blob: d519f172eef2a184d1120739305c9f7d5090df32 [file] [log] [blame]
whtb1e79592025-06-07 16:03:09 +08001from utils.parse_args import args
2from os import path
3from tqdm import tqdm
4import numpy as np
5import scipy.sparse as sp
6import torch
7import networkx as nx
8from copy import deepcopy
9from collections import defaultdict
10import pandas as pd
11
12
class EdgeListData:
    """User-item interaction dataset loaded from tab-separated edge-list files.

    Each non-blank line of the train file is ``user<TAB>items[<TAB>times]``,
    where ``items`` (and ``times``, when ``has_time`` is True) are
    space-separated integers; ``items`` and ``times`` must have equal length.
    Each non-blank line of the test file is ``user<TAB>items``; further fields
    are ignored.

    After construction the instance exposes, among others:
      * ``edgelist``: ``(num_edges, 2)`` int32 array of ``(user, item)`` pairs.
      * ``edge_time``: per-edge time step aligned with ``edgelist``; the
        earliest bucket is step 1.
      * ``graph``: ``scipy.sparse.coo_matrix`` of shape
        ``(num_users, num_items)`` with one entry of 1 per training edge.
      * ``train_user_dict`` / ``test_user_dict``: ``{user: [items]}``.
      * ``test_edge_num``: total number of test items read.
      * ``edge_time_dict`` (only when ``has_time``): ``{node: {node: step}}``
        filled in both directions, with item ids offset by ``num_users``.
      * ``user_hist_dict`` (only in the ``'pretrain'`` phase).
    """

    def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
        """Load ``train_file`` / ``test_file`` and build the derived structures.

        Args:
            train_file: Path of the training edge-list file.
            test_file: Path of the test edge-list file.
            phase: ``'pretrain'`` selects ``args.hour_interval_pre`` as the
                time-bucket width and builds ``user_hist_dict``; any other
                value selects ``args.hour_interval_f``.
            pre_dataset: Optional previously loaded dataset whose
                ``num_users`` / ``num_items`` are reused instead of being
                inferred from the files.
            has_time: Whether train lines carry a third (timestamp) column.
        """
        self.phase = phase
        self.has_time = has_time
        self.pre_dataset = pre_dataset

        # Width of one time step in hours (converted with * 3600 in
        # timestamp_to_time_step, so timestamps are presumably in seconds).
        self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f

        self.edgelist = []
        self.edge_time = []
        self.num_users = 0
        self.num_items = 0
        self.num_edges = 0

        self.train_user_dict = {}
        self.test_user_dict = {}

        self._load_data(train_file, test_file, has_time)

        if phase == 'pretrain':
            # NOTE: this aliases (does not copy) train_user_dict, so users
            # without training edges also gain an empty entry there.
            self.user_hist_dict = self.train_user_dict
            for u in range(self.num_users):
                self.user_hist_dict.setdefault(u, [])

    def _read_file(self, train_file, test_file, has_time=True):
        """Parse the train and test files into the raw edge containers.

        Appends ``(user, item)`` tuples to ``self.edgelist`` and raw integer
        timestamps to ``self.edge_time`` (all 0 when ``has_time`` is False),
        fills ``train_user_dict`` / ``test_user_dict`` and counts
        ``test_edge_num``. Blank lines are skipped. If a user appears on
        several train lines, the dict keeps the last line while every edge
        is still appended.

        Raises:
            ValueError: On a malformed line, including a train line whose
                item and timestamp counts differ (which would otherwise
                silently misalign ``edge_time`` with ``edgelist``).
        """
        with open(train_file, 'r') as f:
            for line_no, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                fields = line.split('\t')
                if has_time:
                    user, items, times = fields
                    times = [int(t) for t in times.split()]
                else:
                    user, items = fields[:2]
                    times = None
                user = int(user)
                items = [int(i) for i in items.split()]
                if times is None:
                    times = [0] * len(items)
                elif len(times) != len(items):
                    raise ValueError(
                        f"{train_file}:{line_no}: {len(items)} items but {len(times)} timestamps"
                    )

                self.edgelist.extend((user, i) for i in items)
                self.edge_time.extend(times)
                self.train_user_dict[user] = items

        self.test_edge_num = 0
        with open(test_file, 'r') as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                user, items = line.split('\t')[:2]
                test_items = [int(i) for i in items.split()]
                self.test_user_dict[int(user)] = test_items
                self.test_edge_num += len(test_items)

    def _load_data(self, train_file, test_file, has_time=True):
        """Read both files and build the array, sparse-matrix and dict views.

        Converts the raw edges to an ``(E, 2)`` int32 array and the raw
        timestamps to 1-based time steps. It then sets ``num_users`` /
        ``num_items``, taken from ``pre_dataset`` when present and otherwise
        as the largest id seen in train or test plus one. Finally it builds
        ``graph`` and, when ``self.has_time``, ``edge_time_dict``.
        """
        self._read_file(train_file, test_file, has_time)

        # reshape keeps the 2-D (E, 2) layout even when no edges were read.
        self.edgelist = np.array(self.edgelist, dtype=np.int32).reshape(-1, 2)
        self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))
        self.num_edges = len(self.edgelist)
        if self.pre_dataset is not None:
            # Reuse the earlier id space (presumably so ids line up with the
            # pretraining dataset).
            self.num_users = self.pre_dataset.num_users
            self.num_items = self.pre_dataset.num_items
        else:
            # default=-1 keeps empty train/test files from crashing max().
            max_user = max(self.test_user_dict, default=-1)
            max_item = max((max(v) for v in self.test_user_dict.values() if v), default=-1)
            if self.num_edges:
                max_user = max(max_user, np.max(self.edgelist[:, 0]))
                max_item = max(max_item, np.max(self.edgelist[:, 1]))
            self.num_users = max_user + 1
            self.num_items = max_item + 1

        self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])), shape=(self.num_users, self.num_items))

        if self.has_time:
            # Items are shifted by num_users so user and item ids occupy
            # disjoint ranges; each edge is recorded in both directions.
            # A repeated (user, item) pair keeps the time of its last
            # occurrence.
            offset = self.num_users
            self.edge_time_dict = defaultdict(dict)
            for (u, i), t in zip(self.edgelist, self.edge_time):
                self.edge_time_dict[u][i + offset] = t
                self.edge_time_dict[i + offset][u] = t

    def timestamp_to_time_step(self, timestamp_arr, least_time=None):
        """Bucket timestamps into ``self.hour_interval``-hour steps.

        Args:
            timestamp_arr: NumPy array of timestamps (presumably Unix seconds,
                given the ``* 3600`` below; confirm against the data).
            least_time: Reference timestamp mapped to step 0. Defaults to the
                array minimum, or 0 for an empty array.

        Returns:
            ``(timestamp_arr - least_time) // (hour_interval * 3600)``.
        """
        interval_hour = self.hour_interval
        if least_time is None:
            # np.min raises on an empty array; an empty input needs no offset.
            least_time = np.min(timestamp_arr) if np.size(timestamp_arr) else 0
        timestamp_arr = timestamp_arr - least_time
        timestamp_arr = timestamp_arr // (interval_hour * 3600)
        return timestamp_arr
92 return timestamp_arr