Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove usage of ORM from migrations 228 and 229 #158

176 changes: 122 additions & 54 deletions threedi_schema/migrations/versions/0228_upgrade_db_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

import sqlalchemy as sa
from alembic import op
from sqlalchemy import Column, Float, func, Integer, select, String
from sqlalchemy import Column, Float, func, Integer, select, String, Text
from sqlalchemy.orm import declarative_base, Session

from threedi_schema.domain import constants, models
from threedi_schema.domain.custom_types import IntegerEnum
from threedi_schema.domain import constants
from threedi_schema.domain.custom_types import Geometry, IntegerEnum
from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table

Base = declarative_base()
Expand Down Expand Up @@ -84,10 +84,45 @@
"pump": ["connection_node_end_id", "zoom_category", "classification"]
}

ADD_COLUMNS = [
("channel", Column("tags", Text)),
("cross_section_location", Column("tags", Text)),
("culvert", Column("tags", Text)),
("culvert", Column("material_id", Integer)),
("orifice", Column("tags", Text)),
("orifice", Column("material_id", Integer)),
# ("orifice", Column("geom", Geometry("LINESTRING"), nullable=False)),
# ("pipe", Column("geom", Geometry("LINESTRING"), nullable=False)),
("pipe", Column("tags", Text)),
("pump", Column("tags", Text)),
("weir", Column("tags", Text)),
("weir", Column("material_id", Integer)),
# ("weir", Column("geom", Geometry("LINESTRING"), nullable=False)),
margrietpalm marked this conversation as resolved.
Show resolved Hide resolved
("windshielding_1d", Column("tags", Text)),
]

RETYPE_COLUMNS = {}


def add_columns_to_tables(table_columns: List[Tuple[str, Column]]):
# no checks for existence are done, this will fail if any column already exists
for dst_table, col in table_columns:
if isinstance(col.type, Geometry):
add_geometry_column(dst_table, col)
else:
with op.batch_alter_table(dst_table) as batch_op:
batch_op.add_column(col)


def add_geometry_column(table: str, geocol: Column):
# Adding geometry columns via alembic doesn't work
# https://postgis.net/docs/AddGeometryColumn.html
geotype = geocol.type
query = (
f"SELECT AddGeometryColumn('{table}', '{geocol.name}', {geotype.srid}, '{geotype.geometry_type}', 'XY', 1);")
op.execute(sa.text(query))


class Schema228UpgradeException(Exception):
pass

Expand All @@ -104,42 +139,69 @@ def remove_tables(tables: List[str]):
drop_geo_table(op, table)


def get_geom_type(table_name, geo_col_name):
connection = op.get_bind()
columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall()
for col in columns:
if col[1] == geo_col_name:
return col[2]

