"""Top-1 next-item recommendation for a single user with a pretrained LightGCN.

Loads the ``uig.txt`` edge list, restores pretrained LightGCN weights, scores
the items for one hard-coded user, and prints the highest-scoring item index
together with coarse timings for data loading, inference and the whole script.
"""
import sys
sys.path.append('./')

from os import path
from utils.parse_args import args
from utils.dataloader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np
import time

# Timing: script start. perf_counter is monotonic and high-resolution, which
# makes it the appropriate clock for measuring intervals (time.time can jump).
t_start = time.perf_counter()

# Configuration (overrides whatever utils.parse_args produced).
args.data_path = './'

# Prefer GPU 7 as originally configured, but fall back when this machine does
# not have it; otherwise torch.load(map_location=...) below fails outright.
PREFERRED_GPU = 7
if torch.cuda.device_count() > PREFERRED_GPU:
    args.device = f'cuda:{PREFERRED_GPU}'
elif torch.cuda.is_available():
    args.device = 'cuda:0'
else:
    args.device = 'cpu'
if args.device != f'cuda:{PREFERRED_GPU}':
    print(f"Warning: cuda:{PREFERRED_GPU} unavailable, using {args.device}",
          file=sys.stderr)

args.pre_model_path = './model/LightGCN_pretrained.pt'

# 1. Load the dataset.
t_data_start = time.perf_counter()
pretrain_data = path.join(args.data_path, "uig.txt")
# NOTE(review): the same file is also passed as the validation split;
# presumably only the interaction graph matters for inference — confirm.
pretrain_val_data = path.join(args.data_path, "uig.txt")
dataset = EdgeListData(pretrain_data, pretrain_val_data)
t_data_end = time.perf_counter()


# 2. Load the LightGCN model.
pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
# Trim the pretrained embedding tables to this dataset's user/item counts.
pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]

model = LightGCN(dataset, phase='vanilla').to(args.device)
# strict=False tolerates extra checkpoint entries, but it would also silently
# keep a model parameter at its initial value if its key is missing from the
# checkpoint (e.g. a naming mismatch). Surface that case instead of hiding it.
load_result = model.load_state_dict(pretrained_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: parameters missing from checkpoint (kept at initial values): "
          f"{load_result.missing_keys}", file=sys.stderr)
model.eval()

# 3. Target user ID (used as a row index into the user embedding table).
user_id = 1

# 4. Inference: compute embeddings and score items for the user.
t_infer_start = time.perf_counter()
with torch.no_grad():
    user_emb, item_emb = model.generate()
    # Reject out-of-range IDs explicitly; a negative ID would otherwise
    # silently wrap around and select a different user.
    num_user_rows = user_emb.shape[0]
    if not 0 <= user_id < num_user_rows:
        raise ValueError(f"user_id {user_id} out of range [0, {num_user_rows})")
    user_vec = user_emb[user_id].unsqueeze(0)
    scores = model.rating(user_vec, item_emb).squeeze(0)
    # NOTE(review): items the user already interacted with are not masked out,
    # so the top-1 may be an item already in the user's history — confirm.
    # pred_item is a row index into item_emb; whether it equals the raw item
    # ID depends on how EdgeListData maps IDs — verify.
    # .item() copies the result to the host, which also waits for pending
    # device work, so the inference timing below covers the full computation.
    pred_item = torch.argmax(scores).item()
t_infer_end = time.perf_counter()

t_end = time.perf_counter()

print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
print(f"脚本总耗时: {t_end - t_start:.4f} 秒")