import os
import sys
sys.path.append('./')

import torch
import pymysql
from flask_cors import CORS
from flask import Flask, request, jsonify
from model.LightGCN import LightGCN
from utils.parse_args import args
from utils.data_loader import EdgeListData
from utils.graph_build import build_user_seed_graph
| 12 | |
app = Flask(__name__)
CORS(app)

# Inference settings. The device can be overridden with RECOMMEND_DEVICE so the
# service also starts on hosts that do not expose a 'cuda:7' device.
args.device = os.environ.get('RECOMMEND_DEVICE', 'cuda:7')
args.data_path = './user_seed_graph.txt'
args.pre_model_path = './model/LightGCN_pretrained.pt'

# Database connection settings. Every value can be supplied through the
# environment; the fallbacks are the values that used to be hard-coded here.
# NOTE(review): the fallback credentials are kept only for backward
# compatibility - set DB_USER / DB_PASSWORD in any shared deployment.
DB_CONFIG = {
    'host': os.environ.get('DB_HOST', '10.126.59.25'),
    'port': int(os.environ.get('DB_PORT', '3306')),
    'user': os.environ.get('DB_USER', 'root'),
    'password': os.environ.get('DB_PASSWORD', '123456'),
    'database': os.environ.get('DB_NAME', 'pt_database_test'),
    'charset': 'utf8mb4',
}

TOPK = 2  # default number of recommendations
| 31 | |
def user_cold_start(topk=TOPK):
    """
    Cold start: return detailed info for the ``topk`` most popular seeds.

    Args:
        topk: Maximum number of seeds to return; coerced to ``int`` before it
            is bound to the ``LIMIT`` clause.

    Returns:
        A list of dicts ordered by descending ``popularity``, each with the
        keys ``seed_id``, ``tags``, ``title``, ``size``, ``username``,
        ``popularity`` and ``discount``. ``username`` falls back to ``""``
        when the owner is not found in ``User``; ``discount`` falls back to
        ``1`` when the seed has no row in ``SeedPromotion``.

    Raises:
        Any exception raised by pymysql while connecting or querying is
        propagated; the connection is closed before it leaves this function.
    """
    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            # Most popular seeds first.
            cursor.execute(
                "SELECT seed_id, owner_user_id, tags, title, size, popularity "
                "FROM Seed ORDER BY popularity DESC LIMIT %s",
                (int(topk),)
            )
            seed_rows = cursor.fetchall()

            # Resolve owner usernames in one query.
            user_map = {}
            owner_ids = tuple({row[1] for row in seed_rows})
            if owner_ids:
                owner_marks = ','.join(['%s'] * len(owner_ids))
                cursor.execute(
                    f"SELECT user_id, username FROM User WHERE user_id IN ({owner_marks})",
                    owner_ids
                )
                user_map = {row[0]: row[1] for row in cursor.fetchall()}

            # Resolve promotion discounts in one query.
            promo_map = {}
            seed_ids = tuple(row[0] for row in seed_rows)
            if seed_ids:
                seed_marks = ','.join(['%s'] * len(seed_ids))
                cursor.execute(
                    f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({seed_marks})",
                    seed_ids
                )
                promo_map = {row[0]: row[1] for row in cursor.fetchall()}
    finally:
        # Release the connection even when one of the queries fails.
        conn.close()

    return [
        {
            'seed_id': seed_id,
            'tags': tags,
            'title': title,
            'size': size,
            'username': user_map.get(owner_user_id, ""),
            'popularity': popularity,
            'discount': promo_map.get(seed_id, 1)
        }
        for seed_id, owner_user_id, tags, title, size, popularity in seed_rows
    ]
| 92 | |
def run_inference(user_id, topk=TOPK):
    """
    Recommend seeds for one user with the pretrained LightGCN model.

    The user/seed mappings, the dataset and the model weights are rebuilt and
    reloaded on every call.

    Args:
        user_id: ID of the requesting user, looked up in the user mapping
            returned by ``build_user_seed_graph(return_mapping=True)``.
        topk: Number of seeds to recommend. It is clamped to the number of
            scored items so a small catalogue does not make ``torch.topk``
            raise.

    Returns:
        - If ``user_id`` is present in the user mapping: a list of at most
          ``topk`` original seed IDs, highest score first.
        - Otherwise (cold start): the list of seed-detail dicts produced by
          ``user_cold_start(topk)``.
    """
    user2idx, seed2idx = build_user_seed_graph(return_mapping=True)

    if user_id not in user2idx:
        # Cold start: unknown user, fall back to the most popular seeds.
        return user_cold_start(topk)

    user_idx = user2idx[user_id]
    idx2seed = {idx: seed for seed, idx in seed2idx.items()}

    dataset = EdgeListData(args.data_path, args.data_path)
    pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
    # Trim the pretrained embedding tables to the sizes of the current dataset.
    pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
    pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

    model = LightGCN(dataset, phase='vanilla').to(args.device)
    model.load_state_dict(pretrained_dict, strict=False)
    model.eval()

    with torch.no_grad():
        user_emb, item_emb = model.generate()
        user_vec = user_emb[user_idx].unsqueeze(0)
        # presumably a 1-D tensor with one score per item - TODO confirm
        scores = model.rating(user_vec, item_emb).squeeze(0)
        k = max(0, min(int(topk), scores.size(-1)))
        # .tolist() yields plain Python ints for the dictionary lookup below.
        topk_indices = torch.topk(scores, k).indices.tolist()

    return [idx2seed[idx] for idx in topk_indices]
