diff --git a/CHANGELOG.md b/CHANGELOG.md index a4f4b411..4420f32e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- Add new `_bulk_create` parameter to `make` for using Django manager `bulk_create` with `_quantity` [PR #134](https://github.com/model-bakers/model_bakery/pull/134) + ### Changed - Type hinting fixed for Recipe "_model" parameter diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 54f8e6e2..02eaabcc 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -54,6 +54,7 @@ def make( _refresh_after_create: bool = False, _create_files: bool = False, _using: str = "", + _bulk_create: bool = False, **attrs: Any ): """Create a persisted instance from a given model its associated models. @@ -68,7 +69,14 @@ def make( if _valid_quantity(_quantity): raise InvalidQuantityException - if _quantity: + if _quantity and _bulk_create: + return baker.model._base_manager.bulk_create( + [ + baker.prepare(_save_kwargs=_save_kwargs, **attrs) + for _ in range(_quantity) + ] + ) + elif _quantity: return [ baker.make( _save_kwargs=_save_kwargs, @@ -77,6 +85,7 @@ def make( ) for _ in range(_quantity) ] + return baker.make( _save_kwargs=_save_kwargs, _refresh_after_create=_refresh_after_create, **attrs ) diff --git a/tests/generic/models.py b/tests/generic/models.py index 24e4a8ca..791102a9 100755 --- a/tests/generic/models.py +++ b/tests/generic/models.py @@ -409,3 +409,9 @@ class Meta(object): class SubclassOfAbstract(AbstractModel): height = models.IntegerField() + + +class NonStandardManager(models.Model): + name = models.CharField(max_length=30) + + manager = models.Manager() diff --git a/tests/test_baker.py b/tests/test_baker.py index 6595945f..3fa29b5c 100644 --- a/tests/test_baker.py +++ b/tests/test_baker.py @@ -5,6 +5,7 @@ import pytest from django.conf import settings +from django.db import connection from django.db.models import Manager from django.db.models.signals import m2m_changed from django.test import TestCase @@ -29,6 +30,45 @@ def test_import_seq_from_baker(): pytest.fail("{} raised".format(ImportError.__name__)) +class QueryCount: + """ + Keep track of db calls. + + Example: + ======== + + qc = QueryCount() + + with qc.start_count(): + MyModel.objects.get(pk=1) + MyModel.objects.create() + + qc.count # 2 + + """ + + def __init__(self): + self.count = 0 + + def __call__(self, execute, sql, params, many, context): + """ + `django.db.connection.execute_wrapper` callback + + https://docs.djangoproject.com/en/3.1/topics/db/instrumentation/ + """ + self.count += 1 + execute(sql, params, many, context) + + def start_count(self): + """ + Reset query count to 0 and return context manager for wrapping db + queries. + """ + self.count = 0 + + return connection.execute_wrapper(self) + + class TestsModelFinder: def test_unicode_regression(self): obj = baker.prepare("generic.Person") @@ -114,11 +154,46 @@ def test_multiple_inheritance_creation(self): @pytest.mark.django_db class TestsBakerRepeatedCreatesSimpleModel: def test_make_should_create_objects_respecting_quantity_parameter(self): - baker.make(models.Person, _quantity=5) - assert models.Person.objects.count() == 5 + queries = QueryCount() - people = baker.make(models.Person, _quantity=5, name="George Washington") - assert all(p.name == "George Washington" for p in people) + with queries.start_count(): + baker.make(models.Person, _quantity=5) + assert queries.count == 5 + assert models.Person.objects.count() == 5 + + with queries.start_count(): + people = baker.make(models.Person, _quantity=5, name="George Washington") + assert all(p.name == "George Washington" for p in people) + assert queries.count == 5 + + def test_make_quantity_respecting_bulk_create_parameter(self): + queries = QueryCount() + + with queries.start_count(): + baker.make(models.Person, _quantity=5, _bulk_create=True) + assert queries.count == 1 + assert models.Person.objects.count() == 5 + + with queries.start_count(): + people = baker.make( + models.Person, name="George Washington", _quantity=5, _bulk_create=True + ) + assert all(p.name == "George Washington" for p in people) + assert queries.count == 1 + + with queries.start_count(): + baker.make(models.NonStandardManager, _quantity=3, _bulk_create=True) + assert queries.count == 1 + assert getattr(models.NonStandardManager, "objects", None) is None + assert ( + models.NonStandardManager._base_manager + == models.NonStandardManager.manager + ) + assert ( + models.NonStandardManager._default_manager + == models.NonStandardManager.manager + ) + assert models.NonStandardManager.manager.count() == 3 def test_make_raises_correct_exception_if_invalid_quantity(self): with pytest.raises(InvalidQuantityException):