Skip to content

Commit

Permalink
python methods (apache#38)
Browse files Browse the repository at this point in the history
* ir builder in python

* `ForFrame`s in python

* python methods

* rename and add @staticmethod for current builder

* POC demo in python

* apply code review suggestions

* apply code review suggestions

* apply code review suggestions

* apply code review suggestions
  • Loading branch information
cyx-6 authored and junrushao committed Jul 4, 2022
1 parent c4b5ac7 commit 30224ed
Show file tree
Hide file tree
Showing 21 changed files with 505 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@

from . import tir

from .builder import Builder
from .parser import ir_module, from_source
22 changes: 22 additions & 0 deletions python/tvm/script/builder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Namespace for the TVMScript Builder API."""


from .builder import Builder, def_, def_many
from .frame import Frame, IRModuleFrame
20 changes: 20 additions & 0 deletions python/tvm/script/builder/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.script.builder"""
import tvm._ffi

tvm._ffi._init_api("script.builder", __name__)
58 changes: 58 additions & 0 deletions python/tvm/script/builder/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script IR Builder"""
from typing import List
from tvm._ffi import register_object as _register_object
from .frame import Frame

from tvm.runtime import Object

from . import _ffi_api

from typing import TypeVar


@_register_object("script.builder.Builder")
class Builder(Object):
def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.Builder)

def __enter__(self) -> "Builder":
_ffi_api.BuilderEnter(self)
return self

def __exit__(self, ptype, value, trace) -> None:
_ffi_api.BuilderExit(self)

@staticmethod
def current(self) -> "Builder":
return _ffi_api.BuilderCurrent(self)

def get(self) -> Frame:
return _ffi_api.BuilderGet(self)


DefType = TypeVar("DefType", bound=Object)


def def_(name: str, var: DefType) -> DefType:
return _ffi_api.Def(name, var)


def def_many(names: List[str], vars: List[DefType]) -> List[DefType]:
assert len(names) == len(vars)
return [def_(name, var) for name, var in zip(names, vars)]
38 changes: 38 additions & 0 deletions python/tvm/script/builder/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script Frames"""
from tvm._ffi import register_object as _register_object

from tvm.runtime import Object

from . import _ffi_api


@_register_object("script.builder.Frame")
class Frame(Object):
def __enter__(self) -> "Frame":
_ffi_api.FrameEnter(self)
return self

def __exit__(self, ptype, value, trace) -> None:
_ffi_api.FrameExit(self)


@_register_object("script.builder.IRModuleFrame")
class IRModuleFrame(Frame):
def __init__(self) -> None:
self.__init_handle_by_constructor__(_ffi_api.IRModuleFrame)
33 changes: 33 additions & 0 deletions python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Namespace for the TVMScript TIR Builder API."""

from .base import TIRFrame
from .for_frame import (
ForFrame,
serial,
parallel,
vectorized,
unroll,
thread_binding,
grid,
)
from .prim_func_frame import prim_func, arg
from .block_frame import block
from .var import Buffer
from . import axis
22 changes: 22 additions & 0 deletions python/tvm/script/builder/tir/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.script.builder"""
import tvm._ffi

from .. import _ffi_api as _base_ffi_api

tvm._ffi._init_api("script.builder.tir", __name__)
37 changes: 37 additions & 0 deletions python/tvm/script/builder/tir/axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Axis"""

from . import _ffi_api
from tvm.ir import Range
from tvm.tir import IterVar


def spatial(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
return _ffi_api.AxisSpatial(dom, binding, dtype)


def reduce(dom, binding, dtype="int32") -> IterVar:
if not isinstance(dom, Range):
dom = Range(0, dom)
return _ffi_api.AxisReduce(dom, binding, dtype)


def remap(kinds, bindings, dtype="int32") -> IterVar:
return _ffi_api.AxisRemap(kinds, bindings, dtype)
26 changes: 26 additions & 0 deletions python/tvm/script/builder/tir/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Frame"""
from tvm._ffi import register_object as _register_object

from . import _ffi_api
from ..frame import Frame


@_register_object("script.builder.tir.TIRFrame")
class TIRFrame(Frame):
pass
31 changes: 31 additions & 0 deletions python/tvm/script/builder/tir/block_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Block Frame"""
from tvm._ffi import register_object as _register_object
from .base import TIRFrame


from . import _ffi_api


@_register_object("script.builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
pass


def block(name) -> BlockFrame:
return _ffi_api.BlockFrame(name)
56 changes: 56 additions & 0 deletions python/tvm/script/builder/tir/for_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR For Frame"""
from tvm._ffi import register_object as _register_object

from tvm.tir import Var

from . import _ffi_api
from ._ffi_api import _base_ffi_api
from .base import TIRFrame
from typing import List


@_register_object("script.builder.tir.ForFrame")
class ForFrame(TIRFrame):
def __enter__(self) -> List[Var]:
_base_ffi_api.FrameEnter(self)
return self.vars


def serial(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Serial(min_val, extent, attrs)


def parallel(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Parallel(min_val, extent, attrs)


def vectorized(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Vectorized(min_val, extent, attrs)


def unroll(min_val, extent, attrs) -> ForFrame:
return _ffi_api.Unroll(min_val, extent, attrs)


def thread_binding(min_val, extent, attrs) -> ForFrame:
return _ffi_api.ThreadBinding(min_val, extent, attrs)


def grid(*extents) -> ForFrame:
return _ffi_api.Grid(extents)
40 changes: 40 additions & 0 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Prim Func Frame"""
from tvm._ffi import register_object as _register_object

from tvm.tir.expr import Var
from tvm.tir.buffer import Buffer


from . import _ffi_api
from .base import TIRFrame

from typing import Union


@_register_object("script.builder.tir.PrimFuncFrame")
class PrimFuncFrame(TIRFrame):
pass


def prim_func(name) -> PrimFuncFrame:
return _ffi_api.PrimFuncFrame(name)


def arg(name, arg) -> Union[Var, Buffer]:
return _ffi_api.Arg(name, arg)
Loading

0 comments on commit 30224ed

Please sign in to comment.