"""Run single-user next-item inference with a pretrained LightGCN model.

Loads an edge-list dataset, restores pretrained user/item embeddings into a
LightGCN model, scores every item for one user, and prints the top-scoring
item ID together with wall-clock timings for data loading, inference and the
whole script.
"""
import sys
sys.path.append('./')  # make project-local packages (utils, model) importable from the CWD

from os import path
from utils.parse_args import args
from utils.data_loader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np
import time

# Timing: script start. perf_counter is monotonic and high-resolution, which
# makes it the right clock for elapsed intervals (time.time() can jump).
t_start = time.perf_counter()

# Preferred compute device; used only if it actually exists on this host.
PREFERRED_DEVICE = 'cuda:7'


def _resolve_device(preferred):
    """Return *preferred* if that device is usable here, otherwise 'cpu'.

    Non-CUDA devices are returned unchanged. A CUDA device is accepted only
    when CUDA is available and its index is below the visible GPU count.
    """
    dev = torch.device(preferred)
    if dev.type != 'cuda':
        return preferred
    if torch.cuda.is_available() and (dev.index or 0) < torch.cuda.device_count():
        return preferred
    # Falling back avoids a hard failure inside torch.load / Module.to().
    print(f"warning: {preferred} unavailable, falling back to cpu", file=sys.stderr)
    return 'cpu'


# Configuration (overrides whatever utils.parse_args produced).
args.device = _resolve_device(PREFERRED_DEVICE)
args.data_path = './user_seed_graph.txt'
args.pre_model_path = './model/LightGCN_pretrained.pt'


# 1. Load the dataset.
t_data_start = time.perf_counter()
# NOTE(review): the same file is passed for both constructor arguments —
# presumably the train and test edge lists; confirm this is intended here.
dataset = EdgeListData(args.data_path, args.data_path)
t_data_end = time.perf_counter()


# 2. Load the LightGCN model.
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
# Trim the pretrained embedding tables to this dataset's user/item counts,
# presumably so their shapes match the model built from `dataset` — confirm.
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

model = LightGCN(dataset, phase='vanilla').to(args.device)
# strict=False only tolerates missing/unexpected keys; a shape mismatch on a
# shared key still raises.
model.load_state_dict(pretrained_dict, strict=False)
model.eval()

# 3. Target user ID.
user_id = 1
# Fail with a clear message instead of an opaque IndexError, or (for a
# negative ID) silently scoring a user indexed from the end of the table.
if not 0 <= user_id < dataset.num_users:
    raise ValueError(f"用户ID {user_id} 超出范围 [0, {dataset.num_users})")

# 4. Inference: compute embeddings and score every item for this user.
t_infer_start = time.perf_counter()
with torch.no_grad():
    user_emb, item_emb = model.generate()
    user_vec = user_emb[user_id].unsqueeze(0)  # add a leading dimension of size 1
    scores = model.rating(user_vec, item_emb).squeeze(0)
    # .item() copies the result to the host, which waits for pending device
    # work, so the measured interval covers the actual computation.
    pred_item = torch.argmax(scores).item()
t_infer_end = time.perf_counter()

t_end = time.perf_counter()

print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
print(f"脚本总耗时: {t_end - t_start:.4f} 秒")