推荐系统

Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/utils/parse_args.py b/rhj/backend/app/utils/parse_args.py
new file mode 100644
index 0000000..82b3bb4
--- /dev/null
+++ b/rhj/backend/app/utils/parse_args.py
@@ -0,0 +1,77 @@
+import argparse
+import sys
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='GraphPro')
+    parser.add_argument('--phase', type=str, default='pretrain')
+    parser.add_argument('--plugin', action='store_true', default=False)
+    parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs')
+    parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data')
+    parser.add_argument('--exp_name', type=str, default='1')
+    parser.add_argument('--desc', type=str, default='')
+    parser.add_argument('--ab', type=str, default='full')
+    parser.add_argument('--log', type=int, default=1)
+
+    parser.add_argument('--device', type=str, default="cuda")
+    parser.add_argument('--model', type=str, default='GraphPro')
+    parser.add_argument('--pre_model', type=str, default='GraphPro')
+    parser.add_argument('--f_model', type=str, default='GraphPro')
+    parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
+
+    parser.add_argument('--hour_interval_pre', type=float, default=1)
+    parser.add_argument('--hour_interval_f', type=int, default=1)
+    parser.add_argument('--emb_dropout', type=float, default=0)
+
+    parser.add_argument('--updt_inter', type=int, default=1)
+    parser.add_argument('--samp_decay', type=float, default=0.05)
+    
+    parser.add_argument('--edge_dropout', type=float, default=0.5)
+    parser.add_argument('--emb_size', type=int, default=64)
+    parser.add_argument('--batch_size', type=int, default=2048)
+    parser.add_argument('--eval_batch_size', type=int, default=512)
+    parser.add_argument('--seed', type=int, default=2023)
+    parser.add_argument('--num_epochs', type=int, default=300)
+    parser.add_argument('--neighbor_sample_num', type=int, default=5)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--weight_decay', type=float, default=1e-4)
+    parser.add_argument('--metrics', type=str, default='recall;ndcg')
+    parser.add_argument('--metrics_k', type=str, default='20')
+    parser.add_argument('--early_stop_patience', type=int, default=10)
+    parser.add_argument('--neg_num', type=int, default=1)
+    parser.add_argument('--num_layers', type=int, default=3)
+    parser.add_argument('--n_layers', type=int, default=3)
+    parser.add_argument('--ssl_reg', type=float, default=1e-4)
+    parser.add_argument('--ssl_alpha', type=float, default=1)
+    parser.add_argument('--ssl_temp', type=float, default=0.2)
+    parser.add_argument('--epoch', type=int, default=200)
+    parser.add_argument('--decay', type=float, default=1e-3)
+    parser.add_argument('--model_reg', type=float, default=1e-4)
+    parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+')
+    parser.add_argument('--aug_type', type=str, default='ED')
+    parser.add_argument('--metric_topk', type=int, default=10)
+    parser.add_argument('--n_neighbors', type=int, default=32)
+    parser.add_argument('--n_samp', type=int, default=7)
+    parser.add_argument('--temp', type=float, default=1)
+    parser.add_argument('--temp_f', type=float, default=1)
+    
+    return parser
+
+# 创建默认args,支持在没有命令行参数时使用
+try:
+    # 如果是在Flask应用中运行,使用默认参数
+    if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']):
+        parser = parse_args()
+        args = parser.parse_args([])  # 使用空参数列表
+    else:
+        parser = parse_args()
+        args = parser.parse_args()
+except SystemExit:
+    # 如果parse_args失败,使用默认参数
+    parser = parse_args()
+    args = parser.parse_args([])
+
+if hasattr(args, 'pre_model') and hasattr(args, 'f_model'):
+    if args.pre_model == args.f_model:
+        args.model = args.pre_model
+    elif args.pre_model != 'LightGCN':
+        args.model = args.pre_model