推荐系统

Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/models/recall/multi_recall_manager.py b/rhj/backend/app/models/recall/multi_recall_manager.py
new file mode 100644
index 0000000..03cb3f8
--- /dev/null
+++ b/rhj/backend/app/models/recall/multi_recall_manager.py
@@ -0,0 +1,253 @@
+from typing import List, Tuple, Dict, Any
+import numpy as np
+from collections import defaultdict
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+
+class MultiRecallManager:
+    """
+    Multi-channel recall manager.
+    Combines the Swing, hot, ad and UserCF recallers and fuses their candidates.
+    """
+
+    def __init__(self, db_config: dict, recall_config: dict = None):
+        """
+        Build the configuration and instantiate every enabled recaller.
+
+        Args:
+            db_config: database configuration handed to each recaller
+            recall_config: optional per-source overrides merged over the defaults
+        """
+        self.db_config = db_config
+
+        # Default per-source settings
+        default_config = {
+            'swing': {
+                'enabled': True,
+                'num_items': 15,
+                'alpha': 0.5
+            },
+            'hot': {
+                'enabled': True,
+                'num_items': 10
+            },
+            'ad': {
+                'enabled': True,
+                'num_items': 2
+            },
+            'usercf': {
+                'enabled': True,
+                'num_items': 10,
+                'min_common_items': 3,
+                'num_similar_users': 50
+            }
+        }
+
+        # Overlay the caller's overrides; unknown sources are stored as given
+        self.config = default_config
+        if recall_config:
+            for key, value in recall_config.items():
+                if key in self.config:
+                    self.config[key].update(value)
+                else:
+                    self.config[key] = value
+
+        # Instantiate the recallers
+        self.recalls = {}
+        self._init_recalls()
+
+    def _init_recalls(self):
+        """Rebuild self.recalls from the current configuration."""
+        # Start empty so a source disabled through update_config is dropped
+        self.recalls = {}
+        if self.config['swing']['enabled']:
+            self.recalls['swing'] = SwingRecall(
+                self.db_config, alpha=self.config['swing']['alpha']
+            )
+
+        if self.config['hot']['enabled']:
+            self.recalls['hot'] = HotRecall(self.db_config)
+
+        if self.config['ad']['enabled']:
+            self.recalls['ad'] = AdRecall(self.db_config)
+
+        if self.config['usercf']['enabled']:
+            self.recalls['usercf'] = UserCFRecall(
+                self.db_config, min_common_items=self.config['usercf']['min_common_items']
+            )
+
+    def train_all(self):
+        """Train every instantiated recaller in turn."""
+        print("开始训练多路召回模型...")
+
+        for name, recall_model in self.recalls.items():
+            print(f"训练 {name} 召回器...")
+            recall_model.train()
+
+        print("所有召回器训练完成!")
+
+    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
+        """
+        Run every enabled recaller for a user and fuse the candidates.
+
+        Args:
+            user_id: user ID
+            total_items: maximum number of fused items to return
+
+        Returns:
+            Tuple containing:
+            - List of item IDs
+            - List of scores
+            - Dict of recall results by source
+        """
+        recall_results = {}
+        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]
+
+        # Query each recaller
+        for source, recall_model in self.recalls.items():
+            if not self.config[source]['enabled']:
+                continue
+
+            num_items = self.config[source]['num_items']
+
+            # UserCF takes an extra num_similar_users argument
+            if source == 'usercf':
+                items_scores = recall_model.recall(
+                    user_id,
+                    num_items=num_items,
+                    num_similar_users=self.config[source]['num_similar_users']
+                )
+            else:
+                items_scores = recall_model.recall(user_id, num_items=num_items)
+
+            recall_results[source] = items_scores
+
+            # Group candidates by item so overlapping sources can be combined
+            for item_id, score in items_scores:
+                all_candidates[item_id].append((source, score))
+
+        # Fuse the per-source results
+        final_candidates = self._merge_recall_results(all_candidates, total_items)
+
+        # Split into parallel id / score lists
+        item_ids = [item_id for item_id, _ in final_candidates]
+        scores = [score for _, score in final_candidates]
+
+        return item_ids, scores, recall_results
+
+    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]], 
+                            total_items: int) -> List[Tuple[int, float]]:
+        """
+        Fuse candidates from all sources into one ranked list.
+
+        Args:
+            all_candidates: item_id -> list of (source, score) pairs
+            total_items: maximum number of items to return
+
+        Returns:
+            List of (item_id, final_score) tuples, highest score first
+        """
+        # Per-source weights; sources missing here fall back to 0.1 below
+        source_weights = {
+            'swing': 0.3,
+            'hot': 0.2,
+            'ad': 0.1,
+            'usercf': 0.4
+        }
+
+        final_scores = []
+
+        for item_id, source_scores in all_candidates.items():
+            # Weighted average of the scores this item received
+            weighted_score = 0.0
+            total_weight = 0.0
+
+            for source, score in source_scores:
+                weight = source_weights.get(source, 0.1)
+                weighted_score += weight * score
+                total_weight += weight
+
+            # Normalize by the total weight, guarding against division by zero
+            if total_weight > 0:
+                final_score = weighted_score / total_weight
+            else:
+                final_score = 0.0
+
+            # Overlap bonus: +0.1 per contributing source (single-source items get 0.1 too)
+            diversity_bonus = len(source_scores) * 0.1
+            final_score += diversity_bonus
+
+            final_scores.append((item_id, final_score))
+
+        # Sort by final score, descending
+        final_scores.sort(key=lambda x: x[1], reverse=True)
+
+        return final_scores[:total_items]
+
+    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
+        """
+        Collect diagnostic information about recall for a user.
+
+        Args:
+            user_id: user ID
+
+        Returns:
+            Dict with user_id, enabled_recalls, config and, when available, UserCF details
+        """
+        stats = {
+            'user_id': user_id,
+            'enabled_recalls': list(self.recalls.keys()),
+            'config': self.config
+        }
+
+        # UserCF details are optional extras added on a best-effort basis
+        if 'usercf' in self.recalls:
+            try:
+                user_profile = self.recalls['usercf'].get_user_profile(user_id)
+                stats['user_profile'] = user_profile
+                neighbors = self.recalls['usercf'].get_user_neighbors(user_id, 5)
+                stats['similar_users'] = neighbors
+            except Exception as exc:
+                # Report instead of hiding; KeyboardInterrupt/SystemExit still propagate
+                print(f"获取UserCF统计信息失败: {exc}")
+
+        return stats
+
+    def update_config(self, new_config: dict):
+        """
+        Merge new settings into the configuration and rebuild the recallers.
+
+        Args:
+            new_config: per-source settings to merge into the current configuration
+        """
+        for key, value in new_config.items():
+            if key in self.config:
+                self.config[key].update(value)
+            else:
+                self.config[key] = value
+
+        # Recreate the recallers; presumably train_all must be run again - confirm
+        self._init_recalls()
+
+    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
+        """
+        Report the configured item quota of each instantiated recall source.
+
+        Args:
+            user_id: user ID (not used by the current implementation)
+
+        Returns:
+            Dict mapping each source to its configured num_items, or 0 when disabled
+        """
+        breakdown = {}
+
+        for source in self.recalls.keys():
+            if self.config[source]['enabled']:
+                breakdown[source] = self.config[source]['num_items']
+            else:
+                breakdown[source] = 0
+
+        return breakdown