blob: d75e6d85cf0790fccf76f94dbbd18853ce090612 [file] [log] [blame]
import math
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pymysql
class UserCFRecall:
    """
    UserCF (User-based Collaborative Filtering) recall model.

    Users are compared by the set of posts they interacted with (cosine
    similarity over interaction sets). For a target user, posts interacted
    with by the most similar users, and not yet seen by the target user, are
    returned as recall candidates.
    """

    # Weight applied per behavior type when scoring a (user, post) pair.
    # Only behavior types listed here are read from the `behaviors` table.
    DEFAULT_BEHAVIOR_WEIGHTS = {
        'like': 1.0,
        'favorite': 2.0,
        'comment': 3.0,
        'view': 0.1,
    }

    def __init__(self, db_config: dict, min_common_items: int = 3,
                 score_threshold: float = 1.0,
                 behavior_weights: Optional[Dict[str, float]] = None):
        """
        Initialise the UserCF recall model (no DB access happens here).

        Args:
            db_config: keyword arguments passed verbatim to ``pymysql.connect``.
            min_common_items: minimum number of shared posts two users need
                before any similarity is recorded between them.
            score_threshold: minimum weighted behavior score for a
                (user, post) pair to count as an interaction.
            behavior_weights: mapping of behavior type -> weight; defaults to
                ``DEFAULT_BEHAVIOR_WEIGHTS``.

        Raises:
            ValueError: if ``behavior_weights`` is given but empty (it would
                produce an invalid ``IN ()`` SQL clause).
        """
        if behavior_weights is None:
            behavior_weights = self.DEFAULT_BEHAVIOR_WEIGHTS
        elif not behavior_weights:
            raise ValueError("behavior_weights must not be empty")
        self.db_config = db_config
        self.min_common_items = min_common_items
        self.score_threshold = score_threshold
        # Copy so later mutation of the caller's dict cannot affect the model.
        self.behavior_weights = dict(behavior_weights)
        self.user_items = defaultdict(set)   # user_id -> set of post_ids
        self.item_users = defaultdict(set)   # post_id -> set of user_ids
        self.user_similarity = {}            # user_id -> {other_user_id: similarity}
        self._trained = False

    def _fetch_interactions(self) -> List[Tuple]:
        """Return ``(user_id, post_id, type, count)`` rows for the weighted behavior types."""
        placeholders = ','.join(['%s'] * len(self.behavior_weights))
        conn = pymysql.connect(**self.db_config)
        try:
            # The cursor context manager closes the cursor even on error and
            # avoids touching an unbound name if conn.cursor() itself fails.
            with conn.cursor() as cursor:
                cursor.execute(f"""
                    SELECT user_id, post_id, type, COUNT(*) as count
                    FROM behaviors
                    WHERE type IN ({placeholders})
                    GROUP BY user_id, post_id, type
                """, tuple(self.behavior_weights))
                return cursor.fetchall()
        finally:
            conn.close()

    def _build_interaction_index(self, interactions) -> None:
        """Aggregate weighted behavior rows and rebuild ``user_items`` / ``item_users``."""
        # Weighted score per (user, post), summed over behavior types.
        user_item_scores = defaultdict(lambda: defaultdict(float))
        for user_id, post_id, behavior_type, count in interactions:
            weight = self.behavior_weights.get(behavior_type, 1.0)
            user_item_scores[user_id][post_id] += weight * count

        # Keep only pairs whose score reaches the threshold, as plain sets
        # (similarity is computed on sets, not on scores).
        user_items = defaultdict(set)
        item_users = defaultdict(set)
        for user_id, item_scores in user_item_scores.items():
            for item_id, score in item_scores.items():
                if score >= self.score_threshold:
                    user_items[user_id].add(item_id)
                    item_users[item_id].add(user_id)

        # Replace rather than extend, so retraining drops stale interactions.
        self.user_items = user_items
        self.item_users = item_users

    def _get_user_item_interactions(self):
        """Load behavior data from the DB and rebuild the user/item interaction sets."""
        self._build_interaction_index(self._fetch_interactions())

    def _calculate_user_similarity(self):
        """
        Compute the user-user similarity map.

        sim(u, v) = |I(u) & I(v)| / sqrt(|I(u)| * |I(v)|), recorded only when
        the pair shares at least ``min_common_items`` posts. Shared-post counts
        are gathered through the ``item_users`` inverted index, so only users
        that share at least one post are ever compared. Pairs without a
        similarity are not stored. Every user in ``user_items`` gets an entry,
        possibly an empty dict. The map is symmetric by construction.
        """
        print("开始计算用户相似度...")
        users = list(self.user_items.keys())
        total_users = len(users)
        similarity = {}
        for processed, user_i in enumerate(users, start=1):
            if processed % 1000 == 0:
                print(f"处理进度: {processed}/{total_users}")
            items_i = self.user_items[user_i]

            # Count shared posts with every user that co-occurs on any post.
            common_counts = defaultdict(int)
            for item_id in items_i:
                for user_j in self.item_users.get(item_id, ()):
                    if user_j != user_i:
                        common_counts[user_j] += 1

            neighbours = {}
            for user_j, common in common_counts.items():
                if common < self.min_common_items:
                    continue
                denominator = math.sqrt(len(items_i) * len(self.user_items[user_j]))
                if denominator > 0:
                    neighbours[user_j] = common / denominator
            similarity[user_i] = neighbours

        self.user_similarity = similarity
        print("用户相似度计算完成")

    def train(self):
        """Train the model: reload interactions from the DB, then recompute similarities."""
        self._get_user_item_interactions()
        self._calculate_user_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
        """
        Recall posts interacted with by the users most similar to ``user_id``.

        Trains the model lazily on first use.

        Args:
            user_id: target user ID.
            num_items: maximum number of posts to return.
            num_similar_users: number of most-similar users to draw candidates from.

        Returns:
            List of ``(item_id, score)`` tuples sorted by descending score. The
            score is the sum of similarities of the neighbours that interacted
            with the item. Items the target user already interacted with are
            excluded. Returns ``[]`` for unknown users.
        """
        # The _trained flag stops an empty dataset from triggering a full DB
        # reload on every call.
        if not getattr(self, 'user_similarity', None) and not getattr(self, '_trained', False):
            self.train()
        if user_id not in self.user_similarity or user_id not in self.user_items:
            return []

        similar_users = self.get_user_neighbors(user_id, num_similar_users)
        user_interacted_items = self.user_items[user_id]

        candidate_scores = defaultdict(float)
        for similar_user_id, similarity in similar_users:
            if similarity <= 0:
                continue
            for item_id in self.user_items[similar_user_id]:
                # Skip posts the target user has already interacted with.
                if item_id not in user_interacted_items:
                    candidate_scores[item_id] += similarity

        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]

    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
        """
        Return the users most similar to ``user_id``.

        Args:
            user_id: user ID.
            num_neighbors: maximum number of neighbours to return.

        Returns:
            List of ``(neighbor_user_id, similarity)`` tuples sorted by
            descending similarity; ``[]`` if the user has no similarity entry.
        """
        if user_id not in self.user_similarity:
            return []
        return sorted(
            self.user_similarity[user_id].items(),
            key=lambda x: x[1],
            reverse=True,
        )[:num_neighbors]

    def get_user_profile(self, user_id: int) -> Dict:
        """
        Build a simple profile for ``user_id`` from the DB.

        Args:
            user_id: user ID.

        Returns:
            Dict with ``user_id``, ``total_interactions`` (number of posts in
            the user's interaction set), ``tag_preferences`` (tag name -> count
            over those posts) and ``behavior_stats`` (behavior type -> count).
            Returns ``{}`` when the user has no known interactions.
        """
        if user_id not in self.user_items:
            return {}
        user_item_list = list(self.user_items[user_id])
        if not user_item_list:
            # Check before connecting so no DB connection is opened needlessly.
            return {}

        format_strings = ','.join(['%s'] * len(user_item_list))
        conn = pymysql.connect(**self.db_config)
        try:
            with conn.cursor() as cursor:
                # Tag counts across the posts the user interacted with.
                cursor.execute(f"""
                    SELECT t.name, COUNT(*) as count
                    FROM post_tags pt
                    JOIN tags t ON pt.tag_id = t.id
                    WHERE pt.post_id IN ({format_strings})
                    GROUP BY t.name
                    ORDER BY count DESC
                """, tuple(user_item_list))
                tag_preferences = cursor.fetchall()

                # Raw behavior counts for this user, per behavior type.
                cursor.execute("""
                    SELECT type, COUNT(*) as count
                    FROM behaviors
                    WHERE user_id = %s
                    GROUP BY type
                """, (user_id,))
                behavior_stats = cursor.fetchall()
        finally:
            conn.close()

        # NOTE(review): dict(rows) assumes the default tuple cursor; a DictCursor
        # in db_config would break this — confirm against deployment config.
        return {
            'user_id': user_id,
            'total_interactions': len(self.user_items[user_id]),
            'tag_preferences': dict(tag_preferences),
            'behavior_stats': dict(behavior_stats),
        }