"""Single-user next-item inference with a pretrained LightGCN checkpoint.

The script runs top to bottom as a flat program:

1. Build an ``EdgeListData`` dataset from ``./uig.txt`` (the same file is
   passed as both the training and the validation edge list).
2. Load ``./model/LightGCN_pretrained.pt``, truncate its ``user_embedding``
   and ``item_embedding`` tensors to the dataset's user/item counts, and
   load them into a ``LightGCN`` model in ``'vanilla'`` phase.
3. Score every item for one hard-coded user and print the arg-max item ID
   together with the data-loading, inference and total wall-clock timings.
"""
import sys
# Make the project-local ``utils`` and ``model`` packages importable when the
# script is launched from the repository root.
sys.path.append('./')

from os import path
from utils.parse_args import args
from utils.dataloader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np
import time

# Timing: script start (after imports, so import cost is not counted).
# perf_counter() is monotonic, so the measured intervals cannot be skewed by
# system clock adjustments the way time.time() differences can.
t_start = time.perf_counter()

# Configuration. These overrides are applied to the shared ``args`` object
# before the dataset/model are built.
args.data_path = './'
# Prefer GPU index 7 as originally configured, but fall back to CPU when that
# device does not exist instead of failing at the first CUDA call.
if torch.cuda.is_available() and torch.cuda.device_count() > 7:
    args.device = 'cuda:7'
else:
    args.device = 'cpu'
args.pre_model_path = './model/LightGCN_pretrained.pt'

# 1. Load the dataset.
t_data_start = time.perf_counter()
pretrain_data = path.join(args.data_path, "uig.txt")
# NOTE(review): validation uses the same file as training — confirm intended.
pretrain_val_data = path.join(args.data_path, "uig.txt")
dataset = EdgeListData(pretrain_data, pretrain_val_data)
t_data_end = time.perf_counter()


# 2. Load the LightGCN model.
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
# Keep only the first num_users / num_items embedding rows so the checkpoint
# tensors are no larger than the current dataset's user/item counts.
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

model = LightGCN(dataset, phase='vanilla').to(args.device)
# strict=False: keys missing from or extra to the checkpoint are tolerated.
model.load_state_dict(pretrained_dict, strict=False)
model.eval()

# 3. Input user ID.
user_id = 1

# 4. Inference: obtain embeddings and score the items.
t_infer_start = time.perf_counter()
with torch.no_grad():
    user_emb, item_emb = model.generate()
    # Add a leading batch dimension for the single user.
    user_vec = user_emb[user_id].unsqueeze(0)
    # presumably one score per item for this user; verify against LightGCN.rating
    scores = model.rating(user_vec, item_emb).squeeze(0)
    # .item() copies the result to the host, so pending device work finishes
    # before the end timestamp is taken.
    pred_item = torch.argmax(scores).item()
t_infer_end = time.perf_counter()

t_end = time.perf_counter()

print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
print(f"脚本总耗时: {t_end - t_start:.4f} 秒")