blob: 82b3bb465c0b132599cdf17ac72d3e97d6ed65b3 [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001import argparse
2import sys
3
4def parse_args():
5 parser = argparse.ArgumentParser(description='GraphPro')
6 parser.add_argument('--phase', type=str, default='pretrain')
7 parser.add_argument('--plugin', action='store_true', default=False)
8 parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs')
9 parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data')
10 parser.add_argument('--exp_name', type=str, default='1')
11 parser.add_argument('--desc', type=str, default='')
12 parser.add_argument('--ab', type=str, default='full')
13 parser.add_argument('--log', type=int, default=1)
14
15 parser.add_argument('--device', type=str, default="cuda")
16 parser.add_argument('--model', type=str, default='GraphPro')
17 parser.add_argument('--pre_model', type=str, default='GraphPro')
18 parser.add_argument('--f_model', type=str, default='GraphPro')
19 parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
20
21 parser.add_argument('--hour_interval_pre', type=float, default=1)
22 parser.add_argument('--hour_interval_f', type=int, default=1)
23 parser.add_argument('--emb_dropout', type=float, default=0)
24
25 parser.add_argument('--updt_inter', type=int, default=1)
26 parser.add_argument('--samp_decay', type=float, default=0.05)
27
28 parser.add_argument('--edge_dropout', type=float, default=0.5)
29 parser.add_argument('--emb_size', type=int, default=64)
30 parser.add_argument('--batch_size', type=int, default=2048)
31 parser.add_argument('--eval_batch_size', type=int, default=512)
32 parser.add_argument('--seed', type=int, default=2023)
33 parser.add_argument('--num_epochs', type=int, default=300)
34 parser.add_argument('--neighbor_sample_num', type=int, default=5)
35 parser.add_argument('--lr', type=float, default=0.001)
36 parser.add_argument('--weight_decay', type=float, default=1e-4)
37 parser.add_argument('--metrics', type=str, default='recall;ndcg')
38 parser.add_argument('--metrics_k', type=str, default='20')
39 parser.add_argument('--early_stop_patience', type=int, default=10)
40 parser.add_argument('--neg_num', type=int, default=1)
41 parser.add_argument('--num_layers', type=int, default=3)
42 parser.add_argument('--n_layers', type=int, default=3)
43 parser.add_argument('--ssl_reg', type=float, default=1e-4)
44 parser.add_argument('--ssl_alpha', type=float, default=1)
45 parser.add_argument('--ssl_temp', type=float, default=0.2)
46 parser.add_argument('--epoch', type=int, default=200)
47 parser.add_argument('--decay', type=float, default=1e-3)
48 parser.add_argument('--model_reg', type=float, default=1e-4)
49 parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+')
50 parser.add_argument('--aug_type', type=str, default='ED')
51 parser.add_argument('--metric_topk', type=int, default=10)
52 parser.add_argument('--n_neighbors', type=int, default=32)
53 parser.add_argument('--n_samp', type=int, default=7)
54 parser.add_argument('--temp', type=float, default=1)
55 parser.add_argument('--temp_f', type=float, default=1)
56
57 return parser
58
59# 创建默认args,支持在没有命令行参数时使用
60try:
61 # 如果是在Flask应用中运行,使用默认参数
62 if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']):
63 parser = parse_args()
64 args = parser.parse_args([]) # 使用空参数列表
65 else:
66 parser = parse_args()
67 args = parser.parse_args()
68except SystemExit:
69 # 如果parse_args失败,使用默认参数
70 parser = parse_args()
71 args = parser.parse_args([])
72
73if hasattr(args, 'pre_model') and hasattr(args, 'f_model'):
74 if args.pre_model == args.f_model:
75 args.model = args.pre_model
76 elif args.pre_model != 'LightGCN':
77 args.model = args.pre_model