22301110 | f2e3c09 | 2025-06-05 01:24:43 +0800 | [diff] [blame] | 1 | import os |
| 2 | import time |
| 3 | import jieba |
| 4 | import fasttext |
| 5 | from flask import Flask, request, jsonify |
| 6 | import mysql.connector |
| 7 | |
| 8 | # ✅ 初始化数据库连接 |
| 9 | import mysql.connector |
| 10 | from sklearn.metrics.pairwise import cosine_similarity |
| 11 | import numpy as np |
| 12 | |
def compute_similarity(vec1, vec2):
    """Return the cosine similarity between two word vectors.

    Each vector is wrapped into a single-row matrix before being handed to
    ``cosine_similarity``; the only cell of the resulting 1x1 matrix is the
    similarity of the pair.
    """
    pairwise_scores = cosine_similarity([vec1], [vec2])
    return pairwise_scores[0][0]
| 17 | |
# Database connection. Every setting can be overridden through an environment
# variable so that credentials do not have to live in the source tree; the
# defaults keep the values that were previously hard-coded.
# NOTE(review): this one connection and cursor are module-level and are used by
# the request handlers, while the server is started with threads=16 — confirm
# the connector tolerates concurrent use, or move to per-request connections.
db = mysql.connector.connect(
    host=os.environ.get("PT_STATION_DB_HOST", "49.233.215.144"),
    port=int(os.environ.get("PT_STATION_DB_PORT", "3306")),
    user=os.environ.get("PT_STATION_DB_USER", "sy"),
    password=os.environ.get("PT_STATION_DB_PASSWORD", "sy_password"),
    database=os.environ.get("PT_STATION_DB_NAME", "pt_station"),
)

cursor = db.cursor()
| 27 | |
| 28 | |
app = Flask(__name__)

# Load the fastText model; abort start-up right away when the file is absent.
fasttext_model_path = './models/cc.zh.300.bin'
if os.path.exists(fasttext_model_path):
    print("加载 fastText 模型中...")
    ft_model = fasttext.load_model(fasttext_model_path)
    print("模型加载完成 ✅")
else:
    raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")

# Vocabulary of the loaded model, kept as a set.
fasttext_vocab = set(ft_model.words)
| 40 | |
# Module-level tag cache, kept so that the tag list is not queried on every
# request.
existing_tags = set()


def refresh_existing_tags():
    """Rebuild the cache from the distinct tags stored in ``user_tag_scores``."""
    global existing_tags
    cursor.execute("SELECT DISTINCT tag FROM user_tag_scores")
    rows = cursor.fetchall()
    # Each row is a one-column tuple; keep only the tag value.
    existing_tags = {row[0] for row in rows}
    print(f"已加载标签数: {len(existing_tags)}")


# Fill the cache once at start-up.
refresh_existing_tags()
| 53 | |
# Tag expansion: only tags that already exist in the database are returned.
def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7):
    """Find cached database tags that are semantically close to the input tags.

    Every tag in the module-level ``existing_tags`` cache is compared with
    every input tag using the cosine similarity of their fastText word
    vectors. A cached tag qualifies when its similarity to at least one input
    tag reaches ``similarity_threshold``; its score is the highest similarity
    found across all input tags.

    :param input_tags: tag strings to expand
    :param topn: maximum number of tags to return, default 5
    :param similarity_threshold: minimum cosine similarity for a cached tag
        to qualify, default 0.7
    :return: list of at most ``topn`` qualifying tags, best score first
    """
    if not input_tags:
        # Nothing to compare against, so skip vectorising the whole cache.
        return []

    # The vector of a cached tag does not depend on the input tag, so compute
    # each one once per call instead of once per (input tag, cached tag) pair.
    db_tag_vectors = [
        (db_tag, ft_model.get_word_vector(db_tag)) for db_tag in existing_tags
    ]

    # Best similarity seen so far for each qualifying cached tag.
    tag_scores = {}
    for tag in input_tags:
        tag_vector = ft_model.get_word_vector(tag)
        for db_tag, db_tag_vector in db_tag_vectors:
            similarity = compute_similarity(tag_vector, db_tag_vector)
            if similarity >= similarity_threshold:
                best_so_far = tag_scores.get(db_tag, similarity)
                tag_scores[db_tag] = max(best_so_far, similarity)

    # Highest score first, then keep the leading ``topn`` tags.
    sorted_tags = sorted(tag_scores.items(), key=lambda x: x[1], reverse=True)
    return [tag for tag, _ in sorted_tags[:topn]]
