新增管理员页面、用户申诉页面、迁移审核页面及推荐系统
Change-Id: Ief5646321feb98fadb17da4b4e91caeaacdbacc5
diff --git a/recommend/inference.py b/recommend/inference.py
new file mode 100644
index 0000000..697b569
--- /dev/null
+++ b/recommend/inference.py
@@ -0,0 +1,90 @@
+import sys
+sys.path.append('./')
+
+from os import path
+from utils.parse_args import args
+from utils.dataloader import EdgeListData
+from model.LightGCN import LightGCN
+import torch
+import numpy as np
+import time
+
+# GPU preferred for inference; _select_device() falls back when it is absent.
+PREFERRED_DEVICE = 'cuda:7'
+
+
+def _select_device(preferred=PREFERRED_DEVICE):
+    """Return *preferred* if that device exists, otherwise a usable fallback.
+
+    A hard-coded 'cuda:7' crashes on hosts with fewer than eight GPUs (or
+    none), so fall back to 'cuda:0' and then to 'cpu'.
+    """
+    device = torch.device(preferred)
+    if device.type != 'cuda':
+        return preferred
+    if not torch.cuda.is_available():
+        return 'cpu'
+    if (device.index or 0) < torch.cuda.device_count():
+        return preferred
+    return 'cuda:0'
+
+
+def main():
+    """Load the dataset and pretrained LightGCN, then print the top-1 item for one user."""
+    # Timing: script start
+    t_start = time.time()
+
+    # Configuration (all paths are relative to the current working directory)
+    args.data_path = './'
+    args.device = _select_device()
+    args.pre_model_path = './model/LightGCN_pretrained.pt'
+
+    # 1. Load the dataset; the same interaction file is passed as both
+    #    EdgeListData arguments.
+    t_data_start = time.time()
+    pretrain_data = path.join(args.data_path, "uig.txt")
+    pretrain_val_data = path.join(args.data_path, "uig.txt")
+    dataset = EdgeListData(pretrain_data, pretrain_val_data)
+    t_data_end = time.time()
+
+    # 2. Load the LightGCN model, trimming both checkpoint embedding tables
+    #    to the number of users/items in the current dataset.
+    pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+    pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+    pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+
+    model = LightGCN(dataset, phase='vanilla').to(args.device)
+    # strict=False silently skips mismatched keys; report them on stderr
+    # (stdout stays unchanged) rather than run on initial weights unnoticed.
+    incompatible = model.load_state_dict(pretrained_dict, strict=False)
+    if incompatible.missing_keys or incompatible.unexpected_keys:
+        print(f"警告: 缺失参数 {incompatible.missing_keys}, "
+              f"多余参数 {incompatible.unexpected_keys}", file=sys.stderr)
+    model.eval()
+
+    # 3. Target user ID; a negative index would silently wrap to another user.
+    user_id = 1
+    if not 0 <= user_id < dataset.num_users:
+        raise ValueError(f"user_id {user_id} out of range [0, {dataset.num_users})")
+
+    # 4. Inference: compute embeddings and score every item for this user.
+    t_infer_start = time.time()
+    with torch.no_grad():
+        user_emb, item_emb = model.generate()
+        user_vec = user_emb[user_id].unsqueeze(0)
+        scores = model.rating(user_vec, item_emb).squeeze(0)
+        # NOTE(review): already-interacted items are not masked out, so the
+        # top-1 may be an item the user has clicked before — confirm intent.
+        pred_item = torch.argmax(scores).item()
+    t_infer_end = time.time()
+
+    t_end = time.time()
+
+    print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
+    print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
+    print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
+    print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file