Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: Typing for bindings #1218

Merged
merged 5 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 26 additions & 59 deletions src/gt4py/next/otf/binding/cpp_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import ctypes
import types
from typing import Final, Sequence, Type

import numpy as np
from typing import Final, Sequence

import gt4py.next.type_system.type_specifications as ts
from gt4py.next.otf import languages
from gt4py.next.otf.binding import interface

Expand All @@ -29,63 +26,33 @@
header_extension="cpp.inc",
)

_TYPE_MAPPING: Final = types.MappingProxyType(
{
bool: "bool",
int: "long",
float: "double",
complex: "std::complex<double>",
np.bool_: "bool",
np.byte: "signed char",
np.ubyte: "unsigned char",
np.short: "short",
np.ushort: "unsigned short",
np.intc: "int",
np.uintc: "unsigned int",
np.int_: "long",
np.uint: "unsigned long",
np.longlong: "long long",
np.ulonglong: "unsigned long long",
np.single: "float",
np.double: "double",
np.longdouble: "long double",
np.csingle: "std::complex<float>",
np.cdouble: "std::complex<double>",
np.clongdouble: "std::complex<long double>",
ctypes.c_bool: "bool",
ctypes.c_char: "char",
ctypes.c_wchar: "wchar_t",
ctypes.c_byte: "char",
ctypes.c_ubyte: "unsigned char",
ctypes.c_short: "short",
ctypes.c_ushort: "unsigned short",
ctypes.c_int: "int",
ctypes.c_uint: "unsigned int",
ctypes.c_long: "long",
ctypes.c_ulong: "unsigned long",
ctypes.c_longlong: "long long",
ctypes.c_ulonglong: "unsigned long long",
ctypes.c_size_t: "std::size_t",
ctypes.c_ssize_t: "std::ptrdiff_t",
ctypes.c_float: "float",
ctypes.c_double: "double",
ctypes.c_longdouble: "long double",
}
)


def render_python_type(python_type: Type) -> str:
return _TYPE_MAPPING[python_type]
def render_scalar_type(scalar_type: ts.ScalarType) -> str:
if scalar_type.kind == ts.ScalarKind.BOOL:
return "bool"
elif scalar_type.kind == ts.ScalarKind.INT32:
return "std::int32_t"
elif scalar_type.kind == ts.ScalarKind.INT64:
return "std::int64_t"
elif scalar_type.kind == ts.ScalarKind.FLOAT32:
return "float"
elif scalar_type.kind == ts.ScalarKind.FLOAT64:
return "double"
elif scalar_type.kind == ts.ScalarKind.STRING:
return "std::string"
elif scalar_type.kind == ts.ScalarKind.DIMENSION:
raise AssertionError(f"Deprecated type '{scalar_type}' is not supported.")
else:
raise AssertionError(f"Scalar kind '{scalar_type}' is not implemented when it should be.")
havogt marked this conversation as resolved.
Show resolved Hide resolved


def _render_function_param(
param: interface.ScalarParameter | interface.BufferParameter | interface.ConnectivityParameter,
index: int,
) -> str:
if isinstance(param, interface.ScalarParameter):
return f"{render_python_type(param.scalar_type.type)} {param.name}"
else:
def _render_function_param(param: interface.Parameter, index: int) -> str:
if isinstance(param.type_, ts.ScalarType):
return f"{render_scalar_type(param.type_)} {param.name}"
elif isinstance(param.type_, ts.FieldType):
return f"BufferT{index}&& {param.name}"
else:
raise ValueError(f"Type '{param.type_}' is not supported in C++ interfaces.")


