22301110 | f2e3c09 | 2025-06-05 01:24:43 +0800 | [diff] [blame] | 1 | import os |
| 2 | import time |
| 3 | import jieba |
| 4 | import fasttext |
| 5 | import pandas as pd |
| 6 | from flask import Flask, request, jsonify |
| 7 | from sqlalchemy import create_engine |
| 8 | from scipy.sparse import coo_matrix |
| 9 | from sklearn.metrics.pairwise import cosine_similarity |
| 10 | import pickle |
| 11 | |
app = Flask(__name__)

# === ✅ SQLAlchemy database connection ===
# NOTE(review): credentials and host are hard-coded in source — move to
# environment variables / config before deploying; verify the DB is not
# publicly reachable.
engine = create_engine("mysql+pymysql://sy:sy_password@49.233.215.144:3306/pt_station")
| 16 | |
22301110 | f2e3c09 | 2025-06-05 01:24:43 +0800 | [diff] [blame] | 17 | |
| 18 | # === ✅ 用户标签行为矩阵构建 === |
def get_user_tag_matrix():
    """Load the user_tag_scores table and derive matrix views of it.

    Returns a 5-tuple:
      df            -- raw rows with user_id coerced to str, plus derived
                       user_index / tag_index integer columns
      matrix        -- dense pivot table: rows=user_id (str), cols=tag,
                       values=score, missing cells filled with 0
                       (pandas' default aggfunc 'mean' applies if a
                       (user, tag) pair appears more than once)
      sparse_matrix -- scipy COO matrix built as (tag_index, user_index),
                       i.e. oriented tags x users, NOT users x tags
      user_map      -- user_id (str) -> integer index
      tag_map       -- tag -> integer index
    """
    df = pd.read_sql("SELECT user_id, tag, score FROM user_tag_scores", engine)
    #print(df)
    # user_id may come back from the driver as a numeric type; downstream
    # lookups (matrix.index membership tests) expect strings.
    df['user_id'] = df['user_id'].astype(str)
    user_map = {u: i for i, u in enumerate(df['user_id'].unique())}
    tag_map = {t: i for i, t in enumerate(df['tag'].unique())}
    df['user_index'] = df['user_id'].map(user_map)
    df['tag_index'] = df['tag'].map(tag_map)
    matrix = df.pivot_table(index='user_id', columns='tag', values='score', fill_value=0)
    # COO rows are tags, columns are users (shape: n_tags x n_users).
    sparse_matrix = coo_matrix((df['score'], (df['tag_index'], df['user_index'])))
    return df, matrix, sparse_matrix, user_map, tag_map
| 30 | |
| 31 | # === ✅ 基于 fastText 的语义相似推荐方法 === |
def semantic_recommend(user_id, topn=5):
    """Recommend up to ``topn`` tags the user has not scored yet, ranked by
    average fastText cosine similarity to the user's top-3 highest-scored tags.

    :param user_id: user identifier (coerced to str to match the DB column)
    :param topn: maximum number of tags to return
    :return: list of recommended tag strings (empty if the user has no tags)
    """
    print(f"正在为用户 {user_id} 生成推荐...")

    df = pd.read_sql("SELECT user_id, tag, score FROM user_tag_scores", engine)

    # Normalize types so the equality filter below works regardless of how
    # the driver returned user_id.
    df['user_id'] = df['user_id'].astype(str)
    user_id = str(user_id)

    user_tags = df[df['user_id'] == user_id].sort_values(by="score", ascending=False)['tag'].tolist()
    print(f"用户 {user_id} 的标签(按分数排序): {user_tags}")

    if not user_tags:
        print(f"用户 {user_id} 没有标签记录,返回空推荐结果。")
        return []

    # Keep only the top-3 tags as the user's interest profile.
    # (Fixed: the original performed the empty check and this truncation
    # twice; the second copies were redundant/unreachable and are removed.)
    user_tags = user_tags[:3]
    print(f"用户 {user_id} 的 Top 3 标签: {user_tags}")

    all_tags = df['tag'].unique()
    print(f"所有唯一标签数量: {len(all_tags)}")

    # Pre-compute the user-tag vectors once instead of re-querying the model
    # for every (candidate, user-tag) pair as the original did.
    user_vectors = [ft_model.get_word_vector(t) for t in user_tags]

    # Score every unseen tag by its average cosine similarity to the profile.
    scores = {}
    for tag in all_tags:
        if tag in user_tags:
            continue
        vec = ft_model.get_word_vector(tag)
        sims = [cosine_similarity([vec], [uv])[0][0] for uv in user_vectors]
        scores[tag] = sum(sims) / len(sims)

    sorted_tags = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:topn]
    print(f"\n最终推荐标签(前 {topn}):")

    return [tag for tag, _ in sorted_tags]
