Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix][relay] Fix alpha attribute with None in ELU #14742

Merged
merged 2 commits into from
May 2, 2023

Conversation

jikechao
Copy link
Contributor

@jikechao jikechao commented Apr 28, 2023

This patch fixes a bug in ELU. when the alpha=None, It will lead to a wrong compilation and give a confusing inference result.
For example, the following script will produce inconsistent inference results between TVM and Keras. However, we hope tvm can reject it immediately in the model loading stage. Because the model is invalid (see Line 210-212 in keras source code file).

This PR is a similar fix to the pr-14707. pr-14707 fixed a similar bug in LeakyReLU.

image

Reproducible script

import tvm
import tvm.relay as relay
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, models


input_shape = (1, 2, 3, 4)
input_data = np.random.random(input_shape)
x = layers.Input(shape=input_shape[1:], dtype='float32')

layer = keras.layers.ELU(alpha=None)
layer.set_weights(layer.get_weights())

y = layer(x)
model = models.Model(x, y)

res_keras = model.predict(input_data)

shape_dict = {'input_1': input_shape}
mod, params = relay.frontend.from_keras(model, shape_dict)

with tvm.transform.PassContext(opt_level=3):
    model = relay.build_module.create_executor("graph", mod, tvm.cpu(0), 'llvm', params).evaluate()


test_x_tvm = input_data
res_tvm = model(tvm.nd.array(test_x_tvm.astype('float32'))).numpy()

np.testing.assert_allclose(res_keras, res_tvm, atol=1e-3, rtol=1e-3)

@tvm-bot
Copy link
Collaborator

tvm-bot commented Apr 28, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@jikechao
Copy link
Contributor Author

Hi @echuraev @AndrewZhaoLuo, could you help me review it? Thank you!

Copy link
Contributor

@echuraev echuraev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for your PR.

@jikechao
Copy link
Contributor Author

jikechao commented May 1, 2023

@tvm-bot rerun

1 similar comment
@jikechao
Copy link
Contributor Author

jikechao commented May 1, 2023

@tvm-bot rerun

@echuraev echuraev merged commit d1e1b4c into apache:main May 2, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants