blob: d75e6d85cf0790fccf76f94dbbd18853ce090612 [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001import pymysql
2from typing import List, Tuple, Dict, Set
3from collections import defaultdict
4import math
5import numpy as np
6
class UserCFRecall:
    """User-based collaborative filtering (UserCF) recall.

    Users are represented by the set of posts they interacted with strongly
    enough (see ``BEHAVIOR_WEIGHTS`` / ``SCORE_THRESHOLD``). Two users are
    similar when those sets overlap; candidates for a user are the posts that
    the most similar users interacted with and the target user did not.
    """

    # Weight applied to every occurrence of a behavior type when scoring a
    # (user, post) pair. Unknown behavior types fall back to 1.0.
    BEHAVIOR_WEIGHTS = {
        'like': 1.0,
        'favorite': 2.0,
        'comment': 3.0,
        'view': 0.1,
    }

    # A (user, post) pair is kept only if its aggregated score reaches this.
    SCORE_THRESHOLD = 1.0

    # Print a progress line every this many users during similarity building.
    PROGRESS_INTERVAL = 1000

    def __init__(self, db_config: dict, min_common_items: int = 3):
        """Create an untrained model.

        Args:
            db_config: Keyword arguments forwarded to ``pymysql.connect``.
            min_common_items: Minimum number of shared items two users need
                before a similarity is computed for them.
        """
        self.db_config = db_config
        self.min_common_items = min_common_items
        self.user_items = defaultdict(set)   # user_id -> {item_id}
        self.item_users = defaultdict(set)   # item_id -> {user_id}
        self.user_similarity = {}            # user_id -> {other_user_id: sim}
        self._trained = False

    def _get_user_item_interactions(self):
        """Load behavior counts from the database and rebuild the indexes.

        Reads the ``behaviors`` table, then delegates to
        ``_index_interactions``. Rows are unpacked positionally, so this
        assumes the connection uses a tuple-returning cursor.
        """
        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor context manager closes the cursor for us, and is only
            # entered once the cursor exists, so a failure in conn.cursor()
            # can no longer be masked by an unbound-variable error.
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT user_id, post_id, type, COUNT(*) as count
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment', 'view')
                    GROUP BY user_id, post_id, type
                """)
                interactions = cursor.fetchall()
        finally:
            conn.close()

        self._index_interactions(interactions)

    def _index_interactions(self, interactions):
        """Rebuild ``user_items`` / ``item_users`` from aggregated rows.

        Args:
            interactions: Iterable of ``(user_id, post_id, type, count)``.
        """
        scores = defaultdict(lambda: defaultdict(float))
        for user_id, post_id, behavior_type, count in interactions:
            weight = self.BEHAVIOR_WEIGHTS.get(behavior_type, 1.0)
            scores[user_id][post_id] += weight * count

        # Build into fresh containers so that re-training never keeps stale
        # users or items from a previous run.
        user_items = defaultdict(set)
        item_users = defaultdict(set)
        for user_id, items in scores.items():
            for item_id, score in items.items():
                if score >= self.SCORE_THRESHOLD:
                    user_items[user_id].add(item_id)
                    item_users[item_id].add(user_id)

        self.user_items = user_items
        self.item_users = item_users

    def _calculate_user_similarity(self):
        """Compute the symmetric user-user cosine similarity map.

        Similarity is ``|A & B| / sqrt(|A| * |B|)`` over the users' item sets.
        Only pairs sharing at least ``min_common_items`` items (and at least
        one item) are stored; every user still gets an entry, possibly empty.
        Pairs are discovered through an item -> users index, so users with no
        item in common are never visited.
        """
        print("开始计算用户相似度...")

        users = list(self.user_items.keys())
        position = {user: idx for idx, user in enumerate(users)}

        # Local inverted index derived from user_items, so the result does not
        # depend on item_users having been kept in sync by the caller.
        item_to_users = defaultdict(list)
        for user in users:
            for item in self.user_items[user]:
                item_to_users[item].append(user)

        # Pre-create every row. Rows are filled from both sides of each pair
        # and must never be reset while the computation is in progress.
        similarity = {user: {} for user in users}
        min_common = max(self.min_common_items, 1)
        total = len(users)

        for idx, user_i in enumerate(users):
            if (idx + 1) % self.PROGRESS_INTERVAL == 0:
                print(f"处理进度: {idx + 1}/{total}")

            items_i = self.user_items[user_i]

            # Count shared items with users that come later in the ordering;
            # earlier users already recorded this pair symmetrically.
            common_counts = defaultdict(int)
            for item in items_i:
                for user_j in item_to_users[item]:
                    if position[user_j] > idx:
                        common_counts[user_j] += 1

            # Insert in a fixed order so tie-breaking downstream is stable.
            for user_j in sorted(common_counts, key=position.__getitem__):
                common = common_counts[user_j]
                if common < min_common:
                    continue
                # Both sets are non-empty here (they share an item), so the
                # denominator is strictly positive.
                denominator = math.sqrt(len(items_i) * len(self.user_items[user_j]))
                value = common / denominator
                similarity[user_i][user_j] = value
                similarity[user_j][user_i] = value

        self.user_similarity = similarity
        print("用户相似度计算完成")

    def train(self):
        """Load interactions from the database and build the similarity map."""
        self._get_user_item_interactions()
        self._calculate_user_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
        """Recall items that similar users interacted with.

        Trains the model first if it has never been trained and holds no
        similarity data.

        Args:
            user_id: Target user.
            num_items: Maximum number of items to return.
            num_similar_users: Number of most-similar users to draw from.

        Returns:
            ``(item_id, score)`` tuples sorted by descending score, where the
            score is the sum of the similarities of the neighbors that
            interacted with the item. Empty for unknown users or
            non-positive limits.
        """
        # Train lazily, but only once: an empty result of a completed training
        # run must not trigger another full database load on every call.
        if not getattr(self, '_trained', False) and not getattr(self, 'user_similarity', None):
            self.train()

        if num_items <= 0 or num_similar_users <= 0:
            return []

        if user_id not in self.user_similarity or user_id not in self.user_items:
            return []

        similar_users = sorted(
            self.user_similarity[user_id].items(),
            key=lambda pair: pair[1],
            reverse=True,
        )[:num_similar_users]

        user_interacted_items = self.user_items[user_id]
        candidate_scores = defaultdict(float)

        for similar_user_id, similarity in similar_users:
            if similarity <= 0:
                continue

            # .get avoids creating empty entries in a defaultdict.
            for item_id in self.user_items.get(similar_user_id, ()):
                # Never recommend what the target user already interacted with.
                if item_id not in user_interacted_items:
                    candidate_scores[item_id] += similarity

        ranked = sorted(candidate_scores.items(), key=lambda pair: pair[1], reverse=True)
        return ranked[:num_items]

    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
        """Return the most similar users for ``user_id``.

        Args:
            user_id: User to look up.
            num_neighbors: Maximum number of neighbors to return.

        Returns:
            ``(neighbor_user_id, similarity)`` tuples sorted by descending
            similarity. Only users with a stored similarity are returned;
            empty if the user is unknown.
        """
        if user_id not in self.user_similarity:
            return []

        return sorted(
            self.user_similarity[user_id].items(),
            key=lambda pair: pair[1],
            reverse=True,
        )[:num_neighbors]

    def get_user_profile(self, user_id: int) -> Dict:
        """Return tag preferences and behavior statistics for a user.

        Args:
            user_id: User to describe.

        Returns:
            A dict with ``user_id``, ``total_interactions`` (number of items
            in the user's item set), ``tag_preferences`` (tag name -> count
            over those items) and ``behavior_stats`` (behavior type -> count).
            Empty if the user has no indexed items.
        """
        interacted = self.user_items.get(user_id)
        # Decide before connecting: no items means there is nothing to query.
        if not interacted:
            return {}

        user_item_list = list(interacted)
        # One placeholder per item id; the values are passed as parameters.
        format_strings = ','.join(['%s'] * len(user_item_list))

        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute(f"""
                    SELECT t.name, COUNT(*) as count
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({format_strings})
                    GROUP BY t.name
                    ORDER BY count DESC
                """, tuple(user_item_list))
                tag_preferences = cursor.fetchall()

                cursor.execute("""
                    SELECT type, COUNT(*) as count
                    FROM behaviors
                    WHERE user_id = %s
                    GROUP BY type
                """, (user_id,))
                behavior_stats = cursor.fetchall()
        finally:
            conn.close()

        return {
            'user_id': user_id,
            'total_interactions': len(interacted),
            'tag_preferences': dict(tag_preferences),
            'behavior_stats': dict(behavior_stats)
        }