blob: 819442ab97c6a1069b2a01c6068dc9d7f0a14ff7 [file] [log] [blame]
import torch
import torch.nn as nn
from utils.parse_args import args
from scipy.sparse import csr_matrix
import scipy.sparse as sp
import numpy as np
import torch.nn.functional as F
class BaseModel(nn.Module):
def __init__(self, dataloader):
super(BaseModel, self).__init__()
self.num_users = dataloader.num_users
self.num_items = dataloader.num_items
self.emb_size = args.emb_size
def forward(self):
pass
def cal_loss(self, batch_data):
pass
def _check_inf(self, loss, pos_score, neg_score, edge_weight):
# find inf idx
inf_idx = torch.isinf(loss) | torch.isnan(loss)
if inf_idx.any():
print("find inf in loss")
if type(edge_weight) != int:
print(edge_weight[inf_idx])
print(f"pos_score: {pos_score[inf_idx]}")
print(f"neg_score: {neg_score[inf_idx]}")
raise ValueError("find inf in loss")
def _make_binorm_adj(self, mat):
a = csr_matrix((self.num_users, self.num_users))
b = csr_matrix((self.num_items, self.num_items))
mat = sp.vstack(
[sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
mat = (mat != 0) * 1.0
# mat = (mat + sp.eye(mat.shape[0])) * 1.0# MARK
degree = np.array(mat.sum(axis=-1))
d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
d_inv_sqrt_mat).tocoo()
# make torch tensor
idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
vals = torch.from_numpy(mat.data.astype(np.float32))
shape = torch.Size(mat.shape)
return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)
def _make_binorm_adj_self_loop(self, mat):
a = csr_matrix((self.num_users, self.num_users))
b = csr_matrix((self.num_items, self.num_items))
mat = sp.vstack(
[sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
mat = (mat != 0) * 1.0
mat = (mat + sp.eye(mat.shape[0])) * 1.0 # self loop
degree = np.array(mat.sum(axis=-1))
d_inv_sqrt = np.reshape(np.power(degree, -0.5), [-1])
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
d_inv_sqrt_mat = sp.diags(d_inv_sqrt)
mat = mat.dot(d_inv_sqrt_mat).transpose().dot(
d_inv_sqrt_mat).tocoo()
# make torch tensor
idxs = torch.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
vals = torch.from_numpy(mat.data.astype(np.float32))
shape = torch.Size(mat.shape)
return torch.sparse.FloatTensor(idxs, vals, shape).to(args.device)
def _sp_matrix_to_sp_tensor(self, sp_matrix):
coo = sp_matrix.tocoo()
indices = torch.LongTensor([coo.row, coo.col])
values = torch.FloatTensor(coo.data)
return torch.sparse.FloatTensor(indices, values, coo.shape).coalesce().to(args.device)
def _bpr_loss(self, user_emb, pos_item_emb, neg_item_emb):
pos_score = (user_emb * pos_item_emb).sum(dim=1)
neg_score = (user_emb * neg_item_emb).sum(dim=1)
loss = -torch.log(1e-10 + torch.sigmoid((pos_score - neg_score)))
self._check_inf(loss, pos_score, neg_score, 0)
return loss.mean()
def _nce_loss(self, pos_score, neg_score, edge_weight=1):
numerator = torch.exp(pos_score)
denominator = torch.exp(pos_score) + torch.exp(neg_score).sum(dim=1)
loss = -torch.log(numerator/denominator) * edge_weight
self._check_inf(loss, pos_score, neg_score, edge_weight)
return loss.mean()
def _infonce_loss(self, pos_1, pos_2, negs, tau):
pos_1 = self.cl_mlp(pos_1)
pos_2 = self.cl_mlp(pos_2)
negs = self.cl_mlp(negs)
pos_1 = F.normalize(pos_1, dim=-1)
pos_2 = F.normalize(pos_2, dim=-1)
negs = F.normalize(negs, dim=-1)
pos_score = torch.mul(pos_1, pos_2).sum(dim=1)
# B, 1, E * B, E, N -> B, N
neg_score = torch.bmm(pos_1.unsqueeze(1), negs.transpose(1, 2)).squeeze(1)
# infonce loss
numerator = torch.exp(pos_score / tau)
denominator = torch.exp(pos_score / tau) + torch.exp(neg_score / tau).sum(dim=1)
loss = -torch.log(numerator/denominator)
self._check_inf(loss, pos_score, neg_score, 0)
return loss.mean()