Source code for dgllife.model.model_zoo.gatv2_predictor

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# GATv2-based model for regression and classification on graphs
#
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn

from .mlp_predictor import MLPPredictor
from ..gnn.gatv2 import GATv2
from ..readout.weighted_sum_and_max import WeightedSumAndMax

# pylint: disable=W0221
class GATv2Predictor(nn.Module):
    r"""GATv2-based model for regression and classification on graphs

    GATv2 is introduced in `How Attentive Are Graph Attention Networks?
    <https://arxiv.org/pdf/2105.14491.pdf>`__. This model is based on GATv2 and
    can be used for regression and classification on graphs.

    After updating node representations, we perform a weighted sum with
    learnable weights and max pooling on them and concatenate the output of the
    two operations, which is then fed into an MLP for final prediction.

    For classification tasks, the output will be logits, i.e. values before
    sigmoid or softmax.

    Parameters
    ----------
    in_feats : int
        Number of input node features
    hidden_feats : list of int, optional
        ``hidden_feats[i]`` gives the output size of an attention head in the
        i-th GATv2 layer. ``len(hidden_feats)`` equals the number of GATv2
        layers. By default, we use ``[32, 32]``.
    num_heads : list of int, optional
        ``num_heads[i]`` gives the number of attention heads in the i-th GATv2
        layer. ``len(num_heads)`` equals the number of GATv2 layers. By default,
        we use 4 attention heads per GATv2 layer.
    feat_drops : list of float, optional
        ``feat_drops[i]`` gives the dropout applied to the input features in the
        i-th GATv2 layer. ``len(feat_drops)`` equals the number of GATv2 layers.
        By default, we use zero for all GATv2 layers.
    attn_drops : list of float, optional
        ``attn_drops[i]`` gives the dropout applied to the attention values of
        edges in the i-th GATv2 layer. ``len(attn_drops)`` equals the number of
        GATv2 layers. By default, we use zero for all GATv2 layers.
    alphas : list of float, optional
        ``alphas[i]`` gives the slope for the negative values in the LeakyReLU
        function of the i-th GATv2 layer. ``len(alphas)`` equals the number of
        GATv2 layers. By default, we use 0.2 for all GATv2 layers.
    residuals : list of bool, optional
        ``residuals[i]`` decides if a residual connection is to be used for the
        i-th GATv2 layer. ``len(residuals)`` equals the number of GATv2 layers.
        By default, we use ``False`` for all GATv2 layers.
    activations : list of callable, optional
        ``activations[i]`` gives the activation function applied to the result
        of the i-th GATv2 layer. ``len(activations)`` equals the number of GATv2
        layers. By default, we use ELU for all GATv2 layers, except for the
        last layer.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will
        be invalid since no messages will be passed to those nodes. This is
        harmful for some applications, causing silent performance regression.
        This module will raise a DGLError if it detects 0-in-degree nodes in the
        input graph. By setting it to True, it will suppress the check and let
        the users handle it by themselves. Defaults: False.
    biases : list of bool, optional
        ``biases[i]`` decides if an additive bias is allowed to be learned by
        the i-th GATv2 layer. ``len(biases)`` equals the number of GATv2 layers.
        By default, additive biases are learned for all GATv2 layers.
    share_weights : list of bool, optional
        ``share_weights[i]`` decides if the learnable weight matrix for source
        and destination nodes is the same in the i-th GATv2 layer.
        ``len(share_weights)`` equals the number of GATv2 layers. By default,
        no weight sharing is used in all GATv2 layers.
    agg_modes : list of str, optional
        ``agg_modes[i]`` gives the way to aggregate multi-head attention results
        in the i-th GATv2 layer. ``len(agg_modes)`` equals the number of GATv2
        layers. By default, we flatten all-head results for each GATv2 layer,
        except for the last layer.
    n_tasks : int, optional
        Number of tasks, which is also the output size. Default to 1.
    predictor_out_feats : int, optional
        Size for hidden representations in the output MLP predictor.
        Default to 128.
    predictor_dropout : float, optional
        The probability for dropout in the output MLP predictor. Default to 0.
    """

    def __init__(self, in_feats, hidden_feats=None, num_heads=None,
                 feat_drops=None, attn_drops=None, alphas=None, residuals=None,
                 activations=None, allow_zero_in_degree=False, biases=None,
                 share_weights=None, agg_modes=None, n_tasks=1,
                 predictor_out_feats=128, predictor_dropout=0.):
        super(GATv2Predictor, self).__init__()

        self.gnn = GATv2(in_feats=in_feats,
                         hidden_feats=hidden_feats,
                         num_heads=num_heads,
                         feat_drops=feat_drops,
                         attn_drops=attn_drops,
                         alphas=alphas,
                         residuals=residuals,
                         activations=activations,
                         allow_zero_in_degree=allow_zero_in_degree,
                         biases=biases,
                         share_weights=share_weights,
                         agg_modes=agg_modes)

        # Read the effective settings back from the GATv2 module, which fills
        # in defaults for arguments left as None; indexing the raw arguments
        # here would fail when they are None
        if self.gnn.agg_modes[-1] == 'flatten':
            gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
        else:
            gnn_out_feats = self.gnn.hidden_feats[-1]
        # WeightedSumAndMax concatenates a learnable weighted-sum readout with
        # a max-pooling readout, hence the factor of 2 in the MLP input size
        self.readout = WeightedSumAndMax(gnn_out_feats)
        self.predict = MLPPredictor(2 * gnn_out_feats, predictor_out_feats,
                                    n_tasks, predictor_dropout)
    def forward(self, bg, feats, get_attention=False):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs.
            * M1 is the input node feature size, which equals in_feats in
              initialization.
        get_attention : bool, optional
            Whether to return the attention values. Defaults: False

        Returns
        -------
        preds : FloatTensor of shape (B, n_tasks)
            * Predictions on graphs
            * B for the number of graphs in the batch
        attentions : list of FloatTensor of shape (E, H, 1), optional
            It is returned when :attr:`get_attention` is True. ``attentions[i]``
            gives the attention values in the i-th GATv2 layer.

            * `E` is the number of edges.
            * `H` is the number of attention heads.
        """
        if get_attention:
            node_feats, attentions = self.gnn(bg, feats,
                                              get_attention=get_attention)
            graph_feats = self.readout(bg, node_feats)
            return self.predict(graph_feats), attentions
        else:
            node_feats = self.gnn(bg, feats)
            graph_feats = self.readout(bg, node_feats)
            return self.predict(graph_feats)
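
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not from the dgllife source):
# build a small batched graph with DGL, run a forward pass, and retrieve the
# per-layer attention values. The graph topology, feature size, and hidden
# sizes below are arbitrary choices for demonstration.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    import dgl
    import torch

    # Two toy graphs; edges are arranged so that every node has an incoming
    # edge, avoiding the 0-in-degree check described in the class docstring
    g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g2 = dgl.graph(([0, 1], [1, 0]))
    bg = dgl.batch([g1, g2])

    in_feats = 16  # arbitrary input node feature size
    feats = torch.randn(bg.num_nodes(), in_feats)

    model = GATv2Predictor(in_feats=in_feats,
                           hidden_feats=[32, 32],
                           num_heads=[4, 4],
                           n_tasks=1)

    # Graph-level predictions: one row per graph in the batch, shape (2, 1).
    # For classification these are logits, to be passed through a sigmoid or
    # softmax before being interpreted as probabilities
    preds = model(bg, feats)

    # Optionally also return the attention values of every GATv2 layer;
    # attentions[i] has shape (num_edges, num_heads, 1) for the i-th layer
    preds, attentions = model(bg, feats, get_attention=True)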