推荐系统
Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/models/recall/swing_recall.py b/rhj/backend/app/models/recall/swing_recall.py
new file mode 100644
index 0000000..bf7fdd6
--- /dev/null
+++ b/rhj/backend/app/models/recall/swing_recall.py
@@ -0,0 +1,126 @@
+import numpy as np
+import pymysql
+from collections import defaultdict
+import math
+from typing import List, Tuple, Dict
+
class SwingRecall:
    """Item-to-item recall based on a Swing-style similarity.

    Two items are considered similar when at least two users interacted with
    both of them. Every pair of such shared users contributes
    ``1 / (alpha + overlap)``, where ``overlap`` is the number of items the
    two users have in common, so pairs of users with broad, heavily
    overlapping histories contribute less. The contributions are averaged
    over the number of user pairs.

    Attributes:
        db_config: Keyword arguments forwarded to ``pymysql.connect``.
        alpha: Smoothing constant added to the user-pair overlap.
        item_similarity: ``{item: {other_item: score}}``; only pairs that
            actually received a score are stored, and the map is symmetric.
        user_items: ``{user_id: set of item ids}``.
        item_users: ``{item_id: set of user ids}``.
    """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """Create an untrained model.

        Args:
            db_config: Database connection settings, passed as keyword
                arguments to ``pymysql.connect``.
            alpha: Constant added to the denominator of every user-pair
                weight; larger values shrink all weights.
        """
        self.db_config = db_config
        self.alpha = alpha
        self.item_similarity = {}
        self.user_items = defaultdict(set)
        self.item_users = defaultdict(set)
        # Set by train(); stops recall() from re-querying the database on
        # every call when training legitimately produced an empty model.
        self._trained = False

    def _get_interaction_data(self):
        """Load user-item interactions and rebuild both lookup indexes.

        The indexes are built from scratch and only swapped in once the query
        succeeded, so a failed reload leaves the previous data intact and a
        successful one does not keep interactions that no longer exist.
        """
        user_items = defaultdict(set)
        item_users = defaultdict(set)

        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor is managed by ``with`` so that a failure while
            # creating it cannot trigger a second error during cleanup.
            with conn.cursor() as cursor:
                # Likes, favorites and comments all count as an interaction.
                cursor.execute("""
                    SELECT DISTINCT user_id, post_id
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment')
                """)
                interactions = cursor.fetchall()
        finally:
            conn.close()

        # NOTE(review): rows are unpacked as 2-tuples, which assumes the
        # default tuple cursor; confirm db_config sets no dict cursor class.
        for user_id, post_id in interactions:
            user_items[user_id].add(post_id)
            item_users[post_id].add(user_id)

        self.user_items = user_items
        self.item_users = item_users

    def _swing_score(self, common_users):
        """Return the averaged Swing weight over all pairs of ``common_users``.

        Args:
            common_users: Users that interacted with both items of a pair;
                must contain at least two users.

        Returns:
            Mean of ``1 / (alpha + |items(u) & items(v)|)`` over all
            unordered user pairs.
        """
        members = list(common_users)
        total = 0.0
        # The weight is symmetric in (u, v), so each unordered pair is
        # visited once and the sum is averaged over the unordered pairs.
        for idx, user_u in enumerate(members):
            items_u = self.user_items[user_u]
            for user_v in members[idx + 1:]:
                overlap = len(items_u & self.user_items[user_v])
                total += 1.0 / (self.alpha + overlap)
        pair_count = len(members) * (len(members) - 1) // 2
        return total / pair_count

    def _calculate_swing_similarity(self):
        """Build the symmetric item-item similarity map.

        Only item pairs that share at least two users are scored and stored.
        Candidate pairs are discovered through the users of each item rather
        than by enumerating every pair of items, which assumes that
        ``user_items`` and ``item_users`` mirror each other.
        """
        print("开始计算Swing相似度...")

        items = list(self.item_users.keys())
        position = {item: idx for idx, item in enumerate(items)}
        # Every known item gets an entry, even if it ends up with no neighbours.
        similarity = {item: {} for item in items}

        for i, item_i in enumerate(items):
            if i % 100 == 0:
                print(f"处理进度: {i}/{len(items)}")

            users_i = self.item_users[item_i]
            if len(users_i) < 2:
                # Fewer than two users can never form a user pair.
                continue

            # Items that co-occur with item_i and come later in ``items``;
            # earlier ones were already handled when they were item_i.
            later = set()
            for user in users_i:
                for item in self.user_items.get(user, ()):
                    idx = position.get(item)
                    if idx is not None and idx > i:
                        later.add(idx)

            # Sorted so neighbours are inserted in a deterministic order.
            for j in sorted(later):
                item_j = items[j]
                common_users = users_i & self.item_users[item_j]
                if len(common_users) < 2:
                    continue
                score = self._swing_score(common_users)
                similarity[item_i][item_j] = score
                similarity[item_j][item_i] = score

        self.item_similarity = similarity
        print("Swing相似度计算完成")

    def train(self):
        """Load the interaction data and compute the similarity map."""
        self._get_interaction_data()
        self._calculate_swing_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """Return the items most similar to what the user interacted with.

        Each candidate's score is the sum of its similarities to the items in
        the user's history. Items the user already interacted with, and items
        without any similarity to the history, are not returned.

        Args:
            user_id: User to recommend for.
            num_items: Maximum number of items to return.

        Returns:
            Up to ``num_items`` ``(item_id, score)`` tuples ordered by
            descending score; empty for unknown users or a non-positive
            ``num_items``.
        """
        # Train lazily on first use, but not again after a training run that
        # produced an empty model.
        already_trained = getattr(self, "_trained", False)
        if not already_trained and not getattr(self, "item_similarity", None):
            self.train()

        if num_items <= 0:
            return []

        seen_items = self.user_items.get(user_id)
        if not seen_items:
            return []

        candidate_scores = defaultdict(float)
        for item_i in seen_items:
            for item_j, similarity in self.item_similarity.get(item_i, {}).items():
                if item_j not in seen_items:
                    candidate_scores[item_j] += similarity

        ranked = sorted(candidate_scores.items(), key=lambda pair: pair[1], reverse=True)
        return ranked[:num_items]