Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with Faker / locale #828

Merged
merged 9 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ build/
dist/
htmlcov/
MANIFEST
tags
11 changes: 11 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ SETUP_PY=setup.py
COVERAGE = python $(shell which coverage)
FLAKE8 = flake8
ISORT = isort
CTAGS = ctags


all: default
Expand Down Expand Up @@ -87,6 +88,16 @@ coverage:
.PHONY: test testall example-test lint coverage


# Development
# ===========

# DOC: Generate a "tags" file
TAGS:
$(CTAGS) --recurse $(PACKAGE) $(TESTS_DIR)

.PHONY: TAGS


# Documentation
# =============

Expand Down
31 changes: 4 additions & 27 deletions factory/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@
)


PostGenerationContext = collections.namedtuple(
'PostGenerationContext',
['value_provided', 'value', 'extra'],
)


class DeclarationSet:
"""A set of declarations, including the recursive parameters.

Expand Down Expand Up @@ -274,21 +268,10 @@ def build(self, parent_step=None, force_sequence=None):
postgen_results = {}
for declaration_name in post.sorted():
declaration = post[declaration_name]
unrolled_context = declaration.declaration.unroll_context(
instance=instance,
step=step,
context=declaration.context,
)

postgen_context = PostGenerationContext(
value_provided='' in unrolled_context,
value=unrolled_context.get(''),
extra={k: v for k, v in unrolled_context.items() if k != ''},
)
postgen_results[declaration_name] = declaration.declaration.call(
postgen_results[declaration_name] = declaration.declaration.evaluate_post(
instance=instance,
step=step,
context=postgen_context,
overrides=declaration.context,
)
self.factory_meta.use_postgeneration_results(
instance=instance,
Expand Down Expand Up @@ -358,16 +341,10 @@ def __getattr__(self, name):
if enums.get_builder_phase(value) == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
self.__pending.append(name)
try:
context = value.unroll_context(
instance=self,
step=self.__step,
context=declaration.context,
)

value = value.evaluate(
value = value.evaluate_pre(
instance=self,
step=self.__step,
extra=context,
overrides=declaration.context,
)
finally:
last = self.__pending.pop()
Expand Down
139 changes: 56 additions & 83 deletions factory/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import itertools
import logging
import typing as T

from . import enums, errors, utils

Expand All @@ -23,16 +24,28 @@ class BaseDeclaration(utils.OrderedBase):
#: Set to False on declarations that perform their own unrolling.
UNROLL_CONTEXT_BEFORE_EVALUATION = True

def __init__(self, **defaults):
super().__init__()
self._defaults = defaults or {}

def unroll_context(self, instance, step, context):
full_context = dict()
full_context.update(self._defaults)
full_context.update(context)

if not self.UNROLL_CONTEXT_BEFORE_EVALUATION:
return context
if not any(enums.get_builder_phase(v) for v in context.values()):
return full_context
if not any(enums.get_builder_phase(v) for v in full_context.values()):
# Optimization for simple contexts - don't do anything.
return context
return full_context

import factory.base
subfactory = factory.base.DictFactory
return step.recurse(subfactory, context, force_sequence=step.sequence)
return step.recurse(subfactory, full_context, force_sequence=step.sequence)

def evaluate_pre(self, instance, step, overrides):
context = self.unroll_context(instance, step, overrides)
return self.evaluate(instance, step, context)

def evaluate(self, instance, step, extra):
"""Evaluate this declaration.
Expand Down Expand Up @@ -61,8 +74,8 @@ class LazyFunction(BaseDeclaration):
returning the computed value.
"""

def __init__(self, function, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, function):
super().__init__()
self.function = function

def evaluate(self, instance, step, extra):
Expand All @@ -78,8 +91,8 @@ class LazyAttribute(BaseDeclaration):
returning the computed value.
"""

def __init__(self, function, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, function):
super().__init__()
self.function = function

def evaluate(self, instance, step, extra):
Expand Down Expand Up @@ -133,8 +146,8 @@ class SelfAttribute(BaseDeclaration):
exist.
"""

def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, attribute_name, default=_UNSPECIFIED):
super().__init__()
depth = len(attribute_name) - len(attribute_name.lstrip('.'))
attribute_name = attribute_name[depth:]

Expand Down Expand Up @@ -241,8 +254,8 @@ class ContainerAttribute(BaseDeclaration):
strict (bool): Whether evaluating should fail when the containers are
not passed in (i.e used outside a SubFactory).
"""
def __init__(self, function, strict=True, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, function, strict=True):
super().__init__()
self.function = function
self.strict = strict

Expand Down Expand Up @@ -272,28 +285,8 @@ class ParameteredAttribute(BaseDeclaration):
Attributes:
defaults (dict): Default values for the parameters.
May be overridden by call-time parameters.

Class attributes:
CONTAINERS_FIELD (str): name of the field, if any, where container
information (e.g for SubFactory) should be stored. If empty,
containers data isn't merged into generate() parameters.
"""

CONTAINERS_FIELD = '__containers'

# Whether to add the current object to the stack of containers
EXTEND_CONTAINERS = False

def __init__(self, **kwargs):
super().__init__()
self.defaults = kwargs

def _prepare_containers(self, obj, containers=()):
if self.EXTEND_CONTAINERS:
return (obj,) + tuple(containers)

return containers

def evaluate(self, instance, step, extra):
"""Evaluate the current definition and fill its attributes.

Expand All @@ -308,11 +301,7 @@ def evaluate(self, instance, step, extra):
extra (dict): additional, call-time added kwargs
for the step.
"""
defaults = dict(self.defaults)
if extra:
defaults.update(extra)

return self.generate(step, defaults)
return self.generate(step, extra)

def generate(self, step, params):
"""Actually generate the related attribute.
Expand All @@ -331,39 +320,6 @@ def generate(self, step, params):
raise NotImplementedError()


