wht | b1e7959 | 2025-06-07 16:03:09 +0800 | [diff] [blame] | 1 | import sys |
| 2 | sys.path.append('./') |
| 3 | |
wht | 1564218 | 2025-06-08 00:16:52 +0800 | [diff] [blame^] | 4 | import time |
wht | b1e7959 | 2025-06-07 16:03:09 +0800 | [diff] [blame] | 5 | import torch |
| 6 | import numpy as np |
wht | 1564218 | 2025-06-08 00:16:52 +0800 | [diff] [blame^] | 7 | from os import path |
| 8 | from model.LightGCN import LightGCN |
| 9 | from utils.parse_args import args |
| 10 | from utils.data_loader import EdgeListData |
| 11 | from utils.data_generator import build_user_seed_graph |
wht | b1e7959 | 2025-06-07 16:03:09 +0800 | [diff] [blame] | 12 | |
# Runtime configuration for this inference script.
# The original setup pins inference to GPU index 7; when that device is not
# present (no CUDA, or fewer than 8 GPUs), fall back to CPU instead of crashing
# later in torch.load(map_location=...) / model.to(...).
_PREFERRED_DEVICE = 'cuda:7'
_preferred_index = torch.device(_PREFERRED_DEVICE).index
args.device = (
    _PREFERRED_DEVICE
    if torch.cuda.is_available() and torch.cuda.device_count() > _preferred_index
    else 'cpu'
)
# User-seed interaction file regenerated by build_user_seed_graph() at run time.
args.data_path = './user_seed_graph.txt'
# Pretrained LightGCN weights (a state dict saved with torch.save).
args.pre_model_path = './model/LightGCN_pretrained.pt'
| 16 | |
def run_inference(user_id=1):
    """Recommend the single highest-scoring item for one user with LightGCN.

    The pipeline regenerates the user-seed interaction file, loads it as a
    dataset, restores pretrained LightGCN weights, and scores every item for
    ``user_id``. Progress and timing information are printed to stdout.

    Args:
        user_id: Row index of the user in the embedding table produced by
            ``model.generate()``. Must satisfy
            ``0 <= user_id < number of user embeddings``.

    Returns:
        int: Index of the item with the highest predicted score. Earlier
        versions returned ``None``; callers that ignore the result are
        unaffected.

    Raises:
        ValueError: If ``user_id`` is outside the range of known users.
    """
    # 1. Regenerate the user-seed interaction graph on every call so the
    #    recommendation is based on the freshly written file.
    print("正在生成用户-种子交互文件...")
    build_user_seed_graph()

    # 2. Load the dataset.
    print("正在加载数据集...")
    t_data_start = time.perf_counter()
    # NOTE(review): the same file is passed for both EdgeListData arguments;
    # presumably these are (train, test) paths — confirm this is intended.
    dataset = EdgeListData(args.data_path, args.data_path)
    t_data_end = time.perf_counter()

    # 3. Load the pretrained LightGCN parameters.
    print("正在加载模型参数...")
    pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
    # Trim the pretrained embedding tables to the current dataset's user/item
    # counts so their shapes match the freshly constructed model.
    pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
    pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

    model = LightGCN(dataset, phase='vanilla').to(args.device)
    # strict=False tolerates key mismatches; surface them instead of letting
    # a partially-loaded model go unnoticed.
    load_result = model.load_state_dict(pretrained_dict, strict=False)
    if load_result.missing_keys:
        print(f"警告: 预训练参数中缺少以下键: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"警告: 以下预训练参数未被模型使用: {load_result.unexpected_keys}")
    model.eval()

    # 4. Inference.
    print(f"正在为用户 {user_id} 推理推荐结果...")
    t_infer_start = time.perf_counter()
    with torch.no_grad():
        user_emb, item_emb = model.generate()
        num_users = user_emb.shape[0]
        # Reject out-of-range ids explicitly; a negative id would otherwise
        # silently select a user counted from the end of the table.
        if not 0 <= user_id < num_users:
            raise ValueError(f"user_id must be in [0, {num_users}), got {user_id}")
        user_vec = user_emb[user_id].unsqueeze(0)
        scores = model.rating(user_vec, item_emb).squeeze(0)
        # NOTE(review): items the user has already interacted with are not
        # masked out, so the argmax may be an already-seen item.
        pred_item = torch.argmax(scores).item()
    t_infer_end = time.perf_counter()

    print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
    print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
    print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
    return pred_item
wht | b1e7959 | 2025-06-07 16:03:09 +0800 | [diff] [blame] | 51 | |
if __name__ == "__main__":
    # Measure total script time with a monotonic, high-resolution clock;
    # time.time() can jump if the system clock is adjusted mid-run.
    t_start = time.perf_counter()
    user_id = 1
    run_inference(user_id)
    t_end = time.perf_counter()
    print(f"脚本总耗时: {t_end - t_start:.4f} 秒")