| from typing import List, Tuple, Dict, Any |
| import numpy as np |
| from collections import defaultdict |
| |
| from .swing_recall import SwingRecall |
| from .hot_recall import HotRecall |
| from .ad_recall import AdRecall |
| from .usercf_recall import UserCFRecall |
| |
class MultiRecallManager:
    """
    Multi-channel recall manager.

    Combines several recall strategies (Swing, popularity ("hot"), ad and
    UserCF recall) and fuses their candidate lists into one ranked list.
    """

    # Default per-source fusion weights used when merging recall results.
    # A source's weight can be overridden with a 'weight' entry in its config.
    SOURCE_WEIGHTS = {
        'swing': 0.3,
        'hot': 0.2,
        'ad': 0.1,
        'usercf': 0.4,
    }
    # Weight used for a source with neither a config override nor an entry above.
    DEFAULT_SOURCE_WEIGHT = 0.1
    # Score added to an item for each distinct source that recalled it.
    DIVERSITY_BONUS = 0.1

    def __init__(self, db_config: dict, recall_config: dict = None):
        """
        Initialise the multi-channel recall manager.

        Args:
            db_config: Database configuration, passed to every recaller.
            recall_config: Optional overrides. Maps a source name ('swing',
                'hot', 'ad', 'usercf', ...) to a dict of settings such as
                'enabled', 'num_items' and recaller-specific parameters.
        """
        self.db_config = db_config

        # Default configuration. It is built fresh on every call, so
        # instances never share the nested dicts.
        self.config = {
            'swing': {
                'enabled': True,
                'num_items': 15,
                'alpha': 0.5,
            },
            'hot': {
                'enabled': True,
                'num_items': 10,
            },
            'ad': {
                'enabled': True,
                'num_items': 2,
            },
            'usercf': {
                'enabled': True,
                'num_items': 10,
                'min_common_items': 3,
                'num_similar_users': 50,
            },
        }

        # Merge the user-supplied overrides on top of the defaults.
        if recall_config:
            self._merge_config(recall_config)

        # Build the recaller instances for the enabled sources.
        self.recalls = {}
        self._init_recalls()

    def _merge_config(self, overrides: dict) -> None:
        """
        Merge ``overrides`` into ``self.config`` in place.

        For a source already in the config, its settings dict is updated key
        by key. Any other top-level key is added as a new entry. Dict values
        are shallow-copied, so later changes to the caller's dict do not leak
        into this manager.

        Args:
            overrides: Mapping of source name -> settings dict.
        """
        for key, value in overrides.items():
            if key in self.config:
                self.config[key].update(value)
            else:
                self.config[key] = dict(value) if isinstance(value, dict) else value

    def _init_recalls(self) -> None:
        """
        (Re)build the recaller instances for every enabled known source.

        Previously built recallers are discarded first. This ensures a source
        disabled via ``update_config`` no longer appears in ``self.recalls``.
        NOTE(review): the new instances presumably need ``train_all()`` again
        before ``recall()``; verify against the recaller classes.
        """
        self.recalls = {}

        if self.config['swing']['enabled']:
            self.recalls['swing'] = SwingRecall(
                self.db_config,
                alpha=self.config['swing']['alpha']
            )

        if self.config['hot']['enabled']:
            self.recalls['hot'] = HotRecall(self.db_config)

        if self.config['ad']['enabled']:
            self.recalls['ad'] = AdRecall(self.db_config)

        if self.config['usercf']['enabled']:
            self.recalls['usercf'] = UserCFRecall(
                self.db_config,
                min_common_items=self.config['usercf']['min_common_items']
            )

    def train_all(self):
        """Call ``train()`` on every configured recaller, printing progress."""
        print("开始训练多路召回模型...")

        for name, recall_model in self.recalls.items():
            print(f"训练 {name} 召回器...")
            recall_model.train()

        print("所有召回器训练完成!")

    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
        """
        Run every enabled recaller for a user and fuse the results.

        Args:
            user_id: User ID.
            total_items: Maximum number of fused items to return. A
                non-positive value yields empty lists.

        Returns:
            Tuple containing:
                - List of item IDs, best first
                - List of fused scores, aligned with the item IDs
                - Dict of raw ``(item_id, score)`` results keyed by source
        """
        recall_results = {}
        # item_id -> [(source, score), ...] gathered across all sources
        all_candidates = defaultdict(list)

        for source, recall_model in self.recalls.items():
            source_config = self.config[source]
            # Defensive: skip a source whose flag was turned off after its
            # recaller was built.
            if not source_config['enabled']:
                continue

            num_items = source_config['num_items']

            if source == 'usercf':
                # UserCF additionally takes the size of the user neighbourhood.
                items_scores = recall_model.recall(
                    user_id,
                    num_items=num_items,
                    num_similar_users=source_config['num_similar_users']
                )
            else:
                items_scores = recall_model.recall(user_id, num_items=num_items)

            recall_results[source] = items_scores

            for item_id, score in items_scores:
                all_candidates[item_id].append((source, score))

        final_candidates = self._merge_recall_results(all_candidates, total_items)

        item_ids = [item_id for item_id, _ in final_candidates]
        scores = [score for _, score in final_candidates]

        return item_ids, scores, recall_results

    def _source_weight(self, source: str) -> float:
        """
        Return the fusion weight for ``source``.

        A numeric 'weight' entry in the source's config takes precedence.
        Otherwise ``SOURCE_WEIGHTS`` is used, falling back to
        ``DEFAULT_SOURCE_WEIGHT``.
        """
        source_config = self.config.get(source)
        if isinstance(source_config, dict) and 'weight' in source_config:
            return source_config['weight']
        return self.SOURCE_WEIGHTS.get(source, self.DEFAULT_SOURCE_WEIGHT)

    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]],
                              total_items: int) -> List[Tuple[int, float]]:
        """
        Fuse the per-source candidates into one ranked list.

        Each item's score is the weighted average of the scores it received,
        plus ``DIVERSITY_BONUS`` for each distinct source that recalled it.

        Args:
            all_candidates: Map of item_id -> list of (source, score) pairs.
            total_items: Maximum number of items to return; non-positive
                values return an empty list.

        Returns:
            List of (item_id, final_score) tuples sorted by descending score.
        """
        # A negative slice bound would silently drop items from the tail
        # (e.g. ``[:-1]``), so treat non-positive limits as "nothing".
        if total_items <= 0:
            return []

        final_scores = []

        for item_id, source_scores in all_candidates.items():
            weighted_score = 0.0
            total_weight = 0.0

            for source, score in source_scores:
                weight = self._source_weight(source)
                weighted_score += weight * score
                total_weight += weight

            # NOTE(review): this is a normalised weighted average. An item
            # recalled by a single source therefore keeps that source's raw
            # score whatever its weight; weights only matter when several
            # sources recall the same item. Scores from different recallers
            # are not rescaled to a common range here.
            final_score = weighted_score / total_weight if total_weight > 0 else 0.0

            # Diversity bonus: counted per distinct source, so a source that
            # returns the same item twice does not earn an extra bonus.
            num_sources = len({source for source, _ in source_scores})
            final_score += num_sources * self.DIVERSITY_BONUS

            final_scores.append((item_id, final_score))

        # Stable sort: ties keep first-seen order.
        final_scores.sort(key=lambda x: x[1], reverse=True)

        return final_scores[:total_items]

    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
        """
        Collect recall statistics for a user.

        Args:
            user_id: User ID.

        Returns:
            Dict with 'user_id', 'enabled_recalls' (names of the built
            recallers) and 'config'. When a UserCF recaller exists, it also
            holds 'user_profile' and 'similar_users' (5 neighbours) if those
            lookups succeed.
        """
        stats = {
            'user_id': user_id,
            'enabled_recalls': list(self.recalls.keys()),
            'config': self.config
        }

        if 'usercf' in self.recalls:
            usercf = self.recalls['usercf']
            # Best effort: the stats are informational, so a failed lookup is
            # reported but never fails the call.
            try:
                stats['user_profile'] = usercf.get_user_profile(user_id)
                stats['similar_users'] = usercf.get_user_neighbors(user_id, 5)
            except Exception as exc:
                print(f"获取 usercf 统计信息失败: {exc}")

        return stats

    def update_config(self, new_config: dict):
        """
        Merge new settings into the recall configuration and rebuild recallers.

        Args:
            new_config: Mapping of source name -> settings dict to merge.
        """
        self._merge_config(new_config)

        # Rebuild so constructor parameters (alpha, min_common_items) and
        # enable/disable flags take effect.
        self._init_recalls()

    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
        """
        Report the configured item count for each built recaller.

        These are the configured 'num_items' values, not the number of items
        actually recalled. ``user_id`` is currently unused.

        Args:
            user_id: User ID.

        Returns:
            Dict of source name -> configured item count (0 if disabled).
        """
        breakdown = {}

        for source in self.recalls.keys():
            if self.config[source]['enabled']:
                breakdown[source] = self.config[source]['num_items']
            else:
                breakdown[source] = 0

        return breakdown