class ParameteredDeclaration(BaseDeclaration):
"""A declaration with parameters.

The parameters can be any factory-enabled declaration, and will be resolved
before the call to the user-defined code in `self.generate()`.

Attributes:
defaults (dict): Default values for the parameters; can be overridden
by call-time parameters. Accepts BaseDeclaration subclasses.
"""

def __init__(self, **defaults):
self.defaults = defaults
super().__init__()

def unroll_context(self, instance, step, context):
merged_context = {}
merged_context.update(self.defaults)
merged_context.update(context)
return super().unroll_context(instance, step, merged_context)

def evaluate(self, instance, step, extra):
return self.generate(extra)

def generate(self, params):
"""Generate a value for this declaration.

Args:
params (dict): the parameters, after a factory evaluation.
"""
raise NotImplementedError()


class _FactoryWrapper:
"""Handle a 'factory' arg.

Expand Down Expand Up @@ -398,7 +354,7 @@ def __repr__(self):
return f'<_FactoryImport: {self.factory.__class__}>'


class SubFactory(ParameteredAttribute):
class SubFactory(BaseDeclaration):
"""Base class for attributes based upon a sub-factory.

Attributes:
Expand All @@ -407,7 +363,6 @@ class SubFactory(ParameteredAttribute):
factory (base.Factory): the wrapped factory
"""

EXTEND_CONTAINERS = True
# Whether to align the attribute's sequence counter to the holding
# factory's sequence counter
FORCE_SEQUENCE = False
Expand All @@ -421,7 +376,7 @@ def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()

def generate(self, step, params):
def evaluate(self, instance, step, extra):
"""Evaluate the current definition and fill its attributes.

Args:
Expand All @@ -433,11 +388,11 @@ def generate(self, step, params):
logger.debug(
"SubFactory: Instantiating %s.%s(%s), create=%r",
subfactory.__module__, subfactory.__name__,
utils.log_pprint(kwargs=params),
utils.log_pprint(kwargs=extra),
step,
)
force_sequence = step.sequence if self.FORCE_SEQUENCE else None
return step.recurse(subfactory, params, force_sequence=force_sequence)
return step.recurse(subfactory, extra, force_sequence=force_sequence)


class Dict(SubFactory):
Expand Down Expand Up @@ -494,36 +449,39 @@ def __init__(self, decider, yes_declaration=SKIP, no_declaration=SKIP):

self.FACTORY_BUILDER_PHASE = used_phases.pop() if used_phases else enums.BuilderPhase.ATTRIBUTE_RESOLUTION

def call(self, instance, step, context):
def evaluate_post(self, instance, step, overrides):
"""Handle post-generation declarations"""
decider_phase = enums.get_builder_phase(self.decider)
if decider_phase == enums.BuilderPhase.ATTRIBUTE_RESOLUTION:
# Note: we work on the *builder stub*, not on the actual instance.
# This gives us access to all Params-level definitions.
choice = self.decider.evaluate(instance=step.stub, step=step, extra=context.extra)
choice = self.decider.evaluate_pre(
instance=step.stub, step=step, overrides=overrides)
else:
assert decider_phase == enums.BuilderPhase.POST_INSTANTIATION
choice = self.decider.call(instance, step, context)
choice = self.decider.evaluate_post(
instance=instance, step=step, overrides={})

target = self.yes if choice else self.no
if enums.get_builder_phase(target) == enums.BuilderPhase.POST_INSTANTIATION:
return target.call(
return target.evaluate_post(
instance=instance,
step=step,
context=context,
overrides=overrides,
)
else:
# Flat value (can't be ATTRIBUTE_RESOLUTION, checked in __init__)
return target

def evaluate(self, instance, step, extra):
def evaluate_pre(self, instance, step, overrides):
choice = self.decider.evaluate(instance=instance, step=step, extra={})
target = self.yes if choice else self.no

if isinstance(target, BaseDeclaration):
return target.evaluate(
return target.evaluate_pre(
instance=instance,
step=step,
extra=extra,
overrides=overrides,
)
else:
# Flat value (can't be POST_INSTANTIATION, checked in __init__)
Expand Down Expand Up @@ -613,11 +571,26 @@ def __repr__(self):
# ===============


class PostGenerationContext(T.NamedTuple):
value_provided: bool
value: T.Any
extra: T.Dict[str, T.Any]


class PostGenerationDeclaration(BaseDeclaration):
"""Declarations to be called once the model object has been generated."""

FACTORY_BUILDER_PHASE = enums.BuilderPhase.POST_INSTANTIATION

def evaluate_post(self, instance, step, overrides):
context = self.unroll_context(instance, step, overrides)
postgen_context = PostGenerationContext(
value_provided=bool('' in context),
value=context.get(''),
extra={k: v for k, v in context.items() if k != ''},
)
return self.call(instance, step, postgen_context)

def call(self, instance, step, context): # pragma: no cover
"""Call this hook; no return value is expected.

Expand Down
6 changes: 3 additions & 3 deletions factory/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _after_postgeneration(cls, instance, create, results=None):
instance.save()


class FileField(declarations.ParameteredDeclaration):
class FileField(declarations.BaseDeclaration):
"""Helper to fill in django.db.models.FileField from a Factory."""

DEFAULT_FILENAME = 'example.dat'
Expand Down Expand Up @@ -219,9 +219,9 @@ def _make_content(self, params):
filename = params.get('filename', default_filename)
return filename, content

def generate(self, params):
def evaluate(self, instance, step, extra):
"""Fill in the field."""
filename, content = self._make_content(params)
filename, content = self._make_content(extra)
return django_files.File(content.file, filename)


Expand Down
Loading