推荐系统

Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/models/recall/swing_recall.py b/rhj/backend/app/models/recall/swing_recall.py
new file mode 100644
index 0000000..bf7fdd6
--- /dev/null
+++ b/rhj/backend/app/models/recall/swing_recall.py
@@ -0,0 +1,126 @@
+import numpy as np
+import pymysql
+from collections import defaultdict
+import math
+from typing import List, Tuple, Dict
+
+class SwingRecall:
+    """
+    Swing召回算法实现
+    基于物品相似度的协同过滤算法,能够有效处理热门物品的问题
+    """
+    
+    def __init__(self, db_config: dict, alpha: float = 0.5):
+        """
+        初始化Swing召回模型
+        
+        Args:
+            db_config: 数据库配置
+            alpha: 控制热门物品惩罚的参数,值越大惩罚越强
+        """
+        self.db_config = db_config
+        self.alpha = alpha
+        self.item_similarity = {}
+        self.user_items = defaultdict(set)
+        self.item_users = defaultdict(set)
+        
+    def _get_interaction_data(self):
+        """获取用户-物品交互数据"""
+        conn = pymysql.connect(**self.db_config)
+        try:
+            cursor = conn.cursor()
+            # 获取用户行为数据(点赞、收藏、评论等)
+            cursor.execute("""
+                SELECT DISTINCT user_id, post_id
+                FROM behaviors
+                WHERE type IN ('like', 'favorite', 'comment')
+            """)
+            interactions = cursor.fetchall()
+            
+            for user_id, post_id in interactions:
+                self.user_items[user_id].add(post_id)
+                self.item_users[post_id].add(user_id)
+                
+        finally:
+            cursor.close()
+            conn.close()
+    
+    def _calculate_swing_similarity(self):
+        """计算Swing相似度矩阵"""
+        print("开始计算Swing相似度...")
+        
+        # 获取所有物品对
+        items = list(self.item_users.keys())
+        
+        for i, item_i in enumerate(items):
+            if i % 100 == 0:
+                print(f"处理进度: {i}/{len(items)}")
+                
+            self.item_similarity[item_i] = {}
+            
+            for item_j in items[i+1:]:
+                # 获取同时交互过两个物品的用户
+                common_users = self.item_users[item_i] & self.item_users[item_j]
+                
+                if len(common_users) < 2:  # 需要至少2个共同用户
+                    similarity = 0.0
+                else:
+                    # 计算Swing相似度
+                    similarity = 0.0
+                    for u in common_users:
+                        for v in common_users:
+                            if u != v:
+                                # Swing算法的核心公式
+                                swing_weight = 1.0 / (self.alpha + len(self.user_items[u] & self.user_items[v]))
+                                similarity += swing_weight
+                    
+                    # 归一化
+                    similarity = similarity / (len(common_users) * (len(common_users) - 1))
+                
+                self.item_similarity[item_i][item_j] = similarity
+                # 对称性
+                if item_j not in self.item_similarity:
+                    self.item_similarity[item_j] = {}
+                self.item_similarity[item_j][item_i] = similarity
+        
+        print("Swing相似度计算完成")
+    
+    def train(self):
+        """训练Swing模型"""
+        self._get_interaction_data()
+        self._calculate_swing_similarity()
+    
+    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
+        """
+        为用户召回相似物品
+        
+        Args:
+            user_id: 用户ID
+            num_items: 召回物品数量
+            
+        Returns:
+            List of (item_id, score) tuples
+        """
+        # 如果尚未训练,先进行训练
+        if not hasattr(self, 'item_similarity') or not self.item_similarity:
+            self.train()
+        
+        if user_id not in self.user_items:
+            return []
+        
+        # 获取用户历史交互的物品
+        user_interacted_items = self.user_items[user_id]
+        
+        # 计算候选物品的分数
+        candidate_scores = defaultdict(float)
+        
+        for item_i in user_interacted_items:
+            if item_i in self.item_similarity:
+                for item_j, similarity in self.item_similarity[item_i].items():
+                    # 排除用户已经交互过的物品
+                    if item_j not in user_interacted_items:
+                        candidate_scores[item_j] += similarity
+        
+        # 按分数排序并返回top-N
+        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
+        return sorted_candidates[:num_items]