feat: 加入了download_model.py进行模型的下载和训练,修复了recommend.py的bug
Change-Id: I72fb3fdb160ff73934396e2127ae6432e8a727c8
diff --git a/recommend/demo.py b/recommend/demo.py
new file mode 100644
index 0000000..b64ec6b
--- /dev/null
+++ b/recommend/demo.py
@@ -0,0 +1,149 @@
+import os
+import time
+import jieba
+import fasttext
+from flask import Flask, request, jsonify
+import mysql.connector
+
+# ✅ 初始化数据库连接
+import mysql.connector
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
+
def compute_similarity(vec1, vec2):
    """Return the cosine similarity of two word vectors.

    Each vector is wrapped as a one-row matrix so sklearn's pairwise
    ``cosine_similarity`` can be used; the single resulting entry is returned.
    """
    pair_matrix = cosine_similarity([vec1], [vec2])
    return pair_matrix[0, 0]
+
# === Database connection ===
# NOTE(review): credentials used to be hard-coded only. They can now be
# supplied via PT_DB_HOST / PT_DB_PORT / PT_DB_USER / PT_DB_PASSWORD /
# PT_DB_NAME; the old literal values remain only as fallbacks for backward
# compatibility and should be removed from source control.
db = mysql.connector.connect(
    host=os.environ.get("PT_DB_HOST", "49.233.215.144"),
    port=int(os.environ.get("PT_DB_PORT", "3306")),
    user=os.environ.get("PT_DB_USER", "sy"),
    password=os.environ.get("PT_DB_PASSWORD", "sy_password"),
    database=os.environ.get("PT_DB_NAME", "pt_station"),
)

# NOTE(review): this one connection and cursor are shared by every request,
# while __main__ starts waitress with threads=16. Concurrent requests can
# interleave statements on the shared cursor; a per-request connection or a
# connection pool would avoid that.
cursor = db.cursor()


app = Flask(__name__)

# === Load the fastText model once at import time ===
# The default path is relative to the current working directory;
# FASTTEXT_MODEL_PATH overrides it.
fasttext_model_path = os.environ.get("FASTTEXT_MODEL_PATH", "./models/cc.zh.300.bin")
if not os.path.exists(fasttext_model_path):
    raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")

print("加载 fastText 模型中...")
ft_model = fasttext.load_model(fasttext_model_path)
print("模型加载完成 ✅")
# Full fastText vocabulary as a set (not referenced elsewhere in this file).
fasttext_vocab = set(ft_model.words)

# === Global tag cache ===
# Distinct tags from user_tag_scores, filled by refresh_existing_tags(), so
# candidate tags are not re-queried on every request.
existing_tags = set()
+
def refresh_existing_tags():
    """Reload the module-level ``existing_tags`` cache from the database.

    Runs ``SELECT DISTINCT tag FROM user_tag_scores`` on the shared cursor and
    rebinds the global set to the first column of every returned row.
    """
    global existing_tags
    cursor.execute("SELECT DISTINCT tag FROM user_tag_scores")
    rows = cursor.fetchall()
    existing_tags = {row[0] for row in rows}
    print(f"已加载标签数: {len(existing_tags)}")
+
# Populate the tag cache at import time so expand_tags_from_input has
# candidate tags from the first request on (runs one DB query while loading).
refresh_existing_tags()
+
+# ✅ 扩展函数:仅保留数据库已有标签
def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7):
    """Find cached database tags that are semantically close to the input tags.

    Every tag in the module-level ``existing_tags`` cache is scored by its
    highest fastText cosine similarity to any of the input tags. Only tags
    whose score is >= ``similarity_threshold`` are kept, and at most ``topn``
    of those are returned, best first. Tags absent from the cache are never
    returned.

    :param input_tags: iterable of tag strings, e.g. ['电影', '动漫', '游戏', '1080p']
    :param topn: maximum number of tags to return (default 5); ``None`` means
        no limit, values <= 0 yield an empty list
    :param similarity_threshold: minimum cosine similarity to keep a tag (default 0.7)
    :return: list of matching cached tags sorted by descending similarity
    """
    input_tags = list(input_tags)
    # Snapshot the cache once: refresh_existing_tags() rebinds the global, and
    # one request should work against a single consistent candidate set.
    candidate_tags = list(existing_tags)
    if not input_tags or not candidate_tags:
        return []

    # One pairwise call instead of len(input_tags) * len(candidate_tags)
    # separate sklearn calls; each candidate vector is computed only once.
    input_vectors = np.vstack([ft_model.get_word_vector(tag) for tag in input_tags])
    candidate_vectors = np.vstack([ft_model.get_word_vector(tag) for tag in candidate_tags])
    # similarities[i, j] = cosine(input_tags[i], candidate_tags[j])
    similarities = cosine_similarity(input_vectors, candidate_vectors)
    # Best match of each candidate against any input tag.
    best_scores = similarities.max(axis=0)

    tag_scores = {
        tag: score
        for tag, score in zip(candidate_tags, best_scores)
        if score >= similarity_threshold
    }
    sorted_tags = sorted(tag_scores.items(), key=lambda item: item[1], reverse=True)

    # Clamp so a negative topn cannot turn the slice into "drop the last n".
    limit = None if topn is None else max(int(topn), 0)
    return [tag for tag, _ in sorted_tags[:limit]]
