import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
import math
import networkx as nx
import random
from copy import deepcopy
from utils.parse_args import args
from model.base_model import BaseModel
from model.operators import EdgelistDrop
from model.operators import scatter_add, scatter_sum


init = nn.init.xavier_uniform_

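# LightGCN backbone (He et al., 2020): embeddings are propagated over the
# symmetrically normalized user-item bipartite graph with no feature transform
# and no nonlinearity, and the per-layer outputs are combined at the end:
#
#   E^(l+1) = A_hat @ E^(l),        E_final = sum_{l=0..L} E^(l)
#
# A_hat comes from self._make_binorm_adj, provided by the BaseModel side (not
# shown in this file); args (utils.parse_args) supplies num_layers,
# edge_dropout and weight_decay.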
class LightGCN(BaseModel):
    def __init__(self, dataset, pretrained_model=None, phase='pretrain'):
        super().__init__(dataset)
        self.adj = self._make_binorm_adj(dataset.graph)
        self.edges = self.adj._indices().t()
        self.edge_norm = self.adj._values()

        self.phase = phase

        # no-op embedding gate (identity)
        self.emb_gate = lambda x: x

        if self.phase in ('pretrain', 'vanilla', 'for_tune'):
            # fresh Xavier-initialized user/item embeddings
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        elif self.phase == 'finetune':
            # start from the embeddings produced by a pretrained model
            pre_user_emb, pre_item_emb = pretrained_model.generate()
            self.user_embedding = nn.Parameter(pre_user_emb).requires_grad_(True)
            self.item_embedding = nn.Parameter(pre_item_emb).requires_grad_(True)

        elif self.phase == 'continue_tune':
            # re-initialize so a saved state dict can be loaded on top
            self.user_embedding = nn.Parameter(init(torch.empty(self.num_users, self.emb_size)))
            self.item_embedding = nn.Parameter(init(torch.empty(self.num_items, self.emb_size)))

        self.edge_dropout = EdgelistDrop()
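        # EdgelistDrop lives in model.operators (not shown). From its use in
        # cal_loss below, it is assumed to keep each edge with the given
        # probability and to return the surviving edges plus a boolean
        # keep-mask that can also index the per-edge normalization weights.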

    def _agg(self, all_emb, edges, edge_norm):
        src_emb = all_emb[edges[:, 0]]

        # bi-norm: scale each message by its symmetric normalization weight
        src_emb = src_emb * edge_norm.unsqueeze(1)

        # conv: sum the messages into their destination nodes
        dst_emb = scatter_sum(src_emb, edges[:, 1], dim=0, dim_size=self.num_users+self.num_items)
        return dst_emb
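
    # The step above is one LightGCN propagation: dst_emb[v] = sum over edges
    # (u -> v) of edge_norm * all_emb[u], i.e. a sparse product A_hat.T @ E.
    # A dense illustration (sketch only, not used by the model):
    #
    #   N = self.num_users + self.num_items
    #   A_hat = torch.sparse_coo_tensor(edges.t(), edge_norm, (N, N)).to_dense()
    #   dst_emb = A_hat.t() @ all_emb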

    def _edge_binorm(self, edges):
        # per-edge symmetric normalization: 1 / sqrt(deg(user) * deg(item))
        user_degs = scatter_add(torch.ones_like(edges[:, 0]), edges[:, 0], dim=0, dim_size=self.num_users)
        user_degs = user_degs[edges[:, 0]]
        item_degs = scatter_add(torch.ones_like(edges[:, 1]), edges[:, 1], dim=0, dim_size=self.num_items)
        item_degs = item_degs[edges[:, 1]]
        norm = torch.pow(user_degs, -0.5) * torch.pow(item_degs, -0.5)
        return norm
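
    # Example: an edge between a user with 4 interactions and an item with 9
    # gets weight 1/(sqrt(4)*sqrt(9)) = 1/6. Note that forward() is driven by
    # the precomputed self.edge_norm taken from the adjacency; this helper is
    # not called anywhere within this file.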

    def forward(self, edges, edge_norm, return_layers=False):
        all_emb = torch.cat([self.user_embedding, self.item_embedding], dim=0)
        all_emb = self.emb_gate(all_emb)
        res_emb = [all_emb]
        for _ in range(args.num_layers):
            all_emb = self._agg(res_emb[-1], edges, edge_norm)
            res_emb.append(all_emb)
        if not return_layers:
            res_emb = sum(res_emb)
            user_res_emb, item_res_emb = res_emb.split([self.num_users, self.num_items], dim=0)
        else:
            user_res_emb, item_res_emb = [], []
            for emb in res_emb:
                u_emb, i_emb = emb.split([self.num_users, self.num_items], dim=0)
                user_res_emb.append(u_emb)
                item_res_emb.append(i_emb)
        return user_res_emb, item_res_emb
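
    # With return_layers=False the final embedding is the plain sum of the
    # layer-0 (free) embeddings and every propagated layer, e = sum_{l=0..L} e^(l).
    # The original LightGCN paper averages the layers (alpha_l = 1/(L+1));
    # this implementation uses an unweighted sum instead.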

    def cal_loss(self, batch_data):
        # drop edges for this step and keep the normalization weights of the survivors
        edges, dropout_mask = self.edge_dropout(self.edges, 1-args.edge_dropout, return_mask=True)
        edge_norm = self.edge_norm[dropout_mask]

        # forward
        users, pos_items, neg_items = batch_data
        user_emb, item_emb = self.forward(edges, edge_norm)
        batch_user_emb = user_emb[users]
        pos_item_emb = item_emb[pos_items]
        neg_item_emb = item_emb[neg_items]
        rec_loss = self._bpr_loss(batch_user_emb, pos_item_emb, neg_item_emb)
        reg_loss = args.weight_decay * self._reg_loss(users, pos_items, neg_items)

        loss = rec_loss + reg_loss
        loss_dict = {
            "rec_loss": rec_loss.item(),
            "reg_loss": reg_loss.item(),
        }
        return loss, loss_dict
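
    # _bpr_loss is inherited from BaseModel (not shown); it is assumed to be the
    # usual Bayesian Personalized Ranking objective,
    #   -mean( log sigmoid( <e_u, e_i+> - <e_u, e_i-> ) ),
    # computed over the sampled (user, positive item, negative item) triples.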

    @torch.no_grad()
    def generate(self, return_layers=False):
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def generate_lgn(self, return_layers=False):
        # identical to generate(); retained under a separate name
        return self.forward(self.edges, self.edge_norm, return_layers=return_layers)

    @torch.no_grad()
    def rating(self, user_emb, item_emb):
        # full score matrix: scores[u, i] = <user_emb[u], item_emb[i]>
        return torch.matmul(user_emb, item_emb.t())

    def _reg_loss(self, users, pos_items, neg_items):
        # L2 regularization on the ego (layer-0) embeddings of the sampled batch
        u_emb = self.user_embedding[users]
        pos_i_emb = self.item_embedding[pos_items]
        neg_i_emb = self.item_embedding[neg_items]
        reg_loss = (1/2) * (u_emb.norm(2).pow(2) +
                            pos_i_emb.norm(2).pow(2) +
                            neg_i_emb.norm(2).pow(2)) / float(len(users))
        return reg_loss
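

# Usage sketch (illustrative only; assumes a `dataset` object exposing the graph
# and the count/size attributes consumed by BaseModel, plus a sampled batch of
# (users, pos_items, neg_items) index tensors):
#
#   model = LightGCN(dataset, phase='vanilla').to(device)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss, loss_dict = model.cal_loss((users, pos_items, neg_items))
#   optimizer.zero_grad(); loss.backward(); optimizer.step()
#
#   user_emb, item_emb = model.generate()          # full propagated embeddings
#   scores = model.rating(user_emb[batch_users], item_emb)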