| import os |
| import time |
| import jieba |
| import fasttext |
| from flask import Flask, request, jsonify |
| import mysql.connector |
| |
| # ✅ 初始化数据库连接 |
| import mysql.connector |
| from sklearn.metrics.pairwise import cosine_similarity |
| import numpy as np |
| |
def compute_similarity(vec1, vec2):
    """Return the cosine similarity between two word vectors.

    The vectors are normalised to unit length and their dot product is
    returned. This is done directly with NumPy so that the per-pair cost
    stays small when the function is called in a tight loop.

    :param vec1: 1-D sequence or array of numbers.
    :param vec2: 1-D sequence or array of numbers, same length as ``vec1``.
    :return: cosine similarity as a built-in ``float``; ``0.0`` when either
        vector has zero norm (the direction is undefined in that case).
    :raises ValueError: if the inputs are not 1-D or differ in length.
    """
    a = np.asarray(vec1, dtype=np.float64)
    b = np.asarray(vec2, dtype=np.float64)
    if a.ndim != 1 or a.shape != b.shape:
        raise ValueError(
            f"expected two 1-D vectors of equal length, got shapes {a.shape} and {b.shape}"
        )

    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    # A zero vector has no direction; report "no similarity" instead of
    # dividing by zero and producing NaN.
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0

    return float(np.dot(a / norm_a, b / norm_b))
| |
# Database connection. Every setting can be overridden through an environment
# variable; the fallbacks are the values that used to be hard-coded, so the
# behaviour is unchanged when no variable is set.
# NOTE(review): the fallback password is still committed to source — it
# should be rotated and then supplied only through the environment.
db = mysql.connector.connect(
    host=os.environ.get("PT_DB_HOST", "49.233.215.144"),
    port=int(os.environ.get("PT_DB_PORT", "3306")),
    user=os.environ.get("PT_DB_USER", "sy"),
    password=os.environ.get("PT_DB_PASSWORD", "sy_password"),
    database=os.environ.get("PT_DB_NAME", "pt_station"),
)

# Single module-level cursor used by every query in this file.
# NOTE(review): this file starts the server with threads=16, so this one
# connection/cursor is shared by all worker threads — confirm that is safe
# for the driver, or switch to per-request connections / a pool.
cursor = db.cursor()
| |
| |
# Flask application object; the route decorators in this module register on it.
app = Flask(__name__)

# Load the fastText model, failing fast when the model file is missing.
fasttext_model_path = "./models/cc.zh.300.bin"
if not os.path.exists(fasttext_model_path):
    raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")

print("加载 fastText 模型中...")
ft_model = fasttext.load_model(fasttext_model_path)
print("模型加载完成 ✅")

# Vocabulary of the loaded model, kept as a set.
fasttext_vocab = set(ft_model.words)

# Module-level cache of the tags stored in the database, so that requests do
# not have to query for them every time. Starts empty.
existing_tags = set()
| |
def refresh_existing_tags():
    """Reload the module-level tag cache from the database.

    Runs ``SELECT DISTINCT tag FROM user_tag_scores`` on the shared cursor,
    rebinds the global ``existing_tags`` to the set of returned tags and
    prints how many tags were loaded.
    """
    global existing_tags
    cursor.execute("SELECT DISTINCT tag FROM user_tag_scores")
    rows = cursor.fetchall()
    # Each row is a one-column tuple; keep only the tag value.
    existing_tags = {row[0] for row in rows}
    print(f"已加载标签数: {len(existing_tags)}")


# Populate the tag cache once when the module is loaded.
refresh_existing_tags()
| |
# Tag expansion: only tags that already exist in the database are returned.
def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7):
    """Find cached database tags that are semantically close to the input tags.

    Every tag in the module-level ``existing_tags`` cache is compared with
    every input tag using the cosine similarity of their fastText word
    vectors. A cached tag qualifies when its similarity to at least one input
    tag is ``>= similarity_threshold``; its score is the highest similarity
    seen. The qualifying tags are sorted by score (highest first) and the
    first ``topn`` are returned.

    :param input_tags: list of tag strings, e.g. ['电影', '动漫', '游戏', '1080p']
    :param topn: maximum number of tags to return, 5 by default
    :param similarity_threshold: minimum similarity for a tag to qualify,
        0.7 by default
    :return: list of at most ``topn`` cached tags, best match first; empty
        when there are no input tags, no cached tags, or nothing qualifies
    """
    if not input_tags or not existing_tags:
        return []

    # A cached tag's vector does not depend on the input tag, so compute each
    # one once per call instead of once per (input tag, cached tag) pair.
    db_vectors = [
        (db_tag, ft_model.get_word_vector(db_tag)) for db_tag in existing_tags
    ]

    # Best similarity found so far for each qualifying cached tag.
    best_scores = {}

    # dict.fromkeys drops repeated input tags while keeping their order; a
    # repeated tag could never change the per-tag maximum.
    for tag in dict.fromkeys(input_tags):
        tag_vector = ft_model.get_word_vector(tag)

        for db_tag, db_vector in db_vectors:
            similarity = compute_similarity(tag_vector, db_vector)
            if similarity < similarity_threshold:
                continue
            previous = best_scores.get(db_tag)
            if previous is None or similarity > previous:
                best_scores[db_tag] = similarity

    ranked = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)
    return [db_tag for db_tag, _ in ranked[:topn]]
