Skip to content

Commit

Permalink
implement remaining predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Apr 25, 2024
1 parent 83cd8b5 commit c7a4bff
Showing 1 changed file with 76 additions and 6 deletions.
82 changes: 76 additions & 6 deletions molpipeline/estimators/chemprop/component_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,36 @@
from chemprop.nn.agg import MeanAggregation as _MeanAggregation
from chemprop.nn.agg import SumAggregation as _SumAggregation
from chemprop.nn.ffn import MLP
from chemprop.nn.loss import BCELoss, LossFunction, MSELoss
from chemprop.nn.loss import (
BCELoss,
BinaryDirichletLoss,
CrossEntropyLoss,
EvidentialLoss,
LossFunction,
MSELoss,
MulticlassDirichletLoss,
MVELoss,
SIDLoss,
)
from chemprop.nn.message_passing import BondMessagePassing as _BondMessagePassing
from chemprop.nn.message_passing import MessagePassing
from chemprop.nn.metrics import BinaryAUROCMetric, Metric, MSEMetric
from chemprop.nn.metrics import (
BinaryAUROCMetric,
CrossEntropyMetric,
Metric,
MSEMetric,
SIDMetric,
)
from chemprop.nn.predictors import BinaryClassificationFFN as _BinaryClassificationFFN
from chemprop.nn.predictors import BinaryDirichletFFN as _BinaryDirichletFFN
from chemprop.nn.predictors import EvidentialFFN as _EvidentialFFN
from chemprop.nn.predictors import (
MulticlassClassificationFFN as _MulticlassClassificationFFN,
)
from chemprop.nn.predictors import MulticlassDirichletFFN as _MulticlassDirichletFFN
from chemprop.nn.predictors import MveFFN as _MveFFN
from chemprop.nn.predictors import RegressionFFN as _RegressionFFN
from chemprop.nn.predictors import SpectralFFN as _SpectralFFN
from chemprop.nn.predictors import (
_FFNPredictorBase as _Predictor, # pylint: disable=protected-access
)
Expand Down Expand Up @@ -265,6 +289,28 @@ def set_params(self, **params: Any) -> Self:
return self


class RegressionFFN(PredictorWrapper, _RegressionFFN): # type: ignore
"""A wrapper for the RegressionFFN class."""

n_targets: int = 1
_T_default_criterion = MSELoss
_T_default_metric = MSEMetric


class MveFFN(PredictorWrapper, _MveFFN): # type: ignore
"""A wrapper for the MveFFN class."""

n_targets: int = 2
_T_default_criterion = MVELoss


class EvidentialFFN(PredictorWrapper, _EvidentialFFN): # type: ignore
"""A wrapper for the EvidentialFFN class."""

n_targets: int = 4
_T_default_criterion = EvidentialLoss


class BinaryClassificationFFN(PredictorWrapper, _BinaryClassificationFFN): # type: ignore
"""A wrapper for the BinaryClassificationFFN class."""

Expand All @@ -273,12 +319,36 @@ class BinaryClassificationFFN(PredictorWrapper, _BinaryClassificationFFN): # ty
_T_default_metric = BinaryAUROCMetric


class RegressionFFN(PredictorWrapper, _RegressionFFN): # type: ignore
"""A wrapper for the RegressionFFN class."""
class BinaryDirichletFFN(PredictorWrapper, _BinaryDirichletFFN): # type: ignore
"""A wrapper for the BinaryDirichletFFN class."""

n_targets: int = 2
_T_default_criterion = BinaryDirichletLoss
_T_default_metric = BinaryAUROCMetric


class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN): # type: ignore
"""A wrapper for the MulticlassClassificationFFN class."""

n_targets: int = 1
_T_default_criterion = MSELoss
_T_default_metric = MSEMetric
_T_default_criterion = CrossEntropyLoss
_T_default_metric = CrossEntropyMetric


class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type: ignore
"""A wrapper for the MulticlassDirichletFFN class."""

n_targets: int = 1
_T_default_criterion = MulticlassDirichletLoss
_T_default_metric = CrossEntropyMetric


class SpectralFFN(PredictorWrapper, _SpectralFFN): # type: ignore
"""A wrapper for the SpectralFFN class."""

n_targets: int = 1
_T_default_criterion = SIDLoss
_T_default_metric = SIDMetric


class MPNN(_MPNN, BaseEstimator):
Expand Down

0 comments on commit c7a4bff

Please sign in to comment.