| 92 | |
| 93 | # === ✅ ItemCF 推荐方法 === |
| 94 | import os |
| 95 | import pickle |
| 96 | |
def itemcf_recommend(user_id, matrix, sim_path="./models/itemcf_sim.pkl", topn=5):
    """Item-based collaborative filtering over the user x tag score matrix.

    For every tag the user has scored, accumulate similarity-weighted scores
    for the tags the user does NOT yet have, and return the best ``topn``.

    :param user_id: user identifier (coerced to str; matrix.index holds strings)
    :param matrix: DataFrame indexed by user_id with one column per tag
    :param sim_path: pickle file holding the tag-tag cosine-similarity DataFrame
    :param topn: number of tags to return
    :return: list of recommended tag names (may be empty)
    """
    user_id = str(user_id)  # keep the lookup type consistent with the index

    if user_id not in matrix.index:
        print(f"⚠️ 用户 {user_id} 不在评分矩阵中。")
        return []

    if not os.path.exists(sim_path):
        # Fixed: the original printed the misleading "user not in matrix"
        # message here; the real condition is a missing similarity file.
        # Also pass sim_path through so training writes the file we then read.
        print(f"⚠️ 相似度矩阵文件不存在: {sim_path},开始重新训练。")
        train_and_save_itemcf(sim_path)

    with open(sim_path, "rb") as f:
        sim_df = pickle.load(f)

    user_row = matrix.loc[user_id]
    user_tags = user_row[user_row > 0]

    if user_tags.empty:
        print(f"⚠️ 用户 {user_id} 没有任何标签评分记录。")
        return []

    # Weighted sum: score(candidate) += sim(candidate, owned_tag) * owned_score.
    # Tags the user already owns are excluded via the drop() below.
    scores = {}
    for tag, val in user_tags.items():
        if tag not in sim_df:
            print(f"标签 {tag} 在相似度矩阵中不存在,跳过。")
            continue
        sims = sim_df[tag].drop(index=user_tags.index, errors="ignore")
        for sim_tag, sim_score in sims.items():
            scores[sim_tag] = scores.get(sim_tag, 0) + sim_score * val

    if not scores:
        print(f"⚠️ 用户 {user_id} 无法生成推荐,可能是标签相似度不足。")
        return []

    sorted_tags = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    print(f"推荐得分(前{topn}):\n", sorted_tags[:topn])

    return [tag for tag, _ in sorted_tags[:topn]]
| 137 | |
| 138 | |
| 139 | # === ✅ ItemCF 相似度训练 === |
def train_and_save_itemcf(path="./models/itemcf_sim.pkl"):
    """Compute the tag-tag cosine-similarity matrix from the current
    user x tag score matrix and pickle it to ``path``.
    """
    _, matrix, _, _, _ = get_user_tag_matrix()
    # Tags are columns, so tag-tag similarity = column similarity -> transpose.
    tag_sim = cosine_similarity(matrix.T)
    sim_df = pd.DataFrame(tag_sim, index=matrix.columns, columns=matrix.columns)
    # Fixed: ensure the target directory exists — the lazy-training path in
    # itemcf_recommend can run before ./models has ever been created, and
    # open(..., "wb") would then raise FileNotFoundError.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(sim_df, f)
    print("ItemCF 相似度矩阵已保存 ✅")
