Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame] | 1 | import torch |
| 2 | import torch.nn as nn |
| 3 | from app.utils.parse_args import args |
| 4 | from scipy.sparse import csr_matrix |
| 5 | import scipy.sparse as sp |
| 6 | import numpy as np |
| 7 | import torch.nn.functional as F |
| 8 | |
| 9 | |
class BaseModel(nn.Module):
    """Common base for graph-based recommender models.

    Provides bipartite-adjacency construction utilities and the standard
    ranking / contrastive losses. Subclasses are expected to implement
    ``forward`` and ``cal_loss`` and (for ``_infonce_loss``) define a
    ``cl_mlp`` projection module.
    """

    def __init__(self, dataloader):
        super().__init__()
        # Counts come from the dataloader; embedding size from global config.
        self.num_users = dataloader.num_users
        self.num_items = dataloader.num_items
        self.emb_size = args.emb_size

    def forward(self):
        """Override in subclasses."""
        pass

    def cal_loss(self, batch_data):
        """Override in subclasses: compute the training loss for one batch."""
        pass

    def _check_inf(self, loss, pos_score, neg_score, edge_weight):
        """Raise ValueError with diagnostics if any loss entry is inf/NaN.

        ``edge_weight`` is an int sentinel (e.g. 0 or 1) when no per-edge
        weights are in use; it is only indexed when it is a tensor.
        """
        inf_idx = torch.isinf(loss) | torch.isnan(loss)
        if inf_idx.any():
            print("find inf in loss")
            if not isinstance(edge_weight, int):
                print(edge_weight[inf_idx])
            print(f"pos_score: {pos_score[inf_idx]}")
            print(f"neg_score: {neg_score[inf_idx]}")
            raise ValueError("find inf in loss")

    def _bipartite_sym_norm(self, mat, add_self_loop):
        """Build and normalize the full (users+items) bipartite adjacency.

        Stacks the user-item interaction matrix ``mat`` into the symmetric
        block matrix [[0, R], [R^T, 0]], binarizes it, optionally adds self
        loops, applies symmetric normalization D^-1/2 A D^-1/2, and returns
        a torch sparse COO tensor on ``args.device``.
        """
        a = csr_matrix((self.num_users, self.num_users))
        b = csr_matrix((self.num_items, self.num_items))
        mat = sp.vstack(
            [sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0  # binarize: keep only interaction presence
        if add_self_loop:
            mat = (mat + sp.eye(mat.shape[0])) * 1.0
        degree = np.array(mat.sum(axis=-1))
        # Isolated nodes have degree 0 -> inf after the -0.5 power; the infs
        # are zeroed below, so silence the expected divide warning here.
        with np.errstate(divide="ignore"):
            d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
        # A @ D^-1/2, transpose, @ D^-1/2 == D^-1/2 A D^-1/2 (A symmetric).
        mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
            d_inv_sqrt_mat).tocoo()

        # Convert to a torch sparse tensor. torch.sparse.FloatTensor is
        # deprecated; sparse_coo_tensor is the supported constructor.
        idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = torch.from_numpy(mat.data.astype(np.float32))
        shape = torch.Size(mat.shape)
        return torch.sparse_coo_tensor(idxs, vals, shape).to(args.device)

    def _make_binorm_adj(self, mat):
        """Symmetrically normalized bipartite adjacency, no self loops."""
        return self._bipartite_sym_norm(mat, add_self_loop=False)

    def _make_binorm_adj_self_loop(self, mat):
        """Symmetrically normalized bipartite adjacency with self loops."""
        return self._bipartite_sym_norm(mat, add_self_loop=True)

    def _sp_matrix_to_sp_tensor(self, sp_matrix):
        """Convert a scipy sparse matrix to a coalesced torch sparse tensor."""
        coo = sp_matrix.tocoo()
        # vstack first: building a LongTensor from a list of ndarrays is slow.
        indices = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
        values = torch.FloatTensor(coo.data)
        return torch.sparse_coo_tensor(
            indices, values, coo.shape).coalesce().to(args.device)

    def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
        """BPR pairwise ranking loss: mean of -log sigmoid(pos - neg).

        Scores are dot products of user and item embeddings; the 1e-10
        offset guards log(0) when the sigmoid underflows.
        """
        pos_score = (user_emb * pos_item_emb).sum(dim=1)
        neg_score = (user_emb * neg_item_emb).sum(dim=1)
        loss = -torch.log(1e-10 + torch.sigmoid(pos_score - neg_score))
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()

    def _nce_loss(self, pos_score, neg_score, edge_weight=1):
        """NCE-style loss: -log(exp(pos) / (exp(pos) + sum exp(neg))).

        ``pos_score`` is (B,), ``neg_score`` is (B, N); ``edge_weight``
        optionally reweights each sample (int 1 means uniform).
        """
        numerator = torch.exp(pos_score)
        denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
        loss = -torch.log(numerator / denominator) * edge_weight
        self._check_inf(loss, pos_score, neg_score, edge_weight)
        return loss.mean()

    def _infonce_loss(self, pos_1, pos_2, negs, tau):
        """InfoNCE contrastive loss over projected, L2-normalized views.

        ``pos_1``/``pos_2`` are (B, E) positive pairs; ``negs`` is (B, N, E)
        negatives; ``tau`` is the temperature. Uses ``self.cl_mlp`` (defined
        by the subclass) as the projection head.
        """
        pos_1 = self.cl_mlp(pos_1)
        pos_2 = self.cl_mlp(pos_2)
        negs = self.cl_mlp(negs)
        pos_1 = F.normalize(pos_1, dim=-1)
        pos_2 = F.normalize(pos_2, dim=-1)
        negs = F.normalize(negs, dim=-1)
        pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
        # (B, 1, E) x (B, E, N) -> (B, N)
        neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
        numerator = torch.exp(pos_score / tau)
        denominator = torch.exp(pos_score / tau) + torch.exp(neg_score / tau).sum(dim=1)
        loss = -torch.log(numerator / denominator)
        self._check_inf(loss, pos_score, neg_score, 0)
        return loss.mean()