diff --git a/torch_geometric/nn/models/attentive_fp.py b/torch_geometric/nn/models/attentive_fp.py
index ddb9f8a99349..c91c27d253f0 100644
--- a/torch_geometric/nn/models/attentive_fp.py
+++ b/torch_geometric/nn/models/attentive_fp.py
@@ -13,8 +13,13 @@
 
 
 class GATEConv(MessagePassing):
-    def __init__(self, in_channels: int, out_channels: int, edge_dim: int,
-                 dropout: float = 0.0):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        edge_dim: int,
+        dropout: float = 0.0,
+    ):
         super().__init__(aggr='add', node_dim=0)
 
         self.dropout = dropout
@@ -29,7 +34,7 @@ def __init__(self, in_channels: int, out_channels: int, edge_dim: int,
 
         self.reset_parameters()
 
-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         glorot(self.att_l)
         glorot(self.att_r)
         glorot(self.lin1.weight)
@@ -73,9 +78,16 @@ class AttentiveFP(torch.nn.Module):
         dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
     """
 
-    def __init__(self, in_channels: int, hidden_channels: int,
-                 out_channels: int, edge_dim: int, num_layers: int,
-                 num_timesteps: int, dropout: float = 0.0):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        edge_dim: int,
+        num_layers: int,
+        num_timesteps: int,
+        dropout: float = 0.0,
+    ):
         super().__init__()
 
         self.num_layers = num_layers
@@ -103,7 +115,7 @@ def __init__(self, in_channels: int, hidden_channels: int,
 
         self.reset_parameters()
 
-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         self.lin1.reset_parameters()
         for conv, gru in zip(self.atom_convs, self.atom_grus):
             conv.reset_parameters()
@@ -112,7 +124,8 @@ def reset_parameters(self):
         self.mol_gru.reset_parameters()
         self.lin2.reset_parameters()
 
-    def forward(self, x, edge_index, edge_attr, batch):
+    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,
+                batch: Tensor) -> Tensor:
         """"""
         # Atom Embedding:
         x = F.leaky_relu_(self.lin1(x))
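
For reference, the call pattern described by the annotated signatures is unchanged; below is a minimal usage sketch of the AttentiveFP constructor and typed forward pass. The tensor shapes and channel sizes are illustrative assumptions, not values taken from the patch.

import torch
from torch_geometric.nn.models import AttentiveFP

# Toy molecular graph: 4 atoms with 16 node features, 3 bonds with 8 edge features.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 8)
batch = torch.zeros(4, dtype=torch.long)  # all atoms belong to molecule 0

model = AttentiveFP(
    in_channels=16,
    hidden_channels=64,
    out_channels=1,
    edge_dim=8,
    num_layers=2,
    num_timesteps=2,
    dropout=0.0,
)
out = model(x, edge_index, edge_attr, batch)  # Tensor of shape [num_molecules, out_channels] = [1, 1]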