blob: 82b3bb465c0b132599cdf17ac72d3e97d6ed65b3 [file] [log] [blame]
import argparse
import sys
def parse_args():
parser = argparse.ArgumentParser(description='GraphPro')
parser.add_argument('--phase', type=str, default='pretrain')
parser.add_argument('--plugin', action='store_true', default=False)
parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs')
parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data')
parser.add_argument('--exp_name', type=str, default='1')
parser.add_argument('--desc', type=str, default='')
parser.add_argument('--ab', type=str, default='full')
parser.add_argument('--log', type=int, default=1)
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument('--model', type=str, default='GraphPro')
parser.add_argument('--pre_model', type=str, default='GraphPro')
parser.add_argument('--f_model', type=str, default='GraphPro')
parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
parser.add_argument('--hour_interval_pre', type=float, default=1)
parser.add_argument('--hour_interval_f', type=int, default=1)
parser.add_argument('--emb_dropout', type=float, default=0)
parser.add_argument('--updt_inter', type=int, default=1)
parser.add_argument('--samp_decay', type=float, default=0.05)
parser.add_argument('--edge_dropout', type=float, default=0.5)
parser.add_argument('--emb_size', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=2048)
parser.add_argument('--eval_batch_size', type=int, default=512)
parser.add_argument('--seed', type=int, default=2023)
parser.add_argument('--num_epochs', type=int, default=300)
parser.add_argument('--neighbor_sample_num', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--metrics', type=str, default='recall;ndcg')
parser.add_argument('--metrics_k', type=str, default='20')
parser.add_argument('--early_stop_patience', type=int, default=10)
parser.add_argument('--neg_num', type=int, default=1)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--n_layers', type=int, default=3)
parser.add_argument('--ssl_reg', type=float, default=1e-4)
parser.add_argument('--ssl_alpha', type=float, default=1)
parser.add_argument('--ssl_temp', type=float, default=0.2)
parser.add_argument('--epoch', type=int, default=200)
parser.add_argument('--decay', type=float, default=1e-3)
parser.add_argument('--model_reg', type=float, default=1e-4)
parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+')
parser.add_argument('--aug_type', type=str, default='ED')
parser.add_argument('--metric_topk', type=int, default=10)
parser.add_argument('--n_neighbors', type=int, default=32)
parser.add_argument('--n_samp', type=int, default=7)
parser.add_argument('--temp', type=float, default=1)
parser.add_argument('--temp_f', type=float, default=1)
return parser
# 创建默认args,支持在没有命令行参数时使用
try:
# 如果是在Flask应用中运行,使用默认参数
if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']):
parser = parse_args()
args = parser.parse_args([]) # 使用空参数列表
else:
parser = parse_args()
args = parser.parse_args()
except SystemExit:
# 如果parse_args失败,使用默认参数
parser = parse_args()
args = parser.parse_args([])
if hasattr(args, 'pre_model') and hasattr(args, 'f_model'):
if args.pre_model == args.f_model:
args.model = args.pre_model
elif args.pre_model != 'LightGCN':
args.model = args.pre_model