blob: 697b569a1b584563d1cb4eb966d95f7f5da7a88e [file] [log] [blame]
import sys
sys.path.append('./')
from os import path
from utils.parse_args import args
from utils.dataloader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np
import time
# Timing: script start.
t_start = time.time()

# Configuration. These assignments override whatever utils.parse_args set.
args.data_path = './'
# GPU 7 is the preferred device, but fall back to CPU when that device index
# does not exist (fewer GPUs, or no CUDA at all) instead of crashing.
args.device = 'cuda:7' if torch.cuda.device_count() > 7 else 'cpu'
args.pre_model_path = './model/LightGCN_pretrained.pt'

# 1. Load the dataset.
t_data_start = time.time()
pretrain_data = path.join(args.data_path, "uig.txt")
# NOTE(review): the validation file is the same file as the training file;
# presumably intentional for pure inference — confirm.
pretrain_val_data = path.join(args.data_path, "uig.txt")
dataset = EdgeListData(pretrain_data, pretrain_val_data)
t_data_end = time.time()

# 2. Load the pretrained LightGCN model.
# weights_only=True restricts unpickling to tensors and plain containers.
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
# Keep only the leading rows of each embedding table so their shapes match a
# model built for this dataset (the checkpoint may cover more users/items).
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
model = LightGCN(dataset, phase='vanilla').to(args.device)
# strict=False tolerates key mismatches; surface them rather than silently
# running with parameters that were never loaded from the checkpoint.
incompatible = model.load_state_dict(pretrained_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"Warning: state_dict mismatch; missing keys: {incompatible.missing_keys}, "
        f"unexpected keys: {incompatible.unexpected_keys}",
        file=sys.stderr,
    )
model.eval()

# 3. Target user ID.
user_id = 1
# Tensor indexing accepts negative indices and would silently pick another
# user, so reject anything outside [0, num_users) explicitly.
if not 0 <= user_id < dataset.num_users:
    raise ValueError(
        f"user_id {user_id} is out of range [0, {dataset.num_users})"
    )

# 4. Inference: compute embeddings and score every item for this user.
t_infer_start = time.time()
with torch.no_grad():
    # Presumably returns the propagated (user, item) embedding tables — confirm
    # against model.LightGCN.generate.
    user_emb, item_emb = model.generate()
    # Add a batch dimension so rating() sees a single-user batch.
    user_vec = user_emb[user_id].unsqueeze(0)
    scores = model.rating(user_vec, item_emb).squeeze(0)
    # .item() copies the result to the host, which also waits for any pending
    # GPU work, so the inference timing below is accurate.
    pred_item = torch.argmax(scores).item()
t_infer_end = time.time()
t_end = time.time()

print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
print(f"脚本总耗时: {t_end - t_start:.4f} 秒")