优化推荐系统和冷启动

Change-Id: I93d3091f249f2396a25702e01eb8dd5a9e95e8bc
diff --git a/.gitignore b/.gitignore
index 3c68961..c319c04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,7 @@
 torrents/
 appeals/
 migrations/
-front/node_modules
\ No newline at end of file
+front/node_modules
+recommend/user_seed_graph.txt
+recommend/model/__pycache__/
+recommend/utils/__pycache__/
\ No newline at end of file
diff --git a/front/src/HomePage.js b/front/src/HomePage.js
index 026f122..088b284 100644
--- a/front/src/HomePage.js
+++ b/front/src/HomePage.js
@@ -1,4 +1,4 @@
-import React from "react";
+import React, { useEffect, useState } from "react";
 import HomeIcon from "@mui/icons-material/Home";
 import MovieIcon from "@mui/icons-material/Movie";
 import TvIcon from "@mui/icons-material/Tv";
@@ -28,40 +28,40 @@
     { label: "求种", icon: <HelpIcon className="emerald-nav-icon" />, path: "/begseed", type: "help" },
 ];
 
-// 示例种子数据
-const exampleSeeds = [
-    {
-        id: 1,
-        tags: "电影,科幻",
-        title: "三体 1080P 蓝光",
-        popularity: 123,
-        user: { username: "Alice" },
-    },
-    {
-        id: 2,
-        tags: "动漫,热血",
-        title: "灌篮高手 国语配音",
-        popularity: 88,
-        user: { username: "Bob" },
-    },
-    {
-        id: 3,
-        tags: "音乐,流行",
-        title: "周杰伦-稻香",
-        popularity: 56,
-        user: { username: "Jay" },
-    },
-    {
-        id: 4,
-        tags: "剧集,悬疑",
-        title: "隐秘的角落",
-        popularity: 77,
-        user: { username: "小明" },
-    },
-];
-
 export default function HomePage() {
     const navigate = useNavigate();
+    const [recommendSeeds, setRecommendSeeds] = useState([]);
+    const [loading, setLoading] = useState(true);
+
+    useEffect(() => {
+        // 获取当前登录用户ID
+        const match = document.cookie.match('(^|;)\\s*userId=([^;]+)');
+        const userId = match ? match[2] : null;
+
+        if (!userId) {
+            setRecommendSeeds([]);
+            setLoading(false);
+            return;
+        }
+
+        setLoading(true);
+        fetch("http://10.126.59.25:5000/recommend", {
+            method: "POST",
+            headers: {
+                "Content-Type": "application/json"
+            },
+            body: JSON.stringify({ user_id: userId })
+        })
+            .then(res => res.json())
+            .then(data => {
+                setRecommendSeeds(Array.isArray(data.recommend) ? data.recommend : []);
+                setLoading(false);
+            })
+            .catch(() => {
+                setRecommendSeeds([]);
+                setLoading(false);
+            });
+    }, []);
 
     return (
         <div className="emerald-home-container">
@@ -134,24 +134,38 @@
                         <thead>
                             <tr>
                                 <th>分类标签</th>
-                                <th>资源标题</th>
-                                <th>热门指数</th>
+                                <th>标题</th>
                                 <th>发布者</th>
+                                <th>大小</th>
+                                <th>热度</th>
+                                <th>折扣倍率</th>
                             </tr>
                         </thead>
                         <tbody>
-                            {exampleSeeds.map((seed) => (
-                                <tr key={seed.id}>
-                                    <td>{seed.tags}</td>
-                                    <td>
-                                        <a href={`/torrent/${seed.id}`}>
-                                            {seed.title}
-                                        </a>
-                                    </td>
-                                    <td>{seed.popularity}</td>
-                                    <td>{seed.user.username}</td>
+                            {loading ? (
+                                <tr>
+                                    <td colSpan={6} style={{ textAlign: "center", color: "#888" }}>正在加载推荐种子...</td>
                                 </tr>
-                            ))}
+                            ) : recommendSeeds.length === 0 ? (
+                                <tr>
+                                    <td colSpan={6} style={{ textAlign: "center", color: "#888" }}>暂无推荐数据</td>
+                                </tr>
+                            ) : (
+                                recommendSeeds.map((seed) => (
+                                    <tr key={seed.seed_id}>
+                                        <td>{seed.tags}</td>
+                                        <td>
+                                            <a href={`/torrent/${seed.seed_id}`}>
+                                                {seed.title}
+                                            </a>
+                                        </td>
+                                        <td>{seed.username}</td>
+                                        <td>{seed.size}</td>
+                                        <td>{seed.popularity}</td>
+                                        <td>{seed.discount == null ? 1 : seed.discount}</td>
+                                    </tr>
+                                ))
+                            )}
                         </tbody>
                     </table>
                 </div>
diff --git a/front/src/config.js b/front/src/config.js
index 24fe7d1..1309822 100644
--- a/front/src/config.js
+++ b/front/src/config.js
@@ -1 +1 @@
-export const API_BASE_URL = "http://10.126.59.25:8083";
+export const API_BASE_URL = "http://10.126.59.25:9999";
diff --git a/recommend/inference.py b/recommend/inference.py
index 9601230..e2b7b00 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -1,32 +1,110 @@
 import sys
 sys.path.append('./')
 
-import time
 import torch
-import numpy as np
-from os import path
+import pymysql
+from flask_cors import CORS
+from flask import Flask, request, jsonify
 from model.LightGCN import LightGCN
 from utils.parse_args import args
 from utils.data_loader import EdgeListData
-from utils.data_generator import build_user_seed_graph
+from utils.graph_build import build_user_seed_graph
+
+app = Flask(__name__)
+CORS(app)
 
 args.device = 'cuda:7'
 args.data_path = './user_seed_graph.txt'
 args.pre_model_path = './model/LightGCN_pretrained.pt'
 
-def run_inference(user_id=1):
-    # 1. 实时生成user-seed交互图
-    print("正在生成用户-种子交互文件...")
-    build_user_seed_graph()
+# 数据库连接配置
+DB_CONFIG = {
+    'host': '10.126.59.25',
+    'port': 3306,
+    'user': 'root',
+    'password': '123456',
+    'database': 'pt_database_test',
+    'charset': 'utf8mb4'
+}
 
-    # 2. 加载数据集
-    print("正在加载数据集...")
-    t_data_start = time.time()
+TOPK = 2  # 默认推荐数量
+
+def user_cold_start(topk=TOPK):
+    """
+    冷启动:直接返回热度最高的topk个种子详细信息
+    """
+    conn = pymysql.connect(**DB_CONFIG)
+    cursor = conn.cursor()
+
+    # 查询热度最高的topk个种子
+    cursor.execute(
+        f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed ORDER BY popularity DESC LIMIT %s",
+        (topk,)
+    )
+    seed_rows = cursor.fetchall()
+    seed_ids = [row[0] for row in seed_rows]
+    seed_map = {row[0]: row for row in seed_rows}
+
+    # 查询用户信息
+    owner_ids = list(set(row[1] for row in seed_rows))
+    if owner_ids:
+        format_strings_user = ','.join(['%s'] * len(owner_ids))
+        cursor.execute(
+            f"SELECT user_id, username FROM User WHERE user_id IN ({format_strings_user})",
+            tuple(owner_ids)
+        )
+        user_rows = cursor.fetchall()
+        user_map = {row[0]: row[1] for row in user_rows}
+    else:
+        user_map = {}
+
+    # 查询促销信息
+    if seed_ids:
+        format_strings = ','.join(['%s'] * len(seed_ids))
+        cursor.execute(
+            f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({format_strings})",
+            tuple(seed_ids)
+        )
+        promo_rows = cursor.fetchall()
+        promo_map = {row[0]: row[1] for row in promo_rows}
+    else:
+        promo_map = {}
+
+    cursor.close()
+    conn.close()
+
+    seed_list = []
+    for seed_id in seed_ids:
+        row = seed_map.get(seed_id)
+        if not row:
+            continue
+        owner_user_id = row[1]
+        seed_list.append({
+            'seed_id': seed_id,
+            'tags': row[2],
+            'title': row[3],
+            'size': row[4],
+            'username': user_map.get(owner_user_id, ""),
+            'popularity': row[5],
+            'discount': promo_map.get(seed_id, 1)
+        })
+    return seed_list
+
+def run_inference(user_id, topk=TOPK):
+    """
+    输入: user_id, topk
+    输出: 推荐的topk个种子ID列表(原始种子ID字符串)
+    """
+    user2idx, seed2idx = build_user_seed_graph(return_mapping=True)
+    idx2seed = {v: k for k, v in seed2idx.items()}
+
+    if user_id not in user2idx:
+        # 冷启动
+        return user_cold_start(topk)
+
+    user_idx = user2idx[user_id]
+
     dataset = EdgeListData(args.data_path, args.data_path)
-    t_data_end = time.time()
-
-    # 3. 加载LightGCN模型
-    print("正在加载模型参数...")
     pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
     pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
     pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
@@ -35,23 +113,90 @@
     model.load_state_dict(pretrained_dict, strict=False)
     model.eval()
 
-    # 4. 推理
-    print(f"正在为用户 {user_id} 推理推荐结果...")
-    t_infer_start = time.time()
     with torch.no_grad():
         user_emb, item_emb = model.generate()
-        user_vec = user_emb[user_id].unsqueeze(0)
+        user_vec = user_emb[user_idx].unsqueeze(0)
         scores = model.rating(user_vec, item_emb).squeeze(0)
-        pred_item = torch.argmax(scores).item()
-    t_infer_end = time.time()
+        topk_indices = torch.topk(scores, topk).indices.cpu().numpy()
+        topk_seed_ids = [idx2seed[idx] for idx in topk_indices]
 
-    print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
-    print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
-    print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
+    return topk_seed_ids
+
+def seed_info(topk_seed_ids):
+    """
+    输入: topk_seed_ids(种子ID字符串列表)
+    输出: 推荐种子的详细信息列表,每个元素为dict
+    """
+    if not topk_seed_ids:
+        return []
+
+    conn = pymysql.connect(**DB_CONFIG)
+    cursor = conn.cursor()
+
+    # 查询种子基本信息
+    format_strings = ','.join(['%s'] * len(topk_seed_ids))
+    cursor.execute(
+        f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed WHERE seed_id IN ({format_strings})",
+        tuple(topk_seed_ids)
+    )
+    seed_rows = cursor.fetchall()
+    seed_map = {row[0]: row for row in seed_rows}
+
+    # 查询用户信息
+    owner_ids = list(set(row[1] for row in seed_rows))
+    if owner_ids:
+        format_strings_user = ','.join(['%s'] * len(owner_ids))
+        cursor.execute(
+            f"SELECT user_id, username FROM User WHERE user_id IN ({format_strings_user})",
+            tuple(owner_ids)
+        )
+        user_rows = cursor.fetchall()
+        user_map = {row[0]: row[1] for row in user_rows}
+    else:
+        user_map = {}
+
+    # 查询促销信息
+    cursor.execute(
+        f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({format_strings})",
+        tuple(topk_seed_ids)
+    )
+    promo_rows = cursor.fetchall()
+    promo_map = {row[0]: row[1] for row in promo_rows}
+
+    cursor.close()
+    conn.close()
+
+    seed_list = []
+    for seed_id in topk_seed_ids:
+        row = seed_map.get(seed_id)
+        if not row:
+            continue
+        owner_user_id = row[1]
+        seed_list.append({
+            'seed_id': seed_id,
+            'tags': row[2],
+            'title': row[3],
+            'size': row[4],
+            'username': user_map.get(owner_user_id, ""),
+            'popularity': row[5],
+            'discount': promo_map.get(seed_id, 1)
+        })
+    return seed_list
+
+@app.route('/recommend', methods=['POST'])
+def recommend():
+    data = request.get_json()
+    user_id = data.get('user_id')
+    try:
+        result = run_inference(user_id)
+        # 如果是冷启动直接返回详细信息,否则查详情
+        if isinstance(result, list) and result and isinstance(result[0], dict):
+            seed_list = result
+        else:
+            seed_list = seed_info(result)
+        return jsonify({'recommend': seed_list})
+    except Exception as e:
+        return jsonify({'error': str(e)}), 400
 
 if __name__ == "__main__":
-    t_start = time.time()
-    user_id = 1
-    run_inference(user_id)
-    t_end = time.time()
-    print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file
+    app.run(host='0.0.0.0', port=5000)
diff --git a/recommend/test_inference.py b/recommend/test_inference.py
new file mode 100644
index 0000000..b85dacd
--- /dev/null
+++ b/recommend/test_inference.py
@@ -0,0 +1,14 @@
+import requests
+
+API_URL = "http://10.126.59.25:5000/recommend"
+
+user_ids = ['user1', '550e8400-e29b-41d4-a716-446655440000']
+
+for user_id in user_ids:
+    try:
+        response = requests.post(API_URL, json={'user_id': user_id})
+        response.raise_for_status()
+        result = response.json()
+        print(f"Recommendations for user_id '{user_id}': {result}")
+    except requests.exceptions.RequestException as e:
+        print(f"Error occurred while fetching recommendations for user_id '{user_id}': {e}")
diff --git a/recommend/user_seed_graph.txt b/recommend/user_seed_graph.txt
index ef29920..5b5ad3d 100644
--- a/recommend/user_seed_graph.txt
+++ b/recommend/user_seed_graph.txt
@@ -1,3 +1,3 @@
-0	0 0 1 2	1746061954 1736237924 1736240066 1736309966
+0	0 0 1	1746061954 1736237924 1748971659
 1	1 1	1746315010 1746583706
-2	2 0 2	1746738305 1746865166 1749284366
+2	2	1746738305
diff --git a/recommend/utils/data_generator.py b/recommend/utils/graph_build.py
similarity index 92%
rename from recommend/utils/data_generator.py
rename to recommend/utils/graph_build.py
index d3ef9cf..becf5df 100644
--- a/recommend/utils/data_generator.py
+++ b/recommend/utils/graph_build.py
@@ -9,7 +9,7 @@
 SqlPassword = "123456"
 
 
-def fetch_data():
+def fetch_user_seed_data():
     conn = pymysql.connect(
         host=SqlURL,
         port=SqlPort,
@@ -74,8 +74,10 @@
             f.write(f"{uid}\t{items}\t{times}\n")
 
 
-def build_user_seed_graph():
-    download_rows, favorite_rows = fetch_data()
+def build_user_seed_graph(return_mapping=False):
+    download_rows, favorite_rows = fetch_user_seed_data()
     records, user_set, seed_set = process_records(download_rows, favorite_rows)
     user2idx, seed2idx = build_id_maps(user_set, seed_set)
     group_and_write(records, user2idx, seed2idx)
+    if return_mapping:
+        return user2idx, seed2idx
\ No newline at end of file