Recommendation system
Change-Id: I49b9205568f1ccf88b32b08511aff8b0bea8d1bd
diff --git a/rhj/backend/app/utils/data_loader.py b/rhj/backend/app/utils/data_loader.py
new file mode 100644
index 0000000..c882a12
--- /dev/null
+++ b/rhj/backend/app/utils/data_loader.py
@@ -0,0 +1,100 @@
+from collections import defaultdict
+
+import numpy as np
+import scipy.sparse as sp
+
+from app.utils.parse_args import args
+
+
+class EdgeListData:
+ def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
+ self.phase = phase
+ self.has_time = has_time
+ self.pre_dataset = pre_dataset
+
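+        # Time-bucket size in hours; pretraining and fine-tuning may use different granularities.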
+ self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f
+
+ self.edgelist = []
+ self.edge_time = []
+ self.num_users = 0
+ self.num_items = 0
+ self.num_edges = 0
+
+ self.train_user_dict = {}
+ self.test_user_dict = {}
+
+ self._load_data(train_file, test_file, has_time)
+
+        if phase == 'pretrain':
+            self.user_hist_dict = self.train_user_dict
+
+            # Backfill users without any training interactions so every user ID resolves.
+            users_has_hist = set(self.user_hist_dict.keys())
+            all_users = set(range(self.num_users))
+            for u in all_users - users_has_hist:
+                self.user_hist_dict[u] = []
+
+    def _read_file(self, train_file, test_file, has_time=True):
+        with open(train_file, 'r') as f:
+            for line in f:
+                line = line.strip().split('\t')
+                if not has_time:
+                    user, items = line[:2]
+                    times = " ".join(["0"] * len(items.split(" ")))
+                    weights = " ".join(["1"] * len(items.split(" "))) if len(line) < 4 else line[3]
+                else:
+                    if len(line) >= 4:  # the line carries per-edge weights
+                        user, items, times, weights = line
+                    else:
+                        user, items, times = line
+                        weights = " ".join(["1"] * len(items.split(" ")))
+
+                # Weights are parsed for format compatibility but not consumed here yet.
+                for i in items.split(" "):
+                    self.edgelist.append((int(user), int(i)))
+                for i in times.split(" "):
+                    self.edge_time.append(int(i))
+                self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+
+ self.test_edge_num = 0
+ with open(test_file, 'r') as f:
+ for line in f:
+ line = line.strip().split('\t')
+ user, items = line[:2]
+ self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
+ self.test_edge_num += len(self.test_user_dict[int(user)])
+
+ def _load_data(self, train_file, test_file, has_time=True):
+ self._read_file(train_file, test_file, has_time)
+
+ self.edgelist = np.array(self.edgelist, dtype=np.int32)
+ self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))
+ self.num_edges = len(self.edgelist)
+ if self.pre_dataset is not None:
+ self.num_users = self.pre_dataset.num_users
+ self.num_items = self.pre_dataset.num_items
+ else:
+ self.num_users = max([np.max(self.edgelist[:, 0]) + 1, np.max(list(self.test_user_dict.keys())) + 1])
+ self.num_items = max([np.max(self.edgelist[:, 1]) + 1, np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1])
+
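+        # Sparse user-item interaction matrix in COO format, shape (num_users, num_items).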
+ self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))
+
+        if self.has_time:
+            # Per-edge timestamp lookup on the bipartite graph: item node IDs are
+            # offset by num_users so users and items share a single node ID space.
+            self.edge_time_dict = defaultdict(dict)
+            for i in range(len(self.edgelist)):
+                user, item = self.edgelist[i]
+                self.edge_time_dict[user][item + self.num_users] = self.edge_time[i]
+                self.edge_time_dict[item + self.num_users][user] = self.edge_time[i]
+
+    def timestamp_to_time_step(self, timestamp_arr, least_time=None):
+        """Map raw UNIX timestamps to discrete time steps of `hour_interval` hours."""
+        interval_hour = self.hour_interval
+        if least_time is None:
+            least_time = np.min(timestamp_arr)
+        timestamp_arr = timestamp_arr - least_time
+        timestamp_arr = timestamp_arr // (interval_hour * 3600)
+        return timestamp_arr
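A minimal usage sketch for EdgeListData, assuming the graph file produced by
graph_build.py below; paths and the train/test split are placeholders, not part
of this change. The loader expects tab-separated lines of
user<TAB>items<TAB>timestamps[<TAB>weights], with space-separated parallel fields:

    # Hypothetical driver; file paths are assumptions.
    from app.utils.data_loader import EdgeListData

    data = EdgeListData(
        train_file="./app/user_post_graph.txt",  # written by graph_build.py
        test_file="./app/user_post_graph.txt",   # placeholder; use a real held-out split
        phase="pretrain",
    )
    print(data.num_users, data.num_items, data.num_edges)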
diff --git a/rhj/backend/app/utils/graph_build.py b/rhj/backend/app/utils/graph_build.py
new file mode 100644
index 0000000..a453e4e
--- /dev/null
+++ b/rhj/backend/app/utils/graph_build.py
@@ -0,0 +1,117 @@
+import pymysql
+import datetime
+from collections import defaultdict
+
+SqlURL = "10.126.59.25"
+SqlPort = 3306
+Database = "redbook"  # use the redbook database
+SqlUsername = "root"
+SqlPassword = "123456"
+
+
+def fetch_user_post_data():
+ """
+    Fetch user-post interactions from the behaviors table, restricted to published posts.
+ """
+ conn = pymysql.connect(
+ host=SqlURL,
+ port=SqlPort,
+ user=SqlUsername,
+ password=SqlPassword,
+ database=Database,
+ charset="utf8mb4"
+ )
+ cursor = conn.cursor()
+    # Fetch behavior rows, keeping only those that reference published posts.
+ cursor.execute("""
+ SELECT b.user_id, b.post_id, b.type, b.value, b.created_at
+ FROM behaviors b
+ INNER JOIN posts p ON b.post_id = p.id
+ WHERE b.type IN ('like', 'favorite', 'comment', 'view', 'share')
+ AND p.status = 'published'
+ ORDER BY b.created_at
+ """)
+ behavior_rows = cursor.fetchall()
+ cursor.close()
+ conn.close()
+ return behavior_rows
+
+
+def process_records(behavior_rows):
+    Process raw behavior rows and assign each behavior type a weight.
+ 处理用户行为记录,为不同类型的行为分配权重
+ """
+ records = []
+ user_set = set()
+ post_set = set()
+
+    # Weight per behavior type: stronger engagement signals get larger weights.
+ behavior_weights = {
+ 'view': 1,
+ 'like': 2,
+ 'comment': 3,
+ 'share': 4,
+ 'favorite': 5
+ }
+
+ for row in behavior_rows:
+ user_id, post_id, behavior_type, value, created_at = row
+ user_set.add(user_id)
+ post_set.add(post_id)
+
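+        # Convert created_at to a UNIX timestamp; rows without a datetime fall back to 0.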
+ if isinstance(created_at, datetime.datetime):
+ ts = int(created_at.timestamp())
+ else:
+ ts = 0
+
+        # Combine the type weight with the recorded value (value defaults to 1).
+ weight = behavior_weights.get(behavior_type, 1) * (value or 1)
+ records.append((user_id, post_id, ts, weight))
+
+ return records, user_set, post_set
+
+
+def build_id_maps(user_set, post_set):
+    Build dense, sorted index mappings for user IDs and post IDs.
+ 构建用户和帖子的ID映射
+ """
+ user2idx = {uid: idx for idx, uid in enumerate(sorted(user_set))}
+ post2idx = {pid: idx for idx, pid in enumerate(sorted(post_set))}
+ return user2idx, post2idx
+
+
+def group_and_write(records, user2idx, post2idx, output_path="./app/user_post_graph.txt"):
+ """
+    Group the records by user and write them to disk, including behavior weights.
+ """
+ user_items = defaultdict(list)
+ user_times = defaultdict(list)
+ user_weights = defaultdict(list)
+
+ for user_id, post_id, ts, weight in records:
+ uid = user2idx[user_id]
+ pid = post2idx[post_id]
+ user_items[uid].append(pid)
+ user_times[uid].append(ts)
+ user_weights[uid].append(weight)
+
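+    # One line per user: uid<TAB>items<TAB>timestamps<TAB>weights, matching EdgeListData._read_file.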
+ with open(output_path, "w", encoding="utf-8") as f:
+ for uid in sorted(user_items.keys()):
+ items = " ".join(str(item) for item in user_items[uid])
+ times = " ".join(str(t) for t in user_times[uid])
+ weights = " ".join(str(w) for w in user_weights[uid])
+ f.write(f"{uid}\t{items}\t{times}\t{weights}\n")
+
+
+def build_user_post_graph(return_mapping=False):
+ """
+    Build the user-post interaction graph file; optionally return the ID mappings.
+ """
+ behavior_rows = fetch_user_post_data()
+ records, user_set, post_set = process_records(behavior_rows)
+ user2idx, post2idx = build_id_maps(user_set, post_set)
+ group_and_write(records, user2idx, post2idx)
+ if return_mapping:
+ return user2idx, post2idx
\ No newline at end of file
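A minimal end-to-end sketch, assuming the MySQL instance configured above is reachable:

    # Hypothetical example; regenerates ./app/user_post_graph.txt from the database.
    from app.utils.graph_build import build_user_post_graph

    user2idx, post2idx = build_user_post_graph(return_mapping=True)
    print(f"wrote {len(user2idx)} users and {len(post2idx)} posts to ./app/user_post_graph.txt")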
diff --git a/rhj/backend/app/utils/parse_args.py b/rhj/backend/app/utils/parse_args.py
new file mode 100644
index 0000000..82b3bb4
--- /dev/null
+++ b/rhj/backend/app/utils/parse_args.py
@@ -0,0 +1,78 @@
+import argparse
+import sys
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='GraphPro')
+ parser.add_argument('--phase', type=str, default='pretrain')
+ parser.add_argument('--plugin', action='store_true', default=False)
+    parser.add_argument('--save_path', type=str, default="saved", help='where to save model and logs')
+    parser.add_argument('--data_path', type=str, default="dataset/yelp", help='where to load data')
+ parser.add_argument('--exp_name', type=str, default='1')
+ parser.add_argument('--desc', type=str, default='')
+ parser.add_argument('--ab', type=str, default='full')
+ parser.add_argument('--log', type=int, default=1)
+
+ parser.add_argument('--device', type=str, default="cuda")
+ parser.add_argument('--model', type=str, default='GraphPro')
+ parser.add_argument('--pre_model', type=str, default='GraphPro')
+ parser.add_argument('--f_model', type=str, default='GraphPro')
+ parser.add_argument('--pre_model_path', type=str, default='pretrained_model.pt')
+
+ parser.add_argument('--hour_interval_pre', type=float, default=1)
+ parser.add_argument('--hour_interval_f', type=int, default=1)
+ parser.add_argument('--emb_dropout', type=float, default=0)
+
+ parser.add_argument('--updt_inter', type=int, default=1)
+ parser.add_argument('--samp_decay', type=float, default=0.05)
+
+ parser.add_argument('--edge_dropout', type=float, default=0.5)
+ parser.add_argument('--emb_size', type=int, default=64)
+ parser.add_argument('--batch_size', type=int, default=2048)
+ parser.add_argument('--eval_batch_size', type=int, default=512)
+ parser.add_argument('--seed', type=int, default=2023)
+ parser.add_argument('--num_epochs', type=int, default=300)
+ parser.add_argument('--neighbor_sample_num', type=int, default=5)
+ parser.add_argument('--lr', type=float, default=0.001)
+ parser.add_argument('--weight_decay', type=float, default=1e-4)
+ parser.add_argument('--metrics', type=str, default='recall;ndcg')
+ parser.add_argument('--metrics_k', type=str, default='20')
+ parser.add_argument('--early_stop_patience', type=int, default=10)
+ parser.add_argument('--neg_num', type=int, default=1)
+ parser.add_argument('--num_layers', type=int, default=3)
+ parser.add_argument('--n_layers', type=int, default=3)
+ parser.add_argument('--ssl_reg', type=float, default=1e-4)
+ parser.add_argument('--ssl_alpha', type=float, default=1)
+ parser.add_argument('--ssl_temp', type=float, default=0.2)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--decay', type=float, default=1e-3)
+ parser.add_argument('--model_reg', type=float, default=1e-4)
+ parser.add_argument('--topk', type=int, default=[1, 5, 10, 20], nargs='+')
+ parser.add_argument('--aug_type', type=str, default='ED')
+ parser.add_argument('--metric_topk', type=int, default=10)
+ parser.add_argument('--n_neighbors', type=int, default=32)
+ parser.add_argument('--n_samp', type=int, default=7)
+ parser.add_argument('--temp', type=float, default=1)
+ parser.add_argument('--temp_f', type=float, default=1)
+
+ return parser
+
+# Resolve `args` at import time so the module also works without CLI arguments
+# (e.g. when imported from the Flask app).
+try:
+    # Under Flask / gunicorn, sys.argv belongs to the server, so use the defaults.
+    if len(sys.argv) == 1 or any(x in sys.argv[0] for x in ['flask', 'app.py', 'gunicorn']):
+        parser = parse_args()
+        args = parser.parse_args([])  # empty argument list -> all defaults
+    else:
+        parser = parse_args()
+        args = parser.parse_args()
+except SystemExit:
+    # argparse calls sys.exit() on unknown arguments; fall back to the defaults.
+    parser = parse_args()
+    args = parser.parse_args([])
+
+# Keep `model` consistent with the pretrain / fine-tune model choice; both flags always exist.
+if args.pre_model == args.f_model:
+    args.model = args.pre_model
+elif args.pre_model != 'LightGCN':
+    args.model = args.pre_model
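A hypothetical override for experiments; all flags shown are defined above:

    # Build a fresh parser instead of relying on the import-time `args`.
    from app.utils.parse_args import parse_args

    parser = parse_args()
    args = parser.parse_args(['--phase', 'finetune', '--device', 'cpu', '--emb_size', '128'])
    print(args.phase, args.device, args.emb_size)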