blob: b64ec6b8110f29469645a9f6b7533d085d4c2aed [file] [log] [blame]
22301110f2e3c092025-06-05 01:24:43 +08001import os
2import time
3import jieba
4import fasttext
5from flask import Flask, request, jsonify
6import mysql.connector
7
8# ✅ 初始化数据库连接
9import mysql.connector
10from sklearn.metrics.pairwise import cosine_similarity
11import numpy as np
12
def compute_similarity(vec1, vec2):
    """Return the cosine similarity between two word vectors.

    Computed directly with NumPy rather than through sklearn's pairwise
    ``cosine_similarity``, which would validate and wrap its inputs into 2-D
    arrays on every call. That overhead matters when this function is called
    in a loop over many tag pairs.

    :param vec1: 1-D sequence of numbers (e.g. a fastText word vector).
    :param vec2: 1-D sequence of numbers with the same length as ``vec1``.
    :return: the similarity as a Python float (about -1.0 to 1.0). Returns
        0.0 when either vector has zero norm, which matches sklearn's
        behaviour of leaving zero rows unnormalized.
    :raises ValueError: if the two vectors have different lengths.
    """
    # Work in float64 so float32 model vectors do not lose precision in the dot product.
    a = np.asarray(vec1, dtype=np.float64)
    b = np.asarray(vec2, dtype=np.float64)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0.0:
        return 0.0
    return float(np.dot(a, b) / norm_product)
17
# ✅ Database connection.
# Settings are read from the environment (DB_HOST, DB_PORT, DB_USER,
# DB_PASSWORD, DB_NAME) so a deployment can supply credentials without editing
# this file. The fallbacks are the values that used to be hard-coded here.
# TODO(security): drop the plaintext password fallback and rotate the password;
# it is visible to anyone who can read this source.
db = mysql.connector.connect(
    host=os.environ.get("DB_HOST", "49.233.215.144"),
    port=int(os.environ.get("DB_PORT", "3306")),
    user=os.environ.get("DB_USER", "sy"),
    password=os.environ.get("DB_PASSWORD", "sy_password"),
    database=os.environ.get("DB_NAME", "pt_station"),
)

# One module-level cursor, shared by every query in this module.
# NOTE(review): mysql.connector declares PEP 249 threadsafety = 1 (connections
# must not be shared between threads). Confirm this is safe under a threaded
# server, or move to a connection pool.
cursor = db.cursor()
27
28
# Flask application object; the HTTP routes in this module are registered on it.
app = Flask(__name__)
30
# ✅ Load the fastText model once at import time; every request reuses it.
# FASTTEXT_MODEL_PATH overrides the default location. A relative path is
# resolved against the current working directory, not this file's directory.
fasttext_model_path = os.environ.get("FASTTEXT_MODEL_PATH", './models/cc.zh.300.bin')
# isfile rather than exists: a directory at this path would otherwise pass the
# check and only fail later, at load time.
if not os.path.isfile(fasttext_model_path):
    raise FileNotFoundError(f"fastText 模型文件不存在,请检查路径。({fasttext_model_path})")

print("加载 fastText 模型中...")
ft_model = fasttext.load_model(fasttext_model_path)
print("模型加载完成 ✅")
# The model's vocabulary as a set.
# NOTE(review): not referenced elsewhere in this file; check whether another
# module imports it before removing it.
fasttext_vocab = set(ft_model.words)
40
# ✅ Global tag cache, so the tag list is not queried from the database on every request.
# It starts empty; refresh_existing_tags() rebinds it to the distinct tags read
# from user_tag_scores.
existing_tags = set()
43
def refresh_existing_tags():
    """Reload the tag cache from the database.

    Rebinds the module-level ``existing_tags`` to the set of distinct, non-NULL
    ``tag`` values in ``user_tag_scores``, then prints how many were loaded.
    """
    global existing_tags
    # Exclude NULL tags. A None in the cache would be handed to the word-vector
    # lookup as if it were a tag string.
    # (Assumes the column may be nullable; harmless if it is NOT NULL.)
    cursor.execute("SELECT DISTINCT tag FROM user_tag_scores WHERE tag IS NOT NULL")
    # Build the new set completely, then rebind the global in one step.
    existing_tags = {row[0] for row in cursor.fetchall()}
    print(f"已加载标签数: {len(existing_tags)}")
50
# ✅ Fill the tag cache once at startup. This runs at import time, so the
# database must be reachable when this module is loaded.
refresh_existing_tags()
53
# ✅ Tag expansion: candidate tags are limited to tags already present in the database.
def _unit_rows(vectors):
    """Stack ``vectors`` into a float64 matrix and scale each row to unit length.

    Zero-norm rows stay all-zero, so they score 0 against every other vector.
    """
    matrix = np.asarray(vectors, dtype=np.float64)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0.0] = 1.0  # avoid 0/0 for zero vectors
    return matrix / norms


