feat: 加入了download_model.py进行模型的下载和训练,修复了recommend.py的bug

Change-Id: I72fb3fdb160ff73934396e2127ae6432e8a727c8
diff --git a/recommend/demo.py b/recommend/demo.py
new file mode 100644
index 0000000..b64ec6b
--- /dev/null
+++ b/recommend/demo.py
@@ -0,0 +1,149 @@
+import os
+import time
+import jieba
+import fasttext
+from flask import Flask, request, jsonify
+import mysql.connector
+
+# ✅ 初始化数据库连接
+import mysql.connector
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
+
+def compute_similarity(vec1, vec2):
+    """计算两个词向量之间的余弦相似度"""
+    # 使用 cosine_similarity 计算相似度
+    return cosine_similarity([vec1], [vec2])[0][0]
+
+db = mysql.connector.connect(
+    host="49.233.215.144",
+    port=3306,
+    user="sy",
+    password="sy_password",
+    database="pt_station"
+)
+
+cursor = db.cursor()
+
+
+app = Flask(__name__)
+
+# ✅ 加载 fastText 模型
+fasttext_model_path = './models/cc.zh.300.bin'
+if not os.path.exists(fasttext_model_path):
+    raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")
+
+print("加载 fastText 模型中...")
+ft_model = fasttext.load_model(fasttext_model_path)
+print("模型加载完成 ✅")
+fasttext_vocab = set(ft_model.words)
+
+# ✅ 全局标签缓存(用于避免频繁查询)
+existing_tags = set()
+
+def refresh_existing_tags():
+    """刷新数据库中已存在的标签集合"""
+    global existing_tags
+    cursor.execute("SELECT DISTINCT tag FROM user_tag_scores")
+    existing_tags = set(tag[0] for tag in cursor.fetchall())
+    print(f"已加载标签数: {len(existing_tags)}")
+
+# ✅ 启动时初始化标签缓存
+refresh_existing_tags()
+
+# ✅ 扩展函数:仅保留数据库已有标签
+def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7):
+    """
+    扩展输入标签列表,查找与之语义相似的标签,返回相似度大于阈值或最相似的前 n 个标签。
+    
+    :param input_tags: 输入标签的列表,例如 ['电影', '动漫', '游戏', '1080p']
+    :param topn: 返回的最相似标签的数量,默认为 5
+    :param similarity_threshold: 相似度阈值,默认为 0.7
+    :return: 返回与输入标签相关的扩展标签列表
+    """
+    # 用于存储所有扩展标签及其相似度
+    tag_scores = {}
+
+    for tag in input_tags:
+        # 获取当前标签的词向量
+        tag_vector = ft_model.get_word_vector(tag)
+        
+        # 遍历标签库中的所有标签并计算相似度
+        for db_tag in existing_tags:
+            db_tag_vector = ft_model.get_word_vector(db_tag)
+            similarity = compute_similarity(tag_vector, db_tag_vector)
+            
+            if similarity >= similarity_threshold:
+                if db_tag not in tag_scores:
+                    tag_scores[db_tag] = similarity
+                else:
+                    tag_scores[db_tag] = max(tag_scores[db_tag], similarity)
+    
+    # 根据相似度排序并返回前 n 个标签
+    sorted_tags = sorted(tag_scores.items(), key=lambda x: x[1], reverse=True)
+    top_tags = [tag for tag, _ in sorted_tags[:topn]]
+
+    return top_tags
+
+
+# ✅ 接口路由
+@app.route("/expand_tags", methods=["POST"])
+def expand_tags():
+    start_time = time.time()
+
+    # 从请求中获取数据
+    data = request.get_json()
+    input_tags = data.get("tags", [])
+    user_id = data.get("user_id")
+    topn = data.get("topn", 10)  # 默认为 5
+    similarity_threshold =  0.4  # 默认阈值为 0.7
+
+    if not input_tags or not user_id:
+        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400
+
+    # 获取与输入标签最相关的标签
+    expanded_tags = expand_tags_from_input(input_tags, topn=topn, similarity_threshold=similarity_threshold)
+
+    # 打印日志
+    print(f"[用户 {user_id}] 输入标签: {input_tags}")
+    print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")
+
+    # 数据写入打分逻辑
+    token_set = set(input_tags)  # 用于确定哪些标签是用户输入的标签
+    for tag in expanded_tags:
+        score = 2.0 if tag in token_set else 1.0
+        try:
+            cursor.execute("""
+                INSERT INTO user_tag_scores (user_id, tag, score)
+                VALUES (%s, %s, %s)
+                ON DUPLICATE KEY UPDATE score = score + VALUES(score)
+            """, (user_id, tag, score))
+        except Exception as e:
+            print(f"插入失败 [{tag}]:", e)
+    db.commit()
+    # ⏳ 插入后立即查询该用户的所有标签和评分
+    try:
+        cursor.execute("""
+            SELECT tag, score FROM user_tag_scores WHERE user_id = %s
+        """, (user_id,))
+        user_scores = [{"tag": tag, "score": float(score)} for tag, score in cursor.fetchall()]
+        print(user_scores)
+    except Exception as e:
+        print(f"查询用户评分失败: {e}")
+
+    duration = round(time.time() - start_time, 3)
+    return jsonify({
+        "expanded_tags": expanded_tags,
+    })
+
+
+# ✅ 触发标签缓存刷新(可选:手动/接口/定时任务调用)
+@app.route("/refresh_tags", methods=["POST"])
+def refresh_tags():
+    refresh_existing_tags()
+    return jsonify({"status": "标签缓存已刷新", "count": len(existing_tags)})
+
+# ✅ 启动服务
+if __name__ == "__main__":
+    from waitress import serve
+    serve(app, host="0.0.0.0", port=5000, threads=16)