Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame] | 1 | import numpy as np |
| 2 | import pymysql |
| 3 | from collections import defaultdict |
| 4 | import math |
| 5 | from typing import List, Tuple, Dict |
| 6 | |
class SwingRecall:
    """Item-to-item recall based on a Swing-style similarity.

    Two items are considered similar when they share users, and each pair
    of shared users contributes less the more items that pair has in
    common overall. Similarities are computed for every item pair seen in
    the interaction data and then used to score unseen items for a user.
    """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """Create an untrained model.

        Args:
            db_config: Keyword arguments forwarded to ``pymysql.connect``.
            alpha: Smoothing term added to the user-pair overlap in the
                weight denominator ``1 / (alpha + overlap)``; larger
                values shrink every pair weight.
        """
        self.db_config = db_config
        self.alpha = alpha
        # item -> {other item -> similarity}
        self.item_similarity = {}
        # user -> set of items the user interacted with
        self.user_items = defaultdict(set)
        # item -> set of users who interacted with it
        self.item_users = defaultdict(set)

    def _get_interaction_data(self):
        """Load distinct (user, post) interactions from the database.

        Rebuilds ``user_items`` and ``item_users`` from scratch and only
        publishes them once the query has succeeded, so a failed fetch
        never leaves half-populated state and a re-train does not keep
        interactions that no longer exist.
        """
        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor is managed by ``with`` so that a failure while
            # creating it cannot trigger a second error during cleanup.
            with conn.cursor() as cursor:
                # Likes, favorites and comments all count as interactions.
                cursor.execute("""
                    SELECT DISTINCT user_id, post_id
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment')
                """)
                interactions = cursor.fetchall()
        finally:
            conn.close()

        user_items = defaultdict(set)
        item_users = defaultdict(set)
        # NOTE(review): assumes rows are (user_id, post_id) sequences, i.e.
        # db_config does not select a dict-style cursor class - confirm.
        for user_id, post_id in interactions:
            user_items[user_id].add(post_id)
            item_users[post_id].add(user_id)

        self.user_items = user_items
        self.item_users = item_users

    def _pair_similarity(self, common_users) -> float:
        """Return the normalized Swing weight for one item pair.

        Args:
            common_users: Users who interacted with both items.

        Returns:
            0.0 when fewer than two users are shared; otherwise the mean
            of ``1 / (alpha + overlap(u, v))`` over all ordered pairs of
            distinct shared users.
        """
        user_count = len(common_users)
        if user_count < 2:
            return 0.0

        total = 0.0
        for u in common_users:
            items_u = self.user_items[u]
            for v in common_users:
                if u != v:
                    overlap = len(items_u & self.user_items[v])
                    total += 1.0 / (self.alpha + overlap)

        # Normalize by the number of ordered user pairs.
        return total / (user_count * (user_count - 1))

    def _calculate_swing_similarity(self):
        """Compute the symmetric item-item similarity matrix.

        Every unordered item pair is evaluated once and the value is
        stored in both directions. The matrix is built locally and then
        assigned to ``item_similarity`` in one step.
        """
        print("开始计算Swing相似度...")

        items = list(self.item_users.keys())
        item_count = len(items)

        # Create every row up front. Resetting a row when its item becomes
        # the outer-loop item would discard the mirrored entries that
        # earlier iterations already wrote into it.
        similarity_matrix = {item: {} for item in items}

        for i, item_i in enumerate(items):
            if i % 100 == 0:
                print(f"处理进度: {i}/{item_count}")

            users_i = self.item_users[item_i]
            row_i = similarity_matrix[item_i]

            for item_j in items[i + 1:]:
                # Users who interacted with both items.
                common_users = users_i & self.item_users[item_j]
                similarity = self._pair_similarity(common_users)

                row_i[item_j] = similarity
                similarity_matrix[item_j][item_i] = similarity

        self.item_similarity = similarity_matrix
        print("Swing相似度计算完成")

    def train(self):
        """Load interaction data and compute item similarities."""
        self._get_interaction_data()
        self._calculate_swing_similarity()

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """Recall items similar to those the user already interacted with.

        Trains the model first when no similarities are available yet.

        Args:
            user_id: User to recall items for.
            num_items: Maximum number of items to return.

        Returns:
            Up to ``num_items`` ``(item_id, score)`` tuples sorted by
            descending score. Empty when the user has no recorded
            interactions or ``num_items`` is not positive.
        """
        if not getattr(self, 'item_similarity', None):
            self.train()

        if user_id not in self.user_items:
            return []

        # A negative count would otherwise slice off the tail of the
        # ranking instead of returning nothing.
        if num_items <= 0:
            return []

        user_interacted_items = self.user_items[user_id]

        # Sum the similarity of each candidate to every seed item.
        candidate_scores = defaultdict(float)
        for item_i in user_interacted_items:
            neighbours = self.item_similarity.get(item_i, {})
            for item_j, similarity in neighbours.items():
                # Never recommend something the user already touched.
                if item_j not in user_interacted_items:
                    candidate_scores[item_j] += similarity

        sorted_candidates = sorted(
            candidate_scores.items(), key=lambda x: x[1], reverse=True
        )
        return sorted_candidates[:num_items]