Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👷 Move to Ruff and add pre-commit #661

Merged
merged 6 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-toml
- id: check-yaml
args:
- --unsafe
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.7.0
hooks:
- id: pyupgrade
args:
- --py3-plus
- --keep-runtime-typing
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.275
hooks:
- id: ruff
args:
- --fix
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
49 changes: 29 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ SQLAlchemy = ">=1.4.17,<=1.4.41"
pydantic = "^1.8.2"
sqlalchemy2-stubs = {version = "*", allow-prereleases = true}

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^7.0.1"
mypy = "0.971"
flake8 = "^5.0.4"
black = "^22.10.0"
mkdocs = "^1.2.1"
mkdocs-material = "^8.1.4"
Expand All @@ -48,8 +47,7 @@ mdx-include = "^1.4.1"
coverage = {extras = ["toml"], version = "^6.2"}
fastapi = "^0.68.1"
requests = "^2.26.0"
autoflake = "^1.4"
isort = "^5.9.3"
ruff = "^0.1.1"

[build-system]
requires = ["poetry-core"]
Expand All @@ -75,27 +73,19 @@ exclude_lines = [
"if TYPE_CHECKING:",
]

[tool.isort]
profile = "black"
known_third_party = ["sqlmodel"]
skip_glob = [
"sqlmodel/__init__.py",
]


[tool.mypy]
# --strict
disallow_any_generics = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_return_any = true
warn_return_any = true
implicit_reexport = false
strict_equality = true
# --strict end
Expand All @@ -104,4 +94,23 @@ strict_equality = true
module = "sqlmodel.sql.expression"
warn_unused_ignores = false

# invalidate CI cache: 1
[tool.ruff]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"C", # flake8-comprehensions
"B", # flake8-bugbear
]
ignore = [
"E501", # line too long, handled by black
"B008", # do not perform function calls in argument defaults
"C901", # too complex
]

[tool.ruff.per-file-ignores]
# "__init__.py" = ["F401"]

[tool.ruff.isort]
known-third-party = ["sqlmodel", "sqlalchemy", "pydantic", "fastapi"]
5 changes: 2 additions & 3 deletions scripts/format.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/bin/sh -e
set -x

autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place sqlmodel docs_src tests --exclude=__init__.py
black sqlmodel tests docs_src
isort sqlmodel tests docs_src
ruff sqlmodel tests docs_src scripts --fix
black sqlmodel tests docs_src scripts
3 changes: 1 addition & 2 deletions scripts/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ set -e
set -x

