blob: 9601230581b0edbf9d178d35512a49f373f7bb13 [file] [log] [blame]
whtb1e79592025-06-07 16:03:09 +08001import sys
2sys.path.append('./')
3
wht15642182025-06-08 00:16:52 +08004import time
whtb1e79592025-06-07 16:03:09 +08005import torch
6import numpy as np
wht15642182025-06-08 00:16:52 +08007from os import path
8from model.LightGCN import LightGCN
9from utils.parse_args import args
10from utils.data_loader import EdgeListData
11from utils.data_generator import build_user_seed_graph
whtb1e79592025-06-07 16:03:09 +080012
# Override the parsed CLI defaults for this standalone inference script.
# Prefer the GPU this script was written against (cuda:7), but fall back to
# CPU when that device index does not exist (fewer than 8 visible GPUs or no
# CUDA at all). Otherwise torch.load(map_location=...) and .to(device) fail.
args.device = 'cuda:7' if torch.cuda.device_count() > 7 else 'cpu'
# Interaction edge list loaded by run_inference().
# NOTE(review): presumably the file written by build_user_seed_graph(); confirm.
args.data_path = './user_seed_graph.txt'
# Pretrained LightGCN checkpoint (a state dict loaded with weights_only=True).
args.pre_model_path = './model/LightGCN_pretrained.pt'
16
def run_inference(user_id=1):
    """Recommend the single highest-scoring item for one user with a pretrained LightGCN.

    Steps: regenerate the user-seed interaction graph, load it as a dataset,
    load the pretrained LightGCN weights, then score every item for
    ``user_id`` and pick the argmax. Progress and timings are printed.

    Args:
        user_id: Row index of the user in the generated user embedding table.
            Must satisfy ``0 <= user_id < number of users``.

    Returns:
        int: Index of the highest-scoring item for ``user_id``.

    Raises:
        IndexError: If ``user_id`` is outside ``[0, number of users)``.
            Negative ids are rejected explicitly because tensor indexing
            would otherwise silently wrap around to a user counted from
            the end of the table.
    """
    # 1. Regenerate the user-seed interaction graph.
    print("正在生成用户-种子交互文件...")
    build_user_seed_graph()

    # 2. Load the dataset. The same file is passed for both path arguments.
    #    NOTE(review): presumably (train, test) paths; confirm this is intended.
    print("正在加载数据集...")
    t_data_start = time.perf_counter()
    dataset = EdgeListData(args.data_path, args.data_path)
    t_data_end = time.perf_counter()

    # 3. Load the pretrained LightGCN parameters.
    print("正在加载模型参数...")
    pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
    # Trim the pretrained embedding tables to the number of users/items in the
    # freshly generated dataset so their first dimension matches the model.
    # NOTE(review): slicing can only shrink the tables; if the dataset has more
    # users/items than the checkpoint, load_state_dict will presumably still
    # fail with a size mismatch.
    pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
    pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

    model = LightGCN(dataset, phase='vanilla').to(args.device)
    # strict=False tolerates missing/unexpected keys, but not shape mismatches.
    model.load_state_dict(pretrained_dict, strict=False)
    model.eval()

    # 4. Inference: score every item for this user and take the best one.
    print(f"正在为用户 {user_id} 推理推荐结果...")
    t_infer_start = time.perf_counter()
    with torch.no_grad():
        user_emb, item_emb = model.generate()
        num_users = user_emb.shape[0]
        # Reject negative ids too: user_emb[-1] would silently pick the last user.
        if not 0 <= user_id < num_users:
            raise IndexError(f"user_id {user_id} out of range [0, {num_users})")
        user_vec = user_emb[user_id].unsqueeze(0)
        scores = model.rating(user_vec, item_emb).squeeze(0)
        pred_item = torch.argmax(scores).item()
    t_infer_end = time.perf_counter()

    print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
    print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
    print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
    return pred_item
whtb1e79592025-06-07 16:03:09 +080051
if __name__ == "__main__":
    # Time the whole pipeline end to end (graph generation, loading, inference).
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    t_start = time.perf_counter()
    user_id = 1
    run_inference(user_id)
    t_end = time.perf_counter()
    print(f"脚本总耗时: {t_end - t_start:.4f} 秒")