Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame^] | 1 | #!/usr/bin/env python3 |
| 2 | # -*- coding: utf-8 -*- |
| 3 | """ |
| 4 | 测试基于redbook数据库的推荐系统 |
| 5 | """ |
| 6 | |
| 7 | import sys |
| 8 | import os |
| 9 | import time |
| 10 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| 11 | |
| 12 | from app.services.recommendation_service import RecommendationService |
| 13 | from app.utils.graph_build import build_user_post_graph |
| 14 | import pymysql |
| 15 | |
def test_database_connection():
    """Check that the redbook MySQL database is reachable and print basic counts.

    Connects with the configuration defined in this function and prints the
    row count of ``users``, the count of ``posts`` whose status is
    ``'published'``, and the per-type row counts of ``behaviors``.

    Returns:
        bool: True when the connection and every query succeeded; False when
        any exception was raised (the exception is printed, not propagated).
    """
    print("=== 测试数据库连接 ===")
    # NOTE(review): the credentials below are hard-coded in source; consider
    # loading them from configuration or the environment instead.
    db_config = {
        'host': '10.126.59.25',
        'port': 3306,
        'user': 'root',
        'password': '123456',
        'database': 'redbook',
        'charset': 'utf8mb4'
    }
    try:
        conn = pymysql.connect(**db_config)
        # Always release the connection, even when one of the queries fails.
        try:
            with conn.cursor() as cursor:
                # Total number of users.
                cursor.execute("SELECT COUNT(*) FROM users")
                user_count = cursor.fetchone()[0]
                print(f"用户总数: {user_count}")

                # Number of published posts.
                cursor.execute("SELECT COUNT(*) FROM posts WHERE status = 'published'")
                post_count = cursor.fetchone()[0]
                print(f"已发布帖子数: {post_count}")

                # Behavior rows grouped by type.
                cursor.execute("SELECT type, COUNT(*) FROM behaviors GROUP BY type")
                behavior_stats = cursor.fetchall()
                print("行为统计:")
                for behavior_type, count in behavior_stats:
                    print(f" {behavior_type}: {count}")
        finally:
            conn.close()

        print("数据库连接测试成功!")
        return True
    except Exception as e:
        # Top-level boundary of this check: report the failure as a result.
        print(f"数据库连接失败: {e}")
        return False
| 55 | |
def test_graph_building():
    """Build the user-post graph mappings and print a short preview.

    Calls ``build_user_post_graph(return_mapping=True)``, prints the size of
    both returned mappings, then prints the first five entries of each.

    Returns:
        bool: True when the mappings were built and printed; False when any
        exception was raised (the exception is printed, not propagated).
    """
    print("\n=== 测试图构建 ===")
    try:
        user2idx, post2idx = build_user_post_graph(return_mapping=True)
        print(f"用户数量: {len(user2idx)}")
        print(f"帖子数量: {len(post2idx)}")

        # Preview the first few user -> index entries.
        print("前5个用户映射:")
        user_preview = list(user2idx.items())[:5]
        for user_id, idx in user_preview:
            print(f" 用户{user_id} -> 索引{idx}")

        # Preview the first few post -> index entries.
        print("前5个帖子映射:")
        post_preview = list(post2idx.items())[:5]
        for post_id, idx in post_preview:
            print(f" 帖子{post_id} -> 索引{idx}")

        print("图构建测试成功!")
        return True
    except Exception as e:
        print(f"图构建失败: {e}")
        return False
| 78 | |
def test_cold_start_recommendation():
    """Request recommendations for an unknown user and print the results.

    Asks ``RecommendationService.get_recommendations`` for the top 10 items
    for user id 999999 (intended to be an id with no record, so that the
    cold-start path is exercised), prints how long the call took, and prints
    each returned item.

    Each recommendation is read as a mapping with the keys ``post_id``,
    ``title``, ``username`` and ``heat``; ``like_count`` and
    ``comment_count`` are optional and default to 0 in the output.

    Returns:
        bool: True when the call and the printing succeeded; False when any
        exception was raised (the exception is printed, not propagated).
    """
    print("\n=== 测试冷启动推荐 ===")
    try:
        service = RecommendationService()

        # A user id that is not expected to exist, to trigger cold start.
        fake_user_id = 999999

        # perf_counter is a monotonic, high-resolution clock; the wall clock
        # (time.time) may be adjusted while the call is running.
        start_time = time.perf_counter()
        recommendations = service.get_recommendations(fake_user_id, topk=10)
        recommendation_time = time.perf_counter() - start_time
        print(f"冷启动推荐耗时: {recommendation_time:.4f} 秒")

        print(f"冷启动推荐结果(用户{fake_user_id}):")
        for i, rec in enumerate(recommendations):
            print(f" {i+1}. 帖子ID: {rec['post_id']}, 标题: {rec['title'][:50]}...")
            print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
            print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")

        print("冷启动推荐测试成功!")
        return True
    except Exception as e:
        print(f"冷启动推荐失败: {e}")
        return False
