Skip to content

Commit

Permalink
Add SQLAlchemy 2 support (#46)
Browse files Browse the repository at this point in the history
* Add SQLAlchemy 2 support

* Fix linting

* update 2.0.x matrix scripts
  • Loading branch information
jowilf authored Mar 2, 2023
1 parent b3c6367 commit 1b20b95
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 26 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,19 @@ jobs:
ENGINE: 'sqlite:///test.db?check_same_thread=False'
STORAGE_PROVIDER: 'LOCAL'
LOCAL_PATH: '/tmp/storage'
run: hatch run test:all
run: hatch run test:run
- name: Test Local Storage provider & postgresql
env:
ENGINE: 'postgresql+psycopg2://username:password@localhost:5432/test_db'
STORAGE_PROVIDER: 'LOCAL'
LOCAL_PATH: '/tmp/storage'
run: hatch run test:all
run: hatch run test:run
- name: Test Local Storage provider & mysql
env:
ENGINE: 'mysql+pymysql://username:password@localhost:3306/test_db'
STORAGE_PROVIDER: 'LOCAL'
LOCAL_PATH: '/tmp/storage'
run: hatch run test:all
run: hatch run test:run
- name: Test Minio Storage provider & sqlite memory
env:
ENGINE: 'sqlite:///:memory:?check_same_thread=False'
Expand All @@ -83,7 +83,7 @@ jobs:
MINIO_HOST: 'localhost'
MINIO_PORT: 9000
MINIO_SECURE: false
run: hatch run test:all
run: hatch run test:run
- name: Coverage Report
run: hatch run test:cov
- name: Upload coverage
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ poetry.lock
dist
htmlcov
*.egg-info
.coverage
.coverage*
coverage.xml
site
*.db
Expand Down
32 changes: 23 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
"Topic :: Database :: Database Engines/Servers",
]
dependencies = [
"SQLAlchemy >=1.4, <1.5",
"SQLAlchemy >=1.4, <2.1",
"apache-libcloud >=3.6, <3.8",
]
dynamic = ["version"]
Expand All @@ -44,16 +44,15 @@ Changelog = "https://jowilf.github.io/sqlalchemy-file/changelog/"
[project.optional-dependencies]
test = [
"pytest >=7.2.0, <7.3.0",
"mypy ==0.991",
"ruff ==0.0.215",
"black ==22.12.0",
"coverage >=7.0.0, <7.1.0",
"mypy ==1.0.1",
"ruff ==0.0.253",
"black ==23.1.0",
"coverage >=7.0.0, <7.3.0",
"fasteners ==0.18",
"PyMySQL[rsa] >=1.0.2, <1.1.0",
"psycopg2-binary >=2.9.5, <3.0.0",
"Pillow >=9.4.0, <9.5.0",
"sqlmodel ==0.0.8",
"python-multipart ==0.0.5",
"python-multipart ==0.0.6",
"fastapi >=0.92, <0.93",
"Flask >=2.2, <2.3",
"Flask-SQLAlchemy >=3.0,<3.1"
Expand All @@ -63,7 +62,7 @@ doc = [
"mkdocstrings[python] >=0.19.0, <0.21.0"
]
dev = [
"pre-commit >=2.20.0, <3.0.0",
"pre-commit >=2.20.0, <4.0.0",
"uvicorn >=0.20.0, <0.21.0",
]

Expand All @@ -88,13 +87,27 @@ lint = [
"ruff sqlalchemy_file tests",
"black . --check"
]
all = "coverage run -m pytest tests"
run = "coverage run -m pytest tests"
cov = [
"coverage combine",
"coverage report --show-missing",
"coverage xml"
]

[[tool.hatch.envs.test.matrix]]
sqla_version = ["1.4.x", "2.0.x"]

[tool.hatch.envs.test.overrides]
matrix.sqla_version.dependencies = [
{ value = "SQLAlchemy >=2.0, <2.1", if = ["2.0.x"] },
{ value = "SQLAlchemy >=1.4, <1.5", if = ["1.4.x"] },
{ value = "sqlmodel ==0.0.8", if = ["1.4.x"] },
]
matrix.sqla_version.scripts = [
{ key = "run", value = 'coverage run -m pytest tests --ignore=tests/test_sqlmodel.py', if = ["2.0.x"] },
{ key = "cov", value = '', if = ["2.0.x"] },
]

[tool.hatch.envs.docs]
features = [
"doc",
Expand Down Expand Up @@ -136,6 +149,7 @@ known-third-party = ["sqlalchemy_file"]

[tool.mypy]
strict = true
warn_unused_ignores = false

[tool.hatch.build.targets.wheel]
[tool.hatch.build.targets.sdist]
Expand Down
22 changes: 10 additions & 12 deletions sqlalchemy_file/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,7 @@ def add_old_files_to_session(cls, session: Session, paths: List[str]) -> None:
session._old_files.update(paths) # type: ignore

@classmethod
def extract_files_from_history(
cls, data: List[Union[MutableList[File], File]]
) -> List[str]:
def extract_files_from_history(cls, data: Union[Tuple[()], List[Any]]) -> List[str]:
paths = []
for item in data:
if isinstance(item, list):
Expand All @@ -202,7 +200,7 @@ def extract_files_from_history(
return paths

@classmethod
def _mapper_configured(cls, mapper: Mapper, class_: Any) -> None:
def _mapper_configured(cls, mapper: Mapper, class_: Any) -> None: # type: ignore[type-arg]
"""Detect and listen all class with FileField Column"""
for mapper_property in mapper.iterate_properties:
if isinstance(mapper_property, ColumnProperty) and isinstance(
Expand Down Expand Up @@ -238,7 +236,7 @@ def _after_soft_rollback(cls, session: Session, _: SessionTransaction) -> None:
cls.clear_session(session)

@classmethod
def _after_delete(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
def _after_delete(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type: ignore[type-arg]
"""
After delete mark all linked files as old in order to delete
them when after session is committed
Expand All @@ -256,7 +254,7 @@ def _after_delete(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
)

@classmethod
def _after_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
def _after_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type: ignore[type-arg]
"""
After update, mark all edited files as old
in order to delete them when after session is committed
Expand All @@ -269,7 +267,7 @@ def _after_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
)

@classmethod
def _before_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
def _before_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type: ignore[type-arg]
"""
Before updating values, validate and save files. For multiple fields,
mark all removed files as old, as _removed attribute will be
Expand All @@ -292,7 +290,7 @@ def _before_update(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
cls.add_old_files_to_session(session, [f["path"] for f in _removed])

@classmethod
def _before_insert(cls, mapper: Mapper, _: Connection, obj: Any) -> None:
def _before_insert(cls, mapper: Mapper, _: Connection, obj: Any) -> None: # type: ignore[type-arg]
"""Before inserting values, mark all created files as new. They will be
automatically removed when session rollback"""

Expand All @@ -308,7 +306,7 @@ def _before_insert(cls, mapper: Mapper, _: Connection, obj: Any) -> None:

@classmethod
def prepare_file_attr(
cls, mapper: Mapper, obj: Any, key: str
cls, mapper: Mapper, obj: Any, key: str # type: ignore[type-arg]
) -> Tuple[bool, Union[File, List[File]]]:
"""
Prepare file before saved to database, convert bytes and string,
Expand All @@ -320,7 +318,7 @@ def prepare_file_attr(
or when new items is added for multiple field"""
changed = False

column_type = mapper.attrs.get(key).columns[0].type
column_type = mapper.attrs.get(key).columns[0].type # type: ignore[misc,union-attr]
assert isinstance(column_type, FileField)
upload_type = column_type.upload_type

Expand Down Expand Up @@ -351,8 +349,8 @@ def prepare_file_attr(

@classmethod
def setup(cls) -> None:
event.listen(orm.mapper, "mapper_configured", cls._mapper_configured)
event.listen(orm.mapper, "after_configured", cls._after_configured)
event.listen(orm.Mapper, "mapper_configured", cls._mapper_configured)
event.listen(orm.Mapper, "after_configured", cls._after_configured)
event.listen(Session, "after_commit", cls._after_commit)
event.listen(Session, "after_soft_rollback", cls._after_soft_rollback)

Expand Down

0 comments on commit 1b20b95

Please sign in to comment.