import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from .ad_recall import AdRecall
from .hot_recall import HotRecall
from .swing_recall import SwingRecall
from .usercf_recall import UserCFRecall


class MultiRecallManager:
    """Coordinate several recall strategies and fuse their candidates.

    Builds the enabled recallers (Swing, hot-item, advertisement, UserCF)
    from configuration, trains them, queries each one for a user, and merges
    the per-source candidates into a single ranked list.
    """

    # Fusion weight applied to the score coming from each recall source.
    SOURCE_WEIGHTS = {
        'swing': 0.3,
        'hot': 0.2,
        'ad': 0.1,
        'usercf': 0.4,
    }
    # Weight used for a source that has no entry in SOURCE_WEIGHTS.
    DEFAULT_SOURCE_WEIGHT = 0.1
    # Amount added to an item's score for every source that returned it.
    PER_SOURCE_BONUS = 0.1

    def __init__(self, db_config: dict, recall_config: Optional[dict] = None):
        """Set up configuration and construct the enabled recallers.

        Args:
            db_config: Database configuration, passed unchanged to every
                recaller constructor.
            recall_config: Optional per-source overrides keyed by source
                name. Values for known sources are merged into the built-in
                defaults; unknown keys are stored as given.
        """
        self.db_config = db_config

        # Built-in defaults; a fresh dict per instance so merges never leak
        # between managers.
        self.config = {
            'swing': {
                'enabled': True,
                'num_items': 15,
                'alpha': 0.5,
            },
            'hot': {
                'enabled': True,
                'num_items': 10,
            },
            'ad': {
                'enabled': True,
                'num_items': 2,
            },
            'usercf': {
                'enabled': True,
                'num_items': 10,
                'min_common_items': 3,
                'num_similar_users': 50,
            },
        }

        if recall_config:
            self._merge_config(recall_config)

        self.recalls = {}
        self._init_recalls()

    def _merge_config(self, overrides: dict) -> None:
        """Merge ``overrides`` into ``self.config`` one source at a time."""
        for key, value in overrides.items():
            if isinstance(self.config.get(key), dict):
                self.config[key].update(value)
            else:
                # Shallow-copy dict values so later merges into this entry
                # do not mutate the dict object the caller handed in.
                self.config[key] = dict(value) if isinstance(value, dict) else value

    def _init_recalls(self):
        """(Re)build ``self.recalls`` so it holds exactly the enabled recallers."""
        # Build into a local mapping and swap it in at the end: recallers
        # that are no longer enabled are dropped, and a constructor failure
        # leaves the previous mapping untouched.
        recalls = {}

        if self.config['swing']['enabled']:
            recalls['swing'] = SwingRecall(
                self.db_config,
                alpha=self.config['swing']['alpha'],
            )

        if self.config['hot']['enabled']:
            recalls['hot'] = HotRecall(self.db_config)

        if self.config['ad']['enabled']:
            recalls['ad'] = AdRecall(self.db_config)

        if self.config['usercf']['enabled']:
            recalls['usercf'] = UserCFRecall(
                self.db_config,
                min_common_items=self.config['usercf']['min_common_items'],
            )

        self.recalls = recalls

    def train_all(self):
        """Call ``train()`` on every constructed recaller, printing progress."""
        print("开始训练多路召回模型...")

        for name, recall_model in self.recalls.items():
            print(f"训练 {name} 召回器...")
            recall_model.train()

        print("所有召回器训练完成!")

    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
        """Run every enabled recaller for a user and fuse the results.

        Args:
            user_id: User to recall items for.
            total_items: Maximum number of fused items to return.

        Returns:
            A tuple ``(item_ids, scores, recall_results)`` where ``item_ids``
            and ``scores`` are parallel lists ordered by descending fused
            score, and ``recall_results`` maps each source name to the raw
            ``(item_id, score)`` pairs that source returned.
        """
        recall_results = {}
        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]

        for source, recall_model in self.recalls.items():
            source_config = self.config[source]
            # The flag is re-checked here because self.config can be edited
            # directly after the recallers were constructed.
            if not source_config['enabled']:
                continue

            num_items = source_config['num_items']

            # UserCF additionally takes the neighbourhood size.
            if source == 'usercf':
                items_scores = recall_model.recall(
                    user_id,
                    num_items=num_items,
                    num_similar_users=source_config['num_similar_users'],
                )
            else:
                items_scores = recall_model.recall(user_id, num_items=num_items)

            recall_results[source] = items_scores

            for item_id, score in items_scores:
                all_candidates[item_id].append((source, score))

        final_candidates = self._merge_recall_results(all_candidates, total_items)

        item_ids = [item_id for item_id, _ in final_candidates]
        scores = [score for _, score in final_candidates]

        return item_ids, scores, recall_results

    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]],
                              total_items: int) -> List[Tuple[int, float]]:
        """Fuse per-source scores into one ranked candidate list.

        Each item's score is the weight-normalised average of its source
        scores plus ``PER_SOURCE_BONUS`` for every source that returned it.
        The bonus is also paid for the first source, so it only changes the
        ordering between items returned by different numbers of sources.

        Args:
            all_candidates: Maps item id to ``(source, score)`` pairs.
            total_items: Maximum number of items to return; a value of zero
                or less yields an empty list.

        Returns:
            ``(item_id, final_score)`` tuples sorted by descending score.
        """
        if total_items <= 0:
            # Guard against negative limits, which would otherwise slice
            # "all but the last n" items.
            return []

        final_scores = []

        for item_id, source_scores in all_candidates.items():
            weighted_score = 0.0
            total_weight = 0.0

            for source, score in source_scores:
                weight = self.SOURCE_WEIGHTS.get(source, self.DEFAULT_SOURCE_WEIGHT)
                weighted_score += weight * score
                total_weight += weight

            # Normalise by the weights that actually contributed.
            if total_weight > 0:
                final_score = weighted_score / total_weight
            else:
                final_score = 0.0

            final_score += len(source_scores) * self.PER_SOURCE_BONUS

            final_scores.append((item_id, final_score))

        # list.sort is stable, so ties keep their insertion order.
        final_scores.sort(key=lambda pair: pair[1], reverse=True)

        return final_scores[:total_items]

    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
        """Collect diagnostic information about recall for a user.

        Args:
            user_id: User to report on.

        Returns:
            A dict with ``user_id``, ``enabled_recalls`` (names of the
            constructed recallers) and ``config`` (the live configuration
            dict, not a copy). When a UserCF recaller exists and its lookups
            succeed, ``user_profile`` and ``similar_users`` are added too.
        """
        stats = {
            'user_id': user_id,
            'enabled_recalls': list(self.recalls.keys()),
            'config': self.config,
        }

        usercf = self.recalls.get('usercf')
        if usercf is not None:
            try:
                stats['user_profile'] = usercf.get_user_profile(user_id)
                stats['similar_users'] = usercf.get_user_neighbors(user_id, 5)
            except Exception:
                # Best-effort extras: keep returning the base stats, but
                # leave a trace instead of hiding the failure.
                logging.getLogger(__name__).warning(
                    "Failed to collect UserCF stats for user %s",
                    user_id,
                    exc_info=True,
                )

        return stats

    def update_config(self, new_config: dict):
        """Merge new settings into the configuration and rebuild the recallers.

        Every enabled recaller is constructed again, so the new instances
        have not had ``train()`` called on them.

        Args:
            new_config: Per-source overrides, in the same shape as the
                ``recall_config`` constructor argument.
        """
        self._merge_config(new_config)
        self._init_recalls()

    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
        """Report the configured item quota of each constructed recaller.

        Args:
            user_id: Unused; the figures come from configuration only.

        Returns:
            Maps source name to its configured ``num_items``, or ``0`` when
            the source is currently flagged as disabled.
        """
        breakdown = {}

        for source in self.recalls:
            source_config = self.config[source]
            breakdown[source] = source_config['num_items'] if source_config['enabled'] else 0

        return breakdown