diff --git a/threedi_schema/migrations/versions/0228_upgrade_db_1D.py b/threedi_schema/migrations/versions/0228_upgrade_db_1D.py index 61a332a..04e71ff 100644 --- a/threedi_schema/migrations/versions/0228_upgrade_db_1D.py +++ b/threedi_schema/migrations/versions/0228_upgrade_db_1D.py @@ -12,11 +12,11 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy import Column, Float, func, Integer, select, String +from sqlalchemy import Column, Float, func, Integer, select, String, Text from sqlalchemy.orm import declarative_base, Session -from threedi_schema.domain import constants, models -from threedi_schema.domain.custom_types import IntegerEnum +from threedi_schema.domain import constants +from threedi_schema.domain.custom_types import Geometry, IntegerEnum from threedi_schema.migrations.utils import drop_conflicting, drop_geo_table Base = declarative_base() @@ -84,10 +84,42 @@ "pump": ["connection_node_end_id", "zoom_category", "classification"] } +ADD_COLUMNS = [ + ("channel", Column("tags", Text)), + ("cross_section_location", Column("tags", Text)), + ("culvert", Column("tags", Text)), + ("culvert", Column("material_id", Integer)), + ("orifice", Column("tags", Text)), + ("orifice", Column("material_id", Integer)), + ("pipe", Column("tags", Text)), + ("pump", Column("tags", Text)), + ("weir", Column("tags", Text)), + ("weir", Column("material_id", Integer)), + ("windshielding_1d", Column("tags", Text)), +] RETYPE_COLUMNS = {} +def add_columns_to_tables(table_columns: List[Tuple[str, Column]]): + # no checks for existence are done, this will fail if any column already exists + for dst_table, col in table_columns: + if isinstance(col.type, Geometry): + add_geometry_column(dst_table, col) + else: + with op.batch_alter_table(dst_table) as batch_op: + batch_op.add_column(col) + + +def add_geometry_column(table: str, geocol: Column): + # Adding geometry columns via alembic doesn't work + # https://postgis.net/docs/AddGeometryColumn.html + geotype = geocol.type + query = ( + f"SELECT AddGeometryColumn('{table}', '{geocol.name}', {geotype.srid}, '{geotype.geometry_type}', 'XY', 1);") + op.execute(sa.text(query)) + + class Schema228UpgradeException(Exception): pass @@ -104,42 +136,69 @@ def remove_tables(tables: List[str]): drop_geo_table(op, table) +def get_geom_type(table_name, geo_col_name): + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + for col in columns: + if col[1] == geo_col_name: + return col[2] + def modify_table(old_table_name, new_table_name): - # Create a new table named `new_table_name` using the declared models + # Create a new table named `new_table_name` by copying the + # data from `old_table_name`. # Use the columns from `old_table_name`, with the following exceptions: + # * columns in `REMOVE_COLUMNS[new_table_name]` are skipped # * columns in `RENAME_COLUMNS[new_table_name]` are renamed + # * columns in `RETYPE_COLUMNS[new_table_name]` change type # * `the_geom` is renamed to `geom` and NOT NULL is enforced - model = find_model(new_table_name) - # create new table - create_sqlite_table_from_model(model) - # get column names from model and match them to available data in sqlite connection = op.get_bind() - rename_cols = {**RENAME_COLUMNS.get(new_table_name, {}), "the_geom": "geom"} - rename_cols_rev = {v: k for k, v in rename_cols.items()} - col_map = [(col.name, rename_cols_rev.get(col.name, col.name)) for col in get_cols_for_model(model)] - available_cols = [col[1] for col in connection.execute(sa.text(f"PRAGMA table_info('{old_table_name}')")).fetchall()] - new_col_names, old_col_names = zip(*[(new_col, old_col) for new_col, old_col in col_map if old_col in available_cols]) + columns = connection.execute(sa.text(f"PRAGMA table_info('{old_table_name}')")).fetchall() + # get all column names and types + col_names = [col[1] for col in columns] + col_types = [col[2] for col in columns] + # get type of the geometry column + geom_type = get_geom_type(old_table_name, 'the_geom') + # create list of new columns and types for creating the new table + # create list of old columns to copy to new table + skip_cols = ['id', 'the_geom'] + if new_table_name in REMOVE_COLUMNS: + skip_cols += REMOVE_COLUMNS[new_table_name] + old_col_names = [] + new_col_names = [] + new_col_types = [] + for cname, ctype in zip(col_names, col_types): + if cname in skip_cols: + continue + old_col_names.append(cname) + if new_table_name in RENAME_COLUMNS and cname in RENAME_COLUMNS[new_table_name]: + new_col_names.append(RENAME_COLUMNS[new_table_name][cname]) + else: + new_col_names.append(cname) + if new_table_name in RETYPE_COLUMNS and cname in RETYPE_COLUMNS[new_table_name]: + new_col_types.append(RETYPE_COLUMNS[new_table_name][cname]) + else: + new_col_types.append(ctype) + # add to the end manually + old_col_names.append('the_geom') + new_col_names.append('geom') + new_col_types.append(f'{geom_type} NOT NULL') + # Create new table (temp), insert data, drop original and rename temp to table_name + new_col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in + zip(new_col_names, new_col_types)]) + op.execute(sa.text(f"CREATE TABLE {new_table_name} ({new_col_str});")) # Copy data - # This may copy wrong type data because some types change!! - op.execute(sa.text(f"INSERT INTO {new_table_name} ({','.join(new_col_names)}) " - f"SELECT {','.join(old_col_names)} FROM {old_table_name}")) - + op.execute(sa.text(f"INSERT INTO {new_table_name} (id, {','.join(new_col_names)}) " + f"SELECT id, {','.join(old_col_names)} FROM {old_table_name}")) -def find_model(table_name): - for model in models.DECLARED_MODELS: - if model.__tablename__ == table_name: - return model - # This can only go wrong if the migration or model is incorrect - raise def fix_geometry_columns(): - update_models = [models.Channel, models.ConnectionNode, models.CrossSectionLocation, - models.Culvert, models.Orifice, models.Pipe, models.Pump, - models.PumpMap, models.Weir, models.Windshielding] - for model in update_models: - op.execute(sa.text(f"SELECT RecoverGeometryColumn('{model.__tablename__}', " - f"'geom', {4326}, '{model.geom.type.geometry_type}', 'XY')")) - op.execute(sa.text(f"SELECT CreateSpatialIndex('{model.__tablename__}', 'geom')")) + tables = ['channel', 'connection_node', 'cross_section_location', 'culvert', + 'orifice', 'pipe', 'pump', 'pump_map', 'weir', 'windshielding_1d'] + for table in tables: + geom_type = get_geom_type(table, geo_col_name='geom') + op.execute(sa.text(f"SELECT RecoverGeometryColumn('{table}', " + f"'geom', {4326}, '{geom_type}', 'XY')")) + op.execute(sa.text(f"SELECT CreateSpatialIndex('{table}', 'geom')")) class Temp(Base): @@ -305,29 +364,16 @@ def set_geom_for_v2_pumpstation(): op.execute(sa.text(q)) -def get_cols_for_model(model, skip_cols=None): - from sqlalchemy.orm.attributes import InstrumentedAttribute - if skip_cols is None: - skip_cols = [] - return [getattr(model, item) for item in model.__dict__ - if item not in skip_cols - and isinstance(getattr(model, item), InstrumentedAttribute)] - - -def create_sqlite_table_from_model(model): - cols = get_cols_for_model(model, skip_cols = ["id", "geom"]) - op.execute(sa.text(f""" - CREATE TABLE {model.__tablename__} ( - id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, - {','.join(f"{col.name} {col.type}" for col in cols)}, - geom {model.geom.type.geometry_type} NOT NULL - );""")) - - def create_pump_map(): # Create table - create_sqlite_table_from_model(models.PumpMap) - + query = """ + CREATE TABLE pump_map ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + pump_id INTEGER,connection_node_id_end INTEGER,tags TEXT,code VARCHAR(100),display_name VARCHAR(255), + geom LINESTRING NOT NULL + ); + """ + op.execute(sa.text(query)) # Create geometry op.execute(sa.text(f"SELECT AddGeometryColumn('v2_pumpstation', 'map_geom', 4326, 'LINESTRING', 'XY', 0);")) op.execute(sa.text(""" @@ -358,7 +404,15 @@ def create_pump_map(): def create_connection_node(): - create_sqlite_table_from_model(models.ConnectionNode) + # Create table + query = """ + CREATE TABLE connection_node ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + code VARCHAR(100),tags TEXT,display_name TEXT,storage_area FLOAT,initial_water_level FLOAT,visualisation INTEGER,manhole_surface_level FLOAT,bottom_level FLOAT,exchange_level FLOAT,exchange_type INTEGER,exchange_thickness FLOAT,hydraulic_conductivity_in FLOAT,hydraulic_conductivity_out FLOAT, + geom POINT NOT NULL + ); + """ + op.execute(sa.text(query)) # copy from v2_connection_nodes old_col_names = ["id", "initial_waterlevel", "storage_area", "the_geom", "code"] rename_map = {"initial_waterlevel": "initial_water_level", "the_geom": "geom"} @@ -389,6 +443,15 @@ def create_connection_node(): """)) +# define Material class needed to populate table in create_material +class Material(Base): + __tablename__ = "material" + id = Column(Integer, primary_key=True) + description = Column(Text) + friction_type = Column(IntegerEnum(constants.FrictionType)) + friction_coefficient = Column(Float) + + def create_material(): op.execute(sa.text(""" CREATE TABLE material ( @@ -397,12 +460,13 @@ def create_material(): friction_type INTEGER, friction_coefficient REAL); """)) + connection = op.get_bind() + nof_settings = connection.execute(sa.text("SELECT COUNT(*) FROM model_settings")).scalar() session = Session(bind=op.get_bind()) - nof_settings = session.execute(select(func.count()).select_from(models.ModelSettings)).scalar() if nof_settings > 0: with open(data_dir.joinpath('0228_materials.csv')) as file: reader = csv.DictReader(file) - session.bulk_save_objects([models.Material(**row) for row in reader]) + session.bulk_save_objects([Material(**row) for row in reader]) session.commit() @@ -468,6 +532,7 @@ def upgrade(): set_geom_for_v2_pumpstation() for old_table_name, new_table_name in RENAME_TABLES: modify_table(old_table_name, new_table_name) + add_columns_to_tables(ADD_COLUMNS) # Create new tables create_pump_map() create_material() diff --git a/threedi_schema/migrations/versions/0229_clean_up.py b/threedi_schema/migrations/versions/0229_clean_up.py index dabdba1..7346fc6 100644 --- a/threedi_schema/migrations/versions/0229_clean_up.py +++ b/threedi_schema/migrations/versions/0229_clean_up.py @@ -11,8 +11,6 @@ import sqlalchemy as sa from alembic import op -from threedi_schema import models - # revision identifiers, used by Alembic. revision = "0229" down_revision = "0228" @@ -20,40 +18,31 @@ depends_on = None -def find_model(table_name): - for model in models.DECLARED_MODELS: - if model.__tablename__ == table_name: - return model - # This can only go wrong if the migration or model is incorrect - raise - - -def create_sqlite_table_from_model(model, table_name): - cols = get_cols_for_model(model, skip_cols=["id"]) - op.execute(sa.text(f""" - CREATE TABLE {table_name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, - {','.join(f"{col.name} {col.type}" for col in cols)} - );""")) - - -def get_cols_for_model(model, skip_cols=None): - from sqlalchemy.orm.attributes import InstrumentedAttribute - if skip_cols is None: - skip_cols = [] - return [getattr(model, item) for item in model.__dict__ - if item not in skip_cols - and isinstance(getattr(model, item), InstrumentedAttribute)] - +def get_geom_type(table_name, geo_col_name): + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + for col in columns: + if col[1] == geo_col_name: + return col[2] -def sync_orm_types_to_sqlite(table_name): +def change_types_in_settings_table(): temp_table_name = f'_temp_229_{uuid.uuid4().hex}' - model = find_model(table_name) - create_sqlite_table_from_model(model, temp_table_name) - col_names = [col.name for col in get_cols_for_model(model)] - # This may copy wrong type data because some types change!! - op.execute(sa.text(f"INSERT INTO {temp_table_name} ({','.join(col_names)}) " - f"SELECT {','.join(col_names)} FROM {table_name}")) + table_name = 'model_settings' + change_types = {'use_d2_rain': 'bool', 'friction_averaging': 'bool'} + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + # get all column names and types + skip_cols = ['id', 'the_geom'] + col_names = [col[1] for col in columns if col[1] not in skip_cols] + old_col_types = [col[2] for col in columns if col[1] not in skip_cols] + col_types = [change_types.get(col_name, col_type) for col_name, col_type in zip(col_names, old_col_types)] + # Create new table, insert data, drop original and rename temp to table_name + col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in + zip(col_names, col_types)]) + op.execute(sa.text(f"CREATE TABLE {temp_table_name} ({col_str});")) + # Copy data + op.execute(sa.text(f"INSERT INTO {temp_table_name} (id, {','.join(col_names)}) " + f"SELECT id, {','.join(col_names)} FROM {table_name}")) op.execute(sa.text(f"DROP TABLE {table_name}")) op.execute(sa.text(f"ALTER TABLE {temp_table_name} RENAME TO {table_name};")) @@ -98,33 +87,24 @@ def clean_by_type(type: str): def update_use_settings(): # Ensure that use_* settings are only True when there is actual data for them use_settings = [ - (models.ModelSettings.use_groundwater_storage, models.GroundWater), - (models.ModelSettings.use_groundwater_flow, models.GroundWater), - (models.ModelSettings.use_interflow, models.Interflow), - (models.ModelSettings.use_simple_infiltration, models.SimpleInfiltration), - (models.ModelSettings.use_vegetation_drag_2d, models.VegetationDrag), - (models.ModelSettings.use_interception, models.Interception) + ('use_groundwater_storage', 'groundwater'), + ('use_groundwater_flow', 'groundwater'), + ('use_interflow', 'interflow'), + ('use_simple_infiltration', 'simple_infiltration'), + ('use_vegetation_drag_2d', 'vegetation_drag_2d'), + ('use_interception', 'interception') ] connection = op.get_bind() # Get the connection for raw SQL execution for setting, table in use_settings: - use_row = connection.execute( - sa.select(getattr(models.ModelSettings, setting.name)) - ).scalar() + use_row = connection.execute(sa.text(f"SELECT {setting} FROM model_settings")).scalar() if not use_row: continue - row = connection.execute(sa.select(table)).first() + row = connection.execute(sa.text(f"SELECT * FROM {table}")).first() use_row = (row is not None) if use_row: - use_row = not all( - getattr(row, column.name) in (None, "") - for column in table.__table__.columns - if column.name != "id" - ) + use_row = not all(item in (None, "") for item in row[1:]) if not use_row: - connection.execute( - sa.update(models.ModelSettings) - .values({setting.name: False}) - ) + connection.execute(sa.text(f"UPDATE model_settings SET {setting} = 0")) def upgrade(): @@ -133,8 +113,7 @@ def upgrade(): clean_by_type('triggers') clean_by_type('views') update_use_settings() - # Apply changing use_2d_rain and friction_averaging type to bool - sync_orm_types_to_sqlite('model_settings') + change_types_in_settings_table() def downgrade(): diff --git a/threedi_schema/tests/test_migration.py b/threedi_schema/tests/test_migration.py index b2f1f07..8d58175 100644 --- a/threedi_schema/tests/test_migration.py +++ b/threedi_schema/tests/test_migration.py @@ -64,7 +64,11 @@ def get_columns_from_sqlite(cursor, table_name): for c in cursor.fetchall(): if 'geom' in c[1]: continue - type_str = c[2].lower() if c[2] != 'bool' else 'boolean' + type_str = c[2].lower() + if type_str == 'bool': + type_str = 'boolean' + if type_str == 'int': + type_str = 'integer' col_map[c[1]] = (type_str, not c[3]) return col_map