推荐系统
Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/models/recall/multi_recall_manager.py b/rhj/backend/app/models/recall/multi_recall_manager.py
new file mode 100644
index 0000000..03cb3f8
--- /dev/null
+++ b/rhj/backend/app/models/recall/multi_recall_manager.py
@@ -0,0 +1,271 @@
+from typing import List, Tuple, Dict, Any
+import numpy as np
+from collections import defaultdict
+
+from .swing_recall import SwingRecall
+from .hot_recall import HotRecall
+from .ad_recall import AdRecall
+from .usercf_recall import UserCFRecall
+
+class MultiRecallManager:
+    """
+    Multi-channel recall manager.
+
+    Queries the Swing, hot, ad and UserCF recallers that are enabled in the
+    configuration and fuses their per-source scores into one ranked list.
+    """
+
+    # Fusion weight applied to the score reported by each recall source.
+    SOURCE_WEIGHTS = {
+        'swing': 0.3,
+        'hot': 0.2,
+        'ad': 0.1,
+        'usercf': 0.4
+    }
+    # Weight used for a source that has no entry in SOURCE_WEIGHTS.
+    FALLBACK_WEIGHT = 0.1
+    # Added to an item's fused score once per source that recalled it.
+    SOURCE_BONUS = 0.1
+
+    def __init__(self, db_config: dict, recall_config: dict = None):
+        """
+        Set up the configuration and create every enabled recaller.
+
+        Args:
+            db_config: Database configuration passed to each recaller.
+            recall_config: Optional per-source overrides merged over the
+                defaults (enable flag, item quota, recaller parameters).
+        """
+        self.db_config = db_config
+
+        # Built per instance, so merged overrides never leak between managers.
+        self.config = {
+            'swing': {
+                'enabled': True,
+                'num_items': 15,
+                'alpha': 0.5
+            },
+            'hot': {
+                'enabled': True,
+                'num_items': 10
+            },
+            'ad': {
+                'enabled': True,
+                'num_items': 2
+            },
+            'usercf': {
+                'enabled': True,
+                'num_items': 10,
+                'min_common_items': 3,
+                'num_similar_users': 50
+            }
+        }
+        if recall_config:
+            self._merge_config(recall_config)
+
+        self.recalls = {}
+        self._init_recalls()
+
+    def _merge_config(self, overrides: dict):
+        """
+        Merge configuration overrides into self.config in place.
+
+        Args:
+            overrides: Mapping of source name to its settings. Settings of a
+                known source update the existing entry key by key; an unknown
+                source is added as a new entry.
+        """
+        for key, value in overrides.items():
+            if key in self.config:
+                self.config[key].update(value)
+            elif isinstance(value, dict):
+                # Copy so later changes to the caller's dict cannot silently
+                # alter this manager's configuration.
+                self.config[key] = dict(value)
+            else:
+                self.config[key] = value
+
+    def _init_recalls(self):
+        """
+        Create the recaller instances for the current configuration.
+
+        The registry is rebuilt from scratch, so a source that has been
+        disabled is dropped instead of lingering in self.recalls. Every
+        enabled source gets a new instance; call train_all() again afterwards
+        if the recallers need training.
+        """
+        recalls = {}
+
+        if self.config['swing']['enabled']:
+            recalls['swing'] = SwingRecall(
+                self.db_config,
+                alpha=self.config['swing']['alpha']
+            )
+
+        if self.config['hot']['enabled']:
+            recalls['hot'] = HotRecall(self.db_config)
+
+        if self.config['ad']['enabled']:
+            recalls['ad'] = AdRecall(self.db_config)
+
+        if self.config['usercf']['enabled']:
+            recalls['usercf'] = UserCFRecall(
+                self.db_config,
+                min_common_items=self.config['usercf']['min_common_items']
+            )
+
+        self.recalls = recalls
+
+    def train_all(self):
+        """Call train() on every registered recaller, in registration order."""
+        print("开始训练多路召回模型...")
+
+        for name, recall_model in self.recalls.items():
+            print(f"训练 {name} 召回器...")
+            recall_model.train()
+
+        print("所有召回器训练完成!")
+
+    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
+        """
+        Run every enabled recaller for a user and fuse the results.
+
+        Args:
+            user_id: ID of the user to recall items for.
+            total_items: Maximum number of fused items to return.
+
+        Returns:
+            Tuple containing:
+                - item IDs ordered by fused score, highest first
+                - fused scores aligned with the item IDs
+                - raw results of each recaller keyed by source name
+        """
+        recall_results = {}
+        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]
+
+        for source, recall_model in self.recalls.items():
+            source_config = self.config[source]
+            # The flag is re-checked because self.config can be edited after
+            # the recaller was created.
+            if not source_config['enabled']:
+                continue
+
+            extra_args = {}
+            if source == 'usercf':
+                # UserCF additionally receives the configured neighbour count.
+                extra_args['num_similar_users'] = source_config['num_similar_users']
+
+            items_scores = recall_model.recall(
+                user_id,
+                num_items=source_config['num_items'],
+                **extra_args
+            )
+            recall_results[source] = items_scores
+
+            for item_id, score in items_scores:
+                all_candidates[item_id].append((source, score))
+
+        final_candidates = self._merge_recall_results(all_candidates, total_items)
+
+        item_ids = [item_id for item_id, _ in final_candidates]
+        scores = [score for _, score in final_candidates]
+
+        return item_ids, scores, recall_results
+
+    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]],
+                              total_items: int) -> List[Tuple[int, float]]:
+        """
+        Fuse the per-source scores of every candidate into one ranked list.
+
+        An item's fused score is the weighted average of its source scores
+        (weights from SOURCE_WEIGHTS) plus SOURCE_BONUS for each source that
+        recalled it, so an item found by several sources gets a larger bonus.
+
+        Args:
+            all_candidates: item_id -> list of (source, score) pairs.
+            total_items: Maximum number of items to return; a value of zero
+                or less returns an empty list.
+
+        Returns:
+            List of (item_id, final_score) tuples, highest score first.
+        """
+        final_scores = []
+
+        for item_id, source_scores in all_candidates.items():
+            weighted_score = 0.0
+            total_weight = 0.0
+
+            for source, score in source_scores:
+                weight = self.SOURCE_WEIGHTS.get(source, self.FALLBACK_WEIGHT)
+                weighted_score += weight * score
+                total_weight += weight
+
+            # An item without any source entry has no weight to divide by.
+            final_score = weighted_score / total_weight if total_weight > 0 else 0.0
+            final_score += len(source_scores) * self.SOURCE_BONUS
+
+            final_scores.append((item_id, final_score))
+
+        final_scores.sort(key=lambda x: x[1], reverse=True)
+
+        # Clamp the limit: a negative slice bound would drop items from the
+        # tail instead of returning nothing.
+        return final_scores[:max(total_items, 0)]
+
+    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
+        """
+        Collect diagnostic information about the recall setup for a user.
+
+        Args:
+            user_id: ID of the user.
+
+        Returns:
+            Dict with 'user_id', 'enabled_recalls' and 'config'. When UserCF
+            is registered it also holds 'user_profile' and 'similar_users',
+            or 'usercf_error' describing why they could not all be fetched.
+        """
+        stats = {
+            'user_id': user_id,
+            'enabled_recalls': list(self.recalls.keys()),
+            'config': self.config
+        }
+
+        if 'usercf' in self.recalls:
+            usercf = self.recalls['usercf']
+            try:
+                stats['user_profile'] = usercf.get_user_profile(user_id)
+                stats['similar_users'] = usercf.get_user_neighbors(user_id, 5)
+            except Exception as exc:
+                # Stay best-effort, but report the failure instead of hiding
+                # it; a bare except would also swallow KeyboardInterrupt.
+                stats['usercf_error'] = repr(exc)
+
+        return stats
+
+    def update_config(self, new_config: dict):
+        """
+        Merge new settings into the configuration and rebuild the recallers.
+
+        Args:
+            new_config: Mapping of source name to the settings to change.
+        """
+        self._merge_config(new_config)
+        self._init_recalls()
+
+    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
+        """
+        Report the configured item quota of each registered recall source.
+
+        The values come from the configuration, not from an actual recall run.
+
+        Args:
+            user_id: ID of the user; not used by the current implementation.
+
+        Returns:
+            Dict of source name to its 'num_items' setting, or 0 when the
+            source is disabled in the configuration.
+        """
+        return {
+            source: self.config[source]['num_items'] if self.config[source]['enabled'] else 0
+            for source in self.recalls
+        }