blob: e2b7b00fd4025cbc86189a58b8b8907ee3f19018 [file] [log] [blame]
import os
import sys
sys.path.append('./')

import torch
import pymysql
from flask_cors import CORS
from flask import Flask, request, jsonify

from model.LightGCN import LightGCN
from utils.parse_args import args
from utils.data_loader import EdgeListData
from utils.graph_build import build_user_seed_graph
12
app = Flask(__name__)
CORS(app)

# Runtime overrides for the shared `args` namespace (from utils.parse_args).
# Every value can be overridden via an environment variable; the fallbacks are
# the values this service was originally hard-coded with.
# 'cuda:7' only exists on hosts with at least 8 GPUs; anywhere else
# torch.load(map_location=...) / .to() would fail, so fall back to CPU there.
_default_device = 'cuda:7' if torch.cuda.device_count() > 7 else 'cpu'
args.device = os.environ.get('REC_DEVICE', _default_device)
args.data_path = os.environ.get('REC_DATA_PATH', './user_seed_graph.txt')
args.pre_model_path = os.environ.get(
    'REC_MODEL_PATH', './model/LightGCN_pretrained.pt'
)

# Database connection configuration, passed as keyword arguments to
# pymysql.connect. Credentials should come from the environment; the
# fallbacks keep existing deployments working.
# NOTE(review): the fallback credentials are still committed to source —
# remove them once every deployment sets REC_DB_USER / REC_DB_PASSWORD.
DB_CONFIG = {
    'host': os.environ.get('REC_DB_HOST', '10.126.59.25'),
    'port': int(os.environ.get('REC_DB_PORT', '3306')),
    'user': os.environ.get('REC_DB_USER', 'root'),
    'password': os.environ.get('REC_DB_PASSWORD', '123456'),
    'database': os.environ.get('REC_DB_NAME', 'pt_database_test'),
    'charset': 'utf8mb4',
}

TOPK = 2  # Default number of recommendations returned per request.
31
def user_cold_start(topk=TOPK):
    """Cold-start recommendation: return details of the ``topk`` most popular seeds.

    Used when the requesting user has no node in the user-seed graph. Seeds
    are ranked by ``Seed.popularity`` (descending) and then enriched through
    ``seed_info``, so the result has exactly the same shape as a model-based
    recommendation.

    Args:
        topk: Maximum number of seeds to return. Values <= 0 yield ``[]``
            without touching the database.

    Returns:
        A list of dicts with keys ``seed_id``, ``tags``, ``title``, ``size``,
        ``username``, ``popularity`` and ``discount``, most popular first.
    """
    if topk <= 0:
        # A negative LIMIT is a SQL error; zero would simply return nothing.
        return []

    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT seed_id FROM Seed ORDER BY popularity DESC LIMIT %s",
                (topk,),
            )
            seed_ids = [row[0] for row in cursor.fetchall()]
    finally:
        # Release the connection even if the query raises.
        conn.close()

    # Shared enrichment path (owner username + promotion discount).
    # seed_info preserves the order of seed_ids, i.e. popularity order.
    return seed_info(seed_ids)
92
def run_inference(user_id, topk=TOPK):
    """Recommend seeds for ``user_id`` using the pretrained LightGCN model.

    Args:
        user_id: Raw user identifier, looked up in the user mapping returned
            by ``build_user_seed_graph``.
        topk: Number of seeds to recommend. Values <= 0 yield ``[]``.

    Returns:
        For a user present in the graph: a list of up to ``topk`` raw seed
        IDs, highest model score first.
        For an unknown user (cold start): the list of seed-detail dicts
        returned by ``user_cold_start`` instead. Callers must distinguish
        the two shapes (``recommend`` checks whether items are dicts).
    """
    if topk <= 0:
        return []

    # NOTE(review): the graph, dataset and model are rebuilt and reloaded from
    # disk on every call, which is expensive. Presumably this keeps results in
    # sync with the latest DB state; consider caching if that is not required.
    user2idx, seed2idx = build_user_seed_graph(return_mapping=True)

    if user_id not in user2idx:
        # Cold start: the user has no node in the interaction graph.
        return user_cold_start(topk)

    idx2seed = {idx: seed_id for seed_id, idx in seed2idx.items()}
    user_idx = user2idx[user_id]

    dataset = EdgeListData(args.data_path, args.data_path)
    pretrained_dict = torch.load(
        args.pre_model_path, map_location=args.device, weights_only=True
    )
    # Trim the pretrained embedding tables to the current graph's user/item counts.
    pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
    pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

    model = LightGCN(dataset, phase='vanilla').to(args.device)
    model.load_state_dict(pretrained_dict, strict=False)
    model.eval()

    with torch.no_grad():
        user_emb, item_emb = model.generate()
        user_vec = user_emb[user_idx].unsqueeze(0)
        scores = model.rating(user_vec, item_emb).squeeze(0)
        # torch.topk raises if k exceeds the size of the scored dimension.
        k = min(topk, scores.size(-1))
        # .tolist() gives plain Python ints for the idx2seed lookup.
        topk_indices = torch.topk(scores, k).indices.tolist()

    return [idx2seed[idx] for idx in topk_indices]
