Skip to content

Commit

Permalink
Remove usage of Any from some modules in the impl package.
Browse files Browse the repository at this point in the history
In most cases it's possible to simply replace `Any` with `object` but in some cases a more specific type is appropriate.

This change also fixes any pytype errors caused by adding the correct Python type annotation.

Note: This also meant moving the PolymorphicComputaiton class into it's own file in order to break a circular dependency.
PiperOrigin-RevId: 567727933
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Sep 22, 2023
1 parent 21d32c2 commit da6e997
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 189 deletions.
24 changes: 22 additions & 2 deletions tensorflow_federated/python/core/impl/computation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ py_library(
deps = [
":computation_impl",
":function_utils",
":polymorphic_computation",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/types:computation_types",
Expand Down Expand Up @@ -126,7 +127,6 @@ py_library(
name = "function_utils",
srcs = ["function_utils.py"],
deps = [
":computation_base",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/types:computation_types",
Expand All @@ -140,11 +140,31 @@ py_test(
name = "function_utils_test",
size = "small",
srcs = ["function_utils_test.py"],
deps = [
":function_utils",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/types:computation_types",
],
)

py_library(
name = "polymorphic_computation",
srcs = ["polymorphic_computation.py"],
deps = [
":computation_impl",
":function_utils",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:type_conversions",
],
)

py_test(
name = "polymorphic_computation_test",
srcs = ["polymorphic_computation_test.py"],
deps = [
":computation_impl",
":polymorphic_computation",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_base",
"//tensorflow_federated/python/core/impl/types:computation_types",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_impl
from tensorflow_federated.python.core.impl.computation import function_utils
from tensorflow_federated.python.core.impl.computation import polymorphic_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.tensorflow_libs import function
Expand Down Expand Up @@ -528,7 +529,9 @@ def _polymorphic_wrapper(
**kwargs,
)

wrapped_func = function_utils.PolymorphicComputation(_polymorphic_wrapper)
wrapped_func = polymorphic_computation.PolymorphicComputation(
_polymorphic_wrapper
)
else:
# Either we have a concrete parameter type, or this is no-arg function.
parameter_type = _parameter_type(parameters, parameter_types)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_conversions
Expand Down Expand Up @@ -486,89 +485,3 @@ def _none_arg(arg):
return functools.partial(_unpack_arg, arg_types, kwarg_types)
else:
return functools.partial(_ensure_arg_type, parameter_type)


class PolymorphicComputation:
"""A generic polymorphic function that accepts arguments of diverse types."""

def __init__(
self,
concrete_function_factory: Callable[
[computation_types.Type, Optional[bool]], computation_base.Computation
],
):
"""Crates a polymorphic function with a given function factory.
Args:
concrete_function_factory: A callable that accepts a (non-None) TFF type
as an argument, as well as an optional boolean `unpack` argument which
should be treated as documented in `create_argument_unpacking_fn` above.
The callable must return a `Computation` instance that's been created to
accept a single positional argument of this TFF type (to be reused for
future calls with parameters of a matching type).
"""
self._concrete_function_factory = concrete_function_factory
self._concrete_function_cache = {}

def fn_for_argument_type(
self, arg_type: computation_types.Type, unpack: Optional[bool] = None
) -> computation_base.Computation:
"""Concretizes this function with the provided `arg_type`.
The first time this function is called with a particular type on a
given `PolymorphicComputation` (or this `PolymorphicComputation` is called
with an argument of the given type), the underlying function will be
traced using the provided argument type as input. Later calls will
return the cached computed concrete function.
Args:
arg_type: The argument type to use when concretizing this function.
unpack: Whether to force unpacking the arguments (`True`), never unpack
the arguments (`False`), or infer whether or not to unpack the arguments
(`None`).
Returns:
The `computation_base.Computation` that results from tracing this
`PolymorphicComputation` with `arg_type.
"""
key = repr(arg_type) + str(unpack)
concrete_fn = self._concrete_function_cache.get(key)
if not concrete_fn:
concrete_fn = (self._concrete_function_factory)(arg_type, unpack)
py_typecheck.check_type(
concrete_fn, computation_base.Computation, 'computation'
)
if concrete_fn.type_signature.parameter != arg_type:
raise TypeError(
'Expected a concrete function that takes parameter {}, got one '
'that takes {}.'.format(
arg_type, concrete_fn.type_signature.parameter
)
)
self._concrete_function_cache[key] = concrete_fn
return concrete_fn

def __call__(self, *args, **kwargs):
"""Invokes this polymorphic function with a given set of arguments.
Args:
*args: Positional args.
**kwargs: Keyword args.
Returns:
The result of calling a concrete function, instantiated on demand based
on the argument types (and cached for future calls).
Raises:
TypeError: if the concrete functions created by the factory are of the
wrong computation_types.
"""
# TODO: b/113112885 - We may need to normalize individuals args, such that
# the type is more predictable and uniform (e.g., if someone supplies an
# unordered dictionary), possibly by converting dict-like and tuple-like
# containers into `Struct`s.
packed_arg = pack_args_into_struct(args, kwargs)
arg_type = type_conversions.infer_type(packed_arg)
# We know the argument types have been packed, so force unpacking.
concrete_fn = self.fn_for_argument_type(arg_type, unpack=True)
return concrete_fn(packed_arg)
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@
from absl.testing import parameterized
import tensorflow as tf

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_impl
from tensorflow_federated.python.core.impl.computation import function_utils
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.context_stack import context_stack_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_serialization


class FunctionUtilsTest(parameterized.TestCase):
Expand Down Expand Up @@ -263,94 +258,5 @@ def test_wrap_as_zero_or_one_arg_callable(
self.assertEqual(actual_result, expected_result)


class PolymorphicComputationTest(absltest.TestCase):

def test_call_returns_result(self):
class TestContext(context_base.SyncContext):

def ingest(self, val, type_spec):
return val

def invoke(self, comp, arg):
return 'name={},type={},arg={},unpack={}'.format(
comp.name, comp.type_signature.parameter, arg, comp.unpack
)

class TestContextStack(context_stack_base.ContextStack):

def __init__(self):
super().__init__()
self._context = TestContext()

@property
def current(self):
return self._context

def install(self, ctx):
del ctx # Unused
return self._context

context_stack = TestContextStack()

class TestFunction(computation_impl.ConcreteComputation):

def __init__(self, name, unpack, parameter_type):
self._name = name
self._unpack = unpack
type_signature = computation_types.FunctionType(
parameter_type, tf.string
)
test_proto = pb.Computation(
type=type_serialization.serialize_type(type_signature)
)
super().__init__(test_proto, context_stack, type_signature)

@property
def name(self):
return self._name

@property
def unpack(self):
return self._unpack

class TestFunctionFactory:

def __init__(self):
self._count = 0

def __call__(self, parameter_type, unpack):
self._count = self._count + 1
return TestFunction(str(self._count), str(unpack), parameter_type)

fn = function_utils.PolymorphicComputation(TestFunctionFactory())

self.assertEqual(fn(10), 'name=1,type=<int32>,arg=<10>,unpack=True')
self.assertEqual(
fn(20, x=True), 'name=2,type=<int32,x=bool>,arg=<20,x=True>,unpack=True'
)
fn_with_bool_arg = fn.fn_for_argument_type(
computation_types.to_type(tf.bool)
)
self.assertEqual(
fn_with_bool_arg(True), 'name=3,type=bool,arg=True,unpack=None'
)
self.assertEqual(
fn(30, x=40), 'name=4,type=<int32,x=int32>,arg=<30,x=40>,unpack=True'
)
self.assertEqual(fn(50), 'name=1,type=<int32>,arg=<50>,unpack=True')
self.assertEqual(
fn(0, x=False), 'name=2,type=<int32,x=bool>,arg=<0,x=False>,unpack=True'
)
fn_with_bool_arg = fn.fn_for_argument_type(
computation_types.to_type(tf.bool)
)
self.assertEqual(
fn_with_bool_arg(False), 'name=3,type=bool,arg=False,unpack=None'
)
self.assertEqual(
fn(60, x=70), 'name=4,type=<int32,x=int32>,arg=<60,x=70>,unpack=True'
)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2018, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Python functions, defuns, and other types of callables."""

from collections.abc import Callable
from typing import Optional

from tensorflow_federated.python.core.impl.computation import computation_impl
from tensorflow_federated.python.core.impl.computation import function_utils
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_conversions


class PolymorphicComputation:
"""A generic polymorphic function that accepts arguments of diverse types."""

def __init__(
self,
concrete_function_factory: Callable[
[computation_types.Type, Optional[bool]],
computation_impl.ConcreteComputation,
],
):
"""Crates a polymorphic function with a given function factory.
Args:
concrete_function_factory: A callable that accepts a (non-None) TFF type
as an argument, as well as an optional boolean `unpack` argument which
should be treated as documented in `create_argument_unpacking_fn` above.
The callable must return a `Computation` instance that's been created to
accept a single positional argument of this TFF type (to be reused for
future calls with parameters of a matching type).
"""
self._concrete_function_factory = concrete_function_factory
self._concrete_function_cache = {}

def fn_for_argument_type(
self, arg_type: computation_types.Type, unpack: Optional[bool] = None
) -> computation_impl.ConcreteComputation:
"""Concretizes this function with the provided `arg_type`.
The first time this function is called with a particular type on a
given `PolymorphicComputation` (or this `PolymorphicComputation` is called
with an argument of the given type), the underlying function will be
traced using the provided argument type as input. Later calls will
return the cached computed concrete function.
Args:
arg_type: The argument type to use when concretizing this function.
unpack: Whether to force unpacking the arguments (`True`), never unpack
the arguments (`False`), or infer whether or not to unpack the arguments
(`None`).
Returns:
The `tff.framework.ConcreteComputation` that results from tracing this
`PolymorphicComputation` with `arg_type.
"""
key = repr(arg_type) + str(unpack)
concrete_fn = self._concrete_function_cache.get(key)
if not concrete_fn:
concrete_fn = (self._concrete_function_factory)(arg_type, unpack)
if concrete_fn.type_signature.parameter != arg_type:
raise TypeError(
'Expected a concrete function that takes parameter {}, got one '
'that takes {}.'.format(
arg_type, concrete_fn.type_signature.parameter
)
)
self._concrete_function_cache[key] = concrete_fn
return concrete_fn

def __call__(self, *args, **kwargs):
"""Invokes this polymorphic function with a given set of arguments.
Args:
*args: Positional args.
**kwargs: Keyword args.
Returns:
The result of calling a concrete function, instantiated on demand based
on the argument types (and cached for future calls).
Raises:
TypeError: if the concrete functions created by the factory are of the
wrong computation_types.
"""
# TODO: b/113112885 - We may need to normalize individuals args, such that
# the type is more predictable and uniform (e.g., if someone supplies an
# unordered dictionary), possibly by converting dict-like and tuple-like
# containers into `Struct`s.
packed_arg = function_utils.pack_args_into_struct(args, kwargs)
arg_type = type_conversions.infer_type(packed_arg)
# We know the argument types have been packed, so force unpacking.
concrete_fn = self.fn_for_argument_type(arg_type, unpack=True)
return concrete_fn(packed_arg)
Loading

0 comments on commit da6e997

Please sign in to comment.