Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mbarbetti committed Jun 19, 2024
1 parent 23b1865 commit 95d8375
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 27 deletions.
2 changes: 2 additions & 0 deletions src/pidgan/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
k_vrs = float(".".join([n for n in k_vrs]))

if k_vrs >= 3.0:
from .k3.BaseMetric import BaseMetric
from .k3.Accuracy import Accuracy
from .k3.BinaryCrossentropy import BinaryCrossentropy
from .k3.JSDivergence import JSDivergence
Expand All @@ -13,6 +14,7 @@
from .k3.RootMeanSquaredError import RootMeanSquaredError
from .k3.WassersteinDistance import WassersteinDistance
else:
from .k2.BaseMetric import BaseMetric
from .k2.Accuracy import Accuracy
from .k2.BinaryCrossentropy import BinaryCrossentropy
from .k2.JSDivergence import JSDivergence
Expand Down
4 changes: 2 additions & 2 deletions src/pidgan/optimization/callbacks/HopaasPruner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tensorflow import keras
import keras as k


class HopaasPruner(keras.callbacks.Callback):
class HopaasPruner(k.callbacks.Callback):
def __init__(
self, trial, loss_name, report_frequency=1, enable_pruning=True
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions src/pidgan/utils/checks/checkMetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pidgan.metrics import MeanSquaredError as MSE
from pidgan.metrics import RootMeanSquaredError as RMSE
from pidgan.metrics import WassersteinDistance as Wass_dist
from pidgan.metrics.BaseMetric import BaseMetric
from pidgan.metrics import BaseMetric

METRIC_SHORTCUTS = [
"accuracy",
Expand Down Expand Up @@ -46,19 +46,19 @@ def checkMetrics(metrics): # TODO: add Union[list, None]
checked_metrics.append(calo_metric)
else:
raise ValueError(
f"`metrics` elements should be selected in "
f'"metrics" elements should be selected in '
f"{METRIC_SHORTCUTS}, instead '{metric}' passed"
)
elif isinstance(metric, BaseMetric):
checked_metrics.append(metric)
else:
raise TypeError(
f"`metrics` elements should be a pidgan's "
f"`BaseMetric`, instead {type(metric)} passed"
f'"metrics" elements should be a pidgan '
f"BaseMetric, instead {type(metric)} passed"
)
return checked_metrics
else:
raise TypeError(
f"`metrics` should be a list of strings or pidgan's "
f"`BaseMetric`s, instead {type(metrics)} passed"
f'"metrics" should be a list of strings or pidgan '
f"BaseMetrics, instead {type(metrics)} passed"
)
16 changes: 8 additions & 8 deletions src/pidgan/utils/checks/checkOptimizer.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from tensorflow import keras
import keras as k

OPT_SHORTCUTS = ["sgd", "rmsprop", "adam"]
TF_OPTIMIZERS = [
keras.optimizers.SGD(),
keras.optimizers.RMSprop(),
keras.optimizers.Adam(),
k.optimizers.SGD(),
k.optimizers.RMSprop(),
k.optimizers.Adam(),
]


def checkOptimizer(optimizer) -> keras.optimizers.Optimizer:
def checkOptimizer(optimizer) -> k.optimizers.Optimizer:
if isinstance(optimizer, str):
if optimizer in OPT_SHORTCUTS:
for opt, tf_opt in zip(OPT_SHORTCUTS, TF_OPTIMIZERS):
if optimizer == opt:
return tf_opt
else:
raise ValueError(
f"`optimizer` should be selected in {OPT_SHORTCUTS}, "
f'"optimizer" should be selected in {OPT_SHORTCUTS}, '
f"instead '{optimizer}' passed"
)
elif isinstance(optimizer, keras.optimizers.Optimizer):
elif isinstance(optimizer, k.optimizers.Optimizer):
return optimizer
else:
raise TypeError(
f"`optimizer` should be a TensorFlow `Optimizer`, "
f'"optimizer" should be a Keras Optimizer, '
f"instead {type(optimizer)} passed"
)
17 changes: 10 additions & 7 deletions tests/optimization/callbacks/test_HopaasPruner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os

import hopaas_client as hpc
import keras as k
import numpy as np
import pytest
import yaml
from tensorflow import keras

NUM_TRIALS = 1
CHUNK_SIZE = int(1e3)
Expand All @@ -24,11 +24,14 @@
]
Y = np.tanh(X[:, 0]) + 2 * X[:, 1] * X[:, 2]

model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(3,)))
model = k.Sequential()
try:
model.add(k.layers.InputLayer(shape=(3,)))
except(ValueError):
model.add(k.layers.InputLayer(input_shape=(3,)))
for units in [16, 16, 16]:
model.add(keras.layers.Dense(units, activation="relu"))
model.add(keras.layers.Dense(1))
model.add(k.layers.Dense(units, activation="relu"))
model.add(k.layers.Dense(1))


@pytest.fixture
Expand Down Expand Up @@ -79,8 +82,8 @@ def test_callback_use(enable_pruning):

for _ in range(NUM_TRIALS):
with study.trial() as trial:
adam = keras.optimizers.Adam(learning_rate=trial.learning_rate)
mse = keras.losses.MeanSquaredError()
adam = k.optimizers.Adam(learning_rate=trial.learning_rate)
mse = k.losses.MeanSquaredError()

report = HopaasPruner(
trial=trial,
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/checks/test_checkMetrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from pidgan.metrics.BaseMetric import BaseMetric
from pidgan.metrics import BaseMetric
from pidgan.utils.checks.checkMetrics import METRIC_SHORTCUTS, PIDGAN_METRICS


Expand Down
6 changes: 3 additions & 3 deletions tests/utils/checks/test_checkOptimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from tensorflow import keras
import keras as k

from pidgan.utils.checks.checkOptimizer import OPT_SHORTCUTS, TF_OPTIMIZERS

Expand All @@ -11,12 +11,12 @@ def test_checker_use_strings(optimizer):
from pidgan.utils.checks import checkOptimizer

res = checkOptimizer(optimizer)
assert isinstance(res, keras.optimizers.Optimizer)
assert isinstance(res, k.optimizers.Optimizer)


@pytest.mark.parametrize("optimizer", TF_OPTIMIZERS)
def test_checker_use_classes(optimizer):
from pidgan.utils.checks import checkOptimizer

res = checkOptimizer(optimizer)
assert isinstance(res, keras.optimizers.Optimizer)
assert isinstance(res, k.optimizers.Optimizer)

0 comments on commit 95d8375

Please sign in to comment.