+
+
+# ✅ 接口路由
@app.route("/expand_tags", methods=["POST"])
def expand_tags():
    """Expand a user's input tags and add scores for the matched tags.

    Expects a JSON object body such as
    ``{"user_id": ..., "tags": ["电影", ...], "topn": 10}``.
    For every expanded tag, 2.0 is added to the user's score when the user
    typed that tag verbatim and 1.0 otherwise (upsert into
    ``user_tag_scores``). Responds with ``{"expanded_tags": [...]}``, or a 400
    error when parameters are missing or malformed.
    """
    start_time = time.time()

    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so the client gets a 400 like any other bad request.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    input_tags = data.get("tags", [])
    user_id = data.get("user_id")
    topn = data.get("topn", 10)  # default: at most 10 expanded tags
    similarity_threshold = 0.4  # fixed server-side similarity threshold

    if not input_tags or not user_id:
        return jsonify({"error": "缺少参数 tags 或 user_id"}), 400
    # A bare string would otherwise be iterated character by character.
    if isinstance(input_tags, str):
        input_tags = [input_tags]
    if not isinstance(input_tags, list) or not all(isinstance(t, str) for t in input_tags):
        return jsonify({"error": "参数 tags 必须是字符串列表"}), 400
    try:
        topn = int(topn)
    except (TypeError, ValueError):
        return jsonify({"error": "参数 topn 必须是整数"}), 400

    # Cached tags most similar to the input tags.
    expanded_tags = expand_tags_from_input(
        input_tags, topn=topn, similarity_threshold=similarity_threshold
    )

    print(f"[用户 {user_id}] 输入标签: {input_tags}")
    print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")

    # NOTE(review): `db` / `cursor` are module-level objects shared by all
    # waitress worker threads; concurrent requests can interleave on them.
    token_set = set(input_tags)  # tags the user typed verbatim score higher
    for tag in expanded_tags:
        score = 2.0 if tag in token_set else 1.0
        try:
            cursor.execute(
                """
                INSERT INTO user_tag_scores (user_id, tag, score)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE score = score + VALUES(score)
                """,
                (user_id, tag, score),
            )
        except Exception as e:
            # Best-effort: one failed upsert should not prevent the others.
            print(f"插入失败 [{tag}]:", e)
    db.commit()

    # Read back the user's scores for logging only; not part of the response.
    try:
        cursor.execute(
            """
            SELECT tag, score FROM user_tag_scores WHERE user_id = %s
            """,
            (user_id,),
        )
        user_scores = [{"tag": tag, "score": float(score)} for tag, score in cursor.fetchall()]
        print(user_scores)
    except Exception as e:
        print(f"查询用户评分失败: {e}")

    duration = round(time.time() - start_time, 3)
    print(f"[用户 {user_id}] 处理耗时: {duration}s")
    return jsonify({
        "expanded_tags": expanded_tags,
    })
+
+
+# ✅ 触发标签缓存刷新(可选:手动/接口/定时任务调用)
@app.route("/refresh_tags", methods=["POST"])
def refresh_tags():
    """Reload the tag cache on demand and report how many tags it now holds."""
    refresh_existing_tags()
    tag_count = len(existing_tags)
    return jsonify({"status": "标签缓存已刷新", "count": tag_count})
