wht | b1e7959 | 2025-06-07 16:03:09 +0800 | [diff] [blame^] | 1 | import argparse |
| 2 | |
| 3 | def parse_args(): |
| 4 | parser = argparse.ArgumentParser(description='GraphPro') |
| 5 | parser.add_argument('--phase', type=str, default='pretrain') |
| 6 | parser.add_argument('--plugin', action='store_true', default=False) |
| 7 | parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs') |
| 8 | parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data') |
| 9 | parser.add_argument('--exp_name', type=str, default='1') |
| 10 | parser.add_argument('--desc', type=str, default='') |
| 11 | parser.add_argument('--ab', type=str, default='full') |
| 12 | parser.add_argument('--log', type=int, default=1) |
| 13 | |
| 14 | parser.add_argument('--device', type=str, default="cuda") |
| 15 | parser.add_argument('--model', type=str, default='GraphPro') |
| 16 | parser.add_argument('--pre_model', type=str, default='GraphPro') |
| 17 | parser.add_argument('--f_model', type=str, default='GraphPro') |
| 18 | parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt') |
| 19 | |
| 20 | parser.add_argument('--hour_interval_pre', type=float, default=1) |
| 21 | parser.add_argument('--hour_interval_f', type=int, default=1) |
| 22 | parser.add_argument('--emb_dropout', type=float, default=0) |
| 23 | |
| 24 | parser.add_argument('--updt_inter', type=int, default=1) |
| 25 | parser.add_argument('--samp_decay', type=float, default=0.05) |
| 26 | |
| 27 | parser.add_argument('--edge_dropout', type=float, default=0.5) |
| 28 | parser.add_argument('--emb_size', type=int, default=64) |
| 29 | parser.add_argument('--batch_size', type=int, default=2048) |
| 30 | parser.add_argument('--eval_batch_size', type=int, default=512) |
| 31 | parser.add_argument('--seed', type=int, default=2023) |
| 32 | parser.add_argument('--num_epochs', type=int, default=300) |
| 33 | parser.add_argument('--neighbor_sample_num', type=int, default=5) |
| 34 | parser.add_argument('--lr', type=float, default=0.001) |
| 35 | parser.add_argument('--weight_decay', type=float, default=1e-4) |
| 36 | parser.add_argument('--metrics', type=str, default='recall;ndcg') |
| 37 | parser.add_argument('--metrics_k', type=str, default='20') |
| 38 | parser.add_argument('--early_stop_patience', type=int, default=10) |
| 39 | parser.add_argument('--neg_num', type=int, default=1) |
| 40 | |
| 41 | parser.add_argument('--num_layers', type=int, default=3) |
| 42 | |
| 43 | |
| 44 | return parser |
| 45 | |
| 46 | parser = parse_args() |
| 47 | args = parser.parse_known_args()[0] |
| 48 | if args.pre_model == args.f_model: |
| 49 | args.model = args.pre_model |
| 50 | elif args.pre_model != 'LightGCN': |
| 51 | args.model = args.pre_model |
| 52 | |
| 53 | args = parser.parse_args() |
| 54 | if args.pre_model == args.f_model: |
| 55 | args.model = args.pre_model |
| 56 | elif args.pre_model != 'LightGCN': |
| 57 | args.model = args.pre_model |