Raver | d789517 | 2025-06-18 17:54:38 +0800 | [diff] [blame^] | 1 | import torch |
| 2 | from typing import Optional, Tuple |
| 3 | from torch import nn |
| 4 | |
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Expand ``src`` to the shape of ``other``.

    A 1-D ``src`` is first aligned with axis ``dim`` of ``other`` by
    prepending ``dim`` singleton axes. Any ``src`` with fewer dimensions than
    ``other`` is then padded with trailing singleton axes, and the result is
    expanded (without copying) to ``other.size()``.

    Args:
        src: Tensor to be expanded.
        other: Tensor whose shape is the expansion target.
        dim: Axis of ``other`` that a 1-D ``src`` runs along; negative values
            count from the last axis of ``other``.

    Returns:
        An expanded view of ``src`` with the same size as ``other``.
    """
    target_rank = other.dim()
    # Normalise a negative axis against the rank of ``other``.
    axis = dim + target_rank if dim < 0 else dim

    if src.dim() == 1:
        # Put the single real axis of ``src`` at position ``axis``.
        for _ in range(axis):
            src = src.unsqueeze(0)

    # Pad on the right until the ranks agree.
    while src.dim() < target_rank:
        src = src.unsqueeze(-1)

    return src.expand(other.size())
| 15 | |
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum the entries of ``src`` into the slots named by ``index`` along ``dim``.

    ``index`` is first expanded to the shape of ``src`` with ``broadcast``.
    The accumulation itself is done by ``Tensor.scatter_add_``.

    Args:
        src: Values to accumulate.
        index: Destination indices along ``dim``.
        dim: Axis along which to scatter.
        out: Optional destination tensor. When given, it is updated in place
            and returned, and ``dim_size`` is ignored.
        dim_size: Size of the result along ``dim`` when ``out`` is not given.
            Defaults to ``index.max() + 1``, or ``0`` if ``index`` is empty.

    Returns:
        ``out`` if it was supplied, otherwise a new zero-initialised tensor
        with the dtype and device of ``src`` holding the accumulated sums.
    """
    index = broadcast(index, src, dim)

    # Caller supplied the destination: accumulate straight into it.
    if out is not None:
        return out.scatter_add_(dim, index, src)

    shape = list(src.size())
    if dim_size is not None:
        shape[dim] = dim_size
    elif index.numel() == 0:
        shape[dim] = 0
    else:
        # Largest index decides how many slots are needed.
        shape[dim] = int(index.max()) + 1

    result = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return result.scatter_add_(dim, index, src)
| 32 | |
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of ``scatter_sum``; all arguments are forwarded unchanged."""
    return scatter_sum(src, index, dim=dim, out=out, dim_size=dim_size)
| 37 | |
| 38 | |
class EdgelistDrop(nn.Module):
    """Randomly drop rows of an edge list.

    Each row is kept independently; for ``keep_rate`` in ``[0, 1]`` the
    probability of keeping a row is ``keep_rate``.
    """

    def __init__(self):
        super(EdgelistDrop, self).__init__()

    def forward(self, edgeList, keep_rate, return_mask=False):
        """Return ``edgeList`` with a random subset of its rows removed.

        Args:
            edgeList: Tensor indexed as ``edgeList[mask, :]``, so it must have
                at least two dimensions; rows along dim 0 are the candidates
                for dropping.
            keep_rate: Keep probability per row. ``1.0`` is a fast path that
                returns ``edgeList`` itself without sampling.
            return_mask: When true, also return the boolean keep-mask.

        Returns:
            The filtered edge list, or ``(filtered edge list, mask)`` when
            ``return_mask`` is true. ``mask`` is a bool tensor of length
            ``edgeList.size(0)`` created on the default device.
        """
        edgeNum = edgeList.size(0)

        if keep_rate == 1.0:
            # Nothing to drop. Honour ``return_mask`` here too, so the return
            # type does not depend on the value of ``keep_rate``.
            if return_mask:
                return edgeList, torch.ones(edgeNum, dtype=torch.bool)
            return edgeList

        # rand() is uniform in [0, 1), so floor(rand + keep_rate) is non-zero
        # exactly when rand >= 1 - keep_rate.
        mask = (torch.rand(edgeNum) + keep_rate).floor().type(torch.bool)
        newEdgeList = edgeList[mask, :]
        if return_mask:
            return newEdgeList, mask
        return newEdgeList