优化推荐系统和冷启动
Change-Id: I93d3091f249f2396a25702e01eb8dd5a9e95e8bc
diff --git a/recommend/inference.py b/recommend/inference.py
index 9601230..e2b7b00 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -1,32 +1,110 @@
import sys
sys.path.append('./')
-import time
import torch
-import numpy as np
-from os import path
+import pymysql
+from flask_cors import CORS
+from flask import Flask, request, jsonify
from model.LightGCN import LightGCN
from utils.parse_args import args
from utils.data_loader import EdgeListData
-from utils.data_generator import build_user_seed_graph
+from utils.graph_build import build_user_seed_graph
+
+app = Flask(__name__)  # Flask application that serves the /recommend route defined below
+CORS(app)  # enable Flask-CORS on the app with its default settings
args.device = 'cuda:7'
args.data_path = './user_seed_graph.txt'
args.pre_model_path = './model/LightGCN_pretrained.pt'
-def run_inference(user_id=1):
- # 1. 实时生成user-seed交互图
- print("正在生成用户-种子交互文件...")
- build_user_seed_graph()
+# Database connection settings, unpacked as keyword arguments into pymysql.connect()
+DB_CONFIG = {
+ 'host': '10.126.59.25',
+ 'port': 3306,
+ 'user': 'root',
+ 'password': '123456',  # NOTE(review): credential is hard-coded in source; consider loading it from the environment
+ 'database': 'pt_database_test',
+ 'charset': 'utf8mb4'
+}
- # 2. 加载数据集
- print("正在加载数据集...")
- t_data_start = time.time()
+TOPK = 2 # default number of recommendations
+
+def user_cold_start(topk=TOPK):
+ """
+ Cold start: return detail dicts for the topk seeds with the highest popularity, most popular first.
+ """
+ conn = pymysql.connect(**DB_CONFIG)
+ cursor = conn.cursor()
+
+ # Fetch the topk seeds with the highest popularity
+ cursor.execute(
+ f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed ORDER BY popularity DESC LIMIT %s",  # NOTE(review): the f prefix is redundant here (no placeholders)
+ (topk,)
+ )
+ seed_rows = cursor.fetchall()
+ seed_ids = [row[0] for row in seed_rows]
+ seed_map = {row[0]: row for row in seed_rows}
+
+ # Look up the usernames of the seed owners
+ owner_ids = list(set(row[1] for row in seed_rows))
+ if owner_ids:
+ format_strings_user = ','.join(['%s'] * len(owner_ids))
+ cursor.execute(
+ f"SELECT user_id, username FROM User WHERE user_id IN ({format_strings_user})",
+ tuple(owner_ids)
+ )
+ user_rows = cursor.fetchall()
+ user_map = {row[0]: row[1] for row in user_rows}
+ else:
+ user_map = {}
+
+ # Look up promotion discounts for the selected seeds
+ if seed_ids:
+ format_strings = ','.join(['%s'] * len(seed_ids))
+ cursor.execute(
+ f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({format_strings})",
+ tuple(seed_ids)
+ )
+ promo_rows = cursor.fetchall()
+ promo_map = {row[0]: row[1] for row in promo_rows}
+ else:
+ promo_map = {}
+
+ cursor.close()  # NOTE(review): not reached if a query above raises; a try/finally would guarantee cleanup
+ conn.close()
+
+ seed_list = []
+ for seed_id in seed_ids:
+ row = seed_map.get(seed_id)
+ if not row:
+ continue
+ owner_user_id = row[1]
+ seed_list.append({
+ 'seed_id': seed_id,
+ 'tags': row[2],
+ 'title': row[3],
+ 'size': row[4],
+ 'username': user_map.get(owner_user_id, ""),  # empty string when the owner has no User row
+ 'popularity': row[5],
+ 'discount': promo_map.get(seed_id, 1)  # 1 when the seed has no SeedPromotion row
+ })
+ return seed_list
+
+def run_inference(user_id, topk=TOPK):
+ """
+ Input: user_id, topk
+ Output: list of the topk recommended seed IDs (original seed-ID strings); if user_id is not in user2idx, the list of seed-detail dicts from user_cold_start(topk) instead
+ """
+ user2idx, seed2idx = build_user_seed_graph(return_mapping=True)  # runs the full fetch/process/write pipeline on each call; returns (user2idx, seed2idx)
+ idx2seed = {v: k for k, v in seed2idx.items()}
+
+ if user_id not in user2idx:
+ # Cold start: this user has no entry in the user-index mapping
+ return user_cold_start(topk)
+
+ user_idx = user2idx[user_id]
+
dataset = EdgeListData(args.data_path, args.data_path)
- t_data_end = time.time()
-
- # 3. 加载LightGCN模型
- print("正在加载模型参数...")
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
@@ -35,23 +113,90 @@
model.load_state_dict(pretrained_dict, strict=False)
model.eval()
- # 4. 推理
- print(f"正在为用户 {user_id} 推理推荐结果...")
- t_infer_start = time.time()
with torch.no_grad():
user_emb, item_emb = model.generate()
- user_vec = user_emb[user_id].unsqueeze(0)
+ user_vec = user_emb[user_idx].unsqueeze(0)
scores = model.rating(user_vec, item_emb).squeeze(0)
- pred_item = torch.argmax(scores).item()
- t_infer_end = time.time()
+ topk_indices = torch.topk(scores, topk).indices.cpu().numpy()  # NOTE(review): torch.topk raises if topk exceeds the number of scores; confirm the item count is always >= topk
+ topk_seed_ids = [idx2seed[idx] for idx in topk_indices]
- print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
- print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
- print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
+ return topk_seed_ids
+
+def seed_info(topk_seed_ids):
+ """
+ Input: topk_seed_ids (list of seed-ID strings)
+ Output: list of detail dicts for the recommended seeds, one per seed found in the Seed table, in the order of topk_seed_ids
+ """
+ if not topk_seed_ids:
+ return []
+
+ conn = pymysql.connect(**DB_CONFIG)
+ cursor = conn.cursor()
+
+ # Fetch the basic seed information
+ format_strings = ','.join(['%s'] * len(topk_seed_ids))
+ cursor.execute(
+ f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed WHERE seed_id IN ({format_strings})",
+ tuple(topk_seed_ids)
+ )
+ seed_rows = cursor.fetchall()
+ seed_map = {row[0]: row for row in seed_rows}
+
+ # Look up the usernames of the seed owners
+ owner_ids = list(set(row[1] for row in seed_rows))
+ if owner_ids:
+ format_strings_user = ','.join(['%s'] * len(owner_ids))
+ cursor.execute(
+ f"SELECT user_id, username FROM User WHERE user_id IN ({format_strings_user})",
+ tuple(owner_ids)
+ )
+ user_rows = cursor.fetchall()
+ user_map = {row[0]: row[1] for row in user_rows}
+ else:
+ user_map = {}
+
+ # Look up promotion discounts for the requested seeds
+ cursor.execute(
+ f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({format_strings})",
+ tuple(topk_seed_ids)
+ )
+ promo_rows = cursor.fetchall()
+ promo_map = {row[0]: row[1] for row in promo_rows}
+
+ cursor.close()  # NOTE(review): not reached if a query above raises; a try/finally would guarantee cleanup
+ conn.close()
+
+ seed_list = []
+ for seed_id in topk_seed_ids:
+ row = seed_map.get(seed_id)
+ if not row:  # requested ID has no matching Seed row: leave it out of the result
+ continue
+ owner_user_id = row[1]
+ seed_list.append({
+ 'seed_id': seed_id,
+ 'tags': row[2],
+ 'title': row[3],
+ 'size': row[4],
+ 'username': user_map.get(owner_user_id, ""),  # empty string when the owner has no User row
+ 'popularity': row[5],
+ 'discount': promo_map.get(seed_id, 1)  # 1 when the seed has no SeedPromotion row
+ })
+ return seed_list
+
+@app.route('/recommend', methods=['POST'])
+def recommend():  # POST /recommend: reads 'user_id' from the JSON body, responds with {'recommend': [...]} or {'error': ...}
+ data = request.get_json()
+ user_id = data.get('user_id')  # NOTE(review): runs outside the try block; if data is None this raises AttributeError, so confirm how non-JSON bodies should be handled
+ try:
+ result = run_inference(user_id)
+ # A cold-start result already holds detail dicts; otherwise look up details for the returned seed IDs
+ if isinstance(result, list) and result and isinstance(result[0], dict):
+ seed_list = result
+ else:
+ seed_list = seed_info(result)
+ return jsonify({'recommend': seed_list})
+ except Exception as e:
+ return jsonify({'error': str(e)}), 400  # every failure, including server-side errors, is reported as HTTP 400
if __name__ == "__main__":
- t_start = time.time()
- user_id = 1
- run_inference(user_id)
- t_end = time.time()
- print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file
+ app.run(host='0.0.0.0', port=5000)  # listen on all interfaces, port 5000
diff --git a/recommend/test_inference.py b/recommend/test_inference.py
new file mode 100644
index 0000000..b85dacd
--- /dev/null
+++ b/recommend/test_inference.py
@@ -0,0 +1,14 @@
+import requests
+
+API_URL = "http://10.126.59.25:5000/recommend"  # recommendation endpoint exercised by this script
+
+user_ids = ['user1', '550e8400-e29b-41d4-a716-446655440000']
+
+for user_id in user_ids:
+ try:
+ response = requests.post(API_URL, json={'user_id': user_id})  # NOTE(review): no timeout is passed, so a hung server can block this call indefinitely
+ response.raise_for_status()  # raise on HTTP 4xx/5xx so the except below reports it
+ result = response.json()
+ print(f"Recommendations for user_id '{user_id}': {result}")
+ except requests.exceptions.RequestException as e:
+ print(f"Error occurred while fetching recommendations for user_id '{user_id}': {e}")
diff --git a/recommend/user_seed_graph.txt b/recommend/user_seed_graph.txt
index ef29920..5b5ad3d 100644
--- a/recommend/user_seed_graph.txt
+++ b/recommend/user_seed_graph.txt
@@ -1,3 +1,3 @@
-0 0 0 1 2 1746061954 1736237924 1736240066 1736309966
+0 0 0 1 1746061954 1736237924 1748971659
1 1 1 1746315010 1746583706
-2 2 0 2 1746738305 1746865166 1749284366
+2 2 1746738305
diff --git a/recommend/utils/data_generator.py b/recommend/utils/graph_build.py
similarity index 92%
rename from recommend/utils/data_generator.py
rename to recommend/utils/graph_build.py
index d3ef9cf..becf5df 100644
--- a/recommend/utils/data_generator.py
+++ b/recommend/utils/graph_build.py
@@ -9,7 +9,7 @@
SqlPassword = "123456"
-def fetch_data():
+def fetch_user_seed_data():
conn = pymysql.connect(
host=SqlURL,
port=SqlPort,
@@ -74,8 +74,10 @@
f.write(f"{uid}\t{items}\t{times}\n")
-def build_user_seed_graph():
- download_rows, favorite_rows = fetch_data()
+def build_user_seed_graph(return_mapping=False):  # return_mapping=True additionally returns (user2idx, seed2idx)
+ download_rows, favorite_rows = fetch_user_seed_data()
records, user_set, seed_set = process_records(download_rows, favorite_rows)
user2idx, seed2idx = build_id_maps(user_set, seed_set)
group_and_write(records, user2idx, seed2idx)
+ if return_mapping:  # the default (False) falls through and returns None, as before this change
+ return user2idx, seed2idx
\ No newline at end of file