Skip to content

Commit

Permalink
able to return loss breakdown
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 21, 2024
1 parent d201355 commit e878ad7
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
35 changes: 25 additions & 10 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Sequential,
)

from typing import Literal, Tuple
from typing import Literal, Tuple, NamedTuple

from alphafold3_pytorch.typing import (
Float,
Expand Down Expand Up @@ -2326,13 +2326,14 @@ def forward(

# main class

LossBreakdown = namedtuple('LossBreakdown', [
'distogram',
'pae',
'pde',
'plddt',
'resolved'
])
class LossBreakdown(NamedTuple):
diffusion: Float['']
distogram: Float['']
pae: Float['']
pde: Float['']
plddt: Float['']
resolved: Float['']
confidence: Float['']

class Alphafold3(Module):
""" Algorithm 1 """
Expand Down Expand Up @@ -2561,7 +2562,8 @@ def forward(
pde_labels: Int['b n n'] | None = None,
plddt_labels: Int['b n'] | None = None,
resolved_labels: Int['b n'] | None = None,
) -> Float['b m 3'] | Float['']:
return_loss_breakdown = False
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:

w = self.atoms_per_window

Expand Down Expand Up @@ -2754,4 +2756,17 @@ def forward(
confidence_loss * self.loss_confidence_weight
)

return loss
if not return_loss_breakdown:
return loss

loss_breakdown = LossBreakdown(
pae = pae_loss,
pde = pde_loss,
plddt = plddt_loss,
resolved = resolved_loss,
distogram = distogram_loss,
diffusion = diffusion_loss,
confidence = confidence_loss
)

return loss, loss_breakdown
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.11"
version = "0.0.12"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down
5 changes: 3 additions & 2 deletions tests/test_af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_alphafold3():
),
)

loss = alphafold3(
loss, breakdown = alphafold3(
num_recycling_steps = 2,
atom_inputs = atom_inputs,
atom_mask = atom_mask,
Expand All @@ -395,7 +395,8 @@ def test_alphafold3():
pae_labels = pae_labels,
pde_labels = pde_labels,
plddt_labels = plddt_labels,
resolved_labels = resolved_labels
resolved_labels = resolved_labels,
return_loss_breakdown = True
)

loss.backward()
Expand Down

0 comments on commit e878ad7

Please sign in to comment.