124
def seed_info(topk_seed_ids):
    """Fetch display details for the given seeds.

    Args:
        topk_seed_ids: Iterable of raw seed IDs in the order they should be
            returned. ``None`` or an empty iterable yields ``[]`` without
            touching the database.

    Returns:
        A list of dicts with keys ``seed_id``, ``tags``, ``title``, ``size``,
        ``username``, ``popularity`` and ``discount``, in the same order as
        ``topk_seed_ids``. IDs with no matching ``Seed`` row are skipped, an
        owner missing from ``User`` gives ``username == ""``, and a seed with
        no ``SeedPromotion`` row gets ``discount == 1``.
    """
    if topk_seed_ids is None:
        return []
    # Materialize once so generators/iterators work and the order is fixed.
    seed_ids = tuple(topk_seed_ids)
    if not seed_ids:
        return []

    seed_placeholders = ','.join(['%s'] * len(seed_ids))

    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            # Basic seed attributes.
            cursor.execute(
                f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed WHERE seed_id IN ({seed_placeholders})",
                seed_ids,
            )
            seed_rows = cursor.fetchall()
            seed_map = {row[0]: row for row in seed_rows}

            # Owner usernames; skip the query when no seed rows matched.
            owner_ids = list({row[1] for row in seed_rows})
            user_map = {}
            if owner_ids:
                user_placeholders = ','.join(['%s'] * len(owner_ids))
                cursor.execute(
                    f"SELECT user_id, username FROM User WHERE user_id IN ({user_placeholders})",
                    tuple(owner_ids),
                )
                user_map = {row[0]: row[1] for row in cursor.fetchall()}

            # Promotion discounts.
            cursor.execute(
                f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({seed_placeholders})",
                seed_ids,
            )
            promo_map = {row[0]: row[1] for row in cursor.fetchall()}
    finally:
        # Always release the connection, including when a query raises.
        conn.close()

    seed_list = []
    for seed_id in seed_ids:
        row = seed_map.get(seed_id)
        if row is None:
            continue
        seed_list.append({
            'seed_id': seed_id,
            'tags': row[2],
            'title': row[3],
            'size': row[4],
            'username': user_map.get(row[1], ""),
            'popularity': row[5],
            'discount': promo_map.get(seed_id, 1),
        })
    return seed_list
185
@app.route('/recommend', methods=['POST'])
def recommend():
    """POST /recommend: return seed recommendations for a user.

    Request JSON: ``{"user_id": ...}``. A missing ``user_id`` is passed
    through as ``None`` and therefore falls back to the cold-start ranking.

    Responses:
        200 ``{"recommend": [<seed dict>, ...]}`` on success.
        400 ``{"error": "<message>"}`` for a body that is not a JSON object,
        or when recommendation fails.
    """
    # silent=True: an absent or malformed JSON body yields None instead of
    # raising, so it can be answered with the endpoint's JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    user_id = data.get('user_id')
    try:
        result = run_inference(user_id)
        # Cold start already returns detail dicts; otherwise result holds raw
        # seed IDs that still need their details looked up.
        if isinstance(result, list) and result and isinstance(result[0], dict):
            seed_list = result
        else:
            seed_list = seed_info(result)
        return jsonify({'recommend': seed_list})
    except Exception as e:
        # Top-level request boundary: keep the JSON error response, but record
        # the traceback instead of silently discarding it.
        app.logger.exception('recommendation failed for user_id=%r', user_id)
        return jsonify({'error': str(e)}), 400
whtb1e79592025-06-07 16:03:09 +0800200
if __name__ == "__main__":
    # Launch Flask's development server on port 5000, listening on every
    # network interface (0.0.0.0), not just localhost.
    app.run(port=5000, host='0.0.0.0')