CAReduce loop reordering C-impl
ricardoV94 committed Aug 13, 2024
1 parent 1a944b7 commit 756170b
Showing 3 changed files with 390 additions and 188 deletions.
202 changes: 101 additions & 101 deletions pytensor/tensor/elemwise.py
@@ -1,4 +1,5 @@
 from copy import copy
+from textwrap import dedent

 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -1466,116 +1467,114 @@ def infer_shape(self, fgraph, node, shapes):
             return ((),)
         return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)

-    def _c_all(self, node, name, inames, onames, sub):
-        input = node.inputs[0]
-        output = node.outputs[0]
+    def _c_all(self, node, name, input_names, output_names, sub):
+        [inp] = node.inputs
+        [out] = node.outputs
+        ndim = inp.type.ndim

-        iname = inames[0]
-        oname = onames[0]
+        [inp_name] = input_names
+        [out_name] = output_names

-        idtype = input.type.dtype_specs()[1]
-        odtype = output.type.dtype_specs()[1]
+        inp_dtype = inp.type.dtype_specs()[1]
+        out_dtype = out.type.dtype_specs()[1]

         acc_dtype = getattr(self, "acc_dtype", None)

         if acc_dtype is not None:
             if acc_dtype == "float16":
                 raise MethodNotDefined("no c_code for float16")
             acc_type = TensorType(shape=node.outputs[0].type.shape, dtype=acc_dtype)
-            adtype = acc_type.dtype_specs()[1]
+            acc_dtype = acc_type.dtype_specs()[1]
         else:
-            adtype = odtype
+            acc_dtype = out_dtype

         axis = self.axis
         if axis is None:
-            axis = list(range(input.type.ndim))
+            axis = list(range(inp.type.ndim))

         if len(axis) == 0:
             # This is just an Elemwise cast operation
             # The acc_dtype is never a downcast compared to the input dtype
             # So we just need a cast to the output dtype.
-            var = pytensor.tensor.basic.cast(input, node.outputs[0].dtype)
-            if var is input:
-                var = Elemwise(scalar_identity)(input)
+            var = pytensor.tensor.basic.cast(inp, node.outputs[0].dtype)
+            if var is inp:
+                var = Elemwise(scalar_identity)(inp)
             assert var.dtype == node.outputs[0].dtype
-            return var.owner.op._c_all(var.owner, name, inames, onames, sub)
-
-        order1 = [i for i in range(input.type.ndim) if i not in axis]
-        order = order1 + list(axis)
-
-        nnested = len(order1)
+            return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)
+
+        inp_dims = list(range(ndim))
+        non_reduced_dims = [i for i in inp_dims if i not in axis]
+        counter = iter(range(ndim))
+        acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]

-        sub = dict(sub)
-        for i, (input, iname) in enumerate(zip(node.inputs, inames)):
-            sub[f"lv{i}"] = iname
+        sub = sub.copy()
+        sub["lv0"] = inp_name
+        sub["lv1"] = out_name
+        sub["olv"] = out_name

-        decl = ""
-        if adtype != odtype:
+        if acc_dtype != out_dtype:
             # Create an accumulator variable different from the output
-            aname = "acc"
-            decl = acc_type.c_declare(aname, sub)
-            decl += acc_type.c_init(aname, sub)
+            acc_name = "acc"
+            setup = acc_type.c_declare(acc_name, sub) + acc_type.c_init(acc_name, sub)
         else:
             # the output is the accumulator variable
-            aname = oname
-
-        decl += cgen.make_declare([order], [idtype], sub)
-        checks = cgen.make_checks([order], [idtype], sub)
-
-        alloc = ""
-        i += 1
-        sub[f"lv{i}"] = oname
-        sub["olv"] = oname
-
-        # Allocate output buffer
-        alloc += cgen.make_declare(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
-        )
-        alloc += cgen.make_alloc([order1], odtype, sub)
-        alloc += cgen.make_checks(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
-        )
+            acc_name = out_name
+            setup = ""
+
+        # Define strides of input array
+        setup += cgen.make_declare(
+            [inp_dims], [inp_dtype], sub, compute_stride_jump=False
+        ) + cgen.make_checks([inp_dims], [inp_dtype], sub, compute_stride_jump=False)
+
+        # Define strides of output array and allocate it
+        out_sub = sub | {"lv0": out_name}
+        alloc = (
+            cgen.make_declare(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
+            + cgen.make_alloc([non_reduced_dims], out_dtype, sub)
+            + cgen.make_checks(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
+        )

-        if adtype != odtype:
-            # Allocate accumulation buffer
-            sub[f"lv{i}"] = aname
-            sub["olv"] = aname
+        if acc_dtype != out_dtype:
+            # Define strides of accumulation buffer and allocate it
+            sub["lv1"] = acc_name
+            sub["olv"] = acc_name

-            alloc += cgen.make_declare(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
-            )
-            alloc += cgen.make_alloc([order1], adtype, sub)
-            alloc += cgen.make_checks(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
-            )
+            acc_sub = sub | {"lv0": acc_name}
+            alloc += (
+                cgen.make_declare(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
+                + cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
+                + cgen.make_checks(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
+            )

         identity = self.scalar_op.identity

         if np.isposinf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "1"
             else:
-                identity = "NPY_MAX_" + str(input.type.dtype).upper()
+                identity = "NPY_MAX_" + str(inp.type.dtype).upper()
         elif np.isneginf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "-__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "0"
             else:
-                identity = "NPY_MIN_" + str(input.type.dtype).upper()
+                identity = "NPY_MIN_" + str(inp.type.dtype).upper()
         elif identity is None:
             raise TypeError(f"The {self.scalar_op} does not define an identity.")

-        task0_decl = f"{adtype}& {aname}_i = *{aname}_iter;\n{aname}_i = {identity};"
-
-        task1_decl = f"{idtype}& {inames[0]}_i = *{inames[0]}_iter;\n"
+        initial_value = f"{acc_name}_i = {identity};"

-        task1_code = self.scalar_op.c_code(
+        inner_task = self.scalar_op.c_code(
             Apply(
                 self.scalar_op,
                 [
@@ -1588,44 +1587,45 @@ def _c_all(self, node, name, inames, onames, sub):
                 ],
             ),
             None,
-            [f"{aname}_i", f"{inames[0]}_i"],
-            [f"{aname}_i"],
+            [f"{acc_name}_i", f"{inp_name}_i"],
+            [f"{acc_name}_i"],
             sub,
         )
-        code1 = f"""
-        {{
-            {task1_decl}
-            {task1_code}
-        }}
-        """
-
-        if node.inputs[0].type.ndim:
-            if len(axis) == 1:
-                all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
-            else:
-                all_code = (
-                    [("", "")] * nnested
-                    + [(task0_decl, "")]
-                    + [("", "")] * (len(axis) - 2)
-                    + [("", code1), ""]
-                )
+        if out.type.ndim == 0:
+            # Simple case where everything is reduced, no need for loop ordering
+            loop = cgen.make_complete_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                initial_value=initial_value,
+                inner_task=inner_task,
+                fail_code=sub["fail"],
+            )
         else:
-            all_code = [task0_decl + code1]
-        loop = cgen.make_loop_careduce(
-            [order, list(range(nnested)) + ["x"] * len(axis)],
-            [idtype, adtype],
-            all_code,
-            sub,
-        )
+            loop = cgen.make_reordered_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                inp_ndim=ndim,
+                reduction_axes=axis,
+                initial_value=initial_value,
+                inner_task=inner_task,
+            )

-        end = ""
-        if adtype != odtype:
-            end = f"""
-            PyArray_CopyInto({oname}, {aname});
-            """
-            end += acc_type.c_cleanup(aname, sub)
+        if acc_dtype != out_dtype:
+            cast = dedent(
+                f"""
+                PyArray_CopyInto({out_name}, {acc_name});
+                {acc_type.c_cleanup(acc_name, sub)}
+                """
+            )
+        else:
+            cast = ""

-        return decl, checks, alloc, loop, end
+        return setup, alloc, loop, cast

     def c_code(self, node, name, inames, onames, sub):
         code = "\n".join(self._c_all(node, name, inames, onames, sub))
@@ -1637,7 +1637,7 @@ def c_headers(self, **kwargs):

     def c_code_cache_version_apply(self, node):
         # the version corresponding to the c code in this Op
-        version = [9]
+        version = [10]

         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
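
For intuition, here is a minimal NumPy sketch of the accumulation scheme that the new cgen.make_reordered_loop_careduce helper emits in C (an illustration only, not the generated code; careduce_sum and the add/0 identity pair are hypothetical choices for this sketch). The accumulation buffer is filled with the reduction identity once, and each input element is then folded into the output cell addressed by its non-reduced coordinates, so the loop nest is free to visit elements in any order, such as the memory-friendly order this commit picks:

import numpy as np

def careduce_sum(inp: np.ndarray, axis: tuple[int, ...]) -> np.ndarray:
    # Non-reduced dims keep their extent; reduced dims are folded away.
    non_reduced = [d for d in range(inp.ndim) if d not in axis]
    # Initialize the accumulation buffer with the identity (0 for add).
    out = np.full([inp.shape[d] for d in non_reduced], 0, dtype=inp.dtype)
    # Visit every input element exactly once, in any iteration order, and
    # accumulate it into the cell indexed by its non-reduced coordinates.
    for idx in np.ndindex(*inp.shape):
        out[tuple(idx[d] for d in non_reduced)] += inp[idx]
    return out

x = np.arange(24, dtype="float64").reshape(2, 3, 4)
assert np.array_equal(careduce_sum(x, axis=(0, 2)), x.sum(axis=(0, 2)))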
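
A quick way to exercise the new implementation end to end (assuming a C compiler is available so the C backend is used; a sum over non-adjacent axes should lower to a single CAReduce node and, because the output is not 0-d, take the make_reordered_loop_careduce branch rather than the fully-reduced fast path):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
f = pytensor.function([x], x.sum(axis=(0, 2)))
print(f(np.arange(24, dtype="float64").reshape(2, 3, 4)))
# expected: [ 60.  92. 124.]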
