Skip to content

Commit

Permalink
refactor: adopt optimized variable name tracing mechanism (Azure#27776)
Browse files Browse the repository at this point in the history
* refactor: adopt optimized variable name tracing mechanism

* feat: add support for python 3.11

* refactor: add PersistentLocalsFunctionBuilder

* fix: try applying new refactor to all

* refactor: minor optimization

* refactor: add some more e2e tests
  • Loading branch information
elliotzh committed Dec 2, 2022
1 parent 89deaed commit 4b40de5
Show file tree
Hide file tree
Showing 5 changed files with 363 additions and 22 deletions.
198 changes: 198 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_func_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from types import FunctionType, MethodType
from typing import List, Union, Callable, Any

from bytecode import Instr, Bytecode


class PersistentLocalsFunction(object):
    """Wrapper class for the 'persistent_locals' decorator.

    Calling an instance calls the wrapped function; afterwards the local
    variables saved by the wrapped function are available in ``self.locals``.
    Refer to the docstring of instances for help about the wrapped
    function.
    """

    def __init__(self, _func, *, _self: Any = None, skip_locals: List[str] = None):
        """
        :param _func: The function to be wrapped. It is bound to this wrapper, so the wrapper
            itself is passed to it as the first positional argument.
        :param _self: If original func is a method, _self should be provided, which is the instance of the method.
        :param skip_locals: A list of local variables to skip when saving the locals.
        """
        self.locals = {}
        self._self = _self
        # make function an instance method
        self._func = MethodType(_func, self)
        self._skip_locals = skip_locals

    def __call__(__self, *args, **kwargs):  # pylint: disable=no-self-argument
        """Call the wrapped function and return its result.

        ``self.locals`` is reset before the call; every name listed in ``skip_locals`` is
        removed from it after the call, whether the call returned or raised.
        """
        # Use __self in case self is also passed as a named argument in kwargs
        # Rebind to a fresh dict instead of clearing in place, so that a dict a caller
        # kept from a previous call is not emptied behind its back.
        __self.locals = {}
        try:
            # Compare with None explicitly: the bound instance may be falsy
            # (e.g. it defines __bool__ or __len__) and must still be passed through.
            if __self._self is not None:
                return __self._func(__self._self, *args, **kwargs)  # pylint: disable=not-callable
            return __self._func(*args, **kwargs)  # pylint: disable=not-callable
        finally:
            # always pop skip locals even if exception is raised in user code
            if __self._skip_locals is not None:
                for skip_local in __self._skip_locals:
                    __self.locals.pop(skip_local, None)


# Bytecode template: PersistentLocalsFunctionBuilder.__init__ reads this function's
# instructions and uses them as the separators for splitting other bytecode.
# The body must stay a bare `return mock_arg`, because
# PersistentLocalsFunctionBuilder.get_body_instruction() hard-codes
# Instr('LOAD_FAST', 'mock_arg') as the placeholder for the user code.
def _source_template_func(mock_arg):
    return mock_arg


# Bytecode template: PersistentLocalsFunctionBuilder.__init__ splits this function's
# instructions by the separators taken from _source_template_func, and the resulting
# pieces are woven around the user code in _build_func.
# The name of the first parameter is read by the builder (template.argnames[0]) and
# becomes the parameter injected in front of the user function's own parameters.
def _target_template_func(__self, mock_arg):
    try:
        return mock_arg
    finally:
        # store a copy of the locals on the object passed as the first argument
        __self.locals = locals().copy()


class PersistentLocalsFunctionBuilder(object):
    """Build PersistentLocalsFunction objects from functions by bytecode injection.

    The instructions of ``_source_template_func`` are used as separators; the instructions
    of ``_target_template_func`` are split by them and the pieces are inserted around the
    corresponding pieces of the function passed to :meth:`build`.
    """

    def __init__(self):
        self._template_separators = self._clear_location(Bytecode.from_code(_source_template_func.__code__))

        template = self._clear_location(Bytecode.from_code(_target_template_func.__code__))
        self._template_body = self.split_bytecode(template)
        # after split, len(self._template_body) will be len(self._template_separators) + 1
        # pop tail to make zip work
        self._template_tail = self._template_body.pop()
        # name of the extra first parameter the generated function will take
        self._injected_param = template.argnames[0]

    def split_bytecode(self, bytecode: Bytecode, *, skip_body_instr=False) -> List[List[Instr]]:
        """Split bytecode into several parts by template separators.

        For example, in Python 3.11, the template separators will be:
        [
            Instr('RESUME', 0),  # initial instruction shared by all functions
            Instr('LOAD_FAST', 'mock_arg'),  # the body execution instruction
            Instr('RETURN_VALUE'),  # the return instruction shared by all functions
        ]
        Then we will split the target template bytecode into 4 parts.
        For passed in bytecode, we should skip the body execution instruction, which is from template,
        and split it into 3 parts.

        :param bytecode: The bytecode to split. Its location information is cleared in place.
        :param skip_body_instr: If True, the body execution instruction is not used as a separator.
        :return: The pieces between the separators; the separators themselves are dropped.
        :raises ValueError: If some separator was never found in the bytecode.
        """
        pieces = []
        piece = Bytecode()

        separator_iter = iter(self._template_separators)

        def get_next_separator():
            # Return the next separator to look for, or None when all of them are consumed.
            try:
                _s = next(separator_iter)
                if skip_body_instr and _s == self.get_body_instruction():
                    _s = next(separator_iter)
                return _s
            except StopIteration:
                return None

        cur_separator = get_next_separator()
        for instr in self._clear_location(bytecode):
            if instr == cur_separator:
                # skip the separator
                pieces.append(piece)
                cur_separator = get_next_separator()
                piece = Bytecode()
            else:
                piece.append(instr)
        pieces.append(piece)

        if cur_separator is not None:
            raise ValueError('Not all template separators are used, please switch to a compatible version of Python.')
        return pieces

    @classmethod
    def get_body_instruction(cls):
        """Get the body execution instruction in template."""
        return Instr('LOAD_FAST', 'mock_arg')

    @classmethod
    def _clear_location(cls, bytecode: Bytecode) -> Bytecode:
        """Clear location information of bytecode instructions and return the cleared bytecode.

        Every Instr is replaced in place by a new Instr built from its name and arg only,
        so that instructions from different functions can be compared for equality.
        """
        for i, instr in enumerate(bytecode):
            if isinstance(instr, Instr):
                bytecode[i] = Instr(instr.name, instr.arg)
        return bytecode

    def _create_base_bytecode(self, func: Union[FunctionType, MethodType]) -> Bytecode:
        """Create the base bytecode for the function to be generated.

        Will keep information of the function, such as name, globals, etc., but skip all instructions.
        The injected parameter is prepended to the argument list.

        :raises ValueError: If the function already has an argument named like the injected parameter.
        """
        generated_bytecode = Bytecode.from_code(func.__code__)
        generated_bytecode.clear()

        if self._injected_param in generated_bytecode.argnames:
            raise ValueError('Injected param name {} conflicts with function args {}'.format(
                self._injected_param,
                generated_bytecode.argnames
            ))
        generated_bytecode.argnames.insert(0, self._injected_param)
        generated_bytecode.argcount += 1  # pylint: disable=no-member
        # Positional-only parameters are the leading ones. Prepending a parameter shifts them
        # by one, so the count must grow too or the last positional-only parameter would
        # become usable as a keyword.
        posonlyargcount = getattr(generated_bytecode, 'posonlyargcount', 0)
        if posonlyargcount:
            generated_bytecode.posonlyargcount = posonlyargcount + 1
        return generated_bytecode

    def _build_func(self, func: Union[FunctionType, MethodType]) -> PersistentLocalsFunction:
        """Generate the injected function for func and wrap it in a PersistentLocalsFunction."""
        generated_bytecode = self._create_base_bytecode(func)

        # Interleave: template piece, user piece, then the separator that followed them.
        for template_piece, input_piece, separator in zip(
            self._template_body,
            self.split_bytecode(
                Bytecode.from_code(func.__code__),
                skip_body_instr=True
            ),
            self._template_separators
        ):
            generated_bytecode.extend(template_piece)
            generated_bytecode.extend(input_piece)
            # the body instruction is only a placeholder for the user code, do not emit it
            if separator != self.get_body_instruction():
                generated_bytecode.append(separator)
        generated_bytecode.extend(self._template_tail)

        generated_code = generated_bytecode.to_code()
        generated_func = FunctionType(
            generated_code,
            func.__globals__,
            func.__name__,
            func.__defaults__,
            func.__closure__
        )
        # The FunctionType constructor has no slot for keyword-only defaults; copy them
        # explicitly, otherwise keyword-only parameters of func would become required.
        kwdefaults = getattr(func, '__kwdefaults__', None)
        if kwdefaults:
            generated_func.__kwdefaults__ = dict(kwdefaults)
        return PersistentLocalsFunction(
            generated_func,
            _self=func.__self__ if isinstance(func, MethodType) else None,
            skip_locals=[self._injected_param]
        )

    def build(self, func: Callable):
        """Build a persistent locals function from the given function.

        Detailed impact is described in the docstring of the persistent_locals decorator.

        :param func: A function, a method, or an object whose ``__call__`` is one of those.
        :raises TypeError: If no Python function or method can be obtained from func.
        """
        if not isinstance(func, (FunctionType, MethodType)):
            if not hasattr(func, '__call__'):
                raise TypeError('func must be a function or a callable object')
            func = func.__call__
            # e.g. builtins: their __call__ carries no bytecode to inject into
            if not isinstance(func, (FunctionType, MethodType)):
                raise TypeError(
                    'func must be a Python function or method, got {}'.format(type(func).__name__)
                )
        return self._build_func(func)


def persistent_locals(func):
    """
    Wrap func so that its local variables are persisted after every call.

    Bytecode injection is used to put a try...finally statement around the code of func.
    The bytecode of func is changed like this:
        def func(__self, *func_args):
            try:
                the func code...
            finally:
                __self.locals = locals().copy()
    The locals of func can then be read like this:
        persistent_locals_func = persistent_locals(your_func)
        # Execute your func
        result = persistent_locals_func(*args)
        # Get the locals in the func.
        func_locals = persistent_locals_func.locals
    """
    builder = PersistentLocalsFunctionBuilder()
    return builder.build(func)
46 changes: 24 additions & 22 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

# pylint: disable=protected-access
import copy
import sys
import typing
from collections import OrderedDict
from contextlib import contextmanager
from inspect import Parameter, signature
from typing import Callable, Union

from azure.ai.ml._utils._func_utils import persistent_locals
from azure.ai.ml._utils.utils import (
get_all_enum_values_iter,
is_private_preview_enabled,
Expand Down Expand Up @@ -114,17 +113,6 @@ def tracer(frame, event, arg): # pylint: disable=unused-argument
return tracer


@contextmanager
def replace_sys_profiler(profiler):
    """Temporarily install the given profiler as the sys profiler.

    The profiler that was active on entry is put back on exit, also when the
    managed block raises.
    """
    previous = sys.getprofile()
    sys.setprofile(profiler)
    try:
        yield
    finally:
        sys.setprofile(previous)


class PipelineComponentBuilder:
# map from python built-in type to component type
# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -203,17 +191,9 @@ def build(self, *, user_provided_kwargs=None,
non_pipeline_inputs=non_pipeline_inputs
)
kwargs.update(non_pipeline_inputs_dict or {})
# We use this stack to store the dsl pipeline definition hierarchy
_definition_builder_stack.push(self)

# Use a dict to store all variables in self.func
_locals = {}
func_variable_profiler = get_func_variable_tracer(_locals, self.func.__code__)
try:
with replace_sys_profiler(func_variable_profiler):
outputs = self.func(**kwargs)
finally:
_definition_builder_stack.pop()
outputs, _locals = self._get_outputs_and_locals(kwargs)

if outputs is None:
outputs = {}
Expand All @@ -235,6 +215,28 @@ def build(self, *, user_provided_kwargs=None,
pipeline_component._outputs = self._build_pipeline_outputs(outputs)
return pipeline_component

def _get_outputs_and_locals(
    self,
    _all_kwargs: typing.Dict[str, typing.Any]
) -> typing.Tuple[typing.Dict, typing.Dict]:
    """Call self.func and collect both its outputs and its local variables.

    Locals will be used to update node variable names.

    :param _all_kwargs: All kwargs to call self.func.
    :type _all_kwargs: typing.Dict[str, typing.Any]
    :return: A tuple of outputs and locals.
    :rtype: typing.Tuple[typing.Dict, typing.Dict]
    """
    # We use this stack to store the dsl pipeline definition hierarchy
    _definition_builder_stack.push(self)
    try:
        traced_func = persistent_locals(self.func)
        result = traced_func(**_all_kwargs)
    finally:
        # leave the stack balanced even if building or calling the function fails
        _definition_builder_stack.pop()
    return result, traced_func.locals

def _validate_group_annotation(self, name:str, val:GroupInput):
for k, v in val.values.items():
if isinstance(v, GroupInput):
Expand Down
2 changes: 2 additions & 0 deletions sdk/ml/azure-ai-ml/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
"azure-common<2.0.0,>=1.1",
"typing-extensions<5.0.0",
"opencensus-ext-azure<2.0.0",
# Used in pipeline_component_builder
'bytecode<0.15.0,>=0.13.0',
],
extras_require={
# user can run `pip install azure-ai-ml[designer]` to install mldesigner along with this package
Expand Down
49 changes: 49 additions & 0 deletions sdk/ml/azure-ai-ml/tests/dsl/unittests/test_pipeline_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from pathlib import Path
import pytest

from azure.ai.ml import dsl, load_component, Input

tests_root_dir = Path(__file__).parent.parent.parent
components_dir = tests_root_dir / "test_configs/components/"


@pytest.mark.unittest
@pytest.mark.pipeline_test
class TestPersistentLocals:
    """Tests for dsl pipelines whose function locals are collected after the call."""

    def test_simple(self):
        # a node assigned to a local variable is expected under that variable's name
        component_func = load_component(components_dir / "helloworld_component_optional_input.yml")

        @dsl.pipeline
        def pipeline_func(required_input: int, optional_input: int = 2):
            named_step = component_func(required_input=required_input, optional_input=optional_input)

        built_job = pipeline_func(1, 2)
        assert 'named_step' in built_job.jobs

    def test_raise_exception(self):
        # an error raised by user code must reach the caller unchanged
        @dsl.pipeline
        def mock_error_exception():
            mock_local_variable = 1
            return mock_local_variable / 0

        with pytest.raises(ZeroDivisionError):
            mock_error_exception()

    def test_instance_func(self):
        # the pipeline function is defined as a method, so `self` is one of its inputs
        component_func = load_component(components_dir / "helloworld_component_optional_input.yml")

        class MockClass(Input):
            def __init__(self, mock_path):
                super().__init__(path=mock_path)

            @dsl.pipeline
            def pipeline_func(self, required_input: int, optional_input: int = 2):
                named_step = component_func(required_input=required_input, optional_input=optional_input)

        instance = MockClass("./some/path")
        built_job = instance.pipeline_func(1, 2)
        assert 'named_step' in built_job.jobs
        assert 'self' in built_job.inputs
        assert built_job.inputs['self'].path == "./some/path"
Loading

0 comments on commit 4b40de5

Please sign in to comment.