def modify_table(old_table_name, new_table_name):
# Create a new table named `new_table_name` using the declared models
# Create a new table named `new_table_name` by copying the
# data from `old_table_name`.
# Use the columns from `old_table_name`, with the following exceptions:
# * columns in `REMOVE_COLUMNS[new_table_name]` are skipped
# * columns in `RENAME_COLUMNS[new_table_name]` are renamed
# * columns in `RETYPE_COLUMNS[new_table_name]` change type
# * `the_geom` is renamed to `geom` and NOT NULL is enforced
model = find_model(new_table_name)
# create new table
create_sqlite_table_from_model(model)
# get column names from model and match them to available data in sqlite
connection = op.get_bind()
rename_cols = {**RENAME_COLUMNS.get(new_table_name, {}), "the_geom": "geom"}
rename_cols_rev = {v: k for k, v in rename_cols.items()}
col_map = [(col.name, rename_cols_rev.get(col.name, col.name)) for col in get_cols_for_model(model)]
available_cols = [col[1] for col in connection.execute(sa.text(f"PRAGMA table_info('{old_table_name}')")).fetchall()]
new_col_names, old_col_names = zip(*[(new_col, old_col) for new_col, old_col in col_map if old_col in available_cols])
columns = connection.execute(sa.text(f"PRAGMA table_info('{old_table_name}')")).fetchall()
# get all column names and types
col_names = [col[1] for col in columns]
col_types = [col[2] for col in columns]
# get type of the geometry column
geom_type = get_geom_type(old_table_name, 'the_geom')
# create list of new columns and types for creating the new table
# create list of old columns to copy to new table
skip_cols = ['id', 'the_geom']
if new_table_name in REMOVE_COLUMNS:
skip_cols += REMOVE_COLUMNS[new_table_name]
old_col_names = []
new_col_names = []
new_col_types = []
for cname, ctype in zip(col_names, col_types):
if cname in skip_cols:
continue
old_col_names.append(cname)
if new_table_name in RENAME_COLUMNS and cname in RENAME_COLUMNS[new_table_name]:
new_col_names.append(RENAME_COLUMNS[new_table_name][cname])
else:
new_col_names.append(cname)
if new_table_name in RETYPE_COLUMNS and cname in RETYPE_COLUMNS[new_table_name]:
new_col_types.append(RETYPE_COLUMNS[new_table_name][cname])
else:
new_col_types.append(ctype)
# add to the end manually
old_col_names.append('the_geom')
new_col_names.append('geom')
new_col_types.append(f'{geom_type} NOT NULL')
# Create new table (temp), insert data, drop original and rename temp to table_name
new_col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in
zip(new_col_names, new_col_types)])
op.execute(sa.text(f"CREATE TABLE {new_table_name} ({new_col_str});"))
# Copy data
# This may copy wrong type data because some types change!!
op.execute(sa.text(f"INSERT INTO {new_table_name} ({','.join(new_col_names)}) "
f"SELECT {','.join(old_col_names)} FROM {old_table_name}"))

op.execute(sa.text(f"INSERT INTO {new_table_name} (id, {','.join(new_col_names)}) "
f"SELECT id, {','.join(old_col_names)} FROM {old_table_name}"))

def find_model(table_name):
for model in models.DECLARED_MODELS:
if model.__tablename__ == table_name:
return model
# This can only go wrong if the migration or model is incorrect
raise

def fix_geometry_columns():
update_models = [models.Channel, models.ConnectionNode, models.CrossSectionLocation,
models.Culvert, models.Orifice, models.Pipe, models.Pump,
models.PumpMap, models.Weir, models.Windshielding]
for model in update_models:
op.execute(sa.text(f"SELECT RecoverGeometryColumn('{model.__tablename__}', "
f"'geom', {4326}, '{model.geom.type.geometry_type}', 'XY')"))
op.execute(sa.text(f"SELECT CreateSpatialIndex('{model.__tablename__}', 'geom')"))
tables = ['channel', 'connection_node', 'cross_section_location', 'culvert',
'orifice', 'pipe', 'pump', 'pump_map', 'weir', 'windshielding_1d']
for table in tables:
geom_type = get_geom_type(table, geo_col_name='geom')
op.execute(sa.text(f"SELECT RecoverGeometryColumn('{table}', "
f"'geom', {4326}, '{geom_type}', 'XY')"))
op.execute(sa.text(f"SELECT CreateSpatialIndex('{table}', 'geom')"))


class Temp(Base):
Expand Down Expand Up @@ -305,29 +367,16 @@ def set_geom_for_v2_pumpstation():
op.execute(sa.text(q))


def get_cols_for_model(model, skip_cols=None):
from sqlalchemy.orm.attributes import InstrumentedAttribute
if skip_cols is None:
skip_cols = []
return [getattr(model, item) for item in model.__dict__
if item not in skip_cols
and isinstance(getattr(model, item), InstrumentedAttribute)]


def create_sqlite_table_from_model(model):
cols = get_cols_for_model(model, skip_cols = ["id", "geom"])
op.execute(sa.text(f"""
CREATE TABLE {model.__tablename__} (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
{','.join(f"{col.name} {col.type}" for col in cols)},
geom {model.geom.type.geometry_type} NOT NULL
);"""))


def create_pump_map():
# Create table
create_sqlite_table_from_model(models.PumpMap)

