import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
import math
import networkx as nx
import random
from copy import deepcopy

from utils.parse_args import args
from model.base_model import BaseModel
from model.operators import EdgelistDrop
from model.operators import scatter_add, scatter_sum

# shorthand for Xavier-uniform initialization of the embedding tables
init = nn.init.xavier_uniform_

class LightGCN(BaseModel):
    def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
        super().__init__(dataset)
        # symmetrically (bi-)normalized adjacency, kept as an edge list plus per-edge weights
        self.adj = self._make_binorm_adj(dataset.graph)
        self.edges = self.adj._indices().t()
        self.edge_norm = self.adj._values()

        self.phase = phase
        # no-op embedding gate (identity mapping)
        self.emb_gate = lambda x: x

        if self.phase in ('pretrain', 'vanilla', 'for_tune'):
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))
        elif self.phase == 'finetune':
            # warm-start from the embeddings produced by a pretrained model
            pre_user_emb, pre_item_emb = pretrained_model.generate()
            self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
            self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)
        elif self.phase == 'continue_tune':
            # re-initialize so that a saved state dict can be loaded afterwards
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        self.edge_dropout = EdgelistDrop()

    def _agg(self, all_emb, edges, edge_norm):
        # gather the source-node embedding of every edge
        src_emb = all_emb[edges[:, 0]]
        # scale by the bi-normalization weight of the edge
        src_emb = src_emb * edge_norm.unsqueeze(1)
        # sum the messages into their destination nodes (the convolution step)
        dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users + self.num_items)
        return dst_emb
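
    # Sketch (not used by the model): the aggregation above is the edge-list form of
    # one LightGCN propagation step, E^(l+1) = A_norm @ E^(l), with A_norm the
    # bi-normalized adjacency. Assuming num_nodes = num_users + num_items, the same
    # result can be obtained with a sparse matmul:
    #
    #   idx = torch.stack([edges[:, 1], edges[:, 0]])        # rows = dst, cols = src
    #   a_norm = torch.sparse_coo_tensor(idx, edge_norm, (num_nodes, num_nodes)).coalesce()
    #   out = torch.sparse.mm(a_norm, all_emb)                # == self._agg(all_emb, edges, edge_norm)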

    def _edge_binorm(self, edges):
        # degree of the user endpoint of each edge
        user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
        user_degs = user_degs[edges[:, 0]]
        # degree of the item endpoint of each edge
        item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
        item_degs = item_degs[edges[:, 1]]
        # d_u^{-1/2} * d_i^{-1/2}
        norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
        return norm
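
    # Symmetric ("bi") normalization weight of a user-item edge (u, i):
    #     norm(u, i) = 1 / (sqrt(deg(u)) * sqrt(deg(i)))
    # e.g. a user of degree 4 interacting with an item of degree 9 yields
    # 1 / (2 * 3) = 1/6. This helper recomputes the weights from a raw
    # user-item edge list; the forward pass below uses the precomputed
    # values already stored in self.edge_norm.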

    def forward(self, edges, edge_norm, return_layers=False):
        all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
        all_emb = self.emb_gate(all_emb)
        res_emb = [all_emb]
        # propagate for num_layers rounds; no feature transform or nonlinearity, as in LightGCN
        for _ in range(args.num_layers):
            all_emb = self._agg(res_emb[-1], edges, edge_norm)
            res_emb.append(all_emb)
        if not return_layers:
            # combine the layer outputs by summation, then split back into users and items
            res_emb = sum(res_emb)
            user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
        else:
            # return the per-layer embeddings separately
            user_res_emb, item_res_emb = [], []
            for emb in res_emb:
                u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
                user_res_emb.append(u_emb)
                item_res_emb.append(i_emb)
        return user_res_emb, item_res_emb
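
    # The combined representation is the sum of all layer outputs, E = sum_{l=0}^{L} E^(l).
    # The original LightGCN paper averages the layers (factor 1/(L+1)); the two differ
    # only by a constant scale, which leaves dot-product rankings unchanged.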

    def cal_loss(self, batch_data):
        # drop a fraction of edges for this training step and keep the matching norms
        edges, dropout_mask = self.edge_dropout(self.edges, 1 - args.edge_dropout, return_mask=True)
        edge_norm = self.edge_norm[dropout_mask]

        # forward
        users, pos_items, neg_items = batch_data
        user_emb, item_emb = self.forward(edges, edge_norm)
        batch_user_emb = user_emb[users]
        pos_item_emb = item_emb[pos_items]
        neg_item_emb = item_emb[neg_items]

        rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
        reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)
        loss = rec_loss + reg_loss
        loss_dict = {
            "rec_loss": rec_loss.item(),
            "reg_loss": reg_loss.item(),
        }
        return loss, loss_dict
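
    # _bpr_loss is inherited from BaseModel and not defined in this file. The standard
    # BPR objective it is assumed to implement is
    #     L_bpr = -mean( log sigmoid( <u, i_pos> - <u, i_neg> ) )
    # i.e. each observed interaction should score higher than its sampled negative.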

    @torch.no_grad()
    def generate(self, return_layers=False):
        # full-graph inference embeddings (no edge dropout)
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def generate_lgn(self, return_layers=False):
        # identical to generate(); kept as a separate entry point
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def rating(self, user_emb, item_emb):
        # score every user-item pair with a dot product
        return torch.matmul(user_emb, item_emb.t())

    def _reg_loss(self, users, pos_items, neg_items):
        # L2 regularization on the ego (layer-0) embeddings of the sampled batch
        u_emb = self.user_embedding[users]
        pos_i_emb = self.item_embedding[pos_items]
        neg_i_emb = self.item_embedding[neg_items]
        reg_loss = (1 / 2) * (u_emb.norm(2).pow(2) +
                              pos_i_emb.norm(2).pow(2) +
                              neg_i_emb.norm(2).pow(2)) / float(len(users))
        return reg_loss
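

# Hypothetical usage sketch (not part of the model). It assumes `dataset` exposes
# `graph` plus the num_users/num_items/emb_size used by BaseModel, that `args`
# carries `num_layers`, `edge_dropout`, and `weight_decay`, and that `train_loader`
# yields (users, pos_items, neg_items) index tensors from a BPR sampler; `device`,
# `train_loader`, and the optimizer settings are illustrative.
#
#   model = LightGCN(dataset, phase='pretrain').to(device)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   for users, pos_items, neg_items in train_loader:
#       loss, loss_dict = model.cal_loss((users, pos_items, neg_items))
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
#
#   # inference: full-graph embeddings and the user-item score matrix
#   user_emb, item_emb = model.generate()
#   scores = model.rating(user_emb, item_emb)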