+
+# ✅ 启动服务
if __name__ == "__main__":
    # Serve the Flask app through waitress on 0.0.0.0:5000 with 16 threads.
    import waitress

    waitress.serve(app, host="0.0.0.0", port=5000, threads=16)
diff --git a/recommend/download_model.py b/recommend/download_model.py
new file mode 100644
index 0000000..1bb0c50
--- /dev/null
+++ b/recommend/download_model.py
@@ -0,0 +1,32 @@
+import os
+import urllib.request
+from recommend import train_and_save_itemcf
# Source archive of the fastText Chinese Common Crawl vectors.
MODEL_URL = (
    "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.bin.gz"
)
# Target directory, relative to the current working directory.
MODEL_DIR = "./models"
# Decompressed model file and the temporary downloaded archive next to it.
MODEL_PATH = os.path.join(MODEL_DIR, "cc.zh.300.bin")
COMPRESSED_PATH = f"{MODEL_PATH}.gz"
+
def download_model():
    """Download and decompress the fastText Chinese vectors into MODEL_DIR.

    Does nothing when MODEL_PATH already exists. Otherwise the archive is
    downloaded to COMPRESSED_PATH and decompressed into a temporary ``.part``
    file. That file is renamed to MODEL_PATH only once it is complete, so an
    interrupted run never leaves a truncated model that a later run would
    mistake for a finished download. The archive is deleted after a
    successful decompression.
    """
    import gzip
    import shutil

    os.makedirs(MODEL_DIR, exist_ok=True)

    if os.path.exists(MODEL_PATH):
        print("✅ 模型已存在,跳过下载。")
        return

    print("⏬ 下载 fastText 中文模型...")
    urllib.request.urlretrieve(MODEL_URL, COMPRESSED_PATH)

    print("📦 解压模型文件...")
    partial_path = MODEL_PATH + ".part"
    try:
        with gzip.open(COMPRESSED_PATH, "rb") as f_in, open(partial_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        # Atomic rename: MODEL_PATH only appears once it is complete.
        os.replace(partial_path, MODEL_PATH)
    except BaseException:
        # Remove the half-written file; MODEL_PATH still does not exist, so
        # the next run downloads again.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    os.remove(COMPRESSED_PATH)
    print("✅ 模型下载并解压完成!")
+
if __name__ == "__main__":
    # Train and save the ItemCF similarity matrix, then fetch the fastText model.
    # NOTE(review): `from recommend import train_and_save_itemcf` at the top of
    # this file imports recommend.py, which loads the fastText model at import
    # time and raises FileNotFoundError when the file is missing. On a machine
    # without the model this script therefore fails before download_model()
    # can run. Consider importing train_and_save_itemcf here, after calling
    # download_model().
    train_and_save_itemcf()
    download_model()
diff --git a/recommend/recommend.py b/recommend/recommend.py
new file mode 100644
index 0000000..25032a0
--- /dev/null
+++ b/recommend/recommend.py
@@ -0,0 +1,250 @@
+import os
+import time
+import jieba
+import fasttext
+import pandas as pd
+from flask import Flask, request, jsonify
+from sqlalchemy import create_engine
+from scipy.sparse import coo_matrix
+from sklearn.metrics.pairwise import cosine_similarity
+import pickle
+
app = Flask(__name__)

# === Database connection (SQLAlchemy) ===
# NOTE(review): the default URL embeds credentials. RECOMMEND_DB_URL now
# overrides it; the literal default remains only for backward compatibility
# and should be removed from source control. The "mysql+pymysql" driver needs
# the PyMySQL package, which requirements.txt does not list.
engine = create_engine(
    os.environ.get(
        "RECOMMEND_DB_URL",
        "mysql+pymysql://sy:sy_password@49.233.215.144:3306/pt_station",
    )
)

# === Load the fastText model ===
# The model is expected in a `models` directory next to this file (instead of
# a hard-coded, machine-specific absolute path), so the service runs from any
# working directory and OS. FASTTEXT_MODEL_PATH overrides the location.
fasttext_model_path = os.environ.get(
    "FASTTEXT_MODEL_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "models", "cc.zh.300.bin"),
)
if not os.path.exists(fasttext_model_path):
    raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")
print("加载 fastText 模型中...")
ft_model = fasttext.load_model(fasttext_model_path)
print("模型加载完成 ✅")
+
+# === ✅ 用户标签行为矩阵构建 ===
def get_user_tag_matrix(df=None):
    """Build user/tag score structures from ``user_tag_scores``.

    :param df: optional DataFrame with ``user_id``, ``tag`` and ``score``
        columns. When omitted, the whole table is read from the database.
        A DataFrame passed in is copied, never modified.
    :return: tuple ``(df, matrix, sparse_matrix, user_map, tag_map)``:

        - ``df``: the data with ``user_id`` as str, ``score`` as float, and
          added ``user_index`` / ``tag_index`` columns.
        - ``matrix``: dense pivot table (index: user_id, columns: tag,
          missing = 0).
        - ``sparse_matrix``: COO matrix of shape (n_tags, n_users); note that
          rows are tags.
        - ``user_map`` / ``tag_map``: map user ids / tags to those indices.
    """
    if df is None:
        df = pd.read_sql("SELECT user_id, tag, score FROM user_tag_scores", engine)
    else:
        df = df.copy()
    print(f"user_tag_scores 记录数: {len(df)}")

    # String ids so lookups with str(user_id) elsewhere match.
    df['user_id'] = df['user_id'].astype(str)
    # A DECIMAL score column may come back as decimal.Decimal objects; make it
    # float so pivoting, cosine similarity and score arithmetic all work.
    df['score'] = df['score'].astype(float)

    user_map = {u: i for i, u in enumerate(df['user_id'].unique())}
    tag_map = {t: i for i, t in enumerate(df['tag'].unique())}
    df['user_index'] = df['user_id'].map(user_map)
    df['tag_index'] = df['tag'].map(tag_map)
    matrix = df.pivot_table(index='user_id', columns='tag', values='score', fill_value=0)
    # Explicit shape: scipy cannot infer it from empty index arrays.
    sparse_matrix = coo_matrix(
        (df['score'], (df['tag_index'], df['user_index'])),
        shape=(len(tag_map), len(user_map)),
    )
    return df, matrix, sparse_matrix, user_map, tag_map
+
+# === ✅ 基于 fastText 的语义相似推荐方法 ===
def semantic_recommend(user_id, topn=5):
    """Recommend tags that are semantically close to the user's top tags.

    Reads ``user_tag_scores``, takes the user's three highest-scored tags as
    interests, and ranks every other tag in the table by its mean fastText
    cosine similarity to those interest tags.

    :param user_id: user identifier; compared as a string
    :param topn: number of tags to return
    :return: up to ``topn`` tags, most similar first; ``[]`` when the user
        has no tag records
    """
    print(f"正在为用户 {user_id} 生成推荐...")

    df = pd.read_sql("SELECT user_id, tag, score FROM user_tag_scores", engine)
    print(f"总记录数: {len(df)}")
    user_id = str(user_id)
    # Normalize the column as well: comparing a str against a numeric
    # user_id column never matches, so every user got no recommendations.
    df['user_id'] = df['user_id'].astype(str)

    # The user's tags, highest score first.
    user_tags = (
        df[df['user_id'] == user_id]
        .sort_values(by="score", ascending=False)['tag']
        .tolist()
    )
    print(f"用户 {user_id} 的标签(按分数排序): {user_tags}")

    if not user_tags:
        print(f"用户 {user_id} 没有标签记录,返回空推荐结果。")
        return []

    # The three strongest tags act as the user's interests.
    interest_tags = user_tags[:3]
    print(f"用户 {user_id} 的 Top 3 标签: {interest_tags}")

    all_tags = df['tag'].unique()
    print(f"所有唯一标签数量: {len(all_tags)}")

    candidates = [tag for tag in all_tags if tag not in interest_tags]
    if not candidates:
        return []

    # One pairwise call: rows are candidates, columns are interest tags; each
    # word vector is computed exactly once.
    similarity = cosine_similarity(
        [ft_model.get_word_vector(tag) for tag in candidates],
        [ft_model.get_word_vector(tag) for tag in interest_tags],
    )
    avg_scores = similarity.mean(axis=1)

    ranked = sorted(zip(candidates, avg_scores), key=lambda item: item[1], reverse=True)[:topn]
    print(f"\n最终推荐标签(前 {topn}):")
    for tag, score in ranked:
        print(f"{tag}: {score:.4f}")

    return [tag for tag, _ in ranked]
+
+# === ✅ ItemCF 推荐方法 ===
+import os
+import pickle
+
def itemcf_recommend(user_id, matrix, sim_path="./models/itemcf_sim.pkl", topn=5):
    """Recommend tags with item-based collaborative filtering.

    Each tag the user has a positive score for adds
    ``similarity(tag, candidate) * user_score`` to every candidate tag the
    user does not already have. Candidates are ranked by the summed score.

    :param user_id: user identifier; looked up in ``matrix.index`` as a string
    :param matrix: user x tag score DataFrame indexed by user id strings
    :param sim_path: pickled tag x tag similarity DataFrame; it is trained
        into this same path via ``train_and_save_itemcf`` when missing
    :param topn: number of tags to return
    :return: up to ``topn`` tags, highest score first; ``[]`` when the user is
        unknown, has no positive scores, or no similarity data applies
    """
    user_id = str(user_id)  # matrix index holds user ids as strings

    if user_id not in matrix.index:
        print(f"⚠️ 用户 {user_id} 不在评分矩阵中。")
        return []

    if not os.path.exists(sim_path):
        print(f"⚠️ 相似度矩阵 {sim_path} 不存在,重新训练。")
        # Train into sim_path itself, otherwise the open() below would fail
        # for any non-default path.
        train_and_save_itemcf(path=sim_path)

    # NOTE(review): pickle.load can execute code from the file; acceptable only
    # because the file is produced locally by train_and_save_itemcf.
    with open(sim_path, "rb") as f:
        sim_df = pickle.load(f)

    user_row = matrix.loc[user_id]
    user_tags = user_row[user_row > 0]

    if user_tags.empty:
        print(f"⚠️ 用户 {user_id} 没有任何标签评分记录。")
        return []

    print(f"用户 {user_id} 的标签评分:\n{user_tags}")

    scores = {}
    for tag, val in user_tags.items():
        if tag not in sim_df:
            print(f"标签 {tag} 在相似度矩阵中不存在,跳过。")
            continue
        # Never recommend tags the user already has.
        sims = sim_df[tag].drop(index=user_tags.index, errors="ignore")
        # float(): float * decimal.Decimal raises TypeError for DECIMAL scores.
        weight = float(val)
        for sim_tag, sim_score in sims.items():
            scores[sim_tag] = scores.get(sim_tag, 0) + sim_score * weight

    if not scores:
        print(f"⚠️ 用户 {user_id} 无法生成推荐,可能是标签相似度不足。")
        return []

    sorted_tags = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    print(f"推荐得分(前{topn}):\n", sorted_tags[:topn])

    return [tag for tag, _ in sorted_tags[:topn]]
+
+
+# === ✅ ItemCF 相似度训练 ===
def train_and_save_itemcf(path="./models/itemcf_sim.pkl"):
    """Train the ItemCF tag-similarity matrix and pickle it to ``path``.

    Tags are compared by the cosine similarity of their score columns in the
    user x tag matrix from get_user_tag_matrix(). The result is a DataFrame
    indexed and labelled by tag on both axes.
    """
    import tempfile

    _, matrix, _, _, _ = get_user_tag_matrix()
    if matrix.empty:
        # cosine_similarity rejects empty input; store an empty matrix so
        # itemcf_recommend finds no similar tags instead of crashing here.
        sim_df = pd.DataFrame(index=matrix.columns, columns=matrix.columns, dtype=float)
    else:
        tag_sim = cosine_similarity(matrix.T)
        sim_df = pd.DataFrame(tag_sim, index=matrix.columns, columns=matrix.columns)

    # The target directory (./models by default) may not exist yet.
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)

    # Write to a unique temp file in the same directory, then rename it over
    # the target, so concurrent readers never unpickle a half-written file.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(sim_df, f)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("ItemCF 相似度矩阵已保存 ✅")
