diff --git a/onnx_array_api/reference/ops/op_cast_like.py b/onnx_array_api/reference/ops/op_cast_like.py index a520405..08c065b 100644 --- a/onnx_array_api/reference/ops/op_cast_like.py +++ b/onnx_array_api/reference/ops/op_cast_like.py @@ -17,7 +17,8 @@ def _cast_like(x, y, saturate): if bfloat16 is None: - return (cast_to(x, y.dtype, saturate),) + to = np_dtype_to_tensor_dtype(y.dtype) + return (cast_to(x, to, saturate),) if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 to = TensorProto.BFLOAT16