Skip to content

Commit

Permalink
Fix GAT processors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 427537373
  • Loading branch information
adria-p authored and copybara-github committed Feb 9, 2022
1 parent 8287ae6 commit d2731dd
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions clrs/_src/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def __call__(
skip = hk.Linear(self.out_size)

bias_mat = (adj - 1.0) * 1e9
bias_mat = jnp.tile(bias_mat, (1, 1, 1, self.nb_heads)) # [B, N, N, H]
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N]
bias_mat = jnp.tile(bias_mat[..., None],
(1, 1, 1, self.nb_heads)) # [B, N, N, H]
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N]

a_1 = hk.Linear(self.nb_heads)
a_2 = hk.Linear(self.nb_heads)
Expand Down Expand Up @@ -169,8 +170,9 @@ def __call__(
skip = hk.Linear(self.out_size)

bias_mat = (adj - 1.0) * 1e9
bias_mat = jnp.tile(bias_mat, (1, 1, 1, self.nb_heads)) # [B, N, N, H]
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N]
bias_mat = jnp.tile(bias_mat[..., None],
(1, 1, 1, self.nb_heads)) # [B, N, N, H]
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N]

w_1 = hk.Linear(self.mid_size)
w_2 = hk.Linear(self.mid_size)
Expand Down

0 comments on commit d2731dd

Please sign in to comment.