Introduce float8 training #19488

Merged

fchollet merged 12 commits into keras-team:master from james77777778:add-fp8-v2 on Apr 15, 2024

Conversation

james77777778 (Contributor)

This PR introduces float8 training, particularly for the TF and JAX backends.

Notes

A New Attribute overwrite_with_gradient Has Been Added to Variable

If set to True, this attribute is honored by Optimizer: in BaseOptimizer.apply, BaseOptimizer._overwrite_variables_directly_with_gradients overwrites the variable directly with the computed gradient instead of applying the usual update rule.

With this new behavior, we can correctly update float8 parameters such as scale and amax_history.
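
For illustration, a minimal sketch of the contract (not the exact Keras source; the assignment shown is the conceptual effect inside `BaseOptimizer.apply`):

```python
import keras

# Opt a variable in to the new behavior: the optimizer skips its update rule
# for this variable and overwrites it with the computed "gradient" instead.
scale = keras.Variable(initializer="ones", shape=(), name="outputs_scale")
scale.overwrite_with_gradient = True

# Conceptually, inside BaseOptimizer.apply:
#     if variable.overwrite_with_gradient:
#         variable.assign(gradient)  # e.g. new scale / amax_history values
#     else:
#         ...apply the regular optimizer update...
```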

A New Quantization Mode float8 Has Been Added

To DTypePolicy:

# This should only be invoked by quantizable layers.
original_policy = dtype_policy().name
policy = QuantizedDTypePolicy(f"float8_from_{original_policy}")

To Dense and EinsumDense:

import numpy as np

from keras import layers
from keras import models

# Dense
x = np.random.rand(64, 32)
y = np.random.rand(64, 16)
model = models.Sequential([layers.Input((32,)), layers.Dense(units=16)])
model.quantize("float8")
model.compile(optimizer="sgd", loss="mse")
model.fit(x, y, epochs=2)

# EinsumDense
x = np.random.rand(64, 8)
y = np.random.rand(64, 8, 32)
model = models.Sequential(
    [layers.Input((8,)), layers.EinsumDense("ab,bcd->acd", (8, 32))]
)
model.quantize("float8")
model.compile(optimizer="sgd", loss="mse")
model.fit(x, y, epochs=2)
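
Equivalently, because quantization is keyed off the dtype policy name, a layer can be constructed already quantized by passing the policy name through the `dtype` argument (a sketch following the `"float8_from_..."` naming above, assuming a float32 base policy):

```python
from keras import layers

# Equivalent to building the layer and then calling `layer.quantize("float8")`.
layer = layers.Dense(units=16, dtype="float8_from_float32")
```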

The compiled graph in TF should contain the expected fused op __cublas$lt$matmul$f8, as shown in the HLO graph screenshot attached to this PR.

generate_hlo_dot.py
import tensorflow as tf

from keras import config
from keras import layers
from keras import ops

config.set_dtype_policy("mixed_bfloat16")

layer = layers.EinsumDense("abcde,aebf->adbcf", output_shape=[16, 16, 16, 16])
layer.build([32, 16, 16, 16, 16])
x = ops.ones([32, 16, 16, 16, 16])
layer.quantize("float8")

def fn(x):
    return layer(x)

# Compile with XLA and dump the optimized HLO graph as Graphviz source.
compiled_fn = tf.function(fn, jit_compile=True)
y = compiled_fn(x)
ir = compiled_fn.experimental_get_compiler_ir(x)(stage="optimized_hlo_dot")
with open("ir_tf.dot", "w") as f:
    f.write(ir)
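
The resulting .dot file can be rendered with Graphviz, e.g. `dot -Tpng ir_tf.dot -o ir_tf.png`, to verify that the fused op appears.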

Results

  • RTX 4070 12GB, CUDA 12.3, cuDNN 8.9
  • GPT2 Base: batch_size=64, frozen embeddings
  • GPT2 Medium: batch_size=24, frozen embeddings

| Backend | Model | Recipe | Training Speed | Peak Mem. Usage | Val. Acc. |
| --- | --- | --- | --- | --- | --- |
| TF | GPT2 Base | mixed_bfloat16 | 274~275 ms/step | 6.234 GB | 0.4997 |
| TF | GPT2 Base | float8 | 269~272 ms/step | 5.909 GB | 0.4971 |
| TF | GPT2 Medium | mixed_bfloat16 | 252~254 ms/step | 7.338 GB | 0.5535 |
| TF | GPT2 Medium | float8 | 268~271 ms/step | 7.002 GB | 0.5513 |
| JAX | GPT2 Base | mixed_bfloat16 | 207~208 ms/step | X | 0.4991 |
| JAX | GPT2 Base | float8 | 200~202 ms/step | X | 0.4977 |
| JAX | GPT2 Medium | mixed_bfloat16 | 236~238 ms/step | X | 0.5534 |
| JAX | GPT2 Medium | float8 | 232~234 ms/step | X | 0.5521 |

Standalone script:

fp8.py
"""
Dataset:
https://huggingface.co/datasets/databricks/databricks-dolly-15k/blob/main/databricks-dolly-15k.jsonl
"""

import argparse
import json

import kagglehub
import keras_nlp

import keras


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gpt2_base_en",
        choices=["gpt2_base_en", "gpt2_medium_en", "gemma_2b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--fp8",
        action="store_true",
        help="Whether to use float8 technique",
    )
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    args = parser.parse_args()
    return args


def get_optimizer_and_loss():
    optimizer = keras.optimizers.AdamW(
        learning_rate=5e-5,
        weight_decay=0.01,
        epsilon=1e-6,
        global_clipnorm=1.0,  # Gradient clipping.
    )
    # Exclude layernorm and bias terms from weight decay.
    optimizer.exclude_from_weight_decay(var_names=["bias"])
    optimizer.exclude_from_weight_decay(var_names=["gamma"])
    optimizer.exclude_from_weight_decay(var_names=["beta"])

    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return optimizer, loss


class GPUMemoryCallback(keras.callbacks.Callback):
    def __init__(self, target_batches, **kwargs):
        super().__init__(**kwargs)
        self.target_batches = target_batches
        self.memory_usage = []

    def _compute_memory_usage(self):
        if keras.backend.backend() == "tensorflow":
            import tensorflow as tf

            memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        else:
            memory_stats = {"peak": 0.0}
        # Convert bytes to GB and store in list.
        peak_usage = round(memory_stats["peak"] / (2**30), 3)
        self.memory_usage.append(peak_usage)

    def on_epoch_begin(self, epoch, logs=None):
        self._compute_memory_usage()

    def on_train_batch_begin(self, batch, logs=None):
        if batch in self.target_batches:
            self._compute_memory_usage()

    def on_epoch_end(self, epoch, logs=None):
        self._compute_memory_usage()


if __name__ == "__main__":
    EPOCHS = 2
    keras.config.disable_traceback_filtering()
    keras.mixed_precision.set_global_policy("mixed_bfloat16")

    args = get_args()
    if args.model == "gemma_2b_en":
        kagglehub.login()

    # Setup dataset
    data = []
    with open("databricks-dolly-15k.jsonl") as file:
        for line in file:
            features = json.loads(line)
            # Filter out examples with context, to keep it simple.
            if features["context"]:
                continue
            # Format the entire example as a single string.
            template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
            data.append(template.format(**features))

    # Only use 16000 training examples, to keep it fast.
    data = data[:16000]

    if args.model == "gemma_2b_en":
        preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset(
            args.model, sequence_length=128
        )
        model = keras_nlp.models.GemmaCausalLM.from_preset(
            args.model, preprocessor=preprocessor
        )
        model.backbone.token_embedding.trainable = False
    elif "gpt2" in args.model:
        preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
            args.model, sequence_length=128
        )
        model = keras_nlp.models.GPT2CausalLM.from_preset(
            args.model, preprocessor=preprocessor
        )
        model.backbone.token_embedding.trainable = False
        model.backbone.position_embedding.trainable = False
    if args.fp8:
        model.quantize("float8")

    model.summary()
    optimizer, loss = get_optimizer_and_loss()
    model.compile(optimizer=optimizer, loss=loss, weighted_metrics=["accuracy"])
    callbacks = [
        GPUMemoryCallback(target_batches=[5, 10, 25, 50, 100, 150, 200, 300])
    ]
    model.fit(
        data, batch_size=args.batch_size, epochs=EPOCHS, callbacks=callbacks
    )
    if keras.backend.backend() == "tensorflow":
        model_memory_usage = callbacks[0].memory_usage
        print(f"GPU Memory Usage (in GB): {max(model_memory_usage)}")

    if args.fp8:
        from keras import layers

        count = 0
        for layer in model._flatten_layers(False, True):
            list_of_sublayers = list(layer._flatten_layers())
            if len(list_of_sublayers) == 1:  # leaves of the model
                if isinstance(layer, (layers.Dense, layers.EinsumDense)):
                    print(
                        layer.outputs_grad_scale.path,
                        layer.outputs_grad_scale.value,
                    )
                    count += 1
                if count > 10:
                    break
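
For reference, a typical invocation of the script (using the flags defined in `get_args` above; the model preset and batch size are just examples):

```
python fp8.py --model gpt2_base_en --fp8 --batch-size 64
```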

Issues

References for Implementation

cc @kaixih

github-actions bot added the 'Gemma' (Gemma model specific issues) label Apr 11, 2024
codecov-commenter commented Apr 11, 2024

Codecov Report

Attention: Patch coverage is 95.32164%, with 16 lines in your changes missing coverage. Please review.

Project coverage is 76.51%. Comparing base (5187fac) to head (197c6e7).
Report is 5 commits behind head on master.

| Files | Patch % | Lines |
| --- | --- | --- |
| keras/layers/core/dense.py | 94.05% | 2 Missing and 4 partials ⚠️ |
| keras/layers/core/einsum_dense.py | 94.54% | 2 Missing and 4 partials ⚠️ |
| keras/layers/core/embedding.py | 92.00% | 1 Missing and 1 partial ⚠️ |
| keras/layers/layer.py | 80.00% | 0 Missing and 2 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19488      +/-   ##
==========================================
+ Coverage   76.27%   76.51%   +0.23%     
==========================================
  Files         367      367              
  Lines       41233    41578     +345     
  Branches     8076     8127      +51     
==========================================
+ Hits        31451    31813     +362     
+ Misses       8060     8052       -8     
+ Partials     1722     1713       -9     
| Flag | Coverage Δ |
| --- | --- |
| keras | 76.36% <94.15%> (+0.22%) ⬆️ |
| keras-jax | 60.54% <93.56%> (+0.24%) ⬆️ |
| keras-numpy | 54.36% <57.89%> (+0.08%) ⬆️ |
| keras-tensorflow | 61.77% <94.15%> (+0.22%) ⬆️ |
| keras-torch | 60.63% <92.39%> (+0.21%) ⬆️ |

Flags with carried forward coverage won't be shown.

james77777778 marked this pull request as draft April 11, 2024 15:07
james77777778 (Contributor, Author)

This PR needs to wait for #19495.

There might be bugs in LoRA and in torch-backend quantization if torch_params doesn't correctly track/untrack the added/removed variables.

fchollet (Collaborator) left a comment

Thanks for the PR! 👍

Resolved review threads: keras/backend/common/variables.py, keras/dtype_policies/dtype_policy.py, keras/layers/core/dense.py
@@ -94,17 +97,23 @@ def __init__(
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.lora_rank = lora_rank
self.amax_history_length = amax_history_length
fchollet (Collaborator):

I think this should not be set on the layer as a constructor arg, but on the DTypePolicy. If you want to configure it, call layer.quantize(Float8DtypePolicy(amax_history_length=32)) or something like that. We may want to add dedicated DTypePolicy subclasses to expose such arguments (this is the first one we encounter, but there may be more in the future).

james77777778 (Contributor, Author):

I struggled with passing args to DTypePolicy and I'm uncertain whether this is a good approach.

In the current codebase, configuring DTypePolicy for individual layers is challenging (as shown in #19381):

  1. DTypePolicy is treated as global state, making it difficult to set a different amax_history_length for different layers. (We also need to consider scenarios involving subclasses.)
  2. Serialization/deserialization may be an issue.

Based on these concerns, I adopted an approach similar to lora_rank: amax_history_length defaults to None and is overridden to 1024 in self._float8_build if not specified.

What do you think?

fchollet (Collaborator):

I would strongly prefer to keep dtype configuration and layer constructor args separate. All dtype configuration should go through DTypePolicy.

We also need to consider scenarios involving subclasses

If the dtype policy is a settable attribute, this is not a big issue -- you can instantiate the subclass, and then set the dtype policy on a specific layer owned by the subclass.

For serialization, the dtype policy (and its args) should be part of the layer config, under the dtype entry in the config dict.

lora_rank is simply a shortcut for: 1. instantiate the layer, 2. call enable_lora(). It's not necessary from a configuration standpoint.
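
For illustration, a minimal sketch of the equivalence described here (the input shape is hypothetical; `enable_lora` requires a built layer):

```python
from keras import layers

# Shortcut form: LoRA is enabled automatically when the layer is built.
layer = layers.Dense(16, lora_rank=4)

# Equivalent explicit form.
layer = layers.Dense(16)
layer.build((None, 8))
layer.enable_lora(4)
```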

james77777778 (Contributor, Author):

Got it. The idea of keeping the layer configuration minimal is solid.

As discussed above, I have introduced a QuantizedFloat8DTypePolicy in this PR. I think it is clearer to use "Quantized" in the naming, as the policy involves quantizing and dequantizing.

BTW, I'd like to point out that setting the dtype policy after instantiation does not trigger the quantization process. We need to either pass the correct dtype policy (such as "float8_from_float32") through the dtype argument in __init__ or use the quantize API.

fchollet (Collaborator):

> BTW, I'd like to point out that setting the dtype policy after instantiation does not trigger the quantization process.

Do you think it would be a good idea to implement a setter that checks whether the policy is a quantized policy, and calls .quantize() if it is?

james77777778 (Contributor, Author):

> Do you think it would be a good idea to implement a setter that checks whether the policy is a quantized policy, and calls .quantize() if it is?

This sounds like a great idea!
I think this change will also improve the UX of serializing and deserializing quantizable layers.

james77777778 marked this pull request as ready for review April 13, 2024 04:35
james77777778 requested a review from fchollet April 13, 2024 04:43
return self._dtype_policy

@dtype_policy.setter
def dtype_policy(self, value):
james77777778 (Contributor, Author):

I have added the dtype_policy setter, and it works well.
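
As a minimal sketch of the setter in action (the int8 policy name follows the existing "{mode}_from_{base}" convention; the triggering behavior is as described in the discussion above):

```python
from keras import dtype_policies
from keras import layers

layer = layers.Dense(16)
layer.build((None, 8))

# Assigning a quantized policy now calls `layer.quantize(...)` under the hood.
layer.dtype_policy = dtype_policies.QuantizedDTypePolicy("int8_from_float32")
```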

james77777778 (Contributor, Author) commented Apr 15, 2024

After adding the dtype_policy setter, we can now implement a cleaner solution to support quantization of subclassed layers.

import tempfile

import numpy as np

import keras
from keras import dtype_policies


@keras.saving.register_keras_serializable("MyPackage")
class MySubclass(keras.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer1 = keras.layers.Dense(4)
        self.layer2 = keras.layers.BatchNormalization(axis=-1)
        self.layer3 = keras.layers.ReLU()

    def call(self, inputs):
        x = self.layer1(inputs)
        x = self.layer2(x)
        return self.layer3(x)

    def get_config(self):
        config = super().get_config()
        quantized_dtype_policies = {}
        for idx, layer in enumerate(
            self._flatten_layers(include_self=False, recursive=False)
        ):
            if isinstance(
                layer.dtype_policy, dtype_policies.QuantizedDTypePolicy
            ):
                quantized_dtype_policies[str(idx)] = layer.dtype_policy
        if quantized_dtype_policies:
            config.update(
                {"quantized_dtype_policies": quantized_dtype_policies}
            )
        return config

    @classmethod
    def from_config(cls, config):
        quantized_dtype_policies = config.pop("quantized_dtype_policies", None)
        obj = super(MySubclass, cls).from_config(config)
        if quantized_dtype_policies:
            for idx, layer in enumerate(
                obj._flatten_layers(include_self=False, recursive=False)
            ):
                if str(idx) in quantized_dtype_policies:
                    layer.dtype_policy = quantized_dtype_policies[str(idx)]
        return obj


model = keras.Sequential([keras.layers.Input([8]), MySubclass()])
model.quantize("int8")

with tempfile.TemporaryDirectory() as tempdir:
    path = f"{tempdir}/model.keras"
    model.save(path)
    reloaded_model = keras.saving.load_model(path)

x = keras.random.uniform([1, 8])
y1 = model(x, training=False)
y2 = reloaded_model(x, training=False)
np.testing.assert_allclose(y1, y2)

fchollet (Collaborator) left a comment

Excellent work -- the changes look good!

google-ml-butler bot added the 'ready to pull' (Ready to be merged into the codebase) label Apr 15, 2024
fchollet merged commit 4c67dcf into keras-team:master Apr 15, 2024
9 checks passed
google-ml-butler bot removed the 'awaiting review' and 'ready to pull' labels Apr 15, 2024
james77777778 deleted the add-fp8-v2 branch April 16, 2024 01:18
fchollet added a commit that referenced this pull request Apr 19, 2024
* Introduce float8 training (#19488)

* Add float8 training support

* Add tests for fp8 training

* Add `quantize_and_dequantize` test

* Fix bugs and add float8 correctness tests

* Cleanup

* Address comments and cleanup

* Add docstrings and some minor refactoring

* Add `QuantizedFloat8DTypePolicy`

* Add dtype policy setter

* Fix torch dynamo issue by using `self._dtype_policy`

* Improve test coverage

* Add LoRA to ConvND layers (#19516)

* Add LoRA to `BaseConv`

* Add tests

* Fix typo

* Fix tests

* Fix tests

* Add path to run keras on dm-tree when optree is not available.

* feat(losses): add Tversky loss implementation (#19511)

* feat(losses): add Tversky loss implementation

* adjusted documentation

* Update KLD docs

* Models and layers now return owned metrics recursively. (#19522)

- added `Layer.metrics` to return all metrics owned by the layer and its sub-layers recursively.
- `Layer.metrics_variables` now returns variables from all metrics recursively, not just the layer and its direct sub-layers.
- `Model.metrics` now returns all metrics recursively, not just the model level metrics.
- `Model.metrics_variables` now returns variables from all metrics recursively, not just the model level metrics.
- added test coverage to test metrics and variables 2 levels deep.

This is consistent with the Keras 2 behavior and how `Model/Layer.variables` and `Model/Layer.weights` work.

* Update IoU ignore_class handling

* Fix `RandomBrightness`, Enhance `IndexLookup` Initialization and Expand Test Coverage for `Preprocessing Layers` (#19513)

* Add tests for CategoryEncoding class in category_encoding_test.py

* fix

* Fix IndexLookup class initialization and add test cases

* Add test case for IndexLookupLayerTest without vocabulary

* Fix IndexLookup class initialization

* Add normalization test cases

* Add test cases for Hashing class

* Fix value range validation error in RandomBrightness class

* Refactor IndexLookup class initialization and add test cases

* Refix IndexLookup class initialization and fix test cases

* Add test for spectral norm

* Add missing test decorator

* Fix torch test

* Fix code format

* Generate API (#19530)

* API Generator for Keras

* API Generator for Keras

* Generates API Gen via api_gen.sh

* Remove recursive import of _tf_keras

* Generate API Files via api_gen.sh

* Update APIs

* Added metrics from custom `train_step`/`test_step` are now returned. (#19529)

This works the same way as in Keras 2, whereby the metrics are returned directly from the logs if the set of keys doesn't match the model metrics.

* Use temp dir and abs path in `api_gen.py` (#19533)

* Use temp dir and abs path

* Use temp dir and abs path

* Update Readme

* Update API

* Fix gradient accumulation when using `overwrite_with_gradient` during float8 training (#19534)

* Fix gradient accumulation with `overwrite_with_gradient` in float8 training

* Add comments

* Fix annotation

* Update code path in ignore path (#19537)

* Add operations per run (#19538)

* Include input shapes in model visualization.

* Add pad_to_aspect_ratio feature in ops.image.resize

* Add pad_to_aspect_ratio feature in Resizing layer.

* Fix incorrect usage of `quantize` (#19541)

* Add logic to prevent double quantization

* Add detailed info for double quantization error

* Update error msg

* Add eigh op.

* Add keepdim in argmax/argmin.

* Fix small bug in model.save_weights (#19545)

* Update public APIs.

* eigh should work on JAX GPU

* Copy init to keras/__init__.py (#19551)

* Revert "Copy init to keras/__init__.py (#19551)" (#19552)

This reverts commit da9af61.

* fixes for master

---------

Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com>
Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Luca Pizzini <lpizzini7@gmail.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Sachin Prasad <sachinprasad@google.com>
Co-authored-by: Uwe Schmidt <uschmidt83@users.noreply.github.com>
fchollet added a commit that referenced this pull request May 3, 2024

* sum-reduce inlined losses

* Remove the dependency on `tensorflow.experimental.numpy` and support negative indices for `take` and `take_along_axis` (#19556)

* Remove `tfnp`

* Update numpy api

* Improve test coverage

* Improve test coverage

* Fix `Tri` and `Eye` and increase test converage

* Update `round` test

* Fix `jnp.round`

* Fix `diag` bug for iou_metrics

* Add op.select.

* Add new API for select

* Make `ops.abs` and `ops.absolute` consistent between backends. (#19563)

- The TensorFlow implementation was missing `convert_to_tensor`
- The sparse annotation was unnecessarily applied twice
- Now `abs` calls `absolute` in all backends

Also fixed TensorFlow `ops.select`.

* Add pickle support for Keras model (#19555)

* Implement unit tests for pickling

* Reformat model_test

* Reformat model_test

* Rename depickle to unpickle

* Rename depickle to unpickle

* Reformat

* remove a comment

* Ellipsis Serialization and tests (#19564)

* Serialization and tests

* Serialization and tests

* Serialization and tests

* Make TF one_hot input dtype less strict.

* Fix einsum `_int8_call` (#19570)

* CTC Decoding for JAX and Tensorflow (#19366)

* Tensorflow OP for CTC decoding

* JAX op for CTC greedy decoding

* Update CTC decoding documentation

* Fix linting issues

* Fix trailing whitespace

* Simplify returns in tensorflow CTC wrapper

* Fix CTC decoding error messages

* Fix line too long

* Bug fixes to JAX CTC greedy decoder

* Force int typecast in TF CTC decoder

* Unit tests for CTC greedy decoding

* Add unit test for CTC beam search decoding

* Fix mask index set location in JAX CTC decoding

* CTC beam search decoding for JAX

* Fix unhandled token repetitions in ctc_beam_search_decode

* Fix merge_repeated bug in CTC beam search decode

* Fix beam storage and repetition bugs in JAX ctc_decode

* Remove trailing whitespace

* Fix ordering bug for ties in JAX CTC beam search

* Cast sequence lengths to integers in JAX ctc_decode

* Remove line break in docstring

* CTC beam search decoding for JAX

* Fix unhandled token repetitions in ctc_beam_search_decode

* Fix merge_repeated bug in CTC beam search decode

* Fix beam storage and repetition bugs in JAX ctc_decode

* Fix ordering bug for ties in JAX CTC beam search

* Generate public api directory

* Add not implemented errors for NumPy and Torch CTC decoding

* Remove unused redefinition of JAX ctc_beam_search_decode

* Docstring edits

* Expand nan_to_num args.

* Add vectorize op.

* list insert requires index (#19575)

* Add signature and exclude args to knp.vectorize.

* Fix the apis of `dtype_polices` (#19580)

* Fix api of `dtype_polices`

* Update docstring

* Increase test coverage

* Fix format

* Fix keys of `save_own_variables` and `load_own_variables` (#19581)

* Fix JAX CTC test.

* Fix loss_weights handling in single output case

* Fix JAX vectorize.

* Move _tf_keras directory to the root of the pip package.

* One time fix to _tf_keras API.

* Convert return type imdb.load_data to nparray (#19598)

Convert the return type of imdb.load_data to NumPy arrays. Currently X_train and X_test are returned as lists.

* Fix typo

* fix api_gen.py for legacy (#19590)

* fix api_gen.py for legacy

* merge api and legacy for _tf_keras

* Improve int8 for `Embedding` (#19595)

* pin torch < 2.3.0 (#19603)

* Clean up duplicated `inputs_quantizer` (#19604)

* Cleanup duplicated `inputs_quantizer` and add type check for `input_spec` and `supports_masking`

* Revert setter

* output format changes and errors in github (#19608)

* Provide write permission to action for cache management. (#19606)

* Pickle support for all saveables (#19592)

* Pickle support

* Add keras pickleable mixin

* Reformat

* Implement pickle all over

* reformat

* Reformat

* Keras saveable

* Keras saveable

* Keras saveable

* Keras saveable

* Keras saveable

* obj_type

* Update pickleable

* Saveable logic touchups

* Add slogdet op.

* Update APIs

* Remove unused import

* Refactor CTC APIs (#19611)

* Add `ctc_loss` and `ctc_decode` for numpy backend, improve imports and tests

* Support "beam_search" strategy for torch's `ctc_decode`

* Improve `ctc_loss`

* Cleanup

* Refactor `ctc_decode`

* Update docstring

* Update docstring

* Add `CTCDecode` operation and ensure dtype inference of `ctc_decode`

* Fix `name` of `losses.CTC`

* update the namex version requirements (#19617)

* Add `PSNR` API (#19616)

* PSNR

* Fix

* Docstring format

* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps (#19618)

* Remove `PYTORCH_ENABLE_MPS_FALLBACK` flag requirement for mps

* Formatting

* Implement custom layer insertion in clone_model. (#19610)

* Implement custom layer insertion in clone_model.

* Add recursive arg and tests.

* Add nested sequential cloning test

* Fix bidir lstm saving issue.

* Fix CI

* Fix cholesky tracing with jax

* made extract_patches dtype agnostic (#19621)

* Simplify Bidirectional implementation

* Add support for infinite `PyDataset`s. (#19624)

`PyDataset` now uses the `num_batches` property instead of `__len__` to support `None`, which is how one indicates the dataset is infinite. Note that infinite datasets are not shuffled.

Fixes #19528

Also added exception reporting when using multithreading / multiprocessing. Previously, the program would just hang with no error reported.

* Fix dataset shuffling issue.

* Update version string.

* Minor fix

* Restore version string resolution in pip_build.

* Speed up `DataAdapter` tests by testing only the current backend. (#19625)

There is no use case for using an iterator for a different backend than the current backend.

Also:
- limit the number of tests using multiprocessing, the threading tests give us good coverage.
- fixed the `test_exception_reported` test, which was not actually exercising the multiprocessing / multithreading cases.
- removed unused `init_pool` method.

* feat(ops): support np.argpartition (#19588)

* feat(ops): support np.argpartition

* updated documentation, type-casting, and tf implementation

* fixed tf implementation

* added torch cast to int32

* updated torch type and API generated files

* added torch output type cast

* test(trainers): add test_errors implementation for ArrayDataAdapter class (#19626)

* Fix torch GPU CI

* Fix argmax/argmin keepdims with defined axis in TF

* Misc fixes in TF backend ops.

* Fix `argpartition` cuda bug in torch (#19634)

* fix(ops): specify NonZero output dtype and add test coverage (#19635)

* Fix `ops.ctc_decode` (#19633)

* Fix greedy ctc decode

* Remove print

* Fix `tf.nn.ctc_beam_search_decoder`

* Change default `mask_index` to `0`

* Fix losses test

* Update

* Ensure the same rule applies for np arrays in autocasting (#19636)

* Ensure the same rule applies for np arrays in autocasting

* Trigger CI by adding docstring

* Update

* Update docstring

* Fix `istft` and add class `TestMathErrors` in `ops/math_test.py` (#19594)

* Fix and test math functions for jax backend

* run /workspaces/keras/shell/format.sh

* refix

* fix

* fix _get_complex_tensor_from_tuple

* fix

* refix

* Fix istft function to handle inputs with less than 2 dimensions

* fix

* Fix ValueError in istft function for inputs with less than 2 dimensions

* Return a tuple from `ops.shape` with the Torch backend. (#19640)

With Torch, `x.shape` returns a `torch.Size`, which is a subclass of `tuple` but can cause different behaviors. In particular `convert_to_tensor` does not work on `torch.Size`.

This fixes #18900

* support conv3d on cpu for TF (#19641)

* Enable cudnn rnns when dropout is set (#19645)

* Enable cudnn rnns when dropout is set

* Fix

* Fix plot_model for input dicts.

* Fix deprecation warning in torch

* Bump the github-actions group with 2 updates (#19653)

Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [github/codeql-action](https://github.com/github/codeql-action).


Updates `actions/upload-artifact` from 4.3.1 to 4.3.3
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](actions/upload-artifact@5d5d22a...6546280)

Updates `github/codeql-action` from 3.24.9 to 3.25.3
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@1b1aada...d39d31e)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions
- dependency-name: github/codeql-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump the python group with 2 updates (#19654)

Bumps the python group with 2 updates: torch and torchvision.


Updates `torch` from 2.2.1+cu121 to 2.3.0+cu121

Updates `torchvision` from 0.17.1+cu121 to 0.18.0+cu121

---
updated-dependencies:
- dependency-name: torch
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
- dependency-name: torchvision
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Revert "Bump the python group with 2 updates (#19654)" (#19655)

This reverts commit 09133f4.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: james77777778 <20734616+james77777778@users.noreply.github.com>
Co-authored-by: Francois Chollet <francois.chollet@gmail.com>
Co-authored-by: Luca Pizzini <lpizzini7@gmail.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com>
Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Co-authored-by: Sachin Prasad <sachinprasad@google.com>
Co-authored-by: Uwe Schmidt <uschmidt83@users.noreply.github.com>
Co-authored-by: Luke Wood <LukeWood@users.noreply.github.com>
Co-authored-by: Maanas Arora <maanasarora23@gmail.com>
Co-authored-by: AlexanderLavelle <73360008+AlexanderLavelle@users.noreply.github.com>
Co-authored-by: Surya <116063290+SuryanarayanaY@users.noreply.github.com>
Co-authored-by: Shivam Mishra <124146945+shmishra99@users.noreply.github.com>
Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Co-authored-by: IMvision12 <88665786+IMvision12@users.noreply.github.com>
Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com>
Co-authored-by: Vachan V Y <109357590+VachanVY@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Labels: Gemma (Gemma model specific issues), size:XL
Projects: Status: Merged
Participants: 5