blob: bf7fdd6fed2bd4170cd8fffad6af40aa60fce31c [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001import numpy as np
2import pymysql
3from collections import defaultdict
4import math
5from typing import List, Tuple, Dict
6
class SwingRecall:
    """
    Swing recall: item-to-item collaborative filtering.

    The similarity of two items is derived from pairs of users who both
    interacted with them. Each user pair is down-weighted by how many items
    those two users share overall, which damps the influence of user pairs
    that co-interact with many (typically popular) items.
    """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """
        Initialise the Swing recall model. No database access happens here.

        Args:
            db_config: Keyword arguments passed verbatim to ``pymysql.connect``.
            alpha: Smoothing constant in the user-pair weight
                ``1 / (alpha + |I_u ∩ I_v|)``.
        """
        self.db_config = db_config
        self.alpha = alpha
        # item_id -> {other_item_id: similarity}; populated by train()
        self.item_similarity = {}
        # user_id -> set of item ids the user interacted with
        self.user_items = defaultdict(set)
        # item_id -> set of user ids that interacted with the item
        self.item_users = defaultdict(set)

    def _get_interaction_data(self):
        """
        Load user-item interactions from the ``behaviors`` table.

        Only rows whose ``type`` is 'like', 'favorite' or 'comment' are used.
        The in-memory indexes are rebuilt from scratch on every call, so a
        retrain reflects the current table contents instead of accumulating
        rows that have since been removed.
        """
        user_items = defaultdict(set)
        item_users = defaultdict(set)

        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor's context manager closes it even if execute() fails,
            # and avoids referencing an unbound `cursor` when conn.cursor()
            # itself raises.
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT DISTINCT user_id, post_id
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment')
                """)
                for user_id, post_id in cursor.fetchall():
                    user_items[user_id].add(post_id)
                    item_users[post_id].add(user_id)
        finally:
            conn.close()

        # Swap in the fresh indexes only after a successful load.
        self.user_items = user_items
        self.item_users = item_users

    def _calculate_swing_similarity(self):
        """
        Compute the symmetric Swing similarity for every pair of known items.

        For items i and j with common users C = U_i ∩ U_j and |C| >= 2:

            sim(i, j) = sum over ordered pairs u != v in C of
                        1 / (alpha + |I_u ∩ I_v|),
                        divided by |C| * (|C| - 1)

        That is, the mean user-pair weight. Pairs with fewer than two common
        users get 0.0. These zero entries are stored as well, so every item
        maps to every other item. Both directions (i -> j and j -> i) are
        recorded.
        """
        print("开始计算Swing相似度...")

        items = list(self.item_users.keys())
        # Create every row up front; rows must never be reset inside the loop,
        # or the mirrored entries written for earlier items would be lost.
        similarity_matrix = {item: {} for item in items}

        for i, item_i in enumerate(items):
            if i % 100 == 0:
                print(f"处理进度: {i}/{len(items)}")

            users_i = self.item_users[item_i]
            for item_j in items[i + 1:]:
                common_users = list(users_i & self.item_users[item_j])
                n_common = len(common_users)

                similarity = 0.0
                if n_common >= 2:
                    # |I_u ∩ I_v| is symmetric in (u, v): evaluate each
                    # unordered pair once and count it for both orders.
                    for idx, u in enumerate(common_users):
                        items_u = self.user_items[u]
                        for v in common_users[idx + 1:]:
                            overlap = len(items_u & self.user_items[v])
                            similarity += 2.0 / (self.alpha + overlap)
                    similarity /= n_common * (n_common - 1)

                similarity_matrix[item_i][item_j] = similarity
                similarity_matrix[item_j][item_i] = similarity

        self.item_similarity = similarity_matrix
        print("Swing相似度计算完成")

    def train(self):
        """Load interactions from the database and build the similarity matrix."""
        self._get_interaction_data()
        self._calculate_swing_similarity()

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """
        Recall items for a user by summing Swing similarities.

        A candidate's score is the sum of its similarity to each item the user
        has already interacted with. Already-interacted items are excluded.
        If no similarity matrix exists yet, the model is trained first.

        Args:
            user_id: ID of the user to recall for.
            num_items: Maximum number of items to return; values <= 0 yield [].

        Returns:
            Up to ``num_items`` ``(item_id, score)`` tuples in descending score
            order. Empty if the user has no recorded interactions.
        """
        # Lazy training. While the matrix stays empty (e.g. no interaction
        # rows), this re-runs train() on every call.
        if not getattr(self, 'item_similarity', None):
            self.train()

        if num_items <= 0 or user_id not in self.user_items:
            return []

        user_interacted_items = self.user_items[user_id]
        candidate_scores = defaultdict(float)
        for item_i in user_interacted_items:
            for item_j, similarity in self.item_similarity.get(item_i, {}).items():
                if item_j not in user_interacted_items:
                    candidate_scores[item_j] += similarity

        ranked = sorted(candidate_scores.items(), key=lambda kv: kv[1], reverse=True)
        return ranked[:num_items]