推荐系统

Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/services/recommendation_service.py b/rhj/backend/app/services/recommendation_service.py
new file mode 100644
index 0000000..2f4de13
--- /dev/null
+++ b/rhj/backend/app/services/recommendation_service.py
@@ -0,0 +1,719 @@
+import torch
+import pymysql
+import numpy as np
+import random
+from app.models.recommend.LightGCN import LightGCN
+from app.models.recall import MultiRecallManager
+from app.services.lightgcn_scorer import LightGCNScorer
+from app.utils.parse_args import args
+from app.utils.data_loader import EdgeListData
+from app.utils.graph_build import build_user_post_graph
+from config import Config
+
+class RecommendationService:
+    def __init__(self):
+        # 数据库连接配置 - 修改为redbook数据库
+        self.db_config = {
+            'host': '10.126.59.25',
+            'port': 3306,
+            'user': 'root',
+            'password': '123456',
+            'database': 'redbook',  # 使用redbook数据库
+            'charset': 'utf8mb4'
+        }
+        
+        # 模型配置
+        args.device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
+        args.data_path = './app/user_post_graph.txt'  # 修改为帖子图文件
+        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'
+        
+        self.topk = 2  # 默认推荐数量
+        
+        # 初始化多路召回管理器
+        self.multi_recall = None
+        self.multi_recall_enabled = True  # 控制是否启用多路召回
+        
+        # 初始化LightGCN评分器
+        self.lightgcn_scorer = None
+        self.use_lightgcn_rerank = True  # 控制是否使用LightGCN对多路召回结果重新打分
+        
+        # 多路召回配置
+        self.recall_config = {
+            'swing': {
+                'enabled': True,
+                'num_items': 20,  # 增加召回数量
+                'alpha': 0.5
+            },
+            'hot': {
+                'enabled': True,
+                'num_items': 15  # 增加热度召回数量
+            },
+            'ad': {
+                'enabled': True,
+                'num_items': 5   # 增加广告召回数量
+            },
+            'usercf': {
+                'enabled': True,
+                'num_items': 15,
+                'min_common_items': 1,  # 降低阈值,从3改为1
+                'num_similar_users': 20  # 减少相似用户数量以提高效率
+            }
+        }
+    
+    def calculate_tag_similarity(self, tags1, tags2):
+        """
+        计算两个帖子标签的相似度
+        输入: tags1, tags2 - 标签字符串,以逗号分隔
+        输出: 相似度分数(0-1之间)
+        """
+        if not tags1 or not tags2:
+            return 0.0
+        
+        # 将标签字符串转换为集合
+        set1 = set(tag.strip() for tag in tags1.split(',') if tag.strip())
+        set2 = set(tag.strip() for tag in tags2.split(',') if tag.strip())
+        
+        if not set1 or not set2:
+            return 0.0
+        
+        # 计算标签重叠比例(Jaccard相似度)
+        intersection = len(set1.intersection(set2))
+        union = len(set1.union(set2))
+        
+        return intersection / union if union > 0 else 0.0
+    
+    def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
+        """
+        使用MMR算法重新排序推荐结果,并在过程中加入广告约束
+        输入:
+        - post_ids: 帖子ID列表
+        - scores: 对应的推荐分数列表
+        - theta: 平衡相关性和多样性的参数(0.5表示各占一半)
+        - target_size: 目标结果数量,默认与输入相同
+        输出: 重排后的(post_ids, scores),每5条帖子包含1条广告
+        """
+        if target_size is None:
+            target_size = len(post_ids)
+        
+        if len(post_ids) <= 1:
+            return post_ids, scores
+        
+        # 获取帖子标签信息和广告标识
+        conn = pymysql.connect(**self.db_config)
+        cursor = conn.cursor()
+        
+        try:
+            # 查询所有候选帖子的标签和广告标识
+            format_strings = ','.join(['%s'] * len(post_ids))
+            cursor.execute(
+                f"""SELECT p.id, p.is_advertisement, 
+                          COALESCE(GROUP_CONCAT(t.name), '') as tags
+                    FROM posts p
+                    LEFT JOIN post_tags pt ON p.id = pt.post_id
+                    LEFT JOIN tags t ON pt.tag_id = t.id 
+                    WHERE p.id IN ({format_strings}) AND p.status = 'published'
+                    GROUP BY p.id, p.is_advertisement""",
+                tuple(post_ids)
+            )
+            post_info_rows = cursor.fetchall()
+            post_tags = {}
+            post_is_ad = {}
+            
+            for row in post_info_rows:
+                post_id, is_ad, tags = row
+                post_tags[post_id] = tags or ""
+                post_is_ad[post_id] = bool(is_ad)
+            
+            # 对于没有查询到的帖子,设置默认值
+            for post_id in post_ids:
+                if post_id not in post_tags:
+                    post_tags[post_id] = ""
+                    post_is_ad[post_id] = False
+            
+            # 获取额外的广告帖子作为候选
+            cursor.execute("""
+                SELECT id, heat FROM posts 
+                WHERE is_advertisement = 1 AND status = 'published' 
+                AND id NOT IN ({})
+                ORDER BY heat DESC
+                LIMIT 50
+            """.format(format_strings), tuple(post_ids))
+            extra_ad_rows = cursor.fetchall()
+            
+        finally:
+            cursor.close()
+            conn.close()
+        
+        # 分离普通帖子和广告帖子
+        normal_candidates = []
+        ad_candidates = []
+        
+        for post_id, score in zip(post_ids, scores):
+            if post_is_ad[post_id]:
+                ad_candidates.append((post_id, score))
+            else:
+                normal_candidates.append((post_id, score))
+        
+        # 添加额外的广告候选
+        for ad_id, heat in extra_ad_rows:
+            # 为广告帖子设置标签和广告标识
+            post_tags[ad_id] = ""  # 广告帖子暂时设置为空标签
+            post_is_ad[ad_id] = True
+            ad_score = float(heat) / 1000.0  # 将热度转换为分数
+            ad_candidates.append((ad_id, ad_score))
+        
+        # 排序候选列表
+        normal_candidates.sort(key=lambda x: x[1], reverse=True)
+        ad_candidates.sort(key=lambda x: x[1], reverse=True)
+        
+        # MMR算法实现,加入广告约束
+        selected = []
+        normal_idx = 0
+        ad_idx = 0
+        
+        while len(selected) < target_size:
+            current_position = len(selected)
+            
+            # 检查是否需要插入广告(每5个位置插入1个广告)
+            if (current_position + 1) % 5 == 0 and ad_idx < len(ad_candidates):
+                # 插入广告
+                selected.append(ad_candidates[ad_idx])
+                ad_idx += 1
+            else:
+                # 使用MMR选择普通帖子
+                if normal_idx >= len(normal_candidates):
+                    break
+                
+                best_score = -float('inf')
+                best_local_idx = normal_idx
+                
+                # 在剩余的普通候选中选择最佳的
+                for i in range(normal_idx, min(normal_idx + 10, len(normal_candidates))):
+                    post_id, relevance_score = normal_candidates[i]
+                    
+                    # 计算与已选帖子的最大相似度
+                    max_similarity = 0.0
+                    current_tags = post_tags[post_id]
+                    
+                    for selected_post_id, _ in selected:
+                        selected_tags = post_tags[selected_post_id]
+                        similarity = self.calculate_tag_similarity(current_tags, selected_tags)
+                        max_similarity = max(max_similarity, similarity)
+                    
+                    # 计算MMR分数
+                    mmr_score = theta * relevance_score - (1 - theta) * max_similarity
+                    
+                    if mmr_score > best_score:
+                        best_score = mmr_score
+                        best_local_idx = i
+                
+                # 选择最佳候选
+                selected.append(normal_candidates[best_local_idx])
+                # 将选中的元素移到已处理区域
+                normal_candidates[normal_idx], normal_candidates[best_local_idx] = \
+                    normal_candidates[best_local_idx], normal_candidates[normal_idx]
+                normal_idx += 1
+        
+        # 提取重排后的结果
+        reranked_post_ids = [post_id for post_id, _ in selected]
+        reranked_scores = [score for _, score in selected]
+        
+        return reranked_post_ids, reranked_scores
+    
+    def insert_advertisements(self, post_ids, scores):
+        """
+        在推荐结果中插入广告,每5条帖子插入1条广告
+        输入: post_ids, scores - 原始推荐结果
+        输出: 插入广告后的(post_ids, scores)
+        """
+        # 获取可用的广告帖子
+        conn = pymysql.connect(**self.db_config)
+        cursor = conn.cursor()
+        
+        try:
+            cursor.execute("""
+                SELECT id, heat FROM posts 
+                WHERE is_advertisement = 1 AND status = 'published' 
+                ORDER BY heat DESC
+                LIMIT 50
+            """)
+            ad_rows = cursor.fetchall()
+            
+            if not ad_rows:
+                # 没有广告,直接返回原结果
+                return post_ids, scores
+            
+            # 可用的广告帖子(排除已在推荐结果中的)
+            available_ads = [(ad_id, heat) for ad_id, heat in ad_rows if ad_id not in post_ids]
+            
+            if not available_ads:
+                # 没有可用的新广告,直接返回原结果
+                return post_ids, scores
+            
+        finally:
+            cursor.close()
+            conn.close()
+        
+        # 插入广告的逻辑
+        result_posts = []
+        result_scores = []
+        ad_index = 0
+        
+        for i, (post_id, score) in enumerate(zip(post_ids, scores)):
+            result_posts.append(post_id)
+            result_scores.append(score)
+            
+            # 每5条帖子后插入一条广告
+            if (i + 1) % 5 == 0 and ad_index < len(available_ads):
+                ad_id, ad_heat = available_ads[ad_index]
+                result_posts.append(ad_id)
+                result_scores.append(float(ad_heat) / 1000.0)  # 将热度转换为分数范围
+                ad_index += 1
+        
+        return result_posts, result_scores
+
+    def user_cold_start(self, topk=None):
+        """
+        冷启动:直接返回热度最高的topk个帖子详细信息
+        """
+        if topk is None:
+            topk = self.topk
+            
+        conn = pymysql.connect(**self.db_config)
+        cursor = conn.cursor()
+
+        try:
+            # 查询热度最高的topk个帖子
+            cursor.execute("""
+                SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
+                FROM posts p 
+                WHERE p.status = 'published' 
+                ORDER BY p.heat DESC 
+                LIMIT %s
+            """, (topk,))
+            post_rows = cursor.fetchall()
+            post_ids = [row[0] for row in post_rows]
+            post_map = {row[0]: row for row in post_rows}
+
+            # 查询用户信息
+            owner_ids = list(set(row[1] for row in post_rows))
+            if owner_ids:
+                format_strings_user = ','.join(['%s'] * len(owner_ids))
+                cursor.execute(
+                    f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
+                    tuple(owner_ids)
+                )
+                user_rows = cursor.fetchall()
+                user_map = {row[0]: row[1] for row in user_rows}
+            else:
+                user_map = {}
+
+            # 查询帖子标签
+            if post_ids:
+                format_strings = ','.join(['%s'] * len(post_ids))
+                cursor.execute(
+                    f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
+                        FROM post_tags pt 
+                        JOIN tags t ON pt.tag_id = t.id 
+                        WHERE pt.post_id IN ({format_strings})
+                        GROUP BY pt.post_id""",
+                    tuple(post_ids)
+                )
+                tag_rows = cursor.fetchall()
+                tag_map = {row[0]: row[1] for row in tag_rows}
+            else:
+                tag_map = {}
+
+            post_list = []
+            for post_id in post_ids:
+                row = post_map.get(post_id)
+                if not row:
+                    continue
+                owner_user_id = row[1]
+                post_list.append({
+                    'post_id': post_id,
+                    'title': row[2],
+                    'content': row[3][:200] + '...' if len(row[3]) > 200 else row[3],  # 截取前200字符
+                    'type': row[4],
+                    'username': user_map.get(owner_user_id, ""),
+                    'heat': row[5],
+                    'tags': tag_map.get(post_id, ""),
+                    'created_at': str(row[6]) if row[6] else ""
+                })
+            return post_list
+        finally:
+            cursor.close()
+            conn.close()
+
+    def run_inference(self, user_id, topk=None, use_multi_recall=None):
+        """
+        推荐推理主函数
+        
+        Args:
+            user_id: 用户ID
+            topk: 推荐数量
+            use_multi_recall: 是否使用多路召回,None表示使用默认设置
+        """
+        if topk is None:
+            topk = self.topk
+        
+        # 决定使用哪种召回方式
+        if use_multi_recall is None:
+            use_multi_recall = self.multi_recall_enabled
+        
+        return self._run_multi_recall_inference(user_id, topk)
+    
+    def _run_multi_recall_inference(self, user_id, topk):
+        """使用多路召回进行推荐,并可选择使用LightGCN重新打分"""
+        try:
+            # 初始化多路召回(如果尚未初始化)
+            self.init_multi_recall()
+            
+            # 执行多路召回,召回更多候选物品
+            total_candidates = min(topk * 10, 500)  # 召回候选数是最终推荐数的10倍
+            candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
+                user_id, total_candidates
+            )
+            
+            if not candidate_post_ids:
+                # 如果多路召回没有结果,回退到冷启动
+                print(f"用户 {user_id} 多路召回无结果,使用冷启动")
+                return self.user_cold_start(topk)
+            
+            print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
+            print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")
+            
+            # 如果启用LightGCN重新打分,使用LightGCN对候选结果进行评分
+            if self.use_lightgcn_rerank:
+                print("使用LightGCN对多路召回结果进行重新打分...")
+                lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)
+                
+                # 直接使用LightGCN分数,不进行融合
+                final_scores = lightgcn_scores
+                
+                print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
+                print(f"使用LightGCN分数进行重排")
+            else:
+                # 使用原始多路召回分数
+                final_scores = candidate_scores
+            
+            # 使用MMR算法重排,包含广告约束
+            final_post_ids, final_scores = self.mmr_rerank_with_ads(
+                candidate_post_ids, final_scores, theta=0.5, target_size=topk
+            )
+            
+            return final_post_ids, final_scores
+            
+        except Exception as e:
+            print(f"多路召回失败: {str(e)},回退到LightGCN")
+            return self._run_lightgcn_inference(user_id, topk)
+    
+    def _run_lightgcn_inference(self, user_id, topk):
+        """使用原始LightGCN进行推荐"""
+        user2idx, post2idx = build_user_post_graph(return_mapping=True)
+        idx2post = {v: k for k, v in post2idx.items()}
+
+        if user_id not in user2idx:
+            # 冷启动
+            return self.user_cold_start(topk)
+
+        user_idx = user2idx[user_id]
+
+        dataset = EdgeListData(args.data_path, args.data_path)
+        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+        model = LightGCN(dataset, phase='vanilla').to(args.device)
+        model.load_state_dict(pretrained_dict, strict=False)
+        model.eval()
+
+        with torch.no_grad():
+            user_emb, item_emb = model.generate()
+            user_vec = user_emb[user_idx].unsqueeze(0)
+            scores = model.rating(user_vec, item_emb).squeeze(0)
+            
+            # 获取所有物品的分数(而不是只取top候选)
+            all_scores = scores.cpu().numpy()
+            all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
+            
+            # 过滤掉分数为负的物品,只保留正分数的候选
+            positive_candidates = [(post_id, score) for post_id, score in zip(all_post_ids, all_scores) if score > 0]
+            
+            if not positive_candidates:
+                # 如果没有正分数的候选,取分数最高的一些
+                sorted_candidates = sorted(zip(all_post_ids, all_scores), key=lambda x: x[1], reverse=True)
+                positive_candidates = sorted_candidates[:min(100, len(sorted_candidates))]
+            
+            candidate_post_ids = [post_id for post_id, _ in positive_candidates]
+            candidate_scores = [score for _, score in positive_candidates]
+
+        print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")
+
+        # 使用MMR算法重排,包含广告约束,theta=0.5平衡相关性和多样性
+        final_post_ids, final_scores = self.mmr_rerank_with_ads(
+            candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
+        )
+
+        return final_post_ids, final_scores
+    
+    def _get_recall_source_stats(self, recall_breakdown):
+        """获取召回来源统计"""
+        stats = {}
+        for source, items in recall_breakdown.items():
+            stats[source] = len(items)
+        return stats
+
+    def get_post_info(self, topk_post_ids, topk_scores=None):
+        """
+        输入: topk_post_ids(帖子ID列表),topk_scores(对应的打分列表,可选)
+        输出: 推荐帖子的详细信息列表,每个元素为dict
+        """
+        if not topk_post_ids:
+            return []
+        
+        print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
+        if topk_scores is not None:
+            print(f"对应的推荐打分: {topk_scores}")
+
+        conn = pymysql.connect(**self.db_config)
+        cursor = conn.cursor()
+
+        try:
+            # 查询帖子基本信息
+            format_strings = ','.join(['%s'] * len(topk_post_ids))
+            cursor.execute(
+                f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.is_advertisement
+                    FROM posts p 
+                    WHERE p.id IN ({format_strings}) AND p.status = 'published'""",
+                tuple(topk_post_ids)
+            )
+            post_rows = cursor.fetchall()
+            post_map = {row[0]: row for row in post_rows}
+
+            # 查询用户信息
+            owner_ids = list(set(row[1] for row in post_rows))
+            if owner_ids:
+                format_strings_user = ','.join(['%s'] * len(owner_ids))
+                cursor.execute(
+                    f"SELECT id, username FROM users WHERE id IN ({format_strings_user})",
+                    tuple(owner_ids)
+                )
+                user_rows = cursor.fetchall()
+                user_map = {row[0]: row[1] for row in user_rows}
+            else:
+                user_map = {}
+
+            # 查询帖子标签
+            cursor.execute(
+                f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
+                    FROM post_tags pt 
+                    JOIN tags t ON pt.tag_id = t.id 
+                    WHERE pt.post_id IN ({format_strings})
+                    GROUP BY pt.post_id""",
+                tuple(topk_post_ids)
+            )
+            tag_rows = cursor.fetchall()
+            tag_map = {row[0]: row[1] for row in tag_rows}
+
+            # 查询行为统计(点赞数、评论数等)
+            cursor.execute(
+                f"""SELECT post_id, type, COUNT(*) as count
+                    FROM behaviors 
+                    WHERE post_id IN ({format_strings}) 
+                    GROUP BY post_id, type""",
+                tuple(topk_post_ids)
+            )
+            behavior_rows = cursor.fetchall()
+            behavior_stats = {}
+            for row in behavior_rows:
+                post_id, behavior_type, count = row
+                if post_id not in behavior_stats:
+                    behavior_stats[post_id] = {}
+                behavior_stats[post_id][behavior_type] = count
+
+            post_list = []
+            for i, post_id in enumerate(topk_post_ids):
+                row = post_map.get(post_id)
+                if not row:
+                    print(f"帖子ID {post_id} 不存在或未发布,跳过")
+                    continue
+                owner_user_id = row[1]
+                stats = behavior_stats.get(post_id, {})
+                post_info = {
+                    'post_id': post_id,
+                    'title': row[2],
+                    'content': row[3][:200] + '...' if len(row[3]) > 200 else row[3],
+                    'type': row[4],
+                    'username': user_map.get(owner_user_id, ""),
+                    'heat': row[5],
+                    'tags': tag_map.get(post_id, ""),
+                    'created_at': str(row[6]) if row[6] else "",
+                    'is_advertisement': bool(row[7]),  # 添加广告标识
+                    'like_count': stats.get('like', 0),
+                    'comment_count': stats.get('comment', 0),
+                    'favorite_count': stats.get('favorite', 0),
+                    'view_count': stats.get('view', 0),
+                    'share_count': stats.get('share', 0)
+                }
+                
+                # 如果有推荐打分,添加到结果中
+                if topk_scores is not None and i < len(topk_scores):
+                    post_info['recommendation_score'] = float(topk_scores[i])
+                
+                post_list.append(post_info)
+            return post_list
+        finally:
+            cursor.close()
+            conn.close()
+
+    def get_recommendations(self, user_id, topk=None):
+        """
+        获取推荐结果的主要接口
+        """
+        try:
+            result = self.run_inference(user_id, topk)
+            # 如果是冷启动直接返回详细信息,否则查详情
+            if isinstance(result, list) and result and isinstance(result[0], dict):
+                return result
+            else:
+                # result 现在是 (topk_post_ids, topk_scores) 的元组
+                if isinstance(result, tuple) and len(result) == 2:
+                    topk_post_ids, topk_scores = result
+                    return self.get_post_info(topk_post_ids, topk_scores)
+                else:
+                    # 兼容旧的返回格式
+                    return self.get_post_info(result)
+        except Exception as e:
+            raise Exception(f"推荐系统错误: {str(e)}")
+
+    def get_all_item_scores(self, user_id):
+        """
+        获取用户对所有物品的打分
+        输入: user_id
+        输出: (post_ids, scores) - 所有帖子ID和对应的打分
+        """
+        user2idx, post2idx = build_user_post_graph(return_mapping=True)
+        idx2post = {v: k for k, v in post2idx.items()}
+
+        if user_id not in user2idx:
+            # 用户不存在,返回空结果
+            return [], []
+
+        user_idx = user2idx[user_id]
+
+        dataset = EdgeListData(args.data_path, args.data_path)
+        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+        model = LightGCN(dataset, phase='vanilla').to(args.device)
+        model.load_state_dict(pretrained_dict, strict=False)
+        model.eval()
+
+        with torch.no_grad():
+            user_emb, item_emb = model.generate()
+            user_vec = user_emb[user_idx].unsqueeze(0)
+            scores = model.rating(user_vec, item_emb).squeeze(0)
+            
+            # 获取所有物品的ID和分数
+            all_scores = scores.cpu().numpy()
+            all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
+            
+            return all_post_ids, all_scores
+
+    def init_multi_recall(self):
+        """初始化多路召回管理器"""
+        if self.multi_recall is None:
+            print("初始化多路召回管理器...")
+            self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
+            print("多路召回管理器初始化完成")
+    
+    def init_lightgcn_scorer(self):
+        """初始化LightGCN评分器"""
+        if self.lightgcn_scorer is None:
+            print("初始化LightGCN评分器...")
+            self.lightgcn_scorer = LightGCNScorer()
+            print("LightGCN评分器初始化完成")
+    
+    def _get_lightgcn_scores(self, user_id, candidate_post_ids):
+        """
+        获取候选物品的LightGCN分数
+        
+        Args:
+            user_id: 用户ID
+            candidate_post_ids: 候选物品ID列表
+            
+        Returns:
+            List[float]: LightGCN分数列表
+        """
+        self.init_lightgcn_scorer()
+        return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)
+    
+    def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
+        """
+        融合多路召回分数和LightGCN分数
+        
+        Args:
+            multi_recall_scores: 多路召回分数列表
+            lightgcn_scores: LightGCN分数列表
+            alpha: LightGCN分数的权重(0-1之间)
+            
+        Returns:
+            List[float]: 融合后的分数列表
+        """
+        if len(multi_recall_scores) != len(lightgcn_scores):
+            raise ValueError("分数列表长度不匹配")
+        
+        # 对分数进行归一化
+        def normalize_scores(scores):
+            scores = np.array(scores)
+            min_score = np.min(scores)
+            max_score = np.max(scores)
+            if max_score == min_score:
+                return np.ones_like(scores) * 0.5
+            return (scores - min_score) / (max_score - min_score)
+        
+        norm_multi_scores = normalize_scores(multi_recall_scores)
+        norm_lightgcn_scores = normalize_scores(lightgcn_scores)
+        
+        # 加权融合
+        fused_scores = alpha * norm_lightgcn_scores + (1 - alpha) * norm_multi_scores
+        
+        return fused_scores.tolist()
+    
+    def train_multi_recall(self):
+        """训练多路召回模型"""
+        self.init_multi_recall()
+        self.multi_recall.train_all()
+    
+    def update_recall_config(self, new_config):
+        """更新多路召回配置"""
+        self.recall_config.update(new_config)
+        if self.multi_recall:
+            self.multi_recall.update_config(new_config)
+    
+    def multi_recall_inference(self, user_id, total_items=200):
+        """
+        使用多路召回进行推荐
+        
+        Args:
+            user_id: 用户ID
+            total_items: 总召回物品数量
+            
+        Returns:
+            Tuple of (item_ids, scores, recall_breakdown)
+        """
+        self.init_multi_recall()
+        
+        # 执行多路召回
+        item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
+        
+        return item_ids, scores, recall_results
+    
+    def get_multi_recall_stats(self, user_id):
+        """获取多路召回统计信息"""
+        if self.multi_recall is None:
+            return {"error": "多路召回未初始化"}
+        
+        return self.multi_recall.get_recall_stats(user_id)