Commit
Merge remote-tracking branch 'github/25-allow-to-trigger-pipelines-manually' into 25-allow-to-trigger-pipelines-manually
Showing 14 changed files with 442 additions and 149 deletions.
@@ -0,0 +1,104 @@
"""Wrapper for Chemprop loss functions.""" | ||
|
||
from typing import Any | ||
|
||
import torch | ||
from chemprop.nn.loss import BCELoss as _BCELoss | ||
from chemprop.nn.loss import BinaryDirichletLoss as _BinaryDirichletLoss | ||
from chemprop.nn.loss import CrossEntropyLoss as _CrossEntropyLoss | ||
from chemprop.nn.loss import EvidentialLoss as _EvidentialLoss | ||
from chemprop.nn.loss import LossFunction as _LossFunction | ||
from chemprop.nn.loss import MSELoss as _MSELoss | ||
from chemprop.nn.loss import MulticlassDirichletLoss as _MulticlassDirichletLoss | ||
from chemprop.nn.loss import MVELoss as _MVELoss | ||
from chemprop.nn.loss import SIDLoss as _SIDLoss | ||
from numpy.typing import ArrayLike | ||
|
||
|
||
class LossFunctionParamMixin: | ||
"""Mixin for loss functions to get and set parameters.""" | ||
|
||
_original_task_weights: ArrayLike | ||
|
||
def __init__(self: _LossFunction, task_weights: ArrayLike) -> None: | ||
"""Initialize the loss function. | ||
Parameters | ||
---------- | ||
task_weights : ArrayLike | ||
The weights for each task. | ||
""" | ||
super().__init__(task_weights=task_weights) # type: ignore | ||
self._original_task_weights = task_weights | ||
|
||
# pylint: disable=unused-argument | ||
def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]: | ||
"""Get the parameters of the loss function. | ||
Parameters | ||
---------- | ||
deep : bool, optional | ||
Not used, only present to match the sklearn API. | ||
Returns | ||
------- | ||
dict[str, Any] | ||
The parameters of the loss function. | ||
""" | ||
return {"task_weights": self._original_task_weights} | ||
|
||
def set_params(self: _LossFunction, **params: Any) -> _LossFunction: | ||
"""Set the parameters of the loss function. | ||
Parameters | ||
---------- | ||
**params : Any | ||
The parameters to set. | ||
Returns | ||
------- | ||
Self | ||
The loss function with the new parameters. | ||
""" | ||
task_weights = params.pop("task_weights", None) | ||
if task_weights is not None: | ||
self._original_task_weights = task_weights | ||
state_dict = self.state_dict() | ||
state_dict["task_weights"] = torch.as_tensor( | ||
task_weights, dtype=torch.float | ||
).view(1, -1) | ||
self.load_state_dict(state_dict) | ||
return self | ||
|
||
|
||
class BCELoss(LossFunctionParamMixin, _BCELoss): | ||
"""Binary cross-entropy loss function.""" | ||
|
||
|
||
class BinaryDirichletLoss(LossFunctionParamMixin, _BinaryDirichletLoss): | ||
"""Binary Dirichlet loss function.""" | ||
|
||
|
||
class CrossEntropyLoss(LossFunctionParamMixin, _CrossEntropyLoss): | ||
"""Cross-entropy loss function.""" | ||
|
||
|
||
class EvidentialLoss(LossFunctionParamMixin, _EvidentialLoss): | ||
"""Evidential loss function.""" | ||
|
||
|
||
class MSELoss(LossFunctionParamMixin, _MSELoss): | ||
"""Mean squared error loss function.""" | ||
|
||
|
||
class MulticlassDirichletLoss(LossFunctionParamMixin, _MulticlassDirichletLoss): | ||
"""Multiclass Dirichlet loss function.""" | ||
|
||
|
||
class MVELoss(LossFunctionParamMixin, _MVELoss):
    """Mean-variance estimation loss function."""


class SIDLoss(LossFunctionParamMixin, _SIDLoss):
    """Spectral information divergence (SID) loss function."""
@@ -0,0 +1,60 @@
"""Functions for serializing and deserializing PyTorch models.""" | ||
|
||
from typing import TypeVar | ||
|
||
try: | ||
import torch | ||
|
||
TORCH_AVAILABLE = True | ||
except ImportError: | ||
TORCH_AVAILABLE = False | ||
from typing import Any, Literal | ||
|
||
_T = TypeVar("_T") | ||
|
||
if TORCH_AVAILABLE: | ||
|
||
def tensor_to_json( | ||
obj: _T, | ||
) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]: | ||
"""Recursively convert a PyTorch model to a JSON-serializable object. | ||
Parameters | ||
---------- | ||
obj : object | ||
The object to convert. | ||
Returns | ||
------- | ||
object | ||
The JSON-serializable object. | ||
""" | ||
if isinstance(obj, torch.Tensor): | ||
object_dict: dict[str, Any] = { | ||
"__name__": obj.__class__.__name__, | ||
"__module__": obj.__class__.__module__, | ||
"__init__": True, | ||
} | ||
else: | ||
return obj, False | ||
object_dict["data"] = obj.tolist() | ||
return object_dict, True | ||
|
||
else: | ||
|
||
def tensor_to_json( | ||
obj: _T, | ||
) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]: | ||
"""Recursively convert a PyTorch model to a JSON-serializable object. | ||
Parameters | ||
---------- | ||
obj : object | ||
The object to convert. | ||
Returns | ||
------- | ||
object | ||
The JSON-serializable object. | ||
""" | ||
return obj, False |
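A minimal sketch of how tensor_to_json is used (assuming torch is installed, so the first branch above is taken); the second element of the returned tuple signals whether a conversion happened.

import torch

converted, was_tensor = tensor_to_json(torch.tensor([1.0, 2.0]))
# converted == {"__name__": "Tensor", "__module__": "torch",
#               "__init__": True, "data": [1.0, 2.0]}
assert was_tensor

# Non-tensor objects are passed through unchanged.
same_obj, was_tensor = tensor_to_json([1, 2, 3])
assert same_obj == [1, 2, 3] and not was_tensor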
@@ -1,2 +1,2 @@
-chemprop>=2.0.0
+chemprop >= 2.0.0, < 2.0.3
 lightning
@@ -0,0 +1 @@
"""Functions repeatedly used in tests for Chemprop models.""" |
test_extras/test_chemprop/chemprop_test_utils/compare_models.py (54 additions, 0 deletions)
@@ -0,0 +1,54 @@
"""Functions for comparing chemprop models.""" | ||
|
||
from typing import Sequence | ||
from unittest import TestCase | ||
|
||
import torch | ||
from chemprop.nn.loss import LossFunction | ||
from lightning.pytorch.accelerators import Accelerator | ||
from lightning.pytorch.profilers.base import PassThroughProfiler | ||
from sklearn.base import BaseEstimator | ||
from torch import nn | ||
|
||
|
||
def compare_params( | ||
test_case: TestCase, model_a: BaseEstimator, model_b: BaseEstimator | ||
) -> None: | ||
"""Compare the parameters of two models. | ||
Parameters | ||
---------- | ||
test_case : TestCase | ||
The test case for which to raise the assertion. | ||
model_a : BaseEstimator | ||
The first model. | ||
model_b : BaseEstimator | ||
The second model. | ||
""" | ||
model_a_params = model_a.get_params(deep=True) | ||
model_b_params = model_b.get_params(deep=True) | ||
test_case.assertSetEqual(set(model_a_params.keys()), set(model_b_params.keys())) | ||
for param_name, param_a in model_a_params.items(): | ||
param_b = model_b_params[param_name] | ||
test_case.assertEqual(param_a.__class__, param_b.__class__) | ||
if hasattr(param_a, "get_params"): | ||
test_case.assertTrue(hasattr(param_b, "get_params")) | ||
test_case.assertNotEqual(id(param_a), id(param_b)) | ||
elif isinstance(param_a, LossFunction): | ||
test_case.assertEqual( | ||
param_a.state_dict()["task_weights"], | ||
param_b.state_dict()["task_weights"], | ||
) | ||
test_case.assertEqual(type(param_a), type(param_b)) | ||
elif isinstance(param_a, (nn.Identity, Accelerator, PassThroughProfiler)): | ||
test_case.assertEqual(type(param_a), type(param_b)) | ||
elif isinstance(param_a, torch.Tensor): | ||
test_case.assertTrue( | ||
torch.equal(param_a, param_b), f"Test failed for {param_name}" | ||
) | ||
elif param_name == "lightning_trainer__callbacks": | ||
test_case.assertIsInstance(param_b, Sequence) | ||
for i, callback in enumerate(param_a): | ||
test_case.assertIsInstance(callback, type(param_b[i])) | ||
else: | ||
test_case.assertEqual(param_a, param_b, f"Test failed for {param_name}") |
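A minimal sketch of compare_params inside a unit test; make_chemprop_model is a hypothetical factory standing in for whatever estimator the real test suite constructs.

from unittest import TestCase

from sklearn.base import clone


class CloneTest(TestCase):
    """Check that cloning keeps parameter values but creates new nested objects."""

    def test_clone_keeps_params(self) -> None:
        model = make_chemprop_model()  # hypothetical factory
        cloned = clone(model)
        # Asserts equal parameter values and distinct nested estimator objects.
        compare_params(self, model, cloned)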