blob: 0547f7b2cf8cec417110b4c1f7b2b585ba396293 [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001import torch
2import pymysql
3import numpy as np
4import random
5from app.models.recommend.LightGCN import LightGCN
6from app.models.recall import MultiRecallManager
7from app.services.lightgcn_scorer import LightGCNScorer
8from app.utils.parse_args import args
9from app.utils.data_loader import EdgeListData
10from app.utils.graph_build import build_user_post_graph
11from config import Config
12
class RecommendationService:
    """Post recommendation service backed by the ``redbook`` MySQL database.

    Pipeline used by ``run_inference`` / ``get_recommendations``:

    1. multi-channel recall (swing / hot / ad / usercf) via ``MultiRecallManager``;
    2. optional re-scoring of the recalled candidates with LightGCN;
    3. MMR re-ranking on tag similarity, with every 5th slot reserved for an ad.

    If multi-channel recall raises, plain LightGCN scoring is used instead; when
    no personalised candidates exist, a heat-based cold start is returned.
    """

    def __init__(self):
        # Database connection settings (redbook database).
        # NOTE(review): host and credentials are hard-coded in source; consider
        # sourcing them from ``Config`` or environment variables.
        self.db_config = {
            'host': '10.126.59.25',
            'port': 3306,
            'user': 'root',
            'password': '123456',
            'database': 'redbook',  # use the redbook database
            'charset': 'utf8mb4'
        }

        # Model configuration, written onto the shared module-level ``args``.
        args.device = self._pick_device()
        args.data_path = './app/user_post_graph.txt'  # user-post interaction graph file
        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'

        self.topk = 2  # default number of recommendations

        # Multi-channel recall manager, created lazily by init_multi_recall().
        self.multi_recall = None
        self.multi_recall_enabled = True  # default for run_inference(use_multi_recall=None)

        # LightGCN scorer, created lazily by init_lightgcn_scorer().
        self.lightgcn_scorer = None
        self.use_lightgcn_rerank = True  # re-score multi-recall candidates with LightGCN

        # Per-channel multi-recall configuration.
        self.recall_config = {
            'swing': {
                'enabled': True,
                'num_items': 20,  # raised recall count
                'alpha': 0.5
            },
            'hot': {
                'enabled': True,
                'num_items': 15  # raised hot-recall count
            },
            'ad': {
                'enabled': True,
                'num_items': 5  # raised ad-recall count
            },
            'usercf': {
                'enabled': True,
                'num_items': 15,
                'min_common_items': 1,  # threshold lowered from 3 to 1
                'num_similar_users': 20  # fewer similar users for efficiency
            }
        }

    @staticmethod
    def _pick_device(preferred_index=7):
        """Return the torch device string to run LightGCN on.

        Uses ``cuda:<preferred_index>`` when that GPU exists, otherwise
        ``cuda:0`` if CUDA is available at all, otherwise ``'cpu'``.
        """
        if not torch.cuda.is_available():
            return 'cpu'
        if torch.cuda.device_count() > preferred_index:
            return f'cuda:{preferred_index}'
        # Fewer GPUs are visible than the preferred index requires; selecting a
        # non-existent ordinal would fail when the model is moved to the device.
        return 'cuda:0'

    def calculate_tag_similarity(self, tags1, tags2):
        """Return the Jaccard similarity (0-1) of two comma-separated tag strings.

        Tags are stripped of surrounding whitespace and empty tags are ignored.
        Returns 0.0 if either input is empty or contains no non-blank tag.
        """
        if not tags1 or not tags2:
            return 0.0

        # Convert the tag strings into sets.
        set1 = set(tag.strip() for tag in tags1.split(',') if tag.strip())
        set2 = set(tag.strip() for tag in tags2.split(',') if tag.strip())

        if not set1 or not set2:
            return 0.0

        # Jaccard similarity: |intersection| / |union|.
        intersection = len(set1.intersection(set2))
        union = len(set1.union(set2))

        return intersection / union if union > 0 else 0.0

    def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
        """Re-rank candidates with MMR while reserving every 5th slot for an ad.

        Args:
            post_ids: candidate post IDs.
            scores: relevance score for each entry of ``post_ids`` (same order).
            theta: relevance/diversity trade-off; each candidate's MMR score is
                ``theta * relevance - (1 - theta) * max_tag_similarity`` against
                the posts already selected (0.5 weighs both equally).
            target_size: number of results wanted; defaults to ``len(post_ids)``.

        Returns:
            ``(post_ids, scores)`` after re-ranking. Slots 5, 10, 15, ... (1-based)
            hold ads while ad candidates remain. Ad candidates are the inputs
            flagged ``is_advertisement`` plus up to 50 extra published ads
            fetched from the DB (scored as ``heat / 1000``). Inputs with at most
            one element are returned unchanged without touching the DB.
        """
        if target_size is None:
            target_size = len(post_ids)

        if len(post_ids) <= 1:
            return post_ids, scores

        post_tags, post_is_ad, extra_ad_rows = self._fetch_rerank_metadata(post_ids)

        # Split the inputs into normal posts and ads.
        normal_candidates = []
        ad_candidates = []
        for post_id, score in zip(post_ids, scores):
            if post_is_ad[post_id]:
                ad_candidates.append((post_id, score))
            else:
                normal_candidates.append((post_id, score))

        # Add the extra ads fetched from the DB.
        for ad_id, heat in extra_ad_rows:
            post_tags[ad_id] = ""  # extra ads carry no tag info for now
            post_is_ad[ad_id] = True
            ad_candidates.append((ad_id, float(heat or 0) / 1000.0))  # map heat to a score

        normal_candidates.sort(key=lambda x: x[1], reverse=True)
        ad_candidates.sort(key=lambda x: x[1], reverse=True)

        # NOTE(review): relevance scores are not normalised; if they are much
        # larger than 1 (e.g. raw model scores) the similarity penalty, which is
        # in [0, 1], may have little effect on the ordering.
        selected = self._mmr_select(normal_candidates, ad_candidates, post_tags,
                                    theta, target_size)

        reranked_post_ids = [post_id for post_id, _ in selected]
        reranked_scores = [score for _, score in selected]
        return reranked_post_ids, reranked_scores

    def _fetch_rerank_metadata(self, post_ids):
        """Load tag strings and ad flags for ``post_ids`` plus extra ad candidates.

        Returns:
            ``(post_tags, post_is_ad, extra_ad_rows)``:
            ``post_tags`` maps every ID in ``post_ids`` to a comma-joined tag
            string (``""`` when none); ``post_is_ad`` maps every ID to a bool.
            IDs not found among published posts default to ``""`` / ``False``.
            ``extra_ad_rows`` holds up to 50 ``(id, heat)`` rows of published ads
            not in ``post_ids``, hottest first.
        """
        format_strings = ','.join(['%s'] * len(post_ids))
        post_tags = {}
        post_is_ad = {}

        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                # Tags and ad flag of every candidate post.
                cursor.execute(
                    f"""SELECT p.id, p.is_advertisement,
                    COALESCE(GROUP_CONCAT(t.name), '') as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags t ON pt.tag_id = t.id
                    WHERE p.id IN ({format_strings}) AND p.status = 'published'
                    GROUP BY p.id, p.is_advertisement""",
                    tuple(post_ids)
                )
                for post_id, is_ad, tags in cursor.fetchall():
                    post_tags[post_id] = tags or ""
                    post_is_ad[post_id] = bool(is_ad)

                # Additional published ads that are not already candidates.
                cursor.execute("""
                    SELECT id, heat FROM posts
                    WHERE is_advertisement = 1 AND status = 'published'
                    AND id NOT IN ({})
                    ORDER BY heat DESC
                    LIMIT 50
                """.format(format_strings), tuple(post_ids))
                extra_ad_rows = cursor.fetchall()
        finally:
            conn.close()

        # Defaults for candidates the query did not return.
        for post_id in post_ids:
            if post_id not in post_tags:
                post_tags[post_id] = ""
                post_is_ad[post_id] = False

        return post_tags, post_is_ad, extra_ad_rows

    def _mmr_select(self, normal_candidates, ad_candidates, post_tags, theta,
                    target_size, window=10):
        """Greedily pick up to ``target_size`` items, placing an ad in every 5th slot.

        ``normal_candidates`` and ``ad_candidates`` are ``(post_id, score)`` lists
        already sorted by score descending; ``normal_candidates`` is reordered
        in place. ``post_tags`` must contain every candidate ID. For normal slots,
        only the next ``window`` unselected candidates are evaluated. If no ad is
        left for an ad slot, a normal post fills it. Selection stops early when
        the normal candidates run out on a normal slot.

        Returns:
            The selected ``(post_id, score)`` pairs in display order.
        """
        selected = []
        normal_idx = 0
        ad_idx = 0

        while len(selected) < target_size:
            current_position = len(selected)

            # Slots 5, 10, 15, ... (1-based) are ad slots while ads remain.
            if (current_position + 1) % 5 == 0 and ad_idx < len(ad_candidates):
                selected.append(ad_candidates[ad_idx])
                ad_idx += 1
                continue

            if normal_idx >= len(normal_candidates):
                break

            best_score = -float('inf')
            best_local_idx = normal_idx
            for i in range(normal_idx, min(normal_idx + window, len(normal_candidates))):
                post_id, relevance_score = normal_candidates[i]
                current_tags = post_tags[post_id]

                # Highest tag similarity to anything already selected (ads included).
                max_similarity = max(
                    (self.calculate_tag_similarity(current_tags, post_tags[selected_id])
                     for selected_id, _ in selected),
                    default=0.0,
                )

                mmr_score = theta * relevance_score - (1 - theta) * max_similarity
                if mmr_score > best_score:
                    best_score = mmr_score
                    best_local_idx = i

            selected.append(normal_candidates[best_local_idx])
            # Move the chosen item into the consumed prefix of the list.
            normal_candidates[normal_idx], normal_candidates[best_local_idx] = \
                normal_candidates[best_local_idx], normal_candidates[normal_idx]
            normal_idx += 1

        return selected

    def insert_advertisements(self, post_ids, scores):
        """Insert one ad after every 5th post of a recommendation list.

        Ads are the 50 hottest published ads that are not already in
        ``post_ids``; each inserted ad gets the score ``heat / 1000``.

        Returns:
            ``(post_ids, scores)`` with ads inserted, or the inputs unchanged
            when no usable ad exists.
        """
        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT id, heat FROM posts
                    WHERE is_advertisement = 1 AND status = 'published'
                    ORDER BY heat DESC
                    LIMIT 50
                """)
                ad_rows = cursor.fetchall()
        finally:
            conn.close()

        if not ad_rows:
            # No ads at all: return the original result.
            return post_ids, scores

        # Skip ads that are already part of the recommendation list.
        existing_ids = set(post_ids)
        available_ads = [(ad_id, heat) for ad_id, heat in ad_rows if ad_id not in existing_ids]
        if not available_ads:
            return post_ids, scores

        result_posts = []
        result_scores = []
        ad_index = 0
        for i, (post_id, score) in enumerate(zip(post_ids, scores)):
            result_posts.append(post_id)
            result_scores.append(score)

            # After every 5th post, add the next unused ad.
            if (i + 1) % 5 == 0 and ad_index < len(available_ads):
                ad_id, ad_heat = available_ads[ad_index]
                result_posts.append(ad_id)
                result_scores.append(float(ad_heat or 0) / 1000.0)  # map heat to a score
                ad_index += 1

        return result_posts, result_scores

    def user_cold_start(self, topk=None):
        """Cold start: return details of the ``topk`` hottest published posts.

        Args:
            topk: number of posts; defaults to ``self.topk``.

        Returns:
            A list of dicts ordered by heat descending, with keys ``post_id``,
            ``title``, ``content`` (first 200 characters plus ``'...'`` when
            longer; ``""`` when NULL), ``type``, ``username``, ``heat``, ``tags``
            and ``created_at`` (``str`` of the DB value, or ``""``).
        """
        if topk is None:
            topk = self.topk

        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                # The topk hottest published posts.
                cursor.execute("""
                    SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
                    FROM posts p
                    WHERE p.status = 'published'
                    ORDER BY p.heat DESC
                    LIMIT %s
                """, (topk,))
                post_rows = cursor.fetchall()

                # Author usernames.
                owner_ids = list({row[1] for row in post_rows})
                user_map = {}
                if owner_ids:
                    format_strings_user = ','.join(['%s'] * len(owner_ids))
                    cursor.execute(
                        f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
                        tuple(owner_ids)
                    )
                    user_map = {row[0]: row[1] for row in cursor.fetchall()}

                # Tags per post.
                tag_map = {}
                if post_rows:
                    post_ids = [row[0] for row in post_rows]
                    format_strings = ','.join(['%s'] * len(post_ids))
                    cursor.execute(
                        f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
                        FROM post_tags pt
                        JOIN tags t ON pt.tag_id = t.id
                        WHERE pt.post_id IN ({format_strings})
                        GROUP BY pt.post_id""",
                        tuple(post_ids)
                    )
                    tag_map = {row[0]: row[1] for row in cursor.fetchall()}
        finally:
            conn.close()

        post_list = []
        for post_id, owner_user_id, title, content, post_type, heat, created_at in post_rows:
            content = content or ""  # NULL content would otherwise break len()
            post_list.append({
                'post_id': post_id,
                'title': title,
                'content': content[:200] + '...' if len(content) > 200 else content,  # first 200 chars
                'type': post_type,
                'username': user_map.get(owner_user_id, ""),
                'heat': heat,
                'tags': tag_map.get(post_id, ""),
                'created_at': str(created_at) if created_at else ""
            })
        return post_list

    def run_inference(self, user_id, topk=None, use_multi_recall=None):
        """Main recommendation entry point.

        Args:
            user_id: user ID.
            topk: number of recommendations; defaults to ``self.topk``.
            use_multi_recall: whether to use multi-channel recall; ``None``
                means ``self.multi_recall_enabled``. When false, plain LightGCN
                scoring is used.

        Returns:
            Either ``(post_ids, scores)`` or, on cold start, the list of dicts
            produced by ``user_cold_start``.
        """
        if topk is None:
            topk = self.topk

        if use_multi_recall is None:
            use_multi_recall = self.multi_recall_enabled

        if use_multi_recall:
            return self._run_multi_recall_inference(user_id, topk)
        return self._run_lightgcn_inference(user_id, topk)

    def _run_multi_recall_inference(self, user_id, topk):
        """Recommend via multi-channel recall, optionally re-scored by LightGCN.

        Recalls ``min(topk * 10, 500)`` candidates, replaces their scores with
        LightGCN scores when ``self.use_lightgcn_rerank`` is set, then applies
        ``mmr_rerank_with_ads``. Falls back to ``user_cold_start`` when recall is
        empty, and to ``_run_lightgcn_inference`` if any step raises.
        """
        try:
            self.init_multi_recall()

            # Recall ten times as many candidates as the final list, capped at 500.
            total_candidates = min(topk * 10, 500)
            candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
                user_id, total_candidates
            )

            if not candidate_post_ids:
                print(f"用户 {user_id} 多路召回无结果,使用冷启动")
                return self.user_cold_start(topk)

            print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
            print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")

            if self.use_lightgcn_rerank:
                print("使用LightGCN对多路召回结果进行重新打分...")
                lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)

                # LightGCN scores are used as-is (no fusion with recall scores).
                final_scores = lightgcn_scores

                print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
                print("使用LightGCN分数进行重排")
            else:
                final_scores = candidate_scores

            # MMR re-ranking with the ad-slot constraint.
            final_post_ids, final_scores = self.mmr_rerank_with_ads(
                candidate_post_ids, final_scores, theta=0.5, target_size=topk
            )

            return final_post_ids, final_scores

        except Exception as e:
            print(f"多路召回失败: {str(e)},回退到LightGCN")
            return self._run_lightgcn_inference(user_id, topk)

    def _score_all_items(self, user_id):
        """Score every item in the graph for ``user_id`` with the pretrained LightGCN.

        Returns:
            ``(all_post_ids, all_scores)``, where ``all_post_ids`` lists post IDs
            in model item-index order and ``all_scores`` is the matching NumPy
            array. Returns ``None`` when ``user_id`` is not in the graph.
        """
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        if user_id not in user2idx:
            return None

        idx2post = {v: k for k, v in post2idx.items()}
        user_idx = user2idx[user_id]

        dataset = EdgeListData(args.data_path, args.data_path)
        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
        # Slice the pretrained embeddings to the user/item counts of this dataset.
        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

        model = LightGCN(dataset, phase='vanilla').to(args.device)
        model.load_state_dict(pretrained_dict, strict=False)
        model.eval()

        with torch.no_grad():
            user_emb, item_emb = model.generate()
            user_vec = user_emb[user_idx].unsqueeze(0)
            scores = model.rating(user_vec, item_emb).squeeze(0)

        all_scores = scores.cpu().numpy()
        all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
        return all_post_ids, all_scores

    def _run_lightgcn_inference(self, user_id, topk):
        """Recommend with plain LightGCN scoring followed by MMR re-ranking.

        Candidates are all items with a positive score; if none is positive, the
        100 highest-scoring items are used instead. Users unknown to the graph
        get ``user_cold_start(topk)``.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return self.user_cold_start(topk)
        all_post_ids, all_scores = scored

        # Keep only positively scored items as candidates.
        positive_candidates = [(post_id, score) for post_id, score in zip(all_post_ids, all_scores) if score > 0]

        if not positive_candidates:
            # No positive score: fall back to the top-scoring items.
            sorted_candidates = sorted(zip(all_post_ids, all_scores), key=lambda x: x[1], reverse=True)
            positive_candidates = sorted_candidates[:100]

        candidate_post_ids = [post_id for post_id, _ in positive_candidates]
        candidate_scores = [score for _, score in positive_candidates]

        print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")

        # MMR re-ranking with the ad constraint; theta=0.5 balances relevance and diversity.
        final_post_ids, final_scores = self.mmr_rerank_with_ads(
            candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
        )

        return final_post_ids, final_scores

    def _get_recall_source_stats(self, recall_breakdown):
        """Return ``{source: number_of_items}`` for a recall breakdown mapping."""
        return {source: len(items) for source, items in recall_breakdown.items()}

    def get_post_info(self, topk_post_ids, topk_scores=None):
        """Fetch display details for recommended posts.

        Args:
            topk_post_ids: post IDs, in the order results should be returned.
            topk_scores: optional scores for the same posts; only printed.

        Returns:
            One dict per ID that is found among published posts (others are
            skipped), in the order of ``topk_post_ids``. Keys: ``id``,
            ``user_id``, ``title``, full ``content``, ``media_urls``, ``status``,
            ``heat``, ISO-formatted ``created_at``/``updated_at``, ``type``,
            ``username``, comma-joined ``tags``, ``is_advertisement`` and the
            behaviour counts ``like_count``, ``comment_count``,
            ``favorite_count``, ``view_count``, ``share_count``.
        """
        if not topk_post_ids:
            return []

        print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
        if topk_scores is not None:
            print(f"对应的推荐打分: {topk_scores}")

        format_strings = ','.join(['%s'] * len(topk_post_ids))
        id_params = tuple(topk_post_ids)

        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                # Basic post fields.
                cursor.execute(
                    f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.updated_at, p.media_urls, p.status, p.is_advertisement
                    FROM posts p
                    WHERE p.id IN ({format_strings}) AND p.status = 'published'""",
                    id_params
                )
                post_rows = cursor.fetchall()
                post_map = {row[0]: row for row in post_rows}

                # Author usernames.
                owner_ids = list({row[1] for row in post_rows})
                user_map = {}
                if owner_ids:
                    format_strings_user = ','.join(['%s'] * len(owner_ids))
                    cursor.execute(
                        f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
                        tuple(owner_ids)
                    )
                    user_map = {row[0]: row[1] for row in cursor.fetchall()}

                # Tags per post.
                cursor.execute(
                    f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({format_strings})
                    GROUP BY pt.post_id""",
                    id_params
                )
                tag_map = {row[0]: row[1] for row in cursor.fetchall()}

                # Behaviour counts per post and behaviour type (like, comment, ...).
                cursor.execute(
                    f"""SELECT post_id, type, COUNT(*) as count
                    FROM behaviors
                    WHERE post_id IN ({format_strings})
                    GROUP BY post_id, type""",
                    id_params
                )
                behavior_stats = {}
                for post_id, behavior_type, count in cursor.fetchall():
                    behavior_stats.setdefault(post_id, {})[behavior_type] = count
        finally:
            conn.close()

        post_list = []
        for post_id in topk_post_ids:
            row = post_map.get(post_id)
            if not row:
                print(f"帖子ID {post_id} 不存在或未发布,跳过")
                continue
            owner_user_id = row[1]
            stats = behavior_stats.get(post_id, {})
            post_list.append({
                'id': post_id,
                'user_id': owner_user_id,
                'title': row[2],
                'content': row[3],  # full content, not truncated
                'media_urls': row[8],
                'status': row[9],
                'heat': row[5],
                'created_at': row[6].isoformat() if row[6] else "",
                'updated_at': row[7].isoformat() if row[7] else "",
                # Extra, optional fields.
                'type': row[4],
                'username': user_map.get(owner_user_id, ""),
                'tags': tag_map.get(post_id, ""),
                'is_advertisement': bool(row[10]),
                'like_count': stats.get('like', 0),
                'comment_count': stats.get('comment', 0),
                'favorite_count': stats.get('favorite', 0),
                'view_count': stats.get('view', 0),
                'share_count': stats.get('share', 0)
            })
        return post_list

    def get_recommendations(self, user_id, topk=None):
        """Public API: return recommended posts for ``user_id`` as a list of dicts.

        Cold-start results from ``run_inference`` are returned as-is; otherwise
        the ``(post_ids, scores)`` result is expanded with ``get_post_info``.

        NOTE(review): cold-start dicts use the key ``post_id`` and truncated
        content, while ``get_post_info`` dicts use ``id`` and carry more fields;
        confirm that clients handle both shapes.

        Raises:
            Exception: wrapping any failure as ``"推荐系统错误: ..."`` (cause chained).
        """
        try:
            result = self.run_inference(user_id, topk)
            # Cold start already returns detailed dicts.
            if isinstance(result, list) and result and isinstance(result[0], dict):
                return result
            if isinstance(result, tuple) and len(result) == 2:
                topk_post_ids, topk_scores = result
                return self.get_post_info(topk_post_ids, topk_scores)
            # Legacy return format: a bare list of post IDs.
            return self.get_post_info(result)
        except Exception as e:
            raise Exception(f"推荐系统错误: {str(e)}") from e

    def get_all_item_scores(self, user_id):
        """Return LightGCN scores of ``user_id`` for every item.

        Returns:
            ``(post_ids, scores)`` for all items in the graph, or ``([], [])``
            when the user is not in the graph.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return [], []
        return scored

    def init_multi_recall(self):
        """Create the multi-channel recall manager on first use."""
        if self.multi_recall is None:
            print("初始化多路召回管理器...")
            self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
            print("多路召回管理器初始化完成")

    def init_lightgcn_scorer(self):
        """Create the LightGCN scorer on first use."""
        if self.lightgcn_scorer is None:
            print("初始化LightGCN评分器...")
            self.lightgcn_scorer = LightGCNScorer()
            print("LightGCN评分器初始化完成")

    def _get_lightgcn_scores(self, user_id, candidate_post_ids):
        """Score ``candidate_post_ids`` for ``user_id`` with ``LightGCNScorer``.

        Returns whatever ``score_batch_candidates`` returns — presumably one
        float per candidate, aligned with ``candidate_post_ids`` (TODO confirm).
        """
        self.init_lightgcn_scorer()
        return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)

    def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
        """Fuse multi-recall and LightGCN scores after min-max normalisation.

        Args:
            multi_recall_scores: multi-recall scores.
            lightgcn_scores: LightGCN scores, aligned with ``multi_recall_scores``.
            alpha: weight of the LightGCN scores (0-1).

        Returns:
            ``alpha * norm(lightgcn) + (1 - alpha) * norm(multi_recall)`` as a
            list of floats; a constant score list normalises to 0.5. Empty
            inputs give ``[]``.

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(multi_recall_scores) != len(lightgcn_scores):
            raise ValueError("分数列表长度不匹配")
        if len(multi_recall_scores) == 0:
            return []

        def normalize_scores(scores):
            scores = np.asarray(scores, dtype=float)
            min_score = np.min(scores)
            max_score = np.max(scores)
            if max_score == min_score:
                return np.ones_like(scores) * 0.5
            return (scores - min_score) / (max_score - min_score)

        norm_multi_scores = normalize_scores(multi_recall_scores)
        norm_lightgcn_scores = normalize_scores(lightgcn_scores)

        fused_scores = alpha * norm_lightgcn_scores + (1 - alpha) * norm_multi_scores
        return fused_scores.tolist()

    def train_multi_recall(self):
        """Train all multi-recall models (initialising the manager if needed)."""
        self.init_multi_recall()
        self.multi_recall.train_all()

    def update_recall_config(self, new_config):
        """Merge ``new_config`` into the recall config and forward it to the manager.

        NOTE(review): ``dict.update`` is shallow, so passing e.g.
        ``{'swing': {'num_items': 30}}`` replaces the whole ``'swing'`` entry
        (dropping ``'enabled'`` / ``'alpha'``); confirm this is intended.
        """
        self.recall_config.update(new_config)
        if self.multi_recall:
            self.multi_recall.update_config(new_config)

    def multi_recall_inference(self, user_id, total_items=200):
        """Run multi-channel recall for ``user_id``.

        Args:
            user_id: user ID.
            total_items: total number of items to recall.

        Returns:
            ``(item_ids, scores, recall_breakdown)`` as returned by
            ``MultiRecallManager.recall``.
        """
        self.init_multi_recall()
        item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
        return item_ids, scores, recall_results

    def get_multi_recall_stats(self, user_id):
        """Return multi-recall statistics, or an error dict if not initialised."""
        if self.multi_recall is None:
            return {"error": "多路召回未初始化"}

        return self.multi_recall.get_recall_stats(user_id)
720 return self.multi_recall.get_recall_stats(user_id)