blob: bf7fdd6fed2bd4170cd8fffad6af40aa60fce31c [file] [log] [blame]
import math
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Tuple

import numpy as np
import pymysql
class SwingRecall:
    """
    Swing item-to-item recall.

    Builds an item-item similarity table from user-item interactions stored in
    MySQL. Two items are similar when pairs of users who both interacted with
    them otherwise share few items. Each user pair (u, v) contributes
    1 / (alpha + |I_u & I_v|). Recall then scores unseen items by summing
    their similarity to the items in a user's history.
    """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """
        Initialise the Swing recall model. No database access happens here.

        Args:
            db_config: Keyword arguments passed verbatim to ``pymysql.connect``.
            alpha: Smoothing term added to the user-pair overlap in the Swing
                weight ``1 / (alpha + |I_u & I_v|)``. It must keep that
                denominator positive.
        """
        self.db_config = db_config
        self.alpha = alpha
        # item -> {other_item: similarity}; only non-trivial pairs are stored.
        self.item_similarity = {}
        # user -> set of items, and item -> set of users (inverse index).
        self.user_items = defaultdict(set)
        self.item_users = defaultdict(set)
        # Set by train(); lets recall() avoid re-querying the DB every call
        # when training legitimately produced an empty similarity table.
        self._trained = False

    def _get_interaction_data(self):
        """
        Load user-item interactions from the ``behaviors`` table.

        Rebuilds ``self.user_items`` and ``self.item_users`` from scratch, so a
        retrain reflects the current database contents instead of
        accumulating stale rows. The new indexes are installed only after the
        query succeeds.
        """
        user_items = defaultdict(set)
        item_users = defaultdict(set)
        conn = pymysql.connect(**self.db_config)
        try:
            # `with` closes the cursor even if execute/fetchall raises. It also
            # avoids a NameError in cleanup if conn.cursor() itself fails.
            with conn.cursor() as cursor:
                # Like / favorite / comment behaviours count as interactions.
                cursor.execute("""
                    SELECT DISTINCT user_id, post_id
                    FROM behaviors
                    WHERE type IN ('like', 'favorite', 'comment')
                """)
                interactions = cursor.fetchall()
        finally:
            conn.close()
        for user_id, post_id in interactions:
            user_items[user_id].add(post_id)
            item_users[post_id].add(user_id)
        self.user_items = user_items
        self.item_users = item_users

    def _calculate_swing_similarity(self):
        """
        Compute the symmetric Swing similarity table into ``self.item_similarity``.

        Every item pair (i, j) with a common-user set U of size n >= 2 is scored as
        the average Swing weight over ordered user pairs:

            sim(i, j) = sum_{u != v in U} 1 / (alpha + |I_u & I_v|) / (n * (n - 1))

        Every known item gets an entry, possibly an empty dict. Pairs with
        fewer than two common users are not stored, because the score is
        defined as 0 for them.
        """
        print("开始计算Swing相似度...")
        items = list(self.item_users.keys())
        # Iteration position of each item, used to visit every unordered pair once.
        position = {item: idx for idx, item in enumerate(items)}
        similarity_table = {item: {} for item in items}
        for i, item_i in enumerate(items):
            if i % 100 == 0:
                print(f"处理进度: {i}/{len(items)}")
            users_i = self.item_users[item_i]
            # Only items sharing a user with item_i can score > 0. Counting
            # co-occurrences avoids scanning every O(n^2) item pair.
            co_counts = defaultdict(int)
            for u in users_i:
                for item_j in self.user_items[u]:
                    if position.get(item_j, -1) > i:
                        co_counts[item_j] += 1
            for item_j, count in co_counts.items():
                if count < 2:  # Swing needs at least two common users
                    continue
                common_users = users_i & self.item_users[item_j]
                n_common = len(common_users)
                if n_common < 2:
                    continue
                total = 0.0
                for u, v in combinations(common_users, 2):
                    total += 1.0 / (self.alpha + len(self.user_items[u] & self.user_items[v]))
                # The weight is symmetric in (u, v). Each unordered pair
                # therefore counts twice in the ordered-pair average.
                similarity = 2.0 * total / (n_common * (n_common - 1))
                similarity_table[item_i][item_j] = similarity
                similarity_table[item_j][item_i] = similarity
        self.item_similarity = similarity_table
        print("Swing相似度计算完成")

    def train(self):
        """Load interactions from the database and rebuild the similarity table."""
        self._get_interaction_data()
        self._calculate_swing_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """
        Recall items similar to those the user has already interacted with.

        Trains lazily on first use, unless a similarity table has already been
        populated.

        Args:
            user_id: The user to recall for.
            num_items: Maximum number of items to return. A value <= 0 yields [].

        Returns:
            Up to ``num_items`` ``(item_id, score)`` tuples, sorted by descending
            score. A candidate's score is the sum of its similarities to the
            user's items. Items the user already interacted with are excluded.
            Unknown users get [].
        """
        # getattr keeps instances created without __init__ (e.g. unpickled
        # older objects) working.
        if not getattr(self, "_trained", False) and not self.item_similarity:
            self.train()
        if num_items <= 0 or user_id not in self.user_items:
            return []
        user_interacted_items = self.user_items[user_id]
        candidate_scores = defaultdict(float)
        for item_i in user_interacted_items:
            for item_j, similarity in self.item_similarity.get(item_i, {}).items():
                # Exclude items the user has already interacted with.
                if item_j not in user_interacted_items:
                    candidate_scores[item_j] += similarity
        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]