推荐系统
Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/services/recommendation_service.py b/rhj/backend/app/services/recommendation_service.py
new file mode 100644
index 0000000..2f4de13
--- /dev/null
+++ b/rhj/backend/app/services/recommendation_service.py
@@ -0,0 +1,719 @@
+import torch
+import pymysql
+import numpy as np
+import random
+from app.models.recommend.LightGCN import LightGCN
+from app.models.recall import MultiRecallManager
+from app.services.lightgcn_scorer import LightGCNScorer
+from app.utils.parse_args import args
+from app.utils.data_loader import EdgeListData
+from app.utils.graph_build import build_user_post_graph
+from config import Config
+
+class RecommendationService:
+ def __init__(self):
+ # 数据库连接配置 - 修改为redbook数据库
+ self.db_config = {
+ 'host': '10.126.59.25',
+ 'port': 3306,
+ 'user': 'root',
+ 'password': '123456',
+ 'database': 'redbook', # 使用redbook数据库
+ 'charset': 'utf8mb4'
+ }
+
+ # 模型配置
+ args.device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
+ args.data_path = './app/user_post_graph.txt' # 修改为帖子图文件
+ args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'
+
+ self.topk = 2 # 默认推荐数量
+
+ # 初始化多路召回管理器
+ self.multi_recall = None
+ self.multi_recall_enabled = True # 控制是否启用多路召回
+
+ # 初始化LightGCN评分器
+ self.lightgcn_scorer = None
+ self.use_lightgcn_rerank = True # 控制是否使用LightGCN对多路召回结果重新打分
+
+ # 多路召回配置
+ self.recall_config = {
+ 'swing': {
+ 'enabled': True,
+ 'num_items': 20, # 增加召回数量
+ 'alpha': 0.5
+ },
+ 'hot': {
+ 'enabled': True,
+ 'num_items': 15 # 增加热度召回数量
+ },
+ 'ad': {
+ 'enabled': True,
+ 'num_items': 5 # 增加广告召回数量
+ },
+ 'usercf': {
+ 'enabled': True,
+ 'num_items': 15,
+ 'min_common_items': 1, # 降低阈值,从3改为1
+ 'num_similar_users': 20 # 减少相似用户数量以提高效率
+ }
+ }
+
+ def calculate_tag_similarity(self, tags1, tags2):
+ """
+ 计算两个帖子标签的相似度
+ 输入: tags1, tags2 - 标签字符串,以逗号分隔
+ 输出: 相似度分数(0-1之间)
+ """
+ if not tags1 or not tags2:
+ return 0.0
+
+ # 将标签字符串转换为集合
+ set1 = set(tag.strip() for tag in tags1.split(',') if tag.strip())
+ set2 = set(tag.strip() for tag in tags2.split(',') if tag.strip())
+
+ if not set1 or not set2:
+ return 0.0
+
+ # 计算标签重叠比例(Jaccard相似度)
+ intersection = len(set1.intersection(set2))
+ union = len(set1.union(set2))
+
+ return intersection / union if union > 0 else 0.0
+
+ def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
+ """
+ 使用MMR算法重新排序推荐结果,并在过程中加入广告约束
+ 输入:
+ - post_ids: 帖子ID列表
+ - scores: 对应的推荐分数列表
+ - theta: 平衡相关性和多样性的参数(0.5表示各占一半)
+ - target_size: 目标结果数量,默认与输入相同
+ 输出: 重排后的(post_ids, scores),每5条帖子包含1条广告
+ """
+ if target_size is None:
+ target_size = len(post_ids)
+
+ if len(post_ids) <= 1:
+ return post_ids, scores
+
+ # 获取帖子标签信息和广告标识
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ # 查询所有候选帖子的标签和广告标识
+ format_strings = ','.join(['%s'] * len(post_ids))
+ cursor.execute(
+ f"""SELECT p.id, p.is_advertisement,
+ COALESCE(GROUP_CONCAT(t.name), '') as tags
+ FROM posts p
+ LEFT JOIN post_tags pt ON p.id = pt.post_id
+ LEFT JOIN tags t ON pt.tag_id = t.id
+ WHERE p.id IN ({format_strings}) AND p.status = 'published'
+ GROUP BY p.id, p.is_advertisement""",
+ tuple(post_ids)
+ )
+ post_info_rows = cursor.fetchall()
+ post_tags = {}
+ post_is_ad = {}
+
+ for row in post_info_rows:
+ post_id, is_ad, tags = row
+ post_tags[post_id] = tags or ""
+ post_is_ad[post_id] = bool(is_ad)
+
+ # 对于没有查询到的帖子,设置默认值
+ for post_id in post_ids:
+ if post_id not in post_tags:
+ post_tags[post_id] = ""
+ post_is_ad[post_id] = False
+
+ # 获取额外的广告帖子作为候选
+ cursor.execute("""
+ SELECT id, heat FROM posts
+ WHERE is_advertisement = 1 AND status = 'published'
+ AND id NOT IN ({})
+ ORDER BY heat DESC
+ LIMIT 50
+ """.format(format_strings), tuple(post_ids))
+ extra_ad_rows = cursor.fetchall()
+
+ finally:
+ cursor.close()
+ conn.close()
+
+ # 分离普通帖子和广告帖子
+ normal_candidates = []
+ ad_candidates = []
+
+ for post_id, score in zip(post_ids, scores):
+ if post_is_ad[post_id]:
+ ad_candidates.append((post_id, score))
+ else:
+ normal_candidates.append((post_id, score))
+
+ # 添加额外的广告候选
+ for ad_id, heat in extra_ad_rows:
+ # 为广告帖子设置标签和广告标识
+ post_tags[ad_id] = "" # 广告帖子暂时设置为空标签
+ post_is_ad[ad_id] = True
+ ad_score = float(heat) / 1000.0 # 将热度转换为分数
+ ad_candidates.append((ad_id, ad_score))
+
+ # 排序候选列表
+ normal_candidates.sort(key=lambda x: x[1], reverse=True)
+ ad_candidates.sort(key=lambda x: x[1], reverse=True)
+
+ # MMR算法实现,加入广告约束
+ selected = []
+ normal_idx = 0
+ ad_idx = 0
+
+ while len(selected) < target_size:
+ current_position = len(selected)
+
+ # 检查是否需要插入广告(每5个位置插入1个广告)
+ if (current_position + 1) % 5 == 0 and ad_idx < len(ad_candidates):
+ # 插入广告
+ selected.append(ad_candidates[ad_idx])
+ ad_idx += 1
+ else:
+ # 使用MMR选择普通帖子
+ if normal_idx >= len(normal_candidates):
+ break
+
+ best_score = -float('inf')
+ best_local_idx = normal_idx
+
+ # 在剩余的普通候选中选择最佳的
+ for i in range(normal_idx, min(normal_idx + 10, len(normal_candidates))):
+ post_id, relevance_score = normal_candidates[i]
+
+ # 计算与已选帖子的最大相似度
+ max_similarity = 0.0
+ current_tags = post_tags[post_id]
+
+ for selected_post_id, _ in selected:
+ selected_tags = post_tags[selected_post_id]
+ similarity = self.calculate_tag_similarity(current_tags, selected_tags)
+ max_similarity = max(max_similarity, similarity)
+
+ # 计算MMR分数
+ mmr_score = theta * relevance_score - (1 - theta) * max_similarity
+
+ if mmr_score > best_score:
+ best_score = mmr_score
+ best_local_idx = i
+
+ # 选择最佳候选
+ selected.append(normal_candidates[best_local_idx])
+ # 将选中的元素移到已处理区域
+ normal_candidates[normal_idx], normal_candidates[best_local_idx] = \
+ normal_candidates[best_local_idx], normal_candidates[normal_idx]
+ normal_idx += 1
+
+ # 提取重排后的结果
+ reranked_post_ids = [post_id for post_id, _ in selected]
+ reranked_scores = [score for _, score in selected]
+
+ return reranked_post_ids, reranked_scores
+
+ def insert_advertisements(self, post_ids, scores):
+ """
+ 在推荐结果中插入广告,每5条帖子插入1条广告
+ 输入: post_ids, scores - 原始推荐结果
+ 输出: 插入广告后的(post_ids, scores)
+ """
+ # 获取可用的广告帖子
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute("""
+ SELECT id, heat FROM posts
+ WHERE is_advertisement = 1 AND status = 'published'
+ ORDER BY heat DESC
+ LIMIT 50
+ """)
+ ad_rows = cursor.fetchall()
+
+ if not ad_rows:
+ # 没有广告,直接返回原结果
+ return post_ids, scores
+
+ # 可用的广告帖子(排除已在推荐结果中的)
+ available_ads = [(ad_id, heat) for ad_id, heat in ad_rows if ad_id not in post_ids]
+
+ if not available_ads:
+ # 没有可用的新广告,直接返回原结果
+ return post_ids, scores
+
+ finally:
+ cursor.close()
+ conn.close()
+
+ # 插入广告的逻辑
+ result_posts = []
+ result_scores = []
+ ad_index = 0
+
+ for i, (post_id, score) in enumerate(zip(post_ids, scores)):
+ result_posts.append(post_id)
+ result_scores.append(score)
+
+ # 每5条帖子后插入一条广告
+ if (i + 1) % 5 == 0 and ad_index < len(available_ads):
+ ad_id, ad_heat = available_ads[ad_index]
+ result_posts.append(ad_id)
+ result_scores.append(float(ad_heat) / 1000.0) # 将热度转换为分数范围
+ ad_index += 1
+
+ return result_posts, result_scores
+
+ def user_cold_start(self, topk=None):
+ """
+ 冷启动:直接返回热度最高的topk个帖子详细信息
+ """
+ if topk is None:
+ topk = self.topk
+
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ # 查询热度最高的topk个帖子
+ cursor.execute("""
+ SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
+ FROM posts p
+ WHERE p.status = 'published'
+ ORDER BY p.heat DESC
+ LIMIT %s
+ """, (topk,))
+ post_rows = cursor.fetchall()
+ post_ids = [row[0] for row in post_rows]
+ post_map = {row[0]: row for row in post_rows}
+
+ # 查询用户信息
+ owner_ids = list(set(row[1] for row in post_rows))
+ if owner_ids:
+ format_strings_user = ','.join(['%s'] * len(owner_ids))
+ cursor.execute(
+ f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
+ tuple(owner_ids)
+ )
+ user_rows = cursor.fetchall()
+ user_map = {row[0]: row[1] for row in user_rows}
+ else:
+ user_map = {}
+
+ # 查询帖子标签
+ if post_ids:
+ format_strings = ','.join(['%s'] * len(post_ids))
+ cursor.execute(
+ f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
+ FROM post_tags pt
+ JOIN tags t ON pt.tag_id = t.id
+ WHERE pt.post_id IN ({format_strings})
+ GROUP BY pt.post_id""",
+ tuple(post_ids)
+ )
+ tag_rows = cursor.fetchall()
+ tag_map = {row[0]: row[1] for row in tag_rows}
+ else:
+ tag_map = {}
+
+ post_list = []
+ for post_id in post_ids:
+ row = post_map.get(post_id)
+ if not row:
+ continue
+ owner_user_id = row[1]
+ post_list.append({
+ 'post_id': post_id,
+ 'title': row[2],
+ 'content': row[3][:200] + '...' if len(row[3]) > 200 else row[3], # 截取前200字符
+ 'type': row[4],
+ 'username': user_map.get(owner_user_id, ""),
+ 'heat': row[5],
+ 'tags': tag_map.get(post_id, ""),
+ 'created_at': str(row[6]) if row[6] else ""
+ })
+ return post_list
+ finally:
+ cursor.close()
+ conn.close()
+
+ def run_inference(self, user_id, topk=None, use_multi_recall=None):
+ """
+ 推荐推理主函数
+
+ Args:
+ user_id: 用户ID
+ topk: 推荐数量
+ use_multi_recall: 是否使用多路召回,None表示使用默认设置
+ """
+ if topk is None:
+ topk = self.topk
+
+ # 决定使用哪种召回方式
+ if use_multi_recall is None:
+ use_multi_recall = self.multi_recall_enabled
+
+ return self._run_multi_recall_inference(user_id, topk)
+
+ def _run_multi_recall_inference(self, user_id, topk):
+ """使用多路召回进行推荐,并可选择使用LightGCN重新打分"""
+ try:
+ # 初始化多路召回(如果尚未初始化)
+ self.init_multi_recall()
+
+ # 执行多路召回,召回更多候选物品
+ total_candidates = min(topk * 10, 500) # 召回候选数是最终推荐数的10倍
+ candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
+ user_id, total_candidates
+ )
+
+ if not candidate_post_ids:
+ # 如果多路召回没有结果,回退到冷启动
+ print(f"用户 {user_id} 多路召回无结果,使用冷启动")
+ return self.user_cold_start(topk)
+
+ print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
+ print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")
+
+ # 如果启用LightGCN重新打分,使用LightGCN对候选结果进行评分
+ if self.use_lightgcn_rerank:
+ print("使用LightGCN对多路召回结果进行重新打分...")
+ lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)
+
+ # 直接使用LightGCN分数,不进行融合
+ final_scores = lightgcn_scores
+
+ print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
+ print(f"使用LightGCN分数进行重排")
+ else:
+ # 使用原始多路召回分数
+ final_scores = candidate_scores
+
+ # 使用MMR算法重排,包含广告约束
+ final_post_ids, final_scores = self.mmr_rerank_with_ads(
+ candidate_post_ids, final_scores, theta=0.5, target_size=topk
+ )
+
+ return final_post_ids, final_scores
+
+ except Exception as e:
+ print(f"多路召回失败: {str(e)},回退到LightGCN")
+ return self._run_lightgcn_inference(user_id, topk)
+
+ def _run_lightgcn_inference(self, user_id, topk):
+ """使用原始LightGCN进行推荐"""
+ user2idx, post2idx = build_user_post_graph(return_mapping=True)
+ idx2post = {v: k for k, v in post2idx.items()}
+
+ if user_id not in user2idx:
+ # 冷启动
+ return self.user_cold_start(topk)
+
+ user_idx = user2idx[user_id]
+
+ dataset = EdgeListData(args.data_path, args.data_path)
+ pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+ pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+ pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+ model = LightGCN(dataset, phase='vanilla').to(args.device)
+ model.load_state_dict(pretrained_dict, strict=False)
+ model.eval()
+
+ with torch.no_grad():
+ user_emb, item_emb = model.generate()
+ user_vec = user_emb[user_idx].unsqueeze(0)
+ scores = model.rating(user_vec, item_emb).squeeze(0)
+
+ # 获取所有物品的分数(而不是只取top候选)
+ all_scores = scores.cpu().numpy()
+ all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
+
+ # 过滤掉分数为负的物品,只保留正分数的候选
+ positive_candidates = [(post_id, score) for post_id, score in zip(all_post_ids, all_scores) if score > 0]
+
+ if not positive_candidates:
+ # 如果没有正分数的候选,取分数最高的一些
+ sorted_candidates = sorted(zip(all_post_ids, all_scores), key=lambda x: x[1], reverse=True)
+ positive_candidates = sorted_candidates[:min(100, len(sorted_candidates))]
+
+ candidate_post_ids = [post_id for post_id, _ in positive_candidates]
+ candidate_scores = [score for _, score in positive_candidates]
+
+ print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")
+
+ # 使用MMR算法重排,包含广告约束,theta=0.5平衡相关性和多样性
+ final_post_ids, final_scores = self.mmr_rerank_with_ads(
+ candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
+ )
+
+ return final_post_ids, final_scores
+
+ def _get_recall_source_stats(self, recall_breakdown):
+ """获取召回来源统计"""
+ stats = {}
+ for source, items in recall_breakdown.items():
+ stats[source] = len(items)
+ return stats
+
+ def get_post_info(self, topk_post_ids, topk_scores=None):
+ """
+ 输入: topk_post_ids(帖子ID列表),topk_scores(对应的打分列表,可选)
+ 输出: 推荐帖子的详细信息列表,每个元素为dict
+ """
+ if not topk_post_ids:
+ return []
+
+ print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
+ if topk_scores is not None:
+ print(f"对应的推荐打分: {topk_scores}")
+
+ conn = pymysql.connect(**self.db_config)
+ cursor = conn.cursor()
+
+ try:
+ # 查询帖子基本信息
+ format_strings = ','.join(['%s'] * len(topk_post_ids))
+ cursor.execute(
+ f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.is_advertisement
+ FROM posts p
+ WHERE p.id IN ({format_strings}) AND p.status = 'published'""",
+ tuple(topk_post_ids)
+ )
+ post_rows = cursor.fetchall()
+ post_map = {row[0]: row for row in post_rows}
+
+ # 查询用户信息
+ owner_ids = list(set(row[1] for row in post_rows))
+ if owner_ids:
+ format_strings_user = ','.join(['%s'] * len(owner_ids))
+ cursor.execute(
+ f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
+ tuple(owner_ids)
+ )
+ user_rows = cursor.fetchall()
+ user_map = {row[0]: row[1] for row in user_rows}
+ else:
+ user_map = {}
+
+ # 查询帖子标签
+ cursor.execute(
+ f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
+ FROM post_tags pt
+ JOIN tags t ON pt.tag_id = t.id
+ WHERE pt.post_id IN ({format_strings})
+ GROUP BY pt.post_id""",
+ tuple(topk_post_ids)
+ )
+ tag_rows = cursor.fetchall()
+ tag_map = {row[0]: row[1] for row in tag_rows}
+
+ # 查询行为统计(点赞数、评论数等)
+ cursor.execute(
+ f"""SELECT post_id, type, COUNT(*) as count
+ FROM behaviors
+ WHERE post_id IN ({format_strings})
+ GROUP BY post_id, type""",
+ tuple(topk_post_ids)
+ )
+ behavior_rows = cursor.fetchall()
+ behavior_stats = {}
+ for row in behavior_rows:
+ post_id, behavior_type, count = row
+ if post_id not in behavior_stats:
+ behavior_stats[post_id] = {}
+ behavior_stats[post_id][behavior_type] = count
+
+ post_list = []
+ for i, post_id in enumerate(topk_post_ids):
+ row = post_map.get(post_id)
+ if not row:
+ print(f"帖子ID {post_id} 不存在或未发布,跳过")
+ continue
+ owner_user_id = row[1]
+ stats = behavior_stats.get(post_id, {})
+ post_info = {
+ 'post_id': post_id,
+ 'title': row[2],
+ 'content': row[3][:200] + '...' if len(row[3]) > 200 else row[3],
+ 'type': row[4],
+ 'username': user_map.get(owner_user_id, ""),
+ 'heat': row[5],
+ 'tags': tag_map.get(post_id, ""),
+ 'created_at': str(row[6]) if row[6] else "",
+ 'is_advertisement': bool(row[7]), # 添加广告标识
+ 'like_count': stats.get('like', 0),
+ 'comment_count': stats.get('comment', 0),
+ 'favorite_count': stats.get('favorite', 0),
+ 'view_count': stats.get('view', 0),
+ 'share_count': stats.get('share', 0)
+ }
+
+ # 如果有推荐打分,添加到结果中
+ if topk_scores is not None and i < len(topk_scores):
+ post_info['recommendation_score'] = float(topk_scores[i])
+
+ post_list.append(post_info)
+ return post_list
+ finally:
+ cursor.close()
+ conn.close()
+
+ def get_recommendations(self, user_id, topk=None):
+ """
+ 获取推荐结果的主要接口
+ """
+ try:
+ result = self.run_inference(user_id, topk)
+ # 如果是冷启动直接返回详细信息,否则查详情
+ if isinstance(result, list) and result and isinstance(result[0], dict):
+ return result
+ else:
+ # result 现在是 (topk_post_ids, topk_scores) 的元组
+ if isinstance(result, tuple) and len(result) == 2:
+ topk_post_ids, topk_scores = result
+ return self.get_post_info(topk_post_ids, topk_scores)
+ else:
+ # 兼容旧的返回格式
+ return self.get_post_info(result)
+ except Exception as e:
+ raise Exception(f"推荐系统错误: {str(e)}")
+
+ def get_all_item_scores(self, user_id):
+ """
+ 获取用户对所有物品的打分
+ 输入: user_id
+ 输出: (post_ids, scores) - 所有帖子ID和对应的打分
+ """
+ user2idx, post2idx = build_user_post_graph(return_mapping=True)
+ idx2post = {v: k for k, v in post2idx.items()}
+
+ if user_id not in user2idx:
+ # 用户不存在,返回空结果
+ return [], []
+
+ user_idx = user2idx[user_id]
+
+ dataset = EdgeListData(args.data_path, args.data_path)
+ pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+ pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+ pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+ model = LightGCN(dataset, phase='vanilla').to(args.device)
+ model.load_state_dict(pretrained_dict, strict=False)
+ model.eval()
+
+ with torch.no_grad():
+ user_emb, item_emb = model.generate()
+ user_vec = user_emb[user_idx].unsqueeze(0)
+ scores = model.rating(user_vec, item_emb).squeeze(0)
+
+ # 获取所有物品的ID和分数
+ all_scores = scores.cpu().numpy()
+ all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
+
+ return all_post_ids, all_scores
+
+ def init_multi_recall(self):
+ """初始化多路召回管理器"""
+ if self.multi_recall is None:
+ print("初始化多路召回管理器...")
+ self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
+ print("多路召回管理器初始化完成")
+
+ def init_lightgcn_scorer(self):
+ """初始化LightGCN评分器"""
+ if self.lightgcn_scorer is None:
+ print("初始化LightGCN评分器...")
+ self.lightgcn_scorer = LightGCNScorer()
+ print("LightGCN评分器初始化完成")
+
+ def _get_lightgcn_scores(self, user_id, candidate_post_ids):
+ """
+ 获取候选物品的LightGCN分数
+
+ Args:
+ user_id: 用户ID
+ candidate_post_ids: 候选物品ID列表
+
+ Returns:
+ List[float]: LightGCN分数列表
+ """
+ self.init_lightgcn_scorer()
+ return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)
+
+ def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
+ """
+ 融合多路召回分数和LightGCN分数
+
+ Args:
+ multi_recall_scores: 多路召回分数列表
+ lightgcn_scores: LightGCN分数列表
+ alpha: LightGCN分数的权重(0-1之间)
+
+ Returns:
+ List[float]: 融合后的分数列表
+ """
+ if len(multi_recall_scores) != len(lightgcn_scores):
+ raise ValueError("分数列表长度不匹配")
+
+ # 对分数进行归一化
+ def normalize_scores(scores):
+ scores = np.array(scores)
+ min_score = np.min(scores)
+ max_score = np.max(scores)
+ if max_score == min_score:
+ return np.ones_like(scores) * 0.5
+ return (scores - min_score) / (max_score - min_score)
+
+ norm_multi_scores = normalize_scores(multi_recall_scores)
+ norm_lightgcn_scores = normalize_scores(lightgcn_scores)
+
+ # 加权融合
+ fused_scores = alpha * norm_lightgcn_scores + (1 - alpha) * norm_multi_scores
+
+ return fused_scores.tolist()
+
+ def train_multi_recall(self):
+ """训练多路召回模型"""
+ self.init_multi_recall()
+ self.multi_recall.train_all()
+
+ def update_recall_config(self, new_config):
+ """更新多路召回配置"""
+ self.recall_config.update(new_config)
+ if self.multi_recall:
+ self.multi_recall.update_config(new_config)
+
+ def multi_recall_inference(self, user_id, total_items=200):
+ """
+ 使用多路召回进行推荐
+
+ Args:
+ user_id: 用户ID
+ total_items: 总召回物品数量
+
+ Returns:
+ Tuple of (item_ids, scores, recall_breakdown)
+ """
+ self.init_multi_recall()
+
+ # 执行多路召回
+ item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
+
+ return item_ids, scores, recall_results
+
+ def get_multi_recall_stats(self, user_id):
+ """获取多路召回统计信息"""
+ if self.multi_recall is None:
+ return {"error": "多路召回未初始化"}
+
+ return self.multi_recall.get_recall_stats(user_id)