| 108 | |
def test_user_recommendation():
    """Request recommendations for a user that has behavior records.

    Picks one ``user_id`` from the ``behaviors`` table (connecting with
    ``service.db_config``), prints that user's behavior counts per type,
    then asks the service for the top 10 recommendations, prints the elapsed
    time and prints every returned item. When an item carries a
    ``recommendation_score`` it is printed; otherwise its ``heat`` value is
    printed in its place.

    Returns:
        bool: True when recommendations were fetched and printed; False when
        no user with behavior records exists or when any exception was
        raised (the exception is printed, not propagated).
    """
    print("\n=== 测试用户推荐 ===")
    try:
        service = RecommendationService()

        # Look up a real user id and that user's behavior history.
        conn = pymysql.connect(**service.db_config)
        # A single finally releases the connection on every path, including
        # the early return below and a failing query.
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 1")
                result = cursor.fetchone()

                if not result:
                    print("没有找到有行为记录的用户")
                    return False

                user_id = result[0]
                print(f"测试用户ID: {user_id}")

                # Behavior counts of this user, grouped by type.
                cursor.execute("""
                    SELECT b.type, COUNT(*) as count
                    FROM behaviors b
                    WHERE b.user_id = %s
                    GROUP BY b.type
                """, (user_id,))
                user_behaviors = cursor.fetchall()
                print("用户历史行为:")
                for behavior_type, count in user_behaviors:
                    print(f" {behavior_type}: {count}")
        finally:
            conn.close()

        # Time the recommendation call with a monotonic clock.
        print("开始生成推荐...")
        start_time = time.perf_counter()
        recommendations = service.get_recommendations(user_id, topk=10)
        recommendation_time = time.perf_counter() - start_time
        print(f"用户推荐耗时: {recommendation_time:.4f} 秒")

        print(f"用户推荐结果(用户{user_id}):")
        for i, rec in enumerate(recommendations):
            print(f" {i+1}. 帖子ID: {rec['post_id']}, 标题: {rec['title'][:50]}...")
            print(f" 作者: {rec['username']}, 热度: {rec['heat']}")
            print(f" 点赞: {rec.get('like_count', 0)}, 评论: {rec.get('comment_count', 0)}")
            if 'recommendation_score' in rec:
                print(f" 推荐分数: {rec['recommendation_score']:.4f}")
            else:
                print(f" 热度分数: {rec['heat']}")

        print("用户推荐测试成功!")
        return True

    except Exception as e:
        print(f"用户推荐失败: {e}")
        return False
| 172 | |
def test_recommendation_performance():
    """Measure recommendation latency over several users and rounds.

    Reads up to five distinct ``user_id`` values from ``behaviors``
    (connecting with ``service.db_config``), calls
    ``service.get_recommendations(user_id, topk=10)`` for each of them in
    three rounds, and prints per-call, per-round and overall statistics
    (count, mean, min, max, population standard deviation) followed by a
    rating derived from the overall mean latency.

    Returns:
        bool: True when the measurement completed; False when no user with
        behavior records exists or when any exception was raised (the
        exception is printed, not propagated).
    """
    print("\n=== 测试推荐性能 ===")
    try:
        service = RecommendationService()

        # Collect a few real user ids; release the connection even if the
        # query fails.
        conn = pymysql.connect(**service.db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT DISTINCT user_id FROM behaviors LIMIT 5")
                user_ids = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

        if not user_ids:
            print("没有找到有行为记录的用户")
            return False

        print(f"测试用户数量: {len(user_ids)}")

        times = []
        test_rounds = 3  # every user is measured once per round

        for round_num in range(test_rounds):
            print(f"\n第 {round_num + 1} 轮测试:")
            round_times = []

            for user_id in user_ids:
                # Monotonic, high-resolution clock for interval measurement.
                start_time = time.perf_counter()
                recommendations = service.get_recommendations(user_id, topk=10)
                recommendation_time = time.perf_counter() - start_time

                round_times.append(recommendation_time)
                times.append(recommendation_time)

                print(f" 用户 {user_id}: {recommendation_time:.4f}s, 推荐数量: {len(recommendations)}")

            # Statistics of this round.
            round_avg = sum(round_times) / len(round_times)
            min_time = min(round_times)
            max_time = max(round_times)
            print(f" 本轮平均耗时: {round_avg:.4f}s, 最快: {min_time:.4f}s, 最慢: {max_time:.4f}s")

        # Overall statistics. The mean is computed once and reused, instead
        # of being recomputed for every sample inside the variance sum.
        avg_time = sum(times) / len(times)
        variance = sum((t - avg_time) ** 2 for t in times) / len(times)
        std_dev = variance ** 0.5

        print("\n=== 性能统计总结 ===")
        print(f"总测试次数: {len(times)}")
        print(f"平均推荐耗时: {avg_time:.4f} 秒")
        print(f"最快推荐耗时: {min(times):.4f} 秒")
        print(f"最慢推荐耗时: {max(times):.4f} 秒")
        print(f"推荐耗时标准差: {std_dev:.4f} 秒")

        # Rate the overall mean latency against fixed thresholds (seconds).
        if avg_time < 0.1:
            performance_level = "优秀"
        elif avg_time < 0.5:
            performance_level = "良好"
        elif avg_time < 1.0:
            performance_level = "一般"
        else:
            performance_level = "需要优化"

        print(f"性能评级: {performance_level}")

        print("推荐性能测试成功!")
        return True

    except Exception as e:
        print(f"推荐性能测试失败: {e}")
        return False
| 246 | |
def main():
    """Run every check in order and print a pass/fail summary.

    A check counts as passed when it returns a truthy value. An exception
    escaping a check is printed and counted as a failure, so the remaining
    checks still run.

    Returns:
        int: 0 when every check passed, 1 otherwise, so the value can be
        used directly as the process exit status.
    """
    print("开始测试基于redbook数据库的推荐系统")
    print("=" * 50)

    tests = [
        test_database_connection,
        test_graph_building,
        test_cold_start_recommendation,
        test_user_recommendation,
        test_recommendation_performance,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            # Keep going: one crashing check must not hide the others.
            print(f"测试异常: {e}")

    print("\n" + "=" * 50)
    print(f"测试完成: {passed}/{total} 通过")

    if passed == total:
        print("所有测试通过!")
        return 0

    print("部分测试失败,请检查配置和代码")
    return 1


if __name__ == "__main__":
    # Propagate the result so shells and CI can detect a failed run.
    sys.exit(main())