+
+# === ✅ Flask 推荐接口 ===
+import random
+
@app.route("/recommend_torrents", methods=["POST"])
def recommend_torrents():
    """Return an ordered list of torrent ids for a user.

    Expects a JSON object body ``{"user_id": ...}``. Tags recommended by
    ItemCF and by fastText similarity are shuffled and placed first, followed
    by a few randomly sampled other tags for exploration. The resulting tag
    list is handed to get_torrent_ids_by_tags.

    Responds with ``{"torrent_ids": [...]}``, or a 400 error when ``user_id``
    is missing.
    """
    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so clients get the same 400 as for a missing user_id.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_id = data.get("user_id")

    if not user_id:
        return jsonify({"error": "缺少 user_id"}), 400

    df, matrix, _, _, _ = get_user_tag_matrix()

    # Personalised tag recommendations from both methods.
    itemcf_result = itemcf_recommend(user_id, matrix)
    semantic_result = semantic_recommend(user_id)

    print(f"ItemCF 推荐标签: {itemcf_result}")
    print(f"Semantic 推荐标签: {semantic_result}")

    # Number of non-personalised tags mixed in for exploration.
    explore_count = 3

    # Personalised tags (deduplicated, first occurrence kept) in random order.
    # They must stay ahead of, and separate from, the random tags: ranking
    # every tag in the table by an equal random priority and passing all of
    # them on makes the result a random permutation in which the ItemCF and
    # Semantic recommendations have no effect.
    personalised = list(dict.fromkeys(itemcf_result + semantic_result))
    random.shuffle(personalised)

    used_tags = set(personalised)
    other_tags = [tag for tag in df['tag'].unique().tolist() if tag not in used_tags]
    # random.sample already returns its picks in random order.
    explore_tags = random.sample(other_tags, min(explore_count, len(other_tags)))

    final_tags = personalised + explore_tags
    print(f"最终推荐标签: {final_tags}")
    torrent_ids = get_torrent_ids_by_tags(final_tags)

    return jsonify({"torrent_ids": torrent_ids})
