Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame] | 1 | import argparse |
| 2 | import sys |
| 3 | |
| 4 | def parse_args(): |
| 5 | parser = argparse.ArgumentParser(description='GraphPro') |
| 6 | parser.add_argument('--phase', type=str, default='pretrain') |
| 7 | parser.add_argument('--plugin', action='store_true', default=False) |
| 8 | parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs') |
| 9 | parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data') |
| 10 | parser.add_argument('--exp_name', type=str, default='1') |
| 11 | parser.add_argument('--desc', type=str, default='') |
| 12 | parser.add_argument('--ab', type=str, default='full') |
| 13 | parser.add_argument('--log', type=int, default=1) |
| 14 | |
| 15 | parser.add_argument('--device', type=str, default="cuda") |
| 16 | parser.add_argument('--model', type=str, default='GraphPro') |
| 17 | parser.add_argument('--pre_model', type=str, default='GraphPro') |
| 18 | parser.add_argument('--f_model', type=str, default='GraphPro') |
| 19 | parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt') |
| 20 | |
| 21 | parser.add_argument('--hour_interval_pre', type=float, default=1) |
| 22 | parser.add_argument('--hour_interval_f', type=int, default=1) |
| 23 | parser.add_argument('--emb_dropout', type=float, default=0) |
| 24 | |
| 25 | parser.add_argument('--updt_inter', type=int, default=1) |
| 26 | parser.add_argument('--samp_decay', type=float, default=0.05) |
| 27 | |
| 28 | parser.add_argument('--edge_dropout', type=float, default=0.5) |
| 29 | parser.add_argument('--emb_size', type=int, default=64) |
| 30 | parser.add_argument('--batch_size', type=int, default=2048) |
| 31 | parser.add_argument('--eval_batch_size', type=int, default=512) |
| 32 | parser.add_argument('--seed', type=int, default=2023) |
| 33 | parser.add_argument('--num_epochs', type=int, default=300) |
| 34 | parser.add_argument('--neighbor_sample_num', type=int, default=5) |
| 35 | parser.add_argument('--lr', type=float, default=0.001) |
| 36 | parser.add_argument('--weight_decay', type=float, default=1e-4) |
| 37 | parser.add_argument('--metrics', type=str, default='recall;ndcg') |
| 38 | parser.add_argument('--metrics_k', type=str, default='20') |
| 39 | parser.add_argument('--early_stop_patience', type=int, default=10) |
| 40 | parser.add_argument('--neg_num', type=int, default=1) |
| 41 | parser.add_argument('--num_layers', type=int, default=3) |
| 42 | parser.add_argument('--n_layers', type=int, default=3) |
| 43 | parser.add_argument('--ssl_reg', type=float, default=1e-4) |
| 44 | parser.add_argument('--ssl_alpha', type=float, default=1) |
| 45 | parser.add_argument('--ssl_temp', type=float, default=0.2) |
| 46 | parser.add_argument('--epoch', type=int, default=200) |
| 47 | parser.add_argument('--decay', type=float, default=1e-3) |
| 48 | parser.add_argument('--model_reg', type=float, default=1e-4) |
| 49 | parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+') |
| 50 | parser.add_argument('--aug_type', type=str, default='ED') |
| 51 | parser.add_argument('--metric_topk', type=int, default=10) |
| 52 | parser.add_argument('--n_neighbors', type=int, default=32) |
| 53 | parser.add_argument('--n_samp', type=int, default=7) |
| 54 | parser.add_argument('--temp', type=float, default=1) |
| 55 | parser.add_argument('--temp_f', type=float, default=1) |
| 56 | |
| 57 | return parser |
| 58 | |
| 59 | # 创建默认args,支持在没有命令行参数时使用 |
| 60 | try: |
| 61 | # 如果是在Flask应用中运行,使用默认参数 |
| 62 | if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']): |
| 63 | parser = parse_args() |
| 64 | args = parser.parse_args([]) # 使用空参数列表 |
| 65 | else: |
| 66 | parser = parse_args() |
| 67 | args = parser.parse_args() |
| 68 | except SystemExit: |
| 69 | # 如果parse_args失败,使用默认参数 |
| 70 | parser = parse_args() |
| 71 | args = parser.parse_args([]) |
| 72 | |
| 73 | if hasattr(args, 'pre_model') and hasattr(args, 'f_model'): |
| 74 | if args.pre_model == args.f_model: |
| 75 | args.model = args.pre_model |
| 76 | elif args.pre_model != 'LightGCN': |
| 77 | args.model = args.pre_model |