From 5a3542b6aef643d4e43bbe9a9ee38b651c556da0 Mon Sep 17 00:00:00 2001
From: Kartheek L
Date: Fri, 3 May 2024 22:58:08 +0530
Subject: [PATCH] mlx - merge master into mlx (#19657)

* Introduce float8 training (#19488)
* Add float8 training support
* Add tests for fp8 training
* Add `quantize_and_dequantize` test
* Fix bugs and add float8 correctness tests
* Cleanup
* Address comments and cleanup
* Add docstrings and some minor refactoring
* Add `QuantizedFloat8DTypePolicy`
* Add dtype policy setter
* Fix torch dynamo issue by using `self._dtype_policy`
* Improve test coverage
* Add LoRA to ConvND layers (#19516) (example sketch at the end of this list)
* Add LoRA to `BaseConv`
* Add tests
* Fix typo
* Fix tests
* Fix tests
* Add path to run keras on dm-tree when optree is not available.
* feat(losses): add Tversky loss implementation (#19511)
* feat(losses): add Tversky loss implementation
* adjusted documentation
* Update KLD docs
* Models and layers now return owned metrics recursively. (#19522) (example sketch at the end of this list)

  - added `Layer.metrics` to return all metrics owned by the layer and its sub-layers recursively.
  - `Layer.metrics_variables` now returns variables from all metrics recursively, not just the layer and its direct sub-layers.
  - `Model.metrics` now returns all metrics recursively, not just the model level metrics.
  - `Model.metrics_variables` now returns variables from all metrics recursively, not just the model level metrics.
  - added test coverage to test metrics and variables 2 levels deep.

  This is consistent with the Keras 2 behavior and how `Model/Layer.variables` and `Model/Layer.weights` work.

* Update IoU ignore_class handling
* Fix `RandomBrightness`, Enhance `IndexLookup` Initialization and Expand Test Coverage for `Preprocessing Layers` (#19513)
* Add tests for CategoryEncoding class in category_encoding_test.py
* fix
* Fix IndexLookup class initialization and add test cases
* Add test case for IndexLookupLayerTest without vocabulary
* Fix IndexLookup class initialization
* Add normalization test cases
* Add test cases for Hashing class
* Fix value range validation error in RandomBrightness class
* Refactor IndexLookup class initialization and add test cases
* Refix IndexLookup class initialization and fix test cases
* Add test for spectral norm
* Add missing test decorator
* Fix torch test
* Fix code format
* Generate API (#19530)
* API Generator for Keras
* API Generator for Keras
* Generates API Gen via api_gen.sh
* Remove recursive import of _tf_keras
* Generate API Files via api_gen.sh
* Update APIs
* Added metrics from custom `train_step`/`test_step` are now returned. (#19529)

  This works the same way as in Keras 2, whereby the metrics are returned directly from the logs if the set of keys doesn't match the model metrics.

* Use temp dir and abs path in `api_gen.py` (#19533)
* Use temp dir and abs path
* Use temp dir and abs path
* Update Readme
* Update API
* Fix gradient accumulation when using `overwrite_with_gradient` during float8 training (#19534)
* Fix gradient accumulation with `overwrite_with_gradient` in float8 training
* Add comments
* Fix annotation
* Update code path in ignore path (#19537)
* Add operations per run (#19538)
* Include input shapes in model visualization.
* Add pad_to_aspect_ratio feature in ops.image.resize
* Add pad_to_aspect_ratio feature in Resizing layer.
* Fix incorrect usage of `quantize` (#19541)
* Add logic to prevent double quantization
* Add detailed info for double quantization error
* Update error msg
* Add eigh op.
* Add keepdim in argmax/argmin.
* Fix small bug in model.save_weights (#19545)
* Update public APIs.
* eigh should work on JAX GPU
* Copy init to keras/__init__.py (#19551)
* Revert "Copy init to keras/__init__.py (#19551)" (#19552)

  This reverts commit da9af61032d5eaf7b5c64bcc7b215bff42dbe19a.

* sum-reduce inlined losses
* Remove the dependency on `tensorflow.experimental.numpy` and support negative indices for `take` and `take_along_axis` (#19556) (example sketch at the end of this list)
* Remove `tfnp`
* Update numpy api
* Improve test coverage
* Improve test coverage
* Fix `Tri` and `Eye` and increase test coverage
* Update `round` test
* Fix `jnp.round`
* Fix `diag` bug for iou_metrics
* Add op.select.
* Add new API for select
* Make `ops.abs` and `ops.absolute` consistent between backends. (#19563)

  - The TensorFlow implementation was missing `convert_to_tensor`
  - The sparse annotation was unnecessarily applied twice
  - Now `abs` calls `absolute` in all backends

  Also fixed TensorFlow `ops.select`.

* Add pickle support for Keras model (#19555)
* Implement unit tests for pickling
* Reformat model_test
* Reformat model_test
* Rename depickle to unpickle
* Rename depickle to unpickle
* Reformat
* remove a comment
* Ellipsis Serialization and tests (#19564)
* Serialization and tests
* Serialization and tests
* Serialization and tests
* Make TF one_hot input dtype less strict.
* Fix einsum `_int8_call` (#19570)
* CTC Decoding for JAX and Tensorflow (#19366)
* Tensorflow OP for CTC decoding
* JAX op for CTC greedy decoding
* Update CTC decoding documentation
* Fix linting issues
* Fix trailing whitespace
* Simplify returns in tensorflow CTC wrapper
* Fix CTC decoding error messages
* Fix line too long
* Bug fixes to JAX CTC greedy decoder
* Force int typecast in TF CTC decoder
* Unit tests for CTC greedy decoding
* Add unit test for CTC beam search decoding
* Fix mask index set location in JAX CTC decoding
* CTC beam search decoding for JAX
* Fix unhandled token repetitions in ctc_beam_search_decode
* Fix merge_repeated bug in CTC beam search decode
* Fix beam storage and repetition bugs in JAX ctc_decode
* Remove trailing whitespace
* Fix ordering bug for ties in JAX CTC beam search
* Cast sequence lengths to integers in JAX ctc_decode
* Remove line break in docstring
* CTC beam search decoding for JAX
* Fix unhandled token repetitions in ctc_beam_search_decode
* Fix merge_repeated bug in CTC beam search decode
* Fix beam storage and repetition bugs in JAX ctc_decode
* Fix ordering bug for ties in JAX CTC beam search
* Generate public api directory
* Add not implemented errors for NumPy and Torch CTC decoding
* Remove unused redefinition of JAX ctc_beam_search_decode
* Docstring edits
* Expand nan_to_num args.
* Add vectorize op. (example sketch at the end of this list)
* list insert requires index (#19575)
* Add signature and exclude args to knp.vectorize.
* Fix the apis of `dtype_policies` (#19580)
* Fix api of `dtype_policies`
* Update docstring
* Increase test coverage
* Fix format
* Fix keys of `save_own_variables` and `load_own_variables` (#19581)
* Fix JAX CTC test.
* Fix loss_weights handling in single output case
* Fix JAX vectorize.
* Move _tf_keras directory to the root of the pip package.
* One time fix to _tf_keras API.
* Convert return type imdb.load_data to NumPy array (#19598)

  Convert the return type of imdb.load_data to NumPy arrays. Currently X_train and X_test are returned as lists.
* Fix typo
* fix api_gen.py for legacy (#19590)
* fix api_gen.py for legacy
* merge api and legacy for _tf_keras
* Improve int8 for `Embedding` (#19595)
* pin torch < 2.3.0 (#19603)
* Clean up duplicated `inputs_quantizer` (#19604)
* Cleanup duplicated `inputs_quantizer` and add type check for `input_spec` and `supports_masking`
* Revert setter
* output format changes and errors in github (#19608)
* Provide write permission to action for cache management. (#19606)
* Pickle support for all saveables (#19592)
* Pickle support
* Add keras pickleable mixin
* Reformat
* Implement pickle all over
* reformat
* Reformat
* Keras saveable
* Keras saveable
* Keras saveable
* Keras saveable
* Keras saveable
* obj_type
* Update pickleable
* Saveable logic touchups
* Add slogdet op.
* Update APIs
* Remove unused import
* Refactor CTC APIs (#19611)
* Add `ctc_loss` and `ctc_decode` for numpy backend, improve imports and tests
* Support "beam_search" strategy for torch's `ctc_decode`
* Improve `ctc_loss`
* Cleanup
* Refactor `ctc_decode`
* Update docstring
* Update docstring
* Add `CTCDecode` operation and ensure dtype inference of `ctc_decode`
* Fix `name` of `losses.CTC`
* update the namex version requirements (#19617)
* Add `PSNR` API (#19616) (example sketch at the end of this list)
* PSNR
* Fix
* Docstring format
* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps (#19618)
* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps
* Formatting
* Implement custom layer insertion in clone_model. (#19610)
* Implement custom layer insertion in clone_model.
* Add recursive arg and tests.
* Add nested sequential cloning test
* Fix bidir lstm saving issue.
* Fix CI
* Fix cholesky tracing with jax
* made extract_patches dtype agnostic (#19621)
* Simplify Bidirectional implementation
* Add support for infinite `PyDataset`s. (#19624) (example sketch at the end of this list)

  `PyDataset` now uses the `num_batches` property instead of `__len__` to support `None`, which is how one indicates the dataset is infinite. Note that infinite datasets are not shuffled.

  Fixes https://github.com/keras-team/keras/issues/19528

  Also added exception reporting when using multithreading / multiprocessing. Previously, the program would just hang with no error reported.

* Fix dataset shuffling issue.
* Update version string.
* Minor fix
* Restore version string resolution in pip_build.
* Speed up `DataAdapter` tests by testing only the current backend. (#19625)

  There is no use case for using an iterator for a different backend than the current backend.

  Also:
  - limit the number of tests using multiprocessing, the threading tests give us good coverage.
  - fixed the `test_exception_reported` test, which was not actually exercising the multiprocessing / multithreading cases.
  - removed unused `init_pool` method.

* feat(ops): support np.argpartition (#19588) (example sketch at the end of this list)
* feat(ops): support np.argpartition
* updated documentation, type-casting, and tf implementation
* fixed tf implementation
* added torch cast to int32
* updated torch type and API generated files
* added torch output type cast
* test(trainers): add test_errors implementation for ArrayDataAdapter class (#19626)
* Fix torch GPU CI
* Fix argmax/argmin keepdims with defined axis in TF
* Misc fixes in TF backend ops.
* Fix `argpartition` cuda bug in torch (#19634)
* fix(ops): specify NonZero output dtype and add test coverage (#19635)
* Fix `ops.ctc_decode` (#19633)
* Fix greedy ctc decode
* Remove print
* Fix `tf.nn.ctc_beam_search_decoder`
* Change default `mask_index` to `0`
* Fix losses test
* Update
* Ensure the same rule applies for np arrays in autocasting (#19636)
* Ensure the same rule applies for np arrays in autocasting
* Trigger CI by adding docstring
* Update
* Update docstring
* Fix `istft` and add class `TestMathErrors` in `ops/math_test.py` (#19594)
* Fix and test math functions for jax backend
* run /workspaces/keras/shell/format.sh
* refix
* fix
* fix _get_complex_tensor_from_tuple
* fix
* refix
* Fix istft function to handle inputs with less than 2 dimensions
* fix
* Fix ValueError in istft function for inputs with less than 2 dimensions
* Return a tuple from `ops.shape` with the Torch backend. (#19640) (example sketch at the end of this list)

  With Torch, `x.shape` returns a `torch.Size`, which is a subclass of `tuple` but can cause different behaviors. In particular `convert_to_tensor` does not work on `torch.Size`.

  This fixes https://github.com/keras-team/keras/issues/18900

* support conv3d on cpu for TF (#19641)
* Enable cudnn rnns when dropout is set (#19645)
* Enable cudnn rnns when dropout is set
* Fix
* Fix plot_model for input dicts.
* Fix deprecation warning in torch
* Bump the github-actions group with 2 updates (#19653)

  Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [github/codeql-action](https://github.com/github/codeql-action).

  Updates `actions/upload-artifact` from 4.3.1 to 4.3.3
  - [Release notes](https://github.com/actions/upload-artifact/releases)
  - [Commits](https://github.com/actions/upload-artifact/compare/5d5d22a31266ced268874388b861e4b58bb5c2f3...65462800fd760344b1a7b4382951275a0abb4808)

  Updates `github/codeql-action` from 3.24.9 to 3.25.3
  - [Release notes](https://github.com/github/codeql-action/releases)
  - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
  - [Commits](https://github.com/github/codeql-action/compare/1b1aada464948af03b950897e5eb522f92603cc2...d39d31e687223d841ef683f52467bd88e9b21c14)

  ---
  updated-dependencies:
  - dependency-name: actions/upload-artifact
    dependency-type: direct:production
    update-type: version-update:semver-patch
    dependency-group: github-actions
  - dependency-name: github/codeql-action
    dependency-type: direct:production
    update-type: version-update:semver-minor
    dependency-group: github-actions
  ...

  Signed-off-by: dependabot[bot]
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump the python group with 2 updates (#19654)

  Bumps the python group with 2 updates: torch and torchvision.

  Updates `torch` from 2.2.1+cu121 to 2.3.0+cu121
  Updates `torchvision` from 0.17.1+cu121 to 0.18.0+cu121

  ---
  updated-dependencies:
  - dependency-name: torch
    dependency-type: direct:production
    update-type: version-update:semver-minor
    dependency-group: python
  - dependency-name: torchvision
    dependency-type: direct:production
    update-type: version-update:semver-minor
    dependency-group: python
  ...

  Signed-off-by: dependabot[bot]
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Revert "Bump the python group with 2 updates (#19654)" (#19655)

  This reverts commit 09133f459d4158d35ca582433c2577a02696f62e.
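For the LoRA-on-convolutions item (#19516), a minimal sketch assuming `enable_lora(rank)` behaves on `Conv2D` the same way it already does on `Dense`/`EinsumDense`; the rank and shapes are illustrative only:

```python
import keras

conv = keras.layers.Conv2D(8, 3, padding="same")
conv.build((None, 32, 32, 3))
conv.enable_lora(2)
# The base kernel is frozen; only the low-rank adapter weights
# (plus the bias) remain trainable.
print([w.name for w in conv.trainable_weights])
```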
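For the recursive-metrics item (#19522), a small sketch; the custom layer and the metric name are assumptions used purely for illustration:

```python
import keras
from keras import ops


class TrackedDense(keras.layers.Dense):
    """Dense layer that owns its own metric."""

    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.activation_mean = keras.metrics.Mean(name="activation_mean")

    def call(self, inputs):
        outputs = super().call(inputs)
        self.activation_mean.update_state(ops.mean(outputs))
        return outputs


inner = keras.Sequential([TrackedDense(4)], name="inner")
outer = keras.Sequential([keras.Input((8,)), inner, keras.layers.Dense(1)])
# Metrics (and metrics_variables) are now gathered from sub-layers
# recursively, so `activation_mean` also shows up on the outer model.
print([m.name for m in outer.metrics])
```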
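For the negative-index support in `take` / `take_along_axis` (#19556), a quick sketch of the NumPy-style behavior now shared by all backends:

```python
from keras import ops

x = ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
print(ops.take(x, -1))            # last element of the flattened input -> 6
print(ops.take(x, [-1], axis=1))  # last column -> [[3], [6]]
```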
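For the `vectorize` items (including the `signature` and `excluded` arguments), a sketch assuming the op mirrors `np.vectorize`:

```python
from keras import ops


def dot(a, b):
    # Operates on single 1-D vectors; `vectorize` maps it over batches.
    return ops.sum(a * b)


batched_dot = ops.vectorize(dot, signature="(n),(n)->()")
a = ops.ones((4, 3))
b = ops.ones((4, 3))
print(batched_dot(a, b))  # shape (4,), each entry 3.0
```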
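For the `np.argpartition` support (#19588), a sketch of the new op, assuming NumPy semantics (the element at position `kth` lands in its sorted position):

```python
from keras import ops

x = ops.convert_to_tensor([7, 2, 9, 1, 5])
idx = ops.argpartition(x, 2)
# The first three indices point at the three smallest values (1, 2, 5),
# in no particular order; the remaining indices point at the larger ones.
print(idx)
```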
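For the `PSNR` item (#19616), a sketch assuming the signature `ops.psnr(x1, x2, max_val)` and a result in decibels:

```python
from keras import ops

img_a = ops.zeros((16, 16, 3))
img_b = img_a + 0.1  # constant error of 0.1 -> MSE of 0.01
print(ops.psnr(img_a, img_b, max_val=1.0))  # 10 * log10(1 / 0.01) = 20 dB
```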
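For the infinite `PyDataset` item (#19624), a minimal sketch; the class name and random data are illustrative, and `steps_per_epoch` is needed because an infinite dataset has no length:

```python
import numpy as np
import keras


class RandomStream(keras.utils.PyDataset):
    @property
    def num_batches(self):
        return None  # None marks the dataset as infinite (and unshuffled).

    def __getitem__(self, index):
        x = np.random.rand(32, 8).astype("float32")
        y = np.random.rand(32, 1).astype("float32")
        return x, y


model = keras.Sequential([keras.Input((8,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.fit(RandomStream(), steps_per_epoch=10, epochs=2)
```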
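For the `ops.shape` item (#19640), the result is now a plain tuple on the Torch backend as well, so it can be fed straight back into other ops:

```python
from keras import ops

x = ops.ones((2, 3, 4))
shape = ops.shape(x)                 # (2, 3, 4) as a plain tuple
print(ops.convert_to_tensor(shape))  # works on every backend, torch included
```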
--------- Signed-off-by: dependabot[bot] Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com> Co-authored-by: Francois Chollet Co-authored-by: Luca Pizzini Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com> Co-authored-by: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Sachin Prasad Co-authored-by: Uwe Schmidt Co-authored-by: Luke Wood Co-authored-by: Maanas Arora Co-authored-by: AlexanderLavelle <73360008+AlexanderLavelle@users.noreply.github.com> Co-authored-by: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Co-authored-by: Shivam Mishra <124146945+shmishra99@users.noreply.github.com> Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: IMvision12 <88665786+IMvision12@users.noreply.github.com> Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Co-authored-by: Vachan V Y <109357590+VachanVY@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/actions.yml | 9 + .github/workflows/nightly.yml | 10 + .github/workflows/scorecard.yml | 4 +- .github/workflows/stale-issue-pr.yaml | 1 + SECURITY.md | 2 +- api_gen.py | 18 +- keras/api/_tf_keras/keras/__init__.py | 11 +- keras/api/_tf_keras/keras/backend/__init__.py | 123 ++++ .../keras/dtype_policies/__init__.py | 3 + keras/api/_tf_keras/keras/layers/__init__.py | 5 +- keras/api/_tf_keras/keras/losses/__init__.py | 19 +- keras/api/_tf_keras/keras/metrics/__init__.py | 18 +- keras/api/_tf_keras/keras/ops/__init__.py | 6 + keras/api/_tf_keras/keras/ops/nn/__init__.py | 2 + .../api/_tf_keras/keras/ops/numpy/__init__.py | 4 + .../_tf_keras/keras/preprocessing/__init__.py | 5 +- .../keras/preprocessing/image/__init__.py | 13 + .../keras/preprocessing/sequence/__init__.py | 3 + .../keras/preprocessing/text/__init__.py | 11 + keras/api/dtype_policies/__init__.py | 3 + keras/api/ops/__init__.py | 6 + keras/api/ops/nn/__init__.py | 2 + keras/api/ops/numpy/__init__.py | 4 + keras/src/backend/common/backend_utils.py | 223 +++++++ keras/src/backend/jax/linalg.py | 17 +- keras/src/backend/jax/math.py | 6 + keras/src/backend/jax/nn.py | 415 ++++++++++--- keras/src/backend/jax/numpy.py | 40 +- keras/src/backend/numpy/nn.py | 400 ++++++++++++- keras/src/backend/numpy/numpy.py | 20 +- keras/src/backend/tensorflow/linalg.py | 5 +- keras/src/backend/tensorflow/math.py | 5 +- keras/src/backend/tensorflow/nn.py | 178 ++++-- keras/src/backend/tensorflow/numpy.py | 557 ++++++++++++++---- keras/src/backend/tensorflow/random.py | 7 +- keras/src/backend/torch/core.py | 9 +- keras/src/backend/torch/linalg.py | 2 +- keras/src/backend/torch/nn.py | 137 ++++- keras/src/backend/torch/numpy.py | 87 ++- keras/src/callbacks/early_stopping.py | 1 - keras/src/datasets/imdb.py | 4 +- keras/src/dtype_policies/__init__.py | 77 ++- keras/src/dtype_policies/dtype_policy.py | 37 +- keras/src/dtype_policies/dtype_policy_test.py | 55 +- keras/src/layers/convolutional/base_conv.py | 12 +- keras/src/layers/core/dense.py | 42 +- keras/src/layers/core/einsum_dense.py | 48 +- keras/src/layers/core/einsum_dense_test.py | 22 + keras/src/layers/core/embedding.py | 35 +- keras/src/layers/layer.py | 6 +- keras/src/layers/layer_test.py | 37 +- .../spectral_normalization_test.py | 2 +- keras/src/layers/rnn/bidirectional.py | 11 +- keras/src/layers/rnn/gru.py | 13 +- keras/src/layers/rnn/lstm.py 
| 14 +- keras/src/losses/__init__.py | 4 +- keras/src/losses/loss.py | 6 +- keras/src/losses/loss_test.py | 7 + keras/src/losses/losses.py | 14 +- keras/src/losses/losses_test.py | 2 +- keras/src/metrics/metric.py | 6 +- keras/src/metrics/metric_test.py | 7 + keras/src/models/cloning.py | 169 +++++- keras/src/models/cloning_test.py | 102 +++- keras/src/models/functional.py | 3 + keras/src/models/model.py | 4 +- keras/src/models/model_test.py | 27 +- keras/src/models/sequential.py | 3 + keras/src/models/sequential_test.py | 9 + keras/src/models/variable_mapping.py | 45 +- keras/src/ops/core_test.py | 4 +- keras/src/ops/function.py | 8 +- keras/src/ops/image.py | 2 +- keras/src/ops/linalg_test.py | 9 + keras/src/ops/math_test.py | 88 +++ keras/src/ops/nn.py | 192 +++++- keras/src/ops/nn_test.py | 278 ++++++++- keras/src/ops/numpy.py | 377 +++++++----- keras/src/ops/numpy_test.py | 267 ++++++++- keras/src/optimizers/base_optimizer.py | 6 +- keras/src/optimizers/optimizer_test.py | 25 +- keras/src/saving/keras_saveable.py | 38 ++ keras/src/saving/saving_lib.py | 227 ++++--- keras/src/saving/saving_lib_test.py | 12 + keras/src/saving/serialization_lib.py | 7 + keras/src/saving/serialization_lib_test.py | 4 + keras/src/testing/test_case.py | 14 + keras/src/trainers/compile_utils.py | 12 +- keras/src/trainers/data_adapters/__init__.py | 7 + .../data_adapters/array_data_adapter_test.py | 68 ++- .../generator_data_adapter_test.py | 40 +- .../data_adapters/py_dataset_adapter.py | 119 ++-- .../data_adapters/py_dataset_adapter_test.py | 132 ++++- .../data_adapters/tf_dataset_adapter_test.py | 29 +- .../torch_data_loader_adapter_test.py | 45 +- keras/src/utils/audio_dataset_utils.py | 51 +- keras/src/utils/image_dataset_utils.py | 28 +- keras/src/utils/model_visualization.py | 18 +- keras/src/utils/text_dataset_utils.py | 29 +- keras/src/utils/tracking.py | 4 +- keras/src/version.py | 2 +- pip_build.py | 63 +- requirements-common.txt | 2 +- requirements-jax-cuda.txt | 2 +- requirements-tensorflow-cuda.txt | 2 +- requirements.txt | 3 +- shell/api_gen.sh | 2 +- 107 files changed, 4396 insertions(+), 1034 deletions(-) create mode 100644 keras/api/_tf_keras/keras/preprocessing/text/__init__.py create mode 100644 keras/src/saving/keras_saveable.py diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 342d9414c3e..7a9e7b10cc5 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -126,3 +126,12 @@ jobs: fi - name: Lint run: bash shell/lint.sh + - name: Check for API changes + run: | + bash shell/api_gen.sh + git status + clean=$(git status | grep "nothing to commit") + if [ -z "$clean" ]; then + echo "Please run shell/api_gen.sh to generate API." + exit 1 + fi diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 5edfd2c988b..9c827a62c99 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -92,6 +92,16 @@ jobs: fi - name: Lint run: bash shell/lint.sh + - name: Check for API changes + run: | + bash shell/api_gen.sh + git status + clean=$(git status | grep "nothing to commit") + if [ -z "$clean" ]; then + echo "Please run shell/api_gen.sh to generate API." + exit 1 + fi + nightly: name: Build Wheel file and upload diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 1286b78b527..2fe3b5cdca8 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -48,7 +48,7 @@ jobs: # Upload the results as artifacts (optional). 
Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 + uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3 with: name: SARIF file path: results.sarif @@ -56,6 +56,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9 + uses: github/codeql-action/upload-sarif@d39d31e687223d841ef683f52467bd88e9b21c14 # v3.25.3 with: sarif_file: results.sarif diff --git a/.github/workflows/stale-issue-pr.yaml b/.github/workflows/stale-issue-pr.yaml index a5c570dd780..309760a0751 100644 --- a/.github/workflows/stale-issue-pr.yaml +++ b/.github/workflows/stale-issue-pr.yaml @@ -10,6 +10,7 @@ jobs: permissions: issues: write pull-requests: write + actions: write steps: - name: Awaiting response issues uses: actions/stale@v9 diff --git a/SECURITY.md b/SECURITY.md index 90853890d8b..e2ccb038246 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -59,7 +59,7 @@ Besides the virtual environment, the hardware (GPUs or TPUs) can also be attacke ## Reporting a Vulnerability -Beware that none of the topics under [Using Keras Securely](#using-Keras-securely) are considered vulnerabilities of Keras. +Beware that none of the topics under [Using Keras Securely](#using-keras-securely) are considered vulnerabilities of Keras. If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you diff --git a/api_gen.py b/api_gen.py index 28fac8fa4f1..a38ff20fd23 100644 --- a/api_gen.py +++ b/api_gen.py @@ -7,6 +7,7 @@ """ import os +import re import shutil import namex @@ -78,8 +79,7 @@ def create_legacy_directory(package_dir): for path in os.listdir(os.path.join(src_dir, "legacy")) if os.path.isdir(os.path.join(src_dir, "legacy", path)) ] - - for root, _, fnames in os.walk(os.path.join(package_dir, "_legacy")): + for root, _, fnames in os.walk(os.path.join(api_dir, "_legacy")): for fname in fnames: if fname.endswith(".py"): legacy_fpath = os.path.join(root, fname) @@ -110,6 +110,20 @@ def create_legacy_directory(package_dir): f"keras.api.{legacy_submodule}", f"keras.api._tf_keras.keras.{legacy_submodule}", ) + # Remove duplicate generated comments string. 
+ legacy_contents = re.sub(r"\n", r"\\n", legacy_contents) + legacy_contents = re.sub('""".*"""', "", legacy_contents) + legacy_contents = re.sub(r"\\n", r"\n", legacy_contents) + # If the same module is in legacy and core_api, use legacy + legacy_imports = re.findall( + r"import (\w+)", legacy_contents + ) + for import_name in legacy_imports: + core_api_contents = re.sub( + f"\n.* import {import_name}\n", + r"\n", + core_api_contents, + ) legacy_contents = core_api_contents + "\n" + legacy_contents with open(tf_keras_fpath, "w") as f: f.write(legacy_contents) diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index 334dc282386..5e0a7229473 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -6,7 +6,6 @@ from keras.api import activations from keras.api import applications -from keras.api import backend from keras.api import callbacks from keras.api import config from keras.api import constraints @@ -15,21 +14,21 @@ from keras.api import dtype_policies from keras.api import export from keras.api import initializers -from keras.api import layers from keras.api import legacy -from keras.api import losses -from keras.api import metrics from keras.api import mixed_precision from keras.api import models from keras.api import ops from keras.api import optimizers -from keras.api import preprocessing from keras.api import quantizers from keras.api import random from keras.api import regularizers -from keras.api import saving from keras.api import tree from keras.api import utils +from keras.api._tf_keras.keras import backend +from keras.api._tf_keras.keras import layers +from keras.api._tf_keras.keras import losses +from keras.api._tf_keras.keras import metrics +from keras.api._tf_keras.keras import preprocessing from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.exports import Variable diff --git a/keras/api/_tf_keras/keras/backend/__init__.py b/keras/api/_tf_keras/keras/backend/__init__.py index 840bde6e4de..94ccc4bf3d8 100644 --- a/keras/api/_tf_keras/keras/backend/__init__.py +++ b/keras/api/_tf_keras/keras/backend/__init__.py @@ -17,4 +17,127 @@ from keras.src.backend.config import set_epsilon from keras.src.backend.config import set_floatx from keras.src.backend.config import set_image_data_format +from keras.src.legacy.backend import abs +from keras.src.legacy.backend import all +from keras.src.legacy.backend import any +from keras.src.legacy.backend import arange +from keras.src.legacy.backend import argmax +from keras.src.legacy.backend import argmin +from keras.src.legacy.backend import batch_dot +from keras.src.legacy.backend import batch_flatten +from keras.src.legacy.backend import batch_get_value +from keras.src.legacy.backend import batch_normalization +from keras.src.legacy.backend import batch_set_value +from keras.src.legacy.backend import bias_add +from keras.src.legacy.backend import binary_crossentropy +from keras.src.legacy.backend import binary_focal_crossentropy +from keras.src.legacy.backend import cast +from keras.src.legacy.backend import cast_to_floatx +from keras.src.legacy.backend import categorical_crossentropy +from keras.src.legacy.backend import categorical_focal_crossentropy +from keras.src.legacy.backend import clip +from keras.src.legacy.backend import concatenate +from keras.src.legacy.backend import constant +from keras.src.legacy.backend import conv1d +from 
keras.src.legacy.backend import conv2d +from keras.src.legacy.backend import conv2d_transpose +from keras.src.legacy.backend import conv3d +from keras.src.legacy.backend import cos +from keras.src.legacy.backend import count_params +from keras.src.legacy.backend import ctc_batch_cost +from keras.src.legacy.backend import ctc_decode +from keras.src.legacy.backend import ctc_label_dense_to_sparse +from keras.src.legacy.backend import cumprod +from keras.src.legacy.backend import cumsum +from keras.src.legacy.backend import depthwise_conv2d +from keras.src.legacy.backend import dot +from keras.src.legacy.backend import dropout +from keras.src.legacy.backend import dtype +from keras.src.legacy.backend import elu +from keras.src.legacy.backend import equal +from keras.src.legacy.backend import eval +from keras.src.legacy.backend import exp +from keras.src.legacy.backend import expand_dims +from keras.src.legacy.backend import eye +from keras.src.legacy.backend import flatten +from keras.src.legacy.backend import foldl +from keras.src.legacy.backend import foldr +from keras.src.legacy.backend import gather +from keras.src.legacy.backend import get_value +from keras.src.legacy.backend import gradients +from keras.src.legacy.backend import greater +from keras.src.legacy.backend import greater_equal +from keras.src.legacy.backend import hard_sigmoid +from keras.src.legacy.backend import in_top_k +from keras.src.legacy.backend import int_shape +from keras.src.legacy.backend import is_sparse +from keras.src.legacy.backend import l2_normalize +from keras.src.legacy.backend import less +from keras.src.legacy.backend import less_equal +from keras.src.legacy.backend import log +from keras.src.legacy.backend import map_fn +from keras.src.legacy.backend import max +from keras.src.legacy.backend import maximum +from keras.src.legacy.backend import mean +from keras.src.legacy.backend import min +from keras.src.legacy.backend import minimum +from keras.src.legacy.backend import moving_average_update +from keras.src.legacy.backend import name_scope +from keras.src.legacy.backend import ndim +from keras.src.legacy.backend import not_equal +from keras.src.legacy.backend import one_hot +from keras.src.legacy.backend import ones +from keras.src.legacy.backend import ones_like +from keras.src.legacy.backend import permute_dimensions +from keras.src.legacy.backend import pool2d +from keras.src.legacy.backend import pool3d +from keras.src.legacy.backend import pow +from keras.src.legacy.backend import prod +from keras.src.legacy.backend import random_bernoulli +from keras.src.legacy.backend import random_normal +from keras.src.legacy.backend import random_normal_variable +from keras.src.legacy.backend import random_uniform +from keras.src.legacy.backend import random_uniform_variable +from keras.src.legacy.backend import relu +from keras.src.legacy.backend import repeat +from keras.src.legacy.backend import repeat_elements +from keras.src.legacy.backend import reshape +from keras.src.legacy.backend import resize_images +from keras.src.legacy.backend import resize_volumes +from keras.src.legacy.backend import reverse +from keras.src.legacy.backend import rnn +from keras.src.legacy.backend import round +from keras.src.legacy.backend import separable_conv2d +from keras.src.legacy.backend import set_value +from keras.src.legacy.backend import shape +from keras.src.legacy.backend import sigmoid +from keras.src.legacy.backend import sign +from keras.src.legacy.backend import sin +from keras.src.legacy.backend import 
softmax +from keras.src.legacy.backend import softplus +from keras.src.legacy.backend import softsign +from keras.src.legacy.backend import sparse_categorical_crossentropy +from keras.src.legacy.backend import spatial_2d_padding +from keras.src.legacy.backend import spatial_3d_padding +from keras.src.legacy.backend import sqrt +from keras.src.legacy.backend import square +from keras.src.legacy.backend import squeeze +from keras.src.legacy.backend import stack +from keras.src.legacy.backend import std +from keras.src.legacy.backend import stop_gradient +from keras.src.legacy.backend import sum +from keras.src.legacy.backend import switch +from keras.src.legacy.backend import tanh +from keras.src.legacy.backend import temporal_padding +from keras.src.legacy.backend import tile +from keras.src.legacy.backend import to_dense +from keras.src.legacy.backend import transpose +from keras.src.legacy.backend import truncated_normal +from keras.src.legacy.backend import update +from keras.src.legacy.backend import update_add +from keras.src.legacy.backend import update_sub +from keras.src.legacy.backend import var +from keras.src.legacy.backend import variable +from keras.src.legacy.backend import zeros +from keras.src.legacy.backend import zeros_like from keras.src.utils.naming import get_uid diff --git a/keras/api/_tf_keras/keras/dtype_policies/__init__.py b/keras/api/_tf_keras/keras/dtype_policies/__init__.py index da8364263a2..2abb181f5df 100644 --- a/keras/api/_tf_keras/keras/dtype_policies/__init__.py +++ b/keras/api/_tf_keras/keras/dtype_policies/__init__.py @@ -4,6 +4,9 @@ since your modifications would be overwritten. """ +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index a4e1bf9a6bb..3d10d172b19 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -157,7 +157,6 @@ from keras.src.layers.regularization.activity_regularization import ( ActivityRegularization, ) -from keras.src.layers.regularization.alpha_dropout import AlphaDropout from keras.src.layers.regularization.dropout import Dropout from keras.src.layers.regularization.gaussian_dropout import GaussianDropout from keras.src.layers.regularization.gaussian_noise import GaussianNoise @@ -190,6 +189,10 @@ from keras.src.layers.rnn.simple_rnn import SimpleRNNCell from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells from keras.src.layers.rnn.time_distributed import TimeDistributed +from keras.src.legacy.layers import AlphaDropout +from keras.src.legacy.layers import RandomHeight +from keras.src.legacy.layers import RandomWidth +from keras.src.legacy.layers import ThresholdedReLU from keras.src.utils.jax_layer import FlaxLayer from keras.src.utils.jax_layer import JaxLayer from keras.src.utils.torch_utils import TorchModuleWrapper diff --git a/keras/api/_tf_keras/keras/losses/__init__.py b/keras/api/_tf_keras/keras/losses/__init__.py index ecaadddf6b7..832d78f5fda 100644 --- a/keras/api/_tf_keras/keras/losses/__init__.py +++ b/keras/api/_tf_keras/keras/losses/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. 
""" +from keras.src.legacy.losses import Reduction from keras.src.losses import deserialize from keras.src.losses import get from keras.src.losses import serialize @@ -38,12 +39,18 @@ from keras.src.losses.losses import dice from keras.src.losses.losses import hinge from keras.src.losses.losses import huber -from keras.src.losses.losses import kl_divergence -from keras.src.losses.losses import log_cosh -from keras.src.losses.losses import mean_absolute_error -from keras.src.losses.losses import mean_absolute_percentage_error -from keras.src.losses.losses import mean_squared_error -from keras.src.losses.losses import mean_squared_logarithmic_error +from keras.src.losses.losses import kl_divergence as KLD +from keras.src.losses.losses import kl_divergence as kld +from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence +from keras.src.losses.losses import log_cosh as logcosh +from keras.src.losses.losses import mean_absolute_error as MAE +from keras.src.losses.losses import mean_absolute_error as mae +from keras.src.losses.losses import mean_absolute_percentage_error as MAPE +from keras.src.losses.losses import mean_absolute_percentage_error as mape +from keras.src.losses.losses import mean_squared_error as MSE +from keras.src.losses.losses import mean_squared_error as mse +from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE +from keras.src.losses.losses import mean_squared_logarithmic_error as msle from keras.src.losses.losses import poisson from keras.src.losses.losses import sparse_categorical_crossentropy from keras.src.losses.losses import squared_hinge diff --git a/keras/api/_tf_keras/keras/metrics/__init__.py b/keras/api/_tf_keras/keras/metrics/__init__.py index dc59b32a46c..9b029f7aecb 100644 --- a/keras/api/_tf_keras/keras/metrics/__init__.py +++ b/keras/api/_tf_keras/keras/metrics/__init__.py @@ -11,12 +11,18 @@ from keras.src.losses.losses import categorical_hinge from keras.src.losses.losses import hinge from keras.src.losses.losses import huber -from keras.src.losses.losses import kl_divergence -from keras.src.losses.losses import log_cosh -from keras.src.losses.losses import mean_absolute_error -from keras.src.losses.losses import mean_absolute_percentage_error -from keras.src.losses.losses import mean_squared_error -from keras.src.losses.losses import mean_squared_logarithmic_error +from keras.src.losses.losses import kl_divergence as KLD +from keras.src.losses.losses import kl_divergence as kld +from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence +from keras.src.losses.losses import log_cosh as logcosh +from keras.src.losses.losses import mean_absolute_error as MAE +from keras.src.losses.losses import mean_absolute_error as mae +from keras.src.losses.losses import mean_absolute_percentage_error as MAPE +from keras.src.losses.losses import mean_absolute_percentage_error as mape +from keras.src.losses.losses import mean_squared_error as MSE +from keras.src.losses.losses import mean_squared_error as mse +from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE +from keras.src.losses.losses import mean_squared_logarithmic_error as msle from keras.src.losses.losses import poisson from keras.src.losses.losses import sparse_categorical_crossentropy from keras.src.losses.losses import squared_hinge diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index f62b9ac8223..be8f00acb55 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ 
b/keras/api/_tf_keras/keras/ops/__init__.py @@ -56,6 +56,7 @@ from keras.src.ops.nn import categorical_crossentropy from keras.src.ops.nn import conv from keras.src.ops.nn import conv_transpose +from keras.src.ops.nn import ctc_decode from keras.src.ops.nn import ctc_loss from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import elu @@ -71,6 +72,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu @@ -100,6 +102,7 @@ from keras.src.ops.numpy import arctanh from keras.src.ops.numpy import argmax from keras.src.ops.numpy import argmin +from keras.src.ops.numpy import argpartition from keras.src.ops.numpy import argsort from keras.src.ops.numpy import array from keras.src.ops.numpy import average @@ -190,10 +193,12 @@ from keras.src.ops.numpy import reshape from keras.src.ops.numpy import roll from keras.src.ops.numpy import round +from keras.src.ops.numpy import select from keras.src.ops.numpy import sign from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size +from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt @@ -218,6 +223,7 @@ from keras.src.ops.numpy import true_divide from keras.src.ops.numpy import var from keras.src.ops.numpy import vdot +from keras.src.ops.numpy import vectorize from keras.src.ops.numpy import vstack from keras.src.ops.numpy import where from keras.src.ops.numpy import zeros diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py index 9452ea18a76..8c7e3d921b3 100644 --- a/keras/api/_tf_keras/keras/ops/nn/__init__.py +++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py @@ -10,6 +10,7 @@ from keras.src.ops.nn import categorical_crossentropy from keras.src.ops.nn import conv from keras.src.ops.nn import conv_transpose +from keras.src.ops.nn import ctc_decode from keras.src.ops.nn import ctc_loss from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import elu @@ -25,6 +26,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index 1d5434e4028..05c8f93fd73 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -22,6 +22,7 @@ from keras.src.ops.numpy import arctanh from keras.src.ops.numpy import argmax from keras.src.ops.numpy import argmin +from keras.src.ops.numpy import argpartition from keras.src.ops.numpy import argsort from keras.src.ops.numpy import array from keras.src.ops.numpy import average @@ -112,10 +113,12 @@ from keras.src.ops.numpy import reshape from keras.src.ops.numpy import roll from keras.src.ops.numpy import round +from keras.src.ops.numpy import select from keras.src.ops.numpy import sign from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size +from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt @@ 
-140,6 +143,7 @@ from keras.src.ops.numpy import true_divide from keras.src.ops.numpy import var from keras.src.ops.numpy import vdot +from keras.src.ops.numpy import vectorize from keras.src.ops.numpy import vstack from keras.src.ops.numpy import where from keras.src.ops.numpy import zeros diff --git a/keras/api/_tf_keras/keras/preprocessing/__init__.py b/keras/api/_tf_keras/keras/preprocessing/__init__.py index c9ed7fd664c..737515c3696 100644 --- a/keras/api/_tf_keras/keras/preprocessing/__init__.py +++ b/keras/api/_tf_keras/keras/preprocessing/__init__.py @@ -4,8 +4,9 @@ since your modifications would be overwritten. """ -from keras.api.preprocessing import image -from keras.api.preprocessing import sequence +from keras.api._tf_keras.keras.preprocessing import image +from keras.api._tf_keras.keras.preprocessing import sequence +from keras.api._tf_keras.keras.preprocessing import text from keras.src.utils.image_dataset_utils import image_dataset_from_directory from keras.src.utils.text_dataset_utils import text_dataset_from_directory from keras.src.utils.timeseries_dataset_utils import ( diff --git a/keras/api/_tf_keras/keras/preprocessing/image/__init__.py b/keras/api/_tf_keras/keras/preprocessing/image/__init__.py index f68afe8789d..2ca54805acb 100644 --- a/keras/api/_tf_keras/keras/preprocessing/image/__init__.py +++ b/keras/api/_tf_keras/keras/preprocessing/image/__init__.py @@ -4,6 +4,19 @@ since your modifications would be overwritten. """ +from keras.src.legacy.preprocessing.image import DirectoryIterator +from keras.src.legacy.preprocessing.image import ImageDataGenerator +from keras.src.legacy.preprocessing.image import Iterator +from keras.src.legacy.preprocessing.image import NumpyArrayIterator +from keras.src.legacy.preprocessing.image import apply_affine_transform +from keras.src.legacy.preprocessing.image import apply_brightness_shift +from keras.src.legacy.preprocessing.image import apply_channel_shift +from keras.src.legacy.preprocessing.image import random_brightness +from keras.src.legacy.preprocessing.image import random_channel_shift +from keras.src.legacy.preprocessing.image import random_rotation +from keras.src.legacy.preprocessing.image import random_shear +from keras.src.legacy.preprocessing.image import random_shift +from keras.src.legacy.preprocessing.image import random_zoom from keras.src.utils.image_utils import array_to_img from keras.src.utils.image_utils import img_to_array from keras.src.utils.image_utils import load_img diff --git a/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py b/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py index 188e01af9c4..1f6388250b6 100644 --- a/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py +++ b/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py @@ -4,4 +4,7 @@ since your modifications would be overwritten. """ +from keras.src.legacy.preprocessing.sequence import TimeseriesGenerator +from keras.src.legacy.preprocessing.sequence import make_sampling_table +from keras.src.legacy.preprocessing.sequence import skipgrams from keras.src.utils.sequence_utils import pad_sequences diff --git a/keras/api/_tf_keras/keras/preprocessing/text/__init__.py b/keras/api/_tf_keras/keras/preprocessing/text/__init__.py new file mode 100644 index 00000000000..2e8799f3d5d --- /dev/null +++ b/keras/api/_tf_keras/keras/preprocessing/text/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. 
+""" + +from keras.src.legacy.preprocessing.text import Tokenizer +from keras.src.legacy.preprocessing.text import hashing_trick +from keras.src.legacy.preprocessing.text import one_hot +from keras.src.legacy.preprocessing.text import text_to_word_sequence +from keras.src.legacy.preprocessing.text import tokenizer_from_json diff --git a/keras/api/dtype_policies/__init__.py b/keras/api/dtype_policies/__init__.py index da8364263a2..2abb181f5df 100644 --- a/keras/api/dtype_policies/__init__.py +++ b/keras/api/dtype_policies/__init__.py @@ -4,6 +4,9 @@ since your modifications would be overwritten. """ +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index f62b9ac8223..be8f00acb55 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -56,6 +56,7 @@ from keras.src.ops.nn import categorical_crossentropy from keras.src.ops.nn import conv from keras.src.ops.nn import conv_transpose +from keras.src.ops.nn import ctc_decode from keras.src.ops.nn import ctc_loss from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import elu @@ -71,6 +72,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu @@ -100,6 +102,7 @@ from keras.src.ops.numpy import arctanh from keras.src.ops.numpy import argmax from keras.src.ops.numpy import argmin +from keras.src.ops.numpy import argpartition from keras.src.ops.numpy import argsort from keras.src.ops.numpy import array from keras.src.ops.numpy import average @@ -190,10 +193,12 @@ from keras.src.ops.numpy import reshape from keras.src.ops.numpy import roll from keras.src.ops.numpy import round +from keras.src.ops.numpy import select from keras.src.ops.numpy import sign from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size +from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt @@ -218,6 +223,7 @@ from keras.src.ops.numpy import true_divide from keras.src.ops.numpy import var from keras.src.ops.numpy import vdot +from keras.src.ops.numpy import vectorize from keras.src.ops.numpy import vstack from keras.src.ops.numpy import where from keras.src.ops.numpy import zeros diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py index 9452ea18a76..8c7e3d921b3 100644 --- a/keras/api/ops/nn/__init__.py +++ b/keras/api/ops/nn/__init__.py @@ -10,6 +10,7 @@ from keras.src.ops.nn import categorical_crossentropy from keras.src.ops.nn import conv from keras.src.ops.nn import conv_transpose +from keras.src.ops.nn import ctc_decode from keras.src.ops.nn import ctc_loss from keras.src.ops.nn import depthwise_conv from keras.src.ops.nn import elu @@ -25,6 +26,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu diff --git 
a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index 1d5434e4028..05c8f93fd73 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -22,6 +22,7 @@ from keras.src.ops.numpy import arctanh from keras.src.ops.numpy import argmax from keras.src.ops.numpy import argmin +from keras.src.ops.numpy import argpartition from keras.src.ops.numpy import argsort from keras.src.ops.numpy import array from keras.src.ops.numpy import average @@ -112,10 +113,12 @@ from keras.src.ops.numpy import reshape from keras.src.ops.numpy import roll from keras.src.ops.numpy import round +from keras.src.ops.numpy import select from keras.src.ops.numpy import sign from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size +from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt @@ -140,6 +143,7 @@ from keras.src.ops.numpy import true_divide from keras.src.ops.numpy import var from keras.src.ops.numpy import vdot +from keras.src.ops.numpy import vectorize from keras.src.ops.numpy import vstack from keras.src.ops.numpy import where from keras.src.ops.numpy import zeros diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py index 1d005be50b7..4be0d75d5f2 100644 --- a/keras/src/backend/common/backend_utils.py +++ b/keras/src/backend/common/backend_utils.py @@ -1,4 +1,6 @@ +import functools import operator +import re import warnings @@ -288,3 +290,224 @@ def to_tuple_or_list(value): if isinstance(value, int): return (value,) return value + + +### Code for ops.vectorize() used for TF and torch backends. + +# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html +_DIMENSION_NAME = r"\w+" +_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def _vectorize_parse_gufunc_signature( + signature, +): + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + args, retvals = ( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) + return args, retvals + + +def _vectorize_update_dim_sizes(dim_sizes, shape, core_dims, is_input=True): + num_core_dims = len(core_dims) + if is_input: + if len(shape) < num_core_dims: + raise ValueError( + f"input with shape {shape} does not " + "have enough dimensions for all core " + f"dimensions {core_dims}" + ) + else: + if len(shape) != num_core_dims: + raise ValueError( + f"output shape {shape} does not " + f"match core dimensions {core_dims}" + ) + + core_shape = shape[-num_core_dims:] if core_dims else () + for dim, size in zip(core_dims, core_shape): + if dim not in dim_sizes: + dim_sizes[dim] = size + elif size != dim_sizes[dim]: + raise ValueError( + f"inconsistent size for core dimension {dim}: " + f"{size} vs {dim_sizes[dim]}" + ) + + +def _vectorize_parse_input_dimensions( + args, + input_core_dims, +): + from keras.src import ops + + if len(args) != len(input_core_dims): + raise TypeError( + "wrong number of positional arguments: " + f"expected {len(input_core_dims)}, got {len(args)}" + ) + shapes = [] + dim_sizes: dict[str, int] = {} + for arg, core_dims in zip(args, input_core_dims): + _vectorize_update_dim_sizes( + dim_sizes, 
arg.shape, core_dims, is_input=True + ) + ndim = arg.ndim - len(core_dims) + shapes.append(arg.shape[:ndim]) + broadcast_shape = shapes[0] + for s in shapes: + broadcast_shape = ops.broadcast_shapes(broadcast_shape, s) + return broadcast_shape, dim_sizes + + +def _vectorize_check_output_dims( + func, + dim_sizes, + expected_output_core_dims, +): + from keras.src import ops + + def wrapped(*args): + out = func(*args) + if isinstance(out, (list, tuple)): + out_shapes = [ops.shape(x) for x in out] + else: + out_shapes = [out.shape] + + if expected_output_core_dims is None: + output_core_dims = [()] * len(out_shapes) + else: + output_core_dims = expected_output_core_dims + if len(output_core_dims) > 1 and not isinstance(out, tuple): + raise TypeError( + "output must be a tuple when multiple outputs " + f"are expected, got: {out}" + ) + if len(out_shapes) != len(output_core_dims): + raise TypeError( + "wrong number of output arguments: " + f"expected {len(output_core_dims)}, got {len(out_shapes)}" + ) + + sizes = dict(dim_sizes) + for shape, core_dims in zip(out_shapes, output_core_dims): + _vectorize_update_dim_sizes(sizes, shape, core_dims, is_input=False) + + return out + + return wrapped + + +def _vectorize_apply_excluded(func, excluded, args, kwargs): + if not excluded: + return func, args, kwargs + + dynamic_args = [arg for i, arg in enumerate(args) if i not in excluded] + dynamic_kwargs = { + key: val for key, val in kwargs.items() if key not in excluded + } + static_args = [ + (i, args[i]) + for i in sorted(e for e in excluded if isinstance(e, int)) + if i < len(args) + ] + static_kwargs = {key: val for key, val in kwargs.items() if key in excluded} + + def new_func(*args, **kwargs): + args = list(args) + for i, arg in static_args: + args.insert(i, arg) + return func(*args, **kwargs, **static_kwargs) + + return new_func, dynamic_args, dynamic_kwargs + + +def vectorize_impl(pyfunc, vmap_fn, *, excluded=None, signature=None): + """Implementation adapted from JAX and NumPy.""" + + from keras.src import ops + + excluded = None or set() + + @functools.wraps(pyfunc) + def wrapped(*args, **kwargs): + excluded_func, args, kwargs = _vectorize_apply_excluded( + pyfunc, excluded, args, kwargs + ) + + if signature is not None: + input_core_dims, output_core_dims = ( + _vectorize_parse_gufunc_signature(signature) + ) + else: + input_core_dims = [()] * len(args) + output_core_dims = None + + none_args = {i for i, arg in enumerate(args) if arg is None} + if any(none_args): + if any(input_core_dims[i] != () for i in none_args): + raise ValueError( + f"Cannot pass None at locations {none_args} " + f"with signature={signature}" + ) + excluded_func, args, _ = _vectorize_apply_excluded( + excluded_func, none_args, args, {} + ) + input_core_dims = [ + dim + for i, dim in enumerate(input_core_dims) + if i not in none_args + ] + + args = tuple(map(ops.convert_to_tensor, args)) + + broadcast_shape, dim_sizes = _vectorize_parse_input_dimensions( + args, input_core_dims + ) + checked_func = _vectorize_check_output_dims( + excluded_func, dim_sizes, output_core_dims + ) + squeezed_args = [] + rev_filled_shapes = [] + for arg, core_dims in zip(args, input_core_dims): + noncore_shape = arg.shape[: arg.ndim - len(core_dims)] + + pad_ndim = len(broadcast_shape) - len(noncore_shape) + filled_shape = pad_ndim * (1,) + noncore_shape + rev_filled_shapes.append(filled_shape[::-1]) + + squeeze_indices = tuple( + i for i, size in enumerate(noncore_shape) if size == 1 + ) + squeezed_arg = ops.squeeze(arg, 
axis=squeeze_indices) + squeezed_args.append(squeezed_arg) + + vectorized_func = checked_func + dims_to_expand = [] + for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)): + in_axes = tuple(None if size == 1 else 0 for size in axis_sizes) + if all(axis is None for axis in in_axes): + dims_to_expand.append(len(broadcast_shape) - 1 - negdim) + else: + vectorized_func = vmap_fn(vectorized_func, in_axes) + result = vectorized_func(*squeezed_args) + + if not dims_to_expand: + return result + elif isinstance(result, tuple): + return tuple( + ops.expand_dims(r, axis=dims_to_expand) for r in result + ) + else: + return ops.expand_dims(result, axis=dims_to_expand) + + return wrapped diff --git a/keras/src/backend/jax/linalg.py b/keras/src/backend/jax/linalg.py index 7984a734e9d..1e1c1cedf9b 100644 --- a/keras/src/backend/jax/linalg.py +++ b/keras/src/backend/jax/linalg.py @@ -11,11 +11,18 @@ def cholesky(a): out = jnp.linalg.cholesky(a) - if jnp.any(jnp.isnan(out)): - raise ValueError( - "Cholesky decomposition failed. " - "The input might not be a valid positive definite matrix." - ) + try: + # In eager mode, raise for nan to + # achieve behavior consistency with numpy + if jnp.any(jnp.isnan(out)): + raise ValueError( + "Cholesky decomposition failed. " + "The input might not be a valid " + "positive definite matrix." + ) + except jax.errors.TracerBoolConversionError: + # Cannot raise for nan in tracing mode + pass return out diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 361eeee8917..4119f744e1a 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -204,6 +204,12 @@ def istft( x = _get_complex_tensor_from_tuple(x) dtype = jnp.real(x).dtype + if len(x.shape) < 2: + raise ValueError( + f"Input `x` must have at least 2 dimensions. " + f"Received shape: {x.shape}" + ) + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) l_pad = (fft_length - sequence_length) // 2 r_pad = fft_length - sequence_length - l_pad diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 740c9b17e5a..863ced6b005 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1,16 +1,16 @@ +import builtins +import math + import jax import jax.experimental.sparse as jax_sparse import jax.numpy as jnp -import numpy as np from jax import lax from jax import nn as jnn -from keras.src.backend import standardize_data_format -from keras.src.backend import standardize_dtype +from keras.src import backend from keras.src.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_jax, ) -from keras.src.backend.config import epsilon from keras.src.backend.jax.core import cast from keras.src.backend.jax.core import convert_to_tensor @@ -157,7 +157,7 @@ def max_pool( padding="valid", data_format=None, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 pool_size = _convert_to_spatial_operand( pool_size, num_spatial_dims, data_format @@ -176,7 +176,7 @@ def average_pool( padding, data_format=None, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 pool_size = _convert_to_spatial_operand( pool_size, num_spatial_dims, data_format @@ -189,7 +189,7 @@ def average_pool( pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) if padding == "valid": # Avoid the extra reduce_window. 
- return pooled / np.prod(pool_size) + return pooled / math.prod(pool_size) else: # Count the number of valid entries at each input point, then use that # for computing average. Assumes that any two arrays of same shape will @@ -242,7 +242,7 @@ def conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 dimension_numbers = _convert_to_lax_conv_dimension_numbers( num_spatial_dims, @@ -292,7 +292,7 @@ def depthwise_conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 dimension_numbers = _convert_to_lax_conv_dimension_numbers( num_spatial_dims, @@ -338,7 +338,7 @@ def separable_conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) depthwise_conv_output = depthwise_conv( inputs, depthwise_kernel, @@ -366,7 +366,7 @@ def conv_transpose( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 padding_values = compute_conv_transpose_padding_args_for_jax( input_shape=inputs.shape, @@ -477,7 +477,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): log_prob = jax.nn.log_softmax(output, axis=axis) else: output = output / jnp.sum(output, axis, keepdims=True) - output = jnp.clip(output, epsilon(), 1.0 - epsilon()) + output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) log_prob = jnp.log(output) return -jnp.sum(target * log_prob, axis=axis) @@ -504,7 +504,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): log_prob = jax.nn.log_softmax(output, axis=axis) else: output = output / jnp.sum(output, axis, keepdims=True) - output = jnp.clip(output, epsilon(), 1.0 - epsilon()) + output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) log_prob = jnp.log(output) target = jnn.one_hot(target, output.shape[axis], axis=axis) return -jnp.sum(target * log_prob, axis=axis) @@ -526,7 +526,7 @@ def binary_crossentropy(target, output, from_logits=False): log_neg_logits = jax.nn.log_sigmoid(-output) return -1.0 * target * log_logits - (1.0 - target) * log_neg_logits - output = jnp.clip(output, epsilon(), 1.0 - epsilon()) + output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) bce = target * jnp.log(output) bce += (1.0 - target) * jnp.log(1.0 - output) return -bce @@ -541,7 +541,7 @@ def moments(x, axes, keepdims=False, synchronized=False): # workaround, we simply perform the operations on float32 and convert back # to float16 need_cast = False - ori_dtype = standardize_dtype(x.dtype) + ori_dtype = backend.standardize_dtype(x.dtype) if ori_dtype in ("float16", "bfloat16"): need_cast = True x = cast(x, "float32") @@ -586,76 +586,357 @@ def batch_normalization( return jnp.add(x * inv, res) -def ctc_loss( - target, - output, - target_length, - output_length, - mask_index=0, +def ctc_loss(target, output, target_length, output_length, mask_index=0): + # Ref: https://github.com/google-deepmind/optax + # optax.ctc_loss_with_forward_probs + target = convert_to_tensor(target, dtype="int32") + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length, "int32") + output_length = 
convert_to_tensor(output_length, "int32") + batch_size, max_input_length, num_classes = output.shape + batch_size, max_label_length = target.shape + log_epsilon = -1e5 + + # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = cast(output, dtype) + + def _lengths_to_paddings(lengths, max_length): + indices = jnp.arange(max_length).reshape( + (1,) * lengths.ndim + (max_length,) + ) + lengths = jnp.expand_dims(lengths, axis=-1) + elem_valid = indices < lengths + return jnp.logical_not(elem_valid) + + target_paddings = _lengths_to_paddings(target_length, max_label_length) + output_paddings = _lengths_to_paddings(output_length, max_input_length) + target_paddings = target_paddings.astype(output.dtype) + output_paddings = output_paddings.astype(output.dtype) + + logprobs = jnn.log_softmax(output) + label_lengths = max_label_length - jnp.sum(target_paddings, axis=1).astype( + jnp.int32 + ) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (target[:, :-1] == target[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, mask_index : mask_index + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + _one_hot = jax.nn.one_hot(target, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, _one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + # [B, N] + logalpha_phi_init = ( + jnp.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon + ) + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = ( + jnp.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon + ) + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]` by adding `added_score` in log space. 
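Aside (not part of the patch): the `_lengths_to_paddings` helper above builds the boolean padding masks consumed by the forward recursion below; its behavior in plain NumPy (standalone sketch, helper name mine):

import numpy as np

def lengths_to_paddings(lengths, max_length):
    # True wherever a timestep index falls at or beyond the sequence length.
    indices = np.arange(max_length).reshape((1,) * lengths.ndim + (max_length,))
    return np.logical_not(indices < np.expand_dims(lengths, axis=-1))

print(lengths_to_paddings(np.array([2, 4]), 5).astype(int))
# [[0 0 1 1 1]
#  [0 0 0 0 1]]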
+ return jnp.concatenate( + [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1 + ) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp( + prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit + ) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat) + ) + + pad = pad.reshape((batch_size, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan( + loop_body, (logalpha_phi_init, logalpha_emit_init), xs + ) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + # [B, N+1] + _one_hot = jax.nn.one_hot(label_lengths, num_classes=max_label_length + 1) + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, _one_hot) + return per_seq_loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = jnp.argmax(inputs, axis=-1) + scores = jnp.max(inputs, axis=-1) + + seqlen_mask = jnp.arange(max_length)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = jnp.where(seqlen_mask, mask_index, indices) + scores = jnp.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat_mask = indices[:, 1:] == indices[:, :-1] + repeat_mask = jnp.pad(repeat_mask, ((0, 0), (1, 0))) + indices = jnp.where(repeat_mask, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = jnp.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + order = jnp.expand_dims(jnp.arange(max_length), axis=0) # [1, N] + order = jnp.tile(order, (batch_size, 1)) # [B, N] + order = jnp.where(invalid_mask, max_length, order) + order = jnp.argsort(order, axis=-1) + indices = jnp.take_along_axis(indices, order, axis=-1) + + scores = -jnp.sum(scores, axis=1)[:, None] + indices = jnp.expand_dims(indices, axis=0) + return indices, scores + + +def _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=100, + top_paths=1, + mask_index=None, ): - batch_size, _, _ = output.shape - batch_size, max_target_length = target.shape + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths) + + batch_size, max_seq_len, num_classes = inputs.shape + inputs = jnn.log_softmax(inputs) + seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_lengths[:, None] - output = output.transpose((1, 0, 2)) - target = target.transpose((1, 0)).astype("int32") + if mask_index is None: + mask_index = num_classes - 1 - logits = 
jnn.log_softmax(output) - mgrid_t, mgrid_b = jnp.meshgrid( - jnp.arange(max_target_length), jnp.arange(batch_size) + # This is a workaround for the fact that jnp.argsort does not support + # the order parameter which is used to break ties when scores are equal. + # For compatibility with the tensorflow implementation, we flip the inputs + # and the mask_index, and then flip the classes back to the correct indices + inputs = jnp.flip(inputs, axis=2) + mask_index = num_classes - mask_index - 1 + + _pad = -1 + + init_paths = jnp.full( + (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=jnp.int32 ) - logprobs_emit = logits[mgrid_t, mgrid_b, target[:, :, None]] - logprobs_mask = logits[:, :, mask_index] - logit_paddings = jnp.array( - jnp.arange(max_target_length) < output_length[:, None], - dtype=jnp.float32, + num_init_paths = builtins.min(num_classes, beam_width) + max_classes = jnp.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:] + init_classes = jnp.where(max_classes == mask_index, _pad, max_classes) + init_paths = init_paths.at[:, :num_init_paths, 0].set(init_classes) + + init_scores = ( + jnp.full((batch_size, 2 * beam_width), -jnp.inf, dtype=inputs.dtype) + .at[:, :num_init_paths] + .set(jnp.take_along_axis(inputs[:, 0], max_classes, axis=1)) ) + init_masked = init_paths[:, :, 0] == _pad + + def _extend_paths(paths, scores, masked, x): + paths = jnp.repeat(paths, num_classes, axis=0) + scores = jnp.repeat(scores, num_classes) + masked = jnp.repeat(masked, num_classes) + + path_tail_index = jnp.argmax(paths == _pad, axis=1) + paths_arange = jnp.arange(2 * beam_width * num_classes) + path_tails = paths[paths_arange, path_tail_index - 1] + path_tails = jnp.where(path_tail_index == 0, _pad, path_tails) + + classes = jnp.arange(num_classes).at[mask_index].set(_pad) + classes = jnp.tile(classes, 2 * beam_width) + + prev_masked = masked + masked = classes == _pad + + masked_repeat = ~prev_masked & (path_tails == classes) + classes = jnp.where(masked_repeat, _pad, classes) + paths = paths.at[paths_arange, path_tail_index].set(classes) + + x = jnp.tile(x, 2 * beam_width) + scores = scores + x + + return paths, scores, masked + + def _merge_scores(unique_inverse, scores): + scores_max = jnp.max(scores) + scores_exp = jnp.exp(scores - scores_max) + scores = jnp.zeros_like(scores).at[unique_inverse].add(scores_exp) + scores = jnp.log(scores) + scores_max + return scores + + def _prune_paths(paths, scores, masked): + paths, unique_inverse = jnp.unique( + paths, + return_inverse=True, + size=2 * num_classes * beam_width, + axis=0, + fill_value=_pad, + ) + if len(unique_inverse.shape) >= 2: + unique_inverse = jnp.squeeze(unique_inverse, axis=1) - repeat = jnp.array(target[1:] == target[:-1]) - repeat = jnp.pad(repeat, ((0, 1), (0, 0))).transpose((1, 0)) + emit_scores = jnp.where(masked, -jnp.inf, scores) + mask_scores = jnp.where(masked, scores, -jnp.inf) - _logepsilon = -100000.0 + emit_scores = _merge_scores(unique_inverse, emit_scores) + mask_scores = _merge_scores(unique_inverse, mask_scores) - def _iterate(prev, x): - prev_mask, prev_emit = prev - logprob_mask, logprob_emit, pad = x + total_scores = jnp.logaddexp(emit_scores, mask_scores) + top_indices = jnp.argsort(total_scores)[-beam_width:] - prev_mask_orig = prev_mask - prev_mask = prev_mask.at[:, 1:].set( - jnp.logaddexp(prev_mask[:, 1:], prev_emit + _logepsilon * repeat), + paths = paths[top_indices] + emit_scores = emit_scores[top_indices] + mask_scores = mask_scores[top_indices] + + paths = jnp.tile(paths, (2, 1)) + scores = 
jnp.concatenate([emit_scores, mask_scores]) + masked = jnp.concatenate( + [jnp.zeros(beam_width, bool), jnp.ones(beam_width, bool)] ) - emit = jnp.logaddexp( - prev_mask[:, :-1] + logprob_emit, prev_emit + logprob_emit + + return paths, scores, masked + + def _decode_step(paths, scores, masked, x): + paths, scores, masked = _extend_paths(paths, scores, masked, x) + paths, scores, masked = _prune_paths(paths, scores, masked) + return paths, scores, masked + + def _step(prev, x): + paths, scores, masked = prev + x, seqlen_mask = x + + paths, scores, masked = lax.cond( + seqlen_mask, + lambda paths, scores, masked, x: (paths, scores, masked), + _decode_step, + paths, + scores, + masked, + x, ) - mask = prev_mask + logprob_mask[:, None] - mask = mask.at[:, 1:].set( - jnp.logaddexp( - mask[:, 1:], - prev_emit + logprob_mask[:, None] + _logepsilon * (1 - repeat), - ) + return (paths, scores, masked), None + + def _decode_batch( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ): + (paths, scores, masked), _ = lax.scan( + _step, + (init_paths, init_scores, init_masked), + (inputs[1:], seqlen_mask[1:]), ) - pad = pad[:, None] - emit = emit * pad + prev_emit * (1 - pad) - mask = mask * pad + prev_mask_orig * (1 - pad) + paths, unique_inverse = jnp.unique( + paths, + return_inverse=True, + size=2 * num_classes * beam_width, + axis=0, + fill_value=_pad, + ) + if len(unique_inverse.shape) >= 2: + unique_inverse = jnp.squeeze(unique_inverse, axis=1) + scores = _merge_scores(unique_inverse, scores) - return (mask, emit), (mask, emit) + top_indices = jnp.argsort(scores)[-top_paths:][::-1] + paths = paths[top_indices] + scores = scores[top_indices] - mask_init = jnp.full((batch_size, max_target_length + 1), _logepsilon) - mask_init = mask_init.at[:, 0].set(0.0) - emit_init = jnp.full((batch_size, max_target_length), _logepsilon) + return paths, scores - _, (alphas_mask, alphas_emit) = lax.scan( - _iterate, - (mask_init, emit_init), - (logprobs_mask, logprobs_emit, logit_paddings.transpose()), + paths, scores = jax.vmap(_decode_batch)( + init_paths, init_scores, init_masked, inputs, seqlen_mask ) - last_alpha_mask = ( - alphas_mask[-1] - .at[:, 1:] - .set(jnp.logaddexp(alphas_mask[-1, :, 1:], alphas_emit[-1])) - ) + # convert classes back to the correct indices + paths = jnp.where(paths == _pad, _pad, num_classes - paths - 1) + paths = jnp.transpose(paths, [1, 0, 2]) + return paths, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + return _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + mask_index=mask_index, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. 
" + ) - return -last_alpha_mask[jnp.arange(batch_size), target_length] + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = jnp.mean(jnp.square(x1 - x2)) + psnr = 20 * jnp.log10(max_val) - 10 * jnp.log10(mse) + return psnr diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 36605efc811..60dec64420c 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -223,10 +223,8 @@ def absolute(x): return jnp.absolute(x) -@sparse.elementwise_unary(linear=False) def abs(x): - x = convert_to_tensor(x) - return jnp.absolute(x) + return absolute(x) def all(x, axis=None, keepdims=False): @@ -770,9 +768,9 @@ def moveaxis(x, source, destination): return jnp.moveaxis(x, source=source, destination=destination) -def nan_to_num(x): +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): x = convert_to_tensor(x) - return jnp.nan_to_num(x) + return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) def ndim(x): @@ -974,7 +972,18 @@ def tensordot(x1, x2, axes=2): @sparse.elementwise_unary(linear=False) def round(x, decimals=0): x = convert_to_tensor(x) - return jnp.round(x, decimals=decimals) + + # jnp.round doesn't support decimals < 0 for integers + x_dtype = standardize_dtype(x.dtype) + if "int" in x_dtype and decimals < 0: + factor = cast(math.pow(10, decimals), config.floatx()) + x = cast(x, config.floatx()) + x = jnp.multiply(x, factor) + x = jnp.round(x) + x = jnp.divide(x, factor) + return cast(x, x_dtype) + else: + return jnp.round(x, decimals=decimals) def tile(x, repeats): @@ -1014,6 +1023,12 @@ def vstack(xs): return jnp.vstack(xs) +def vectorize(pyfunc, *, excluded=None, signature=None): + if excluded is None: + excluded = set() + return jnp.vectorize(pyfunc, excluded=excluded, signature=signature) + + def where(condition, x1, x2): return jnp.where(condition, x1, x2) @@ -1143,3 +1158,16 @@ def correlate(x1, x2, mode="valid"): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) return jnp.correlate(x1, x2, mode) + + +def select(condlist, choicelist, default=0): + return jnp.select(condlist, choicelist, default=default) + + +def slogdet(x): + x = convert_to_tensor(x) + return tuple(jnp.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + return jnp.argpartition(x, kth, axis) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index 7dee370ef00..bcac274c245 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -1,14 +1,11 @@ import jax import numpy as np from jax import lax -from jax import numpy as jnp -from keras.src.backend import standardize_data_format -from keras.src.backend import standardize_dtype +from keras.src import backend from keras.src.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_jax, ) -from keras.src.backend.config import epsilon from keras.src.backend.numpy.core import cast from keras.src.backend.numpy.core import convert_to_tensor from keras.src.backend.numpy.core import is_tensor @@ -191,7 +188,7 @@ def max_pool( padding="valid", data_format=None, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 pool_size = _convert_to_spatial_operand( pool_size, num_spatial_dims, data_format @@ -200,7 +197,7 @@ def max_pool( strides = _convert_to_spatial_operand( strides, num_spatial_dims, data_format ) - return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding) + return _pool(inputs, -np.inf, lax.max, pool_size, strides, 
padding) def average_pool( @@ -210,7 +207,7 @@ def average_pool( padding, data_format=None, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 pool_size = _convert_to_spatial_operand( pool_size, num_spatial_dims, data_format @@ -233,7 +230,7 @@ def average_pool( (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size) ] window_counts = _pool( - jnp.ones(shape, inputs.dtype), + np.ones(shape, inputs.dtype), 0.0, lax.add, pool_size, @@ -276,7 +273,7 @@ def conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 dimension_numbers = _convert_to_lax_conv_dimension_numbers( num_spatial_dims, @@ -328,7 +325,7 @@ def depthwise_conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 dimension_numbers = _convert_to_lax_conv_dimension_numbers( num_spatial_dims, @@ -350,7 +347,7 @@ def depthwise_conv( feature_group_count = ( inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] ) - kernel = jnp.reshape( + kernel = np.reshape( kernel if is_tensor(kernel) else kernel.numpy(), kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), ) @@ -376,7 +373,7 @@ def separable_conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) depthwise_conv_output = depthwise_conv( inputs, depthwise_kernel, @@ -404,7 +401,7 @@ def conv_transpose( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = inputs.ndim - 2 padding_values = compute_conv_transpose_padding_args_for_jax( input_shape=inputs.shape, @@ -508,7 +505,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): log_prob = log_softmax(output, axis=axis) else: output = output / np.sum(output, axis, keepdims=True) - output = np.clip(output, epsilon(), 1.0 - epsilon()) + output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) log_prob = np.log(output) return -np.sum(target * log_prob, axis=axis) @@ -535,7 +532,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): log_prob = log_softmax(output, axis=axis) else: output = output / np.sum(output, axis, keepdims=True) - output = np.clip(output, epsilon(), 1.0 - epsilon()) + output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) log_prob = np.log(output) target = one_hot(target, output.shape[axis], axis=axis) return -np.sum(target * log_prob, axis=axis) @@ -555,7 +552,7 @@ def binary_crossentropy(target, output, from_logits=False): if from_logits: output = sigmoid(output) - output = np.clip(output, epsilon(), 1.0 - epsilon()) + output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) bce = target * np.log(output) bce += (1.0 - target) * np.log(1.0 - output) return -bce @@ -571,7 +568,7 @@ def moments(x, axes, keepdims=False, synchronized=False): # workaround, we simply perform the operations on float32 and convert back # to float16 need_cast = False - ori_dtype = standardize_dtype(x.dtype) + ori_dtype = backend.standardize_dtype(x.dtype) if ori_dtype == "float16": need_cast = True x = cast(x, "float32") @@ -615,3 
+612,372 @@ def batch_normalization( res = res + offset return x * inv + res + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + # Ref: https://github.com/google-deepmind/optax + # optax.ctc_loss_with_forward_probs + target = convert_to_tensor(target, dtype="int32") + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length, "int32") + output_length = convert_to_tensor(output_length, "int32") + batch_size, max_input_length, num_classes = output.shape + batch_size, max_label_length = target.shape + log_epsilon = -1e5 + + # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = output.astype(dtype) + + def _lengths_to_paddings(lengths, max_length): + indices = np.arange(max_length).reshape( + (1,) * lengths.ndim + (max_length,) + ) + lengths = np.expand_dims(lengths, axis=-1) + elem_valid = indices < lengths + return np.logical_not(elem_valid) + + target_paddings = _lengths_to_paddings(target_length, max_label_length) + output_paddings = _lengths_to_paddings(output_length, max_input_length) + target_paddings = target_paddings.astype(output.dtype) + output_paddings = output_paddings.astype(output.dtype) + + logprobs = log_softmax(output, axis=-1) + label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype( + np.int32 + ) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32) + repeat = np.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, mask_index : mask_index + 1] # [B, T, 1] + logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + _one_hot = one_hot(target, num_classes=num_classes) # [B, N, K] + logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot) + logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + # [B, N] + logalpha_phi_init = ( + np.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon + ) + logalpha_phi_init[:, 0] = 0.0 + logalpha_emit_init = ( + np.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon + ) + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]` by adding `added_score` in log space. 
+ return np.concatenate( + [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1 + ) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = np.logaddexp( + prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit + ) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat) + ) + + pad = pad.reshape((batch_size, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + def np_scan(f, init, xs): + carry = init + ys = [] + for x in zip(*xs): + carry, y = f(carry, x) + ys.append(y) + result = [] + for i in range(len(ys[0])): + result.append(np.stack([y[i] for y in ys])) + return carry, result + + xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = np_scan( + loop_body, (logalpha_phi_init, logalpha_emit_init), xs + ) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi[-1] = logalpha_phi_last + + # extract per_seq_loss + # [B, N+1] + _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1) + per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot) + return per_seq_loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = np.argmax(inputs, axis=-1).astype("int32") + scores = np.max(inputs, axis=-1) + + seqlen_mask = np.arange(max_length)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = np.where(seqlen_mask, mask_index, indices) + scores = np.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat_mask = indices[:, 1:] == indices[:, :-1] + repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0))) + indices = np.where(repeat_mask, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = np.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + order = np.expand_dims(np.arange(max_length), axis=0) # [1, N] + order = np.tile(order, (batch_size, 1)) # [B, N] + order = np.where(invalid_mask, max_length, order) + order = np.argsort(order, axis=-1) + indices = np.take_along_axis(indices, order, axis=-1) + + scores = -np.sum(scores, axis=1)[:, None] + indices = np.expand_dims(indices, axis=0) + return indices, scores + + +def _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=100, + top_paths=1, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths) + + batch_size, max_seq_len, num_classes = inputs.shape + inputs = log_softmax(inputs, axis=-1) + seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None] + + if mask_index is None: + mask_index = num_classes - 1 + 
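Aside (not part of the patch): the greedy and beam-search decoders being added here sit behind the backend-level `ctc_decode` dispatcher. A usage sketch, assuming the op is exposed publicly as `keras.ops.ctc_decode`:

import numpy as np
from keras import ops

# Toy logits: batch of 2, 6 timesteps, 5 classes; class 0 is the blank label
# under the default mask_index=0.
logits = np.random.random((2, 6, 5)).astype("float32")
lengths = np.array([6, 4], dtype="int32")

paths, scores = ops.ctc_decode(logits, lengths, strategy="greedy")
# `paths` holds decoded label ids padded with -1 past each sequence end;
# strategy="beam_search" plus beam_width/top_paths selects the beam decoder.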
+ # This is a workaround for the fact that np.argsort does not support + # the order parameter which is used to break ties when scores are equal. + # For compatibility with the tensorflow implementation, we flip the inputs + # and the mask_index, and then flip the classes back to the correct indices + inputs = np.flip(inputs, axis=2) + mask_index = num_classes - mask_index - 1 + + _pad = -1 + + init_paths = np.full( + (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32 + ) + + num_init_paths = np.min(np.array([num_classes, beam_width])) + max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:] + init_classes = np.where(max_classes == mask_index, _pad, max_classes) + init_paths[:, :num_init_paths, 0] = init_classes + + init_scores = np.full( + (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype + ) + init_scores[:, :num_init_paths] = np.take_along_axis( + inputs[:, 0], max_classes, axis=1 + ) + init_masked = init_paths[:, :, 0] == _pad + + def _extend_paths(paths, scores, masked, x): + paths = np.repeat(paths, num_classes, axis=0) + scores = np.repeat(scores, num_classes) + masked = np.repeat(masked, num_classes) + + path_tail_index = np.argmax(paths == _pad, axis=1) + paths_arange = np.arange(2 * beam_width * num_classes) + path_tails = paths[paths_arange, path_tail_index - 1] + path_tails = np.where(path_tail_index == 0, _pad, path_tails) + + classes = np.arange(num_classes) + classes[mask_index] = _pad + classes = np.tile(classes, 2 * beam_width) + + prev_masked = masked + masked = classes == _pad + + masked_repeat = ~prev_masked & (path_tails == classes) + classes = np.where(masked_repeat, _pad, classes) + paths[paths_arange, path_tail_index] = classes + + x = np.tile(x, 2 * beam_width) + scores = scores + x + + return paths, scores, masked + + def _merge_scores(unique_inverse, scores): + scores_max = np.max(scores) + scores_exp = np.exp(scores - scores_max) + scores = np.zeros_like(scores) + for i, u in enumerate(unique_inverse): + scores[u] += scores_exp[i] + scores = np.log(scores) + scores_max + return scores + + def _prune_paths(paths, scores, masked): + paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) + pad_size = (2 * num_classes * beam_width) - len(paths) + if pad_size > 0: + paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) + paths = paths[: 2 * num_classes * beam_width] + if len(unique_inverse.shape) >= 2: + unique_inverse = np.squeeze(unique_inverse, axis=1) + + emit_scores = np.where(masked, -np.inf, scores) + mask_scores = np.where(masked, scores, -np.inf) + + emit_scores = _merge_scores(unique_inverse, emit_scores) + mask_scores = _merge_scores(unique_inverse, mask_scores) + + total_scores = np.logaddexp(emit_scores, mask_scores) + top_indices = np.argsort(total_scores, kind="stable")[-beam_width:] + + paths = paths[top_indices] + emit_scores = emit_scores[top_indices] + mask_scores = mask_scores[top_indices] + + paths = np.tile(paths, (2, 1)) + scores = np.concatenate([emit_scores, mask_scores]) + masked = np.concatenate( + [np.zeros(beam_width, bool), np.ones(beam_width, bool)] + ) + + return paths, scores, masked + + def _decode_step(paths, scores, masked, x): + paths, scores, masked = _extend_paths(paths, scores, masked, x) + paths, scores, masked = _prune_paths(paths, scores, masked) + return paths, scores, masked + + def _step(prev, x): + paths, scores, masked = prev + x, seqlen_mask = x + if not seqlen_mask: + paths, scores, masked = _decode_step(paths, scores, masked, x) + return (paths, scores, 
masked), None + + def _decode_batch( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ): + def np_scan_only_carry(f, init, xs): + carry = init + for x in zip(*xs): + carry, y = f(carry, x) + return carry, None + + (paths, scores, masked), _ = np_scan_only_carry( + _step, + (init_paths, init_scores, init_masked), + (inputs[1:], seqlen_mask[1:]), + ) + + paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) + pad_size = (2 * num_classes * beam_width) - len(paths) + if pad_size > 0: + paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) + paths = paths[: 2 * num_classes * beam_width] + if len(unique_inverse.shape) >= 2: + unique_inverse = np.squeeze(unique_inverse, axis=1) + scores = _merge_scores(unique_inverse, scores) + + top_indices = np.argsort(scores)[-top_paths:][::-1] + paths = paths[top_indices] + scores = scores[top_indices] + + return paths, scores + + results = [ + _decode_batch(p, s, m, i, sm) + for p, s, m, i, sm in zip( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ) + ] + paths = np.stack([r[0] for r in results]) + scores = np.stack([r[1] for r in results]) + + # convert classes back to the correct indices + paths = np.where(paths == _pad, _pad, num_classes - paths - 1) + paths = np.transpose(paths, [1, 0, 2]) + return paths, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + return _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + mask_index=mask_index, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. 
" + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = np.mean(np.square(x1 - x2)) + psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) + return psnr diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index 0793cee0d03..b820bc91d50 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -688,8 +688,8 @@ def moveaxis(x, source, destination): return np.moveaxis(x, source=source, destination=destination) -def nan_to_num(x): - return np.nan_to_num(x) +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) def ndim(x): @@ -937,6 +937,10 @@ def vstack(xs): return np.vstack(xs) +def vectorize(pyfunc, *, excluded=None, signature=None): + return np.vectorize(pyfunc, excluded=excluded, signature=signature) + + def where(condition, x1, x2): if x1 is not None and x2 is not None: if not isinstance(x1, (int, float)): @@ -1090,3 +1094,15 @@ def correlate(x1, x2, mode="valid"): x1 = convert_to_tensor(x1, dtype) x2 = convert_to_tensor(x2, dtype) return np.correlate(x1, x2, mode) + + +def select(condlist, choicelist, default=0): + return np.select(condlist, choicelist, default=default) + + +def slogdet(x): + return tuple(np.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + return np.argpartition(x, kth, axis).astype("int32") diff --git a/keras/src/backend/tensorflow/linalg.py b/keras/src/backend/tensorflow/linalg.py index 15459f41133..e76b00ea3fd 100644 --- a/keras/src/backend/tensorflow/linalg.py +++ b/keras/src/backend/tensorflow/linalg.py @@ -1,5 +1,4 @@ import tensorflow as tf -from tensorflow.experimental import numpy as tfnp from keras.src.backend import config from keras.src.backend import standardize_dtype @@ -36,6 +35,8 @@ def lu_factor(a): def norm(x, ord=None, axis=None, keepdims=False): + from keras.src.backend.tensorflow.numpy import moveaxis + x = convert_to_tensor(x) x_shape = x.shape ndim = x_shape.rank @@ -129,7 +130,7 @@ def norm(x, ord=None, axis=None, keepdims=False): keepdims=keepdims, ) elif ord in ("nuc", 2, -2): - x = tfnp.moveaxis(x, axis, (-2, -1)) + x = moveaxis(x, axis, (-2, -1)) if ord == -2: x = tf.math.reduce_min( tf.linalg.svd(x, compute_uv=False), axis=-1 diff --git a/keras/src/backend/tensorflow/math.py b/keras/src/backend/tensorflow/math.py index ffc7f99f0da..f9071b3c2a0 100644 --- a/keras/src/backend/tensorflow/math.py +++ b/keras/src/backend/tensorflow/math.py @@ -1,5 +1,4 @@ import tensorflow as tf -from tensorflow.experimental import numpy as tfnp from keras.src.backend import config from keras.src.backend import standardize_dtype @@ -260,6 +259,8 @@ def solve(a, b): def norm(x, ord=None, axis=None, keepdims=False): + from keras.src.backend.tensorflow.numpy import moveaxis + x = convert_to_tensor(x) x_shape = x.shape ndim = x_shape.rank @@ -328,7 +329,7 @@ def norm(x, ord=None, axis=None, keepdims=False): keepdims=keepdims, ) else: - x = tfnp.moveaxis(x, axis, (-2, -1)) + x = moveaxis(x, axis, (-2, -1)) if ord == -2: x = tf.math.reduce_min( tf.linalg.svd(x, compute_uv=False), axis=-1 diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index 2167087198f..e590d10e2d3 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -3,12 +3,10 @@ import tensorflow as tf -from keras.src.backend import standardize_data_format -from keras.src.backend import standardize_dtype +from keras.src import backend from keras.src.backend.common.backend_utils import ( 
compute_conv_transpose_output_shape, ) -from keras.src.backend.config import epsilon from keras.src.backend.tensorflow.core import cast from keras.src.backend.tensorflow.core import convert_to_tensor @@ -75,30 +73,7 @@ def selu(x): def gelu(x, approximate=True): x = convert_to_tensor(x) - # we need to explicitly implement gelu because bfloat16 will trigger - # DTypePromotionError when using enable_numpy_behavior() - if approximate: - coeff = tf.constant(0.044715, x.dtype) - return ( - tf.constant(0.5, x.dtype) - * x - * ( - tf.constant(1.0, x.dtype) - + tf.math.tanh( - tf.constant(0.7978845608028654, x.dtype) - * (x + coeff * tf.pow(x, 3)) - ) - ) - ) - else: - return ( - tf.constant(0.5, x.dtype) - * x - * ( - tf.constant(1.0, x.dtype) - + tf.math.erf(x / tf.constant(1.4142135623730951, x.dtype)) - ) - ) + return tf.nn.gelu(x, approximate=approximate) def softmax(x, axis=-1): @@ -162,7 +137,7 @@ def max_pool( padding="valid", data_format=None, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) strides = pool_size if strides is None else strides padding = padding.upper() tf_data_format = _convert_data_format("channels_last", len(inputs.shape)) @@ -190,7 +165,7 @@ def average_pool( padding="valid", data_format=None, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) strides = pool_size if strides is None else strides padding = padding.upper() tf_data_format = _convert_data_format("channels_last", len(inputs.shape)) @@ -268,7 +243,7 @@ def _conv(): def _conv_xla(): return _conv() - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) if data_format == "channels_last": channels = inputs.shape[-1] else: @@ -277,6 +252,12 @@ def _conv_xla(): # If kernel's in_channel does not match input's channels, it indicates # convolution is broken down into groups. 
return _conv_xla() + if data_format == "channels_first" and len(inputs.shape) == 5: + inputs = convert_to_tensor(inputs) + if inputs.device.split(":")[-2] == "CPU": + inputs = tf.transpose(inputs, perm=(0, 2, 3, 4, 1)) + data_format = "channels_last" + return tf.transpose(_conv(), perm=(0, 4, 1, 2, 3)) return _conv() @@ -288,7 +269,7 @@ def depthwise_conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = len(inputs.shape) - 2 if num_spatial_dims > 2: raise ValueError( @@ -351,7 +332,7 @@ def separable_conv( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) num_spatial_dims = len(inputs.shape) - 2 if num_spatial_dims > 2: raise ValueError( @@ -414,7 +395,7 @@ def conv_transpose( data_format=None, dilation_rate=1, ): - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) tf_data_format = _convert_data_format(data_format, len(inputs.shape)) kernel_size = kernel.shape[:-2] filters = kernel.shape[-2] @@ -446,7 +427,7 @@ def conv_transpose( def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False): - x = convert_to_tensor(x) + x = convert_to_tensor(x, dtype="int64") if dtype is None: dtype = "float32" if sparse: @@ -597,7 +578,9 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): output = output / tf.reduce_sum(output, axis, keepdims=True) # Compute cross entropy from probabilities. - output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon()) + output = tf.clip_by_value( + output, backend.epsilon(), 1.0 - backend.epsilon() + ) return -tf.reduce_sum(target * tf.math.log(output), axis) @@ -653,7 +636,9 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): ) if not from_logits: - output = tf.clip_by_value(output, epsilon(), 1 - epsilon()) + output = tf.clip_by_value( + output, backend.epsilon(), 1 - backend.epsilon() + ) output = tf.math.log(output) result = tf.nn.sparse_softmax_cross_entropy_with_logits( @@ -702,7 +687,9 @@ def binary_crossentropy(target, output, from_logits=False): ) # Compute cross entropy from probabilities. - output = tf.clip_by_value(output, epsilon(), 1.0 - epsilon()) + output = tf.clip_by_value( + output, backend.epsilon(), 1.0 - backend.epsilon() + ) bce = target * tf.math.log(output) bce += (1 - target) * tf.math.log(1 - output) return -bce @@ -713,7 +700,7 @@ def moments(x, axes, keepdims=False, synchronized=False): # workaround, we simply perform the operations on float32 and convert back # to float16 need_cast = False - ori_dtype = standardize_dtype(x.dtype) + ori_dtype = backend.standardize_dtype(x.dtype) if ori_dtype in ("float16", "bfloat16"): need_cast = True x = cast(x, "float32") @@ -797,30 +784,18 @@ def ctc_loss( output_length, mask_index=0, ): - """Runs CTC (Connectionist Temporal Classification) loss on each - batch element. - - Arguments: - target: Tensor `(batch_size, max_length)` containing the - target sequences in integer format. - output: Tensor `(batch_size, max_length, num_classes)` - containing the output of the softmax. - target_length: Tensor `(batch_size,)` containing the sequence length - for each target sequence in the batch. - output_length: Tensor `(batch_size,)` containing the sequence length - for each output sequence in the batch. 
- mask_index: The value in `target` and `output` that represents the - blank label. - - Returns: - A tensor of shape `(batch_size,)` containing the CTC loss for each - sample in the batch. - """ - target = tf.convert_to_tensor(target) + target = convert_to_tensor(target) + output = convert_to_tensor(output) target = tf.cast(target, dtype="int32") - output = tf.convert_to_tensor(output) - output = tf.cast(output, dtype="float32") - return tf.nn.ctc_loss( + + # `tf.nn.ctc_loss` will internally cast to float32 when the input is float16 + # or bfloat16. Additionally, it will raise an error when the input is + # float64. As a result, we perform the casting externally and add support + # for float64. + result_dtype = backend.result_type(output.dtype, "float32") + compute_dtype = "float32" if result_dtype == "float64" else result_dtype + output = tf.cast(output, compute_dtype) + loss = tf.nn.ctc_loss( labels=target, logits=output, label_length=target_length, @@ -828,3 +803,84 @@ def ctc_loss( blank_index=mask_index, logits_time_major=False, ) + return tf.cast(loss, result_dtype) + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + input_shape = tf.shape(inputs) + num_samples, num_steps = input_shape[0], input_shape[1] + inputs = tf.transpose(inputs, (1, 0, 2)) + + dtype = backend.result_type(inputs.dtype, "float32") + inputs = tf.cast(inputs, dtype) + + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + if strategy == "greedy": + (decoded, scores) = tf.nn.ctc_greedy_decoder( + inputs=inputs, + sequence_length=sequence_lengths, + merge_repeated=merge_repeated, + blank_index=mask_index, + ) + elif strategy == "beam_search": + # Move `mask_index` column to the last position since this is the + # default for `tf.nn.ctc_beam_search_decoder` + if mask_index is not None: + inputs_before = inputs[..., :mask_index] + inputs_mask = inputs[..., mask_index : mask_index + 1] + inputs_after = inputs[..., mask_index + 1 :] + inputs = tf.concat( + [inputs_before, inputs_after, inputs_mask], axis=-1 + ) + (decoded, scores) = tf.nn.ctc_beam_search_decoder( + inputs=inputs, + sequence_length=sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + # Postprocess sparse tensor + decoded_dense = [] + for st in decoded: + st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) + decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) + decoded_dense = tf.stack(decoded_dense, axis=0) + decoded_dense = tf.cast(decoded_dense, "int32") + + # We need to recover the labels because we swapped the indices earlier + if strategy == "beam_search" and mask_index is not None: + if mask_index < 0: + mask_index = mask_index + input_shape[-1] + decoded_dense = tf.where( + decoded_dense >= mask_index, decoded_dense + 1, decoded_dense + ) + return decoded_dense, scores + + +def psnr(x1, x2, max_val): + from keras.src.backend.tensorflow.numpy import log10 + + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. 
" + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = tf.reduce_mean(tf.square(x1 - x2)) + psnr = 20 * log10(max_val) - 10 * log10(mse) + return psnr diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index d033c40536b..51ee5833cad 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -7,7 +7,6 @@ import numpy as np import tensorflow as tf -from tensorflow.experimental import numpy as tfnp from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops from keras.src import tree @@ -16,9 +15,11 @@ from keras.src.backend.common import dtypes from keras.src.backend.common.backend_utils import canonicalize_axis from keras.src.backend.common.backend_utils import to_tuple_or_list +from keras.src.backend.common.backend_utils import vectorize_impl from keras.src.backend.tensorflow import sparse from keras.src.backend.tensorflow.core import cast from keras.src.backend.tensorflow.core import convert_to_tensor +from keras.src.backend.tensorflow.core import shape as shape_op @sparse.elementwise_binary_union(tf.sparse.add) @@ -575,12 +576,18 @@ def mean(x, axis=None, keepdims=False): def max(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + # The TensorFlow numpy API implementation doesn't support `initial` so we # handle it manually here. if initial is not None: - return tf.math.maximum( - tfnp.max(x, axis=axis, keepdims=keepdims), initial - ) + if standardize_dtype(x.dtype) == "bool": + x = tf.reduce_any(x, axis=axis, keepdims=keepdims) + x = tf.math.maximum(tf.cast(x, "int32"), tf.cast(initial, "int32")) + return tf.cast(x, "bool") + else: + x = tf.reduce_max(x, axis=axis, keepdims=keepdims) + return tf.math.maximum(x, initial) # TensorFlow returns -inf by default for an empty list, but for consistency # with other backends and the numpy API we want to throw in this case. @@ -592,7 +599,10 @@ def max(x, axis=None, keepdims=False, initial=None): message="Cannot compute the max of an empty tensor.", ) - return tfnp.max(x, axis=axis, keepdims=keepdims) + if standardize_dtype(x.dtype) == "bool": + return tf.reduce_any(x, axis=axis, keepdims=keepdims) + else: + return tf.reduce_max(x, axis=axis, keepdims=keepdims) def ones(shape, dtype=None): @@ -607,6 +617,7 @@ def zeros(shape, dtype=None): @sparse.elementwise_unary def absolute(x): + x = convert_to_tensor(x) # uintx and bool are always non-negative dtype = standardize_dtype(x.dtype) if "uint" in dtype or dtype == "bool": @@ -614,7 +625,6 @@ def absolute(x): return tf.abs(x) -@sparse.elementwise_unary def abs(x): return absolute(x) @@ -650,8 +660,6 @@ def append(x1, x2, axis=None): def arange(start, stop=None, step=1, dtype=None): - # tfnp.arange has trouble with dynamic Tensors in compiled function. - # tf.range does not. 
if dtype is None: dtypes_to_resolve = [ getattr(start, "dtype", type(start)), @@ -749,7 +757,7 @@ def _keepdims(x, y, axis): if axis is None: shape = [1 for _ in range(len(x.shape))] else: - shape = [tf.shape[i] for i in range(len(x.shape))] + shape = list(shape_op(x)) for axis in tree.flatten(axis): shape[axis] = 1 y = tf.reshape(y, shape) @@ -797,27 +805,34 @@ def array(x, dtype=None): def average(x, axis=None, weights=None): x = convert_to_tensor(x) - axis = to_tuple_or_list(axis) - dtypes_to_resolve = [x.dtype, float] - if weights is not None: + + if weights is None: # Treat all weights as 1 + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + avg = tf.reduce_mean(x, axis=axis) + else: weights = convert_to_tensor(weights) - dtypes_to_resolve.append(weights.dtype) - result_dtype = dtypes.result_type(*dtypes_to_resolve) - compute_dtype = result_dtype - # TODO: since tfnp.average incorrectly promote bfloat16 to float64, we - # need to cast to float32 first and then cast back to bfloat16 - if compute_dtype == "bfloat16": - compute_dtype = "float32" - x = tf.cast(x, compute_dtype) - if weights is not None: - weights = tf.cast(weights, compute_dtype) - if axis is None: - x = tfnp.average(x, weights=weights, axis=None) - return tf.cast(x, result_dtype) - for a in axis: - # `tfnp.average` does not handle multiple axes. - x = tfnp.average(x, weights=weights, axis=a) - return tf.cast(x, result_dtype) + dtype = dtypes.result_type(x.dtype, weights.dtype, float) + x = tf.cast(x, dtype) + weights = tf.cast(weights, dtype) + + def _rank_equal_case(): + weights_sum = tf.reduce_sum(weights, axis=axis) + return tf.reduce_sum(x * weights, axis=axis) / weights_sum + + def _rank_not_equal_case(): + weights_sum = tf.reduce_sum(weights) + axes = tf.convert_to_tensor([[axis], [0]]) + return tf.tensordot(x, weights, axes) / weights_sum + + if axis is None: + avg = _rank_equal_case() + else: + if len(x.shape) == len(weights.shape): + avg = _rank_equal_case() + else: + avg = _rank_not_equal_case() + return avg def broadcast_to(x, shape): @@ -876,7 +891,8 @@ def conj(x): @sparse.elementwise_unary def copy(x): - return tfnp.copy(x) + x = convert_to_tensor(x) + return tf.identity(x) @sparse.densifying_unary(1) @@ -911,13 +927,60 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): dtype = dtypes.result_type(x1.dtype, x2.dtype) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - return tfnp.cross( - x1, - x2, - axisa=axisa, - axisb=axisb, - axisc=axisc, - axis=axis, + + if axis is not None: + axisa = axis + axisb = axis + axisc = axis + x1 = moveaxis(x1, axisa, -1) + x2 = moveaxis(x2, axisb, -1) + + def maybe_pad_zeros(x, size_of_last_dim): + def pad_zeros(x): + return tf.pad( + x, + tf.concat( + [ + tf.zeros([tf.rank(x) - 1, 2], "int32"), + tf.constant([[0, 1]], "int32"), + ], + axis=0, + ), + ) + + if isinstance(size_of_last_dim, int): + if size_of_last_dim == 2: + return pad_zeros(x) + return x + + return tf.cond( + tf.equal(size_of_last_dim, 2), lambda: pad_zeros(x), lambda: x + ) + + x1_dim = shape_op(x1)[-1] + x2_dim = shape_op(x2)[-1] + + x1 = maybe_pad_zeros(x1, x1_dim) + x2 = maybe_pad_zeros(x2, x2_dim) + + # Broadcast each other + shape = shape_op(x1) + + shape = tf.broadcast_dynamic_shape(shape, shape_op(x2)) + x1 = tf.broadcast_to(x1, shape) + x2 = tf.broadcast_to(x2, shape) + + c = tf.linalg.cross(x1, x2) + + if isinstance(x1_dim, int) and isinstance(x2_dim, int): + if (x1_dim == 2) & (x2_dim == 2): + return c[..., 2] + return moveaxis(c, -1, axisc) + + return tf.cond( + 
(x1_dim == 2) & (x2_dim == 2), + lambda: c[..., 2], + lambda: moveaxis(c, -1, axisc), ) @@ -944,20 +1007,74 @@ def cumsum(x, axis=None, dtype=None): def diag(x, k=0): - return tfnp.diag(x, k=k) + x = convert_to_tensor(x) + if len(x.shape) == 1: + return tf.cond( + tf.equal(tf.size(x), 0), + lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype), + lambda: tf.linalg.diag(x, k=k), + ) + elif len(x.shape) == 2: + return diagonal(x, offset=k) + else: + raise ValueError(f"`x` must be 1d or 2d. Received: x.shape={x.shape}") def diagonal(x, offset=0, axis1=0, axis2=1): - return tfnp.diagonal( - x, - offset=offset, - axis1=axis1, - axis2=axis2, - ) + x = convert_to_tensor(x) + x_rank = x.ndim + if ( + offset == 0 + and (axis1 == x_rank - 2 or axis1 == -2) + and (axis2 == x_rank - 1 or axis2 == -1) + ): + return tf.linalg.diag_part(x) + + x = moveaxis(x, (axis1, axis2), (-2, -1)) + x_shape = shape_op(x) + + def _zeros(): + return tf.zeros(tf.concat([x_shape[:-1], [0]], 0), dtype=x.dtype) + + if isinstance(x_shape[-1], int) and isinstance(x_shape[-2], int): + if offset <= -1 * x_shape[-2] or offset >= x_shape[-1]: + x = _zeros() + else: + x = tf.cond( + tf.logical_or( + tf.less_equal(offset, -1 * x_shape[-2]), + tf.greater_equal(offset, x_shape[-1]), + ), + lambda: _zeros(), + lambda: x, + ) + return tf.linalg.diag_part(x, k=offset) def diff(a, n=1, axis=-1): - return tfnp.diff(a, n=n, axis=axis) + a = convert_to_tensor(a) + if n == 0: + return a + elif n < 0: + raise ValueError(f"Order `n` must be non-negative. Received n={n}") + elif a.ndim == 0: + raise ValueError( + "`diff` requires input that is at least one dimensional. " + f"Received: a={a}" + ) + axis = canonicalize_axis(axis, a.ndim) + slice1 = [slice(None)] * a.ndim + slice2 = [slice(None)] * a.ndim + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + slice1_tuple = tuple(slice1) + slice2_tuple = tuple(slice2) + for _ in range(n): + if standardize_dtype(a.dtype) == "bool": + a = tf.not_equal(a[slice1_tuple], a[slice2_tuple]) + else: + a = tf.subtract(a[slice1_tuple], a[slice2_tuple]) + return a def digitize(x, bins): @@ -1119,12 +1236,9 @@ def hstack(xs): if len(dtype_set) > 1: dtype = dtypes.result_type(*dtype_set) xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) - rank = tf.rank(xs[0]) - return tf.cond( - tf.equal(rank, 1), - lambda: tf.concat(xs, axis=0), - lambda: tf.concat(xs, axis=1), - ) + if len(xs[0].shape) == 1: + return tf.concat(xs, axis=0) + return tf.concat(xs, axis=1) def identity(n, dtype=None): @@ -1152,7 +1266,6 @@ def isclose(x1, x2): @sparse.densifying_unary(True) def isfinite(x): - # `tfnp.isfinite` requires `enable_numpy_behavior`, so we reimplement it. x = convert_to_tensor(x) dtype_as_dtype = tf.as_dtype(x.dtype) if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: @@ -1161,7 +1274,6 @@ def isfinite(x): def isinf(x): - # `tfnp.isinf` requires `enable_numpy_behavior`, so we reimplement it. x = convert_to_tensor(x) dtype_as_dtype = tf.as_dtype(x.dtype) if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: @@ -1170,7 +1282,6 @@ def isinf(x): def isnan(x): - # `tfnp.isnan` requires `enable_numpy_behavior`, so we reimplement it. 
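Aside (not part of the patch): the native `diff` reimplementation above follows the NumPy definition of repeated adjacent differences. For example, via the public wrapper (assumed to be `keras.ops.diff`):

import numpy as np
from keras import ops

x = np.array([1, 3, 6, 10])
print(ops.diff(x))       # first differences -> [2, 3, 4]
print(ops.diff(x, n=2))  # differences applied twice -> [1, 1]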
x = convert_to_tensor(x) dtype_as_dtype = tf.as_dtype(x.dtype) if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: @@ -1206,15 +1317,37 @@ def linspace( float, ] dtype = dtypes.result_type(*dtypes_to_resolve) - return tfnp.linspace( - start, - stop, - num=num, - endpoint=endpoint, - retstep=retstep, - dtype=dtype, - axis=axis, - ) + else: + dtype = standardize_dtype(dtype) + start = convert_to_tensor(start, dtype=dtype) + stop = convert_to_tensor(stop, dtype=dtype) + if num < 0: + raise ValueError( + f"`num` must be a non-negative integer. Received: num={num}" + ) + step = tf.convert_to_tensor(np.nan) + if endpoint: + result = tf.linspace(start, stop, num, axis=axis) + if num > 1: + step = (stop - start) / (num - 1) + else: + # tf.linspace doesn't support endpoint=False, so we manually handle it + if num > 0: + step = (stop - start) / num + if num > 1: + new_stop = tf.cast(stop, step.dtype) - step + start = tf.cast(start, new_stop.dtype) + result = tf.linspace(start, new_stop, num, axis=axis) + else: + result = tf.linspace(start, stop, num, axis=axis) + if dtype is not None: + if "int" in dtype: + result = tf.floor(result) + result = tf.cast(result, dtype) + if retstep: + return (result, step) + else: + return result @sparse.densifying_unary(-np.inf) @@ -1271,9 +1404,6 @@ def logaddexp(x1, x2): dtype = dtypes.result_type(x1.dtype, x2.dtype, float) x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - - # Below is the same implementation as tfnp.logaddexp using all native - # ops to prevent incorrect promotion of bfloat16. delta = x1 - x2 return tf.where( tf.math.is_nan(delta), @@ -1300,24 +1430,15 @@ def logical_or(x1, x2): def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): - if dtype is None: - dtypes_to_resolve = [ - getattr(start, "dtype", type(start)), - getattr(stop, "dtype", type(stop)), - float, - ] - dtype = dtypes.result_type(*dtypes_to_resolve) - start = tf.cast(start, dtype) - stop = tf.cast(stop, dtype) - return tfnp.logspace( - start, - stop, + result = linspace( + start=start, + stop=stop, num=num, endpoint=endpoint, - base=base, dtype=dtype, axis=axis, ) + return tf.pow(tf.cast(base, result.dtype), result) @sparse.elementwise_binary_union(tf.sparse.maximum, densify_mixed=True) @@ -1345,12 +1466,17 @@ def meshgrid(*x, indexing="xy"): def min(x, axis=None, keepdims=False, initial=None): x = convert_to_tensor(x) + # The TensorFlow numpy API implementation doesn't support `initial` so we # handle it manually here. if initial is not None: - return tf.math.minimum( - tfnp.min(x, axis=axis, keepdims=keepdims), initial - ) + if standardize_dtype(x.dtype) == "bool": + x = tf.reduce_all(x, axis=axis, keepdims=keepdims) + x = tf.math.minimum(tf.cast(x, "int32"), tf.cast(initial, "int32")) + return tf.cast(x, "bool") + else: + x = tf.reduce_min(x, axis=axis, keepdims=keepdims) + return tf.math.minimum(x, initial) # TensorFlow returns inf by default for an empty list, but for consistency # with other backends and the numpy API we want to throw in this case. 
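Aside (not part of the patch): `logspace` above is now derived from `linspace` as base ** linspace(start, stop), which is the NumPy identity it replaces. A sketch of the equivalence, assuming the public `keras.ops.logspace` wrapper:

import numpy as np
from keras import ops

# Both lines should print values close to [1., 10., 100.].
print(ops.logspace(0.0, 2.0, num=3))
print(np.power(10.0, np.linspace(0.0, 2.0, num=3)))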
@@ -1362,7 +1488,10 @@ def min(x, axis=None, keepdims=False, initial=None): message="Cannot compute the min of an empty tensor.", ) - return tfnp.min(x, axis=axis, keepdims=keepdims) + if standardize_dtype(x.dtype) == "bool": + return tf.reduce_all(x, axis=axis, keepdims=keepdims) + else: + return tf.reduce_min(x, axis=axis, keepdims=keepdims) @sparse.elementwise_binary_union(tf.sparse.minimum, densify_mixed=True) @@ -1392,10 +1521,27 @@ def mod(x1, x2): def moveaxis(x, source, destination): - return tfnp.moveaxis(x, source=source, destination=destination) + x = convert_to_tensor(x) + + _source = to_tuple_or_list(source) + _destination = to_tuple_or_list(destination) + _source = tuple(canonicalize_axis(i, x.ndim) for i in _source) + _destination = tuple(canonicalize_axis(i, x.ndim) for i in _destination) + if len(_source) != len(_destination): + raise ValueError( + "Inconsistent number of `source` and `destination`. " + f"Received: source={source}, destination={destination}" + ) + # Directly return x if no movement is required + if _source == _destination: + return x + perm = [i for i in range(x.ndim) if i not in _source] + for dest, src in sorted(zip(_destination, _source)): + perm.insert(dest, src) + return tf.transpose(x, perm) -def nan_to_num(x): +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): x = convert_to_tensor(x) dtype = x.dtype @@ -1403,14 +1549,18 @@ def nan_to_num(x): if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: return x - # Replace NaN with 0 - x = tf.where(tf.math.is_nan(x), tf.constant(0, dtype), x) + # Replace NaN with `nan` + x = tf.where(tf.math.is_nan(x), tf.constant(nan, dtype), x) - # Replace positive infinity with dtype.max - x = tf.where(tf.math.is_inf(x) & (x > 0), tf.constant(dtype.max, dtype), x) + # Replace positive infinity with `posinf` or `dtype.max` + if posinf is None: + posinf = dtype.max + x = tf.where(tf.math.is_inf(x) & (x > 0), tf.constant(posinf, dtype), x) - # Replace negative infinity with dtype.min - x = tf.where(tf.math.is_inf(x) & (x < 0), tf.constant(dtype.min, dtype), x) + # Replace negative infinity with `neginf` or `dtype.min` + if neginf is None: + neginf = dtype.min + x = tf.where(tf.math.is_inf(x) & (x < 0), tf.constant(neginf, dtype), x) return x @@ -1615,8 +1765,6 @@ def reciprocal(x): def repeat(x, repeats, axis=None): - # tfnp.repeat has trouble with dynamic Tensors in compiled function. - # tf.repeat does not. x = convert_to_tensor(x) # TODO: tf.repeat doesn't support uint16 if standardize_dtype(x.dtype) == "uint16": @@ -1640,7 +1788,14 @@ def reshape(x, newshape): def roll(x, shift, axis=None): - return tfnp.roll(x, shift, axis=axis) + x = convert_to_tensor(x) + if axis is not None: + return tf.roll(x, shift=shift, axis=axis) + + # If axis is None, the roll happens as a 1-d tensor. + original_shape = tf.shape(x) + x = tf.roll(tf.reshape(x, [-1]), shift, 0) + return tf.reshape(x, original_shape) @sparse.elementwise_unary @@ -1695,14 +1850,12 @@ def split(x, indices_or_sections, axis=0): if not isinstance(indices_or_sections, int): # `tf.split` requires `num_or_size_splits`, so we need to convert # `indices_or_sections` to the appropriate format. - # The following implementation offers better compatibility for the - # tensor argument `indices_or_sections` than original `tfnp.split`. 
total_size = x.shape[axis] indices_or_sections = convert_to_tensor(indices_or_sections) start_size = indices_or_sections[0:1] end_size = total_size - indices_or_sections[-1:] num_or_size_splits = tf.concat( - [start_size, tfnp.diff(indices_or_sections), end_size], axis=0 + [start_size, diff(indices_or_sections), end_size], axis=0 ) else: num_or_size_splits = indices_or_sections @@ -1726,7 +1879,34 @@ def std(x, axis=None, keepdims=False): def swapaxes(x, axis1, axis2): - return tfnp.swapaxes(x, axis1=axis1, axis2=axis2) + x = convert_to_tensor(x) + + if ( + x.shape.rank is not None + and isinstance(axis1, int) + and isinstance(axis2, int) + ): + # This branch makes sure `perm` is statically known, to avoid a + # not-compile-time-constant XLA error. + axis1 = canonicalize_axis(axis1, x.ndim) + axis2 = canonicalize_axis(axis2, x.ndim) + + # Directly return x if no movement is required + if axis1 == axis2: + return x + + perm = list(range(x.ndim)) + perm[axis1] = axis2 + perm[axis2] = axis1 + else: + x_rank = tf.rank(x) + axis1 = tf.where(axis1 < 0, tf.add(axis1, x_rank), axis1) + axis2 = tf.where(axis2 < 0, tf.add(axis2, x_rank), axis2) + perm = tf.range(x_rank) + perm = tf.tensor_scatter_nd_update( + perm, [[axis1], [axis2]], [axis2, axis1] + ) + return tf.transpose(x, perm) def take(x, indices, axis=None): @@ -1737,9 +1917,7 @@ def take(x, indices, axis=None): f"`x.dtype={x.dtype}` when `indices` is a sparse tensor; " "densifying `indices`." ) - return tfnp.take( - x, convert_to_tensor(indices, sparse=False), axis=axis - ) + return take(x, convert_to_tensor(indices, sparse=False), axis=axis) if axis is None: x = tf.reshape(x, (-1,)) elif axis != 0: @@ -1748,9 +1926,7 @@ def take(x, indices, axis=None): f"`axis={axis}` when `indices` is a sparse tensor; " "densifying `indices`." ) - return tfnp.take( - x, convert_to_tensor(indices, sparse=False), axis=axis - ) + return take(x, convert_to_tensor(indices, sparse=False), axis=axis) output = tf.nn.safe_embedding_lookup_sparse( embedding_weights=tf.convert_to_tensor(x), sparse_ids=tf.sparse.expand_dims(indices, axis=-1), @@ -1758,11 +1934,72 @@ def take(x, indices, axis=None): ) output.set_shape(indices.shape + output.shape[len(indices.shape) :]) return output - return tfnp.take(x, indices, axis=axis) + + x = convert_to_tensor(x) + indices = convert_to_tensor(indices) + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + # Correct the indices using "fill" mode which is the same as in jax + indices = tf.where( + indices < 0, + indices + tf.cast(tf.shape(x)[axis], indices.dtype), + indices, + ) + return tf.gather(x, indices, axis=axis) def take_along_axis(x, indices, axis=None): - return tfnp.take_along_axis(x, indices, axis=axis) + x = convert_to_tensor(x) + indices = convert_to_tensor(indices, "int64") + if axis is None: + if indices.ndim != 1: + raise ValueError( + "`indices` must be 1D if axis=None. " + f"Received: indices.shape={indices.shape}" + ) + return take_along_axis(tf.reshape(x, [-1]), indices, 0) + rank = tf.rank(x) + static_axis = axis + axis = axis + rank if axis < 0 else axis + + # Broadcast shapes to match, ensure that the axis of interest is not + # broadcast. 
+ x_shape_original = tf.shape(x, out_type=indices.dtype) + indices_shape_original = tf.shape(indices, out_type=indices.dtype) + x_shape = tf.tensor_scatter_nd_update(x_shape_original, [[axis]], [1]) + indices_shape = tf.tensor_scatter_nd_update( + indices_shape_original, [[axis]], [1] + ) + broadcasted_shape = tf.broadcast_dynamic_shape(x_shape, indices_shape) + x_shape = tf.tensor_scatter_nd_update( + broadcasted_shape, [[axis]], [x_shape_original[axis]] + ) + indices_shape = tf.tensor_scatter_nd_update( + broadcasted_shape, [[axis]], [indices_shape_original[axis]] + ) + x = tf.broadcast_to(x, x_shape) + indices = tf.broadcast_to(indices, indices_shape) + + # Save indices shape so we can restore it later. + possible_result_shape = indices.shape + + # Correct the indices using "fill" mode which is the same as in jax + indices = tf.where(indices < 0, indices + x_shape[static_axis], indices) + + x = swapaxes(x, static_axis, -1) + indices = swapaxes(indices, static_axis, -1) + + x_shape = tf.shape(x) + x = tf.reshape(x, [-1, x_shape[-1]]) + indices_shape = tf.shape(indices) + indices = tf.reshape(indices, [-1, indices_shape[-1]]) + + result = tf.gather(x, indices, batch_dims=1) + result = tf.reshape(result, indices_shape) + result = swapaxes(result, static_axis, -1) + result.set_shape(possible_result_shape) + return result @sparse.elementwise_unary @@ -1800,7 +2037,6 @@ def tensordot(x1, x2, axes=2): @sparse.elementwise_unary def round(x, decimals=0): - # `tfnp.round` requires `enable_numpy_behavior`, so we reimplement it. if decimals == 0: return tf.round(x) x_dtype = x.dtype @@ -1821,7 +2057,6 @@ def round(x, decimals=0): def tile(x, repeats): - # The TFNP implementation is buggy, we roll our own. x = convert_to_tensor(x) repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1]) repeats_size = tf.size(repeats) @@ -1844,12 +2079,39 @@ def trace(x, offset=0, axis1=0, axis2=1): dtype = standardize_dtype(x.dtype) if dtype not in ("int64", "uint32", "uint64"): dtype = dtypes.result_type(dtype, "int32") - return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + x_shape = tf.shape(x) + x = moveaxis(x, (axis1, axis2), (-2, -1)) + # Mask out the diagonal and reduce. + x = tf.where( + eye(x_shape[axis1], x_shape[axis2], k=offset, dtype="bool"), + x, + tf.zeros_like(x), + ) + # The output dtype is set to "int32" if the input dtype is "bool" + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype) def tri(N, M=None, k=0, dtype=None): - dtype = dtype or config.floatx() - return tfnp.tri(N, M=M, k=k, dtype=dtype) + M = M if M is not None else N + dtype = standardize_dtype(dtype or config.floatx()) + if k < 0: + lower = -k - 1 + if lower > N: + r = tf.zeros([N, M], dtype=dtype) + else: + o = tf.ones([N, M], dtype="bool") + r = tf.cast( + tf.logical_not(tf.linalg.band_part(o, lower, -1)), dtype=dtype + ) + else: + o = tf.ones([N, M], dtype=dtype) + if k > M: + r = o + else: + r = tf.linalg.band_part(o, -1, k) + return r def tril(x, k=0): @@ -1904,6 +2166,25 @@ def vstack(xs): return tf.concat(xs, axis=0) +def _vmap_fn(fn, in_axes=0): + if in_axes != 0: + raise ValueError( + "Not supported with `vectorize()` with the TensorFlow backend." 
+ ) + + @functools.wraps(fn) + def wrapped(x): + return tf.vectorized_map(fn, x) + + return wrapped + + +def vectorize(pyfunc, *, excluded=None, signature=None): + return vectorize_impl( + pyfunc, _vmap_fn, excluded=excluded, signature=signature + ) + + def where(condition, x1, x2): condition = tf.cast(condition, "bool") if x1 is not None and x2 is not None: @@ -2069,7 +2350,29 @@ def sum(x, axis=None, keepdims=False): def eye(N, M=None, k=0, dtype=None): dtype = dtype or config.floatx() - return tfnp.eye(N, M=M, k=k, dtype=dtype) + if not M: + M = N + # Making sure N, M and k are `int` + N, M, k = int(N), int(M), int(k) + if k >= M or -k >= N: + # tf.linalg.diag will raise an error in this case + return zeros([N, M], dtype=dtype) + if k == 0: + return tf.eye(N, M, dtype=dtype) + # We need the precise length, otherwise tf.linalg.diag will raise an error + diag_len = builtins.min(N, M) + if k > 0: + if N >= M: + diag_len -= k + elif N + k > M: + diag_len = M - k + elif k <= 0: + if M >= N: + diag_len += k + elif M - k > N: + diag_len = N + k + diagonal_ = tf.ones([diag_len], dtype=dtype) + return tf.linalg.diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k) def floor_divide(x1, x2): @@ -2132,3 +2435,33 @@ def correlate(x1, x2, mode="valid"): x2 = tf.reshape(x2, (x2_len, 1, 1)) return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper())) + + +def select(condlist, choicelist, default=0): + return tf.experimental.numpy.select(condlist, choicelist, default=default) + + +def slogdet(x): + x = convert_to_tensor(x) + return tuple(tf.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + x = convert_to_tensor(x, tf.int32) + + x = swapaxes(x, axis, -1) + bottom_ind = tf.math.top_k(-x, kth + 1).indices + + n = tf.shape(x)[-1] + + mask = tf.reduce_sum(tf.one_hot(bottom_ind, n, dtype=tf.int32), axis=0) + + indices = tf.where(mask) + updates = tf.squeeze(tf.zeros(tf.shape(indices)[0], dtype=tf.int32)) + + final_mask = tf.tensor_scatter_nd_update(x, indices, updates) + + top_ind = tf.math.top_k(final_mask, tf.shape(x)[-1] - kth - 1).indices + + out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1) + return swapaxes(out, -1, axis) diff --git a/keras/src/backend/tensorflow/random.py b/keras/src/backend/tensorflow/random.py index eeb38a6aa52..0212610085d 100644 --- a/keras/src/backend/tensorflow/random.py +++ b/keras/src/backend/tensorflow/random.py @@ -1,5 +1,4 @@ import tensorflow as tf -from tensorflow.experimental import numpy as tfnp from keras.src.backend.common import standardize_dtype from keras.src.backend.config import floatx @@ -87,12 +86,14 @@ def dropout(inputs, rate, noise_shape=None, seed=None): def shuffle(x, axis=0, seed=None): + from keras.src.backend.tensorflow.numpy import swapaxes + seed = tf_draw_seed(seed) if axis == 0: return tf.random.experimental.stateless_shuffle(x, seed=seed) - x = tfnp.swapaxes(x, axis1=0, axis2=axis) + x = swapaxes(x, axis1=0, axis2=axis) x = tf.random.experimental.stateless_shuffle(x, seed=seed) - x = tfnp.swapaxes(x, axis1=0, axis2=axis) + x = swapaxes(x, axis1=0, axis2=axis) return x diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 257afeeec69..68453255b1f 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -1,5 +1,4 @@ import contextlib -import os import ml_dtypes import numpy as np @@ -19,10 +18,7 @@ # Some operators such as 'aten::_foreach_mul_.Scalar' # are not currently implemented for the MPS device. # check https://github.com/pytorch/pytorch/issues/77764. 
-if ( - torch.backends.mps.is_available() - and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "1" -): +if torch.backends.mps.is_available(): DEFAULT_DEVICE = "mps" elif torch.cuda.is_available(): DEFAULT_DEVICE = "cuda" @@ -238,7 +234,8 @@ def is_tensor(x): def shape(x): - return x.shape + # Convert from `torch.Size` to plain tuple. + return tuple(x.shape) def cast(x, dtype): diff --git a/keras/src/backend/torch/linalg.py b/keras/src/backend/torch/linalg.py index 81041782a1b..9a15f24786e 100644 --- a/keras/src/backend/torch/linalg.py +++ b/keras/src/backend/torch/linalg.py @@ -8,7 +8,7 @@ def cholesky(x): - return torch.cholesky(x) + return torch.linalg.cholesky(x) def det(x): diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index a5cbaab3ea4..97dd04c1b5e 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -1,13 +1,11 @@ import torch import torch.nn.functional as tnn +from keras.src import backend from keras.src import tree -from keras.src.backend import standardize_data_format -from keras.src.backend import standardize_dtype from keras.src.backend.common.backend_utils import ( compute_conv_transpose_padding_args_for_torch, ) -from keras.src.backend.config import epsilon from keras.src.backend.torch.core import cast from keras.src.backend.torch.core import convert_to_tensor from keras.src.backend.torch.core import get_device @@ -92,9 +90,12 @@ def gelu(x, approximate=True): def softmax(x, axis=-1): x = convert_to_tensor(x) - dtype = standardize_dtype(x.dtype) + dtype = backend.standardize_dtype(x.dtype) # TODO: tnn.softmax doesn't support float16 using cpu - if get_device() == "cpu" and standardize_dtype(x.dtype) == "float16": + if ( + get_device() == "cpu" + and backend.standardize_dtype(x.dtype) == "float16" + ): x = cast(x, "float32") if axis is None: # Unlike numpy, PyTorch will handle axis=None as axis=-1. @@ -109,9 +110,12 @@ def softmax(x, axis=-1): def log_softmax(x, axis=-1): x = convert_to_tensor(x) - dtype = standardize_dtype(x.dtype) + dtype = backend.standardize_dtype(x.dtype) # TODO: tnn.log_softmax doesn't support float16 using cpu - if get_device() == "cpu" and standardize_dtype(x.dtype) == "float16": + if ( + get_device() == "cpu" + and backend.standardize_dtype(x.dtype) == "float16" + ): x = cast(x, "float32") if axis is None: # Unlike numpy, PyTorch will handle axis=None as axis=-1. @@ -240,7 +244,7 @@ def max_pool( else: strides = standardize_tuple(strides, num_spatial_dims, "strides") - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) if data_format == "channels_last": inputs = _transpose_spatial_inputs(inputs) @@ -301,7 +305,7 @@ def average_pool( else: strides = standardize_tuple(strides, num_spatial_dims, "strides") - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) if data_format == "channels_last": inputs = _transpose_spatial_inputs(inputs) padding_value = 0 @@ -375,7 +379,7 @@ def conv( num_spatial_dims = inputs.ndim - 2 strides = standardize_tuple(strides, num_spatial_dims, "strides") - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) if data_format == "channels_last": inputs = _transpose_spatial_inputs(inputs) # Transpose kernel from keras format to torch format. 
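
The `softmax`/`log_softmax` hunks above keep the existing float16-on-CPU workaround: the CPU kernels are not implemented for half precision, so the input is upcast to float32, the op is applied, and the result is cast back. A minimal standalone sketch of that pattern in plain PyTorch (the function name is illustrative, not from the patch):

```python
import torch
import torch.nn.functional as tnn


def softmax_with_cpu_fp16_fallback(x, axis=-1):
    # tnn.softmax has no CPU kernel for float16, so upcast, compute, downcast.
    original_dtype = x.dtype
    if x.device.type == "cpu" and x.dtype == torch.float16:
        x = x.float()
    out = tnn.softmax(x, dim=axis)
    return out.to(original_dtype)


x = torch.randn(2, 3).half()  # CPU float16 input
probs = softmax_with_cpu_fp16_fallback(x)
assert probs.dtype == torch.float16
assert torch.allclose(probs.sum(dim=-1).float(), torch.ones(2), atol=1e-2)
```
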
@@ -494,7 +498,7 @@ def conv_transpose( num_spatial_dims = inputs.ndim - 2 strides = standardize_tuple(strides, num_spatial_dims, "strides") - data_format = standardize_data_format(data_format) + data_format = backend.standardize_data_format(data_format) ( torch_padding, torch_output_padding, @@ -610,7 +614,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): log_prob = tnn.log_softmax(output, dim=axis) else: output = output / torch.sum(output, dim=axis, keepdim=True) - output = torch.clip(output, epsilon(), 1.0 - epsilon()) + output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) log_prob = torch.log(output) return -torch.sum(target * log_prob, dim=axis) @@ -638,7 +642,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): log_prob = tnn.log_softmax(output, dim=axis) else: output = output / torch.sum(output, dim=axis, keepdim=True) - output = torch.clip(output, epsilon(), 1.0 - epsilon()) + output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) log_prob = torch.log(output) target = one_hot(target, output.shape[axis], axis=axis) return -torch.sum(target * log_prob, dim=axis) @@ -661,7 +665,7 @@ def binary_crossentropy(target, output, from_logits=False): output, target, reduction="none" ) else: - output = torch.clip(output, epsilon(), 1.0 - epsilon()) + output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) return tnn.binary_cross_entropy(output, target, reduction="none") @@ -675,7 +679,7 @@ def moments(x, axes, keepdims=False, synchronized=False): # workaround, we simply perform the operations on float32 and convert back # to float16 need_cast = False - ori_dtype = standardize_dtype(x.dtype) + ori_dtype = backend.standardize_dtype(x.dtype) if ori_dtype == "float16": need_cast = True x = cast(x, "float32") @@ -752,10 +756,13 @@ def ctc_loss( target_length = convert_to_tensor(target_length) output_length = convert_to_tensor(output_length) + # Ensure that the dtype promotion behavior matchs that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = cast(output, dtype) + output = torch.transpose(output, 1, 0) logits = tnn.log_softmax(output, dim=-1) - - return tnn.ctc_loss( + loss = tnn.ctc_loss( logits, target, output_length, @@ -763,3 +770,99 @@ def ctc_loss( blank=mask_index, reduction="none", ) + return loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = torch.argmax(inputs, axis=-1) + indices = cast(indices, "int32") + scores = torch.max(inputs, axis=-1)[0] + + seqlen_mask = torch.arange(max_length, device=indices.device)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = torch.where(seqlen_mask, mask_index, indices) + scores = torch.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat = indices[:, 1:] == indices[:, :-1] + repeat = tnn.pad(repeat, (1, 0, 0, 0)) + indices = torch.where(repeat, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = torch.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + order = torch.unsqueeze( + torch.arange(max_length, device=indices.device), dim=0 + ) # [1, N] + order = 
torch.tile(order, (batch_size, 1)) # [B, N] + order = torch.where(invalid_mask, max_length, order) + order = torch.argsort(order, dim=-1) + indices = torch.take_along_dim(indices, order, dim=-1) + + scores = -torch.sum(scores, axis=1)[:, None] + indices = torch.unsqueeze(indices, dim=0) + return indices, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + raise NotImplementedError( + "Torch backend doesn't yet support the beam search strategy for CTC" + "decoding." + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + x1, x2 = ( + convert_to_tensor(x1), + convert_to_tensor(x2), + ) + max_val = convert_to_tensor(max_val, dtype=x1.dtype) + mse = torch.mean((x1 - x2) ** 2) + psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse) + return psnr diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index 257159e4644..8bf95df70c7 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -8,6 +8,7 @@ from keras.src.backend.common import dtypes from keras.src.backend.common.backend_utils import canonicalize_axis from keras.src.backend.common.backend_utils import to_tuple_or_list +from keras.src.backend.common.backend_utils import vectorize_impl from keras.src.backend.common.variables import standardize_dtype from keras.src.backend.torch.core import cast from keras.src.backend.torch.core import convert_to_tensor @@ -172,8 +173,11 @@ def max(x, axis=None, keepdims=False, initial=None): result = result.values if initial is not None: - initial = convert_to_tensor(initial) - return torch.maximum(result, torch.full(result.shape, initial)) + dtype = to_torch_dtype(result.dtype) + initial = convert_to_tensor(initial, dtype=dtype) + return torch.maximum( + result, torch.full(result.shape, initial, dtype=dtype) + ) return result @@ -198,10 +202,6 @@ def zeros_like(x, dtype=None): def absolute(x): - return abs(x) - - -def abs(x): x = convert_to_tensor(x) # bool are always non-negative if standardize_dtype(x.dtype) == "bool": @@ -209,6 +209,10 @@ def abs(x): return torch.abs(x) +def abs(x): + return absolute(x) + + def all(x, axis=None, keepdims=False): x = convert_to_tensor(x) if axis is None: @@ -744,8 +748,15 @@ def linspace( dtype = dtypes.result_type(*dtypes_to_resolve) dtype = to_torch_dtype(dtype) - if endpoint is False: - stop = stop - ((stop - start) / num) + step = convert_to_tensor(torch.nan) + if endpoint: + if num > 1: + step = (stop - start) / (num - 1) + else: + if num > 0: + step = (stop - start) / num + if num > 1: + stop = stop - ((stop - start) / num) if hasattr(start, "__len__") and hasattr(stop, "__len__"): start = convert_to_tensor(start, dtype=dtype) stop = convert_to_tensor(stop, dtype=dtype) @@ -766,7 +777,7 @@ def linspace( device=get_device(), ) if retstep is True: - return (linspace, num) + return (linspace, step) return linspace @@ -949,7 +960,8 @@ def min(x, axis=None, 
keepdims=False, initial=None): result = result.values if initial is not None: - initial = convert_to_tensor(initial) + dtype = to_torch_dtype(result.dtype) + initial = convert_to_tensor(initial, dtype=dtype) return torch.minimum(result, initial) return result @@ -983,9 +995,9 @@ def moveaxis(x, source, destination): return torch.moveaxis(x, source=source, destination=destination) -def nan_to_num(x): +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): x = convert_to_tensor(x) - return torch.nan_to_num(x) + return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) def ndim(x): @@ -1265,6 +1277,13 @@ def swapaxes(x, axis1, axis2): def take(x, indices, axis=None): x = convert_to_tensor(x) indices = convert_to_tensor(indices).long() + # Correct the indices using "fill" mode which is the same as in jax + x_dim = x.shape[axis] if axis is not None else x.shape[0] + indices = torch.where( + indices < 0, + indices + x_dim, + indices, + ) if x.ndim == 2 and axis == 0: # This case is equivalent to embedding lookup. return torch.nn.functional.embedding(indices, x) @@ -1285,6 +1304,13 @@ def take(x, indices, axis=None): def take_along_axis(x, indices, axis=None): x = convert_to_tensor(x) indices = convert_to_tensor(indices).long() + # Correct the indices using "fill" mode which is the same as in jax + x_dim = x.shape[axis] if axis is not None else x.shape[0] + indices = torch.where( + indices < 0, + indices + x_dim, + indices, + ) return torch.take_along_dim(x, indices, dim=axis) @@ -1387,6 +1413,12 @@ def vstack(xs): return torch.vstack(xs) +def vectorize(pyfunc, *, excluded=None, signature=None): + return vectorize_impl( + pyfunc, torch.vmap, excluded=excluded, signature=signature + ) + + def where(condition, x1, x2): condition = convert_to_tensor(condition, dtype=bool) if x1 is not None and x2 is not None: @@ -1567,3 +1599,34 @@ def correlate(x1, x2, mode="valid"): result = result[..., start_idx : start_idx + x1_len] return torch.squeeze(result) + + +def select(condlist, choicelist, default=0): + condlist = [convert_to_tensor(c) for c in condlist] + choicelist = [convert_to_tensor(c) for c in choicelist] + out = convert_to_tensor(default) + for c, v in reversed(list(zip(condlist, choicelist))): + out = torch.where(c, v, out) + return out + + +def slogdet(x): + x = convert_to_tensor(x) + return tuple(torch.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + x = convert_to_tensor(x, "int32") + x = torch.transpose(x, axis, -1) + bottom_ind = torch.topk(-x, kth + 1)[1] + + def set_to_zero(a, i): + a[i] = torch.zeros(1, dtype=a.dtype, device=a.device) + return a + + for _ in range(x.dim() - 1): + set_to_zero = torch.vmap(set_to_zero) + proxy = set_to_zero(torch.ones_like(x, dtype=torch.int32), bottom_ind) + top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1] + out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1) + return cast(torch.transpose(out, -1, axis), "int32") diff --git a/keras/src/callbacks/early_stopping.py b/keras/src/callbacks/early_stopping.py index e7c1fe9c0dc..5571cf606de 100644 --- a/keras/src/callbacks/early_stopping.py +++ b/keras/src/callbacks/early_stopping.py @@ -50,7 +50,6 @@ class EarlyStopping(Callback): improvement is expected and thus training will not be stopped. Defaults to `0`. 
- Example: >>> callback = keras.callbacks.EarlyStopping(monitor='loss', diff --git a/keras/src/datasets/imdb.py b/keras/src/datasets/imdb.py index a8b5537b111..f38dfaf0a15 100644 --- a/keras/src/datasets/imdb.py +++ b/keras/src/datasets/imdb.py @@ -135,8 +135,8 @@ def load_data( xs = [[w for w in x if skip_top <= w < num_words] for x in xs] idx = len(x_train) - x_train, y_train = xs[:idx], labels[:idx] - x_test, y_test = xs[idx:], labels[idx:] + x_train, y_train = np.array(xs[:idx], dtype="object"), labels[:idx] + x_test, y_test = np.array(xs[idx:], dtype="object"), labels[idx:] return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/dtype_policies/__init__.py b/keras/src/dtype_policies/__init__.py index ec84c266041..03cff8015b9 100644 --- a/keras/src/dtype_policies/__init__.py +++ b/keras/src/dtype_policies/__init__.py @@ -1,23 +1,96 @@ from keras.src import backend +from keras.src.api_export import keras_export from keras.src.dtype_policies import dtype_policy from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES +from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy +ALL_OBJECTS = { + DTypePolicy, + FloatDTypePolicy, + QuantizedDTypePolicy, + QuantizedFloat8DTypePolicy, +} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} + +@keras_export("keras.dtype_policies.serialize") +def serialize(dtype_policy): + """Serializes `DTypePolicy` instance. + + Args: + dtype_policy: A Keras `DTypePolicy` instance. + + Returns: + `DTypePolicy` configuration dictionary. + """ + from keras.src.saving import serialization_lib + + return serialization_lib.serialize_keras_object(dtype_policy) + + +@keras_export("keras.dtype_policies.deserialize") +def deserialize(config, custom_objects=None): + """Deserializes a serialized `DTypePolicy` instance. + + Args: + config: `DTypePolicy` configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras `DTypePolicy` instance. + """ + from keras.src.saving import serialization_lib + + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.dtype_policies.get") def get(identifier): + """Retrieves a Keras `DTypePolicy` instance. + + The `identifier` may be the string name of a `DTypePolicy` class. + + >>> policy = dtype_policies.get("mixed_bfloat16") + >>> type(loss) + + + You can also specify `config` of the dtype policy to this function by + passing dict containing `class_name` and `config` as an identifier. Also + note that the `class_name` must map to a `DTypePolicy` class + + >>> identifier = {"class_name": "FloatDTypePolicy", + ... "config": {"name": "float32"}} + >>> policy = dtype_policies.get(identifier) + >>> type(loss) + + + Args: + identifier: A dtype policy identifier. One of `None` or string name of a + `DTypePolicy` or `DTypePolicy` configuration dictionary or a + `DTypePolicy` instance. + + Returns: + A Keras `DTypePolicy` instance. 
+ """ from keras.src.dtype_policies.dtype_policy import ( _get_quantized_dtype_policy_by_str, ) - from keras.src.saving import serialization_lib if identifier is None: return dtype_policy.dtype_policy() if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)): return identifier if isinstance(identifier, dict): - return serialization_lib.deserialize_keras_object(identifier) + return deserialize(identifier) if isinstance(identifier, str): if identifier.startswith(QUANTIZATION_MODES): return _get_quantized_dtype_policy_by_str(identifier) diff --git a/keras/src/dtype_policies/dtype_policy.py b/keras/src/dtype_policies/dtype_policy.py index 2618e118e2b..e9bf91aaab9 100644 --- a/keras/src/dtype_policies/dtype_policy.py +++ b/keras/src/dtype_policies/dtype_policy.py @@ -1,5 +1,4 @@ from keras.src import backend -from keras.src import ops from keras.src.api_export import keras_export from keras.src.backend.common import global_state @@ -135,25 +134,27 @@ def name(self): return self._name def convert_input(self, x, autocast, dtype): + """Converts the input dtype based on `autocast` and `dtype`. + + Note that `x` can be a tensor, symbolic tensor or numpy array, and this + method will keep integer inputs untouched and only apply casting to + floats. + """ + dtype = backend.standardize_dtype(dtype) if backend.is_tensor(x): - if ( - autocast - and backend.is_float_dtype(x.dtype) - and x.dtype != dtype - ): + if self._should_cast(x, autocast, dtype): x = backend.cast(x, dtype=dtype) return x elif backend.is_keras_tensor(x): - if ( - autocast - and backend.is_float_dtype(x.dtype) - and x.dtype != dtype - ): + if self._should_cast(x, autocast, dtype): x.dtype = dtype return x elif hasattr(x, "__array__"): - return ops.convert_to_tensor(x, dtype=dtype) + x = backend.convert_to_tensor(x) + if self._should_cast(x, autocast, dtype): + x = backend.cast(x, dtype=dtype) + return x return x def get_config(self): @@ -163,6 +164,13 @@ def get_config(self): def from_config(cls, config): return cls(**config) + def _should_cast(self, x, autocast, dtype): + x_dtype = backend.standardize_dtype(x.dtype) + if autocast and backend.is_float_dtype(x_dtype) and x_dtype != dtype: + return True + else: + return False + @keras_export( ["keras.FloatDTypePolicy", "keras.dtype_policies.FloatDTypePolicy"] @@ -293,6 +301,11 @@ def _get_all_valid_policies(self): ] return valid_policies + def get_config(self): + config = super().get_config() + config.update({"amax_history_length": self.amax_history_length}) + return config + @keras_export( [ diff --git a/keras/src/dtype_policies/dtype_policy_test.py b/keras/src/dtype_policies/dtype_policy_test.py index b040663781a..b66df0779f3 100644 --- a/keras/src/dtype_policies/dtype_policy_test.py +++ b/keras/src/dtype_policies/dtype_policy_test.py @@ -1,5 +1,8 @@ from absl.testing import parameterized +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy @@ -64,7 +67,7 @@ def test_get_config_from_config(self): new_policy = DTypePolicy.from_config(config) self.assertEqual(new_policy.name, "mixed_float16") - def test_serialization(self): + def test_python_serialization(self): """Test builtin serialization methods.""" import copy import pickle @@ -91,6 +94,16 @@ def test_serialization(self): 
repr(copied_policy), '' ) + def test_serialization(self): + policy = DTypePolicy("mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + class FloatDTypePolicyTest(test_case.TestCase): def test_initialization_valid_name(self): @@ -154,6 +167,16 @@ def test_get_config_from_config(self): new_policy = FloatDTypePolicy.from_config(config) self.assertEqual(new_policy.name, "mixed_float16") + def test_serialization(self): + policy = FloatDTypePolicy("mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + class QuantizedDTypePolicyTest(test_case.TestCase, parameterized.TestCase): @parameterized.named_parameters( @@ -224,7 +247,7 @@ def test_get_config_from_config(self): '', ), ) - def test_serialization(self, name, repr_str): + def test_python_serialization(self, name, repr_str): import copy import pickle @@ -244,6 +267,16 @@ def test_serialization(self, name, repr_str): copied_policy = pickle.load(f) self.assertEqual(repr(copied_policy), repr_str) + def test_serialization(self): + policy = QuantizedDTypePolicy("int8_from_float32") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + def test_properties_for_float8(self): policy = QuantizedFloat8DTypePolicy("float8_from_mixed_bfloat16") self.assertEqual(policy.amax_history_length, 1024) @@ -256,7 +289,7 @@ def test_invalid_properties_for_float8(self): with self.assertRaisesRegex(TypeError, "must be an integer."): QuantizedFloat8DTypePolicy("float8_from_float32", 512.0) - def test_serialization_for_float8(self): + def test_python_serialization_for_float8(self): import copy import pickle @@ -288,6 +321,22 @@ def test_serialization_for_float8(self): ) self.assertEqual(copied_policy.amax_history_length, 123) + def test_serialization_for_float8(self): + policy = QuantizedFloat8DTypePolicy("float8_from_mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + self.assertEqual( + policy.amax_history_length, reloaded_policy.amax_history_length + ) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + self.assertEqual( + policy.amax_history_length, reloaded_policy.amax_history_length + ) + @parameterized.named_parameters( ("int8_from_mixed_bfloat16", "int8_from_mixed_bfloat16"), ("float8_from_mixed_bfloat16", "float8_from_mixed_bfloat16"), diff --git a/keras/src/layers/convolutional/base_conv.py b/keras/src/layers/convolutional/base_conv.py index 96a66e58e40..ffb1e4a8780 100644 --- a/keras/src/layers/convolutional/base_conv.py +++ b/keras/src/layers/convolutional/base_conv.py @@ -307,9 +307,11 @@ def save_own_variables(self, store): # Do nothing if the layer isn't yet built if not self.built: return - store["0"] = self.kernel + target_variables = [self.kernel] if self.use_bias: - store["1"] = self.bias + target_variables.append(self.bias) + for i, variable in enumerate(target_variables): + store[str(i)] = variable def 
load_own_variables(self, store): if not self.lora_enabled: @@ -317,9 +319,11 @@ def load_own_variables(self, store): # Do nothing if the layer isn't yet built if not self.built: return - self._kernel.assign(store["0"]) + target_variables = [self._kernel] if self.use_bias: - self.bias.assign(store["1"]) + target_variables.append(self.bias) + for i, variable in enumerate(target_variables): + variable.assign(store[str(i)]) if self.lora_enabled: self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index e856feb9840..8b78a285154 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -202,24 +202,26 @@ def save_own_variables(self, store): # The keys of the `store` will be saved as determined because the # default ordering will change after quantization kernel_value, kernel_scale = self._get_kernel_with_merged_lora() - store["0"] = kernel_value + target_variables = [kernel_value] if self.use_bias: - store["1"] = self.bias + target_variables.append(self.bias) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): mode = self.dtype_policy.quantization_mode if mode == "int8": - store["2"] = kernel_scale + target_variables.append(kernel_scale) elif mode == "float8": - store["2"] = self.inputs_scale - store["3"] = self.inputs_amax_history - store["4"] = self.kernel_scale - store["5"] = self.kernel_amax_history - store["6"] = self.outputs_grad_scale - store["7"] = self.outputs_grad_amax_history + target_variables.append(self.inputs_scale) + target_variables.append(self.inputs_amax_history) + target_variables.append(self.kernel_scale) + target_variables.append(self.kernel_amax_history) + target_variables.append(self.outputs_grad_scale) + target_variables.append(self.outputs_grad_amax_history) else: raise NotImplementedError( self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode) ) + for i, variable in enumerate(target_variables): + store[str(i)] = variable def load_own_variables(self, store): if not self.lora_enabled: @@ -229,24 +231,26 @@ def load_own_variables(self, store): return # The keys of the `store` will be saved as determined because the # default ordering will change after quantization - self._kernel.assign(store["0"]) + target_variables = [self._kernel] if self.use_bias: - self.bias.assign(store["1"]) + target_variables.append(self.bias) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): mode = self.dtype_policy.quantization_mode if mode == "int8": - self.kernel_scale.assign(store["2"]) + target_variables.append(self.kernel_scale) elif mode == "float8": - self.inputs_scale.assign(store["2"]) - self.inputs_amax_history.assign(store["3"]) - self.kernel_scale.assign(store["4"]) - self.kernel_amax_history.assign(store["5"]) - self.outputs_grad_scale.assign(store["6"]) - self.outputs_grad_amax_history.assign(store["7"]) + target_variables.append(self.inputs_scale) + target_variables.append(self.inputs_amax_history) + target_variables.append(self.kernel_scale) + target_variables.append(self.kernel_amax_history) + target_variables.append(self.outputs_grad_scale) + target_variables.append(self.outputs_grad_amax_history) else: raise NotImplementedError( self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode) ) + for i, variable in enumerate(target_variables): + variable.assign(store[str(i)]) if self.lora_enabled: self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) 
self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) @@ -553,8 +557,6 @@ def quantize(self, mode): self._tracker.unlock() if mode == "int8": - # Configure `self.inputs_quantizer` - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) # Quantize `self._kernel` to int8 and compute corresponding scale kernel_value, kernel_scale = quantizers.abs_max_quantize( self._kernel, axis=0 diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py index f3b9cb31a1d..a884171dec4 100644 --- a/keras/src/layers/core/einsum_dense.py +++ b/keras/src/layers/core/einsum_dense.py @@ -257,24 +257,26 @@ def save_own_variables(self, store): # The keys of the `store` will be saved as determined because the # default ordering will change after quantization kernel_value, kernel_scale = self._get_kernel_with_merged_lora() - store["0"] = kernel_value + target_variables = [kernel_value] if self.bias is not None: - store["1"] = self.bias + target_variables.append(self.bias) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): mode = self.dtype_policy.quantization_mode if mode == "int8": - store["2"] = kernel_scale + target_variables.append(kernel_scale) elif mode == "float8": - store["2"] = self.inputs_scale - store["3"] = self.inputs_amax_history - store["4"] = self.kernel_scale - store["5"] = self.kernel_amax_history - store["6"] = self.outputs_grad_scale - store["7"] = self.outputs_grad_amax_history + target_variables.append(self.inputs_scale) + target_variables.append(self.inputs_amax_history) + target_variables.append(self.kernel_scale) + target_variables.append(self.kernel_amax_history) + target_variables.append(self.outputs_grad_scale) + target_variables.append(self.outputs_grad_amax_history) else: raise NotImplementedError( self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode) ) + for i, variable in enumerate(target_variables): + store[str(i)] = variable def load_own_variables(self, store): if not self.lora_enabled: @@ -284,24 +286,26 @@ def load_own_variables(self, store): return # The keys of the `store` will be saved as determined because the # default ordering will change after quantization - self._kernel.assign(store["0"]) + target_variables = [self._kernel] if self.bias is not None: - self.bias.assign(store["1"]) + target_variables.append(self.bias) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): mode = self.dtype_policy.quantization_mode if mode == "int8": - self.kernel_scale.assign(store["2"]) + target_variables.append(self.kernel_scale) elif mode == "float8": - self.inputs_scale.assign(store["2"]) - self.inputs_amax_history.assign(store["3"]) - self.kernel_scale.assign(store["4"]) - self.kernel_amax_history.assign(store["5"]) - self.outputs_grad_scale.assign(store["6"]) - self.outputs_grad_amax_history.assign(store["7"]) + target_variables.append(self.inputs_scale) + target_variables.append(self.inputs_amax_history) + target_variables.append(self.kernel_scale) + target_variables.append(self.kernel_amax_history) + target_variables.append(self.outputs_grad_scale) + target_variables.append(self.outputs_grad_amax_history) else: raise NotImplementedError( self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode) ) + for i, variable in enumerate(target_variables): + variable.assign(store[str(i)]) if self.lora_enabled: self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) @@ -408,7 +412,9 @@ def _int8_build( self._custom_gradient_equation, 
self._kernel_reverse_transpose_axes, ) = _analyze_quantization_info(self.equation, self.input_spec.ndim) - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=self._input_reduced_axes + ) self._kernel = self.add_weight( name="kernel", shape=kernel_shape, @@ -678,10 +684,6 @@ def quantize(self, mode): self._custom_gradient_equation, self._kernel_reverse_transpose_axes, ) = _analyze_quantization_info(self.equation, self.input_spec.ndim) - # Configure `self.inputs_quantizer` - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=self._input_reduced_axes - ) # Quantize `self._kernel` to int8 and compute corresponding scale kernel_value, kernel_scale = quantizers.abs_max_quantize( self._kernel, axis=self._kernel_reduced_axes diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 098c3595b19..eaa102a4df3 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -469,6 +469,28 @@ def test_quantize_int8(self): backend.standardize_dtype(layer.kernel_scale.dtype), "float32" ) + @parameterized.named_parameters( + ("btnh,nhd->btd", "btnh,nhd->btd", (None, 8), (1, 2, 2, 4)), + ("btd,ndh->btnh", "btd,ndh->btnh", (None, 2, 8), (1, 2, 4)), + ("btd,df->btf", "btd,df->btf", (None, 4), (1, 2, 4)), + ) + @pytest.mark.skipif( + backend.backend() == "numpy", + reason=f"{backend.backend()} does not support ops.custom_gradient.", + ) + def test_quantize_int8_with_specific_equations( + self, equation, output_shape, input_shape + ): + layer = layers.EinsumDense(equation=equation, output_shape=output_shape) + layer.build(input_shape) + x = ops.random.uniform(input_shape) + y_float = layer(x) + + layer.quantize("int8") + y_quantized = layer(x) + mse = ops.mean(ops.square(y_float - y_quantized)) + self.assertLess(mse, 1e-3) # A weak correctness test + @parameterized.named_parameters( ("int8", "int8"), ("float8", "float8"), diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py index 1807a86b125..bb85c3dd13e 100644 --- a/keras/src/layers/core/embedding.py +++ b/keras/src/layers/core/embedding.py @@ -199,15 +199,17 @@ def save_own_variables(self, store): embeddings_value, embeddings_scale = ( self._get_embeddings_with_merged_lora() ) - store["0"] = embeddings_value + target_variables = [embeddings_value] if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): mode = self.dtype_policy.quantization_mode if mode == "int8": - store["1"] = embeddings_scale + target_variables.append(embeddings_scale) else: raise NotImplementedError( self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode) ) + for i, variable in enumerate(target_variables): + store[str(i)] = variable def load_own_variables(self, store): if not self.lora_enabled: @@ -217,15 +219,17 @@ def load_own_variables(self, store): return # The keys of the `store` will be saved as determined because the # default ordering will change after quantization - self._embeddings.assign(store["0"]) + target_variables = [self._embeddings] if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): mode = self.dtype_policy.quantization_mode if mode == "int8": - self.embeddings_scale.assign(store["1"]) + target_variables.append(self.embeddings_scale) else: raise NotImplementedError( self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode) ) + for i, variable in enumerate(target_variables): + variable.assign(store[str(i)]) if self.lora_enabled: 
self.lora_embeddings_a.assign( ops.zeros(self.lora_embeddings_a.shape) @@ -311,7 +315,6 @@ def _int8_build( embeddings_initializer="zeros", embeddings_scale_initializer="ones", ): - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) self._embeddings = self.add_weight( name="embeddings", shape=(self.input_dim, self.output_dim), @@ -319,9 +322,12 @@ def _int8_build( dtype="int8", trainable=False, ) + # We choose to reduce the axis of `output_dim` because, typically, + # `input_dim` is larger than `output_dim`. This reduces quantization + # error. self.embeddings_scale = self.add_weight( name="embeddings_scale", - shape=(self.output_dim,), + shape=(self.input_dim,), initializer=embeddings_scale_initializer, trainable=False, ) @@ -341,11 +347,12 @@ def _int8_call(self, inputs): # not needed if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"): inputs = ops.cast(inputs, "int32") + embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0) outputs = ops.take(self._embeddings, inputs, axis=0) # De-scale outputs - outputs = ops.cast(outputs, self.compute_dtype) outputs = ops.divide( - outputs, ops.expand_dims(self.embeddings_scale, axis=0) + ops.cast(outputs, dtype=self.compute_dtype), + ops.expand_dims(embeddings_scale, axis=-1), ) if self.lora_enabled: lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0) @@ -375,14 +382,12 @@ def quantize(self, mode): self._tracker.unlock() if mode == "int8": - # Configure `self.inputs_quantizer` - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) # Quantize `self._embeddings` to int8 and compute corresponding # scale embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, axis=0 + self._embeddings, axis=-1 ) - embeddings_scale = ops.squeeze(embeddings_scale, axis=0) + embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) self._untrack_variable(self._embeddings) del self._embeddings # Utilize a lambda expression as an initializer to prevent adding a @@ -408,15 +413,15 @@ def _get_embeddings_with_merged_lora(self): # Dequantize & quantize to merge lora weights into embeddings # Note that this is a lossy compression embeddings_value = ops.divide( - embeddings_value, embeddings_scale + embeddings_value, ops.expand_dims(embeddings_scale, axis=-1) ) embeddings_value = ops.add( embeddings_value, ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b), ) embeddings_value, embeddings_scale = ( - quantizers.abs_max_quantize(embeddings_value, axis=0) + quantizers.abs_max_quantize(embeddings_value, axis=-1) ) - embeddings_scale = ops.squeeze(embeddings_scale, axis=0) + embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) return embeddings_value, embeddings_scale return self.embeddings, None diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 9757bd11f15..a8c764d719c 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -36,6 +36,7 @@ from keras.src.layers import input_spec from keras.src.metrics.metric import Metric from keras.src.ops.operation import Operation +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import python_utils from keras.src.utils import summary_utils from keras.src.utils import traceback_utils @@ -58,7 +59,7 @@ @keras_export(["keras.Layer", "keras.layers.Layer"]) -class Layer(BackendLayer, Operation): +class Layer(BackendLayer, Operation, KerasSaveable): """This is the class from which all layers inherit. 
A layer is a callable object that takes as input one or more tensors and @@ -424,6 +425,9 @@ def build_from_config(self, config): self.build(**config["shapes_dict"]) self.built = True + def _obj_type(self): + return "Layer" + def add_variable( self, shape, diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py index ad274da84af..e0c71a0cf7f 100644 --- a/keras/src/layers/layer_test.py +++ b/keras/src/layers/layer_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest @@ -11,6 +13,7 @@ class LayerTest(testing.TestCase): + def test_compute_output_spec(self): # Test that implementing compute_output_shape # is enough to make compute_output_spec work. @@ -434,13 +437,13 @@ def test_mixed_precision(self): y = layer(x) self.assertEqual(layer.compute_dtype, "float16") self.assertEqual(layer.variable_dtype, "float16") - self.assertEqual(backend.standardize_dtype(y.dtype), "float16") + self.assertDType(y, "float16") layer = layers.Dense(2, dtype="mixed_float16") y = layer(x) self.assertEqual(layer.compute_dtype, "float16") self.assertEqual(layer.variable_dtype, "float32") - self.assertEqual(backend.standardize_dtype(y.dtype), "float16") + self.assertDType(y, "float16") self.assertEqual(layer.kernel.dtype, "float32") @pytest.mark.skipif( @@ -448,7 +451,7 @@ def test_mixed_precision(self): reason="Some torch ops not implemented for float16 on CPU.", ) def test_autocast(self): - assertEqual = self.assertEqual + assertDType = self.assertDType # A layer with a int dtype (some preprocessing layers do this). class InnerLayerOne(layers.Layer): @@ -464,7 +467,7 @@ def __init__(self): def call(self, x): # Should not autocast. - assertEqual(backend.standardize_dtype(self.v.dtype), "float32") + assertDType(self.v, "float32") return ops.cast(x, "float32") + self.v # A layer that is explicitly full precision. @@ -480,7 +483,7 @@ def __init__(self): def call(self, x): # Should not autocast. - assertEqual(backend.standardize_dtype(self.v.dtype), "float32") + assertDType(self.v, "float32") return x + self.v # A layer that is explicitly mixed precision but with autocast=False @@ -498,7 +501,7 @@ def __init__(self): def call(self, x): # Should not autocast `self.v`. - assertEqual(backend.standardize_dtype(self.v.dtype), "float32") + assertDType(self.v, "float32") return ops.add(x, self.v) # A layer that is explicitly mixed precision with inner layers. @@ -517,7 +520,7 @@ def __init__(self): def call(self, x): # Should autocast. - assertEqual(backend.standardize_dtype(self.v.dtype), "float16") + assertDType(self.v, "float16") return self.inner_three( self.inner_two(self.inner_one(x + self.v)) ) @@ -526,6 +529,21 @@ def call(self, x): y = layer(np.array(0.0)) self.assertEqual(y, 4.0) + def test_autocast_with_np_array(self): + assertDType = self.assertDType + + class CustomLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + # Here are the assertions. 
+ assertDType(x[0], "float32") # Cast to compute_dtype + assertDType(x[1], "int32") # Untouched + + x = [np.zeros(1, dtype="float64"), np.zeros(1, dtype="int32")] + CustomLayer()(x) + @pytest.mark.skipif( backend.backend() == "numpy", reason="Numpy backend does not support masking.", @@ -957,3 +975,8 @@ def test_dtype_policy_setter(self): self.assertEqual(layer.dtype_policy.name, "mixed_float16") self.assertEqual(layer.dtype_policy.compute_dtype, "float16") self.assertEqual(layer.dtype_policy.variable_dtype, "float32") + + def test_pickle_layer(self): + layer = layers.Dense(2) + reloaded = pickle.loads(pickle.dumps(layer)) + self.assertEqual(layer.get_config(), reloaded.get_config()) diff --git a/keras/src/layers/normalization/spectral_normalization_test.py b/keras/src/layers/normalization/spectral_normalization_test.py index b3cc47d8d9f..f9a34b4626d 100644 --- a/keras/src/layers/normalization/spectral_normalization_test.py +++ b/keras/src/layers/normalization/spectral_normalization_test.py @@ -25,7 +25,7 @@ def test_basic_spectralnorm(self): self.run_layer_test( layers.SpectralNormalization, init_kwargs={"layer": layers.Embedding(10, 4)}, - input_data=np.random.randint(10, size=(10,)), + input_data=np.random.randint(10, size=(10,)).astype("float32"), expected_output_shape=(10, 4), expected_num_trainable_weights=1, expected_num_non_trainable_weights=1, diff --git a/keras/src/layers/rnn/bidirectional.py b/keras/src/layers/rnn/bidirectional.py index 9d9d29d2460..a89c30f9a4e 100644 --- a/keras/src/layers/rnn/bidirectional.py +++ b/keras/src/layers/rnn/bidirectional.py @@ -3,13 +3,12 @@ from keras.src import ops from keras.src import utils from keras.src.api_export import keras_export -from keras.src.layers.core.wrapper import Wrapper from keras.src.layers.layer import Layer from keras.src.saving import serialization_lib @keras_export("keras.layers.Bidirectional") -class Bidirectional(Wrapper): +class Bidirectional(Layer): """Bidirectional wrapper for RNNs. Args: @@ -105,7 +104,7 @@ def __init__( "Merge mode should be one of " '{"sum", "mul", "ave", "concat", None}' ) - super().__init__(layer, **kwargs) + super().__init__(**kwargs) # Recreate the forward layer from the original layer config, so that it # will not carry over any state from the layer. @@ -272,8 +271,10 @@ def states(self): return None def build(self, sequences_shape, initial_state_shape=None): - self.forward_layer.build(sequences_shape) - self.backward_layer.build(sequences_shape) + if not self.forward_layer.built: + self.forward_layer.build(sequences_shape) + if not self.backward_layer.built: + self.backward_layer.build(sequences_shape) self.built = True def compute_mask(self, _, mask): diff --git a/keras/src/layers/rnn/gru.py b/keras/src/layers/rnn/gru.py index f489bd6638f..8a516d0b440 100644 --- a/keras/src/layers/rnn/gru.py +++ b/keras/src/layers/rnn/gru.py @@ -538,14 +538,23 @@ def inner_loop(self, sequences, initial_state, mask, training=False): if tree.is_nested(mask): mask = mask[0] if self.use_cudnn in ("auto", True): - if not self.dropout and not self.recurrent_dropout: + if not self.recurrent_dropout: try: + if self.dropout: + dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :]) + dp_mask = ops.expand_dims(dp_mask, axis=1) + dp_mask = ops.broadcast_to( + dp_mask, ops.shape(sequences) + ) + dp_sequences = sequences * dp_mask + else: + dp_sequences = sequences # Backends are allowed to specify (optionally) optimized # implementation of the inner GRU loop. 
In the case of # TF for instance, it will leverage cuDNN when feasible, and # it will raise NotImplementedError otherwise. out = backend.gru( - sequences, + dp_sequences, initial_state, mask, kernel=self.cell.kernel, diff --git a/keras/src/layers/rnn/lstm.py b/keras/src/layers/rnn/lstm.py index 33055fd197e..f4903655bb8 100644 --- a/keras/src/layers/rnn/lstm.py +++ b/keras/src/layers/rnn/lstm.py @@ -518,14 +518,24 @@ def inner_loop(self, sequences, initial_state, mask, training=False): mask = mask[0] if self.use_cudnn in ("auto", True): - if not self.dropout and not self.recurrent_dropout: + if not self.recurrent_dropout: try: + if self.dropout: + dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :]) + dp_mask = ops.expand_dims(dp_mask, axis=1) + dp_mask = ops.broadcast_to( + dp_mask, ops.shape(sequences) + ) + dp_sequences = sequences * dp_mask + else: + dp_sequences = sequences + # Backends are allowed to specify (optionally) optimized # implementation of the inner LSTM loop. In the case of # TF for instance, it will leverage cuDNN when feasible, and # it will raise NotImplementedError otherwise. out = backend.lstm( - sequences, + dp_sequences, initial_state[0], initial_state[1], mask, diff --git a/keras/src/losses/__init__.py b/keras/src/losses/__init__.py index 9652ceb057b..3f4ef8d0f69 100644 --- a/keras/src/losses/__init__.py +++ b/keras/src/losses/__init__.py @@ -135,8 +135,8 @@ def deserialize(name, custom_objects=None): Args: name: Loss configuration. custom_objects: Optional dictionary mapping names (strings) to custom - objects (classes and functions) to be considered during - deserialization. + objects (classes and functions) to be considered during + deserialization. Returns: A Keras `Loss` instance or a loss function. diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py index ba4c78ebc5a..a35ecf7d48f 100644 --- a/keras/src/losses/loss.py +++ b/keras/src/losses/loss.py @@ -2,11 +2,12 @@ from keras.src import ops from keras.src import tree from keras.src.api_export import keras_export +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils.naming import auto_name @keras_export(["keras.Loss", "keras.losses.Loss"]) -class Loss: +class Loss(KerasSaveable): """Loss base class. 
To be implemented by subclasses: @@ -69,6 +70,9 @@ def get_config(self): def from_config(cls, config): return cls(**config) + def _obj_type(self): + return "Loss" + def standardize_reduction(reduction): allowed = {"sum_over_batch_size", "sum", None, "none"} diff --git a/keras/src/losses/loss_test.py b/keras/src/losses/loss_test.py index 1d5725ffd3a..e438f7d882b 100644 --- a/keras/src/losses/loss_test.py +++ b/keras/src/losses/loss_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest @@ -226,6 +228,11 @@ def test_mixed_dtypes(self): loss, ) + def test_pickle(self): + loss = losses_module.get("mse") + loss = pickle.loads(pickle.dumps(loss)) + self.assertEqual(loss, losses_module.mean_squared_error) + def test_get_method(self): loss = losses_module.get("mse") self.assertEqual(loss, losses_module.mean_squared_error) diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index f3f997616a0..b91d15e87e7 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -1893,7 +1893,7 @@ class CTC(LossFunctionWrapper): def __init__( self, reduction="sum_over_batch_size", - name="sparse_categorical_crossentropy", + name="ctc", ): super().__init__( ctc, @@ -1933,14 +1933,16 @@ def ctc(y_true, y_pred): f"Received: y_pred.shape={ops.shape(y_pred)}" ) - batch_length = ops.cast(ops.shape(y_true)[0], dtype="int32") - input_length = ops.cast(ops.shape(y_pred)[1], dtype="int32") - label_length = ops.cast(ops.shape(y_true)[1], dtype="int32") + mask_index = 0 + batch_length = ops.shape(y_pred)[0] + input_length = ops.shape(y_pred)[1] input_length = input_length * ops.ones((batch_length,), dtype="int32") - label_length = label_length * ops.ones((batch_length,), dtype="int32") + label_length = ops.cast( + ops.sum(y_true != mask_index, axis=-1), dtype="int32" + ) return ops.ctc_loss( - y_true, y_pred, label_length, input_length, mask_index=0 + y_true, y_pred, label_length, input_length, mask_index=mask_index ) diff --git a/keras/src/losses/losses_test.py b/keras/src/losses/losses_test.py index b97a8a253c3..07b74fa3739 100644 --- a/keras/src/losses/losses_test.py +++ b/keras/src/losses/losses_test.py @@ -1387,7 +1387,7 @@ def test_correctness(self): logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100 y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]])) output = losses.CTC()(y_true, logits) - self.assertAllClose(output, 4.389582) + self.assertAllClose(output, 2.448645) class DiceTest(testing.TestCase): diff --git a/keras/src/metrics/metric.py b/keras/src/metrics/metric.py index 27f39f94d79..91e50dab896 100644 --- a/keras/src/metrics/metric.py +++ b/keras/src/metrics/metric.py @@ -2,12 +2,13 @@ from keras.src import initializers from keras.src import ops from keras.src.api_export import keras_export +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils.naming import auto_name from keras.src.utils.tracking import Tracker @keras_export(["keras.Metric", "keras.metrics.Metric"]) -class Metric: +class Metric(KerasSaveable): """Encapsulates metric logic and state. 
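The `ctc` loss change above (which is also why the expected value in `losses_test.py` changes) derives each label length from the non-mask entries of the padded labels rather than from the padded width; a small NumPy illustration with made-up labels:

```python
import numpy as np

mask_index = 0
y_true = np.array([[1, 2, 1, 0], [1, 2, 0, 0]])  # 0 marks padding.
label_length = np.sum(y_true != mask_index, axis=-1).astype("int32")
# -> array([3, 2], dtype=int32); the old code would have used the padded width 4 for both rows.
```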
Args: @@ -179,6 +180,9 @@ def stateless_reset_state(self): def dtype(self): return self._dtype + def _obj_type(self): + return "Metric" + def add_variable( self, shape, initializer, dtype=None, aggregation="sum", name=None ): diff --git a/keras/src/metrics/metric_test.py b/keras/src/metrics/metric_test.py index 346e6140c89..90bd1a9b9a0 100644 --- a/keras/src/metrics/metric_test.py +++ b/keras/src/metrics/metric_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np from keras.src import backend @@ -165,6 +167,11 @@ def test_serialization(self): custom_objects={"ExampleMetric": ExampleMetric}, ) + def test_pickle(self): + metric = metrics_module.get("mse") + reloaded = pickle.loads(pickle.dumps(metric)) + self.assertIsInstance(reloaded, metrics_module.MeanSquaredError) + def test_get_method(self): metric = metrics_module.get("mse") self.assertIsInstance(metric, metrics_module.MeanSquaredError) diff --git a/keras/src/models/cloning.py b/keras/src/models/cloning.py index 3875b3522a6..7e28e4438c3 100644 --- a/keras/src/models/cloning.py +++ b/keras/src/models/cloning.py @@ -11,7 +11,14 @@ @keras_export("keras.models.clone_model") -def clone_model(model, input_tensors=None, clone_function=None): +def clone_model( + model, + input_tensors=None, + clone_function=None, + call_function=None, + recursive=False, + **kwargs, +): """Clone a Functional or Sequential `Model` instance. Model cloning is similar to calling a model on new inputs, @@ -29,24 +36,44 @@ def clone_model(model, input_tensors=None, clone_function=None): input_tensors: optional list of input tensors or InputLayer objects to build the model upon. If not provided, new `Input` objects will be created. - clone_function: Callable to be used to clone each layer in the target + clone_function: Callable with signature `fn(layer)` + to be used to clone each layer in the target model (except `Input` instances). It takes as argument the layer instance to be cloned, and returns the corresponding layer instance to be used in the model copy. If unspecified, this callable - becomes the following serialization/deserialization function: + defaults to the following serialization/deserialization function: `lambda layer: layer.__class__.from_config(layer.get_config())`. By passing a custom callable, you can customize your copy of the model, e.g. by wrapping certain layers of interest (you might want to replace all `LSTM` instances with equivalent `Bidirectional(LSTM(...))` instances, for example). Defaults to `None`. + call_function: Callable with signature + `fn(layer, *args, **kwargs)` to be used to call each + cloned layer and a set of inputs. It takes the layer instance, + the call arguments and keyword arguments, and returns the + call outputs. If unspecified, this callable defaults to + the regular `__call__()` method: + `def fn(layer, *args, **kwargs): return layer(*args, **kwargs)`. + By passing a custom callable, you can insert new layers before or + after a given layer. Note: this argument can only be used with + Functional models. + recursive: Boolean. Whether to recursively clone any Sequential + or Functional models encountered in the original + Sequential/Functional model. If `False`, + then inner models are cloned by calling `clone_function()`. + If `True`, then inner models are cloned by calling `clone_model()` + with the same `clone_function`, `call_function`, and `recursive` + arguments. Note that in this case, `call_function` + will not be propagated to any Sequential model + (since it is not applicable to Sequential models). 
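A possible usage sketch for the `recursive` argument described above, using a small nested Functional model; the shapes, layer sizes, and the `flag` attribute are illustrative only:

```python
import keras

inner = keras.Sequential([keras.layers.Dense(4), keras.layers.Dense(4)])
inputs = keras.Input(shape=(3,))
outputs = keras.layers.Dense(2)(inner(inputs))
model = keras.Model(inputs, outputs)

def tag_clone(layer):
    clone = layer.__class__.from_config(layer.get_config())
    clone.flag = True  # Mark clones so they can be told apart from the originals.
    return clone

# With recursive=True the inner Sequential's layers are cloned individually,
# instead of the inner model being handed to `clone_function` as one unit.
new_model = keras.models.clone_model(model, clone_function=tag_clone, recursive=True)
```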
Returns: An instance of `Model` reproducing the behavior of the original model, on top of new inputs tensors, using newly instantiated weights. The cloned model may behave differently from the original model if a custom `clone_function` - modifies the layer. + or `call_function` modifies a layer or layer call. Example: @@ -74,6 +101,23 @@ def clone_function(layer): new_model = clone_model(model) ``` + Using a `call_function` to add a `Dropout` layer after each `Dense` layer + (without recreating new layers): + + ```python + def call_function(layer, *args, **kwargs): + out = layer(*args, **kwargs) + if isinstance(layer, keras.layers.Dense): + out = keras.layers.Dropout(0.5)(out) + return out + + new_model = clone_model( + model, + clone_function=lambda x: x, # Reuse the same layers. + call_function=call_function, + ) + ``` + Note that subclassed models cannot be cloned by default, since their internal layer structure is not known. To achieve equivalent functionality @@ -88,11 +132,44 @@ def clone_function(layer): In the case of a subclassed model, you cannot use a custom `clone_function`. """ + cache = kwargs.pop("cache", None) + if kwargs: + raise ValueError( + f"Unexpected keyword argument(s): {tuple(kwargs.keys())}" + ) + if isinstance(model, Sequential): + # Wrap clone_function to handle recursiveness and layer sharing. + clone_function = _wrap_clone_function( + clone_function, + call_function=call_function, + recursive=recursive, + cache=cache, + ) + if call_function is not None: + raise ValueError( + "`call_function` argument is not supported with Sequential " + "models. In a Sequential model, layers aren't called " + "at model-construction time (they're merely listed). " + "Use `call_function` with Functional models only. " + "Received model of " + f"type '{model.__class__.__name__}', with " + f"call_function={call_function}" + ) return _clone_sequential_model( - model, input_tensors=input_tensors, clone_function=clone_function + model, + clone_function=clone_function, + input_tensors=input_tensors, ) if isinstance(model, Functional): + # Wrap clone_function to handle recursiveness and layer sharing. + clone_function = _wrap_clone_function( + clone_function, + call_function=call_function, + recursive=recursive, + cache=cache, + ) + + # If the get_config() method is the same as a regular Functional # model, we're safe to use _clone_functional_model (which relies # on a Functional constructor). In the case where the get_config @@ -104,27 +181,78 @@ def clone_function(layer): ): return _clone_functional_model( model, - input_tensors=input_tensors, clone_function=clone_function, + clone_function=clone_function, + call_function=call_function, + input_tensors=input_tensors, ) # Case of a custom model class if clone_function or input_tensors: raise ValueError( - "Arguments clone_function and input_tensors " + "Arguments `clone_function` and `input_tensors` " "are only supported for Sequential models " "or Functional models. Received model of " f"type '{model.__class__.__name__}', with " f"clone_function={clone_function} and " f"input_tensors={input_tensors}" ) + if call_function is not None: + raise ValueError( + "Argument `call_function` is only supported " + "for Functional models.
Received model of " + f"type '{model.__class__.__name__}', with " + f"call_function={clone_function}" + ) config = serialization_lib.serialize_keras_object(model) return serialization_lib.deserialize_keras_object( config, custom_objects={model.__class__.__name__: model.__class__} ) -def _clone_sequential_model(model, input_tensors=None, clone_function=None): +def _wrap_clone_function( + clone_function, call_function=None, recursive=False, cache=None +): + """Wrapper to handle recursiveness and layer sharing.""" + if clone_function is None: + + def _clone_layer(layer): + return layer.__class__.from_config(layer.get_config()) + + clone_function = _clone_layer + + if cache is None: + cache = {} + + def wrapped_clone_function(layer): + if id(layer) in cache: + return cache[id(layer)] + if recursive: + if isinstance(layer, Sequential): + # Note: Sequential doens't support call_function. + clone = clone_model( + layer, + clone_function=clone_function, + cache=cache, + ) + cache[id(layer)] = clone + return clone + elif isinstance(layer, Functional): + clone = clone_model( + layer, + clone_function=clone_function, + call_function=call_function, + cache=cache, + ) + cache[id(layer)] = clone + return clone + clone = clone_function(layer) + cache[id(layer)] = clone + return clone + + return wrapped_clone_function + + +def _clone_sequential_model(model, clone_function, input_tensors=None): """Clone a `Sequential` model instance. Model cloning is similar to calling a model on new inputs, @@ -144,12 +272,6 @@ def _clone_sequential_model(model, input_tensors=None, clone_function=None): of the original model, on top of new inputs tensors, using newly instantiated weights. """ - if clone_function is None: - - def _clone_layer(layer): - return layer.__class__.from_config(layer.get_config()) - - clone_function = _clone_layer if not isinstance(model, Sequential): raise ValueError( @@ -202,7 +324,9 @@ def _clone_layer(layer): return Sequential(new_layers, name=model.name, trainable=model.trainable) -def _clone_functional_model(model, input_tensors=None, clone_function=None): +def _clone_functional_model( + model, clone_function, input_tensors=None, call_function=None +): """Clone a `Functional` model instance. Model cloning is similar to calling a model on new inputs, @@ -224,17 +348,6 @@ def _clone_functional_model(model, input_tensors=None, clone_function=None): of the original model, on top of new inputs tensors, using newly instantiated weights. 
""" - if clone_function is None: - seen = {} - - def _clone_layer(layer): - if layer in seen: - return seen[layer] - new_layer = layer.__class__.from_config(layer.get_config()) - seen[layer] = new_layer - return new_layer - - clone_function = _clone_layer if not callable(clone_function): raise ValueError( @@ -276,7 +389,9 @@ def operation_fn(layer): return new_layer output_tensors = model._run_through_graph( - input_tensors, operation_fn=operation_fn + input_tensors, + operation_fn=operation_fn, + call_fn=call_function, ) if functional_like_constructor(model.__class__): diff --git a/keras/src/models/cloning_test.py b/keras/src/models/cloning_test.py index d9a46ac29cf..b77122b28a2 100644 --- a/keras/src/models/cloning_test.py +++ b/keras/src/models/cloning_test.py @@ -4,6 +4,7 @@ from keras.src import layers from keras.src import models +from keras.src import ops from keras.src import testing from keras.src import tree from keras.src.models.cloning import clone_model @@ -21,6 +22,24 @@ def get_mlp_functional_model(shared_layers=False): return model +def get_nested_functional_model(): + inputs = layers.Input(shape=(4,)) + x = layers.Dense(3)(inputs) + mlp = get_mlp_functional_model() + x = mlp(x) + outputs = layers.Dense(2)(x) + model = models.Model(inputs, outputs) + return model + + +def get_nested_sequential_model(): + model = models.Sequential() + model.add(layers.Dense(2)) + model.add(get_sequential_model(explicit_input=False)) + model.add(layers.Dense(2)) + return model + + def get_cnn_functional_model(shared_layers=False): inputs = layers.Input(shape=(7, 3)) x = layers.Conv1D(2, 2, padding="same")(inputs) @@ -57,6 +76,19 @@ def call(self, x): @pytest.mark.requires_trainable_backend class CloneModelTest(testing.TestCase, parameterized.TestCase): + + def assert_models_equal(self, model1, model2, ref_input): + result1 = model1(ref_input) + result2 = model2(ref_input) + for r1, r2 in zip(tree.flatten(result1), tree.flatten(result2)): + self.assertAllClose( + ops.convert_to_numpy(r1), ops.convert_to_numpy(r2) + ) + + def assert_weights_equal(self, model1, model2): + for a, b in zip(model1.weights, model2.weights): + self.assertAllClose(a.numpy(), b.numpy()) + @parameterized.named_parameters( ("mlp_functional", get_mlp_functional_model), ("cnn_functional", get_cnn_functional_model, True), @@ -71,11 +103,10 @@ def test_cloning_correctness(self, model_fn, is_conv=False): ref_input = np.random.random((2, 7, 3) if is_conv else (2, 3)) model = model_fn() new_model = clone_model(model) - ref_output = model(ref_input) + model(ref_input) # Maybe needed to build the model new_model(ref_input) # Maybe needed to build the model new_model.set_weights(model.get_weights()) - output = new_model(ref_input) - self.assertAllClose(ref_output, output) + self.assert_models_equal(model, new_model, ref_input) @parameterized.named_parameters( ("mlp_functional", get_mlp_functional_model), @@ -121,3 +152,68 @@ def test_structured_io_cloning(self): "`input_tensors` must have the same structure as model.input", ): model = clone_model(model0, input_tensors=(x, y)) + + def test_call_fn(self): + model = get_mlp_functional_model(shared_layers=False) + + def call_function(layer, *args, **kwargs): + out = layer(*args, **kwargs) + if isinstance(layer, layers.Dense): + out = layers.Dropout(0.5)(out) + return out + + new_model = clone_model( + model, + clone_function=lambda x: x, # Reuse the same layers. 
+ call_function=call_function, + ) + self.assertLen(model.layers, 3) + self.assertLen(new_model.layers, 5) + self.assertIsInstance(new_model.layers[2], layers.Dropout) + self.assertIsInstance(new_model.layers[4], layers.Dropout) + ref_input = np.random.random((2, 3)) + self.assert_models_equal(model, new_model, ref_input) + + def test_recursive(self): + model = get_nested_functional_model() + + def call_function(layer, *args, **kwargs): + out = layer(*args, **kwargs) + if isinstance(layer, layers.Dense): + out = layers.Dropout(0.5)(out) + return out + + new_model = clone_model( + model, + clone_function=lambda x: x, # Reuse the same layers. + call_function=call_function, + recursive=True, + ) + self.assertLen(model._flatten_layers(), 8) + self.assertLen(new_model._flatten_layers(), 12) + self.assertIsInstance(new_model.layers[3].layers[2], layers.Dropout) + self.assertIsInstance(new_model.layers[3].layers[4], layers.Dropout) + ref_input = np.random.random((2, 4)) + self.assert_models_equal(model, new_model, ref_input) + + # Sequential. + def clone_function(layer): + layer = layer.__class__.from_config(layer.get_config()) + layer.flag = True + return layer + + model = get_nested_sequential_model() + new_model = clone_model( + model, + clone_function=clone_function, + recursive=True, + ) + ref_input = np.random.random((2, 3)) + model(ref_input) # Maybe needed to build the model + new_model(ref_input) # Maybe needed to build the model + new_model.set_weights(model.get_weights()) + self.assert_models_equal(model, new_model, ref_input) + for l1, l2 in zip(model._flatten_layers(), new_model._flatten_layers()): + if isinstance(l2, layers.Dense): + self.assertFalse(hasattr(l1, "flag")) + self.assertTrue(hasattr(l2, "flag")) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 20e2266e00e..8151e5eb089 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -178,6 +178,9 @@ def _lock_state(self): # functional DAG. 
pass + def _obj_type(self): + return "Functional" + @property def layers(self): layers = [] diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 62871c8b103..af6d60e4933 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -7,7 +7,7 @@ from keras.src import utils from keras.src.api_export import keras_export from keras.src.layers.layer import Layer -from keras.src.models.variable_mapping import map_trackable_variables +from keras.src.models.variable_mapping import map_saveable_variables from keras.src.saving import saving_api from keras.src.trainers import trainer as base_trainer from keras.src.utils import summary_utils @@ -546,7 +546,7 @@ def from_config(cls, config, custom_objects=None): def _get_variable_map(self): store = {} - map_trackable_variables(self, store=store, visited_trackables=set()) + map_saveable_variables(self, store=store, visited_saveables=set()) return store diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 871fc4bff19..7fa91c5b95d 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest from absl.testing import parameterized @@ -116,6 +118,29 @@ def call(self, x): ) self.assertIsInstance(new_model, Functional) + @parameterized.named_parameters( + ("single_output_1", _get_model_single_output), + ("single_output_2", _get_model_single_output), + ("single_output_3", _get_model_single_output), + ("single_output_4", _get_model_single_output), + ("single_list_output_1", _get_model_single_output_list), + ("single_list_output_2", _get_model_single_output_list), + ("single_list_output_3", _get_model_single_output_list), + ("single_list_output_4", _get_model_single_output_list), + ) + def test_functional_pickling(self, model_fn): + model = model_fn() + self.assertIsInstance(model, Functional) + model.compile() + x = np.random.rand(8, 3) + + reloaded_pickle = pickle.loads(pickle.dumps(model)) + + pred_reloaded = reloaded_pickle.predict(x) + pred = model.predict(x) + + self.assertAllClose(np.array(pred_reloaded), np.array(pred)) + @parameterized.named_parameters( ("single_output_1", _get_model_single_output, None), ("single_output_2", _get_model_single_output, "list"), @@ -138,7 +163,7 @@ def test_functional_single_output(self, model_fn, loss_type): loss = [loss] elif loss_type == "dict": loss = {"output_a": loss} - elif loss_type == "dict_lsit": + elif loss_type == "dict_list": loss = {"output_a": [loss]} model.compile( optimizer="sgd", diff --git a/keras/src/models/sequential.py b/keras/src/models/sequential.py index ecdaa4058e8..9d7b8149e96 100644 --- a/keras/src/models/sequential.py +++ b/keras/src/models/sequential.py @@ -142,6 +142,9 @@ def _lock_state(self): # Unlike other layers, Sequential is mutable after build. 
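A minimal sketch of the round trip exercised by the pickling tests above; the model, shapes, and data are placeholders rather than values from the patch:

```python
import pickle

import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(4, activation="relu"), keras.layers.Dense(1)])
model.build((None, 3))
restored = pickle.loads(pickle.dumps(model))

x = np.random.rand(2, 3).astype("float32")
# Weights survive the round trip, so both models should produce the same output.
np.testing.assert_allclose(model.predict(x), restored.predict(x), rtol=1e-5)
```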
pass + def _obj_type(self): + return "Sequential" + def build(self, input_shape=None): if not isinstance(input_shape, (tuple, list)): # Do not attempt to build if the model does not have a single diff --git a/keras/src/models/sequential_test.py b/keras/src/models/sequential_test.py index 12c0703ab45..6ce5ff5f2c3 100644 --- a/keras/src/models/sequential_test.py +++ b/keras/src/models/sequential_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest @@ -246,6 +248,13 @@ def test_functional_properties(self): self.assertEqual(model.input_shape, (None, 2)) self.assertEqual(model.output_shape, (None, 4)) + def test_pickleable(self): + model = Sequential(name="seq") + model.add(layers.Dense(4)) + + result = pickle.loads(pickle.dumps(model)) + assert len(result.layers) == 1 + def test_bad_layer(self): model = Sequential(name="seq") with self.assertRaisesRegex(ValueError, "Only instances of"): diff --git a/keras/src/models/variable_mapping.py b/keras/src/models/variable_mapping.py index ed9deb7340e..e06ea5b0939 100644 --- a/keras/src/models/variable_mapping.py +++ b/keras/src/models/variable_mapping.py @@ -2,24 +2,25 @@ from keras.src.metrics.metric import Metric from keras.src.optimizers.optimizer import Optimizer from keras.src.saving import saving_lib +from keras.src.saving.keras_saveable import KerasSaveable -def map_trackable_variables(trackable, store, visited_trackables): - # If the trackable has already been saved, skip it. - if id(trackable) in visited_trackables: +def map_saveable_variables(saveable, store, visited_saveables): + # If the saveable has already been seen, skip it. + if id(saveable) in visited_saveables: return - visited_trackables.add(id(trackable)) + visited_saveables.add(id(saveable)) variables = [] - if isinstance(trackable, Layer): + if isinstance(saveable, Layer): variables = ( - trackable._trainable_variables + trackable._non_trainable_variables + saveable._trainable_variables + saveable._non_trainable_variables ) - elif isinstance(trackable, Optimizer): - variables = trackable._variables - elif isinstance(trackable, Metric): - variables = trackable._variables + elif isinstance(saveable, Optimizer): + variables = saveable._variables + elif isinstance(saveable, Metric): + variables = saveable._variables for v in variables: if v.path in store: raise ValueError( @@ -31,30 +32,30 @@ def map_trackable_variables(trackable, store, visited_trackables): ) store[v.path] = v - # Recursively save state of children trackables (layers, optimizers, etc.) - for child_attr, child_obj in saving_lib._walk_trackable(trackable): - if saving_lib._is_keras_trackable(child_obj): - map_trackable_variables( + # Recursively save state of children saveables (layers, optimizers, etc.) 
+ for child_attr, child_obj in saving_lib._walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + map_saveable_variables( child_obj, store, - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) elif isinstance(child_obj, (list, dict, tuple, set)): map_container_variables( child_obj, store, - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) -def map_container_variables(container, store, visited_trackables): +def map_container_variables(container, store, visited_saveables): if isinstance(container, dict): container = list(container.values()) - for trackable in container: - if saving_lib._is_keras_trackable(trackable): - map_trackable_variables( - trackable, + for saveable in container: + if isinstance(saveable, KerasSaveable): + map_saveable_variables( + saveable, store, - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 7ea3ef23f15..b609835b077 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -338,10 +338,12 @@ def test_stop_gradient_return(self): self.assertAllClose(x, y) def test_shape(self): - x = np.ones((2, 3, 7, 1)) + x = ops.ones((2, 3, 7, 1)) + self.assertEqual(core.shape(x).__class__, tuple) self.assertAllEqual(core.shape(x), (2, 3, 7, 1)) x = KerasTensor((None, 3, None, 1)) + self.assertEqual(core.shape(x).__class__, tuple) self.assertAllEqual(core.shape(x), (None, 3, None, 1)) @pytest.mark.skipif( diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py index 8b6930ac12a..04f2b409614 100644 --- a/keras/src/ops/function.py +++ b/keras/src/ops/function.py @@ -121,7 +121,7 @@ def call(self, inputs): self._assert_input_compatibility(inputs) return self._run_through_graph(inputs, operation_fn=lambda op: op) - def _run_through_graph(self, inputs, operation_fn): + def _run_through_graph(self, inputs, operation_fn, call_fn=None): """Execute the graph. At each node we compute outputs via @@ -148,7 +148,11 @@ def _run_through_graph(self, inputs, operation_fn): continue # Node is not computable, try skipping. args, kwargs = node.arguments.fill_in(tensor_dict) - outputs = operation_fn(node.operation)(*args, **kwargs) + op = operation_fn(node.operation) + if call_fn is not None: + outputs = call_fn(op, *args, **kwargs) + else: + outputs = op(*args, **kwargs) # Update tensor_dict. for x, y in zip(node.outputs, tree.flatten(outputs)): diff --git a/keras/src/ops/image.py b/keras/src/ops/image.py index bb817ec4abe..a9971b2aba4 100644 --- a/keras/src/ops/image.py +++ b/keras/src/ops/image.py @@ -548,7 +548,7 @@ def _extract_patches( if not strides: strides = size out_dim = patch_h * patch_w * channels_in - kernel = backend.numpy.eye(out_dim) + kernel = backend.numpy.eye(out_dim, dtype=image.dtype) kernel = backend.numpy.reshape( kernel, (patch_h, patch_w, channels_in, out_dim) ) diff --git a/keras/src/ops/linalg_test.py b/keras/src/ops/linalg_test.py index e1f0decf64b..a2fd0c61aad 100644 --- a/keras/src/ops/linalg_test.py +++ b/keras/src/ops/linalg_test.py @@ -101,6 +101,15 @@ def test_qr(self): self.assertEqual(q.shape, qref_shape) self.assertEqual(r.shape, rref_shape) + def test_qr_invalid_mode(self): + # backend agnostic error message + x = np.array([[1, 2], [3, 4]]) + invalid_mode = "invalid_mode" + with self.assertRaisesRegex( + ValueError, "Expected one of {'reduced', 'complete'}." 
+ ): + linalg.qr(x, mode=invalid_mode) + def test_solve(self): a = KerasTensor([None, 20, 20]) b = KerasTensor([None, 20, 5]) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 60db9fc70f6..86e3c70a78e 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -1,5 +1,6 @@ import math +import jax.numpy as jnp import numpy as np import pytest import scipy.signal @@ -1256,3 +1257,90 @@ def test_undefined_fft_length_and_last_dimension(self): expected_shape = real_part.shape[:-1] + (None,) self.assertEqual(output_spec.shape, expected_shape) + + +class TestMathErrors(testing.TestCase): + + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_sum_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + kmath.segment_sum(data, segment_ids) + + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_max_no_num_segments(self): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + kmath.segment_max(data, segment_ids) + + def test_stft_invalid_input_type(self): + # backend agnostic error message + x = np.array([1, 2, 3, 4]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_invalid_fft_length(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 4 + sequence_stride = 1 + fft_length = 2 + with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_stft_invalid_window(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = "invalid_window" + with self.assertRaisesRegex(ValueError, "If a string is passed to"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_stft_invalid_window_shape(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = np.ones((sequence_length + 1)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_istft_invalid_window_shape_2D_inputs(self): + # backend agnostic error message + x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + incorrect_window = np.ones((sequence_length + 1,)) + with self.assertRaisesRegex( + ValueError, "The shape of `window` must be equal to" + ): + kmath.istft( + x, + sequence_length, + sequence_stride, + fft_length, + window=incorrect_window, + ) diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 4ec642b018a..189f46bee0d 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -1810,8 +1810,8 @@ def batch_normalization( ) -class CtcLoss(Operation): - def __init__(self, mask_index): +class CTCLoss(Operation): + def __init__(self, mask_index=0): super().__init__() self.mask_index = mask_index @@ -1838,8 +1838,8 @@ def 
compute_output_spec(self, target, output, target_length, output_length): self._check_shape_first_dim( "output_length", output_length.shape, "output", output.shape ) - - return KerasTensor((target.shape[0],), dtype=target.dtype) + dtype = backend.result_type(output.dtype, "float32") + return KerasTensor((target.shape[0],), dtype=dtype) @keras_export( @@ -1865,7 +1865,7 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0): """ if any_symbolic_tensors((target, output, target_length, output_length)): - return CtcLoss(mask_index).symbolic_call( + return CTCLoss(mask_index).symbolic_call( target, output, target_length, output_length ) return backend.nn.ctc_loss( @@ -1873,6 +1873,115 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0): ) +class CTCDecode(Operation): + def __init__( + self, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, + ): + super().__init__() + self.strategy = strategy + self.beam_width = beam_width + self.top_paths = top_paths + self.merge_repeated = merge_repeated + self.mask_index = mask_index + + def call(self, inputs, sequence_lengths): + return backend.nn.ctc_decode( + inputs, + sequence_lengths, + strategy=self.strategy, + beam_width=self.beam_width, + top_paths=self.top_paths, + merge_repeated=self.merge_repeated, + mask_index=self.mask_index, + ) + + def compute_output_spec(self, inputs, sequence_lengths): + inputs_shape = inputs.shape + if self.strategy == "greedy": + top_paths = 1 + else: + top_paths = self.top_paths + dtype = backend.result_type(inputs.dtype, "float32") + return ( + KerasTensor( + (top_paths, inputs_shape[0], inputs_shape[1]), dtype="int32" + ), + KerasTensor((inputs_shape[0], top_paths), dtype=dtype), + ) + + +@keras_export( + [ + "keras.ops.ctc_decode", + "keras.ops.nn.ctc_decode", + ] +) +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + """Decodes the output of a CTC model. + + Args: + inputs: A tensor of shape `(batch_size, max_length, num_classes)` + containing the logits (the output of the model). + They should *not* be normalized via softmax. + sequence_lengths: A tensor of shape `(batch_size,)` containing the + sequence lengths for the batch. + strategy: A string for the decoding strategy. Supported values are + `"greedy"` and `"beam_search"`. + beam_width: An integer scalar beam width used in beam search. + Defaults to 100. + top_paths: An integer scalar, the number of top paths to return. + Defaults to 1. + merge_repeated: A boolean scalar, whether to merge repeated + labels in the output. Defaults to `True`. + mask_index: An integer scalar, the index of the mask character in + the vocabulary. Defaults to `0`. + + Returns: + A tuple containing: + - The tensor representing the list of decoded sequences. If + `strategy="greedy"`, the shape is `(1, batch_size, max_length)`. If + `strategy="beam_search"`, the shape is + `(top_paths, batch_size, max_length)`. Note that: `-1` indicates the + blank label. + - If `strategy="greedy"`, a tensor of shape `(batch_size, 1)` + representing the negative of the sum of the probability logits for + each sequence. If `strategy="beam_search"`, a tensor of shape + `(batch_size, top_paths)` representing the log probability for each + sequence.
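A usage sketch for `ctc_decode` as specified above, with toy logits and lengths; backend support for `"beam_search"` varies (for example, it is not implemented for torch in the tests later in this patch):

```python
import numpy as np
import keras

logits = np.random.rand(2, 6, 5).astype("float32")  # (batch_size, max_length, num_classes)
lengths = np.array([6, 4], dtype="int32")

# Greedy decoding: one path per sample.
decoded, scores = keras.ops.ctc_decode(logits, lengths, strategy="greedy")
# decoded: (1, 2, max_length) padded with -1, scores: (2, 1)

# Beam search decoding: `top_paths` paths per sample.
decoded, scores = keras.ops.ctc_decode(
    logits, lengths, strategy="beam_search", beam_width=8, top_paths=2
)
# decoded: (2, 2, max_length), scores: (2, 2)
```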
+ """ + + if any_symbolic_tensors((inputs, sequence_lengths)): + return CTCDecode( + strategy=strategy, + beam_width=beam_width, + top_paths=top_paths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ).symbolic_call(inputs, sequence_lengths) + return backend.nn.ctc_decode( + inputs=inputs, + sequence_lengths=sequence_lengths, + strategy=strategy, + beam_width=beam_width, + top_paths=top_paths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + + class Normalize(Operation): def __init__(self, axis=-1, order=2): super().__init__() @@ -1933,3 +2042,76 @@ def _normalize(x, axis=-1, order=2): norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True) denom = backend.numpy.maximum(norm, epsilon) return backend.numpy.divide(x, denom) + + +class PSNR(Operation): + def __init__( + self, + max_val, + ): + super().__init__() + self.max_val = max_val + + def call(self, x1, x2): + return backend.nn.psnr( + x1=x1, + x2=x2, + max_val=self.max_val, + ) + + def compute_output_spec(self, x1, x2): + if len(x1.shape) != len(x2.shape): + raise ValueError("Inputs must have the same rank") + + return KerasTensor(shape=()) + + +@keras_export( + [ + "keras.ops.psnr", + "keras.ops.nn.psnr", + ] +) +def psnr( + x1, + x2, + max_val, +): + """Peak Signal-to-Noise Ratio (PSNR) function. + + This function computes the Peak Signal-to-Noise Ratio between two signals, + `x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal. + The higher the PSNR, the closer the reconstructed signal is to the original + signal. Note that it can become negative when the signal power is + smaller that the noise power. + + Args: + x1: The first input signal. + x2: The second input signal. Must have the same shape as `x1`. + max_val: The maximum possible value in the signals. + + Returns: + float: The PSNR value between `x1` and `x2`. 
+ + Examples: + + >>> x1 = keras.random.normal((2, 4, 4, 3)) + >>> x2 = keras.random.normal((2, 4, 4, 3)) + >>> max_val = 1.0 + >>> keras.ops.nn.psnr(x1, x2, max_val) + -3.1697404 + """ + if any_symbolic_tensors( + ( + x1, + x2, + ) + ): + return PSNR( + max_val, + ).symbolic_call(x1, x2) + return backend.nn.psnr( + x1, + x2, + max_val, + ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 01b84b15763..2ca9350384a 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -633,10 +633,33 @@ def test_batch_normalization(self): scale=KerasTensor([3]), ) + def test_ctc_decode(self): + # Test strategy="greedy" + inputs = KerasTensor([None, 2, 3]) + sequence_lengths = KerasTensor([None]) + decoded, scores = knn.ctc_decode(inputs, sequence_lengths) + self.assertEqual(decoded.shape, (1, None, 2)) + self.assertEqual(scores.shape, (None, 1)) + + # Test strategy="beam_search" + inputs = KerasTensor([None, 2, 3]) + sequence_lengths = KerasTensor([None]) + decoded, scores = knn.ctc_decode( + inputs, sequence_lengths, strategy="beam_search", top_paths=2 + ) + self.assertEqual(decoded.shape, (2, None, 2)) + self.assertEqual(scores.shape, (None, 2)) + def test_normalize(self): x = KerasTensor([None, 2, 3]) self.assertEqual(knn.normalize(x).shape, (None, 2, 3)) + def test_psnr(self): + x1 = KerasTensor([None, 2, 3]) + x2 = KerasTensor([None, 5, 6]) + out = knn.psnr(x1, x2, max_val=224) + self.assertEqual(out.shape, ()) + class NNOpsStaticShapeTest(testing.TestCase): def test_relu(self): @@ -1069,10 +1092,6 @@ def test_batch_normalization(self): (10, 3, 4, 5), ) - @pytest.mark.skipif( - backend.backend() == "numpy", - reason="Numpy does not support CTC loss", - ) def test_ctc_loss(self): x = KerasTensor([10, 3, 4]) y = KerasTensor([10, 3], dtype="int32") @@ -1080,10 +1099,33 @@ def test_ctc_loss(self): y_lengths = KerasTensor([10], dtype="int32") self.assertEqual(knn.ctc_loss(x, y, x_lengths, y_lengths).shape, (10,)) + def test_ctc_decode(self): + # Test strategy="greedy" + inputs = KerasTensor([10, 2, 3]) + sequence_lengths = KerasTensor([10]) + decoded, scores = knn.ctc_decode(inputs, sequence_lengths) + self.assertEqual(decoded.shape, (1, 10, 2)) + self.assertEqual(scores.shape, (10, 1)) + + # Test strategy="beam_search" + inputs = KerasTensor([10, 2, 3]) + sequence_lengths = KerasTensor([10]) + decoded, scores = knn.ctc_decode( + inputs, sequence_lengths, strategy="beam_search", top_paths=2 + ) + self.assertEqual(decoded.shape, (2, 10, 2)) + self.assertEqual(scores.shape, (10, 2)) + def test_normalize(self): x = KerasTensor([1, 2, 3]) self.assertEqual(knn.normalize(x).shape, (1, 2, 3)) + def test_psnr(self): + x1 = KerasTensor([1, 2, 3]) + x2 = KerasTensor([5, 6, 7]) + out = knn.psnr(x1, x2, max_val=224) + self.assertEqual(out.shape, ()) + class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): def test_relu(self): @@ -1403,23 +1445,29 @@ def test_conv_2d_group_2(self, strides, dilation_rate): ) self.assertAllClose(outputs, expected) - @parameterized.product(strides=(1, (1, 1, 1), 2), padding=("valid", "same")) - def test_conv_3d(self, strides, padding): - if backend.config.image_data_format() == "channels_last": + @parameterized.product( + strides=(1, (1, 1, 1), 2), + padding=("valid", "same"), + data_format=("channels_first", "channels_last"), + ) + def test_conv_3d(self, strides, padding, data_format): + if data_format == "channels_last": input_shape = (2, 8, 8, 8, 3) else: input_shape = (2, 3, 8, 8, 8) inputs_3d = np.arange(3072, 
dtype=float).reshape(input_shape) kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2]) - outputs = knn.conv(inputs_3d, kernel, strides, padding=padding) + outputs = knn.conv( + inputs_3d, kernel, strides, padding=padding, data_format=data_format + ) expected = np_conv3d( inputs_3d, kernel, bias_weights=np.zeros((2,)), strides=strides, padding=padding, - data_format=backend.config.image_data_format(), + data_format=data_format, dilation_rate=1, groups=1, ) @@ -1884,10 +1932,6 @@ def test_batch_normalization(self): ) self.assertEqual(tuple(output.shape), (2, 3, 3, 5)) - @pytest.mark.skipif( - backend.backend() == "numpy", - reason="Numpy does not support CTC loss", - ) def test_ctc_loss(self): labels = np.array([[1, 2, 1], [1, 2, 2]]) outputs = np.array( @@ -1903,6 +1947,82 @@ def test_ctc_loss(self): result = knn.ctc_loss(labels, outputs, label_length, output_length) self.assertAllClose(result, np.array([3.4411672, 1.91680186])) + def test_ctc_decode(self): + inputs = np.array( + [ + [ + [0.1, 0.4, 0.2, 0.4], + [0.3, -0.3, 0.4, 0.2], + [0.3, 0.2, 0.4, 0.3], + ], + [ + [0.7, 0.4, 0.3, 0.2], + [0.3, 0.3, 0.4, 0.1], + [0.6, -0.1, 0.1, 0.5], + ], + [ + [0.1, 0.4, 0.2, 0.7], + [0.3, 0.3, -0.2, 0.7], + [0.3, 0.2, 0.4, 0.1], + ], + ] + ) + labels = np.array([[1, 2, -1], [2, -1, -1], [3, -1, -1]]) + score_labels = np.array([[-1.2], [-1.7], [-0.7]]) + repeated_labels = np.array([[1, 2, 2], [2, -1, -1], [3, -1, -1]]) + + # Test strategy="greedy" and merge_repeated=True + (decoded,), scores = knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="greedy", + mask_index=0, + ) + self.assertAllClose(decoded, labels) + self.assertAllClose(scores, score_labels) + + # Test strategy="greedy" and merge_repeated=False + (decoded,), scores = knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="greedy", + merge_repeated=False, + mask_index=0, + ) + self.assertAllClose(decoded, repeated_labels) + self.assertAllClose(scores, score_labels) + + if backend.backend() == "torch": + self.skipTest("torch doesn't support 'beam_search' strategy") + + labels = np.array( + [ + [[1, 2, -1], [2, -1, -1], [3, -1, -1]], + [[2, -1, -1], [3, -1, -1], [1, -1, -1]], + ] + ) + score_labels = np.array( + [ + [-2.426537, -2.435596], + [-2.127681, -2.182338], + [-1.063386, -1.363386], + ] + ) + beam_width = 4 + top_paths = 2 + + # Test strategy="beam_search" + decoded, scores = knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="beam_search", + beam_width=beam_width, + top_paths=top_paths, + mask_index=0, + ) + self.assertAllClose(decoded, labels) + self.assertAllClose(scores, score_labels) + def test_normalize(self): x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) self.assertAllClose( @@ -1934,22 +2054,24 @@ def test_normalize(self): ], ) - -class TestLogitRecovery(testing.TestCase): - def test_logit_recovery_binary_crossentropy(self): - layer = layers.Dense( - 4, activation="sigmoid", use_bias=False, kernel_initializer="ones" + def test_psnr(self): + x1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x2 = np.array([[0.2, 0.2, 0.3], [0.4, 0.6, 0.6]]) + max_val = 1.0 + expected_psnr_1 = 20 * np.log10(max_val) - 10 * np.log10( + np.mean(np.square(x1 - x2)) ) - loss = losses.BinaryCrossentropy() - x = np.array([[1.4, 1.6, 0.8]]) - y = np.array([[0.2, 0.6, 0.1, 0.3]]) - loss_value = loss(y, layer(x)) - self.assertAllClose(loss_value, 2.682124) + psnr_1 = knn.psnr(x1, x2, max_val) + self.assertAlmostEqual(psnr_1, expected_psnr_1) - model = models.Sequential([layer]) - 
model.compile(loss="binary_crossentropy", optimizer="sgd") - out = model.evaluate(x, y) - self.assertAllClose(out, 2.682124) + x3 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x4 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + max_val = 1.0 + expected_psnr_2 = 20 * np.log10(max_val) - 10 * np.log10( + np.mean(np.square(x3 - x4)) + ) + psnr_2 = knn.psnr(x3, x4, max_val) + self.assertAlmostEqual(psnr_2, expected_psnr_2) class NNOpsDtypeTest(testing.TestCase, parameterized.TestCase): @@ -2296,6 +2418,85 @@ def test_softsign(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_ctc_loss(self, dtype): + labels = knp.array([[1, 2, 1]], dtype="int32") + outputs = knp.array( + [[[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]]], dtype=dtype + ) + label_length = knp.array([3]) + output_length = knp.array([3]) + expected_dtype = ( + "float32" if dtype in ("float16", "bfloat16") else dtype + ) + + self.assertEqual( + standardize_dtype( + knn.ctc_loss(labels, outputs, label_length, output_length).dtype + ), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knn.CTCLoss() + .symbolic_call(labels, outputs, label_length, output_length) + .dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_ctc_decode(self, dtype): + inputs = knp.array( + [[[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]]], dtype=dtype + ) + sequence_length = knp.array([3]) + expected_dtype = backend.result_type(dtype, "float32") + + # Test strategy="greedy" + decoded, scores = knn.ctc_decode( + inputs, sequence_length, strategy="greedy" + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + decoded, scores = knn.CTCDecode(strategy="greedy").symbolic_call( + inputs, sequence_length + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + + if backend.backend() == "torch": + self.skipTest("torch doesn't support 'beam_search' strategy") + + # Test strategy="beam_search" + decoded, scores = knn.ctc_decode( + inputs, sequence_length, strategy="beam_search" + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + decoded, scores = knn.CTCDecode(strategy="beam_search").symbolic_call( + inputs, sequence_length + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + + +class NNOpsBehaviorTest(testing.TestCase, parameterized.TestCase): + def test_logit_recovery_binary_crossentropy(self): + layer = layers.Dense( + 4, activation="sigmoid", use_bias=False, kernel_initializer="ones" + ) + loss = losses.BinaryCrossentropy() + x = np.array([[1.4, 1.6, 0.8]]) + y = np.array([[0.2, 0.6, 0.1, 0.3]]) + loss_value = loss(y, layer(x)) + self.assertAllClose(loss_value, 2.682124) + + model = models.Sequential([layer]) + model.compile(loss="binary_crossentropy", optimizer="sgd") + out = model.evaluate(x, y) + self.assertAllClose(out, 2.682124) + def test_softmax_on_axis_with_size_one_warns(self): x = np.array([[1.0]]) # Applying softmax on the second axis, which has size 1 @@ -2341,10 +2542,31 @@ def test_normalize_order_validation(self): def test_check_shape_first_dim_mismatch(self): name1, shape1 = "labels", (2, 3) name2, shape2 = "logits", (3, 4, 5) - ctc_loss_instance = knn.CtcLoss(mask_index=-1) + 
ctc_loss_instance = knn.CTCLoss(mask_index=-1) with self.assertRaisesRegex( ValueError, "must have the same first dimension" ): ctc_loss_instance._check_shape_first_dim( name1, shape1, name2, shape2 ) + + def test_invalid_strategy_ctc_decode(self): + inputs = np.array( + [ + [ + [0.1, 0.4, 0.2, 0.4], + [0.3, 0.3, 0.4, 0.2], + [0.3, 0.2, 0.4, 0.3], + ] + ] + ) + beam_width = 4 + top_paths = 2 + with self.assertRaisesRegex(ValueError, "Invalid strategy"): + knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="invalid", + beam_width=beam_width, + top_paths=top_paths, + ) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index f4c8ffb7f71..c11899720ae 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -1,147 +1,3 @@ -""" -MANIFEST: - -abs -absolute -add -all -amax -amin -append -arange -arccos -arccosh -arcsin -arcsinh -arctan -arctan2 -arctanh -argmax -argmin -argsort -array -average -bincount -broadcast_to -ceil -clip -concatenate -conj -conjugate -copy -correlate -cos -cosh -count_nonzero -cross -cumprod -cumsum -diag -diagonal -diff -digitize -divide -dot -dtype -einsum -empty -equal -exp -expand_dims -expm1 -eye -flip -floor -full -full_like -greater -greater_equal -hstack -identity -imag -interp -isclose -isfinite -isinf -isnan -less -less_equal -linspace -log -log10 -log1p -log2 -logaddexp -logical_and -logical_not -logical_or -logspace -matmul -max -maximum -mean -median -meshgrid -mgrid -min -minimum -mod -moveaxis -multiply -nan_to_num -ndim -nonzero -not_equal -ones -ones_like -outer -pad -percentile -power -prod -quantile -ravel -real -reciprocal -repeat -reshape -roll -round -sign -sin -sinh -size -sort -split -sqrt -square -squeeze -stack -std -subtract -sum -swapaxes -take -take_along_axis -tan -tanh -tensordot -tile -trace -transpose -tri -tril -triu -true_divide -vdot -vstack -where -zeros -zeros_like - - -""" - import builtins import re @@ -3955,8 +3811,19 @@ def moveaxis(x, source, destination): class NanToNum(Operation): + def __init__(self, nan=0.0, posinf=None, neginf=None): + super().__init__() + self.nan = nan + self.posinf = posinf + self.neginf = neginf + def call(self, x): - return backend.numpy.nan_to_num(x) + return backend.numpy.nan_to_num( + x, nan=self.nan, posinf=self.posinf, neginf=self.neginf + ) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) @keras_export( @@ -3965,16 +3832,23 @@ def call(self, x): "keras.ops.numpy.nan_to_num", ] ) -def nan_to_num(x): +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): """Replace NaN with zero and infinity with large finite numbers. Args: x: Input data. + nan: Optional float or int. Value to replace `NaN` entries with. + posinf: Optional float or int. + Value to replace positive infinity with. + neginf: Optional float or int. + Value to replace negative infinity with. Returns: `x`, with non-finite values replaced. 
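A quick sketch of the expanded `nan_to_num` arguments documented above, shown with the NumPy function whose semantics the op follows:

```python
import numpy as np

x = np.array([np.nan, np.inf, -np.inf, 1.0])
np.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
# -> [0., 1e6, -1e6, 1.]
```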
""" - return backend.numpy.nan_to_num(x) + if any_symbolic_tensors((x,)): + return NanToNum(nan=nan, posinf=posinf, neginf=neginf).symbolic_call(x) + return backend.numpy.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) class Ndim(Operation): @@ -4007,7 +3881,7 @@ def call(self, x): return backend.numpy.nonzero(x) def compute_output_spec(self, x): - return KerasTensor([None] * len(x.shape)) + return KerasTensor([None] * len(x.shape), dtype="int32") @keras_export(["keras.ops.nonzero", "keras.ops.numpy.nonzero"]) @@ -5266,14 +5140,18 @@ def trace(x, offset=0, axis1=0, axis2=1): class Tri(Operation): - def call(self, N, M=None, k=0, dtype=None): - return backend.numpy.tri(N, M=M, k=k, dtype=dtype) + def __init__(self, k=0, dtype=None): + super().__init__() + self.k = k + self.dtype = dtype or backend.floatx() + + def call(self, N, M=None): + return backend.numpy.tri(N=N, M=M, k=self.k, dtype=self.dtype) - def compute_output_spec(self, N, M=None, k=0, dtype=None): + def compute_output_spec(self, N, M=None): if M is None: M = N - dtype = dtype or backend.floatx() - return KerasTensor((N, M), dtype=dtype) + return KerasTensor((N, M), dtype=self.dtype) @keras_export(["keras.ops.tri", "keras.ops.numpy.tri"]) @@ -5393,6 +5271,48 @@ def vdot(x1, x2): return backend.numpy.vdot(x1, x2) +@keras_export(["keras.ops.vectorize", "keras.ops.numpy.vectorize"]) +def vectorize(pyfunc, *, excluded=None, signature=None): + """Turn a function into a vectorized function. + + Example: + + ```python + def myfunc(a, b): + return a + b + + vfunc = np.vectorize(myfunc) + y = vfunc([1, 2, 3, 4], 2) # Returns Tensor([3, 4, 5, 6]) + ``` + + Args: + pyfunc: Callable of a single tensor argument. + excluded: Optional set of integers representing + positional arguments for which the function + will not be vectorized. + These will be passed directly to `pyfunc` unmodified. + signature: Optional generalized universal function signature, + e.g., `"(m,n),(n)->(m)"` for vectorized + matrix-vector multiplication. If provided, + `pyfunc` will be called with (and expected to return) + arrays with shapes given by the size of corresponding + core dimensions. By default, `pyfunc` is assumed + to take scalars tensors as input and output. + + Returns: + A new function that applies `pyfunc` to every element + of its input along axis 0 (the batch axis). + """ + if not callable(pyfunc): + raise ValueError( + "Expected argument `pyfunc` to be a callable. 
" + f"Received: pyfunc={pyfunc}" + ) + return backend.numpy.vectorize( + pyfunc, excluded=excluded, signature=signature + ) + + class Vstack(Operation): def call(self, xs): return backend.numpy.vstack(xs) @@ -6021,14 +5941,18 @@ def ones(shape, dtype=None): class Eye(Operation): - def call(self, N, M=None, k=0, dtype=None): - return backend.numpy.eye(N, M=M, k=k, dtype=dtype) + def __init__(self, k=0, dtype=None): + super().__init__() + self.k = k + self.dtype = dtype or backend.floatx() + + def call(self, N, M=None): + return backend.numpy.eye(N, M=M, k=self.k, dtype=self.dtype) - def compute_output_spec(self, N, M=None, k=0, dtype=None): + def compute_output_spec(self, N, M=None): if M is None: M = N - dtype = dtype or backend.floatx() - return KerasTensor((N, M), dtype=dtype) + return KerasTensor((N, M), dtype=self.dtype) @keras_export(["keras.ops.eye", "keras.ops.numpy.eye"]) @@ -6172,3 +6096,140 @@ def correlate(x1, x2, mode="valid"): if any_symbolic_tensors((x1, x2)): return Correlate(mode=mode).symbolic_call(x1, x2) return backend.numpy.correlate(x1, x2, mode=mode) + + +class Select(Operation): + def __init__(self): + super().__init__() + + def call(self, condlist, choicelist, default=0): + return backend.numpy.select(condlist, choicelist, default) + + def compute_output_spec(self, condlist, choicelist, default=0): + first_element = choicelist[0] + return KerasTensor(first_element.shape, dtype=first_element.dtype) + + +@keras_export(["keras.ops.select", "keras.ops.numpy.select"]) +def select(condlist, choicelist, default=0): + """Return elements from `choicelist`, based on conditions in `condlist`. + + Args: + condlist: List of boolean tensors. + The list of conditions which determine from which array + in choicelist the output elements are taken. + When multiple conditions are satisfied, + the first one encountered in condlist is used. + choicelist: List of tensors. + The list of tensors from which the output elements are taken. + This list has to be of the same length as `condlist`. + defaults: Optional scalar value. + The element inserted in the output + when all conditions evaluate to `False`. + + Returns: + Tensor where the output at position `m` is the `m`-th element + of the tensor in `choicelist` where the `m`-th element of the + corresponding tensor in `condlist` is `True`. + + Example: + + ```python + from keras import ops + + x = ops.arange(6) + condlist = [x<3, x>3] + choicelist = [x, x**2] + ops.select(condlist, choicelist, 42) + # Returns: tensor([0, 1, 2, 42, 16, 25]) + ``` + """ + if not isinstance(condlist, list) or not isinstance(choicelist, list): + raise ValueError( + "condlist and choicelist must be lists. Received: " + f"type(condlist) = {type(condlist)}, " + f"type(choicelist) = {type(choicelist)}" + ) + if not condlist or not choicelist: + raise ValueError( + "condlist and choicelist must not be empty. 
Received: " + f"condlist = {condlist}, " + f"choicelist = {choicelist}" + ) + if any_symbolic_tensors(condlist + choicelist + [default]): + return Select().symbolic_call(condlist, choicelist, default) + return backend.numpy.select(condlist, choicelist, default) + + +class Slogdet(Operation): + def __init__(self): + super().__init__() + + def call(self, x): + return backend.numpy.slogdet(x) + + def compute_output_spec(self, x): + sign = KerasTensor((), dtype=x.dtype) + logabsdet = KerasTensor((), dtype=x.dtype) + return (sign, logabsdet) + + +@keras_export(["keras.ops.slogdet", "keras.ops.numpy.slogdet"]) +def slogdet(x): + """Compute the sign and natural logarithm of the determinant of a matrix. + + Args: + x: Input matrix. It must 2D and square. + + Returns: + A tuple `(sign, logabsdet)`. `sign` is a number representing + the sign of the determinant. For a real matrix, this is 1, 0, or -1. + For a complex matrix, this is a complex number with absolute value 1 + (i.e., it is on the unit circle), or else 0. + `logabsdet` is the natural log of the absolute value of the determinant. + """ + if any_symbolic_tensors((x,)): + return Slogdet().symbolic_call(x) + return backend.numpy.slogdet(x) + + +class Argpartition(Operation): + def __init__(self, kth, axis=-1): + super().__init__() + if not isinstance(kth, int): + raise ValueError("kth must be an integer. Received:" f"kth = {kth}") + self.kth = kth + self.axis = axis + + def call(self, x): + return backend.numpy.argpartition(x, kth=self.kth, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="int32") + + +@keras_export(["keras.ops.argpartition", "keras.ops.numpy.argpartition"]) +def argpartition(x, kth, axis=-1): + """Performs an indirect partition along the given axis. + + It returns an array + of indices of the same shape as `x` that index data along the given axis + in partitioned order. + + Args: + a: Array to sort. + kth: Element index to partition by. + The k-th element will be in its final sorted position and all + smaller elements will be moved before it and all larger elements + behind it. The order of all elements in the partitions is undefined. + If provided with a sequence of k-th it will partition all of them + into their sorted position at once. + axis: Axis along which to sort. The default is -1 (the last axis). + If `None`, the flattened array is used. + + Returns: + Array of indices that partition `x` along the specified `axis`. 
+ """ + if any_symbolic_tensors((x,)): + return Argpartition(kth, axis).symbolic_call(x) + return backend.numpy.argpartition(x, kth, axis) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 944a6807839..aee3a63e8ee 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1475,6 +1475,14 @@ def test_vstack(self): y = KerasTensor((None, None)) self.assertEqual(knp.vstack([x, y]).shape, (None, 3)) + def test_argpartition(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.argpartition(x, 3).shape, (None, 3)) + self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (None, 3)) + + with self.assertRaises(ValueError): + knp.argpartition(x, (1, 3)) + class NumpyOneInputOpsStaticShapeTest(testing.TestCase): def test_mean(self): @@ -1981,6 +1989,14 @@ def test_vstack(self): y = KerasTensor((2, 3)) self.assertEqual(knp.vstack([x, y]).shape, (4, 3)) + def test_argpartition(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.argpartition(x, 3).shape, (2, 3)) + self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (2, 3)) + + with self.assertRaises(ValueError): + knp.argpartition(x, (1, 3)) + class NumpyTwoInputOpsCorretnessTest(testing.TestCase, parameterized.TestCase): def test_add(self): @@ -2170,6 +2186,14 @@ def test_cross(self): self.assertAllClose(knp.Cross()(x1, y3), np.cross(x1, y3)) self.assertAllClose(knp.Cross()(x2, y3), np.cross(x2, y3)) + # Test axis is not None + self.assertAllClose( + knp.cross(x1, y1, axis=-1), np.cross(x1, y1, axis=-1) + ) + self.assertAllClose( + knp.Cross(axis=-1)(x1, y1), np.cross(x1, y1, axis=-1) + ) + def test_einsum(self): x = np.arange(24).reshape([2, 3, 4]).astype("float32") y = np.arange(24).reshape([2, 4, 3]).astype("float32") @@ -2452,6 +2476,10 @@ def test_linspace(self): knp.Linspace(num=5, endpoint=False)(0, 10), np.linspace(0, 10, 5, endpoint=False), ) + self.assertAllClose( + knp.Linspace(num=0, endpoint=False)(0, 10), + np.linspace(0, 10, 0, endpoint=False), + ) start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) @@ -2659,27 +2687,33 @@ def test_take(self): self.assertAllClose(knp.Take()(x, 0), np.take(x, 0)) self.assertAllClose(knp.Take(axis=1)(x, 0), np.take(x, 0, axis=1)) - # test with multi-dimensional indices + # Test with multi-dimensional indices rng = np.random.default_rng(0) x = rng.standard_normal((2, 3, 4, 5)) indices = rng.integers(0, 4, (6, 7)) self.assertAllClose( - knp.take(x, indices, axis=2), - np.take(x, indices, axis=2), + knp.take(x, indices, axis=2), np.take(x, indices, axis=2) ) - # test with negative axis + # Test with negative axis self.assertAllClose( - knp.take(x, indices, axis=-2), - np.take(x, indices, axis=-2), + knp.take(x, indices, axis=-2), np.take(x, indices, axis=-2) ) - # test with axis=None & x.ndim=2 + + # Test with axis=None & x.ndim=2 x = np.array(([1, 2], [3, 4])) indices = np.array([2, 3]) self.assertAllClose( knp.take(x, indices, axis=None), np.take(x, indices, axis=None) ) + # Test with negative indices + x = rng.standard_normal((2, 3, 4, 5)) + indices = rng.integers(-3, 0, (6, 7)) + self.assertAllClose( + knp.take(x, indices, axis=2), np.take(x, indices, axis=2) + ) + @parameterized.named_parameters( named_product( [ @@ -2744,6 +2778,30 @@ def test_take_along_axis(self): np.take_along_axis(x, indices, axis=2), ) + # Test with axis=None + x = np.arange(12).reshape([1, 1, 3, 4]) + indices = np.array([1, 2, 3], dtype=np.int32) + self.assertAllClose( + knp.take_along_axis(x, indices, axis=None), + np.take_along_axis(x, indices, axis=None), + ) + 
self.assertAllClose( + knp.TakeAlongAxis(axis=None)(x, indices), + np.take_along_axis(x, indices, axis=None), + ) + + # Test with negative indices + x = np.arange(12).reshape([1, 1, 3, 4]) + indices = np.full([1, 4, 1, 1], -1, dtype=np.int32) + self.assertAllClose( + knp.take_along_axis(x, indices, axis=2), + np.take_along_axis(x, indices, axis=2), + ) + self.assertAllClose( + knp.TakeAlongAxis(axis=2)(x, indices), + np.take_along_axis(x, indices, axis=2), + ) + def test_tensordot(self): x = np.arange(24).reshape([1, 2, 3, 4]).astype("float32") y = np.arange(24).reshape([3, 4, 1, 2]).astype("float32") @@ -3047,9 +3105,13 @@ def test_arctanh(self): self.assertAllClose(knp.Arctanh()(x), np.arctanh(x)) def test_argmax(self): - x = np.array([[1, 2, 3], [3, 2, 1]]) + x = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]]) self.assertAllClose(knp.argmax(x), np.argmax(x)) self.assertAllClose(knp.argmax(x, axis=1), np.argmax(x, axis=1)) + self.assertAllClose( + knp.argmax(x, axis=1, keepdims=True), + np.argmax(x, axis=1, keepdims=True), + ) self.assertAllClose( knp.argmax(x, keepdims=True), np.argmax(x, keepdims=True) ) @@ -3448,6 +3510,10 @@ def test_diff(self): self.assertAllClose(knp.diff(x, n=2, axis=0), np.diff(x, n=2, axis=0)) self.assertAllClose(knp.diff(x, n=2, axis=1), np.diff(x, n=2, axis=1)) + # Test n=0 + x = np.array([1, 2, 4, 7, 0]) + self.assertAllClose(knp.diff(x, n=0), np.diff(x, n=0)) + def test_dot(self): x = np.arange(24).reshape([2, 3, 4]).astype("float32") y = np.arange(12).reshape([4, 3]).astype("float32") @@ -3865,6 +3931,20 @@ def test_round(self): self.assertAllClose(knp.round(x), np.round(x)) self.assertAllClose(knp.Round()(x), np.round(x)) + # Test with decimal=1 + self.assertAllClose(knp.round(x, decimals=1), np.round(x, decimals=1)) + self.assertAllClose(knp.Round(decimals=1)(x), np.round(x, decimals=1)) + + # Test with integers + x = np.array([[1, 2, 3], [3, 2, 1]], dtype="int32") + self.assertAllClose(knp.round(x, decimals=1), np.round(x, decimals=1)) + self.assertAllClose(knp.Round(decimals=1)(x), np.round(x, decimals=1)) + + # Test with integers and decimal < 0 + x = np.array([[123, 234, 345], [345, 234, 123]], dtype="int32") + self.assertAllClose(knp.round(x, decimals=-1), np.round(x, decimals=-1)) + self.assertAllClose(knp.Round(decimals=-1)(x), np.round(x, decimals=-1)) + def test_sign(self): x = np.array([[1, -2, 3], [-3, 2, -1]]) self.assertAllClose(knp.sign(x), np.sign(x)) @@ -4028,6 +4108,14 @@ def test_tile(self): self.assertAllClose(knp.tile(x, [2, 3]), np.tile(x, [2, 3])) self.assertAllClose(knp.Tile([2, 3])(x), np.tile(x, [2, 3])) + # If repeats.ndim > x.ndim + self.assertAllClose(knp.tile(x, [2, 3, 4]), np.tile(x, [2, 3, 4])) + self.assertAllClose(knp.Tile([2, 3, 4])(x), np.tile(x, [2, 3, 4])) + + # If repeats.ndim < x.ndim + self.assertAllClose(knp.tile(x, [2]), np.tile(x, [2])) + self.assertAllClose(knp.Tile([2])(x), np.tile(x, [2])) + def test_trace(self): x = np.arange(24).reshape([1, 2, 3, 4]) self.assertAllClose(knp.trace(x), np.trace(x)) @@ -4169,6 +4257,85 @@ def test_correlate_different_size(self): knp.Correlate(mode="full")(x, y), np.correlate(x, y, mode="full") ) + def test_select(self): + x = np.arange(6) + condlist = [x < 3, x > 3] + choicelist = [x, x**2] + y = knp.select(condlist, choicelist, 42) + self.assertAllClose(y, [0, 1, 2, 42, 16, 25]) + + x = backend.KerasTensor((6,)) + condlist = [x < 3, x > 3] + choicelist = [x, x**2] + y = knp.select(condlist, choicelist, 42) + self.assertEqual(y.shape, (6,)) + + def test_slogdet(self): + x = 
np.ones((4, 4)) * 2.0 + out = knp.slogdet(x) + self.assertAllClose(out[0], 0) + self.assertAllClose(out[0], 0) + + x = backend.KerasTensor((3, 3)) + out = knp.slogdet(x) + self.assertEqual(out[0].shape, ()) + self.assertEqual(out[1].shape, ()) + + def test_nan_to_num(self): + x = knp.array([1.0, np.nan, np.inf, -np.inf]) + self.assertAllClose( + knp.nan_to_num(x), [1.0, 0.0, 3.402823e38, -3.402823e38] + ) + self.assertAllClose( + knp.NanToNum()(x), [1.0, 0.0, 3.402823e38, -3.402823e38] + ) + self.assertAllClose( + knp.nan_to_num(x, nan=2, posinf=3, neginf=4), [1.0, 2.0, 3.0, 4.0] + ) + self.assertAllClose( + knp.NanToNum(nan=2, posinf=3, neginf=4)(x), [1.0, 2.0, 3.0, 4.0] + ) + + x = backend.KerasTensor((3, 4)) + self.assertEqual( + knp.NanToNum(nan=2, posinf=3, neginf=4)(x).shape, (3, 4) + ) + + def test_vectorize(self): + # Basic functionality + def myfunc(a, b): + return a + b + + vfunc = np.vectorize(myfunc) + y = vfunc([1, 2, 3, 4], 2) + self.assertAllClose(y, [3, 4, 5, 6]) + + # Test signature arg + vfunc = knp.vectorize(knp.trace, signature="(d,d)->()") + out = vfunc(np.eye(4)) + self.assertAllClose( + out, np.vectorize(np.trace, signature="(d,d)->()")(np.eye(4)) + ) + + vfunc = knp.vectorize(knp.diag, signature="(d,d)->(d)") + out = vfunc(np.eye(4)) + self.assertAllClose( + out, np.vectorize(np.diag, signature="(d,d)->(d)")(np.eye(4)) + ) + + def test_argpartition(self): + x = np.array([3, 4, 2, 1]) + self.assertAllClose(knp.argpartition(x, 2), np.argpartition(x, 2)) + self.assertAllClose(knp.Argpartition(2)(x), np.argpartition(x, 2)) + + x = np.array([[3, 4, 2], [1, 3, 4]]) + self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1)) + self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1)) + + x = np.array([[[3, 4], [2, 3]], [[1, 2], [0, 1]]]) + self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1)) + self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1)) + class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): def test_ones(self): @@ -4186,7 +4353,22 @@ def test_eye(self): self.assertAllClose(knp.Eye()(3), np.eye(3)) self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4)) - self.assertAllClose(knp.Eye()(3, 4, 1), np.eye(3, 4, 1)) + self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1)) + + # Test k >= N + self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3)) + + # Test k > 0 and N >= M + self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1)) + + # Test k > 0 and N < M and N + k > M + self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2)) + + # Test k < 0 and M >= N + self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1)) + + # Test k < 0 and M < N and M - k > N + self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2)) def test_arange(self): self.assertAllClose(knp.arange(3), np.arange(3)) @@ -4228,7 +4410,16 @@ def test_tri(self): self.assertAllClose(knp.Tri()(3), np.tri(3)) self.assertAllClose(knp.Tri()(3, 4), np.tri(3, 4)) - self.assertAllClose(knp.Tri()(3, 4, 1), np.tri(3, 4, 1)) + self.assertAllClose(knp.Tri(k=1)(3, 4), np.tri(3, 4, 1)) + + # Test k < 0 + self.assertAllClose(knp.Tri(k=-1)(3), np.tri(3, k=-1)) + + # Test -k-1 > N + self.assertAllClose(knp.Tri(k=-5)(3), np.tri(3, k=-5)) + + # Test k > M + self.assertAllClose(knp.Tri(k=4)(3), np.tri(3, k=4)) def create_sparse_tensor(x, indices_from=None, start=0, delta=2): @@ -5077,6 +5268,18 @@ def test_max(self, dtype): self.assertEqual(standardize_dtype(knp.max(x).dtype), expected_dtype) self.assertEqual(knp.Max().symbolic_call(x).dtype, expected_dtype) + # Test with 
initial + initial = 1 + expected_dtype = standardize_dtype( + jnp.max(x_jax, initial=initial).dtype + ) + self.assertEqual( + standardize_dtype(knp.max(x, initial=initial).dtype), expected_dtype + ) + self.assertEqual( + knp.Max(initial=initial).symbolic_call(x).dtype, expected_dtype + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_ones(self, dtype): import jax.numpy as jnp @@ -5235,6 +5438,25 @@ def test_argmin(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_argpartition(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + self.skipTest("argpartition doesn't support bool dtype") + + x = knp.array([1, 2, 3], dtype=dtype) + x_jax = jnp.array([1, 2, 3], dtype=dtype) + expected_dtype = standardize_dtype(jnp.argpartition(x_jax, 1).dtype) + + self.assertEqual( + standardize_dtype(knp.argpartition(x, 1).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Argpartition(1).symbolic_call(x).dtype), + expected_dtype, + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_argsort(self, dtype): import jax.numpy as jnp @@ -6048,19 +6270,19 @@ def test_eye(self, dtype): expected_dtype, ) self.assertEqual( - standardize_dtype(knp.Eye().symbolic_call(3, dtype=dtype).dtype), + standardize_dtype(knp.Eye(dtype=dtype).symbolic_call(3).dtype), expected_dtype, ) expected_dtype = standardize_dtype(jnp.eye(3, 4, 1, dtype=dtype).dtype) self.assertEqual( - standardize_dtype(knp.eye(3, 4, 1, dtype=dtype).dtype), + standardize_dtype(knp.eye(3, 4, k=1, dtype=dtype).dtype), expected_dtype, ) self.assertEqual( standardize_dtype( - knp.Eye().symbolic_call(3, 4, 1, dtype=dtype).dtype + knp.Eye(k=1, dtype=dtype).symbolic_call(3, 4).dtype ), expected_dtype, ) @@ -6746,6 +6968,18 @@ def test_min(self, dtype): self.assertEqual(standardize_dtype(knp.min(x).dtype), expected_dtype) self.assertEqual(knp.Min().symbolic_call(x).dtype, expected_dtype) + # Test with initial + initial = 0 + expected_dtype = standardize_dtype( + jnp.min(x_jax, initial=initial).dtype + ) + self.assertEqual( + standardize_dtype(knp.min(x, initial=initial).dtype), expected_dtype + ) + self.assertEqual( + knp.Min(initial=initial).symbolic_call(x).dtype, expected_dtype + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) ) @@ -6871,7 +7105,10 @@ def test_nonzero(self, dtype): self.assertEqual( standardize_dtype(knp.nonzero(x)[0].dtype), expected_dtype ) - # TODO: verify Nonzero + self.assertEqual( + standardize_dtype(knp.Nonzero().symbolic_call(x)[0].dtype), + expected_dtype, + ) @parameterized.named_parameters( named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) @@ -7521,7 +7758,7 @@ def test_tri(self, dtype): expected_dtype, ) self.assertEqual( - standardize_dtype(knp.Tri().symbolic_call(3, dtype=dtype).dtype), + standardize_dtype(knp.Tri(dtype=dtype).symbolic_call(3).dtype), expected_dtype, ) diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index 94ca0ea1cee..6cb63ae9709 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -6,11 +6,12 @@ from keras.src import ops from keras.src.optimizers.schedules import learning_rate_schedule from keras.src.saving import serialization_lib +from keras.src.saving.keras_saveable import KerasSaveable from keras.src.utils import tracking from keras.src.utils.naming import auto_name -class BaseOptimizer: +class 
BaseOptimizer: +class BaseOptimizer(KerasSaveable): def __init__( self, learning_rate,
@@ -814,6 +815,9 @@ def finalize_variable_values(self, var_list): # optimizer. self._overwrite_model_variables_with_average_value(var_list) + def _obj_type(self): + return "Optimizer" + def get_config(self): """Returns the config of the optimizer.
diff --git a/keras/src/optimizers/optimizer_test.py b/keras/src/optimizers/optimizer_test.py index 6ab982d25d2..5706633a62a 100644 --- a/keras/src/optimizers/optimizer_test.py +++ b/keras/src/optimizers/optimizer_test.py
@@ -1,7 +1,9 @@ import os +import pickle import numpy as np import pytest +from absl.testing import parameterized from keras.src import backend from keras.src import constraints
@@ -11,7 +13,7 @@ from keras.src import testing -class OptimizerTest(testing.TestCase): +class OptimizerTest(testing.TestCase, parameterized.TestCase): def test_iterations_counter(self): v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])
@@ -318,3 +320,24 @@ def test_setting_lr_to_callable_untracks_lr_var(self): adam.learning_rate, 4 ) self.assertLen(adam.variables, 1) + + @parameterized.parameters( + [ + ("adam",), + ("sgd",), + ("adamw",), + ("adagrad",), + ("rmsprop",), + ("adadelta",), + ("adamax",), + ("lion",), + ("nadam",), + ("ftrl",), + ("adafactor",), + ] + ) + def test_pickleable_optimizers(self, optimizer): + optimizer = optimizers.get(optimizer) + reloaded = pickle.loads(pickle.dumps(optimizer)) + + self.assertEqual(optimizer.get_config(), reloaded.get_config())
diff --git a/keras/src/saving/keras_saveable.py b/keras/src/saving/keras_saveable.py new file mode 100644 index 00000000000..7fc536b470c --- /dev/null +++ b/keras/src/saving/keras_saveable.py
@@ -0,0 +1,38 @@ +import io + + +class KerasSaveable: + # Note: renaming this function will cause old pickles to be broken. + # This is probably not a huge deal, as pickle should not be a recommended + # saving format -- it should only be supported for use with distributed + # computing frameworks. + + def _obj_type(self): + raise NotImplementedError( + "KerasSaveable subclasses must provide an " + "implementation for `_obj_type()`" + ) + + @classmethod + def _unpickle_model(cls, bytesio): + import keras.src.saving.saving_lib as saving_lib + + # pickle is not safe regardless of what you do. + return saving_lib._load_model_from_fileobj( + bytesio, custom_objects=None, compile=True, safe_mode=False + ) + + def __reduce__(self): + """__reduce__ is used to customize the behavior of `pickle.dumps()`. + + The method returns a tuple of two elements: a function, and a tuple of + arguments to pass to that function. In this case we just leverage the + keras saving library.""" + import keras.src.saving.saving_lib as saving_lib + + buf = io.BytesIO() + saving_lib._save_model_to_fileobj(self, buf, "h5") + return ( + self._unpickle_model, + (buf,), + )
diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py index 7de68802e91..c16d2ffafb4 100644 --- a/keras/src/saving/saving_lib.py +++ b/keras/src/saving/saving_lib.py
@@ -41,16 +41,16 @@ def save_model(model, filepath, weights_format="h5"): The zip-based archive contains the following structure: - JSON-based configuration file (config.json): Records of model, layer, and - other trackables' configuration. - - H5-based trackable state files, found in respective directories, such as + other saveables' configuration. 
+ - H5-based saveable state files, found in respective directories, such as model/states.npz, model/dense_layer/states.npz, etc. - Metadata file. - The states of Keras trackables (layers, optimizers, loss, and metrics) are + The states of Keras saveables (layers, optimizers, loss, and metrics) are automatically saved as long as they can be discovered through the attributes returned by `dir(Model)`. Typically, the state includes the variables - associated with the trackable, but some specially purposed layers may - contain more such as the vocabularies stored in the hashmaps. The trackables + associated with the saveable, but some specially purposed layers may + contain more such as the vocabularies stored in the hashmaps. The saveables define how their states are saved by exposing `save_state()` and `load_state()` APIs. @@ -129,7 +129,7 @@ def _save_model_to_fileobj(model, fileobj, weights_format): weights_store=weights_store, assets_store=asset_store, inner_path="", - visited_trackables=set(), + visited_saveables=set(), ) weights_store.close() asset_store.close() @@ -188,22 +188,22 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode): else: asset_store = None - failed_trackables = set() + failed_saveables = set() error_msgs = {} _load_state( model, weights_store=weights_store, assets_store=asset_store, inner_path="", - visited_trackables=set(), - failed_trackables=failed_trackables, + visited_saveables=set(), + failed_saveables=failed_saveables, error_msgs=error_msgs, ) weights_store.close() if asset_store: asset_store.close() - if failed_trackables: + if failed_saveables: _raise_loading_failure(error_msgs) return model @@ -223,15 +223,15 @@ def save_weights_only(model, filepath, objects_to_skip=None): ) weights_store = H5IOStore(filepath, mode="w") if objects_to_skip is not None: - visited_trackables = set(id(o) for o in objects_to_skip) + visited_saveables = set(id(o) for o in objects_to_skip) else: - visited_trackables = set() + visited_saveables = set() _save_state( model, weights_store=weights_store, assets_store=None, inner_path="", - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) weights_store.close() @@ -254,11 +254,11 @@ def load_weights_only( _VARS_FNAME + ".h5", archive=archive, mode="r" ) - failed_trackables = set() + failed_saveables = set() if objects_to_skip is not None: - visited_trackables = set(id(o) for o in objects_to_skip) + visited_saveables = set(id(o) for o in objects_to_skip) else: - visited_trackables = set() + visited_saveables = set() error_msgs = {} _load_state( model, @@ -266,25 +266,25 @@ def load_weights_only( assets_store=None, inner_path="", skip_mismatch=skip_mismatch, - visited_trackables=visited_trackables, - failed_trackables=failed_trackables, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, error_msgs=error_msgs, ) weights_store.close() if archive: archive.close() - if failed_trackables: + if failed_saveables: _raise_loading_failure(error_msgs, warn_only=skip_mismatch) def _raise_loading_failure(error_msgs, warn_only=False): first_key = list(error_msgs.keys())[0] - ex_trackable, ex_error = error_msgs[first_key] + ex_saveable, ex_error = error_msgs[first_key] msg = ( f"A total of {len(error_msgs)} objects could not " "be loaded. 
Example error message for " - f"object {ex_trackable}:\n\n" + f"object {ex_saveable}:\n\n" f"{ex_error}\n\n" "List of objects that could not be loaded:\n" f"{[x[0] for x in error_msgs.values()]}" @@ -318,37 +318,30 @@ def _name_key(name): return name -def _walk_trackable(trackable): - from keras.src.models import Functional - from keras.src.models import Sequential - - if isinstance(trackable, Sequential): - obj_type = "Sequential" - elif isinstance(trackable, Functional): - obj_type = "Functional" - elif isinstance(trackable, Layer): - obj_type = "Layer" - elif isinstance(trackable, Optimizer): - obj_type = "Optimizer" - elif isinstance(trackable, Metric): - obj_type = "Metric" - elif isinstance(trackable, Loss): - obj_type = "Loss" - else: - raise ValueError(f"Invalid obj_type: {obj_type}") +def _walk_saveable(saveable): + from keras.src.saving.keras_saveable import KerasSaveable + + if not isinstance(saveable, KerasSaveable): + raise ValueError( + "Expected object to be an " + "instance of `KerasSaveable`, but " + f"got {saveable} of type {type(saveable)}" + ) + + obj_type = saveable._obj_type() attr_skiplist = get_attr_skiplist(obj_type) # Save all layers directly tracked by Sequential and Functional first. # This helps avoid ordering concerns for subclassed Sequential or Functional # models with extra attributes--the internal Keras state take precedence. if obj_type in ("Sequential", "Functional"): - yield "layers", trackable.layers + yield "layers", saveable.layers - for child_attr in sorted(dir(trackable), key=lambda x: _name_key(x)): + for child_attr in sorted(dir(saveable), key=lambda x: _name_key(x)): if child_attr.startswith("__") or child_attr in attr_skiplist: continue try: - child_obj = getattr(trackable, child_attr) + child_obj = getattr(saveable, child_attr) except Exception: # Avoid raising the exception when visiting the attributes. continue @@ -356,26 +349,28 @@ def _walk_trackable(trackable): def _save_state( - trackable, + saveable, weights_store, assets_store, inner_path, - visited_trackables, + visited_saveables, ): - # If the trackable has already been saved, skip it. - if id(trackable) in visited_trackables: + from keras.src.saving.keras_saveable import KerasSaveable + + # If the saveable has already been saved, skip it. + if id(saveable) in visited_saveables: return - if hasattr(trackable, "save_own_variables") and weights_store: - trackable.save_own_variables(weights_store.make(inner_path)) - if hasattr(trackable, "save_assets") and assets_store: - trackable.save_assets(assets_store.make(inner_path)) + if hasattr(saveable, "save_own_variables") and weights_store: + saveable.save_own_variables(weights_store.make(inner_path)) + if hasattr(saveable, "save_assets") and assets_store: + saveable.save_assets(assets_store.make(inner_path)) - visited_trackables.add(id(trackable)) + visited_saveables.add(id(saveable)) - # Recursively save state of children trackables (layers, optimizers, etc.) - for child_attr, child_obj in _walk_trackable(trackable): - if _is_keras_trackable(child_obj): + # Recursively save state of children saveables (layers, optimizers, etc.) 
+ for child_attr, child_obj in _walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): _save_state( child_obj, weights_store, @@ -383,7 +378,7 @@ def _save_state( inner_path=file_utils.join(inner_path, child_attr).replace( "\\", "/" ), - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) elif isinstance(child_obj, (list, dict, tuple, set)): _save_container_state( @@ -393,55 +388,57 @@ def _save_state( inner_path=file_utils.join(inner_path, child_attr).replace( "\\", "/" ), - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) def _load_state( - trackable, + saveable, weights_store, assets_store, inner_path, skip_mismatch=False, - visited_trackables=None, - failed_trackables=None, + visited_saveables=None, + failed_saveables=None, error_msgs=None, ): - if visited_trackables and id(trackable) in visited_trackables: + from keras.src.saving.keras_saveable import KerasSaveable + + if visited_saveables and id(saveable) in visited_saveables: return failure = False - if hasattr(trackable, "load_own_variables") and weights_store: - if skip_mismatch or failed_trackables is not None: + if hasattr(saveable, "load_own_variables") and weights_store: + if skip_mismatch or failed_saveables is not None: try: - trackable.load_own_variables(weights_store.get(inner_path)) + saveable.load_own_variables(weights_store.get(inner_path)) except Exception as e: - failed_trackables.add(id(trackable)) - error_msgs[id(trackable)] = trackable, e + failed_saveables.add(id(saveable)) + error_msgs[id(saveable)] = saveable, e failure = True else: - trackable.load_own_variables(weights_store.get(inner_path)) + saveable.load_own_variables(weights_store.get(inner_path)) - if hasattr(trackable, "load_assets") and assets_store: - if skip_mismatch or failed_trackables is not None: + if hasattr(saveable, "load_assets") and assets_store: + if skip_mismatch or failed_saveables is not None: try: - trackable.load_assets(assets_store.get(inner_path)) + saveable.load_assets(assets_store.get(inner_path)) except Exception as e: - failed_trackables.add(id(trackable)) - error_msgs[id(trackable)] = trackable, e + failed_saveables.add(id(saveable)) + error_msgs[id(saveable)] = saveable, e failure = True else: - trackable.load_assets(assets_store.get(inner_path)) + saveable.load_assets(assets_store.get(inner_path)) - if failed_trackables is not None: - currently_failed = len(failed_trackables) + if failed_saveables is not None: + currently_failed = len(failed_saveables) else: currently_failed = 0 - # Recursively load states for Keras trackables such as layers/optimizers. - for child_attr, child_obj in _walk_trackable(trackable): - if _is_keras_trackable(child_obj): + # Recursively load states for Keras saveables such as layers/optimizers. 
+ for child_attr, child_obj in _walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): _load_state( child_obj, weights_store, @@ -450,8 +447,8 @@ def _load_state( "\\", "/" ), skip_mismatch=skip_mismatch, - visited_trackables=visited_trackables, - failed_trackables=failed_trackables, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, error_msgs=error_msgs, ) elif isinstance(child_obj, (list, dict, tuple, set)): @@ -463,48 +460,50 @@ def _load_state( "\\", "/" ), skip_mismatch=skip_mismatch, - visited_trackables=visited_trackables, - failed_trackables=failed_trackables, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, error_msgs=error_msgs, ) - if failed_trackables is not None: - newly_failed = len(failed_trackables) - currently_failed + if failed_saveables is not None: + newly_failed = len(failed_saveables) - currently_failed else: newly_failed = 0 if not failure: - if visited_trackables is not None and newly_failed <= 0: - visited_trackables.add(id(trackable)) - if id(trackable) in failed_trackables: - failed_trackables.remove(id(trackable)) - error_msgs.pop(id(trackable)) + if visited_saveables is not None and newly_failed <= 0: + visited_saveables.add(id(saveable)) + if id(saveable) in failed_saveables: + failed_saveables.remove(id(saveable)) + error_msgs.pop(id(saveable)) def _save_container_state( - container, weights_store, assets_store, inner_path, visited_trackables + container, weights_store, assets_store, inner_path, visited_saveables ): + from keras.src.saving.keras_saveable import KerasSaveable + used_names = {} if isinstance(container, dict): container = list(container.values()) - for trackable in container: - if _is_keras_trackable(trackable): - # Do NOT address the trackable via `trackable.name`, since + for saveable in container: + if isinstance(saveable, KerasSaveable): + # Do NOT address the saveable via `saveable.name`, since # names are usually autogenerated and thus not reproducible # (i.e. they may vary across two instances of the same model). 
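+            # Instead, derive a stable name from the snake_case class name
+            # plus a per-container counter, e.g. two `Dense` layers in the
+            # same container are stored as "dense" and "dense_1".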
- name = naming.to_snake_case(trackable.__class__.__name__) + name = naming.to_snake_case(saveable.__class__.__name__) if name in used_names: used_names[name] += 1 name = f"{name}_{used_names[name]}" else: used_names[name] = 0 _save_state( - trackable, + saveable, weights_store, assets_store, inner_path=file_utils.join(inner_path, name).replace("\\", "/"), - visited_trackables=visited_trackables, + visited_saveables=visited_saveables, ) @@ -514,30 +513,32 @@ def _load_container_state( assets_store, inner_path, skip_mismatch, - visited_trackables, - failed_trackables, + visited_saveables, + failed_saveables, error_msgs, ): + from keras.src.saving.keras_saveable import KerasSaveable + used_names = {} if isinstance(container, dict): container = list(container.values()) - for trackable in container: - if _is_keras_trackable(trackable): - name = naming.to_snake_case(trackable.__class__.__name__) + for saveable in container: + if isinstance(saveable, KerasSaveable): + name = naming.to_snake_case(saveable.__class__.__name__) if name in used_names: used_names[name] += 1 name = f"{name}_{used_names[name]}" else: used_names[name] = 0 _load_state( - trackable, + saveable, weights_store, assets_store, inner_path=file_utils.join(inner_path, name).replace("\\", "/"), skip_mismatch=skip_mismatch, - visited_trackables=visited_trackables, - failed_trackables=failed_trackables, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, error_msgs=error_msgs, ) @@ -793,20 +794,14 @@ def get_attr_skiplist(obj_type): ref_obj = Loss() skiplist += dir(ref_obj) else: - raise ValueError(f"Invalid obj_type: {obj_type}") + raise ValueError( + f"get_attr_skiplist got invalid {obj_type=}. " + "Accepted values for `obj_type` are " + "['Layer', 'Functional', 'Sequential', 'Metric', " + "'Optimizer', 'Loss']" + ) + global_state.set_global_attribute( f"saving_attr_skiplist_{obj_type}", skiplist ) return skiplist - - -def _is_keras_trackable(obj): - return isinstance( - obj, - ( - Layer, - Optimizer, - Metric, - Loss, - ), - ) diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py index 6b2d483d759..23ac2a52e73 100644 --- a/keras/src/saving/saving_lib_test.py +++ b/keras/src/saving/saving_lib_test.py @@ -906,3 +906,15 @@ def func(in_size=4, out_size=2, name=None): out = new_model(x) self.assertAllClose(ref_out[0], out[0]) self.assertAllClose(ref_out[1], out[1]) + + def test_bidirectional_lstm_saving(self): + inputs = keras.Input((3, 2)) + outputs = keras.layers.Bidirectional(keras.layers.LSTM(64))(inputs) + model = keras.Model(inputs, outputs) + temp_filepath = os.path.join(self.get_temp_dir(), "bidir_lstm.keras") + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + x = np.random.random((1, 3, 2)) + ref_out = model(x) + out = new_model(x) + self.assertAllClose(ref_out, out) diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index 40125572809..3adc832884e 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -162,6 +162,11 @@ def serialize_keras_object(obj): "step": serialize_keras_object(obj.step), }, } + # Ellipsis is an instance, and ellipsis class is not in global scope. + # checking equality also fails elsewhere in the library, so we have + # to dynamically get the type. 
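+    # `type(Ellipsis)` is the built-in `ellipsis` type; the object is
+    # serialized as {"class_name": "__ellipsis__", "config": {}} and mapped
+    # back to `Ellipsis` by `deserialize_keras_object` below.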
+ if isinstance(obj, type(Ellipsis)): + return {"class_name": "__ellipsis__", "config": {}} if isinstance(obj, backend.KerasTensor): history = getattr(obj, "_keras_history", None) if history: @@ -613,6 +618,8 @@ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError): return np.array(inner_config["value"], dtype=inner_config["dtype"]) if config["class_name"] == "__bytes__": return inner_config["value"].encode("utf-8") + if config["class_name"] == "__ellipsis__": + return Ellipsis if config["class_name"] == "__slice__": return slice( deserialize_keras_object( diff --git a/keras/src/saving/serialization_lib_test.py b/keras/src/saving/serialization_lib_test.py index 06ed6ac7198..80df36f3eeb 100644 --- a/keras/src/saving/serialization_lib_test.py +++ b/keras/src/saving/serialization_lib_test.py @@ -107,6 +107,10 @@ def tuples_to_lists_str(x): reserialized_str = tuples_to_lists_str(reserialized) self.assertEqual(serialized_str, reserialized_str) + def test_serialize_ellipsis(self): + _, deserialized, _ = self.roundtrip(Ellipsis) + self.assertEqual(..., deserialized) + def test_tensors_and_shapes(self): x = ops.random.normal((2, 2), dtype="float64") obj = {"x": x} diff --git a/keras/src/testing/test_case.py b/keras/src/testing/test_case.py index 0b6fd9d40f3..4b46d1d9b25 100644 --- a/keras/src/testing/test_case.py +++ b/keras/src/testing/test_case.py @@ -99,6 +99,20 @@ def assertSparse(self, x, sparse=True): f"Backend {backend.backend()} does not support sparse tensors", ) + def assertDType(self, x, dtype, msg=None): + if hasattr(x, "dtype"): + x_dtype = backend.standardize_dtype(x.dtype) + else: + # If x is a python number + x_dtype = backend.standardize_dtype(type(x)) + standardized_dtype = backend.standardize_dtype(dtype) + default_msg = ( + "The dtype of x does not match the expected one. " + f"Received: x.dtype={x_dtype} and dtype={dtype}" + ) + msg = msg or default_msg + self.assertEqual(x_dtype, standardized_dtype, msg=msg) + def run_class_serialization_test(self, instance, custom_objects=None): from keras.src.saving import custom_object_scope from keras.src.saving import deserialize_keras_object diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 729c2692133..afb31ed716b 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -164,7 +164,7 @@ def variables(self): if not self.built: return [] vars = [] - for m in self._flat_metrics + self._flat_weighted_metrics: + for m in self.metrics: if m is not None: vars.extend(m.variables) return vars @@ -413,10 +413,14 @@ def __init__( reduction="sum_over_batch_size", output_names=None, ): - if loss_weights and not isinstance(loss_weights, (list, tuple, dict)): + if loss_weights and not isinstance( + loss_weights, (list, tuple, dict, float) + ): raise ValueError( - "Expected `loss_weights` argument to be a list, tuple, or " - f"dict. Received instead: loss_weights={loss_weights} " + "Expected `loss_weights` argument to be a float " + "(single output case) or a list, tuple, or " + "dict (multiple output case). 
" + f"Received instead: loss_weights={loss_weights} " f"of type {type(loss_weights)}" ) self._user_loss = loss diff --git a/keras/src/trainers/data_adapters/__init__.py b/keras/src/trainers/data_adapters/__init__.py index 41f2a91f11a..3dc04b75498 100644 --- a/keras/src/trainers/data_adapters/__init__.py +++ b/keras/src/trainers/data_adapters/__init__.py @@ -71,6 +71,13 @@ def get_data_adapter( "sample_weights", "the sample weights", "PyDataset" ) return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle) + # TODO: should we warn or not? + # if x.num_batches is None and shuffle: + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a infinite PyDataset. The " + # "PyDataset is expected to already be shuffled." + # ) elif is_torch_dataloader(x): if y is not None: raise_unsupported_arg("y", "the targets", "torch DataLoader") diff --git a/keras/src/trainers/data_adapters/array_data_adapter_test.py b/keras/src/trainers/data_adapters/array_data_adapter_test.py index 46eb4fcc194..a61a904240e 100644 --- a/keras/src/trainers/data_adapters/array_data_adapter_test.py +++ b/keras/src/trainers/data_adapters/array_data_adapter_test.py @@ -52,11 +52,10 @@ def make_array(self, array_type, shape, dtype): "scipy_sparse", ], array_dtype=["float32", "float64"], - iterator_type=["np", "tf", "jax", "torch"], shuffle=[False, "batch", True], ) ) - def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle): + def test_basic_flow(self, array_type, array_dtype, shuffle): x = self.make_array(array_type, (34, 4), array_dtype) y = self.make_array(array_type, (34, 2), "int32") xdim1 = 1 if array_type == "pandas_series" else 4 @@ -75,10 +74,10 @@ def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle): self.assertEqual(adapter.has_partial_batch, True) self.assertEqual(adapter.partial_batch_size, 2) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() if array_type == "tf_ragged": expected_class = tf.RaggedTensor @@ -88,13 +87,13 @@ def test_basic_flow(self, array_type, array_dtype, iterator_type, shuffle): expected_class = tf.SparseTensor else: expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"): expected_class = jax_sparse.JAXSparse else: expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -245,5 +244,58 @@ def test_class_weights(self, target_encoding): self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4]) def test_errors(self): - # TODO - pass + x = np.random.random((34, 1)) + y = np.random.random((34, 3)) + sw = np.random.random((34,)) + cw = { + 0: 0.1, + 1: 0.2, + 2: 0.3, + 3: 0.4, + } + + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter(x="Invalid") + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter(x=x, y="Invalid") + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=y, sample_weight="Invalid" + ) + + with self.assertRaisesRegex( + ValueError, 
"You cannot `class_weight` and `sample_weight`" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=y, sample_weight=sw, class_weight=cw + ) + + nested_y = ({"x": x, "y": y},) + with self.assertRaisesRegex( + ValueError, "You should provide one `sample_weight` array per" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, sample_weight=[] + ) + + tensor_sw = self.make_array("tf", (34, 2), "int32") + with self.assertRaisesRegex( + ValueError, "For a model with multiple outputs, when providing" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, sample_weight=tensor_sw + ) + + with self.assertRaisesRegex( + ValueError, + "`class_weight` is only supported for Models with a single", + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, class_weight=cw + ) diff --git a/keras/src/trainers/data_adapters/generator_data_adapter_test.py b/keras/src/trainers/data_adapters/generator_data_adapter_test.py index 4d6ebdc5597..76839308c7f 100644 --- a/keras/src/trainers/data_adapters/generator_data_adapter_test.py +++ b/keras/src/trainers/data_adapters/generator_data_adapter_test.py @@ -3,12 +3,14 @@ import jax import jax.experimental.sparse as jax_sparse import numpy as np +import pytest import scipy import tensorflow as tf import torch from absl.testing import parameterized from jax import numpy as jnp +from keras.src import backend from keras.src import testing from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import generator_data_adapter @@ -37,10 +39,9 @@ class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase): {"testcase_name": "no_weight", "use_sample_weight": False}, ], generator_type=["np", "tf", "jax", "torch"], - iterator_type=["np", "tf", "jax", "torch"], ) ) - def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): + def test_basic_flow(self, use_sample_weight, generator_type): x = np.random.random((34, 4)).astype("float32") y = np.array([[i, i] for i in range(34)], dtype="float32") sw = np.random.random((34,)).astype("float32") @@ -64,16 +65,16 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): ) adapter = generator_data_adapter.GeneratorDataAdapter(make_generator()) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() expected_class = np.ndarray - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.Tensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax.Array - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() expected_class = torch.Tensor @@ -101,10 +102,7 @@ def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): sample_order.append(by[i, 0]) self.assertAllClose(sample_order, list(range(34))) - @parameterized.named_parameters( - named_product(iterator_type=["np", "tf", "jax", "torch"]) - ) - def test_with_different_shapes(self, iterator_type): + def test_with_different_shapes(self): def generator(): yield np.ones([16, 4], "float32"), np.ones([16, 2], "float32") yield np.ones([16, 5], "float32"), np.ones([16, 2], "float32") @@ -112,13 +110,13 @@ def generator(): adapter = generator_data_adapter.GeneratorDataAdapter(generator()) - if iterator_type == "np": + if backend.backend() == "numpy": it = adapter.get_numpy_iterator() - elif iterator_type == "tf": + elif backend.backend() == "tensorflow": it = 
adapter.get_tf_dataset() - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() - elif iterator_type == "torch": + elif backend.backend() == "torch": it = adapter.get_torch_dataloader() for i, batch in enumerate(it): @@ -137,11 +135,13 @@ def generator(): self.assertEqual(by.shape, (2, 2)) @parameterized.named_parameters( - named_product( - generator_type=["tf", "jax", "scipy"], iterator_type=["tf", "jax"] - ) + named_product(generator_type=["tf", "jax", "scipy"]) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors", ) - def test_scipy_sparse_tensors(self, generator_type, iterator_type): + def test_scipy_sparse_tensors(self, generator_type): if generator_type == "tf": x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 4)) y = tf.SparseTensor([[0, 0], [1, 1]], [3.0, 4.0], (2, 2)) @@ -158,10 +158,10 @@ def generate(): adapter = generator_data_adapter.GeneratorDataAdapter(generate()) - if iterator_type == "tf": + if backend.backend() == "tensorflow": it = adapter.get_tf_dataset() expected_class = tf.SparseTensor - elif iterator_type == "jax": + elif backend.backend() == "jax": it = adapter.get_jax_iterator() expected_class = jax_sparse.BCOO diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index 71ab2a67736..19b48570577 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -1,3 +1,4 @@ +import itertools import multiprocessing.dummy import queue import random @@ -153,23 +154,26 @@ def __getitem__(self, index): """ raise NotImplementedError - def __len__(self): - """Number of batch in the PyDataset. + @property + def num_batches(self): + """Number of batches in the PyDataset. Returns: - The number of batches in the PyDataset. + The number of batches in the PyDataset or `None` to indicate that + the dataset is infinite. """ - raise NotImplementedError + # For backwards compatibility, support `__len__`. + if hasattr(self, "__len__"): + return len(self) + raise NotImplementedError( + "You need to implement the `num_batches` property:\n\n" + "@property\ndef num_batches(self):\n return ..." + ) def on_epoch_end(self): """Method called at the end of every epoch.""" pass - def __iter__(self): - """Create a generator that iterate over the PyDataset.""" - for i in range(len(self)): - yield self[i] - class PyDatasetAdapter(DataAdapter): """Adapter for `keras.utils.PyDataset` instances.""" @@ -234,23 +238,33 @@ def generator_fn(): else: def generator_fn(): - order = range(len(self.py_dataset)) - if self.shuffle: + num_batches = self.py_dataset.num_batches + indices = ( + range(num_batches) + if num_batches is not None + else itertools.count() + ) + if self.shuffle and num_batches is not None: # Match the shuffle convention in OrderedEnqueuer. 
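+                # Infinite datasets (num_batches is None) are not shuffled
+                # here: their indices come from itertools.count() and the
+                # PyDataset is expected to already be shuffled.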
- order = list(order) - random.shuffle(order) + indices = list(indices) + random.shuffle(indices) - for i in order: + for i in indices: yield self.py_dataset[i] return generator_fn def _get_iterator(self): + num_batches = self.py_dataset.num_batches gen_fn = self._make_multiprocessed_generator_fn() for i, batch in enumerate(gen_fn()): batch = self._standardize_batch(batch) yield batch - if i >= len(self.py_dataset) - 1 and self.enqueuer: + if ( + self.enqueuer + and num_batches is not None + and i >= num_batches - 1 + ): self.enqueuer.stop() def get_numpy_iterator(self): @@ -262,11 +276,11 @@ def get_jax_iterator(self): def get_tf_dataset(self): from keras.src.utils.module_utils import tensorflow as tf + num_batches = self.py_dataset.num_batches if self._output_signature is None: - num_samples = min( - data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC, - len(self.py_dataset), - ) + num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC + if num_batches is not None: + num_samples = min(num_samples, num_batches) batches = [ self._standardize_batch(self.py_dataset[i]) for i in range(num_samples) @@ -277,7 +291,7 @@ def get_tf_dataset(self): self._get_iterator, output_signature=self._output_signature, ) - if self.shuffle: + if self.shuffle and num_batches is not None: ds = ds.shuffle(8) ds = ds.prefetch(tf.data.AUTOTUNE) return ds @@ -292,7 +306,7 @@ def on_epoch_end(self): @property def num_batches(self): - return len(self.py_dataset) + return self.py_dataset.num_batches @property def batch_size(self): @@ -328,11 +342,6 @@ def get_worker_id_queue(): return _WORKER_ID_QUEUE -def init_pool(seqs): - global _SHARED_SEQUENCES - _SHARED_SEQUENCES = seqs - - def get_index(uid, i): """Get the value from the PyDataset `uid` at index `i`. @@ -520,31 +529,40 @@ def _wait_queue(self): def _run(self): """Submits request to the executor and queue the `Future` objects.""" - indices = list(range(len(self.py_dataset))) - if self.shuffle: - random.shuffle(indices) - self._send_py_dataset() # Share the initial py_dataset - while True: - with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: - for i in indices: + try: + num_batches = self.py_dataset.num_batches + indices = ( + range(num_batches) + if num_batches is not None + else itertools.count() + ) + if self.shuffle and num_batches is not None: + indices = list(indices) + random.shuffle(indices) + self._send_py_dataset() # Share the initial py_dataset + while True: + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: + for i in indices: + if self.stop_signal.is_set(): + return + + self.queue.put( + executor.apply_async(get_index, (self.uid, i)), + block=True, + ) + + # Done with the current epoch, waiting for the final batches + self._wait_queue() + if self.stop_signal.is_set(): + # We're done return - self.queue.put( - executor.apply_async(get_index, (self.uid, i)), - block=True, - ) - - # Done with the current epoch, waiting for the final batches - self._wait_queue() - - if self.stop_signal.is_set(): - # We're done - return - - # Call the internal on epoch end. - self.py_dataset.on_epoch_end() - self._send_py_dataset() # Update the pool + # Call the internal on epoch end. + self.py_dataset.on_epoch_end() + self._send_py_dataset() # Update the pool + except Exception as e: + self.queue.put(e) # Report exception def get(self): """Creates a generator to extract data from the queue. 
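# --- Editorial sketch (not part of the patch): with the `num_batches` change
# --- above, a user-defined PyDataset implements a `num_batches` property
# --- instead of `__len__`, returning None to mark the dataset as infinite.
# --- The class name `ArrayPyDataset` below is hypothetical; it mirrors
# --- `ExamplePyDataset` from the tests in this PR.
import math

from keras.utils import PyDataset


class ArrayPyDataset(PyDataset):
    def __init__(self, x, y, batch_size=32, **kwargs):
        # `workers`, `use_multiprocessing`, `max_queue_size` pass through
        # to `PyDataset.__init__` via **kwargs.
        super().__init__(**kwargs)
        self.x, self.y, self.batch_size = x, y, batch_size

    @property
    def num_batches(self):
        # Return None here instead to signal an infinite dataset.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return self.x[low:high], self.y[low:high]
# --- End of editorial sketch.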
@@ -558,7 +576,10 @@ def get(self): """ while self.is_running(): try: - inputs = self.queue.get(block=True, timeout=5).get() + value = self.queue.get(block=True, timeout=5) + if isinstance(value, Exception): + raise value # Propagate exception from other thread + inputs = value.get() if self.is_running(): self.queue.task_done() if inputs is not None: diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py index b1be7002ac5..ac661c2047a 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py @@ -3,10 +3,12 @@ import jax import numpy as np +import pytest import tensorflow as tf import torch from absl.testing import parameterized +from keras.src import backend from keras.src import testing from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import py_dataset_adapter @@ -15,21 +17,34 @@ class ExamplePyDataset(py_dataset_adapter.PyDataset): def __init__( - self, x_set, y_set, sample_weight=None, batch_size=32, delay=0, **kwargs + self, + x_set, + y_set, + sample_weight=None, + batch_size=32, + delay=0, + infinite=False, + **kwargs ): super().__init__(**kwargs) self.x, self.y = x_set, y_set self.batch_size = batch_size self.sample_weight = sample_weight self.delay = delay + self.infinite = infinite - def __len__(self): + @property + def num_batches(self): + if self.infinite: + return None return math.ceil(len(self.x) / self.batch_size) def __getitem__(self, idx): # Create artificial delay to test multiprocessing time.sleep(self.delay) + if self.infinite: + idx = idx % math.ceil(len(self.x) / self.batch_size) # Return x, y for batch idx. low = idx * self.batch_size # Cap upper bound at array length; the last batch may be smaller @@ -48,7 +63,8 @@ def __init__(self, inputs, batch_size=32, **kwargs): self.inputs = inputs self.batch_size = batch_size - def __len__(self): + @property + def num_batches(self): return math.ceil(len(self.inputs["x"]) / self.batch_size) def __getitem__(self, idx): @@ -63,6 +79,21 @@ def __getitem__(self, idx): return batch +class ExceptionPyDataset(py_dataset_adapter.PyDataset): + + @property + def num_batches(self): + return 4 + + def __getitem__(self, index): + if index < 2: + return ( + np.random.random((64, 4)).astype("float32"), + np.random.random((64, 2)).astype("float32"), + ) + raise ValueError("Expected exception") + + class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( named_product( @@ -98,7 +129,7 @@ class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase): "dataset_type": "torch", }, ], - iterator_type=["np", "tf", "jax", "torch"], + infinite=[True, False], shuffle=[True, False], ) ) @@ -106,11 +137,14 @@ def test_basic_flow( self, shuffle, dataset_type, - iterator_type, + infinite, workers=0, use_multiprocessing=False, max_queue_size=0, ): + if use_multiprocessing and (infinite or shuffle): + pytest.skip("Starting processes is slow, only test one variant") + set_random_seed(1337) x = np.random.random((64, 4)).astype("float32") y = np.array([[i, i] for i in range(64)], dtype="float32") @@ -127,21 +161,22 @@ def test_basic_flow( workers=workers, use_multiprocessing=use_multiprocessing, max_queue_size=max_queue_size, + infinite=infinite, ) adapter = py_dataset_adapter.PyDatasetAdapter( py_dataset, shuffle=shuffle ) - if iterator_type == "np": + if backend.backend() == "numpy": it = 
adapter.get_numpy_iterator()
             expected_class = np.ndarray
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
             expected_class = tf.Tensor
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
             expected_class = jax.Array
-        elif iterator_type == "torch":
+        elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
 
@@ -157,10 +192,16 @@ def test_basic_flow(
             self.assertEqual(by.shape, (16, 2))
             for i in range(by.shape[0]):
                 sample_order.append(by[i, 0])
-        if shuffle:
-            self.assertNotAllClose(sample_order, list(range(64)))
+            if infinite and len(sample_order) >= 128:
+                break
+        expected_order = list(range(64))
+        if infinite:
+            # When the dataset is infinite, we cycle through the data twice.
+            expected_order = expected_order + expected_order
+        if shuffle and not infinite:
+            self.assertNotAllClose(sample_order, expected_order)
         else:
-            self.assertAllClose(sample_order, list(range(64)))
+            self.assertAllClose(sample_order, expected_order)
 
         # TODO: test class_weight
         # TODO: test sample weights
@@ -234,13 +275,11 @@ def test_dict_inputs(self):
         self.assertEqual(tuple(bx.shape), (4, 4))
         self.assertEqual(tuple(by.shape), (4, 2))
 
-    @parameterized.named_parameters(
-        named_product(iterator_type=["np", "tf", "jax", "torch"])
-    )
-    def test_with_different_shapes(self, iterator_type):
+    def test_with_different_shapes(self):
 
         class TestPyDataset(py_dataset_adapter.PyDataset):
-            def __len__(self):
+            @property
+            def num_batches(self):
                 return 3
 
             def __getitem__(self, idx):
@@ -261,13 +300,13 @@ def __getitem__(self, idx):
             TestPyDataset(), shuffle=False
         )
 
-        if iterator_type == "np":
+        if backend.backend() == "numpy":
             it = adapter.get_numpy_iterator()
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-        elif iterator_type == "torch":
+        elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
 
         for i, batch in enumerate(it):
@@ -284,3 +323,54 @@ def __getitem__(self, idx):
             else:
                 self.assertEqual(bx.shape, (2, 6))
                 self.assertEqual(by.shape, (2, 2))
+
+    @parameterized.named_parameters(
+        [
+            {
+                "testcase_name": "multiprocessing",
+                "workers": 2,
+                "use_multiprocessing": True,
+                "max_queue_size": 10,
+            },
+            {
+                "testcase_name": "multithreading",
+                "workers": 2,
+                "max_queue_size": 10,
+            },
+            {
+                "testcase_name": "single",
+            },
+        ]
+    )
+    def test_exception_reported(
+        self,
+        workers=0,
+        use_multiprocessing=False,
+        max_queue_size=0,
+    ):
+        dataset = ExceptionPyDataset(
+            workers=workers,
+            use_multiprocessing=use_multiprocessing,
+            max_queue_size=max_queue_size,
+        )
+        adapter = py_dataset_adapter.PyDatasetAdapter(dataset, shuffle=False)
+
+        expected_exception_class = ValueError
+        if backend.backend() == "numpy":
+            it = adapter.get_numpy_iterator()
+        elif backend.backend() == "tensorflow":
+            it = adapter.get_tf_dataset()
+            # tf.data wraps the exception
+            expected_exception_class = tf.errors.InvalidArgumentError
+        elif backend.backend() == "jax":
+            it = adapter.get_jax_iterator()
+        elif backend.backend() == "torch":
+            it = adapter.get_torch_dataloader()
+
+        it = iter(it)
+        next(it)
+        next(it)
+        with self.assertRaisesRegex(
+            expected_exception_class, "Expected exception"
+        ):
+            next(it)
diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
index ad48c2d3c24..2535e505d61 100644
--- a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
+++ b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
@@ -2,20 +2,18 @@
 import jax
 import numpy as np
+import pytest
 import tensorflow as tf
 import torch
 from absl.testing import parameterized
 
+from keras.src import backend
 from keras.src import testing
-from keras.src.testing.test_utils import named_product
 from keras.src.trainers.data_adapters import tf_dataset_adapter
 
 
 class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase):
-    @parameterized.named_parameters(
-        named_product(iterator_type=["np", "tf", "jax", "torch"])
-    )
-    def test_basic_flow(self, iterator_type):
+    def test_basic_flow(self):
         x = tf.random.normal((34, 4))
         y = tf.random.normal((34, 2))
         base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)
 
@@ -26,16 +24,16 @@ def test_basic_flow(self, iterator_type):
         self.assertEqual(adapter.has_partial_batch, None)
         self.assertEqual(adapter.partial_batch_size, None)
 
-        if iterator_type == "np":
+        if backend.backend() == "numpy":
             it = adapter.get_numpy_iterator()
             expected_class = np.ndarray
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
             expected_class = tf.Tensor
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
             expected_class = jax.Array
-        elif iterator_type == "torch":
+        elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
 
@@ -258,10 +256,11 @@ def test_distribute_dataset(self):
             self.assertEqual(tuple(bx.shape), (2, 4))
             self.assertEqual(tuple(by.shape), (2, 2))
 
-    @parameterized.named_parameters(
-        named_product(iterator_type=["np", "tf", "jax"])
+    @pytest.mark.skipif(
+        not backend.SUPPORTS_SPARSE_TENSORS and backend.backend() != "numpy",
+        reason="Backend does not support sparse tensors",
     )
-    def test_tf_sparse_tensors(self, iterator_type):
+    def test_tf_sparse_tensors(self):
         x = tf.SparseTensor(
             indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 4)
         )
@@ -271,13 +270,13 @@ def test_tf_sparse_tensors(self, iterator_type):
         base_ds = tf.data.Dataset.from_tensors((x, y))
         adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds)
 
-        if iterator_type == "np":
+        if backend.backend() == "numpy":
             it = adapter.get_numpy_iterator()
             expected_class = np.ndarray
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
             expected_class = tf.SparseTensor
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
             expected_class = jax.experimental.sparse.BCOO
diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py
index e86f570d692..4d02f5592f6 100644
--- a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py
+++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py
@@ -6,6 +6,7 @@
 import torch
 from absl.testing import parameterized
 
+from keras.src import backend
 from keras.src import testing
 from keras.src.testing.test_utils import named_product
 from keras.src.trainers.data_adapters.torch_data_loader_adapter import (
@@ -14,10 +15,7 @@
 
 
 class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase):
-    @parameterized.named_parameters(
-        named_product(iterator_type=["np", "tf", "jax", "torch"])
-    )
-    def test_basic_dataloader(self, iterator_type):
+    def test_basic_dataloader(self):
         x = torch.normal(2, 3, size=(34, 4))
         y = torch.normal(1, 3, size=(34, 2))
         ds = torch.utils.data.TensorDataset(x, y)
@@ -29,16 +27,16 @@ def test_basic_dataloader(self, iterator_type):
         self.assertEqual(adapter.has_partial_batch, True)
         self.assertEqual(adapter.partial_batch_size, 2)
 
-        if iterator_type == "np":
+        if backend.backend() == "numpy":
             it = adapter.get_numpy_iterator()
             expected_class = np.ndarray
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
             expected_class = tf.Tensor
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
             expected_class = jax.Array
-        elif iterator_type == "torch":
+        elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
 
@@ -57,15 +55,9 @@ def test_basic_dataloader(self, iterator_type):
             self.assertEqual(by.shape, (2, 2))
 
     @parameterized.named_parameters(
-        named_product(
-            batch_size=[None, 3],
-            implements_len=[True, False],
-            iterator_type=["np", "tf", "jax", "torch"],
-        )
+        named_product(batch_size=[None, 3], implements_len=[True, False])
     )
-    def test_dataloader_iterable_dataset(
-        self, batch_size, implements_len, iterator_type
-    ):
+    def test_dataloader_iterable_dataset(self, batch_size, implements_len):
 
         class TestIterableDataset(torch.utils.data.IterableDataset):
             def __init__(self):
@@ -104,16 +96,16 @@ def __len__(self):
         self.assertIsNone(adapter.has_partial_batch)
         self.assertIsNone(adapter.partial_batch_size)
 
-        if iterator_type == "np":
+        if backend.backend() == "numpy":
             it = adapter.get_numpy_iterator()
             expected_class = np.ndarray
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
             expected_class = tf.Tensor
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
             expected_class = jax.Array
-        elif iterator_type == "torch":
+        elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
             expected_class = torch.Tensor
 
@@ -142,10 +134,7 @@ def __len__(self):
         else:
             self.assertEqual(batch_count, 10)
 
-    @parameterized.named_parameters(
-        named_product(iterator_type=["np", "tf", "jax", "torch"])
-    )
-    def test_with_different_shapes(self, iterator_type):
+    def test_with_different_shapes(self):
         x = (
             [np.ones([4], "float32")] * 16
             + [np.ones([5], "float32")] * 16
@@ -161,13 +150,13 @@ def test_with_different_shapes(self, iterator_type):
         self.assertEqual(adapter.has_partial_batch, True)
         self.assertEqual(adapter.partial_batch_size, 2)
 
-        if iterator_type == "np":
+        if backend.backend() == "numpy":
             it = adapter.get_numpy_iterator()
-        elif iterator_type == "tf":
+        elif backend.backend() == "tensorflow":
             it = adapter.get_tf_dataset()
-        elif iterator_type == "jax":
+        elif backend.backend() == "jax":
             it = adapter.get_jax_iterator()
-        elif iterator_type == "torch":
+        elif backend.backend() == "torch":
             it = adapter.get_torch_dataloader()
 
         for i, batch in enumerate(it):
diff --git a/keras/src/utils/audio_dataset_utils.py b/keras/src/utils/audio_dataset_utils.py
index ac1bab223b8..7e320188225 100644
--- a/keras/src/utils/audio_dataset_utils.py
+++ b/keras/src/utils/audio_dataset_utils.py
@@ -409,27 +409,44 @@ def paths_and_labels_to_dataset(
 ):
     """Constructs a fixed-size dataset of audio and labels."""
     path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
-    if shuffle:
-        path_ds = path_ds.shuffle(
-            buffer_size=shuffle_buffer_size or 1024, seed=seed
+    if label_mode:
+        label_ds = dataset_utils.labels_to_dataset(
+            labels, label_mode, num_classes
         )
+        ds = tf.data.Dataset.zip((path_ds, label_ds))
+    else:
+        ds = path_ds
 
-    audio_ds = path_ds.map(
-        lambda x: read_and_decode_audio(
-            x, sampling_rate, output_sequence_length
-        ),
-        num_parallel_calls=tf.data.AUTOTUNE,
-    )
+    if shuffle:
+        ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed)
 
-    if ragged:
-        audio_ds = audio_ds.map(
-            lambda x: tf.RaggedTensor.from_tensor(x),
+    if label_mode:
+        ds = ds.map(
+            lambda x, y: (
+                read_and_decode_audio(x, sampling_rate, output_sequence_length),
+                y,
+            ),
             num_parallel_calls=tf.data.AUTOTUNE,
         )
 
-    if label_mode:
-        label_ds = dataset_utils.labels_to_dataset(
-            labels, label_mode, num_classes
+        if ragged:
+            ds = ds.map(
+                lambda x, y: (tf.RaggedTensor.from_tensor(x), y),
+                num_parallel_calls=tf.data.AUTOTUNE,
+            )
+
+    else:
+        ds = ds.map(
+            lambda x: read_and_decode_audio(
+                x, sampling_rate, output_sequence_length
+            ),
+            num_parallel_calls=tf.data.AUTOTUNE,
        )
-        audio_ds = tf.data.Dataset.zip((audio_ds, label_ds))
-    return audio_ds
+
+        if ragged:
+            ds = ds.map(
+                lambda x: tf.RaggedTensor.from_tensor(x),
+                num_parallel_calls=tf.data.AUTOTUNE,
+            )
+
+    return ds
diff --git a/keras/src/utils/image_dataset_utils.py b/keras/src/utils/image_dataset_utils.py
index 30317c96780..380b4337973 100755
--- a/keras/src/utils/image_dataset_utils.py
+++ b/keras/src/utils/image_dataset_utils.py
@@ -367,12 +367,17 @@ def paths_and_labels_to_dataset(
     seed=None,
 ):
     """Constructs a dataset of images and labels."""
-    # TODO(fchollet): consider making num_parallel_calls settable
     path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
-    if shuffle:
-        path_ds = path_ds.shuffle(
-            buffer_size=shuffle_buffer_size or 1024, seed=seed
+    if label_mode:
+        label_ds = dataset_utils.labels_to_dataset(
+            labels, label_mode, num_classes
         )
+        ds = tf.data.Dataset.zip((path_ds, label_ds))
+    else:
+        ds = path_ds
+
+    if shuffle:
+        ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed)
 
     args = (
         image_size,
@@ -382,15 +387,16 @@ def paths_and_labels_to_dataset(
         crop_to_aspect_ratio,
         pad_to_aspect_ratio,
     )
-    img_ds = path_ds.map(
-        lambda x: load_image(x, *args), num_parallel_calls=tf.data.AUTOTUNE
-    )
     if label_mode:
-        label_ds = dataset_utils.labels_to_dataset(
-            labels, label_mode, num_classes
+        ds = ds.map(
+            lambda x, y: (load_image(x, *args), y),
+            num_parallel_calls=tf.data.AUTOTUNE,
+        )
+    else:
+        ds = ds.map(
+            lambda x: load_image(x, *args), num_parallel_calls=tf.data.AUTOTUNE
         )
-        img_ds = tf.data.Dataset.zip((img_ds, label_ds))
-    return img_ds
+    return ds
 
 
 def load_image(
diff --git a/keras/src/utils/model_visualization.py b/keras/src/utils/model_visualization.py
index 2784a83e3ca..1fd180339b1 100644
--- a/keras/src/utils/model_visualization.py
+++ b/keras/src/utils/model_visualization.py
@@ -112,18 +112,32 @@ def make_layer_label(layer, **kwargs):
         output_shape = tree.map_structure(lambda x: x.shape, layer.output)
     except (ValueError, AttributeError):
         pass
+
+    def format_shape(shape):
+        if shape is not None:
+            if isinstance(shape, dict):
+                shape_str = ", ".join(
+                    [f"{k}: {v}" for k, v in shape.items()]
+                )
+            else:
+                shape_str = f"{shape}"
+            shape_str = shape_str.replace("}", "").replace("{", "")
+        else:
+            shape_str = "?"
+        return shape_str
+
     if class_name != "InputLayer":
         cols.append(
             (
                 ''
-                f'Input shape: {input_shape or "?"}'
+                f"Input shape: {format_shape(input_shape)}"
                 ""
             )
         )
         cols.append(
             (
                 ''
-                f'Output shape: {output_shape or "?"}'
+                f"Output shape: {format_shape(output_shape)}"
                 ""
             )
         )
diff --git a/keras/src/utils/text_dataset_utils.py b/keras/src/utils/text_dataset_utils.py
index d8e5ece971c..ab1272bf190 100644
--- a/keras/src/utils/text_dataset_utils.py
+++ b/keras/src/utils/text_dataset_utils.py
@@ -258,21 +258,28 @@ def paths_and_labels_to_dataset(
 ):
     """Constructs a dataset of text strings and labels."""
     path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
-    if shuffle:
-        path_ds = path_ds.shuffle(
-            buffer_size=shuffle_buffer_size or 1024, seed=seed
-        )
-
-    string_ds = path_ds.map(
-        lambda x: path_to_string_content(x, max_length),
-        num_parallel_calls=tf.data.AUTOTUNE,
-    )
 
     if label_mode:
         label_ds = dataset_utils.labels_to_dataset(
             labels, label_mode, num_classes
         )
-        string_ds = tf.data.Dataset.zip((string_ds, label_ds))
-    return string_ds
+        ds = tf.data.Dataset.zip((path_ds, label_ds))
+    else:
+        ds = path_ds
+
+    if shuffle:
+        ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed)
+
+    if label_mode:
+        ds = ds.map(
+            lambda x, y: (path_to_string_content(x, max_length), y),
+            num_parallel_calls=tf.data.AUTOTUNE,
+        )
+    else:
+        ds = ds.map(
+            lambda x: path_to_string_content(x, max_length),
+            num_parallel_calls=tf.data.AUTOTUNE,
+        )
+    return ds
 
 
 def path_to_string_content(path, max_length):
diff --git a/keras/src/utils/tracking.py b/keras/src/utils/tracking.py
index 02678de336a..d24cfc3836a 100644
--- a/keras/src/utils/tracking.py
+++ b/keras/src/utils/tracking.py
@@ -146,10 +146,10 @@ def append(self, value):
             self.tracker.track(value)
         super().append(value)
 
-    def insert(self, value):
+    def insert(self, index, value):
         if self.tracker:
             self.tracker.track(value)
-        super().insert(value)
+        super().insert(index, value)
 
     def extend(self, values):
         if self.tracker:
diff --git a/keras/src/version.py b/keras/src/version.py
index c0168d651da..11e49a3b926 100644
--- a/keras/src/version.py
+++ b/keras/src/version.py
@@ -1,7 +1,7 @@
 from keras.src.api_export import keras_export
 
 # Unique source of truth for the version number.
-__version__ = "3.3.0"
+__version__ = "3.3.3"
 
 
 @keras_export("keras.version")
diff --git a/pip_build.py b/pip_build.py
index 887c7119e26..66e7578eee2 100644
--- a/pip_build.py
+++ b/pip_build.py
@@ -60,6 +60,62 @@ def export_version_string(version, is_nightly=False, rc_index=None):
         f.write(init_contents)
 
 
+def ignore_files(_, filenames):
+    return [f for f in filenames if f.endswith("_test.py")]
+
+
+def copy_source_to_build_directory(root_path):
+    # Copy sources (`keras/` directory and setup files) to build
+    # directory
+    os.chdir(root_path)
+    os.mkdir(build_directory)
+    shutil.copytree(
+        package, os.path.join(build_directory, package), ignore=ignore_files
+    )
+    for fname in to_copy:
+        shutil.copy(fname, os.path.join(f"{build_directory}", fname))
+    os.chdir(build_directory)
+
+
+def build(root_path, is_nightly=False, rc_index=None):
+    if os.path.exists(build_directory):
+        raise ValueError(f"Directory already exists: {build_directory}")
+
+    try:
+        copy_source_to_build_directory(root_path)
+        move_tf_keras_directory()
+
+        from keras.src.version import __version__  # noqa: E402
+
+        export_version_string(__version__, is_nightly, rc_index)
+        return build_and_save_output(root_path, __version__)
+    finally:
+        # Clean up: remove the build directory (no longer needed)
+        shutil.rmtree(build_directory)
+
+
+def move_tf_keras_directory():
+    """Move `keras/api/_tf_keras` to `keras/_tf_keras`, update references."""
+    shutil.move(os.path.join(package, "api", "_tf_keras"), "keras")
+    with open(os.path.join(package, "api", "__init__.py")) as f:
+        contents = f.read()
+    contents = contents.replace("from keras.api import _tf_keras", "")
+    with open(os.path.join(package, "api", "__init__.py"), "w") as f:
+        f.write(contents)
+    # Replace `keras.api._tf_keras` with `keras._tf_keras`.
+    for root, _, fnames in os.walk(os.path.join(package, "_tf_keras")):
+        for fname in fnames:
+            if fname.endswith(".py"):
+                tf_keras_fpath = os.path.join(root, fname)
+                with open(tf_keras_fpath) as f:
+                    contents = f.read()
+                contents = contents.replace(
+                    "keras.api._tf_keras", "keras._tf_keras"
+                )
+                with open(tf_keras_fpath, "w") as f:
+                    f.write(contents)
+
+
 def build_and_save_output(root_path, __version__):
     # Build the package
     os.system("python3 -m build")
@@ -85,13 +141,6 @@ def build_and_save_output(root_path, __version__):
     return whl_path
 
 
-def build(root_path, is_nightly=False, rc_index=None):
-    from keras.src.version import __version__  # noqa: E402
-
-    export_version_string(__version__, is_nightly, rc_index)
-    return build_and_save_output(root_path, __version__)
-
-
 def install_whl(whl_fpath):
     print(f"Installing wheel file: {whl_fpath}")
     os.system(f"pip3 install {whl_fpath} --force-reinstall --no-dependencies")
diff --git a/requirements-common.txt b/requirements-common.txt
index 5d15f7ac615..f645c7ba940 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -1,4 +1,4 @@
-namex
+namex>=0.0.8
 black>=22
 flake8
 isort
diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt
index 343026cf6ff..e21c1cb1c5b 100644
--- a/requirements-jax-cuda.txt
+++ b/requirements-jax-cuda.txt
@@ -3,7 +3,7 @@ tensorflow-cpu~=2.16.1  # Pin to TF 2.16
 
 # Torch cpu-only version (needed for testing).
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch>=2.1.0
+torch>=2.1.0, <2.3.0
 torchvision>=0.16.0
 
 # Jax with cuda support.
diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt
index 69be284f766..f3b946ddcfe 100644
--- a/requirements-tensorflow-cuda.txt
+++ b/requirements-tensorflow-cuda.txt
@@ -3,7 +3,7 @@ tensorflow[and-cuda]~=2.16.1  # Pin to TF 2.16
 
 # Torch cpu-only version (needed for testing).
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch>=2.1.0
+torch>=2.1.0, <2.3.0
 torchvision>=0.16.0
 
 # Jax cpu-only version (needed for testing).
diff --git a/requirements.txt b/requirements.txt
index d6ca6bf28c5..7c0000eed07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,9 @@
 tensorflow-cpu~=2.16.1  # Pin to TF 2.16
 
 # Torch.
+# TODO: Pin to < 2.3.0 (GitHub issue #19602)
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch>=2.1.0
+torch>=2.1.0, <2.3.0
 torchvision>=0.16.0
 
 # Jax.
diff --git a/shell/api_gen.sh b/shell/api_gen.sh
index 92cf6c7c247..389874b890a 100755
--- a/shell/api_gen.sh
+++ b/shell/api_gen.sh
@@ -9,4 +9,4 @@ python3 "${base_dir}"/api_gen.py
 
 echo "Formatting api directory..."
 # Format API Files
-bash "${base_dir}"/shell/format.sh > /dev/null 2>&1
+bash "${base_dir}"/shell/format.sh
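Note on the audio/image/text dataset-utils hunks above: the restructured pipelines zip file paths with labels before shuffling and only map the decode step afterwards, so shuffling can no longer desynchronize inputs from their labels. The following is a minimal sketch of the resulting pipeline shape, not part of the patch; the file names are hypothetical and tf.io.read_file stands in for the real decode helpers (read_and_decode_audio, load_image, path_to_string_content).

    import tensorflow as tf

    paths = ["a.txt", "b.txt", "c.txt"]  # hypothetical file paths
    labels = [0, 1, 0]

    path_ds = tf.data.Dataset.from_tensor_slices(paths)
    label_ds = tf.data.Dataset.from_tensor_slices(labels)

    # Pair paths with labels first, then shuffle the pairs together.
    ds = tf.data.Dataset.zip((path_ds, label_ds))
    ds = ds.shuffle(buffer_size=1024, seed=1337)
    # Decode after shuffling; the (input, label) pairing is preserved.
    ds = ds.map(
        lambda x, y: (tf.io.read_file(x), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    )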