from app.utils.parse_args import args
from os import path
from tqdm import tqdm
import numpy as np
import scipy.sparse as sp
import torch
import networkx as nx
from copy import deepcopy
from collections import defaultdict
import pandas as pd


class EdgeListData:
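    """Edge-list dataset of user-item interactions read from tab-separated files.

    Holds the raw edge list, optional per-edge time steps, a sparse user-item
    interaction matrix, and per-user train/test item dictionaries.
    """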
    def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
        self.phase = phase
        self.has_time = has_time
        self.pre_dataset = pre_dataset

        # Time-bucket granularity (in hours) depends on the training phase.
        self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f

        self.edgelist = []
        self.edge_time = []
        self.num_users = 0
        self.num_items = 0
        self.num_edges = 0

        self.train_user_dict = {}
        self.test_user_dict = {}

        self._load_data(train_file, test_file, has_time)

        if phase == 'pretrain':
            self.user_hist_dict = self.train_user_dict

        # Give users without any training interactions an empty history.
        users_has_hist = set(list(self.user_hist_dict.keys()))
        all_users = set(list(range(self.num_users)))
        users_no_hist = all_users - users_has_hist
        for u in users_no_hist:
            self.user_hist_dict[u] = []

    def _read_file(self, train_file, test_file, has_time=True):
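        """Parse the tab-separated train and test files.

        Train lines look like ``user \t items \t timestamps [\t weights]``
        with space-separated item and timestamp fields; only ``user \t items``
        is read from the test file.
        """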
        with open(train_file, 'r') as f:
            for line in f:
                line = line.strip().split('\t')
                if not has_time:
                    user, items = line[:2]
                    times = " ".join(["0"] * len(items.split(" ")))
                    weights = " ".join(["1"] * len(items.split(" "))) if len(line) < 4 else line[3]
                else:
                    if len(line) >= 4:  # the line also carries edge weights
                        user, items, times, weights = line
                    else:
                        user, items, times = line
                        weights = " ".join(["1"] * len(items.split(" ")))

                for i in items.split(" "):
                    self.edgelist.append((int(user), int(i)))
                for i in times.split(" "):
                    self.edge_time.append(int(i))
                self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")]

        self.test_edge_num = 0
        with open(test_file, 'r') as f:
            for line in f:
                line = line.strip().split('\t')
                user, items = line[:2]
                self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
                self.test_edge_num += len(self.test_user_dict[int(user)])

    def _load_data(self, train_file, test_file, has_time=True):
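        """Build the edge array, per-edge time steps, the sparse user-item
        interaction matrix, and (when timestamps are used) a symmetric
        edge-time lookup keyed by graph node ids."""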
        self._read_file(train_file, test_file, has_time)

        self.edgelist = np.array(self.edgelist, dtype=np.int32)
        # Convert raw timestamps to 1-based time-step indices.
        self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))
        self.num_edges = len(self.edgelist)
        if self.pre_dataset is not None:
            self.num_users = self.pre_dataset.num_users
            self.num_items = self.pre_dataset.num_items
        else:
            self.num_users = max([np.max(self.edgelist[:, 0]) + 1, np.max(list(self.test_user_dict.keys())) + 1])
            self.num_items = max([np.max(self.edgelist[:, 1]) + 1, np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1])

        self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))

        if self.has_time:
            # Item node ids are offset by num_users so users and items share one id space.
            self.edge_time_dict = defaultdict(dict)
            for i in range(len(self.edgelist)):
                self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][1] + self.num_users] = self.edge_time[i]
                self.edge_time_dict[self.edgelist[i][1] + self.num_users][self.edgelist[i][0]] = self.edge_time[i]

    def timestamp_to_time_step(self, timestamp_arr, least_time=None):
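        """Convert raw timestamps (assumed to be in seconds) into discrete
        time steps of ``self.hour_interval`` hours, measured from ``least_time``."""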
        interval_hour = self.hour_interval
        if least_time is None:
            least_time = np.min(timestamp_arr)
        timestamp_arr = timestamp_arr - least_time
        timestamp_arr = timestamp_arr // (interval_hour * 3600)
        return timestamp_arr
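

# Minimal usage sketch (illustrative only). It assumes `args.hour_interval_pre`
# is provided by app.utils.parse_args and that the files follow the layout
# parsed in `_read_file`; the paths below are hypothetical placeholders.
if __name__ == "__main__":
    dataset = EdgeListData("data/pretrain_train.txt", "data/pretrain_test.txt", phase='pretrain')
    print(f"users={dataset.num_users}, items={dataset.num_items}, edges={dataset.num_edges}")
    print(f"test interactions: {dataset.test_edge_num}")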