diff --git a/python/py_namespace.bzl b/python/py_namespace.bzl new file mode 100644 index 00000000000..13a7a0f2d59 --- /dev/null +++ b/python/py_namespace.bzl @@ -0,0 +1,141 @@ +"""Repository rule py_workspace(), augmenting @rules_python, for relocating +py_library() and py_package() files underneath a given Python namespace. + +The stock @rules_python py_library() -> py_package() -> py_wheel() BUILD file +workflow packages files at Python package paths set to the source paths of the +files relative to the workspace root. This has a several problems. Firstly, it +implies that files must be located underneath a source directory with the same +name as the desired Python namespace package. ( py_wheel.strip_path_prefixes +can remove path components, but cannot add them.) This is not always feasible +or desirable. + +Secondly, this path naming is incompatible with the PYTHONPATH set by +@rules_python when executing Python programs in the source tree via +py_binary(). PYTHONPATH is set such that imports should begin with the +WORKSPACE name, followed by the path from the workspace root. py_wheel(), +however, packages files such that imports use only the path from the workspace +root. + +For example, the source file: + example/hello.py +is imported by a py_binary() running in the source tree as: + `import workspace_name.example.hello` +but must be imported from within the package created by py_wheel() as: + `import example.hello` + +The end result is that code cannot be written to work both in the source tree +and installed in a Python environment via a package. + +py_namespace() fixes these problems by providing the means to package files +within a Python package namespace without adding a corresponding directory in +the source tree. The BUILD workflow changes to py_libary() -> py_package() -> +**py_namespace()** -> py_wheel(). For example: + +``` + # in example/BUILD + + py_library( + name = "library", + srcs = ["hello.py"], + deps = ..., + ) + + py_package( + name = "package", + deps = [":library"], + ) + + py_namespace( + name = "namespace", + deps = [":package"], + namespace = "foo", + ) + + py_wheel( + .... + deps = [":namespace"], + ) +``` + +In this case, the source file: + example/hello.py +which is imported by a py_binary() running in the source tree as: + `import workspace_name.example.hello` +is imported from the package created by py_wheel() as: + `import foo.example.hello` + +If the namespace and the WORKSPACE name match, the import paths used when +running in the source tree will match the import paths used when installed in +the Python environment. + +Furthermore, the Python package can be given an __init__.py file via the +attribute `init`. The given file is relocated directly under the namespace as +__init__.py, regardless of its path in the source tree. This __init__.py can be +used for, among other things, providing a user-friendly public API: providing +aliases for modules otherwise deeply nested in subpackages due to their +location in the source tree. +""" + +def _relocate_init(ctx): + # Copy the init file directly underneath the namespace directory. + outfile = ctx.actions.declare_file(ctx.attr.namespace + "/__init__.py") + ctx.actions.run_shell( + inputs = [ctx.file.init], + outputs = [outfile], + arguments = [ctx.file.init.path, outfile.path], + command = "cp $1 $2", + ) + return outfile + +def _relocate_deps(ctx): + # Copy all transitive deps underneath the namespace directory. E.g., + # example/hello.py + # becomes: + # namespace/example/hello.py + outfiles = [] + inputs = depset(transitive = [dep[DefaultInfo].files for dep in ctx.attr.deps]) + + for infile in sorted(inputs.to_list()): + outfile = ctx.actions.declare_file(ctx.attr.namespace + "/" + infile.short_path) + ctx.actions.run_shell( + inputs = [infile], + outputs = [outfile], + arguments = [infile.path, outfile.path], + command = "cp $1 $2", + ) + outfiles.append(outfile) + + return outfiles + +def _py_namespace(ctx): + # Copy all input files underneath the namesapce directory and return a + # Provider with the new file locations. + outfiles = [] + + if ctx.file.init: + outfiles.append(_relocate_init(ctx)) + + outfiles.extend(_relocate_deps(ctx)) + + return [ + DefaultInfo(files = depset(outfiles)), + ] + +py_namespace = rule( + implementation = _py_namespace, + attrs = { + "init": attr.label( + doc = "optional file for __init__.py", + allow_single_file = [".py"], + mandatory = False, + ), + "namespace": attr.string( + doc = "name for Python namespace", + mandatory = True, + ), + "deps": attr.label_list( + doc = "list of py_library() and py_package()s to include", + mandatory = True, + ), + }, +) diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index 24efc9f7d97..68321fd454a 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -1,4 +1,6 @@ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("//python:py_namespace.bzl", "py_namespace") +load("@rules_python//python:packaging.bzl", "py_package", "py_wheel") load("@tflm_pip_deps//:requirements.bzl", "requirement") load( "//tensorflow/lite/micro:build_def.bzl", @@ -95,3 +97,83 @@ py_test( "//tensorflow/lite/micro/testing:generate_test_models_lib", ], ) + +py_library( + name = "postinstall_check", + srcs = [ + "postinstall_check.py", + ], + data = [ + "sine_float.tflite", + ], +) + +# Collect the `deps` and their transitive dependences together into a set of +# files to package. The files retain their full path relative to the workspace +# root, which determines the subpackage path at which they're located within +# the Python distribution package. +py_package( + name = "files_to_package", + + # Only Python subpackage paths matching the following prefixes are included + # in the files to package. This avoids packaging, e.g., numpy, which is a + # transitive dependency of the tflm runtime target. This list may require + # modification when adding, directly or indirectly, `deps` from other paths + # in the tflm tree. + packages = [ + "python.tflite_micro", + "tensorflow.lite.python", + "tensorflow.lite.tools.flatbuffer_utils", + ], + deps = [ + ":postinstall_check", + ":runtime", + ], +) + +# Relocate `deps` underneath the given Python package namespace, otherwise +# maintaining their full paths relative to the workspace root. +# +# For example: +# ${workspace_root}/example/hello.py +# becomes: +# namespace.example.hello +# +# Place `init` at the root of the namespace, regardless of `init`'s path in the +# source tree. +py_namespace( + name = "namespace", + init = "__init__.py", + namespace = "tflite_micro", + deps = [ + ":files_to_package", + ], +) + +py_wheel( + name = "whl", + distribution = "tflite_micro", + requires = [ + "flatbuffers", + "numpy", + "tensorflow", + ], + strip_path_prefixes = [package_name()], + version = "0.1.0", + deps = [ + ":namespace", + ], +) + +sh_test( + name = "whl_test", + srcs = [ + "whl_test.sh", + ], + args = [ + "$(location :whl)", + ], + data = [ + ":whl", + ], +) diff --git a/python/tflite_micro/__init__.py b/python/tflite_micro/__init__.py new file mode 100644 index 00000000000..940289ce5d0 --- /dev/null +++ b/python/tflite_micro/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Define a public API for the package by providing aliases for modules which +# are otherwise deeply nested in subpackages determined by their location in +# the tflm source tree. Directly using modules and subpackages not explicitly +# made part of the public API in code outside of the tflm source tree is +# unsupported. + +from tflite_micro.python.tflite_micro import runtime + +# Ordered after `runtime` to avoid a circular dependency +from tflite_micro.python.tflite_micro import postinstall_check diff --git a/python/tflite_micro/postinstall_check.py b/python/tflite_micro/postinstall_check.py new file mode 100644 index 00000000000..89e98551696 --- /dev/null +++ b/python/tflite_micro/postinstall_check.py @@ -0,0 +1,53 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A simple test to check whether the tflite_micro package works after it is +# installed. + +# To test from the perspective of a package user, use import paths to locations +# in the Python installation environment rather than to locations in the tflm +# source tree. +from tflite_micro import runtime + +import numpy as np +import pkg_resources +import sys + + +def passed(): + # Create an interpreter with a sine model + model = pkg_resources.resource_filename(__name__, "sine_float.tflite") + interpreter = runtime.Interpreter.from_file(model) + OUTPUT_INDEX = 0 + INPUT_INDEX = 0 + input_shape = interpreter.get_input_details(INPUT_INDEX).get("shape") + + # The interpreter infers sin(x) + def infer(x): + tensor = np.array(x, np.float32).reshape(input_shape) + interpreter.set_input(tensor, INPUT_INDEX) + interpreter.invoke() + return interpreter.get_output(OUTPUT_INDEX).squeeze() + + # Check a few inferred values against a numerical computation + PI = 3.14 + inputs = (0.0, PI / 2, PI, 3 * PI / 2, 2 * PI) + outputs = [infer(x) for x in inputs] + goldens = np.sin(inputs) + + return np.allclose(outputs, goldens, atol=0.05) + + +if __name__ == "__main__": + sys.exit(0 if passed() else 1) diff --git a/python/tflite_micro/sine_float.tflite b/python/tflite_micro/sine_float.tflite new file mode 100644 index 00000000000..f741b3a7b6b Binary files /dev/null and b/python/tflite_micro/sine_float.tflite differ diff --git a/python/tflite_micro/whl_test.sh b/python/tflite_micro/whl_test.sh new file mode 100755 index 00000000000..8b6d97fd3da --- /dev/null +++ b/python/tflite_micro/whl_test.sh @@ -0,0 +1,43 @@ +#!/usr/bin/sh + +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Install the given tflm-micro.whl in a fresh virtual environment and run its +# embedded, post-installation checks. + +set -e + +WHL="${1}" + +# Create venv for this test. +python3 -m venv pyenv +. pyenv/bin/activate + +# Disable pip's cache for two reasons: 1) the default location in +# $XDG_CACHE_HOME causes errors when pip is run from a bazel sandbox, and 2) it +# makes no sense to relocate the cache within the sandbox since files generated +# in the sandbox are deleted after the run. +export PIP_NO_CACHE_DIR=true + +# Test package installation. +pip install "${WHL}" +pip show --files tflite-micro + +# Run the package's post-installation checks. +python3 << HEREDOC +import sys +from tflite_micro import postinstall_check +sys.exit(0 if postinstall_check.passed() else 1) +HEREDOC