From be7a9cea05e50896dfb0117d62778bb9d5b4a2d1 Mon Sep 17 00:00:00 2001 From: 6clc Date: Mon, 9 Oct 2023 19:12:20 +0800 Subject: [PATCH] cinn(py-dsl): parse compute of python dsl (#57731) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 拆分新特性:CINN Python DSL, 主PR和单测见:#56393 此PR只负责 解析python dsl中的compute定义 1. 装饰器@to_cinn_ir封装cinn的function kernel: CinnLowerLevelIrJit支持从Jit运行时中数据类型、target类型、python ast。后续解析compute的信息都会从CinnLowerLevelIrJit这个类中获取。 CinnLowerLevelIrJit也支持静态获取上述信息,通过python的annotation来填充。 2. compute 语义解析 将整个AST分为三种类型: stmts: Function, For, If, With ,对应封装上下文IR的PR: #57515 Assign: 表达式"lhs = rhs"的类型,Assign类型构成了stmts。 python/cinn/compiler/expr_executor.py中的exec_expr方法将rhs解析成cinn ir Expr python/cinn/compiler/expr_executor.py中的exec_assign方法,将lhs=rhs表达的assign语义存储在局部变量表中。 Expr:组成Assign中的rhs。 3. 变量管理 python/cinn/compiler/utils.py中的class VariableTable:用于管理Python DSL中定义的变量,主要是下面两个功能。 每次Enter新的Context,会复制当前的变量表 每次Exit Context,会删除当前Context增加的变量,恢复上一轮Context的变量表。 --- paddle/cinn/pybind/ir/ir_api.cc | 2 + python/cinn/__init__.py | 3 +- python/cinn/compiler/__init__.py | 17 ++ python/cinn/compiler/compiler.py | 38 +++ .../cinn/compiler/compute_code_generator.py | 245 ++++++++++++++++++ python/cinn/compiler/expr_executor.py | 159 ++++++++++++ python/cinn/compiler/utils.py | 76 ++++++ python/cinn/ir/ir.py | 2 +- python/cinn/runtime/__init__.py | 4 + python/cinn/runtime/cinn_jit.py | 115 ++++++++ python/cinn/runtime/utils.py | 35 +++ 11 files changed, 694 insertions(+), 2 deletions(-) create mode 100644 python/cinn/compiler/__init__.py create mode 100644 python/cinn/compiler/compiler.py create mode 100644 python/cinn/compiler/compute_code_generator.py create mode 100644 python/cinn/compiler/expr_executor.py create mode 100644 python/cinn/compiler/utils.py create mode 100644 python/cinn/runtime/cinn_jit.py create mode 100644 python/cinn/runtime/utils.py diff --git a/paddle/cinn/pybind/ir/ir_api.cc b/paddle/cinn/pybind/ir/ir_api.cc index ffbfd3375bf75..2170f360f5062 100644 --- a/paddle/cinn/pybind/ir/ir_api.cc +++ b/paddle/cinn/pybind/ir/ir_api.cc @@ -843,6 +843,8 @@ void BindIrContext(py::module *m) { .def_static("MakeThenContext", []() { return IRContext(new ThenContextNode()); }); + m->def("link_to_parent_context", &pybind::LinkToParentContext); + py::class_ ir_builder(*m, "IRBuilder"); ir_builder.def(py::init<>()) .def("EnterWithContext", &IRBuilder::EnterWithContext) diff --git a/python/cinn/__init__.py b/python/cinn/__init__.py index 9411b774e3836..55ab35e7e5624 100644 --- a/python/cinn/__init__.py +++ b/python/cinn/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .version import full_version as __version__ +from .runtime.cinn_jit import to_cinn_llir import os cinndir = os.path.dirname(os.path.abspath(__file__)) @@ -189,4 +191,3 @@ reduce_mul, reduce_sum, ) -from .version import full_version as __version__ diff --git a/python/cinn/compiler/__init__.py b/python/cinn/compiler/__init__.py new file mode 100644 index 0000000000000..644bf2d949ca4 --- /dev/null +++ b/python/cinn/compiler/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .compiler import compile + +__all__ = ["compile"] diff --git a/python/cinn/compiler/compiler.py b/python/cinn/compiler/compiler.py new file mode 100644 index 0000000000000..330d34962641d --- /dev/null +++ b/python/cinn/compiler/compiler.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..runtime import CinnLowerLevelIrJit +from .compute_code_generator import ComputeCodeGenerator + + +def ast_to_llir(fn, inputs_signature): + function_name = fn.__name__ + # 1. Parse CINN Compute + llir_compute_generator = ComputeCodeGenerator( + fn, function_name, inputs_signature + ) + cinn_llir_func = llir_compute_generator.parse() + return cinn_llir_func + + +def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs): + if isinstance(fn, CinnLowerLevelIrJit): + llir_func = ast_to_llir(fn, jit_inputs_signature) + else: + raise Exception("Current Only support compile from CinnLowerLevelIrJit") + + if just_convert: + return llir_func + return llir_func diff --git a/python/cinn/compiler/compute_code_generator.py b/python/cinn/compiler/compute_code_generator.py new file mode 100644 index 0000000000000..9a54c504306f3 --- /dev/null +++ b/python/cinn/compiler/compute_code_generator.py @@ -0,0 +1,245 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import contextlib + +from cinn import ir + +from .expr_executor import ExprExecutor, exec_assign +from .utils import VariableTable, is_node_parsed_in_schedule + + +class ComputeCodeGenerator(ast.NodeVisitor): + """ + Convert python ast to CINN Lower Level IR, + containing only the semantics of the compute part + """ + + def __init__(self, fn, function_name, inputs_signature): + self.fn = fn + self.function_name = function_name + self.inputs_signature = inputs_signature + self.cinn_llir_func = None + self.variables_table = VariableTable() + self.extra_scope = {"range": ir.sequential} + + def parse(self): + ast_node = self.fn.parse() + with ir.IRBuilder() as builder, self.variables_table: + for k, v in self.fn.scope.items(): + self.variables_table.add(k, v) + for k, v in self.extra_scope.items(): + self.variables_table.add(k, v) + self.visit(ast_node) + return builder.get() + + def visit_FunctionDef(self, node) -> None: + """ + Parse CINN Low Level IR FunctionDef. + + Args: + node(ast.FunctionDef): The ast FunctionDef Node + """ + with ir.LowerFuncContext(self.function_name) as func_ctx: + arg_names = self.visit(node.args) + + assert len(node.args.defaults) == 0, "Not support default args" + + # 1. Construct args of function + for i, arg_name in enumerate(arg_names): + # Obj of Argument is ir::Buffer + if hasattr(self.inputs_signature[i], "dtype"): + tensor_shape = [ + ir.Expr(dim) for dim in self.inputs_signature[i].shape + ] + llir_value = ir._Buffer_.make( + arg_name, self.inputs_signature[i].dtype + ) + ir.Arg(arg_name, llir_value) + llir_value = ir._Tensor_.make( + arg_name, + self.inputs_signature[i].dtype, + tensor_shape, + tensor_shape, + ) + self.variables_table.add(arg_name, llir_value) + # Obj of Argument is ir::Var + else: + llir_value = ir.Var(arg_name) + ir.Arg(arg_name, llir_value) + llir_value = ir.Expr(llir_value) + self.variables_table.add(arg_name, llir_value) + + # 2. Construct body of function + body = self.visit_compound_statement(node.body) + + def visit_compound_statement(self, stmts): + for stmt in stmts: + self.visit(stmt) + + def visit_arguments(self, node): + """ + Parse CINN Low Level IR Argument. + If it is not jit mode, it will get information from arg.annoatation. + + Args: + node(ast.arguments): The ast argument Node + + Returns: + list[string]: A list of parameter names + """ + arg_names = [arg.arg for arg in node.args] + + if len(self.inputs_signature) != len(arg_names): + self.inputs_signature = [] + for arg in node.args: + arg_annotation = arg.annotation + if isinstance(arg_annotation, ast.Call): + self.inputs_signature.append( + ExprExecutor(self.variables_table.get()).exec( + arg_annotation + ) + ) + elif isinstance(arg_annotation, int): + if ( + -(2**21) <= arg_annotation + and arg_annotation <= 2**31 - 1 + ): + self.inputs_signature.append("i32") + elif ( + 2**63 <= arg_annotation + and arg_annotation <= 2**64 - 1 + ): + self.inputs_signature.append("u64") + else: + self.inputs_signature.append("i64") + elif isinstance(arg_annotation, float): + return self.inputs_signature.append("fp32") + else: + raise TypeError( + f'Unsupported type {type(arg_annotation)} for {arg_annotation}' + ) + + return arg_names + + def visit_For(self, node) -> ir.Expr: + """ + parse CINN Low Level IR For. + + Args: + node(ast.For): The ast For node + """ + for_ctx = ExprExecutor(self.variables_table.get()).exec(node.iter) + with self.variables_table: + with for_ctx as loop_var: + local_var_table = exec_assign( + target=node.target, source=loop_var + ) + for k, v in local_var_table.items(): + loop_var.rename(k) + self.variables_table.add(k, ir.Expr(v)) + self.visit_compound_statement(node.body) + + def visit_Assign(self, node): + """ + parse CINN Low Level IR Store. + + Args: + node(ast.Assign): The ast Assign node + + Returns: + ir.Expr, Points to the Expr of ir::ExprNode + """ + + if isinstance(node.value, ast.Call) and is_node_parsed_in_schedule( + node.value + ): + return "no compute" + + assert ( + len(node.targets) == 1 + ), "Unsupport targets is a \ + list of nodes, like 'a = b = c'" + lhs = node.targets[0] + + # 1 parse RHS + rhs_expr = ExprExecutor(self.variables_table.get()).exec(node.value) + + # 2 parse LHS + # 2.1 Type of arg is Tensor + if isinstance(lhs, ast.Subscript): + expr_tensor = ExprExecutor(self.variables_table.get()).exec( + lhs.value + ) + if isinstance(lhs.slice, ast.Tuple): + expr_indices = [] + for idx in lhs.slice.elts: + expr_indices.append( + ExprExecutor(self.variables_table.get()).exec(idx) + ) + else: + expr_indices = [ + ExprExecutor(self.variables_table.get()).exec(lhs.slice) + ] + if not isinstance(rhs_expr, ir.Expr): + rhs_expr = ir.Expr(rhs_expr) + ir.TensorStore(expr_tensor.Expr(), rhs_expr, expr_indices) + # 2.2 Type of arg is Var + else: + local_var_table = exec_assign(target=lhs, source=rhs_expr) + if isinstance(lhs, ast.Tuple): + for k, v in local_var_table.items(): + v.as_var_ref().rename(k) + self.variables_table.add(k, v) + else: + for k, v in local_var_table.items(): + v[0].as_var_ref().rename(k) + self.variables_table.add(k, v[0]) + + def visit_If(self, node): + with self.variables_table: + with ir.IfContext( + ExprExecutor(self.variables_table.get()).exec(node.test) + ): + with ir.ThenContext(): + with self.variables_table: + self.visit_compound_statement(node.body) + if node.orelse: + with ir.ElseContext(): + with self.variables_table: + self.visit_compound_statement(node.body) + + def visit_With(self, node): + with self.variables_table: + with contextlib.ExitStack() as context_stack: + for item in node.items: + cur_ctx = ExprExecutor(self.variables_table.get()).exec( + item.context_expr + ) + cur_ctx = context_stack.enter_context(cur_ctx) + if item.optional_vars is not None: + local_var_table = exec_assign( + target=item.optional_vars, source=cur_ctx + ) + for k, v in local_var_table.items(): + self.variables_table.add(k, v) + body = self.visit_compound_statement(node.body) + + def visit_Expr(self, node): + if is_node_parsed_in_schedule(node.value): + return + res = ExprExecutor(self.variables_table.get()).exec(node.value) + if isinstance(res, ir.Expr): + ir.link_to_parent_context(res) diff --git a/python/cinn/compiler/expr_executor.py b/python/cinn/compiler/expr_executor.py new file mode 100644 index 0000000000000..cff9a9d62d7c4 --- /dev/null +++ b/python/cinn/compiler/expr_executor.py @@ -0,0 +1,159 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast + +from cinn import ir + +# The Python native AST node that cinn ir supports +AST2CINN = { + ast.Add: ir.Add, + ast.Sub: ir.Sub, + ast.Mult: ir.Mul, + ast.Div: ir.Div, + ast.Mod: ir.Mod, + ast.And: ir.And, + ast.Or: ir.Or, + ast.USub: ir.Minus, + ast.Not: ir.Not, + ast.Eq: ir.EQ, + ast.NotEq: ir.NE, + ast.Lt: ir.LT, + ast.LtE: ir.LE, + ast.Gt: ir.GT, + ast.GtE: ir.GE, +} + + +class ExprExecutor: + def __init__(self, var_table): + self.var_table = var_table + self.tmp_value_count = 1 + + def exec(self, node): + ret = self.visit(node) + if isinstance(ret, ast.Name): + return self.var_table[ret.id] + if isinstance(ret, ast.Constant): + return ret.value + raise Exception(f"Error result type: {type(ret)}") + + def visit(self, node): + if isinstance(node, list): + return [self.visit(item) for item in node] + if isinstance(node, tuple): + return (self.visit(item) for item in node) + assert isinstance(node, ast.AST) + if isinstance(node, ast.Name): + return node + + if isinstance(node, ast.Constant): + return node + + if not isinstance(node, (ast.expr, ast.slice)): + # some nodes don't need to parse, such as ast.Load + return node + if isinstance(node, (ast.Lambda, ast.Starred)): + raise Exception("Current not suporrted: Lambda, Starred") + + cls_fields = {} + for field in node.__class__._fields: + attr = getattr(node, field) + if isinstance(attr, (ast.AST, tuple, list)): + cls_fields[field] = self.visit(attr) + else: + cls_fields[field] = attr + + node_type_name = f'eval_{type(node).__name__}' + if hasattr(self, node_type_name): + exec_func = getattr(self, node_type_name) + value = exec_func(cls_fields) + else: + new_node = node.__class__(**cls_fields) + ast.copy_location(new_node, node) + new_node = ast.Expression(new_node) + value = self.exec_expr(new_node) + return self.save_temp_value(value) + + def exec_expr(self, node): + if isinstance(node, ast.expr): + node = ast.Expression(body=node) + node = ast.fix_missing_locations(node) + exec = compile(node, filename="", mode="eval") + return eval(exec, self.var_table) + + def eval_BinOp(self, fields): + args = [self.exec_expr(fields["left"]), self.exec_expr(fields["right"])] + args = [ + ir.Expr(item) if not isinstance(item, ir.Expr) else item + for item in args + ] + return AST2CINN[type(fields["op"])].make(*args) + + def eval_UnaryOp(self, fields): + args = [self.exec_expr(fields["operand"])] + args = [ + ir.Expr(item) if not isinstance(item, ir.Expr) else item + for item in args + ] + return AST2CINN[type(fields["op"])].make(*args) + + def eval_Compare(self, fields): + assert ( + len(fields["ops"]) == 1 + ), "Only binary comparison symbols are supported. Expressions such as '1 <= a < 10' are not supported." + args = [ + self.exec_expr(fields["left"]), + self.exec_expr(fields["comparators"][0]), + ] + args = [ + ir.Expr(item) if not isinstance(item, ir.Expr) else item + for item in args + ] + return AST2CINN[type(fields["ops"][0])].make(*args) + + def save_temp_value(self, value): + name = f"__cinn_python_script_tmp_value_{self.tmp_value_count}" + self.tmp_value_count += 1 + self.var_table[name] = value + return ast.Name( + id=name, + ctx=ast.Load( + lineno=0, col_offset=0, end_lineno=None, end_col_offset=None + ), + lineno=0, + col_offset=0, + end_lineno=None, + end_col_offset=None, + ) + + +def exec_assign(target, source): + right_value_var_name = "__CINN_RIGHT_VALUE_VAR_NAME__" + local_var_table = {right_value_var_name: source} + mod = ast.fix_missing_locations( + ast.Module( + body=[ + ast.Assign( + targets=[target], + value=ast.Name(id=right_value_var_name, ctx=ast.Load()), + ) + ], + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + exec(exe, {}, local_var_table) + del local_var_table[right_value_var_name] + return local_var_table diff --git a/python/cinn/compiler/utils.py b/python/cinn/compiler/utils.py new file mode 100644 index 0000000000000..6f78446245fb4 --- /dev/null +++ b/python/cinn/compiler/utils.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ast + +try: + from _collections import defaultdict +except ImportError: + pass + + +from cinn.schedule import IRSchedule + + +def is_node_parsed_in_schedule(node: ast.Call): + func_name = "" + if isinstance(node.func, ast.Name): + func_name = node.func.id + elif isinstance(node.func, ast.Attribute): + func_name = node.func.attr + if func_name == "make": + return False + if func_name == "print": + return True + + return getattr(IRSchedule, func_name, None) + + +def node_is_schedule_block_context(node: ast.Call): + if isinstance(node.func, ast.Name): + return node.Name == "ScheduleBlockContext" + if isinstance(node.func, ast.Attribute): + return node.func.attr == "ScheduleBlockContext" + return False + + +class VariableTable: + def __init__(self): + # var name added by current context + self.var_name_list = [] + # var name to var. Dtype is {string:list} + # list records the value assigned to each layer of context + self.name2value = defaultdict(list) + + def __enter__(self): + self.var_name_list.append([]) + return self + + def __exit__(self, ptype, value, trace) -> None: + # clear var assign in current context + if ptype is None and value is None: + var_names = self.var_name_list.pop() + for var_name in var_names: + self.name2value[var_name].pop() + if len(self.name2value[var_name]) == 0: + self.name2value.pop(var_name) + + def add(self, name, value, cover=False): + if cover and name in self.var_name_list[-1]: + self.name2value[name][-1] = value + else: + self.var_name_list[-1].append(name) + self.name2value[name].append(value) + + def get(self): + return {k: v[-1] for k, v in self.name2value.items()} diff --git a/python/cinn/ir/ir.py b/python/cinn/ir/ir.py index 5c683de04e705..7d51a302a3dfb 100644 --- a/python/cinn/ir/ir.py +++ b/python/cinn/ir/ir.py @@ -17,7 +17,7 @@ from .ir_context import ForContext -# Python's rang() function calls the sequential() +# Python's range() function calls the sequential() def sequential(min, extent=None): if extent is None: extent = min diff --git a/python/cinn/runtime/__init__.py b/python/cinn/runtime/__init__.py index a9f32b12d0e22..70753e812e6b6 100644 --- a/python/cinn/runtime/__init__.py +++ b/python/cinn/runtime/__init__.py @@ -66,3 +66,7 @@ seed, set_cinn_cudnn_deterministic, ) + +from .cinn_jit import CinnLowerLevelIrJit + +__all__ = ["CinnLowerLevelIrJit"] diff --git a/python/cinn/runtime/cinn_jit.py b/python/cinn/runtime/cinn_jit.py new file mode 100644 index 0000000000000..7b85808593d62 --- /dev/null +++ b/python/cinn/runtime/cinn_jit.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ast +import functools +import inspect +import textwrap +from typing import Callable, Generic, Optional, TypeVar, Union, cast + +from .utils import inspect_function_scope + +T = TypeVar('T') + + +class CinnLowerLevelIrJit(Generic[T]): + def __init__(self, fn): + self.fn = fn + # function prototype + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def") :] + self.scope = inspect_function_scope(fn) + + # docs of warpped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + # Encapsulates the compile and run processes + self.run = self._make_launcher() + + def _make_launcher(self): + # Gets information about runtime input parameters + jit_input_args = ', '.join(arg_name for arg_name in self.arg_names) + lazy_compile = f""" +import cinn +def {self.fn.__name__}({jit_input_args}, target=cinn.common.DefaultHostTarget()): + from cinn.compiler import compile + jit_inputs = {', '.join([f'{arg}' for arg in self.arg_names])} + jit_inputs_signature = {{ i: self._convert_arg_type(arg) \ + for i, arg in enumerate(jit_inputs)}} + module = compile(self, jit_inputs_signature=jit_inputs_signature, arg_names={ + self.arg_names}, target=target) + module({jit_input_args}) + + return module + """ + scope = { + "self": self, + } + exec(lazy_compile, scope) + return scope[self.fn.__name__] + + def convert_to_llir(self): + from cinn.compiler import compile + + return compile(self, just_convert=True) + + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + return tree + + def __getitem__(self, target): + return cast( + T, functools.partial(cast(Callable, self.run), target=target) + ) + + def _convert_arg_type(self, arg): + # arg is a Tensor + if hasattr(arg, "dtype"): + return arg + # arg is a Var + else: + if isinstance(arg, int): + if -(2**21) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + else: + raise TypeError(f'Unsupported type {type(arg)} for {arg}') + + def __str__(self): + return str(self.convert_to_llir()) + + +def to_cinn_llir( + fn: Optional[T] = None, +) -> Union[CinnLowerLevelIrJit[T]]: + def decorator(fn: T) -> CinnLowerLevelIrJit[T]: + return CinnLowerLevelIrJit(fn) + + if fn is not None: + return decorator(fn) + else: + return decorator diff --git a/python/cinn/runtime/utils.py b/python/cinn/runtime/utils.py new file mode 100644 index 0000000000000..8df8cccc772d1 --- /dev/null +++ b/python/cinn/runtime/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + + +def get_func_global_vars(func): + if inspect.ismethod(func): + func = func.__func__ + + code = func.__code__ + global_vars = {} + if func.__closure__ is not None: + for k, v in zip(code.co_freevars, func.__closure__): + global_vars[k] = v.cell_contents + return global_vars + + +def inspect_function_scope(func): + scope = { + **func.__globals__, + **get_func_global_vars(func), + } + return scope