added type-hints to attentive_fp.py
TanveshT committed Oct 18, 2022
1 parent 8cbbd72 commit 2b3b9f4
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions torch_geometric/nn/models/attentive_fp.py
@@ -13,8 +13,13 @@


 class GATEConv(MessagePassing):
-    def __init__(self, in_channels: int, out_channels: int, edge_dim: int,
-                 dropout: float = 0.0):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        edge_dim: int,
+        dropout: float = 0.0,
+    ):
         super().__init__(aggr='add', node_dim=0)

         self.dropout = dropout
@@ -29,7 +34,7 @@ def __init__(self, in_channels: int, out_channels: int, edge_dim: int,

         self.reset_parameters()

-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         glorot(self.att_l)
         glorot(self.att_r)
         glorot(self.lin1.weight)
@@ -73,9 +78,16 @@ class AttentiveFP(torch.nn.Module):
         dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
     """
-    def __init__(self, in_channels: int, hidden_channels: int,
-                 out_channels: int, edge_dim: int, num_layers: int,
-                 num_timesteps: int, dropout: float = 0.0):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        edge_dim: int,
+        num_layers: int,
+        num_timesteps: int,
+        dropout: float = 0.0,
+    ):
         super().__init__()

         self.num_layers = num_layers
@@ -103,7 +115,7 @@ def __init__(self, in_channels: int, hidden_channels: int,

         self.reset_parameters()

-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         self.lin1.reset_parameters()
         for conv, gru in zip(self.atom_convs, self.atom_grus):
             conv.reset_parameters()
@@ -112,7 +124,8 @@ def reset_parameters(self):
         self.mol_gru.reset_parameters()
         self.lin2.reset_parameters()

-    def forward(self, x, edge_index, edge_attr, batch):
+    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,
+                batch: Tensor) -> Tensor:
         """"""
         # Atom Embedding:
         x = F.leaky_relu_(self.lin1(x))
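For reference, a minimal usage sketch of the now-annotated forward signature; the hyperparameter values and tensor sizes below are illustrative placeholders, not values taken from this commit.

    import torch
    from torch_geometric.nn.models import AttentiveFP

    # Hypothetical model configuration, matching the annotated __init__ signature.
    model = AttentiveFP(in_channels=16, hidden_channels=64, out_channels=1,
                        edge_dim=8, num_layers=2, num_timesteps=2, dropout=0.2)

    # A toy graph: 4 nodes, 3 edges, all assigned to a single molecule (graph 0).
    x = torch.randn(4, 16)                    # node features, shape [num_nodes, in_channels]
    edge_index = torch.tensor([[0, 1, 2],
                               [1, 2, 3]])    # connectivity, shape [2, num_edges]
    edge_attr = torch.randn(3, 8)             # edge features, shape [num_edges, edge_dim]
    batch = torch.zeros(4, dtype=torch.long)  # graph assignment, shape [num_nodes]

    out = model(x, edge_index, edge_attr, batch)  # Tensor of shape [num_graphs, out_channels]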
