Source code for dgllife.model.readout.weighted_sum_and_max

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Apply weighted sum and max pooling to the node representations and concatenate the results.
# pylint: disable= no-member, arguments-differ, invalid-name

import dgl
import torch
import torch.nn as nn

from dgl.nn.pytorch import WeightAndSum

__all__ = ['WeightedSumAndMax']

# pylint: disable=W0221
class WeightedSumAndMax(nn.Module):
    r"""Graph readout that joins a weighted-sum pooling with a max pooling.

    Node representations are pooled per graph in two ways -- once through
    a ``WeightAndSum`` layer and once with an element-wise maximum over the
    nodes -- and the two pooled vectors are concatenated.

    Parameters
    ----------
    in_feats : int
        Input node feature size
    """

    def __init__(self, in_feats):
        super().__init__()
        # Attribute name is part of the module's parameter/state_dict naming.
        self.weight_and_sum = WeightAndSum(in_feats)

    def forward(self, bg, feats):
        """Compute the graph-level readout for a batch of graphs.

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        h_g : FloatTensor of shape (B, 2 * M1)
            * B is the number of graphs in the batch
            * M1 is the input node feature size, which must match
              in_feats in initialization
        """
        weighted_sum = self.weight_and_sum(bg, feats)

        # The 'h' node field is written inside a local scope; it is only
        # needed as the input of the max readout below.
        with bg.local_scope():
            bg.ndata['h'] = feats
            node_max = dgl.max_nodes(bg, 'h')

        # Weighted-sum half first, max half second, joined along dim 1.
        return torch.cat((weighted_sum, node_max), dim=1)