| 124 | |
def seed_info(topk_seed_ids):
    """
    Look up detailed information for the given seed IDs.

    Args:
        topk_seed_ids: Sequence of seed IDs. An empty or ``None`` value
            short-circuits to ``[]`` without opening a database connection.

    Returns:
        A list of dicts in the same order as ``topk_seed_ids``, each with the
        keys ``seed_id``, ``tags``, ``title``, ``size``, ``username``,
        ``popularity`` and ``discount``. IDs with no row in ``Seed`` are
        skipped. ``username`` falls back to ``""`` and ``discount`` to ``1``
        when no matching row exists.

    Raises:
        Any exception raised by pymysql while connecting or querying is
        propagated; the connection is closed before it leaves this function.
    """
    if not topk_seed_ids:
        return []

    seed_ids = tuple(topk_seed_ids)
    seed_marks = ','.join(['%s'] * len(seed_ids))

    conn = pymysql.connect(**DB_CONFIG)
    try:
        with conn.cursor() as cursor:
            # Basic seed information.
            cursor.execute(
                f"SELECT seed_id, owner_user_id, tags, title, size, popularity FROM Seed WHERE seed_id IN ({seed_marks})",
                seed_ids
            )
            seed_rows = cursor.fetchall()
            seed_map = {row[0]: row for row in seed_rows}

            # Owner usernames.
            user_map = {}
            owner_ids = tuple({row[1] for row in seed_rows})
            if owner_ids:
                owner_marks = ','.join(['%s'] * len(owner_ids))
                cursor.execute(
                    f"SELECT user_id, username FROM User WHERE user_id IN ({owner_marks})",
                    owner_ids
                )
                user_map = {row[0]: row[1] for row in cursor.fetchall()}

            # Promotion discounts.
            cursor.execute(
                f"SELECT seed_id, discount FROM SeedPromotion WHERE seed_id IN ({seed_marks})",
                seed_ids
            )
            promo_map = {row[0]: row[1] for row in cursor.fetchall()}
    finally:
        # Release the connection even when one of the queries fails.
        conn.close()

    seed_list = []
    for seed_id in seed_ids:
        # NOTE(review): this lookup assumes the IDs passed in compare equal to
        # the seed_id values returned by the database driver - confirm types.
        row = seed_map.get(seed_id)
        if row is None:
            continue
        seed_list.append({
            'seed_id': seed_id,
            'tags': row[2],
            'title': row[3],
            'size': row[4],
            'username': user_map.get(row[1], ""),
            'popularity': row[5],
            'discount': promo_map.get(seed_id, 1)
        })
    return seed_list
| 185 | |
@app.route('/recommend', methods=['POST'])
def recommend():
    """
    POST /recommend - return seed recommendations for a user.

    Expects a JSON object body; its optional ``user_id`` field is passed to
    ``run_inference``.

    Returns:
        ``{"recommend": [...]}`` on success, where each element is a
        seed-detail dict. ``{"error": "..."}`` with HTTP 400 when the body is
        not a JSON object or when inference / the database lookup raises.
    """
    # silent=True gives None instead of raising on a missing or malformed body,
    # so the client always receives a JSON error rather than an unhandled one.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    user_id = data.get('user_id')
    try:
        result = run_inference(user_id)
        # Cold start already returns detail dicts; otherwise the result is a
        # list of seed IDs whose details still have to be fetched.
        if isinstance(result, list) and result and isinstance(result[0], dict):
            seed_list = result
        else:
            seed_list = seed_info(result)
        return jsonify({'recommend': seed_list})
    except Exception as e:
        # NOTE(review): internal failures are reported as 400 to keep the
        # existing client contract; 500 would be the more accurate status.
        return jsonify({'error': str(e)}), 400
wht | b1e7959 | 2025-06-07 16:03:09 +0800 | [diff] [blame] | 200 | |
if __name__ == "__main__":
    # Script entry point: start the Flask server on all interfaces, port 5000.
    app.run(port=5000, host='0.0.0.0')