blob: 3e86a476840958bfd2940b6b7fb4d805c5d7311d [file] [log] [blame]
whtb1e79592025-06-07 16:03:09 +08001import argparse
2
3def parse_args():
4 parser = argparse.ArgumentParser(description='GraphPro')
5 parser.add_argument('--phase', type=str, default='pretrain')
6 parser.add_argument('--plugin', action='store_true', default=False)
7 parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs')
8 parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data')
9 parser.add_argument('--exp_name', type=str, default='1')
10 parser.add_argument('--desc', type=str, default='')
11 parser.add_argument('--ab', type=str, default='full')
12 parser.add_argument('--log', type=int, default=1)
13
14 parser.add_argument('--device', type=str, default="cuda")
15 parser.add_argument('--model', type=str, default='GraphPro')
16 parser.add_argument('--pre_model', type=str, default='GraphPro')
17 parser.add_argument('--f_model', type=str, default='GraphPro')
18 parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
19
20 parser.add_argument('--hour_interval_pre', type=float, default=1)
21 parser.add_argument('--hour_interval_f', type=int, default=1)
22 parser.add_argument('--emb_dropout', type=float, default=0)
23
24 parser.add_argument('--updt_inter', type=int, default=1)
25 parser.add_argument('--samp_decay', type=float, default=0.05)
26
27 parser.add_argument('--edge_dropout', type=float, default=0.5)
28 parser.add_argument('--emb_size', type=int, default=64)
29 parser.add_argument('--batch_size', type=int, default=2048)
30 parser.add_argument('--eval_batch_size', type=int, default=512)
31 parser.add_argument('--seed', type=int, default=2023)
32 parser.add_argument('--num_epochs', type=int, default=300)
33 parser.add_argument('--neighbor_sample_num', type=int, default=5)
34 parser.add_argument('--lr', type=float, default=0.001)
35 parser.add_argument('--weight_decay', type=float, default=1e-4)
36 parser.add_argument('--metrics', type=str, default='recall;ndcg')
37 parser.add_argument('--metrics_k', type=str, default='20')
38 parser.add_argument('--early_stop_patience', type=int, default=10)
39 parser.add_argument('--neg_num', type=int, default=1)
40
41 parser.add_argument('--num_layers', type=int, default=3)
42
43
44 return parser
45
46parser = parse_args()
47args = parser.parse_known_args()[0]
48if args.pre_model == args.f_model:
49 args.model = args.pre_model
50elif args.pre_model != 'LightGCN':
51 args.model = args.pre_model
52
53args = parser.parse_args()
54if args.pre_model == args.f_model:
55 args.model = args.pre_model
56elif args.pre_model != 'LightGCN':
57 args.model = args.pre_model