from utils.parse_args import args
from os import path
from tqdm import tqdm
import numpy as np
import scipy.sparse as sp
import torch
import networkx as nx
from copy import deepcopy
from collections import defaultdict
import pandas as pd


class EdgeListData:
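    """Edge-list dataset of user-item interactions.

    Reads tab-separated train and test files, builds a sparse user-item
    interaction matrix (`graph`), and, when timestamps are available,
    discretizes them into time steps of `hour_interval` hours. When
    `pre_dataset` is given, its user/item counts are reused.
    """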
    def __init__(self, train_file, test_file, phase='pretrain', pre_dataset=None, has_time=True):
        self.phase = phase
        self.has_time = has_time
        self.pre_dataset = pre_dataset

        # Width (in hours) of one bucket when discretizing edge timestamps.
        self.hour_interval = args.hour_interval_pre if phase == 'pretrain' else args.hour_interval_f

        self.edgelist = []
        self.edge_time = []
        self.num_users = 0
        self.num_items = 0
        self.num_edges = 0

        self.train_user_dict = {}
        self.test_user_dict = {}

        self._load_data(train_file, test_file, has_time)

        if phase == 'pretrain':
            self.user_hist_dict = self.train_user_dict

        # Give every user an entry, even if it has no recorded interactions.
        users_has_hist = set(self.user_hist_dict.keys())
        all_users = set(range(self.num_users))
        users_no_hist = all_users - users_has_hist
        for u in users_no_hist:
            self.user_hist_dict[u] = []

    def _read_file(self, train_file, test_file, has_time=True):
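        """Parse the tab-separated edge-list files.

        Each train line is `user \t items \t times`, where `items` and
        `times` are space-separated; when `has_time` is False only the first
        two columns are used and every time defaults to 0. Test lines only
        need `user \t items`.
        """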
        with open(train_file, 'r') as f:
            for line in f:
                line = line.strip().split('\t')
                if not has_time:
                    user, items = line[:2]
                    times = " ".join(["0"] * len(items.split(" ")))
                else:
                    user, items, times = line

                for i in items.split(" "):
                    self.edgelist.append((int(user), int(i)))
                for i in times.split(" "):
                    self.edge_time.append(int(i))
                self.train_user_dict[int(user)] = [int(i) for i in items.split(" ")]

        self.test_edge_num = 0
        with open(test_file, 'r') as f:
            for line in f:
                line = line.strip().split('\t')
                user, items = line[:2]
                self.test_user_dict[int(user)] = [int(i) for i in items.split(" ")]
                self.test_edge_num += len(self.test_user_dict[int(user)])

    def _load_data(self, train_file, test_file, has_time=True):
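        """Build arrays and lookup structures from the parsed edge list:
        the interaction matrix `graph`, discretized `edge_time`, user/item
        counts, and (when timestamps exist) `edge_time_dict`."""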
        self._read_file(train_file, test_file, has_time)

        self.edgelist = np.array(self.edgelist, dtype=np.int32)
        # Discretized time steps start at 1 (timestamp_to_time_step returns 0-based steps).
        self.edge_time = 1 + self.timestamp_to_time_step(np.array(self.edge_time, dtype=np.int32))
        self.num_edges = len(self.edgelist)
        if self.pre_dataset is not None:
            # Reuse the id space of the pretraining dataset.
            self.num_users = self.pre_dataset.num_users
            self.num_items = self.pre_dataset.num_items
        else:
            # Infer user/item counts from the largest ids seen in train and test.
            self.num_users = max(np.max(self.edgelist[:, 0]) + 1,
                                 np.max(list(self.test_user_dict.keys())) + 1)
            self.num_items = max(np.max(self.edgelist[:, 1]) + 1,
                                 np.max([np.max(self.test_user_dict[u]) for u in self.test_user_dict.keys()]) + 1)

        # Sparse user-item interaction matrix (num_users x num_items) in COO format.
        self.graph = sp.coo_matrix((np.ones(self.num_edges), (self.edgelist[:, 0], self.edgelist[:, 1])),
                                   shape=(self.num_users, self.num_items))

        if self.has_time:
            # Per-edge time steps keyed by node id in both directions; item node
            # ids are offset by num_users so users and items share one id space.
            self.edge_time_dict = defaultdict(dict)
            for i in range(len(self.edgelist)):
                self.edge_time_dict[self.edgelist[i][0]][self.edgelist[i][1] + self.num_users] = self.edge_time[i]
                self.edge_time_dict[self.edgelist[i][1] + self.num_users][self.edgelist[i][0]] = self.edge_time[i]

    def timestamp_to_time_step(self, timestamp_arr, least_time=None):
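        """Map raw timestamps (in seconds) to discrete time steps of
        `hour_interval` hours, counted from `least_time` (the earliest
        timestamp in the array when not given)."""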
        interval_hour = self.hour_interval
        if least_time is None:
            least_time = np.min(timestamp_arr)
        timestamp_arr = timestamp_arr - least_time
        timestamp_arr = timestamp_arr // (interval_hour * 3600)
        return timestamp_arr
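

# Minimal usage sketch (illustrative only; the file paths and the tab-separated
# "user<TAB>items<TAB>times" layout with space-separated item ids and unix
# timestamps are assumptions about how the data is stored on disk):
#
#     data = EdgeListData("data/train.txt", "data/test.txt", phase='pretrain')
#     data.graph        # num_users x num_items sparse COO interaction matrix
#     data.edge_time    # per-edge discretized time steps, starting at 1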