blob: a508966a60e64c51b2d48de32b1634edc27867fc [file] [log] [blame]
Raverd7895172025-06-18 17:54:38 +08001import torch
2from typing import Optional, Tuple
3from torch import nn
4
5def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
6 if dim < 0:
7 dim = other.dim() + dim
8 if src.dim() == 1:
9 for _ in range(0, dim):
10 src = src.unsqueeze(0)
11 for _ in range(src.dim(), other.dim()):
12 src = src.unsqueeze(-1)
13 src = src.expand(other.size())
14 return src
15
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum entries of ``src`` into buckets given by ``index`` along ``dim``.

    Args:
        src: Source values to be summed.
        index: Bucket indices; broadcast to ``src``'s shape along ``dim``.
        dim: Axis along which to scatter (negative values allowed).
        out: Optional pre-allocated output; accumulated into in place.
        dim_size: Size of the output along ``dim`` when ``out`` is None.
            If omitted, it is inferred as ``index.max() + 1`` (or 0 for an
            empty index).

    Returns:
        The accumulated tensor (``out`` if given, else a new zero-initialized
        tensor on ``src``'s device and dtype).
    """
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            # Empty index: produce an empty output along `dim`.
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    # Both branches previously duplicated this call; accumulate in place.
    return out.scatter_add_(dim, index, src)
32
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of :func:`scatter_sum`; kept for API compatibility."""
    return scatter_sum(src=src, index=index, dim=dim, out=out,
                       dim_size=dim_size)
37
38
class EdgelistDrop(nn.Module):
    """Randomly drops rows (edges) of an edge list, keeping each edge
    independently with probability ``keep_rate``."""

    def __init__(self):
        super(EdgelistDrop, self).__init__()

    def forward(self, edgeList, keep_rate, return_mask=False):
        """Drop edges from ``edgeList`` (shape ``[E, k]``).

        Args:
            edgeList: Tensor whose rows are edges.
            keep_rate: Probability in [0, 1] of keeping each edge.
            return_mask: If True, also return the boolean keep mask.

        Returns:
            The filtered edge list, plus the mask when ``return_mask`` is
            True. NOTE: when ``keep_rate == 1.0`` a (edgeList, mask) tuple
            is always returned, regardless of ``return_mask`` (preserved
            for backward compatibility with existing callers).
        """
        if keep_rate == 1.0:
            # Fix: build the mask on edgeList's device (was CPU-only).
            mask = torch.ones(edgeList.size(0), dtype=torch.bool,
                              device=edgeList.device)
            return edgeList, mask
        edgeNum = edgeList.size(0)
        # rand in [0, 1): (rand + keep_rate).floor() is 1 with
        # probability keep_rate, 0 otherwise.
        # Fix: sample on edgeList's device so boolean indexing works on GPU.
        mask = (torch.rand(edgeNum, device=edgeList.device)
                + keep_rate).floor().type(torch.bool)
        newEdgeList = edgeList[mask, :]
        if return_mask:
            return newEdgeList, mask
        else:
            return newEdgeList