| |
| |
# API route: expand a user's tags and accumulate the per-user tag scores.
@app.route("/expand_tags", methods=["POST"])
def expand_tags():
    """Expand the posted tags and add the matches to the user's tag scores.

    Expected JSON object in the request body:

    * ``tags`` – non-empty list of tag strings (required)
    * ``user_id`` – user identifier (required, must be truthy)
    * ``topn`` – maximum number of expanded tags, non-negative integer
      (optional, defaults to 10)

    The expanded tags are looked up with a fixed similarity threshold of 0.4.
    Each expanded tag is upserted into ``user_tag_scores``: a tag the user
    typed literally adds 2.0 to its score, any other match adds 1.0. A failed
    insert is logged and skipped. Afterwards the user's current scores are
    read back and printed for logging only.

    :return: ``{"expanded_tags": [...]}`` on success, or a 400 response with
        ``{"error": ...}`` when the body is missing or malformed.
    """
    # silent=True yields None instead of an error for a missing or invalid
    # JSON body; a body that is not a JSON object is treated the same way, so
    # all of these end in the 400 below rather than an AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    input_tags = data.get("tags", [])
    user_id = data.get("user_id")
    topn = data.get("topn", 10)  # this endpoint defaults to 10
    similarity_threshold = 0.4  # fixed threshold used by this endpoint

    if not input_tags or not user_id:
        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400

    # Reject shapes that would otherwise crash further down (a bare string
    # would be iterated character by character, a non-integer topn cannot be
    # used as a slice bound).
    if not isinstance(input_tags, list) or not all(
        isinstance(tag, str) for tag in input_tags
    ):
        return jsonify({"error": "参数 tags 必须为字符串列表"}), 400
    if isinstance(topn, bool) or not isinstance(topn, int) or topn < 0:
        return jsonify({"error": "参数 topn 必须为非负整数"}), 400

    # Cached database tags that are most similar to the input tags.
    expanded_tags = expand_tags_from_input(
        input_tags, topn=topn, similarity_threshold=similarity_threshold
    )

    print(f"[用户 {user_id}] 输入标签: {input_tags}")
    print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")

    # NOTE(review): `cursor`/`db` are module-level objects shared by all
    # worker threads — confirm this is safe for the driver in use.
    user_supplied = set(input_tags)  # tags the user typed literally
    for tag in expanded_tags:
        score = 2.0 if tag in user_supplied else 1.0
        try:
            cursor.execute(
                """
                INSERT INTO user_tag_scores (user_id, tag, score)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE score = score + VALUES(score)
                """,
                (user_id, tag, score),
            )
        except Exception as e:
            # Best effort: log the failure and carry on with the other tags.
            print(f"插入失败 [{tag}]:", e)
    db.commit()

    # Read the user's scores back right after the write; logged, not returned.
    try:
        cursor.execute(
            """
            SELECT tag, score FROM user_tag_scores WHERE user_id = %s
            """,
            (user_id,),
        )
        user_scores = [
            {"tag": tag, "score": float(score)} for tag, score in cursor.fetchall()
        ]
        print(user_scores)
    except Exception as e:
        print(f"查询用户评分失败: {e}")

    return jsonify({
        "expanded_tags": expanded_tags,
    })
| |
| |
# Trigger a reload of the tag cache (optional: call it manually, through this
# endpoint, or from a scheduled job).
@app.route("/refresh_tags", methods=["POST"])
def refresh_tags():
    """Reload the tag cache and report how many tags it now holds."""
    refresh_existing_tags()
    tag_count = len(existing_tags)
    payload = {"status": "标签缓存已刷新", "count": tag_count}
    return jsonify(payload)
| |
# Entry point: serve the Flask app with waitress when run as a script.
if __name__ == "__main__":
    import waitress

    waitress.serve(app, host="0.0.0.0", port=5000, threads=16)