| 147 | |
| 148 | # === ✅ Flask 推荐接口 === |
| 149 | import random |
| 150 | |
@app.route("/recommend_torrents", methods=["POST"])
def recommend_torrents():
    """Recommend torrent ids for a user.

    Merges ItemCF and semantic tag recommendations, back-fills with the rest
    of the tag universe in random order, then maps tags to torrent ids.
    Responds with {"torrent_ids": [...]}.
    """
    data = request.get_json()
    user_id = data.get("user_id")

    if not user_id:
        return jsonify({"error": "缺少 user_id"}), 400

    df, matrix, _, _, _ = get_user_tag_matrix()

    # Candidate tags from both recommenders.
    itemcf_result = itemcf_recommend(user_id, matrix)
    semantic_result = semantic_recommend(user_id)

    print(f"ItemCF 推荐标签: {itemcf_result}")
    print(f"Semantic 推荐标签: {semantic_result}")

    all_tags = df['tag'].unique().tolist()

    combined = []    # (tag, source, random priority) triples
    used_tags = set()

    def add_unique_tags(tags, method_name):
        # Fixed: the original if/elif appended only for 'ItemCF'/'Semantic',
        # so tags passed with method 'Random' were marked used but never
        # appended — the "add random tags" step below was a silent no-op.
        for tag in tags:
            if tag not in used_tags:
                combined.append((tag, method_name, random.uniform(0, 1)))
                used_tags.add(tag)

    add_unique_tags(itemcf_result, 'ItemCF')
    add_unique_tags(semantic_result, 'Semantic')

    # Back-fill with every remaining tag in random order.
    random.shuffle(all_tags)
    add_unique_tags(all_tags, 'Random')

    # Sort by the random priority — the final tag order is intentionally
    # shuffled across sources.
    combined.sort(key=lambda x: x[2], reverse=True)

    final_tags = [tag for tag, _, _ in combined]
    print(f"最终推荐标签: {final_tags}")
    torrent_ids = get_torrent_ids_by_tags(final_tags)

    return jsonify({"torrent_ids": torrent_ids})
| 202 | |
| 203 | |
| 204 | |
| 205 | from sqlalchemy.sql import text |
| 206 | |
| 207 | import random |
| 208 | from sqlalchemy import text |
| 209 | |
def get_torrent_ids_by_tags(tags, limit_per_tag=10):
    """Map recommended tags to torrent ids.

    Returns every torrent id in the database: ids matching the given tags
    first (shuffled among themselves), then all remaining ids (also shuffled),
    so recommended torrents lead the list but the overall set is complete.
    """
    tags = tags if tags else []
    print(f"传递给 get_torrent_ids_by_tags 的标签: {tags}")

    recommended_ids = set()
    tag_query = text("""
        SELECT torrent_id
        FROM bt_torrent_tags
        WHERE tag = :tag
        LIMIT :limit
    """)

    with engine.connect() as conn:
        # Collect up to limit_per_tag torrent ids for each recommended tag.
        for tag in tags:
            rows = conn.execute(tag_query, {"tag": tag, "limit": limit_per_tag})
            print(f"标签 '{tag}' 的推荐结果:")
            for row in rows:
                print(row[0])
                recommended_ids.add(row[0])

        # Full id universe, used to back-fill the non-recommended torrents.
        all_rows = conn.execute(text("SELECT DISTINCT torrent_id FROM bt_torrent"))
        all_ids = {row[0] for row in all_rows}
        print("数据库中所有torrent_id:", all_ids)

    remaining_ids = all_ids - recommended_ids
    print(remaining_ids)

    # Shuffle each group independently; recommended ids stay in front.
    recommended_list = list(recommended_ids)
    remaining_list = list(remaining_ids)
    random.shuffle(recommended_list)
    random.shuffle(remaining_list)

    return recommended_list + remaining_list
