import pymysql
from typing import List, Tuple, Dict, Set
from collections import defaultdict
import math
import numpy as np

class UserCFRecall:
    """
    UserCF (user-based collaborative filtering) recall.

    Each user is represented by the set of posts whose accumulated, weighted
    behavior score reaches ``score_threshold``. The similarity of two users is
    the set cosine ``|I_u & I_v| / sqrt(|I_u| * |I_v|)``, and it is only kept
    when the users share at least ``min_common_items`` posts. Recall scores
    each candidate post by summing the similarities of the target user's top
    neighbors who interacted with it.
    """

    # Weight applied to each behavior type's count; the keys mirror the
    # ``type IN (...)`` filter of the interaction query.
    BEHAVIOR_WEIGHTS: Dict[str, float] = {
        'like': 1.0,
        'favorite': 2.0,
        'comment': 3.0,
        'view': 0.1,
    }
    # Weight used for behavior types not listed in BEHAVIOR_WEIGHTS.
    DEFAULT_BEHAVIOR_WEIGHT = 1.0

    def __init__(self, db_config: dict, min_common_items: int = 3,
                 score_threshold: float = 1.0):
        """
        Initialize the UserCF recall model (no database access happens here).

        Args:
            db_config: Keyword arguments passed to ``pymysql.connect``.
            min_common_items: Minimum number of shared posts for a user pair
                to receive a non-zero similarity.
            score_threshold: Minimum accumulated weighted behavior score for a
                (user, post) pair to count as an interaction.
        """
        self.db_config = db_config
        self.min_common_items = min_common_items
        self.score_threshold = score_threshold
        self.user_items = defaultdict(set)   # user_id -> set of post_ids
        self.item_users = defaultdict(set)   # post_id -> set of user_ids
        self.user_similarity = {}            # user_id -> {other_user_id: similarity > 0}
        self._trained = False

    def _get_user_item_interactions(self):
        """Fetch aggregated behavior rows from MySQL and rebuild the interaction indexes."""
        conn = pymysql.connect(**self.db_config)
        # Pre-bind so the finally clause cannot hit an unbound name (and mask
        # the real error) when conn.cursor() itself raises.
        cursor = None
        try:
            cursor = conn.cursor()

            # One row per (user, post, behavior type) with its occurrence count.
            cursor.execute("""
                SELECT user_id, post_id, type, COUNT(*) as count
                FROM behaviors
                WHERE type IN ('like', 'favorite', 'comment', 'view')
                GROUP BY user_id, post_id, type
            """)

            interactions = cursor.fetchall()
        finally:
            if cursor is not None:
                cursor.close()
            conn.close()

        self._load_interactions(interactions)

    def _load_interactions(self, interactions):
        """
        Rebuild ``user_items`` and ``item_users`` from aggregated behavior rows.

        Args:
            interactions: Iterable of ``(user_id, post_id, behavior_type, count)``
                rows.

        Each (user, post) pair sums ``weight * count`` over its behavior types.
        The pair is kept when that sum reaches ``self.score_threshold``. The
        existing indexes are cleared first, so a retrain never keeps posts from
        a previous load.
        """
        user_item_scores = defaultdict(lambda: defaultdict(float))
        for user_id, post_id, behavior_type, count in interactions:
            weight = self.BEHAVIOR_WEIGHTS.get(behavior_type, self.DEFAULT_BEHAVIOR_WEIGHT)
            user_item_scores[user_id][post_id] += weight * count

        self.user_items.clear()
        self.item_users.clear()
        for user_id, scores in user_item_scores.items():
            for post_id, score in scores.items():
                if score >= self.score_threshold:
                    self.user_items[user_id].add(post_id)
                    self.item_users[post_id].add(user_id)

    def _calculate_user_similarity(self):
        """
        Compute the sparse, symmetric user-user cosine similarity.

        Shared-post counts are gathered through the ``item_users`` inverted
        index, so only pairs that co-occur on at least one post are visited,
        instead of every pair of users. Every user in ``user_items`` gets an
        entry (possibly empty). Only similarities for pairs sharing at least
        ``min_common_items`` posts are stored.
        """
        print("开始计算用户相似度...")

        # co_counts[u][v] = number of posts both u and v interacted with.
        co_counts = defaultdict(lambda: defaultdict(int))
        total_items = len(self.item_users)
        for processed, users in enumerate(self.item_users.values(), start=1):
            if processed % 10000 == 0:
                print(f"处理进度: {processed}/{total_items}")
            for user_u in users:
                row = co_counts[user_u]
                for user_v in users:
                    if user_v != user_u:
                        row[user_v] += 1

        similarity = {user_id: {} for user_id in self.user_items}
        for user_u, row in co_counts.items():
            size_u = len(self.user_items[user_u])
            for user_v, common in row.items():
                if common < self.min_common_items:
                    continue
                denominator = math.sqrt(size_u * len(self.user_items[user_v]))
                if denominator > 0:
                    similarity[user_u][user_v] = common / denominator

        # Replace wholesale so a retrain never keeps pairs from an old run.
        self.user_similarity = similarity

        print("用户相似度计算完成")

    def train(self):
        """Load interactions from the database and compute user similarities."""
        self._get_user_item_interactions()
        self._calculate_user_similarity()
        self._trained = True

    def _top_neighbors(self, user_id, k: int) -> List[Tuple[int, float]]:
        """Return up to ``k`` ``(neighbor_id, similarity)`` pairs, most similar first."""
        return sorted(
            self.user_similarity.get(user_id, {}).items(),
            key=lambda pair: pair[1],
            reverse=True,
        )[:k]

    def recall(self, user_id: int, num_items: int = 50, num_similar_users: int = 50) -> List[Tuple[int, float]]:
        """
        Recall posts liked by users similar to ``user_id``.

        Trains lazily on first use. Once training has run, an empty result
        (for example, no behavior data) does not trigger another database
        reload on every call.

        Args:
            user_id: Target user ID.
            num_items: Maximum number of posts to return.
            num_similar_users: Number of most-similar neighbors to consider.

        Returns:
            List of ``(item_id, score)`` tuples sorted by descending score.
            The score is the sum of the similarities of the neighbors who
            interacted with the post. Posts the user already interacted with
            are excluded. Returns an empty list for unknown users.
        """
        if not getattr(self, '_trained', False) and not self.user_similarity:
            self.train()

        if user_id not in self.user_similarity or user_id not in self.user_items:
            return []

        user_interacted_items = self.user_items[user_id]
        candidate_scores = defaultdict(float)

        for similar_user_id, similarity in self._top_neighbors(user_id, num_similar_users):
            if similarity <= 0:
                continue
            for item_id in self.user_items[similar_user_id]:
                if item_id not in user_interacted_items:
                    candidate_scores[item_id] += similarity

        sorted_candidates = sorted(candidate_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_candidates[:num_items]

    def get_user_neighbors(self, user_id: int, num_neighbors: int = 10) -> List[Tuple[int, float]]:
        """
        Return the most similar neighbors of a user (does not trigger training).

        Args:
            user_id: User ID.
            num_neighbors: Maximum number of neighbors to return.

        Returns:
            List of ``(neighbor_user_id, similarity)`` tuples, most similar
            first. Only neighbors with a non-zero similarity are included.
        """
        if user_id not in self.user_similarity:
            return []
        return self._top_neighbors(user_id, num_neighbors)

    def get_user_profile(self, user_id: int) -> Dict:
        """
        Build a simple profile for a user from the database.

        Args:
            user_id: User ID.

        Returns:
            A dict with keys ``user_id``, ``total_interactions`` (number of
            distinct posts kept in ``user_items``), ``tag_preferences`` (tag
            name -> count over those posts) and ``behavior_stats`` (behavior
            type -> count over all of the user's rows in ``behaviors``).
            Returns ``{}`` if the user has no kept interactions.
        """
        if user_id not in self.user_items:
            return {}

        conn = pymysql.connect(**self.db_config)
        # Pre-bind so the finally clause cannot hit an unbound name.
        cursor = None
        try:
            cursor = conn.cursor()

            user_item_list = list(self.user_items[user_id])
            if not user_item_list:
                return {}

            # Only "%s" placeholders are interpolated into the SQL text; the
            # post ids themselves are passed as bound parameters.
            format_strings = ','.join(['%s'] * len(user_item_list))
            cursor.execute(f"""
                SELECT t.name, COUNT(*) as count
                FROM post_tags pt
                JOIN tags t ON pt.tag_id = t.id
                WHERE pt.post_id IN ({format_strings})
                GROUP BY t.name
                ORDER BY count DESC
            """, tuple(user_item_list))

            tag_preferences = cursor.fetchall()

            cursor.execute("""
                SELECT type, COUNT(*) as count
                FROM behaviors
                WHERE user_id = %s
                GROUP BY type
            """, (user_id,))

            behavior_stats = cursor.fetchall()

            return {
                'user_id': user_id,
                'total_interactions': len(self.user_items[user_id]),
                'tag_preferences': dict(tag_preferences),
                'behavior_stats': dict(behavior_stats)
            }

        finally:
            if cursor is not None:
                cursor.close()
            conn.close()
