import torch
import torch.nn as nn
from app.utils.parse_args import args
from scipy.sparse import csr_matrix
import scipy.sparse as sp
import numpy as np
import torch.nn.functional as F


class BaseModel(nn.Module):
    def __init__(self, dataloader):
        super(BaseModel, self).__init__()
        self.num_users = dataloader.num_users
        self.num_items = dataloader.num_items
        self.emb_size = args.emb_size

    def forward(self):
        # Implemented by concrete models.
        pass

    def cal_loss(self, batch_data):
        # Implemented by concrete models.
        pass

    def _check_inf(self, loss, pos_score, neg_score, edge_weight):
        # Abort if any per-sample loss is inf/nan and dump the offending scores.
        inf_idx = torch.isinf(loss) | torch.isnan(loss)
        if inf_idx.any():
            print("found inf or nan in loss")
            if not isinstance(edge_weight, int):
                print(edge_weight[inf_idx])
            print(f"pos_score: {pos_score[inf_idx]}")
            print(f"neg_score: {neg_score[inf_idx]}")
            raise ValueError("found inf or nan in loss")

    def _make_binorm_adj(self, mat):
        # Build the bipartite user-item adjacency [[0, R], [R^T, 0]].
        a = csr_matrix((self.num_users, self.num_users))
        b = csr_matrix((self.num_items, self.num_items))
        mat = sp.vstack(
            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0
        # mat = (mat + sp.eye(mat.shape[0])) * 1.0  # MARK
        # Symmetric normalization: D^-1/2 * A * D^-1/2 (A is symmetric here).
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
            d_inv_sqrt_mat).tocoo()

        # Convert to a torch sparse tensor.
        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = torch.from_numpy(mat.data.astype(np.float32))
        shape = torch.Size(mat.shape)
        return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)

    def _make_binorm_adj_self_loop(self, mat):
        # Same as _make_binorm_adj, but adds self-loops before normalization.
        a = csr_matrix((self.num_users, self.num_users))
        b = csr_matrix((self.num_items, self.num_items))
        mat = sp.vstack(
            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0
        mat = (mat + sp.eye(mat.shape[0])) * 1.0  # self loop
        degree = np.array(mat.sum(axis=-1))
        d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
            d_inv_sqrt_mat).tocoo()

        # Convert to a torch sparse tensor.
        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = torch.from_numpy(mat.data.astype(np.float32))
        shape = torch.Size(mat.shape)
        return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)

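    # A typical consumer of the adjacency built above is a graph-propagation
    # step (sketch only, not part of this class): with `adj` the sparse tensor
    # returned by _make_binorm_adj and `emb` a dense
    # [num_users + num_items, emb_size] embedding matrix,
    #     emb_next = torch.sparse.mm(adj, emb)
    # performs one round of normalized neighborhood aggregation.
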
    def _sp_matrix_to_sp_tensor(self, sp_matrix):
        # Convert an arbitrary scipy sparse matrix to a coalesced torch sparse tensor.
        coo = sp_matrix.tocoo()
        indices = torch.LongTensor([coo.row, coo.col])
        values = torch.FloatTensor(coo.data)
        return torch.sparse.FloatTensor(indices, values, coo.shape).coalesce().to(args.device)

    def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
        # Bayesian Personalized Ranking loss; the 1e-10 term guards against log(0).
        pos_score = (user_emb * pos_item_emb).sum(dim=1)
        neg_score = (user_emb * neg_item_emb).sum(dim=1)
        loss = -torch.log(1e-10 + torch.sigmoid(pos_score - neg_score))
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()

    def _nce_loss(self, pos_score, neg_score, edge_weight=1):
        # Softmax-style NCE loss: pos_score is [B], neg_score is [B, N].
        numerator = torch.exp(pos_score)
        denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
        loss = -torch.log(numerator / denominator) * edge_weight
        self._check_inf(loss, pos_score, neg_score, edge_weight)
        return loss.mean()

    def _infonce_loss(self, pos_1, pos_2, negs, tau):
        # InfoNCE contrastive loss with temperature tau.
        # self.cl_mlp is expected to be defined by the subclass.
        pos_1 = self.cl_mlp(pos_1)
        pos_2 = self.cl_mlp(pos_2)
        negs = self.cl_mlp(negs)
        pos_1 = F.normalize(pos_1, dim=-1)
        pos_2 = F.normalize(pos_2, dim=-1)
        negs = F.normalize(negs, dim=-1)
        pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
        # [B, 1, E] x [B, E, N] -> [B, N]
        neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
        numerator = torch.exp(pos_score / tau)
        denominator = torch.exp(pos_score / tau) + torch.exp(neg_score / tau).sum(dim=1)
        loss = -torch.log(numerator / denominator)
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()
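
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: `args` has already been parsed and
# provides `emb_size` and `device`; the toy dataloader below only mimics the
# `num_users` / `num_items` attributes the real dataloader exposes).
# It shows how a subclass would build embeddings and reuse the BPR helper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _ToyModel(BaseModel):
        def __init__(self, dataloader):
            super().__init__(dataloader)
            self.user_emb = nn.Embedding(self.num_users, self.emb_size)
            self.item_emb = nn.Embedding(self.num_items, self.emb_size)

        def cal_loss(self, batch_data):
            users, pos_items, neg_items = batch_data
            return self._bpr_loss(self.user_emb(users),
                                  self.item_emb(pos_items),
                                  self.item_emb(neg_items))

    toy_loader = SimpleNamespace(num_users=4, num_items=6)
    model = _ToyModel(toy_loader).to(args.device)
    batch = tuple(t.to(args.device) for t in
                  (torch.tensor([0, 1]), torch.tensor([2, 3]), torch.tensor([4, 5])))
    print("toy BPR loss:", model.cal_loss(batch).item())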