| 246 | |
| 247 | |
22301110 | 9dfbcee | 2025-06-09 22:27:01 +0800 | [diff] [blame] | 248 | import os |
| 249 | import time |
| 250 | import jieba |
| 251 | import fasttext |
| 252 | from flask import Flask, request, jsonify |
| 253 | import mysql.connector |
| 254 | |
| 255 | # ✅ 初始化数据库连接 |
| 256 | import mysql.connector |
| 257 | from sklearn.metrics.pairwise import cosine_similarity |
| 258 | import numpy as np |
| 259 | |
def compute_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors as a plain float.

    Replaces the per-pair call into sklearn's pairwise machinery (which
    validates and wraps both vectors into 1x1 matrices on every call — this
    function runs inside nested loops) with a direct numpy computation.
    Returns 0.0 when either vector has zero norm, matching sklearn's
    zero-vector behavior.
    """
    v1 = np.asarray(vec1, dtype=float)
    v2 = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0.0:
        return 0.0
    return float(np.dot(v1, v2) / denom)
| 264 | |
# ✅ MySQL connection used by the tag-expansion endpoints below.
# NOTE(review): credentials are hard-coded (same concern as the SQLAlchemy
# engine above) — move to env/config. Also, this connection and cursor are
# shared module-level globals; mysql.connector connections are generally not
# safe for concurrent use across server threads — confirm against the
# threaded waitress deployment at the bottom of the file.
db = mysql.connector.connect(
    host="49.233.215.144",
    port=3306,
    user="sy",
    password="sy_password",
    database="pt_station"
)

# Single shared cursor for all queries in this module.
cursor = db.cursor()
| 274 | |
| 275 | |
# Fixed: the original re-created the Flask app here (`app = Flask(__name__)`),
# which discarded the /recommend_torrents route already registered on the app
# created at the top of the file. All routes now register on that single app.

# ✅ Load the fastText model (required by both recommendation endpoints).
fasttext_model_path = './models/cc.zh.300.bin'
if not os.path.exists(fasttext_model_path):
    raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")

print("加载 fastText 模型中...")
ft_model = fasttext.load_model(fasttext_model_path)
print("模型加载完成 ✅")
# Vocabulary snapshot; loaded once since ft_model.words is expensive.
fasttext_vocab = set(ft_model.words)
| 287 | |
# ✅ Global cache of tags already present in the DB (avoids querying on
# every request).
existing_tags = set()

def refresh_existing_tags():
    """Reload the set of distinct tags from user_tag_scores into the
    module-level cache."""
    global existing_tags
    cursor.execute("SELECT DISTINCT tag FROM user_tag_scores")
    existing_tags = {row[0] for row in cursor.fetchall()}
    print(f"已加载标签数: {len(existing_tags)}")

# ✅ Populate the cache once at startup.
refresh_existing_tags()
| 300 | |
| 301 | # ✅ 扩展函数:仅保留数据库已有标签 |
def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7):
    """Expand a list of input tags to semantically similar known tags.

    Only tags already present in the database cache (``existing_tags``) are
    candidates. A candidate qualifies when its cosine similarity to ANY input
    tag reaches ``similarity_threshold``; its score is the maximum similarity
    over all input tags.

    :param input_tags: input tag list, e.g. ['电影', '动漫', '游戏', '1080p']
    :param topn: maximum number of expanded tags to return (default 5)
    :param similarity_threshold: minimum similarity to qualify (default 0.7)
    :return: expanded tag list, highest similarity first
    """
    # Pre-compute vectors for the DB tags ONCE. The original recomputed every
    # db-tag vector inside the input-tag loop, i.e. len(input_tags) redundant
    # model lookups per cached tag.
    db_vectors = {t: ft_model.get_word_vector(t) for t in existing_tags}

    # Best similarity score seen per candidate tag.
    tag_scores = {}
    for tag in input_tags:
        tag_vector = ft_model.get_word_vector(tag)
        for db_tag, db_vec in db_vectors.items():
            similarity = compute_similarity(tag_vector, db_vec)
            if similarity >= similarity_threshold:
                # Keep the maximum similarity over all input tags.
                tag_scores[db_tag] = max(tag_scores.get(db_tag, similarity), similarity)

    sorted_tags = sorted(tag_scores.items(), key=lambda x: x[1], reverse=True)
    return [tag for tag, _ in sorted_tags[:topn]]
| 334 | |
| 335 | |
| 336 | # ✅ 接口路由 |
@app.route("/expand_tags", methods=["POST"])
def expand_tags():
    """Expand the posted tags to semantically similar known tags and record
    weighted scores for the user.

    Expected JSON body: {"tags": [...], "user_id": ..., "rate": <number>,
    "topn": <int, optional>}. Responds with {"expanded_tags": [...]}.
    """
    start_time = time.time()

    data = request.get_json()
    input_tags = data.get("tags", [])
    user_id = data.get("user_id")
    rate = data.get("rate")

    topn = data.get("topn", 10)
    similarity_threshold = 0.4  # looser than the helper's 0.7 default

    if not input_tags or not user_id:
        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400
    # Fixed: a missing rate used to crash the scoring loop below with a
    # TypeError (HTTP 500); reject it explicitly instead.
    if rate is None:
        return jsonify({"error": "缺少参数 rate"}), 400

    expanded_tags = expand_tags_from_input(input_tags, topn=topn, similarity_threshold=similarity_threshold)

    print(f"[用户 {user_id}] 输入标签: {input_tags}")
    print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")

    # Upsert a score for every expanded tag.
    # NOTE(review): the original ternary gave the identical weight
    # (2.0 * rate) whether or not the tag was typed by the user — presumably
    # direct input was meant to weigh more; confirm the intended weights.
    score = 2.0 * rate
    for tag in expanded_tags:
        try:
            cursor.execute("""
                INSERT INTO user_tag_scores (user_id, tag, score)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE score = score + VALUES(score)
            """, (user_id, tag, score))
        except Exception as e:
            print(f"插入失败 [{tag}]:", e)
    # Commit once for the whole batch instead of per row.
    db.commit()

    # Debug: dump the user's current scores after the writes.
    try:
        cursor.execute("""
            SELECT tag, score FROM user_tag_scores WHERE user_id = %s
        """, (user_id,))
        user_scores = [{"tag": tag, "score": float(score)} for tag, score in cursor.fetchall()]
        print(user_scores)
    except Exception as e:
        print(f"查询用户评分失败: {e}")

    duration = round(time.time() - start_time, 3)
    print(f"[用户 {user_id}] 处理耗时 {duration}s")  # previously computed but never used
    return jsonify({
        "expanded_tags": expanded_tags,
    })
| 387 | |
| 388 | |
| 389 | # ✅ 触发标签缓存刷新(可选:手动/接口/定时任务调用) |
@app.route("/refresh_tags", methods=["POST"])
def refresh_tags():
    """Manually trigger a reload of the tag cache (also usable from a cron
    job); returns the new cache size."""
    refresh_existing_tags()
    payload = {"status": "标签缓存已刷新", "count": len(existing_tags)}
    return jsonify(payload)
| 394 | |
| 395 | |
# === ✅ Start the service ===
if __name__ == '__main__':
    #train_and_save_itemcf()
    # waitress: production-grade WSGI server; serves all registered routes
    # on port 5000 with 16 worker threads.
    from waitress import serve
    serve(app, host="0.0.0.0", port=5000, threads=16)