Skip to content

Commit

Permalink
Better errors when inputs are omitted. (#28289)
Browse files Browse the repository at this point in the history
It's not always possible to know if a transform consumes inputs,
or can act as a root transform (and in fact some may be able to do
both depending on their configuration), but when a transform
expecting inputs doesn't get them the error can be quite obscure.
This adds best-effort checking and a better error in that case.

We also allow explicitly setting empty imputs to work around this
error (which is where most of the complexity of this change lies).
Importantly, sources (no matter their name) are not required to have
inputs.
  • Loading branch information
robertwb committed Sep 18, 2023
1 parent d6068ad commit 3024ec2
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 75 deletions.
89 changes: 60 additions & 29 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def provided_transforms(self) -> Iterable[str]:
"""Returns a list of transform type names this provider can handle."""
raise NotImplementedError(type(self))

def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
"""Returns whether this transform requires inputs.
Specifically, if this returns True and inputs are not provided than an error
will be thrown.
This is best-effort, primarily for better and earlier error messages.
"""
return not typ.startswith('Read')

def create_transform(
self,
typ: str,
Expand Down Expand Up @@ -129,9 +139,7 @@ def __init__(self, urns, service):
def provided_transforms(self):
return self._urns.keys()

def create_transform(self, type, args, yaml_create_transform):
if callable(self._service):
self._service = self._service()
def schema_transforms(self):
if self._schema_transforms is None:
try:
self._schema_transforms = {
Expand All @@ -142,8 +150,19 @@ def create_transform(self, type, args, yaml_create_transform):
except Exception:
# It's possible this service doesn't vend schema transforms.
self._schema_transforms = {}
return self._schema_transforms

def requires_inputs(self, typ, args):
if self._urns[type] in self.schema_transforms():
return bool(self.schema_transforms()[self._urns[type]].inputs)
else:
return super().requires_inputs(typ, args)

def create_transform(self, type, args, yaml_create_transform):
if callable(self._service):
self._service = self._service()
urn = self._urns[type]
if urn in self._schema_transforms:
if urn in self.schema_transforms():
return external.SchemaAwareExternalTransform(
urn, self._service, rearrange_based_on_discovery=True, **args)
else:
Expand Down Expand Up @@ -359,8 +378,9 @@ def fn_takes_side_inputs(fn):


class InlineProvider(Provider):
def __init__(self, transform_factories):
def __init__(self, transform_factories, no_input_transforms=()):
self._transform_factories = transform_factories
self._no_input_transforms = set(no_input_transforms)

def available(self):
return True
Expand All @@ -377,6 +397,14 @@ def create_transform(self, type, args, yaml_create_transform):
def to_json(self):
return {'type': "InlineProvider"}

def requires_inputs(self, typ, args):
if typ in self._no_input_transforms:
return False
elif hasattr(self._transform_factories[typ], '_yaml_requires_inputs'):
return self._transform_factories[typ]._yaml_requires_inputs
else:
return super().requires_inputs(typ, args)


class MetaInlineProvider(InlineProvider):
def create_transform(self, type, args, yaml_create_transform):
Expand Down Expand Up @@ -508,30 +536,30 @@ def _parse_window_spec(spec):
# TODO: Triggering, etc.
return beam.WindowInto(window_fn)

return InlineProvider(
dict({
'Create': create,
'PyMap': lambda fn: beam.Map(
python_callable.PythonCallableWithSource(fn)),
'PyMapTuple': lambda fn: beam.MapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMap': lambda fn: beam.FlatMap(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFilter': lambda keep: beam.Filter(
python_callable.PythonCallableWithSource(keep)),
'PyTransform': fully_qualified_named_transform,
'PyToRow': lambda fields: beam.Select(
**{
name: python_callable.PythonCallableWithSource(fn)
for (name, fn) in fields.items()
}),
'WithSchema': with_schema,
'Flatten': Flatten,
'WindowInto': WindowInto,
'GroupByKey': beam.GroupByKey,
}))
return InlineProvider({
'Create': create,
'PyMap': lambda fn: beam.Map(
python_callable.PythonCallableWithSource(fn)),
'PyMapTuple': lambda fn: beam.MapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMap': lambda fn: beam.FlatMap(
python_callable.PythonCallableWithSource(fn)),
'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
python_callable.PythonCallableWithSource(fn)),
'PyFilter': lambda keep: beam.Filter(
python_callable.PythonCallableWithSource(keep)),
'PyTransform': fully_qualified_named_transform,
'PyToRow': lambda fields: beam.Select(
**{
name: python_callable.PythonCallableWithSource(fn)
for (name, fn) in fields.items()
}),
'WithSchema': with_schema,
'Flatten': Flatten,
'WindowInto': WindowInto,
'GroupByKey': beam.GroupByKey,
},
no_input_transforms=('Create', ))


class PypiExpansionService:
Expand Down Expand Up @@ -639,6 +667,9 @@ def available(self) -> bool:
def provided_transforms(self) -> Iterable[str]:
return self._transforms.keys()

def requires_inputs(self, typ, args):
return self._underlying_provider.requires_inputs(typ, args)

def create_transform(
self,
typ: str,
Expand Down
84 changes: 67 additions & 17 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ def only_element(xs):
return x


# These allow a user to explicitly pass no input to a transform (i.e. use it
# as a root transform) without an error even if the transform is not known to
# handle it.
def explicitly_empty():
return {'__explicitly_empty__': None}


def is_explicitly_empty(io):
return io == explicitly_empty()


def is_empty(io):
return not io or is_explicitly_empty(io)


def empty_if_explicitly_empty(io):
if is_explicitly_empty(io):
return {}
else:
return io


class SafeLineLoader(SafeLoader):
"""A yaml loader that attaches line information to mappings and strings."""
class TaggedString(str):
Expand Down Expand Up @@ -186,7 +208,7 @@ def followers(self, transform_name):
# TODO(yaml): Also trace through outputs and composites.
for transform in self._transforms:
if transform['type'] != 'composite':
for input in transform.get('input').values():
for input in empty_if_explicitly_empty(transform['input']).values():
transform_id, _ = self.get_transform_id_and_output_name(input)
self._all_followers[transform_id].append(transform['__uuid__'])
return self._all_followers[self.get_transform_id(transform_name)]
Expand Down Expand Up @@ -324,6 +346,12 @@ def create_ptransform(self, spec, input_pcolls):
raise ValueError(
'Config for transform at %s must be a mapping.' %
identify_object(spec))

if (not input_pcolls and not is_explicitly_empty(spec.get('input', {})) and
provider.requires_inputs(spec['type'], config)):
raise ValueError(
f'Missing inputs for transform at {identify_object(spec)}')

try:
# pylint: disable=undefined-loop-variable
ptransform = provider.create_transform(
Expand Down Expand Up @@ -402,7 +430,7 @@ def expand_leaf_transform(spec, scope):
spec = normalize_inputs_outputs(spec)
inputs_dict = {
key: scope.get_pcollection(value)
for (key, value) in spec['input'].items()
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
}
input_type = spec.get('input_type', 'default')
if input_type == 'list':
Expand Down Expand Up @@ -442,10 +470,10 @@ def expand_composite_transform(spec, scope):
spec = normalize_inputs_outputs(normalize_source_sink(spec))

inner_scope = Scope(
scope.root, {
scope.root,
{
key: scope.get_pcollection(value)
for key,
value in spec['input'].items()
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
},
spec['transforms'],
yaml_provider.merge_providers(
Expand All @@ -470,8 +498,7 @@ def expand(inputs):
_LOGGER.info("Expanding %s ", identify_object(spec))
return ({
key: scope.get_pcollection(value)
for key,
value in spec['input'].items()
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
} or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()


Expand All @@ -496,12 +523,25 @@ def is_not_output_of_last_transform(new_transforms, value):
composite_spec = normalize_inputs_outputs(spec)
new_transforms = []
for ix, transform in enumerate(composite_spec['transforms']):
if any(io in transform for io in ('input', 'output', 'input', 'output')):
raise ValueError(
f'Transform {identify_object(transform)} is part of a chain, '
'must have implicit inputs and outputs.')
if any(io in transform for io in ('input', 'output')):
if (ix == 0 and 'input' in transform and 'output' not in transform and
is_explicitly_empty(transform['input'])):
# This is OK as source clause sets an explicitly empty input.
pass
else:
raise ValueError(
f'Transform {identify_object(transform)} is part of a chain, '
'must have implicit inputs and outputs.')
if ix == 0:
transform['input'] = {key: key for key in composite_spec['input'].keys()}
if is_explicitly_empty(transform.get('input', None)):
pass
elif is_explicitly_empty(composite_spec['input']):
transform['input'] = composite_spec['input']
else:
transform['input'] = {
key: key
for key in composite_spec['input'].keys()
}
else:
transform['input'] = new_transforms[-1]['__uuid__']
new_transforms.append(transform)
Expand Down Expand Up @@ -554,6 +594,8 @@ def normalize_source_sink(spec):
spec = dict(spec)
spec['transforms'] = list(spec.get('transforms', []))
if 'source' in spec:
if 'input' not in spec['source']:
spec['source']['input'] = explicitly_empty()
spec['transforms'].insert(0, spec.pop('source'))
if 'sink' in spec:
spec['transforms'].append(spec.pop('sink'))
Expand All @@ -567,6 +609,13 @@ def preprocess_source_sink(spec):
return spec


def tag_explicit_inputs(spec):
if 'input' in spec and not SafeLineLoader.strip_metadata(spec['input']):
return dict(spec, input=explicitly_empty())
else:
return spec


def normalize_inputs_outputs(spec):
spec = dict(spec)

Expand Down Expand Up @@ -611,7 +660,7 @@ def push_windowing_to_roots(spec):
scope = LightweightScope(spec['transforms'])
consumed_outputs_by_transform = collections.defaultdict(set)
for transform in spec['transforms']:
for _, input_ref in transform['input'].items():
for _, input_ref in empty_if_explicitly_empty(transform['input']).items():
try:
transform_id, output = scope.get_transform_id_and_output_name(input_ref)
consumed_outputs_by_transform[transform_id].add(output)
Expand All @@ -620,7 +669,7 @@ def push_windowing_to_roots(spec):
pass

for transform in spec['transforms']:
if not transform['input'] and 'windowing' not in transform:
if is_empty(transform['input']) and 'windowing' not in transform:
transform['windowing'] = spec['windowing']
transform['__consumed_outputs'] = consumed_outputs_by_transform[
transform['__uuid__']]
Expand All @@ -647,7 +696,7 @@ def preprocess_windowing(spec):
spec = push_windowing_to_roots(spec)

windowing = spec.pop('windowing')
if spec['input']:
if not is_empty(spec['input']):
# Apply the windowing to all inputs by wrapping it in a transform that
# first applies windowing and then applies the original transform.
original_inputs = spec['input']
Expand Down Expand Up @@ -778,7 +827,7 @@ def ensure_errors_consumed(spec):
raise ValueError(
f'Missing output in error_handling of {identify_object(t)}')
to_handle[t['__uuid__'], config['error_handling']['output']] = t
for _, input in t['input'].items():
for _, input in empty_if_explicitly_empty(t['input']).items():
if input not in spec['input']:
consumed.add(scope.get_transform_id_and_output_name(input))
for error_pcoll, t in to_handle.items():
Expand Down Expand Up @@ -815,7 +864,7 @@ def preprocess(spec, verbose=False, known_transforms=None):

def apply(phase, spec):
spec = phase(spec)
if spec['type'] in {'composite', 'chain'}:
if spec['type'] in {'composite', 'chain'} and 'transforms' in spec:
spec = dict(
spec, transforms=[apply(phase, t) for t in spec['transforms']])
return spec
Expand All @@ -835,6 +884,7 @@ def ensure_transforms_have_providers(spec):
ensure_transforms_have_providers,
preprocess_source_sink,
preprocess_chain,
tag_explicit_inputs,
normalize_inputs_outputs,
preprocess_flattened_inputs,
ensure_errors_consumed,
Expand Down
31 changes: 2 additions & 29 deletions sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,40 +88,13 @@ def test_create_ptransform(self):
spec = '''
transforms:
- type: PyMap
input: something
config:
fn: "lambda x: x*x"
'''
scope, spec = self.get_scope_by_spec(p, spec)

result = scope.create_ptransform(spec['transforms'][0], [])
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')

result_annotations = {**result.annotations()}
target_annotations = {
'yaml_type': 'PyMap',
'yaml_args': '{"fn": "lambda x: x*x"}',
'yaml_provider': '{"type": "InlineProvider"}'
}

# Check if target_annotations is a subset of result_annotations
self.assertDictEqual(
result_annotations, {
**result_annotations, **target_annotations
})

def test_create_ptransform_with_inputs(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
spec = '''
transforms:
- type: PyMap
config:
fn: "lambda x: x*x"
'''
scope, spec = self.get_scope_by_spec(p, spec)

result = scope.create_ptransform(spec['transforms'][0], [])
result = scope.create_ptransform(spec['transforms'][0], ['something'])
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')

Expand Down
Loading

0 comments on commit 3024ec2

Please sign in to comment.