| #!/usr/bin/env python3 |
| # -*- coding: utf-8 -*- |
| """ |
| 测试基于redbook数据库的推荐系统 |
| """ |
| |
| import sys |
| import os |
| import time |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| |
| from app.services.recommendation_service import RecommendationService |
| from app.utils.graph_build import build_user_post_graph |
| import pymysql |
| |
def test_database_connection():
    """Check that the redbook MySQL database is reachable and print basic counts.

    Connects with the settings defined below, then prints the total row count
    of ``users``, the number of ``posts`` rows whose status is ``'published'``,
    and the per-type row counts of ``behaviors``.

    Returns:
        bool: True when the connection and every query succeeded, False if
        any step raised an exception.
    """
    print("=== 测试数据库连接 ===")
    try:
        # NOTE(review): host and credentials are hard-coded here; consider
        # loading them from configuration or the environment instead.
        db_config = {
            'host': '10.126.59.25',
            'port': 3306,
            'user': 'root',
            'password': '123456',
            'database': 'redbook',
            'charset': 'utf8mb4',
        }
        conn = pymysql.connect(**db_config)
        try:
            # The cursor context manager closes the cursor on exit, and the
            # finally clause closes the connection even when a query fails.
            with conn.cursor() as cursor:
                # Total number of users
                cursor.execute("SELECT COUNT(*) FROM users")
                user_count = cursor.fetchone()[0]
                print(f"用户总数: {user_count}")

                # Number of published posts
                cursor.execute("SELECT COUNT(*) FROM posts WHERE status = 'published'")
                post_count = cursor.fetchone()[0]
                print(f"已发布帖子数: {post_count}")

                # Behavior rows grouped by type
                cursor.execute("SELECT type, COUNT(*) FROM behaviors GROUP BY type")
                behavior_stats = cursor.fetchall()
                print("行为统计:")
                for behavior_type, count in behavior_stats:
                    print(f" {behavior_type}: {count}")
        finally:
            conn.close()

        print("数据库连接测试成功!")
        return True
    except Exception as e:
        # Deliberately broad: any failure simply means this check failed.
        print(f"数据库连接失败: {e}")
        return False
| |
def test_graph_building():
    """Build the user/post graph mappings and print a short preview.

    Calls ``build_user_post_graph(return_mapping=True)``, unpacks the result
    into a user mapping and a post mapping, prints the size of each and the
    first five entries of each.

    Returns:
        bool: True on success, False if any step raised an exception.
    """
    # Local import so the block stays self-contained.
    from itertools import islice

    print("\n=== 测试图构建 ===")
    try:
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        print(f"用户数量: {len(user2idx)}")
        print(f"帖子数量: {len(post2idx)}")

        # islice yields only the first five items instead of copying the
        # whole mapping into a list just to slice it.
        print("前5个用户映射:")
        for user_id, idx in islice(user2idx.items(), 5):
            print(f" 用户{user_id} -> 索引{idx}")

        print("前5个帖子映射:")
        for post_id, idx in islice(post2idx.items(), 5):
            print(f" 帖子{post_id} -> 索引{idx}")

        print("图构建测试成功!")
        return True
    except Exception as e:
        # Deliberately broad: any failure simply means this check failed.
        print(f"图构建失败: {e}")
        return False
| |
def test_cold_start_recommendation():
    """Request recommendations for an unknown user and print the results.

    Creates a ``RecommendationService``, asks it for 10 recommendations for a
    user ID that is assumed not to exist, reports how long the call took and
    prints the fields of every returned item.

    Returns:
        bool: True on success, False if any step raised an exception
        (including a returned item missing one of the required keys).
    """
    print("\n=== 测试冷启动推荐 ===")
    try:
        service = RecommendationService()

        # An ID assumed to be absent from the database, to exercise the
        # cold-start path.
        fake_user_id = 999999

        # perf_counter is monotonic and high resolution, so the measured
        # duration is not affected by system clock adjustments.
        start_time = time.perf_counter()
        recommendations = service.get_recommendations(fake_user_id, topk=10)
        recommendation_time = time.perf_counter() - start_time
        print(f"冷启动推荐耗时: {recommendation_time:.4f} 秒")

        print(f"冷启动推荐结果(用户{fake_user_id}):")
        for rank, rec in enumerate(recommendations, start=1):
            print(f" {rank}. 帖子ID: {rec['id']}, 标题: {rec['title'][:50]}...")
            print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
            print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")

        print("冷启动推荐测试成功!")
        return True
    except Exception as e:
        # Deliberately broad: any failure simply means this check failed.
        print(f"冷启动推荐失败: {e}")
        return False