query = """
CREATE TABLE pump_map (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
pump_id INTEGER,connection_node_id_end INTEGER,tags TEXT,code VARCHAR(100),display_name VARCHAR(255),
geom LINESTRING NOT NULL
);
"""
op.execute(sa.text(query))
# Create geometry
op.execute(sa.text(f"SELECT AddGeometryColumn('v2_pumpstation', 'map_geom', 4326, 'LINESTRING', 'XY', 0);"))
op.execute(sa.text("""
Expand Down Expand Up @@ -358,7 +407,15 @@ def create_pump_map():


def create_connection_node():
create_sqlite_table_from_model(models.ConnectionNode)
# Create table
query = """
CREATE TABLE connection_node (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
code VARCHAR(100),tags TEXT,display_name TEXT,storage_area FLOAT,initial_water_level FLOAT,visualisation INTEGER,manhole_surface_level FLOAT,bottom_level FLOAT,exchange_level FLOAT,exchange_type INTEGER,exchange_thickness FLOAT,hydraulic_conductivity_in FLOAT,hydraulic_conductivity_out FLOAT,
geom POINT NOT NULL
);
"""
op.execute(sa.text(query))
# copy from v2_connection_nodes
old_col_names = ["id", "initial_waterlevel", "storage_area", "the_geom", "code"]
rename_map = {"initial_waterlevel": "initial_water_level", "the_geom": "geom"}
Expand Down Expand Up @@ -389,6 +446,15 @@ def create_connection_node():
"""))


# define Material class needed to populate table in create_material
class Material(Base):
__tablename__ = "material"
id = Column(Integer, primary_key=True)
description = Column(Text)
friction_type = Column(IntegerEnum(constants.FrictionType))
friction_coefficient = Column(Float)


def create_material():
op.execute(sa.text("""
CREATE TABLE material (
Expand All @@ -397,12 +463,13 @@ def create_material():
friction_type INTEGER,
friction_coefficient REAL);
"""))
connection = op.get_bind()
nof_settings = connection.execute(sa.text("SELECT COUNT(*) FROM model_settings")).scalar()
session = Session(bind=op.get_bind())
nof_settings = session.execute(select(func.count()).select_from(models.ModelSettings)).scalar()
if nof_settings > 0:
with open(data_dir.joinpath('0228_materials.csv')) as file:
reader = csv.DictReader(file)
session.bulk_save_objects([models.Material(**row) for row in reader])
session.bulk_save_objects([Material(**row) for row in reader])
session.commit()


Expand Down Expand Up @@ -468,6 +535,7 @@ def upgrade():
set_geom_for_v2_pumpstation()
for old_table_name, new_table_name in RENAME_TABLES:
modify_table(old_table_name, new_table_name)
add_columns_to_tables(ADD_COLUMNS)
# Create new tables
create_pump_map()
create_material()
Expand Down
89 changes: 34 additions & 55 deletions threedi_schema/migrations/versions/0229_clean_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,38 @@
import sqlalchemy as sa
from alembic import op

from threedi_schema import models

# revision identifiers, used by Alembic.
revision = "0229"
down_revision = "0228"
branch_labels = None
depends_on = None


def find_model(table_name):
for model in models.DECLARED_MODELS:
if model.__tablename__ == table_name:
return model
# This can only go wrong if the migration or model is incorrect
raise


def create_sqlite_table_from_model(model, table_name):
cols = get_cols_for_model(model, skip_cols=["id"])
op.execute(sa.text(f"""
CREATE TABLE {table_name} (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
{','.join(f"{col.name} {col.type}" for col in cols)}
);"""))


def get_cols_for_model(model, skip_cols=None):
from sqlalchemy.orm.attributes import InstrumentedAttribute
if skip_cols is None:
skip_cols = []
return [getattr(model, item) for item in model.__dict__
if item not in skip_cols
and isinstance(getattr(model, item), InstrumentedAttribute)]

def get_geom_type(table_name, geo_col_name):
connection = op.get_bind()
columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall()
for col in columns:
if col[1] == geo_col_name:
return col[2]

