diff --git a/src/pidgan/callbacks/schedulers/__init__.py b/src/pidgan/callbacks/schedulers/__init__.py index e4eb891..36eae32 100644 --- a/src/pidgan/callbacks/schedulers/__init__.py +++ b/src/pidgan/callbacks/schedulers/__init__.py @@ -1,6 +1,6 @@ -import keras +import keras as k -k_vrs = keras.__version__.split(".")[:2] +k_vrs = k.__version__.split(".")[:2] k_vrs = float(".".join([n for n in k_vrs])) if k_vrs >= 3.0: diff --git a/src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py b/src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py index 0bbdd26..ee6ff53 100644 --- a/src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py +++ b/src/pidgan/callbacks/schedulers/k2/LearnRateBaseScheduler.py @@ -1,16 +1,16 @@ +import keras as k import tensorflow as tf -import keras -K = keras.backend +K = k.backend -class LearnRateBaseScheduler(keras.callbacks.Callback): +class LearnRateBaseScheduler(k.callbacks.Callback): def __init__(self, optimizer, verbose=False, key="lr") -> None: super().__init__() self._name = "LearnRateBaseScheduler" # Optimizer - assert isinstance(optimizer, keras.optimizers.Optimizer) + assert isinstance(optimizer, k.optimizers.Optimizer) self._optimizer = optimizer # Verbose @@ -50,7 +50,7 @@ def name(self) -> str: return self._name @property - def optimizer(self) -> keras.optimizers.Optimizer: + def optimizer(self) -> k.optimizers.Optimizer: return self._optimizer @property