| import numpy as np |
| import pymysql |
| from collections import defaultdict |
| import math |
| from typing import List, Tuple, Dict |
| |
class SwingRecall:
    """Swing-based item-to-item recall.

    Item-based collaborative filtering: two items are scored as similar when
    the users who interacted with both of them have little else in common,
    which damps the similarity between popular items.
    """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """Initialise the Swing recall model.

        Args:
            db_config: Keyword arguments passed unchanged to ``pymysql.connect``.
            alpha: Smoothing constant in the per-user-pair Swing weight
                ``1 / (alpha + |I_u & I_v|)``. Larger values shrink every
                weight and flatten the difference between user pairs with
                small and large overlaps. Any value > -2 avoids division by
                zero, because two common users always share at least the two
                items being compared.
        """
        self.db_config = db_config
        self.alpha = alpha
        # item_id -> {neighbour_item_id: similarity}; only pairs with at least
        # two common users (i.e. a non-zero score) are stored.
        self.item_similarity = {}
        # user_id -> set of item ids the user interacted with
        self.user_items = defaultdict(set)
        # item_id -> set of user ids that interacted with the item
        self.item_users = defaultdict(set)
        # True once train() has finished, so an empty model is not retrained
        # (re-querying the database) on every recall() call.
        self._trained = False

    def _get_interaction_data(self):
        """Load user-item interactions from the ``behaviors`` table.

        Reads the distinct (user_id, post_id) pairs whose behavior type is
        'like', 'favorite' or 'comment' and *replaces* ``user_items`` and
        ``item_users`` with them, so retraining does not keep stale data.
        The attributes are only swapped in after the query succeeds.
        """
        user_items = defaultdict(set)
        item_users = defaultdict(set)

        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor's context manager closes it even if execute() raises,
            # and never references an unbound name if cursor() itself fails.
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT DISTINCT user_id, post_id
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment')
                """)
                for user_id, post_id in cursor.fetchall():
                    user_items[user_id].add(post_id)
                    item_users[post_id].add(user_id)
        finally:
            conn.close()

        self.user_items = user_items
        self.item_users = item_users

    def _calculate_swing_similarity(self):
        """Compute the Swing similarity for every co-occurring item pair.

        For items i and j with common users C = U_i & U_j and |C| >= 2::

            sim(i, j) = sum over ordered pairs u != v in C of
                        1 / (alpha + |I_u & I_v|),  divided by |C| * (|C| - 1)

        i.e. the mean Swing weight over ordered pairs of distinct common
        users. Pairs with fewer than two common users score 0 and are not
        stored. Only items that share at least one user are compared, rather
        than all N^2 item pairs. Rebuilds ``item_similarity`` from scratch.

        NOTE(review): the published Swing formula sums the weights without
        dividing by |C| * (|C| - 1); this normalisation means more common
        users do not raise similarity. Kept for compatibility -- confirm it
        is intended.
        """
        from itertools import combinations

        print("开始计算Swing相似度...")

        item_users = self.item_users
        user_items = self.user_items
        alpha = self.alpha
        items = list(item_users)

        # Every known item gets an entry, even when it has no neighbours.
        similarity = {item: {} for item in items}
        # |I_u & I_v| depends only on the user pair, and the same pair recurs
        # across many item pairs, so memoise it.
        overlap_cache = {}
        # Items already used as the outer item; their pairs are all scored.
        done = set()

        for idx, item_i in enumerate(items):
            if idx % 100 == 0:
                print(f"处理进度: {idx}/{len(items)}")

            users_i = item_users[item_i]
            # Candidate neighbours: items touched by any user of item_i.
            neighbours = set()
            for u in users_i:
                neighbours.update(user_items.get(u, ()))
            neighbours.discard(item_i)
            neighbours -= done

            for item_j in neighbours:
                common_users = users_i & item_users.get(item_j, set())
                n = len(common_users)
                if n < 2:  # at least two common users are required
                    continue

                total = 0.0
                for u, v in combinations(common_users, 2):
                    key = frozenset((u, v))
                    overlap = overlap_cache.get(key)
                    if overlap is None:
                        overlap = len(user_items[u] & user_items[v])
                        overlap_cache[key] = overlap
                    total += 1.0 / (alpha + overlap)

                # Each unordered pair stands for both (u, v) and (v, u), so
                # 2 * total / (n * (n - 1)) is the ordered-pair mean.
                score = 2.0 * total / (n * (n - 1))
                similarity[item_i][item_j] = score
                similarity.setdefault(item_j, {})[item_i] = score

            done.add(item_i)

        self.item_similarity = similarity
        print("Swing相似度计算完成")

    def train(self):
        """Load interactions from the database and rebuild the similarity table."""
        self._get_interaction_data()
        self._calculate_swing_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """Recall items for a user from their interaction history.

        A candidate's score is the sum of its Swing similarities to the items
        the user already interacted with; already-seen items are excluded.
        The model is trained on first use if it has not been trained yet.

        Args:
            user_id: User to recall items for.
            num_items: Maximum number of results; values <= 0 yield ``[]``.

        Returns:
            Up to ``num_items`` ``(item_id, score)`` tuples, highest score
            first (ties keep candidate discovery order). Only items with a
            non-zero similarity to the user's history appear, so fewer than
            ``num_items`` may be returned. An unknown user yields ``[]``.
        """
        import heapq

        # Train lazily, but only once: an empty table after training means
        # there is simply no data, not that training is still pending.
        if not getattr(self, "item_similarity", None) and not getattr(self, "_trained", False):
            self.train()

        interacted = self.user_items.get(user_id)
        if not interacted:
            return []

        candidate_scores = defaultdict(float)
        for item_i in interacted:
            for item_j, sim in self.item_similarity.get(item_i, {}).items():
                # Skip items the user has already interacted with.
                if item_j not in interacted:
                    candidate_scores[item_j] += sim

        # nlargest(n, ...) == sorted(..., reverse=True)[:n] for n >= 0 and
        # returns [] for n <= 0 (a plain slice would mis-handle negatives).
        return heapq.nlargest(num_items, candidate_scores.items(), key=lambda kv: kv[1])