# Source blob: 346209cb6185d2d8422ffedc972291de68320a37
"""Single-user next-item inference with a pretrained LightGCN model.

Loads the edge-list dataset at ``args.data_path`` and restores pretrained
LightGCN embeddings from ``args.pre_model_path``, trimmed to the dataset's
user/item counts. It then scores every item for one user and prints the
highest-scoring item index along with timing information.
"""
import sys
sys.path.append('./')  # make project packages (utils/, model/) importable from the CWD

from os import path  # NOTE(review): currently unused
from utils.parse_args import args
from utils.data_loader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np  # NOTE(review): currently unused
import time

# Timing: script start. perf_counter is monotonic and high-resolution, so it
# is the appropriate clock for elapsed intervals (time.time() can jump).
t_start = time.perf_counter()

# Device to run on when it exists on this host; otherwise we fall back to CPU.
PREFERRED_DEVICE = 'cuda:7'


def _resolve_device(requested):
    """Return *requested* if that device exists on this machine, else ``'cpu'``."""
    device = torch.device(requested)
    if device.type != 'cuda':
        return requested
    if not torch.cuda.is_available():
        return 'cpu'
    if device.index is not None and device.index >= torch.cuda.device_count():
        return 'cpu'
    return requested


# Configuration: these assignments override whatever utils.parse_args produced.
args.device = _resolve_device(PREFERRED_DEVICE)
if args.device != PREFERRED_DEVICE:
    print(f"warning: device {PREFERRED_DEVICE!r} is unavailable; "
          f"falling back to {args.device!r}", file=sys.stderr)
args.data_path = './user_seed_graph.txt'
args.pre_model_path = './model/LightGCN_pretrained.pt'


# 1. Load the dataset.
# NOTE(review): the same file is passed for both constructor arguments;
# presumably they are (train_file, test_file) — confirm against EdgeListData.
t_data_start = time.perf_counter()
dataset = EdgeListData(args.data_path, args.data_path)
t_data_end = time.perf_counter()


# 2. Load the pretrained LightGCN weights. weights_only=True restricts
# unpickling to tensors and plain containers, so loading cannot run
# arbitrary code embedded in the checkpoint.
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
# Keep only the leading rows for the users/items this dataset knows about.
# Presumably the checkpoint was trained on a superset — TODO confirm.
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

model = LightGCN(dataset, phase='vanilla').to(args.device)
# strict=False tolerates key mismatches, but ignoring them silently would leave
# the affected parameters at their initial values. Report any mismatch.
incompatible = model.load_state_dict(pretrained_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"warning: checkpoint/model key mismatch; "
          f"missing={incompatible.missing_keys}, "
          f"unexpected={incompatible.unexpected_keys}", file=sys.stderr)
model.eval()

# 3. User to recommend for; used as a row index into the user embedding table.
user_id = 1
# A negative index would silently wrap to the end of the table, and an
# out-of-range one would fail deep inside tensor indexing, so validate it here.
if not 0 <= user_id < dataset.num_users:
    raise ValueError(f"user_id {user_id} is out of range [0, {dataset.num_users})")

# 4. Inference: get the embeddings, score every item for the user, pick the best.
# NOTE(review): items the user has already interacted with are not masked
# out before argmax — confirm that is intended.
t_infer_start = time.perf_counter()
with torch.no_grad():
    user_emb, item_emb = model.generate()
    # Add a leading batch dimension; assumes user_emb is (num_users, dim) — TODO confirm.
    user_vec = user_emb[user_id].unsqueeze(0)
    # Drop the batch dimension; presumably leaves one score per item — confirm.
    scores = model.rating(user_vec, item_emb).squeeze(0)
    # .item() copies the result to the host, which waits for any pending CUDA
    # work, so t_infer_end below measures the completed computation.
    pred_item = torch.argmax(scores).item()
t_infer_end = time.perf_counter()

t_end = time.perf_counter()

print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
print(f"脚本总耗时: {t_end - t_start:.4f} 秒")