修改个人主页，优化推荐系统：推理封装为可复用的 run_inference()，推理前实时生成用户-种子交互图

Change-Id: I533e36dc891b1b60fcb1e4a71b522b3ba9d77984
diff --git a/recommend/inference.py b/recommend/inference.py
index 346209c..9601230 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -1,53 +1,90 @@
 import sys
 sys.path.append('./')
 
-from os import path
-from utils.parse_args import args
-from utils.data_loader import EdgeListData
-from model.LightGCN import LightGCN
+import time
 import torch
 import numpy as np
-import time
+from os import path
+from model.LightGCN import LightGCN
+from utils.parse_args import args
+from utils.data_loader import EdgeListData
+from utils.data_generator import build_user_seed_graph
 
-# 计时:脚本开始
-t_start = time.time()
-
-# 配置参数
+# Configuration overrides.
+# NOTE(review): these run at import time and mutate the ``args`` object
+# shared via utils.parse_args; 'cuda:7' needs >= 8 visible CUDA devices.
 args.device = 'cuda:7'
 args.data_path = './user_seed_graph.txt'
 args.pre_model_path = './model/LightGCN_pretrained.pt'
 
+def run_inference(user_id=1):
+    """Recommend the single highest-scoring next item for ``user_id``.
+
+    Regenerates the user-seed interaction graph, loads ``args.data_path``
+    with ``EdgeListData``, loads the pretrained LightGCN weights (embedding
+    tables truncated to the dataset's user/item counts) and takes the
+    argmax of ``model.rating`` over all items for that user. All of this
+    is redone on every call; progress and timings are printed to stdout.
+
+    Args:
+        user_id: Row index of the user in the generated graph; must
+            satisfy ``0 <= user_id < dataset.num_users``.
+
+    Returns:
+        int: The predicted item ID, so callers can use the result instead
+        of parsing stdout.
+
+    Raises:
+        IndexError: If ``user_id`` is outside ``[0, dataset.num_users)``.
+    """
+    # 1. Rebuild the user-seed interaction graph on every call.
+    # NOTE(review): presumably this (re)writes args.data_path, which is
+    # loaded next — confirm in utils/data_generator.py.
+    print("正在生成用户-种子交互文件...")
+    build_user_seed_graph()
 
-# 1. 加载数据集
-t_data_start = time.time()
-dataset = EdgeListData(args.data_path, args.data_path)
-t_data_end = time.time()
+    # 2. Load the dataset.
+    print("正在加载数据集...")
+    t_data_start = time.time()
+    dataset = EdgeListData(args.data_path, args.data_path)
+    t_data_end = time.time()
+
+    # Fail fast with a clear error: torch silently wraps a negative index
+    # to another user's row, and an out-of-range one would otherwise only
+    # raise after the model has been loaded and run.
+    if not 0 <= user_id < dataset.num_users:
+        raise IndexError(
+            f"user_id {user_id} out of range [0, {dataset.num_users})")
 
+    # 3. Load the pretrained LightGCN weights, truncating the embedding
+    # tables to the freshly generated graph's user/item counts.
+    print("正在加载模型参数...")
+    pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
+    pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
+    pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
 
-# 2. 加载LightGCN模型
-pretrained_dict = torch.load(args.pre_model_path, map_location=args.device, weights_only=True)
-pretrained_dict['user_embedding'] = pretrained_dict['user_embedding'][:dataset.num_users]
-pretrained_dict['item_embedding'] = pretrained_dict['item_embedding'][:dataset.num_items]
+    model = LightGCN(dataset, phase='vanilla').to(args.device)
+    model.load_state_dict(pretrained_dict, strict=False)
+    model.eval()
 
-model = LightGCN(dataset, phase='vanilla').to(args.device)
-model.load_state_dict(pretrained_dict, strict=False)
-model.eval()
+    # 4. Inference: score all items for this user and take the argmax.
+    print(f"正在为用户 {user_id} 推理推荐结果...")
+    t_infer_start = time.time()
+    with torch.no_grad():
+        user_emb, item_emb = model.generate()
+        user_vec = user_emb[user_id].unsqueeze(0)
+        scores = model.rating(user_vec, item_emb).squeeze(0)
+        pred_item = torch.argmax(scores).item()
+    t_infer_end = time.time()
 
-# 3. 输入用户ID
-user_id = 1
+    print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
+    print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
+    print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
+    return pred_item
 
-# 4. 推理:获取embedding并打分
-t_infer_start = time.time()
-with torch.no_grad():
-    user_emb, item_emb = model.generate()
-    user_vec = user_emb[user_id].unsqueeze(0)
-    scores = model.rating(user_vec, item_emb).squeeze(0)
-    pred_item = torch.argmax(scores).item()
-t_infer_end = time.time()
-
-t_end = time.time()
-
-print(f"用户{user_id}下一个最可能点击的物品ID为: {pred_item}")
-print(f"加载数据集耗时: {t_data_end - t_data_start:.4f} 秒")
-print(f"推理耗时: {t_infer_end - t_infer_start:.4f} 秒")
-print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file
+if __name__ == "__main__":
+    t_start = time.time()
+    user_id = 1
+    run_inference(user_id)
+    t_end = time.time()
+    print(f"脚本总耗时: {t_end - t_start:.4f} 秒")
\ No newline at end of file