"""
LightGCN评分服务
用于对多路召回的结果进行LightGCN打分
"""

import torch
import numpy as np
from typing import List, Tuple, Dict, Any
from app.models.recommend.LightGCN import LightGCN
from app.utils.parse_args import args
from app.utils.data_loader import EdgeListData
from app.utils.graph_build import build_user_post_graph


class LightGCNScorer:
    """
    Scores multi-channel recall candidates with a pretrained LightGCN model.

    The model, the id -> index mappings and the final user/item embeddings
    are loaded lazily on the first public call and cached on the instance.
    A score is the inner product between the cached user embedding and the
    cached item embedding; users or items missing from the mappings score
    0.0.
    """

    # GPU index the service prefers when the host exposes enough devices.
    _PREFERRED_CUDA_INDEX = 7
    # Number of nearest items reported by get_user_profile.
    _TOP_K_SIMILAR = 10
    # Lower bound applied to norms so zero vectors normalise to 0, not NaN.
    _NORM_EPS = 1e-12

    def __init__(self):
        """Configure the shared ``args`` namespace and reset cached state.

        Note: this writes ``device``, ``data_path`` and ``pre_model_path``
        onto the module-level ``args`` object, so the settings are visible
        to every other user of that object.
        """
        # Device configuration
        args.device = self._select_device()
        args.data_path = './app/user_post_graph.txt'
        args.pre_model_path = './app/models/recommend/LightGCN_pretrained.pt'

        # Model-related state, filled in by _initialize_model()
        self.model = None
        self.user2idx = None
        self.post2idx = None
        self.idx2post = None
        self.dataset = None
        self.user_embeddings = None
        self.item_embeddings = None

        # Whether the lazy initialisation has already run
        self._initialized = False

    @classmethod
    def _select_device(cls) -> str:
        """Return the preferred CUDA device, ``cuda:0``, or ``cpu``."""
        if not torch.cuda.is_available():
            return 'cpu'
        # Hosts with fewer GPUs than the preferred index cannot address it,
        # so fall back to the first device instead of failing at load time.
        if torch.cuda.device_count() > cls._PREFERRED_CUDA_INDEX:
            return f'cuda:{cls._PREFERRED_CUDA_INDEX}'
        return 'cuda:0'

    def _initialize_model(self):
        """Load mappings, dataset, pretrained weights and cache embeddings.

        Does nothing when the scorer has already been initialised.
        """
        if self._initialized:
            return

        print("初始化LightGCN评分模型...")

        # Build the user/post id -> index mappings and the reverse post map
        self.user2idx, self.post2idx = build_user_post_graph(return_mapping=True)
        self.idx2post = {v: k for k, v in self.post2idx.items()}

        # Load the dataset (the same edge list is passed for both arguments)
        self.dataset = EdgeListData(args.data_path, args.data_path)

        # Load the pretrained weights and truncate the embedding tables to
        # the number of users/items present in the current dataset
        pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
        pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:self.dataset.num_users]
        pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:self.dataset.num_items]

        # Build the model and switch it to inference mode
        self.model = LightGCN(self.dataset, phase='vanilla').to(args.device)
        self.model.load_state_dict(pretrained_dict, strict=False)
        self.model.eval()

        # Pre-compute the embeddings once so scoring is a pure lookup
        with torch.no_grad():
            self.user_embeddings, self.item_embeddings = self.model.generate()

        self._initialized = True
        print("LightGCN评分模型初始化完成")

    def _known_items(self, candidate_post_ids: List[int]) -> Tuple[List[int], List[int]]:
        """Split candidates into (positions, item indices) for known posts.

        ``positions`` are offsets into ``candidate_post_ids``; the item
        indices are the matching values from ``self.post2idx``. Posts that
        are not in the mapping are omitted from both lists.
        """
        positions: List[int] = []
        item_indices: List[int] = []
        for position, post_id in enumerate(candidate_post_ids):
            if post_id in self.post2idx:
                positions.append(position)
                item_indices.append(self.post2idx[post_id])
        return positions, item_indices

    def _item_rows(self, item_indices: List[int]):
        """Gather rows of the cached item embeddings for ``item_indices``."""
        # Build the index on the embeddings' own device so the lookup never
        # depends on the mutable global ``args.device``.
        index = torch.tensor(
            item_indices, dtype=torch.long, device=self.item_embeddings.device
        )
        return self.item_embeddings[index]

    @classmethod
    def _unit(cls, tensor):
        """L2-normalise ``tensor`` along its last dimension.

        The norm is clamped from below so all-zero vectors stay zero instead
        of turning into NaN.
        """
        norms = torch.norm(tensor, dim=-1, keepdim=True)
        return tensor / torch.clamp(norms, min=cls._NORM_EPS)

    def score_candidates(self, user_id: int, candidate_post_ids: List[int]) -> List[float]:
        """
        Score candidate posts for a user with LightGCN.

        This has the same contract as ``score_batch_candidates`` and simply
        delegates to it, so both entry points share one implementation.

        Args:
            user_id: User id.
            candidate_post_ids: Candidate post ids.

        Returns:
            List[float]: One score per candidate, in input order. Unknown
            users yield all zeros; unknown posts yield 0.0.
        """
        return self.score_batch_candidates(user_id, candidate_post_ids)

    def score_batch_candidates(self, user_id: int, candidate_post_ids: List[int]) -> List[float]:
        """
        Score candidate posts for a user in a single batched product.

        Args:
            user_id: User id.
            candidate_post_ids: Candidate post ids.

        Returns:
            List[float]: One score per candidate, in input order. Unknown
            users yield all zeros; unknown posts yield 0.0.
        """
        self._initialize_model()

        # Users missing from the mapping cannot be scored
        if user_id not in self.user2idx:
            print(f"用户 {user_id} 不在训练数据中，返回零分数")
            return [0.0] * len(candidate_post_ids)

        print(len(candidate_post_ids), "候选物品数量")

        scores = [0.0] * len(candidate_post_ids)
        positions, item_indices = self._known_items(candidate_post_ids)
        if not item_indices:
            return scores

        with torch.no_grad():
            user_emb = self.user_embeddings[self.user2idx[user_id]]  # [emb_size]
            item_embs = self._item_rows(item_indices)  # [num_valid, emb_size]
            batch_scores = torch.matmul(item_embs, user_emb)  # [num_valid]

        # Write each score back to the slot of the candidate it belongs to
        for position, score in zip(positions, batch_scores.tolist()):
            scores[position] = float(score)

        return scores

    def get_user_profile(self, user_id: int) -> Dict[str, Any]:
        """
        Describe a user's LightGCN embedding and its nearest items.

        Args:
            user_id: User id.

        Returns:
            Dict: For an unknown user, a dict with ``exists_in_model`` set
            to False and a message. Otherwise the user index, embedding
            statistics (norm, mean, std, dimension) and up to 10 items with
            the highest cosine similarity to the user embedding.
        """
        self._initialize_model()

        if user_id not in self.user2idx:
            return {
                'user_id': user_id,
                'exists_in_model': False,
                'message': '用户不在训练数据中'
            }

        user_idx = self.user2idx[user_id]

        with torch.no_grad():
            user_emb = self.user_embeddings[user_idx]

            # Summary statistics of the raw user embedding
            emb_norm = torch.norm(user_emb).item()
            emb_mean = torch.mean(user_emb).item()
            emb_std = torch.std(user_emb).item()

            # Cosine similarity between the user and every item
            similarities = torch.matmul(
                self._unit(self.item_embeddings), self._unit(user_emb)
            )

            # topk raises when k exceeds the number of items, so cap it
            k = min(self._TOP_K_SIMILAR, similarities.numel())
            top = torch.topk(similarities, k=k)

            # Skip embedding rows that have no post id in the reverse map
            top_similar_items = [
                {'post_id': self.idx2post[idx], 'similarity': float(similarity)}
                for idx, similarity in zip(top.indices.tolist(), top.values.tolist())
                if idx in self.idx2post
            ]

        return {
            'user_id': user_id,
            'user_idx': user_idx,
            'exists_in_model': True,
            'embedding_stats': {
                'norm': float(emb_norm),
                'mean': float(emb_mean),
                'std': float(emb_std),
                'dimension': user_emb.shape[0]
            },
            'top_similar_items': top_similar_items
        }

    def compare_scoring_methods(self, user_id: int, candidate_post_ids: List[int]) -> Dict[str, Any]:
        """
        Score candidates with both inner product and cosine similarity.

        Args:
            user_id: User id.
            candidate_post_ids: Candidate post ids.

        Returns:
            Dict: ``lightgcn_inner_product`` and
            ``lightgcn_cosine_similarity`` lists aligned with the input
            order (0.0 for unknown posts). For an unknown user both lists
            are all zeros and a ``message`` string is included.
        """
        self._initialize_model()

        num_candidates = len(candidate_post_ids)

        if user_id not in self.user2idx:
            # Two distinct lists, so mutating one result never alters the other
            return {
                'lightgcn_inner_product': [0.0] * num_candidates,
                'lightgcn_cosine_similarity': [0.0] * num_candidates,
                'message': f'用户 {user_id} 不在训练数据中'
            }

        inner_product_scores = [0.0] * num_candidates
        cosine_similarity_scores = [0.0] * num_candidates

        positions, item_indices = self._known_items(candidate_post_ids)
        if item_indices:
            with torch.no_grad():
                user_emb = self.user_embeddings[self.user2idx[user_id]]
                item_embs = self._item_rows(item_indices)

                # Inner-product score
                inner = torch.matmul(item_embs, user_emb).tolist()
                # Cosine score: the same product on unit-length vectors
                cosine = torch.matmul(
                    self._unit(item_embs), self._unit(user_emb)
                ).tolist()

            for position, inner_score, cosine_score in zip(positions, inner, cosine):
                inner_product_scores[position] = float(inner_score)
                cosine_similarity_scores[position] = float(cosine_score)

        return {
            'lightgcn_inner_product': inner_product_scores,
            'lightgcn_cosine_similarity': cosine_similarity_scores
        }

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return basic information about the loaded LightGCN model.

        Returns:
            Dict: Model type, device, user/item counts, embedding size,
            number of layers, configured paths and the initialised flag.
        """
        self._initialize_model()

        return {
            'model_type': 'LightGCN',
            'device': str(args.device),
            'num_users': self.dataset.num_users,
            'num_items': self.dataset.num_items,
            'embedding_size': self.user_embeddings.shape[1],
            'num_layers': args.num_layers,
            'pretrained_model_path': args.pre_model_path,
            'data_path': args.data_path,
            'initialized': self._initialized
        }
