Speedup CAReduce C-implementation with loop reordering #971

Merged
merged 3 commits on Aug 21, 2024
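As the title says, this PR speeds up the C implementation of CAReduce (the Op behind reductions such as sum, prod, and max over given axes) by reordering the reduction loops. In the diff below, the old `cgen.make_loop_careduce` call is replaced with `cgen.make_complete_loop_careduce` (everything is reduced, output is 0-d) and `cgen.make_reordered_loop_careduce` (partial reduction). The Python sketch below only illustrates why loop order matters for a reduction over a C-contiguous array; it is not the PR's generated C code, and the exact ordering heuristic lives in the `cgen` helpers, which are not part of this diff. The function names in the sketch are hypothetical.

import numpy as np

def sum_axis0_naive(x):
    # Reduced axis (0) as the innermost loop: for a C-contiguous 2-d input,
    # each inner read jumps x.shape[1] elements, so access is strided.
    out = np.zeros(x.shape[1], dtype=x.dtype)
    for j in range(x.shape[1]):        # non-reduced axis outermost
        acc = 0
        for i in range(x.shape[0]):    # reduced axis innermost (strided reads)
            acc += x[i, j]
        out[j] = acc
    return out

def sum_axis0_reordered(x):
    # Loops reordered to follow memory order: the contiguous axis is innermost,
    # the input is read sequentially, and each output element is its own accumulator.
    out = np.zeros(x.shape[1], dtype=x.dtype)   # initialised with the identity (0 for sum)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[j] += x[i, j]
    return out

x = np.arange(12, dtype="float64").reshape(3, 4)
assert np.allclose(sum_axis0_naive(x), x.sum(axis=0))
assert np.allclose(sum_axis0_reordered(x), x.sum(axis=0))

Both loop nests compute the same result; the second simply visits the input in the order it is laid out in memory, which is the kind of reordering the new helpers perform in the generated C.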
202 changes: 101 additions & 101 deletions pytensor/tensor/elemwise.py
@@ -1,4 +1,5 @@
from copy import copy
from textwrap import dedent

import numpy as np
from numpy.core.numeric import normalize_axis_tuple
@@ -1466,116 +1467,114 @@ def infer_shape(self, fgraph, node, shapes):
return ((),)
return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)

def _c_all(self, node, name, inames, onames, sub):
input = node.inputs[0]
output = node.outputs[0]
def _c_all(self, node, name, input_names, output_names, sub):
[inp] = node.inputs
[out] = node.outputs
ndim = inp.type.ndim

iname = inames[0]
oname = onames[0]
[inp_name] = input_names
[out_name] = output_names

idtype = input.type.dtype_specs()[1]
odtype = output.type.dtype_specs()[1]
inp_dtype = inp.type.dtype_specs()[1]
out_dtype = out.type.dtype_specs()[1]

acc_dtype = getattr(self, "acc_dtype", None)

if acc_dtype is not None:
if acc_dtype == "float16":
raise MethodNotDefined("no c_code for float16")
acc_type = TensorType(shape=node.outputs[0].type.shape, dtype=acc_dtype)
adtype = acc_type.dtype_specs()[1]
acc_dtype = acc_type.dtype_specs()[1]
else:
adtype = odtype
acc_dtype = out_dtype

axis = self.axis
if axis is None:
axis = list(range(input.type.ndim))
axis = list(range(inp.type.ndim))

if len(axis) == 0:
# This is just an Elemwise cast operation
# The acc_dtype is never a downcast compared to the input dtype
# So we just need a cast to the output dtype.
var = pytensor.tensor.basic.cast(input, node.outputs[0].dtype)
if var is input:
var = Elemwise(scalar_identity)(input)
var = pytensor.tensor.basic.cast(inp, node.outputs[0].dtype)
if var is inp:
var = Elemwise(scalar_identity)(inp)
assert var.dtype == node.outputs[0].dtype
return var.owner.op._c_all(var.owner, name, inames, onames, sub)

order1 = [i for i in range(input.type.ndim) if i not in axis]
order = order1 + list(axis)
return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)

nnested = len(order1)
inp_dims = list(range(ndim))
non_reduced_dims = [i for i in inp_dims if i not in axis]
counter = iter(range(ndim))
acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]

sub = dict(sub)
for i, (input, iname) in enumerate(zip(node.inputs, inames)):
sub[f"lv{i}"] = iname
sub = sub.copy()
sub["lv0"] = inp_name
sub["lv1"] = out_name
sub["olv"] = out_name

decl = ""
if adtype != odtype:
if acc_dtype != out_dtype:
# Create an accumulator variable different from the output
aname = "acc"
decl = acc_type.c_declare(aname, sub)
decl += acc_type.c_init(aname, sub)
acc_name = "acc"
setup = acc_type.c_declare(acc_name, sub) + acc_type.c_init(acc_name, sub)
else:
# the output is the accumulator variable
aname = oname

decl += cgen.make_declare([order], [idtype], sub)
checks = cgen.make_checks([order], [idtype], sub)

alloc = ""
i += 1
sub[f"lv{i}"] = oname
sub["olv"] = oname

# Allocate output buffer
alloc += cgen.make_declare(
[list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
)
alloc += cgen.make_alloc([order1], odtype, sub)
alloc += cgen.make_checks(
[list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
acc_name = out_name
setup = ""

# Define strides of input array
setup += cgen.make_declare(
[inp_dims], [inp_dtype], sub, compute_stride_jump=False
) + cgen.make_checks([inp_dims], [inp_dtype], sub, compute_stride_jump=False)

# Define strides of output array and allocate it
out_sub = sub | {"lv0": out_name}
alloc = (
cgen.make_declare(
[acc_dims], [out_dtype], out_sub, compute_stride_jump=False
)
+ cgen.make_alloc([non_reduced_dims], out_dtype, sub)
+ cgen.make_checks(
[acc_dims], [out_dtype], out_sub, compute_stride_jump=False
)
)

if adtype != odtype:
# Allocate accumulation buffer
sub[f"lv{i}"] = aname
sub["olv"] = aname
if acc_dtype != out_dtype:
# Define strides of accumulation buffer and allocate it
sub["lv1"] = acc_name
sub["olv"] = acc_name

alloc += cgen.make_declare(
[list(range(nnested)) + ["x"] * len(axis)],
[adtype],
dict(sub, lv0=aname),
)
alloc += cgen.make_alloc([order1], adtype, sub)
alloc += cgen.make_checks(
[list(range(nnested)) + ["x"] * len(axis)],
[adtype],
dict(sub, lv0=aname),
acc_sub = sub | {"lv0": acc_name}
alloc += (
cgen.make_declare(
[acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
)
+ cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
+ cgen.make_checks(
[acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
)
)

identity = self.scalar_op.identity

if np.isposinf(identity):
if input.type.dtype in ("float32", "float64"):
if inp.type.dtype in ("float32", "float64"):
identity = "__builtin_inf()"
elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
identity = "1"
else:
identity = "NPY_MAX_" + str(input.type.dtype).upper()
identity = "NPY_MAX_" + str(inp.type.dtype).upper()
elif np.isneginf(identity):
if input.type.dtype in ("float32", "float64"):
if inp.type.dtype in ("float32", "float64"):
identity = "-__builtin_inf()"
elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
identity = "0"
else:
identity = "NPY_MIN_" + str(input.type.dtype).upper()
identity = "NPY_MIN_" + str(inp.type.dtype).upper()
elif identity is None:
raise TypeError(f"The {self.scalar_op} does not define an identity.")

task0_decl = f"{adtype}& {aname}_i = *{aname}_iter;\n{aname}_i = {identity};"

task1_decl = f"{idtype}& {inames[0]}_i = *{inames[0]}_iter;\n"
initial_value = f"{acc_name}_i = {identity};"

task1_code = self.scalar_op.c_code(
inner_task = self.scalar_op.c_code(
Apply(
self.scalar_op,
[
@@ -1588,44 +1587,45 @@ def _c_all(self, node, name, inames, onames, sub):
],
),
None,
[f"{aname}_i", f"{inames[0]}_i"],
[f"{aname}_i"],
[f"{acc_name}_i", f"{inp_name}_i"],
[f"{acc_name}_i"],
sub,
)
code1 = f"""
{{
{task1_decl}
{task1_code}
}}
"""

if node.inputs[0].type.ndim:
if len(axis) == 1:
all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
else:
all_code = (
[("", "")] * nnested
+ [(task0_decl, "")]
+ [("", "")] * (len(axis) - 2)
+ [("", code1), ""]
)
if out.type.ndim == 0:
# Simple case where everything is reduced, no need for loop ordering
loop = cgen.make_complete_loop_careduce(
inp_var=inp_name,
acc_var=acc_name,
inp_dtype=inp_dtype,
acc_dtype=acc_dtype,
initial_value=initial_value,
inner_task=inner_task,
fail_code=sub["fail"],
)
else:
all_code = [task0_decl + code1]
loop = cgen.make_loop_careduce(
[order, list(range(nnested)) + ["x"] * len(axis)],
[idtype, adtype],
all_code,
sub,
)
loop = cgen.make_reordered_loop_careduce(
inp_var=inp_name,
acc_var=acc_name,
inp_dtype=inp_dtype,
acc_dtype=acc_dtype,
inp_ndim=ndim,
reduction_axes=axis,
initial_value=initial_value,
inner_task=inner_task,
)

end = ""
if adtype != odtype:
end = f"""
PyArray_CopyInto({oname}, {aname});
"""
end += acc_type.c_cleanup(aname, sub)
if acc_dtype != out_dtype:
cast = dedent(
f"""
PyArray_CopyInto({out_name}, {acc_name});
{acc_type.c_cleanup(acc_name, sub)}
"""
)
else:
cast = ""

return decl, checks, alloc, loop, end
return setup, alloc, loop, cast

def c_code(self, node, name, inames, onames, sub):
code = "\n".join(self._c_all(node, name, inames, onames, sub))
Expand All @@ -1637,7 +1637,7 @@ def c_headers(self, **kwargs):

def c_code_cache_version_apply(self, node):
# the version corresponding to the c code in this Op
version = [9]
version = [10]

# now we insert versions for the ops on which we depend...
scalar_node = Apply(
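The bump of `c_code_cache_version_apply` from `[9]` to `[10]` invalidates previously compiled CAReduce kernels so that the new loops are regenerated. For context, this code path is only exercised when a reduction is compiled with the C backend; a minimal way to trigger it (assuming PyTensor is installed with a working C compiler configured, otherwise the pure-Python fallback runs instead) is:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
row_sums = x.sum(axis=1)                 # lowered to a CAReduce Op
f = pytensor.function([x], row_sums)     # compiled with the C backend when available
print(f(np.arange(12, dtype="float64").reshape(3, 4)))  # [ 6. 22. 38.]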