# blob: 2f4de137c96979f85b3e727866f47f221f263d1c [file] [log] [blame]
import torch
import pymysql
import numpy as np
import random
from app.models.recommend.LightGCN import LightGCN
from app.models.recall import MultiRecallManager
from app.services.lightgcn_scorer import LightGCNScorer
from app.utils.parse_args import args
from app.utils.data_loader import EdgeListData
from app.utils.graph_build import build_user_post_graph
from config import Config
class RecommendationService:
    """Recommend posts to a user.

    The default pipeline is: multi-channel recall -> optional LightGCN
    re-scoring of the recalled candidates -> MMR reranking with one
    advertisement slot every ``AD_INTERVAL`` positions. A pure LightGCN path
    and a heat-based cold-start path are used as fallbacks.

    Inference methods return either a ``(post_ids, scores)`` tuple or, on the
    cold-start path, a list of post-detail dicts; ``get_recommendations``
    normalizes both shapes into a list of dicts.
    """

    # Every AD_INTERVAL-th slot of a reranked list is reserved for an ad.
    AD_INTERVAL = 5
    # MMR only compares the next MMR_WINDOW best-scored candidates per slot.
    MMR_WINDOW = 10
    # Heat values are divided by this to obtain a score for ad posts.
    AD_HEAT_SCALE = 1000.0
    # CUDA device index the service prefers when it exists on the host.
    PREFERRED_GPU_INDEX = 7

    def __init__(self, db_config=None):
        """Set up database settings, model paths and recall configuration.

        Args:
            db_config: Optional mapping of ``pymysql.connect`` keyword
                arguments. Keys given here override the built-in defaults;
                omitted keys keep their default value.

        Side effects:
            Writes ``device``, ``data_path`` and ``pre_model_path`` onto the
            shared ``args`` object imported from ``app.utils.parse_args``.
        """
        # SECURITY NOTE(review): these defaults embed real-looking credentials
        # in source control. Prefer passing ``db_config`` from configuration.
        default_db_config = {
            'host': '10.126.59.25',
            'port': 3306,
            'user': 'root',
            'password': '123456',
            'database': 'redbook',
            'charset': 'utf8mb4',
        }
        self.db_config = {**default_db_config, **(db_config or {})}

        # Model configuration (stored on the shared ``args`` namespace).
        args.device = self._select_device(self.PREFERRED_GPU_INDEX)
        args.data_path = './app/user_post_graph.txt'
        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'
        self.topk = 2  # default number of recommendations

        # Multi-channel recall manager, created lazily by init_multi_recall().
        self.multi_recall = None
        self.multi_recall_enabled = True

        # LightGCN scorer, created lazily by init_lightgcn_scorer().
        self.lightgcn_scorer = None
        # When True, recalled candidates are re-scored with LightGCN.
        self.use_lightgcn_rerank = True

        # Per-channel recall settings handed to MultiRecallManager.
        self.recall_config = {
            'swing': {
                'enabled': True,
                'num_items': 20,
                'alpha': 0.5,
            },
            'hot': {
                'enabled': True,
                'num_items': 15,
            },
            'ad': {
                'enabled': True,
                'num_items': 5,
            },
            'usercf': {
                'enabled': True,
                'num_items': 15,
                'min_common_items': 1,  # threshold lowered from 3 to 1
                'num_similar_users': 20,  # kept small for efficiency
            },
        }

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _select_device(preferred_index):
        """Return a torch device string that actually exists on this host.

        Uses ``cuda:<preferred_index>`` when that many GPUs are present,
        ``cuda:0`` when CUDA is available with fewer GPUs, otherwise ``cpu``.
        """
        if not torch.cuda.is_available():
            return 'cpu'
        if torch.cuda.device_count() > preferred_index:
            return f'cuda:{preferred_index}'
        return 'cuda:0'

    @staticmethod
    def _placeholders(count):
        """Return ``count`` comma-separated ``%s`` SQL placeholders."""
        return ','.join(['%s'] * count)

    @staticmethod
    def _truncate_content(content, limit=200):
        """Return ``content`` cut to ``limit`` chars plus ``...``; ``""`` for None."""
        if content is None:
            return ""
        if len(content) > limit:
            return content[:limit] + '...'
        return content

    @staticmethod
    def _heat_to_score(heat):
        """Convert a post heat value (possibly NULL) into a score."""
        return float(heat or 0) / RecommendationService.AD_HEAT_SCALE

    def _fetch_usernames(self, cursor, owner_ids):
        """Return ``{user_id: username}`` for ``owner_ids`` using ``cursor``."""
        if not owner_ids:
            return {}
        cursor.execute(
            f"SELECT id, username FROM users WHERE id IN ({self._placeholders(len(owner_ids))})",
            tuple(owner_ids)
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    def _fetch_tags(self, cursor, post_ids):
        """Return ``{post_id: "tag1,tag2,..."}`` for ``post_ids`` using ``cursor``."""
        if not post_ids:
            return {}
        cursor.execute(
            f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({self._placeholders(len(post_ids))})
                GROUP BY pt.post_id""",
            tuple(post_ids)
        )
        return {row[0]: row[1] for row in cursor.fetchall()}

    # ------------------------------------------------------------------
    # Diversity reranking
    # ------------------------------------------------------------------
    def calculate_tag_similarity(self, tags1, tags2):
        """Return the Jaccard similarity of two comma-separated tag strings.

        Tags are stripped of surrounding whitespace and empty tags are
        ignored. Returns 0.0 when either side has no tags; otherwise a value
        in [0, 1].
        """
        if not tags1 or not tags2:
            return 0.0
        first = {tag.strip() for tag in tags1.split(',') if tag.strip()}
        second = {tag.strip() for tag in tags2.split(',') if tag.strip()}
        if not first or not second:
            return 0.0
        return len(first & second) / len(first | second)

    def _load_rerank_metadata(self, post_ids):
        """Load tags/ad flags for ``post_ids`` plus extra ad posts from the DB.

        Returns:
            ``(post_tags, post_is_ad, extra_ad_rows)`` where the two dicts are
            keyed by post id and ``extra_ad_rows`` holds up to 50
            ``(id, heat)`` rows of published ads not already in ``post_ids``,
            ordered by heat descending.
        """
        id_placeholders = self._placeholders(len(post_ids))
        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"""SELECT p.id, p.is_advertisement,
                           COALESCE(GROUP_CONCAT(t.name), '') as tags
                    FROM posts p
                    LEFT JOIN post_tags pt ON p.id = pt.post_id
                    LEFT JOIN tags t ON pt.tag_id = t.id
                    WHERE p.id IN ({id_placeholders}) AND p.status = 'published'
                    GROUP BY p.id, p.is_advertisement""",
                tuple(post_ids)
            )
            post_tags = {}
            post_is_ad = {}
            for post_id, is_ad, tags in cursor.fetchall():
                post_tags[post_id] = tags or ""
                post_is_ad[post_id] = bool(is_ad)

            # Posts the query did not return (not published / missing) are
            # kept as tag-less, non-ad candidates.
            for post_id in post_ids:
                if post_id not in post_tags:
                    post_tags[post_id] = ""
                    post_is_ad[post_id] = False

            cursor.execute(
                f"""
                SELECT id, heat FROM posts
                WHERE is_advertisement = 1 AND status = 'published'
                AND id NOT IN ({id_placeholders})
                ORDER BY heat DESC
                LIMIT 50
                """,
                tuple(post_ids)
            )
            extra_ad_rows = cursor.fetchall()
        finally:
            cursor.close()
            conn.close()
        return post_tags, post_is_ad, extra_ad_rows

    def _mmr_select(self, normal_candidates, ad_candidates, post_tags, theta, target_size):
        """Greedily pick up to ``target_size`` ``(post_id, score)`` pairs.

        Both candidate lists must already be sorted by score descending.
        Every ``AD_INTERVAL``-th slot takes the next ad while ads remain.
        Other slots take, among the next ``MMR_WINDOW`` unused normal
        candidates, the one maximizing
        ``theta * score - (1 - theta) * max_tag_similarity_to_selected``.
        Selection stops early when a normal slot cannot be filled.
        The input lists are not modified.
        """
        pool = list(normal_candidates)
        selected = []
        next_normal = 0
        next_ad = 0
        while len(selected) < target_size:
            is_ad_slot = (len(selected) + 1) % self.AD_INTERVAL == 0
            if is_ad_slot and next_ad < len(ad_candidates):
                selected.append(ad_candidates[next_ad])
                next_ad += 1
                continue

            if next_normal >= len(pool):
                break

            best_value = -float('inf')
            best_index = next_normal
            window_end = min(next_normal + self.MMR_WINDOW, len(pool))
            for index in range(next_normal, window_end):
                post_id, relevance = pool[index]
                tags = post_tags[post_id]
                # Similarities are never negative, so 0.0 is a safe floor.
                redundancy = max(
                    (self.calculate_tag_similarity(tags, post_tags[chosen_id])
                     for chosen_id, _ in selected),
                    default=0.0,
                )
                value = theta * relevance - (1 - theta) * redundancy
                if value > best_value:
                    best_value = value
                    best_index = index

            selected.append(pool[best_index])
            # Swap the winner into the consumed prefix of the pool.
            pool[next_normal], pool[best_index] = pool[best_index], pool[next_normal]
            next_normal += 1
        return selected

    def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
        """Rerank candidates with MMR and reserve every 5th slot for an ad.

        Args:
            post_ids: Candidate post ids.
            scores: Relevance scores aligned with ``post_ids``.
            theta: Weight of relevance versus tag diversity (0.5 = equal).
            target_size: Desired result length; defaults to ``len(post_ids)``.

        Returns:
            ``(post_ids, scores)`` after reranking. Inputs with at most one
            post are returned unchanged without touching the database.
        """
        if target_size is None:
            target_size = len(post_ids)
        if len(post_ids) <= 1:
            return post_ids, scores

        post_tags, post_is_ad, extra_ad_rows = self._load_rerank_metadata(post_ids)

        # Split the incoming candidates into regular posts and ads.
        normal_candidates = []
        ad_candidates = []
        for post_id, score in zip(post_ids, scores):
            bucket = ad_candidates if post_is_ad[post_id] else normal_candidates
            bucket.append((post_id, score))

        # Extra ads carry no tags and are scored from their heat.
        for ad_id, heat in extra_ad_rows:
            post_tags[ad_id] = ""
            ad_candidates.append((ad_id, self._heat_to_score(heat)))

        normal_candidates.sort(key=lambda pair: pair[1], reverse=True)
        ad_candidates.sort(key=lambda pair: pair[1], reverse=True)

        selected = self._mmr_select(
            normal_candidates, ad_candidates, post_tags, theta, target_size
        )
        reranked_post_ids = [post_id for post_id, _ in selected]
        reranked_scores = [score for _, score in selected]
        return reranked_post_ids, reranked_scores

    def insert_advertisements(self, post_ids, scores):
        """Insert one ad after every 5 posts of an existing result list.

        Ads are the (up to 50) hottest published ad posts that are not already
        in ``post_ids``. When no such ad exists the inputs are returned as-is.

        Returns:
            ``(post_ids, scores)`` with ads interleaved; ad scores are derived
            from heat.
        """
        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT id, heat FROM posts
                WHERE is_advertisement = 1 AND status = 'published'
                ORDER BY heat DESC
                LIMIT 50
            """)
            ad_rows = cursor.fetchall()
        finally:
            cursor.close()
            conn.close()

        already_recommended = set(post_ids)
        available_ads = [
            (ad_id, heat) for ad_id, heat in ad_rows
            if ad_id not in already_recommended
        ]
        if not available_ads:
            return post_ids, scores

        result_posts = []
        result_scores = []
        ad_index = 0
        for position, (post_id, score) in enumerate(zip(post_ids, scores), start=1):
            result_posts.append(post_id)
            result_scores.append(score)
            if position % self.AD_INTERVAL == 0 and ad_index < len(available_ads):
                ad_id, ad_heat = available_ads[ad_index]
                result_posts.append(ad_id)
                result_scores.append(self._heat_to_score(ad_heat))
                ad_index += 1
        return result_posts, result_scores

    # ------------------------------------------------------------------
    # Inference paths
    # ------------------------------------------------------------------
    def user_cold_start(self, topk=None):
        """Return details of the ``topk`` hottest published posts.

        Args:
            topk: Number of posts; defaults to ``self.topk``.

        Returns:
            A list of dicts with keys ``post_id``, ``title``, ``content``
            (truncated to 200 characters), ``type``, ``username``, ``heat``,
            ``tags`` and ``created_at``.
        """
        if topk is None:
            topk = self.topk
        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
                FROM posts p
                WHERE p.status = 'published'
                ORDER BY p.heat DESC
                LIMIT %s
            """, (topk,))
            post_rows = cursor.fetchall()
            post_ids = [row[0] for row in post_rows]

            user_map = self._fetch_usernames(cursor, list({row[1] for row in post_rows}))
            tag_map = self._fetch_tags(cursor, post_ids)

            return [
                {
                    'post_id': row[0],
                    'title': row[2],
                    'content': self._truncate_content(row[3]),
                    'type': row[4],
                    'username': user_map.get(row[1], ""),
                    'heat': row[5],
                    'tags': tag_map.get(row[0], ""),
                    'created_at': str(row[6]) if row[6] else "",
                }
                for row in post_rows
            ]
        finally:
            cursor.close()
            conn.close()

    def run_inference(self, user_id, topk=None, use_multi_recall=None):
        """Run recommendation inference for one user.

        Args:
            user_id: Id of the user to recommend for.
            topk: Number of recommendations; defaults to ``self.topk``.
            use_multi_recall: ``True`` for the multi-recall pipeline,
                ``False`` for plain LightGCN, ``None`` to follow
                ``self.multi_recall_enabled``.

        Returns:
            ``(post_ids, scores)``, or a list of post-detail dicts when a
            cold-start fallback was taken.
        """
        if topk is None:
            topk = self.topk
        if use_multi_recall is None:
            use_multi_recall = self.multi_recall_enabled
        if use_multi_recall:
            return self._run_multi_recall_inference(user_id, topk)
        return self._run_lightgcn_inference(user_id, topk)

    def _run_multi_recall_inference(self, user_id, topk):
        """Recommend via multi-channel recall, optionally re-scored by LightGCN.

        Falls back to cold start when recall yields nothing and to the plain
        LightGCN path when any step raises.
        """
        try:
            self.init_multi_recall()

            # Recall 10x the requested amount (capped at 500) so that the
            # reranker has room to diversify.
            total_candidates = min(topk * 10, 500)
            candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
                user_id, total_candidates
            )
            if not candidate_post_ids:
                print(f"用户 {user_id} 多路召回无结果,使用冷启动")
                return self.user_cold_start(topk)

            print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
            print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")

            if self.use_lightgcn_rerank:
                print("使用LightGCN对多路召回结果进行重新打分...")
                lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)
                # LightGCN scores replace the recall scores (no fusion).
                final_scores = lightgcn_scores
                print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
                print(f"使用LightGCN分数进行重排")
            else:
                final_scores = candidate_scores

            final_post_ids, final_scores = self.mmr_rerank_with_ads(
                candidate_post_ids, final_scores, theta=0.5, target_size=topk
            )
            return final_post_ids, final_scores
        except Exception as e:
            # Deliberate best-effort: any failure degrades to plain LightGCN.
            print(f"多路召回失败: {str(e)},回退到LightGCN")
            return self._run_lightgcn_inference(user_id, topk)

    def _score_all_items(self, user_id):
        """Score every item for ``user_id`` with the pretrained LightGCN model.

        The graph mapping, dataset and model weights are loaded on each call.

        Returns:
            ``None`` when the user is absent from the user-post graph,
            otherwise ``(post_ids, scores)`` where ``scores`` is a NumPy array
            and ``post_ids[i]`` is the post mapped to item index ``i``.
        """
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        if user_id not in user2idx:
            return None
        idx2post = {index: post_id for post_id, index in post2idx.items()}
        user_idx = user2idx[user_id]

        dataset = EdgeListData(args.data_path, args.data_path)
        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
        # Trim the checkpoint embeddings to the sizes of the current dataset.
        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

        model = LightGCN(dataset, phase='vanilla').to(args.device)
        model.load_state_dict(pretrained_dict, strict=False)
        model.eval()
        with torch.no_grad():
            user_emb, item_emb = model.generate()
            user_vec = user_emb[user_idx].unsqueeze(0)
            scores = model.rating(user_vec, item_emb).squeeze(0)

        all_scores = scores.cpu().numpy()
        all_post_ids = [idx2post[index] for index in range(len(all_scores))]
        return all_post_ids, all_scores

    def _run_lightgcn_inference(self, user_id, topk):
        """Recommend with plain LightGCN scores followed by MMR reranking.

        Users missing from the graph get the cold-start result instead.
        Only positively scored posts are candidates; if there are none, the
        100 best-scored posts are used.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return self.user_cold_start(topk)
        all_post_ids, all_scores = scored

        ranked = list(zip(all_post_ids, all_scores))
        candidates = [(post_id, score) for post_id, score in ranked if score > 0]
        if not candidates:
            candidates = sorted(ranked, key=lambda pair: pair[1], reverse=True)[:100]

        candidate_post_ids = [post_id for post_id, _ in candidates]
        candidate_scores = [score for _, score in candidates]
        print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")

        final_post_ids, final_scores = self.mmr_rerank_with_ads(
            candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
        )
        return final_post_ids, final_scores

    def _get_recall_source_stats(self, recall_breakdown):
        """Return ``{source: number_of_items}`` for a recall breakdown mapping."""
        return {source: len(items) for source, items in recall_breakdown.items()}

    # ------------------------------------------------------------------
    # Result hydration
    # ------------------------------------------------------------------
    def get_post_info(self, topk_post_ids, topk_scores=None):
        """Load display details for recommended posts.

        Args:
            topk_post_ids: Post ids in recommendation order.
            topk_scores: Optional scores aligned with ``topk_post_ids``.

        Returns:
            A list of dicts in input order. Posts that are missing or not
            published are skipped. Each dict carries the post fields, tags,
            author name, ad flag and per-behavior counts, plus
            ``recommendation_score`` when a score is available.
        """
        if not topk_post_ids:
            return []
        print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
        if topk_scores is not None:
            print(f"对应的推荐打分: {topk_scores}")

        id_placeholders = self._placeholders(len(topk_post_ids))
        id_params = tuple(topk_post_ids)
        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.is_advertisement
                    FROM posts p
                    WHERE p.id IN ({id_placeholders}) AND p.status = 'published'""",
                id_params
            )
            post_rows = cursor.fetchall()
            post_map = {row[0]: row for row in post_rows}

            user_map = self._fetch_usernames(cursor, list({row[1] for row in post_rows}))
            tag_map = self._fetch_tags(cursor, topk_post_ids)

            # Behavior counts (likes, comments, ...) grouped per post and type.
            cursor.execute(
                f"""SELECT post_id, type, COUNT(*) as count
                    FROM behaviors
                    WHERE post_id IN ({id_placeholders})
                    GROUP BY post_id, type""",
                id_params
            )
            behavior_stats = {}
            for post_id, behavior_type, count in cursor.fetchall():
                behavior_stats.setdefault(post_id, {})[behavior_type] = count

            post_list = []
            for position, post_id in enumerate(topk_post_ids):
                row = post_map.get(post_id)
                if not row:
                    print(f"帖子ID {post_id} 不存在或未发布,跳过")
                    continue
                stats = behavior_stats.get(post_id, {})
                post_info = {
                    'post_id': post_id,
                    'title': row[2],
                    'content': self._truncate_content(row[3]),
                    'type': row[4],
                    'username': user_map.get(row[1], ""),
                    'heat': row[5],
                    'tags': tag_map.get(post_id, ""),
                    'created_at': str(row[6]) if row[6] else "",
                    'is_advertisement': bool(row[7]),
                    'like_count': stats.get('like', 0),
                    'comment_count': stats.get('comment', 0),
                    'favorite_count': stats.get('favorite', 0),
                    'view_count': stats.get('view', 0),
                    'share_count': stats.get('share', 0),
                }
                # Index by original position so skipped posts do not shift scores.
                if topk_scores is not None and position < len(topk_scores):
                    post_info['recommendation_score'] = float(topk_scores[position])
                post_list.append(post_info)
            return post_list
        finally:
            cursor.close()
            conn.close()

    def get_recommendations(self, user_id, topk=None):
        """Main entry point: return a list of post-detail dicts for a user.

        Raises:
            Exception: Wraps any failure, with the original error chained as
                the cause.
        """
        try:
            result = self.run_inference(user_id, topk)
            # Cold start already returns hydrated dicts.
            if isinstance(result, list) and result and isinstance(result[0], dict):
                return result
            if isinstance(result, tuple) and len(result) == 2:
                topk_post_ids, topk_scores = result
                return self.get_post_info(topk_post_ids, topk_scores)
            # Legacy shape: a bare list of post ids.
            return self.get_post_info(result)
        except Exception as e:
            raise Exception(f"推荐系统错误: {str(e)}") from e

    def get_all_item_scores(self, user_id):
        """Return LightGCN scores of all posts for a user.

        Returns:
            ``(post_ids, scores)``; ``([], [])`` when the user is not in the
            user-post graph.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return [], []
        return scored

    # ------------------------------------------------------------------
    # Lazy component initialization and scoring utilities
    # ------------------------------------------------------------------
    def init_multi_recall(self):
        """Create the multi-recall manager on first use."""
        if self.multi_recall is None:
            print("初始化多路召回管理器...")
            self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
            print("多路召回管理器初始化完成")

    def init_lightgcn_scorer(self):
        """Create the LightGCN scorer on first use."""
        if self.lightgcn_scorer is None:
            print("初始化LightGCN评分器...")
            self.lightgcn_scorer = LightGCNScorer()
            print("LightGCN评分器初始化完成")

    def _get_lightgcn_scores(self, user_id, candidate_post_ids):
        """Return the scorer's scores for ``candidate_post_ids``.

        Delegates to ``LightGCNScorer.score_batch_candidates``; the result is
        presumably a list of floats aligned with the candidates (defined
        outside this file - confirm against the scorer).
        """
        self.init_lightgcn_scorer()
        return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)

    def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
        """Blend min-max normalized recall and LightGCN scores.

        Args:
            multi_recall_scores: Scores from multi-channel recall.
            lightgcn_scores: LightGCN scores, same length.
            alpha: Weight of the LightGCN scores (0-1).

        Returns:
            A list of fused scores; ``[]`` for empty input.

        Raises:
            ValueError: If the two inputs differ in length.
        """
        if len(multi_recall_scores) != len(lightgcn_scores):
            raise ValueError("分数列表长度不匹配")
        if len(multi_recall_scores) == 0:
            return []

        def normalize_scores(values):
            values = np.asarray(values, dtype=float)
            low = np.min(values)
            high = np.max(values)
            if high == low:
                # A constant vector has no spread; map everything to 0.5.
                return np.full_like(values, 0.5)
            return (values - low) / (high - low)

        fused_scores = (
            alpha * normalize_scores(lightgcn_scores)
            + (1 - alpha) * normalize_scores(multi_recall_scores)
        )
        return fused_scores.tolist()

    def train_multi_recall(self):
        """Train all multi-recall channels."""
        self.init_multi_recall()
        self.multi_recall.train_all()

    def update_recall_config(self, new_config):
        """Update the recall configuration (top-level keys are replaced)."""
        self.recall_config.update(new_config)
        if self.multi_recall:
            self.multi_recall.update_config(new_config)

    def multi_recall_inference(self, user_id, total_items=200):
        """Run multi-channel recall for a user.

        Args:
            user_id: Id of the user.
            total_items: Total number of items to recall.

        Returns:
            ``(item_ids, scores, recall_breakdown)`` as produced by
            ``MultiRecallManager.recall``.
        """
        self.init_multi_recall()
        item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
        return item_ids, scores, recall_results

    def get_multi_recall_stats(self, user_id):
        """Return recall statistics, or an error dict if recall is uninitialized."""
        if self.multi_recall is None:
            return {"error": "多路召回未初始化"}
        return self.multi_recall.get_recall_stats(user_id)