blob: 03cb3f8a7ad224b0381f55d2f8df01c720328af1 [file] [log] [blame]
from typing import List, Tuple, Dict, Any
import numpy as np
from collections import defaultdict
from .swing_recall import SwingRecall
from .hot_recall import HotRecall
from .ad_recall import AdRecall
from .usercf_recall import UserCFRecall
class MultiRecallManager:
    """
    Multi-channel recall manager.

    Combines Swing, popularity ("hot"), ad and UserCF recall strategies and
    fuses their candidate lists into a single ranked list of item IDs.
    """

    # Recall sources this manager knows how to build, in construction order.
    RECALL_SOURCES = ('swing', 'hot', 'ad', 'usercf')

    # Default fusion weight per recall source. A source's config dict may
    # override it with a 'weight' entry, e.g. {'usercf': {'weight': 0.5}}.
    DEFAULT_SOURCE_WEIGHTS = {
        'swing': 0.3,
        'hot': 0.2,
        'ad': 0.1,
        'usercf': 0.4,
    }

    # Weight used for a source with neither a configured nor a default weight.
    FALLBACK_SOURCE_WEIGHT = 0.1

    # Bonus added to an item's fused score for every source that returned it.
    DIVERSITY_BONUS_PER_SOURCE = 0.1

    def __init__(self, db_config: dict, recall_config: dict = None):
        """
        Initialize the multi-channel recall manager.

        Args:
            db_config: Database configuration, passed unchanged to every
                recaller's constructor.
            recall_config: Optional per-source overrides. For a known source
                key (e.g. 'swing') the given dict is merged into that source's
                defaults; any other key is stored as-is.
        """
        self.db_config = db_config
        # Defaults are built per instance, so merging overrides into them can
        # never leak between managers.
        self.config = {
            'swing': {
                'enabled': True,
                'num_items': 15,
                'alpha': 0.5,
            },
            'hot': {
                'enabled': True,
                'num_items': 10,
            },
            'ad': {
                'enabled': True,
                'num_items': 2,
            },
            'usercf': {
                'enabled': True,
                'num_items': 10,
                'min_common_items': 3,
                'num_similar_users': 50,
            },
        }
        if recall_config:
            self._merge_config(recall_config)

        self.recalls = {}
        self._init_recalls()

    def _merge_config(self, overrides: dict) -> None:
        """Merge *overrides* into ``self.config`` one level deep.

        Known keys have their dict updated in place; unknown keys are added.
        """
        for key, value in overrides.items():
            if key in self.config:
                self.config[key].update(value)
            else:
                self.config[key] = value

    def _init_recalls(self):
        """(Re)build one recaller per enabled source.

        ``self.recalls`` is rebuilt from scratch, so sources disabled since the
        previous build are dropped instead of lingering as stale instances.
        """
        self.recalls = {}
        if self.config['swing']['enabled']:
            self.recalls['swing'] = SwingRecall(
                self.db_config,
                alpha=self.config['swing']['alpha']
            )
        if self.config['hot']['enabled']:
            self.recalls['hot'] = HotRecall(self.db_config)
        if self.config['ad']['enabled']:
            self.recalls['ad'] = AdRecall(self.db_config)
        if self.config['usercf']['enabled']:
            self.recalls['usercf'] = UserCFRecall(
                self.db_config,
                min_common_items=self.config['usercf']['min_common_items']
            )

    def train_all(self):
        """Call ``train()`` on every constructed recaller, printing progress."""
        print("开始训练多路召回模型...")
        for name, recall_model in self.recalls.items():
            print(f"训练 {name} 召回器...")
            recall_model.train()
        print("所有召回器训练完成!")

    def recall(self, user_id: int, total_items: int = 200) -> Tuple[List[int], List[float], Dict[str, List[Tuple[int, float]]]]:
        """
        Run every enabled recaller for a user and fuse the results.

        Args:
            user_id: User ID.
            total_items: Maximum number of fused items to return.

        Returns:
            Tuple containing:
            - List of item IDs, best first
            - List of fused scores, aligned with the item IDs
            - Dict mapping each source to its raw (item_id, score) list
        """
        recall_results = {}
        # item_id -> [(source, score), ...] across every source that returned it
        all_candidates = defaultdict(list)

        for source, recall_model in self.recalls.items():
            # Guards against config edited directly without a rebuild.
            if not self.config[source]['enabled']:
                continue
            num_items = self.config[source]['num_items']
            # UserCF is the only recaller that takes an extra argument.
            if source == 'usercf':
                items_scores = recall_model.recall(
                    user_id,
                    num_items=num_items,
                    num_similar_users=self.config[source]['num_similar_users']
                )
            else:
                items_scores = recall_model.recall(user_id, num_items=num_items)
            recall_results[source] = items_scores
            for item_id, score in items_scores:
                all_candidates[item_id].append((source, score))

        final_candidates = self._merge_recall_results(all_candidates, total_items)
        item_ids = [item_id for item_id, _ in final_candidates]
        scores = [score for _, score in final_candidates]
        return item_ids, scores, recall_results

    def _source_weight(self, source: str) -> float:
        """Return the fusion weight for *source*.

        A 'weight' entry in the source's config dict takes precedence; then
        ``DEFAULT_SOURCE_WEIGHTS``; then ``FALLBACK_SOURCE_WEIGHT``.
        """
        default = self.DEFAULT_SOURCE_WEIGHTS.get(source, self.FALLBACK_SOURCE_WEIGHT)
        source_config = self.config.get(source)
        if isinstance(source_config, dict):
            return source_config.get('weight', default)
        return default

    def _merge_recall_results(self, all_candidates: Dict[int, List[Tuple[str, float]]],
                              total_items: int) -> List[Tuple[int, float]]:
        """
        Fuse multi-source candidates into one ranked list.

        Each item's score is the weight-normalized average of its per-source
        scores, plus ``DIVERSITY_BONUS_PER_SOURCE`` for every source that
        returned it.

        NOTE(review): raw scores from different recallers are averaged as-is,
        which assumes they are on comparable scales. Confirm.

        Args:
            all_candidates: item_id -> list of (source, score) pairs.
            total_items: Maximum number of items to return.

        Returns:
            List of (item_id, final_score) tuples, highest score first.
        """
        final_scores = []
        for item_id, source_scores in all_candidates.items():
            weighted_score = 0.0
            total_weight = 0.0
            for source, score in source_scores:
                weight = self._source_weight(source)
                weighted_score += weight * score
                total_weight += weight
            # Normalize so items hit by many sources aren't favored twice
            # (the diversity bonus below handles that explicitly).
            if total_weight > 0:
                final_score = weighted_score / total_weight
            else:
                final_score = 0.0
            # Every contributing source adds the bonus, so multi-source items
            # outrank otherwise-equal single-source ones.
            final_score += len(source_scores) * self.DIVERSITY_BONUS_PER_SOURCE
            final_scores.append((item_id, final_score))

        # Stable sort: ties keep first-seen order.
        final_scores.sort(key=lambda x: x[1], reverse=True)
        return final_scores[:total_items]

    def get_recall_stats(self, user_id: int) -> Dict[str, Any]:
        """
        Collect diagnostic statistics about recall for a user.

        Args:
            user_id: User ID.

        Returns:
            Dict with 'user_id', 'enabled_recalls', a copy of 'config' and,
            when UserCF is available, 'user_profile' and 'similar_users'.
            If a UserCF lookup raises, the message is stored under
            'usercf_error' and the stats gathered so far are still returned.
        """
        stats = {
            'user_id': user_id,
            'enabled_recalls': list(self.recalls.keys()),
            # Copy each source dict so callers can't mutate live config.
            'config': {key: dict(value) if isinstance(value, dict) else value
                       for key, value in self.config.items()},
        }
        usercf = self.recalls.get('usercf')
        if usercf is not None:
            try:
                stats['user_profile'] = usercf.get_user_profile(user_id)
                stats['similar_users'] = usercf.get_user_neighbors(user_id, 5)
            except Exception as exc:
                # Stats are best-effort: record the failure instead of hiding it.
                stats['usercf_error'] = str(exc)
        return stats

    def update_config(self, new_config: dict):
        """
        Merge new settings into the config and rebuild the recallers.

        Freshly constructed recallers replace the old instances, so state
        produced by a previous ``train_all()`` is not carried over. Presumably
        ``train_all()`` must be called again before ``recall()``.

        Args:
            new_config: Per-source overrides, merged like ``recall_config``
                in ``__init__``. ``None`` or empty only triggers the rebuild.
        """
        if new_config:
            self._merge_config(new_config)
        self._init_recalls()

    def get_recall_breakdown(self, user_id: int) -> Dict[str, int]:
        """
        Return the configured item count per known recall source.

        Args:
            user_id: User ID (currently unused; kept for interface symmetry).

        Returns:
            Dict mapping each source in ``RECALL_SOURCES`` present in the
            config to its 'num_items', or 0 when that source is disabled.
        """
        breakdown = {}
        for source in self.RECALL_SOURCES:
            source_config = self.config.get(source)
            if not isinstance(source_config, dict):
                continue
            breakdown[source] = source_config['num_items'] if source_config['enabled'] else 0
        return breakdown