Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame] | 1 | import torch |
| 2 | import pymysql |
| 3 | import numpy as np |
| 4 | import random |
| 5 | from app.models.recommend.LightGCN import LightGCN |
| 6 | from app.models.recall import MultiRecallManager |
| 7 | from app.services.lightgcn_scorer import LightGCNScorer |
| 8 | from app.utils.parse_args import args |
| 9 | from app.utils.data_loader import EdgeListData |
| 10 | from app.utils.graph_build import build_user_post_graph |
| 11 | from config import Config |
| 12 | |
class RecommendationService:
    """Post recommendation service backed by the ``redbook`` MySQL database.

    The main entry point is :meth:`get_recommendations`. By default,
    inference runs multi-route recall (``MultiRecallManager``), optionally
    re-scores the candidates with ``LightGCNScorer``, and then re-ranks them
    with MMR while reserving every fifth slot for an advertisement.

    If multi-route recall raises, inference falls back to scoring all items
    with the pretrained LightGCN model. Users with no recall results, or who
    are absent from the interaction graph, get the heat-ranked cold-start
    list.
    """

    def __init__(self):
        # MySQL connection settings (redbook database).
        # NOTE(review): credentials are hard-coded even though `Config` is
        # imported by this module; consider sourcing them from config/env.
        self.db_config = {
            'host': '10.126.59.25',
            'port': 3306,
            'user': 'root',
            'password': '123456',
            'database': 'redbook',
            'charset': 'utf8mb4'
        }

        # Model settings. These mutate the module-level `args` object, which
        # is shared with every other importer of app.utils.parse_args.
        # GPU 7 is preferred as before. When fewer than 8 GPUs are visible,
        # 'cuda:7' would be an invalid device, so use the first GPU instead.
        if torch.cuda.is_available():
            args.device = 'cuda:7' if torch.cuda.device_count() > 7 else 'cuda:0'
        else:
            args.device = 'cpu'
        args.data_path = './app/user_post_graph.txt'  # user-post graph file
        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'

        self.topk = 2  # default number of recommendations

        # Multi-route recall manager, created lazily by init_multi_recall().
        self.multi_recall = None
        # Default for run_inference(use_multi_recall=None).
        self.multi_recall_enabled = True

        # LightGCN scorer, created lazily by init_lightgcn_scorer().
        self.lightgcn_scorer = None
        # Whether multi-route recall candidates are re-scored with LightGCN.
        self.use_lightgcn_rerank = True

        # Per-route recall configuration handed to MultiRecallManager.
        self.recall_config = {
            'swing': {
                'enabled': True,
                'num_items': 20,
                'alpha': 0.5
            },
            'hot': {
                'enabled': True,
                'num_items': 15
            },
            'ad': {
                'enabled': True,
                'num_items': 5
            },
            'usercf': {
                'enabled': True,
                'num_items': 15,
                'min_common_items': 1,  # lowered from 3
                'num_similar_users': 20
            }
        }

    # ------------------------------------------------------------------
    # Tag similarity helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_tags(tags):
        """Split a comma-separated tag string into a set of stripped, non-empty tags."""
        if not tags:
            return set()
        return {tag.strip() for tag in tags.split(',') if tag.strip()}

    @staticmethod
    def _jaccard(set1, set2):
        """Return the Jaccard similarity of two tag sets (0.0 if either is empty)."""
        if not set1 or not set2:
            return 0.0
        return len(set1 & set2) / len(set1 | set2)

    def calculate_tag_similarity(self, tags1, tags2):
        """
        Compute the tag similarity of two posts.

        Args:
            tags1, tags2: comma-separated tag strings (falsy values count as
                "no tags").

        Returns:
            Jaccard similarity of the two tag sets, in [0, 1]. Returns 0.0
            when either side has no tags.
        """
        return RecommendationService._jaccard(
            RecommendationService._parse_tags(tags1),
            RecommendationService._parse_tags(tags2),
        )

    # ------------------------------------------------------------------
    # Re-ranking
    # ------------------------------------------------------------------
    def _fetch_rerank_metadata(self, post_ids):
        """Load tag strings and ad flags for `post_ids`, plus extra ad candidates.

        Returns:
            (post_tags, post_is_ad, extra_ad_rows):
            - post_tags / post_is_ad: dicts keyed by post id. They contain only
              ids found as published posts. Tags are GROUP_CONCAT-joined, or
              '' if the post has none.
            - extra_ad_rows: up to 50 ``(id, heat)`` rows for published ads
              that are not in `post_ids`, hottest first.
        """
        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()
        try:
            placeholders = ','.join(['%s'] * len(post_ids))
            cursor.execute(
                f"""SELECT p.id, p.is_advertisement,
                COALESCE(GROUP_CONCAT(t.name), '') as tags
                FROM posts p
                LEFT JOIN post_tags pt ON p.id = pt.post_id
                LEFT JOIN tags t ON pt.tag_id = t.id
                WHERE p.id IN ({placeholders}) AND p.status = 'published'
                GROUP BY p.id, p.is_advertisement""",
                tuple(post_ids)
            )
            post_tags = {}
            post_is_ad = {}
            for post_id, is_ad, tags in cursor.fetchall():
                post_tags[post_id] = tags or ""
                post_is_ad[post_id] = bool(is_ad)

            cursor.execute(
                f"""SELECT id, heat FROM posts
                WHERE is_advertisement = 1 AND status = 'published'
                AND id NOT IN ({placeholders})
                ORDER BY heat DESC
                LIMIT 50""",
                tuple(post_ids)
            )
            extra_ad_rows = cursor.fetchall()
        finally:
            cursor.close()
            conn.close()
        return post_tags, post_is_ad, extra_ad_rows

    @staticmethod
    def _mmr_select(normal_candidates, ad_candidates, post_tags, theta, target_size,
                    ad_interval=5, window=10):
        """Greedy MMR selection with a fixed advertisement cadence (pure, no I/O).

        Args:
            normal_candidates: list of ``(post_id, score)`` for regular posts.
            ad_candidates: list of ``(post_id, score)`` for advertisements.
            post_tags: mapping post_id -> comma-separated tag string. Missing
                ids count as untagged.
            theta: relevance weight in
                ``theta * score - (1 - theta) * max_tag_similarity``.
            target_size: maximum number of items to select.
            ad_interval: every `ad_interval`-th slot (1-based) takes the next
                best-scored ad while ads remain.
            window: only the `window` best remaining regular posts are
                examined per step. This bounds the cost of the MMR scan.

        Returns:
            List of ``(post_id, score)`` in display order. Selection stops
            early when regular posts run out and the next slot is not an ad
            slot.
        """
        normals = sorted(normal_candidates, key=lambda c: c[1], reverse=True)
        ads = sorted(ad_candidates, key=lambda c: c[1], reverse=True)

        # Parse each post's tags once instead of on every pairwise comparison.
        tag_cache = {}

        def tags_of(post_id):
            if post_id not in tag_cache:
                tag_cache[post_id] = RecommendationService._parse_tags(post_tags.get(post_id, ""))
            return tag_cache[post_id]

        selected = []
        next_normal = 0  # normals[:next_normal] have already been selected
        next_ad = 0
        while len(selected) < target_size:
            if (len(selected) + 1) % ad_interval == 0 and next_ad < len(ads):
                selected.append(ads[next_ad])
                next_ad += 1
                continue

            if next_normal >= len(normals):
                break

            best_score = -float('inf')
            best_idx = next_normal
            for i in range(next_normal, min(next_normal + window, len(normals))):
                post_id, relevance = normals[i]
                candidate_tags = tags_of(post_id)
                max_similarity = max(
                    (RecommendationService._jaccard(candidate_tags, tags_of(sel_id))
                     for sel_id, _ in selected),
                    default=0.0,
                )
                mmr_score = theta * relevance - (1 - theta) * max_similarity
                if mmr_score > best_score:
                    best_score = mmr_score
                    best_idx = i

            selected.append(normals[best_idx])
            # Move the chosen post into the consumed prefix (selection-sort style).
            normals[next_normal], normals[best_idx] = normals[best_idx], normals[next_normal]
            next_normal += 1

        return selected

    def mmr_rerank_with_ads(self, post_ids, scores, theta=0.5, target_size=None):
        """
        Re-rank recommendations with MMR while enforcing an ad cadence.

        Args:
            post_ids: candidate post ids.
            scores: recommendation scores aligned with `post_ids`.
            theta: relevance/diversity trade-off (0.5 weighs them equally).
            target_size: desired result length; defaults to len(post_ids).

        Returns:
            (post_ids, scores) after re-ranking. Every fifth slot holds an
            advertisement while ads remain. Ads come either from the
            candidates themselves or from up to 50 extra hot published ads,
            which are scored as heat / 1000.

        Candidates that are not found as published posts are dropped here,
        because get_post_info would discard them anyway and they would
        otherwise waste result slots. Inputs of length <= 1 are returned
        unchanged without querying the database.
        """
        if target_size is None:
            target_size = len(post_ids)

        if len(post_ids) <= 1:
            return post_ids, scores

        post_tags, post_is_ad, extra_ad_rows = self._fetch_rerank_metadata(post_ids)

        normal_candidates = []
        ad_candidates = []
        for post_id, score in zip(post_ids, scores):
            if post_id not in post_is_ad:
                continue  # deleted / unpublished post
            if post_is_ad[post_id]:
                ad_candidates.append((post_id, score))
            else:
                normal_candidates.append((post_id, score))

        for ad_id, heat in extra_ad_rows:
            # Tags are not fetched for extra ads; treat them as untagged.
            post_tags[ad_id] = ""
            # Map heat onto a score scale; a NULL heat counts as 0.
            ad_candidates.append((ad_id, float(heat or 0) / 1000.0))

        selected = self._mmr_select(
            normal_candidates, ad_candidates, post_tags, theta, target_size
        )
        return [post_id for post_id, _ in selected], [score for _, score in selected]

    def insert_advertisements(self, post_ids, scores):
        """
        Insert one advertisement after every 5 recommended posts.

        Ads are the hottest (up to 50) published posts with
        ``is_advertisement = 1`` that are not already in `post_ids`. Each
        ad's score is heat / 1000.

        Returns:
            (post_ids, scores) with ads inserted. The inputs are returned
            unchanged when no eligible ad exists.
        """
        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT id, heat FROM posts
                WHERE is_advertisement = 1 AND status = 'published'
                ORDER BY heat DESC
                LIMIT 50
            """)
            ad_rows = cursor.fetchall()
        finally:
            cursor.close()
            conn.close()

        existing = set(post_ids)  # O(1) membership instead of scanning the list per ad
        available_ads = [(ad_id, heat) for ad_id, heat in ad_rows if ad_id not in existing]
        if not available_ads:
            return post_ids, scores

        result_posts = []
        result_scores = []
        ad_index = 0
        for i, (post_id, score) in enumerate(zip(post_ids, scores)):
            result_posts.append(post_id)
            result_scores.append(score)
            if (i + 1) % 5 == 0 and ad_index < len(available_ads):
                ad_id, ad_heat = available_ads[ad_index]
                result_posts.append(ad_id)
                result_scores.append(float(ad_heat or 0) / 1000.0)
                ad_index += 1

        return result_posts, result_scores

    # ------------------------------------------------------------------
    # Cold start
    # ------------------------------------------------------------------
    @staticmethod
    def _summarize_content(content, limit=200):
        """Truncate `content` to `limit` characters plus '...'; NULL content becomes ''."""
        if content is None:
            return ""
        return content[:limit] + '...' if len(content) > limit else content

    def user_cold_start(self, topk=None):
        """
        Cold start: return details of the `topk` hottest published posts.

        Returns:
            List of dicts with keys post_id, title, content (at most 200
            characters plus '...'), type, username, heat, tags, created_at.
        """
        if topk is None:
            topk = self.topk

        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()

        try:
            cursor.execute("""
                SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at
                FROM posts p
                WHERE p.status = 'published'
                ORDER BY p.heat DESC
                LIMIT %s
            """, (topk,))
            post_rows = cursor.fetchall()
            post_ids = [row[0] for row in post_rows]
            post_map = {row[0]: row for row in post_rows}

            # Usernames of the post owners.
            owner_ids = list({row[1] for row in post_rows})
            user_map = {}
            if owner_ids:
                user_placeholders = ','.join(['%s'] * len(owner_ids))
                cursor.execute(
                    f"SELECT id, username FROM users WHERE id IN ({user_placeholders})",
                    tuple(owner_ids)
                )
                user_map = {row[0]: row[1] for row in cursor.fetchall()}

            # Comma-joined tag names per post.
            tag_map = {}
            if post_ids:
                placeholders = ','.join(['%s'] * len(post_ids))
                cursor.execute(
                    f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({placeholders})
                    GROUP BY pt.post_id""",
                    tuple(post_ids)
                )
                tag_map = {row[0]: row[1] for row in cursor.fetchall()}

            post_list = []
            for post_id in post_ids:
                row = post_map.get(post_id)
                if not row:
                    continue
                owner_user_id = row[1]
                post_list.append({
                    'post_id': post_id,
                    'title': row[2],
                    'content': self._summarize_content(row[3]),
                    'type': row[4],
                    'username': user_map.get(owner_user_id, ""),
                    'heat': row[5],
                    'tags': tag_map.get(post_id, ""),
                    'created_at': str(row[6]) if row[6] else ""
                })
            return post_list
        finally:
            cursor.close()
            conn.close()

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def run_inference(self, user_id, topk=None, use_multi_recall=None):
        """
        Main recommendation inference.

        Args:
            user_id: user id.
            topk: number of recommendations (defaults to self.topk).
            use_multi_recall: True for multi-route recall, False for plain
                LightGCN, None to use self.multi_recall_enabled.

        Returns:
            Either (post_ids, scores) or, on cold start, a list of post
            detail dicts.
        """
        if topk is None:
            topk = self.topk

        if use_multi_recall is None:
            use_multi_recall = self.multi_recall_enabled

        if use_multi_recall:
            return self._run_multi_recall_inference(user_id, topk)
        return self._run_lightgcn_inference(user_id, topk)

    def _run_multi_recall_inference(self, user_id, topk):
        """Recommend via multi-route recall, optionally re-scored by LightGCN.

        Up to min(topk * 10, 500) candidates are recalled. No candidates means
        cold start. Any exception falls back to _run_lightgcn_inference.
        """
        try:
            self.init_multi_recall()

            # Recall roughly 10x the final count so MMR has room to diversify.
            total_candidates = min(topk * 10, 500)
            candidate_post_ids, candidate_scores, recall_breakdown = self.multi_recall_inference(
                user_id, total_candidates
            )

            if not candidate_post_ids:
                print(f"用户 {user_id} 多路召回无结果,使用冷启动")
                return self.user_cold_start(topk)

            print(f"用户 {user_id} 多路召回候选数量: {len(candidate_post_ids)}")
            print(f"召回来源分布: {self._get_recall_source_stats(recall_breakdown)}")

            if self.use_lightgcn_rerank:
                print("使用LightGCN对多路召回结果进行重新打分...")
                lightgcn_scores = self._get_lightgcn_scores(user_id, candidate_post_ids)

                # LightGCN scores replace the recall scores outright (no fusion).
                final_scores = lightgcn_scores

                print(f"LightGCN打分完成,分数范围: [{min(lightgcn_scores):.4f}, {max(lightgcn_scores):.4f}]")
                print("使用LightGCN分数进行重排")
            else:
                final_scores = candidate_scores

            final_post_ids, final_scores = self.mmr_rerank_with_ads(
                candidate_post_ids, final_scores, theta=0.5, target_size=topk
            )

            return final_post_ids, final_scores

        except Exception as e:
            print(f"多路召回失败: {str(e)},回退到LightGCN")
            return self._run_lightgcn_inference(user_id, topk)

    def _score_all_items(self, user_id):
        """Score every item for `user_id` with the pretrained LightGCN model.

        The model and graph are reloaded from args.data_path /
        args.pre_model_path on every call.

        Returns:
            None if `user_id` is not in the user-post graph. Otherwise
            (post_ids, scores), where scores is the numpy array produced by
            ``model.rating(...).squeeze(0)`` (presumably one score per item
            index — confirm against LightGCN.rating) and post_ids maps each
            index back to its post id.
        """
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        if user_id not in user2idx:
            return None
        idx2post = {v: k for k, v in post2idx.items()}
        user_idx = user2idx[user_id]

        dataset = EdgeListData(args.data_path, args.data_path)
        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
        # Trim the embedding tables to the current graph size.
        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

        model = LightGCN(dataset, phase='vanilla').to(args.device)
        model.load_state_dict(pretrained_dict, strict=False)
        model.eval()

        with torch.no_grad():
            user_emb, item_emb = model.generate()
            user_vec = user_emb[user_idx].unsqueeze(0)
            scores = model.rating(user_vec, item_emb).squeeze(0)

        all_scores = scores.cpu().numpy()
        all_post_ids = [idx2post[idx] for idx in range(len(all_scores))]
        return all_post_ids, all_scores

    def _run_lightgcn_inference(self, user_id, topk):
        """Recommend with the pretrained LightGCN model alone.

        Users not in the graph get user_cold_start(). Otherwise, items with a
        positive score become candidates. If no item scores positive, the 100
        best items are used instead. Candidates are then re-ranked with
        mmr_rerank_with_ads.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return self.user_cold_start(topk)
        all_post_ids, all_scores = scored

        candidates = [(post_id, score) for post_id, score in zip(all_post_ids, all_scores) if score > 0]
        if not candidates:
            ranked = sorted(zip(all_post_ids, all_scores), key=lambda x: x[1], reverse=True)
            candidates = ranked[:100]

        candidate_post_ids = [post_id for post_id, _ in candidates]
        candidate_scores = [score for _, score in candidates]

        print(f"用户 {user_id} 的LightGCN候选物品数量: {len(candidate_post_ids)}")

        return self.mmr_rerank_with_ads(
            candidate_post_ids, candidate_scores, theta=0.5, target_size=topk
        )

    def _get_recall_source_stats(self, recall_breakdown):
        """Map each recall source to the number of items it returned."""
        return {source: len(items) for source, items in recall_breakdown.items()}

    # ------------------------------------------------------------------
    # Post details
    # ------------------------------------------------------------------
    def get_post_info(self, topk_post_ids, topk_scores=None):
        """
        Fetch display details for recommended posts.

        Args:
            topk_post_ids: post ids, in display order.
            topk_scores: matching scores (optional; only printed).

        Returns:
            List of dicts, one per id that exists as a published post, in
            input order. Each dict holds post fields, owner username, tags,
            the ad flag, and like/comment/favorite/view/share counts from the
            behaviors table.
        """
        if not topk_post_ids:
            return []

        print(f"获取帖子详细信息,帖子ID列表: {topk_post_ids}")
        if topk_scores is not None:
            print(f"对应的推荐打分: {topk_scores}")

        conn = pymysql.connect(**self.db_config)
        cursor = conn.cursor()

        try:
            placeholders = ','.join(['%s'] * len(topk_post_ids))
            cursor.execute(
                f"""SELECT p.id, p.user_id, p.title, p.content, p.type, p.heat, p.created_at, p.updated_at, p.media_urls, p.status, p.is_advertisement
                FROM posts p
                WHERE p.id IN ({placeholders}) AND p.status = 'published'""",
                tuple(topk_post_ids)
            )
            post_rows = cursor.fetchall()
            post_map = {row[0]: row for row in post_rows}

            # Usernames of the post owners.
            owner_ids = list({row[1] for row in post_rows})
            user_map = {}
            if owner_ids:
                user_placeholders = ','.join(['%s'] * len(owner_ids))
                cursor.execute(
                    f"SELECT id, username FROM users WHERE id IN ({user_placeholders})",
                    tuple(owner_ids)
                )
                user_map = {row[0]: row[1] for row in cursor.fetchall()}

            # Comma-joined tag names per post.
            cursor.execute(
                f"""SELECT pt.post_id, GROUP_CONCAT(t.name) as tags
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({placeholders})
                GROUP BY pt.post_id""",
                tuple(topk_post_ids)
            )
            tag_map = {row[0]: row[1] for row in cursor.fetchall()}

            # Behavior counts per (post, behavior type).
            cursor.execute(
                f"""SELECT post_id, type, COUNT(*) as count
                FROM behaviors
                WHERE post_id IN ({placeholders})
                GROUP BY post_id, type""",
                tuple(topk_post_ids)
            )
            behavior_stats = {}
            for post_id, behavior_type, count in cursor.fetchall():
                behavior_stats.setdefault(post_id, {})[behavior_type] = count

            post_list = []
            for post_id in topk_post_ids:
                row = post_map.get(post_id)
                if not row:
                    print(f"帖子ID {post_id} 不存在或未发布,跳过")
                    continue
                owner_user_id = row[1]
                stats = behavior_stats.get(post_id, {})
                post_list.append({
                    'id': post_id,
                    'user_id': owner_user_id,
                    'title': row[2],
                    'content': row[3],  # full content, not truncated
                    'media_urls': row[8],
                    'status': row[9],
                    'heat': row[5],
                    'created_at': row[6].isoformat() if row[6] else "",
                    'updated_at': row[7].isoformat() if row[7] else "",
                    # Extra, optional fields.
                    'type': row[4],
                    'username': user_map.get(owner_user_id, ""),
                    'tags': tag_map.get(post_id, ""),
                    'is_advertisement': bool(row[10]),
                    'like_count': stats.get('like', 0),
                    'comment_count': stats.get('comment', 0),
                    'favorite_count': stats.get('favorite', 0),
                    'view_count': stats.get('view', 0),
                    'share_count': stats.get('share', 0)
                })
            return post_list
        finally:
            cursor.close()
            conn.close()

    def get_recommendations(self, user_id, topk=None):
        """
        Public entry point: return recommended post detail dicts for `user_id`.

        Raises:
            Exception: wrapping any failure, with the original error chained
                as ``__cause__``.
        """
        try:
            result = self.run_inference(user_id, topk)
            # Cold start already returns detail dicts.
            if isinstance(result, list) and result and isinstance(result[0], dict):
                return result
            if isinstance(result, tuple) and len(result) == 2:
                topk_post_ids, topk_scores = result
                return self.get_post_info(topk_post_ids, topk_scores)
            # Legacy return format: a bare list of post ids.
            return self.get_post_info(result)
        except Exception as e:
            raise Exception(f"推荐系统错误: {str(e)}") from e

    def get_all_item_scores(self, user_id):
        """
        Score every item for a user with LightGCN.

        Returns:
            (post_ids, scores) for all items, or ([], []) when the user is not
            in the interaction graph.
        """
        scored = self._score_all_items(user_id)
        if scored is None:
            return [], []
        return scored

    # ------------------------------------------------------------------
    # Lazy initialisation and multi-route recall plumbing
    # ------------------------------------------------------------------
    def init_multi_recall(self):
        """Create the MultiRecallManager on first use."""
        if self.multi_recall is None:
            print("初始化多路召回管理器...")
            self.multi_recall = MultiRecallManager(self.db_config, self.recall_config)
            print("多路召回管理器初始化完成")

    def init_lightgcn_scorer(self):
        """Create the LightGCNScorer on first use."""
        if self.lightgcn_scorer is None:
            print("初始化LightGCN评分器...")
            self.lightgcn_scorer = LightGCNScorer()
            print("LightGCN评分器初始化完成")

    def _get_lightgcn_scores(self, user_id, candidate_post_ids):
        """
        Get LightGCN scores for candidate items.

        Args:
            user_id: user id.
            candidate_post_ids: candidate post ids.

        Returns:
            Whatever LightGCNScorer.score_batch_candidates returns
            (presumably a list of floats aligned with the candidates).
        """
        self.init_lightgcn_scorer()
        return self.lightgcn_scorer.score_batch_candidates(user_id, candidate_post_ids)

    def _fuse_scores(self, multi_recall_scores, lightgcn_scores, alpha=0.6):
        """
        Fuse multi-route recall scores with LightGCN scores.

        Each list is min-max normalised to [0, 1]. A constant list maps to
        0.5. The result is ``alpha * lightgcn + (1 - alpha) * recall``.

        Args:
            multi_recall_scores: multi-route recall scores.
            lightgcn_scores: LightGCN scores, aligned with the above.
            alpha: weight of the LightGCN scores (0-1).

        Returns:
            List of fused scores ([] for empty input).

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(multi_recall_scores) != len(lightgcn_scores):
            raise ValueError("分数列表长度不匹配")
        if len(multi_recall_scores) == 0:
            # np.min/np.max have no identity for empty arrays.
            return []

        def normalize_scores(scores):
            scores = np.array(scores)
            min_score = np.min(scores)
            max_score = np.max(scores)
            if max_score == min_score:
                return np.ones_like(scores) * 0.5
            return (scores - min_score) / (max_score - min_score)

        norm_multi_scores = normalize_scores(multi_recall_scores)
        norm_lightgcn_scores = normalize_scores(lightgcn_scores)

        fused_scores = alpha * norm_lightgcn_scores + (1 - alpha) * norm_multi_scores
        return fused_scores.tolist()

    def train_multi_recall(self):
        """Train all multi-route recall models."""
        self.init_multi_recall()
        self.multi_recall.train_all()

    def update_recall_config(self, new_config):
        """Update the multi-route recall configuration.

        This is a shallow update: a route key in `new_config` replaces that
        route's whole dict. The manager, if already created, receives
        `new_config` as given.
        """
        self.recall_config.update(new_config)
        if self.multi_recall:
            self.multi_recall.update_config(new_config)

    def multi_recall_inference(self, user_id, total_items=200):
        """
        Run multi-route recall for a user.

        Args:
            user_id: user id.
            total_items: total number of items to recall.

        Returns:
            Tuple of (item_ids, scores, recall_breakdown).
        """
        self.init_multi_recall()
        item_ids, scores, recall_results = self.multi_recall.recall(user_id, total_items)
        return item_ids, scores, recall_results

    def get_multi_recall_stats(self, user_id):
        """Return multi-route recall statistics, or an error dict if not initialised."""
        if self.multi_recall is None:
            return {"error": "多路召回未初始化"}

        return self.multi_recall.get_recall_stats(user_id)