推荐系统

Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/utils/__pycache__/bloom_filter.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/bloom_filter.cpython-312.pyc
new file mode 100644
index 0000000..5c90537
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/bloom_filter.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/bloom_filter_manager.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/bloom_filter_manager.cpython-312.pyc
new file mode 100644
index 0000000..268f1fb
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/bloom_filter_manager.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/data_loader.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/data_loader.cpython-312.pyc
new file mode 100644
index 0000000..10b3571
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/data_loader.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/graph_build.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/graph_build.cpython-312.pyc
new file mode 100644
index 0000000..a560e74
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/graph_build.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/__pycache__/parse_args.cpython-312.pyc b/rhj/backend/app/utils/__pycache__/parse_args.cpython-312.pyc
new file mode 100644
index 0000000..a88ee3b
--- /dev/null
+++ b/rhj/backend/app/utils/__pycache__/parse_args.cpython-312.pyc
Binary files differ
diff --git a/rhj/backend/app/utils/data_loader.py b/rhj/backend/app/utils/data_loader.py
new file mode 100644
index 0000000..c882a12
--- /dev/null
+++ b/rhj/backend/app/utils/data_loader.py
@@ -0,0 +1,97 @@
+from app.utils.parse_args import args
+from os import path
+from tqdm import tqdm
+import numpy as np
+import scipy.sparse as sp
+import torch
+import networkx as nx
+from copy import deepcopy
+from collections import defaultdict
+import pandas as pd
+
+
+class EdgeListData:
+    """User-item interactions loaded from tab-separated train and test files.
+
+    Builds the training edge array, per-edge time steps, per-user item lists
+    (train, test, history) and a sparse user-item matrix stored as ``graph``.
+    """
+
+    def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
+        """Load both files; ``pre_dataset``, if given, supplies ``num_users``/``num_items``."""
+        self.phase = phase
+        self.has_time = has_time
+        self.pre_dataset = pre_dataset
+        self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f
+        self.edgelist = []
+        self.edge_time = []
+        self.num_users = 0
+        self.num_items = 0
+        self.num_edges = 0
+        self.train_user_dict = {}
+        self.test_user_dict = {}
+        self._load_data(train_file, test_file, has_time)
+
+        # Only 'pretrain' used to set user_hist_dict, so every other phase raised
+        # AttributeError right below; those phases now start from an independent copy.
+        self.user_hist_dict = self.train_user_dict if phase == 'pretrain' else deepcopy(self.train_user_dict)
+        for u in set(range(self.num_users)) - set(self.user_hist_dict):
+            self.user_hist_dict[u] = []  # users without training data get an empty history
+
+    def _read_file(self, train_file, test_file, has_time=True):
+        """Parse ``user<TAB>items[<TAB>times[<TAB>weights]]`` rows from both files.
+
+        Items and times are space-separated integers. Blank lines are skipped.
+        Without ``has_time`` every training edge gets timestamp 0.
+        """
+        with open(train_file, 'r') as f:
+            for line in f:
+                fields = line.strip().split('\t')
+                if fields == ['']:
+                    continue  # blank line
+                user, items = int(fields[0]), [int(i) for i in fields[1].split(" ")]
+                # Columns after the third (the weights) are not used by this loader, so
+                # rows with extra columns no longer fail on tuple unpacking.
+                times = [int(t) for t in fields[2].split(" ")] if has_time else [0] * len(items)
+                self.edgelist.extend((user, i) for i in items)
+                self.edge_time.extend(times)
+                self.train_user_dict[user] = items
+
+        self.test_edge_num = 0
+        with open(test_file, 'r') as f:
+            for line in f:
+                fields = line.strip().split('\t')
+                if fields == ['']:
+                    continue  # blank line
+                user, items = fields[:2]
+                self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+                self.test_edge_num += len(self.test_user_dict[int(user)])
+
+    def _load_data(self, train_file, test_file, has_time=True):
+        """Read both files, then build the arrays, id counts, sparse graph and time lookup."""
+        self._read_file(train_file, test_file, has_time)
+
+        self.edgelist = np.array(self.edgelist, dtype=np.int32)
+        self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))  # 1-based steps
+        self.num_edges = len(self.edgelist)
+        if self.pre_dataset is not None:
+            self.num_users = self.pre_dataset.num_users
+            self.num_items = self.pre_dataset.num_items
+        else:
+            self.num_users = max([np.max(self.edgelist[:, 0]) + 1, np.max(list(self.test_user_dict.keys())) + 1])
+            self.num_items = max([np.max(self.edgelist[:, 1]) + 1, np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1])
+        self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))
+        # Time lookup in both directions; item nodes are keyed as item id + num_users.
+        if self.has_time:
+            self.edge_time_dict = defaultdict(dict)
+            for i in range(len(self.edgelist)):
+                self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][1]+self.num_users] = self.edge_time[i]
+                self.edge_time_dict[self.edgelist[i][1]+self.num_users][self.edgelist[i][0]] = self.edge_time[i]
+
+    def timestamp_to_time_step(self, timestamp_arr, least_time=None):
+        """Bucket timestamps into steps of ``hour_interval`` hours, counted from
+        ``least_time`` (default: the array minimum); assumes timestamps in seconds.
+        """
+        if least_time is None:
+            least_time = np.min(timestamp_arr)
+        return (timestamp_arr - least_time) // (self.hour_interval * 3600)
diff --git a/rhj/backend/app/utils/graph_build.py b/rhj/backend/app/utils/graph_build.py
new file mode 100644
index 0000000..a453e4e
--- /dev/null
+++ b/rhj/backend/app/utils/graph_build.py
@@ -0,0 +1,115 @@
+import pymysql
+import datetime
+from collections import defaultdict
+
+SqlURL = "10.126.59.25"  # MySQL host used by fetch_user_post_data()
+SqlPort = 3306
+Database = "redbook"  # switched to the redbook database
+SqlUsername = "root"
+SqlPassword = "123456"  # NOTE(review): credential committed in source; consider loading it from configuration
+
+
+def fetch_user_post_data():
+    """
+    Fetch (user_id, post_id, type, value, created_at) rows from ``behaviors`` for
+    published posts, oldest first; the connection is closed even if the query fails.
+    """
+    conn = pymysql.connect(
+        host=SqlURL,
+        port=SqlPort,
+        user=SqlUsername,
+        password=SqlPassword,
+        database=Database,
+        charset="utf8mb4"
+    )
+    try:
+        with conn.cursor() as cursor:
+            cursor.execute("""
+        SELECT b.user_id, b.post_id, b.type, b.value, b.created_at 
+        FROM behaviors b
+        INNER JOIN posts p ON b.post_id = p.id
+        WHERE b.type IN ('like', 'favorite', 'comment', 'view', 'share')
+        AND p.status = 'published'
+        ORDER BY b.created_at
+    """)
+            return cursor.fetchall()
+    finally:
+        conn.close()
+
+
+def process_records(behavior_rows):
+    """
+    Convert raw behavior rows into weighted interaction records.
+
+    Each row is unpacked as (user_id, post_id, type, value, created_at) and
+    becomes a (user_id, post_id, timestamp, weight) tuple. Returns the list of
+    records together with the set of user ids and the set of post ids seen.
+    """
+    # Base weight per behavior type; types not listed here count as 1.
+    type_weight = {
+        'view': 1,
+        'like': 2,
+        'comment': 3,
+        'share': 4,
+        'favorite': 5,
+    }
+
+    seen_users, seen_posts = set(), set()
+    weighted = []
+    for uid, pid, kind, amount, when in behavior_rows:
+        seen_users.add(uid)
+        seen_posts.add(pid)
+
+        # Anything that is not a datetime is recorded with timestamp 0.
+        is_dt = isinstance(when, datetime.datetime)
+        stamp = int(when.timestamp()) if is_dt else 0
+
+        # A falsy value (None or 0) is treated as a multiplier of 1.
+        multiplier = amount or 1
+        weighted.append((uid, pid, stamp, type_weight.get(kind, 1) * multiplier))
+
+    return weighted, seen_users, seen_posts
+
+
+def build_id_maps(user_set, post_set):
+    """
+    Map raw user ids and post ids to dense 0-based indices, numbered in ascending id order.
+    """
+    ordered_users, ordered_posts = sorted(user_set), sorted(post_set)
+    user_index = dict(zip(ordered_users, range(len(ordered_users))))
+    return user_index, dict(zip(ordered_posts, range(len(ordered_posts))))
+
+
+def group_and_write(records, user2idx, post2idx, output_path="./app/user_post_graph.txt"):
+    """
+    Group the records by mapped user index and write one line per user to ``output_path``.
+
+    Each line holds four tab-separated columns: user index, item indices, timestamps and
+    weights, the last three space-separated and in record order. Users are written in
+    ascending index order.
+    """
+    per_user = defaultdict(lambda: ([], [], []))  # index -> (items, times, weights)
+
+    for raw_user, raw_post, stamp, weight in records:
+        uid, pid = user2idx[raw_user], post2idx[raw_post]
+        item_col, time_col, weight_col = per_user[uid]
+        item_col.append(pid)
+        time_col.append(stamp)
+        weight_col.append(weight)
+
+    with open(output_path, "w", encoding="utf-8") as f:
+        for uid in sorted(per_user):
+            columns = [" ".join(map(str, col)) for col in per_user[uid]]
+            f.write("\t".join([str(uid), *columns]) + "\n")
+
+
+def build_user_post_graph(return_mapping=False):
+    """
+    Fetch behaviors, weight them, index the ids and write the graph file; optionally return the id maps.
+    """
+    records, user_set, post_set = process_records(fetch_user_post_data())
+    id_maps = build_id_maps(user_set, post_set)
+    group_and_write(records, *id_maps)
+    if not return_mapping:
+        return None
+    return id_maps
\ No newline at end of file
diff --git a/rhj/backend/app/utils/parse_args.py b/rhj/backend/app/utils/parse_args.py
new file mode 100644
index 0000000..82b3bb4
--- /dev/null
+++ b/rhj/backend/app/utils/parse_args.py
@@ -0,0 +1,77 @@
+import argparse
+import sys
+
+def parse_args():  # Builds and returns the ArgumentParser; despite the name, nothing is parsed here.
+    parser = argparse.ArgumentParser(description='GraphPro')
+    parser.add_argument('--phase', type=str, default='pretrain')
+    parser.add_argument('--plugin', action='store_true', default=False)
+    parser.add_argument('--save_path', type=str, default="saved" ,help='where to save model and logs')
+    parser.add_argument('--data_path', type=str, default="dataset/yelp",help='where to load data')
+    parser.add_argument('--exp_name', type=str, default='1')
+    parser.add_argument('--desc', type=str, default='')
+    parser.add_argument('--ab', type=str, default='full')
+    parser.add_argument('--log', type=int, default=1)
+    # Device and model selection.
+    parser.add_argument('--device', type=str, default="cuda")
+    parser.add_argument('--model', type=str, default='GraphPro')
+    parser.add_argument('--pre_model', type=str, default='GraphPro')
+    parser.add_argument('--f_model', type=str, default='GraphPro')
+    parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
+    # hour_interval_pre / hour_interval_f: time-step width in hours for the pretrain / other phases.
+    parser.add_argument('--hour_interval_pre', type=float, default=1)  # NOTE(review): float here, int for hour_interval_f
+    parser.add_argument('--hour_interval_f', type=int, default=1)
+    parser.add_argument('--emb_dropout', type=float, default=0)
+
+    parser.add_argument('--updt_inter', type=int, default=1)
+    parser.add_argument('--samp_decay', type=float, default=0.05)
+    # Remaining options; several look like aliases (e.g. num_layers/n_layers) - TODO confirm which are read.
+    parser.add_argument('--edge_dropout', type=float, default=0.5)
+    parser.add_argument('--emb_size', type=int, default=64)
+    parser.add_argument('--batch_size', type=int, default=2048)
+    parser.add_argument('--eval_batch_size', type=int, default=512)
+    parser.add_argument('--seed', type=int, default=2023)
+    parser.add_argument('--num_epochs', type=int, default=300)
+    parser.add_argument('--neighbor_sample_num', type=int, default=5)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--weight_decay', type=float, default=1e-4)
+    parser.add_argument('--metrics', type=str, default='recall;ndcg')
+    parser.add_argument('--metrics_k', type=str, default='20')
+    parser.add_argument('--early_stop_patience', type=int, default=10)
+    parser.add_argument('--neg_num', type=int, default=1)
+    parser.add_argument('--num_layers', type=int, default=3)
+    parser.add_argument('--n_layers', type=int, default=3)
+    parser.add_argument('--ssl_reg', type=float, default=1e-4)
+    parser.add_argument('--ssl_alpha', type=float, default=1)
+    parser.add_argument('--ssl_temp', type=float, default=0.2)
+    parser.add_argument('--epoch', type=int, default=200)
+    parser.add_argument('--decay', type=float, default=1e-3)
+    parser.add_argument('--model_reg', type=float, default=1e-4)
+    parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+')
+    parser.add_argument('--aug_type', type=str, default='ED')
+    parser.add_argument('--metric_topk', type=int, default=10)
+    parser.add_argument('--n_neighbors', type=int, default=32)
+    parser.add_argument('--n_samp', type=int, default=7)
+    parser.add_argument('--temp', type=float, default=1)
+    parser.add_argument('--temp_f', type=float, default=1)
+    # Return the parser itself; callers run .parse_args() on it.
+    return parser
+
+# Build a module-level ``args`` so importing this module works even without command-line arguments
+try:
+    # When no CLI arguments were passed, or the launcher name looks like Flask/gunicorn, use the defaults
+    if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']):
+        parser = parse_args()
+        args = parser.parse_args([])  # parse an empty argument list
+    else:
+        parser = parse_args()
+        args = parser.parse_args()
+except SystemExit:
+    # argparse exits on arguments it cannot parse; fall back to the defaults instead
+    parser = parse_args()
+    args = parser.parse_args([])
+
+if hasattr(args, 'pre_model') and hasattr(args, 'f_model'):
+    if args.pre_model == args.f_model:
+        args.model = args.pre_model
+    elif args.pre_model != 'LightGCN':
+        args.model = args.pre_model  # same assignment: model is left alone only when pre_model is 'LightGCN' and differs from f_model