blob: a508966a60e64c51b2d48de32b1634edc27867fc [file] [log] [blame]
import torch
from typing import Optional, Tuple
from torch import nn
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Reshape and expand `src` so it is broadcastable to `other`'s shape.

    A 1-D `src` is first aligned so its single axis lands on `dim` of
    `other`; trailing singleton axes are then appended until the ranks
    match, and the result is expanded (no copy) to `other.size()`.

    Args:
        src: Tensor to broadcast (typically a 1-D index tensor).
        other: Tensor whose shape is the broadcast target.
        dim: Axis of `other` that `src`'s axis should align with
            (negative values count from the end).

    Returns:
        A view of `src` expanded to `other.size()`.
    """
    if dim < 0:
        dim += other.dim()
    if src.dim() == 1:
        # Prepend singleton axes so src's only axis sits at position `dim`.
        for _ in range(dim):
            src = src.unsqueeze(0)
    # Append singleton axes until ranks match, then expand without copying.
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand(other.size())
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum values of `src` into groups selected by `index` along `dim`.

    Equivalent to `torch_scatter.scatter_sum`: `out[..., i, ...]` receives
    the sum of all `src` entries whose `index` equals `i` at position `dim`.

    Args:
        src: Source values to be summed.
        index: Integer group indices; a 1-D index is broadcast to `src`'s shape.
        dim: Axis along which to scatter (negative allowed).
        out: Optional pre-allocated destination, accumulated into in place.
        dim_size: Size of `dim` in the allocated output when `out` is None;
            defaults to `index.max() + 1` (or 0 for an empty index).

    Returns:
        The accumulated tensor (`out` if given, else a freshly allocated one).
    """
    if dim < 0:
        dim += src.dim()
    # Broadcast a lower-rank index up to src's shape (inlined helper):
    # align its axis with `dim`, pad trailing singleton axes, expand.
    if index.dim() == 1:
        for _ in range(dim):
            index = index.unsqueeze(0)
    while index.dim() < src.dim():
        index = index.unsqueeze(-1)
    index = index.expand(src.size())
    if out is None:
        shape = list(src.size())
        if dim_size is not None:
            shape[dim] = dim_size
        else:
            # Empty index has no max(); produce an empty axis instead.
            shape[dim] = 0 if index.numel() == 0 else int(index.max()) + 1
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of `scatter_sum`, kept for torch_scatter API compatibility."""
    return scatter_sum(src, index, dim=dim, out=out, dim_size=dim_size)
class EdgelistDrop(nn.Module):
    """Randomly drops rows (edges) of an edge list for dropout-style
    augmentation: each edge is kept independently with probability
    `keep_rate`.
    """

    def __init__(self):
        super(EdgelistDrop, self).__init__()

    def forward(self, edgeList, keep_rate, return_mask=False):
        """Keep each edge with probability `keep_rate`.

        Args:
            edgeList: 2-D tensor of shape (num_edges, k); rows are edges.
            keep_rate: Probability in (0, 1] of keeping each edge;
                1.0 disables dropping entirely.
            return_mask: When True, also return the boolean keep-mask.

        Returns:
            The filtered edge list, or `(edge_list, mask)` when
            `return_mask` is True.
        """
        if keep_rate == 1.0:
            # Fix: previously this path always returned a tuple, even when
            # return_mask was False — inconsistent with the path below.
            mask = torch.ones(edgeList.size(0), dtype=torch.bool,
                              device=edgeList.device)
            return (edgeList, mask) if return_mask else edgeList
        edgeNum = edgeList.size(0)
        # Bernoulli(keep_rate) per edge: floor(U[0,1) + p) is 1 with prob p.
        # Fix: draw on edgeList's device so CUDA inputs don't hit a
        # CPU/GPU mismatch when indexing with the mask.
        mask = (torch.rand(edgeNum, device=edgeList.device)
                + keep_rate).floor().type(torch.bool)
        newEdgeList = edgeList[mask, :]
        if return_mask:
            return newEdgeList, mask
        else:
            return newEdgeList