import math
from collections import defaultdict
from itertools import combinations
from typing import List, Tuple, Dict

import numpy as np
import pymysql

class SwingRecall:
    """
    Swing recall model (item-to-item collaborative filtering).

    Two items are considered similar when pairs of users both interacted with
    them. Each such user pair is down-weighted by the number of items the two
    users have in common, which is intended to reduce the dominance of
    popular items / heavily overlapping users.
    """

    # Query for positive-feedback interactions; text kept verbatim.
    _INTERACTION_SQL = """
                SELECT DISTINCT user_id, post_id
                FROM behaviors
                WHERE type IN ('like', 'favorite', 'comment')
            """

    def __init__(self, db_config: dict, alpha: float = 0.5):
        """
        Initialize an untrained Swing recall model. No database access happens here.

        Args:
            db_config: keyword arguments forwarded to ``pymysql.connect``.
            alpha: smoothing constant in the Swing weight
                ``1 / (alpha + |items(u) & items(v)|)``. It is added to the
                item-overlap count of every user pair, so larger values shrink
                all weights and flatten the difference between heavily and
                lightly overlapping user pairs.
        """
        self.db_config = db_config
        self.alpha = alpha
        # item_id -> {other_item_id: similarity}; filled by train()
        self.item_similarity: Dict = {}
        # user_id -> set of post_ids the user interacted with
        self.user_items = defaultdict(set)
        # post_id -> set of user_ids who interacted with it
        self.item_users = defaultdict(set)
        # Set once train() has completed; avoids re-querying the DB on every
        # recall() call when training yielded an empty similarity table.
        self._trained = False

    def _get_interaction_data(self):
        """
        Load (user_id, post_id) interactions from the ``behaviors`` table.

        ``user_items`` and ``item_users`` are rebuilt from scratch and only
        replaced once loading succeeds, so retraining does not keep
        interactions that are no longer in the database.
        """
        user_items = defaultdict(set)
        item_users = defaultdict(set)

        conn = pymysql.connect(**self.db_config)
        try:
            # Cursor as a context manager: it is always closed, and a failure
            # in conn.cursor() can no longer trigger a NameError in cleanup.
            with conn.cursor() as cursor:
                cursor.execute(self._INTERACTION_SQL)
                for user_id, post_id in cursor.fetchall():
                    user_items[user_id].add(post_id)
                    item_users[post_id].add(user_id)
        finally:
            conn.close()

        self.user_items = user_items
        self.item_users = item_users

    def _swing_score(self, common_users) -> float:
        """
        Return the mean Swing weight over unordered pairs of *common_users*.

        Returns 0.0 when there are fewer than two common users.
        """
        n = len(common_users)
        if n < 2:
            return 0.0
        alpha = self.alpha
        user_items = self.user_items
        total = sum(
            1.0 / (alpha + len(user_items[u] & user_items[v]))
            for u, v in combinations(common_users, 2)
        )
        # Each unordered pair is visited once. Dividing by n*(n-1)/2 equals
        # averaging over the n*(n-1) ordered pairs, where each pair is counted twice.
        return total / (n * (n - 1) / 2)

    def _calculate_swing_similarity(self):
        """
        Compute the symmetric Swing similarity for every pair of known items.

        For items i, j with common users U = users(i) & users(j), the score is
        the mean over user pairs {u, v} in U of
        ``1 / (alpha + |items(u) & items(v)|)``. Every pair is stored in both
        directions, including pairs whose score is 0.0.
        """
        print("开始计算Swing相似度...")

        items = list(self.item_users.keys())
        # Create every row up front. Rows are filled symmetrically, so a row
        # must never be reset after an earlier item has written into it.
        similarity = {item: {} for item in items}

        for idx, item_i in enumerate(items):
            if idx % 100 == 0:
                print(f"处理进度: {idx}/{len(items)}")

            users_i = self.item_users[item_i]
            row_i = similarity[item_i]
            for item_j in items[idx + 1:]:
                score = self._swing_score(users_i & self.item_users[item_j])
                row_i[item_j] = score
                similarity[item_j][item_i] = score

        self.item_similarity = similarity
        print("Swing相似度计算完成")

    def train(self):
        """Load interactions from the database and rebuild the similarity table."""
        self._get_interaction_data()
        self._calculate_swing_similarity()
        self._trained = True

    def recall(self, user_id: int, num_items: int = 50) -> List[Tuple[int, float]]:
        """
        Recall items for a user based on Swing item similarity.

        A candidate's score is the sum of its similarities to every item the
        user has interacted with. Items the user already interacted with are
        excluded. Every item pair is stored, including 0.0 scores, so the
        result may contain zero-score items.

        Args:
            user_id: user to recommend for.
            num_items: maximum number of items to return; <= 0 yields [].

        Returns:
            Up to ``num_items`` ``(item_id, score)`` tuples, highest score first.
            Returns ``[]`` for users with no known interactions.
        """
        # Lazy training on first use. The getattr default covers instances
        # created without __init__ (e.g. older pickles). A pre-populated
        # item_similarity is respected and not overwritten.
        if not getattr(self, "_trained", False) and not self.item_similarity:
            self.train()

        if num_items <= 0:
            return []

        # .get() avoids inserting unknown users into the defaultdict.
        interacted = self.user_items.get(user_id)
        if not interacted:
            return []

        candidate_scores = defaultdict(float)
        for item_i in interacted:
            for item_j, sim in self.item_similarity.get(item_i, {}).items():
                if item_j not in interacted:
                    candidate_scores[item_j] += sim

        ranked = sorted(candidate_scores.items(), key=lambda kv: kv[1], reverse=True)
        return ranked[:num_items]
