import random
from contextlib import contextmanager

import torch
import pymysql
import numpy as np

from app.models.recommend.LightGCN import LightGCN
from app.models.recall import MultiRecallManager
from app.services.lightgcn_scorer import LightGCNScorer
from app.utils.parse_args import args
from app.utils.data_loader import EdgeListData
from app.utils.graph_build import build_user_post_graph
from config import Config

class RecommendationService:
    """Post recommendation pipeline.

    Candidates come from multi-channel recall (or from a full LightGCN scoring
    pass), may be re-scored with LightGCN, and are finally re-ranked with an
    MMR-style diversity step that reserves periodic slots for advertisements.

    Every database access opens a short-lived pymysql connection built from
    ``self.db_config``. Constructing the service also writes the device, graph
    file path and checkpoint path onto the shared ``args`` object.
    """

    # An advertisement is placed at every AD_INTERVAL-th result position.
    AD_INTERVAL = 5
    # Number of best remaining candidates inspected per MMR selection step.
    MMR_WINDOW = 10
    # Divisor that maps a post's heat value onto the score scale used for ads.
    HEAT_SCALE = 1000.0

    def __init__(self, db_config=None):
        """Set up connection settings, model paths and recall configuration.

        Args:
            db_config: Optional dict of pymysql connection keyword overrides
                merged on top of the built-in defaults.
        """
        # NOTE(security): credentials are hard-coded defaults kept for backward
        # compatibility; prefer passing ``db_config`` from a secret store.
        self.db_config = {
            'host': '10.126.59.25',
            'port': 3306,
            'user': 'root',
            'password': '123456',
            'database': 'redbook',
            'charset': 'utf8mb4',
        }
        if db_config:
            self.db_config.update(db_config)

        # Model configuration. GPU index 7 is the historical choice; fall back
        # to the first GPU when the host exposes fewer than eight devices.
        if torch.cuda.is_available():
            gpu_index = 7 if torch.cuda.device_count() > 7 else 0
            args.device = f'cuda:{gpu_index}'
        else:
            args.device = 'cpu'
        args.data_path = './app/user_post_graph.txt'
        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'

        self.topk = 2  # default number of recommendations

        # Multi-channel recall manager, created lazily by init_multi_recall().
        self.multi_recall = None
        self.multi_recall_enabled = True

        # LightGCN scorer, created lazily by init_lightgcn_scorer().
        self.lightgcn_scorer = None
        self.use_lightgcn_rerank = True

        # Per-channel recall settings handed to MultiRecallManager.
        self.recall_config = {
            'swing': {
                'enabled': True,
                'num_items': 20,
                'alpha': 0.5,
            },
            'hot': {
                'enabled': True,
                'num_items': 15,
            },
            'ad': {
                'enabled': True,
                'num_items': 5,
            },
            'usercf': {
                'enabled': True,
                'num_items': 15,
                'min_common_items': 1,
                'num_similar_users': 20,
            },
        }

    # ------------------------------------------------------------------
    # Database helpers
    # ------------------------------------------------------------------
    @contextmanager
    def _db_cursor(self):
        """Yield a cursor on a fresh connection and always close both."""
        conn = pymysql.connect(**self.db_config)
        try:
            cursor = conn.cursor()
            try:
                yield cursor
            finally:
                cursor.close()
        finally:
            conn.close()

    @staticmethod
    def _fetch_usernames(cursor, owner_ids):
        """Return {user_id: username} for the given user ids (empty dict if none)."""
        if not owner_ids:
            return {}
        placeholders = ','.join(['%s'] * len(owner_ids))
        cursor.execute(
            f"SELECT id, username FROM users WHERE id IN ({placeholders})",
            tuple(owner_ids)
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    @staticmethod
    def _fetch_tag_map(cursor, post_ids):
        """Return {post_id: comma-joined tag names} for posts that have tags."""
        if not post_ids:
            return {}
        placeholders = ','.join(['%s'] * len(post_ids))
        cursor.execute(
            f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({placeholders})
                GROUP BY pt.post_id""",
            tuple(post_ids)
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    # ------------------------------------------------------------------
    # Tag similarity
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_tags(tags):
        """Split a comma-separated tag string into a set of stripped, non-empty tags."""
        if not tags:
            return set()
        return {tag.strip() for tag in tags.split(',') if tag.strip()}

    @staticmethod
    def _jaccard(set1, set2):
        """Return the Jaccard similarity of two sets; 0.0 when either is empty."""
        if not set1 or not set2:
            return 0.0
        union = len(set1 | set2)
        return len(set1 & set2) / union if union > 0 else 0.0

    def calculate_tag_similarity(self, tags1, tags2):
        """Return the Jaccard similarity (0..1) of two comma-separated tag strings.

        Empty or missing tag strings yield 0.0.
        """
        return self._jaccard(self._parse_tags(tags1), self._parse_tags(tags2))

    # ------------------------------------------------------------------
    # MMR re-ranking with advertisement slots
    # ------------------------------------------------------------------
    def _fetch_rerank_metadata(self, post_ids):
        """Load tags and ad flags for the candidates plus extra ad posts.

        Returns:
            (post_tags, post_is_ad, extra_ad_rows) where the first two are dicts
            keyed by post id (candidates missing from the query result default
            to "" / False) and ``extra_ad_rows`` holds (id, heat) rows of up to
            50 published ads that are not among the candidates.
        """
        placeholders = ','.join(['%s'] * len(post_ids))
        params = tuple(post_ids)

        with self._db_cursor() as cursor:
            cursor.execute(
                f"""SELECT p.id, p.is_advertisement,
                           COALESCE(GROUP_CONCAT(t.name), '') as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags t ON pt.tag_id = t.id
                    WHERE p.id IN ({placeholders}) AND p.status = 'published'
                    GROUP BY p.id, p.is_advertisement""",
                params
            )
            info_rows = cursor.fetchall()

            cursor.execute(
                f"""SELECT id, heat FROM posts
                    WHERE is_advertisement = 1 AND status = 'published'
                    AND id NOT IN ({placeholders})
                    ORDER BY heat DESC
                    LIMIT 50""",
                params
            )
            extra_ad_rows = cursor.fetchall()

        # Start from defaults so unpublished/unknown candidates stay usable.
        post_tags = {post_id: "" for post_id in post_ids}
        post_is_ad = {post_id: False for post_id in post_ids}
        for post_id, is_ad, tags in info_rows:
            post_tags[post_id] = tags or ""
            post_is_ad[post_id] = bool(is_ad)

        return post_tags, post_is_ad, extra_ad_rows

    def _mmr_select(self, normal_candidates, ad_candidates, tag_sets, theta, target_size):
        """Greedily pick up to ``target_size`` (post_id, score) pairs.

        Every AD_INTERVAL-th position takes the next ad while ads remain. Other
        positions take, among the next MMR_WINDOW unused normal candidates, the
        one maximising ``theta * score - (1 - theta) * max_tag_similarity`` with
        respect to everything already selected. Selection stops early once the
        normal candidates are exhausted. ``normal_candidates`` is reordered in
        place; both candidate lists are expected to be sorted by score
        descending and ``tag_sets`` maps every candidate id to a set of tags.
        """
        selected = []
        selected_tag_sets = []
        normal_idx = 0
        ad_idx = 0

        while len(selected) < target_size:
            is_ad_slot = (len(selected) + 1) % self.AD_INTERVAL == 0
            if is_ad_slot and ad_idx < len(ad_candidates):
                chosen = ad_candidates[ad_idx]
                ad_idx += 1
            else:
                if normal_idx >= len(normal_candidates):
                    break

                best_score = -float('inf')
                best_idx = normal_idx
                window_end = min(normal_idx + self.MMR_WINDOW, len(normal_candidates))

                for i in range(normal_idx, window_end):
                    post_id, relevance_score = normal_candidates[i]
                    current_tags = tag_sets[post_id]
                    max_similarity = max(
                        (self._jaccard(current_tags, other) for other in selected_tag_sets),
                        default=0.0,
                    )
                    mmr_score = theta * relevance_score - (1 - theta) * max_similarity
                    if mmr_score > best_score:
                        best_score = mmr_score
                        best_idx = i

                chosen = normal_candidates[best_idx]
                # Swap the winner into the consumed prefix of the list.
                normal_candidates[normal_idx], normal_candidates[best_idx] = \
                    normal_candidates[best_idx], normal_candidates[normal_idx]
                normal_idx += 1

            selected.append(chosen)
            selected_tag_sets.append(tag_sets[chosen[0]])

        return selected

    def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
        """Re-rank candidates for relevance and tag diversity, inserting ads.

        Args:
            post_ids: Candidate post ids.
            scores: Relevance scores aligned with ``post_ids``.
            theta: Weight of relevance versus diversity (0.5 weighs both equally).
            target_size: Maximum number of results; defaults to ``len(post_ids)``.

        Returns:
            (post_ids, scores) after re-ranking. Every fifth position holds an
            advertisement while ads are available. Inputs with at most one
            candidate are returned unchanged without touching the database.
        """
        if target_size is None:
            target_size = len(post_ids)

        if len(post_ids) <= 1:
            return post_ids, scores

        post_tags, post_is_ad, extra_ad_rows = self._fetch_rerank_metadata(post_ids)

        # Split the candidates into regular posts and advertisements.
        normal_candidates = []
        ad_candidates = []
        for post_id, score in zip(post_ids, scores):
            bucket = ad_candidates if post_is_ad[post_id] else normal_candidates
            bucket.append((post_id, score))

        # Extra ads carry no tags and are scored from their heat; a NULL heat
        # counts as zero instead of raising.
        for ad_id, heat in extra_ad_rows:
            post_tags[ad_id] = ""
            ad_candidates.append((ad_id, float(heat or 0) / self.HEAT_SCALE))

        normal_candidates.sort(key=lambda x: x[1], reverse=True)
        ad_candidates.sort(key=lambda x: x[1], reverse=True)

        # Parse each tag string once instead of on every similarity comparison.
        tag_sets = {post_id: self._parse_tags(tags) for post_id, tags in post_tags.items()}

        selected = self._mmr_select(normal_candidates, ad_candidates, tag_sets, theta, target_size)

        reranked_post_ids = [post_id for post_id, _ in selected]
        reranked_scores = [score for _, score in selected]
        return reranked_post_ids, reranked_scores

    def insert_advertisements(self, post_ids, scores):
        """Insert one advertisement after every five recommended posts.

        Args:
            post_ids: Recommended post ids.
            scores: Scores aligned with ``post_ids``.

        Returns:
            (post_ids, scores) with ads interleaved. The inputs are returned
            unchanged when no ad outside ``post_ids`` is available.
        """
        with self._db_cursor() as cursor:
            cursor.execute("""
                SELECT id, heat FROM posts
                WHERE is_advertisement = 1 AND status = 'published'
                ORDER BY heat DESC
                LIMIT 50
            """)
            ad_rows = cursor.fetchall()

        # Only ads that are not already part of the recommendations qualify.
        already_listed = set(post_ids)
        available_ads = [(ad_id, heat) for ad_id, heat in ad_rows if ad_id not in already_listed]
        if not available_ads:
            return post_ids, scores

        result_posts = []
        result_scores = []
        remaining_ads = iter(available_ads)

        for position, (post_id, score) in enumerate(zip(post_ids, scores), start=1):
            result_posts.append(post_id)
            result_scores.append(score)

            if position % self.AD_INTERVAL == 0:
                next_ad = next(remaining_ads, None)
                if next_ad is not None:
                    ad_id, ad_heat = next_ad
                    result_posts.append(ad_id)
                    result_scores.append(float(ad_heat or 0) / self.HEAT_SCALE)

        return result_posts, result_scores

    # ------------------------------------------------------------------
    # Cold start
    # ------------------------------------------------------------------
    def user_cold_start(self, topk=None):
        """Return detail dicts for the ``topk`` hottest published posts.

        Content is truncated to 200 characters followed by '...'. Each dict
        carries the post id under both 'post_id' and 'id' so that it shares a
        key with the dicts produced by get_post_info().
        """
        if topk is None:
            topk = self.topk

        with self._db_cursor() as cursor:
            cursor.execute("""
                SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
                FROM posts p
                WHERE p.status = 'published'
                ORDER BY p.heat DESC
                LIMIT %s
            """, (topk,))
            post_rows = cursor.fetchall()

            post_ids = [row[0] for row in post_rows]
            user_map = self._fetch_usernames(cursor, list({row[1] for row in post_rows}))
            tag_map = self._fetch_tag_map(cursor, post_ids)

        post_list = []
        for row in post_rows:
            post_id, owner_user_id = row[0], row[1]
            content = row[3] or ""  # NULL content would otherwise break len()
            if len(content) > 200:
                content = content[:200] + '...'
            post_list.append({
                'post_id': post_id,
                'id': post_id,
                'title': row[2],
                'content': content,
                'type': row[4],
                'username': user_map.get(owner_user_id, ""),
                'heat': row[5],
                'tags': tag_map.get(post_id, ""),
                'created_at': str(row[6]) if row[6] else ""
            })
        return post_list

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def run_inference(self, user_id, topk=None, use_multi_recall=None):
        """Run recommendation inference for one user.

        Args:
            user_id: Id of the user to recommend for.
            topk: Number of recommendations; defaults to ``self.topk``.
            use_multi_recall: True for multi-channel recall, False for plain
                LightGCN scoring, None to follow ``self.multi_recall_enabled``.

        Returns:
            Either a (post_ids, scores) tuple or, when cold start was used, a
            list of post detail dicts.
        """
        if topk is None:
            topk = self.topk

        if use_multi_recall is None:
            use_multi_recall = self.multi_recall_enabled

        if use_multi_recall:
            return self._run_multi_recall_inference(user_id, topk)
        return self._run_lightgcn_inference(user_id, topk)

    def _run_multi_recall_inference(self, user_id, topk):
        """Recommend via multi-channel recall, optionally re-scored by LightGCN.

        Falls back to cold start when recall yields nothing and to plain
        LightGCN inference when any step raises.
        """
        try:
            self.init_multi_recall()

            # Recall ten times the requested amount, capped at 500 candidates.
            total_candidates = min(topk * 10, 500)
            candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
                user_id, total_candidates
            )

            if not candidate_post_ids:
                print(f"用户 {user_id} 多路召回无结果,使用冷启动")
                return self.user_cold_start(topk)

            print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
            print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")

            if self.use_lightgcn_rerank:
                print("使用LightGCN对多路召回结果进行重新打分...")
                lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)

                # LightGCN scores replace the recall scores; no fusion is applied.
                final_scores = lightgcn_scores

                print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
                print(f"使用LightGCN分数进行重排")
            else:
                final_scores = candidate_scores

            final_post_ids, final_scores = self.mmr_rerank_with_ads(
                candidate_post_ids, final_scores, theta=0.5, target_size=topk
            )

            return final_post_ids, final_scores

        except Exception as e:
            # Deliberate best-effort: any recall failure degrades to LightGCN.
            print(f"多路召回失败: {str(e)},回退到LightGCN")
            return self._run_lightgcn_inference(user_id, topk)

    def _score_all_items(self, user_id):
        """Score every item for ``user_id`` with the pretrained LightGCN model.

        Returns:
            (post_ids, scores) where ``scores`` is a numpy array aligned with
            ``post_ids``, or None when the user is absent from the graph.

        NOTE(review): the graph mapping, dataset and checkpoint are rebuilt on
        every call; consider caching if this shows up in profiles.
        """
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        if user_id not in user2idx:
            return None

        idx2post = {idx: post_id for post_id, idx in post2idx.items()}
        user_idx = user2idx[user_id]

        dataset = EdgeListData(args.data_path, args.data_path)
        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
        # Trim the checkpoint embeddings to the sizes of the current dataset.
        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

        model = LightGCN(dataset, phase='vanilla').to(args.device)
        model.load_state_dict(pretrained_dict, strict=False)
        model.eval()

        with torch.no_grad():
            user_emb, item_emb = model.generate()
            user_vec = user_emb[user_idx].unsqueeze(0)
            scores = model.rating(user_vec, item_emb).squeeze(0)

        all_scores = scores.cpu().numpy()
        all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
        return all_post_ids, all_scores

    def _run_lightgcn_inference(self, user_id, topk):
        """Recommend from a full LightGCN scoring pass followed by MMR re-ranking.

        Users missing from the graph get cold-start results instead.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return self.user_cold_start(topk)
        all_post_ids, all_scores = scored

        # Keep positively scored items; if there are none, keep the best 100.
        positive_candidates = [
            (post_id, score) for post_id, score in zip(all_post_ids, all_scores) if score > 0
        ]
        if not positive_candidates:
            sorted_candidates = sorted(zip(all_post_ids, all_scores), key=lambda x: x[1], reverse=True)
            positive_candidates = sorted_candidates[:min(100, len(sorted_candidates))]

        candidate_post_ids = [post_id for post_id, _ in positive_candidates]
        candidate_scores = [score for _, score in positive_candidates]

        print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")

        final_post_ids, final_scores = self.mmr_rerank_with_ads(
            candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
        )

        return final_post_ids, final_scores

    def _get_recall_source_stats(self, recall_breakdown):
        """Return {source: number of recalled items} for a recall breakdown dict."""
        return {source: len(items) for source, items in recall_breakdown.items()}

    # ------------------------------------------------------------------
    # Post details
    # ------------------------------------------------------------------
    def get_post_info(self, topk_post_ids, topk_scores=None):
        """Return detail dicts for the given posts, in the given order.

        Args:
            topk_post_ids: Post ids to look up.
            topk_scores: Optional aligned scores; they are only printed.

        Returns:
            A list of dicts, one per id that exists and is published; other
            ids are skipped with a printed notice.
        """
        if not topk_post_ids:
            return []

        print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
        if topk_scores is not None:
            print(f"对应的推荐打分: {topk_scores}")

        placeholders = ','.join(['%s'] * len(topk_post_ids))
        params = tuple(topk_post_ids)

        with self._db_cursor() as cursor:
            cursor.execute(
                f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.updated_at, p.media_urls, p.status, p.is_advertisement
                    FROM posts p
                    WHERE p.id IN ({placeholders}) AND p.status = 'published'""",
                params
            )
            post_rows = cursor.fetchall()
            post_map = {row[0]: row for row in post_rows}

            user_map = self._fetch_usernames(cursor, list({row[1] for row in post_rows}))
            tag_map = self._fetch_tag_map(cursor, topk_post_ids)

            # Per-post behaviour counts grouped by behaviour type.
            cursor.execute(
                f"""SELECT post_id, type, COUNT(*) as count
                    FROM behaviors
                    WHERE post_id IN ({placeholders})
                    GROUP BY post_id, type""",
                params
            )
            behavior_stats = {}
            for post_id, behavior_type, count in cursor.fetchall():
                behavior_stats.setdefault(post_id, {})[behavior_type] = count

        post_list = []
        for post_id in topk_post_ids:
            row = post_map.get(post_id)
            if not row:
                print(f"帖子ID {post_id} 不存在或未发布,跳过")
                continue
            owner_user_id = row[1]
            stats = behavior_stats.get(post_id, {})
            post_list.append({
                'id': post_id,
                'user_id': owner_user_id,
                'title': row[2],
                'content': row[3],  # full content, not truncated
                'media_urls': row[8],
                'status': row[9],
                'heat': row[5],
                'created_at': row[6].isoformat() if row[6] else "",
                'updated_at': row[7].isoformat() if row[7] else "",
                'type': row[4],
                'username': user_map.get(owner_user_id, ""),
                'tags': tag_map.get(post_id, ""),
                'is_advertisement': bool(row[10]),
                'like_count': stats.get('like', 0),
                'comment_count': stats.get('comment', 0),
                'favorite_count': stats.get('favorite', 0),
                'view_count': stats.get('view', 0),
                'share_count': stats.get('share', 0)
            })
        return post_list

    def get_recommendations(self, user_id, topk=None):
        """Main entry point: return a list of post detail dicts for a user.

        Raises:
            Exception: wraps any failure; the original error is chained as
                ``__cause__``.
        """
        try:
            result = self.run_inference(user_id, topk)

            # Cold start already produced detail dicts.
            if isinstance(result, list) and result and isinstance(result[0], dict):
                return result

            if isinstance(result, tuple) and len(result) == 2:
                topk_post_ids, topk_scores = result
                return self.get_post_info(topk_post_ids, topk_scores)

            # Legacy return format: a bare list of post ids.
            return self.get_post_info(result)
        except Exception as e:
            raise Exception(f"推荐系统错误: {str(e)}") from e

    def get_all_item_scores(self, user_id):
        """Return (post_ids, scores) for every item, or ([], []) for unknown users."""
        scored = self._score_all_items(user_id)
        if scored is None:
            return [], []
        return scored

    # ------------------------------------------------------------------
    # Lazy component initialisation
    # ------------------------------------------------------------------
    def init_multi_recall(self):
        """Create the multi-channel recall manager on first use."""
        if self.multi_recall is None:
            print("初始化多路召回管理器...")
            self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
            print("多路召回管理器初始化完成")

    def init_lightgcn_scorer(self):
        """Create the LightGCN scorer on first use."""
        if self.lightgcn_scorer is None:
            print("初始化LightGCN评分器...")
            self.lightgcn_scorer = LightGCNScorer()
            print("LightGCN评分器初始化完成")

    def _get_lightgcn_scores(self, user_id, candidate_post_ids):
        """Return the scorer's LightGCN scores for the candidate posts of a user."""
        self.init_lightgcn_scorer()
        return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)

    def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
        """Blend min-max normalised recall and LightGCN scores.

        Args:
            multi_recall_scores: Scores from multi-channel recall.
            lightgcn_scores: LightGCN scores aligned with the recall scores.
            alpha: Weight of the LightGCN scores (0..1).

        Returns:
            A list of fused scores; an empty list for empty input.

        Raises:
            ValueError: if the two score sequences differ in length.
        """
        if len(multi_recall_scores) != len(lightgcn_scores):
            raise ValueError("分数列表长度不匹配")

        if len(multi_recall_scores) == 0:
            return []

        def normalize_scores(scores):
            values = np.asarray(scores, dtype=float)
            low = np.min(values)
            high = np.max(values)
            if high == low:
                # A constant vector carries no ranking signal; use the midpoint.
                return np.full_like(values, 0.5)
            return (values - low) / (high - low)

        norm_multi_scores = normalize_scores(multi_recall_scores)
        norm_lightgcn_scores = normalize_scores(lightgcn_scores)

        fused_scores = alpha * norm_lightgcn_scores + (1 - alpha) * norm_multi_scores
        return fused_scores.tolist()

    # ------------------------------------------------------------------
    # Multi-channel recall management
    # ------------------------------------------------------------------
    def train_multi_recall(self):
        """Train all multi-channel recall models."""
        self.init_multi_recall()
        self.multi_recall.train_all()

    def update_recall_config(self, new_config):
        """Merge ``new_config`` into the recall configuration.

        Channel dicts are merged key by key, so a partial override such as
        ``{'swing': {'num_items': 30}}`` keeps the channel's other settings.
        The unmodified ``new_config`` is forwarded to the recall manager when
        one exists.
        """
        for channel, overrides in new_config.items():
            current = self.recall_config.get(channel)
            if isinstance(current, dict) and isinstance(overrides, dict):
                self.recall_config[channel] = {**current, **overrides}
            else:
                self.recall_config[channel] = overrides

        if self.multi_recall:
            self.multi_recall.update_config(new_config)

    def multi_recall_inference(self, user_id, total_items=200):
        """Run multi-channel recall for a user.

        Args:
            user_id: Id of the user.
            total_items: Total number of items to recall.

        Returns:
            Tuple of (item_ids, scores, recall_breakdown) as produced by the
            recall manager.
        """
        self.init_multi_recall()
        item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
        return item_ids, scores, recall_results

    def get_multi_recall_stats(self, user_id):
        """Return recall statistics for a user, or an error dict if recall is uninitialised."""
        if self.multi_recall is None:
            return {"error": "多路召回未初始化"}

        return self.multi_recall.get_recall_stats(user_id)