#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试基于redbook数据库的推荐系统
"""

import sys
import os
import time
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from app.services.recommendation_service import RecommendationService
from app.utils.graph_build import build_user_post_graph
import pymysql

def test_database_connection():
    """测试数据库连接"""
    print("=== 测试数据库连接 ===")
    try:
        db_config = {
            'host': '10.126.59.25',
            'port': 3306,
            'user': 'root',
            'password': '123456',
            'database': 'redbook',
            'charset': 'utf8mb4'
        }
        conn = pymysql.connect(**db_config)
        cursor = conn.cursor()
        
        # 检查用户数量
        cursor.execute("SELECT COUNT(*) FROM users")
        user_count = cursor.fetchone()[0]
        print(f"用户总数: {user_count}")
        
        # 检查帖子数量
        cursor.execute("SELECT COUNT(*) FROM posts WHERE status = 'published'")
        post_count = cursor.fetchone()[0]
        print(f"已发布帖子数: {post_count}")
        
        # 检查行为数据
        cursor.execute("SELECT type, COUNT(*) FROM behaviors GROUP BY type")
        behavior_stats = cursor.fetchall()
        print("行为统计:")
        for behavior_type, count in behavior_stats:
            print(f"  {behavior_type}: {count}")
            
        cursor.close()
        conn.close()
        print("数据库连接测试成功!")
        return True
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return False

def test_graph_building():
    """测试图构建"""
    print("\n=== 测试图构建 ===")
    try:
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        print(f"用户数量: {len(user2idx)}")
        print(f"帖子数量: {len(post2idx)}")
        
        # 显示前几个用户和帖子的映射
        print("前5个用户映射:")
        for i, (user_id, idx) in enumerate(list(user2idx.items())[:5]):
            print(f"  用户{user_id} -> 索引{idx}")
            
        print("前5个帖子映射:")
        for i, (post_id, idx) in enumerate(list(post2idx.items())[:5]):
            print(f"  帖子{post_id} -> 索引{idx}")
            
        print("图构建测试成功!")
        return True
    except Exception as e:
        print(f"图构建失败: {e}")
        return False

def test_cold_start_recommendation():
    """测试冷启动推荐"""
    print("\n=== 测试冷启动推荐 ===")
    try:
        service = RecommendationService()
        
        # 使用一个不存在的用户ID进行冷启动测试
        fake_user_id = 999999
        
        # 计时开始
        start_time = time.time()
        recommendations = service.get_recommendations(fake_user_id, topk=10)
        end_time = time.time()
        
        # 计算推荐耗时
        recommendation_time = end_time - start_time
        print(f"冷启动推荐耗时: {recommendation_time:.4f} 秒")
        
        print(f"冷启动推荐结果(用户{fake_user_id}):")
        for i, rec in enumerate(recommendations):
            print(f"  {i+1}. 帖子ID: {rec['post_id']}, 标题: {rec['title'][:50]}...")
            print(f"     作者: {rec['username']}, 热度: {rec['heat']}")
            print(f"     点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
            
        print("冷启动推荐测试成功!")
        return True
    except Exception as e:
        print(f"冷启动推荐失败: {e}")
        return False

def test_user_recommendation():
    """测试用户推荐"""
    print("\n=== 测试用户推荐 ===")
    try:
        service = RecommendationService()
        
        # 获取一个真实用户ID
        db_config = service.db_config
        conn = pymysql.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 1")
        result = cursor.fetchone()
        
        if result:
            user_id = result[0]
            print(f"测试用户ID: {user_id}")
            
            # 查看用户的历史行为
            cursor.execute("""
                SELECT b.type, COUNT(*) as count
                FROM behaviors b 
                WHERE b.user_id = %s 
                GROUP BY b.type
            """, (user_id,))
            user_behaviors = cursor.fetchall()
            print("用户历史行为:")
            for behavior_type, count in user_behaviors:
                print(f"  {behavior_type}: {count}")
            
            cursor.close()
            conn.close()
            
            # 尝试获取推荐 - 添加计时
            print("开始生成推荐...")
            start_time = time.time()
            recommendations = service.get_recommendations(user_id, topk=10)
            end_time = time.time()
            
            # 计算推荐耗时
            recommendation_time = end_time - start_time
            print(f"用户推荐耗时: {recommendation_time:.4f} 秒")
            
            print(f"用户推荐结果(用户{user_id}):")
            for i, rec in enumerate(recommendations):
                print(f"  {i+1}. 帖子ID: {rec['post_id']}, 标题: {rec['title'][:50]}...")
                print(f"     作者: {rec['username']}, 热度: {rec['heat']}")
                print(f"     点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
                if 'recommendation_score' in rec:
                    print(f"     推荐分数: {rec['recommendation_score']:.4f}")
                else:
                    print(f"     热度分数: {rec['heat']}")
                
            print("用户推荐测试成功!")
            return True
        else:
            print("没有找到有行为记录的用户")
            cursor.close()
            conn.close()
            return False
            
    except Exception as e:
        print(f"用户推荐失败: {e}")
        return False

def test_recommendation_performance():
    """测试推荐性能 - 多次调用统计"""
    print("\n=== 测试推荐性能 ===")
    try:
        service = RecommendationService()
        
        # 获取几个真实用户ID进行测试
        db_config = service.db_config
        conn = pymysql.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 5")
        user_ids = [row[0] for row in cursor.fetchall()]
        cursor.close()
        conn.close()
        
        if not user_ids:
            print("没有找到有行为记录的用户")
            return False
        
        print(f"测试用户数量: {len(user_ids)}")
        
        # 进行多次推荐测试
        times = []
        test_rounds = 3  # 每个用户测试3轮
        
        for round_num in range(test_rounds):
            print(f"\n第 {round_num + 1} 轮测试:")
            round_times = []
            
            for i, user_id in enumerate(user_ids):
                start_time = time.time()
                recommendations = service.get_recommendations(user_id, topk=10)
                end_time = time.time()
                
                recommendation_time = end_time - start_time
                round_times.append(recommendation_time)
                times.append(recommendation_time)
                
                print(f"  用户 {user_id}: {recommendation_time:.4f}s, 推荐数量: {len(recommendations)}")
            
            # 计算本轮统计
            avg_time = sum(round_times) / len(round_times)
            min_time = min(round_times)
            max_time = max(round_times)
            print(f"  本轮平均耗时: {avg_time:.4f}s, 最快: {min_time:.4f}s, 最慢: {max_time:.4f}s")
        
        # 计算总体统计
        print(f"\n=== 性能统计总结 ===")
        print(f"总测试次数: {len(times)}")
        print(f"平均推荐耗时: {sum(times) / len(times):.4f} 秒")
        print(f"最快推荐耗时: {min(times):.4f} 秒")
        print(f"最慢推荐耗时: {max(times):.4f} 秒")
        print(f"推荐耗时标准差: {(sum([(t - sum(times)/len(times))**2 for t in times]) / len(times))**0.5:.4f} 秒")
        
        # 性能等级评估
        avg_time = sum(times) / len(times)
        if avg_time < 0.1:
            performance_level = "优秀"
        elif avg_time < 0.5:
            performance_level = "良好"
        elif avg_time < 1.0:
            performance_level = "一般"
        else:
            performance_level = "需要优化"
        
        print(f"性能评级: {performance_level}")
        
        print("推荐性能测试成功!")
        return True
        
    except Exception as e:
        print(f"推荐性能测试失败: {e}")
        return False

def main():
    """主测试函数"""
    print("开始测试基于redbook数据库的推荐系统")
    print("=" * 50)
    
    tests = [
        test_database_connection,
        test_graph_building,
        test_cold_start_recommendation,
        test_user_recommendation,
        test_recommendation_performance
    ]
    
    passed = 0
    total = len(tests)
    
    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            print(f"测试异常: {e}")
    
    print("\n" + "=" * 50)
    print(f"测试完成: {passed}/{total} 通过")
    
    if passed == total:
        print("所有测试通过!")
    else:
        print("部分测试失败，请检查配置和代码")

if __name__ == "__main__":
    main()
