优化推荐系统和冷启动
Change-Id: I93d3091f249f2396a25702e01eb8dd5a9e95e8bc
diff --git a/.gitignore b/.gitignore
index 3c68961..c319c04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,7 @@
torrents/
appeals/
migrations/
-front/node_modules
\ No newline at end of file
+front/node_modules
+recommend/user_seed_graph.txt
+recommend/model/__pycache__/
+recommend/utils/__pycache__/
\ No newline at end of file
diff --git a/front/src/HomePage.js b/front/src/HomePage.js
index 026f122..088b284 100644
--- a/front/src/HomePage.js
+++ b/front/src/HomePage.js
@@ -1,4 +1,4 @@
-import React from "react";
+import React, { useEffect, useState } from "react";
import HomeIcon from "@mui/icons-material/Home";
import MovieIcon from "@mui/icons-material/Movie";
import TvIcon from "@mui/icons-material/Tv";
@@ -28,40 +28,40 @@
{ label: "求种", icon: <HelpIcon className="emerald-nav-icon" />, path: "/begseed", type: "help" },
];
-// 示例种子数据
-const exampleSeeds = [
- {
- id: 1,
- tags: "电影,科幻",
- title: "三体 1080P 蓝光",
- popularity: 123,
- user: { username: "Alice" },
- },
- {
- id: 2,
- tags: "动漫,热血",
- title: "灌篮高手 国语配音",
- popularity: 88,
- user: { username: "Bob" },
- },
- {
- id: 3,
- tags: "音乐,流行",
- title: "周杰伦-稻香",
- popularity: 56,
- user: { username: "Jay" },
- },
- {
- id: 4,
- tags: "剧集,悬疑",
- title: "隐秘的角落",
- popularity: 77,
- user: { username: "小明" },
- },
-];
-
export default function HomePage() {
const navigate = useNavigate();
+ const [recommendSeeds, setRecommendSeeds] = useState([]);
+ const [loading, setLoading] = useState(true);
+
+ useEffect(() => {
+ // 获取当前登录用户ID
+ const match = document.cookie.match('(^|;)\\s*userId=([^;]+)');
+ const userId = match ? match[2] : null;
+
+ if (!userId) {
+ setRecommendSeeds([]);
+ setLoading(false);
+ return;
+ }
+
+ setLoading(true);
+ fetch("http://10.126.59.25:5000/recommend", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json"
+ },
+ body: JSON.stringify({ user_id: userId })
+ })
+ .then(res => res.json())
+ .then(data => {
+ setRecommendSeeds(Array.isArray(data.recommend) ? data.recommend : []);
+ setLoading(false);
+ })
+ .catch(() => {
+ setRecommendSeeds([]);
+ setLoading(false);
+ });
+ }, []);
return (
<div className="emerald-home-container">
@@ -134,24 +134,38 @@
<thead>
<tr>
<th>分类标签</th>
- <th>资源标题</th>
- <th>热门指数</th>
+ <th>标题</th>
<th>发布者</th>
+ <th>大小</th>
+ <th>热度</th>
+ <th>折扣倍率</th>
</tr>
</thead>
<tbody>
- {exampleSeeds.map((seed) => (
- <tr key={seed.id}>
- <td>{seed.tags}</td>
- <td>
- <a href={`/torrent/${seed.id}`}>
- {seed.title}
- </a>
- </td>
- <td>{seed.popularity}</td>
- <td>{seed.user.username}</td>
+ {loading ? (
+ <tr>
+ <td colSpan={6} style={{ textAlign: "center", color: "#888" }}>正在加载推荐种子...</td>
</tr>
- ))}
+ ) : recommendSeeds.length === 0 ? (
+ <tr>
+ <td colSpan={6} style={{ textAlign: "center", color: "#888" }}>暂无推荐数据</td>
+ </tr>
+ ) : (
+ recommendSeeds.map((seed) => (
+ <tr key={seed.seed_id}>
+ <td>{seed.tags}</td>
+ <td>
+ <a href={`/torrent/${seed.seed_id}`}>
+ {seed.title}
+ </a>
+ </td>
+ <td>{seed.username}</td>
+ <td>{seed.size}</td>
+ <td>{seed.popularity}</td>
+ <td>{seed.discount == null ? 1 : seed.discount}</td>
+ </tr>
+ ))
+ )}
</tbody>
</table>
</div>
diff --git a/front/src/config.js b/front/src/config.js
index 24fe7d1..1309822 100644
--- a/front/src/config.js
+++ b/front/src/config.js
@@ -1 +1 @@
-export const API_BASE_URL = "http://10.126.59.25:8083";
+export const API_BASE_URL = "http://10.126.59.25:9999";
diff --git a/recommend/inference.py b/recommend/inference.py
index 9601230..e2b7b00 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -1,32 +1,110 @@
import sys
sys.path.append('./')
-import time
import torch
-import numpy as np
-from os import path
+import pymysql
+from flask_cors import CORS
+from flask import Flask, request, jsonify
from model.LightGCN import LightGCN
from utils.parse_args import args
from utils.data_loader import EdgeListData
-from utils.data_generator import build_user_seed_graph
+from utils.graph_build import build_user_seed_graph
+
+app = Flask(__name__)
+CORS(app)
args.device = 'cuda:7'
args.data_path = './user_seed_graph.txt'
args.pre_model_path = './model/LightGCN_pretrained.pt'
-def run_inference(user_id=1):
- # 1. 实时生成user-seed交互图
- print("正在生成用户-种子交互文件...")
- build_user_seed_graph()
+# MySQL connection settings passed to pymysql.connect(). NOTE(review): credentials are hard-coded here — consider loading them from the environment.
+DB_CONFIG = {
+ 'host': '10.126.59.25',
+ 'port': 3306,
+ 'user': 'root',
+ 'password': '123456',
+ 'database': 'pt_database_test',
+ 'charset': 'utf8mb4'
+}
- # 2. 加载数据集
- print("正在加载数据集...")
- t_data_start = time.time()
+TOPK = 2 # default number of recommendations returned
+
+def user_cold_start(topk=TOPK):
+    """Return the ``topk`` most popular seeds with full details (cold-start fallback).
+
+    Used by ``run_inference`` for users that are absent from the user-seed graph mapping.
+
+    Args:
+        topk: Maximum number of seeds to return.
+
+    Returns:
+        A list of dicts, most popular first, with keys ``seed_id``, ``tags``,
+        ``title``, ``size``, ``username``, ``popularity`` and ``discount``.
+        ``username`` defaults to ``""`` and ``discount`` to ``1`` when no
+        matching User / SeedPromotion row is found.
+    """
+    conn = pymysql.connect(**DB_CONFIG)
+    try:
+        # try/finally + cursor context manager: nothing leaks if a query raises.
+        with conn.cursor() as cursor:
+            # int() keeps LIMIT numeric even when the caller passes a numeric string.
+            cursor.execute(
+                "SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed ORDER BY popularity DESC LIMIT %s",
+                (int(topk),)
+            )
+            seed_rows = cursor.fetchall()
+            seed_ids = [row[0] for row in seed_rows]
+
+            user_map, promo_map = {}, {}
+            if seed_rows:
+                # Resolve usernames for the distinct owners of the selected seeds.
+                owner_ids = list({row[1] for row in seed_rows})
+                placeholders = ','.join(['%s'] * len(owner_ids))
+                cursor.execute(
+                    f"SELECT user_id, username FROM User WHERE user_id IN ({placeholders})",
+                    tuple(owner_ids)
+                )
+                user_map = {row[0]: row[1] for row in cursor.fetchall()}
+
+                # Look up promotion discounts; seeds without one keep the default of 1.
+                placeholders = ','.join(['%s'] * len(seed_ids))
+                cursor.execute(
+                    f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({placeholders})",
+                    tuple(seed_ids)
+                )
+                promo_map = {row[0]: row[1] for row in cursor.fetchall()}
+    finally:
+        conn.close()
+
+    seed_list = []
+    for row in seed_rows:
+        seed_id, owner_user_id = row[0], row[1]
+        seed_list.append({
+            'seed_id': seed_id,
+            'tags': row[2],
+            'title': row[3],
+            'size': row[4],
+            'username': user_map.get(owner_user_id, ""),
+            'popularity': row[5],
+            'discount': promo_map.get(seed_id, 1)
+        })
+    return seed_list
+
+def run_inference(user_id, topk=TOPK):
+ """
+ 输入: user_id, topk
+ 输出: 推荐的topk个种子ID列表(原始种子ID字符串)
+ """
+ user2idx, seed2idx = build_user_seed_graph(return_mapping=True)
+ idx2seed = {v: k for k, v in seed2idx.items()}
+
+ if user_id not in user2idx:
+ # 冷启动
+ return user_cold_start(topk)
+
+ user_idx = user2idx[user_id]
+
dataset = EdgeListData(args.data_path, args.data_path)
- t_data_end = time.time()
-
- # 3. 加载LightGCN模型
- print("正在加载模型参数...")
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
@@ -35,23 +113,90 @@
model.load_state_dict(pretrained_dict, strict=False)
model.eval()
- # 4. 推理
- print(f"正在为用户 {user_id} 推理推荐结果...")
- t_infer_start = time.time()
with torch.no_grad():
user_emb, item_emb = model.generate()
- user_vec = user_emb[user_id].unsqueeze(0)
+ user_vec = user_emb[user_idx].unsqueeze(0)
scores = model.rating(user_vec, item_emb).squeeze(0)
- pred_item = torch.argmax(scores).item()
- t_infer_end = time.time()
+ topk_indices = torch.topk(scores, topk).indices.cpu().numpy()
+ topk_seed_ids = [idx2seed[idx] for idx in topk_indices]
- print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
- print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
- print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
+ return topk_seed_ids
+
+def seed_info(topk_seed_ids):
+    """Fetch display details for the given seed IDs, preserving their order.
+
+    Args:
+        topk_seed_ids: Seed IDs to look up; an empty value short-circuits to ``[]``.
+
+    Returns:
+        A list of dicts with keys ``seed_id``, ``tags``, ``title``, ``size``,
+        ``username``, ``popularity`` and ``discount``, in input order.  IDs with
+        no Seed row are skipped; ``username`` defaults to ``""`` and
+        ``discount`` to ``1`` when no User / SeedPromotion row matches.
+    """
+    if not topk_seed_ids:
+        return []
+
+    seed_params = tuple(topk_seed_ids)
+    seed_marks = ','.join(['%s'] * len(seed_params))
+    conn = pymysql.connect(**DB_CONFIG)
+    try:
+        # try/finally + cursor context manager: nothing leaks if a query raises.
+        with conn.cursor() as cursor:
+            cursor.execute(
+                f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed WHERE seed_id IN ({seed_marks})",
+                seed_params
+            )
+            seed_map = {row[0]: row for row in cursor.fetchall()}
+
+            user_map = {}
+            owner_ids = list({row[1] for row in seed_map.values()})
+            if owner_ids:  # an empty IN () list is invalid SQL, so skip the query
+                owner_marks = ','.join(['%s'] * len(owner_ids))
+                cursor.execute(
+                    f"SELECT user_id, username FROM User WHERE user_id IN ({owner_marks})",
+                    tuple(owner_ids)
+                )
+                user_map = {row[0]: row[1] for row in cursor.fetchall()}
+
+            cursor.execute(
+                f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({seed_marks})",
+                seed_params
+            )
+            promo_map = {row[0]: row[1] for row in cursor.fetchall()}
+    finally:
+        conn.close()
+
+    seed_list = []
+    for seed_id in topk_seed_ids:
+        row = seed_map.get(seed_id)
+        if not row:
+            continue
+        seed_list.append({
+            'seed_id': seed_id,
+            'tags': row[2],
+            'title': row[3],
+            'size': row[4],
+            'username': user_map.get(row[1], ""),
+            'popularity': row[5],
+            'discount': promo_map.get(seed_id, 1)
+        })
+    return seed_list
+
+@app.route('/recommend', methods=['POST'])
+def recommend():
+    """Handle POST /recommend: reply ``{'recommend': [...]}`` or ``{'error': msg}`` with 400."""
+    # silent=True makes a missing/malformed JSON body come back as None instead of raising.
+    data = request.get_json(silent=True)
+    if not isinstance(data, dict):
+        return jsonify({'error': 'request body must be a JSON object'}), 400
+    try:
+        result = run_inference(data.get('user_id'))
+        # The cold-start path already returns detail dicts; ID lists still need a lookup.
+        cold_start = isinstance(result, list) and result and isinstance(result[0], dict)
+        return jsonify({'recommend': result if cold_start else seed_info(result)})
+    except Exception as e:  # service boundary: always answer with a JSON error
+        return jsonify({'error': str(e)}), 400
if __name__ == "__main__":
- t_start = time.time()
- user_id = 1
- run_inference(user_id)
- t_end = time.time()
- print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file
+ app.run(host='0.0.0.0', port=5000)
diff --git a/recommend/test_inference.py b/recommend/test_inference.py
new file mode 100644
index 0000000..b85dacd
--- /dev/null
+++ b/recommend/test_inference.py
@@ -0,0 +1,14 @@
+"""Smoke test: POST each sample user id to the /recommend endpoint and print the reply."""
+import requests
+
+API_URL = "http://10.126.59.25:5000/recommend"
+
+user_ids = ['user1', '550e8400-e29b-41d4-a716-446655440000']
+
+for user_id in user_ids:
+    try:
+        # requests applies no timeout by default; without one a dead server hangs the script.
+        response = requests.post(API_URL, json={'user_id': user_id}, timeout=10)
+        response.raise_for_status()
+        print(f"Recommendations for user_id '{user_id}': {response.json()}")
+    except requests.exceptions.RequestException as e:
+        print(f"Error occurred while fetching recommendations for user_id '{user_id}': {e}")
diff --git a/recommend/user_seed_graph.txt b/recommend/user_seed_graph.txt
index ef29920..5b5ad3d 100644
--- a/recommend/user_seed_graph.txt
+++ b/recommend/user_seed_graph.txt
@@ -1,3 +1,3 @@
-0 0 0 1 2 1746061954 1736237924 1736240066 1736309966
+0 0 0 1 1746061954 1736237924 1748971659
1 1 1 1746315010 1746583706
-2 2 0 2 1746738305 1746865166 1749284366
+2 2 1746738305
diff --git a/recommend/utils/data_generator.py b/recommend/utils/graph_build.py
similarity index 92%
rename from recommend/utils/data_generator.py
rename to recommend/utils/graph_build.py
index d3ef9cf..becf5df 100644
--- a/recommend/utils/data_generator.py
+++ b/recommend/utils/graph_build.py
@@ -9,7 +9,7 @@
SqlPassword = "123456"
-def fetch_data():
+def fetch_user_seed_data():
conn = pymysql.connect(
host=SqlURL,
port=SqlPort,
@@ -74,8 +74,10 @@
f.write(f"{uid}\t{items}\t{times}\n")
-def build_user_seed_graph():
- download_rows, favorite_rows = fetch_data()
+def build_user_seed_graph(return_mapping=False):
+ download_rows, favorite_rows = fetch_user_seed_data()
records, user_set, seed_set = process_records(download_rows, favorite_rows)
user2idx, seed2idx = build_id_maps(user_set, seed_set)
group_and_write(records, user2idx, seed2idx)
+ if return_mapping:
+ return user2idx, seed2idx
\ No newline at end of file