import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
import math
import networkx as nx
import random
from copy import deepcopy
from utils.parse_args import args
from model.base_model import BaseModel
from model.operators import EdgelistDrop
from model.operators import scatter_add, scatter_sum


init = nn.init.xavier_uniform_

class LightGCN(BaseModel):
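    """LightGCN backbone: propagates user/item embeddings over the bi-normalized
    user-item interaction graph without feature transforms or non-linearities,
    and combines the per-layer embeddings by summation."""
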
    def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
        super().__init__(dataset)
        self.adj = self._make_binorm_adj(dataset.graph)
        self.edges = self.adj._indices().t()
        self.edge_norm = self.adj._values()

        self.phase = phase

        # identity gate over the embeddings (pass-through by default)
        self.emb_gate = lambda x: x

        if self.phase in ('pretrain', 'vanilla', 'for_tune'):
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        elif self.phase == 'finetune':
            # start from the embeddings produced by the pretrained model
            pre_user_emb, pre_item_emb = pretrained_model.generate()
            self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
            self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)

        elif self.phase == 'continue_tune':
            # re-initialize for loading state dict
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        self.edge_dropout = EdgelistDrop()

    def _agg(self, all_emb, edges, edge_norm):
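        """One propagation step: weight each edge's source embedding by its
        bi-normalization coefficient, then scatter-sum into the destination nodes."""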
        src_emb = all_emb[edges[:, 0]]

        # bi-norm
        src_emb = src_emb * edge_norm.unsqueeze(1)

        # conv
        dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
        return dst_emb

    def _edge_binorm(self, edges):
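        """Per-edge symmetric normalization 1/sqrt(deg_u * deg_i), with degrees
        counted from the given edge list."""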
        user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
        user_degs = user_degs[edges[:, 0]]
        item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
        item_degs = item_degs[edges[:, 1]]
        norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
        return norm

    def forward(self, edges, edge_norm, return_layers=False):
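        """Run args.num_layers propagation steps. Returns the layer-summed
        embeddings split into (user, item), or per-layer lists of both when
        return_layers is True."""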
        all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
        all_emb = self.emb_gate(all_emb)
        res_emb = [all_emb]
        for _ in range(args.num_layers):
            all_emb = self._agg(res_emb[-1], edges, edge_norm)
            res_emb.append(all_emb)
        if not return_layers:
            res_emb = sum(res_emb)
            user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
        else:
            user_res_emb, item_res_emb = [], []
            for emb in res_emb:
                u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
                user_res_emb.append(u_emb)
                item_res_emb.append(i_emb)
        return user_res_emb, item_res_emb

    def cal_loss(self, batch_data):
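        """BPR loss on a (users, pos_items, neg_items) batch, computed with edge
        dropout on the propagation graph plus weighted L2 regularization."""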
        edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
        edge_norm = self.edge_norm[dropout_mask]

        # forward
        users, pos_items, neg_items = batch_data
        user_emb, item_emb = self.forward(edges, edge_norm)
        batch_user_emb = user_emb[users]
        pos_item_emb = item_emb[pos_items]
        neg_item_emb = item_emb[neg_items]
        rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
        reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)

        loss = rec_loss + reg_loss
        loss_dict = {
            "rec_loss": rec_loss.item(),
            "reg_loss": reg_loss.item(),
        }
        return loss, loss_dict

    @torch.no_grad()
    def generate(self, return_layers=False):
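        """Full-graph inference (no edge dropout)."""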
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def generate_lgn(self, return_layers=False):
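        """Same computation as generate(); kept as a separate entry point."""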
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def rating(self, user_emb, item_emb):
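        """Score all user-item pairs by inner product."""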
        return torch.matmul(user_emb, item_emb.t())

    def _reg_loss(self, users, pos_items, neg_items):
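        """L2 regularization over the layer-0 (ego) embeddings of the batch,
        averaged over the batch size."""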
        u_emb = self.user_embedding[users]
        pos_i_emb = self.item_embedding[pos_items]
        neg_i_emb = self.item_embedding[neg_items]
        reg_loss = (1/2)*(u_emb.norm(2).pow(2) +
                          pos_i_emb.norm(2).pow(2) +
                          neg_i_emb.norm(2).pow(2))/float(len(users))
        return reg_loss
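
# Usage sketch (hypothetical; how `dataset`, `args`, and the training batches are
# built lives elsewhere in this repo):
#   model = LightGCN(dataset, phase='pretrain')
#   loss, loss_dict = model.cal_loss((users, pos_items, neg_items))
#   user_emb, item_emb = model.generate()
#   scores = model.rating(user_emb, item_emb)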