"修改了后端,现在用户下载会调用接口并给 tag 加分"
Change-Id: I6cde94104fcf3f4c19f30435a9d41ed679795d68
diff --git a/recommend/recommend.py b/recommend/recommend.py
index b216d52..e8c042d 100644
--- a/recommend/recommend.py
+++ b/recommend/recommend.py
@@ -14,13 +14,6 @@
# === ✅ SQLAlchemy 数据库连接 ===
engine = create_engine("mysql+pymysql://sy:sy_password@49.233.215.144:3306/pt_station")
-# === ✅ 加载 fastText 模型 ===
-fasttext_model_path = 'models\\cc.zh.300.bin'
-if not os.path.exists(fasttext_model_path):
- raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")
-print("加载 fastText 模型中...")
-ft_model = fasttext.load_model(fasttext_model_path)
-print("模型加载完成 ✅")
# === ✅ 用户标签行为矩阵构建 ===
def get_user_tag_matrix():
@@ -252,8 +245,156 @@
return recommended_list + remaining_list
+import os
+import time
+import jieba
+import fasttext
+from flask import Flask, request, jsonify
+import mysql.connector
+
+# ✅ 初始化数据库连接
+import mysql.connector
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
+
+def compute_similarity(vec1, vec2):
+ """计算两个词向量之间的余弦相似度"""
+ # 使用 cosine_similarity 计算相似度
+ return cosine_similarity([vec1], [vec2])[0][0]
+
+# Module-level MySQL connection, opened once when this module is imported.
+# NOTE(review): host and credentials are hard-coded here; consider loading them from configuration.
+db = mysql.connector.connect(
+ host="49.233.215.144",
+ port=3306,
+ user="sy",
+ password="sy_password",
+ database="pt_station"
+)
+
+# NOTE(review): this single cursor is used by every request handler below, while the
+# server is started with threads=16 — confirm that sharing it across threads is safe.
+cursor = db.cursor()
+
+
+app = Flask(__name__)
+
+# ✅ Load the fastText model (raises at import time if the model file is missing)
+fasttext_model_path = './models/cc.zh.300.bin'
+if not os.path.exists(fasttext_model_path):
+ raise FileNotFoundError("fastText 模型文件不存在,请检查路径。")
+
+print("加载 fastText 模型中...")
+ft_model = fasttext.load_model(fasttext_model_path)
+print("模型加载完成 ✅")
+# Model vocabulary as a set; not referenced by the code visible in this change.
+fasttext_vocab = set(ft_model.words)
+
+# ✅ Global tag cache (avoids querying the database on every request)
+existing_tags = set()
+
+def refresh_existing_tags():
+ """刷新数据库中已存在的标签集合"""
+ global existing_tags
+ cursor.execute("SELECT DISTINCT tag FROM user_tag_scores")
+ existing_tags = set(tag[0] for tag in cursor.fetchall())
+ print(f"已加载标签数: {len(existing_tags)}")
+
+# ✅ Populate the tag cache once at startup
+refresh_existing_tags()
+
+# ✅ 扩展函数:仅保留数据库已有标签
+def expand_tags_from_input(input_tags, topn=5, similarity_threshold=0.7):
+ """
+ 扩展输入标签列表,查找与之语义相似的标签,返回相似度大于阈值或最相似的前 n 个标签。
+
+ :param input_tags: 输入标签的列表,例如 ['电影', '动漫', '游戏', '1080p']
+ :param topn: 返回的最相似标签的数量,默认为 5
+ :param similarity_threshold: 相似度阈值,默认为 0.7
+ :return: 返回与输入标签相关的扩展标签列表
+ """
+ # 用于存储所有扩展标签及其相似度
+ tag_scores = {}
+
+ for tag in input_tags:
+ # 获取当前标签的词向量
+ tag_vector = ft_model.get_word_vector(tag)
+
+ # 遍历标签库中的所有标签并计算相似度
+ for db_tag in existing_tags:
+ db_tag_vector = ft_model.get_word_vector(db_tag)
+ similarity = compute_similarity(tag_vector, db_tag_vector)
+
+ if similarity >= similarity_threshold:
+ if db_tag not in tag_scores:
+ tag_scores[db_tag] = similarity
+ else:
+ tag_scores[db_tag] = max(tag_scores[db_tag], similarity)
+
+ # 根据相似度排序并返回前 n 个标签
+ sorted_tags = sorted(tag_scores.items(), key=lambda x: x[1], reverse=True)
+ top_tags = [tag for tag, _ in sorted_tags[:topn]]
+
+ return top_tags
+
+
+# ✅ 接口路由
+@app.route("/expand_tags", methods=["POST"])
+def expand_tags():
+ start_time = time.time()
+
+ # 从请求中获取数据
+ data = request.get_json()
+ input_tags = data.get("tags", [])
+ user_id = data.get("user_id")
+ rate = data.get("rate")
+
+ topn = data.get("topn", 10) # 默认为 5
+ similarity_threshold = 0.4 # 默认阈值为 0.7
+
+ if not input_tags or not user_id:
+ return jsonify({"error": "缺少参数 tags 或 user_id"}), 400
+
+ # 获取与输入标签最相关的标签
+ expanded_tags = expand_tags_from_input(input_tags, topn=topn, similarity_threshold=similarity_threshold)
+
+ # 打印日志
+ print(f"[用户 {user_id}] 输入标签: {input_tags}")
+ print(f"[用户 {user_id}] 匹配扩展标签: {expanded_tags}")
+
+ # 数据写入打分逻辑
+ token_set = set(input_tags) # 用于确定哪些标签是用户输入的标签
+ for tag in expanded_tags:
+ score = 2.0 * rate if tag in token_set else 2.0 * rate
+ try:
+ cursor.execute("""
+ INSERT INTO user_tag_scores (user_id, tag, score)
+ VALUES (%s, %s, %s)
+ ON DUPLICATE KEY UPDATE score = score + VALUES(score)
+ """, (user_id, tag, score))
+ except Exception as e:
+ print(f"插入失败 [{tag}]:", e)
+ db.commit()
+ # ⏳ 插入后立即查询该用户的所有标签和评分
+ try:
+ cursor.execute("""
+ SELECT tag, score FROM user_tag_scores WHERE user_id = %s
+ """, (user_id,))
+ user_scores = [{"tag": tag, "score": float(score)} for tag, score in cursor.fetchall()]
+ print(user_scores)
+ except Exception as e:
+ print(f"查询用户评分失败: {e}")
+
+ duration = round(time.time() - start_time, 3)
+ return jsonify({
+ "expanded_tags": expanded_tags,
+ })
+
+
+# ✅ 触发标签缓存刷新(可选:手动/接口/定时任务调用)
+@app.route("/refresh_tags", methods=["POST"])
+def refresh_tags():
+ refresh_existing_tags()
+ return jsonify({"status": "标签缓存已刷新", "count": len(existing_tags)})
+
+
# === ✅ 启动服务 ===
if __name__ == '__main__':
- train_and_save_itemcf()
+ # NOTE(review): the train_and_save_itemcf() call is commented out by this change, so it
+ # no longer runs at startup — confirm that is intended, then remove the dead line.
+ #train_and_save_itemcf()
from waitress import serve
serve(app, host="0.0.0.0", port=5000, threads=16)
diff --git a/ruoyi-admin/src/main/java/com/ruoyi/torrent/controller/BtTorrentController.java b/ruoyi-admin/src/main/java/com/ruoyi/torrent/controller/BtTorrentController.java
index 559547b..035f258 100644
--- a/ruoyi-admin/src/main/java/com/ruoyi/torrent/controller/BtTorrentController.java
+++ b/ruoyi-admin/src/main/java/com/ruoyi/torrent/controller/BtTorrentController.java
@@ -14,9 +14,11 @@
import com.ruoyi.common.utils.SecurityUtils;
import com.ruoyi.torrent.domain.BtTorrentAnnounce;
import com.ruoyi.torrent.domain.BtTorrentFile;
+import com.ruoyi.torrent.domain.BtTorrentTags;
import com.ruoyi.torrent.service.IBtTorrentAnnounceService;
import com.ruoyi.torrent.service.IBtTorrentFileService;
import com.ruoyi.torrent.service.IBtTorrentService;
+import com.ruoyi.torrent.service.IBtTorrentTagsService;
import com.ruoyi.torrent.util.TorrentFileUtil;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.access.prepost.PreAuthorize;
@@ -56,6 +58,8 @@
{
@Autowired
private IBtTorrentService btTorrentService;
+ @Autowired
+ private IBtTorrentTagsService btTorrentTagsService;
@Autowired
private IBtTorrentFileService btTorrentFileService;
@@ -203,26 +207,43 @@
@GetMapping("/download/{id}")
public void downloadTorrent(@PathVariable("id") Long id, HttpServletResponse response) {
try {
+ String userId = String.valueOf(getUserId()); // id of the current user, sent to the tag service as user_id
+
// 1. 调用 getTorrentInfo(id) 获取文件路径
- BtTorrent btTorrent= btTorrentService.selectBtTorrentByTorrentId(id);
- String filePath=btTorrent.getFilePath();
- String fileName=btTorrent.getName();
+ // NOTE(review): btTorrent is dereferenced without a null check — confirm the service never returns null for an unknown id
+ BtTorrent btTorrent = btTorrentService.selectBtTorrentByTorrentId(id);
+ String filePath = btTorrent.getFilePath();
+ String fileName = btTorrent.getName();
+ long torrentId = btTorrent.getTorrentId();
+
+ // Query the tag service with a BtTorrentTags probe that carries only the torrent id
+ BtTorrentTags query = new BtTorrentTags();
+ query.setTorrentId(torrentId);
+ List<BtTorrentTags> tags = btTorrentTagsService.selectBtTorrentTagsList(query);
+
+ // ✅ Build the JSON body for the Flask tag-expansion endpoint
+ ObjectMapper mapper = new ObjectMapper();
+ Map<String, Object> bodyMap = new HashMap<>();
+ bodyMap.put("user_id", userId);
+ // NOTE(review): this serializes whole BtTorrentTags objects, while /expand_tags passes each
+ // entry of "tags" to the fastText model as a tag string; presumably the tag names should be sent — confirm
+ bodyMap.put("tags", tags);
+ bodyMap.put("rate", 2);
+ bodyMap.put("topn", 10); // optional
+ String jsonRequest = mapper.writeValueAsString(bodyMap);
+
+ // Call the Flask endpoint
+ // NOTE(review): this runs before the file is sent, so if sendJsonPost throws, the generic
+ // catch below answers 500 and the download never happens — confirm that is intended
+ String flaskUrl = "http://127.0.0.1:5000/expand_tags";
+ String flaskResponse = sendJsonPost(flaskUrl, jsonRequest);
+ System.out.println("标签扩展返回: " + flaskResponse);
// 2. 使用工具方法下载文件(不删除源文件)
TorrentFileUtil.downloadFile(response, filePath, fileName, false);
} catch (FileNotFoundException e) {
-
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
System.out.println(e.getMessage());
} catch (Exception e) {
-
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
System.out.println(e.getMessage());
-
}
-
}
/**