Skip to content

Commit

Permalink
feat: build and test a Python distribution package tflite_micro (#2151
Browse files Browse the repository at this point in the history
)

Add the build configuration and integrated test to generate a Python
distribution package named `tflite_micro` for publishing the tflm interpreter
as a Python module with a native extension.

Use the build tools provided in @rules_python, augmented by a custom rule
`py_namespace` for the reasons documented in `python/py_namespace.bzl`.

Provide an integration test at `//python/tflite_micro:whl_test`. Use a .tflite
model copied from the hello_world example. (Copied to avoid creating a
dependency.)

BUG=part of #1484
  • Loading branch information
rkuester authored Aug 1, 2023
1 parent aa945a0 commit ca74563
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 0 deletions.
141 changes: 141 additions & 0 deletions python/py_namespace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Repository rule py_workspace(), augmenting @rules_python, for relocating
py_library() and py_package() files underneath a given Python namespace.
The stock @rules_python py_library() -> py_package() -> py_wheel() BUILD file
workflow packages files at Python package paths set to the source paths of the
files relative to the workspace root. This has a several problems. Firstly, it
implies that files must be located underneath a source directory with the same
name as the desired Python namespace package. ( py_wheel.strip_path_prefixes
can remove path components, but cannot add them.) This is not always feasible
or desirable.
Secondly, this path naming is incompatible with the PYTHONPATH set by
@rules_python when executing Python programs in the source tree via
py_binary(). PYTHONPATH is set such that imports should begin with the
WORKSPACE name, followed by the path from the workspace root. py_wheel(),
however, packages files such that imports use only the path from the workspace
root.
For example, the source file:
example/hello.py
is imported by a py_binary() running in the source tree as:
`import workspace_name.example.hello`
but must be imported from within the package created by py_wheel() as:
`import example.hello`
The end result is that code cannot be written to work both in the source tree
and installed in a Python environment via a package.
py_namespace() fixes these problems by providing the means to package files
within a Python package namespace without adding a corresponding directory in
the source tree. The BUILD workflow changes to py_libary() -> py_package() ->
**py_namespace()** -> py_wheel(). For example:
```
# in example/BUILD
py_library(
name = "library",
srcs = ["hello.py"],
deps = ...,
)
py_package(
name = "package",
deps = [":library"],
)
py_namespace(
name = "namespace",
deps = [":package"],
namespace = "foo",
)
py_wheel(
....
deps = [":namespace"],
)
```
In this case, the source file:
example/hello.py
which is imported by a py_binary() running in the source tree as:
`import workspace_name.example.hello`
is imported from the package created by py_wheel() as:
`import foo.example.hello`
If the namespace and the WORKSPACE name match, the import paths used when
running in the source tree will match the import paths used when installed in
the Python environment.
Furthermore, the Python package can be given an __init__.py file via the
attribute `init`. The given file is relocated directly under the namespace as
__init__.py, regardless of its path in the source tree. This __init__.py can be
used for, among other things, providing a user-friendly public API: providing
aliases for modules otherwise deeply nested in subpackages due to their
location in the source tree.
"""

def _relocate_init(ctx):
# Copy the init file directly underneath the namespace directory.
outfile = ctx.actions.declare_file(ctx.attr.namespace + "/__init__.py")
ctx.actions.run_shell(
inputs = [ctx.file.init],
outputs = [outfile],
arguments = [ctx.file.init.path, outfile.path],
command = "cp $1 $2",
)
return outfile

def _relocate_deps(ctx):
# Copy all transitive deps underneath the namespace directory. E.g.,
# example/hello.py
# becomes:
# namespace/example/hello.py
outfiles = []
inputs = depset(transitive = [dep[DefaultInfo].files for dep in ctx.attr.deps])

for infile in sorted(inputs.to_list()):
outfile = ctx.actions.declare_file(ctx.attr.namespace + "/" + infile.short_path)
ctx.actions.run_shell(
inputs = [infile],
outputs = [outfile],
arguments = [infile.path, outfile.path],
command = "cp $1 $2",
)
outfiles.append(outfile)

return outfiles

def _py_namespace(ctx):
# Copy all input files underneath the namesapce directory and return a
# Provider with the new file locations.
outfiles = []

if ctx.file.init:
outfiles.append(_relocate_init(ctx))

outfiles.extend(_relocate_deps(ctx))

return [
DefaultInfo(files = depset(outfiles)),
]

py_namespace = rule(
implementation = _py_namespace,
attrs = {
"init": attr.label(
doc = "optional file for __init__.py",
allow_single_file = [".py"],
mandatory = False,
),
"namespace": attr.string(
doc = "name for Python namespace",
mandatory = True,
),
"deps": attr.label_list(
doc = "list of py_library() and py_package()s to include",
mandatory = True,
),
},
)
82 changes: 82 additions & 0 deletions python/tflite_micro/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("//python:py_namespace.bzl", "py_namespace")
load("@rules_python//python:packaging.bzl", "py_package", "py_wheel")
load("@tflm_pip_deps//:requirements.bzl", "requirement")
load(
"//tensorflow/lite/micro:build_def.bzl",
Expand Down Expand Up @@ -95,3 +97,83 @@ py_test(
"//tensorflow/lite/micro/testing:generate_test_models_lib",
],
)

py_library(
name = "postinstall_check",
srcs = [
"postinstall_check.py",
],
data = [
"sine_float.tflite",
],
)

# Collect the `deps` and their transitive dependences together into a set of
# files to package. The files retain their full path relative to the workspace
# root, which determines the subpackage path at which they're located within
# the Python distribution package.
py_package(
name = "files_to_package",

# Only Python subpackage paths matching the following prefixes are included
# in the files to package. This avoids packaging, e.g., numpy, which is a
# transitive dependency of the tflm runtime target. This list may require
# modification when adding, directly or indirectly, `deps` from other paths
# in the tflm tree.
packages = [
"python.tflite_micro",
"tensorflow.lite.python",
"tensorflow.lite.tools.flatbuffer_utils",
],
deps = [
":postinstall_check",
":runtime",
],
)

# Relocate `deps` underneath the given Python package namespace, otherwise
# maintaining their full paths relative to the workspace root.
#
# For example:
# ${workspace_root}/example/hello.py
# becomes:
# namespace.example.hello
#
# Place `init` at the root of the namespace, regardless of `init`'s path in the
# source tree.
py_namespace(
name = "namespace",
init = "__init__.py",
namespace = "tflite_micro",
deps = [
":files_to_package",
],
)

py_wheel(
name = "whl",
distribution = "tflite_micro",
requires = [
"flatbuffers",
"numpy",
"tensorflow",
],
strip_path_prefixes = [package_name()],
version = "0.1.0",
deps = [
":namespace",
],
)

sh_test(
name = "whl_test",
srcs = [
"whl_test.sh",
],
args = [
"$(location :whl)",
],
data = [
":whl",
],
)
24 changes: 24 additions & 0 deletions python/tflite_micro/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Define a public API for the package by providing aliases for modules which
# are otherwise deeply nested in subpackages determined by their location in
# the tflm source tree. Directly using modules and subpackages not explicitly
# made part of the public API in code outside of the tflm source tree is
# unsupported.

from tflite_micro.python.tflite_micro import runtime

# Ordered after `runtime` to avoid a circular dependency
from tflite_micro.python.tflite_micro import postinstall_check
53 changes: 53 additions & 0 deletions python/tflite_micro/postinstall_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A simple test to check whether the tflite_micro package works after it is
# installed.

# To test from the perspective of a package user, use import paths to locations
# in the Python installation environment rather than to locations in the tflm
# source tree.
from tflite_micro import runtime

import numpy as np
import pkg_resources
import sys


def passed():
# Create an interpreter with a sine model
model = pkg_resources.resource_filename(__name__, "sine_float.tflite")
interpreter = runtime.Interpreter.from_file(model)
OUTPUT_INDEX = 0
INPUT_INDEX = 0
input_shape = interpreter.get_input_details(INPUT_INDEX).get("shape")

# The interpreter infers sin(x)
def infer(x):
tensor = np.array(x, np.float32).reshape(input_shape)
interpreter.set_input(tensor, INPUT_INDEX)
interpreter.invoke()
return interpreter.get_output(OUTPUT_INDEX).squeeze()

# Check a few inferred values against a numerical computation
PI = 3.14
inputs = (0.0, PI / 2, PI, 3 * PI / 2, 2 * PI)
outputs = [infer(x) for x in inputs]
goldens = np.sin(inputs)

return np.allclose(outputs, goldens, atol=0.05)


if __name__ == "__main__":
sys.exit(0 if passed() else 1)
Binary file added python/tflite_micro/sine_float.tflite
Binary file not shown.
43 changes: 43 additions & 0 deletions python/tflite_micro/whl_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/sh

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Install the given tflm-micro.whl in a fresh virtual environment and run its
# embedded, post-installation checks.

set -e

WHL="${1}"

# Create venv for this test.
python3 -m venv pyenv
. pyenv/bin/activate

# Disable pip's cache for two reasons: 1) the default location in
# $XDG_CACHE_HOME causes errors when pip is run from a bazel sandbox, and 2) it
# makes no sense to relocate the cache within the sandbox since files generated
# in the sandbox are deleted after the run.
export PIP_NO_CACHE_DIR=true

# Test package installation.
pip install "${WHL}"
pip show --files tflite-micro

# Run the package's post-installation checks.
python3 << HEREDOC
import sys
from tflite_micro import postinstall_check
sys.exit(0 if postinstall_check.passed() else 1)
HEREDOC

0 comments on commit ca74563

Please sign in to comment.