修改个人主页,优化推荐系统
Change-Id: I533e36dc891b1b60fcb1e4a71b522b3ba9d77984
diff --git a/recommend/inference.py b/recommend/inference.py
index 346209c..9601230 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -1,53 +1,57 @@
import sys
sys.path.append('./')
-from os import path
-from utils.parse_args import args
-from utils.data_loader import EdgeListData
-from model.LightGCN import LightGCN
+import time
import torch
import numpy as np
-import time
+from os import path
+from model.LightGCN import LightGCN
+from utils.parse_args import args
+from utils.data_loader import EdgeListData
+from utils.data_generator import build_user_seed_graph
-# 计时:脚本开始
-t_start = time.time()
-
-# 配置参数
+# Module-level configuration: these assignments override attributes on the
+# `args` object imported from utils.parse_args, and they run at import time.
+# NOTE(review): the device is pinned to CUDA index 7; presumably this fails on
+# hosts with fewer GPUs -- confirm, or make it configurable.
args.device = 'cuda:7'
+# Edge-list file loaded by run_inference(); presumably the file that
+# build_user_seed_graph() writes -- verify against utils.data_generator.
args.data_path = './user_seed_graph.txt'
+# Checkpoint file passed to torch.load() in run_inference().
args.pre_model_path = './model/LightGCN_pretrained.pt'
+def run_inference(user_id=1):
+ """Score all items for one user with the pretrained LightGCN model and print the best one.
+
+ Steps: call build_user_seed_graph(), load the dataset from args.data_path,
+ load the checkpoint at args.pre_model_path onto args.device, then take the
+ argmax of the model's rating scores for `user_id`.
+
+ Args:
+     user_id: Row index into the user embeddings returned by model.generate().
+         NOTE(review): it is used as an index without a range check against
+         dataset.num_users -- confirm callers only pass valid IDs.
+
+ Returns:
+     Nothing; the predicted item ID and the timings are only printed.
+     NOTE(review): callers that need the prediction cannot get it from this
+     function as written -- consider returning pred_item.
+ """
+ # 1. Generate the user-seed interaction graph on the fly
+ print("正在生成用户-种子交互文件...")
+ # Presumably (re)writes the file at args.data_path -- verify in utils.data_generator.
+ build_user_seed_graph()
-# 1. 加载数据集
-t_data_start = time.time()
-dataset = EdgeListData(args.data_path, args.data_path)
-t_data_end = time.time()
+ # 2. Load the dataset
+ print("正在加载数据集...")
+ t_data_start = time.time()
+ # The same path is passed for both EdgeListData arguments.
+ # NOTE(review): confirm what the second argument is meant to be.
+ dataset = EdgeListData(args.data_path, args.data_path)
+ t_data_end = time.time()
+ # 3. Load the LightGCN model
+ print("正在加载模型参数...")
+ pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+ # Keep only the first num_users / num_items embedding rows; presumably this
+ # trims the checkpoint tables to the sizes of the freshly loaded dataset.
+ pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+ pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
-# 2. 加载LightGCN模型
-pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
-pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
-pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+ model = LightGCN(dataset, phase='vanilla').to(args.device)
+ # strict=False: checkpoint keys are not required to match the model's keys exactly.
+ model.load_state_dict(pretrained_dict, strict=False)
+ model.eval()
-model = LightGCN(dataset, phase='vanilla').to(args.device)
-model.load_state_dict(pretrained_dict, strict=False)
-model.eval()
+ # 4. Inference
+ print(f"正在为用户 {user_id} 推理推荐结果...")
+ t_infer_start = time.time()
+ # Gradient tracking is disabled for the whole scoring pass.
+ with torch.no_grad():
+ user_emb, item_emb = model.generate()
+ # unsqueeze(0) adds a leading dimension before rating(); squeeze(0) removes it again.
+ user_vec = user_emb[user_id].unsqueeze(0)
+ scores = model.rating(user_vec, item_emb).squeeze(0)
+ # Index of the highest score, converted to a plain Python number.
+ pred_item = torch.argmax(scores).item()
+ t_infer_end = time.time()
-# 3. 输入用户ID
-user_id = 1
+ print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
+ print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
+ print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
-# 4. 推理:获取embedding并打分
-t_infer_start = time.time()
-with torch.no_grad():
- user_emb, item_emb = model.generate()
- user_vec = user_emb[user_id].unsqueeze(0)
- scores = model.rating(user_vec, item_emb).squeeze(0)
- pred_item = torch.argmax(scores).item()
-t_infer_end = time.time()
-
-t_end = time.time()
-
-print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
-print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
-print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
-print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file
+# Script entry point: run inference once and report the total wall-clock time
+# spent inside run_inference().
+if __name__ == "__main__":
+ t_start = time.time()
+ # The user ID is fixed at 1 here; it is not read from the command line.
+ user_id = 1
+ run_inference(user_id)
+ t_end = time.time()
+ print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file