blob: 03cb3f8a7ad224b0381f55d2f8df01c720328af1 [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001from typing import List, Tuple, Dict, Any
2import numpy as np
3from collections import defaultdict
4
5from .swing_recall import SwingRecall
6from .hot_recall import HotRecall
7from .ad_recall import AdRecall
8from .usercf_recall import UserCFRecall
9
class MultiRecallManager:
    """Multi-channel recall manager.

    Combines several recall strategies -- Swing, popularity ("hot"), ad and
    UserCF -- runs each enabled one for a user and fuses their candidates
    into a single ranked list.
    """

    # Default per-source fusion weights used by ``_merge_recall_results``.
    # A source's config section may override its weight via a 'weight' key.
    DEFAULT_SOURCE_WEIGHTS = {
        'swing': 0.3,
        'hot': 0.2,
        'ad': 0.1,
        'usercf': 0.4,
    }
    # Weight for a source with neither a default nor a configured weight.
    FALLBACK_SOURCE_WEIGHT = 0.1
    # Score bonus added per *distinct* source that retrieved an item.
    DIVERSITY_BONUS_PER_SOURCE = 0.1

    def __init__(self, db_config: dict, recall_config: dict = None):
        """Create the manager and instantiate every enabled recaller.

        Args:
            db_config: Database configuration, passed unchanged to each
                recaller's constructor.
            recall_config: Optional per-source overrides (may be None). Each
                key is a source name mapping to a dict of settings such as
                'enabled' and 'num_items'. Sections for known sources are
                merged into the defaults; unknown keys are added as-is.
        """
        self.db_config = db_config

        # Built fresh on every call, so instances never share config dicts.
        self.config = {
            'swing': {
                'enabled': True,
                'num_items': 15,
                'alpha': 0.5,
            },
            'hot': {
                'enabled': True,
                'num_items': 10,
            },
            'ad': {
                'enabled': True,
                'num_items': 2,
            },
            'usercf': {
                'enabled': True,
                'num_items': 10,
                'min_common_items': 3,
                'num_similar_users': 50,
            },
        }
        if recall_config:
            self._merge_config(recall_config)

        self.recalls = {}
        self._init_recalls()

    def _merge_config(self, overrides: dict):
        """Merge per-source ``overrides`` into ``self.config`` in place.

        Existing sections are updated key by key. New sections are stored as
        a copy (when they are dicts), so later in-place updates never mutate
        the caller's object.
        """
        for key, value in overrides.items():
            if key in self.config:
                self.config[key].update(value)
            else:
                self.config[key] = dict(value) if isinstance(value, dict) else value

    def _init_recalls(self):
        """Rebuild ``self.recalls`` from the current config.

        The mapping is rebuilt from scratch, so a source disabled after
        construction (via ``update_config``) is removed rather than left
        behind. Recallers are re-created, so state held by previous instances
        (e.g. from an earlier ``train_all``) is not carried over.
        """
        recalls = {}

        if self.config['swing']['enabled']:
            recalls['swing'] = SwingRecall(
                self.db_config,
                alpha=self.config['swing']['alpha']
            )

        if self.config['hot']['enabled']:
            recalls['hot'] = HotRecall(self.db_config)

        if self.config['ad']['enabled']:
            recalls['ad'] = AdRecall(self.db_config)

        if self.config['usercf']['enabled']:
            recalls['usercf'] = UserCFRecall(
                self.db_config,
                min_common_items=self.config['usercf']['min_common_items']
            )

        self.recalls = recalls

    def train_all(self):
        """Call ``train()`` on every registered recaller, printing progress."""
        print("开始训练多路召回模型...")

        for name, recall_model in self.recalls.items():
            print(f"训练 {name} 召回器...")
            recall_model.train()

        print("所有召回器训练完成!")

    def recall(
        self, user_id: int, total_items: int = 200
    ) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
        """Run every enabled recaller for ``user_id`` and fuse the results.

        Args:
            user_id: The user to recall items for.
            total_items: Maximum number of fused items to return.

        Returns:
            A tuple of:
            - item IDs, best fused score first;
            - the fused scores, aligned with the item IDs;
            - the raw ``(item_id, score)`` list from each source, keyed by
              source name.
        """
        recall_results = {}
        all_candidates = defaultdict(list)  # item_id -> [(source, score), ...]

        for source, recall_model in self.recalls.items():
            source_config = self.config[source]
            if not source_config['enabled']:
                continue

            num_items = source_config['num_items']

            # UserCF additionally takes the neighbourhood size.
            if source == 'usercf':
                items_scores = recall_model.recall(
                    user_id,
                    num_items=num_items,
                    num_similar_users=source_config['num_similar_users']
                )
            else:
                items_scores = recall_model.recall(user_id, num_items=num_items)

            # Materialise once: the result is both stored and iterated below.
            items_scores = list(items_scores)
            recall_results[source] = items_scores

            for item_id, score in items_scores:
                all_candidates[item_id].append((source, score))

        final_candidates = self._merge_recall_results(all_candidates, total_items)

        item_ids = [item_id for item_id, _ in final_candidates]
        scores = [score for _, score in final_candidates]

        return item_ids, scores, recall_results

    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]],
                              total_items: int) -> List[Tuple[int, float]]:
        """Fuse per-source candidates into one ranked list.

        Each item's score is the weight-normalised average of its per-source
        scores, plus ``DIVERSITY_BONUS_PER_SOURCE`` for every distinct source
        that retrieved it. Raw scores are not rescaled per source.
        NOTE(review): this assumes all sources emit comparable score ranges;
        confirm.

        Args:
            all_candidates: item_id -> list of (source, score) pairs.
            total_items: Maximum number of items to return; values <= 0
                yield an empty list.

        Returns:
            List of (item_id, fused_score) tuples, highest score first.
        """
        if total_items <= 0:
            return []

        source_weights = dict(self.DEFAULT_SOURCE_WEIGHTS)
        for source, section in self.config.items():
            if isinstance(section, dict) and 'weight' in section:
                source_weights[source] = section['weight']

        final_scores = []

        for item_id, source_scores in all_candidates.items():
            weighted_score = 0.0
            total_weight = 0.0

            for source, score in source_scores:
                weight = source_weights.get(source, self.FALLBACK_SOURCE_WEIGHT)
                weighted_score += weight * score
                total_weight += weight

            final_score = weighted_score / total_weight if total_weight > 0 else 0.0

            # Count distinct sources, so a source returning the same item
            # twice does not earn a double bonus.
            num_sources = len({source for source, _ in source_scores})
            final_score += num_sources * self.DIVERSITY_BONUS_PER_SOURCE

            final_scores.append((item_id, final_score))

        final_scores.sort(key=lambda x: x[1], reverse=True)

        return final_scores[:total_items]

    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
        """Collect diagnostic information about the recall setup for a user.

        Args:
            user_id: The user to describe.

        Returns:
            A dict with 'user_id', 'enabled_recalls' and a copy of 'config'.
            When UserCF is registered, it also contains 'user_profile' and
            'similar_users' (top 5). If either lookup raises, 'usercf_error'
            holds the message instead; values fetched before the failure are
            kept.
        """
        stats = {
            'user_id': user_id,
            'enabled_recalls': list(self.recalls.keys()),
            # Copy each section so callers cannot mutate the live config.
            'config': {
                key: dict(value) if isinstance(value, dict) else value
                for key, value in self.config.items()
            },
        }

        if 'usercf' in self.recalls:
            usercf = self.recalls['usercf']
            try:
                stats['user_profile'] = usercf.get_user_profile(user_id)
                stats['similar_users'] = usercf.get_user_neighbors(user_id, 5)
            except Exception as exc:
                # Best effort: a failing lookup must not break stats
                # reporting, but the failure is surfaced rather than hidden.
                stats['usercf_error'] = str(exc)

        return stats

    def update_config(self, new_config: dict):
        """Merge ``new_config`` into the current config and rebuild recallers.

        Args:
            new_config: Per-source overrides, in the same format as
                ``recall_config`` in ``__init__``.
        """
        self._merge_config(new_config)
        self._init_recalls()

    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
        """Return the configured per-source item quota.

        The values come from the config ('num_items', or 0 if disabled), not
        from an actual recall run; ``user_id`` is currently unused.

        Args:
            user_id: User ID (kept for interface compatibility).

        Returns:
            Dict mapping each registered source to its configured item count.
        """
        return {
            source: self.config[source]['num_items'] if self.config[source]['enabled'] else 0
            for source in self.recalls
        }
253 return breakdown