def sync_orm_types_to_sqlite(table_name):
def change_types_in_settings_table():
temp_table_name = f'_temp_229_{uuid.uuid4().hex}'
model = find_model(table_name)
create_sqlite_table_from_model(model, temp_table_name)
col_names = [col.name for col in get_cols_for_model(model)]
# This may copy wrong type data because some types change!!
op.execute(sa.text(f"INSERT INTO {temp_table_name} ({','.join(col_names)}) "
f"SELECT {','.join(col_names)} FROM {table_name}"))
table_name = 'model_settings'
change_types = {'use_d2_rain': 'bool', 'friction_averaging': 'bool'}
connection = op.get_bind()
columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall()
# get all column names and types
skip_cols = ['id', 'the_geom']
col_names = [col[1] for col in columns if col[1] not in skip_cols]
old_col_types = [col[2] for col in columns if col[1] not in skip_cols]
col_types = [change_types.get(col_name, col_type) for col_name, col_type in zip(col_names, old_col_types)]
# Create new table, insert data, drop original and rename temp to table_name
col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in
zip(col_names, col_types)])
op.execute(sa.text(f"CREATE TABLE {temp_table_name} ({col_str});"))
# Copy data
op.execute(sa.text(f"INSERT INTO {temp_table_name} (id, {','.join(col_names)}) "
f"SELECT id, {','.join(col_names)} FROM {table_name}"))
op.execute(sa.text(f"DROP TABLE {table_name}"))
op.execute(sa.text(f"ALTER TABLE {temp_table_name} RENAME TO {table_name};"))

Expand Down Expand Up @@ -98,33 +87,24 @@ def clean_by_type(type: str):
def update_use_settings():
# Ensure that use_* settings are only True when there is actual data for them
use_settings = [
(models.ModelSettings.use_groundwater_storage, models.GroundWater),
(models.ModelSettings.use_groundwater_flow, models.GroundWater),
(models.ModelSettings.use_interflow, models.Interflow),
(models.ModelSettings.use_simple_infiltration, models.SimpleInfiltration),
(models.ModelSettings.use_vegetation_drag_2d, models.VegetationDrag),
(models.ModelSettings.use_interception, models.Interception)
('use_groundwater_storage', 'groundwater'),
('use_groundwater_flow', 'groundwater'),
('use_interflow', 'interflow'),
('use_simple_infiltration', 'simple_infiltration'),
('use_vegetation_drag_2d', 'vegetation_drag_2d'),
('use_interception', 'interception')
]
connection = op.get_bind() # Get the connection for raw SQL execution
for setting, table in use_settings:
use_row = connection.execute(
sa.select(getattr(models.ModelSettings, setting.name))
).scalar()
use_row = connection.execute(sa.text(f"SELECT {setting} FROM model_settings")).scalar()
if not use_row:
continue
row = connection.execute(sa.select(table)).first()
row = connection.execute(sa.text(f"SELECT * FROM {table}")).first()
use_row = (row is not None)
if use_row:
use_row = not all(
getattr(row, column.name) in (None, "")
for column in table.__table__.columns
if column.name != "id"
)
use_row = not all(item in (None, "") for item in row[1:])
if not use_row:
connection.execute(
sa.update(models.ModelSettings)
.values({setting.name: False})
)
connection.execute(sa.text(f"UPDATE model_settings SET {setting} = 0"))


def upgrade():
Expand All @@ -133,8 +113,7 @@ def upgrade():
clean_by_type('triggers')
clean_by_type('views')
update_use_settings()
# Apply changing use_2d_rain and friction_averaging type to bool
sync_orm_types_to_sqlite('model_settings')
change_types_in_settings_table()


def downgrade():
Expand Down
6 changes: 5 additions & 1 deletion threedi_schema/tests/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ def get_columns_from_sqlite(cursor, table_name):
for c in cursor.fetchall():
if 'geom' in c[1]:
continue
type_str = c[2].lower() if c[2] != 'bool' else 'boolean'
type_str = c[2].lower()
if type_str == 'bool':
type_str = 'boolean'
if type_str == 'int':
type_str = 'integer'
col_map[c[1]] = (type_str, not c[3])
return col_map

Expand Down
Loading