
add ignore_class param to sparse CE loss #599

Merged Aug 27, 2023 (15 commits)

Conversation

AakashKumarNain (Collaborator)

Add ignore_class param for sparse categorical cross entropy loss

@AakashKumarNain (Collaborator, Author) commented Jul 25, 2023

@fchollet ops.where is pretty inconsistent across the backends, and even within a single backend. The problem is that it accepts two positional arguments that aren't optional; ideally it should work with the condition alone. This is why the torch tests are failing.

Apart from that, tf.where(...) returns a tensor while every other backend returns a tuple. Do we make it consistent across the backends, or should we leave it to the behavior of each backend?
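
(For illustration, a minimal sketch of the divergence being described, with illustrative inputs:)

import numpy as np
import tensorflow as tf

cond = np.array([[True, False], [False, True]])

# NumPy: single-argument where returns a tuple of 1-D index arrays,
# one array per dimension.
np.where(cond)                    # (array([0, 1]), array([0, 1]))

# TensorFlow: tf.where(condition) returns a single int64 tensor of
# shape [num_true, ndim], one coordinate row per True entry.
tf.where(tf.constant(cond))       # [[0, 0], [1, 1]]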

@fchollet (Member)

Do we make it consistent across the backends or should we leave it to the behavior of the backend?

We should make the behavior consistent across all backends -- that's important. And when applicable the behavior should also be standardized on the numpy behavior.
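
(A minimal sketch of what standardizing on NumPy could look like on the TF backend; the helper name where_single_arg is hypothetical, not the API this PR settled on:)

import tensorflow as tf

def where_single_arg(condition):
    # Emulate NumPy's single-argument where: split tf.where's
    # [num_true, ndim] coordinate matrix column-wise into a tuple
    # of 1-D index arrays, one per dimension.
    coords = tf.where(condition)
    return tuple(coords[:, i] for i in range(condition.shape.rank))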

@AakashKumarNain (Collaborator, Author)

Thank you. I will make the changes accordingly.

@AakashKumarNain (Collaborator, Author) commented Jul 27, 2023

@fchollet I have skipped the torch backend for now. ops.scatter(...) in the torch backend seems broken, and I have no idea how to fix it right away.

@AakashKumarNain (Collaborator, Author)

I just realized that even though the tests for JAX are passing, I am skeptical about it working when the loss function is compiled. Why? Because the masking is dynamic in this case.
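
(A minimal sketch of the concern: under jax.jit, an op whose output shape depends on runtime data, such as jnp.nonzero without a static size, cannot be traced:)

import jax
import jax.numpy as jnp

@jax.jit
def dynamic_mask(y_true, ignore_class=-1):
    mask = jnp.equal(y_true, ignore_class)
    # The number of True entries is only known at runtime, so the
    # output shape is data-dependent and tracing fails.
    return jnp.nonzero(mask)

# dynamic_mask(jnp.array([0, -1, 2]))  # raises a ConcretizationTypeError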

@fchollet (Member) left a comment

Thanks for the PR. What's the error in torch? And for JAX, surely correctness is being tested so the test passing should be proof that it works?

(Review comment on keras_core/losses/losses.py; outdated, resolved.)
@AakashKumarNain (Collaborator, Author) commented Jul 30, 2023

Thanks for reviewing it. Here is the traceback for torch:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)


File ~/keras-core/keras_core/losses/loss.py:44, in Loss.__call__(self, y_true, y_pred, sample_weight)
     37 y_pred = tree.map_structure(
     38     lambda x: ops.convert_to_tensor(x, dtype=dtype), y_pred
     39 )
     40 y_true = tree.map_structure(
     41     lambda x: ops.convert_to_tensor(x, dtype=dtype), y_true
     42 )
---> 44 losses = self.call(y_true, y_pred)
     45 out_mask = getattr(losses, "_keras_mask", None)
     47 if in_mask is not None and out_mask is not None:

Cell In[5], line 20, in LossFunctionWrapper.call(self, y_true, y_pred)
     18 def call(self, y_true, y_pred):
     19     y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
---> 20     return self.fn(y_true, y_pred, **self._fn_kwargs)

Cell In[13], line 57, in sparse_categorical_crossentropy(y_true, y_pred, from_logits, ignore_class, axis)
     54 if isinstance(valid_idx, (tuple, list)):
     55     valid_idx = ops.stack(valid_idx, -1)
---> 57 res = ops.scatter(indices=valid_idx, values=res, shape=res_shape)
     59 if backend.backend() != "numpy":
     60     res._keras_mask = valid_mask

File ~/keras-core/keras_core/ops/core.py:61, in scatter(indices, values, shape)
     59 if any_symbolic_tensors((indices, values, shape)):
     60     return Scatter().symbolic_call(indices, values, shape)
---> 61 return backend.core.scatter(indices, values, shape)

File ~/keras-core/keras_core/backend/torch/core.py:277, in scatter(indices, values, shape)
    275 value_shape = shape[index_length:]
    276 indices = torch.reshape(indices, [-1, index_length])
--> 277 values = torch.reshape(values, [-1] + list(value_shape))
    279 for i in range(indices.shape[0]):
    280     index = indices[i]

RuntimeError: shape '[-1, 2]' is invalid for input of size 1
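
(The failing step can be reproduced in isolation; a minimal standalone repro, not the PR's code, assuming a single-element values tensor reaches the reshape:)

import torch

values = torch.tensor([1.0])  # only one element
# Same failure mode as in the traceback above:
# RuntimeError: shape '[-1, 2]' is invalid for input of size 1
torch.reshape(values, [-1, 2])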

And for JAX, surely correctness is being tested so the test passing should be proof that it works?

I don't know how the test is passing, because I definitely get an error when I try to jit it separately. And the error makes sense as well. This is the version that is jittable in JAX and works as expected:

from keras_core import ops
import jax.numpy as jnp

def loss_fn(y_true, y_pred, ignore_class=-1):
    # Positions whose label equals ignore_class should not contribute.
    mask = ops.equal(y_true, ignore_class)
    # A static `size` keeps the output shape constant at trace time.
    mask_idx = ops.stack(jnp.nonzero(mask, size=len(mask)), -1)

    # Zero out the ignored entries before computing the loss.
    y_true_masked = ops.where(mask, 0.0, y_true)
    y_pred_masked = ops.where(mask, 0.0, y_pred)

    loss = ops.sparse_categorical_crossentropy(
        target=y_true_masked,
        output=y_pred_masked,
        from_logits=True,
        axis=-1,
    )

    # Zero out the loss at the ignored positions as well.
    mask = ops.reshape(mask, y_pred.shape[:-1])
    loss = ops.where(mask, 0.0, loss)
    return loss

The problem with this code, though, is that it relies on the lax-compatible nonzero(...) method that takes a size argument. This argument is missing from the other backends. One way to tackle this is to branch on the backend and implement two versions: one for JAX, and one for the others.
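
(A sketch of that dispatch; the helper name static_nonzero is hypothetical and not part of this PR, and it assumes ops.nonzero is available on the remaining backends:)

from keras_core import backend, ops

def static_nonzero(mask, size):
    # JAX requires a static `size` so the output shape is known at
    # trace time; the other backends' nonzero does not take that
    # argument.
    if backend.backend() == "jax":
        import jax.numpy as jnp
        return jnp.nonzero(mask, size=size)
    return ops.nonzero(mask)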

@kerighan commented Aug 2, 2023

I don't know if this is the place to add a comment or try to help, and I may be mistaken, but I believe the ignore_class behavior could be achieved without using ops.where, for example with the following implementation:

from keras_core import ops

def sparse_categorical_crossentropy_loss(ignore_class=None):
    def wrapper(y_true, y_pred_logits):
        # Transform logits to softmax for a probability distribution
        y_pred = ops.nn.softmax(y_pred_logits, axis=-1)

        # Get the number of classes from the predictions
        num_classes = ops.shape(y_pred)[-1]

        # One-hot encode the true labels
        y_true_one_hot = ops.one_hot(
            ops.cast(y_true, "int32"), num_classes)

        # Mask out positions whose label equals ignore_class, if specified
        if ignore_class is not None:
            mask = ops.not_equal(y_true, ignore_class)
            mask = ops.cast(mask, "float32")
            # Zero the one-hot labels at the ignored positions
            y_true_one_hot *= ops.expand_dims(mask, axis=-1)

        # Add a small value to avoid computing log(0), which gives NaN
        y_pred = ops.clip(y_pred, 1e-7, 1.)

        # Compute cross-entropy
        loss = -ops.sum(y_true_one_hot * ops.log(y_pred), axis=-1)

        return ops.mean(loss)
    return wrapper

It does not use ops.sparse_categorical_crossentropy, though, but the masking bit seems to work across all backends.
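
(A quick usage sketch of the wrapper above, with illustrative shapes and values:)

import numpy as np

y_true = np.array([0, 2, -1])                 # -1 marks positions to ignore
logits = np.random.uniform(size=(3, 4)).astype("float32")

loss_fn = sparse_categorical_crossentropy_loss(ignore_class=-1)
loss = loss_fn(y_true, logits)                # scalar mean loss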

@AakashKumarNain (Collaborator, Author) commented Aug 3, 2023

Thanks for the suggestion. If you look at the changed files in this PR, this is exactly what it does, except that it uses ops.sparse_categorical_crossentropy(...); the reason behind that is to avoid redundancy as much as possible.

@AakashKumarNain (Collaborator, Author) commented Aug 23, 2023

Update: now that I have fixed the where op, I have a version that works with all the backends. The only catch is with JAX, as that version won't be jittable. I am figuring out a way to address that as well.

@AakashKumarNain (Collaborator, Author)

@fchollet good news! The current version is now fully jit compatible.

@fchollet (Member) left a comment

Thanks for the update! Looking good.

(Two review comments on keras_core/losses/losses.py; outdated, resolved.)
@AakashKumarNain (Collaborator, Author)

@fchollet thank you for the review and the valuable suggestions. I have made the necessary changes.

@fchollet (Member) left a comment

LGTM, thanks!

@fchollet fchollet merged commit 65ef46a into keras-team:main Aug 27, 2023
6 checks passed