From ba850f135fde1b41f260d224a3e103261f022cbc Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Mon, 2 Dec 2024 15:14:51 -0500 Subject: [PATCH] Add a basic Django model generator. Due to limitations of Django the schema must only have the `default` module. Also all link properties and multi properties are forbidden since they produce models without a unique `id` (or other primary key) column. Django cannot properly handle updates to that. Added tests for the generated Django models. --- gel/_testbase.py | 105 +++++- gel/orm/__init__.py | 7 + gel/orm/cli.py | 18 +- gel/orm/django/__init__.py | 0 gel/orm/django/gelmodels/__init__.py | 3 + gel/orm/django/gelmodels/apps.py | 31 ++ gel/orm/django/gelmodels/compiler.py | 64 ++++ gel/orm/django/generator.py | 285 ++++++++++++++++ gel/orm/introspection.py | 38 +++ gel/orm/sqla.py | 55 +-- setup.py | 10 +- tests/dbsetup/base.edgeql | 2 +- tests/test_django_basic.py | 494 +++++++++++++++++++++++++++ tests/test_sqla_basic.py | 25 +- tests/test_sqla_features.py | 15 +- 15 files changed, 1075 insertions(+), 77 deletions(-) create mode 100644 gel/orm/django/__init__.py create mode 100644 gel/orm/django/gelmodels/__init__.py create mode 100644 gel/orm/django/gelmodels/apps.py create mode 100644 gel/orm/django/gelmodels/compiler.py create mode 100644 gel/orm/django/generator.py create mode 100644 tests/test_django_basic.py diff --git a/gel/_testbase.py b/gel/_testbase.py index 775f1a0c..d7f2b77b 100644 --- a/gel/_testbase.py +++ b/gel/_testbase.py @@ -25,6 +25,7 @@ import inspect import json import logging +import pathlib import os import re import subprocess @@ -37,7 +38,8 @@ from gel import asyncio_client from gel import blocking_client from gel.orm.introspection import get_schema_json -from gel.orm.sqla import ModelGenerator +from gel.orm.sqla import ModelGenerator as SQLAModGen +from gel.orm.django.generator import ModelGenerator as DjangoModGen log = logging.getLogger(__name__) @@ -630,13 +632,13 @@ def adapt_call(cls, result): return result -class SQLATestCase(SyncQueryTestCase): - SQLAPACKAGE = None +class ORMTestCase(SyncQueryTestCase): + MODEL_PACKAGE = None DEFAULT_MODULE = 'default' @classmethod def setUpClass(cls): - # SQLAlchemy relies on psycopg2 to connect to Postgres and thus we + # ORMs rely on psycopg2 to connect to Postgres and thus we # need it to run tests. Unfortunately not all test environemnts might # have psycopg2 installed, as long as we run this in the test # environments that have this, it is fine since we're not expecting @@ -648,24 +650,34 @@ def setUpClass(cls): class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP') if not class_set_up: - # Now that the DB is setup, generate the SQLAlchemy models from it - spec = get_schema_json(cls.client) # We'll need a temp directory to setup the generated Python # package - cls.tmpsqladir = tempfile.TemporaryDirectory() - gen = ModelGenerator( - outdir=os.path.join(cls.tmpsqladir.name, cls.SQLAPACKAGE), - basemodule=cls.SQLAPACKAGE, - ) - gen.render_models(spec) - sys.path.append(cls.tmpsqladir.name) + cls.tmpormdir = tempfile.TemporaryDirectory() + sys.path.append(cls.tmpormdir.name) + # Now that the DB is setup, generate the ORM models from it + cls.spec = get_schema_json(cls.client) + cls.setupORM() + + @classmethod + def setupORM(cls): + raise NotImplementedError @classmethod def tearDownClass(cls): super().tearDownClass() # cleanup the temp modules - sys.path.remove(cls.tmpsqladir.name) - cls.tmpsqladir.cleanup() + sys.path.remove(cls.tmpormdir.name) + cls.tmpormdir.cleanup() + + +class SQLATestCase(ORMTestCase): + @classmethod + def setupORM(cls): + gen = SQLAModGen( + outdir=os.path.join(cls.tmpormdir.name, cls.MODEL_PACKAGE), + basemodule=cls.MODEL_PACKAGE, + ) + gen.render_models(cls.spec) @classmethod def get_dsn_for_sqla(cls): @@ -678,6 +690,69 @@ def get_dsn_for_sqla(cls): return dsn +APPS_PY = '''\ +from django.apps import AppConfig + + +class TestConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = {name!r} +''' + +SETTINGS_PY = '''\ +from pathlib import Path + +mysettings = dict( + INSTALLED_APPS=[ + '{appname}.apps.TestConfig', + 'gel.orm.django.gelmodels.apps.GelPGModel', + ], + DATABASES={{ + 'default': {{ + 'ENGINE': 'django.db.backends.postgresql', + 'NAME': {database!r}, + 'USER': {user!r}, + 'PASSWORD': {password!r}, + 'HOST': {host!r}, + 'PORT': {port!r}, + }} + }}, +) +''' + + +class DjangoTestCase(ORMTestCase): + @classmethod + def setupORM(cls): + pkgbase = os.path.join(cls.tmpormdir.name, cls.MODEL_PACKAGE) + # Set up the package for testing Django models + os.mkdir(pkgbase) + open(os.path.join(pkgbase, '__init__.py'), 'w').close() + with open(os.path.join(pkgbase, 'apps.py'), 'wt') as f: + print( + APPS_PY.format(name=cls.MODEL_PACKAGE), + file=f, + ) + + with open(os.path.join(pkgbase, 'settings.py'), 'wt') as f: + cargs = cls.get_connect_args(database=cls.get_database_name()) + print( + SETTINGS_PY.format( + appname=cls.MODEL_PACKAGE, + database=cargs["database"], + user=cargs["user"], + password=cargs["password"], + host=cargs["host"], + port=cargs["port"], + ), + file=f, + ) + + models = os.path.join(pkgbase, 'models.py') + gen = DjangoModGen(out=models) + gen.render_models(cls.spec) + + _lock_cnt = 0 diff --git a/gel/orm/__init__.py b/gel/orm/__init__.py index e69de29b..4f44a1f2 100644 --- a/gel/orm/__init__.py +++ b/gel/orm/__init__.py @@ -0,0 +1,7 @@ +import unittest + +# No tests here, but we want to skip the unittest loader from attempting to +# import ORM packages which may not have been installed (like Django that has +# a few custom adjustments to make our models work). +def load_tests(loader, tests, pattern): + return tests diff --git a/gel/orm/cli.py b/gel/orm/cli.py index 5f387879..1ae29e55 100644 --- a/gel/orm/cli.py +++ b/gel/orm/cli.py @@ -23,7 +23,8 @@ from gel.codegen.generator import _get_conn_args from .introspection import get_schema_json -from .sqla import ModelGenerator +from .sqla import ModelGenerator as SQLAModGen +from .django.generator import ModelGenerator as DjangoModGen class ArgumentParser(argparse.ArgumentParser): @@ -65,7 +66,6 @@ def error(self, message): "--mod", help="The fullname of the Python module corresponding to the output " "directory.", - required=True, ) @@ -74,13 +74,23 @@ def main(): # setup client client = gel.create_client(**_get_conn_args(args)) spec = get_schema_json(client) + generate_models(args, spec) + +def generate_models(args, spec): match args.orm: case 'sqlalchemy': - gen = ModelGenerator( + if args.mod is None: + parser.error('sqlalchemy requires to specify --mod') + + gen = SQLAModGen( outdir=args.out, basemodule=args.mod, ) gen.render_models(spec) + case 'django': - print('Not available yet. Coming soon!') + gen = DjangoModGen( + out=args.out, + ) + gen.render_models(spec) diff --git a/gel/orm/django/__init__.py b/gel/orm/django/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/gel/orm/django/gelmodels/__init__.py b/gel/orm/django/gelmodels/__init__.py new file mode 100644 index 00000000..dc4e1c79 --- /dev/null +++ b/gel/orm/django/gelmodels/__init__.py @@ -0,0 +1,3 @@ +import django + +__version__ = "0.0.1" diff --git a/gel/orm/django/gelmodels/apps.py b/gel/orm/django/gelmodels/apps.py new file mode 100644 index 00000000..7ad85b16 --- /dev/null +++ b/gel/orm/django/gelmodels/apps.py @@ -0,0 +1,31 @@ +from django.apps import AppConfig + + +class GelPGModel(AppConfig): + name = "gel.orm.django.gelmodels" + + def ready(self): + from django.db import connections, utils + + gel_compiler_module = "gel.orm.django.gelmodels.compiler" + + # Change the current compiler_module + for c in connections: + connections[c].ops.compiler_module = gel_compiler_module + + # Update the load_backend to use our DatabaseWrapper + orig_load_backend = utils.load_backend + + def custom_load_backend(*args, **kwargs): + backend = orig_load_backend(*args, **kwargs) + + class GelPGBackend: + @staticmethod + def DatabaseWrapper(*args2, **kwargs2): + connection = backend.DatabaseWrapper(*args2, **kwargs2) + connection.ops.compiler_module = gel_compiler_module + return connection + + return GelPGBackend + + utils.load_backend = custom_load_backend \ No newline at end of file diff --git a/gel/orm/django/gelmodels/compiler.py b/gel/orm/django/gelmodels/compiler.py new file mode 100644 index 00000000..5201fe70 --- /dev/null +++ b/gel/orm/django/gelmodels/compiler.py @@ -0,0 +1,64 @@ +from django.db.models.sql.compiler import ( # noqa + SQLAggregateCompiler, + SQLCompiler, + SQLDeleteCompiler, +) +from django.db.models.sql.compiler import ( # noqa + SQLInsertCompiler as BaseSQLInsertCompiler, +) +from django.db.models.sql.compiler import ( # noqa + SQLUpdateCompiler as BaseSQLUpdateCompiler, +) + + +class GelSQLCompilerMixin: + ''' + The reflected models have two special fields: `id` and `obj_type`. Both of + those fields should be read-only as they are populated automatically by + Gel and must not be modified. + ''' + @property + def readonly_gel_fields(self): + try: + # Verify that this is a Gel model reflected via Postgres protocol. + gel_pg_meta = getattr(self.query.model, "GelPGMeta") + except AttributeError: + return set() + else: + return {'id', 'gel_type_id'} + + def as_sql(self): + readonly_gel_fields = self.readonly_gel_fields + if readonly_gel_fields: + self.remove_readonly_gel_fields(readonly_gel_fields) + return super().as_sql() + + +class SQLUpdateCompiler(GelSQLCompilerMixin, BaseSQLUpdateCompiler): + def remove_readonly_gel_fields(self, names): + ''' + Remove the values corresponding to the read-only fields. + ''' + values = self.query.values + # The tuple is (field, model, value) + values[:] = (tup for tup in values if tup[0].name not in names) + + +class SQLInsertCompiler(GelSQLCompilerMixin, BaseSQLInsertCompiler): + def remove_readonly_gel_fields(self, names): + ''' + Remove the read-only fields. + ''' + fields = self.query.fields + + try: + fields[:] = (f for f in fields if f.name not in names) + except AttributeError: + # When deserializing, we might get an attribute error because this + # list shoud be copied first: + # + # "AttributeError: The return type of 'local_concrete_fields' + # should never be mutated. If you want to manipulate this list for + # your own use, make a copy first." + + self.query.fields = [f for f in fields if f.name not in names] diff --git a/gel/orm/django/generator.py b/gel/orm/django/generator.py new file mode 100644 index 00000000..93c45b0d --- /dev/null +++ b/gel/orm/django/generator.py @@ -0,0 +1,285 @@ +import pathlib +import re + +from ..introspection import get_mod_and_name, FilePrinter + + +GEL_SCALAR_MAP = { + 'std::uuid': 'UUIDField', + 'std::bigint': 'DecimalField', + 'std::bool': 'BooleanField', + 'std::bytes': 'BinaryField', + 'std::decimal': 'DecimalField', + 'std::float32': 'FloatField', + 'std::float64': 'FloatField', + 'std::int16': 'SmallIntegerField', + 'std::int32': 'IntegerField', + 'std::int64': 'BigIntegerField', + 'std::json': 'JSONField', + 'std::str': 'TextField', + # Extreme caution is needed for datetime field, the TZ aware and naive + # values are controlled in Django via settings (USE_TZ) and are mutually + # exclusive in the same app under default circumstances. + 'std::datetime': 'DateTimeField', + 'cal::local_date': 'DateField', + 'cal::local_datetime': 'DateTimeField', + 'cal::local_time': 'TimeField', + # all kinds of duration is not supported due to this error: + # iso_8601 intervalstyle currently not supported +} + +BASE_STUB = f'''\ +# +# Automatically generated from Gel schema. +# +# This is based on the auto-generated Django model module, which has been +# updated to fit Gel schema more closely. +# + +from django.db import models + +class GelUUIDField(models.UUIDField): + # This field must be treated as a auto-generated UUID. + db_returning = True + + +class LTForeignKey(models.ForeignKey): + # Linked tables need to return their source/target ForeignKeys. + db_returning = True\ +''' + +GEL_META = f''' +class GelPGMeta: + 'This is a model reflected from Gel using Postgres protocol.' +''' + +FK_RE = re.compile(r'''models\.ForeignKey\((.+?),''') +CLOSEPAR_RE = re.compile(r'\)(?=\s+#|$)') + + +class ModelClass(object): + def __init__(self, name): + self.name = name + self.props = {} + self.links = {} + self.mlinks = {} + self.meta = {'managed': False} + self.backlinks = {} + self.backlink_renames = {} + + @property + def table(self): + return self.meta['db_table'].strip("'") + + def get_backlink_name(self, name): + return self.backlink_renames.get(name, f'backlink_via_{name}') + + +class ModelGenerator(FilePrinter): + def __init__(self, *, out): + super().__init__() + # record the output file path + self.outfile = pathlib.Path(out).resolve() + + def spec_to_modules_dict(self, spec): + modules = { + mod: {} for mod in sorted(spec['modules']) + } + + for rec in spec['link_tables']: + mod = rec['module'] + if 'link_tables' not in modules[mod]: + modules[mod]['link_tables'] = {} + modules[mod]['link_tables'][rec['table']] = rec + + for rec in spec['object_types']: + mod, name = get_mod_and_name(rec['name']) + if 'object_types' not in modules[mod]: + modules[mod]['object_types'] = {} + modules[mod]['object_types'][name] = rec + + return modules['default'] + + def replace_foreignkey(self, fval, origtarget, newtarget, bkname=None): + # Replace the reference with the string quoted + # (because we don't check the order of definition) + # name. + fval = fval.replace(origtarget, repr(newtarget)) + + if bkname: + # Add a backlink reference + fval = CLOSEPAR_RE.sub(f', related_name={bkname!r})', fval) + + return fval + + def build_models(self, maps): + modmap = {} + + for name, rec in maps['object_types'].items(): + mod = ModelClass(name) + mod.meta['db_table'] = repr(name) + if 'backlink_renames' in rec: + mod.backlink_renames = rec['backlink_renames'] + + # copy backlink information + for link in rec['backlinks']: + mod.backlinks[link['name']] = link + + # process properties as fields + for prop in rec['properties']: + pname = prop['name'] + if pname == 'id': + continue + + mod.props[pname] = self.render_prop(prop) + + # process single links as fields + for link in rec['links']: + if link['cardinality'] != 'One': + # Multi links require link tables and are handled + # separately. + continue + + lname = link['name'] + bklink = mod.get_backlink_name(lname) + mod.links[lname] = self.render_link(link, bklink) + + modmap[mod.name] = mod + + for table, rec in maps['link_tables'].items(): + source, fwname = table.split('.') + mod = ModelClass(f'{source}{fwname.title()}') + mod.meta['db_table'] = repr(table) + mod.meta['unique_together'] = "(('source', 'target'),)" + + # Only have source and target + _, target = get_mod_and_name(rec['target']) + mod.links['source'] = ( + f"LTForeignKey({source!r}, models.DO_NOTHING, " + f"db_column='source', primary_key=True)" + ) + mod.links['target'] = ( + f"LTForeignKey({target!r}, models.DO_NOTHING, " + f"db_column='target')" + ) + + # Update the source model with the corresponding + # ManyToManyField. + src = modmap[source] + tgt = modmap[target] + bkname = src.get_backlink_name(fwname) + src.mlinks[fwname] = ( + f'models.ManyToManyField(' + f'{tgt.name!r}, ' + f'through={mod.name!r}, ' + f'through_fields=("source", "target"), ' + f'related_name={bkname!r})' + ) + + modmap[mod.name] = mod + + return modmap + + def render_prop(self, prop): + if prop['required']: + req = '' + else: + req = 'blank=True, null=True' + + target = prop['target']['name'] + try: + ftype = GEL_SCALAR_MAP[target] + except KeyError: + raise RuntimeError( + f'Scalar type {target} is not supported') + + return f'models.{ftype}({req})' + + def render_link(self, link, bklink=None): + if link['required']: + req = '' + else: + req = ', blank=True, null=True' + + _, target = get_mod_and_name(link['target']['name']) + + if bklink: + bklink = f', related_name={bklink!r}' + else: + bklink = '' + + return (f'models.ForeignKey(' + f'{target!r}, models.DO_NOTHING{bklink}{req})') + + def render_models(self, spec): + # Check that there is only "default" module + mods = spec['modules'] + if mods[0] != 'default' or len(mods) > 1: + raise RuntimeError( + f"Django reflection doesn't support multiple modules or " + f"non-default modules." + ) + # Check that we don't have multiprops or link properties as they + # produce models without `id` field and Django doesn't like that. It + # causes Django to mistakenly use `source` as `id` and also attempt to + # UPDATE `target` on link tables. + if len(spec['prop_objects']) > 0: + raise RuntimeError( + f"Django reflection doesn't support multi properties as they " + f"produce models without `id` field." + ) + if len(spec['link_objects']) > 0: + raise RuntimeError( + f"Django reflection doesn't support link properties as they " + f"produce models without `id` field." + ) + + maps = self.spec_to_modules_dict(spec) + modmap = self.build_models(maps) + + with open(self.outfile, 'w+t') as f: + self.out = f + self.write(BASE_STUB) + + for mod in modmap.values(): + self.write() + self.write() + self.render_model_class(mod) + + def render_model_class(self, mod): + self.write(f'class {mod.name}(models.Model):') + self.indent() + + if '.' not in mod.table: + # This is only valid for regular objects, not link tables. + self.write(f"id = GelUUIDField(primary_key=True)") + self.write(f"gel_type_id = models.UUIDField(db_column='__type__')") + + if mod.props: + self.write() + self.write(f'# properties as Fields') + for name, val in mod.props.items(): + self.write(f'{name} = {val}') + + if mod.links: + self.write() + self.write(f'# links as ForeignKeys') + for name, val in mod.links.items(): + self.write(f'{name} = {val}') + + if mod.mlinks: + self.write() + self.write(f'# multi links as ManyToManyFields') + for name, val in mod.mlinks.items(): + self.write(f'{name} = {val}') + + if '.' not in mod.table: + self.write(GEL_META) + + self.write('class Meta:') + self.indent() + for name, val in mod.meta.items(): + self.write(f'{name} = {val}') + self.dedent() + + self.dedent() \ No newline at end of file diff --git a/gel/orm/introspection.py b/gel/orm/introspection.py index f19bffd7..7a6fdfeb 100644 --- a/gel/orm/introspection.py +++ b/gel/orm/introspection.py @@ -1,6 +1,7 @@ import json import re import collections +import textwrap INTRO_QUERY = ''' @@ -68,6 +69,17 @@ def get_sql_name(name): return name +def get_mod_and_name(name): + # Assume the names are already validated to be properly formed + # alphanumeric identifiers that may be prefixed by a module. If the module + # is present assume it is safe to drop it (currently only defualt module + # is allowed). + + # Split on module separator. Potentially if we ever handle more unusual + # names, there may be more processing done. + return name.rsplit('::', 1) + + def check_name(name): # Just remove module separators and check the rest name = name.replace('::', '') @@ -232,3 +244,29 @@ def _process_links(types, modules): 'link_objects': link_objects, 'prop_objects': prop_objects, } + + +class FilePrinter(object): + INDENT = ' ' * 4 + + def __init__(self): + # set the output to be stdout by default, but this is generally + # expected to be overridden + self.out = None + self._indent_level = 0 + + def indent(self): + self._indent_level += 1 + + def dedent(self): + if self._indent_level > 0: + self._indent_level -= 1 + + def reset_indent(self): + self._indent_level -= 0 + + def write(self, text=''): + print( + textwrap.indent(text, prefix=self.INDENT * self._indent_level), + file=self.out, + ) diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py index bf18f264..093096b1 100644 --- a/gel/orm/sqla.py +++ b/gel/orm/sqla.py @@ -1,14 +1,12 @@ import pathlib import re -import textwrap from contextlib import contextmanager -from .introspection import get_sql_name +from .introspection import get_sql_name, get_mod_and_name +from .introspection import FilePrinter -INDENT = ' ' * 4 - GEL_SCALAR_MAP = { 'std::bool': ('bool', 'Boolean'), 'std::str': ('str', 'String'), @@ -52,20 +50,7 @@ class Base(DeclarativeBase): ''' -def get_mod_and_name(name): - # Assume the names are already validated to be properly formed - # alphanumeric identifiers that may be prefixed by a module. If the module - # is present assume it is safe to drop it (currently only defualt module - # is allowed). - - # Split on module separator. Potentially if we ever handle more unusual - # names, there may be more processing done. - return name.rsplit('::', 1) - - -class ModelGenerator(object): - INDENT = ' ' * 4 - +class ModelGenerator(FilePrinter): def __init__(self, *, outdir=None, basemodule=None): # set the output to be stdout by default, but this is generally # expected to be overridden by appropriate files in the `outdir` @@ -75,24 +60,7 @@ def __init__(self, *, outdir=None, basemodule=None): self.outdir = None self.basemodule = basemodule - self.out = None - self._indent_level = 0 - - def indent(self): - self._indent_level += 1 - - def dedent(self): - if self._indent_level > 0: - self._indent_level -= 1 - - def reset_indent(self): - self._indent_level -= 0 - - def write(self, text=''): - print( - textwrap.indent(text, prefix=self.INDENT * self._indent_level), - file=self.out, - ) + super().__init__() def init_dir(self, dirpath): if not dirpath: @@ -161,11 +129,7 @@ def get_py_name(self, mod, name, curmod): mod = mod.replace('::', '.') return f"'{self.basemodule}.{mod}.{name}'" - def render_models(self, spec): - # The modules dict will be populated with the respective types, link - # tables, etc., since they will need to be put in their own files. We - # sort the modules so that nested modules are initialized from root to - # leaf. + def spec_to_modules_dict(self, spec): modules = { mod: {} for mod in sorted(spec['modules']) } @@ -194,6 +158,15 @@ def render_models(self, spec): modules[mod]['object_types'] = {} modules[mod]['object_types'][name] = rec + return modules + + def render_models(self, spec): + # The modules dict will be populated with the respective types, link + # tables, etc., since they will need to be put in their own files. We + # sort the modules so that nested modules are initialized from root to + # leaf. + modules = self.spec_to_modules_dict(spec) + # Initialize the base directory self.init_dir(self.outdir) self.init_sqlabase() diff --git a/setup.py b/setup.py index 447c3a45..b5bb73fd 100644 --- a/setup.py +++ b/setup.py @@ -50,14 +50,14 @@ 'flake8-bugbear~=24.4.26', 'flake8~=7.0.0', 'uvloop>=0.15.1; platform_system != "Windows"', - 'SQLAlchemy>=2.0.0', ] -# This is needed specifically to test ORM reflection because the ORMs tend to -# use this library to access Postgres. It's not always avaialable as a -# pre-built package and we don't necessarily want to try and build it from -# source. +# The ORMs and the SQL libraries they rely on may not be avaialble for older +# Python versions. That's OK for the overall client build, though and we only +# want to test them for the versions where they are avaialable. SQLTEST_DEPENDENCIES = [ + 'SQLAlchemy>=2.0.0', + 'Django>=5.1.3', 'psycopg2-binary>=2.9.10', ] diff --git a/tests/dbsetup/base.edgeql b/tests/dbsetup/base.edgeql index 74959e71..cfbf46e8 100644 --- a/tests/dbsetup/base.edgeql +++ b/tests/dbsetup/base.edgeql @@ -7,7 +7,7 @@ insert User {name := 'Zoe'}; insert UserGroup { name := 'red', - users := (select User filter .name != 'Zoe'), + users := (select User filter .name not in {'Elsa', 'Zoe'}), }; insert UserGroup { name := 'green', diff --git a/tests/test_django_basic.py b/tests/test_django_basic.py new file mode 100644 index 00000000..6ecdc382 --- /dev/null +++ b/tests/test_django_basic.py @@ -0,0 +1,494 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import uuid +import unittest + +try: + import django + from django.db import transaction +except ImportError: + NO_ORM = True +else: + NO_ORM = False + +from gel import _testbase as tb + + +class TestDjangoBasic(tb.DjangoTestCase): + SCHEMA = os.path.join(os.path.dirname(__file__), 'dbsetup', + 'base.esdl') + + SETUP = os.path.join(os.path.dirname(__file__), 'dbsetup', + 'base.edgeql') + + MODEL_PACKAGE = 'djangobase' + + @classmethod + def setUpClass(cls): + if NO_ORM: + raise unittest.SkipTest("django is not installed") + + super().setUpClass() + + from django.conf import settings + from djangobase.settings import mysettings + + settings.configure(**mysettings) + django.setup() + + from djangobase import models + cls.m = models + transaction.set_autocommit(False) + + def setUp(self): + super().setUp() + + if self.client.query_required_single(''' + select sys::get_version().major < 6 + '''): + self.skipTest("Test needs SQL DML queries") + + transaction.savepoint() + + def tearDown(self): + super().tearDown() + transaction.rollback() + + def test_django_read_models_01(self): + vals = {r.name for r in self.m.User.objects.all()} + self.assertEqual( + vals, {'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa', 'Zoe'}) + + vals = {r.name for r in self.m.UserGroup.objects.all()} + self.assertEqual( + vals, {'red', 'green', 'blue'}) + + vals = {r.num for r in self.m.GameSession.objects.all()} + self.assertEqual(vals, {123, 456}) + + vals = {r.body for r in self.m.Post.objects.all()} + self.assertEqual( + vals, {'Hello', "I'm Alice", "I'm Cameron", '*magic stuff*'}) + + # Read from the abstract type + vals = {r.name for r in self.m.Named.objects.all()} + self.assertEqual( + vals, + { + 'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa', 'Zoe', + 'red', 'green', 'blue', + } + ) + + def test_django_read_models_02(self): + # test single link and the one-to-many backlink + # using load-on-demand + + res = self.m.Post.objects.all() + vals = {(p.author.name, p.body) for p in res} + self.assertEqual( + vals, + { + ('Alice', 'Hello'), + ('Alice', "I'm Alice"), + ('Cameron', "I'm Cameron"), + ('Elsa', '*magic stuff*'), + } + ) + + # use backlink + res = self.m.User.objects.order_by('name').all() + vals = [ + (u.name, {p.body for p in u.backlink_via_author.all()}) + for u in res + ] + self.assertEqual( + vals, + [ + ('Alice', {'Hello', "I'm Alice"}), + ('Billie', set()), + ('Cameron', {"I'm Cameron"}), + ('Dana', set()), + ('Elsa', {'*magic stuff*'}), + ('Zoe', set()), + ] + ) + + def test_django_read_models_03(self): + # test single link and the one-to-many backlink + + res = self.m.Post.objects.select_related('author') + vals = {(p.author.name, p.body) for p in res} + self.assertEqual( + vals, + { + ('Alice', 'Hello'), + ('Alice', "I'm Alice"), + ('Cameron', "I'm Cameron"), + ('Elsa', '*magic stuff*'), + } + ) + + # prefetch via backlink + res = self.m.User.objects.prefetch_related('backlink_via_author') \ + .order_by('backlink_via_author__body') + vals = { + (u.name, tuple(p.body for p in u.backlink_via_author.all())) + for u in res + } + self.assertEqual( + vals, + { + ('Alice', ('Hello', "I'm Alice")), + ('Billie', ()), + ('Cameron', ("I'm Cameron",)), + ('Dana', ()), + ('Elsa', ('*magic stuff*',)), + ('Zoe', ()), + } + ) + + def test_django_read_models_04(self): + # test exclusive multi link and its backlink + # using load-on-demand + + res = self.m.GameSession.objects.order_by('num').all() + vals = [(g.num, {u.name for u in g.players.all()}) for g in res] + self.assertEqual( + vals, + [ + (123, {'Alice', 'Billie'}), + (456, {'Dana'}), + ] + ) + + # use backlink + res = self.m.User.objects.all() + vals = { + (u.name, tuple(g.num for g in u.backlink_via_players.all())) + for u in res + } + self.assertEqual( + vals, + { + ('Alice', (123,)), + ('Billie', (123,)), + ('Cameron', ()), + ('Dana', (456,)), + ('Elsa', ()), + ('Zoe', ()), + } + ) + + def test_django_read_models_05(self): + # test exclusive multi link and its backlink + + res = self.m.GameSession.objects.prefetch_related('players') + vals = { + (g.num, tuple(sorted(u.name for u in g.players.all()))) + for g in res + } + self.assertEqual( + vals, + { + (123, ('Alice', 'Billie')), + (456, ('Dana',)), + } + ) + + # prefetch via backlink + res = self.m.User.objects.prefetch_related('backlink_via_players') + vals = { + (u.name, tuple(g.num for g in u.backlink_via_players.all())) + for u in res + } + self.assertEqual( + vals, + { + ('Alice', (123,)), + ('Billie', (123,)), + ('Cameron', ()), + ('Dana', (456,)), + ('Elsa', ()), + ('Zoe', ()), + } + ) + + def test_django_read_models_06(self): + # test multi link and its backlink + # using load-on-demand + + res = self.m.UserGroup.objects.order_by('name').all() + vals = [(g.name, {u.name for u in g.users.all()}) for g in res] + self.assertEqual( + vals, + [ + ('blue', set()), + ('green', {'Alice', 'Billie'}), + ('red', {'Alice', 'Billie', 'Cameron', 'Dana'}), + ] + ) + + # use backlink + res = self.m.User.objects.order_by('name').all() + vals = [ + (u.name, {g.name for g in u.backlink_via_users.all()}) + for u in res + ] + self.assertEqual( + vals, + [ + ('Alice', {'red', 'green'}), + ('Billie', {'red', 'green'}), + ('Cameron', {'red'}), + ('Dana', {'red'}), + ('Elsa', set()), + ('Zoe', set()), + ] + ) + + def test_django_read_models_07(self): + # test exclusive multi link and its backlink + + res = self.m.UserGroup.objects.prefetch_related('users') + vals = { + (g.name, tuple(sorted(u.name for u in g.users.all()))) + for g in res + } + self.assertEqual( + vals, + { + ('blue', ()), + ('green', ('Alice', 'Billie')), + ('red', ('Alice', 'Billie', 'Cameron', 'Dana')), + } + ) + + # prefetch via backlink + res = self.m.User.objects.prefetch_related('backlink_via_users') + vals = { + (u.name, tuple(sorted(g.name for g in u.backlink_via_users.all()))) + for u in res + } + self.assertEqual( + vals, + { + ('Alice', ('green', 'red')), + ('Billie', ('green', 'red')), + ('Cameron', ('red',)), + ('Dana', ('red',)), + ('Elsa', ()), + ('Zoe', ()), + } + ) + + def test_django_create_models_01(self): + vals = self.m.User.objects.filter(name='Yvonne').all() + self.assertEqual(list(vals), []) + + user = self.m.User(name='Yvonne') + user.save() + + self.assertEqual(user.name, 'Yvonne') + self.assertIsInstance(user.id, uuid.UUID) + + def test_django_create_models_02(self): + x = self.m.User(name='Xander') + y = self.m.User(name='Yvonne') + cyan = self.m.UserGroup(name='cyan') + + x.save() + y.save() + cyan.save() + cyan.users.set([x, y]) + + for name in ['Yvonne', 'Xander']: + user = self.m.User.objects.get(name=name) + + self.assertEqual(user.name, name) + self.assertEqual(user.backlink_via_users.all()[0].name, 'cyan') + self.assertIsInstance(user.id, uuid.UUID) + + def test_django_create_models_03(self): + x = self.m.User(name='Xander') + y = self.m.User(name='Yvonne') + cyan = self.m.UserGroup(name='cyan') + + x.save() + y.save() + cyan.save() + + x.backlink_via_users.add(cyan) + y.backlink_via_users.add(cyan) + + group = self.m.UserGroup.objects.get(name='cyan') + self.assertEqual(group.name, 'cyan') + self.assertEqual( + {u.name for u in group.users.all()}, + {'Xander', 'Yvonne'}, + ) + + def test_django_create_models_04(self): + user = self.m.User(name='Yvonne') + user.save() + self.m.Post(body='this is a test', author=user).save() + self.m.Post(body='also a test', author=user).save() + + res = self.m.Post.objects.select_related('author') \ + .filter(author__name='Yvonne') + self.assertEqual( + {p.body for p in res}, + {'this is a test', 'also a test'}, + ) + + def test_django_delete_models_01(self): + user = self.m.User.objects.get(name='Zoe') + self.assertEqual(user.name, 'Zoe') + self.assertIsInstance(user.id, uuid.UUID) + + user.delete() + + vals = self.m.User.objects.filter(name='Zoe').all() + self.assertEqual(list(vals), []) + + def test_django_delete_models_02(self): + post = self.m.Post.objects.select_related('author') \ + .get(author__name='Elsa') + user_id = post.author.id + + post.delete() + + vals = self.m.Post.objects.select_related('author') \ + .filter(author__name='Elsa') + self.assertEqual(list(vals), []) + + user = self.m.User.objects.get(id=user_id) + self.assertEqual(user.name, 'Elsa') + + def test_django_delete_models_03(self): + post = self.m.Post.objects.select_related('author') \ + .get(author__name='Elsa') + user = post.author + + post.delete() + user.delete() + + vals = self.m.Post.objects.select_related('author') \ + .filter(author__name='Elsa') + self.assertEqual(list(vals), []) + + vals = self.m.User.objects.filter(name='Elsa') + self.assertEqual(list(vals), []) + + def test_django_delete_models_04(self): + group = self.m.UserGroup.objects.get(name='green') + names = {u.name for u in group.users.all()} + + group.delete() + + vals = self.m.UserGroup.objects.filter(name='green').all() + self.assertEqual(list(vals), []) + + users = self.m.User.objects.all() + for name in names: + self.assertIn(name, {u.name for u in users}) + + def test_django_delete_models_05(self): + group = self.m.UserGroup.objects.get(name='green') + for u in group.users.all(): + if u.name == 'Billie': + user = u + break + + group.delete() + # make sure the user object is no longer a link target + user.backlink_via_users.clear() + user.backlink_via_players.clear() + user.delete() + + vals = self.m.UserGroup.objects.filter(name='green').all() + self.assertEqual(list(vals), []) + + users = self.m.User.objects.all() + self.assertNotIn('Billie', {u.name for u in users}) + + def test_django_update_models_01(self): + user = self.m.User.objects.get(name='Alice') + self.assertEqual(user.name, 'Alice') + self.assertIsInstance(user.id, uuid.UUID) + + user.name = 'Xander' + user.save() + + vals = self.m.User.objects.filter(name='Alice').all() + self.assertEqual(list(vals), []) + other = self.m.User.objects.get(name='Xander') + self.assertEqual(user, other) + + def test_django_update_models_02(self): + red = self.m.UserGroup.objects.get(name='red') + blue = self.m.UserGroup.objects.get(name='blue') + user = self.m.User(name='Yvonne') + + user.save() + red.users.add(user) + blue.users.add(user) + + self.assertEqual( + {g.name for g in user.backlink_via_users.all()}, + {'red', 'blue'}, + ) + self.assertEqual(user.name, 'Yvonne') + self.assertIsInstance(user.id, uuid.UUID) + + group = [g for g in user.backlink_via_users.all() + if g.name == 'red'][0] + self.assertEqual( + {u.name for u in group.users.all()}, + {'Alice', 'Billie', 'Cameron', 'Dana', 'Yvonne'}, + ) + + def test_django_update_models_03(self): + user0 = self.m.User.objects.get(name='Elsa') + user1 = self.m.User.objects.get(name='Zoe') + # Replace the author or a post + post = user0.backlink_via_author.all()[0] + body = post.body + post.author = user1 + post.save() + + res = self.m.Post.objects.select_related('author') \ + .filter(author__name='Zoe') + self.assertEqual( + {p.body for p in res}, + {body}, + ) + + def test_django_update_models_04(self): + user = self.m.User.objects.get(name='Zoe') + post = self.m.Post.objects.select_related('author') \ + .get(author__name='Elsa') + # Replace the author or a post + post_id = post.id + post.author = user + post.save() + + post = self.m.Post.objects.get(id=post_id) + self.assertEqual(post.author.name, 'Zoe') diff --git a/tests/test_sqla_basic.py b/tests/test_sqla_basic.py index a9b274a5..c5711a8f 100644 --- a/tests/test_sqla_basic.py +++ b/tests/test_sqla_basic.py @@ -18,9 +18,15 @@ import os import uuid +import unittest -from sqlalchemy import create_engine, select -from sqlalchemy.orm import Session +try: + from sqlalchemy import create_engine, select + from sqlalchemy.orm import Session +except ImportError: + NO_ORM = True +else: + NO_ORM = False from gel import _testbase as tb @@ -32,10 +38,13 @@ class TestSQLABasic(tb.SQLATestCase): SETUP = os.path.join(os.path.dirname(__file__), 'dbsetup', 'base.edgeql') - SQLAPACKAGE = 'basemodels' + MODEL_PACKAGE = 'basemodels' @classmethod def setUpClass(cls): + if NO_ORM: + raise unittest.SkipTest("sqlalchemy is not installed") + super().setUpClass() cls.engine = create_engine(cls.get_dsn_for_sqla()) cls.sess = Session(cls.engine, autobegin=False) @@ -269,7 +278,7 @@ def test_sqla_read_models_06(self): [ ('blue', set()), ('green', {'Alice', 'Billie'}), - ('red', {'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa'}), + ('red', {'Alice', 'Billie', 'Cameron', 'Dana'}), ] ) @@ -286,7 +295,7 @@ def test_sqla_read_models_06(self): ('Billie', {'red', 'green'}), ('Cameron', {'red'}), ('Dana', {'red'}), - ('Elsa', {'red'}), + ('Elsa', set()), ('Zoe', set()), ] ) @@ -310,7 +319,7 @@ def test_sqla_read_models_07(self): { ('blue', ()), ('green', ('Alice', 'Billie')), - ('red', ('Alice', 'Billie', 'Cameron', 'Dana', 'Elsa')), + ('red', ('Alice', 'Billie', 'Cameron', 'Dana')), } ) @@ -330,7 +339,7 @@ def test_sqla_read_models_07(self): ('Billie', ('green', 'red')), ('Cameron', ('red',)), ('Dana', ('red',)), - ('Elsa', ('red',)), + ('Elsa', ()), ('Zoe', ()), } ) @@ -528,7 +537,7 @@ def test_sqla_update_models_02(self): group = [g for g in user.backlink_via_users if g.name == 'red'][0] self.assertEqual( {u.name for u in group.users}, - {'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa', 'Yvonne'}, + {'Alice', 'Billie', 'Cameron', 'Dana', 'Yvonne'}, ) def test_sqla_update_models_03(self): diff --git a/tests/test_sqla_features.py b/tests/test_sqla_features.py index 4c1e42e6..7b79b818 100644 --- a/tests/test_sqla_features.py +++ b/tests/test_sqla_features.py @@ -17,9 +17,15 @@ # import os +import unittest -from sqlalchemy import create_engine -from sqlalchemy.orm import Session +try: + from sqlalchemy import create_engine + from sqlalchemy.orm import Session +except ImportError: + NO_ORM = True +else: + NO_ORM = False from gel import _testbase as tb @@ -37,10 +43,13 @@ class TestSQLAFeatures(tb.SQLATestCase): SETUP = os.path.join(os.path.dirname(__file__), 'dbsetup', 'features.edgeql') - SQLAPACKAGE = 'fmodels' + MODEL_PACKAGE = 'fmodels' @classmethod def setUpClass(cls): + if NO_ORM: + raise unittest.SkipTest("sqlalchemy is not installed") + super().setUpClass() cls.engine = create_engine(cls.get_dsn_for_sqla()) cls.sess = Session(cls.engine, autobegin=False)