def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7, *,
                           candidate_tags=None, model=None):
    """
    Expand the input tags with semantically similar tags from the tag cache.

    Each candidate is scored by its best cosine similarity over all input
    tags. Candidates scoring below ``similarity_threshold`` are dropped, and
    at most ``topn`` of the rest are returned, highest score first. Both the
    threshold and the ``topn`` limit apply.

    :param input_tags: iterable of input tags, e.g. ['电影', '动漫', '游戏', '1080p']
    :param topn: maximum number of tags to return (used as a slice bound), default 5
    :param similarity_threshold: minimum best-match similarity to keep a tag, default 0.7
    :param candidate_tags: tags to choose from; defaults to the module-level
        ``existing_tags`` cache
    :param model: any object with ``get_word_vector(str)``; defaults to the
        module-level ``ft_model``
    :return: list of candidate tags, most similar first
    """
    if candidate_tags is None:
        candidate_tags = existing_tags
    if model is None:
        model = ft_model

    # Lists give a fixed order, so row/column indices map back to tags.
    candidates = list(candidate_tags)
    queries = list(input_tags)
    if not candidates or not queries:
        return []

    # Embed every tag exactly once per call, then score all pairs with a
    # single matrix product instead of one similarity call per pair.
    cand_unit = _unit_rows([model.get_word_vector(t) for t in candidates])
    query_unit = _unit_rows([model.get_word_vector(t) for t in queries])
    similarity = query_unit @ cand_unit.T  # shape: (len(queries), len(candidates))
    best = similarity.max(axis=0)          # best match per candidate

    kept = [i for i in range(len(candidates)) if best[i] >= similarity_threshold]
    kept.sort(key=lambda i: best[i], reverse=True)
    return [candidates[i] for i in kept[:topn]]
87
88
# ✅ API route
@app.route("/expand_tags", methods=["POST"])
def expand_tags():
    """Expand a user's tags and add scores for the matched tags.

    Request body (JSON object):
      - ``tags``: list of strings; a single non-empty string is treated as a
        one-element list.
      - ``user_id``: required; any falsy value is rejected.
      - ``topn``: optional non-negative int, the maximum number of expanded
        tags (default 10).

    Each expanded tag is upserted into ``user_tag_scores`` for this user. It
    adds 2.0 if the tag was one of the input tags, otherwise 1.0.

    :return: ``{"expanded_tags": [...]}``, or ``({"error": ...}, 400)`` when
        the request is malformed.
    """
    start_time = time.time()

    # silent=True yields None for a missing or unparsable JSON body. Together
    # with the dict check, a body such as `null` or `[...]` gets a 400 instead
    # of failing on .get().
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400

    input_tags = data.get("tags", [])
    user_id = data.get("user_id")
    topn = data.get("topn", 10)  # default: 10
    similarity_threshold = 0.4  # fixed for this endpoint (the helper's own default is 0.7)

    # A bare string would otherwise be iterated character by character.
    if isinstance(input_tags, str):
        input_tags = [input_tags] if input_tags else []

    if not input_tags or not user_id:
        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400
    if not isinstance(input_tags, list) or not all(isinstance(t, str) for t in input_tags):
        return jsonify({"error": "参数 tags 必须是字符串列表"}), 400
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(topn, bool) or not isinstance(topn, int) or topn < 0:
        return jsonify({"error": "参数 topn 必须是非负整数"}), 400

    # Find the cached tags most related to the input tags.
    expanded_tags = expand_tags_from_input(
        input_tags, topn=topn, similarity_threshold=similarity_threshold
    )

    print(f"[用户 {user_id}] 输入标签: {input_tags}")
    print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")

    # Scoring: input tags that also came back as expansions count double.
    input_tag_set = set(input_tags)
    # NOTE(review): `cursor`/`db` form one module-level connection, and the
    # server is started with threads=16. mysql.connector connections should
    # not be shared across threads (PEP 249 threadsafety = 1), so concurrent
    # requests may interleave here. Consider a connection pool or a lock.
    for tag in expanded_tags:
        score = 2.0 if tag in input_tag_set else 1.0
        try:
            cursor.execute(
                """
                INSERT INTO user_tag_scores (user_id, tag, score)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE score = score + VALUES(score)
                """,
                (user_id, tag, score),
            )
        except Exception as e:
            # Best effort: log the failed upsert and continue with the remaining tags.
            print(f"插入失败 [{tag}]:", e)
    db.commit()

    # Read back the user's full score table right after the writes (logged only).
    try:
        cursor.execute(
            "SELECT tag, score FROM user_tag_scores WHERE user_id = %s",
            (user_id,),
        )
        user_scores = [{"tag": tag, "score": float(score)} for tag, score in cursor.fetchall()]
        print(user_scores)
    except Exception as e:
        print(f"查询用户评分失败: {e}")

    duration = round(time.time() - start_time, 3)
    print(f"[用户 {user_id}] 耗时: {duration}s")
    return jsonify({"expanded_tags": expanded_tags})
138
139
# ✅ Endpoint that rebuilds the tag cache on demand (optional: call it manually,
# from another service, or from a scheduled job).
@app.route("/refresh_tags", methods=["POST"])
def refresh_tags():
    """Reload the tag cache from the database and report its new size.

    :return: JSON with a status message and the number of cached tags.
    """
    refresh_existing_tags()
    response = dict(status="标签缓存已刷新", count=len(existing_tags))
    return jsonify(response)
145
# ✅ Serve the app with waitress when this file is run as a script.
if __name__ == "__main__":
    import waitress

    # Listen on all interfaces, port 5000, with 16 worker threads.
    # NOTE(review): those threads share the module-level MySQL connection and
    # cursor; confirm that is safe or switch to a pool.
    waitress.serve(app, host="0.0.0.0", port=5000, threads=16)