def render_function_declaration(function: interface.Function, body: str) -> str:
Expand All @@ -98,7 +65,7 @@ def render_function_declaration(function: interface.Function, body: str) -> str:
template_params = [
f"class BufferT{index}"
for index, param in enumerate(function.parameters)
if isinstance(param, (interface.BufferParameter, interface.ConnectivityParameter))
if isinstance(param.type_, ts.FieldType)
]
if template_params:
return f"""
Expand Down
24 changes: 4 additions & 20 deletions src/gt4py/next/otf/binding/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

import dataclasses

import numpy as np

import gt4py.next.type_system.type_specifications as ts
from gt4py.eve import codegen
from gt4py.next.otf import languages

Expand All @@ -27,30 +26,15 @@ def format_source(settings: languages.LanguageSettings, source):


@dataclasses.dataclass(frozen=True)
class ScalarParameter:
name: str
scalar_type: np.dtype


@dataclasses.dataclass(frozen=True)
class BufferParameter:
name: str
dimensions: tuple[str, ...]
scalar_type: np.dtype


@dataclasses.dataclass(frozen=True)
class ConnectivityParameter:
class Parameter:
name: str
origin_axis: str
offset_tag: str
index_type: type[np.int32] | type[np.int64]
type_: ts.TypeSpec


@dataclasses.dataclass(frozen=True)
class Function:
name: str
parameters: tuple[ScalarParameter | BufferParameter | ConnectivityParameter, ...]
parameters: tuple[Parameter, ...]


@dataclasses.dataclass(frozen=True)
Expand Down
70 changes: 25 additions & 45 deletions src/gt4py/next/otf/binding/pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@

from typing import Any, Sequence

import numpy as np

import gt4py.eve as eve
import gt4py.next.type_system.type_specifications as ts
from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator
from gt4py.next.otf import languages, stages, workflow
from gt4py.next.otf.binding import cpp_interface, interface
Expand All @@ -35,10 +34,10 @@ class DimensionType(Expr):
name: str


class SidConversion(Expr):
buffer_name: str
class BufferSID(Expr):
source_buffer: str
dimensions: Sequence[DimensionType]
scalar_type: np.dtype
scalar_type: ts.ScalarType
dim_config: int


Expand All @@ -53,8 +52,7 @@ class ReturnStmt(eve.Node):

class FunctionParameter(eve.Node):
name: str
ndim: int
dtype: np.dtype
type_: ts.TypeSpec


class WrapperFunction(eve.Node):
Expand Down Expand Up @@ -109,11 +107,13 @@ class BindingCodeGenerator(TemplatedGenerator):
)

def visit_FunctionParameter(self, param: FunctionParameter):
if param.ndim > 0:
if isinstance(param.type_, ts.FieldType):
type_str = "pybind11::buffer"
elif isinstance(param.type_, ts.ScalarType):
type_str = cpp_interface.render_scalar_type(param.type_)
else:
type_str = cpp_interface.render_python_type(param.dtype.type)
return type_str + " " + param.name
raise ValueError(f"Type '{param.type_}' is not supported in pybind11 interfaces.")
return f"{type_str} {param.name}"

ReturnStmt = as_jinja("""return {{expr}};""")

Expand All @@ -132,59 +132,39 @@ def visit_FunctionCall(self, call: FunctionCall):
args = [self.visit(arg) for arg in call.args]
return cpp_interface.render_function_call(call.target, args)

def visit_SidConversion(self, sid: SidConversion):
def visit_BufferSID(self, sid: BufferSID):
return self.generic_visit(
sid, rendered_scalar_type=cpp_interface.render_python_type(sid.scalar_type.type)
sid, rendered_scalar_type=cpp_interface.render_scalar_type(sid.scalar_type)
)

SidConversion = as_jinja(
BufferSID = as_jinja(
"""gridtools::sid::rename_numbered_dimensions<{{", ".join(dimensions)}}>(
gridtools::as_sid<{{rendered_scalar_type}},\
{{dimensions.__len__()}},\
gridtools::integral_constant<int, {{dim_config}}>,\
999'999'999>({{buffer_name}})
999'999'999>({{source_buffer}})
)"""
)

DimensionType = as_jinja("""generated::{{name}}_t""")


def make_parameter(
havogt marked this conversation as resolved.
Show resolved Hide resolved
parameter: interface.ScalarParameter
| interface.BufferParameter
| interface.ConnectivityParameter,
parameter: interface.Parameter,
) -> FunctionParameter:
if isinstance(parameter, interface.ConnectivityParameter):
return FunctionParameter(name=parameter.name, ndim=2, dtype=parameter.index_type)
name = parameter.name
ndim = 0 if isinstance(parameter, interface.ScalarParameter) else len(parameter.dimensions)
scalar_type = parameter.scalar_type
return FunctionParameter(name=name, ndim=ndim, dtype=scalar_type)


def make_argument(
index: int,
param: interface.ScalarParameter | interface.BufferParameter | interface.ConnectivityParameter,
) -> str | SidConversion:
if isinstance(param, interface.ScalarParameter):
return param.name
elif isinstance(param, interface.ConnectivityParameter):
return SidConversion(
buffer_name=param.name,
dimensions=[
DimensionType(name=param.origin_axis),
DimensionType(name=param.offset_tag),
],
scalar_type=param.index_type,
return FunctionParameter(name=parameter.name, type_=parameter.type_)


def make_argument(index: int, param: interface.Parameter) -> str | BufferSID:
if isinstance(param.type_, ts.FieldType):
return BufferSID(
source_buffer=param.name,
dimensions=[DimensionType(name=dim.value) for dim in param.type_.dims],
scalar_type=param.type_.dtype,
dim_config=index,
)
else:
return SidConversion(
buffer_name=param.name,
dimensions=[DimensionType(name=dim) for dim in param.dimensions],
scalar_type=param.scalar_type,
dim_config=index,
)
return param.name


def create_bindings(
Expand Down
14 changes: 2 additions & 12 deletions src/gt4py/next/otf/compilation/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,8 @@ class Strategy(enum.Enum):
_persistent_cache_dir_path = pathlib.Path(tempfile.gettempdir()) / "gt4py_cache"


def _serialize_param(
parameter: interface.ScalarParameter
| interface.BufferParameter
| interface.ConnectivityParameter,
) -> str:
if isinstance(parameter, interface.ScalarParameter):
return f"{parameter.name}: {str(parameter.scalar_type)}"
elif isinstance(parameter, interface.BufferParameter):
return f"{parameter.name}: {str(parameter.scalar_type)}<{', '.join(parameter.dimensions)}>"
elif isinstance(parameter, interface.ConnectivityParameter):
return f"{parameter.name}: {parameter.offset_tag}"
raise ValueError("Invalid parameter type. This is a bug.")
def _serialize_param(parameter: interface.Parameter) -> str:
return f"{parameter.name}: {str(parameter.type_)}"


def _serialize_library_dependency(dependency: interface.LibraryDependency) -> str:
Expand Down
34 changes: 15 additions & 19 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,16 @@
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.binding import cpp_interface, interface
from gt4py.next.program_processors.codegens.gtfn import gtfn_backend
from gt4py.next.type_system import type_specifications as ts, type_translation


T = TypeVar("T")

GENERATED_CONNECTIVITY_PARAM_PREFIX = "gt_conn_"


def get_param_description(
name: str, obj: Any
) -> interface.ScalarParameter | interface.BufferParameter:
view: np.ndarray = np.asarray(obj)
if view.ndim > 0:
return interface.BufferParameter(name, tuple(dim.value for dim in obj.axes), view.dtype)
else:
return interface.ScalarParameter(name, view.dtype)
def get_param_description(name: str, obj: Any) -> interface.Parameter:
return interface.Parameter(name, type_translation.from_value(obj))


@dataclasses.dataclass(frozen=True)
Expand All @@ -53,7 +48,7 @@ def _process_regular_arguments(
program: itir.FencilDefinition,
args: tuple[Any, ...],
):
parameters: list[interface.ScalarParameter | interface.BufferParameter] = []
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []

# TODO(tehrengruber): The backend expects all arguments to a stencil closure to be a SID
Expand All @@ -78,7 +73,7 @@ def _process_regular_arguments(

# argument conversion expression
if (
isinstance(parameter, interface.ScalarParameter)
isinstance(parameter.type_, ts.ScalarType)
and parameter.name in closure_scalar_parameters
):
# convert into sid
Expand All @@ -92,7 +87,7 @@ def _process_connectivity_args(
self,
offset_provider: dict[str, Connectivity | Dimension],
):
parameters: list[interface.ConnectivityParameter] = []
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []

for name, connectivity in offset_provider.items():
Expand All @@ -104,11 +99,14 @@ def _process_connectivity_args(

# parameter
parameters.append(
interface.ConnectivityParameter(
GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(),
connectivity.origin_axis.value,
name,
connectivity.index_type, # type: ignore[arg-type]
interface.Parameter(
name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(),
type_=ts.FieldType(
dims=[connectivity.origin_axis, Dimension(name)],
dtype=ts.ScalarType(
type_translation.get_scalar_kind(connectivity.index_type)
),
),
)
)

Expand Down Expand Up @@ -150,9 +148,7 @@ def __call__(
)

# combine into a format that is aligned with what the backend expects
parameters: list[
interface.ScalarParameter | interface.BufferParameter | interface.ConnectivityParameter
] = [*regular_parameters, *connectivity_parameters]
parameters: list[interface.Parameter] = [*regular_parameters, *connectivity_parameters]
havogt marked this conversation as resolved.
Show resolved Hide resolved
args_expr: list[str] = ["gridtools::fn::backend::naive{}", *regular_args_expr]

function = interface.Function(program.id, tuple(parameters))
Expand Down
Loading