Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame^] | 1 | import pymysql |
| 2 | from typing import List, Tuple, Dict, Set |
| 3 | from collections import defaultdict |
| 4 | import math |
| 5 | import numpy as np |
| 6 | |
class UserCFRecall:
    """
    UserCF (user-based collaborative filtering) recall model.

    Each user is represented by the set of posts they interacted with strongly
    enough (weighted behavior score >= ``SCORE_THRESHOLD``). Two users are
    compared with the cosine similarity of those sets, and a user is recalled
    the posts that their most similar users interacted with.
    """

    # Weight applied to each behavior type when scoring a (user, post) pair.
    BEHAVIOR_WEIGHTS: Dict[str, float] = {
        'like': 1.0,
        'favorite': 2.0,
        'comment': 3.0,
        'view': 0.1,
    }
    # Minimum summed weighted score for a (user, post) pair to count as an
    # interaction (on its own, one like counts but a view needs 10 repetitions).
    SCORE_THRESHOLD: float = 1.0

    def __init__(self, db_config: dict, min_common_items: int = 3):
        """
        Initialize the UserCF recall model. No database access happens here.

        Args:
            db_config: Keyword arguments passed to ``pymysql.connect``.
            min_common_items: Minimum number of shared posts two users need
                before they get a non-zero similarity.
        """
        self.db_config = db_config
        self.min_common_items = min_common_items
        # user_id -> set of post_ids the user interacted with
        self.user_items = defaultdict(set)
        # post_id -> set of user_ids that interacted with the post
        self.item_users = defaultdict(set)
        # user_id -> {other_user_id: similarity}; only positive similarities are stored
        self.user_similarity = {}
        # Set by train(); prevents re-training on every recall when training
        # produced no data.
        self._trained = False

    def _get_user_item_interactions(self):
        """
        Load behavior data from the ``behaviors`` table and rebuild
        ``user_items`` / ``item_users`` from it.
        """
        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor context manager closes the cursor even if execute fails.
            with conn.cursor() as cursor:
                # One row per (user, post, behavior type) with its occurrence count.
                cursor.execute("""
                    SELECT user_id, post_id, type, COUNT(*) as count
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment', 'view')
                    GROUP BY user_id, post_id, type
                """)
                interactions = cursor.fetchall()
        finally:
            conn.close()

        self._build_interaction_sets(interactions)

    def _build_interaction_sets(self, interactions) -> None:
        """
        Turn aggregated behavior rows into user->items and item->users sets.

        Args:
            interactions: Iterable of ``(user_id, post_id, behavior_type, count)``
                rows. Scores for the same (user, post) pair are summed across
                behavior types using ``BEHAVIOR_WEIGHTS`` (unknown types weigh
                1.0). Pairs scoring below ``SCORE_THRESHOLD`` are dropped.

        The previous contents of ``user_items`` and ``item_users`` are replaced,
        so retraining does not keep stale interactions.
        """
        user_item_scores = defaultdict(lambda: defaultdict(float))
        for user_id, post_id, behavior_type, count in interactions:
            weight = self.BEHAVIOR_WEIGHTS.get(behavior_type, 1.0)
            user_item_scores[user_id][post_id] += weight * count

        user_items = defaultdict(set)
        item_users = defaultdict(set)
        for user_id, item_scores in user_item_scores.items():
            for item_id, score in item_scores.items():
                if score >= self.SCORE_THRESHOLD:
                    user_items[user_id].add(item_id)
                    item_users[item_id].add(user_id)

        self.user_items = user_items
        self.item_users = item_users

    def _calculate_user_similarity(self):
        """
        Compute symmetric user-user cosine similarities into ``user_similarity``.

        Co-occurrence counts are accumulated through the ``item_users`` inverted
        index, so only user pairs that share at least one item are visited
        instead of all user pairs. For users u and v sharing ``c`` items:
        ``sim(u, v) = c / sqrt(|items(u)| * |items(v)|)``. Pairs with
        ``c < min_common_items`` are omitted. Every user in ``user_items``
        gets an entry, which may be an empty dict.
        """
        print("开始计算用户相似度...")

        # co_counts[u][v] = number of items both u and v interacted with
        co_counts = defaultdict(lambda: defaultdict(int))
        total_items = len(self.item_users)
        for processed, users in enumerate(self.item_users.values(), start=1):
            if processed % 10000 == 0:
                print(f"处理进度: {processed}/{total_items}")
            for user_u in users:
                for user_v in users:
                    if user_u != user_v:
                        co_counts[user_u][user_v] += 1

        # Build into a fresh dict so each user's row is created exactly once
        # (never overwritten), and retraining drops stale users.
        similarity = {user_id: {} for user_id in self.user_items}
        for user_u, related in co_counts.items():
            size_u = len(self.user_items[user_u])
            row = similarity.setdefault(user_u, {})
            for user_v, common in related.items():
                if common < self.min_common_items:
                    continue
                row[user_v] = common / math.sqrt(size_u * len(self.user_items[user_v]))

        self.user_similarity = similarity
        print("用户相似度计算完成")

    def train(self):
        """Train the model: reload interactions from the database and recompute similarities."""
        self._get_user_item_interactions()
        self._calculate_user_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
        """
        Recall posts liked by the users most similar to ``user_id``.

        The model trains lazily on the first call, unless ``train()`` already
        ran or similarities were populated some other way.

        Args:
            user_id: Target user ID.
            num_items: Maximum number of posts to return; values <= 0 yield [].
            num_similar_users: Number of most similar users to draw posts from.

        Returns:
            List of ``(item_id, score)`` tuples sorted by descending score. The
            score is the sum of similarities of the neighbors who interacted
            with the item. Items the target user already interacted with are
            excluded. Returns [] for unknown users.
        """
        if not self._trained and not self.user_similarity:
            self.train()

        if user_id not in self.user_similarity or user_id not in self.user_items:
            return []
        if num_items <= 0:
            return []

        similar_users = self.get_user_neighbors(user_id, num_similar_users)
        user_interacted_items = self.user_items[user_id]

        candidate_scores = defaultdict(float)
        for similar_user_id, similarity in similar_users:
            if similarity <= 0:
                continue
            # .get avoids inserting empty sets into the defaultdict
            for item_id in self.user_items.get(similar_user_id, ()):
                if item_id not in user_interacted_items:
                    candidate_scores[item_id] += similarity

        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]

    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
        """
        Return the users most similar to ``user_id``.

        Args:
            user_id: User ID.
            num_neighbors: Maximum number of neighbors; values <= 0 yield [].

        Returns:
            List of ``(neighbor_user_id, similarity)`` tuples sorted by
            descending similarity. Returns [] if the user has no similarity data.
        """
        if num_neighbors <= 0 or user_id not in self.user_similarity:
            return []

        return sorted(
            self.user_similarity[user_id].items(),
            key=lambda x: x[1],
            reverse=True,
        )[:num_neighbors]

    def get_user_profile(self, user_id: int) -> Dict:
        """
        Build a simple profile for ``user_id`` from the database.

        Args:
            user_id: User ID.

        Returns:
            Dict with ``user_id``, ``total_interactions`` (number of posts in
            ``user_items``), ``tag_preferences`` (tag name -> number of the
            user's posts carrying it) and ``behavior_stats`` (behavior type ->
            count). Returns {} when the user has no recorded interactions.
        """
        user_item_list = list(self.user_items.get(user_id, ()))
        if not user_item_list:
            return {}

        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                # Only "%s" placeholders are interpolated; values are bound by the driver.
                format_strings = ','.join(['%s'] * len(user_item_list))
                cursor.execute(f"""
                    SELECT t.name, COUNT(*) as count
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({format_strings})
                    GROUP BY t.name
                    ORDER BY count DESC
                """, tuple(user_item_list))
                tag_preferences = cursor.fetchall()

                cursor.execute("""
                    SELECT type, COUNT(*) as count
                    FROM behaviors
                    WHERE user_id = %s
                    GROUP BY type
                """, (user_id,))
                behavior_stats = cursor.fetchall()
        finally:
            conn.close()

        return {
            'user_id': user_id,
            'total_interactions': len(user_item_list),
            'tag_preferences': dict(tag_preferences),
            'behavior_stats': dict(behavior_stats),
        }