| 87 | |
| 88 | |
# API route: expand a user's tags and record the resulting scores.
@app.route("/expand_tags", methods=["POST"])
def expand_tags():
    """Expand the posted tags and add the matches to the user's tag scores.

    The request body must be a JSON object with:

    * ``tags`` -- non-empty list of tag strings
    * ``user_id`` -- identifier of the user (must be truthy)
    * ``topn`` -- optional non-negative integer, default 10

    Matching uses a fixed similarity threshold of 0.4. Each expanded tag is
    upserted into ``user_tag_scores``: 2.0 is added when the tag was part of
    the input, 1.0 otherwise.

    :return: JSON ``{"expanded_tags": [...]}`` on success, or a JSON error
        with HTTP status 400 when a parameter is missing or invalid
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising; bodies that are not JSON objects are treated as empty too, so
    # they fall through to the missing-parameter response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    input_tags = data.get("tags", [])
    user_id = data.get("user_id")
    topn = data.get("topn", 10)  # defaults to 10
    similarity_threshold = 0.4  # fixed threshold used by this endpoint

    if not input_tags or not user_id:
        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400

    # Tags are turned into word vectors and placed in a set below, so they
    # have to be strings inside a list.
    if not isinstance(input_tags, list) or not all(
        isinstance(tag, str) for tag in input_tags
    ):
        return jsonify({"error": "参数 tags 必须是字符串列表"}), 400

    # topn is used as a slice bound; bool is rejected explicitly because it is
    # a subclass of int.
    if isinstance(topn, bool) or not isinstance(topn, int) or topn < 0:
        return jsonify({"error": "参数 topn 必须是非负整数"}), 400

    # Look up the cached tags most similar to the input.
    expanded_tags = expand_tags_from_input(
        input_tags, topn=topn, similarity_threshold=similarity_threshold
    )

    print(f"[用户 {user_id}] 输入标签: {input_tags}")
    print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")

    # Scoring: tags the user supplied directly weigh more than tags that were
    # only found through similarity.
    # NOTE(review): the module-level cursor is shared by all worker threads —
    # confirm this is safe for the connector in use.
    token_set = set(input_tags)
    for tag in expanded_tags:
        score = 2.0 if tag in token_set else 1.0
        try:
            cursor.execute("""
                INSERT INTO user_tag_scores (user_id, tag, score)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE score = score + VALUES(score)
            """, (user_id, tag, score))
        except Exception as e:
            # Best effort: a failed tag is logged and the remaining tags are
            # still written.
            print(f"插入失败 [{tag}]:", e)
    db.commit()

    # Read back all of the user's scores right after the write; the result is
    # only logged, it is not part of the response.
    try:
        cursor.execute("""
            SELECT tag, score FROM user_tag_scores WHERE user_id = %s
        """, (user_id,))
        user_scores = [
            {"tag": tag, "score": float(score)}
            for tag, score in cursor.fetchall()
        ]
        print(user_scores)
    except Exception as e:
        print(f"查询用户评分失败: {e}")

    return jsonify({
        "expanded_tags": expanded_tags,
    })
| 138 | |
| 139 | |
# Tag cache refresh trigger (optional: call it manually, through this
# endpoint, or from a scheduled job).
@app.route("/refresh_tags", methods=["POST"])
def refresh_tags():
    """Reload the tag cache and report how many tags it holds afterwards."""
    refresh_existing_tags()
    payload = {"status": "标签缓存已刷新", "count": len(existing_tags)}
    return jsonify(payload)
| 145 | |
# Start the service with waitress when the file is run as a script.
if __name__ == "__main__":
    import waitress

    waitress.serve(app, host="0.0.0.0", port=5000, threads=16)