[TESTING] pytorch-like nn.Module API to build neural network (apache#54)
* nn module

* address comments.

* Add nn.init_params

* Remove nn.Builder and use BlockBuilder instead.

* Rebase.

* Refactor block builder and add tests.

* Address comments.

* Update.
YuchenJin authored and junrushao committed Feb 9, 2023
1 parent 1e5be1a commit ac620da
Showing 9 changed files with 517 additions and 85 deletions.
7 changes: 3 additions & 4 deletions apps/relax_examples/mlp.py
@@ -27,7 +27,7 @@
def build_mlp(data, weight):
bb = relax.BlockBuilder()

with bb.function([data, weight], "mlp"):
with bb.function("mlp", [data, weight]):
gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
gv1 = bb.emit_te(topi.nn.relu, gv0)
bb.emit_func_output(gv1)
@@ -47,9 +47,8 @@ def build_mlp(data, weight):
mod = build_mlp(data, weight)

# build and create vm executor
target = tvm.target.Target("llvm")
target_host = tvm.target.Target("llvm")
ex, lib = relax.vm.build(mod, target, target_host)
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

# run the mlp model on relax vm
69 changes: 69 additions & 0 deletions apps/relax_examples/nn_module.py
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example code on creating, compiling, and running a neural network with a PyTorch-like API


import tvm
from tvm.relay import Call
from tvm import relax, tir
from tvm.relax.testing import nn
from tvm.script import relax as R
import numpy as np


if __name__ == "__main__":
builder = relax.BlockBuilder()

# a symbolic variable to represent minibatch size
n = tir.Var("n", "int64")
input_size = 784
hidden_sizes = [128, 32]
output_size = 10

# build a neural network with three linear layers for a classification task
with builder.function("main"):
model = nn.Sequential(
nn.Linear(input_size, hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], output_size),
nn.LogSoftmax(),
)
data = nn.Placeholder((n, input_size), name="data")
output = model(data)
params = [data] + model.parameters()
builder.emit_func_output(output, params=params)

# get and print the IRModule being built
mod = builder.get()
print(R.parser.astext(mod))

# build the IRModule and create relax vm
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

# init parameters
params = nn.init_params(mod)

# run the model on relax vm
# the input data has a minibatch size of 3
data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))
res = vm["main"](data, *params)
print(res)
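
Because the minibatch dimension n is a symbolic tir.Var, the compiled function should accept other batch sizes without rebuilding. A minimal sketch of that reuse, assuming vm, params, and input_size from the example above are still in scope:

# rerun with a different minibatch size; n is symbolic, so no recompilation is expected
data = tvm.nd.array(np.random.rand(5, input_size).astype(np.float32))
res = vm["main"](data, *params)
print(res)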
186 changes: 143 additions & 43 deletions python/tvm/relax/block_builder.py
@@ -30,42 +30,37 @@
class FunctionScope(object):
"""Auxiliary scope for function"""

def __init__(self, irbuilder):
self._ib = irbuilder
def __init__(self, block_builder, name, params):
self._bb = block_builder
self._name = name
self._params = params

def __enter__(self):
_ffi_api.BlockBuilderBeginBindingBlock(self._ib)
self._bb._enter_function_scope(self._name, self._params)

def __exit__(self, ptype, value, trace):
block = _ffi_api.BlockBuilderEndBlock(self._ib)
if len(block.bindings) > 0:
self._ib._blocks.append(block)
seqe = rx.SeqExpr(self._ib._blocks, self._ib._func_ret)
func = rx.Function(
self._ib._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._ib._func_name)
)
gvar = rx.GlobalVar(self._ib._func_name)
self._ib._context_mod[gvar] = func
return func
def __exit__(self, exc_type, exc_val, exc_tb):
# __exit__ should properly handle the case where the with block exits with an exception.
# When handling an error in __exit__, always check whether an exception has already been raised inside the with block.
self._bb._exit_function_scope(exc_type, exc_val, exc_tb)


class DataflowScope(object):
"""Auxiliary scope for Dataflow block"""

def __init__(self, irbuilder):
self._ib = irbuilder
def __init__(self, block_builder):
self._bb = block_builder

def __enter__(self):
block = _ffi_api.BlockBuilderEndBlock(self._ib)
block = self._bb._end_block()
if len(block.bindings) > 0:
self._ib._blocks.append(block)
_ffi_api.BlockBuilderBeginDataflowBlock(self._ib)
self._bb._blocks.append(block)
self._bb._begin_dataflow_block()

def __exit__(self, ptype, value, trace):
block = _ffi_api.BlockBuilderEndBlock(self._ib)
block = self._bb._end_block()
if len(block.bindings) > 0:
self._ib._blocks.append(block)
_ffi_api.BlockBuilderBeginBindingBlock(self._ib)
self._bb._blocks.append(block)
self._bb._begin_binding_block()


@tvm._ffi.register_object("relax.BlockBuilder")
@@ -82,19 +77,55 @@ class BlockBuilder(Object):
dtype1 = rx.DynTensorType(rank=1, dtype="float16")
x = rx.Var("x", [m, n], dtype0)
y = rx.Var("y", [n], dtype1)
ib = rx.BlockBuilder()
with ib.function([x, y], "func"):
with ib.dataflow() as df:
lv0 = ib.emit(rx.add(x, y))
lv1 = ib.emit(rx.multiply(lv0, y))
gv0 = ib.emit_output(lv1)
ib.emit_func_output(gv0)
mod = ib.get()
bb = rx.BlockBuilder()
with bb.function("func", [x, y]):
with bb.dataflow() as df:
lv0 = bb.emit(rx.add(x, y))
lv1 = bb.emit(rx.multiply(lv0, y))
gv0 = bb.emit_output(lv1)
bb.emit_func_output(gv0)
mod = bb.get()
BlockBuilder can also be used to construct neural networks with the nn.Module API.
.. code-block:: python
from tvm.relax.testing import nn
n = tir.Var("n", "int64")
input_size = 784
hidden_sizes = [128, 32]
output_size = 10
bb = rx.BlockBuilder()
with bb.function("main"):
model = nn.Sequential(
nn.Linear(input_size, hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], output_size),
nn.LogSoftmax(),
)
data = nn.Placeholder((n, input_size), name="data")
output = model(data)
params = [data] + model.parameters()
bb.emit_func_output(output, params=params)
mod = bb.get()
"""

_current = None

@staticmethod
def current():
"""Returns the current BlockBuilder."""
return BlockBuilder._current

def __init__(self):
self._blocks = []
self._context_mod = tvm.IRModule()
# a boolean flag that tracks if emit_func_output has been called
self._is_emit_func_output_called = False
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)

def _begin_dataflow_block(self) -> None:
@@ -105,6 +136,22 @@ def _begin_binding_block(self) -> None:

def _end_block(self) -> BindingBlock:
return _ffi_api.BlockBuilderEndBlock(self)

def _enter_function_scope(self, name, params):
if BlockBuilder.current() is not None:
raise RuntimeError("BlockBuilder does not allow nested functions.")
BlockBuilder._current = self
self._func_name = name
self._func_params = params
self._begin_binding_block()

def _exit_function_scope(self, exc_type, exc_val, exc_tb):
if exc_type is None:
if not self._is_emit_func_output_called:
raise RuntimeError("emit_func_output must be called in a relax function.")

self._is_emit_func_output_called = False
BlockBuilder._current = None
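
# A hedged sketch of the guard above: opening a second function scope while one
# is active is expected to raise, e.g.
#
#   bb = relax.BlockBuilder()
#   with bb.function("outer"):
#       with bb.function("inner"):  # RuntimeError: BlockBuilder does not allow nested functions.
#           ...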

def _convert_te_arg(self,
te_args: Any
@@ -173,31 +220,36 @@ def _populate_used_vars(expr):


def function(self,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
name: Optional[str] = "") -> FunctionScope:
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None) -> FunctionScope:
"""Annotate a Relax function.
Parameters
----------
name : str
    The name of the function.
params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
    The parameters of the function.
    If params is None, initialization of the function parameters is deferred until emit_func_output.
Returns
-------
ret: FunctionScope
A FunctionScope for building a Relax function node.
"""
if not params:
params = []
if not isinstance(params, (list, tuple)):
params = None
elif isinstance(params, rx.Var):
params = [params]
elif isinstance(params, (list, tuple)):
for param in params:
if not isinstance(param, rx.Var):
raise TypeError("each element of function parameters must be of type tvm.relax.Var,\
but got: {}".format(type(param)))

self._func_params = params
self._func_name = name
return FunctionScope(self)
name = self.get_unique_name(name)
return FunctionScope(self, name, params)

def dataflow(self) -> DataflowScope:
"""Annotate a Relax dataflow block.
Expand Down Expand Up @@ -304,12 +356,12 @@ def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tenso

inputs = [*te_args, te_out]
tir_func = tvm.te.create_prim_func(inputs)
func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
func_name = self.get_unique_name(func.__name__)
tir_func = tir_func.with_attr("global_symbol", func_name)
gvar = GlobalVar(func_name)
self._context_mod[gvar] = tir_func
call = call_dps(inputs[-1].shape, gvar, [x.op.value for x in inputs[:-1]])
return _ffi_api.BlockBuilderEmit(self, call)
return self.emit(call)
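
# A minimal usage sketch of emit_te, mirroring apps/relax_examples/mlp.py above:
# each emit_te call lowers the TE computation to a TIR PrimFunc, registers it in
# the module under a unique name, and emits a call_dps binding to it.
#
#   bb = relax.BlockBuilder()
#   with bb.function("mlp", [data, weight]):
#       gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
#       gv1 = bb.emit_te(topi.nn.relu, gv0)
#       bb.emit_func_output(gv1)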


def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
@@ -347,22 +399,54 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
output = Tuple(output)
return _ffi_api.BlockBuilderEmitOutput(self, output)

def emit_func_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
def emit_func_output(self,
output: Union[Expr, Tuple, List[Expr]],
params: Optional[Union[Var, Tuple, List[Var]]] = None) -> None:
"""Emit output for the function.
Parameters
----------
output : Expr | Tuple | List[Expr]
The output of the current block/function.
params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
The parameters of the function to be built.
If params is None, the parameters have already been initialized in the function scope.
Returns
-------
ret : tvm.relax.Var
The return variable which gets bound to the output.
"""
if self._is_emit_func_output_called:
raise RuntimeError("emit_func_output must be called exactly once in a relax function.")
self._is_emit_func_output_called = True

if self._func_params is not None and params is not None:
raise RuntimeError("function parameters have been initialized in the function with scope.")

if self._func_params is None and params is None:
raise RuntimeError("Relax function must have parameter.")

if self._func_params is None:
self._func_params = params

if BlockBuilder.current() is not self:
raise RuntimeError("BlockBuilder._current must be self.")

if isinstance(output, (list, tuple)):
output = Tuple(output)
self._func_ret = output

block = self._end_block()
if len(block.bindings) > 0:
self._blocks.append(block)
seqe = rx.SeqExpr(self._blocks, self._func_ret)
func = rx.Function(
self._func_params, seqe, rx.DynTensorType(-1), rx.GlobalVar(self._func_name)
)
gvar = rx.GlobalVar(self._func_name)
self._context_mod[gvar] = func
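
# A hedged sketch of the parameter rules enforced above: params may be supplied
# either when entering the function scope or at emit_func_output, but not both
# and not neither.
#
#   # 1) parameters given up front:
#   with bb.function("f", [x, y]):
#       bb.emit_func_output(out)
#
#   # 2) parameters deferred to emit_func_output (as the nn.Module example does):
#   with bb.function("main"):
#       ...
#       bb.emit_func_output(output, params=params)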

def normalize(self, expr: Expr) -> Expr:
"""Normalize an Expr to complete its shape and type.
@@ -388,3 +472,19 @@ def get(self) -> tvm.IRModule:
An IRModule with Relax and TIR functions being built.
"""
return self._context_mod


def get_unique_name(self, name_prefix: str) -> str:
"""Generate a unique name with a specified prefix.
Parameters
----------
name_prefix : str
The name prefix.
Returns
-------
ret : str
The generated name.
"""
return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix)
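
# A hedged sketch: repeated prefixes are deduplicated on the C++ side, e.g.
#
#   bb.get_unique_name("matmul")  # -> "matmul"
#   bb.get_unique_name("matmul")  # -> a fresh name such as "matmul1"
#                                 #    (the exact suffix scheme is an implementation detail)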
