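推荐系统推理脚本更新：调整数据加载模块导入与数据路径

- EdgeListData 改为从 utils.data_loader 导入（原为 utils.dataloader）。
- args.data_path 由目录 './' 改为文件 './user_seed_graph.txt'；
  预训练数据与验证数据均直接使用该文件（原为 './' 下的 uig.txt）。
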
Change-Id: I0c1cd2201bd3baae442b5fd11f36e73c001a7173
diff --git a/recommend/inference.py b/recommend/inference.py
index 697b569..346209c 100644
--- a/recommend/inference.py
+++ b/recommend/inference.py
@@ -3,7 +3,7 @@
from os import path
from utils.parse_args import args
-from utils.dataloader import EdgeListData
+from utils.data_loader import EdgeListData
from model.LightGCN import LightGCN
import torch
import numpy as np
@@ -13,15 +13,14 @@
t_start = time.time()
# 配置参数
-args.data_path = './'
args.device = 'cuda:7'
+args.data_path = './user_seed_graph.txt'  # now a file path, not a directory (was './' joined with "uig.txt")
args.pre_model_path = './model/LightGCN_pretrained.pt'
+
# 1. 加载数据集
t_data_start = time.time()
-pretrain_data = path.join(args.data_path, "uig.txt")
-pretrain_val_data = path.join(args.data_path, "uig.txt")
-dataset = EdgeListData(pretrain_data, pretrain_val_data)
+dataset = EdgeListData(args.data_path, args.data_path)  # same file as both pretrain and pretrain-val data, as before
t_data_end = time.time()