blob: 325def58ee51be3a82d987e00f21dfb0eb4022d9 [file] [log] [blame]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试基于redbook数据库的推荐系统
"""
import sys
import os
import time
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from app.services.recommendation_service import RecommendationService
from app.utils.graph_build import build_user_post_graph
import pymysql
def test_database_connection():
    """Check that the redbook MySQL database is reachable and print row counts.

    Connects with the configuration defined in this function, then prints the
    number of rows in ``users``, the number of ``posts`` rows whose status is
    'published', and the per-type row counts of ``behaviors``.

    Returns:
        bool: True when the connection and all three queries succeeded,
        False when any exception was raised (the exception is printed and
        not propagated).
    """
    # Local import so this block does not depend on the file's import header.
    from contextlib import closing

    print("=== 测试数据库连接 ===")
    # NOTE(review): host and credentials are hard-coded; consider loading them
    # from configuration or the environment instead.
    db_config = {
        'host': '10.126.59.25',
        'port': 3306,
        'user': 'root',
        'password': '123456',
        'database': 'redbook',
        'charset': 'utf8mb4',
    }
    try:
        # closing() guarantees cursor and connection are released even when a
        # query raises; previously they leaked on every failure path.
        with closing(pymysql.connect(**db_config)) as conn:
            with closing(conn.cursor()) as cursor:
                # Number of users.
                cursor.execute("SELECT COUNT(*) FROM users")
                user_count = cursor.fetchone()[0]
                print(f"用户总数: {user_count}")

                # Number of published posts.
                cursor.execute(
                    "SELECT COUNT(*) FROM posts WHERE status = 'published'"
                )
                post_count = cursor.fetchone()[0]
                print(f"已发布帖子数: {post_count}")

                # Behaviour rows grouped by type.
                cursor.execute(
                    "SELECT type, COUNT(*) FROM behaviors GROUP BY type"
                )
                behavior_stats = cursor.fetchall()
                print("行为统计:")
                for behavior_type, count in behavior_stats:
                    print(f" {behavior_type}: {count}")
        print("数据库连接测试成功!")
        return True
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return False
def test_graph_building():
    """Build the user/post index mappings and print a short preview of each.

    Calls ``build_user_post_graph(return_mapping=True)``, prints the size of
    the two returned mappings and their first five entries.

    Returns:
        bool: True when the mappings were built and printed, False when any
        exception was raised (the exception is printed and not propagated).
    """
    print("\n=== 测试图构建 ===")
    try:
        user_map, post_map = build_user_post_graph(return_mapping=True)
        print(f"用户数量: {len(user_map)}")
        print(f"帖子数量: {len(post_map)}")

        # Preview the first five user -> index pairs.
        print("前5个用户映射:")
        user_preview = list(user_map.items())[:5]
        for uid, position in user_preview:
            print(f" 用户{uid} -> 索引{position}")

        # Preview the first five post -> index pairs.
        print("前5个帖子映射:")
        post_preview = list(post_map.items())[:5]
        for pid, position in post_preview:
            print(f" 帖子{pid} -> 索引{position}")
    except Exception as err:
        print(f"图构建失败: {err}")
        return False

    print("图构建测试成功!")
    return True
def test_cold_start_recommendation():
    """Request recommendations for an unknown user and print the results.

    Creates a ``RecommendationService``, asks it for the top 10
    recommendations for user id 999999 (assumed not to exist in the
    database), reports how long the call took and prints each returned item.

    Returns:
        bool: True when the call and the printing succeeded, False when any
        exception was raised (the exception is printed and not propagated).
    """
    print("\n=== 测试冷启动推荐 ===")
    try:
        service = RecommendationService()
        # A user id that is expected to have no history.
        fake_user_id = 999999

        # perf_counter() is monotonic and high-resolution, so the measured
        # interval cannot be distorted by system clock adjustments the way a
        # time.time() difference can.
        start_time = time.perf_counter()
        recommendations = service.get_recommendations(fake_user_id, topk=10)
        recommendation_time = time.perf_counter() - start_time

        print(f"冷启动推荐耗时: {recommendation_time:.4f} 秒")
        print(f"冷启动推荐结果(用户{fake_user_id}):")
        for rank, rec in enumerate(recommendations, start=1):
            print(f" {rank}. 帖子ID: {rec['id']}, 标题: {rec['title'][:50]}...")
            print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
            print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
        print("冷启动推荐测试成功!")
        return True
    except Exception as e:
        print(f"冷启动推荐失败: {e}")
        return False
def test_user_recommendation():
    """Request recommendations for a user that has recorded behaviours.

    Picks one ``user_id`` from the ``behaviors`` table (connecting with
    ``service.db_config``), prints that user's behaviour counts per type,
    then asks the service for the top 10 recommendations, reports the
    elapsed time and prints each returned item.

    Returns:
        bool: True when recommendations were fetched and printed. False when
        no user with behaviours exists or when any exception was raised (the
        exception is printed and not propagated).
    """
    # Local import so this block does not depend on the file's import header.
    from contextlib import closing

    print("\n=== 测试用户推荐 ===")
    try:
        service = RecommendationService()

        # closing() releases cursor and connection on every path, including
        # query errors; previously they leaked when execute()/fetch raised.
        with closing(pymysql.connect(**service.db_config)) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 1")
                result = cursor.fetchone()
                if not result:
                    print("没有找到有行为记录的用户")
                    return False

                user_id = result[0]
                print(f"测试用户ID: {user_id}")

                # Show the user's historical behaviour, grouped by type.
                cursor.execute(
                    """
                    SELECT b.type, COUNT(*) as count
                    FROM behaviors b
                    WHERE b.user_id = %s
                    GROUP BY b.type
                    """,
                    (user_id,),
                )
                user_behaviors = cursor.fetchall()
                print("用户历史行为:")
                for behavior_type, count in user_behaviors:
                    print(f" {behavior_type}: {count}")

        # Time the recommendation call with a monotonic clock.
        print("开始生成推荐...")
        start_time = time.perf_counter()
        recommendations = service.get_recommendations(user_id, topk=10)
        recommendation_time = time.perf_counter() - start_time

        print(f"用户推荐耗时: {recommendation_time:.4f} 秒")
        print(f"用户推荐结果(用户{user_id}):")
        for rank, rec in enumerate(recommendations, start=1):
            print(f" {rank}. 帖子ID: {rec['id']}, 标题: {rec['title'][:50]}...")
            print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
            print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
            # Items without a model score fall back to showing the heat value.
            if 'recommendation_score' in rec:
                print(f" 推荐分数: {rec['recommendation_score']:.4f}")
            else:
                print(f" 热度分数: {rec['heat']}")
        print("用户推荐测试成功!")
        return True
    except Exception as e:
        print(f"用户推荐失败: {e}")
        return False
def test_recommendation_performance():
    """Benchmark ``get_recommendations`` over several users and rounds.

    Reads up to five distinct ``user_id`` values from ``behaviors``
    (connecting with ``service.db_config``), calls
    ``service.get_recommendations(user_id, topk=10)`` for each of them in
    three rounds, and prints per-call, per-round and overall timing
    statistics (mean, min, max, population standard deviation) followed by a
    rating derived from the overall mean.

    Returns:
        bool: True when the benchmark completed. False when no user with
        behaviours exists or when any exception was raised (the exception is
        printed and not propagated).
    """
    # Local import so this block does not depend on the file's import header.
    from contextlib import closing

    print("\n=== 测试推荐性能 ===")
    try:
        service = RecommendationService()

        # closing() releases cursor and connection even if the query raises.
        with closing(pymysql.connect(**service.db_config)) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 5")
                user_ids = [row[0] for row in cursor.fetchall()]

        if not user_ids:
            print("没有找到有行为记录的用户")
            return False
        print(f"测试用户数量: {len(user_ids)}")

        times = []
        test_rounds = 3  # every user is measured once per round
        for round_num in range(test_rounds):
            print(f"\n第 {round_num + 1} 轮测试:")
            round_times = []
            for user_id in user_ids:
                # Monotonic, high-resolution clock: immune to wall-clock jumps.
                start_time = time.perf_counter()
                recommendations = service.get_recommendations(user_id, topk=10)
                recommendation_time = time.perf_counter() - start_time
                round_times.append(recommendation_time)
                times.append(recommendation_time)
                print(f" 用户 {user_id}: {recommendation_time:.4f}s, 推荐数量: {len(recommendations)}")

            # Statistics for this round only.
            round_avg = sum(round_times) / len(round_times)
            print(f" 本轮平均耗时: {round_avg:.4f}s, 最快: {min(round_times):.4f}s, 最慢: {max(round_times):.4f}s")

        # Overall statistics. The mean is computed once and reused; the
        # original recomputed it for every element of the variance sum.
        avg_time = sum(times) / len(times)
        std_dev = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5
        print("\n=== 性能统计总结 ===")
        print(f"总测试次数: {len(times)}")
        print(f"平均推荐耗时: {avg_time:.4f} 秒")
        print(f"最快推荐耗时: {min(times):.4f} 秒")
        print(f"最慢推荐耗时: {max(times):.4f} 秒")
        print(f"推荐耗时标准差: {std_dev:.4f} 秒")

        # Rate the overall mean latency against fixed thresholds (seconds).
        if avg_time < 0.1:
            performance_level = "优秀"
        elif avg_time < 0.5:
            performance_level = "良好"
        elif avg_time < 1.0:
            performance_level = "一般"
        else:
            performance_level = "需要优化"
        print(f"性能评级: {performance_level}")
        print("推荐性能测试成功!")
        return True
    except Exception as e:
        print(f"推荐性能测试失败: {e}")
        return False
def main():
    """Run every check in order and print a pass/fail summary.

    A check counts as passed when it returns a truthy value. An exception
    escaping a check is printed and counted as a failure; the remaining
    checks still run.
    """
    print("开始测试基于redbook数据库的推荐系统")
    print("=" * 50)

    suite = (
        test_database_connection,
        test_graph_building,
        test_cold_start_recommendation,
        test_user_recommendation,
        test_recommendation_performance,
    )
    success_count = 0
    for case in suite:
        try:
            outcome = case()
        except Exception as err:
            print(f"测试异常: {err}")
            continue
        if outcome:
            success_count += 1

    print("\n" + "=" * 50)
    print(f"测试完成: {success_count}/{len(suite)} 通过")
    if success_count != len(suite):
        print("部分测试失败,请检查配置和代码")
    else:
        print("所有测试通过!")
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()