From 4dab3457cc255b04a56ccf725c03b0b805bd1072 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 21 Jul 2025 12:08:47 +0200 Subject: [PATCH] fix documentation --- onnx_array_api/reference/ops/op_cast_like.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnx_array_api/reference/ops/op_cast_like.py b/onnx_array_api/reference/ops/op_cast_like.py index a520405..08c065b 100644 --- a/onnx_array_api/reference/ops/op_cast_like.py +++ b/onnx_array_api/reference/ops/op_cast_like.py @@ -17,7 +17,8 @@ def _cast_like(x, y, saturate): if bfloat16 is None: - return (cast_to(x, y.dtype, saturate),) + to = np_dtype_to_tensor_dtype(y.dtype) + return (cast_to(x, to, saturate),) if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 to = TensorProto.BFLOAT16