| |
def test_user_recommendation():
    """Request recommendations for a user that has behavior records.

    Picks one ``user_id`` from the ``behaviors`` table (connecting with
    ``service.db_config``), prints that user's behavior counts per type, then
    asks the service for 10 recommendations, reports how long the call took
    and prints every returned item.

    Returns:
        bool: True on success; False when no user with behavior records is
        found or when any step raised an exception.
    """
    print("\n=== 测试用户推荐 ===")
    try:
        service = RecommendationService()

        conn = pymysql.connect(**service.db_config)
        try:
            # The cursor context manager closes the cursor on exit, and the
            # finally clause closes the connection on every path, including
            # the early return and any query failure.
            with conn.cursor() as cursor:
                cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 1")
                result = cursor.fetchone()

                if not result:
                    print("没有找到有行为记录的用户")
                    return False

                user_id = result[0]
                print(f"测试用户ID: {user_id}")

                # Summarize this user's history per behavior type
                cursor.execute(
                    """
                    SELECT b.type, COUNT(*) as count
                    FROM behaviors b
                    WHERE b.user_id = %s
                    GROUP BY b.type
                    """,
                    (user_id,),
                )
                user_behaviors = cursor.fetchall()
                print("用户历史行为:")
                for behavior_type, count in user_behaviors:
                    print(f" {behavior_type}: {count}")
        finally:
            conn.close()

        # perf_counter is monotonic and high resolution, so the measured
        # duration is not affected by system clock adjustments.
        print("开始生成推荐...")
        start_time = time.perf_counter()
        recommendations = service.get_recommendations(user_id, topk=10)
        recommendation_time = time.perf_counter() - start_time
        print(f"用户推荐耗时: {recommendation_time:.4f} 秒")

        print(f"用户推荐结果(用户{user_id}):")
        for rank, rec in enumerate(recommendations, start=1):
            print(f" {rank}. 帖子ID: {rec['id']}, 标题: {rec['title'][:50]}...")
            print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
            print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
            # Items without a recommendation score fall back to the heat value.
            if 'recommendation_score' in rec:
                print(f" 推荐分数: {rec['recommendation_score']:.4f}")
            else:
                print(f" 热度分数: {rec['heat']}")

        print("用户推荐测试成功!")
        return True

    except Exception as e:
        # Deliberately broad: any failure simply means this check failed.
        print(f"用户推荐失败: {e}")
        return False
| |
def test_recommendation_performance():
    """Time repeated recommendation calls and print latency statistics.

    Reads up to five distinct ``user_id`` values from the ``behaviors`` table
    (connecting with ``service.db_config``), requests 10 recommendations for
    each of them in three rounds, and prints per-round and overall average,
    minimum, maximum and population standard deviation of the call times,
    followed by a rating derived from the overall average.

    Returns:
        bool: True on success; False when no user with behavior records is
        found or when any step raised an exception.
    """
    print("\n=== 测试推荐性能 ===")
    try:
        service = RecommendationService()

        conn = pymysql.connect(**service.db_config)
        try:
            # The cursor context manager closes the cursor on exit, and the
            # finally clause closes the connection even if the query fails.
            with conn.cursor() as cursor:
                cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 5")
                user_ids = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

        if not user_ids:
            print("没有找到有行为记录的用户")
            return False

        print(f"测试用户数量: {len(user_ids)}")

        times = []
        test_rounds = 3  # every user is measured once per round

        for round_num in range(test_rounds):
            print(f"\n第 {round_num + 1} 轮测试:")
            round_times = []

            for user_id in user_ids:
                # perf_counter is monotonic and high resolution, so the
                # measurement is not affected by system clock adjustments.
                start_time = time.perf_counter()
                recommendations = service.get_recommendations(user_id, topk=10)
                recommendation_time = time.perf_counter() - start_time

                round_times.append(recommendation_time)
                times.append(recommendation_time)

                print(f" 用户 {user_id}: {recommendation_time:.4f}s, 推荐数量: {len(recommendations)}")

            # Statistics for this round only
            avg_time = sum(round_times) / len(round_times)
            min_time = min(round_times)
            max_time = max(round_times)
            print(f" 本轮平均耗时: {avg_time:.4f}s, 最快: {min_time:.4f}s, 最慢: {max_time:.4f}s")

        # Overall statistics; the mean is computed once and reused rather
        # than being recomputed for every sample.
        avg_time = sum(times) / len(times)
        std_dev = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5

        print("\n=== 性能统计总结 ===")
        print(f"总测试次数: {len(times)}")
        print(f"平均推荐耗时: {avg_time:.4f} 秒")
        print(f"最快推荐耗时: {min(times):.4f} 秒")
        print(f"最慢推荐耗时: {max(times):.4f} 秒")
        print(f"推荐耗时标准差: {std_dev:.4f} 秒")

        # Rating thresholds are in seconds of average latency.
        if avg_time < 0.1:
            performance_level = "优秀"
        elif avg_time < 0.5:
            performance_level = "良好"
        elif avg_time < 1.0:
            performance_level = "一般"
        else:
            performance_level = "需要优化"

        print(f"性能评级: {performance_level}")

        print("推荐性能测试成功!")
        return True

    except Exception as e:
        # Deliberately broad: any failure simply means this check failed.
        print(f"推荐性能测试失败: {e}")
        return False
| |
def main():
    """Run every test function in order and print a pass/fail summary.

    A test counts as passed when it returns a truthy value. An exception
    escaping a test is reported and counted as a failure, and the remaining
    tests still run.

    Returns:
        int: 0 when all tests passed, 1 otherwise, so the value can be used
        directly as the process exit status.
    """
    print("开始测试基于redbook数据库的推荐系统")
    print("=" * 50)

    tests = [
        test_database_connection,
        test_graph_building,
        test_cold_start_recommendation,
        test_user_recommendation,
        test_recommendation_performance,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            # Keep going so one crashing test does not hide the others.
            print(f"测试异常: {e}")

    print("\n" + "=" * 50)
    print(f"测试完成: {passed}/{total} 通过")

    if passed == total:
        print("所有测试通过!")
        return 0

    print("部分测试失败,请检查配置和代码")
    return 1


if __name__ == "__main__":
    # Propagate the result so shells and CI can detect failed runs.
    sys.exit(main())