+
+
+
+from sqlalchemy.sql import text
+
+import random
+from sqlalchemy import text
+
def get_torrent_ids_by_tags(tags, limit_per_tag=10):
    """Return torrent ids: ids for the given tags first, then all others.

    Up to ``limit_per_tag`` torrent ids are collected per tag from
    ``bt_torrent_tags``. The collected ids (deduplicated) come first in
    shuffled order, followed by every other distinct torrent id in the
    table, also shuffled. A falsy ``tags`` counts as no tags.
    """
    tags = tags or []
    per_tag_query = text("""
        SELECT torrent_id
        FROM bt_torrent_tags
        WHERE tag = :tag
        LIMIT :limit
    """)
    all_ids_query = text("SELECT DISTINCT torrent_id FROM bt_torrent_tags")

    recommended_ids = set()
    with engine.connect() as conn:
        for tag in tags:
            rows = conn.execute(per_tag_query, {"tag": tag, "limit": limit_per_tag})
            recommended_ids.update(row[0] for row in rows)

        # Every torrent id in the table, used for the non-recommended tail.
        all_ids = {row[0] for row in conn.execute(all_ids_query)}

    # Shuffle each group independently; recommended ids stay in front.
    recommended_list = list(recommended_ids)
    remaining_list = list(all_ids - recommended_ids)
    for group in (recommended_list, remaining_list):
        random.shuffle(group)

    return recommended_list + remaining_list
+
+
+# === ✅ 启动服务 ===
if __name__ == '__main__':
    # Rebuild the ItemCF similarity matrix, then serve the Flask app through
    # waitress on 0.0.0.0:5000 with 16 threads.
    train_and_save_itemcf()

    import waitress

    waitress.serve(app, host="0.0.0.0", port=5000, threads=16)
diff --git a/recommend/requirements.txt b/recommend/requirements.txt
new file mode 100644
index 0000000..2efe47a
--- /dev/null
+++ b/recommend/requirements.txt
@@ -0,0 +1,18 @@
+fasttext_wheel==0.9.2
+Flask==3.1.1
+huggingface_hub==0.31.2
+jieba==0.42.1
+mysql_connector_repackaged==0.3.1
+numpy==2.2.6
+pandas==2.2.3
+scikit_learn==1.6.1
+scikit_surprise==1.1.4
+scipy==1.15.3
+sentence_transformers==4.1.0
+SQLAlchemy==2.0.41
+surprise==0.1
+tensorflow==2.19.0
+tensorflow_hub==0.16.1
+torch==2.7.0
+transformers==4.51.3
+waitress==3.0.2