mypy sqlmodel
flake8 sqlmodel tests docs_src
ruff sqlmodel tests docs_src scripts
black sqlmodel tests docs_src --check
isort sqlmodel tests docs_src scripts --check-only
62 changes: 30 additions & 32 deletions sqlmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from sqlalchemy.engine import engine_from_config as engine_from_config
from sqlalchemy.inspection import inspect as inspect
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
from sqlalchemy.schema import DDL as DDL
from sqlalchemy.schema import CheckConstraint as CheckConstraint
from sqlalchemy.schema import Column as Column
from sqlalchemy.schema import ColumnDefault as ColumnDefault
from sqlalchemy.schema import Computed as Computed
from sqlalchemy.schema import Constraint as Constraint
from sqlalchemy.schema import DDL as DDL
from sqlalchemy.schema import DefaultClause as DefaultClause
from sqlalchemy.schema import FetchedValue as FetchedValue
from sqlalchemy.schema import ForeignKey as ForeignKey
Expand All @@ -23,6 +23,14 @@
from sqlalchemy.schema import Table as Table
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
)
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import alias as alias
from sqlalchemy.sql import all_ as all_
from sqlalchemy.sql import and_ as and_
Expand All @@ -48,14 +56,6 @@
from sqlalchemy.sql import intersect as intersect
from sqlalchemy.sql import intersect_all as intersect_all
from sqlalchemy.sql import join as join
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
)
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import lambda_stmt as lambda_stmt
from sqlalchemy.sql import lateral as lateral
from sqlalchemy.sql import literal as literal
Expand Down Expand Up @@ -85,55 +85,53 @@
from sqlalchemy.sql import within_group as within_group
from sqlalchemy.types import ARRAY as ARRAY
from sqlalchemy.types import BIGINT as BIGINT
from sqlalchemy.types import BigInteger as BigInteger
from sqlalchemy.types import BINARY as BINARY
from sqlalchemy.types import BLOB as BLOB
from sqlalchemy.types import BOOLEAN as BOOLEAN
from sqlalchemy.types import Boolean as Boolean
from sqlalchemy.types import CHAR as CHAR
from sqlalchemy.types import CLOB as CLOB
from sqlalchemy.types import DATE as DATE
from sqlalchemy.types import Date as Date
from sqlalchemy.types import DATETIME as DATETIME
from sqlalchemy.types import DateTime as DateTime
from sqlalchemy.types import DECIMAL as DECIMAL
from sqlalchemy.types import Enum as Enum
from sqlalchemy.types import FLOAT as FLOAT
from sqlalchemy.types import Float as Float
from sqlalchemy.types import INT as INT
from sqlalchemy.types import INTEGER as INTEGER
from sqlalchemy.types import Integer as Integer
from sqlalchemy.types import Interval as Interval
from sqlalchemy.types import JSON as JSON
from sqlalchemy.types import LargeBinary as LargeBinary
from sqlalchemy.types import NCHAR as NCHAR
from sqlalchemy.types import NUMERIC as NUMERIC
from sqlalchemy.types import Numeric as Numeric
from sqlalchemy.types import NVARCHAR as NVARCHAR
from sqlalchemy.types import PickleType as PickleType
from sqlalchemy.types import REAL as REAL
from sqlalchemy.types import SMALLINT as SMALLINT
from sqlalchemy.types import TEXT as TEXT
from sqlalchemy.types import TIME as TIME
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
from sqlalchemy.types import VARBINARY as VARBINARY
from sqlalchemy.types import VARCHAR as VARCHAR
from sqlalchemy.types import BigInteger as BigInteger
from sqlalchemy.types import Boolean as Boolean
from sqlalchemy.types import Date as Date
from sqlalchemy.types import DateTime as DateTime
from sqlalchemy.types import Enum as Enum
from sqlalchemy.types import Float as Float
from sqlalchemy.types import Integer as Integer
from sqlalchemy.types import Interval as Interval
from sqlalchemy.types import LargeBinary as LargeBinary
from sqlalchemy.types import Numeric as Numeric
from sqlalchemy.types import PickleType as PickleType
from sqlalchemy.types import SmallInteger as SmallInteger
from sqlalchemy.types import String as String
from sqlalchemy.types import TEXT as TEXT
from sqlalchemy.types import Text as Text
from sqlalchemy.types import TIME as TIME
from sqlalchemy.types import Time as Time
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
from sqlalchemy.types import TypeDecorator as TypeDecorator
from sqlalchemy.types import Unicode as Unicode
from sqlalchemy.types import UnicodeText as UnicodeText
from sqlalchemy.types import VARBINARY as VARBINARY
from sqlalchemy.types import VARCHAR as VARCHAR

# Extensions and modifications of SQLAlchemy in SQLModel
# From SQLModel, modifications of SQLAlchemy or equivalents of Pydantic
from .engine.create import create_engine as create_engine
from .main import Field as Field
from .main import Relationship as Relationship
from .main import SQLModel as SQLModel
from .orm.session import Session as Session
from .sql.expression import select as select
from .sql.expression import col as col
from .sql.expression import select as select
from .sql.sqltypes import AutoString as AutoString

# Export SQLModel specifics (equivalent to Pydantic)
from .main import SQLModel as SQLModel
from .main import Field as Field
from .main import Relationship as Relationship
29 changes: 19 additions & 10 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,24 @@

from pydantic import BaseConfig, BaseModel
from pydantic.errors import ConfigError, DictError
from pydantic.fields import SHAPE_SINGLETON
from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass, validate_model
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import ROOT_KEY, Representation
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import (
Boolean,
Column,
Date,
DateTime,
Float,
ForeignKey,
Integer,
Interval,
Numeric,
inspect,
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
Expand Down Expand Up @@ -305,9 +314,9 @@ def get_config(name: str) -> Any:
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.registry = config_table
setattr(new_cls, "_sa_registry", config_registry)
setattr(new_cls, "metadata", config_registry.metadata)
setattr(new_cls, "__abstract__", True)
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
setattr(new_cls, "__abstract__", True) # noqa: B010
return new_cls

# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
Expand All @@ -320,7 +329,7 @@ def __init__(
# triggers an error
base_is_table = False
for base in bases:
config = getattr(base, "__config__")
config = getattr(base, "__config__") # noqa: B009
if config and getattr(config, "table", False):
base_is_table = True
break
Expand Down Expand Up @@ -351,7 +360,7 @@ def __init__(
rel_kwargs["back_populates"] = rel_info.back_populates
if rel_info.link_model:
ins = inspect(rel_info.link_model)
local_table = getattr(ins, "local_table")
local_table = getattr(ins, "local_table") # noqa: B009
if local_table is None:
raise RuntimeError(
"Couldn't find the secondary table for "
Expand Down Expand Up @@ -430,7 +439,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
# Override derived nullability if the nullable property is set explicitly
# on the field
if hasattr(field.field_info, "nullable"):
field_nullable = getattr(field.field_info, "nullable")
field_nullable = getattr(field.field_info, "nullable") # noqa: B009
if field_nullable != Undefined:
nullable = field_nullable
args = []
Expand Down