Skip to content

Commit

Permalink
fix: TypeError fix for can_cast (#3255)
Browse files Browse the repository at this point in the history
* fix type error for can_cast

* cleanup
  • Loading branch information
ianna authored Sep 26, 2024
1 parent eaa43ff commit 704fb45
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def broadcast_any_option():
mask = backend.index_nplike.logical_or(mask, m, maybe_out=mask)

nextmask = Index8(mask.view(np.int8))
index = backend.index_nplike.full(mask.shape[0], -1, dtype=np.int64)
index = backend.index_nplike.full(mask.shape[0], np.int64(-1), dtype=np.int64)
index[~mask] = backend.index_nplike.arange(
backend.index_nplike.shape_item_as_index(mask.shape[0])
- backend.index_nplike.count_nonzero(mask),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
// def f(grid, block, args):
// (tmpptr, fromstarts, fromstops, length, toequal, invocation_index, err_code) = args
// if length > 1:
// scan_in_array = cupy.full((length - 1) * (length - 2), 0, dtype=cupy.int64)
// scan_in_array = cupy.full((length - 1) * (length - 2), cupy.array(0), dtype=cupy.int64)
// else:
// scan_in_array = cupy.full(0, 0, dtype=cupy.int64)
// scan_in_array = cupy.full(0, cupy.array(0), dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_NumpyArray_subrange_equal_bool", bool_, fromstarts.dtype, fromstops.dtype, bool_]))(grid, block, (tmpptr, fromstarts, fromstops, length, toequal, scan_in_array, invocation_index, err_code))
// toequal[0] = cupy.any(scan_in_array == True)
// out["awkward_NumpyArray_subrange_equal_bool", {dtype_specializations}] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// grid_size = math.floor((lenparents + block[0] - 1) / block[0])
// else:
// grid_size = 1
// temp = cupy.full(lenparents, identity, dtype=toptr.dtype)
// temp = cupy.full(lenparents, cupy.array([identity]), dtype=toptr.dtype)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_a", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, toptr.dtype.type(identity), temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_b", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, toptr.dtype.type(identity), temp, invocation_index, err_code))
// out["awkward_reduce_max_a", {dtype_specializations}] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// grid_size = math.floor((lenparents + block[0] - 1) / block[0])
// else:
// grid_size = 1
// temp = cupy.full(lenparents, identity, dtype=toptr.dtype)
// temp = cupy.full(lenparents, cupy.array([identity]), dtype=toptr.dtype)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_a", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, toptr.dtype.type(identity), temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_b", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, toptr.dtype.type(identity), temp, invocation_index, err_code))
// out["awkward_reduce_min_a", {dtype_specializations}] = None
Expand Down
6 changes: 4 additions & 2 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def full(
*,
dtype: DTypeLike | None = None,
) -> ArrayLikeT:
return self._module.full(shape, fill_value, dtype=dtype)
return self._module.full(shape, self._module.array(fill_value), dtype=dtype)

def zeros_like(
self, x: ArrayLikeT | PlaceholderArray, *, dtype: DTypeLike | None = None
Expand Down Expand Up @@ -146,7 +146,9 @@ def full_like(
if isinstance(x, PlaceholderArray):
return self.full(x.shape, fill_value, dtype=dtype or x.dtype)
else:
return self._module.full_like(x, fill_value, dtype=dtype)
return self._module.full_like(
x, self._module.array(fill_value), dtype=dtype
)

def arange(
self,
Expand Down

0 comments on commit 704fb45

Please sign in to comment.