Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Tests] Fix Autograph warnings in division test (#1329)
**Context:** Currently, there are two `UserWarning`s issued in the test `TestJaxIndexOperatorUpdate::test_single_index_div_update_all_items`: Worflow `f`: ``` frontend/test/pytest/test_autograph.py::TestJaxIndexOperatorUpdate::test_single_index_div_update_all_items .../catalyst/frontend/catalyst/autograph/ag_primitives.py:350: UserWarning: Tracing of an AutoGraph converted for loop failed with an exception: AutoGraphError: The variable 'x' was initialized with the wrong type, or you may be trying to change its type from one iteration to the next. Expected: ShapedArray(float64[3]), Got: ShapedArray(int64[3]) The error ocurred within the body of the following for loop statement: File ".../catalyst/frontend/test/pytest/test_autograph.py", line 2165, in f for i in range(first_dim): If you intended for the conversion to happen, make sure that the (now dynamic) loop variable is not used in tracing-incompatible ways, for instance by indexing a Python list with it. In that case, the list should be wrapped into an array. To understand different types of JAX tracing errors, please refer to the guide at: https://jax.readthedocs.io/en/latest/errors.html If you did not intend for the conversion to happen, you may safely ignore this warning. warnings.warn( ``` Workflow `g`: ``` frontend/test/pytest/test_autograph.py::TestJaxIndexOperatorUpdate::test_single_index_div_update_all_items .../catalyst/frontend/catalyst/autograph/ag_primitives.py:350: UserWarning: Tracing of an AutoGraph converted for loop failed with an exception: AutoGraphError: The variable 'result' was initialized with the wrong type, or you may be trying to change its type from one iteration to the next. Expected: ShapedArray(float64[3]), Got: ShapedArray(int64[3]) The error ocurred within the body of the following for loop statement: File ".../catalyst/frontend/test/pytest/test_autograph.py", line 2176, in g for i in range(first_dim): If you intended for the conversion to happen, make sure that the (now dynamic) loop variable is not used in tracing-incompatible ways, for instance by indexing a Python list with it. In that case, the list should be wrapped into an array. To understand different types of JAX tracing errors, please refer to the guide at: https://jax.readthedocs.io/en/latest/errors.html If you did not intend for the conversion to happen, you may safely ignore this warning. warnings.warn( ``` Both warnings are symptoms of the same issue: the input is an array of integers, `np.array(5, 3, 4)`, but the test is performing division (directly on the input in worflow `f` and on a copy of the input in worflow `g`), which requires changing the array's type to floats, hence the warning. **Description of the Change:** Changes the input to an array of floats: `np.array([5.0, 3.0, 4.0])`. This fixes the warnings since there is now no array type conversion during the division operation. **Benefits:** Fixes warnings in test suite; cleaner test output. **Possible Drawbacks:** None.
- Loading branch information