diff --git a/.github/workflows/test_destination_clickhouse.yml b/.github/workflows/test_destination_clickhouse.yml index d834df6b28..4aea5a8e90 100644 --- a/.github/workflows/test_destination_clickhouse.yml +++ b/.github/workflows/test_destination_clickhouse.yml @@ -1,4 +1,3 @@ - name: test | clickhouse on: @@ -8,7 +7,7 @@ on: - devel workflow_dispatch: schedule: - - cron: '0 2 * * *' + - cron: '0 2 * * *' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -20,7 +19,7 @@ env: DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} ACTIVE_DESTINATIONS: "[\"clickhouse\"]" - ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" jobs: get_docs_changes: @@ -67,12 +66,51 @@ jobs: - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + # OSS ClickHouse + - run: | + docker-compose -f "tests/load/clickhouse/clickhouse-compose.yml" up -d + echo "Waiting for ClickHouse to be healthy..." + timeout 30s bash -c 'until docker-compose -f "tests/load/clickhouse/clickhouse-compose.yml" ps | grep -q "healthy"; do sleep 1; done' + echo "ClickHouse is up and running" + name: Start ClickHouse OSS + + + - run: poetry run pytest tests/load -m "essential" + name: Run essential tests Linux (ClickHouse OSS) + if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} + env: + DESTINATION__CLICKHOUSE__CREDENTIALS__HOST: localhost + DESTINATION__CLICKHOUSE__CREDENTIALS__DATABASE: dlt_data + DESTINATION__CLICKHOUSE__CREDENTIALS__USERNAME: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PASSWORD: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PORT: 9000 + DESTINATION__CLICKHOUSE__CREDENTIALS__HTTP_PORT: 8123 + DESTINATION__CLICKHOUSE__CREDENTIALS__SECURE: 0 + + - run: poetry run pytest tests/load + name: Run all tests Linux (ClickHouse OSS) + if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} + env: + DESTINATION__CLICKHOUSE__CREDENTIALS__HOST: localhost + DESTINATION__CLICKHOUSE__CREDENTIALS__DATABASE: dlt_data + DESTINATION__CLICKHOUSE__CREDENTIALS__USERNAME: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PASSWORD: loader + DESTINATION__CLICKHOUSE__CREDENTIALS__PORT: 9000 + DESTINATION__CLICKHOUSE__CREDENTIALS__HTTP_PORT: 8123 + DESTINATION__CLICKHOUSE__CREDENTIALS__SECURE: 0 + + - name: Stop ClickHouse OSS + if: always() + run: docker-compose -f "tests/load/clickhouse/clickhouse-compose.yml" down -v + + # ClickHouse Cloud - run: | poetry run pytest tests/load -m "essential" - name: Run essential tests Linux + name: Run essential tests Linux (ClickHouse Cloud) if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} - run: | poetry run pytest tests/load - name: Run all tests Linux + name: Run all tests Linux (ClickHouse Cloud) if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} + diff --git a/.github/workflows/test_destination_lancedb.yml b/.github/workflows/test_destination_lancedb.yml new file mode 100644 index 0000000000..02b5ef66eb --- /dev/null +++ b/.github/workflows/test_destination_lancedb.yml @@ -0,0 +1,81 @@ +name: dest | lancedb + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + schedule: + - cron: '0 2 * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + + ACTIVE_DESTINATIONS: "[\"lancedb\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} + + run_loader: + name: dest | lancedb tests + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + steps: + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.11.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-gcp + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + - name: Install dependencies + run: poetry install --no-interaction -E lancedb -E parquet --with sentry-sdk --with pipeline + + - name: Install embedding provider dependencies + run: poetry run pip install openai + + - run: | + poetry run pytest tests/load -m "essential" + name: Run essential tests Linux + if: ${{ ! (contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule')}} + + - run: | + poetry run pytest tests/load + name: Run all tests Linux + if: ${{ contains(github.event.pull_request.labels.*.name, 'ci full') || github.event_name == 'schedule'}} diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index cb6417a4ab..b140935d4c 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -21,6 +21,8 @@ env: # Slack hook for chess in production example RUNTIME__SLACK_INCOMING_HOOK: ${{ secrets.RUNTIME__SLACK_INCOMING_HOOK }} + # Path to local qdrant database + DESTINATION__QDRANT__CREDENTIALS__PATH: zendesk.qdb # detect if the workflow is executed in a repo fork IS_FORK: ${{ github.event.pull_request.head.repo.fork }} @@ -32,6 +34,26 @@ jobs: # Do not run on forks, unless allowed, secrets are used here if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres + # Provide the password for postgres + env: + POSTGRES_DB: dlt_data + POSTGRES_USER: loader + POSTGRES_PASSWORD: loader + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: - name: Check out @@ -61,7 +83,7 @@ jobs: - name: Install dependencies # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction -E duckdb -E weaviate -E parquet -E qdrant -E bigquery -E postgres --with docs,sentry-sdk --without airflow + run: poetry install --no-interaction -E duckdb -E weaviate -E parquet -E qdrant -E bigquery -E postgres -E lancedb --with docs,sentry-sdk --without airflow - name: create secrets.toml for examples run: pwd && echo "$DLT_SECRETS_TOML" > docs/examples/.dlt/secrets.toml diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 263d3f588c..f1bf6016bc 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -21,7 +21,7 @@ env: RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 RUNTIME__LOG_LEVEL: ERROR RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} - ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\"]" + ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\", \"qdrant\"]" ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" DESTINATION__WEAVIATE__VECTORIZER: text2vec-contextionary @@ -63,6 +63,11 @@ jobs: --health-timeout 5s --health-retries 5 + qdrant: + image: qdrant/qdrant:v1.8.4 + ports: + - 6333:6333 + steps: - name: Check out uses: actions/checkout@master @@ -90,7 +95,7 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations - name: Install dependencies - run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate --with sentry-sdk --with pipeline -E deltalake + run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate -E qdrant --with sentry-sdk --with pipeline -E deltalake - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml @@ -100,6 +105,7 @@ jobs: name: Run tests Linux env: DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data + DESTINATION__QDRANT__CREDENTIALS__location: http://localhost:6333 - name: Stop weaviate if: always() diff --git a/Makefile b/Makefile index fd0920d188..15fb895a9f 100644 --- a/Makefile +++ b/Makefile @@ -67,9 +67,9 @@ lint-and-test-snippets: cd docs/website/docs && poetry run pytest --ignore=node_modules lint-and-test-examples: - poetry run mypy --config-file mypy.ini docs/examples - poetry run flake8 --max-line-length=200 docs/examples cd docs/tools && poetry run python prepare_examples_tests.py + poetry run flake8 --max-line-length=200 docs/examples + poetry run mypy --config-file mypy.ini docs/examples cd docs/examples && poetry run pytest diff --git a/README.md b/README.md index ed1cc751c2..bc0f40b62f 100644 --- a/README.md +++ b/README.md @@ -30,20 +30,12 @@ Be it a Google Colab notebook, AWS Lambda function, an Airflow DAG, your local l dlt supports Python 3.8+. -**pip:** ```sh pip install dlt ``` -**pixi:** -```sh -pixi add dlt -``` +More options: [Install via Conda or Pixi](https://dlthub.com/docs/reference/installation#install-dlt-via-pixi-and-conda) -**conda:** -```sh -conda install -c conda-forge dlt -``` ## Quick Start diff --git a/deploy/dlt/Dockerfile b/deploy/dlt/Dockerfile index f3d4f9d707..3f9f6a2341 100644 --- a/deploy/dlt/Dockerfile +++ b/deploy/dlt/Dockerfile @@ -31,7 +31,7 @@ RUN apk update &&\ # add build labels and envs ARG COMMIT_SHA="" ARG IMAGE_VERSION="" -LABEL commit_sha = ${COMMIT_SHA} +LABEL commit_sha=${COMMIT_SHA} LABEL version=${IMAGE_VERSION} ENV COMMIT_SHA=${COMMIT_SHA} ENV IMAGE_VERSION=${IMAGE_VERSION} diff --git a/deploy/dlt/Dockerfile.airflow b/deploy/dlt/Dockerfile.airflow index 43adf5ea95..620b72da0e 100644 --- a/deploy/dlt/Dockerfile.airflow +++ b/deploy/dlt/Dockerfile.airflow @@ -14,7 +14,7 @@ WORKDIR /tmp/pydlt # add build labels and envs ARG COMMIT_SHA="" ARG IMAGE_VERSION="" -LABEL commit_sha = ${COMMIT_SHA} +LABEL commit_sha=${COMMIT_SHA} LABEL version=${IMAGE_VERSION} ENV COMMIT_SHA=${COMMIT_SHA} ENV IMAGE_VERSION=${IMAGE_VERSION} diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py index af4f2f66e9..7c6526c0a2 100644 --- a/dlt/cli/_dlt.py +++ b/dlt/cli/_dlt.py @@ -164,7 +164,7 @@ def schema_command_wrapper(file_path: str, format_: str, remove_defaults: bool) schema_str = json.dumps(s.to_dict(remove_defaults=remove_defaults), pretty=True) else: schema_str = s.to_pretty_yaml(remove_defaults=remove_defaults) - print(schema_str) + fmt.echo(schema_str) return 0 diff --git a/dlt/cli/pipeline_command.py b/dlt/cli/pipeline_command.py index d66d884ff2..6aa479a398 100644 --- a/dlt/cli/pipeline_command.py +++ b/dlt/cli/pipeline_command.py @@ -8,7 +8,12 @@ from dlt.common.destination.reference import TDestinationReferenceArg from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout -from dlt.common.schema.utils import group_tables_by_resource, remove_defaults +from dlt.common.schema.utils import ( + group_tables_by_resource, + has_table_seen_data, + is_complete_column, + remove_defaults, +) from dlt.common.storages import FileStorage, PackageStorage from dlt.pipeline.helpers import DropCommand from dlt.pipeline.exceptions import CannotRestorePipelineException @@ -180,6 +185,35 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.bold(str(res_state_slots)), ) ) + if verbosity > 0: + for table in tables: + incomplete_columns = len( + [ + col + for col in table["columns"].values() + if not is_complete_column(col) + ] + ) + fmt.echo( + "\t%s table %s column(s) %s %s" + % ( + fmt.bold(table["name"]), + fmt.bold(str(len(table["columns"]))), + ( + fmt.style("received data", fg="green") + if has_table_seen_data(table) + else fmt.style("not yet received data", fg="yellow") + ), + ( + fmt.style( + f"{incomplete_columns} incomplete column(s)", + fg="yellow", + ) + if incomplete_columns > 0 + else "" + ), + ) + ) fmt.echo() fmt.echo("Working dir content:") _display_pending_packages() @@ -272,7 +306,7 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.echo(package_info.asstr(verbosity)) if len(package_info.schema_update) > 0: if verbosity == 0: - print("Add -v option to see schema update. Note that it could be large.") + fmt.echo("Add -v option to see schema update. Note that it could be large.") else: tables = remove_defaults({"tables": package_info.schema_update}) # type: ignore fmt.echo(fmt.bold("Schema update:")) @@ -316,7 +350,7 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.echo( "About to drop the following data in dataset %s in destination %s:" % ( - fmt.bold(drop.info["dataset_name"]), + fmt.bold(p.dataset_name), fmt.bold(p.destination.destination_name), ) ) @@ -329,6 +363,10 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: ) ) fmt.echo("%s: %s" % (fmt.style("Table(s) to drop", fg="green"), drop.info["tables"])) + fmt.echo( + "%s: %s" + % (fmt.style("\twith data in destination", fg="green"), drop.info["tables_with_data"]) + ) fmt.echo( "%s: %s" % ( diff --git a/dlt/common/configuration/paths.py b/dlt/common/configuration/paths.py index 89494ba6bd..9d0b47f8b6 100644 --- a/dlt/common/configuration/paths.py +++ b/dlt/common/configuration/paths.py @@ -1,16 +1,16 @@ import os import tempfile -# dlt settings folder -DOT_DLT = ".dlt" +from dlt.common import known_env + -# dlt data dir is by default not set, see get_dlt_data_dir for details -DLT_DATA_DIR: str = None +# dlt settings folder +DOT_DLT = os.environ.get(known_env.DLT_CONFIG_FOLDER, ".dlt") def get_dlt_project_dir() -> str: """The dlt project dir is the current working directory but may be overridden by DLT_PROJECT_DIR env variable.""" - return os.environ.get("DLT_PROJECT_DIR", ".") + return os.environ.get(known_env.DLT_PROJECT_DIR, ".") def get_dlt_settings_dir() -> str: @@ -27,14 +27,14 @@ def make_dlt_settings_path(path: str) -> str: def get_dlt_data_dir() -> str: - """Gets default directory where pipelines' data will be stored - 1. in user home directory: ~/.dlt/ - 2. if current user is root: in /var/dlt/ - 3. if current user does not have a home directory: in /tmp/dlt/ - 4. if DLT_DATA_DIR is set in env then it is used + """Gets default directory where pipelines' data (working directories) will be stored + 1. if DLT_DATA_DIR is set in env then it is used + 2. in user home directory: ~/.dlt/ + 3. if current user is root: in /var/dlt/ + 4. if current user does not have a home directory: in /tmp/dlt/ """ - if "DLT_DATA_DIR" in os.environ: - return os.environ["DLT_DATA_DIR"] + if known_env.DLT_DATA_DIR in os.environ: + return os.environ[known_env.DLT_DATA_DIR] # geteuid not available on Windows if hasattr(os, "geteuid") and os.geteuid() == 0: diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 9a4373039b..ee8a1f6029 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -126,14 +126,21 @@ def _maybe_parse_native_value( not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration) ): try: + # parse the native value anyway because there are configs with side effects config.parse_native_representation(explicit_value) + default_value = config.__class__() + # parse native value and convert it into dict, extract the diff and use it as exact value + # NOTE: as those are the same dataclasses, the set of keys must be the same + explicit_value = { + k: v + for k, v in config.__class__.from_init_value(explicit_value).items() + if default_value[k] != v + } except ValueError as v_err: # provide generic exception raise InvalidNativeValue(type(config), type(explicit_value), embedded_sections, v_err) except NotImplementedError: pass - # explicit value was consumed - explicit_value = None return explicit_value @@ -336,7 +343,11 @@ def _resolve_config_field( # print(f"{embedded_config} IS RESOLVED with VALUE {value}") # injected context will be resolved if value is not None: - _maybe_parse_native_value(embedded_config, value, embedded_sections + (key,)) + from_native_explicit = _maybe_parse_native_value( + embedded_config, value, embedded_sections + (key,) + ) + if from_native_explicit is not value: + embedded_config.update(from_native_explicit) value = embedded_config else: # only config with sections may look for initial values diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 1751b6ae13..2504fdeaef 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -49,8 +49,7 @@ # forward class declaration _F_BaseConfiguration: Any = type(object) _F_ContainerInjectableContext: Any = type(object) -_T = TypeVar("_T", bound="BaseConfiguration") -_C = TypeVar("_C", bound="CredentialsConfiguration") +_B = TypeVar("_B", bound="BaseConfiguration") class NotResolved: @@ -289,6 +288,33 @@ class BaseConfiguration(MutableMapping[str, Any]): """Typing for dataclass fields""" __hint_resolvers__: ClassVar[Dict[str, Callable[["BaseConfiguration"], Type[Any]]]] = {} + @classmethod + def from_init_value(cls: Type[_B], init_value: Any = None) -> _B: + """Initializes credentials from `init_value` + + Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials) + a `parse_native_representation` method will be used to parse it. In case of a dict, the credentials object will be updated with key: values of the dict. + Unexpected values in the dict will be ignored. + + Credentials will be marked as resolved if all required fields are set resolve() method is successful + """ + # create an instance + self = cls() + self._apply_init_value(init_value) + if not self.is_partial(): + # let it fail gracefully + with contextlib.suppress(Exception): + self.resolve() + return self + + def _apply_init_value(self, init_value: Any = None) -> None: + if isinstance(init_value, C_Mapping): + self.update(init_value) + elif init_value is not None: + self.parse_native_representation(init_value) + else: + return + def parse_native_representation(self, native_value: Any) -> None: """Initialize the configuration fields by parsing the `native_value` which should be a native representation of the configuration or credentials, for example database connection string or JSON serialized GCP service credentials file. @@ -348,7 +374,7 @@ def resolve(self) -> None: self.call_method_in_mro("on_resolved") self.__is_resolved__ = True - def copy(self: _T) -> _T: + def copy(self: _B) -> _B: """Returns a deep copy of the configuration instance""" return copy.deepcopy(self) @@ -426,21 +452,6 @@ class CredentialsConfiguration(BaseConfiguration): __section__: ClassVar[str] = "credentials" - @classmethod - def from_init_value(cls: Type[_C], init_value: Any = None) -> _C: - """Initializes credentials from `init_value` - - Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials) - a `parse_native_representation` method will be used to parse it. In case of a dict, the credentials object will be updated with key: values of the dict. - Unexpected values in the dict will be ignored. - - Credentials will be marked as resolved if all required fields are set. - """ - # create an instance - self = cls() - self._apply_init_value(init_value) - return self - def to_native_credentials(self) -> Any: """Returns native credentials object. @@ -448,16 +459,6 @@ def to_native_credentials(self) -> Any: """ return self.to_native_representation() - def _apply_init_value(self, init_value: Any = None) -> None: - if isinstance(init_value, C_Mapping): - self.update(init_value) - elif init_value is not None: - self.parse_native_representation(init_value) - else: - return - if not self.is_partial(): - self.resolve() - def __str__(self) -> str: """Get string representation of credentials to be displayed, with all secret parts removed""" return super().__str__() diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 2691c5d886..5b9a4587c7 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -15,7 +15,7 @@ class ConnectionStringCredentials(CredentialsConfiguration): username: str = None host: Optional[str] = None port: Optional[int] = None - query: Optional[Dict[str, str]] = None + query: Optional[Dict[str, Any]] = None __config_gen_annotations__: ClassVar[List[str]] = ["port", "password", "host"] @@ -44,7 +44,22 @@ def on_resolved(self) -> None: def to_native_representation(self) -> str: return self.to_url().render_as_string(hide_password=False) + def get_query(self) -> Dict[str, Any]: + """Gets query preserving parameter types. Mostly used internally to export connection params""" + return {} if self.query is None else self.query + def to_url(self) -> URL: + """Creates SQLAlchemy compatible URL object, computes current query via `get_query` and serializes its values to str""" + # circular dependencies here + from dlt.common.configuration.utils import serialize_value + + def _serialize_value(v_: Any) -> str: + if v_ is None: + return None + return serialize_value(v_) + + # query must be str -> str + query = {k: _serialize_value(v) for k, v in self.get_query().items()} return URL.create( self.drivername, self.username, @@ -52,8 +67,12 @@ def to_url(self) -> URL: self.host, self.port, self.database, - self.query, + query, ) def __str__(self) -> str: - return self.to_url().render_as_string(hide_password=True) + url = self.to_url() + # do not display query. it often contains secret values + url = url._replace(query=None) + # we only have control over netloc/path + return url.render_as_string(hide_password=True) diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index a1d82fc577..ca5bd076f1 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -33,12 +33,6 @@ class GcpCredentials(CredentialsConfiguration): project_id: str = None - location: ( - str - ) = ( # DEPRECATED! and present only for backward compatibility. please set bigquery location in BigQuery configuration - "US" - ) - def parse_native_representation(self, native_value: Any) -> None: if not isinstance(native_value, str): raise InvalidGoogleNativeCredentialsType(self.__class__, native_value) diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index ed85aae8ba..dcb78683fb 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -17,7 +17,9 @@ class RunConfiguration(BaseConfiguration): dlthub_telemetry: bool = True # enable or disable dlthub telemetry dlthub_telemetry_endpoint: Optional[str] = "https://telemetry.scalevector.ai" dlthub_telemetry_segment_write_key: Optional[str] = None - log_format: str = "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}" + log_format: str = ( + "{asctime}|[{levelname}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}" + ) log_level: str = "WARNING" request_timeout: float = 60 """Timeout for http requests""" diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 6402afcfbe..74190a87de 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -100,7 +100,7 @@ def deserialize_value(key: str, value: Any, hint: Type[TAny]) -> TAny: raise ConfigValueCannotBeCoercedException(key, value, hint) from exc -def serialize_value(value: Any) -> Any: +def serialize_value(value: Any) -> str: if value is None: raise ValueError(value) # return literal for tuples @@ -108,13 +108,13 @@ def serialize_value(value: Any) -> Any: return str(value) if isinstance(value, BaseConfiguration): try: - return value.to_native_representation() + return str(value.to_native_representation()) except NotImplementedError: # no native representation: use dict value = dict(value) # coerce type to text which will use json for mapping and sequences value_dt = py_type_to_sc_type(type(value)) - return coerce_value("text", value_dt, value) + return coerce_value("text", value_dt, value) # type: ignore[no-any-return] def auto_cast(value: str) -> Any: diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index 97451d8be7..945e74a37b 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -3,6 +3,7 @@ DataWriterMetrics, TDataItemFormat, FileWriterSpec, + create_import_spec, resolve_best_writer_spec, get_best_writer_spec, is_native_writer, @@ -11,12 +12,13 @@ from dlt.common.data_writers.escape import ( escape_redshift_literal, escape_redshift_identifier, - escape_bigquery_identifier, + escape_hive_identifier, ) __all__ = [ "DataWriter", "FileWriterSpec", + "create_import_spec", "resolve_best_writer_spec", "get_best_writer_spec", "is_native_writer", @@ -26,5 +28,5 @@ "new_file_id", "escape_redshift_literal", "escape_redshift_identifier", - "escape_bigquery_identifier", + "escape_hive_identifier", ] diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index bd32c68c49..8077007edb 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,11 +1,13 @@ import gzip import time -from typing import ClassVar, List, IO, Any, Optional, Type, Generic +import contextlib +from typing import ClassVar, Iterator, List, IO, Any, Optional, Type, Generic from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers.exceptions import ( BufferedDataWriterClosed, DestinationCapabilitiesRequired, + FileImportNotFound, InvalidFileNameTemplateException, ) from dlt.common.data_writers.writers import TWriter, DataWriter, DataWriterMetrics, FileWriterSpec @@ -138,18 +140,31 @@ def write_empty_file(self, columns: TTableSchemaColumns) -> DataWriterMetrics: self._last_modified = time.time() return self._rotate_file(allow_empty_file=True) - def import_file(self, file_path: str, metrics: DataWriterMetrics) -> DataWriterMetrics: + def import_file( + self, file_path: str, metrics: DataWriterMetrics, with_extension: str = None + ) -> DataWriterMetrics: """Import a file from `file_path` into items storage under a new file name. Does not check the imported file format. Uses counts from `metrics` as a base. Logically closes the imported file The preferred import method is a hard link to avoid copying the data. If current filesystem does not support it, a regular copy is used. + + Alternative extension may be provided via `with_extension` so various file formats may be imported into the same folder. """ # TODO: we should separate file storage from other storages. this creates circular deps from dlt.common.storages import FileStorage - self._rotate_file() - FileStorage.link_hard_with_fallback(file_path, self._file_name) + # import file with alternative extension + spec = self.writer_spec + if with_extension: + spec = self.writer_spec._replace(file_extension=with_extension) + with self.alternative_spec(spec): + self._rotate_file() + try: + FileStorage.link_hard_with_fallback(file_path, self._file_name) + except FileNotFoundError as f_ex: + raise FileImportNotFound(file_path, self._file_name) from f_ex + self._last_modified = time.time() metrics = metrics._replace( file_path=self._file_name, @@ -176,6 +191,16 @@ def close(self, skip_flush: bool = False) -> None: def closed(self) -> bool: return self._closed + @contextlib.contextmanager + def alternative_spec(self, spec: FileWriterSpec) -> Iterator[FileWriterSpec]: + """Temporarily changes the writer spec ie. for the moment file is rotated""" + old_spec = self.writer_spec + try: + self.writer_spec = spec + yield spec + finally: + self.writer_spec = old_spec + def __enter__(self) -> "BufferedDataWriter[TWriter]": return self diff --git a/dlt/common/data_writers/configuration.py b/dlt/common/data_writers/configuration.py new file mode 100644 index 0000000000..a837cb47b0 --- /dev/null +++ b/dlt/common/data_writers/configuration.py @@ -0,0 +1,31 @@ +from typing import ClassVar, Literal, Optional +from dlt.common.configuration import configspec, known_sections +from dlt.common.configuration.specs import BaseConfiguration + +CsvQuoting = Literal["quote_all", "quote_needed"] + + +@configspec +class CsvFormatConfiguration(BaseConfiguration): + delimiter: str = "," + include_header: bool = True + quoting: CsvQuoting = "quote_needed" + + # read options + on_error_continue: bool = False + encoding: str = "utf-8" + + __section__: ClassVar[str] = known_sections.DATA_WRITER + + +@configspec +class ParquetFormatConfiguration(BaseConfiguration): + flavor: Optional[str] = None # could be ie. "spark" + version: Optional[str] = "2.4" + data_page_size: Optional[int] = None + timestamp_timezone: str = "UTC" + row_group_size: Optional[int] = None + coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None + allow_truncated_timestamps: bool = False + + __section__: ClassVar[str] = known_sections.DATA_WRITER diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 580b057716..06c8d7a95a 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -124,7 +124,7 @@ def escape_redshift_identifier(v: str) -> str: escape_dremio_identifier = escape_postgres_identifier -def escape_bigquery_identifier(v: str) -> str: +def escape_hive_identifier(v: str) -> str: # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical return "`" + v.replace("\\", "\\\\").replace("`", "\\`") + "`" @@ -132,10 +132,10 @@ def escape_bigquery_identifier(v: str) -> str: def escape_snowflake_identifier(v: str) -> str: # Snowcase uppercase all identifiers unless quoted. Match this here so queries on information schema work without issue # See also https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers - return escape_postgres_identifier(v.upper()) + return escape_postgres_identifier(v) -escape_databricks_identifier = escape_bigquery_identifier +escape_databricks_identifier = escape_hive_identifier DATABRICKS_ESCAPE_DICT = {"'": "\\'", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index 1d5c58f787..3b11ed70fc 100644 --- a/dlt/common/data_writers/exceptions.py +++ b/dlt/common/data_writers/exceptions.py @@ -22,6 +22,16 @@ def __init__(self, file_name: str): super().__init__(f"Writer with recent file name {file_name} is already closed") +class FileImportNotFound(DataWriterException, FileNotFoundError): + def __init__(self, import_file_path: str, local_file_path: str) -> None: + self.import_file_path = import_file_path + self.local_file_path = local_file_path + super().__init__( + f"Attempt to import non existing file {import_file_path} into extract storage file" + f" {local_file_path}" + ) + + class DestinationCapabilitiesRequired(DataWriterException, ValueError): def __init__(self, file_format: TLoaderFileFormat): self.file_format = file_format diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 8936dae605..d324792a83 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -4,7 +4,6 @@ IO, TYPE_CHECKING, Any, - ClassVar, Dict, List, Literal, @@ -17,8 +16,7 @@ ) from dlt.common.json import json -from dlt.common.configuration import configspec, known_sections, with_config -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration import with_config from dlt.common.data_writers.exceptions import ( SpecLookupFailed, DataWriterNotFound, @@ -26,15 +24,25 @@ FileSpecNotFound, InvalidDataItem, ) -from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat +from dlt.common.data_writers.configuration import ( + CsvFormatConfiguration, + CsvQuoting, + ParquetFormatConfiguration, +) +from dlt.common.destination import ( + DestinationCapabilitiesContext, + TLoaderFileFormat, + ALL_SUPPORTED_FILE_FORMATS, +) from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.typing import StrAny + if TYPE_CHECKING: from dlt.common.libs.pyarrow import pyarrow as pa -TDataItemFormat = Literal["arrow", "object"] +TDataItemFormat = Literal["arrow", "object", "file"] TWriter = TypeVar("TWriter", bound="DataWriter") @@ -124,6 +132,9 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat: return "object" elif extension == "parquet": return "arrow" + # those files may be imported by normalizer as is + elif extension in ALL_SUPPORTED_FILE_FORMATS: + return "file" else: raise ValueError(f"Cannot figure out data item format for extension {extension}") @@ -132,6 +143,8 @@ def writer_class_from_spec(spec: FileWriterSpec) -> Type["DataWriter"]: try: return WRITER_SPECS[spec] except KeyError: + if spec.data_item_format == "file": + return ImportFileWriter raise FileSpecNotFound(spec.file_format, spec.data_item_format, spec) @staticmethod @@ -147,6 +160,19 @@ def class_factory( raise FileFormatForItemFormatNotFound(file_format, data_item_format) +class ImportFileWriter(DataWriter): + """May only import files, fails on any open/write operations""" + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + raise NotImplementedError( + "ImportFileWriter cannot write any files. You have bug in your code." + ) + + @classmethod + def writer_spec(cls) -> FileWriterSpec: + raise NotImplementedError("ImportFileWriter has no single spec") + + class JsonlWriter(DataWriter): def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) @@ -260,21 +286,8 @@ def writer_spec(cls) -> FileWriterSpec: ) -@configspec -class ParquetDataWriterConfiguration(BaseConfiguration): - flavor: Optional[str] = None # could be ie. "spark" - version: Optional[str] = "2.4" - data_page_size: Optional[int] = None - timestamp_timezone: str = "UTC" - row_group_size: Optional[int] = None - coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None - allow_truncated_timestamps: bool = False - - __section__: ClassVar[str] = known_sections.DATA_WRITER - - class ParquetDataWriter(DataWriter): - @with_config(spec=ParquetDataWriterConfiguration) + @with_config(spec=ParquetFormatConfiguration) def __init__( self, f: IO[Any], @@ -381,20 +394,8 @@ def writer_spec(cls) -> FileWriterSpec: ) -CsvQuoting = Literal["quote_all", "quote_needed"] - - -@configspec -class CsvDataWriterConfiguration(BaseConfiguration): - delimiter: str = "," - include_header: bool = True - quoting: CsvQuoting = "quote_needed" - - __section__: ClassVar[str] = known_sections.DATA_WRITER - - class CsvWriter(DataWriter): - @with_config(spec=CsvDataWriterConfiguration) + @with_config(spec=CsvFormatConfiguration) def __init__( self, f: IO[Any], @@ -525,7 +526,7 @@ def writer_spec(cls) -> FileWriterSpec: class ArrowToCsvWriter(DataWriter): - @with_config(spec=CsvDataWriterConfiguration) + @with_config(spec=CsvFormatConfiguration) def __init__( self, f: IO[Any], @@ -783,3 +784,16 @@ def get_best_writer_spec( return DataWriter.class_factory(file_format, item_format, native_writers).writer_spec() except DataWriterNotFound: return DataWriter.class_factory(file_format, item_format, ALL_WRITERS).writer_spec() + + +def create_import_spec( + item_file_format: TLoaderFileFormat, + possible_file_formats: Sequence[TLoaderFileFormat], +) -> FileWriterSpec: + """Creates writer spec that may be used only to import files""" + # can the item file be directly imported? + if item_file_format not in possible_file_formats: + raise SpecLookupFailed("file", possible_file_formats, item_file_format) + + spec = DataWriter.class_factory(item_file_format, "object", ALL_WRITERS).writer_spec() + return spec._replace(data_item_format="file") diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index d8361d7140..e8c341d97c 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -10,7 +10,8 @@ Protocol, get_args, ) - +from dlt.common.normalizers.typing import TNamingConventionReferenceArg +from dlt.common.typing import TLoaderFileFormat from dlt.common.configuration.utils import serialize_value from dlt.common.configuration import configspec from dlt.common.configuration.specs import ContainerInjectableContext @@ -19,19 +20,11 @@ DestinationLoadingViaStagingNotSupported, DestinationLoadingWithoutStagingNotSupported, ) -from dlt.common.utils import identity - +from dlt.common.normalizers.naming import NamingConvention from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION -# known loader file formats -# jsonl - new line separated json documents -# typed-jsonl - internal extract -> normalize format bases on jsonl -# insert_values - insert SQL statements -# sql - any sql statement -TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] TLoaderParallelismStrategy = Literal["parallel", "table-sequential", "sequential"] - ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) @@ -61,9 +54,15 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): """Recommended file size in bytes when writing extract/load files""" preferred_staging_file_format: Optional[TLoaderFileFormat] = None supported_staging_file_formats: Sequence[TLoaderFileFormat] = None + format_datetime_literal: Callable[..., str] = None escape_identifier: Callable[[str], str] = None + "Escapes table name, column name and other identifiers" escape_literal: Callable[[Any], Any] = None - format_datetime_literal: Callable[..., str] = None + "Escapes string literal" + casefold_identifier: Callable[[str], str] = str + """Casing function applied by destination to represent case insensitive identifiers.""" + has_case_sensitive_identifiers: bool = None + """Tells if destination supports case sensitive identifiers""" decimal_precision: Tuple[int, int] = None wei_precision: Tuple[int, int] = None max_identifier_length: int = None @@ -74,7 +73,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): is_max_text_data_type_length_in_bytes: bool = None supports_transactions: bool = None supports_ddl_transactions: bool = None - naming_convention: str = "snake_case" + # use naming convention in the schema + naming_convention: TNamingConventionReferenceArg = None alter_add_multi_column: bool = True supports_truncate_command: bool = True schema_supports_numeric_precision: bool = True @@ -88,6 +88,9 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): max_table_nesting: Optional[int] = None """Allows a destination to overwrite max_table_nesting from source""" + supported_merge_strategies: Sequence["TLoaderMergeStrategy"] = None # type: ignore[name-defined] # noqa: F821 + # TODO: also add `supported_replace_strategies` capability + # do not allow to create default value, destination caps must be always explicitly inserted into container can_create_default: ClassVar[bool] = False @@ -96,9 +99,15 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None """The destination can override the parallelism strategy""" + def generates_case_sensitive_identifiers(self) -> bool: + """Tells if capabilities as currently adjusted, will generate case sensitive identifiers""" + # must have case sensitive support and folding function must preserve casing + return self.has_case_sensitive_identifiers and self.casefold_identifier is str + @staticmethod def generic_capabilities( preferred_loader_file_format: TLoaderFileFormat = None, + naming_convention: TNamingConventionReferenceArg = None, loader_file_format_adapter: LoaderFileFormatAdapter = None, supported_table_formats: Sequence["TTableFormat"] = None, # type: ignore[name-defined] # noqa: F821 ) -> "DestinationCapabilitiesContext": @@ -110,9 +119,12 @@ def generic_capabilities( caps.loader_file_format_adapter = loader_file_format_adapter caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] + caps.naming_convention = naming_convention or caps.naming_convention + caps.escape_identifier = str caps.supported_table_formats = supported_table_formats or [] - caps.escape_identifier = identity caps.escape_literal = serialize_value + caps.casefold_identifier = str + caps.has_case_sensitive_identifiers = True caps.format_datetime_literal = format_datetime_literal caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (EVM_DECIMAL_PRECISION, 0) diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index cd8f50bcce..49c9b822e3 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -124,3 +124,20 @@ def __init__(self, schema_name: str, version_hash: str, stored_version_hash: str " schema in load package, you should first save it into schema storage. You can also" " use schema._bump_version() in test code to remove modified flag." ) + + +class DestinationCapabilitiesException(DestinationException): + pass + + +class DestinationInvalidFileFormat(DestinationTerminalException): + def __init__( + self, destination_type: str, file_format: str, file_name: str, message: str + ) -> None: + self.destination_type = destination_type + self.file_format = file_format + self.message = message + super().__init__( + f"Destination {destination_type} cannot process file {file_name} with format" + f" {file_format}: {message}" + ) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index fd4c939bae..e8e012f4fd 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -28,30 +28,28 @@ from dlt.common import logger from dlt.common.typing import DataFrame, ArrowTable +from dlt.common.configuration.specs.base_configuration import extract_inner_hint +from dlt.common.destination.utils import verify_schema_capabilities +from dlt.common.normalizers.naming import NamingConvention + from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import MERGE_STRATEGIES -from dlt.common.schema.exceptions import SchemaException from dlt.common.schema.utils import ( + get_file_format, get_write_disposition, get_table_format, - get_columns_names_with_prop, - has_column_with_prop, - get_first_column_name_with_prop, ) from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.destination.exceptions import ( - IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule, DestinationSchemaTampered, ) -from dlt.common.schema.utils import is_complete_column from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.storages.load_package import LoadJobInfo +from dlt.common.storages.load_package import LoadJobInfo, TPipelineStateDoc TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] @@ -70,14 +68,69 @@ class StorageSchemaInfo(NamedTuple): inserted_at: datetime.datetime schema: str + @classmethod + def from_normalized_mapping( + cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention + ) -> "StorageSchemaInfo": + """Instantiate this class from mapping where keys are normalized according to given naming convention + + Args: + normalized_doc: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...}) + naming_convention: Naming convention that was used to normalize keys + + Returns: + StorageSchemaInfo: Instance of this class + """ + return cls( + version_hash=normalized_doc[naming_convention.normalize_identifier("version_hash")], + schema_name=normalized_doc[naming_convention.normalize_identifier("schema_name")], + version=normalized_doc[naming_convention.normalize_identifier("version")], + engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")], + inserted_at=normalized_doc[naming_convention.normalize_identifier("inserted_at")], + schema=normalized_doc[naming_convention.normalize_identifier("schema")], + ) + -class StateInfo(NamedTuple): +@dataclasses.dataclass +class StateInfo: version: int engine_version: int pipeline_name: str state: str created_at: datetime.datetime - dlt_load_id: str = None + version_hash: Optional[str] = None + _dlt_load_id: Optional[str] = None + + def as_doc(self) -> TPipelineStateDoc: + doc: TPipelineStateDoc = dataclasses.asdict(self) # type: ignore[assignment] + if self._dlt_load_id is None: + doc.pop("_dlt_load_id") + if self.version_hash is None: + doc.pop("version_hash") + return doc + + @classmethod + def from_normalized_mapping( + cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention + ) -> "StateInfo": + """Instantiate this class from mapping where keys are normalized according to given naming convention + + Args: + normalized_doc: Mapping with normalized keys (e.g. {Version: ..., PipelineName: ...}) + naming_convention: Naming convention that was used to normalize keys + + Returns: + StateInfo: Instance of this class + """ + return cls( + version=normalized_doc[naming_convention.normalize_identifier("version")], + engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")], + pipeline_name=normalized_doc[naming_convention.normalize_identifier("pipeline_name")], + state=normalized_doc[naming_convention.normalize_identifier("state")], + created_at=normalized_doc[naming_convention.normalize_identifier("created_at")], + version_hash=normalized_doc.get(naming_convention.normalize_identifier("version_hash")), + _dlt_load_id=normalized_doc.get(naming_convention.normalize_identifier("_dlt_load_id")), + ) @configspec @@ -102,6 +155,25 @@ def __str__(self) -> str: def on_resolved(self) -> None: self.destination_name = self.destination_name or self.destination_type + @classmethod + def credentials_type( + cls, config: "DestinationClientConfiguration" = None + ) -> Type[CredentialsConfiguration]: + """Figure out credentials type, using hint resolvers for dynamic types + + For correct type resolution of filesystem, config should have bucket_url populated + """ + key = "credentials" + type_ = cls.get_resolvable_fields()[key] + if key in cls.__hint_resolvers__ and config is not None: + try: + # Type hint for this field is created dynamically + type_ = cls.__hint_resolvers__[key](config) + except Exception: + # we suppress failed hint resolutions + pass + return extract_inner_hint(type_) + @configspec class DestinationClientDwhConfiguration(DestinationClientConfiguration): @@ -117,6 +189,8 @@ class DestinationClientDwhConfiguration(DestinationClientConfiguration): """name of default schema to be used to name effective dataset to load data to""" replace_strategy: TLoaderReplaceStrategy = "truncate-and-insert" """How to handle replace disposition for this destination, can be classic or staging""" + staging_dataset_name_layout: str = "%s_staging" + """Layout for staging dataset, where %s is replaced with dataset name. placeholder is optional""" def _bind_dataset_name( self: TDestinationDwhClient, dataset_name: str, default_schema_name: str = None @@ -134,21 +208,37 @@ def normalize_dataset_name(self, schema: Schema) -> str: If default schema name is None or equals schema.name, the schema suffix is skipped. """ - if not schema.name: + dataset_name = self._make_dataset_name(schema.name) + return ( + dataset_name + if not dataset_name + else schema.naming.normalize_table_identifier(dataset_name) + ) + + def normalize_staging_dataset_name(self, schema: Schema) -> str: + """Builds staging dataset name out of dataset_name and staging_dataset_name_layout.""" + if "%s" in self.staging_dataset_name_layout: + # if dataset name is empty, staging dataset name is also empty + dataset_name = self._make_dataset_name(schema.name) + if not dataset_name: + return dataset_name + # fill the placeholder + dataset_name = self.staging_dataset_name_layout % dataset_name + else: + # no placeholder, then layout is a full name. so you can have a single staging dataset + dataset_name = self.staging_dataset_name_layout + + return schema.naming.normalize_table_identifier(dataset_name) + + def _make_dataset_name(self, schema_name: str) -> str: + if not schema_name: raise ValueError("schema_name is None or empty") # if default schema is None then suffix is not added - if self.default_schema_name is not None and schema.name != self.default_schema_name: - # also normalize schema name. schema name is Python identifier and here convention may be different - return schema.naming.normalize_table_identifier( - (self.dataset_name or "") + "_" + schema.name - ) + if self.default_schema_name is not None and schema_name != self.default_schema_name: + return (self.dataset_name or "") + "_" + schema_name - return ( - self.dataset_name - if not self.dataset_name - else schema.naming.normalize_table_identifier(self.dataset_name) - ) + return self.dataset_name @configspec @@ -257,11 +347,15 @@ class DoNothingFollowupJob(DoNothingJob, FollowupJob): class JobClientBase(ABC): - capabilities: ClassVar[DestinationCapabilitiesContext] = None - - def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: + def __init__( + self, + schema: Schema, + config: DestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: self.schema = schema self.config = config + self.capabilities = capabilities @abstractmethod def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: @@ -319,7 +413,7 @@ def should_truncate_table_before_load(self, table: TTableSchema) -> bool: def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] @@ -340,96 +434,13 @@ def __exit__( pass def _verify_schema(self) -> None: - """Verifies and cleans up a schema before loading - - * Checks all table and column name lengths against destination capabilities and raises on too long identifiers - * Removes and warns on (unbound) incomplete columns - """ - - for table in self.schema.data_tables(): - table_name = table["name"] - if len(table_name) > self.capabilities.max_identifier_length: - raise IdentifierTooLongException( - self.config.destination_type, - "table", - table_name, - self.capabilities.max_identifier_length, - ) - if table.get("write_disposition") == "merge": - if "x-merge-strategy" in table and table["x-merge-strategy"] not in MERGE_STRATEGIES: # type: ignore[typeddict-item] - raise SchemaException( - f'"{table["x-merge-strategy"]}" is not a valid merge strategy. ' # type: ignore[typeddict-item] - f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""" - ) - if ( - table.get("x-merge-strategy") == "delete-insert" - and not has_column_with_prop(table, "primary_key") - and not has_column_with_prop(table, "merge_key") - ): - logger.warning( - f"Table {table_name} has `write_disposition` set to `merge`" - " and `merge_strategy` set to `delete-insert`, but no primary or" - " merge keys defined." - " dlt will fall back to `append` for this table." - ) - if has_column_with_prop(table, "hard_delete"): - if len(get_columns_names_with_prop(table, "hard_delete")) > 1: - raise SchemaException( - f'Found multiple "hard_delete" column hints for table "{table_name}" in' - f' schema "{self.schema.name}" while only one is allowed:' - f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.' - ) - if table.get("write_disposition") in ("replace", "append"): - logger.warning( - f"""The "hard_delete" column hint for column "{get_first_column_name_with_prop(table, 'hard_delete')}" """ - f'in table "{table_name}" with write disposition' - f' "{table.get("write_disposition")}"' - f' in schema "{self.schema.name}" will be ignored.' - ' The "hard_delete" column hint is only applied when using' - ' the "merge" write disposition.' - ) - if has_column_with_prop(table, "dedup_sort"): - if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: - raise SchemaException( - f'Found multiple "dedup_sort" column hints for table "{table_name}" in' - f' schema "{self.schema.name}" while only one is allowed:' - f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.' - ) - if table.get("write_disposition") in ("replace", "append"): - logger.warning( - f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ - f'in table "{table_name}" with write disposition' - f' "{table.get("write_disposition")}"' - f' in schema "{self.schema.name}" will be ignored.' - ' The "dedup_sort" column hint is only applied when using' - ' the "merge" write disposition.' - ) - if table.get("write_disposition") == "merge" and not has_column_with_prop( - table, "primary_key" - ): - logger.warning( - f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ - f'in table "{table_name}" with write disposition' - f' "{table.get("write_disposition")}"' - f' in schema "{self.schema.name}" will be ignored.' - ' The "dedup_sort" column hint is only applied when a' - " primary key has been specified." - ) - for column_name, column in dict(table["columns"]).items(): - if len(column_name) > self.capabilities.max_column_identifier_length: - raise IdentifierTooLongException( - self.config.destination_type, - "column", - f"{table_name}.{column_name}", - self.capabilities.max_column_identifier_length, - ) - if not is_complete_column(column): - logger.warning( - f"A column {column_name} in table {table_name} in schema" - f" {self.schema.name} is incomplete. It was not bound to the data during" - " normalizations stage and its data type is unknown. Did you add this" - " column manually in code ie. as a merge key?" - ) + """Verifies schema before loading""" + if exceptions := verify_schema_capabilities( + self.schema, self.capabilities, self.config.destination_type, warnings=False + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] def prepare_load_table( self, table_name: str, prepare_for_staging: bool = False @@ -442,9 +453,11 @@ def prepare_load_table( table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) if "table_format" not in table: table["table_format"] = get_table_format(self.schema.tables, table_name) + if "file_format" not in table: + table["file_format"] = get_file_format(self.schema.tables, table_name) return table except KeyError: - raise UnknownTableException(table_name) + raise UnknownTableException(self.schema.name, table_name) class WithStateSync(ABC): @@ -515,7 +528,10 @@ class Destination(ABC, Generic[TDestinationConfig, TDestinationClient]): with credentials and other config params. """ - config_params: Optional[Dict[str, Any]] = None + config_params: Dict[str, Any] + """Explicit config params, overriding any injected or default values.""" + caps_params: Dict[str, Any] + """Explicit capabilities params, overriding any default values for this destination""" def __init__(self, **kwargs: Any) -> None: # Create initial unresolved destination config @@ -523,9 +539,27 @@ def __init__(self, **kwargs: Any) -> None: # to supersede config from the environment or pipeline args sig = inspect.signature(self.__class__.__init__) params = sig.parameters - self.config_params = { - k: v for k, v in kwargs.items() if k not in params or v != params[k].default - } + + # get available args + spec = self.spec + spec_fields = spec.get_resolvable_fields() + caps_fields = DestinationCapabilitiesContext.get_resolvable_fields() + + # remove default kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in params or v != params[k].default} + + # warn on unknown params + for k in list(kwargs): + if k not in spec_fields and k not in caps_fields: + logger.warning( + f"When initializing destination factory of type {self.destination_type}," + f" argument {k} is not a valid field in {spec.__name__} or destination" + " capabilities" + ) + kwargs.pop(k) + + self.config_params = {k: v for k, v in kwargs.items() if k in spec_fields} + self.caps_params = {k: v for k, v in kwargs.items() if k in caps_fields} @property @abstractmethod @@ -533,9 +567,37 @@ def spec(self) -> Type[TDestinationConfig]: """A spec of destination configuration that also contains destination credentials""" ... + def capabilities( + self, config: Optional[TDestinationConfig] = None, naming: Optional[NamingConvention] = None + ) -> DestinationCapabilitiesContext: + """Destination capabilities ie. supported loader file formats, identifier name lengths, naming conventions, escape function etc. + Explicit caps arguments passed to the factory init and stored in `caps_params` are applied. + + If `config` is provided, it is used to adjust the capabilities, otherwise the explicit config composed just of `config_params` passed + to factory init is applied + If `naming` is provided, the case sensitivity and case folding are adjusted. + """ + caps = self._raw_capabilities() + caps.update(self.caps_params) + # get explicit config if final config not passed + if config is None: + # create mock credentials to avoid credentials being resolved + init_config = self.spec() + init_config.update(self.config_params) + credentials = self.spec.credentials_type(init_config)() + credentials.__is_resolved__ = True + config = self.spec(credentials=credentials) + try: + config = self.configuration(config, accept_partial=True) + except Exception: + # in rare cases partial may fail ie. when invalid native value is present + # in that case we fallback to "empty" config + pass + return self.adjust_capabilities(caps, config, naming) + @abstractmethod - def capabilities(self) -> DestinationCapabilitiesContext: - """Destination capabilities ie. supported loader file formats, identifier name lengths, naming conventions, escape function etc.""" + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + """Returns raw capabilities, before being adjusted with naming convention and config""" ... @property @@ -558,16 +620,61 @@ def client_class(self) -> Type[TDestinationClient]: """A job client class responsible for starting and resuming load jobs""" ... - def configuration(self, initial_config: TDestinationConfig) -> TDestinationConfig: + def configuration( + self, initial_config: TDestinationConfig, accept_partial: bool = False + ) -> TDestinationConfig: """Get a fully resolved destination config from the initial config""" + config = resolve_configuration( - initial_config, + initial_config or self.spec(), sections=(known_sections.DESTINATION, self.destination_name), # Already populated values will supersede resolved env config explicit_value=self.config_params, + accept_partial=accept_partial, ) return config + def client( + self, schema: Schema, initial_config: TDestinationConfig = None + ) -> TDestinationClient: + """Returns a configured instance of the destination's job client""" + config = self.configuration(initial_config) + return self.client_class(schema, config, self.capabilities(config, schema.naming)) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: TDestinationConfig, + naming: Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + """Adjust the capabilities to match the case sensitivity as requested by naming convention.""" + # if naming not provided, skip the adjustment + if not naming or not naming.is_case_sensitive: + # all destinations are configured to be case insensitive so there's nothing to adjust + return caps + if not caps.has_case_sensitive_identifiers: + if caps.casefold_identifier is str: + logger.info( + f"Naming convention {naming.name()} is case sensitive but the destination does" + " not support case sensitive identifiers. Nevertheless identifier casing will" + " be preserved in the destination schema." + ) + else: + logger.warn( + f"Naming convention {naming.name()} is case sensitive but the destination does" + " not support case sensitive identifiers. Destination will case fold all the" + f" identifiers with {caps.casefold_identifier}" + ) + else: + # adjust case folding to store casefold identifiers in the schema + if caps.casefold_identifier is not str: + caps.casefold_identifier = str + logger.info( + f"Enabling case sensitive identifiers for naming convention {naming.name()}" + ) + return caps + @staticmethod def to_name(ref: TDestinationReferenceArg) -> str: if ref is None: @@ -580,7 +687,7 @@ def to_name(ref: TDestinationReferenceArg) -> str: @staticmethod def normalize_type(destination_type: str) -> str: - """Normalizes destination type string into a canonical form. Assumes that type names without dots correspond to build in destinations.""" + """Normalizes destination type string into a canonical form. Assumes that type names without dots correspond to built in destinations.""" if "." not in destination_type: destination_type = "dlt.destinations." + destination_type # the next two lines shorten the dlt internal destination paths to dlt.destinations. @@ -593,7 +700,7 @@ def normalize_type(destination_type: str) -> str: @staticmethod def from_reference( ref: TDestinationReferenceArg, - credentials: Optional[CredentialsConfiguration] = None, + credentials: Optional[Any] = None, destination_name: Optional[str] = None, environment: Optional[str] = None, **kwargs: Any, @@ -643,11 +750,5 @@ def from_reference( raise InvalidDestinationReference(ref) from e return dest - def client( - self, schema: Schema, initial_config: TDestinationConfig = None - ) -> TDestinationClient: - """Returns a configured instance of the destination's job client""" - return self.client_class(schema, self.configuration(initial_config)) - TDestination = Destination[DestinationClientConfiguration, JobClientBase] diff --git a/dlt/common/destination/utils.py b/dlt/common/destination/utils.py new file mode 100644 index 0000000000..2c5e97df14 --- /dev/null +++ b/dlt/common/destination/utils.py @@ -0,0 +1,115 @@ +from typing import List + +from dlt.common import logger +from dlt.common.destination.exceptions import IdentifierTooLongException +from dlt.common.schema import Schema +from dlt.common.schema.exceptions import ( + SchemaIdentifierNormalizationCollision, +) +from dlt.common.schema.utils import is_complete_column +from dlt.common.typing import DictStrStr + +from .capabilities import DestinationCapabilitiesContext + + +def verify_schema_capabilities( + schema: Schema, + capabilities: DestinationCapabilitiesContext, + destination_type: str, + warnings: bool = True, +) -> List[Exception]: + """Verifies schema tables before loading against capabilities. Returns a list of exceptions representing critical problems with the schema. + It will log warnings by default. It is up to the caller to eventually raise exception + + * Checks all table and column name lengths against destination capabilities and raises on too long identifiers + * Checks if schema has collisions due to case sensitivity of the identifiers + """ + + log = logger.warning if warnings else logger.info + # collect all exceptions to show all problems in the schema + exception_log: List[Exception] = [] + # combined casing function + case_identifier = lambda ident: capabilities.casefold_identifier( + (str if capabilities.has_case_sensitive_identifiers else str.casefold)(ident) # type: ignore + ) + table_name_lookup: DictStrStr = {} + # name collision explanation + collision_msg = "Destination is case " + ( + "sensitive" if capabilities.has_case_sensitive_identifiers else "insensitive" + ) + if capabilities.casefold_identifier is not str: + collision_msg += ( + f" but it uses {capabilities.casefold_identifier} to generate case insensitive" + " identifiers. You may try to change the destination capabilities by changing the" + " `casefold_identifier` to `str`" + ) + collision_msg += ( + ". Please clean up your data before loading so the entities have different name. You can" + " also change to case insensitive naming convention. Note that in that case data from both" + " columns will be merged into one." + ) + + # check for any table clashes + for table in schema.data_tables(): + table_name = table["name"] + # detect table name conflict + cased_table_name = case_identifier(table_name) + if cased_table_name in table_name_lookup: + conflict_table_name = table_name_lookup[cased_table_name] + exception_log.append( + SchemaIdentifierNormalizationCollision( + schema.name, + table_name, + "table", + table_name, + conflict_table_name, + schema.naming.name(), + collision_msg, + ) + ) + table_name_lookup[cased_table_name] = table_name + if len(table_name) > capabilities.max_identifier_length: + exception_log.append( + IdentifierTooLongException( + destination_type, + "table", + table_name, + capabilities.max_identifier_length, + ) + ) + + column_name_lookup: DictStrStr = {} + for column_name, column in dict(table["columns"]).items(): + # detect table name conflict + cased_column_name = case_identifier(column_name) + if cased_column_name in column_name_lookup: + conflict_column_name = column_name_lookup[cased_column_name] + exception_log.append( + SchemaIdentifierNormalizationCollision( + schema.name, + table_name, + "column", + column_name, + conflict_column_name, + schema.naming.name(), + collision_msg, + ) + ) + column_name_lookup[cased_column_name] = column_name + if len(column_name) > capabilities.max_column_identifier_length: + exception_log.append( + IdentifierTooLongException( + destination_type, + "column", + f"{table_name}.{column_name}", + capabilities.max_column_identifier_length, + ) + ) + if not is_complete_column(column): + log( + f"A column {column_name} in table {table_name} in schema" + f" {schema.name} is incomplete. It was not bound to the data during" + " normalizations stage and its data type is unknown. Did you add this" + " column manually in code ie. as a merge key?" + ) + return exception_log diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index cf68e5d3d4..00d8dcc430 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -12,6 +12,7 @@ except ImportError: PydanticBaseModel = None # type: ignore[misc] +from dlt.common import known_env from dlt.common.pendulum import pendulum from dlt.common.arithmetics import Decimal from dlt.common.wei import Wei @@ -80,7 +81,7 @@ def custom_encode(obj: Any) -> str: # use PUA range to encode additional types -PUA_START = int(os.environ.get("DLT_JSON_TYPED_PUA_START", "0xf026"), 16) +PUA_START = int(os.environ.get(known_env.DLT_JSON_TYPED_PUA_START, "0xf026"), 16) _DECIMAL = chr(PUA_START) _DATETIME = chr(PUA_START + 1) @@ -191,7 +192,7 @@ def may_have_pua(line: bytes) -> bool: # pick the right impl json: SupportsJson = None -if os.environ.get("DLT_USE_JSON") == "simplejson": +if os.environ.get(known_env.DLT_USE_JSON) == "simplejson": from dlt.common.json import _simplejson as _json_d json = _json_d # type: ignore[assignment] diff --git a/dlt/common/known_env.py b/dlt/common/known_env.py new file mode 100644 index 0000000000..7ac36d252d --- /dev/null +++ b/dlt/common/known_env.py @@ -0,0 +1,25 @@ +"""Defines env variables that `dlt` uses independently of its configuration system""" + +DLT_PROJECT_DIR = "DLT_PROJECT_DIR" +"""The dlt project dir is the current working directory, '.' (current working dir) by default""" + +DLT_DATA_DIR = "DLT_DATA_DIR" +"""Gets default directory where pipelines' data (working directories) will be stored""" + +DLT_CONFIG_FOLDER = "DLT_CONFIG_FOLDER" +"""A folder (path relative to DLT_PROJECT_DIR) where config and secrets are stored""" + +DLT_DEFAULT_NAMING_NAMESPACE = "DLT_DEFAULT_NAMING_NAMESPACE" +"""Python namespace default where naming modules reside, defaults to dlt.common.normalizers.naming""" + +DLT_DEFAULT_NAMING_MODULE = "DLT_DEFAULT_NAMING_MODULE" +"""A module name with the default naming convention, defaults to snake_case""" + +DLT_DLT_ID_LENGTH_BYTES = "DLT_DLT_ID_LENGTH_BYTES" +"""The length of the _dlt_id identifier, before base64 encoding""" + +DLT_USE_JSON = "DLT_USE_JSON" +"""Type of json parser to use, defaults to orjson, may be simplejson""" + +DLT_JSON_TYPED_PUA_START = "DLT_JSON_TYPED_PUA_START" +"""Start of the unicode block within the PUA used to encode types in typed json""" diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 8a6dc68078..ee249b111c 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -348,13 +348,13 @@ def normalize_py_arrow_item( def get_normalized_arrow_fields_mapping(schema: pyarrow.Schema, naming: NamingConvention) -> StrStr: - """Normalizes schema field names and returns mapping from original to normalized name. Raises on name clashes""" + """Normalizes schema field names and returns mapping from original to normalized name. Raises on name collisions""" norm_f = naming.normalize_identifier name_mapping = {n.name: norm_f(n.name) for n in schema} # verify if names uniquely normalize normalized_names = set(name_mapping.values()) if len(name_mapping) != len(normalized_names): - raise NameNormalizationClash( + raise NameNormalizationCollision( f"Arrow schema fields normalized from {list(name_mapping.keys())} to" f" {list(normalized_names)}" ) @@ -497,7 +497,7 @@ def cast_arrow_schema_types( return schema -class NameNormalizationClash(ValueError): +class NameNormalizationCollision(ValueError): def __init__(self, reason: str) -> None: - msg = f"Arrow column name clash after input data normalization. {reason}" + msg = f"Arrow column name collision after input data normalization. {reason}" super().__init__(msg) diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py index 774a1641a7..e3e7d373c7 100644 --- a/dlt/common/libs/pydantic.py +++ b/dlt/common/libs/pydantic.py @@ -332,6 +332,7 @@ def validate_items( list_model, {"columns": "freeze"}, items, + err["msg"], ) from e # raise on freeze if err["type"] == "extra_forbidden": @@ -345,6 +346,7 @@ def validate_items( list_model, {"columns": "freeze"}, err_item, + err["msg"], ) from e elif column_mode == "discard_row": # pop at the right index @@ -366,6 +368,7 @@ def validate_items( list_model, {"data_type": "freeze"}, err_item, + err["msg"], ) from e elif data_mode == "discard_row": items.pop(err_idx - len(deleted)) @@ -403,6 +406,7 @@ def validate_item( model, {"columns": "freeze"}, item, + err["msg"], ) from e elif column_mode == "discard_row": return None @@ -420,6 +424,7 @@ def validate_item( model, {"data_type": "freeze"}, item, + err["msg"], ) from e elif data_mode == "discard_row": return None diff --git a/dlt/common/normalizers/__init__.py b/dlt/common/normalizers/__init__.py index 2ff41d4c12..af6add6a19 100644 --- a/dlt/common/normalizers/__init__.py +++ b/dlt/common/normalizers/__init__.py @@ -1,11 +1,9 @@ -from dlt.common.normalizers.configuration import NormalizersConfiguration from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig -from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers +from dlt.common.normalizers.naming import NamingConvention + __all__ = [ - "NormalizersConfiguration", + "NamingConvention", "TJSONNormalizer", "TNormalizersConfig", - "explicit_normalizers", - "import_normalizers", ] diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/normalizers/configuration.py index 54b725db1f..6011ba4774 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/normalizers/configuration.py @@ -1,9 +1,8 @@ -from typing import ClassVar, Optional, TYPE_CHECKING +from typing import ClassVar, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration, known_sections -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.normalizers.typing import TJSONNormalizer +from dlt.common.normalizers.typing import TNamingConventionReferenceArg from dlt.common.typing import DictStrAny @@ -12,22 +11,6 @@ class NormalizersConfiguration(BaseConfiguration): # always in section __section__: ClassVar[str] = known_sections.SCHEMA - naming: Optional[str] = None + naming: Optional[TNamingConventionReferenceArg] = None # Union[str, NamingConvention] json_normalizer: Optional[DictStrAny] = None - destination_capabilities: Optional[DestinationCapabilitiesContext] = None # injectable - - def on_resolved(self) -> None: - # get naming from capabilities if not present - if self.naming is None: - if self.destination_capabilities: - self.naming = self.destination_capabilities.naming_convention - # if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer - if ( - self.destination_capabilities - and self.destination_capabilities.max_table_nesting is not None - ): - self.json_normalizer = self.json_normalizer or {} - self.json_normalizer.setdefault("config", {}) - self.json_normalizer["config"][ - "max_nesting" - ] = self.destination_capabilities.max_table_nesting + allow_identifier_change_on_table_with_data: Optional[bool] = None diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index bad275ca4f..c98949322f 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -2,18 +2,24 @@ from typing import Dict, List, Mapping, Optional, Sequence, Tuple, cast, TypedDict, Any from dlt.common.json import json from dlt.common.normalizers.exceptions import InvalidJsonNormalizer -from dlt.common.normalizers.typing import TJSONNormalizer +from dlt.common.normalizers.typing import TJSONNormalizer, TRowIdType from dlt.common.normalizers.utils import generate_dlt_id, DLT_ID_LENGTH_BYTES -from dlt.common.typing import DictStrAny, DictStrStr, TDataItem, StrAny +from dlt.common.typing import DictStrAny, TDataItem, StrAny from dlt.common.schema import Schema from dlt.common.schema.typing import ( + TLoaderMergeStrategy, TColumnSchema, TColumnName, TSimpleRegex, DLT_NAME_PREFIX, ) -from dlt.common.schema.utils import column_name_validator, get_validity_column_names +from dlt.common.schema.utils import ( + column_name_validator, + get_validity_column_names, + get_columns_names_with_prop, + get_first_column_name_with_prop, +) from dlt.common.schema.exceptions import ColumnNameConflictException from dlt.common.utils import digest128, update_dict_nested from dlt.common.normalizers.json import ( @@ -23,28 +29,10 @@ ) from dlt.common.validation import validate_dict -EMPTY_KEY_IDENTIFIER = "_empty" # replace empty keys with this - - -class TDataItemRow(TypedDict, total=False): - _dlt_id: str # unique id of current row - - -class TDataItemRowRoot(TDataItemRow, total=False): - _dlt_load_id: (str) # load id to identify records loaded together that ie. need to be processed - # _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer - - -class TDataItemRowChild(TDataItemRow, total=False): - _dlt_root_id: str # unique id of top level parent - _dlt_parent_id: str # unique id of parent row - _dlt_list_idx: int # position in the list of rows - value: Any # for lists of simple types - class RelationalNormalizerConfigPropagation(TypedDict, total=False): - root: Optional[Mapping[str, TColumnName]] - tables: Optional[Mapping[str, Mapping[str, TColumnName]]] + root: Optional[Dict[TColumnName, TColumnName]] + tables: Optional[Dict[str, Dict[TColumnName, TColumnName]]] class RelationalNormalizerConfig(TypedDict, total=False): @@ -54,6 +42,23 @@ class RelationalNormalizerConfig(TypedDict, total=False): class DataItemNormalizer(DataItemNormalizerBase[RelationalNormalizerConfig]): + # known normalizer props + C_DLT_ID = "_dlt_id" + """unique id of current row""" + C_DLT_LOAD_ID = "_dlt_load_id" + """load id to identify records loaded together that ie. need to be processed""" + C_DLT_ROOT_ID = "_dlt_root_id" + """unique id of top level parent""" + C_DLT_PARENT_ID = "_dlt_parent_id" + """unique id of parent row""" + C_DLT_LIST_IDX = "_dlt_list_idx" + """position in the list of rows""" + C_VALUE = "value" + """for lists of simple types""" + + # other constants + EMPTY_KEY_IDENTIFIER = "_empty" # replace empty keys with this + normalizer_config: RelationalNormalizerConfig propagation_config: RelationalNormalizerConfigPropagation max_nesting: int @@ -63,12 +68,29 @@ def __init__(self, schema: Schema) -> None: """This item normalizer works with nested dictionaries. It flattens dictionaries and descends into lists. It yields row dictionaries at each nesting level.""" self.schema = schema + self.naming = schema.naming self._reset() def _reset(self) -> None: - self.normalizer_config = ( - self.schema._normalizers_config["json"].get("config") or {} # type: ignore[assignment] + # normalize known normalizer column identifiers + self.c_dlt_id: TColumnName = TColumnName(self.naming.normalize_identifier(self.C_DLT_ID)) + self.c_dlt_load_id: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_LOAD_ID) + ) + self.c_dlt_root_id: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_ROOT_ID) ) + self.c_dlt_parent_id: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_PARENT_ID) + ) + self.c_dlt_list_idx: TColumnName = TColumnName( + self.naming.normalize_identifier(self.C_DLT_LIST_IDX) + ) + self.c_value: TColumnName = TColumnName(self.naming.normalize_identifier(self.C_VALUE)) + + # normalize config + + self.normalizer_config = self.schema._normalizers_config["json"].get("config") or {} # type: ignore[assignment] self.propagation_config = self.normalizer_config.get("propagation", None) self.max_nesting = self.normalizer_config.get("max_nesting", 1000) self._skip_primary_key = {} @@ -103,8 +125,8 @@ def _is_complex_type(self, table_name: str, field_name: str, _r_lvl: int) -> boo return data_type == "complex" def _flatten( - self, table: str, dict_row: TDataItemRow, _r_lvl: int - ) -> Tuple[TDataItemRow, Dict[Tuple[str, ...], Sequence[Any]]]: + self, table: str, dict_row: DictStrAny, _r_lvl: int + ) -> Tuple[DictStrAny, Dict[Tuple[str, ...], Sequence[Any]]]: out_rec_row: DictStrAny = {} out_rec_list: Dict[Tuple[str, ...], Sequence[Any]] = {} schema_naming = self.schema.naming @@ -115,7 +137,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - norm_k = schema_naming.normalize_identifier(k) else: # for empty keys in the data use _ - norm_k = EMPTY_KEY_IDENTIFIER + norm_k = self.EMPTY_KEY_IDENTIFIER # if norm_k != k: # print(f"{k} -> {norm_k}") child_name = ( @@ -139,10 +161,10 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - out_rec_row[child_name] = v norm_row_dicts(dict_row, _r_lvl) - return cast(TDataItemRow, out_rec_row), out_rec_list + return out_rec_row, out_rec_list @staticmethod - def get_row_hash(row: Dict[str, Any]) -> str: + def get_row_hash(row: Dict[str, Any], subset: Optional[List[str]] = None) -> str: """Returns hash of row. Hash includes column names and values and is ordered by column name. @@ -150,6 +172,8 @@ def get_row_hash(row: Dict[str, Any]) -> str: Can be used as deterministic row identifier. """ row_filtered = {k: v for k, v in row.items() if not k.startswith(DLT_NAME_PREFIX)} + if subset is not None: + row_filtered = {k: v for k, v in row.items() if k in subset} row_str = json.dumps(row_filtered, sort_keys=True) return digest128(row_str, DLT_ID_LENGTH_BYTES) @@ -160,7 +184,7 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES) @staticmethod - def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDataItemRowChild: + def _link_row(row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny: assert parent_row_id row["_dlt_parent_id"] = parent_row_id row["_dlt_list_idx"] = list_idx @@ -168,31 +192,52 @@ def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDat return row @staticmethod - def _extend_row(extend: DictStrAny, row: TDataItemRow) -> None: - row.update(extend) # type: ignore + def _extend_row(extend: DictStrAny, row: DictStrAny) -> None: + row.update(extend) def _add_row_id( - self, table: str, row: TDataItemRow, parent_row_id: str, pos: int, _r_lvl: int + self, + table: str, + dict_row: DictStrAny, + flattened_row: DictStrAny, + parent_row_id: str, + pos: int, + _r_lvl: int, ) -> str: - # row_id is always random, no matter if primary_key is present or not - row_id = generate_dlt_id() - if _r_lvl > 0: - primary_key = self.schema.filter_row_with_hint(table, "primary_key", row) - if not primary_key: - # child table row deterministic hash - row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) - # link to parent table - DataItemNormalizer._link_row(cast(TDataItemRowChild, row), parent_row_id, pos) - row["_dlt_id"] = row_id + primary_key = False + if _r_lvl > 0: # child table + primary_key = bool( + self.schema.filter_row_with_hint(table, "primary_key", flattened_row) + ) + row_id_type = self._get_row_id_type(self.schema, table, primary_key, _r_lvl) + + if row_id_type == "random": + row_id = generate_dlt_id() + else: + if _r_lvl == 0: # root table + if row_id_type in ("key_hash", "row_hash"): + subset = None + if row_id_type == "key_hash": + subset = self._get_primary_key(self.schema, table) + # base hash on `dict_row` instead of `flattened_row` + # so changes in child tables lead to new row id + row_id = self.get_row_hash(dict_row, subset=subset) + elif _r_lvl > 0: # child table + if row_id_type == "row_hash": + row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) + # link to parent table + DataItemNormalizer._link_row(flattened_row, parent_row_id, pos) + + flattened_row[self.c_dlt_id] = row_id return row_id - def _get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> StrAny: + def _get_propagated_values(self, table: str, row: DictStrAny, _r_lvl: int) -> StrAny: extend: DictStrAny = {} config = self.propagation_config if config: # mapping(k:v): propagate property with name "k" as property with name "v" in child table - mappings: DictStrStr = {} + mappings: Dict[TColumnName, TColumnName] = {} if _r_lvl == 0: mappings.update(config.get("root") or {}) if table in (config.get("tables") or {}): @@ -200,7 +245,7 @@ def _get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> # look for keys and create propagation as values for prop_from, prop_as in mappings.items(): if prop_from in row: - extend[prop_as] = row[prop_from] # type: ignore + extend[prop_as] = row[prop_from] return extend @@ -214,7 +259,7 @@ def _normalize_list( parent_row_id: Optional[str] = None, _r_lvl: int = 0, ) -> TNormalizedRowIterator: - v: TDataItemRowChild = None + v: DictStrAny = None table = self.schema.naming.shorten_fragments(*parent_path, *ident_path) for idx, v in enumerate(seq): @@ -238,36 +283,31 @@ def _normalize_list( # list of simple types child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) wrap_v = wrap_in_dict(v) - wrap_v["_dlt_id"] = child_row_hash + wrap_v[self.c_dlt_id] = child_row_hash e = DataItemNormalizer._link_row(wrap_v, parent_row_id, idx) DataItemNormalizer._extend_row(extend, e) yield (table, self.schema.naming.shorten_fragments(*parent_path)), e def _normalize_row( self, - dict_row: TDataItemRow, + dict_row: DictStrAny, extend: DictStrAny, ident_path: Tuple[str, ...], parent_path: Tuple[str, ...] = (), parent_row_id: Optional[str] = None, pos: Optional[int] = None, _r_lvl: int = 0, - row_hash: bool = False, ) -> TNormalizedRowIterator: schema = self.schema table = schema.naming.shorten_fragments(*parent_path, *ident_path) - # compute row hash and set as row id - if row_hash: - row_id = self.get_row_hash(dict_row) # type: ignore[arg-type] - dict_row["_dlt_id"] = row_id # flatten current row and extract all lists to recur into flattened_row, lists = self._flatten(table, dict_row, _r_lvl) # always extend row DataItemNormalizer._extend_row(extend, flattened_row) # infer record hash or leave existing primary key if present - row_id = flattened_row.get("_dlt_id", None) + row_id = flattened_row.get(self.c_dlt_id, None) if not row_id: - row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl) + row_id = self._add_row_id(table, dict_row, flattened_row, parent_row_id, pos, _r_lvl) # find fields to propagate to child tables in config extend.update(self._get_propagated_values(table, flattened_row, _r_lvl)) @@ -292,43 +332,55 @@ def _normalize_row( ) def extend_schema(self) -> None: - # validate config + """Extends Schema with normalizer-specific hints and settings. + + This method is called by Schema when instance is created or restored from storage. + """ config = cast( RelationalNormalizerConfig, self.schema._normalizers_config["json"].get("config") or {}, ) DataItemNormalizer._validate_normalizer_config(self.schema, config) - # quick check to see if hints are applied - default_hints = self.schema.settings.get("default_hints") or {} - if "not_null" in default_hints and "^_dlt_id$" in default_hints["not_null"]: - return - # add hints - self.schema.merge_hints( + # add hints, do not compile. + self.schema._merge_hints( { "not_null": [ - TSimpleRegex("_dlt_id"), - TSimpleRegex("_dlt_root_id"), - TSimpleRegex("_dlt_parent_id"), - TSimpleRegex("_dlt_list_idx"), - TSimpleRegex("_dlt_load_id"), + TSimpleRegex(self.c_dlt_id), + TSimpleRegex(self.c_dlt_root_id), + TSimpleRegex(self.c_dlt_parent_id), + TSimpleRegex(self.c_dlt_list_idx), + TSimpleRegex(self.c_dlt_load_id), ], - "foreign_key": [TSimpleRegex("_dlt_parent_id")], - "root_key": [TSimpleRegex("_dlt_root_id")], - "unique": [TSimpleRegex("_dlt_id")], - } + "foreign_key": [TSimpleRegex(self.c_dlt_parent_id)], + "root_key": [TSimpleRegex(self.c_dlt_root_id)], + "unique": [TSimpleRegex(self.c_dlt_id)], + }, + normalize_identifiers=False, # already normalized ) for table_name in self.schema.tables.keys(): self.extend_table(table_name) def extend_table(self, table_name: str) -> None: - # if the table has a merge w_d, add propagation info to normalizer + """If the table has a merge write disposition, add propagation info to normalizer + + Called by Schema when new table is added to schema or table is updated with partial table. + Table name should be normalized. + """ table = self.schema.tables.get(table_name) if not table.get("parent") and table.get("write_disposition") == "merge": DataItemNormalizer.update_normalizer_config( self.schema, - {"propagation": {"tables": {table_name: {"_dlt_id": TColumnName("_dlt_root_id")}}}}, + { + "propagation": { + "tables": { + table_name: { + TColumnName(self.c_dlt_id): TColumnName(self.c_dlt_root_id) + } + } + } + }, ) def normalize_data_item( @@ -338,21 +390,18 @@ def normalize_data_item( if not isinstance(item, dict): item = wrap_in_dict(item) # we will extend event with all the fields necessary to load it as root row - row = cast(TDataItemRowRoot, item) + row = cast(DictStrAny, item) # identify load id if loaded data must be processed after loading incrementally - row["_dlt_load_id"] = load_id - # determine if row hash should be used as dlt id - row_hash = False - if self._is_scd2_table(self.schema, table_name): - row_hash = self._dlt_id_is_row_hash(self.schema, table_name) + row[self.c_dlt_load_id] = load_id + if self._get_merge_strategy(self.schema, table_name) == "scd2": self._validate_validity_column_names( - self._get_validity_column_names(self.schema, table_name), item + self.schema.name, self._get_validity_column_names(self.schema, table_name), item ) + yield from self._normalize_row( - cast(TDataItemRowChild, row), + row, {}, (self.schema.naming.normalize_table_identifier(table_name),), - row_hash=row_hash, ) @classmethod @@ -365,12 +414,12 @@ def ensure_this_normalizer(cls, norm_config: TJSONNormalizer) -> None: @classmethod def update_normalizer_config(cls, schema: Schema, config: RelationalNormalizerConfig) -> None: cls._validate_normalizer_config(schema, config) - norm_config = schema._normalizers_config["json"] - cls.ensure_this_normalizer(norm_config) - if "config" in norm_config: - update_dict_nested(norm_config["config"], config) # type: ignore + existing_config = schema._normalizers_config["json"] + cls.ensure_this_normalizer(existing_config) + if "config" in existing_config: + update_dict_nested(existing_config["config"], config) # type: ignore else: - norm_config["config"] = config + existing_config["config"] = config @classmethod def get_normalizer_config(cls, schema: Schema) -> RelationalNormalizerConfig: @@ -380,6 +429,29 @@ def get_normalizer_config(cls, schema: Schema) -> RelationalNormalizerConfig: @staticmethod def _validate_normalizer_config(schema: Schema, config: RelationalNormalizerConfig) -> None: + """Normalizes all known column identifiers according to the schema and then validates the configuration""" + + def _normalize_prop( + mapping: Mapping[TColumnName, TColumnName] + ) -> Dict[TColumnName, TColumnName]: + return { + TColumnName(schema.naming.normalize_path(from_col)): TColumnName( + schema.naming.normalize_path(to_col) + ) + for from_col, to_col in mapping.items() + } + + # normalize the identifiers first + propagation_config = config.get("propagation") + if propagation_config: + if "root" in propagation_config: + propagation_config["root"] = _normalize_prop(propagation_config["root"]) + if "tables" in propagation_config: + for table_name in propagation_config["tables"]: + propagation_config["tables"][table_name] = _normalize_prop( + propagation_config["tables"][table_name] + ) + validate_dict( RelationalNormalizerConfig, config, @@ -397,11 +469,18 @@ def _get_table_nesting_level(schema: Schema, table_name: str) -> Optional[int]: @staticmethod @lru_cache(maxsize=None) - def _is_scd2_table(schema: Schema, table_name: str) -> bool: - if table_name in schema.data_table_names(): - if schema.get_table(table_name).get("x-merge-strategy") == "scd2": - return True - return False + def _get_merge_strategy(schema: Schema, table_name: str) -> Optional[TLoaderMergeStrategy]: + if table_name in schema.data_table_names(include_incomplete=True): + return schema.get_table(table_name).get("x-merge-strategy") # type: ignore[return-value] + return None + + @staticmethod + @lru_cache(maxsize=None) + def _get_primary_key(schema: Schema, table_name: str) -> List[str]: + if table_name not in schema.tables: + return [] + table = schema.get_table(table_name) + return get_columns_names_with_prop(table, "primary_key", include_incomplete=True) @staticmethod @lru_cache(maxsize=None) @@ -410,21 +489,35 @@ def _get_validity_column_names(schema: Schema, table_name: str) -> List[Optional @staticmethod @lru_cache(maxsize=None) - def _dlt_id_is_row_hash(schema: Schema, table_name: str) -> bool: - return ( - schema.get_table(table_name)["columns"] # type: ignore[return-value] - .get("_dlt_id", {}) - .get("x-row-version", False) - ) + def _get_row_id_type( + schema: Schema, table_name: str, primary_key: bool, _r_lvl: int + ) -> TRowIdType: + if _r_lvl == 0: # root table + merge_strategy = DataItemNormalizer._get_merge_strategy(schema, table_name) + if merge_strategy == "upsert": + return "key_hash" + elif merge_strategy == "scd2": + x_row_version_col = get_first_column_name_with_prop( + schema.get_table(table_name), + "x-row-version", + include_incomplete=True, + ) + if x_row_version_col == DataItemNormalizer.C_DLT_ID: + return "row_hash" + elif _r_lvl > 0: # child table + if not primary_key: + return "row_hash" + return "random" @staticmethod def _validate_validity_column_names( - validity_column_names: List[Optional[str]], item: TDataItem + schema_name: str, validity_column_names: List[Optional[str]], item: TDataItem ) -> None: """Raises exception if configured validity column name appears in data item.""" for validity_column_name in validity_column_names: if validity_column_name in item.keys(): raise ColumnNameConflictException( + schema_name, "Found column in data item with same name as validity column" - f' "{validity_column_name}".' + f' "{validity_column_name}".', ) diff --git a/dlt/common/normalizers/naming/__init__.py b/dlt/common/normalizers/naming/__init__.py index 967fb9643e..2b3ecd74d0 100644 --- a/dlt/common/normalizers/naming/__init__.py +++ b/dlt/common/normalizers/naming/__init__.py @@ -1,3 +1,3 @@ -from .naming import SupportsNamingConvention, NamingConvention +from .naming import NamingConvention -__all__ = ["SupportsNamingConvention", "NamingConvention"] +__all__ = ["NamingConvention"] diff --git a/dlt/common/normalizers/naming/direct.py b/dlt/common/normalizers/naming/direct.py index 0998650852..fc146dbc4c 100644 --- a/dlt/common/normalizers/naming/direct.py +++ b/dlt/common/normalizers/naming/direct.py @@ -1,20 +1,23 @@ -from typing import Any, Sequence +from typing import ClassVar from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention class NamingConvention(BaseNamingConvention): - PATH_SEPARATOR = "▶" + """Case sensitive naming convention that maps source identifiers to destination identifiers with + only minimal changes. New line characters, double and single quotes are replaced with underscores. - _CLEANUP_TABLE = str.maketrans(".\n\r'\"▶", "______") + Uses ▶ as path separator. + """ + + PATH_SEPARATOR: ClassVar[str] = "▶" + _CLEANUP_TABLE = str.maketrans("\n\r'\"▶", "_____") def normalize_identifier(self, identifier: str) -> str: identifier = super().normalize_identifier(identifier) norm_identifier = identifier.translate(self._CLEANUP_TABLE) return self.shorten_identifier(norm_identifier, identifier, self.max_length) - def make_path(self, *identifiers: Any) -> str: - return self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) - - def break_path(self, path: str) -> Sequence[str]: - return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] + @property + def is_case_sensitive(self) -> bool: + return True diff --git a/dlt/common/normalizers/naming/duck_case.py b/dlt/common/normalizers/naming/duck_case.py index 063482a799..3801660ba8 100644 --- a/dlt/common/normalizers/naming/duck_case.py +++ b/dlt/common/normalizers/naming/duck_case.py @@ -5,8 +5,15 @@ class NamingConvention(SnakeCaseNamingConvention): + """Case sensitive naming convention preserving all unicode characters except new line(s). Uses __ for path + separation and will replace multiple underscores with a single one. + """ + _CLEANUP_TABLE = str.maketrans('\n\r"', "___") - _RE_LEADING_DIGITS = None # do not remove leading digits + + @property + def is_case_sensitive(self) -> bool: + return True @staticmethod @lru_cache(maxsize=None) @@ -17,5 +24,5 @@ def _normalize_identifier(identifier: str, max_length: int) -> str: # shorten identifier return NamingConvention.shorten_identifier( - NamingConvention._RE_UNDERSCORES.sub("_", normalized_ident), identifier, max_length + NamingConvention.RE_UNDERSCORES.sub("_", normalized_ident), identifier, max_length ) diff --git a/dlt/common/normalizers/naming/exceptions.py b/dlt/common/normalizers/naming/exceptions.py index 572fc7e0d0..0b22ae2dd5 100644 --- a/dlt/common/normalizers/naming/exceptions.py +++ b/dlt/common/normalizers/naming/exceptions.py @@ -5,21 +5,33 @@ class NormalizersException(DltException): pass -class UnknownNamingModule(NormalizersException): +class UnknownNamingModule(ImportError, NormalizersException): def __init__(self, naming_module: str) -> None: self.naming_module = naming_module if "." in naming_module: msg = f"Naming module {naming_module} could not be found and imported" else: - msg = f"Naming module {naming_module} is not one of the standard dlt naming convention" + msg = ( + f"Naming module {naming_module} is not one of the standard dlt naming conventions" + " and could not be locally imported" + ) super().__init__(msg) -class InvalidNamingModule(NormalizersException): - def __init__(self, naming_module: str) -> None: +class NamingTypeNotFound(ImportError, NormalizersException): + def __init__(self, naming_module: str, naming_class: str) -> None: + self.naming_module = naming_module + self.naming_class = naming_class + msg = f"In naming module '{naming_module}' type '{naming_class}' does not exist" + super().__init__(msg) + + +class InvalidNamingType(NormalizersException): + def __init__(self, naming_module: str, naming_class: str) -> None: self.naming_module = naming_module + self.naming_class = naming_class msg = ( - f"Naming module {naming_module} does not implement required SupportsNamingConvention" - " protocol" + f"In naming module '{naming_module}' the class '{naming_class}' is not a" + " NamingConvention" ) super().__init__(msg) diff --git a/dlt/common/normalizers/naming/naming.py b/dlt/common/normalizers/naming/naming.py index fccb147981..5ae5847963 100644 --- a/dlt/common/normalizers/naming/naming.py +++ b/dlt/common/normalizers/naming/naming.py @@ -3,16 +3,28 @@ from functools import lru_cache import math import hashlib -from typing import Any, List, Protocol, Sequence, Type +from typing import Sequence, ClassVar class NamingConvention(ABC): - _TR_TABLE = bytes.maketrans(b"/+", b"ab") - _DEFAULT_COLLISION_PROB = 0.001 + """Initializes naming convention to generate identifier with `max_length` if specified. Base naming convention + is case sensitive by default + """ + + _TR_TABLE: ClassVar[bytes] = bytes.maketrans(b"/+", b"ab") + _DEFAULT_COLLISION_PROB: ClassVar[float] = 0.001 + PATH_SEPARATOR: ClassVar[str] = "__" + """Subsequent nested fields will be separated with the string below, applies both to field and table names""" def __init__(self, max_length: int = None) -> None: self.max_length = max_length + @property + @abstractmethod + def is_case_sensitive(self) -> bool: + """Tells if given naming convention is producing case insensitive or case sensitive identifiers.""" + pass + @abstractmethod def normalize_identifier(self, identifier: str) -> str: """Normalizes and shortens the identifier according to naming convention in this function code""" @@ -27,15 +39,13 @@ def normalize_table_identifier(self, identifier: str) -> str: """Normalizes and shortens identifier that will function as a dataset, table or schema name, defaults to `normalize_identifier`""" return self.normalize_identifier(identifier) - @abstractmethod def make_path(self, *identifiers: str) -> str: """Builds path out of identifiers. Identifiers are neither normalized nor shortened""" - pass + return self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) - @abstractmethod def break_path(self, path: str) -> Sequence[str]: """Breaks path into sequence of identifiers""" - pass + return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] def normalize_path(self, path: str) -> str: """Breaks path into identifiers, normalizes components, reconstitutes and shortens the path""" @@ -58,6 +68,21 @@ def shorten_fragments(self, *normalized_idents: str) -> str: path_str = self.make_path(*normalized_idents) return self.shorten_identifier(path_str, path_str, self.max_length) + @classmethod + def name(cls) -> str: + """Naming convention name is the name of the module in which NamingConvention is defined""" + if cls.__module__.startswith("dlt.common.normalizers.naming."): + # return last component + return cls.__module__.split(".")[-1] + return cls.__module__ + + def __str__(self) -> str: + name = self.name() + name += "_cs" if self.is_case_sensitive else "_ci" + if self.max_length: + name += f"_{self.max_length}" + return name + @staticmethod @lru_cache(maxsize=None) def shorten_identifier( @@ -100,10 +125,3 @@ def _trim_and_tag(identifier: str, tag: str, max_length: int) -> str: ) assert len(identifier) == max_length return identifier - - -class SupportsNamingConvention(Protocol): - """Expected of modules defining naming convention""" - - NamingConvention: Type[NamingConvention] - """A class with a name NamingConvention deriving from normalizers.naming.NamingConvention""" diff --git a/dlt/common/normalizers/naming/snake_case.py b/dlt/common/normalizers/naming/snake_case.py index b3c65e9b8d..d38841a238 100644 --- a/dlt/common/normalizers/naming/snake_case.py +++ b/dlt/common/normalizers/naming/snake_case.py @@ -1,42 +1,54 @@ import re -from typing import Any, List, Sequence from functools import lru_cache +from typing import ClassVar from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention +from dlt.common.normalizers.naming.sql_cs_v1 import ( + RE_UNDERSCORES, + RE_LEADING_DIGITS, + RE_NON_ALPHANUMERIC, +) +from dlt.common.typing import REPattern class NamingConvention(BaseNamingConvention): - _RE_UNDERSCORES = re.compile("__+") - _RE_LEADING_DIGITS = re.compile(r"^\d+") - # _RE_ENDING_UNDERSCORES = re.compile(r"_+$") - _RE_NON_ALPHANUMERIC = re.compile(r"[^a-zA-Z\d_]+") + """Case insensitive naming convention, converting source identifiers into lower case snake case with reduced alphabet. + + - Spaces around identifier are trimmed + - Removes all ascii characters except ascii alphanumerics and underscores + - Prepends `_` if name starts with number. + - Multiples of `_` are converted into single `_`. + - Replaces all trailing `_` with `x` + - Replaces `+` and `*` with `x`, `-` with `_`, `@` with `a` and `|` with `l` + + Uses __ as patent-child separator for tables and flattened column names. + """ + + RE_UNDERSCORES: ClassVar[REPattern] = RE_UNDERSCORES + RE_LEADING_DIGITS: ClassVar[REPattern] = RE_LEADING_DIGITS + RE_NON_ALPHANUMERIC: ClassVar[REPattern] = RE_NON_ALPHANUMERIC + _SNAKE_CASE_BREAK_1 = re.compile("([^_])([A-Z][a-z]+)") _SNAKE_CASE_BREAK_2 = re.compile("([a-z0-9])([A-Z])") _REDUCE_ALPHABET = ("+-*@|", "x_xal") _TR_REDUCE_ALPHABET = str.maketrans(_REDUCE_ALPHABET[0], _REDUCE_ALPHABET[1]) - # subsequent nested fields will be separated with the string below, applies both to field and table names - PATH_SEPARATOR = "__" + @property + def is_case_sensitive(self) -> bool: + return False def normalize_identifier(self, identifier: str) -> str: identifier = super().normalize_identifier(identifier) # print(f"{identifier} -> {self.shorten_identifier(identifier, self.max_length)} ({self.max_length})") return self._normalize_identifier(identifier, self.max_length) - def make_path(self, *identifiers: str) -> str: - # only non empty identifiers participate - return self.PATH_SEPARATOR.join(filter(lambda x: x.strip(), identifiers)) - - def break_path(self, path: str) -> Sequence[str]: - return [ident for ident in path.split(self.PATH_SEPARATOR) if ident.strip()] - @staticmethod @lru_cache(maxsize=None) def _normalize_identifier(identifier: str, max_length: int) -> str: """Normalizes the identifier according to naming convention represented by this function""" # all characters that are not letters digits or a few special chars are replaced with underscore normalized_ident = identifier.translate(NamingConvention._TR_REDUCE_ALPHABET) - normalized_ident = NamingConvention._RE_NON_ALPHANUMERIC.sub("_", normalized_ident) + normalized_ident = NamingConvention.RE_NON_ALPHANUMERIC.sub("_", normalized_ident) # shorten identifier return NamingConvention.shorten_identifier( @@ -50,7 +62,7 @@ def _to_snake_case(cls, identifier: str) -> str: identifier = cls._SNAKE_CASE_BREAK_2.sub(r"\1_\2", identifier).lower() # leading digits will be prefixed (if regex is defined) - if cls._RE_LEADING_DIGITS and cls._RE_LEADING_DIGITS.match(identifier): + if cls.RE_LEADING_DIGITS and cls.RE_LEADING_DIGITS.match(identifier): identifier = "_" + identifier # replace trailing _ with x @@ -59,5 +71,5 @@ def _to_snake_case(cls, identifier: str) -> str: stripped_ident += "x" * strip_count # identifier = cls._RE_ENDING_UNDERSCORES.sub("x", identifier) - # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR - return cls._RE_UNDERSCORES.sub("_", stripped_ident) + # replace consecutive underscores with single one to prevent name collisions with PATH_SEPARATOR + return cls.RE_UNDERSCORES.sub("_", stripped_ident) diff --git a/dlt/common/normalizers/naming/sql_ci_v1.py b/dlt/common/normalizers/naming/sql_ci_v1.py new file mode 100644 index 0000000000..4fff52ffd6 --- /dev/null +++ b/dlt/common/normalizers/naming/sql_ci_v1.py @@ -0,0 +1,12 @@ +from dlt.common.normalizers.naming.sql_cs_v1 import NamingConvention as SqlCsNamingConvention + + +class NamingConvention(SqlCsNamingConvention): + """A variant of sql_cs which lower cases all identifiers.""" + + def normalize_identifier(self, identifier: str) -> str: + return super().normalize_identifier(identifier).lower() + + @property + def is_case_sensitive(self) -> bool: + return False diff --git a/dlt/common/normalizers/naming/sql_cs_v1.py b/dlt/common/normalizers/naming/sql_cs_v1.py new file mode 100644 index 0000000000..788089fa7d --- /dev/null +++ b/dlt/common/normalizers/naming/sql_cs_v1.py @@ -0,0 +1,44 @@ +import re +from typing import ClassVar + +from dlt.common.typing import REPattern +from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention + + +RE_UNDERSCORES = re.compile("__+") +RE_LEADING_DIGITS = re.compile(r"^\d+") +RE_ENDING_UNDERSCORES = re.compile(r"_+$") +RE_NON_ALPHANUMERIC = re.compile(r"[^a-zA-Z\d_]+") + + +class NamingConvention(BaseNamingConvention): + """Generates case sensitive SQL safe identifiers, preserving the source casing. + + - Spaces around identifier are trimmed + - Removes all ascii characters except ascii alphanumerics and underscores + - Prepends `_` if name starts with number. + - Removes all trailing underscores. + - Multiples of `_` are converted into single `_`. + """ + + RE_NON_ALPHANUMERIC: ClassVar[REPattern] = RE_NON_ALPHANUMERIC + RE_UNDERSCORES: ClassVar[REPattern] = RE_UNDERSCORES + RE_ENDING_UNDERSCORES: ClassVar[REPattern] = RE_ENDING_UNDERSCORES + + def normalize_identifier(self, identifier: str) -> str: + identifier = super().normalize_identifier(identifier) + # remove non alpha characters + norm_identifier = self.RE_NON_ALPHANUMERIC.sub("_", identifier) + # remove leading digits + if RE_LEADING_DIGITS.match(norm_identifier): + norm_identifier = "_" + norm_identifier + # remove trailing underscores to not mess with how we break paths + if norm_identifier != "_": + norm_identifier = self.RE_ENDING_UNDERSCORES.sub("", norm_identifier) + # contract multiple __ + norm_identifier = self.RE_UNDERSCORES.sub("_", norm_identifier) + return self.shorten_identifier(norm_identifier, identifier, self.max_length) + + @property + def is_case_sensitive(self) -> bool: + return True diff --git a/dlt/common/normalizers/typing.py b/dlt/common/normalizers/typing.py index 599426259f..9840f3a4d2 100644 --- a/dlt/common/normalizers/typing.py +++ b/dlt/common/normalizers/typing.py @@ -1,14 +1,22 @@ -from typing import List, Optional, TypedDict +from typing import List, Optional, Type, TypedDict, Literal, Union +from types import ModuleType from dlt.common.typing import StrAny +from dlt.common.normalizers.naming import NamingConvention + +TNamingConventionReferenceArg = Union[str, Type[NamingConvention], ModuleType] + + +TRowIdType = Literal["random", "row_hash", "key_hash"] class TJSONNormalizer(TypedDict, total=False): module: str - config: Optional[StrAny] # config is a free form and is consumed by `module` + config: Optional[StrAny] # config is a free form and is validated by `module` class TNormalizersConfig(TypedDict, total=False): names: str + allow_identifier_change_on_table_with_data: Optional[bool] detections: Optional[List[str]] json: TJSONNormalizer diff --git a/dlt/common/normalizers/utils.py b/dlt/common/normalizers/utils.py index 645bad2bea..d852cfb7d9 100644 --- a/dlt/common/normalizers/utils.py +++ b/dlt/common/normalizers/utils.py @@ -1,60 +1,168 @@ +import os from importlib import import_module -from typing import Any, Type, Tuple, cast, List +from types import ModuleType +from typing import Any, Dict, Optional, Type, Tuple, cast, List import dlt +from dlt.common import logger +from dlt.common import known_env from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs import known_sections from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers.configuration import NormalizersConfiguration +from dlt.common.normalizers.exceptions import InvalidJsonNormalizer from dlt.common.normalizers.json import SupportsDataItemNormalizer, DataItemNormalizer -from dlt.common.normalizers.naming import NamingConvention, SupportsNamingConvention -from dlt.common.normalizers.naming.exceptions import UnknownNamingModule, InvalidNamingModule -from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig -from dlt.common.utils import uniq_id_base64, many_uniq_ids_base64 +from dlt.common.normalizers.naming import NamingConvention +from dlt.common.normalizers.naming.exceptions import ( + NamingTypeNotFound, + UnknownNamingModule, + InvalidNamingType, +) +from dlt.common.normalizers.typing import ( + TJSONNormalizer, + TNormalizersConfig, + TNamingConventionReferenceArg, +) +from dlt.common.typing import is_subclass +from dlt.common.utils import get_full_class_name, uniq_id_base64, many_uniq_ids_base64 -DEFAULT_NAMING_MODULE = "dlt.common.normalizers.naming.snake_case" -DLT_ID_LENGTH_BYTES = 10 +DEFAULT_NAMING_NAMESPACE = os.environ.get( + known_env.DLT_DEFAULT_NAMING_NAMESPACE, "dlt.common.normalizers.naming" +) +DEFAULT_NAMING_MODULE = os.environ.get(known_env.DLT_DEFAULT_NAMING_MODULE, "snake_case") +DLT_ID_LENGTH_BYTES = int(os.environ.get(known_env.DLT_DLT_ID_LENGTH_BYTES, 10)) -@with_config(spec=NormalizersConfiguration) +def _section_for_schema(kwargs: Dict[str, Any]) -> Tuple[str, ...]: + """Uses the schema name to generate dynamic section normalizer settings""" + if schema_name := kwargs.get("schema_name"): + return (known_sections.SOURCES, schema_name) + else: + return (known_sections.SOURCES,) + + +@with_config(spec=NormalizersConfiguration, sections=_section_for_schema) # type: ignore[call-overload] def explicit_normalizers( - naming: str = dlt.config.value, json_normalizer: TJSONNormalizer = dlt.config.value + naming: TNamingConventionReferenceArg = dlt.config.value, + json_normalizer: TJSONNormalizer = dlt.config.value, + allow_identifier_change_on_table_with_data: bool = None, + schema_name: Optional[str] = None, ) -> TNormalizersConfig: - """Gets explicitly configured normalizers - via config or destination caps. May return None as naming or normalizer""" - return {"names": naming, "json": json_normalizer} + """Gets explicitly configured normalizers without any defaults or capabilities injection. If `naming` + is a module or a type it will get converted into string form via import. + + If `schema_name` is present, a section ("sources", schema_name, "schema") is used to inject the config + """ + + norm_conf: TNormalizersConfig = {"names": serialize_reference(naming), "json": json_normalizer} + if allow_identifier_change_on_table_with_data is not None: + norm_conf["allow_identifier_change_on_table_with_data"] = ( + allow_identifier_change_on_table_with_data + ) + return norm_conf @with_config def import_normalizers( - normalizers_config: TNormalizersConfig, + explicit_normalizers: TNormalizersConfig, + default_normalizers: TNormalizersConfig = None, destination_capabilities: DestinationCapabilitiesContext = None, ) -> Tuple[TNormalizersConfig, NamingConvention, Type[DataItemNormalizer[Any]]]: """Imports the normalizers specified in `normalizers_config` or taken from defaults. Returns the updated config and imported modules. - `destination_capabilities` are used to get max length of the identifier. + `destination_capabilities` are used to get naming convention, max length of the identifier and max nesting level. """ + if default_normalizers is None: + default_normalizers = {} # add defaults to normalizer_config - normalizers_config["names"] = names = normalizers_config["names"] or "snake_case" - # set default json normalizer module - normalizers_config["json"] = item_normalizer = normalizers_config.get("json") or {} - if "module" not in item_normalizer: - item_normalizer["module"] = "dlt.common.normalizers.json.relational" - - try: - if "." in names: + naming: TNamingConventionReferenceArg = explicit_normalizers.get("names") + if naming is None: + if destination_capabilities: + naming = destination_capabilities.naming_convention + if naming is None: + naming = default_normalizers.get("names") or DEFAULT_NAMING_MODULE + naming_convention = naming_from_reference(naming, destination_capabilities) + explicit_normalizers["names"] = serialize_reference(naming) + + item_normalizer = explicit_normalizers.get("json") or default_normalizers.get("json") or {} + item_normalizer.setdefault("module", "dlt.common.normalizers.json.relational") + # if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer + if destination_capabilities and destination_capabilities.max_table_nesting is not None: + # TODO: this is a hack, we need a better method to do this + from dlt.common.normalizers.json.relational import DataItemNormalizer + + try: + DataItemNormalizer.ensure_this_normalizer(item_normalizer) + item_normalizer.setdefault("config", {}) + item_normalizer["config"]["max_nesting"] = destination_capabilities.max_table_nesting # type: ignore[index] + except InvalidJsonNormalizer: + # not a right normalizer + logger.warning(f"JSON Normalizer {item_normalizer} does not support max_nesting") + pass + json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) + explicit_normalizers["json"] = item_normalizer + return ( + explicit_normalizers, + naming_convention, + json_module.DataItemNormalizer, + ) + + +def naming_from_reference( + names: TNamingConventionReferenceArg, + destination_capabilities: DestinationCapabilitiesContext = None, +) -> NamingConvention: + """Resolves naming convention from reference in `names` and applies max length from `destination_capabilities` + + Reference may be: (1) shorthand name pointing to `dlt.common.normalizers.naming` namespace + (2) a type name which is a module containing `NamingConvention` attribute (3) a type of class deriving from NamingConvention + """ + + def _import_naming(module: str) -> ModuleType: + if "." in module: # TODO: bump schema engine version and migrate schema. also change the name in TNormalizersConfig from names to naming - if names == "dlt.common.normalizers.names.snake_case": - names = DEFAULT_NAMING_MODULE + if module == "dlt.common.normalizers.names.snake_case": + module = f"{DEFAULT_NAMING_NAMESPACE}.{DEFAULT_NAMING_MODULE}" # this is full module name - naming_module = cast(SupportsNamingConvention, import_module(names)) + naming_module = import_module(module) else: # from known location - naming_module = cast( - SupportsNamingConvention, import_module(f"dlt.common.normalizers.naming.{names}") - ) - except ImportError: - raise UnknownNamingModule(names) - if not hasattr(naming_module, "NamingConvention"): - raise InvalidNamingModule(names) + try: + naming_module = import_module(f"{DEFAULT_NAMING_NAMESPACE}.{module}") + except ImportError: + # also import local module + naming_module = import_module(module) + return naming_module + + def _get_type(naming_module: ModuleType, cls: str) -> Type[NamingConvention]: + class_: Type[NamingConvention] = getattr(naming_module, cls, None) + if class_ is None: + raise NamingTypeNotFound(naming_module.__name__, cls) + if is_subclass(class_, NamingConvention): + return class_ + raise InvalidNamingType(naming_module.__name__, cls) + + if is_subclass(names, NamingConvention): + class_: Type[NamingConvention] = names # type: ignore[assignment] + elif isinstance(names, ModuleType): + class_ = _get_type(names, "NamingConvention") + elif isinstance(names, str): + try: + class_ = _get_type(_import_naming(names), "NamingConvention") + except ImportError: + parts = names.rsplit(".", 1) + # we have no more options to try + if len(parts) <= 1: + raise UnknownNamingModule(names) + try: + class_ = _get_type(_import_naming(parts[0]), parts[1]) + except UnknownNamingModule: + raise + except ImportError: + raise UnknownNamingModule(names) + else: + raise ValueError(names) + # get max identifier length if destination_capabilities: max_length = min( @@ -63,13 +171,18 @@ def import_normalizers( ) else: max_length = None - json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) - return ( - normalizers_config, - naming_module.NamingConvention(max_length), - json_module.DataItemNormalizer, - ) + return class_(max_length) + + +def serialize_reference(naming: Optional[TNamingConventionReferenceArg]) -> Optional[str]: + """Serializes generic `naming` reference to importable string.""" + if naming is None: + return naming + if isinstance(naming, str): + return naming + # import reference and use naming to get valid path to type + return get_full_class_name(naming_from_reference(naming)) def generate_dlt_ids(n_ids: int) -> List[str]: diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 6cefdd9e6c..1e1416eb53 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -260,9 +260,6 @@ def asstr(self, verbosity: int = 0) -> str: return self._load_packages_asstr(self.load_packages, verbosity) -# reveal_type(ExtractInfo) - - class NormalizeMetrics(StepMetrics): job_metrics: Dict[str, DataWriterMetrics] """Metrics collected per job id during writing of job file""" @@ -605,6 +602,14 @@ def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> self._deferred_pipeline = deferred_pipeline +def current_pipeline() -> SupportsPipeline: + """Gets active pipeline context or None if not found""" + proxy = Container()[PipelineContext] + if not proxy.is_active(): + return None + return proxy.pipeline() + + @configspec class StateInjectableContext(ContainerInjectableContext): state: TPipelineState = None diff --git a/dlt/common/runners/stdout.py b/dlt/common/runners/stdout.py index 6a92838342..bb5251764c 100644 --- a/dlt/common/runners/stdout.py +++ b/dlt/common/runners/stdout.py @@ -21,11 +21,11 @@ def exec_to_stdout(f: AnyFun) -> Iterator[Any]: rv = f() yield rv except Exception as ex: - print(encode_obj(ex), file=sys.stderr, flush=True) + print(encode_obj(ex), file=sys.stderr, flush=True) # noqa raise finally: if rv is not None: - print(encode_obj(rv), flush=True) + print(encode_obj(rv), flush=True) # noqa def iter_std( @@ -126,6 +126,6 @@ def iter_stdout_with_result( if isinstance(exception, Exception): raise exception from cpe else: - print(cpe.stderr, file=sys.stderr) + sys.stderr.write(cpe.stderr) # otherwise reraise cpe raise diff --git a/dlt/common/runtime/collector.py b/dlt/common/runtime/collector.py index e00bca576e..95117b70cc 100644 --- a/dlt/common/runtime/collector.py +++ b/dlt/common/runtime/collector.py @@ -230,7 +230,7 @@ def _log(self, log_level: int, log_message: str) -> None: if isinstance(self.logger, (logging.Logger, logging.LoggerAdapter)): self.logger.log(log_level, log_message) else: - print(log_message, file=self.logger or sys.stdout) + print(log_message, file=self.logger or sys.stdout) # noqa def _start(self, step: str) -> None: self.counters = defaultdict(int) diff --git a/dlt/common/runtime/telemetry.py b/dlt/common/runtime/telemetry.py index 28dde0206c..6b783483cc 100644 --- a/dlt/common/runtime/telemetry.py +++ b/dlt/common/runtime/telemetry.py @@ -14,7 +14,6 @@ disable_anon_tracker, track, ) -from dlt.pipeline.platform import disable_platform_tracker, init_platform_tracker _TELEMETRY_STARTED = False @@ -36,6 +35,10 @@ def start_telemetry(config: RunConfiguration) -> None: init_anon_tracker(config) if config.dlthub_dsn: + # TODO: we need pluggable modules for tracing so import into + # concrete modules is not needed + from dlt.pipeline.platform import init_platform_tracker + init_platform_tracker() _TELEMETRY_STARTED = True @@ -55,6 +58,9 @@ def stop_telemetry() -> None: pass disable_anon_tracker() + + from dlt.pipeline.platform import disable_platform_tracker + disable_platform_tracker() _TELEMETRY_STARTED = False diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 678f4de15e..2f016577ce 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -7,37 +7,45 @@ TSchemaContractEntities, TSchemaEvolutionMode, ) +from dlt.common.normalizers.naming import NamingConvention class SchemaException(DltException): - pass + def __init__(self, schema_name: str, msg: str) -> None: + self.schema_name = schema_name + if schema_name: + msg = f"In schema: {schema_name}: " + msg + super().__init__(msg) class InvalidSchemaName(ValueError, SchemaException): MAXIMUM_SCHEMA_NAME_LENGTH = 64 - def __init__(self, name: str) -> None: - self.name = name + def __init__(self, schema_name: str) -> None: + self.name = schema_name super().__init__( - f"{name} is an invalid schema/source name. The source or schema name must be a valid" - " Python identifier ie. a snake case function name and have maximum" + schema_name, + f"{schema_name} is an invalid schema/source name. The source or schema name must be a" + " valid Python identifier ie. a snake case function name and have maximum" f" {self.MAXIMUM_SCHEMA_NAME_LENGTH} characters. Ideally should contain only small" - " letters, numbers and underscores." + " letters, numbers and underscores.", ) -class InvalidDatasetName(ValueError, SchemaException): - def __init__(self, destination_name: str) -> None: - self.destination_name = destination_name - super().__init__( - f"Destination {destination_name} does not accept empty datasets. Please pass the" - " dataset name to the destination configuration ie. via dlt pipeline." - ) +# TODO: does not look like a SchemaException +# class InvalidDatasetName(ValueError, SchemaException): +# def __init__(self, destination_name: str) -> None: +# self.destination_name = destination_name +# super().__init__( +# f"Destination {destination_name} does not accept empty datasets. Please pass the" +# " dataset name to the destination configuration ie. via dlt pipeline." +# ) class CannotCoerceColumnException(SchemaException): def __init__( self, + schema_name: str, table_name: str, column_name: str, from_type: TDataType, @@ -50,37 +58,43 @@ def __init__( self.to_type = to_type self.coerced_value = coerced_value super().__init__( + schema_name, f"Cannot coerce type in table {table_name} column {column_name} existing type" - f" {from_type} coerced type {to_type} value: {coerced_value}" + f" {from_type} coerced type {to_type} value: {coerced_value}", ) class TablePropertiesConflictException(SchemaException): - def __init__(self, table_name: str, prop_name: str, val1: str, val2: str): + def __init__(self, schema_name: str, table_name: str, prop_name: str, val1: str, val2: str): self.table_name = table_name self.prop_name = prop_name self.val1 = val1 self.val2 = val2 super().__init__( + schema_name, f"Cannot merge partial tables for {table_name} due to property {prop_name}: {val1} !=" - f" {val2}" + f" {val2}", ) class ParentTableNotFoundException(SchemaException): - def __init__(self, table_name: str, parent_table_name: str, explanation: str = "") -> None: + def __init__( + self, schema_name: str, table_name: str, parent_table_name: str, explanation: str = "" + ) -> None: self.table_name = table_name self.parent_table_name = parent_table_name super().__init__( + schema_name, f"Parent table {parent_table_name} for {table_name} was not found in the" - f" schema.{explanation}" + f" schema.{explanation}", ) class CannotCoerceNullException(SchemaException): - def __init__(self, table_name: str, column_name: str) -> None: + def __init__(self, schema_name: str, table_name: str, column_name: str) -> None: super().__init__( - f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable" + schema_name, + f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable", ) @@ -88,19 +102,48 @@ class SchemaCorruptedException(SchemaException): pass +class SchemaIdentifierNormalizationCollision(SchemaCorruptedException): + def __init__( + self, + schema_name: str, + table_name: str, + identifier_type: str, + identifier_name: str, + conflict_identifier_name: str, + naming_name: str, + collision_msg: str, + ) -> None: + if identifier_type == "column": + table_info = f"in table {table_name} " + else: + table_info = "" + msg = ( + f"A {identifier_type} name {identifier_name} {table_info}collides with" + f" {conflict_identifier_name} after normalization with {naming_name} naming" + " convention. " + + collision_msg + ) + self.table_name = table_name + self.identifier_type = identifier_type + self.identifier_name = identifier_name + self.conflict_identifier_name = conflict_identifier_name + self.naming_name = naming_name + super().__init__(schema_name, msg) + + class SchemaEngineNoUpgradePathException(SchemaException): def __init__( self, schema_name: str, init_engine: int, from_engine: int, to_engine: int ) -> None: - self.schema_name = schema_name self.init_engine = init_engine self.from_engine = from_engine self.to_engine = to_engine super().__init__( + schema_name, f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}," f" stopped at {from_engine}. You possibly tried to run an older dlt" " version against a destination you have previously loaded data to with a newer dlt" - " version." + " version.", ) @@ -133,8 +176,7 @@ def __init__( + f" . Contract on {schema_entity} with mode {contract_mode} is violated. " + (extended_info or "") ) - super().__init__(msg) - self.schema_name = schema_name + super().__init__(schema_name, msg) self.table_name = table_name self.column_name = column_name @@ -148,10 +190,43 @@ def __init__( self.data_item = data_item -class UnknownTableException(SchemaException): - def __init__(self, table_name: str) -> None: +class UnknownTableException(KeyError, SchemaException): + def __init__(self, schema_name: str, table_name: str) -> None: self.table_name = table_name - super().__init__(f"Trying to access unknown table {table_name}.") + super().__init__(schema_name, f"Trying to access unknown table {table_name}.") + + +class TableIdentifiersFrozen(SchemaException): + def __init__( + self, + schema_name: str, + table_name: str, + to_naming: NamingConvention, + from_naming: NamingConvention, + details: str, + ) -> None: + self.table_name = table_name + self.to_naming = to_naming + self.from_naming = from_naming + msg = ( + f"Attempt to normalize identifiers for a table {table_name} from naming" + f" {from_naming.name()} to {to_naming.name()} changed one or more identifiers. " + ) + msg += ( + " This table already received data and tables were created at the destination. By" + " default changing the identifiers is not allowed. " + ) + msg += ( + " Such changes may result in creation of a new table or a new columns while the old" + " columns with data will still be kept. " + ) + msg += ( + " You may disable this behavior by setting" + " schema.allow_identifier_change_on_table_with_data to True or removing `x-normalizer`" + " hints from particular tables. " + ) + msg += f" Details: {details}" + super().__init__(schema_name, msg) class ColumnNameConflictException(SchemaException): diff --git a/dlt/common/schema/migrations.py b/dlt/common/schema/migrations.py index 9b206d61a6..b64714ba19 100644 --- a/dlt/common/schema/migrations.py +++ b/dlt/common/schema/migrations.py @@ -1,7 +1,7 @@ from typing import Dict, List, cast from dlt.common.data_types import TDataType -from dlt.common.normalizers import explicit_normalizers +from dlt.common.normalizers.utils import explicit_normalizers from dlt.common.typing import DictStrAny from dlt.common.schema.typing import ( LOADS_TABLE_NAME, @@ -14,7 +14,7 @@ from dlt.common.schema.exceptions import SchemaEngineNoUpgradePathException from dlt.common.normalizers.utils import import_normalizers -from dlt.common.schema.utils import new_table, version_table, load_table +from dlt.common.schema.utils import new_table, version_table, loads_table def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: @@ -29,7 +29,8 @@ def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> # current version of the schema current = cast(TStoredSchema, schema_dict) # add default normalizers and root hash propagation - current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) + normalizers = explicit_normalizers() + current["normalizers"], _, _ = import_normalizers(normalizers, normalizers) current["normalizers"]["json"]["config"] = { "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} } @@ -92,11 +93,11 @@ def migrate_filters(group: str, filters: List[str]) -> None: if from_engine == 4 and to_engine > 4: # replace schema versions table schema_dict["tables"][VERSION_TABLE_NAME] = version_table() - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + schema_dict["tables"][LOADS_TABLE_NAME] = loads_table() from_engine = 5 if from_engine == 5 and to_engine > 5: # replace loads table - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + schema_dict["tables"][LOADS_TABLE_NAME] = loads_table() from_engine = 6 if from_engine == 6 and to_engine > 6: # migrate from sealed properties to schema evolution settings diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 6d5dc48907..39db0e42ae 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -1,5 +1,16 @@ from copy import copy, deepcopy -from typing import ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Any, cast, Literal +from typing import ( + Callable, + ClassVar, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Any, + cast, +) from dlt.common.schema.migrations import migrate_schema from dlt.common.utils import extend_list_deduplicated @@ -11,8 +22,8 @@ VARIANT_FIELD_FORMAT, TDataItem, ) -from dlt.common.normalizers import TNormalizersConfig, explicit_normalizers, import_normalizers -from dlt.common.normalizers.naming import NamingConvention +from dlt.common.normalizers import TNormalizersConfig, NamingConvention +from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers from dlt.common.normalizers.json import DataItemNormalizer, TNormalizedRowIterator from dlt.common.schema import utils from dlt.common.data_types import py_type_to_sc_type, coerce_value, TDataType @@ -22,7 +33,7 @@ SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, VERSION_TABLE_NAME, - STATE_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, TPartialTableSchema, TSchemaContractEntities, TSchemaEvolutionMode, @@ -45,6 +56,7 @@ InvalidSchemaName, ParentTableNotFoundException, SchemaCorruptedException, + TableIdentifiersFrozen, ) from dlt.common.validation import validate_dict from dlt.common.schema.exceptions import DataValidationError @@ -102,13 +114,18 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: self._reset_schema(name, normalizers) @classmethod - def from_dict(cls, d: DictStrAny, bump_version: bool = True) -> "Schema": + def from_dict( + cls, d: DictStrAny, remove_processing_hints: bool = False, bump_version: bool = True + ) -> "Schema": # upgrade engine if needed stored_schema = migrate_schema(d, d["engine_version"], cls.ENGINE_VERSION) # verify schema utils.validate_stored_schema(stored_schema) # add defaults stored_schema = utils.apply_defaults(stored_schema) + # remove processing hints that could be created by normalize and load steps + if remove_processing_hints: + utils.remove_processing_hints(stored_schema["tables"]) # bump version if modified if bump_version: @@ -141,30 +158,6 @@ def replace_schema_content( self._reset_schema(schema.name, schema._normalizers_config) self._from_stored_schema(stored_schema) - def to_dict(self, remove_defaults: bool = False, bump_version: bool = True) -> TStoredSchema: - stored_schema: TStoredSchema = { - "version": self._stored_version, - "version_hash": self._stored_version_hash, - "engine_version": Schema.ENGINE_VERSION, - "name": self._schema_name, - "tables": self._schema_tables, - "settings": self._settings, - "normalizers": self._normalizers_config, - "previous_hashes": self._stored_previous_hashes, - } - if self._imported_version_hash and not remove_defaults: - stored_schema["imported_version_hash"] = self._imported_version_hash - if self._schema_description: - stored_schema["description"] = self._schema_description - - # bump version if modified - if bump_version: - utils.bump_version_if_modified(stored_schema) - # remove defaults after bumping version - if remove_defaults: - utils.remove_defaults(stored_schema) - return stored_schema - def normalize_data_item( self, item: TDataItem, load_id: str, table_name: str ) -> TNormalizedRowIterator: @@ -317,7 +310,7 @@ def apply_schema_contract( column_mode, data_mode = schema_contract["columns"], schema_contract["data_type"] # allow to add new columns when table is new or if columns are allowed to evolve once - if is_new_table or existing_table.get("x-normalizer", {}).get("evolve-columns-once", False): # type: ignore[attr-defined] + if is_new_table or existing_table.get("x-normalizer", {}).get("evolve-columns-once", False): column_mode = "evolve" # check if we should filter any columns, partial table below contains only new columns @@ -402,14 +395,20 @@ def resolve_contract_settings_for_table( # expand settings, empty settings will expand into default settings return Schema.expand_schema_contract_settings(settings) - def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchema: - """Adds or merges `partial_table` into the schema. Identifiers are not normalized""" + def update_table( + self, partial_table: TPartialTableSchema, normalize_identifiers: bool = True + ) -> TPartialTableSchema: + """Adds or merges `partial_table` into the schema. Identifiers are normalized by default""" + if normalize_identifiers: + partial_table = utils.normalize_table_identifiers(partial_table, self.naming) + table_name = partial_table["name"] parent_table_name = partial_table.get("parent") # check if parent table present if parent_table_name is not None: if self._schema_tables.get(parent_table_name) is None: raise ParentTableNotFoundException( + self.name, table_name, parent_table_name, " This may be due to misconfigured excludes filter that fully deletes content" @@ -422,21 +421,20 @@ def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchem self._schema_tables[table_name] = partial_table else: # merge tables performing additional checks - partial_table = utils.merge_table(table, partial_table) + partial_table = utils.merge_table(self.name, table, partial_table) self.data_item_normalizer.extend_table(table_name) return partial_table def update_schema(self, schema: "Schema") -> None: """Updates this schema from an incoming schema. Normalizes identifiers after updating normalizers.""" - # update all tables - for table in schema.tables.values(): - self.update_table(table) # pass normalizer config - self._configure_normalizers(schema._normalizers_config) - # update and compile settings self._settings = deepcopy(schema.settings) + self._configure_normalizers(schema._normalizers_config) self._compile_settings() + # update all tables + for table in schema.tables.values(): + self.update_table(table) def drop_tables( self, table_names: Sequence[str], seen_data_only: bool = False @@ -444,7 +442,7 @@ def drop_tables( """Drops tables from the schema and returns the dropped tables""" result = [] for table_name in table_names: - table = self.tables.get(table_name) + table = self.get_table(table_name) if table and (not seen_data_only or utils.has_table_seen_data(table)): result.append(self._schema_tables.pop(table_name)) return result @@ -467,67 +465,70 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str # dicts are ordered and we will return the rows with hints in the same order as they appear in the columns return rv_row - def merge_hints(self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]]) -> None: - # validate regexes - validate_dict( - TSchemaSettings, - {"default_hints": new_hints}, - ".", - validator_f=utils.simple_regex_validator, - ) - # prepare hints to be added - default_hints = self._settings.setdefault("default_hints", {}) - # add `new_hints` to existing hints - for h, l in new_hints.items(): - if h in default_hints: - extend_list_deduplicated(default_hints[h], l) - else: - # set new hint type - default_hints[h] = l # type: ignore + def merge_hints( + self, + new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]], + normalize_identifiers: bool = True, + ) -> None: + """Merges existing default hints with `new_hints`. Normalizes names in column regexes if possible. Compiles setting at the end + + NOTE: you can manipulate default hints collection directly via `Schema.settings` as long as you call Schema._compile_settings() at the end. + """ + self._merge_hints(new_hints, normalize_identifiers) self._compile_settings() - def normalize_table_identifiers(self, table: TTableSchema) -> TTableSchema: - """Normalizes all table and column names in `table` schema according to current schema naming convention and returns - new normalized TTableSchema instance. + def update_preferred_types( + self, + new_preferred_types: Mapping[TSimpleRegex, TDataType], + normalize_identifiers: bool = True, + ) -> None: + """Updates preferred types dictionary with `new_preferred_types`. Normalizes names in column regexes if possible. Compiles setting at the end - Naming convention like snake_case may produce name clashes with the column names. Clashing column schemas are merged - where the column that is defined later in the dictionary overrides earlier column. + NOTE: you can manipulate preferred hints collection directly via `Schema.settings` as long as you call Schema._compile_settings() at the end. + """ + self._update_preferred_types(new_preferred_types, normalize_identifiers) + self._compile_settings() - Note that resource name is not normalized. + def add_type_detection(self, detection: TTypeDetections) -> None: + """Add type auto detection to the schema.""" + if detection not in self.settings["detections"]: + self.settings["detections"].append(detection) + self._compile_settings() - """ - # normalize all identifiers in table according to name normalizer of the schema - table["name"] = self.naming.normalize_tables_path(table["name"]) - parent = table.get("parent") - if parent: - table["parent"] = self.naming.normalize_tables_path(parent) - columns = table.get("columns") - if columns: - new_columns: TTableSchemaColumns = {} - for c in columns.values(): - new_col_name = c["name"] = self.naming.normalize_path(c["name"]) - # re-index columns as the name changed, if name space was reduced then - # some columns now clash with each other. so make sure that we merge columns that are already there - if new_col_name in new_columns: - new_columns[new_col_name] = utils.merge_column( - new_columns[new_col_name], c, merge_defaults=False - ) - else: - new_columns[new_col_name] = c - table["columns"] = new_columns - return table + def remove_type_detection(self, detection: TTypeDetections) -> None: + """Adds type auto detection to the schema.""" + if detection in self.settings["detections"]: + self.settings["detections"].remove(detection) + self._compile_settings() def get_new_table_columns( self, table_name: str, - exiting_columns: TTableSchemaColumns, + existing_columns: TTableSchemaColumns, + case_sensitive: bool, include_incomplete: bool = False, ) -> List[TColumnSchema]: - """Gets new columns to be added to `exiting_columns` to bring them up to date with `table_name` schema. Optionally includes incomplete columns (without data type)""" + """Gets new columns to be added to `existing_columns` to bring them up to date with `table_name` schema. + Columns names are compared case sensitive by default. `existing_column` names are expected to be normalized. + Typically they come from the destination schema. Columns that are in `existing_columns` and not in `table_name` columns are ignored. + + Optionally includes incomplete columns (without data type)""" + casefold_f: Callable[[str], str] = str.casefold if not case_sensitive else str # type: ignore[assignment] + casefold_existing = { + casefold_f(col_name): col for col_name, col in existing_columns.items() + } + if len(existing_columns) != len(casefold_existing): + raise SchemaCorruptedException( + self.name, + f"A set of existing columns passed to get_new_table_columns table {table_name} has" + " colliding names when case insensitive comparison is used. Original names:" + f" {list(existing_columns.keys())}. Case-folded names:" + f" {list(casefold_existing.keys())}", + ) diff_c: List[TColumnSchema] = [] - s_t = self.get_table_columns(table_name, include_incomplete=include_incomplete) - for c in s_t.values(): - if c["name"] not in exiting_columns: + updated_columns = self.get_table_columns(table_name, include_incomplete=include_incomplete) + for c in updated_columns.values(): + if casefold_f(c["name"]) not in casefold_existing: diff_c.append(c) return diff_c @@ -564,9 +565,16 @@ def data_tables( ) ] - def data_table_names(self) -> List[str]: + def data_table_names( + self, seen_data_only: bool = False, include_incomplete: bool = False + ) -> List[str]: """Returns list of table table names. Excludes dlt table names.""" - return [t["name"] for t in self.data_tables()] + return [ + t["name"] + for t in self.data_tables( + seen_data_only=seen_data_only, include_incomplete=include_incomplete + ) + ] def dlt_tables(self) -> List[TTableSchema]: """Gets dlt tables""" @@ -651,20 +659,70 @@ def tables(self) -> TSchemaTables: def settings(self) -> TSchemaSettings: return self._settings - def to_pretty_json(self, remove_defaults: bool = True) -> str: - d = self.to_dict(remove_defaults=remove_defaults) + def to_dict( + self, + remove_defaults: bool = False, + remove_processing_hints: bool = False, + bump_version: bool = True, + ) -> TStoredSchema: + stored_schema: TStoredSchema = { + "version": self._stored_version, + "version_hash": self._stored_version_hash, + "engine_version": Schema.ENGINE_VERSION, + "name": self._schema_name, + "tables": self._schema_tables, + "settings": self._settings, + "normalizers": self._normalizers_config, + "previous_hashes": self._stored_previous_hashes, + } + if self._imported_version_hash and not remove_defaults: + stored_schema["imported_version_hash"] = self._imported_version_hash + if self._schema_description: + stored_schema["description"] = self._schema_description + + # remove processing hints that could be created by normalize and load steps + if remove_processing_hints: + stored_schema["tables"] = utils.remove_processing_hints( + deepcopy(stored_schema["tables"]) + ) + + # bump version if modified + if bump_version: + utils.bump_version_if_modified(stored_schema) + # remove defaults after bumping version + if remove_defaults: + utils.remove_defaults(stored_schema) + return stored_schema + + def to_pretty_json( + self, remove_defaults: bool = True, remove_processing_hints: bool = False + ) -> str: + d = self.to_dict( + remove_defaults=remove_defaults, remove_processing_hints=remove_processing_hints + ) return utils.to_pretty_json(d) - def to_pretty_yaml(self, remove_defaults: bool = True) -> str: - d = self.to_dict(remove_defaults=remove_defaults) + def to_pretty_yaml( + self, remove_defaults: bool = True, remove_processing_hints: bool = False + ) -> str: + d = self.to_dict( + remove_defaults=remove_defaults, remove_processing_hints=remove_processing_hints + ) return utils.to_pretty_yaml(d) - def clone(self, with_name: str = None, update_normalizers: bool = False) -> "Schema": - """Make a deep copy of the schema, optionally changing the name, and updating normalizers and identifiers in the schema if `update_normalizers` is True - - Note that changing of name will set the schema as new + def clone( + self, + with_name: str = None, + remove_processing_hints: bool = False, + update_normalizers: bool = False, + ) -> "Schema": + """Make a deep copy of the schema, optionally changing the name, removing processing markers and updating normalizers and identifiers in the schema if `update_normalizers` is True + Processing markers are `x-` hints created by normalizer (`x-normalizer`) and loader (`x-loader`) to ie. mark newly inferred tables and tables that seen data. + Note that changing of name will break the previous version chain """ - d = deepcopy(self.to_dict(bump_version=False)) + d = deepcopy( + self.to_dict(bump_version=False, remove_processing_hints=remove_processing_hints) + ) if with_name is not None: d["version"] = d["version_hash"] = None d.pop("imported_version_hash", None) @@ -677,12 +735,23 @@ def clone(self, with_name: str = None, update_normalizers: bool = False) -> "Sch return schema def update_normalizers(self) -> None: - """Looks for new normalizer configuration or for destination capabilities context and updates all identifiers in the schema""" - normalizers = explicit_normalizers() - # set the current values as defaults - normalizers["names"] = normalizers["names"] or self._normalizers_config["names"] - normalizers["json"] = normalizers["json"] or self._normalizers_config["json"] - self._configure_normalizers(normalizers) + """Looks for new normalizer configuration or for destination capabilities context and updates all identifiers in the schema + + Table and column names will be normalized with new naming convention, except tables that have seen data ('x-normalizer`) which will + raise if any identifier is to be changed. + Default hints, preferred data types and normalize configs (ie. column propagation) are normalized as well. Regexes are included as long + as textual parts can be extracted from an expression. + """ + self._configure_normalizers(explicit_normalizers(schema_name=self._schema_name)) + self._compile_settings() + + def will_update_normalizers(self) -> bool: + """Checks if schema has any pending normalizer updates due to configuration or destination capabilities""" + # import desired modules + _, to_naming, _ = import_normalizers( + explicit_normalizers(schema_name=self._schema_name), self._normalizers_config + ) + return type(to_naming) is not type(self.naming) # noqa def set_schema_contract(self, settings: TSchemaContract) -> None: if not settings: @@ -690,18 +759,6 @@ def set_schema_contract(self, settings: TSchemaContract) -> None: else: self._settings["schema_contract"] = settings - def add_type_detection(self, detection: TTypeDetections) -> None: - """Add type auto detection to the schema.""" - if detection not in self.settings["detections"]: - self.settings["detections"].append(detection) - self._compile_settings() - - def remove_type_detection(self, detection: TTypeDetections) -> None: - """Adds type auto detection to the schema.""" - if detection in self.settings["detections"]: - self.settings["detections"].remove(detection) - self._compile_settings() - def _infer_column( self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False ) -> TColumnSchema: @@ -727,7 +784,7 @@ def _coerce_null_value( if col_name in table_columns: existing_column = table_columns[col_name] if not existing_column.get("nullable", True): - raise CannotCoerceNullException(table_name, col_name) + raise CannotCoerceNullException(self.name, table_name, col_name) def _coerce_non_null_value( self, @@ -759,7 +816,12 @@ def _coerce_non_null_value( if is_variant: # this is final call: we cannot generate any more auto-variants raise CannotCoerceColumnException( - table_name, col_name, py_type, table_columns[col_name]["data_type"], v + self.name, + table_name, + col_name, + py_type, + table_columns[col_name]["data_type"], + v, ) # otherwise we must create variant extension to the table # pass final=True so no more auto-variants can be created recursively @@ -816,6 +878,57 @@ def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: else: return False + def _merge_hints( + self, + new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]], + normalize_identifiers: bool = True, + ) -> None: + """Used by `merge_hints method, does not compile settings at the end""" + # validate regexes + validate_dict( + TSchemaSettings, + {"default_hints": new_hints}, + ".", + validator_f=utils.simple_regex_validator, + ) + if normalize_identifiers: + new_hints = self._normalize_default_hints(new_hints) + # prepare hints to be added + default_hints = self._settings.setdefault("default_hints", {}) + # add `new_hints` to existing hints + for h, l in new_hints.items(): + if h in default_hints: + extend_list_deduplicated(default_hints[h], l, utils.canonical_simple_regex) + else: + # set new hint type + default_hints[h] = l # type: ignore + + def _update_preferred_types( + self, + new_preferred_types: Mapping[TSimpleRegex, TDataType], + normalize_identifiers: bool = True, + ) -> None: + # validate regexes + validate_dict( + TSchemaSettings, + {"preferred_types": new_preferred_types}, + ".", + validator_f=utils.simple_regex_validator, + ) + if normalize_identifiers: + new_preferred_types = self._normalize_preferred_types(new_preferred_types) + preferred_types = self._settings.setdefault("preferred_types", {}) + # we must update using canonical simple regex + canonical_preferred = { + utils.canonical_simple_regex(rx): rx for rx in preferred_types.keys() + } + for new_rx, new_dt in new_preferred_types.items(): + canonical_new_rx = utils.canonical_simple_regex(new_rx) + if canonical_new_rx not in canonical_preferred: + preferred_types[new_rx] = new_dt + else: + preferred_types[canonical_preferred[canonical_new_rx]] = new_dt + def _bump_version(self) -> Tuple[int, str]: """Computes schema hash in order to check if schema content was modified. In such case the schema ``stored_version`` and ``stored_version_hash`` are updated. @@ -839,40 +952,175 @@ def _drop_version(self) -> None: self._stored_version_hash = self._stored_previous_hashes.pop(0) def _add_standard_tables(self) -> None: - self._schema_tables[self.version_table_name] = self.normalize_table_identifiers( - utils.version_table() + self._schema_tables[self.version_table_name] = utils.normalize_table_identifiers( + utils.version_table(), self.naming ) - self._schema_tables[self.loads_table_name] = self.normalize_table_identifiers( - utils.load_table() + self._schema_tables[self.loads_table_name] = utils.normalize_table_identifiers( + utils.loads_table(), self.naming ) def _add_standard_hints(self) -> None: - default_hints = utils.standard_hints() + default_hints = utils.default_hints() if default_hints: - self._settings["default_hints"] = default_hints + self._merge_hints(default_hints, normalize_identifiers=False) type_detections = utils.standard_type_detections() if type_detections: self._settings["detections"] = type_detections - def _configure_normalizers(self, normalizers: TNormalizersConfig) -> None: - # import desired modules - self._normalizers_config, naming_module, item_normalizer_class = import_normalizers( - normalizers + def _normalize_default_hints( + self, default_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]] + ) -> Dict[TColumnHint, List[TSimpleRegex]]: + """Normalizes the column names in default hints. In case of column names that are regexes, normalization is skipped""" + return { + hint: [utils.normalize_simple_regex_column(self.naming, regex) for regex in regexes] + for hint, regexes in default_hints.items() + } + + def _normalize_preferred_types( + self, preferred_types: Mapping[TSimpleRegex, TDataType] + ) -> Dict[TSimpleRegex, TDataType]: + """Normalizes the column names in preferred types mapping. In case of column names that are regexes, normalization is skipped""" + return { + utils.normalize_simple_regex_column(self.naming, regex): data_type + for regex, data_type in preferred_types.items() + } + + def _verify_update_normalizers( + self, + normalizers_config: TNormalizersConfig, + to_naming: NamingConvention, + from_naming: NamingConvention, + ) -> TSchemaTables: + """Verifies if normalizers can be updated before schema is changed""" + allow_ident_change = normalizers_config.get( + "allow_identifier_change_on_table_with_data", False ) - # print(f"{self.name}: {type(self.naming)} {type(naming_module)}") - if self.naming and type(self.naming) is not type(naming_module): - self.naming = naming_module - for table in self._schema_tables.values(): - self.normalize_table_identifiers(table) + + def _verify_identifiers(table: TTableSchema, norm_table: TTableSchema) -> None: + if not allow_ident_change: + # make sure no identifier got changed in table + if norm_table["name"] != table["name"]: + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + f"Attempt to rename table name to {norm_table['name']}.", + ) + # if len(norm_table["columns"]) != len(table["columns"]): + # print(norm_table["columns"]) + # raise TableIdentifiersFrozen( + # self.name, + # table["name"], + # to_naming, + # from_naming, + # "Number of columns changed after normalization. Some columns must have" + # " merged.", + # ) + col_diff = set(norm_table["columns"].keys()).symmetric_difference( + table["columns"].keys() + ) + if len(col_diff) > 0: + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + f"Some columns got renamed to {col_diff}.", + ) + + naming_changed = from_naming and type(from_naming) is not type(to_naming) + if naming_changed: + schema_tables = {} + # check dlt tables + schema_seen_data = any( + utils.has_table_seen_data(t) for t in self._schema_tables.values() + ) + # modify dlt tables using original naming + orig_dlt_tables = [ + (self.version_table_name, utils.version_table()), + (self.loads_table_name, utils.loads_table()), + (self.state_table_name, utils.pipeline_state_table(add_dlt_id=True)), + ] + for existing_table_name, original_table in orig_dlt_tables: + table = self._schema_tables.get(existing_table_name) + # state table is optional + if table: + table = copy(table) + # keep all attributes of the schema table, copy only what we need to normalize + table["columns"] = original_table["columns"] + norm_table = utils.normalize_table_identifiers(table, to_naming) + table_seen_data = utils.has_table_seen_data(norm_table) + if schema_seen_data: + _verify_identifiers(table, norm_table) + schema_tables[norm_table["name"]] = norm_table + + schema_seen_data = False + for table in self.data_tables(include_incomplete=True): + # TODO: when lineage is fully implemented we should use source identifiers + # not `table` which was already normalized + norm_table = utils.normalize_table_identifiers(table, to_naming) + table_seen_data = utils.has_table_seen_data(norm_table) + if table_seen_data: + _verify_identifiers(table, norm_table) + schema_tables[norm_table["name"]] = norm_table + schema_seen_data |= table_seen_data + if schema_seen_data and not allow_ident_change: + # if any of the tables has seen data, fail naming convention change + # NOTE: this will be dropped with full identifier lineage. currently we cannot detect + # strict schemas being changed to lax + raise TableIdentifiersFrozen( + self.name, + "-", + to_naming, + from_naming, + "Schema contains tables that received data. As a precaution changing naming" + " conventions is disallowed until full identifier lineage is implemented.", + ) # re-index the table names - self._schema_tables = {t["name"]: t for t in self._schema_tables.values()} + return schema_tables + else: + return self._schema_tables + def _renormalize_schema_identifiers( + self, + normalizers_config: TNormalizersConfig, + to_naming: NamingConvention, + from_naming: NamingConvention, + ) -> None: + """Normalizes all identifiers in the schema in place""" + self._schema_tables = self._verify_update_normalizers( + normalizers_config, to_naming, from_naming + ) + self._normalizers_config = normalizers_config + self.naming = to_naming # name normalization functions - self.naming = naming_module - self._dlt_tables_prefix = self.naming.normalize_table_identifier(DLT_NAME_PREFIX) - self.version_table_name = self.naming.normalize_table_identifier(VERSION_TABLE_NAME) - self.loads_table_name = self.naming.normalize_table_identifier(LOADS_TABLE_NAME) - self.state_table_name = self.naming.normalize_table_identifier(STATE_TABLE_NAME) + self._dlt_tables_prefix = to_naming.normalize_table_identifier(DLT_NAME_PREFIX) + self.version_table_name = to_naming.normalize_table_identifier(VERSION_TABLE_NAME) + self.loads_table_name = to_naming.normalize_table_identifier(LOADS_TABLE_NAME) + self.state_table_name = to_naming.normalize_table_identifier(PIPELINE_STATE_TABLE_NAME) + # do a sanity check - dlt tables must start with dlt prefix + for table_name in [self.version_table_name, self.loads_table_name, self.state_table_name]: + if not table_name.startswith(self._dlt_tables_prefix): + raise SchemaCorruptedException( + self.name, + f"A naming convention {self.naming.name()} mangles _dlt table prefix to" + f" '{self._dlt_tables_prefix}'. A table '{table_name}' does not start with it.", + ) + # normalize default hints + if default_hints := self._settings.get("default_hints"): + self._settings["default_hints"] = self._normalize_default_hints(default_hints) + # normalized preferred types + if preferred_types := self.settings.get("preferred_types"): + self._settings["preferred_types"] = self._normalize_preferred_types(preferred_types) + + def _configure_normalizers(self, explicit_normalizers: TNormalizersConfig) -> None: + """Gets naming and item normalizer from schema yaml, config providers and destination capabilities and applies them to schema.""" + # import desired modules + normalizers_config, to_naming, item_normalizer_class = import_normalizers( + explicit_normalizers, self._normalizers_config + ) + self._renormalize_schema_identifiers(normalizers_config, to_naming, self.naming) # data item normalization function self.data_item_normalizer = item_normalizer_class(self) self.data_item_normalizer.extend_schema() @@ -903,7 +1151,7 @@ def _reset_schema(self, name: str, normalizers: TNormalizersConfig = None) -> No self._add_standard_hints() # configure normalizers, including custom config if present if not normalizers: - normalizers = explicit_normalizers() + normalizers = explicit_normalizers(schema_name=self._schema_name) self._configure_normalizers(normalizers) # add version tables self._add_standard_tables() @@ -913,9 +1161,13 @@ def _reset_schema(self, name: str, normalizers: TNormalizersConfig = None) -> No def _from_stored_schema(self, stored_schema: TStoredSchema) -> None: self._schema_tables = stored_schema.get("tables") or {} if self.version_table_name not in self._schema_tables: - raise SchemaCorruptedException(f"Schema must contain table {self.version_table_name}") + raise SchemaCorruptedException( + stored_schema["name"], f"Schema must contain table {self.version_table_name}" + ) if self.loads_table_name not in self._schema_tables: - raise SchemaCorruptedException(f"Schema must contain table {self.loads_table_name}") + raise SchemaCorruptedException( + stored_schema["name"], f"Schema must contain table {self.loads_table_name}" + ) self._stored_version = stored_schema["version"] self._stored_version_hash = stored_schema["version_hash"] self._imported_version_hash = stored_schema.get("imported_version_hash") diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index fb360b38d3..9a4dd51d4b 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -17,8 +17,7 @@ from dlt.common.data_types import TDataType from dlt.common.normalizers.typing import TNormalizersConfig -from dlt.common.typing import TSortOrder, TAnyDateTime -from dlt.common.pendulum import pendulum +from dlt.common.typing import TSortOrder, TAnyDateTime, TLoaderFileFormat try: from pydantic import BaseModel as _PydanticBaseModel @@ -32,7 +31,7 @@ # dlt tables VERSION_TABLE_NAME = "_dlt_version" LOADS_TABLE_NAME = "_dlt_loads" -STATE_TABLE_NAME = "_dlt_pipeline_state" +PIPELINE_STATE_TABLE_NAME = "_dlt_pipeline_state" DLT_NAME_PREFIX = "_dlt" TColumnProp = Literal[ @@ -47,6 +46,7 @@ "unique", "merge_key", "root_key", + "hard_delete", "dedup_sort", ] """Known properties and hints of the column""" @@ -59,12 +59,15 @@ "foreign_key", "sort", "unique", - "root_key", "merge_key", + "root_key", + "hard_delete", "dedup_sort", ] """Known hints of a column used to declare hint regexes.""" + TTableFormat = Literal["iceberg", "delta"] +TFileFormat = Literal[Literal["preferred"], TLoaderFileFormat] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" ] @@ -72,7 +75,7 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" -COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) +# COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( [ "partition", @@ -153,8 +156,20 @@ class NormalizerInfo(TypedDict, total=True): new_table: bool +# Part of Table containing processing hints added by pipeline stages +TTableProcessingHints = TypedDict( + "TTableProcessingHints", + { + "x-normalizer": Optional[Dict[str, Any]], + "x-loader": Optional[Dict[str, Any]], + "x-extractor": Optional[Dict[str, Any]], + }, + total=False, +) + + TWriteDisposition = Literal["skip", "append", "replace", "merge"] -TLoaderMergeStrategy = Literal["delete-insert", "scd2"] +TLoaderMergeStrategy = Literal["delete-insert", "scd2", "upsert"] WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition)) @@ -178,7 +193,8 @@ class TMergeDispositionDict(TWriteDispositionDict, total=False): TWriteDispositionConfig = Union[TWriteDisposition, TWriteDispositionDict, TMergeDispositionDict] -class TTableSchema(TypedDict, total=False): +# TypedDict that defines properties of a table +class TTableSchema(TTableProcessingHints, total=False): """TypedDict that defines properties of a table""" name: Optional[str] @@ -191,6 +207,7 @@ class TTableSchema(TypedDict, total=False): columns: TTableSchemaColumns resource: Optional[str] table_format: Optional[TTableFormat] + file_format: Optional[TFileFormat] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 51269cbb38..cd0cc5aa63 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -7,6 +7,7 @@ from dlt.common.pendulum import pendulum from dlt.common.time import ensure_pendulum_datetime +from dlt.common import logger from dlt.common.json import json from dlt.common.data_types import TDataType from dlt.common.exceptions import DictValidationException @@ -21,12 +22,15 @@ LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, TColumnName, + TFileFormat, TPartialTableSchema, TSchemaTables, TSchemaUpdate, TSimpleRegex, TStoredSchema, + TTableProcessingHints, TTableSchema, TColumnSchemaBase, TColumnSchema, @@ -96,7 +100,8 @@ def apply_defaults(stored_schema: TStoredSchema) -> TStoredSchema: def remove_defaults(stored_schema: TStoredSchema) -> TStoredSchema: """Removes default values from `stored_schema` in place, returns the input for chaining - Default values are removed from table schemas and complete column schemas. Incomplete columns are preserved intact. + * removes column and table names from the value + * removed resource name if same as table name """ clean_tables = deepcopy(stored_schema["tables"]) for table_name, t in clean_tables.items(): @@ -202,6 +207,33 @@ def verify_schema_hash( return hash_ == stored_schema["version_hash"] +def normalize_simple_regex_column(naming: NamingConvention, regex: TSimpleRegex) -> TSimpleRegex: + """Assumes that regex applies to column name and normalizes it.""" + + def _normalize(r_: str) -> str: + is_exact = len(r_) >= 2 and r_[0] == "^" and r_[-1] == "$" + if is_exact: + r_ = r_[1:-1] + # if this a simple string then normalize it + if r_ == re.escape(r_): + r_ = naming.normalize_path(r_) + if is_exact: + r_ = "^" + r_ + "$" + return r_ + + if regex.startswith(SIMPLE_REGEX_PREFIX): + return cast(TSimpleRegex, SIMPLE_REGEX_PREFIX + _normalize(regex[3:])) + else: + return cast(TSimpleRegex, _normalize(regex)) + + +def canonical_simple_regex(regex: str) -> TSimpleRegex: + if regex.startswith(SIMPLE_REGEX_PREFIX): + return cast(TSimpleRegex, regex) + else: + return cast(TSimpleRegex, SIMPLE_REGEX_PREFIX + "^" + regex + "$") + + def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: # custom validator on type TSimpleRegex if t is TSimpleRegex: @@ -237,7 +269,7 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: # we know how to validate that type return True else: - # don't know how to validate t + # don't know how to validate this return False @@ -299,7 +331,9 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: parent_table_name = table.get("parent") if parent_table_name: if parent_table_name not in stored_schema["tables"]: - raise ParentTableNotFoundException(table_name, parent_table_name) + raise ParentTableNotFoundException( + stored_schema["name"], table_name, parent_table_name + ) def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: Any) -> TDataType: @@ -370,7 +404,9 @@ def merge_columns( return columns_a -def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTableSchema: +def diff_table( + schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema +) -> TPartialTableSchema: """Creates a partial table that contains properties found in `tab_b` that are not present or different in `tab_a`. The name is always present in returned partial. It returns new columns (not present in tab_a) and merges columns from tab_b into tab_a (overriding non-default hint values). @@ -384,7 +420,7 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # check if table properties can be merged if tab_a.get("parent") != tab_b.get("parent"): raise TablePropertiesConflictException( - table_name, "parent", tab_a.get("parent"), tab_b.get("parent") + schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent") ) # get new columns, changes in the column data type or other properties are not allowed @@ -398,6 +434,7 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable if not compare_complete_columns(tab_a_columns[col_b_name], col_b): # attempt to update to incompatible columns raise CannotCoerceColumnException( + schema_name, table_name, col_b_name, col_b["data_type"], @@ -426,7 +463,7 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # this should not really happen if tab_a.get("parent") is not None and (resource := tab_b.get("resource")): raise TablePropertiesConflictException( - table_name, "resource", resource, tab_a.get("parent") + schema_name, table_name, "resource", resource, tab_a.get("parent") ) return partial_table @@ -444,7 +481,9 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # return False -def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPartialTableSchema: +def merge_table( + schema_name: str, table: TTableSchema, partial_table: TPartialTableSchema +) -> TPartialTableSchema: """Merges "partial_table" into "table". `table` is merged in place. Returns the diff partial table. `table` and `partial_table` names must be identical. A table diff is generated and applied to `table`: @@ -456,9 +495,10 @@ def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPar if table["name"] != partial_table["name"]: raise TablePropertiesConflictException( - table["name"], "name", table["name"], partial_table["name"] + schema_name, table["name"], "name", table["name"], partial_table["name"] ) - diff = diff_table(table, partial_table) + diff = diff_table(schema_name, table, partial_table) + # add new columns when all checks passed updated_columns = merge_columns(table["columns"], diff["columns"]) table.update(diff) table["columns"] = updated_columns @@ -466,9 +506,67 @@ def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPar return diff +def normalize_table_identifiers(table: TTableSchema, naming: NamingConvention) -> TTableSchema: + """Normalizes all table and column names in `table` schema according to current schema naming convention and returns + new instance with modified table schema. + + Naming convention like snake_case may produce name collisions with the column names. Colliding column schemas are merged + where the column that is defined later in the dictionary overrides earlier column. + + Note that resource name is not normalized. + """ + + table = copy(table) + table["name"] = naming.normalize_tables_path(table["name"]) + parent = table.get("parent") + if parent: + table["parent"] = naming.normalize_tables_path(parent) + columns = table.get("columns") + if columns: + new_columns: TTableSchemaColumns = {} + for c in columns.values(): + c = copy(c) + origin_c_name = c["name"] + new_col_name = c["name"] = naming.normalize_path(c["name"]) + # re-index columns as the name changed, if name space was reduced then + # some columns now collide with each other. so make sure that we merge columns that are already there + if new_col_name in new_columns: + new_columns[new_col_name] = merge_column( + new_columns[new_col_name], c, merge_defaults=False + ) + logger.warning( + f"In schema {naming} column {origin_c_name} got normalized into" + f" {new_col_name} which collides with other column. Both columns got merged" + " into one." + ) + else: + new_columns[new_col_name] = c + table["columns"] = new_columns + return table + + def has_table_seen_data(table: TTableSchema) -> bool: """Checks if normalizer has seen data coming to the table.""" - return "x-normalizer" in table and table["x-normalizer"].get("seen-data", None) is True # type: ignore[typeddict-item] + return "x-normalizer" in table and table["x-normalizer"].get("seen-data", None) is True + + +def remove_processing_hints(tables: TSchemaTables) -> TSchemaTables: + "Removes processing hints like x-normalizer and x-loader from schema tables. Modifies the input tables and returns it for convenience" + for table_name, hints in get_processing_hints(tables).items(): + for hint in hints: + del tables[table_name][hint] # type: ignore[misc] + return tables + + +def get_processing_hints(tables: TSchemaTables) -> Dict[str, List[str]]: + """Finds processing hints in a set of tables and returns table_name: [hints] mapping""" + hints: Dict[str, List[str]] = {} + for table in tables.values(): + for hint in TTableProcessingHints.__annotations__.keys(): + if hint in table: + table_hints = hints.setdefault(table["name"], []) + table_hints.append(hint) + return hints def hint_to_column_prop(h: TColumnHint) -> TColumnProp: @@ -581,6 +679,12 @@ def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: ) +def get_file_format(tables: TSchemaTables, table_name: str) -> TFileFormat: + return cast( + TFileFormat, get_inherited_table_hint(tables, table_name, "file_format", allow_none=True) + ) + + def fill_hints_from_parent_and_clone_table( tables: TSchemaTables, table: TTableSchema ) -> TTableSchema: @@ -592,6 +696,8 @@ def fill_hints_from_parent_and_clone_table( table["write_disposition"] = get_write_disposition(tables, table["name"]) if "table_format" not in table: table["table_format"] = get_table_format(tables, table["name"]) + if "file_format" not in table: + table["file_format"] = get_file_format(tables, table["name"]) return table @@ -650,6 +756,8 @@ def group_tables_by_resource( def version_table() -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) + # set to nullable so we can migrate existing tables + # WARNING: do not reorder the columns table = new_table( VERSION_TABLE_NAME, columns=[ @@ -670,9 +778,11 @@ def version_table() -> TTableSchema: return table -def load_table() -> TTableSchema: +def loads_table() -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) + # set to nullable so we can migrate existing tables + # WARNING: do not reorder the columns table = new_table( LOADS_TABLE_NAME, columns=[ @@ -692,6 +802,33 @@ def load_table() -> TTableSchema: return table +def pipeline_state_table(add_dlt_id: bool = False) -> TTableSchema: + # NOTE: always add new columns at the end of the table so we have identical layout + # after an update of existing tables (always at the end) + # set to nullable so we can migrate existing tables + # WARNING: do not reorder the columns + columns: List[TColumnSchema] = [ + {"name": "version", "data_type": "bigint", "nullable": False}, + {"name": "engine_version", "data_type": "bigint", "nullable": False}, + {"name": "pipeline_name", "data_type": "text", "nullable": False}, + {"name": "state", "data_type": "text", "nullable": False}, + {"name": "created_at", "data_type": "timestamp", "nullable": False}, + {"name": "version_hash", "data_type": "text", "nullable": True}, + {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, + ] + if add_dlt_id: + columns.append({"name": "_dlt_id", "data_type": "text", "nullable": False, "unique": True}) + table = new_table( + PIPELINE_STATE_TABLE_NAME, + write_disposition="append", + columns=columns, + # always use caps preferred file format for processing + file_format="preferred", + ) + table["description"] = "Created by DLT. Tracks pipeline state" + return table + + def new_table( table_name: str, parent_table_name: str = None, @@ -701,6 +838,7 @@ def new_table( resource: str = None, schema_contract: TSchemaContract = None, table_format: TTableFormat = None, + file_format: TFileFormat = None, ) -> TTableSchema: table: TTableSchema = { "name": table_name, @@ -719,6 +857,8 @@ def new_table( table["schema_contract"] = schema_contract if table_format: table["table_format"] = table_format + if file_format: + table["file_format"] = file_format if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, @@ -754,7 +894,7 @@ def new_column( return column -def standard_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: +def default_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: return None diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index 09beb0015e..b2bdb3a7b6 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -111,8 +111,19 @@ def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: return self.PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration] # type: ignore[return-value] def fingerprint(self) -> str: - """Returns a fingerprint of bucket_url""" - return digest128(self.bucket_url) if self.bucket_url else "" + """Returns a fingerprint of bucket schema and netloc. + + Returns: + str: Fingerprint. + """ + if not self.bucket_url: + return "" + + if self.is_local_path(self.bucket_url): + return digest128("") + + uri = urlparse(self.bucket_url) + return digest128(self.bucket_url.replace(uri.path, "")) def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index f6072c0260..29a9da8acf 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -60,15 +60,18 @@ def import_items_file( table_name: str, file_path: str, metrics: DataWriterMetrics, + with_extension: str = None, ) -> DataWriterMetrics: """Import a file from `file_path` into items storage under a new file name. Does not check the imported file format. Uses counts from `metrics` as a base. Logically closes the imported file The preferred import method is a hard link to avoid copying the data. If current filesystem does not support it, a regular copy is used. + + Alternative extension may be provided via `with_extension` so various file formats may be imported into the same folder. """ writer = self._get_writer(load_id, schema_name, table_name) - return writer.import_file(file_path, metrics) + return writer.import_file(file_path, metrics, with_extension) def close_writers(self, load_id: str, skip_flush: bool = False) -> None: """Flush, write footers (skip_flush), write metrics and close files in all diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index 26a76bb5c0..028491dd9c 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -79,6 +79,23 @@ def __init__(self, load_id: str) -> None: super().__init__(f"Package with load id {load_id} could not be found") +class LoadPackageAlreadyCompleted(LoadStorageException): + def __init__(self, load_id: str) -> None: + self.load_id = load_id + super().__init__( + f"Package with load id {load_id} is already completed, but another complete was" + " requested" + ) + + +class LoadPackageNotCompleted(LoadStorageException): + def __init__(self, load_id: str) -> None: + self.load_id = load_id + super().__init__( + f"Package with load id {load_id} is not yet completed, but method required that" + ) + + class SchemaStorageException(StorageException): pass diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index d768ec720a..7d14b8f7f7 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -6,7 +6,7 @@ import tempfile import shutil import pathvalidate -from typing import IO, Any, Optional, List, cast, overload +from typing import IO, Any, Optional, List, cast from dlt.common.typing import AnyFun from dlt.common.utils import encoding_for_mode, uniq_id @@ -18,7 +18,7 @@ class FileStorage: def __init__(self, storage_path: str, file_type: str = "t", makedirs: bool = False) -> None: # make it absolute path - self.storage_path = os.path.realpath(storage_path) # os.path.join(, '') + self.storage_path = os.path.realpath(storage_path) self.file_type = file_type if makedirs: os.makedirs(storage_path, exist_ok=True) @@ -243,7 +243,8 @@ def atomic_import( FileStorage.move_atomic_to_file(external_file_path, dest_file_path) ) - def in_storage(self, path: str) -> bool: + def is_path_in_storage(self, path: str) -> bool: + """Checks if a given path is below storage root, without checking for item existence""" assert path is not None # all paths are relative to root if not os.path.isabs(path): @@ -256,25 +257,30 @@ def in_storage(self, path: str) -> bool: def to_relative_path(self, path: str) -> str: if path == "": return "" - if not self.in_storage(path): + if not self.is_path_in_storage(path): raise ValueError(path) if not os.path.isabs(path): path = os.path.realpath(os.path.join(self.storage_path, path)) # for abs paths find the relative return os.path.relpath(path, start=self.storage_path) - def make_full_path(self, path: str) -> str: + def make_full_path_safe(self, path: str) -> str: + """Verifies that path is under storage root and then returns normalized absolute path""" # try to make a relative path if paths are absolute or overlapping path = self.to_relative_path(path) # then assume that it is a path relative to storage root return os.path.realpath(os.path.join(self.storage_path, path)) + def make_full_path(self, path: str) -> str: + """Joins path with storage root. Intended for path known to be relative to storage root""" + return os.path.join(self.storage_path, path) + def from_wd_to_relative_path(self, wd_relative_path: str) -> str: path = os.path.realpath(wd_relative_path) return self.to_relative_path(path) def from_relative_path_to_wd(self, relative_path: str) -> str: - return os.path.relpath(self.make_full_path(relative_path), start=".") + return os.path.relpath(self.make_full_path_safe(relative_path), start=".") @staticmethod def get_file_name_from_file_path(file_path: str) -> str: diff --git a/dlt/common/storages/fsspec_filesystem.py b/dlt/common/storages/fsspec_filesystem.py index a21f0f2c0c..be9ae2bbb1 100644 --- a/dlt/common/storages/fsspec_filesystem.py +++ b/dlt/common/storages/fsspec_filesystem.py @@ -5,7 +5,6 @@ import pathlib import posixpath from io import BytesIO -from gzip import GzipFile from typing import ( Literal, cast, @@ -320,7 +319,7 @@ def glob_files( rel_path = pathlib.Path(file).relative_to(root_dir).as_posix() file_url = FilesystemConfiguration.make_file_uri(file) else: - rel_path = posixpath.relpath(file, root_dir) + rel_path = posixpath.relpath(file.lstrip("/"), root_dir) file_url = bucket_url_parsed._replace( path=posixpath.join(bucket_url_parsed.path, rel_path) ).geturl() diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index fd4ecc968e..1ecc491174 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -32,20 +32,6 @@ def remove_schema(self, name: str) -> None: # also remove the live schema self.live_schemas.pop(name, None) - def save_import_schema_if_not_exists(self, schema: Schema) -> bool: - """Saves import schema, if not exists. If schema was saved, link itself as imported from""" - if self.config.import_schema_path: - try: - self._load_import_schema(schema.name) - except FileNotFoundError: - # save import schema only if it not exist - self._export_schema(schema, self.config.import_schema_path) - # if import schema got saved then add own version hash as import version hash - schema._imported_version_hash = schema.version_hash - return True - - return False - def commit_live_schema(self, name: str) -> str: """Saves live schema in storage if it was modified""" if not self.is_live_schema_committed(name): diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 4d72458e3e..4d84094427 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -5,7 +5,7 @@ import datetime # noqa: 251 import humanize -from pathlib import Path +from pathlib import PurePath from pendulum.datetime import DateTime from typing import ( ClassVar, @@ -37,7 +37,12 @@ from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns, TTableSchema from dlt.common.storages import FileStorage -from dlt.common.storages.exceptions import LoadPackageNotFound, CurrentLoadPackageStateNotAvailable +from dlt.common.storages.exceptions import ( + LoadPackageAlreadyCompleted, + LoadPackageNotCompleted, + LoadPackageNotFound, + CurrentLoadPackageStateNotAvailable, +) from dlt.common.typing import DictStrAny, SupportsHumanize from dlt.common.utils import flatten_list_or_items from dlt.common.versioned_state import ( @@ -52,6 +57,7 @@ TJobFileFormat = Literal["sql", "reference", TLoaderFileFormat] """Loader file formats with internal job types""" +JOB_EXCEPTION_EXTENSION = ".exception" class TPipelineStateDoc(TypedDict, total=False): @@ -61,12 +67,19 @@ class TPipelineStateDoc(TypedDict, total=False): engine_version: int pipeline_name: str state: str - version_hash: str created_at: datetime.datetime - dlt_load_id: NotRequired[str] + version_hash: str + _dlt_load_id: NotRequired[str] -class TLoadPackageState(TVersionedState, total=False): +class TLoadPackageDropTablesState(TypedDict): + dropped_tables: NotRequired[List[TTableSchema]] + """List of tables that are to be dropped from the schema and destination (i.e. when `refresh` mode is used)""" + truncated_tables: NotRequired[List[TTableSchema]] + """List of tables that are to be truncated in the destination (i.e. when `refresh='drop_data'` mode is used)""" + + +class TLoadPackageState(TVersionedState, TLoadPackageDropTablesState, total=False): created_at: DateTime """Timestamp when the load package was created""" pipeline_state: NotRequired[TPipelineStateDoc] @@ -76,11 +89,6 @@ class TLoadPackageState(TVersionedState, total=False): destination_state: NotRequired[Dict[str, Any]] """private space for destinations to store state relevant only to the load package""" - dropped_tables: NotRequired[List[TTableSchema]] - """List of tables that are to be dropped from the schema and destination (i.e. when `refresh` mode is used)""" - truncated_tables: NotRequired[List[TTableSchema]] - """List of tables that are to be truncated in the destination (i.e. when `refresh='drop_data'` mode is used)""" - class TLoadPackage(TypedDict, total=False): load_id: str @@ -165,7 +173,7 @@ def with_retry(self) -> "ParsedLoadJobFileName": @staticmethod def parse(file_name: str) -> "ParsedLoadJobFileName": - p = Path(file_name) + p = PurePath(file_name) parts = p.name.split(".") if len(parts) != 4: raise TerminalValueError(parts) @@ -319,13 +327,16 @@ def __init__(self, storage: FileStorage, initial_state: TLoadPackageStatus) -> N # def get_package_path(self, load_id: str) -> str: + """Gets path of the package relative to storage root""" return load_id - def get_job_folder_path(self, load_id: str, folder: TJobState) -> str: - return os.path.join(self.get_package_path(load_id), folder) + def get_job_state_folder_path(self, load_id: str, state: TJobState) -> str: + """Gets path to the jobs in `state` in package `load_id`, relative to the storage root""" + return os.path.join(self.get_package_path(load_id), state) - def get_job_file_path(self, load_id: str, folder: TJobState, file_name: str) -> str: - return os.path.join(self.get_job_folder_path(load_id, folder), file_name) + def get_job_file_path(self, load_id: str, state: TJobState, file_name: str) -> str: + """Get path to job with `file_name` in `state` in package `load_id`, relative to the storage root""" + return os.path.join(self.get_job_state_folder_path(load_id, state), file_name) def list_packages(self) -> Sequence[str]: """Lists all load ids in storage, earliest first @@ -338,29 +349,42 @@ def list_packages(self) -> Sequence[str]: def list_new_jobs(self, load_id: str) -> Sequence[str]: new_jobs = self.storage.list_folder_files( - self.get_job_folder_path(load_id, PackageStorage.NEW_JOBS_FOLDER) + self.get_job_state_folder_path(load_id, PackageStorage.NEW_JOBS_FOLDER) ) return new_jobs def list_started_jobs(self, load_id: str) -> Sequence[str]: return self.storage.list_folder_files( - self.get_job_folder_path(load_id, PackageStorage.STARTED_JOBS_FOLDER) + self.get_job_state_folder_path(load_id, PackageStorage.STARTED_JOBS_FOLDER) ) def list_failed_jobs(self, load_id: str) -> Sequence[str]: - return self.storage.list_folder_files( - self.get_job_folder_path(load_id, PackageStorage.FAILED_JOBS_FOLDER) - ) - - def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: - return self.filter_jobs_for_table(self.list_all_jobs(load_id), table_name) - - def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]: - info = self.get_load_package_info(load_id) - return [job for job in flatten_list_or_items(iter(info.jobs.values()))] # type: ignore + return [ + file + for file in self.storage.list_folder_files( + self.get_job_state_folder_path(load_id, PackageStorage.FAILED_JOBS_FOLDER) + ) + if not file.endswith(JOB_EXCEPTION_EXTENSION) + ] + + def list_job_with_states_for_table( + self, load_id: str, table_name: str + ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + return self.filter_jobs_for_table(self.list_all_jobs_with_states(load_id), table_name) + + def list_all_jobs_with_states( + self, load_id: str + ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + info = self.get_load_package_jobs(load_id) + state_jobs = [] + for state, jobs in info.items(): + state_jobs.extend([(state, job) for job in jobs]) + return state_jobs def list_failed_jobs_infos(self, load_id: str) -> Sequence[LoadJobInfo]: """List all failed jobs and associated error messages for a load package with `load_id`""" + if not self.is_package_completed(load_id): + raise LoadPackageNotCompleted(load_id) failed_jobs: List[LoadJobInfo] = [] package_path = self.get_package_path(load_id) package_created_at = pendulum.from_timestamp( @@ -371,12 +395,19 @@ def list_failed_jobs_infos(self, load_id: str) -> Sequence[LoadJobInfo]: ) ) for file in self.list_failed_jobs(load_id): - if not file.endswith(".exception"): - failed_jobs.append( - self._read_job_file_info("failed_jobs", file, package_created_at) + failed_jobs.append( + self._read_job_file_info( + load_id, "failed_jobs", ParsedLoadJobFileName.parse(file), package_created_at ) + ) return failed_jobs + def is_package_completed(self, load_id: str) -> bool: + package_path = self.get_package_path(load_id) + return self.storage.has_file( + os.path.join(package_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME) + ) + # # Move jobs # @@ -385,7 +416,9 @@ def import_job( self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs" ) -> None: """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" - self.storage.atomic_import(job_file_path, self.get_job_folder_path(load_id, job_state)) + self.storage.atomic_import( + job_file_path, self.get_job_state_folder_path(load_id, job_state) + ) def start_job(self, load_id: str, file_name: str) -> str: return self._move_job( @@ -397,7 +430,7 @@ def fail_job(self, load_id: str, file_name: str, failed_message: Optional[str]) if failed_message: self.storage.save( self.get_job_file_path( - load_id, PackageStorage.FAILED_JOBS_FOLDER, file_name + ".exception" + load_id, PackageStorage.FAILED_JOBS_FOLDER, file_name + JOB_EXCEPTION_EXTENSION ), failed_message, ) @@ -455,6 +488,8 @@ def create_package(self, load_id: str, initial_state: TLoadPackageState = None) def complete_loading_package(self, load_id: str, load_state: TLoadPackageStatus) -> str: """Completes loading the package by writing marker file with`package_state. Returns path to the completed package""" load_path = self.get_package_path(load_id) + if self.is_package_completed(load_id): + raise LoadPackageAlreadyCompleted(load_id) # save marker file self.storage.save( os.path.join(load_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME), load_state @@ -468,7 +503,7 @@ def remove_completed_jobs(self, load_id: str) -> None: # delete completed jobs if not has_failed_jobs: self.storage.delete_folder( - self.get_job_folder_path(load_id, PackageStorage.COMPLETED_JOBS_FOLDER), + self.get_job_state_folder_path(load_id, PackageStorage.COMPLETED_JOBS_FOLDER), recursively=True, ) @@ -533,11 +568,32 @@ def get_load_package_state_path(self, load_id: str) -> str: # Get package info # - def get_load_package_info(self, load_id: str) -> LoadPackageInfo: - """Gets information on normalized/completed package with given load_id, all jobs and their statuses.""" + def get_load_package_jobs(self, load_id: str) -> Dict[TJobState, List[ParsedLoadJobFileName]]: + """Gets all jobs in a package and returns them as lists assigned to a particular state.""" package_path = self.get_package_path(load_id) if not self.storage.has_folder(package_path): raise LoadPackageNotFound(load_id) + all_jobs: Dict[TJobState, List[ParsedLoadJobFileName]] = {} + for state in WORKING_FOLDERS: + jobs: List[ParsedLoadJobFileName] = [] + with contextlib.suppress(FileNotFoundError): + # we ignore if load package lacks one of working folders. completed_jobs may be deleted on archiving + for file in self.storage.list_folder_files( + self.get_job_state_folder_path(load_id, state), to_root=False + ): + if not file.endswith(JOB_EXCEPTION_EXTENSION): + jobs.append(ParsedLoadJobFileName.parse(file)) + all_jobs[state] = jobs + return all_jobs + + def get_load_package_info(self, load_id: str) -> LoadPackageInfo: + """Gets information on normalized/completed package with given load_id, all jobs and their statuses. + + Will reach to the file system to get additional stats, mtime, also collects exceptions for failed jobs. + NOTE: do not call this function often. it should be used only to generate metrics + """ + package_path = self.get_package_path(load_id) + package_jobs = self.get_load_package_jobs(load_id) package_created_at: DateTime = None package_state = self.initial_state @@ -560,15 +616,11 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: schema = Schema.from_dict(self._load_schema(load_id)) # read jobs with all statuses - all_jobs: Dict[TJobState, List[LoadJobInfo]] = {} - for state in WORKING_FOLDERS: - jobs: List[LoadJobInfo] = [] - with contextlib.suppress(FileNotFoundError): - # we ignore if load package lacks one of working folders. completed_jobs may be deleted on archiving - for file in self.storage.list_folder_files(os.path.join(package_path, state)): - if not file.endswith(".exception"): - jobs.append(self._read_job_file_info(state, file, package_created_at)) - all_jobs[state] = jobs + all_job_infos: Dict[TJobState, List[LoadJobInfo]] = {} + for state, jobs in package_jobs.items(): + all_job_infos[state] = [ + self._read_job_file_info(load_id, state, job, package_created_at) for job in jobs + ] return LoadPackageInfo( load_id, @@ -577,15 +629,46 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: schema, applied_update, package_created_at, - all_jobs, + all_job_infos, ) - def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) -> LoadJobInfo: - try: - failed_message = self.storage.load(file + ".exception") - except FileNotFoundError: - failed_message = None - full_path = self.storage.make_full_path(file) + def get_job_failed_message(self, load_id: str, job: ParsedLoadJobFileName) -> str: + """Get exception message of a failed job.""" + rel_path = self.get_job_file_path(load_id, "failed_jobs", job.file_name()) + if not self.storage.has_file(rel_path): + raise FileNotFoundError(rel_path) + failed_message: str = None + with contextlib.suppress(FileNotFoundError): + failed_message = self.storage.load(rel_path + JOB_EXCEPTION_EXTENSION) + return failed_message + + def job_to_job_info( + self, load_id: str, state: TJobState, job: ParsedLoadJobFileName + ) -> LoadJobInfo: + """Creates partial job info by converting job object. size, mtime and failed message will not be populated""" + full_path = os.path.join( + self.storage.storage_path, self.get_job_file_path(load_id, state, job.file_name()) + ) + return LoadJobInfo( + state, + full_path, + 0, + None, + 0, + job, + None, + ) + + def _read_job_file_info( + self, load_id: str, state: TJobState, job: ParsedLoadJobFileName, now: DateTime = None + ) -> LoadJobInfo: + """Creates job info by reading additional props from storage""" + failed_message = None + if state == "failed_jobs": + failed_message = self.get_job_failed_message(load_id, job) + full_path = os.path.join( + self.storage.storage_path, self.get_job_file_path(load_id, state, job.file_name()) + ) st = os.stat(full_path) return LoadJobInfo( state, @@ -593,7 +676,7 @@ def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) st.st_size, pendulum.from_timestamp(st.st_mtime), PackageStorage._job_elapsed_time_seconds(full_path, now.timestamp() if now else None), - ParsedLoadJobFileName.parse(file), + job, failed_message, ) @@ -611,10 +694,11 @@ def _move_job( ) -> str: # ensure we move file names, not paths assert file_name == FileStorage.get_file_name_from_file_path(file_name) - load_path = self.get_package_path(load_id) - dest_path = os.path.join(load_path, dest_folder, new_file_name or file_name) - self.storage.atomic_rename(os.path.join(load_path, source_folder, file_name), dest_path) - # print(f"{join(load_path, source_folder, file_name)} -> {dest_path}") + + dest_path = self.get_job_file_path(load_id, dest_folder, new_file_name or file_name) + self.storage.atomic_rename( + self.get_job_file_path(load_id, source_folder, file_name), dest_path + ) return self.storage.make_full_path(dest_path) def _load_schema(self, load_id: str) -> DictStrAny: @@ -659,9 +743,9 @@ def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: @staticmethod def filter_jobs_for_table( - all_jobs: Iterable[LoadJobInfo], table_name: str - ) -> Sequence[LoadJobInfo]: - return [job for job in all_jobs if job.job_file_info.table_name == table_name] + all_jobs: Iterable[Tuple[TJobState, ParsedLoadJobFileName]], table_name: str + ) -> Sequence[Tuple[TJobState, ParsedLoadJobFileName]]: + return [job for job in all_jobs if job[1].table_name == table_name] @configspec diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 1afed18929..0544de696f 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -5,7 +5,7 @@ from dlt.common.json import json from dlt.common.configuration import with_config from dlt.common.configuration.accessors import config -from dlt.common.schema.utils import to_pretty_json, to_pretty_yaml +from dlt.common.schema.utils import get_processing_hints, to_pretty_json, to_pretty_yaml from dlt.common.storages.configuration import ( SchemaStorageConfiguration, TSchemaFileFormat, @@ -57,6 +57,14 @@ def load_schema(self, name: str) -> Schema: return Schema.from_dict(storage_schema) def save_schema(self, schema: Schema) -> str: + """Saves schema to the storage and returns the path relative to storage. + + If import schema path is configured and import schema with schema.name exits, it + will be linked to `schema` via `_imported_version_hash`. Such hash is used in `load_schema` to + detect if import schema changed and thus to overwrite the storage schema. + + If export schema path is configured, `schema` will be exported to it. + """ # check if there's schema to import if self.config.import_schema_path: try: @@ -66,11 +74,25 @@ def save_schema(self, schema: Schema) -> str: except FileNotFoundError: # just save the schema pass - path = self._save_schema(schema) - if self.config.export_schema_path: - self._export_schema(schema, self.config.export_schema_path) + path = self._save_and_export_schema(schema) return path + def save_import_schema_if_not_exists(self, schema: Schema) -> bool: + """Saves import schema, if not exists. If schema was saved, link itself as imported from""" + if self.config.import_schema_path: + try: + self._load_import_schema(schema.name) + except FileNotFoundError: + # save import schema only if it not exist + self._export_schema( + schema, self.config.import_schema_path, remove_processing_hints=True + ) + # if import schema got saved then add own version hash as import version hash + schema._imported_version_hash = schema.version_hash + return True + + return False + def remove_schema(self, name: str) -> None: schema_file = self._file_name_in_store(name, "json") self.storage.delete(schema_file) @@ -116,25 +138,32 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> f" {rv_schema._imported_version_hash}" ) # if schema was imported, overwrite storage schema - self._save_schema(rv_schema) - if self.config.export_schema_path: - self._export_schema(rv_schema, self.config.export_schema_path) + self._save_and_export_schema(rv_schema, check_processing_hints=True) else: # import schema when imported schema was modified from the last import rv_schema = Schema.from_dict(storage_schema) i_s = Schema.from_dict(imported_schema) if i_s.version_hash != rv_schema._imported_version_hash: + logger.warning( + f"Schema {name} was present in schema storage at" + f" {self.storage.storage_path} but will be overwritten with imported schema" + f" version {i_s.version} and imported hash {i_s.version_hash}" + ) + tables_seen_data = rv_schema.data_tables(seen_data_only=True) + if tables_seen_data: + logger.warning( + f"Schema {name} in schema storage contains tables" + f" ({', '.join(t['name'] for t in tables_seen_data)}) that are present" + " in the destination. If you changed schema of those tables in import" + " schema, consider using one of the refresh options:" + " https://dlthub.com/devel/general-usage/pipeline#refresh-pipeline-data-and-state" + ) + rv_schema.replace_schema_content(i_s, link_to_replaced_schema=True) rv_schema._imported_version_hash = i_s.version_hash - logger.info( - f"Schema {name} was present in {self.storage.storage_path} but is" - f" overwritten with imported schema version {i_s.version} and" - f" imported hash {i_s.version_hash}" - ) + # if schema was imported, overwrite storage schema - self._save_schema(rv_schema) - if self.config.export_schema_path: - self._export_schema(rv_schema, self.config.export_schema_path) + self._save_and_export_schema(rv_schema, check_processing_hints=True) except FileNotFoundError: # no schema to import -> skip silently and return the original if storage_schema is None: @@ -156,8 +185,13 @@ def _load_import_schema(self, name: str) -> DictStrAny: import_storage.load(schema_file), self.config.external_schema_format ) - def _export_schema(self, schema: Schema, export_path: str) -> None: - stored_schema = schema.to_dict(remove_defaults=True) + def _export_schema( + self, schema: Schema, export_path: str, remove_processing_hints: bool = False + ) -> None: + stored_schema = schema.to_dict( + remove_defaults=self.config.external_schema_format_remove_defaults, + remove_processing_hints=remove_processing_hints, + ) if self.config.external_schema_format == "json": exported_schema_s = to_pretty_json(stored_schema) elif self.config.external_schema_format == "yaml": @@ -175,7 +209,7 @@ def _export_schema(self, schema: Schema, export_path: str) -> None: ) def _save_schema(self, schema: Schema) -> str: - # save a schema to schema store + """Saves schema to schema store and bumps the version""" schema_file = self._file_name_in_store(schema.name, "json") stored_schema = schema.to_dict() saved_path = self.storage.save(schema_file, to_pretty_json(stored_schema)) @@ -184,16 +218,45 @@ def _save_schema(self, schema: Schema) -> str: schema._bump_version() return saved_path + def _save_and_export_schema(self, schema: Schema, check_processing_hints: bool = False) -> str: + """Saves schema to schema store and then exports it. If the export path is the same as import + path, processing hints will be removed. + """ + saved_path = self._save_schema(schema) + if self.config.export_schema_path: + self._export_schema( + schema, + self.config.export_schema_path, + self.config.export_schema_path == self.config.import_schema_path, + ) + # if any processing hints are found we should warn the user + if check_processing_hints and (processing_hints := get_processing_hints(schema.tables)): + msg = ( + f"Imported schema {schema.name} contains processing hints for some tables." + " Processing hints are used by normalizer (x-normalizer) to mark tables that got" + " materialized and that prevents destructive changes to the schema. In most cases" + " import schema should not contain processing hints because it is mostly used to" + " initialize tables in a new dataset. " + ) + msg += "Affected tables are: " + ", ".join(processing_hints.keys()) + logger.warning(msg) + return saved_path + @staticmethod def load_schema_file( - path: str, name: str, extensions: Tuple[TSchemaFileFormat, ...] = SchemaFileExtensions + path: str, + name: str, + extensions: Tuple[TSchemaFileFormat, ...] = SchemaFileExtensions, + remove_processing_hints: bool = False, ) -> Schema: storage = FileStorage(path) for extension in extensions: file = SchemaStorage._file_name_in_store(name, extension) if storage.has_file(file): parsed_schema = SchemaStorage._parse_schema_str(storage.load(file), extension) - schema = Schema.from_dict(parsed_schema) + schema = Schema.from_dict( + parsed_schema, remove_processing_hints=remove_processing_hints + ) if schema.name != name: raise UnexpectedSchemaName(name, path, schema.name) return schema diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 5668185391..48f641c994 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -4,7 +4,7 @@ import os from re import Pattern as _REPattern import sys -from types import FunctionType, MethodType, ModuleType +from types import FunctionType from typing import ( ForwardRef, Callable, @@ -39,6 +39,7 @@ Concatenate, get_args, get_origin, + get_original_bases, ) try: @@ -116,6 +117,8 @@ VARIANT_FIELD_FORMAT = "v_%s" TFileOrPath = Union[str, PathLike, IO[Any]] TSortOrder = Literal["asc", "desc"] +TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] +"""known loader file formats""" class ConfigValueSentinel(NamedTuple): @@ -268,6 +271,25 @@ def is_literal_type(hint: Type[Any]) -> bool: return False +def get_literal_args(literal: Type[Any]) -> List[Any]: + """Recursively get arguments from nested Literal types and return an unified list.""" + if not hasattr(literal, "__origin__") or literal.__origin__ is not Literal: + raise ValueError("Provided type is not a Literal") + + unified_args = [] + + def _get_args(literal: Type[Any]) -> None: + for arg in get_args(literal): + if hasattr(arg, "__origin__") and arg.__origin__ is Literal: + _get_args(arg) + else: + unified_args.append(arg) + + _get_args(literal) + + return unified_args + + def is_newtype_type(t: Type[Any]) -> bool: if hasattr(t, "__supertype__"): return True @@ -373,7 +395,7 @@ def is_subclass(subclass: Any, cls: Any) -> bool: def get_generic_type_argument_from_instance( - instance: Any, sample_value: Optional[Any] + instance: Any, sample_value: Optional[Any] = None ) -> Type[Any]: """Infers type argument of a Generic class from an `instance` of that class using optional `sample_value` of the argument type @@ -387,8 +409,14 @@ def get_generic_type_argument_from_instance( Type[Any]: type argument or Any if not known """ orig_param_type = Any - if hasattr(instance, "__orig_class__"): - orig_param_type = get_args(instance.__orig_class__)[0] + if cls_ := getattr(instance, "__orig_class__", None): + # instance of generic class + pass + elif bases_ := get_original_bases(instance.__class__): + # instance of class deriving from generic + cls_ = bases_[0] + if cls_: + orig_param_type = get_args(cls_)[0] if orig_param_type is Any and sample_value is not None: orig_param_type = type(sample_value) return orig_param_type # type: ignore diff --git a/dlt/common/utils.py b/dlt/common/utils.py index cb2ec4c3d9..7109daf497 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -13,6 +13,7 @@ from typing import ( Any, + Callable, ContextManager, Dict, MutableMapping, @@ -136,47 +137,11 @@ def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> DictStrAn else: key = str(e) if key in o: - raise KeyError(f"Cannot flatten with duplicate key {k}") + raise KeyError(f"Cannot flatten with duplicate key {key}") o[key] = None return o -# def flatten_dicts_of_dicts(dicts: Mapping[str, Any]) -> Sequence[Any]: -# """ -# Transform and object {K: {...}, L: {...}...} -> [{key:K, ....}, {key: L, ...}, ...] -# """ -# o: List[Any] = [] -# for k, v in dicts.items(): -# if isinstance(v, list): -# # if v is a list then add "key" to each list element -# for lv in v: -# lv["key"] = k -# else: -# # add as "key" to dict -# v["key"] = k - -# o.append(v) -# return o - - -# def tuplify_list_of_dicts(dicts: Sequence[DictStrAny]) -> Sequence[DictStrAny]: -# """ -# Transform list of dictionaries with single key into single dictionary of {"key": orig_key, "value": orig_value} -# """ -# for d in dicts: -# if len(d) > 1: -# raise ValueError(f"Tuplify requires one key dicts {d}") -# if len(d) == 1: -# key = next(iter(d)) -# # delete key first to avoid name clashes -# value = d[key] -# del d[key] -# d["key"] = key -# d["value"] = value - -# return dicts - - def flatten_list_or_items(_iter: Union[Iterable[TAny], Iterable[List[TAny]]]) -> Iterator[TAny]: for items in _iter: if isinstance(items, List): @@ -503,11 +468,15 @@ def merge_row_counts(row_counts_1: RowCounts, row_counts_2: RowCounts) -> None: row_counts_1[counter_name] = row_counts_1.get(counter_name, 0) + row_counts_2[counter_name] -def extend_list_deduplicated(original_list: List[Any], extending_list: Iterable[Any]) -> List[Any]: +def extend_list_deduplicated( + original_list: List[Any], + extending_list: Iterable[Any], + normalize_f: Callable[[str], str] = str.__call__, +) -> List[Any]: """extends the first list by the second, but does not add duplicates""" - list_keys = set(original_list) + list_keys = set(normalize_f(s) for s in original_list) for item in extending_list: - if item not in list_keys: + if normalize_f(item) not in list_keys: original_list.append(item) return original_list diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 0a8bced287..8862c10024 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -7,6 +7,7 @@ from dlt.common.exceptions import DictValidationException from dlt.common.typing import ( StrAny, + get_literal_args, get_type_name, is_callable_type, is_literal_type, @@ -114,7 +115,7 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: failed_validations, ) elif is_literal_type(t): - a_l = get_args(t) + a_l = get_literal_args(t) if pv not in a_l: raise DictValidationException( f"field '{pk}' with value {pv} is not one of: {a_l}", path, t, pk, pv diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 302de24a6b..0546d16bcd 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -8,6 +8,7 @@ from dlt.destinations.impl.athena.factory import athena from dlt.destinations.impl.redshift.factory import redshift from dlt.destinations.impl.qdrant.factory import qdrant +from dlt.destinations.impl.lancedb.factory import lancedb from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate from dlt.destinations.impl.destination.factory import destination @@ -28,6 +29,7 @@ "athena", "redshift", "qdrant", + "lancedb", "motherduck", "weaviate", "synapse", diff --git a/dlt/destinations/adapters.py b/dlt/destinations/adapters.py index 1c3e094e19..0cf04b7b59 100644 --- a/dlt/destinations/adapters.py +++ b/dlt/destinations/adapters.py @@ -1,17 +1,20 @@ """This module collects all destination adapters present in `impl` namespace""" -from dlt.destinations.impl.weaviate import weaviate_adapter -from dlt.destinations.impl.qdrant import qdrant_adapter -from dlt.destinations.impl.bigquery import bigquery_adapter -from dlt.destinations.impl.synapse import synapse_adapter -from dlt.destinations.impl.clickhouse import clickhouse_adapter -from dlt.destinations.impl.athena import athena_adapter +from dlt.destinations.impl.weaviate.weaviate_adapter import weaviate_adapter +from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter +from dlt.destinations.impl.lancedb import lancedb_adapter +from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.impl.synapse.synapse_adapter import synapse_adapter +from dlt.destinations.impl.clickhouse.clickhouse_adapter import clickhouse_adapter +from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition __all__ = [ "weaviate_adapter", "qdrant_adapter", + "lancedb_adapter", "bigquery_adapter", "synapse_adapter", "clickhouse_adapter", "athena_adapter", + "athena_partition", ] diff --git a/dlt/destinations/fs_client.py b/dlt/destinations/fs_client.py index 2c7ae5bbec..14ec3c0717 100644 --- a/dlt/destinations/fs_client.py +++ b/dlt/destinations/fs_client.py @@ -1,3 +1,4 @@ +import gzip from typing import Iterable, cast, Any, List from abc import ABC, abstractmethod from fsspec import AbstractFileSystem @@ -39,10 +40,19 @@ def read_bytes(self, path: str, start: Any = None, end: Any = None, **kwargs: An def read_text( self, path: str, - encoding: Any = None, + encoding: Any = "utf-8", errors: Any = None, newline: Any = None, + compression: str = None, **kwargs: Any ) -> str: - """reads given file into string""" - return cast(str, self.fs_client.read_text(path, encoding, errors, newline, **kwargs)) + """reads given file into string, tries gzip and pure text""" + if compression is None: + try: + return self.read_text(path, encoding, errors, newline, "gzip", **kwargs) + except (gzip.BadGzipFile, OSError): + pass + with self.fs_client.open( + path, mode="rt", compression=compression, encoding=encoding, newline=newline + ) as f: + return cast(str, f.read()) diff --git a/dlt/destinations/impl/athena/__init__.py b/dlt/destinations/impl/athena/__init__.py index 87a11f9f41..e69de29bb2 100644 --- a/dlt/destinations/impl/athena/__init__.py +++ b/dlt/destinations/impl/athena/__init__.py @@ -1,33 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import ( - escape_athena_identifier, - format_bigquery_datetime_literal, -) -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - # athena only supports loading from staged files on s3 for now - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] - caps.supported_table_formats = ["iceberg"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] - caps.escape_identifier = escape_athena_identifier - caps.format_datetime_literal = format_bigquery_datetime_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 16 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 262144 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.supports_transactions = False - caps.alter_add_multi_column = True - caps.schema_supports_numeric_precision = False - caps.timestamp_precision = 3 - caps.supports_truncate_command = False - return caps diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 60ea64a4e7..4225d63fe7 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -17,6 +17,7 @@ import re from contextlib import contextmanager +from fsspec import AbstractFileSystem from pendulum.datetime import DateTime, Date from datetime import datetime # noqa: I251 @@ -33,22 +34,20 @@ from dlt.common import logger from dlt.common.exceptions import TerminalValueError -from dlt.common.utils import without_none -from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema +from dlt.common.storages.fsspec_filesystem import fsspec_from_config +from dlt.common.utils import uniq_id, without_none +from dlt.common.schema import TColumnSchema, Schema, TTableSchema from dlt.common.schema.typing import ( TTableSchema, TColumnType, - TWriteDisposition, TTableFormat, TSortOrder, ) -from dlt.common.schema.utils import table_schema_has_type, get_table_format +from dlt.common.schema.utils import table_schema_has_type from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob -from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination -from dlt.common.storages import FileStorage -from dlt.common.data_writers.escape import escape_bigquery_identifier +from dlt.common.destination.reference import NewLoadJob, SupportsStagingDestination +from dlt.common.data_writers.escape import escape_hive_identifier from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob from dlt.destinations.typing import DBApi, DBTransaction @@ -58,7 +57,6 @@ DatabaseUndefinedRelation, LoadJobTerminalException, ) -from dlt.destinations.impl.athena import capabilities from dlt.destinations.sql_client import ( SqlClientBase, DBApiCursorImpl, @@ -166,7 +164,7 @@ class AthenaMergeJob(SqlMergeJob): @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: # reproducible name so we know which table to drop - with sql_client.with_staging_dataset(staging=True): + with sql_client.with_staging_dataset(): return sql_client.make_qualified_table_name(name_prefix) @classmethod @@ -221,11 +219,16 @@ def requires_temp_table_for_delete(cls) -> bool: class AthenaSQLClient(SqlClientBase[Connection]): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() dbapi: ClassVar[DBApi] = pyathena - def __init__(self, dataset_name: str, config: AthenaClientConfiguration) -> None: - super().__init__(None, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + config: AthenaClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(None, dataset_name, staging_dataset_name, capabilities) self._conn: Connection = None self.config = config self.credentials = config.credentials @@ -254,8 +257,9 @@ def escape_ddl_identifier(self, v: str) -> str: # Athena uses HIVE to create tables but for querying it uses PRESTO (so normal escaping) if not v: return v + v = self.capabilities.casefold_identifier(v) # bigquery uses hive escaping - return escape_bigquery_identifier(v) + return escape_hive_identifier(v) def fully_qualified_ddl_dataset_name(self) -> str: return self.escape_ddl_identifier(self.dataset_name) @@ -271,11 +275,6 @@ def create_dataset(self) -> None: def drop_dataset(self) -> None: self.execute_sql(f"DROP DATABASE {self.fully_qualified_ddl_dataset_name()} CASCADE;") - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - def drop_tables(self, *tables: str) -> None: if not tables: return @@ -366,17 +365,14 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB yield DBApiCursorImpl(cursor) # type: ignore - def has_dataset(self) -> bool: - # PRESTO escaping for queries - query = f"""SHOW DATABASES LIKE {self.fully_qualified_dataset_name()};""" - rows = self.execute_sql(query) - return len(rows) > 0 - class AthenaClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: + def __init__( + self, + schema: Schema, + config: AthenaClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: # verify if staging layout is valid for Athena # this will raise if the table prefix is not properly defined # we actually that {table_name} is first, no {schema_name} is allowed @@ -386,7 +382,12 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: table_needs_own_folder=True, ) - sql_client = AthenaSQLClient(config.normalize_dataset_name(schema), config) + sql_client = AthenaSQLClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config, + capabilities, + ) super().__init__(schema, config, sql_client) self.sql_client: AthenaSQLClient = sql_client # type: ignore self.config: AthenaClientConfiguration = config @@ -432,8 +433,12 @@ def _get_table_update_sql( is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" columns = ", ".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + # create unique tag for iceberg table so it is never recreated in the same folder + # athena requires some kind of special cleaning (or that is a bug) so we cannot refresh + # iceberg tables without it + location_tag = uniq_id(6) if is_iceberg else "" # this will fail if the table prefix is not properly defined - table_prefix = self.table_prefix_layout.format(table_name=table_name) + table_prefix = self.table_prefix_layout.format(table_name=table_name + location_tag) location = f"{bucket}/{dataset}/{table_prefix}" # use qualified table names diff --git a/dlt/destinations/impl/athena/factory.py b/dlt/destinations/impl/athena/factory.py index 5b37607cca..4945a15661 100644 --- a/dlt/destinations/impl/athena/factory.py +++ b/dlt/destinations/impl/athena/factory.py @@ -1,9 +1,14 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration from dlt.common.configuration.specs import AwsCredentials -from dlt.destinations.impl.athena import capabilities +from dlt.common.data_writers.escape import ( + escape_athena_identifier, + format_bigquery_datetime_literal, +) +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + +from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration if t.TYPE_CHECKING: from dlt.destinations.impl.athena.athena import AthenaClient @@ -12,8 +17,37 @@ class athena(Destination[AthenaClientConfiguration, "AthenaClient"]): spec = AthenaClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + # athena only supports loading from staged files on s3 for now + caps.preferred_loader_file_format = None + caps.supported_loader_file_formats = [] + caps.supported_table_formats = ["iceberg"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] + # athena is storing all identifiers in lower case and is case insensitive + # it also uses lower case in all the queries + # https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html + caps.escape_identifier = escape_athena_identifier + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = False + caps.format_datetime_literal = format_bigquery_datetime_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 16 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 262144 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.supports_transactions = False + caps.alter_add_multi_column = True + caps.schema_supports_numeric_precision = False + caps.timestamp_precision = 3 + caps.supports_truncate_command = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] + return caps @property def client_class(self) -> t.Type["AthenaClient"]: diff --git a/dlt/destinations/impl/bigquery/__init__.py b/dlt/destinations/impl/bigquery/__init__.py index 39322b43a0..e69de29bb2 100644 --- a/dlt/destinations/impl/bigquery/__init__.py +++ b/dlt/destinations/impl/bigquery/__init__.py @@ -1,31 +0,0 @@ -from dlt.common.data_writers.escape import ( - escape_bigquery_identifier, - format_bigquery_datetime_literal, -) -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] - # BQ limit is 4GB but leave a large headroom since buffered writer does not preemptively check size - caps.recommended_file_size = int(1024 * 1024 * 1024) - caps.escape_identifier = escape_bigquery_identifier - caps.escape_literal = None - caps.format_datetime_literal = format_bigquery_datetime_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (76, 38) - caps.max_identifier_length = 1024 - caps.max_column_identifier_length = 300 - caps.max_query_length = 1024 * 1024 - caps.is_max_query_length_in_bytes = False - caps.max_text_data_type_length = 10 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.supports_clone_table = True - - return caps diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index f26e6f42ee..095974d186 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -1,7 +1,7 @@ import functools import os from pathlib import Path -from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, cast +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, cast import google.cloud.bigquery as bigquery # noqa: I250 from google.api_core import exceptions as api_core_exceptions @@ -20,23 +20,24 @@ SupportsStagingDestination, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.exceptions import UnknownTableException from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.schema.utils import get_inherited_table_hint from dlt.common.schema.utils import table_schema_has_type from dlt.common.storages.file_storage import FileStorage +from dlt.common.storages.load_package import destination_state from dlt.common.typing import DictStrAny from dlt.destinations.job_impl import DestinationJsonlLoadJob, DestinationParquetLoadJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.exceptions import ( + DatabaseTransientException, + DatabaseUndefinedRelation, DestinationSchemaWillNotUpdate, DestinationTerminalException, - DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, ) -from dlt.destinations.impl.bigquery import capabilities from dlt.destinations.impl.bigquery.bigquery_adapter import ( + AUTODETECT_SCHEMA_HINT, PARTITION_HINT, CLUSTER_HINT, TABLE_DESCRIPTION_HINT, @@ -50,7 +51,7 @@ from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.type_mapping import TypeMapper -from dlt.pipeline.current import destination_state +from dlt.destinations.utils import parse_db_data_type_str_with_precision class BigQueryTypeMapper(TypeMapper): @@ -58,10 +59,10 @@ class BigQueryTypeMapper(TypeMapper): "complex": "JSON", "text": "STRING", "double": "FLOAT64", - "bool": "BOOLEAN", + "bool": "BOOL", "date": "DATE", "timestamp": "TIMESTAMP", - "bigint": "INTEGER", + "bigint": "INT64", "binary": "BYTES", "wei": "BIGNUMERIC", # non-parametrized should hold wei values "time": "TIME", @@ -74,11 +75,11 @@ class BigQueryTypeMapper(TypeMapper): dbt_to_sct = { "STRING": "text", - "FLOAT": "double", - "BOOLEAN": "bool", + "FLOAT64": "double", + "BOOL": "bool", "DATE": "date", "TIMESTAMP": "timestamp", - "INTEGER": "bigint", + "INT64": "bigint", "BYTES": "binary", "NUMERIC": "decimal", "BIGNUMERIC": "decimal", @@ -97,9 +98,10 @@ def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> def from_db_type( self, db_type: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - if db_type == "BIGNUMERIC" and precision is None: + # precision is present in the type name + if db_type == "BIGNUMERIC": return dict(data_type="wei") - return super().from_db_type(db_type, precision, scale) + return super().from_db_type(*parse_db_data_type_str_with_precision(db_type)) class BigQueryLoadJob(LoadJob, FollowupJob): @@ -173,12 +175,17 @@ def gen_key_table_clauses( class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: + def __init__( + self, + schema: Schema, + config: BigQueryClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: sql_client = BigQuerySqlClient( config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), config.credentials, + capabilities, config.get_location(), config.http_timeout, config.retry_deadline, @@ -221,7 +228,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: file_path, f"The server reason was: {reason}" ) from gace else: - raise DestinationTransientException(gace) from gace + raise DatabaseTransientException(gace) from gace return job def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: @@ -266,7 +273,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> reason = BigQuerySqlClient._get_reason_from_errors(gace) if reason == "notFound": # google.api_core.exceptions.NotFound: 404 – table not found - raise UnknownTableException(table["name"]) from gace + raise DatabaseUndefinedRelation(gace) from gace elif ( reason == "duplicate" ): # google.api_core.exceptions.Conflict: 409 PUT – already exists @@ -277,13 +284,18 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> file_path, f"The server reason was: {reason}" ) from gace else: - raise DestinationTransientException(gace) from gace + raise DatabaseTransientException(gace) from gace return job def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: + # return empty columns which will skip table CREATE or ALTER + # to let BigQuery autodetect table from data + if self._should_autodetect_schema(table_name): + return [] + table: Optional[TTableSchema] = self.prepare_load_table(table_name) sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) canonical_name = self.sql_client.make_qualified_table_name(table_name) @@ -292,15 +304,15 @@ def _get_table_update_sql( c for c in new_columns if c.get("partition") or c.get(PARTITION_HINT, False) ]: if len(partition_list) > 1: - col_names = [self.capabilities.escape_identifier(c["name"]) for c in partition_list] + col_names = [self.sql_client.escape_column_name(c["name"]) for c in partition_list] raise DestinationSchemaWillNotUpdate( canonical_name, col_names, "Partition requested for more than one column" ) elif (c := partition_list[0])["data_type"] == "date": - sql[0] += f"\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" + sql[0] += f"\nPARTITION BY {self.sql_client.escape_column_name(c['name'])}" elif (c := partition_list[0])["data_type"] == "timestamp": sql[0] = ( - f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + f"{sql[0]}\nPARTITION BY DATE({self.sql_client.escape_column_name(c['name'])})" ) # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp. # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. @@ -309,12 +321,12 @@ def _get_table_update_sql( # See: https://dlthub.com/devel/dlt-ecosystem/destinations/bigquery#supported-column-hints elif (c := partition_list[0])["data_type"] == "bigint": sql[0] += ( - f"\nPARTITION BY RANGE_BUCKET({self.capabilities.escape_identifier(c['name'])}," + f"\nPARTITION BY RANGE_BUCKET({self.sql_client.escape_column_name(c['name'])}," " GENERATE_ARRAY(-172800000, 691200000, 86400))" ) if cluster_list := [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster") or c.get(CLUSTER_HINT, False) ]: @@ -365,8 +377,58 @@ def prepare_load_table( ) return table + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + """Gets table schemas from BigQuery using INFORMATION_SCHEMA or get_table for hidden datasets""" + if not self.sql_client.is_hidden_dataset: + return super().get_storage_tables(table_names) + + # use the api to get storage tables for hidden dataset + schema_tables: List[Tuple[str, TTableSchemaColumns]] = [] + for table_name in table_names: + try: + schema_table: TTableSchemaColumns = {} + table = self.sql_client.native_connection.get_table( + self.sql_client.make_qualified_table_name(table_name, escape=False), + retry=self.sql_client._default_retry, + timeout=self.config.http_timeout, + ) + for c in table.schema: + schema_c: TColumnSchema = { + "name": c.name, + "nullable": c.is_nullable, + **self._from_db_type(c.field_type, c.precision, c.scale), + } + schema_table[c.name] = schema_c + schema_tables.append((table_name, schema_table)) + except gcp_exceptions.NotFound: + # table is not present + schema_tables.append((table_name, {})) + return schema_tables + + def _get_info_schema_columns_query( + self, catalog_name: Optional[str], schema_name: str, folded_table_names: List[str] + ) -> Tuple[str, List[Any]]: + """Bigquery needs to scope the INFORMATION_SCHEMA.COLUMNS with project and dataset name so standard query generator cannot be used.""" + # escape schema and catalog names + catalog_name = self.capabilities.escape_identifier(catalog_name) + schema_name = self.capabilities.escape_identifier(schema_name) + + query = f""" +SELECT {",".join(self._get_storage_table_query_columns())} + FROM {catalog_name}.{schema_name}.INFORMATION_SCHEMA.COLUMNS +""" + if folded_table_names: + # placeholder for each table + table_placeholders = ",".join(["%s"] * len(folded_table_names)) + query += f"WHERE table_name IN ({table_placeholders}) " + query += "ORDER BY table_name, ordinal_position;" + + return query, folded_table_names + def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(column["name"]) + name = self.sql_client.escape_column_name(column["name"]) column_def_sql = ( f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" ) @@ -376,32 +438,6 @@ def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_AWAY_FROM_ZERO')" return column_def_sql - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - schema_table: TTableSchemaColumns = {} - try: - table = self.sql_client.native_connection.get_table( - self.sql_client.make_qualified_table_name(table_name, escape=False), - retry=self.sql_client._default_retry, - timeout=self.config.http_timeout, - ) - partition_field = table.time_partitioning.field if table.time_partitioning else None - for c in table.schema: - schema_c: TColumnSchema = { - "name": c.name, - "nullable": c.is_nullable, - "unique": False, - "sort": False, - "primary_key": False, - "foreign_key": False, - "cluster": c.name in (table.clustering_fields or []), - "partition": c.name == partition_field, - **self._from_db_type(c.field_type, c.precision, c.scale), - } - schema_table[c.name] = schema_c - return True, schema_table - except gcp_exceptions.NotFound: - return False, schema_table - def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.LoadJob: # append to table for merge loads (append to stage) and regular appends. table_name = table["name"] @@ -417,12 +453,6 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON decimal_target_types: Optional[List[str]] = None if ext == "parquet": - # if table contains complex types, we cannot load with parquet - if table_schema_has_type(table, "complex"): - raise LoadJobTerminalException( - file_path, - "Bigquery cannot load into JSON data type from parquet. Use jsonl instead.", - ) source_format = bigquery.SourceFormat.PARQUET # parquet needs NUMERIC type auto-detection decimal_target_types = ["NUMERIC", "BIGNUMERIC"] @@ -437,6 +467,19 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load ignore_unknown_values=False, max_bad_records=0, ) + if self._should_autodetect_schema(table_name): + # allow BigQuery to infer and evolve the schema, note that dlt is not + # creating such tables at all + job_config.autodetect = True + job_config.schema_update_options = bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION + job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED + elif ext == "parquet" and table_schema_has_type(table, "complex"): + # if table contains complex types, we cannot load with parquet + raise LoadJobTerminalException( + file_path, + "Bigquery cannot load into JSON data type from parquet. Enable autodetect_schema in" + " config or via BigQuery adapter or use jsonl format instead.", + ) if bucket_path: return self.sql_client.native_connection.load_table_from_uri( @@ -465,6 +508,11 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) + def _should_autodetect_schema(self, table_name: str) -> bool: + return get_inherited_table_hint( + self.schema._schema_tables, table_name, AUTODETECT_SCHEMA_HINT, allow_none=True + ) or (self.config.autodetect_schema and table_name not in self.schema.dlt_table_names()) + def _streaming_load( sql_client: SqlClientBase[BigQueryClient], items: List[Dict[Any, Any]], table: Dict[str, Any] diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 38828249ff..4dee572f57 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -20,6 +20,7 @@ ROUND_HALF_EVEN_HINT: Literal["x-bigquery-round-half-even"] = "x-bigquery-round-half-even" TABLE_EXPIRATION_HINT: Literal["x-bigquery-table-expiration"] = "x-bigquery-table-expiration" TABLE_DESCRIPTION_HINT: Literal["x-bigquery-table-description"] = "x-bigquery-table-description" +AUTODETECT_SCHEMA_HINT: Literal["x-bigquery-autodetect-schema"] = "x-bigquery-autodetect-schema" def bigquery_adapter( @@ -31,6 +32,7 @@ def bigquery_adapter( table_description: Optional[str] = None, table_expiration_datetime: Optional[str] = None, insert_api: Optional[Literal["streaming", "default"]] = None, + autodetect_schema: Optional[bool] = None, ) -> DltResource: """ Prepares data for loading into BigQuery. @@ -62,6 +64,8 @@ def bigquery_adapter( If "streaming" is chosen, the streaming API (https://cloud.google.com/bigquery/docs/streaming-data-into-bigquery) is used. NOTE: due to BigQuery features, streaming insert is only available for `append` write_disposition. + autodetect_schema (bool, optional): If set to True, BigQuery schema autodetection will be used to create data tables. This + allows to create structured types from nested data. Returns: A `DltResource` object that is ready to be loaded into BigQuery. @@ -136,6 +140,9 @@ def bigquery_adapter( ) additional_table_hints[TABLE_DESCRIPTION_HINT] = table_description + if autodetect_schema: + additional_table_hints[AUTODETECT_SCHEMA_HINT] = autodetect_schema + if table_expiration_datetime: if not isinstance(table_expiration_datetime, str): raise ValueError( diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index f69e85ca3d..47cc997a4a 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -14,26 +14,26 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_type: Final[str] = dataclasses.field(default="bigquery", init=False, repr=False, compare=False) # type: ignore credentials: GcpServiceAccountCredentials = None location: str = "US" - - http_timeout: float = 15.0 # connection timeout for http request to BigQuery api - file_upload_timeout: float = 30 * 60.0 # a timeout for file upload when loading local files - retry_deadline: float = ( - 60.0 # how long to retry the operation in case of error, the backoff 60 s. - ) + has_case_sensitive_identifiers: bool = True + """If True then dlt expects to load data into case sensitive dataset""" + should_set_case_sensitivity_on_new_dataset: bool = False + """If True, dlt will set case sensitivity flag on created datasets that corresponds to naming convention""" + + http_timeout: float = 15.0 + """connection timeout for http request to BigQuery api""" + file_upload_timeout: float = 30 * 60.0 + """a timeout for file upload when loading local files""" + retry_deadline: float = 60.0 + """How long to retry the operation in case of error, the backoff 60 s.""" batch_size: int = 500 + """Number of rows in streaming insert batch""" + autodetect_schema: bool = False + """Allow BigQuery to autodetect schemas and create data tables""" __config_gen_annotations__: ClassVar[List[str]] = ["location"] def get_location(self) -> str: - if self.location != "US": - return self.location - # default was changed in credentials, emit deprecation message - if self.credentials.location != "US": - warnings.warn( - "Setting BigQuery location in the credentials is deprecated. Please set the" - " location directly in bigquery section ie. destinations.bigquery.location='EU'" - ) - return self.credentials.location + return self.location def fingerprint(self) -> str: """Returns a fingerprint of project_id""" diff --git a/dlt/destinations/impl/bigquery/factory.py b/dlt/destinations/impl/bigquery/factory.py index bee55fa164..14f976113f 100644 --- a/dlt/destinations/impl/bigquery/factory.py +++ b/dlt/destinations/impl/bigquery/factory.py @@ -1,10 +1,13 @@ import typing as t -from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration +from dlt.common.normalizers.naming import NamingConvention from dlt.common.configuration.specs import GcpServiceAccountCredentials -from dlt.destinations.impl.bigquery import capabilities +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.data_writers.escape import escape_hive_identifier, format_bigquery_datetime_literal from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration + if t.TYPE_CHECKING: from dlt.destinations.impl.bigquery.bigquery import BigQueryClient @@ -13,8 +16,35 @@ class bigquery(Destination[BigQueryClientConfiguration, "BigQueryClient"]): spec = BigQueryClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl", "parquet"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] + # BigQuery is by default case sensitive but that cannot be turned off for a dataset + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity + caps.escape_identifier = escape_hive_identifier + caps.escape_literal = None + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + # BQ limit is 4GB but leave a large headroom since buffered writer does not preemptively check size + caps.recommended_file_size = int(1024 * 1024 * 1024) + caps.format_datetime_literal = format_bigquery_datetime_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (76, 38) + caps.max_identifier_length = 1024 + caps.max_column_identifier_length = 300 + caps.max_query_length = 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 10 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.supports_clone_table = True + caps.schema_supports_numeric_precision = False # no precision information in BigQuery + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property def client_class(self) -> t.Type["BigQueryClient"]: @@ -26,14 +56,39 @@ def __init__( self, credentials: t.Optional[GcpServiceAccountCredentials] = None, location: t.Optional[str] = None, + has_case_sensitive_identifiers: bool = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, ) -> None: + """Configure the MsSql destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to the mssql database. Can be an instance of `GcpServiceAccountCredentials` or + a dict or string with service accounts credentials as used in the Google Cloud + location: A location where the datasets will be created, eg. "EU". The default is "US" + has_case_sensitive_identifiers: Is the dataset case-sensitive, defaults to True + **kwargs: Additional arguments passed to the destination config + """ super().__init__( credentials=credentials, location=location, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: BigQueryClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.should_set_case_sensitivity_on_new_dataset: + caps.has_case_sensitive_identifiers = config.has_case_sensitive_identifiers + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 21086a4db6..dfc4094e7b 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence +from typing import Any, AnyStr, ClassVar, Iterator, List, Optional, Sequence, Generator import google.cloud.bigquery as bigquery # noqa: I250 from google.api_core import exceptions as api_core_exceptions @@ -8,6 +8,7 @@ from google.cloud.bigquery.dbapi import Connection as DbApiConnection, Cursor as BQDbApiCursor from google.cloud.bigquery.dbapi import exceptions as dbapi_exceptions +from dlt.common import logger from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.typing import StrAny @@ -16,7 +17,6 @@ DatabaseTransientException, DatabaseUndefinedRelation, ) -from dlt.destinations.impl.bigquery import capabilities from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, @@ -44,29 +44,43 @@ class BigQueryDBApiCursorImpl(DBApiCursorImpl): """Use native BigQuery data frame support if available""" native_cursor: BQDbApiCursor # type: ignore + df_iterator: Generator[Any, None, None] - def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: - if chunk_size is not None: - return super().df(chunk_size=chunk_size) + def __init__(self, curr: DBApiCursor) -> None: + super().__init__(curr) + self.df_iterator = None + + def df(self, chunk_size: Optional[int] = None, **kwargs: Any) -> DataFrame: query_job: bigquery.QueryJob = getattr( self.native_cursor, "_query_job", self.native_cursor.query_job ) - + if self.df_iterator: + return next(self.df_iterator, None) try: + if chunk_size is not None: + # create iterator with given page size + self.df_iterator = query_job.result(page_size=chunk_size).to_dataframe_iterable() + return next(self.df_iterator, None) return query_job.to_dataframe(**kwargs) - except ValueError: + except ValueError as ex: # no pyarrow/db-types, fallback to our implementation - return super().df() + logger.warning(f"Native BigQuery pandas reader could not be used: {str(ex)}") + return super().df(chunk_size=chunk_size) + + def close(self) -> None: + if self.df_iterator: + self.df_iterator.close() class BigQuerySqlClient(SqlClientBase[bigquery.Client], DBTransaction): dbapi: ClassVar[DBApi] = bq_dbapi - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__( self, dataset_name: str, + staging_dataset_name: str, credentials: GcpServiceAccountCredentialsWithoutDefaults, + capabilities: DestinationCapabilitiesContext, location: str = "US", http_timeout: float = 15.0, retry_deadline: float = 60.0, @@ -75,7 +89,7 @@ def __init__( self.credentials: GcpServiceAccountCredentialsWithoutDefaults = credentials self.location = location self.http_timeout = http_timeout - super().__init__(credentials.project_id, dataset_name) + super().__init__(credentials.project_id, dataset_name, staging_dataset_name, capabilities) self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) self._default_query = bigquery.QueryJobConfig( @@ -177,20 +191,24 @@ def has_dataset(self) -> bool: return False def create_dataset(self) -> None: - self._client.create_dataset( - self.fully_qualified_dataset_name(escape=False), - retry=self._default_retry, - timeout=self.http_timeout, - ) - - def drop_dataset(self) -> None: - self._client.delete_dataset( - self.fully_qualified_dataset_name(escape=False), - not_found_ok=True, - delete_contents=True, - retry=self._default_retry, - timeout=self.http_timeout, - ) + dataset = bigquery.Dataset(self.fully_qualified_dataset_name(escape=False)) + dataset.location = self.location + dataset.is_case_insensitive = not self.capabilities.has_case_sensitive_identifiers + try: + self._client.create_dataset( + dataset, + retry=self._default_retry, + timeout=self.http_timeout, + ) + except api_core_exceptions.GoogleAPICallError as gace: + reason = BigQuerySqlClient._get_reason_from_errors(gace) + if reason == "notFound": + # google.api_core.exceptions.NotFound: 404 – table not found + raise DatabaseUndefinedRelation(gace) from gace + elif reason in BQ_TERMINAL_REASONS: + raise DatabaseTerminalException(gace) from gace + else: + raise DatabaseTransientException(gace) from gace def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any @@ -221,14 +239,19 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # will close all cursors conn.close() - def fully_qualified_dataset_name(self, escape: bool = True) -> str: + def catalog_name(self, escape: bool = True) -> Optional[str]: + project_id = self.capabilities.casefold_identifier(self.credentials.project_id) if escape: - project_id = self.capabilities.escape_identifier(self.credentials.project_id) - dataset_name = self.capabilities.escape_identifier(self.dataset_name) - else: - project_id = self.credentials.project_id - dataset_name = self.dataset_name - return f"{project_id}.{dataset_name}" + project_id = self.capabilities.escape_identifier(project_id) + return project_id + + @property + def is_hidden_dataset(self) -> bool: + """Tells if the dataset associated with sql_client is a hidden dataset. + + Hidden datasets are not present in information schema. + """ + return self.dataset_name.startswith("_") @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: diff --git a/dlt/destinations/impl/clickhouse/__init__.py b/dlt/destinations/impl/clickhouse/__init__.py index bead136828..e69de29bb2 100644 --- a/dlt/destinations/impl/clickhouse/__init__.py +++ b/dlt/destinations/impl/clickhouse/__init__.py @@ -1,53 +0,0 @@ -import sys - -from dlt.common.pendulum import pendulum -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.data_writers.escape import ( - escape_clickhouse_identifier, - escape_clickhouse_literal, - format_clickhouse_datetime_literal, -) -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.impl.clickhouse.clickhouse_adapter import clickhouse_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["parquet", "jsonl"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["parquet", "jsonl"] - - caps.format_datetime_literal = format_clickhouse_datetime_literal - caps.escape_identifier = escape_clickhouse_identifier - caps.escape_literal = escape_clickhouse_literal - - # https://stackoverflow.com/questions/68358686/what-is-the-maximum-length-of-a-column-in-clickhouse-can-it-be-modified - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - - # ClickHouse has no max `String` type length. - caps.max_text_data_type_length = sys.maxsize - - caps.schema_supports_numeric_precision = True - # Use 'Decimal128' with these defaults. - # https://clickhouse.com/docs/en/sql-reference/data-types/decimal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - # Use 'Decimal256' with these defaults. - caps.wei_precision = (76, 0) - caps.timestamp_precision = 6 - - # https://clickhouse.com/docs/en/operations/settings/settings#max_query_size - caps.is_max_query_length_in_bytes = True - caps.max_query_length = 262144 - - # ClickHouse has limited support for transactional semantics, especially for `ReplicatedMergeTree`, - # the default ClickHouse Cloud engine. It does, however, provide atomicity for individual DDL operations like `ALTER TABLE`. - # https://clickhouse-driver.readthedocs.io/en/latest/dbapi.html#clickhouse_driver.dbapi.connection.Connection.commit - # https://clickhouse.com/docs/en/guides/developer/transactional#transactions-commit-and-rollback - caps.supports_transactions = False - caps.supports_ddl_transactions = False - - caps.supports_truncate_command = True - - return caps diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index cf1f1bc857..148fca3f1e 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -2,21 +2,18 @@ import re from copy import deepcopy from textwrap import dedent -from typing import ClassVar, Optional, Dict, List, Sequence, cast, Tuple +from typing import Optional, List, Sequence, cast from urllib.parse import urlparse import clickhouse_connect from clickhouse_connect.driver.tools import insert_file -import dlt from dlt import config from dlt.common.configuration.specs import ( CredentialsConfiguration, AzureCredentialsWithoutDefaults, - GcpCredentials, AwsCredentialsWithoutDefaults, ) -from dlt.destinations.exceptions import DestinationTransientException from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( SupportsStagingDestination, @@ -29,27 +26,27 @@ from dlt.common.schema.typing import ( TTableFormat, TTableSchema, - TColumnHint, TColumnType, - TTableSchemaColumns, - TColumnSchemaBase, ) from dlt.common.storages import FileStorage from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.clickhouse import capabilities -from dlt.destinations.impl.clickhouse.clickhouse_adapter import ( - TTableEngineType, - TABLE_ENGINE_TYPE_HINT, -) from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseClientConfiguration, ) from dlt.destinations.impl.clickhouse.sql_client import ClickHouseSqlClient -from dlt.destinations.impl.clickhouse.utils import ( - convert_storage_to_http_scheme, +from dlt.destinations.impl.clickhouse.typing import ( + HINT_TO_CLICKHOUSE_ATTR, + TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR, +) +from dlt.destinations.impl.clickhouse.typing import ( + TTableEngineType, + TABLE_ENGINE_TYPE_HINT, FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING, SUPPORTED_FILE_FORMATS, ) +from dlt.destinations.impl.clickhouse.utils import ( + convert_storage_to_http_scheme, +) from dlt.destinations.job_client_impl import ( SqlJobClientBase, SqlJobClientWithStaging, @@ -59,18 +56,6 @@ from dlt.destinations.type_mapping import TypeMapper -HINT_TO_CLICKHOUSE_ATTR: Dict[TColumnHint, str] = { - "primary_key": "PRIMARY KEY", - "unique": "", # No unique constraints available in ClickHouse. - "foreign_key": "", # No foreign key constraints support in ClickHouse. -} - -TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR: Dict[TTableEngineType, str] = { - "merge_tree": "MergeTree", - "replicated_merge_tree": "ReplicatedMergeTree", -} - - class ClickHouseTypeMapper(TypeMapper): sct_to_unbound_dbt = { "complex": "String", @@ -114,7 +99,8 @@ def from_db_type( if db_type == "DateTime('UTC')": db_type = "DateTime" if datetime_match := re.match( - r"DateTime64(?:\((?P\d+)(?:,?\s*'(?PUTC)')?\))?", db_type + r"DateTime64(?:\((?P\d+)(?:,?\s*'(?PUTC)')?\))?", + db_type, ): if datetime_match["precision"]: precision = int(datetime_match["precision"]) @@ -132,7 +118,7 @@ def from_db_type( db_type = "Decimal" if db_type == "Decimal" and (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") + return cast(TColumnType, dict(data_type="wei")) return super().from_db_type(db_type, precision, scale) @@ -162,7 +148,7 @@ def __init__( compression = "auto" - # Don't use dbapi driver for local files. + # Don't use the DBAPI driver for local files. if not bucket_path: # Local filesystem. if ext == "jsonl": @@ -183,8 +169,8 @@ def __init__( fmt=clickhouse_format, settings={ "allow_experimental_lightweight_delete": 1, - # "allow_experimental_object_type": 1, "enable_http_compression": 1, + "date_time_input_format": "best_effort", }, compression=compression, ) @@ -202,13 +188,7 @@ def __init__( compression = "none" if config.get("data_writer.disable_compression") else "gz" if bucket_scheme in ("s3", "gs", "gcs"): - if isinstance(staging_credentials, AwsCredentialsWithoutDefaults): - bucket_http_url = convert_storage_to_http_scheme( - bucket_url, endpoint=staging_credentials.endpoint_url - ) - access_key_id = staging_credentials.aws_access_key_id - secret_access_key = staging_credentials.aws_secret_access_key - else: + if not isinstance(staging_credentials, AwsCredentialsWithoutDefaults): raise LoadJobTerminalException( file_path, dedent( @@ -220,6 +200,11 @@ def __init__( ).strip(), ) + bucket_http_url = convert_storage_to_http_scheme( + bucket_url, endpoint=staging_credentials.endpoint_url + ) + access_key_id = staging_credentials.aws_access_key_id + secret_access_key = staging_credentials.aws_secret_access_key auth = "NOSIGN" if access_key_id and secret_access_key: auth = f"'{access_key_id}','{secret_access_key}'" @@ -289,15 +274,18 @@ def requires_temp_table_for_delete(cls) -> bool: class ClickHouseClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__( self, schema: Schema, config: ClickHouseClientConfiguration, + capabilities: DestinationCapabilitiesContext, ) -> None: self.sql_client: ClickHouseSqlClient = ClickHouseSqlClient( - config.normalize_dataset_name(schema), config.credentials + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + config, ) super().__init__(schema, config, self.sql_client) self.config: ClickHouseClientConfiguration = config @@ -310,10 +298,10 @@ def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> Li def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: # Build column definition. # The primary key and sort order definition is defined outside column specification. - hints_str = " ".join( + hints_ = " ".join( self.active_hints.get(hint) for hint in self.active_hints.keys() - if c.get(hint, False) is True + if c.get(cast(str, hint), False) is True and hint not in ("primary_key", "sort") and hint in self.active_hints ) @@ -327,7 +315,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) return ( - f"{self.capabilities.escape_identifier(c['name'])} {type_with_nullability_modifier} {hints_str}" + f"{self.sql_client.escape_column_name(c['name'])} {type_with_nullability_modifier} {hints_}" .strip() ) @@ -342,7 +330,10 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> ) def _get_table_update_sql( - self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, ) -> List[str]: table: TTableSchema = self.prepare_load_table(table_name, self.in_staging_mode) sql = SqlJobClientBase._get_table_update_sql(self, table_name, new_columns, generate_alter) @@ -350,14 +341,20 @@ def _get_table_update_sql( if generate_alter: return sql - # Default to 'ReplicatedMergeTree' if user didn't explicitly set a table engine hint. + # Default to 'MergeTree' if the user didn't explicitly set a table engine hint. + # Clickhouse Cloud will automatically pick `SharedMergeTree` for this option, + # so it will work on both local and cloud instances of CH. table_type = cast( - TTableEngineType, table.get(TABLE_ENGINE_TYPE_HINT, "replicated_merge_tree") + TTableEngineType, + table.get( + cast(str, TABLE_ENGINE_TYPE_HINT), + self.config.table_engine_type, + ), ) sql[0] = f"{sql[0]}\nENGINE = {TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR.get(table_type)}" if primary_key_list := [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("primary_key") ]: @@ -367,34 +364,6 @@ def _get_table_update_sql( return sql - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - fields = self._get_storage_table_query_columns() - db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( - ".", 3 - ) - query = f'SELECT {",".join(fields)} FROM INFORMATION_SCHEMA.COLUMNS WHERE ' - if len(db_params) == 3: - query += "table_catalog = %s AND " - query += "table_schema = %s AND table_name = %s ORDER BY ordinal_position;" - rows = self.sql_client.execute_sql(query, *db_params) - - # If no rows we assume that table does not exist. - schema_table: TTableSchemaColumns = {} - if len(rows) == 0: - return False, schema_table - for c in rows: - numeric_precision = ( - c[3] if self.capabilities.schema_supports_numeric_precision else None - ) - numeric_scale = c[4] if self.capabilities.schema_supports_numeric_precision else None - schema_c: TColumnSchemaBase = { - "name": c[0], - "nullable": bool(c[2]), - **self._from_db_type(c[1], numeric_precision, numeric_scale), - } - schema_table[c[0]] = schema_c # type: ignore - return True, schema_table - @staticmethod def _gen_not_null(v: bool) -> str: # ClickHouse fields are not nullable by default. diff --git a/dlt/destinations/impl/clickhouse/clickhouse_adapter.py b/dlt/destinations/impl/clickhouse/clickhouse_adapter.py index 1bbde8e45d..dc030ef88c 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse_adapter.py +++ b/dlt/destinations/impl/clickhouse/clickhouse_adapter.py @@ -1,12 +1,15 @@ -from typing import Any, Literal, Set, get_args, Dict +from typing import Any, Dict +from dlt.destinations.impl.clickhouse.configuration import TTableEngineType +from dlt.destinations.impl.clickhouse.typing import ( + TABLE_ENGINE_TYPES, + TABLE_ENGINE_TYPE_HINT, +) from dlt.destinations.utils import ensure_resource from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate -TTableEngineType = Literal["merge_tree", "replicated_merge_tree"] - """ The table engine (type of table) determines: @@ -19,8 +22,6 @@ See https://clickhouse.com/docs/en/engines/table-engines. """ -TABLE_ENGINE_TYPES: Set[TTableEngineType] = set(get_args(TTableEngineType)) -TABLE_ENGINE_TYPE_HINT: Literal["x-table-engine-type"] = "x-table-engine-type" def clickhouse_adapter(data: Any, table_engine_type: TTableEngineType = None) -> DltResource: diff --git a/dlt/destinations/impl/clickhouse/configuration.py b/dlt/destinations/impl/clickhouse/configuration.py index bbff6e0a9c..fbda58abc7 100644 --- a/dlt/destinations/impl/clickhouse/configuration.py +++ b/dlt/destinations/impl/clickhouse/configuration.py @@ -1,16 +1,13 @@ import dataclasses -from typing import ClassVar, List, Any, Final, Literal, cast, Optional +from typing import ClassVar, Dict, List, Any, Final, cast, Optional from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.destination.reference import ( DestinationClientDwhWithStagingConfiguration, ) -from dlt.common.libs.sql_alchemy import URL from dlt.common.utils import digest128 - - -TSecureConnection = Literal[0, 1] +from dlt.destinations.impl.clickhouse.typing import TSecureConnection, TTableEngineType @configspec(init=False) @@ -34,10 +31,6 @@ class ClickHouseCredentials(ConnectionStringCredentials): """Timeout for establishing connection. Defaults to 10 seconds.""" send_receive_timeout: int = 300 """Timeout for sending and receiving data. Defaults to 300 seconds.""" - dataset_table_separator: str = "___" - """Separator for dataset table names, defaults to '___', i.e. 'database.dataset___table'.""" - dataset_sentinel_table_name: str = "dlt_sentinel_table" - """Special table to mark dataset as existing""" gcp_access_key_id: Optional[str] = None """When loading from a gcp bucket, you need to provide gcp interoperable keys""" gcp_secret_access_key: Optional[str] = None @@ -59,37 +52,44 @@ def parse_native_representation(self, native_value: Any) -> None: self.query.get("send_receive_timeout", self.send_receive_timeout) ) self.secure = cast(TSecureConnection, int(self.query.get("secure", self.secure))) - if not self.is_partial(): - self.resolve() - def to_url(self) -> URL: - url = super().to_url() - url = url.update_query_pairs( - [ - ("connect_timeout", str(self.connect_timeout)), - ("send_receive_timeout", str(self.send_receive_timeout)), - ("secure", str(1) if self.secure else str(0)), - # Toggle experimental settings. These are necessary for certain datatypes and not optional. - ("allow_experimental_lightweight_delete", "1"), - # ("allow_experimental_object_type", "1"), - ("enable_http_compression", "1"), - ] + def get_query(self) -> Dict[str, Any]: + query = dict(super().get_query()) + query.update( + { + "connect_timeout": str(self.connect_timeout), + "send_receive_timeout": str(self.send_receive_timeout), + "secure": 1 if self.secure else 0, + "allow_experimental_lightweight_delete": 1, + "enable_http_compression": 1, + "date_time_input_format": "best_effort", + } ) - return url + return query @configspec class ClickHouseClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = dataclasses.field(default="clickhouse", init=False, repr=False, compare=False) # type: ignore[misc] + destination_type: Final[str] = dataclasses.field( # type: ignore[misc] + default="clickhouse", init=False, repr=False, compare=False + ) credentials: ClickHouseCredentials = None - # Primary key columns are used to build a sparse primary index which allows for efficient data retrieval, - # but they do not enforce uniqueness constraints. It permits duplicate values even for the primary key - # columns within the same granule. - # See: https://clickhouse.com/docs/en/optimize/sparse-primary-indexes + dataset_table_separator: str = "___" + """Separator for dataset table names, defaults to '___', i.e. 'database.dataset___table'.""" + table_engine_type: Optional[TTableEngineType] = "merge_tree" + """The default table engine to use. Defaults to 'merge_tree'. Other implemented options are 'shared_merge_tree' and 'replicated_merge_tree'.""" + dataset_sentinel_table_name: str = "dlt_sentinel_table" + """Special table to mark dataset as existing""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "dataset_table_separator", + "dataset_sentinel_table_name", + "table_engine_type", + ] def fingerprint(self) -> str: - """Returns a fingerprint of host part of a connection string.""" + """Returns a fingerprint of the host part of a connection string.""" if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" diff --git a/dlt/destinations/impl/clickhouse/factory.py b/dlt/destinations/impl/clickhouse/factory.py index e5b8fc0e6a..93da6c866a 100644 --- a/dlt/destinations/impl/clickhouse/factory.py +++ b/dlt/destinations/impl/clickhouse/factory.py @@ -1,7 +1,13 @@ +import sys import typing as t +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.data_writers.escape import ( + escape_clickhouse_identifier, + escape_clickhouse_literal, + format_clickhouse_datetime_literal, +) from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.impl.clickhouse import capabilities from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseClientConfiguration, ClickHouseCredentials, @@ -16,8 +22,53 @@ class clickhouse(Destination[ClickHouseClientConfiguration, "ClickHouseClient"]): spec = ClickHouseClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["parquet", "jsonl"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["parquet", "jsonl"] + + caps.format_datetime_literal = format_clickhouse_datetime_literal + caps.escape_identifier = escape_clickhouse_identifier + caps.escape_literal = escape_clickhouse_literal + # docs are very unclear https://clickhouse.com/docs/en/sql-reference/syntax + # taking into account other sources: identifiers are case sensitive + caps.has_case_sensitive_identifiers = True + # and store as is in the information schema + caps.casefold_identifier = str + + # https://stackoverflow.com/questions/68358686/what-is-the-maximum-length-of-a-column-in-clickhouse-can-it-be-modified + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + + # ClickHouse has no max `String` type length. + caps.max_text_data_type_length = sys.maxsize + + caps.schema_supports_numeric_precision = True + # Use 'Decimal128' with these defaults. + # https://clickhouse.com/docs/en/sql-reference/data-types/decimal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + # Use 'Decimal256' with these defaults. + caps.wei_precision = (76, 0) + caps.timestamp_precision = 6 + + # https://clickhouse.com/docs/en/operations/settings/settings#max_query_size + caps.is_max_query_length_in_bytes = True + caps.max_query_length = 262144 + + # ClickHouse has limited support for transactional semantics, especially for `ReplicatedMergeTree`, + # the default ClickHouse Cloud engine. It does, however, provide atomicity for individual DDL operations like `ALTER TABLE`. + # https://clickhouse-driver.readthedocs.io/en/latest/dbapi.html#clickhouse_driver.dbapi.connection.Connection.commit + # https://clickhouse.com/docs/en/guides/developer/transactional#transactions-commit-and-rollback + caps.supports_transactions = False + caps.supports_ddl_transactions = False + + caps.supports_truncate_command = True + + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property def client_class(self) -> t.Type["ClickHouseClient"]: diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index 8fb89c90cd..25914e4093 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -1,3 +1,5 @@ +import datetime # noqa: I251 +from clickhouse_driver import dbapi as clickhouse_dbapi # type: ignore[import-untyped] from contextlib import contextmanager from typing import ( Iterator, @@ -7,21 +9,32 @@ Optional, Sequence, ClassVar, + Literal, + Tuple, + cast, ) -import clickhouse_driver # type: ignore[import-untyped] +import clickhouse_driver import clickhouse_driver.errors # type: ignore[import-untyped] from clickhouse_driver.dbapi import OperationalError # type: ignore[import-untyped] from clickhouse_driver.dbapi.extras import DictCursor # type: ignore[import-untyped] +from pendulum import DateTime # noqa: I251 from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.typing import DictStrAny from dlt.destinations.exceptions import ( DatabaseUndefinedRelation, DatabaseTransientException, DatabaseTerminalException, ) -from dlt.destinations.impl.clickhouse import capabilities -from dlt.destinations.impl.clickhouse.configuration import ClickHouseCredentials +from dlt.destinations.impl.clickhouse.configuration import ( + ClickHouseCredentials, + ClickHouseClientConfiguration, +) +from dlt.destinations.impl.clickhouse.typing import ( + TTableEngineType, + TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR, +) from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, @@ -32,6 +45,7 @@ from dlt.destinations.utils import _convert_to_old_pyformat +TDeployment = Literal["ClickHouseOSS", "ClickHouseCloud"] TRANSACTIONS_UNSUPPORTED_WARNING_MESSAGE = ( "ClickHouse does not support transactions! Each statement is auto-committed separately." ) @@ -44,19 +58,27 @@ class ClickHouseDBApiCursorImpl(DBApiCursorImpl): class ClickHouseSqlClient( SqlClientBase[clickhouse_driver.dbapi.connection.Connection], DBTransaction ): - dbapi: ClassVar[DBApi] = clickhouse_driver.dbapi - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + dbapi: ClassVar[DBApi] = clickhouse_dbapi - def __init__(self, dataset_name: str, credentials: ClickHouseCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: ClickHouseCredentials, + capabilities: DestinationCapabilitiesContext, + config: ClickHouseClientConfiguration, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: clickhouse_driver.dbapi.connection = None self.credentials = credentials self.database_name = credentials.database + self.config = config def has_dataset(self) -> bool: - sentinel_table = self.credentials.dataset_sentinel_table_name + # we do not need to normalize dataset_sentinel_table_name. + sentinel_table = self.config.dataset_sentinel_table_name return sentinel_table in [ - t.split(self.credentials.dataset_table_separator)[1] for t in self._list_tables() + t.split(self.config.dataset_table_separator)[1] for t in self._list_tables() ] def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection: @@ -93,27 +115,49 @@ def execute_sql( return None if curr.description is None else curr.fetchall() def create_dataset(self) -> None: - # We create a sentinel table which defines wether we consider the dataset created + # We create a sentinel table which defines whether we consider the dataset created. sentinel_table_name = self.make_qualified_table_name( - self.credentials.dataset_sentinel_table_name - ) - self.execute_sql( - f"""CREATE TABLE {sentinel_table_name} (_dlt_id String NOT NULL PRIMARY KEY) ENGINE=ReplicatedMergeTree COMMENT 'internal dlt sentinel table'""" + self.config.dataset_sentinel_table_name ) + sentinel_table_type = cast(TTableEngineType, self.config.table_engine_type) + self.execute_sql(f""" + CREATE TABLE {sentinel_table_name} + (_dlt_id String NOT NULL PRIMARY KEY) + ENGINE={TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR.get(sentinel_table_type)} + COMMENT 'internal dlt sentinel table'""") def drop_dataset(self) -> None: + # always try to drop the sentinel table. + sentinel_table_name = self.make_qualified_table_name( + self.config.dataset_sentinel_table_name + ) + # drop a sentinel table + self.execute_sql(f"DROP TABLE {sentinel_table_name} SYNC") + # Since ClickHouse doesn't have schemas, we need to drop all tables in our virtual schema, # or collection of tables, that has the `dataset_name` as a prefix. - to_drop_results = self._list_tables() + to_drop_results = [ + f"{self.catalog_name()}.{self.capabilities.escape_identifier(table)}" + for table in self._list_tables() + ] for table in to_drop_results: # The "DROP TABLE" clause is discarded if we allow clickhouse_driver to handle parameter substitution. # This is because the driver incorrectly substitutes the entire query string, causing the "DROP TABLE" keyword to be omitted. # To resolve this, we are forced to provide the full query string here. - self.execute_sql( - f"""DROP TABLE {self.capabilities.escape_identifier(self.database_name)}.{self.capabilities.escape_identifier(table)} SYNC""" - ) + self.execute_sql(f"DROP TABLE {table} SYNC") + + def drop_tables(self, *tables: str) -> None: + """Drops a set of tables if they exist""" + if not tables: + return + statements = [ + f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)} SYNC;" + for table in tables + ] + self.execute_many(statements) def _list_tables(self) -> List[str]: + catalog_name, table_name = self.make_qualified_table_name_path("%", escape=False) rows = self.execute_sql( """ SELECT name @@ -121,13 +165,20 @@ def _list_tables(self) -> List[str]: WHERE database = %s AND name LIKE %s """, - ( - self.database_name, - f"{self.dataset_name}{self.credentials.dataset_table_separator}%", - ), + catalog_name, + table_name, ) return [row[0] for row in rows] + @staticmethod + def _sanitise_dbargs(db_args: DictStrAny) -> DictStrAny: + """For ClickHouse OSS, the DBapi driver doesn't parse datetime types. + We remove timezone specifications in this case.""" + for key, value in db_args.items(): + if isinstance(value, (DateTime, datetime.datetime)): + db_args[key] = str(value.replace(microsecond=0, tzinfo=None)) + return db_args + @contextmanager @raise_database_error def execute_query( @@ -135,12 +186,14 @@ def execute_query( ) -> Iterator[ClickHouseDBApiCursorImpl]: assert isinstance(query, str), "Query must be a string." - db_args = kwargs.copy() + db_args: DictStrAny = kwargs.copy() if args: query, db_args = _convert_to_old_pyformat(query, args, OperationalError) db_args.update(kwargs) + db_args = self._sanitise_dbargs(db_args) + with self._conn.cursor() as cursor: for query_line in query.split(";"): if query_line := query_line.strip(): @@ -151,21 +204,33 @@ def execute_query( yield ClickHouseDBApiCursorImpl(cursor) # type: ignore[abstract] - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = self.database_name - dataset_name = self.dataset_name + def catalog_name(self, escape: bool = True) -> Optional[str]: + database_name = self.capabilities.casefold_identifier(self.database_name) if escape: database_name = self.capabilities.escape_identifier(database_name) - dataset_name = self.capabilities.escape_identifier(dataset_name) - return f"{database_name}.{dataset_name}" + return database_name - def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: - database_name = self.database_name - table_name = f"{self.dataset_name}{self.credentials.dataset_table_separator}{table_name}" - if escape: - database_name = self.capabilities.escape_identifier(database_name) - table_name = self.capabilities.escape_identifier(table_name) - return f"{database_name}.{table_name}" + def make_qualified_table_name_path( + self, table_name: Optional[str], escape: bool = True + ) -> List[str]: + # get catalog and dataset + path = super().make_qualified_table_name_path(None, escape=escape) + if table_name: + # table name combines dataset name and table name + table_name = self.capabilities.casefold_identifier( + f"{self.dataset_name}{self.config.dataset_table_separator}{table_name}" + ) + if escape: + table_name = self.capabilities.escape_identifier(table_name) + # we have only two path components + path[1] = table_name + return path + + def _get_information_schema_components(self, *tables: str) -> Tuple[str, str, List[str]]: + components = super()._get_information_schema_components(*tables) + # clickhouse has a catalogue and no schema but uses catalogue as a schema to query the information schema 🤷 + # so we must disable catalogue search. also note that table name is prefixed with logical "dataset_name" + return (None, components[0], components[2]) @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: diff --git a/dlt/destinations/impl/clickhouse/typing.py b/dlt/destinations/impl/clickhouse/typing.py new file mode 100644 index 0000000000..658822149c --- /dev/null +++ b/dlt/destinations/impl/clickhouse/typing.py @@ -0,0 +1,32 @@ +from typing import Literal, Dict, get_args, Set + +from dlt.common.schema import TColumnHint + +TSecureConnection = Literal[0, 1] +TTableEngineType = Literal[ + "merge_tree", + "shared_merge_tree", + "replicated_merge_tree", +] + +HINT_TO_CLICKHOUSE_ATTR: Dict[TColumnHint, str] = { + "primary_key": "PRIMARY KEY", + "unique": "", # No unique constraints available in ClickHouse. + "foreign_key": "", # No foreign key constraints support in ClickHouse. +} + +TABLE_ENGINE_TYPE_TO_CLICKHOUSE_ATTR: Dict[TTableEngineType, str] = { + "merge_tree": "MergeTree", + "shared_merge_tree": "SharedMergeTree", + "replicated_merge_tree": "ReplicatedMergeTree", +} + +TDeployment = Literal["ClickHouseOSS", "ClickHouseCloud"] + +SUPPORTED_FILE_FORMATS = Literal["jsonl", "parquet"] +FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING: Dict[SUPPORTED_FILE_FORMATS, str] = { + "jsonl": "JSONEachRow", + "parquet": "Parquet", +} +TABLE_ENGINE_TYPES: Set[TTableEngineType] = set(get_args(TTableEngineType)) +TABLE_ENGINE_TYPE_HINT: Literal["x-table-engine-type"] = "x-table-engine-type" diff --git a/dlt/destinations/impl/clickhouse/utils.py b/dlt/destinations/impl/clickhouse/utils.py index 0e2fa3db00..02e4e93943 100644 --- a/dlt/destinations/impl/clickhouse/utils.py +++ b/dlt/destinations/impl/clickhouse/utils.py @@ -1,16 +1,12 @@ -from typing import Union, Literal, Dict +from typing import Union from urllib.parse import urlparse, ParseResult -SUPPORTED_FILE_FORMATS = Literal["jsonl", "parquet"] -FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING: Dict[SUPPORTED_FILE_FORMATS, str] = { - "jsonl": "JSONEachRow", - "parquet": "Parquet", -} - - def convert_storage_to_http_scheme( - url: Union[str, ParseResult], use_https: bool = False, endpoint: str = None, region: str = None + url: Union[str, ParseResult], + use_https: bool = False, + endpoint: str = None, + region: str = None, ) -> str: try: if isinstance(url, str): diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index 81884fae4b..e69de29bb2 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -1,30 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - -from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_databricks_identifier - caps.escape_literal = escape_databricks_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 2 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 16 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.supports_truncate_command = True - # caps.supports_transactions = False - caps.alter_add_multi_column = True - caps.supports_multiple_statements = False - caps.supports_clone_table = True - return caps diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index cd203e7e4d..fbe7fa4c6b 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,6 +1,7 @@ from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type, cast from urllib.parse import urlparse, urlunparse +from dlt import config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJob, @@ -15,27 +16,22 @@ AzureCredentials, AzureCredentialsWithoutDefaults, ) -from dlt.common.data_types import TDataType from dlt.common.exceptions import TerminalValueError from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat from dlt.common.schema.utils import table_schema_has_type +from dlt.common.storages import FilesystemConfiguration, fsspec_from_config from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException - -from dlt.destinations.impl.databricks import capabilities from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper -from dlt.common.storages import FilesystemConfiguration, fsspec_from_config -from dlt import config class DatabricksTypeMapper(TypeMapper): @@ -258,10 +254,18 @@ def gen_delete_from_sql( class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: - sql_client = DatabricksSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: DatabricksClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = DatabricksSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] @@ -303,7 +307,7 @@ def _get_table_update_sql( sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) cluster_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster") ] if cluster_list: @@ -317,14 +321,14 @@ def _from_db_type( return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() - fields[1] = ( # Override because this is the only way to get data type with precision + fields[2] = ( # Override because this is the only way to get data type with precision "full_data_type" ) return fields diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 7c6c95137d..59ba3b1ec1 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -1,12 +1,13 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.databricks.configuration import ( DatabricksCredentials, DatabricksClientConfiguration, ) -from dlt.destinations.impl.databricks import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.databricks.databricks import DatabricksClient @@ -15,8 +16,34 @@ class databricks(Destination[DatabricksClientConfiguration, "DatabricksClient"]): spec = DatabricksClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = None + caps.supported_loader_file_formats = [] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.escape_identifier = escape_databricks_identifier + # databricks identifiers are case insensitive and stored in lower case + # https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html + caps.escape_literal = escape_databricks_literal + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.supports_truncate_command = True + # caps.supports_transactions = False + caps.alter_add_multi_column = True + caps.supports_multiple_statements = False + caps.supports_clone_table = True + caps.supported_merge_strategies = ["delete-insert", "scd2"] + return caps @property def client_class(self) -> t.Type["DatabricksClient"]: diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 530b03715a..4c06ef1cf3 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Union, Dict +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Tuple, Union, Dict from databricks import sql as databricks_lib @@ -21,18 +21,38 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction +from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame from dlt.destinations.impl.databricks.configuration import DatabricksCredentials -from dlt.destinations.impl.databricks import capabilities -from dlt.common.time import to_py_date, to_py_datetime + + +class DatabricksCursorImpl(DBApiCursorImpl): + """Use native data frame support if available""" + + native_cursor: DatabricksSqlCursor # type: ignore[assignment] + vector_size: ClassVar[int] = 2048 + + def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: + if chunk_size is None: + return self.native_cursor.fetchall_arrow().to_pandas() + else: + df = self.native_cursor.fetchmany_arrow(chunk_size).to_pandas() + if df.shape[0] == 0: + return None + else: + return df class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: - super().__init__(credentials.catalog, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: DatabricksCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.catalog, dataset_name, staging_dataset_name, capabilities) self._conn: DatabricksSqlConnection = None self.credentials = credentials @@ -68,9 +88,6 @@ def rollback_transaction(self) -> None: def native_connection(self) -> "DatabricksSqlConnection": return self._conn - def drop_dataset(self) -> None: - self.execute_sql("DROP SCHEMA IF EXISTS %s CASCADE;" % self.fully_qualified_dataset_name()) - def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. @@ -112,16 +129,13 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB db_args = args or kwargs or None with self._conn.cursor() as curr: # type: ignore[assignment] curr.execute(query, db_args) - yield DBApiCursorImpl(curr) # type: ignore[abstract] + yield DatabricksCursorImpl(curr) # type: ignore[abstract] - def fully_qualified_dataset_name(self, escape: bool = True) -> str: + def catalog_name(self, escape: bool = True) -> Optional[str]: + catalog = self.capabilities.casefold_identifier(self.credentials.catalog) if escape: - catalog = self.capabilities.escape_identifier(self.credentials.catalog) - dataset_name = self.capabilities.escape_identifier(self.dataset_name) - else: - catalog = self.credentials.catalog - dataset_name = self.dataset_name - return f"{catalog}.{dataset_name}" + catalog = self.capabilities.escape_identifier(catalog) + return catalog @staticmethod def _make_database_exception(ex: Exception) -> Exception: diff --git a/dlt/destinations/impl/destination/__init__.py b/dlt/destinations/impl/destination/__init__.py index 5b076df4c6..e69de29bb2 100644 --- a/dlt/destinations/impl/destination/__init__.py +++ b/dlt/destinations/impl/destination/__init__.py @@ -1,21 +0,0 @@ -from typing import Optional -from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat -from dlt.common.destination.capabilities import TLoaderParallelismStrategy - - -def capabilities( - preferred_loader_file_format: TLoaderFileFormat = "typed-jsonl", - naming_convention: str = "direct", - max_table_nesting: Optional[int] = 0, - max_parallel_load_jobs: Optional[int] = 0, - loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None, -) -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) - caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] - caps.supports_ddl_transactions = False - caps.supports_transactions = False - caps.naming_convention = naming_convention - caps.max_table_nesting = max_table_nesting - caps.max_parallel_load_jobs = max_parallel_load_jobs - caps.loader_parallelism_strategy = loader_parallelism_strategy - return caps diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index c3b677058c..705f3b0bb5 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -1,20 +1,23 @@ import dataclasses -from typing import Optional, Final, Callable, Union +from typing import Optional, Final, Callable, Union, Any from typing_extensions import ParamSpec -from dlt.common.configuration import configspec +from dlt.common.configuration import configspec, ConfigurationValueError from dlt.common.destination import TLoaderFileFormat from dlt.common.destination.reference import ( DestinationClientConfiguration, ) from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema -from dlt.common.destination import Destination TDestinationCallable = Callable[[Union[TDataItems, str], TTableSchema], None] TDestinationCallableParams = ParamSpec("TDestinationCallableParams") +def dummy_custom_destination(*args: Any, **kwargs: Any) -> None: + pass + + @configspec class CustomDestinationClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = dataclasses.field(default="destination", init=False, repr=False, compare=False) # type: ignore @@ -23,3 +26,15 @@ class CustomDestinationClientConfiguration(DestinationClientConfiguration): batch_size: int = 10 skip_dlt_columns_and_tables: bool = True max_table_nesting: Optional[int] = 0 + + def ensure_callable(self) -> None: + """Makes sure that valid callable was provided""" + # TODO: this surely can be done with `on_resolved` + if ( + self.destination_callable is None + or self.destination_callable is dummy_custom_destination + ): + raise ConfigurationValueError( + f"A valid callable was not provided to {self.__class__.__name__}. Did you decorate" + " a function @dlt.destination correctly?" + ) diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index 69d1d1d98a..976dfa4fb5 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -3,9 +3,8 @@ from typing import ClassVar, Optional, Type, Iterable, cast, List from dlt.common.destination.reference import LoadJob -from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import AnyFun -from dlt.pipeline.current import destination_state +from dlt.common.storages.load_package import destination_state from dlt.common.configuration import create_resolved_partial from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -16,7 +15,7 @@ JobClientBase, ) -from dlt.destinations.impl.destination import capabilities +from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration from dlt.destinations.job_impl import ( DestinationJsonlLoadJob, @@ -27,10 +26,14 @@ class DestinationClient(JobClientBase): """Sink Client""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: CustomDestinationClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: CustomDestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + config.ensure_callable() + super().__init__(schema, config, capabilities) self.config: CustomDestinationClientConfiguration = config # create pre-resolved callable to avoid multiple config resolutions during execution of the jobs self.destination_callable = create_resolved_partial( diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py index b3127ab99b..e307b651fb 100644 --- a/dlt/destinations/impl/destination/factory.py +++ b/dlt/destinations/impl/destination/factory.py @@ -4,18 +4,20 @@ from types import ModuleType from dlt.common import logger +from dlt.common.destination.capabilities import TLoaderParallelismStrategy +from dlt.common.exceptions import TerminalValueError +from dlt.common.normalizers.naming.naming import NamingConvention from dlt.common.typing import AnyFun from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.configuration import known_sections, with_config, get_fun_spec from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.utils import get_callable_name, is_inner_callable -from dlt.destinations.exceptions import DestinationTransientException from dlt.destinations.impl.destination.configuration import ( CustomDestinationClientConfiguration, + dummy_custom_destination, TDestinationCallable, ) -from dlt.destinations.impl.destination import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.destination.destination import DestinationClient @@ -34,16 +36,16 @@ class DestinationInfo(t.NamedTuple): class destination(Destination[CustomDestinationClientConfiguration, "DestinationClient"]): - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities( - preferred_loader_file_format=self.config_params.get( - "loader_file_format", "typed-jsonl" - ), - naming_convention=self.config_params.get("naming_convention", "direct"), - max_table_nesting=self.config_params.get("max_table_nesting", None), - max_parallel_load_jobs=self.config_params.get("max_parallel_load_jobs", None), - loader_parallelism_strategy=self.config_params.get("loader_parallelism_strategy", None), - ) + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext.generic_capabilities("typed-jsonl") + caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] + caps.supports_ddl_transactions = False + caps.supports_transactions = False + caps.naming_convention = "direct" + caps.max_table_nesting = 0 + caps.max_parallel_load_jobs = 0 + caps.loader_parallelism_strategy = None + return caps @property def spec(self) -> t.Type[CustomDestinationClientConfiguration]: @@ -68,7 +70,7 @@ def __init__( **kwargs: t.Any, ) -> None: if spec and not issubclass(spec, CustomDestinationClientConfiguration): - raise ValueError( + raise TerminalValueError( "A SPEC for a sink destination must use CustomDestinationClientConfiguration as a" " base." ) @@ -76,8 +78,10 @@ def __init__( if callable(destination_callable): pass elif destination_callable: + if "." not in destination_callable: + raise ValueError("str destination reference must be of format 'module.function'") + module_path, attr_name = destination_callable.rsplit(".", 1) try: - module_path, attr_name = destination_callable.rsplit(".", 1) dest_module = import_module(module_path) except ModuleNotFoundError as e: raise ConfigurationValueError( @@ -97,14 +101,7 @@ def __init__( "No destination callable provided, providing dummy callable which will fail on" " load." ) - - def dummy_callable(*args: t.Any, **kwargs: t.Any) -> None: - raise DestinationTransientException( - "You tried to load to a custom destination without a valid callable." - ) - - destination_callable = dummy_callable - + destination_callable = dummy_custom_destination elif not callable(destination_callable): raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") @@ -138,9 +135,21 @@ def dummy_callable(*args: t.Any, **kwargs: t.Any) -> None: super().__init__( destination_name=destination_name, environment=environment, + # NOTE: `loader_file_format` is not a field in the caps so we had to hack the base class to allow this loader_file_format=loader_file_format, batch_size=batch_size, naming_convention=naming_convention, destination_callable=conf_callable, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: CustomDestinationClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super().adjust_capabilities(caps, config, naming) + caps.preferred_loader_file_format = config.loader_file_format + return caps diff --git a/dlt/destinations/impl/dremio/__init__.py b/dlt/destinations/impl/dremio/__init__.py index b4bde2fe6d..e69de29bb2 100644 --- a/dlt/destinations/impl/dremio/__init__.py +++ b/dlt/destinations/impl/dremio/__init__.py @@ -1,27 +0,0 @@ -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.data_writers.escape import escape_dremio_identifier -from dlt.common.destination import DestinationCapabilitiesContext - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_dremio_identifier - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 2 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 16 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_transactions = False - caps.supports_ddl_transactions = False - caps.alter_add_multi_column = True - caps.supports_clone_table = False - caps.supports_multiple_statements = False - caps.timestamp_precision = 3 - return caps diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 23bca0ad74..bea18cdea5 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -14,7 +14,6 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import uniq_id from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.dremio import capabilities from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration from dlt.destinations.impl.dremio.sql_client import DremioSqlClient from dlt.destinations.job_client_impl import SqlJobClientWithStaging @@ -137,10 +136,18 @@ def exception(self) -> str: class DremioClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DremioClientConfiguration) -> None: - sql_client = DremioSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: DremioClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = DremioSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.config: DremioClientConfiguration = config self.sql_client: DremioSqlClient = sql_client # type: ignore @@ -172,7 +179,7 @@ def _get_table_update_sql( if not generate_alter: partition_list = [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("partition") ] @@ -180,7 +187,7 @@ def _get_table_update_sql( sql[0] += "\nPARTITION BY (" + ",".join(partition_list) + ")" sort_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("sort") + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("sort") ] if sort_list: sql[0] += "\nLOCALSORT BY (" + ",".join(sort_list) + ")" @@ -193,45 +200,11 @@ def _from_db_type( return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - def _null_to_bool(v: str) -> bool: - if v == "NO": - return False - elif v == "YES": - return True - raise ValueError(v) - - fields = self._get_storage_table_query_columns() - table_schema = self.sql_client.fully_qualified_dataset_name(escape=False) - db_params = (table_schema, table_name) - query = f""" -SELECT {",".join(fields)} - FROM INFORMATION_SCHEMA.COLUMNS -WHERE - table_catalog = 'DREMIO' AND table_schema = %s AND table_name = %s ORDER BY ordinal_position; -""" - rows = self.sql_client.execute_sql(query, *db_params) - - # if no rows we assume that table does not exist - schema_table: TTableSchemaColumns = {} - if len(rows) == 0: - return False, schema_table - for c in rows: - numeric_precision = c[3] - numeric_scale = c[4] - schema_c: TColumnSchemaBase = { - "name": c[0], - "nullable": _null_to_bool(c[2]), - **self._from_db_type(c[1], numeric_precision, numeric_scale), - } - schema_table[c[0]] = schema_c # type: ignore - return True, schema_table - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] diff --git a/dlt/destinations/impl/dremio/factory.py b/dlt/destinations/impl/dremio/factory.py index 61895e4f90..b8c7e1b746 100644 --- a/dlt/destinations/impl/dremio/factory.py +++ b/dlt/destinations/impl/dremio/factory.py @@ -1,11 +1,13 @@ import typing as t +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.data_writers.escape import escape_dremio_identifier + from dlt.destinations.impl.dremio.configuration import ( DremioCredentials, DremioClientConfiguration, ) -from dlt.destinations.impl.dremio import capabilities -from dlt.common.destination import Destination, DestinationCapabilitiesContext if t.TYPE_CHECKING: from dlt.destinations.impl.dremio.dremio import DremioClient @@ -14,8 +16,32 @@ class dremio(Destination[DremioClientConfiguration, "DremioClient"]): spec = DremioClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = None + caps.supported_loader_file_formats = [] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.escape_identifier = escape_dremio_identifier + # all identifiers are case insensitive but are stored as is + # https://docs.dremio.com/current/sonar/data-sources + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_transactions = False + caps.supports_ddl_transactions = False + caps.alter_add_multi_column = True + caps.supports_clone_table = False + caps.supports_multiple_statements = False + caps.timestamp_precision = 3 + caps.supported_merge_strategies = ["delete-insert", "scd2"] + return caps @property def client_class(self) -> t.Type["DremioClient"]: diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index 255c8acee0..7dee056da7 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Tuple import pyarrow @@ -10,7 +10,7 @@ DatabaseUndefinedRelation, DatabaseTransientException, ) -from dlt.destinations.impl.dremio import capabilities, pydremio +from dlt.destinations.impl.dremio import pydremio from dlt.destinations.impl.dremio.configuration import DremioCredentials from dlt.destinations.sql_client import ( DBApiCursorImpl, @@ -32,10 +32,16 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: class DremioSqlClient(SqlClientBase[pydremio.DremioConnection]): dbapi: ClassVar[DBApi] = pydremio - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, dataset_name: str, credentials: DremioCredentials) -> None: - super().__init__(credentials.database, dataset_name) + SENTINEL_TABLE_NAME: ClassVar[str] = "_dlt_sentinel_table" + + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: DremioCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: Optional[pydremio.DremioConnection] = None self.credentials = credentials @@ -99,18 +105,16 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB raise DatabaseTransientException(ex) yield DremioCursorImpl(curr) # type: ignore - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = self.credentials.database - dataset_name = self.dataset_name + def catalog_name(self, escape: bool = True) -> Optional[str]: + database_name = self.capabilities.casefold_identifier(self.database_name) if escape: database_name = self.capabilities.escape_identifier(database_name) - dataset_name = self.capabilities.escape_identifier(dataset_name) - return f"{database_name}.{dataset_name}" + return database_name - def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: - if escape: - table_name = self.capabilities.escape_identifier(table_name) - return f"{self.fully_qualified_dataset_name(escape=escape)}.{table_name}" + def _get_information_schema_components(self, *tables: str) -> Tuple[str, str, List[str]]: + components = super()._get_information_schema_components(*tables) + # catalog is always DREMIO but schema contains "database" prefix 🤷 + return ("DREMIO", self.fully_qualified_dataset_name(escape=False), components[2]) @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: @@ -132,19 +136,26 @@ def is_dbapi_exception(ex: Exception) -> bool: return isinstance(ex, (pyarrow.lib.ArrowInvalid, pydremio.MalformedQueryError)) def create_dataset(self) -> None: - pass + # We create a sentinel table which defines wether we consider the dataset created + sentinel_table_name = self.make_qualified_table_name(self.SENTINEL_TABLE_NAME) + self.execute_sql(f"CREATE TABLE {sentinel_table_name} (_dlt_id BIGINT);") def _get_table_names(self) -> List[str]: query = """ SELECT TABLE_NAME FROM INFORMATION_SCHEMA."TABLES" - WHERE TABLE_CATALOG = 'DREMIO' AND TABLE_SCHEMA = %s + WHERE TABLE_CATALOG = %s AND TABLE_SCHEMA = %s """ - db_params = [self.fully_qualified_dataset_name(escape=False)] - tables = self.execute_sql(query, *db_params) or [] + catalog_name, schema_name, _ = self._get_information_schema_components() + tables = self.execute_sql(query, catalog_name, schema_name) or [] return [table[0] for table in tables] def drop_dataset(self) -> None: + # drop sentinel table + sentinel_table_name = self.make_qualified_table_name(self.SENTINEL_TABLE_NAME) + # must exist or we get undefined relation exception + self.execute_sql(f"DROP TABLE {sentinel_table_name}") + table_names = self._get_table_names() for table_name in table_names: full_table_name = self.make_qualified_table_name(table_name) diff --git a/dlt/destinations/impl/duckdb/__init__.py b/dlt/destinations/impl/duckdb/__init__.py index 5cbc8dea53..e69de29bb2 100644 --- a/dlt/destinations/impl/duckdb/__init__.py +++ b/dlt/destinations/impl/duckdb/__init__.py @@ -1,26 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_duckdb_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 65536 - caps.max_column_identifier_length = 65536 - caps.max_query_length = 32 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - caps.alter_add_multi_column = False - caps.supports_truncate_command = False - - return caps diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 7016e9bfff..10d4fc13de 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -12,7 +12,6 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.impl.duckdb import capabilities from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration from dlt.destinations.type_mapping import TypeMapper @@ -151,10 +150,18 @@ def exception(self) -> str: class DuckDbClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DuckDbClientConfiguration) -> None: - sql_client = DuckDbSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: DuckDbClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = DuckDbSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.config: DuckDbClientConfiguration = config self.sql_client: DuckDbSqlClient = sql_client # type: ignore @@ -173,7 +180,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in self.active_hints.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return ( f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) diff --git a/dlt/destinations/impl/duckdb/factory.py b/dlt/destinations/impl/duckdb/factory.py index 55fcd3b339..2c4df2cb58 100644 --- a/dlt/destinations/impl/duckdb/factory.py +++ b/dlt/destinations/impl/duckdb/factory.py @@ -1,8 +1,10 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + from dlt.destinations.impl.duckdb.configuration import DuckDbCredentials, DuckDbClientConfiguration -from dlt.destinations.impl.duckdb import capabilities if t.TYPE_CHECKING: from duckdb import DuckDBPyConnection @@ -12,8 +14,30 @@ class duckdb(Destination[DuckDbClientConfiguration, "DuckDbClient"]): spec = DuckDbClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + caps.escape_identifier = escape_postgres_identifier + # all identifiers are case insensitive but are stored as is + caps.escape_literal = escape_duckdb_literal + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 65536 + caps.max_column_identifier_length = 65536 + caps.max_query_length = 32 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = False + caps.supports_truncate_command = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property def client_class(self) -> t.Type["DuckDbClient"]: diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index bb85b5825b..80bbbedc9c 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -17,12 +17,11 @@ raise_open_connection_error, ) -from dlt.destinations.impl.duckdb import capabilities from dlt.destinations.impl.duckdb.configuration import DuckDbBaseCredentials class DuckDBDBApiCursorImpl(DBApiCursorImpl): - """Use native BigQuery data frame support if available""" + """Use native duckdb data frame support if available""" native_cursor: duckdb.DuckDBPyConnection # type: ignore vector_size: ClassVar[int] = 2048 @@ -43,10 +42,15 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: class DuckDbSqlClient(SqlClientBase[duckdb.DuckDBPyConnection], DBTransaction): dbapi: ClassVar[DBApi] = duckdb - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: DuckDbBaseCredentials) -> None: - super().__init__(None, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: DuckDbBaseCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(None, dataset_name, staging_dataset_name, capabilities) self._conn: duckdb.DuckDBPyConnection = None self.credentials = credentials @@ -142,11 +146,6 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # else: # return None - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, (duckdb.CatalogException)): diff --git a/dlt/destinations/impl/dummy/__init__.py b/dlt/destinations/impl/dummy/__init__.py index e09f7d07a9..e69de29bb2 100644 --- a/dlt/destinations/impl/dummy/__init__.py +++ b/dlt/destinations/impl/dummy/__init__.py @@ -1,39 +0,0 @@ -from typing import List -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.capabilities import TLoaderFileFormat - -from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration - - -@with_config( - spec=DummyClientConfiguration, - sections=( - known_sections.DESTINATION, - "dummy", - ), -) -def _configure(config: DummyClientConfiguration = config.value) -> DummyClientConfiguration: - return config - - -def capabilities() -> DestinationCapabilitiesContext: - config = _configure() - additional_formats: List[TLoaderFileFormat] = ( - ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] - ) - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = config.loader_file_format - caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] - caps.max_identifier_length = 127 - caps.max_column_identifier_length = 127 - caps.max_query_length = 8 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 65536 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - - return caps diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 3c78493b57..c41b7dca61 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -36,7 +36,6 @@ LoadJobNotExistsException, LoadJobInvalidStateTransitionException, ) -from dlt.destinations.impl.dummy import capabilities from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration from dlt.destinations.job_impl import NewReferenceJob @@ -110,10 +109,13 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: class DummyClient(JobClientBase, SupportsStagingDestination, WithStagingDataset): """dummy client storing jobs in memory""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: DummyClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) self.in_staging_context = False self.config: DummyClientConfiguration = config @@ -160,7 +162,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py index 1c848cf22d..c2792fc432 100644 --- a/dlt/destinations/impl/dummy/factory.py +++ b/dlt/destinations/impl/dummy/factory.py @@ -2,11 +2,12 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.common.normalizers.naming.naming import NamingConvention from dlt.destinations.impl.dummy.configuration import ( DummyClientConfiguration, DummyClientCredentials, ) -from dlt.destinations.impl.dummy import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.dummy.dummy import DummyClient @@ -15,8 +16,20 @@ class dummy(Destination[DummyClientConfiguration, "DummyClient"]): spec = DummyClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_staging_file_format = None + caps.has_case_sensitive_identifiers = True + caps.max_identifier_length = 127 + caps.max_column_identifier_length = 127 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 65536 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.supported_merge_strategies = ["delete-insert", "upsert"] + + return caps @property def client_class(self) -> t.Type["DummyClient"]: @@ -37,3 +50,19 @@ def __init__( environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: DummyClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super().adjust_capabilities(caps, config, naming) + additional_formats: t.List[TLoaderFileFormat] = ( + ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] + ) + caps.preferred_loader_file_format = config.loader_file_format + caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] + caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] + return caps diff --git a/dlt/destinations/impl/filesystem/__init__.py b/dlt/destinations/impl/filesystem/__init__.py index 49fabd61d7..e69de29bb2 100644 --- a/dlt/destinations/impl/filesystem/__init__.py +++ b/dlt/destinations/impl/filesystem/__init__.py @@ -1,24 +0,0 @@ -from typing import Sequence, Tuple - -from dlt.common.schema.typing import TTableSchema -from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat - - -def loader_file_format_adapter( - preferred_loader_file_format: TLoaderFileFormat, - supported_loader_file_formats: Sequence[TLoaderFileFormat], - /, - *, - table_schema: TTableSchema, -) -> Tuple[TLoaderFileFormat, Sequence[TLoaderFileFormat]]: - if table_schema.get("table_format") == "delta": - return ("parquet", ["parquet"]) - return (preferred_loader_file_format, supported_loader_file_formats) - - -def capabilities() -> DestinationCapabilitiesContext: - return DestinationCapabilitiesContext.generic_capabilities( - preferred_loader_file_format="jsonl", - loader_file_format_adapter=loader_file_format_adapter, - supported_table_formats=["delta"], - ) diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 029a5bdda5..1e6eec5cce 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -1,19 +1,38 @@ import typing as t -from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.destinations.impl.filesystem import capabilities -from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat +from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT +from dlt.common.schema.typing import TTableSchema from dlt.common.storages.configuration import FileSystemCredentials +from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration +from dlt.destinations.impl.filesystem.typing import TCurrentDateTime, TExtraPlaceholders + if t.TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +def loader_file_format_adapter( + preferred_loader_file_format: TLoaderFileFormat, + supported_loader_file_formats: t.Sequence[TLoaderFileFormat], + /, + *, + table_schema: TTableSchema, +) -> t.Tuple[TLoaderFileFormat, t.Sequence[TLoaderFileFormat]]: + if table_schema.get("table_format") == "delta": + return ("parquet", ["parquet"]) + return (preferred_loader_file_format, supported_loader_file_formats) + + class filesystem(Destination[FilesystemDestinationClientConfiguration, "FilesystemClient"]): spec = FilesystemDestinationClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + return DestinationCapabilitiesContext.generic_capabilities( + preferred_loader_file_format="jsonl", + loader_file_format_adapter=loader_file_format_adapter, + supported_table_formats=["delta"], + ) @property def client_class(self) -> t.Type["FilesystemClient"]: @@ -25,6 +44,9 @@ def __init__( self, bucket_url: str = None, credentials: t.Union[FileSystemCredentials, t.Dict[str, t.Any], t.Any] = None, + layout: str = DEFAULT_FILE_LAYOUT, + extra_placeholders: t.Optional[TExtraPlaceholders] = None, + current_datetime: t.Optional[TCurrentDateTime] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -46,11 +68,20 @@ def __init__( credentials: Credentials to connect to the filesystem. The type of credentials should correspond to the bucket protocol. For example, for AWS S3, the credentials should be an instance of `AwsCredentials`. A dictionary with the credentials parameters can also be provided. + layout (str): A layout of the files holding table data in the destination bucket/filesystem. Uses a set of pre-defined + and user-defined (extra) placeholders. Please refer to https://dlthub.com/docs/dlt-ecosystem/destinations/filesystem#files-layout + extra_placeholders (dict(str, str | callable)): A dictionary of extra placeholder names that can be used in the `layout` parameter. Names + are mapped to string values or to callables evaluated at runtime. + current_datetime (DateTime | callable): current datetime used by date/time related placeholders. If not provided, load package creation timestamp + will be used. **kwargs: Additional arguments passed to the destination config """ super().__init__( bucket_url=bucket_url, credentials=credentials, + layout=layout, + extra_placeholders=extra_placeholders, + current_datetime=current_datetime, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 2ea3bfbcad..67ceb5ce34 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -20,10 +20,16 @@ import dlt from dlt.common import logger, time, json, pendulum +from dlt.common.storages.fsspec_filesystem import glob_files from dlt.common.typing import DictStrAny from dlt.common.schema import Schema, TSchemaTables, TTableSchema from dlt.common.storages import FileStorage, fsspec_from_config -from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName +from dlt.common.storages.load_package import ( + LoadJobInfo, + ParsedLoadJobFileName, + TPipelineStateDoc, + load_package as current_load_package, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( NewLoadJob, @@ -41,7 +47,6 @@ from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.destinations.typing import DataFrame, ArrowTable from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob -from dlt.destinations.impl.filesystem import capabilities from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations import path_utils @@ -165,15 +170,19 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: class FilesystemClient(FSClientBase, JobClientBase, WithStagingDataset, WithStateSync): """filesystem client storing jobs in memory""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() fs_client: AbstractFileSystem # a path (without the scheme) to a location in the bucket where dataset is present bucket_path: str # name of the dataset dataset_name: str - def __init__(self, schema: Schema, config: FilesystemDestinationClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: FilesystemDestinationClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) self.fs_client, fs_path = fsspec_from_config(config) self.is_local_filesystem = config.protocol == "file" self.bucket_path = ( @@ -233,7 +242,7 @@ def drop_tables(self, *tables: str, delete_schema: bool = True) -> None: self._delete_file(filename) def truncate_tables(self, table_names: List[str]) -> None: - """Truncate table with given name""" + """Truncate a set of tables with given `table_names`""" table_dirs = set(self.get_table_dirs(table_names)) table_prefixes = [self.get_table_prefix(t) for t in table_names] for table_dir in table_dirs: @@ -311,18 +320,19 @@ def list_table_files(self, table_name: str) -> List[str]: def list_files_with_prefixes(self, table_dir: str, prefixes: List[str]) -> List[str]: """returns all files in a directory that match given prefixes""" result = [] - for current_dir, _dirs, files in self.fs_client.walk(table_dir, detail=False, refresh=True): - for file in files: - # skip INIT files - if file == INIT_FILE_NAME: - continue - filepath = self.pathlib.join( - path_utils.normalize_path_sep(self.pathlib, current_dir), file - ) - for p in prefixes: - if filepath.startswith(p): - result.append(filepath) - break + # we fallback to our own glob implementation that is tested to return consistent results for + # filesystems we support. we were not able to use `find` or `walk` because they were selecting + # files wrongly (on azure walk on path1/path2/ would also select files from path1/path2_v2/ but returning wrong dirs) + for details in glob_files(self.fs_client, self.make_remote_uri(table_dir), "**"): + file = details["file_name"] + filepath = self.pathlib.join(table_dir, details["relative_path"]) + # skip INIT files + if file == INIT_FILE_NAME: + continue + for p in prefixes: + if filepath.startswith(p): + result.append(filepath) + break return result def is_storage_initialized(self) -> bool: @@ -377,7 +387,7 @@ def _write_to_json_file(self, filepath: str, data: DictStrAny) -> None: dirname = self.pathlib.dirname(filepath) if not self.fs_client.isdir(dirname): return - self.fs_client.write_text(filepath, json.dumps(data), "utf-8") + self.fs_client.write_text(filepath, json.dumps(data), encoding="utf-8") def _to_path_safe_string(self, s: str) -> str: """for base64 strings""" @@ -431,11 +441,9 @@ def _store_current_state(self, load_id: str) -> None: # don't save the state this way when used as staging if self.config.as_staging: return - # get state doc from current pipeline - from dlt.pipeline.current import load_package - - pipeline_state_doc = load_package()["state"].get("pipeline_state") + # get state doc from current pipeline + pipeline_state_doc = current_load_package()["state"].get("pipeline_state") if not pipeline_state_doc: return @@ -459,8 +467,13 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # Load compressed state from destination if selected_path: - state_json = json.loads(self.fs_client.read_text(selected_path)) - state_json.pop("version_hash") + state_json: TPipelineStateDoc = json.loads( + self.fs_client.read_text(selected_path, encoding="utf-8") + ) + # we had dlt_load_id stored until version 0.5 and since we do not have any version control + # we always migrate + if load_id := state_json.pop("dlt_load_id", None): # type: ignore[typeddict-item] + state_json["_dlt_load_id"] = load_id return StateInfo(**state_json) return None @@ -503,7 +516,9 @@ def _get_stored_schema_by_hash_or_newest( break if selected_path: - return StorageSchemaInfo(**json.loads(self.fs_client.read_text(selected_path))) + return StorageSchemaInfo( + **json.loads(self.fs_client.read_text(selected_path, encoding="utf-8")) + ) return None @@ -540,19 +555,23 @@ def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchema def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: def get_table_jobs( table_jobs: Sequence[LoadJobInfo], table_name: str ) -> Sequence[LoadJobInfo]: return [job for job in table_jobs if job.job_file_info.table_name == table_name] - assert table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs(table_chain, table_chain_jobs) + assert completed_table_chain_jobs is not None + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) table_format = table_chain[0].get("table_format") if table_format == "delta": delta_jobs = [ - DeltaLoadFilesystemJob(self, table, get_table_jobs(table_chain_jobs, table["name"])) + DeltaLoadFilesystemJob( + self, table, get_table_jobs(completed_table_chain_jobs, table["name"]) + ) for table in table_chain ] jobs.extend(delta_jobs) diff --git a/dlt/destinations/impl/filesystem/typing.py b/dlt/destinations/impl/filesystem/typing.py index 139602198d..6781fe21ac 100644 --- a/dlt/destinations/impl/filesystem/typing.py +++ b/dlt/destinations/impl/filesystem/typing.py @@ -15,5 +15,7 @@ `schema name`, `table name`, `load_id`, `file_id` and an `extension` """ -TExtraPlaceholders: TypeAlias = Dict[str, Union[str, TLayoutPlaceholderCallback]] +TExtraPlaceholders: TypeAlias = Dict[ + str, Union[Union[str, int, DateTime], TLayoutPlaceholderCallback] +] """Extra placeholders for filesystem layout""" diff --git a/dlt/destinations/impl/lancedb/__init__.py b/dlt/destinations/impl/lancedb/__init__.py new file mode 100644 index 0000000000..bc6974b072 --- /dev/null +++ b/dlt/destinations/impl/lancedb/__init__.py @@ -0,0 +1 @@ +from dlt.destinations.impl.lancedb.lancedb_adapter import lancedb_adapter diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py new file mode 100644 index 0000000000..ba3a8b49d9 --- /dev/null +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -0,0 +1,111 @@ +import dataclasses +from typing import Optional, Final, Literal, ClassVar, List + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + CredentialsConfiguration, +) +from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.common.typing import TSecretStrValue +from dlt.common.utils import digest128 + + +@configspec +class LanceDBCredentials(CredentialsConfiguration): + uri: Optional[str] = ".lancedb" + """LanceDB database URI. Defaults to local, on-disk instance. + + The available schemas are: + + - `/path/to/database` - local database. + - `db://host:port` - remote database (LanceDB cloud). + """ + api_key: Optional[TSecretStrValue] = None + """API key for the remote connections (LanceDB cloud).""" + embedding_model_provider_api_key: Optional[str] = None + """API key for the embedding model provider.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "uri", + "api_key", + "embedding_model_provider_api_key", + ] + + +@configspec +class LanceDBClientOptions(BaseConfiguration): + max_retries: Optional[int] = 3 + """`EmbeddingFunction` class wraps the calls for source and query embedding + generation inside a rate limit handler that retries the requests with exponential + backoff after successive failures. + + You can tune it by setting it to a different number, or disable it by setting it to 0.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "max_retries", + ] + + +TEmbeddingProvider = Literal[ + "gemini-text", + "bedrock-text", + "cohere", + "gte-text", + "imagebind", + "instructor", + "open-clip", + "openai", + "sentence-transformers", + "huggingface", + "colbert", +] + + +@configspec +class LanceDBClientConfiguration(DestinationClientDwhConfiguration): + destination_type: Final[str] = dataclasses.field( # type: ignore + default="LanceDB", init=False, repr=False, compare=False + ) + credentials: LanceDBCredentials = None + dataset_separator: str = "___" + """Character for the dataset separator.""" + dataset_name: Final[Optional[str]] = dataclasses.field( # type: ignore + default=None, init=False, repr=False, compare=False + ) + + options: Optional[LanceDBClientOptions] = None + """LanceDB client options.""" + + embedding_model_provider: TEmbeddingProvider = "cohere" + """Embedding provider used for generating embeddings. Default is "cohere". You can find the full list of + providers at https://github.com/lancedb/lancedb/tree/main/python/python/lancedb/embeddings as well as + https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.""" + embedding_model: str = "embed-english-v3.0" + """The model used by the embedding provider for generating embeddings. + Check with the embedding provider which options are available. + Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.""" + embedding_model_dimensions: Optional[int] = None + """The dimensions of the embeddings generated. In most cases it will be automatically inferred, by LanceDB, + but it is configurable in rare cases. + + Make sure it corresponds with the associated embedding model's dimensionality.""" + vector_field_name: str = "vector__" + """Name of the special field to store the vector embeddings.""" + id_field_name: str = "id__" + """Name of the special field to manage deduplication.""" + sentinel_table_name: str = "dltSentinelTable" + """Name of the sentinel table that encapsulates datasets. Since LanceDB has no + concept of schemas, this table serves as a proxy to group related dlt tables together.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "embedding_model", + "embedding_model_provider", + ] + + def fingerprint(self) -> str: + """Returns a fingerprint of a connection string.""" + + if self.credentials and self.credentials.uri: + return digest128(self.credentials.uri) + return "" diff --git a/dlt/destinations/impl/lancedb/exceptions.py b/dlt/destinations/impl/lancedb/exceptions.py new file mode 100644 index 0000000000..35b86ce76c --- /dev/null +++ b/dlt/destinations/impl/lancedb/exceptions.py @@ -0,0 +1,30 @@ +from functools import wraps +from typing import ( + Any, +) + +from lancedb.exceptions import MissingValueError, MissingColumnError # type: ignore + +from dlt.common.destination.exceptions import ( + DestinationUndefinedEntity, + DestinationTerminalException, +) +from dlt.common.destination.reference import JobClientBase +from dlt.common.typing import TFun + + +def lancedb_error(f: TFun) -> TFun: + @wraps(f) + def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: + try: + return f(self, *args, **kwargs) + except ( + FileNotFoundError, + MissingValueError, + MissingColumnError, + ) as status_ex: + raise DestinationUndefinedEntity(status_ex) from status_ex + except Exception as e: + raise DestinationTerminalException(e) from e + + return _wrap # type: ignore[return-value] diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py new file mode 100644 index 0000000000..f2e17168b9 --- /dev/null +++ b/dlt/destinations/impl/lancedb/factory.py @@ -0,0 +1,53 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBCredentials, + LanceDBClientConfiguration, +) + + +if t.TYPE_CHECKING: + from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient + + +class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): + spec = LanceDBClientConfiguration + + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl"] + + caps.max_identifier_length = 200 + caps.max_column_identifier_length = 1024 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 8 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = False + + caps.decimal_precision = (38, 18) + caps.timestamp_precision = 6 + + return caps + + @property + def client_class(self) -> t.Type["LanceDBClient"]: + from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient + + return LanceDBClient + + def __init__( + self, + credentials: t.Union[LanceDBCredentials, t.Dict[str, t.Any]] = None, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py new file mode 100644 index 0000000000..bb33632b48 --- /dev/null +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -0,0 +1,58 @@ +from typing import Any + +from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns +from dlt.destinations.utils import ensure_resource +from dlt.extract import DltResource + + +VECTORIZE_HINT = "x-lancedb-embed" + + +def lancedb_adapter( + data: Any, + embed: TColumnNames = None, +) -> DltResource: + """Prepares data for the LanceDB destination by specifying which columns should be embedded. + + Args: + data (Any): The data to be transformed. It can be raw data or an instance + of DltResource. If raw data, the function wraps it into a DltResource + object. + embed (TColumnNames, optional): Specify columns to generate embeddings for. + It can be a single column name as a string, or a list of column names. + + Returns: + DltResource: A resource with applied LanceDB-specific hints. + + Raises: + ValueError: If input for `embed` invalid or empty. + + Examples: + >>> data = [{"name": "Marcel", "description": "Moonbase Engineer"}] + >>> lancedb_adapter(data, embed="description") + [DltResource with hints applied] + """ + resource = ensure_resource(data) + + column_hints: TTableSchemaColumns = {} + + if embed: + if isinstance(embed, str): + embed = [embed] + if not isinstance(embed, list): + raise ValueError( + "'embed' must be a list of column names or a single column name as a string." + ) + + for column_name in embed: + column_hints[column_name] = { + "name": column_name, + VECTORIZE_HINT: True, # type: ignore[misc] + } + + if not column_hints: + raise ValueError("A value for 'embed' must be specified.") + else: + resource.apply_hints(columns=column_hints) + + return resource diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py new file mode 100644 index 0000000000..8265e50fbf --- /dev/null +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -0,0 +1,770 @@ +import uuid +from types import TracebackType +from typing import ( + ClassVar, + List, + Any, + cast, + Union, + Tuple, + Iterable, + Type, + Optional, + Dict, + Sequence, + TYPE_CHECKING, +) + +import lancedb # type: ignore +import pyarrow as pa +from lancedb import DBConnection +from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore +from lancedb.query import LanceQueryBuilder # type: ignore +from lancedb.table import Table # type: ignore +from numpy import ndarray +from pyarrow import Array, ChunkedArray, ArrowInvalid + +from dlt.common import json, pendulum, logger +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.exceptions import ( + DestinationUndefinedEntity, + DestinationTransientException, + DestinationTerminalException, +) +from dlt.common.destination.reference import ( + JobClientBase, + WithStateSync, + LoadJob, + StorageSchemaInfo, + StateInfo, + TLoadJobState, +) +from dlt.common.pendulum import timedelta +from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema.typing import ( + TColumnType, + TTableFormat, + TTableSchemaColumns, + TWriteDisposition, +) +from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.storages import FileStorage +from dlt.common.typing import DictStrAny +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBClientConfiguration, +) +from dlt.destinations.impl.lancedb.exceptions import ( + lancedb_error, +) +from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.schema import ( + make_arrow_field_schema, + make_arrow_table_schema, + TArrowSchema, + NULL_SCHEMA, + TArrowField, +) +from dlt.destinations.impl.lancedb.utils import ( + list_merge_identifiers, + generate_uuid, + set_non_standard_providers_environment_variables, +) +from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.type_mapping import TypeMapper + +if TYPE_CHECKING: + NDArray = ndarray[Any, Any] +else: + NDArray = ndarray + + +TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} + + +class LanceDBTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "complex": pa.string(), + } + + sct_to_dbt = {} + + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type( + self, precision: Optional[int], scale: Optional[int] + ) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(precision, scale) + return pa.decimal128(precision, scale) + + def to_db_datetime_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> pa.TimestampType: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.timestamp(unit, "UTC") + + def to_db_time_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> pa.Time64Type: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.time64(unit) + + def from_db_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: + if isinstance(db_type, pa.TimestampType): + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Time64Type): + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Decimal128Type): + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return cast(TColumnType, dict(data_type="wei")) + return dict(data_type="decimal", precision=precision, scale=scale) + return super().from_db_type(db_type, precision, scale) + + +def upload_batch( + records: List[DictStrAny], + /, + *, + db_client: DBConnection, + table_name: str, + write_disposition: TWriteDisposition, + id_field_name: Optional[str] = None, +) -> None: + """Inserts records into a LanceDB table with automatic embedding computation. + + Args: + records: The data to be inserted as payload. + db_client: The LanceDB client connection. + table_name: The name of the table to insert into. + id_field_name: The name of the ID field for update/merge operations. + write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + + Raises: + ValueError: If the write disposition is unsupported, or `id_field_name` is not + provided for update/merge operations. + """ + + try: + tbl = db_client.open_table(table_name) + tbl.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Batch WILL BE RETRIED" + ) from e + + try: + if write_disposition in ("append", "skip"): + tbl.add(records) + elif write_disposition == "replace": + tbl.add(records, mode="overwrite") + elif write_disposition == "merge": + if not id_field_name: + raise ValueError("To perform a merge update, 'id_field_name' must be specified.") + tbl.merge_insert( + id_field_name + ).when_matched_update_all().when_not_matched_insert_all().execute(records) + else: + raise DestinationTerminalException( + f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" + " failed AND WILL **NOT** BE RETRIED." + ) + except ArrowInvalid as e: + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." + ) from e + + +class LanceDBClient(JobClientBase, WithStateSync): + """LanceDB destination handler.""" + + model_func: TextEmbeddingFunction + + def __init__( + self, + schema: Schema, + config: LanceDBClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) + self.config: LanceDBClientConfiguration = config + self.db_client: DBConnection = lancedb.connect( + uri=self.config.credentials.uri, + api_key=self.config.credentials.api_key, + read_consistency_interval=timedelta(0), + ) + self.registry = EmbeddingFunctionRegistry.get_instance() + self.type_mapper = LanceDBTypeMapper(self.capabilities) + self.sentinel_table_name = config.sentinel_table_name + + embedding_model_provider = self.config.embedding_model_provider + + # LanceDB doesn't provide a standardized way to set API keys across providers. + # Some use ENV variables and others allow passing api key as an argument. + # To account for this, we set provider environment variable as well. + set_non_standard_providers_environment_variables( + embedding_model_provider, + self.config.credentials.embedding_model_provider_api_key, + ) + # Use the monkey-patched implementation if openai was chosen. + if embedding_model_provider == "openai": + from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings + + self.model_func = PatchedOpenAIEmbeddings( + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) + else: + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) + + self.vector_field_name = self.config.vector_field_name + self.id_field_name = self.config.id_field_name + + @property + def dataset_name(self) -> str: + return self.config.normalize_dataset_name(self.schema) + + @property + def sentinel_table(self) -> str: + return self.make_qualified_table_name(self.sentinel_table_name) + + def make_qualified_table_name(self, table_name: str) -> str: + return ( + f"{self.dataset_name}{self.config.dataset_separator}{table_name}" + if self.dataset_name + else table_name + ) + + def get_table_schema(self, table_name: str) -> TArrowSchema: + schema_table: Table = self.db_client.open_table(table_name) + schema_table.checkout_latest() + schema = schema_table.schema + return cast( + TArrowSchema, + schema, + ) + + @lancedb_error + def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + """Create a LanceDB Table from the provided LanceModel or PyArrow schema. + + Args: + schema: The table schema to create. + table_name: The name of the table to create. + mode (): The mode to use when creating the table. Can be either "create" or "overwrite". + By default, if the table already exists, an exception is raised. + If you want to overwrite the table, use mode="overwrite". + """ + return self.db_client.create_table(table_name, schema=schema, mode=mode) + + def delete_table(self, table_name: str) -> None: + """Delete a LanceDB table. + + Args: + table_name: The name of the table to delete. + """ + self.db_client.drop_table(table_name) + + def query_table( + self, + table_name: str, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + ) -> LanceQueryBuilder: + """Query a LanceDB table. + + Args: + table_name: The name of the table to query. + query: The targeted vector to search for. + + Returns: + A LanceDB query builder. + """ + query_table: Table = self.db_client.open_table(table_name) + query_table.checkout_latest() + return query_table.search(query=query) + + @lancedb_error + def _get_table_names(self) -> List[str]: + """Return all tables in the dataset, excluding the sentinel table.""" + if self.dataset_name: + prefix = f"{self.dataset_name}{self.config.dataset_separator}" + table_names = [ + table_name + for table_name in self.db_client.table_names() + if table_name.startswith(prefix) + ] + else: + table_names = self.db_client.table_names() + + return [table_name for table_name in table_names if table_name != self.sentinel_table] + + @lancedb_error + def drop_storage(self) -> None: + """Drop the dataset from the LanceDB instance. + + Deletes all tables in the dataset and all data, as well as sentinel table associated with them. + + If the dataset name was not provided, it deletes all the tables in the current schema. + """ + for table_name in self._get_table_names(): + self.db_client.drop_table(table_name) + + self._delete_sentinel_table() + + @lancedb_error + def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: + if not self.is_storage_initialized(): + self._create_sentinel_table() + elif truncate_tables: + for table_name in truncate_tables: + fq_table_name = self.make_qualified_table_name(table_name) + if not self.table_exists(fq_table_name): + continue + schema = self.get_table_schema(fq_table_name) + self.db_client.drop_table(fq_table_name) + self.create_table( + table_name=fq_table_name, + schema=schema, + ) + + @lancedb_error + def is_storage_initialized(self) -> bool: + return self.table_exists(self.sentinel_table) + + def _create_sentinel_table(self) -> Table: + """Create an empty table to indicate that the storage is initialized.""" + return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) + + def _delete_sentinel_table(self) -> None: + """Delete the sentinel table.""" + self.db_client.drop_table(self.sentinel_table) + + @lancedb_error + def update_stored_schema( + self, + only_tables: Iterable[str] = None, + expected_update: TSchemaTables = None, + ) -> Optional[TSchemaTables]: + super().update_stored_schema(only_tables, expected_update) + applied_update: TSchemaTables = {} + + try: + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + except DestinationUndefinedEntity: + schema_info = None + + if schema_info is None: + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + "not found in the storage. upgrading" + ) + self._execute_schema_update(only_tables) + else: + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + f"inserted at {schema_info.inserted_at} found " + "in storage, no upgrade required" + ) + return applied_update + + def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: + table_schema: TTableSchemaColumns = {} + + try: + fq_table_name = self.make_qualified_table_name(table_name) + + table: Table = self.db_client.open_table(fq_table_name) + table.checkout_latest() + arrow_schema: TArrowSchema = table.schema + except FileNotFoundError: + return False, table_schema + + field: TArrowField + for field in arrow_schema: + name = self.schema.naming.normalize_identifier(field.name) + table_schema[name] = { + "name": name, + **self.type_mapper.from_db_type(field.type), + } + return True, table_schema + + @lancedb_error + def add_table_fields( + self, table_name: str, field_schemas: List[TArrowField] + ) -> Optional[Table]: + """Add multiple fields to the LanceDB table at once. + + Args: + table_name: The name of the table to create the fields on. + field_schemas: The list of fields to create. + """ + table: Table = self.db_client.open_table(table_name) + table.checkout_latest() + arrow_table = table.to_arrow() + + # Check if any of the new fields already exist in the table. + existing_fields = set(arrow_table.schema.names) + new_fields = [field for field in field_schemas if field.name not in existing_fields] + + if not new_fields: + # All fields already present, skip. + return None + + null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + + for field, null_array in zip(new_fields, null_arrays): + arrow_table = arrow_table.append_column(field, null_array) + + try: + return self.db_client.create_table(table_name, arrow_table, mode="overwrite") + except OSError: + # Error occurred while creating the table, skip. + return None + + def _execute_schema_update(self, only_tables: Iterable[str]) -> None: + for table_name in only_tables or self.schema.tables: + exists, existing_columns = self.get_storage_table(table_name) + new_columns = self.schema.get_new_table_columns( + table_name, + existing_columns, + self.capabilities.generates_case_sensitive_identifiers(), + ) + embedding_fields: List[str] = get_columns_names_with_prop( + self.schema.get_table(table_name), VECTORIZE_HINT + ) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + if len(new_columns) > 0: + if exists: + field_schemas: List[TArrowField] = [ + make_arrow_field_schema(column["name"], column, self.type_mapper) + for column in new_columns + ] + fq_table_name = self.make_qualified_table_name(table_name) + self.add_table_fields(fq_table_name, field_schemas) + else: + if table_name not in self.schema.dlt_table_names(): + embedding_fields = get_columns_names_with_prop( + self.schema.get_table(table_name=table_name), VECTORIZE_HINT + ) + vector_field_name = self.vector_field_name + id_field_name = self.id_field_name + embedding_model_func = self.model_func + embedding_model_dimensions = self.config.embedding_model_dimensions + else: + embedding_fields = None + vector_field_name = None + id_field_name = None + embedding_model_func = None + embedding_model_dimensions = None + + table_schema: TArrowSchema = make_arrow_table_schema( + table_name, + schema=self.schema, + type_mapper=self.type_mapper, + embedding_fields=embedding_fields, + embedding_model_func=embedding_model_func, + embedding_model_dimensions=embedding_model_dimensions, + vector_field_name=vector_field_name, + id_field_name=id_field_name, + ) + fq_table_name = self.make_qualified_table_name(table_name) + self.create_table(fq_table_name, table_schema) + + self.update_schema_in_storage() + + @lancedb_error + def update_schema_in_storage(self) -> None: + records = [ + { + self.schema.naming.normalize_identifier("version"): self.schema.version, + self.schema.naming.normalize_identifier( + "engine_version" + ): self.schema.ENGINE_VERSION, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "version_hash" + ): self.schema.stored_version_hash, + self.schema.naming.normalize_identifier("schema"): json.dumps( + self.schema.to_dict() + ), + } + ] + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + write_disposition = self.schema.get_table(self.schema.version_table_name).get( + "write_disposition" + ) + upload_batch( + records, + db_client=self.db_client, + table_name=fq_version_table_name, + write_disposition=write_disposition, + ) + + @lancedb_error + def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: + """Retrieves the latest completed state for a pipeline.""" + fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + + state_table_: Table = self.db_client.open_table(fq_state_table_name) + state_table_.checkout_latest() + + loads_table_: Table = self.db_client.open_table(fq_loads_table_name) + loads_table_.checkout_latest() + + # normalize property names + p_load_id = self.schema.naming.normalize_identifier("load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") + p_status = self.schema.naming.normalize_identifier("status") + p_version = self.schema.naming.normalize_identifier("version") + p_engine_version = self.schema.naming.normalize_identifier("engine_version") + p_state = self.schema.naming.normalize_identifier("state") + p_created_at = self.schema.naming.normalize_identifier("created_at") + p_version_hash = self.schema.naming.normalize_identifier("version_hash") + + # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less + # data into memory as possible. + state_table = ( + state_table_.search() + .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) + .to_arrow() + ) + loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + + # Join arrow tables in-memory. + joined_table: pa.Table = state_table.join( + loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" + ).sort_by([(p_dlt_load_id, "descending")]) + + if joined_table.num_rows == 0: + return None + + state = joined_table.take([0]).to_pylist()[0] + return StateInfo( + version=state[p_version], + engine_version=state[p_engine_version], + pipeline_name=state[p_pipeline_name], + state=state[p_state], + created_at=pendulum.instance(state[p_created_at]), + version_hash=state[p_version_hash], + _dlt_load_id=state[p_dlt_load_id], + ) + + @lancedb_error + def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + + version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table.checkout_latest() + p_version_hash = self.schema.naming.normalize_identifier("version_hash") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_version = self.schema.naming.normalize_identifier("version") + p_engine_version = self.schema.naming.normalize_identifier("engine_version") + p_schema = self.schema.naming.normalize_identifier("schema") + + try: + schemas = ( + version_table.search().where( + f'`{p_version_hash}` = "{schema_hash}"', prefilter=True + ) + ).to_list() + + # LanceDB's ORDER BY clause doesn't seem to work. + # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) + except IndexError: + return None + + @lancedb_error + def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + """Retrieves newest schema from destination storage.""" + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + + version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table.checkout_latest() + p_version_hash = self.schema.naming.normalize_identifier("version_hash") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_version = self.schema.naming.normalize_identifier("version") + p_engine_version = self.schema.naming.normalize_identifier("engine_version") + p_schema = self.schema.naming.normalize_identifier("schema") + + try: + schemas = ( + version_table.search().where( + f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True + ) + ).to_list() + + # LanceDB's ORDER BY clause doesn't seem to work. + # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) + except IndexError: + return None + + def __exit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + pass + + def __enter__(self) -> "LanceDBClient": + return self + + @lancedb_error + def complete_load(self, load_id: str) -> None: + records = [ + { + self.schema.naming.normalize_identifier("load_id"): load_id, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("status"): 0, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier( + "schema_version_hash" + ): None, # Payload schema must match the target schema. + } + ] + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + write_disposition = self.schema.get_table(self.schema.loads_table_name).get( + "write_disposition" + ) + upload_batch( + records, + db_client=self.db_client, + table_name=fq_loads_table_name, + write_disposition=write_disposition, + ) + + def restore_file_load(self, file_path: str) -> LoadJob: + return EmptyLoadJob.from_file_path(file_path, "completed") + + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + return LoadLanceDBJob( + self.schema, + table, + file_path, + type_mapper=self.type_mapper, + db_client=self.db_client, + client_config=self.config, + model_func=self.model_func, + fq_table_name=self.make_qualified_table_name(table["name"]), + ) + + def table_exists(self, table_name: str) -> bool: + return table_name in self.db_client.table_names() + + +class LoadLanceDBJob(LoadJob): + arrow_schema: TArrowSchema + + def __init__( + self, + schema: Schema, + table_schema: TTableSchema, + local_path: str, + type_mapper: LanceDBTypeMapper, + db_client: DBConnection, + client_config: LanceDBClientConfiguration, + model_func: TextEmbeddingFunction, + fq_table_name: str, + ) -> None: + file_name = FileStorage.get_file_name_from_file_path(local_path) + super().__init__(file_name) + self.schema: Schema = schema + self.table_schema: TTableSchema = table_schema + self.db_client: DBConnection = db_client + self.type_mapper: TypeMapper = type_mapper + self.table_name: str = table_schema["name"] + self.fq_table_name: str = fq_table_name + self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) + self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.embedding_model_func: TextEmbeddingFunction = model_func + self.embedding_model_dimensions: int = client_config.embedding_model_dimensions + self.id_field_name: str = client_config.id_field_name + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition", "append") + ) + + with FileStorage.open_zipsafe_ro(local_path) as f: + records: List[DictStrAny] = [json.loads(line) for line in f] + + if self.table_schema not in self.schema.dlt_tables(): + for record in records: + # Add reserved ID fields. + uuid_id = ( + generate_uuid(record, self.unique_identifiers, self.fq_table_name) + if self.unique_identifiers + else str(uuid.uuid4()) + ) + record.update({self.id_field_name: uuid_id}) + + # LanceDB expects all fields in the target arrow table to be present in the data payload. + # We add and set these missing fields, that are fields not present in the target schema, to NULL. + missing_fields = set(self.table_schema["columns"]) - set(record) + for field in missing_fields: + record[field] = None + + upload_batch( + records, + db_client=db_client, + table_name=self.fq_table_name, + write_disposition=self.write_disposition, + id_field_name=self.id_field_name, + ) + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() diff --git a/dlt/destinations/impl/lancedb/models.py b/dlt/destinations/impl/lancedb/models.py new file mode 100644 index 0000000000..d90adb62bd --- /dev/null +++ b/dlt/destinations/impl/lancedb/models.py @@ -0,0 +1,34 @@ +from typing import Union, List + +import numpy as np +from lancedb.embeddings import OpenAIEmbeddings # type: ignore +from lancedb.embeddings.registry import register # type: ignore +from lancedb.embeddings.utils import TEXT # type: ignore + + +@register("openai_patched") +class PatchedOpenAIEmbeddings(OpenAIEmbeddings): + EMPTY_STRING_PLACEHOLDER: str = "___EMPTY___" + + def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: # type: ignore[type-arg] + """ + Replace empty strings with a placeholder value. + """ + + sanitized_texts = super().sanitize_input(texts) + return [self.EMPTY_STRING_PLACEHOLDER if item == "" else item for item in sanitized_texts] + + def generate_embeddings( + self, + texts: Union[List[str], np.ndarray], # type: ignore[type-arg] + ) -> List[np.array]: # type: ignore[valid-type] + """ + Generate embeddings, treating the placeholder as an empty result. + """ + embeddings: List[np.array] = super().generate_embeddings(texts) # type: ignore[valid-type] + + for i, text in enumerate(texts): + if text == self.EMPTY_STRING_PLACEHOLDER: + embeddings[i] = np.zeros(self.ndims()) + + return embeddings diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py new file mode 100644 index 0000000000..c7cceec274 --- /dev/null +++ b/dlt/destinations/impl/lancedb/schema.py @@ -0,0 +1,84 @@ +"""Utilities for creating arrow schemas from table schemas.""" + +from dlt.common.json import json +from typing import ( + List, + cast, + Optional, +) + +import pyarrow as pa +from lancedb.embeddings import TextEmbeddingFunction # type: ignore +from typing_extensions import TypeAlias + +from dlt.common.schema import Schema, TColumnSchema +from dlt.common.typing import DictStrAny +from dlt.destinations.type_mapping import TypeMapper + + +TArrowSchema: TypeAlias = pa.Schema +TArrowDataType: TypeAlias = pa.DataType +TArrowField: TypeAlias = pa.Field +NULL_SCHEMA: TArrowSchema = pa.schema([]) +"""Empty pyarrow Schema with no fields.""" + + +def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: + return {field.name: field.type for field in schema} + + +def make_arrow_field_schema( + column_name: str, + column: TColumnSchema, + type_mapper: TypeMapper, +) -> TArrowField: + """Creates a PyArrow field from a dlt column schema.""" + dtype = cast(TArrowDataType, type_mapper.to_db_type(column)) + return pa.field(column_name, dtype) + + +def make_arrow_table_schema( + table_name: str, + schema: Schema, + type_mapper: TypeMapper, + id_field_name: Optional[str] = None, + vector_field_name: Optional[str] = None, + embedding_fields: Optional[List[str]] = None, + embedding_model_func: Optional[TextEmbeddingFunction] = None, + embedding_model_dimensions: Optional[int] = None, +) -> TArrowSchema: + """Creates a PyArrow schema from a dlt schema.""" + arrow_schema: List[TArrowField] = [] + + if id_field_name: + arrow_schema.append(pa.field(id_field_name, pa.string())) + + if embedding_fields: + # User's provided dimension config, if provided, takes precedence. + vec_size = embedding_model_dimensions or embedding_model_func.ndims() + arrow_schema.append(pa.field(vector_field_name, pa.list_(pa.float32(), vec_size))) + + for column_name, column in schema.get_table_columns(table_name).items(): + field = make_arrow_field_schema(column_name, column, type_mapper) + arrow_schema.append(field) + + metadata = {} + if embedding_model_func: + # Get the registered alias if it exists, otherwise use the class name. + name = getattr( + embedding_model_func, + "__embedding_function_registry_alias__", + embedding_model_func.__class__.__name__, + ) + embedding_functions = [ + { + "source_column": source_column, + "vector_column": vector_field_name, + "name": name, + "model": embedding_model_func.safe_model_dump(), + } + for source_column in embedding_fields + ] + metadata["embedding_functions"] = json.dumps(embedding_functions).encode("utf-8") + + return pa.schema(arrow_schema, metadata=metadata) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py new file mode 100644 index 0000000000..aeacd4d34b --- /dev/null +++ b/dlt/destinations/impl/lancedb/utils.py @@ -0,0 +1,55 @@ +import os +import uuid +from typing import Sequence, Union, Dict + +from dlt.common.schema import TTableSchema +from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.typing import DictStrAny +from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider + + +PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { + "cohere": "COHERE_API_KEY", + "gemini-text": "GOOGLE_API_KEY", + "openai": "OPENAI_API_KEY", + "huggingface": "HUGGINGFACE_API_KEY", +} + + +def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: + """Generates deterministic UUID - used for deduplication. + + Args: + data (Dict[str, Any]): Arbitrary data to generate UUID for. + unique_identifiers (Sequence[str]): A list of unique identifiers. + table_name (str): LanceDB table name. + + Returns: + str: A string representation of the generated UUID. + """ + data_id = "_".join(str(data[key]) for key in unique_identifiers) + return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) + + +def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: + """Returns a list of merge keys for a table used for either merging or deduplication. + + Args: + table_schema (TTableSchema): a dlt table schema. + + Returns: + Sequence[str]: A list of unique column identifiers. + """ + if table_schema.get("write_disposition") == "merge": + primary_keys = get_columns_names_with_prop(table_schema, "primary_key") + merge_keys = get_columns_names_with_prop(table_schema, "merge_key") + if join_keys := list(set(primary_keys + merge_keys)): + return join_keys + return get_columns_names_with_prop(table_schema, "unique") + + +def set_non_standard_providers_environment_variables( + embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] +) -> None: + if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: + os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" diff --git a/dlt/destinations/impl/motherduck/__init__.py b/dlt/destinations/impl/motherduck/__init__.py index 74c0e36ef3..e69de29bb2 100644 --- a/dlt/destinations/impl/motherduck/__init__.py +++ b/dlt/destinations/impl/motherduck/__init__.py @@ -1,24 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "parquet" - caps.supported_loader_file_formats = ["parquet", "insert_values", "jsonl"] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_duckdb_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 65536 - caps.max_column_identifier_length = 65536 - caps.max_query_length = 512 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = False - caps.alter_add_multi_column = False - caps.supports_truncate_command = False - - return caps diff --git a/dlt/destinations/impl/motherduck/factory.py b/dlt/destinations/impl/motherduck/factory.py index 5e35f69d75..a9bab96d08 100644 --- a/dlt/destinations/impl/motherduck/factory.py +++ b/dlt/destinations/impl/motherduck/factory.py @@ -1,11 +1,13 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + from dlt.destinations.impl.motherduck.configuration import ( MotherDuckCredentials, MotherDuckClientConfiguration, ) -from dlt.destinations.impl.motherduck import capabilities if t.TYPE_CHECKING: from duckdb import DuckDBPyConnection @@ -15,8 +17,28 @@ class motherduck(Destination[MotherDuckClientConfiguration, "MotherDuckClient"]): spec = MotherDuckClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet", "insert_values", "jsonl"] + caps.escape_identifier = escape_postgres_identifier + # all identifiers are case insensitive but are stored as is + caps.escape_literal = escape_duckdb_literal + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 65536 + caps.max_column_identifier_length = 65536 + caps.max_query_length = 512 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = False + caps.alter_add_multi_column = False + caps.supports_truncate_command = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property def client_class(self) -> t.Type["MotherDuckClient"]: diff --git a/dlt/destinations/impl/motherduck/motherduck.py b/dlt/destinations/impl/motherduck/motherduck.py index c695d9715e..5a700294fe 100644 --- a/dlt/destinations/impl/motherduck/motherduck.py +++ b/dlt/destinations/impl/motherduck/motherduck.py @@ -1,20 +1,25 @@ -from typing import ClassVar - from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import Schema from dlt.destinations.impl.duckdb.duck import DuckDbClient -from dlt.destinations.impl.motherduck import capabilities from dlt.destinations.impl.motherduck.sql_client import MotherDuckSqlClient from dlt.destinations.impl.motherduck.configuration import MotherDuckClientConfiguration class MotherDuckClient(DuckDbClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: MotherDuckClientConfiguration) -> None: - super().__init__(schema, config) # type: ignore - sql_client = MotherDuckSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: MotherDuckClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) # type: ignore + sql_client = MotherDuckSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) self.config: MotherDuckClientConfiguration = config # type: ignore self.sql_client: MotherDuckSqlClient = sql_client diff --git a/dlt/destinations/impl/motherduck/sql_client.py b/dlt/destinations/impl/motherduck/sql_client.py index 7990f90947..5d680160f5 100644 --- a/dlt/destinations/impl/motherduck/sql_client.py +++ b/dlt/destinations/impl/motherduck/sql_client.py @@ -1,41 +1,23 @@ -import duckdb +from typing import Optional -from contextlib import contextmanager -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence -from dlt.common.destination import DestinationCapabilitiesContext - -from dlt.destinations.exceptions import ( - DatabaseTerminalException, - DatabaseTransientException, - DatabaseUndefinedRelation, -) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.sql_client import ( - SqlClientBase, - DBApiCursorImpl, - raise_database_error, - raise_open_connection_error, -) - -from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient, DuckDBDBApiCursorImpl -from dlt.destinations.impl.motherduck import capabilities +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.impl.motherduck.configuration import MotherDuckCredentials class MotherDuckSqlClient(DuckDbSqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, dataset_name: str, credentials: MotherDuckCredentials) -> None: - super().__init__(dataset_name, credentials) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: MotherDuckCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(dataset_name, staging_dataset_name, credentials, capabilities) self.database_name = credentials.database - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - database_name = ( - self.capabilities.escape_identifier(self.database_name) - if escape - else self.database_name - ) - dataset_name = ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - return f"{database_name}.{dataset_name}" + def catalog_name(self, escape: bool = True) -> Optional[str]: + database_name = self.database_name + if escape: + database_name = self.capabilities.escape_identifier(database_name) + return database_name diff --git a/dlt/destinations/impl/mssql/__init__.py b/dlt/destinations/impl/mssql/__init__.py index f7768d9238..e69de29bb2 100644 --- a/dlt/destinations/impl/mssql/__init__.py +++ b/dlt/destinations/impl/mssql/__init__.py @@ -1,29 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.wei import EVM_DECIMAL_PRECISION - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_mssql_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - # https://learn.microsoft.com/en-us/sql/sql-server/maximum-capacity-specifications-for-sql-server?view=sql-server-ver16&redirectedfrom=MSDN - caps.max_identifier_length = 128 - caps.max_column_identifier_length = 128 - # A SQL Query can be a varchar(max) but is shown as limited to 65,536 * Network Packet - caps.max_query_length = 65536 * 10 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 2**30 - 1 - caps.is_max_text_data_type_length_in_bytes = False - caps.supports_ddl_transactions = True - caps.max_rows_per_insert = 1000 - caps.timestamp_precision = 7 - - return caps diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index 1d085f40c1..64d87065f3 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -34,8 +34,6 @@ def parse_native_representation(self, native_value: Any) -> None: self.query = {k.lower(): v for k, v in self.query.items()} # Make case-insensitive. self.driver = self.query.get("driver", self.driver) self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout)) - if not self.is_partial(): - self.resolve() def on_resolved(self) -> None: if self.driver not in self.SUPPORTED_DRIVERS: @@ -45,10 +43,10 @@ def on_resolved(self) -> None: ) self.database = self.database.lower() - def to_url(self) -> URL: - url = super().to_url() - url.update_query_pairs([("connect_timeout", str(self.connect_timeout))]) - return url + def get_query(self) -> Dict[str, Any]: + query = dict(super().get_query()) + query["connect_timeout"] = self.connect_timeout + return query def on_partial(self) -> None: self.driver = self._get_driver() @@ -95,6 +93,7 @@ class MsSqlClientConfiguration(DestinationClientDwhWithStagingConfiguration): credentials: MsSqlCredentials = None create_indexes: bool = False + has_case_sensitive_identifiers: bool = False def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py index 2e19d7c2a8..4434fdd1e1 100644 --- a/dlt/destinations/impl/mssql/factory.py +++ b/dlt/destinations/impl/mssql/factory.py @@ -1,30 +1,59 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.normalizers.naming.naming import NamingConvention +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration -from dlt.destinations.impl.mssql import capabilities if t.TYPE_CHECKING: - from dlt.destinations.impl.mssql.mssql import MsSqlClient + from dlt.destinations.impl.mssql.mssql import MsSqlJobClient -class mssql(Destination[MsSqlClientConfiguration, "MsSqlClient"]): +class mssql(Destination[MsSqlClientConfiguration, "MsSqlJobClient"]): spec = MsSqlClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + # mssql is by default case insensitive and stores identifiers as is + # case sensitivity can be changed by database collation so we allow to reconfigure + # capabilities in the mssql factory + caps.escape_identifier = escape_postgres_identifier + caps.escape_literal = escape_mssql_literal + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + # https://learn.microsoft.com/en-us/sql/sql-server/maximum-capacity-specifications-for-sql-server?view=sql-server-ver16&redirectedfrom=MSDN + caps.max_identifier_length = 128 + caps.max_column_identifier_length = 128 + # A SQL Query can be a varchar(max) but is shown as limited to 65,536 * Network Packet + caps.max_query_length = 65536 * 10 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 2**30 - 1 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = True + caps.max_rows_per_insert = 1000 + caps.timestamp_precision = 7 + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property - def client_class(self) -> t.Type["MsSqlClient"]: - from dlt.destinations.impl.mssql.mssql import MsSqlClient + def client_class(self) -> t.Type["MsSqlJobClient"]: + from dlt.destinations.impl.mssql.mssql import MsSqlJobClient - return MsSqlClient + return MsSqlJobClient def __init__( self, credentials: t.Union[MsSqlCredentials, t.Dict[str, t.Any], str] = None, - create_indexes: bool = True, + create_indexes: bool = False, + has_case_sensitive_identifiers: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -37,12 +66,27 @@ def __init__( credentials: Credentials to connect to the mssql database. Can be an instance of `MsSqlCredentials` or a connection string in the format `mssql://user:password@host:port/database` create_indexes: Should unique indexes be created + has_case_sensitive_identifiers: Are identifiers used by mssql database case sensitive (following the collation) **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, create_indexes=create_indexes, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: MsSqlClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.has_case_sensitive_identifiers: + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 6f364c8af1..ec4a54d6f7 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,19 +1,15 @@ -from typing import ClassVar, Dict, Optional, Sequence, List, Any, Tuple +from typing import Dict, Optional, Sequence, List, Any from dlt.common.exceptions import TerminalValueError -from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.common.destination.reference import NewLoadJob from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat -from dlt.common.utils import uniq_id from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.impl.mssql import capabilities from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration from dlt.destinations.sql_client import SqlClientBase @@ -99,7 +95,7 @@ def generate_sql( ) -> List[str]: sql: List[str] = [] for table in table_chain: - with sql_client.with_staging_dataset(staging=True): + with sql_client.with_staging_dataset(): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) # drop destination table @@ -145,11 +141,19 @@ def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) return "#" + name -class MsSqlClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None: - sql_client = PyOdbcMsSqlClient(config.normalize_dataset_name(schema), config.credentials) +class MsSqlJobClient(InsertValuesJobClient): + def __init__( + self, + schema: Schema, + config: MsSqlClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = PyOdbcMsSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.config: MsSqlClientConfiguration = config self.sql_client = sql_client @@ -180,7 +184,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in self.active_hints.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}" def _create_replace_followup_jobs( diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index db043bae25..e1b51743f5 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -1,4 +1,3 @@ -import platform import struct from datetime import datetime, timedelta, timezone # noqa: I251 @@ -23,7 +22,6 @@ ) from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -from dlt.destinations.impl.mssql import capabilities def handle_datetimeoffset(dto_value: bytes) -> datetime: @@ -43,10 +41,15 @@ def handle_datetimeoffset(dto_value: bytes) -> datetime: class PyOdbcMsSqlClient(SqlClientBase[pyodbc.Connection], DBTransaction): dbapi: ClassVar[DBApi] = pyodbc - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: MsSqlCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: MsSqlCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: pyodbc.Connection = None self.credentials = credentials @@ -104,14 +107,14 @@ def drop_dataset(self) -> None: # Drop all views rows = self.execute_sql( "SELECT table_name FROM information_schema.views WHERE table_schema = %s;", - self.dataset_name, + self.capabilities.casefold_identifier(self.dataset_name), ) view_names = [row[0] for row in rows] self._drop_views(*view_names) # Drop all tables rows = self.execute_sql( "SELECT table_name FROM information_schema.tables WHERE table_schema = %s;", - self.dataset_name, + self.capabilities.casefold_identifier(self.dataset_name), ) table_names = [row[0] for row in rows] self.drop_tables(*table_names) @@ -127,7 +130,7 @@ def _drop_views(self, *tables: str) -> None: self.execute_many(statements) def _drop_schema(self) -> None: - self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + self.execute_sql("DROP SCHEMA %s;" % self.fully_qualified_dataset_name()) def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any @@ -158,11 +161,6 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB except pyodbc.Error as outer: raise outer - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, pyodbc.ProgrammingError): diff --git a/dlt/destinations/impl/postgres/__init__.py b/dlt/destinations/impl/postgres/__init__.py index bdb9297210..e69de29bb2 100644 --- a/dlt/destinations/impl/postgres/__init__.py +++ b/dlt/destinations/impl/postgres/__init__.py @@ -1,27 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.wei import EVM_DECIMAL_PRECISION - - -def capabilities() -> DestinationCapabilitiesContext: - # https://www.postgresql.org/docs/current/limits.html - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values", "csv"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_postgres_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (2 * EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) - caps.max_identifier_length = 63 - caps.max_column_identifier_length = 63 - caps.max_query_length = 32 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - - return caps diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 0d12abbac7..13bdc7f6b2 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -1,6 +1,7 @@ import dataclasses -from typing import Final, ClassVar, Any, List, TYPE_CHECKING, Union +from typing import Dict, Final, ClassVar, Any, List, Optional +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials @@ -23,13 +24,11 @@ class PostgresCredentials(ConnectionStringCredentials): def parse_native_representation(self, native_value: Any) -> None: super().parse_native_representation(native_value) self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout)) - if not self.is_partial(): - self.resolve() - def to_url(self) -> URL: - url = super().to_url() - url.update_query_pairs([("connect_timeout", str(self.connect_timeout))]) - return url + def get_query(self) -> Dict[str, Any]: + query = dict(super().get_query()) + query["connect_timeout"] = self.connect_timeout + return query @configspec @@ -39,6 +38,9 @@ class PostgresClientConfiguration(DestinationClientDwhWithStagingConfiguration): create_indexes: bool = True + csv_format: Optional[CsvFormatConfiguration] = None + """Optional csv format configuration""" + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/postgres/factory.py b/dlt/destinations/impl/postgres/factory.py index 68d72f890a..e14aa61465 100644 --- a/dlt/destinations/impl/postgres/factory.py +++ b/dlt/destinations/impl/postgres/factory.py @@ -1,12 +1,15 @@ import typing as t +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.destinations.impl.postgres.configuration import ( PostgresCredentials, PostgresClientConfiguration, ) -from dlt.destinations.impl.postgres import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.postgres.postgres import PostgresClient @@ -15,8 +18,32 @@ class postgres(Destination[PostgresClientConfiguration, "PostgresClient"]): spec = PostgresClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + # https://www.postgresql.org/docs/current/limits.html + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values", "csv"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + caps.escape_identifier = escape_postgres_identifier + # postgres has case sensitive identifiers but by default + # it folds them to lower case which makes them case insensitive + # https://stackoverflow.com/questions/20878932/are-postgresql-column-names-case-sensitive + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = True + caps.escape_literal = escape_postgres_literal + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (2 * EVM_DECIMAL_PRECISION, EVM_DECIMAL_PRECISION) + caps.max_identifier_length = 63 + caps.max_column_identifier_length = 63 + caps.max_query_length = 32 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] + + return caps @property def client_class(self) -> t.Type["PostgresClient"]: @@ -28,6 +55,7 @@ def __init__( self, credentials: t.Union[PostgresCredentials, t.Dict[str, t.Any], str] = None, create_indexes: bool = True, + csv_format: t.Optional[CsvFormatConfiguration] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -40,11 +68,13 @@ def __init__( credentials: Credentials to connect to the postgres database. Can be an instance of `PostgresCredentials` or a connection string in the format `postgres://user:password@host:port/database` create_indexes: Should unique indexes be created + csv_format: Formatting options for csv file format **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, create_indexes=create_indexes, + csv_format=csv_format, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 11cee208b1..f47549fc4f 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -1,5 +1,11 @@ -from typing import ClassVar, Dict, Optional, Sequence, List, Any - +from typing import Dict, Optional, Sequence, List, Any + +from dlt.common import logger +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.destination.exceptions import ( + DestinationInvalidFileFormat, + DestinationTerminalException, +) from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.exceptions import TerminalValueError @@ -9,7 +15,6 @@ from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.impl.postgres import capabilities from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration from dlt.destinations.sql_client import SqlClientBase @@ -90,7 +95,7 @@ def generate_sql( ) -> List[str]: sql: List[str] = [] for table in table_chain: - with sql_client.with_staging_dataset(staging=True): + with sql_client.with_staging_dataset(): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) # drop destination table @@ -106,21 +111,85 @@ def generate_sql( class PostgresCsvCopyJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: Psycopg2SqlClient) -> None: + def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient") -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + config = client.config + sql_client = client.sql_client + csv_format = config.csv_format or CsvFormatConfiguration() + table_name = table["name"] + sep = csv_format.delimiter + if csv_format.on_error_continue: + logger.warning( + f"When processing {file_path} on table {table_name} Postgres csv reader does not" + " support on_error_continue" + ) with FileStorage.open_zipsafe_ro(file_path, "rb") as f: - # all headers in first line - headers = f.readline().decode("utf-8").strip() - # quote headers if not quoted - all special keywords like "binary" must be quoted - headers = ",".join(h if h.startswith('"') else f'"{h}"' for h in headers.split(",")) + if csv_format.include_header: + # all headers in first line + headers_row: str = f.readline().decode(csv_format.encoding).strip() + split_headers = headers_row.split(sep) + else: + # read first row to figure out the headers + split_first_row: str = f.readline().decode(csv_format.encoding).strip().split(sep) + split_headers = list(client.schema.get_table_columns(table_name).keys()) + if len(split_first_row) > len(split_headers): + raise DestinationInvalidFileFormat( + "postgres", + "csv", + file_path, + f"First row {split_first_row} has more rows than columns {split_headers} in" + f" table {table_name}", + ) + if len(split_first_row) < len(split_headers): + logger.warning( + f"First row {split_first_row} has less rows than columns {split_headers} in" + f" table {table_name}. We will not load data to superfluous columns." + ) + split_headers = split_headers[: len(split_first_row)] + # stream the first row again + f.seek(0) + + # normalized and quoted headers + split_headers = [ + sql_client.escape_column_name(h.strip('"'), escape=True) for h in split_headers + ] + split_null_headers = [] + split_columns = [] + # detect columns with NULL to use in FORCE NULL + # detect headers that are not in columns + for col in client.schema.get_table_columns(table_name).values(): + norm_col = sql_client.escape_column_name(col["name"], escape=True) + split_columns.append(norm_col) + if norm_col in split_headers and col.get("nullable", True): + split_null_headers.append(norm_col) + split_unknown_headers = set(split_headers).difference(split_columns) + if split_unknown_headers: + raise DestinationInvalidFileFormat( + "postgres", + "csv", + file_path, + f"Following headers {split_unknown_headers} cannot be matched to columns" + f" {split_columns} of table {table_name}.", + ) + + # use comma to join + headers = ",".join(split_headers) + if split_null_headers: + null_headers = f"FORCE_NULL({','.join(split_null_headers)})," + else: + null_headers = "" + qualified_table_name = sql_client.make_qualified_table_name(table_name) copy_sql = ( - "COPY %s (%s) FROM STDIN WITH (FORMAT CSV, DELIMITER ',', NULL '', FORCE_NULL(%s))" + "COPY %s (%s) FROM STDIN WITH (FORMAT CSV, DELIMITER '%s', NULL ''," + " %s ENCODING '%s')" % ( qualified_table_name, headers, - headers, + sep, + null_headers, + csv_format.encoding, ) ) with sql_client.begin_transaction(): @@ -135,10 +204,18 @@ def exception(self) -> str: class PostgresClient(InsertValuesJobClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: - sql_client = Psycopg2SqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: PostgresClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = Psycopg2SqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.config: PostgresClientConfiguration = config self.sql_client: Psycopg2SqlClient = sql_client @@ -148,7 +225,7 @@ def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job and file_path.endswith("csv"): - job = PostgresCsvCopyJob(table["name"], file_path, self.sql_client) + job = PostgresCsvCopyJob(table, file_path, self) return job def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -157,7 +234,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in self.active_hints.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return ( f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index 366ed243ef..d867248196 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -26,15 +26,19 @@ ) from dlt.destinations.impl.postgres.configuration import PostgresCredentials -from dlt.destinations.impl.postgres import capabilities class Psycopg2SqlClient(SqlClientBase["psycopg2.connection"], DBTransaction): dbapi: ClassVar[DBApi] = psycopg2 - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: PostgresCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: PostgresCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: psycopg2.connection = None self.credentials = credentials @@ -112,11 +116,6 @@ def execute_fragments( composed = Composed(sql if isinstance(sql, Composable) else SQL(sql) for sql in fragments) return self.execute_sql(composed, *args, **kwargs) - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name - ) - def _reset_connection(self) -> None: # self._conn.autocommit = True self._conn.reset() diff --git a/dlt/destinations/impl/qdrant/__init__.py b/dlt/destinations/impl/qdrant/__init__.py index 1a2c466b14..e69de29bb2 100644 --- a/dlt/destinations/impl/qdrant/__init__.py +++ b/dlt/destinations/impl/qdrant/__init__.py @@ -1,18 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] - - caps.max_identifier_length = 200 - caps.max_column_identifier_length = 1024 - caps.max_query_length = 8 * 1024 * 1024 - caps.is_max_query_length_in_bytes = False - caps.max_text_data_type_length = 8 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = False - caps.supports_ddl_transactions = False - - return caps diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index fd11cc7dcb..baf5e5dc59 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -1,6 +1,6 @@ import dataclasses -from typing import Optional, Final -from typing_extensions import Annotated +from typing import Optional, Final, Any +from typing_extensions import Annotated, TYPE_CHECKING from dlt.common.configuration import configspec, NotResolved from dlt.common.configuration.specs.base_configuration import ( @@ -8,16 +8,65 @@ CredentialsConfiguration, ) from dlt.common.destination.reference import DestinationClientDwhConfiguration +from dlt.destinations.impl.qdrant.exceptions import InvalidInMemoryQdrantCredentials + +if TYPE_CHECKING: + from qdrant_client import QdrantClient @configspec class QdrantCredentials(CredentialsConfiguration): - # If `:memory:` - use in-memory Qdrant instance. + if TYPE_CHECKING: + _external_client: "QdrantClient" + # If `str` - use it as a `url` parameter. # If `None` - use default values for `host` and `port` location: Optional[str] = None # API key for authentication in Qdrant Cloud. Default: `None` api_key: Optional[str] = None + # Persistence path for QdrantLocal. Default: `None` + path: Optional[str] = None + + def is_local(self) -> bool: + return self.path is not None + + def on_resolved(self) -> None: + if self.location == ":memory:": + raise InvalidInMemoryQdrantCredentials() + + def parse_native_representation(self, native_value: Any) -> None: + try: + from qdrant_client import QdrantClient + + if isinstance(native_value, QdrantClient): + self._external_client = native_value + self.resolve() + except ModuleNotFoundError: + pass + + super().parse_native_representation(native_value) + + def _create_client(self, model: str, **options: Any) -> "QdrantClient": + from qdrant_client import QdrantClient + + creds = dict(self) + if creds["path"]: + del creds["location"] + + client = QdrantClient(**creds, **options) + client.set_model(model) + return client + + def get_client(self, model: str, **options: Any) -> "QdrantClient": + client = getattr(self, "_external_client", None) + return client or self._create_client(model, **options) + + def close_client(self, client: "QdrantClient") -> None: + """Close client if not external""" + if getattr(self, "_external_client", None) is client: + # Do not close client created externally + return + client.close() def __str__(self) -> str: return self.location or "localhost" @@ -44,7 +93,7 @@ class QdrantClientOptions(BaseConfiguration): # Default: `None` host: Optional[str] = None # Persistence path for QdrantLocal. Default: `None` - path: Optional[str] = None + # path: Optional[str] = None @configspec @@ -79,6 +128,12 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration): # Find the list here. https://qdrant.github.io/fastembed/examples/Supported_Models/. model: str = "BAAI/bge-small-en" + def get_client(self) -> "QdrantClient": + return self.credentials.get_client(self.model, **dict(self.options)) + + def close_client(self, client: "QdrantClient") -> None: + self.credentials.close_client(client) + def fingerprint(self) -> str: """Returns a fingerprint of a connection string""" diff --git a/dlt/destinations/impl/qdrant/exceptions.py b/dlt/destinations/impl/qdrant/exceptions.py new file mode 100644 index 0000000000..19f33f64c1 --- /dev/null +++ b/dlt/destinations/impl/qdrant/exceptions.py @@ -0,0 +1,11 @@ +from dlt.common.destination.exceptions import DestinationTerminalException + + +class InvalidInMemoryQdrantCredentials(DestinationTerminalException): + def __init__(self) -> None: + super().__init__( + "To use in-memory instance of qdrant, " + "please instantiate it first and then pass to destination factory\n" + '\nclient = QdrantClient(":memory:")\n' + 'dlt.pipeline(pipeline_name="...", destination=dlt.destinations.qdrant(client)' + ) diff --git a/dlt/destinations/impl/qdrant/factory.py b/dlt/destinations/impl/qdrant/factory.py index df9cd64871..f994948d91 100644 --- a/dlt/destinations/impl/qdrant/factory.py +++ b/dlt/destinations/impl/qdrant/factory.py @@ -1,23 +1,50 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.reference import TDestinationConfig +from dlt.common.normalizers.naming import NamingConvention from dlt.destinations.impl.qdrant.configuration import QdrantCredentials, QdrantClientConfiguration -from dlt.destinations.impl.qdrant import capabilities if t.TYPE_CHECKING: - from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient + from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient class qdrant(Destination[QdrantClientConfiguration, "QdrantClient"]): spec = QdrantClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl"] + caps.has_case_sensitive_identifiers = True + caps.max_identifier_length = 200 + caps.max_column_identifier_length = 1024 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 8 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = False + + return caps + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: QdrantClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super(qdrant, cls).adjust_capabilities(caps, config, naming) + if config.credentials.is_local(): + # Local qdrant can not load in parallel + caps.loader_parallelism_strategy = "sequential" + caps.max_parallel_load_jobs = 1 + return caps @property def client_class(self) -> t.Type["QdrantClient"]: - from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient + from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient return QdrantClient diff --git a/dlt/destinations/impl/qdrant/qdrant_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py similarity index 65% rename from dlt/destinations/impl/qdrant/qdrant_client.py rename to dlt/destinations/impl/qdrant/qdrant_job_client.py index 9898b28c86..28d7388701 100644 --- a/dlt/destinations/impl/qdrant/qdrant_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -1,19 +1,27 @@ from types import TracebackType -from typing import ClassVar, Optional, Sequence, List, Dict, Type, Iterable, Any, IO +from typing import Optional, Sequence, List, Dict, Type, Iterable, Any +import threading from dlt.common import logger from dlt.common.json import json from dlt.common.pendulum import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + loads_table, + normalize_table_identifiers, + version_table, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.common.storages import FileStorage +from dlt.common.time import precise_time from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo -from dlt.destinations.impl.qdrant import capabilities +from dlt.destinations.utils import get_pipeline_state_query_columns from dlt.destinations.impl.qdrant.configuration import QdrantClientConfiguration from dlt.destinations.impl.qdrant.qdrant_adapter import VECTORIZE_HINT @@ -40,6 +48,7 @@ def __init__( self.config = client_config with FileStorage.open_zipsafe_ro(local_path) as f: + ids: List[str] docs, payloads, ids = [], [], [] for line in f: @@ -47,23 +56,26 @@ def __init__( point_id = ( self._generate_uuid(data, self.unique_identifiers, self.collection_name) if self.unique_identifiers - else uuid.uuid4() + else str(uuid.uuid4()) ) - embedding_doc = self._get_embedding_doc(data) payloads.append(data) ids.append(point_id) - docs.append(embedding_doc) - - embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) - embeddings = list( - embedding_model.embed( - docs, - batch_size=self.config.embedding_batch_size, - parallel=self.config.embedding_parallelism, + if len(self.embedding_fields) > 0: + docs.append(self._get_embedding_doc(data)) + + if len(self.embedding_fields) > 0: + embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) + embeddings = list( + embedding_model.embed( + docs, + batch_size=self.config.embedding_batch_size, + parallel=self.config.embedding_parallelism, + ) ) - ) - vector_name = db_client.get_vector_field_name() - embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] + vector_name = db_client.get_vector_field_name() + embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] + else: + embeddings = [{}] * len(ids) assert len(embeddings) == len(payloads) == len(ids) self._upload_data(vectors=embeddings, ids=ids, payloads=payloads) @@ -126,7 +138,7 @@ def _generate_uuid( collection_name (str): Qdrant collection name. Returns: - str: A string representation of the genrated UUID + str: A string representation of the generated UUID """ data_id = "_".join(str(data[key]) for key in unique_identifiers) return str(uuid.uuid5(uuid.NAMESPACE_DNS, collection_name + data_id)) @@ -141,20 +153,25 @@ def exception(self) -> str: class QdrantClient(JobClientBase, WithStateSync): """Qdrant Destination Handler""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - state_properties: ClassVar[List[str]] = [ - "version", - "engine_version", - "pipeline_name", - "state", - "created_at", - "_dlt_load_id", - ] - - def __init__(self, schema: Schema, config: QdrantClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: QdrantClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) + # get definitions of the dlt tables, normalize column names and keep for later use + version_table_ = normalize_table_identifiers(version_table(), schema.naming) + self.version_collection_properties = list(version_table_["columns"].keys()) + loads_table_ = normalize_table_identifiers(loads_table(), schema.naming) + self.loads_collection_properties = list(loads_table_["columns"].keys()) + state_table_ = normalize_table_identifiers( + get_pipeline_state_query_columns(), schema.naming + ) + self.pipeline_state_properties = list(state_table_["columns"].keys()) + self.config: QdrantClientConfiguration = config - self.db_client: QC = QdrantClient._create_db_client(config) + self.db_client: QC = None self.model = config.model @property @@ -165,22 +182,6 @@ def dataset_name(self) -> str: def sentinel_collection(self) -> str: return self.dataset_name or "DltSentinelCollection" - @staticmethod - def _create_db_client(config: QdrantClientConfiguration) -> QC: - """Generates a Qdrant client from the 'qdrant_client' package. - - Args: - config (QdrantClientConfiguration): Credentials and options for the Qdrant client. - - Returns: - QdrantClient: A Qdrant client instance. - """ - credentials = dict(config.credentials) - options = dict(config.options) - client = QC(**credentials, **options) - client.set_model(config.model) - return client - def _make_qualified_collection_name(self, table_name: str) -> str: """Generates a qualified collection name. @@ -216,8 +217,10 @@ def _create_collection(self, full_collection_name: str) -> None: self.db_client.create_collection( collection_name=full_collection_name, vectors_config=vectors_config ) + # TODO: we can use index hints to create indexes on properties or full text + # self.db_client.create_payload_index(full_collection_name, "_dlt_load_id", field_type="float") - def _create_point(self, obj: Dict[str, Any], collection_name: str) -> None: + def _create_point_no_vector(self, obj: Dict[str, Any], collection_name: str) -> None: """Inserts a point into a Qdrant collection without a vector. Args: @@ -308,8 +311,14 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Loads compressed state from destination storage By finding a load id that was completed """ + # normalize property names + p_load_id = self.schema.naming.normalize_identifier("load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") + p_created_at = self.schema.naming.normalize_identifier("created_at") + limit = 10 - offset = None + start_from = None while True: try: scroll_table_name = self._make_qualified_collection_name( @@ -317,22 +326,27 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) state_records, offset = self.db_client.scroll( scroll_table_name, - with_payload=self.state_properties, + with_payload=self.pipeline_state_properties, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="pipeline_name", match=models.MatchValue(value=pipeline_name) + key=p_pipeline_name, match=models.MatchValue(value=pipeline_name) ) ] ), + order_by=models.OrderBy( + key=p_created_at, + direction=models.Direction.DESC, + start_from=start_from, + ), limit=limit, - offset=offset, ) if len(state_records) == 0: return None for state_record in state_records: state = state_record.payload - load_id = state["_dlt_load_id"] + start_from = state[p_created_at] + load_id = state[p_dlt_load_id] scroll_table_name = self._make_qualified_collection_name( self.schema.loads_table_name ) @@ -342,58 +356,87 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: count_filter=models.Filter( must=[ models.FieldCondition( - key="load_id", match=models.MatchValue(value=load_id) + key=p_load_id, match=models.MatchValue(value=load_id) ) ] ), ) - if load_records.count > 0: - state["dlt_load_id"] = state.pop("_dlt_load_id") - return StateInfo(**state) - except Exception: - return None + if load_records.count == 0: + continue + return StateInfo.from_normalized_mapping(state, self.schema.naming) + except UnexpectedResponse as e: + if e.status_code == 404: + raise DestinationUndefinedEntity(str(e)) from e + raise + except ValueError as e: # Local qdrant error + if "not found" in str(e): + raise DestinationUndefinedEntity(str(e)) from e + raise def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") response = self.db_client.scroll( scroll_table_name, with_payload=True, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="schema_name", + key=p_schema_name, match=models.MatchValue(value=self.schema.name), ) ] ), limit=1, + order_by=models.OrderBy( + key=p_inserted_at, + direction=models.Direction.DESC, + ), ) - record = response[0][0].payload - return StorageSchemaInfo(**record) - except Exception: - return None + if not response[0]: + return None + payload = response[0][0].payload + return StorageSchemaInfo.from_normalized_mapping(payload, self.schema.naming) + except UnexpectedResponse as e: + if e.status_code == 404: + raise DestinationUndefinedEntity(str(e)) from e + raise + except ValueError as e: # Local qdrant error + if "not found" in str(e): + raise DestinationUndefinedEntity(str(e)) from e + raise def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: try: scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + p_version_hash = self.schema.naming.normalize_identifier("version_hash") response = self.db_client.scroll( scroll_table_name, with_payload=True, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="version_hash", match=models.MatchValue(value=schema_hash) + key=p_version_hash, match=models.MatchValue(value=schema_hash) ) ] ), limit=1, ) - record = response[0][0].payload - return StorageSchemaInfo(**record) - except Exception: - return None + if not response[0]: + return None + payload = response[0][0].payload + return StorageSchemaInfo.from_normalized_mapping(payload, self.schema.naming) + except UnexpectedResponse as e: + if e.status_code == 404: + return None + raise + except ValueError as e: # Local qdrant error + if "not found" in str(e): + return None + raise def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: return LoadQdrantJob( @@ -408,16 +451,14 @@ def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") def complete_load(self, load_id: str) -> None: - properties = { - "load_id": load_id, - "schema_name": self.schema.name, - "status": 0, - "inserted_at": str(pendulum.now()), - } + values = [load_id, self.schema.name, 0, str(pendulum.now()), self.schema.version_hash] + assert len(values) == len(self.loads_collection_properties) + properties = {k: v for k, v in zip(self.loads_collection_properties, values)} loads_table_name = self._make_qualified_collection_name(self.schema.loads_table_name) - self._create_point(properties, loads_table_name) + self._create_point_no_vector(properties, loads_table_name) def __enter__(self) -> "QdrantClient": + self.db_client = self.config.get_client() return self def __exit__( @@ -426,29 +467,52 @@ def __exit__( exc_val: BaseException, exc_tb: TracebackType, ) -> None: - pass + if self.db_client: + self.config.close_client(self.db_client) + self.db_client = None def _update_schema_in_storage(self, schema: Schema) -> None: schema_str = json.dumps(schema.to_dict()) - properties = { - "version_hash": schema.stored_version_hash, - "schema_name": schema.name, - "version": schema.version, - "engine_version": schema.ENGINE_VERSION, - "inserted_at": str(pendulum.now()), - "schema": schema_str, - } + values = [ + schema.version, + schema.ENGINE_VERSION, + str(pendulum.now().isoformat()), + schema.name, + schema.stored_version_hash, + schema_str, + ] + assert len(values) == len(self.version_collection_properties) + properties = {k: v for k, v in zip(self.version_collection_properties, values)} version_table_name = self._make_qualified_collection_name(self.schema.version_table_name) - self._create_point(properties, version_table_name) + self._create_point_no_vector(properties, version_table_name) def _execute_schema_update(self, only_tables: Iterable[str]) -> None: + is_local = self.config.credentials.is_local() for table_name in only_tables or self.schema.tables: exists = self._collection_exists(table_name) + qualified_collection_name = self._make_qualified_collection_name(table_name) + # NOTE: there are no property schemas in qdrant so we do not need to alter + # existing collections if not exists: self._create_collection( - full_collection_name=self._make_qualified_collection_name(table_name) + full_collection_name=qualified_collection_name, ) + if not is_local: # Indexes don't work in local Qdrant (trigger log warning) + # Create indexes to enable order_by in state and schema tables + if table_name == self.schema.state_table_name: + self.db_client.create_payload_index( + collection_name=qualified_collection_name, + field_name=self.schema.naming.normalize_identifier("created_at"), + field_schema="datetime", + ) + elif table_name == self.schema.version_table_name: + self.db_client.create_payload_index( + collection_name=qualified_collection_name, + field_name=self.schema.naming.normalize_identifier("inserted_at"), + field_schema="datetime", + ) + self._update_schema_in_storage(self.schema) def _collection_exists(self, table_name: str, qualify_table_name: bool = True) -> bool: @@ -460,6 +524,10 @@ def _collection_exists(self, table_name: str, qualify_table_name: bool = True) - ) self.db_client.get_collection(table_name) return True + except ValueError as e: + if "not found" in str(e): + return False + raise e except UnexpectedResponse as e: if e.status_code == 404: return False diff --git a/dlt/destinations/impl/redshift/__init__.py b/dlt/destinations/impl/redshift/__init__.py index 8a8cae84b4..e69de29bb2 100644 --- a/dlt/destinations/impl/redshift/__init__.py +++ b/dlt/destinations/impl/redshift/__init__.py @@ -1,25 +0,0 @@ -from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_redshift_identifier - caps.escape_literal = escape_redshift_literal - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 127 - caps.max_column_identifier_length = 127 - caps.max_query_length = 16 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 65535 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - caps.alter_add_multi_column = False - - return caps diff --git a/dlt/destinations/impl/redshift/configuration.py b/dlt/destinations/impl/redshift/configuration.py index 72d7f70a9f..3b84c8663e 100644 --- a/dlt/destinations/impl/redshift/configuration.py +++ b/dlt/destinations/impl/redshift/configuration.py @@ -23,7 +23,9 @@ class RedshiftCredentials(PostgresCredentials): class RedshiftClientConfiguration(PostgresClientConfiguration): destination_type: Final[str] = dataclasses.field(default="redshift", init=False, repr=False, compare=False) # type: ignore credentials: RedshiftCredentials = None + staging_iam_role: Optional[str] = None + has_case_sensitive_identifiers: bool = False def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" diff --git a/dlt/destinations/impl/redshift/factory.py b/dlt/destinations/impl/redshift/factory.py index d80ef9dcad..ef1ee6b754 100644 --- a/dlt/destinations/impl/redshift/factory.py +++ b/dlt/destinations/impl/redshift/factory.py @@ -1,12 +1,14 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.normalizers.naming import NamingConvention from dlt.destinations.impl.redshift.configuration import ( RedshiftCredentials, RedshiftClientConfiguration, ) -from dlt.destinations.impl.redshift import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.redshift.redshift import RedshiftClient @@ -15,8 +17,32 @@ class redshift(Destination[RedshiftClientConfiguration, "RedshiftClient"]): spec = RedshiftClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + # redshift is case insensitive and will lower case identifiers when stored + # you can enable case sensitivity https://docs.aws.amazon.com/redshift/latest/dg/r_enable_case_sensitive_identifier.html + # then redshift behaves like postgres + caps.escape_identifier = escape_redshift_identifier + caps.escape_literal = escape_redshift_literal + caps.casefold_identifier = str.lower + caps.has_case_sensitive_identifiers = False + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 127 + caps.max_column_identifier_length = 127 + caps.max_query_length = 16 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 65535 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = False + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property def client_class(self) -> t.Type["RedshiftClient"]: @@ -27,8 +53,8 @@ def client_class(self) -> t.Type["RedshiftClient"]: def __init__( self, credentials: t.Union[RedshiftCredentials, t.Dict[str, t.Any], str] = None, - create_indexes: bool = True, staging_iam_role: t.Optional[str] = None, + has_case_sensitive_identifiers: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -40,15 +66,28 @@ def __init__( Args: credentials: Credentials to connect to the redshift database. Can be an instance of `RedshiftCredentials` or a connection string in the format `redshift://user:password@host:port/database` - create_indexes: Should unique indexes be created staging_iam_role: IAM role to use for staging data in S3 + has_case_sensitive_identifiers: Are case sensitive identifiers enabled for a database **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, - create_indexes=create_indexes, staging_iam_role=staging_iam_role, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: RedshiftClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.has_case_sensitive_identifiers: + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 672fceb7b2..8eacc76d11 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -1,11 +1,6 @@ import platform import os -from dlt.common.exceptions import TerminalValueError -from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient - -from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision - if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 @@ -15,25 +10,27 @@ # from psycopg2.sql import SQL, Composed -from typing import ClassVar, Dict, List, Optional, Sequence, Any +from typing import Dict, List, Optional, Sequence, Any, Tuple + -from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( NewLoadJob, CredentialsConfiguration, SupportsStagingDestination, ) from dlt.common.data_types import TDataType +from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat, TTableSchemaColumns from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.sql_jobs import SqlMergeJob from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob - -from dlt.destinations.impl.redshift import capabilities +from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase @@ -109,8 +106,6 @@ def from_db_type( class RedshiftSqlClient(Psycopg2SqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - @staticmethod def _maybe_make_terminal_exception_from_data_error( pg_ex: psycopg2.DataError, @@ -151,7 +146,6 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: "CREDENTIALS" f" 'aws_access_key_id={aws_access_key};aws_secret_access_key={aws_secret_key}'" ) - table_name = table["name"] # get format ext = os.path.splitext(bucket_path)[1][1:] @@ -191,10 +185,9 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: raise ValueError(f"Unsupported file type {ext} for Redshift.") with self._sql_client.begin_transaction(): - dataset_name = self._sql_client.dataset_name # TODO: if we ever support csv here remember to add column names to COPY self._sql_client.execute_sql(f""" - COPY {dataset_name}.{table_name} + COPY {self._sql_client.make_qualified_table_name(table['name'])} FROM '{bucket_path}' {file_type} {dateformat} @@ -231,10 +224,18 @@ def gen_key_table_clauses( class RedshiftClient(InsertValuesJobClient, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: - sql_client = RedshiftSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: RedshiftClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = RedshiftSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.sql_client = sql_client self.config: RedshiftClientConfiguration = config @@ -249,7 +250,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True ) - column_name = self.capabilities.escape_identifier(c["name"]) + column_name = self.sql_client.escape_column_name(c["name"]) return ( f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) diff --git a/dlt/destinations/impl/snowflake/__init__.py b/dlt/destinations/impl/snowflake/__init__.py index dde4d5a382..e69de29bb2 100644 --- a/dlt/destinations/impl/snowflake/__init__.py +++ b/dlt/destinations/impl/snowflake/__init__.py @@ -1,25 +0,0 @@ -from dlt.common.data_writers.escape import escape_bigquery_identifier -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import escape_snowflake_identifier -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["jsonl", "parquet"] - caps.escape_identifier = escape_snowflake_identifier - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - caps.max_identifier_length = 255 - caps.max_column_identifier_length = 255 - caps.max_query_length = 2 * 1024 * 1024 - caps.is_max_query_length_in_bytes = True - caps.max_text_data_type_length = 16 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True - caps.alter_add_multi_column = True - caps.supports_clone_table = True - return caps diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index c8cc805712..d2d822b23e 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -1,8 +1,9 @@ import dataclasses import base64 -from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING, Union +from typing import Final, Optional, Any, Dict, ClassVar, List from dlt import version +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.libs.sql_alchemy import URL from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue @@ -13,8 +14,8 @@ from dlt.common.utils import digest128 -def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes: - """Load an encrypted or unencrypted private key from string.""" +def _decode_private_key(private_key: str, password: Optional[str] = None) -> bytes: + """Decode encrypted or unencrypted private key from string.""" try: from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa @@ -61,67 +62,62 @@ class SnowflakeCredentials(ConnectionStringCredentials): warehouse: Optional[str] = None role: Optional[str] = None authenticator: Optional[str] = None + token: Optional[str] = None private_key: Optional[TSecretStrValue] = None private_key_passphrase: Optional[TSecretStrValue] = None application: Optional[str] = SNOWFLAKE_APPLICATION_ID + query_tag: Optional[str] = None __config_gen_annotations__: ClassVar[List[str]] = ["password", "warehouse", "role"] + __query_params__: ClassVar[List[str]] = [ + "warehouse", + "role", + "authenticator", + "token", + "private_key", + "private_key_passphrase", + ] def parse_native_representation(self, native_value: Any) -> None: super().parse_native_representation(native_value) - self.warehouse = self.query.get("warehouse") - self.role = self.query.get("role") - self.private_key = self.query.get("private_key") # type: ignore - self.private_key_passphrase = self.query.get("private_key_passphrase") # type: ignore - if not self.is_partial() and (self.password or self.private_key): - self.resolve() + for param in self.__query_params__: + if param in self.query: + setattr(self, param, self.query.get(param)) def on_resolved(self) -> None: - if not self.password and not self.private_key: + if not self.password and not self.private_key and not self.authenticator: raise ConfigurationValueError( - "Please specify password or private_key. SnowflakeCredentials supports password and" - " private key authentication and one of those must be specified." + "Please specify password or private_key or authenticator fields." + " SnowflakeCredentials supports password, private key and authenticator based (ie." + " oauth2) authentication and one of those must be specified." ) - def to_url(self) -> URL: - query = dict(self.query or {}) - if self.warehouse and "warehouse" not in query: - query["warehouse"] = self.warehouse - if self.role and "role" not in query: - query["role"] = self.role - - if self.application != "" and "application" not in query: - query["application"] = self.application - - return URL.create( - self.drivername, - self.username, - self.password, - self.host, - self.port, - self.database, - query, - ) + def get_query(self) -> Dict[str, Any]: + query = dict(super().get_query() or {}) + for param in self.__query_params__: + if self.get(param, None) is not None: + query[param] = self[param] + return query def to_connector_params(self) -> Dict[str, Any]: - private_key: Optional[bytes] = None + # gather all params in query + query = self.get_query() if self.private_key: - private_key = _read_private_key(self.private_key, self.private_key_passphrase) + query["private_key"] = _decode_private_key( + self.private_key, self.private_key_passphrase + ) + + # we do not want passphrase to be passed + query.pop("private_key_passphrase", None) - conn_params = dict( - self.query or {}, + conn_params: Dict[str, Any] = dict( + query, user=self.username, password=self.password, account=self.host, database=self.database, - warehouse=self.warehouse, - role=self.role, - private_key=private_key, ) - if self.authenticator: - conn_params["authenticator"] = self.authenticator - if self.application != "" and "application" not in conn_params: conn_params["application"] = self.application @@ -138,6 +134,9 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) keep_staged_files: bool = True """Whether to keep or delete the staged files after COPY INTO succeeds""" + csv_format: Optional[CsvFormatConfiguration] = None + """Optional csv format configuration""" + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/snowflake/factory.py b/dlt/destinations/impl/snowflake/factory.py index c4459232b7..c5fbd8600b 100644 --- a/dlt/destinations/impl/snowflake/factory.py +++ b/dlt/destinations/impl/snowflake/factory.py @@ -1,11 +1,14 @@ import typing as t +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_writers.escape import escape_snowflake_identifier +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + from dlt.destinations.impl.snowflake.configuration import ( SnowflakeCredentials, SnowflakeClientConfiguration, ) -from dlt.destinations.impl.snowflake import capabilities -from dlt.common.destination import Destination, DestinationCapabilitiesContext if t.TYPE_CHECKING: from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient @@ -14,8 +17,32 @@ class snowflake(Destination[SnowflakeClientConfiguration, "SnowflakeClient"]): spec = SnowflakeClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl", "parquet", "csv"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet", "csv"] + # snowflake is case sensitive but all unquoted identifiers are upper cased + # so upper case identifiers are considered case insensitive + caps.escape_identifier = escape_snowflake_identifier + # dlt is configured to create case insensitive identifiers + # note that case sensitive naming conventions will change this setting to "str" (case sensitive) + caps.casefold_identifier = str.upper + caps.has_case_sensitive_identifiers = True + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = True + caps.supports_clone_table = True + caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] + return caps @property def client_class(self) -> t.Type["SnowflakeClient"]: @@ -28,6 +55,7 @@ def __init__( credentials: t.Union[SnowflakeCredentials, t.Dict[str, t.Any], str] = None, stage_name: t.Optional[str] = None, keep_staged_files: bool = True, + csv_format: t.Optional[CsvFormatConfiguration] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -46,6 +74,7 @@ def __init__( credentials=credentials, stage_name=stage_name, keep_staged_files=keep_staged_files, + csv_format=csv_format, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 70377de709..96db81e0c6 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,6 +1,7 @@ -from typing import ClassVar, Optional, Sequence, Tuple, List, Any +from typing import Optional, Sequence, List from urllib.parse import urlparse, urlunparse +from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJob, @@ -14,23 +15,22 @@ AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults, ) -from dlt.common.data_types import TDataType +from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.storages.load_package import ParsedLoadJobFileName +from dlt.common.typing import TLoaderFileFormat from dlt.destinations.job_client_impl import SqlJobClientWithStaging from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.snowflake import capabilities from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.sql_jobs import SqlJobParams from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -86,36 +86,90 @@ def __init__( table_name: str, load_id: str, client: SnowflakeSqlClient, + config: SnowflakeClientConfiguration, stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) + # resolve reference + is_local_file = not NewReferenceJob.is_reference_job(file_path) + file_url = file_path if is_local_file else NewReferenceJob.resolve_reference(file_path) + # take file name + file_name = FileStorage.get_file_name_from_file_path(file_url) + file_format = file_name.rsplit(".", 1)[-1] qualified_table_name = client.make_qualified_table_name(table_name) + # this means we have a local file + stage_file_path: str = "" + if is_local_file: + if not stage_name: + # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" + stage_name = client.make_qualified_table_name("%" + table_name) + stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' - # extract and prepare some vars - bucket_path = ( - NewReferenceJob.resolve_reference(file_path) - if NewReferenceJob.is_reference_job(file_path) - else "" + copy_sql = self.gen_copy_sql( + file_url, + qualified_table_name, + file_format, # type: ignore[arg-type] + client.capabilities.generates_case_sensitive_identifiers(), + stage_name, + stage_file_path, + staging_credentials, + config.csv_format, ) - file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + + with client.begin_transaction(): + # PUT and COPY in one tx if local file, otherwise only copy + if is_local_file: + client.execute_sql( + f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' + " AUTO_COMPRESS = FALSE" + ) + client.execute_sql(copy_sql) + if stage_file_path and not keep_staged_files: + client.execute_sql(f"REMOVE {stage_file_path}") + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() + + @classmethod + def gen_copy_sql( + cls, + file_url: str, + qualified_table_name: str, + loader_file_format: TLoaderFileFormat, + is_case_sensitive: bool, + stage_name: Optional[str] = None, + local_stage_file_path: Optional[str] = None, + staging_credentials: Optional[CredentialsConfiguration] = None, + csv_format: Optional[CsvFormatConfiguration] = None, + ) -> str: + parsed_file_url = urlparse(file_url) + # check if local filesystem (file scheme or just a local file in native form) + is_local = parsed_file_url.scheme == "file" or FilesystemConfiguration.is_local_path( + file_url ) + # file_name = FileStorage.get_file_name_from_file_path(file_url) + from_clause = "" credentials_clause = "" files_clause = "" - stage_file_path = "" + on_error_clause = "" - if bucket_path: - bucket_url = urlparse(bucket_path) - bucket_scheme = bucket_url.scheme + case_folding = "CASE_SENSITIVE" if is_case_sensitive else "CASE_INSENSITIVE" + column_match_clause = f"MATCH_BY_COLUMN_NAME='{case_folding}'" + + if not is_local: + bucket_scheme = parsed_file_url.scheme # referencing an external s3/azure stage does not require explicit AWS credentials if bucket_scheme in ["s3", "az", "abfs"] and stage_name: from_clause = f"FROM '@{stage_name}'" - files_clause = f"FILES = ('{bucket_url.path.lstrip('/')}')" + files_clause = f"FILES = ('{parsed_file_url.path.lstrip('/')}')" # referencing an staged files via a bucket URL requires explicit AWS credentials elif ( bucket_scheme == "s3" @@ -123,7 +177,7 @@ def __init__( and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) ): credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" - from_clause = f"FROM '{bucket_path}'" + from_clause = f"FROM '{file_url}'" elif ( bucket_scheme in ["az", "abfs"] and staging_credentials @@ -133,70 +187,80 @@ def __init__( credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" # Converts an az:/// to azure://.blob.core.windows.net// # as required by snowflake - _path = "/" + bucket_url.netloc + bucket_url.path - bucket_path = urlunparse( - bucket_url._replace( + _path = "/" + parsed_file_url.netloc + parsed_file_url.path + file_url = urlunparse( + parsed_file_url._replace( scheme="azure", netloc=f"{staging_credentials.azure_storage_account_name}.blob.core.windows.net", path=_path, ) ) - from_clause = f"FROM '{bucket_path}'" + from_clause = f"FROM '{file_url}'" else: # ensure that gcs bucket path starts with gcs://, this is a requirement of snowflake - bucket_path = bucket_path.replace("gs://", "gcs://") + file_url = file_url.replace("gs://", "gcs://") if not stage_name: # when loading from bucket stage must be given raise LoadJobTerminalException( - file_path, - f"Cannot load from bucket path {bucket_path} without a stage name. See" + file_url, + f"Cannot load from bucket path {file_url} without a stage name. See" " https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for" " instructions on setting up the `stage_name`", ) from_clause = f"FROM @{stage_name}/" - files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" + files_clause = f"FILES = ('{urlparse(file_url).path.lstrip('/')}')" else: - # this means we have a local file - if not stage_name: - # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name("%" + table_name) - stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' - from_clause = f"FROM {stage_file_path}" + from_clause = f"FROM {local_stage_file_path}" # decide on source format, stage_file_path will either be a local file or a bucket path - source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" - if file_name.endswith("parquet"): - source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" - - with client.begin_transaction(): - # PUT and COPY in one tx if local file, otherwise only copy - if not bucket_path: - client.execute_sql( - f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' - " AUTO_COMPRESS = FALSE" - ) - client.execute_sql(f"""COPY INTO {qualified_table_name} - {from_clause} - {files_clause} - {credentials_clause} - FILE_FORMAT = {source_format} - MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' - """) - if stage_file_path and not keep_staged_files: - client.execute_sql(f"REMOVE {stage_file_path}") - - def state(self) -> TLoadJobState: - return "completed" + if loader_file_format == "jsonl": + source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" + elif loader_file_format == "parquet": + source_format = ( + "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" + # TODO: USE_VECTORIZED_SCANNER inserts null strings into VARIANT JSON + # " USE_VECTORIZED_SCANNER = TRUE)" + ) + elif loader_file_format == "csv": + # empty strings are NULL, no data is NULL, missing columns (ERROR_ON_COLUMN_COUNT_MISMATCH) are NULL + csv_format = csv_format or CsvFormatConfiguration() + source_format = ( + "(TYPE = 'CSV', BINARY_FORMAT = 'UTF-8', PARSE_HEADER =" + f" {csv_format.include_header}, FIELD_OPTIONALLY_ENCLOSED_BY = '\"', NULL_IF =" + " (''), ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE," + f" FIELD_DELIMITER='{csv_format.delimiter}', ENCODING='{csv_format.encoding}')" + ) + # disable column match if headers are not provided + if not csv_format.include_header: + column_match_clause = "" + if csv_format.on_error_continue: + on_error_clause = "ON_ERROR = CONTINUE" + else: + raise ValueError(f"{loader_file_format} not supported for Snowflake COPY command.") - def exception(self) -> str: - raise NotImplementedError() + return f"""COPY INTO {qualified_table_name} + {from_clause} + {files_clause} + {credentials_clause} + FILE_FORMAT = {source_format} + {column_match_clause} + {on_error_clause} + """ class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None: - sql_client = SnowflakeSqlClient(config.normalize_dataset_name(schema), config.credentials) + def __init__( + self, + schema: Schema, + config: SnowflakeClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + sql_client = SnowflakeSqlClient( + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, + ) super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore @@ -211,6 +275,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> table["name"], load_id, self.sql_client, + self.config, stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, staging_credentials=( @@ -241,7 +306,7 @@ def _get_table_update_sql( sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) cluster_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get("cluster") ] if cluster_list: @@ -255,17 +320,7 @@ def _from_db_type( return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - name = self.capabilities.escape_identifier(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - table_name = table_name.upper() # All snowflake tables are uppercased in information schema - exists, table = super().get_storage_table(table_name) - if not exists: - return exists, table - # Snowflake converts all unquoted columns to UPPER CASE - # Convert back to lower case to enable comparison with dlt schema - table = {col_name.lower(): dict(col, name=col_name.lower()) for col_name, col in table.items()} # type: ignore - return exists, table diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 4a602ce0e8..e2b7d1026a 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List +from typing import Any, AnyStr, ClassVar, Dict, Iterator, Optional, Sequence, List import snowflake.connector as snowflake_lib @@ -12,12 +12,12 @@ from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, + TJobQueryTags, raise_database_error, raise_open_connection_error, ) from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials -from dlt.destinations.impl.snowflake import capabilities class SnowflakeCursorImpl(DBApiCursorImpl): @@ -31,10 +31,15 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: class SnowflakeSqlClient(SqlClientBase[snowflake_lib.SnowflakeConnection], DBTransaction): dbapi: ClassVar[DBApi] = snowflake_lib - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, dataset_name: str, credentials: SnowflakeCredentials) -> None: - super().__init__(credentials.database, dataset_name) + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: SnowflakeCredentials, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self._conn: snowflake_lib.SnowflakeConnection = None self.credentials = credentials @@ -112,16 +117,25 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB self.open_connection() raise outer - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - # Always escape for uppercase - if escape: - return self.capabilities.escape_identifier(self.dataset_name) - return self.dataset_name.upper() - def _reset_connection(self) -> None: self._conn.rollback() self._conn.autocommit(True) + def set_query_tags(self, tags: TJobQueryTags) -> None: + super().set_query_tags(tags) + self._tag_session() + + def _tag_session(self) -> None: + """Wraps query with Snowflake query tag""" + if not self.credentials.query_tag: + return + if self._query_tags: + tag = self.credentials.query_tag.format(**self._query_tags) + tag_query = f"ALTER SESSION SET QUERY_TAG = '{tag}'" + else: + tag_query = "ALTER SESSION UNSET QUERY_TAG" + self.execute_sql(tag_query) + @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: if isinstance(ex, snowflake_lib.errors.ProgrammingError): diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py index f6ad7369c1..e69de29bb2 100644 --- a/dlt/destinations/impl/synapse/__init__.py +++ b/dlt/destinations/impl/synapse/__init__.py @@ -1,54 +0,0 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.wei import EVM_DECIMAL_PRECISION - -from dlt.destinations.impl.synapse.synapse_adapter import synapse_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - - caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet"] - - caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 - - caps.escape_identifier = escape_postgres_identifier - caps.escape_literal = escape_mssql_literal - - # Synapse has a max precision of 38 - # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#DataTypes - caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) - caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) - - # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#LimitationsRestrictions - caps.max_identifier_length = 128 - caps.max_column_identifier_length = 128 - - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits#queries - caps.max_query_length = 65536 * 4096 - caps.is_max_query_length_in_bytes = True - - # nvarchar(max) can store 2 GB - # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16#nvarchar---n--max-- - caps.max_text_data_type_length = 2 * 1024 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = True - - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-develop-transactions - caps.supports_transactions = True - caps.supports_ddl_transactions = False - - # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries." - # if number of records exceeds a certain number. Which exact number that is seems not deterministic: - # in tests, I've seen a query with 12230 records run succesfully on one run, but fail on a subsequent run, while the query remained exactly the same. - # 10.000 records is a "safe" amount that always seems to work. - caps.max_rows_per_insert = 10000 - - # datetimeoffset can store 7 digits for fractional seconds - # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 - caps.timestamp_precision = 7 - - return caps diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index 100878ae05..bb117e48d2 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -1,8 +1,10 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.normalizers.naming import NamingConvention +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.configuration import ( SynapseCredentials, SynapseClientConfiguration, @@ -21,8 +23,59 @@ class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): # def spec(self) -> t.Type[SynapseClientConfiguration]: # return SynapseClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet"] + + caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 + + # similarly to mssql case sensitivity depends on database collation + # https://learn.microsoft.com/en-us/sql/relational-databases/collations/collation-and-unicode-support?view=sql-server-ver16#collations-in-azure-sql-database + # note that special option CATALOG_COLLATION is used to change it + caps.escape_identifier = escape_postgres_identifier + caps.escape_literal = escape_mssql_literal + # we allow to reconfigure capabilities in the mssql factory + caps.has_case_sensitive_identifiers = False + + # Synapse has a max precision of 38 + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#DataTypes + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#LimitationsRestrictions + caps.max_identifier_length = 128 + caps.max_column_identifier_length = 128 + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits#queries + caps.max_query_length = 65536 * 4096 + caps.is_max_query_length_in_bytes = True + + # nvarchar(max) can store 2 GB + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16#nvarchar---n--max-- + caps.max_text_data_type_length = 2 * 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-develop-transactions + caps.supports_transactions = True + caps.supports_ddl_transactions = False + + # Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries." + # if number of records exceeds a certain number. Which exact number that is seems not deterministic: + # in tests, I've seen a query with 12230 records run succesfully on one run, but fail on a subsequent run, while the query remained exactly the same. + # 10.000 records is a "safe" amount that always seems to work. + caps.max_rows_per_insert = 10000 + + # datetimeoffset can store 7 digits for fractional seconds + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 + caps.timestamp_precision = 7 + + caps.supported_merge_strategies = ["delete-insert", "scd2"] + + return caps @property def client_class(self) -> t.Type["SynapseClient"]: @@ -36,6 +89,7 @@ def __init__( default_table_index_type: t.Optional[TTableIndexType] = "heap", create_indexes: bool = False, staging_use_msi: bool = False, + has_case_sensitive_identifiers: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -50,6 +104,7 @@ def __init__( default_table_index_type: Maps directly to the default_table_index_type attribute of the SynapseClientConfiguration object. create_indexes: Maps directly to the create_indexes attribute of the SynapseClientConfiguration object. staging_use_msi: Maps directly to the staging_use_msi attribute of the SynapseClientConfiguration object. + has_case_sensitive_identifiers: Are identifiers used by synapse database case sensitive (following the catalog collation) **kwargs: Additional arguments passed to the destination config """ super().__init__( @@ -57,7 +112,21 @@ def __init__( default_table_index_type=default_table_index_type, create_indexes=create_indexes, staging_use_msi=staging_use_msi, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, destination_name=destination_name, environment=environment, **kwargs, ) + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: SynapseClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + # modify the caps if case sensitive identifiers are requested + if config.has_case_sensitive_identifiers: + caps.has_case_sensitive_identifiers = True + caps.casefold_identifier = str + return super().adjust_capabilities(caps, config, naming) diff --git a/dlt/destinations/impl/synapse/sql_client.py b/dlt/destinations/impl/synapse/sql_client.py index 089c58e57c..cd9a929901 100644 --- a/dlt/destinations/impl/synapse/sql_client.py +++ b/dlt/destinations/impl/synapse/sql_client.py @@ -1,28 +1,16 @@ -from typing import ClassVar from contextlib import suppress -from dlt.common.destination import DestinationCapabilitiesContext - from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -from dlt.destinations.impl.synapse import capabilities -from dlt.destinations.impl.synapse.configuration import SynapseCredentials - from dlt.destinations.exceptions import DatabaseUndefinedRelation class SynapseSqlClient(PyOdbcMsSqlClient): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def drop_tables(self, *tables: str) -> None: if not tables: return # Synapse does not support DROP TABLE IF EXISTS. # Workaround: use DROP TABLE and suppress non-existence errors. statements = [f"DROP TABLE {self.make_qualified_table_name(table)};" for table in tables] - with suppress(DatabaseUndefinedRelation): - self.execute_fragments(statements) - - def _drop_schema(self) -> None: - # Synapse does not support DROP SCHEMA IF EXISTS. - self.execute_sql("DROP SCHEMA %s;" % self.fully_qualified_dataset_name()) + for statement in statements: + with suppress(DatabaseUndefinedRelation): + self.execute_sql(statement) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 48171ace4c..408bfc2b53 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -1,5 +1,5 @@ import os -from typing import ClassVar, Sequence, List, Dict, Any, Optional, cast, Union +from typing import Sequence, List, Dict, Any, Optional, cast, Union from copy import deepcopy from textwrap import dedent from urllib.parse import urlparse, urlunparse @@ -29,12 +29,11 @@ from dlt.destinations.impl.mssql.mssql import ( MsSqlTypeMapper, - MsSqlClient, + MsSqlJobClient, VARCHAR_MAX_N, VARBINARY_MAX_N, ) -from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient from dlt.destinations.impl.synapse.configuration import SynapseClientConfiguration from dlt.destinations.impl.synapse.synapse_adapter import ( @@ -53,14 +52,20 @@ } -class SynapseClient(MsSqlClient, SupportsStagingDestination): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - - def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: - super().__init__(schema, config) +class SynapseClient(MsSqlJobClient, SupportsStagingDestination): + def __init__( + self, + schema: Schema, + config: SynapseClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) self.config: SynapseClientConfiguration = config self.sql_client = SynapseSqlClient( - config.normalize_dataset_name(schema), config.credentials + config.normalize_dataset_name(schema), + config.normalize_staging_dataset_name(schema), + config.credentials, + capabilities, ) self.active_hints = deepcopy(HINT_TO_SYNAPSE_ATTR) diff --git a/dlt/destinations/impl/weaviate/__init__.py b/dlt/destinations/impl/weaviate/__init__.py index 143e0260d2..e69de29bb2 100644 --- a/dlt/destinations/impl/weaviate/__init__.py +++ b/dlt/destinations/impl/weaviate/__init__.py @@ -1,19 +0,0 @@ -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.impl.weaviate.weaviate_adapter import weaviate_adapter - - -def capabilities() -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] - - caps.max_identifier_length = 200 - caps.max_column_identifier_length = 1024 - caps.max_query_length = 8 * 1024 * 1024 - caps.is_max_query_length_in_bytes = False - caps.max_text_data_type_length = 8 * 1024 * 1024 - caps.is_max_text_data_type_length_in_bytes = False - caps.supports_ddl_transactions = False - caps.naming_convention = "dlt.destinations.impl.weaviate.naming" - - return caps diff --git a/dlt/destinations/impl/weaviate/ci_naming.py b/dlt/destinations/impl/weaviate/ci_naming.py index cc8936f42d..6e1b0c129e 100644 --- a/dlt/destinations/impl/weaviate/ci_naming.py +++ b/dlt/destinations/impl/weaviate/ci_naming.py @@ -2,6 +2,12 @@ class NamingConvention(WeaviateNamingConvention): + """Case insensitive naming convention for Weaviate. Lower cases all identifiers""" + + @property + def is_case_sensitive(self) -> bool: + return False + def _lowercase_property(self, identifier: str) -> str: """Lowercase the whole property to become case insensitive""" return identifier.lower() diff --git a/dlt/destinations/impl/weaviate/exceptions.py b/dlt/destinations/impl/weaviate/exceptions.py index ee798e4e76..11e440a811 100644 --- a/dlt/destinations/impl/weaviate/exceptions.py +++ b/dlt/destinations/impl/weaviate/exceptions.py @@ -1,16 +1,16 @@ from dlt.common.destination.exceptions import DestinationException, DestinationTerminalException -class WeaviateBatchError(DestinationException): +class WeaviateGrpcError(DestinationException): pass class PropertyNameConflict(DestinationTerminalException): - def __init__(self) -> None: + def __init__(self, error: str) -> None: super().__init__( "Your data contains items with identical property names when compared case insensitive." " Weaviate cannot handle such data. Please clean up your data before loading or change" " to case insensitive naming convention. See" " https://dlthub.com/docs/dlt-ecosystem/destinations/weaviate#names-normalization for" - " details." + f" details. [{error}]" ) diff --git a/dlt/destinations/impl/weaviate/factory.py b/dlt/destinations/impl/weaviate/factory.py index 0449e6cdd5..3d78c9582a 100644 --- a/dlt/destinations/impl/weaviate/factory.py +++ b/dlt/destinations/impl/weaviate/factory.py @@ -6,7 +6,6 @@ WeaviateCredentials, WeaviateClientConfiguration, ) -from dlt.destinations.impl.weaviate import capabilities if t.TYPE_CHECKING: from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient @@ -15,8 +14,26 @@ class weaviate(Destination[WeaviateClientConfiguration, "WeaviateClient"]): spec = WeaviateClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl"] + # weaviate names are case sensitive following GraphQL naming convention + # https://weaviate.io/developers/weaviate/config-refs/schema + caps.has_case_sensitive_identifiers = False + # weaviate will upper case first letter of class name and lower case first letter of a property + # we assume that naming convention will do that + caps.casefold_identifier = str + caps.max_identifier_length = 200 + caps.max_column_identifier_length = 1024 + caps.max_query_length = 8 * 1024 * 1024 + caps.is_max_query_length_in_bytes = False + caps.max_text_data_type_length = 8 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = False + caps.supports_ddl_transactions = False + caps.naming_convention = "dlt.destinations.impl.weaviate.naming" + + return caps @property def client_class(self) -> t.Type["WeaviateClient"]: diff --git a/dlt/destinations/impl/weaviate/naming.py b/dlt/destinations/impl/weaviate/naming.py index f5c94c872f..81a53dafd3 100644 --- a/dlt/destinations/impl/weaviate/naming.py +++ b/dlt/destinations/impl/weaviate/naming.py @@ -1,14 +1,20 @@ import re +from typing import ClassVar from dlt.common.normalizers.naming import NamingConvention as BaseNamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.typing import REPattern class NamingConvention(SnakeCaseNamingConvention): """Normalizes identifiers according to Weaviate documentation: https://weaviate.io/developers/weaviate/config-refs/schema#class""" + @property + def is_case_sensitive(self) -> bool: + return True + RESERVED_PROPERTIES = {"id": "__id", "_id": "___id", "_additional": "__additional"} - _RE_UNDERSCORES = re.compile("([^_])__+") + RE_UNDERSCORES: ClassVar[REPattern] = re.compile("([^_])__+") _STARTS_DIGIT = re.compile("^[0-9]") _STARTS_NON_LETTER = re.compile("^[0-9_]") _SPLIT_UNDERSCORE_NON_CAP = re.compile("(_[^A-Z])") @@ -51,11 +57,11 @@ def _lowercase_property(self, identifier: str) -> str: def _base_normalize(self, identifier: str) -> str: # all characters that are not letters digits or a few special chars are replaced with underscore normalized_ident = identifier.translate(self._TR_REDUCE_ALPHABET) - normalized_ident = self._RE_NON_ALPHANUMERIC.sub("_", normalized_ident) + normalized_ident = self.RE_NON_ALPHANUMERIC.sub("_", normalized_ident) # replace trailing _ with x stripped_ident = normalized_ident.rstrip("_") strip_count = len(normalized_ident) - len(stripped_ident) stripped_ident += "x" * strip_count # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR - return self._RE_UNDERSCORES.sub(r"\1_", stripped_ident) + return self.RE_UNDERSCORES.sub(r"\1_", stripped_ident) diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index 2d75ca0809..dfbf83d7e5 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -31,20 +31,23 @@ from dlt.common.time import ensure_pendulum_datetime from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TTableSchemaColumns from dlt.common.schema.typing import TColumnSchema, TColumnType -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + loads_table, + normalize_table_identifiers, + version_table, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync -from dlt.common.data_types import TDataType from dlt.common.storages import FileStorage from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT - from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo -from dlt.destinations.impl.weaviate import capabilities from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration -from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateBatchError +from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateGrpcError from dlt.destinations.type_mapping import TypeMapper +from dlt.destinations.utils import get_pipeline_state_query_columns NON_VECTORIZED_CLASS = { @@ -104,7 +107,7 @@ def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: if "conflict for property" in str(status_ex) or "none vectorizer module" in str( status_ex ): - raise PropertyNameConflict() + raise PropertyNameConflict(str(status_ex)) raise DestinationTerminalException(status_ex) # looks like there are no more terminal exception raise DestinationTransientException(status_ex) @@ -115,23 +118,25 @@ def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: return _wrap # type: ignore -def wrap_batch_error(f: TFun) -> TFun: +def wrap_grpc_error(f: TFun) -> TFun: @wraps(f) def _wrap(*args: Any, **kwargs: Any) -> Any: try: return f(*args, **kwargs) # those look like terminal exceptions - except WeaviateBatchError as batch_ex: + except WeaviateGrpcError as batch_ex: errors = batch_ex.args[0] message = errors["error"][0]["message"] # TODO: actually put the job in failed/retry state and prepare exception message with full info on failing item if "invalid" in message and "property" in message and "on class" in message: raise DestinationTerminalException( - f"Batch failed {errors} AND WILL **NOT** BE RETRIED" + f"Grpc (batch, query) failed {errors} AND WILL **NOT** BE RETRIED" ) if "conflict for property" in message: - raise PropertyNameConflict() - raise DestinationTransientException(f"Batch failed {errors} AND WILL BE RETRIED") + raise PropertyNameConflict(message) + raise DestinationTransientException( + f"Grpc (batch, query) failed {errors} AND WILL BE RETRIED" + ) except Exception: raise DestinationTransientException("Batch failed AND WILL BE RETRIED") @@ -174,14 +179,14 @@ def load_batch(self, f: IO[str]) -> None: Weaviate batch supports retries so we do not need to do that. """ - @wrap_batch_error + @wrap_grpc_error def check_batch_result(results: List[StrAny]) -> None: """This kills batch on first error reported""" if results is not None: for result in results: if "result" in result and "errors" in result["result"]: if "error" in result["result"]["errors"]: - raise WeaviateBatchError(result["result"]["errors"]) + raise WeaviateGrpcError(result["result"]["errors"]) with self.db_client.batch( batch_size=self.client_config.batch_size, @@ -233,20 +238,25 @@ def exception(self) -> str: class WeaviateClient(JobClientBase, WithStateSync): """Weaviate client implementation.""" - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - state_properties: ClassVar[List[str]] = [ - "version", - "engine_version", - "pipeline_name", - "state", - "created_at", - "_dlt_load_id", - ] - - def __init__(self, schema: Schema, config: WeaviateClientConfiguration) -> None: - super().__init__(schema, config) + def __init__( + self, + schema: Schema, + config: WeaviateClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + super().__init__(schema, config, capabilities) + # get definitions of the dlt tables, normalize column names and keep for later use + version_table_ = normalize_table_identifiers(version_table(), schema.naming) + self.version_collection_properties = list(version_table_["columns"].keys()) + loads_table_ = normalize_table_identifiers(loads_table(), schema.naming) + self.loads_collection_properties = list(loads_table_["columns"].keys()) + state_table_ = normalize_table_identifiers( + get_pipeline_state_query_columns(), schema.naming + ) + self.pipeline_state_properties = list(state_table_["columns"].keys()) + self.config: WeaviateClientConfiguration = config - self.db_client = self.create_db_client(config) + self.db_client: weaviate.Client = None self._vectorizer_config = { "vectorizer": config.vectorizer, @@ -451,15 +461,22 @@ def update_stored_schema( return applied_update def _execute_schema_update(self, only_tables: Iterable[str]) -> None: - for table_name in only_tables or self.schema.tables: + for table_name in only_tables or self.schema.tables.keys(): exists, existing_columns = self.get_storage_table(table_name) # TODO: detect columns where vectorization was added or removed and modify it. currently we ignore change of hints - new_columns = self.schema.get_new_table_columns(table_name, existing_columns) + new_columns = self.schema.get_new_table_columns( + table_name, + existing_columns, + case_sensitive=self.capabilities.generates_case_sensitive_identifiers(), + ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: + is_collection_vectorized = self._is_collection_vectorized(table_name) for column in new_columns: - prop = self._make_property_schema(column["name"], column) + prop = self._make_property_schema( + column["name"], column, is_collection_vectorized + ) self.create_class_property(table_name, prop) else: class_schema = self.make_weaviate_class_schema(table_name) @@ -487,6 +504,11 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Loads compressed state from destination storage""" + # normalize properties + p_load_id = self.schema.naming.normalize_identifier("load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") + p_status = self.schema.naming.normalize_identifier("status") # we need to find a stored state that matches a load id that was completed # we retrieve the state in blocks of 10 for this @@ -496,44 +518,45 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: state_records = self.get_records( self.schema.state_table_name, # search by package load id which is guaranteed to increase over time - sort={"path": ["_dlt_load_id"], "order": "desc"}, + sort={"path": [p_dlt_load_id], "order": "desc"}, where={ - "path": ["pipeline_name"], + "path": [p_pipeline_name], "operator": "Equal", "valueString": pipeline_name, }, limit=stepsize, offset=offset, - properties=self.state_properties, + properties=self.pipeline_state_properties, ) offset += stepsize if len(state_records) == 0: return None for state in state_records: - load_id = state["_dlt_load_id"] + load_id = state[p_dlt_load_id] load_records = self.get_records( self.schema.loads_table_name, where={ - "path": ["load_id"], + "path": [p_load_id], "operator": "Equal", "valueString": load_id, }, limit=1, - properties=["load_id", "status"], + properties=[p_load_id, p_status], ) # if there is a load for this state which was successful, return the state if len(load_records): - state["dlt_load_id"] = state.pop("_dlt_load_id") return StateInfo(**state) def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" + p_schema_name = self.schema.naming.normalize_identifier("schema_name") + p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") try: record = self.get_records( self.schema.version_table_name, - sort={"path": ["inserted_at"], "order": "desc"}, + sort={"path": [p_inserted_at], "order": "desc"}, where={ - "path": ["schema_name"], + "path": [p_schema_name], "operator": "Equal", "valueString": self.schema.name, }, @@ -544,11 +567,12 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: return None def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + p_version_hash = self.schema.naming.normalize_identifier("version_hash") try: record = self.get_records( self.schema.version_table_name, where={ - "path": ["version_hash"], + "path": [p_version_hash], "operator": "Equal", "valueString": schema_hash, }, @@ -585,8 +609,13 @@ def get_records( query = query.with_offset(offset) response = query.do() + # if json rpc is used, weaviate does not raise exceptions + if "errors" in response: + raise WeaviateGrpcError(response["errors"]) full_class_name = self.make_qualified_class_name(table_name) records = response["data"]["Get"][full_class_name] + if records is None: + raise DestinationTransientException(f"Could not obtain records for {full_class_name}") return cast(List[Dict[str, Any]], records) def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]: @@ -597,31 +626,39 @@ def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]: } # check if any column requires vectorization - if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT): + if self._is_collection_vectorized(table_name): class_schema.update(self._vectorizer_config) else: class_schema.update(NON_VECTORIZED_CLASS) return class_schema + def _is_collection_vectorized(self, table_name: str) -> bool: + """Tells is any of the columns has vectorize hint set""" + return ( + len(get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT)) > 0 + ) + def _make_properties(self, table_name: str) -> List[Dict[str, Any]]: """Creates a Weaviate properties schema from a table schema. Args: table: The table name for which columns should be converted to properties """ - + is_collection_vectorized = self._is_collection_vectorized(table_name) return [ - self._make_property_schema(column_name, column) + self._make_property_schema(column_name, column, is_collection_vectorized) for column_name, column in self.schema.get_table_columns(table_name).items() ] - def _make_property_schema(self, column_name: str, column: TColumnSchema) -> Dict[str, Any]: + def _make_property_schema( + self, column_name: str, column: TColumnSchema, is_collection_vectorized: bool + ) -> Dict[str, Any]: extra_kv = {} vectorizer_name = self._vectorizer_config["vectorizer"] # x-weaviate-vectorize: (bool) means that this field should be vectorized - if not column.get(VECTORIZE_HINT, False): + if is_collection_vectorized and not column.get(VECTORIZE_HINT, False): # tell weaviate explicitly to not vectorize when column has no vectorize hint extra_kv["moduleConfig"] = { vectorizer_name: { @@ -655,15 +692,20 @@ def restore_file_load(self, file_path: str) -> LoadJob: @wrap_weaviate_error def complete_load(self, load_id: str) -> None: - properties = { - "load_id": load_id, - "schema_name": self.schema.name, - "status": 0, - "inserted_at": pendulum.now().isoformat(), - } + # corresponds to order of the columns in loads_table() + values = [ + load_id, + self.schema.name, + 0, + pendulum.now().isoformat(), + self.schema.version_hash, + ] + assert len(values) == len(self.loads_collection_properties) + properties = {k: v for k, v in zip(self.loads_collection_properties, values)} self.create_object(properties, self.schema.loads_table_name) def __enter__(self) -> "WeaviateClient": + self.db_client = self.create_db_client(self.config) return self def __exit__( @@ -672,18 +714,22 @@ def __exit__( exc_val: BaseException, exc_tb: TracebackType, ) -> None: - pass + if self.db_client: + self.db_client = None def _update_schema_in_storage(self, schema: Schema) -> None: schema_str = json.dumps(schema.to_dict()) - properties = { - "version_hash": schema.stored_version_hash, - "schema_name": schema.name, - "version": schema.version, - "engine_version": schema.ENGINE_VERSION, - "inserted_at": pendulum.now().isoformat(), - "schema": schema_str, - } + # corresponds to order of the columns in version_table() + values = [ + schema.version, + schema.ENGINE_VERSION, + str(pendulum.now().isoformat()), + schema.name, + schema.stored_version_hash, + schema_str, + ] + assert len(values) == len(self.version_collection_properties) + properties = {k: v for k, v in zip(self.version_collection_properties, values)} self.create_object(properties, self.schema.version_table_name) def _from_db_type( diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 74e14f0221..652d13f556 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -36,6 +36,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # the procedure below will split the inserts into max_query_length // 2 packs with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: header = f.readline() + # format and casefold header + header = self._sql_client.capabilities.casefold_identifier(header).format( + qualified_table_name + ) writer_type = self._sql_client.capabilities.insert_values_writer_type if writer_type == "default": sep = "," @@ -70,7 +74,7 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # Chunk by max_rows - 1 for simplicity because one more row may be added for chunk in chunks(values_rows, max_rows - 1): processed += len(chunk) - insert_sql.append(header.format(qualified_table_name)) + insert_sql.append(header) if writer_type == "default": insert_sql.append(values_mark) if processed == len_rows: @@ -82,11 +86,9 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st else: # otherwise write all content in a single INSERT INTO if writer_type == "default": - insert_sql.extend( - [header.format(qualified_table_name), values_mark, content + until_nl] - ) + insert_sql.extend([header, values_mark, content + until_nl]) elif writer_type == "select_union": - insert_sql.extend([header.format(qualified_table_name), content + until_nl]) + insert_sql.extend([header, content + until_nl]) # actually this may be empty if we were able to read a full file into content if not is_eof: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index ac3636db2b..f67784f77b 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -1,44 +1,41 @@ import os from abc import abstractmethod import base64 -import binascii import contextlib from copy import copy -import datetime # noqa: 251 from types import TracebackType from typing import ( Any, ClassVar, List, - NamedTuple, Optional, Sequence, Tuple, Type, Iterable, Iterator, - ContextManager, - cast, ) import zlib import re -from dlt.common import logger +from dlt.common import pendulum, logger from dlt.common.json import json -from dlt.common.pendulum import pendulum -from dlt.common.data_types import TDataType from dlt.common.schema.typing import ( COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, - TWriteDisposition, TTableFormat, ) +from dlt.common.schema.utils import ( + get_inherited_table_hint, + loads_table, + normalize_table_identifiers, + version_table, +) from dlt.common.storages import FileStorage from dlt.common.storages.load_package import LoadJobInfo from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables -from dlt.common.schema.typing import LOADS_TABLE_NAME, VERSION_TABLE_NAME from dlt.common.destination.reference import ( StateInfo, StorageSchemaInfo, @@ -59,6 +56,11 @@ from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob from dlt.destinations.typing import TNativeConn from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.utils import ( + get_pipeline_state_query_columns, + info_schema_null_to_bool, + verify_sql_job_client_schema, +) # this should suffice for now DDL_COMMANDS = ["ALTER", "CREATE", "DROP"] @@ -78,7 +80,7 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: sql_client.execute_many(self._split_fragments(sql)) # if we detect ddl transactions, only execute transaction if supported by client elif ( - not self._string_containts_ddl_queries(sql) + not self._string_contains_ddl_queries(sql) or sql_client.capabilities.supports_ddl_transactions ): # with sql_client.begin_transaction(): @@ -95,7 +97,7 @@ def exception(self) -> str: # this part of code should be never reached raise NotImplementedError() - def _string_containts_ddl_queries(self, sql: str) -> bool: + def _string_contains_ddl_queries(self, sql: str) -> bool: for cmd in DDL_COMMANDS: if re.search(cmd, sql, re.IGNORECASE): return True @@ -133,22 +135,8 @@ def state(self) -> TLoadJobState: class SqlJobClientBase(JobClientBase, WithStateSync): - _VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[Tuple[str, ...]] = ( - "version_hash", - "schema_name", - "version", - "engine_version", - "inserted_at", - "schema", - ) - _STATE_TABLE_COLUMNS: ClassVar[Tuple[str, ...]] = ( - "version", - "engine_version", - "pipeline_name", - "state", - "created_at", - "_dlt_load_id", - ) + INFO_TABLES_QUERY_THRESHOLD: ClassVar[int] = 1000 + """Fallback to querying all tables in the information schema if checking more than threshold""" def __init__( self, @@ -156,14 +144,22 @@ def __init__( config: DestinationClientConfiguration, sql_client: SqlClientBase[TNativeConn], ) -> None: + # get definitions of the dlt tables, normalize column names and keep for later use + version_table_ = normalize_table_identifiers(version_table(), schema.naming) self.version_table_schema_columns = ", ".join( - sql_client.escape_column_name(col) for col in self._VERSION_TABLE_SCHEMA_COLUMNS + sql_client.escape_column_name(col) for col in version_table_["columns"] + ) + loads_table_ = normalize_table_identifiers(loads_table(), schema.naming) + self.loads_table_schema_columns = ", ".join( + sql_client.escape_column_name(col) for col in loads_table_["columns"] + ) + state_table_ = normalize_table_identifiers( + get_pipeline_state_query_columns(), schema.naming ) self.state_table_columns = ", ".join( - sql_client.escape_column_name(col) for col in self._STATE_TABLE_COLUMNS + sql_client.escape_column_name(col) for col in state_table_["columns"] ) - - super().__init__(schema, config) + super().__init__(schema, config, sql_client.capabilities) self.sql_client = sql_client assert isinstance(config, DestinationClientDwhConfiguration) self.config: DestinationClientDwhConfiguration = config @@ -250,10 +246,12 @@ def _create_replace_followup_jobs( def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], - table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" - jobs = super().create_table_chain_completed_followup_jobs(table_chain, table_chain_jobs) + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) write_disposition = table_chain[0]["write_disposition"] if write_disposition == "append": jobs.extend(self._create_append_followup_jobs(table_chain)) @@ -265,6 +263,7 @@ def create_table_chain_completed_followup_jobs( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" + if SqlLoadJob.is_sql_job(file_path): # execute sql load job return SqlLoadJob(file_path, self.sql_client) @@ -290,8 +289,7 @@ def complete_load(self, load_id: str) -> None: name = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) now_ts = pendulum.now() self.sql_client.execute_sql( - f"INSERT INTO {name}(load_id, schema_name, status, inserted_at, schema_version_hash)" - " VALUES(%s, %s, %s, %s, %s);", + f"INSERT INTO {name}({self.loads_table_schema_columns}) VALUES(%s, %s, %s, %s, %s);", load_id, self.schema.name, 0, @@ -308,54 +306,96 @@ def __exit__( ) -> None: self.sql_client.close_connection() - def _get_storage_table_query_columns(self) -> List[str]: - """Column names used when querying table from information schema. - Override for databases that use different namings. - """ - fields = ["column_name", "data_type", "is_nullable"] - if self.capabilities.schema_supports_numeric_precision: - fields += ["numeric_precision", "numeric_scale"] - return fields + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + """Uses INFORMATION_SCHEMA to retrieve table and column information for tables in `table_names` iterator. + Table names should be normalized according to naming convention and will be further converted to desired casing + in order to (in most cases) create case-insensitive name suitable for search in information schema. - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - def _null_to_bool(v: str) -> bool: - if v == "NO": - return False - elif v == "YES": - return True - raise ValueError(v) + The column names are returned as in information schema. To match those with columns in existing table, you'll need to use + `schema.get_new_table_columns` method and pass the correct casing. Most of the casing function are irreversible so it is not + possible to convert identifiers into INFORMATION SCHEMA back into case sensitive dlt schema. + """ + table_names = list(table_names) + if len(table_names) == 0: + # empty generator + return + # get schema search components + catalog_name, schema_name, folded_table_names = ( + self.sql_client._get_information_schema_components(*table_names) + ) + # create table name conversion lookup table + name_lookup = { + folded_name: name for folded_name, name in zip(folded_table_names, table_names) + } + # this should never happen: we verify schema for name collisions before loading + assert len(name_lookup) == len(table_names), ( + f"One or more of tables in {table_names} after applying" + f" {self.capabilities.casefold_identifier} produced a name collision." + ) + # if we have more tables to lookup than a threshold, we prefer to filter them in code + if ( + len(name_lookup) > self.INFO_TABLES_QUERY_THRESHOLD + or len(",".join(folded_table_names)) > self.capabilities.max_query_length / 2 + ): + logger.info( + "Fallback to query all columns from INFORMATION_SCHEMA due to limited query length" + " or table threshold" + ) + folded_table_names = [] - fields = self._get_storage_table_query_columns() - db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( - ".", 3 + query, db_params = self._get_info_schema_columns_query( + catalog_name, schema_name, folded_table_names ) - query = f""" -SELECT {",".join(fields)} - FROM INFORMATION_SCHEMA.COLUMNS -WHERE """ - if len(db_params) == 3: - query += "table_catalog = %s AND " - query += "table_schema = %s AND table_name = %s ORDER BY ordinal_position;" rows = self.sql_client.execute_sql(query, *db_params) - - # if no rows we assume that table does not exist - schema_table: TTableSchemaColumns = {} - if len(rows) == 0: - # TODO: additionally check if table exists - return False, schema_table - # TODO: pull more data to infer indexes, PK and uniques attributes/constraints + prev_table: str = None + storage_columns: TTableSchemaColumns = None for c in rows: + # if we are selecting all tables this is expected + if not folded_table_names and c[0] not in name_lookup: + continue + # make sure that new table is known + assert ( + c[0] in name_lookup + ), f"Table name {c[0]} not in expected tables {name_lookup.keys()}" + table_name = name_lookup[c[0]] + if prev_table != table_name: + # yield what we have + if storage_columns: + yield (prev_table, storage_columns) + # we have new table + storage_columns = {} + prev_table = table_name + # remove from table_names + table_names.remove(prev_table) + # add columns + col_name = c[1] numeric_precision = ( - c[3] if self.capabilities.schema_supports_numeric_precision else None + c[4] if self.capabilities.schema_supports_numeric_precision else None ) - numeric_scale = c[4] if self.capabilities.schema_supports_numeric_precision else None + numeric_scale = c[5] if self.capabilities.schema_supports_numeric_precision else None + schema_c: TColumnSchemaBase = { - "name": c[0], - "nullable": _null_to_bool(c[2]), - **self._from_db_type(c[1], numeric_precision, numeric_scale), + "name": col_name, + "nullable": info_schema_null_to_bool(c[3]), + **self._from_db_type(c[2], numeric_precision, numeric_scale), } - schema_table[c[0]] = schema_c # type: ignore - return True, schema_table + storage_columns[col_name] = schema_c # type: ignore + # yield last table, it must have at least one column or we had no rows + if storage_columns: + yield (prev_table, storage_columns) + # if no columns we assume that table does not exist + for table_name in table_names: + yield (table_name, {}) + + def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: + """Uses get_storage_tables to get single `table_name` schema. + + Returns (True, ...) if table exists and (False, {}) when not + """ + storage_table = list(self.get_storage_tables([table_name]))[0] + return len(storage_table[1]) > 0, storage_table[1] @abstractmethod def _from_db_type( @@ -365,31 +405,91 @@ def _from_db_type( def get_stored_schema(self) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) + c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at") query = ( - f"SELECT {self.version_table_schema_columns} FROM {name} WHERE schema_name = %s ORDER" - " BY inserted_at DESC;" + f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s" + f" ORDER BY {c_inserted_at} DESC;" ) return self._row_to_schema_info(query, self.schema.name) def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) loads_table = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) + c_load_id, c_dlt_load_id, c_pipeline_name, c_status = self._norm_and_escape_columns( + "load_id", "_dlt_load_id", "pipeline_name", "status" + ) query = ( f"SELECT {self.state_table_columns} FROM {state_table} AS s JOIN {loads_table} AS l ON" - " l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY" - " l.load_id DESC" + f" l.{c_load_id} = s.{c_dlt_load_id} WHERE {c_pipeline_name} = %s AND l.{c_status} = 0" + f" ORDER BY {c_load_id} DESC" ) with self.sql_client.execute_query(query, pipeline_name) as cur: row = cur.fetchone() if not row: return None - return StateInfo(row[0], row[1], row[2], row[3], pendulum.instance(row[4])) + # NOTE: we request order of columns in SELECT statement which corresponds to StateInfo + return StateInfo( + version=row[0], + engine_version=row[1], + pipeline_name=row[2], + state=row[3], + created_at=pendulum.instance(row[4]), + _dlt_load_id=row[5], + ) + + def _norm_and_escape_columns(self, *columns: str) -> Iterator[str]: + return map( + self.sql_client.escape_column_name, map(self.schema.naming.normalize_path, columns) + ) def get_stored_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo: - name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - query = f"SELECT {self.version_table_schema_columns} FROM {name} WHERE version_hash = %s;" + table_name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) + (c_version_hash,) = self._norm_and_escape_columns("version_hash") + query = ( + f"SELECT {self.version_table_schema_columns} FROM {table_name} WHERE" + f" {c_version_hash} = %s;" + ) return self._row_to_schema_info(query, version_hash) + def _get_info_schema_columns_query( + self, catalog_name: Optional[str], schema_name: str, folded_table_names: List[str] + ) -> Tuple[str, List[Any]]: + """Generates SQL to query INFORMATION_SCHEMA.COLUMNS for a set of tables in `folded_table_names`. Input identifiers must be already + in a form that can be passed to a query via db_params. `catalogue_name` and `folded_tableS_name` is optional and when None, the part of query selecting it + is skipped. + + Returns: query and list of db_params tuple + """ + query = f""" +SELECT {",".join(self._get_storage_table_query_columns())} + FROM INFORMATION_SCHEMA.COLUMNS +WHERE """ + + db_params = [] + if catalog_name: + db_params.append(catalog_name) + query += "table_catalog = %s AND " + db_params.append(schema_name) + select_tables_clause = "" + # look for particular tables only when requested, otherwise return the full schema + if folded_table_names: + db_params = db_params + folded_table_names + # placeholder for each table + table_placeholders = ",".join(["%s"] * len(folded_table_names)) + select_tables_clause = f"AND table_name IN ({table_placeholders})" + query += f"table_schema = %s {select_tables_clause} ORDER BY table_name, ordinal_position;" + + return query, db_params + + def _get_storage_table_query_columns(self) -> List[str]: + """Column names used when querying table from information schema. + Override for databases that use different namings. + """ + fields = ["table_name", "column_name", "data_type", "is_nullable"] + if self.capabilities.schema_supports_numeric_precision: + fields += ["numeric_precision", "numeric_scale"] + return fields + def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables: sql_scripts, schema_update = self._build_schema_update_sql(only_tables) # Stay within max query size when doing DDL. @@ -416,12 +516,16 @@ def _build_schema_update_sql( """ sql_updates = [] schema_update: TSchemaTables = {} - for table_name in only_tables or self.schema.tables: - exists, storage_table = self.get_storage_table(table_name) - new_columns = self._create_table_update(table_name, storage_table) + for table_name, storage_columns in self.get_storage_tables( + only_tables or self.schema.tables.keys() + ): + # this will skip incomplete columns + new_columns = self._create_table_update(table_name, storage_columns) if len(new_columns) > 0: # build and add sql to execute - sql_statements = self._get_table_update_sql(table_name, new_columns, exists) + sql_statements = self._get_table_update_sql( + table_name, new_columns, len(storage_columns) > 0 + ) for sql in sql_statements: if not sql.endswith(";"): sql += ";" @@ -472,7 +576,7 @@ def _get_table_update_sql( for hint in COLUMN_HINTS: if any(c.get(hint, False) is True for c in new_columns): hint_columns = [ - self.capabilities.escape_identifier(c["name"]) + self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get(hint, False) ] @@ -501,8 +605,12 @@ def _gen_not_null(v: bool) -> str: def _create_table_update( self, table_name: str, storage_columns: TTableSchemaColumns ) -> Sequence[TColumnSchema]: - # compare table with stored schema and produce delta - updates = self.schema.get_new_table_columns(table_name, storage_columns) + """Compares storage columns with schema table and produce delta columns difference""" + updates = self.schema.get_new_table_columns( + table_name, + storage_columns, + case_sensitive=self.capabilities.generates_case_sensitive_identifiers(), + ) logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") return updates @@ -526,16 +634,17 @@ def _row_to_schema_info(self, query: str, *args: Any) -> StorageSchemaInfo: pass # make utc datetime - inserted_at = pendulum.instance(row[4]) + inserted_at = pendulum.instance(row[2]) - return StorageSchemaInfo(row[0], row[1], row[2], row[3], inserted_at, schema_str) + return StorageSchemaInfo(row[4], row[3], row[0], row[1], inserted_at, schema_str) def _delete_schema_in_storage(self, schema: Schema) -> None: """ Delete all stored versions with the same name as given schema """ name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - self.sql_client.execute_sql(f"DELETE FROM {name} WHERE schema_name = %s;", schema.name) + (c_schema_name,) = self._norm_and_escape_columns("schema_name") + self.sql_client.execute_sql(f"DELETE FROM {name} WHERE {c_schema_name} = %s;", schema.name) def _update_schema_in_storage(self, schema: Schema) -> None: # get schema string or zip @@ -554,14 +663,42 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: self.sql_client.execute_sql( f"INSERT INTO {name}({self.version_table_schema_columns}) VALUES (%s, %s, %s, %s, %s," " %s);", - schema.stored_version_hash, - schema.name, schema.version, schema.ENGINE_VERSION, now_ts, + schema.name, + schema.stored_version_hash, schema_str, ) + def _verify_schema(self) -> None: + super()._verify_schema() + if exceptions := verify_sql_job_client_schema(self.schema, warnings=True): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + + def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: + """Sets query tags in sql_client for a job in package `load_id`, starting for a particular `table`""" + from dlt.common.pipeline import current_pipeline + + pipeline = current_pipeline() + pipeline_name = pipeline.pipeline_name if pipeline else "" + self.sql_client.set_query_tags( + { + "source": self.schema.name, + "resource": ( + get_inherited_table_hint( + self.schema._schema_tables, table["name"], "resource", allow_none=True + ) + or "" + ), + "table": table["name"], + "load_id": load_id, + "pipeline_name": pipeline_name, + } + ) + class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): in_staging_mode: bool = False @@ -569,7 +706,7 @@ class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): @contextlib.contextmanager def with_staging_dataset(self) -> Iterator["SqlJobClientBase"]: try: - with self.sql_client.with_staging_dataset(True): + with self.sql_client.with_staging_dataset(): self.in_staging_mode = True yield self finally: diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index a4e4b998af..9a8f7277b7 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -5,6 +5,7 @@ from dlt.common.json import json from dlt.common.destination.reference import NewLoadJob, FollowupJob, TLoadJobState, LoadJob +from dlt.common.storages.load_package import commit_load_package_state from dlt.common.schema import Schema, TTableSchema from dlt.common.storages import FileStorage from dlt.common.typing import TDataItems @@ -14,8 +15,6 @@ TDestinationCallable, ) -from dlt.pipeline.current import commit_load_package_state - class EmptyLoadJobWithoutFollowup(LoadJob): def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index c3f5d3cdf4..66248e6c3b 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -7,6 +7,7 @@ Any, ClassVar, ContextManager, + Dict, Generic, Iterator, Optional, @@ -16,6 +17,7 @@ AnyStr, List, Generator, + TypedDict, ) from dlt.common.typing import TFun @@ -37,15 +39,45 @@ from dlt.common.destination.reference import SupportsDataAccess +from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction + + +class TJobQueryTags(TypedDict): + """Applied to sql client when a job using it starts. Using to tag queries""" + + source: str + resource: str + table: str + load_id: str + pipeline_name: str + + class SqlClientBase(SupportsDataAccess, ABC, Generic[TNativeConn]): dbapi: ClassVar[DBApi] = None - capabilities: ClassVar[DestinationCapabilitiesContext] = None - def __init__(self, database_name: str, dataset_name: str) -> None: + database_name: Optional[str] + """Database or catalog name, optional""" + dataset_name: str + """Normalized dataset name""" + staging_dataset_name: str + """Normalized staging dataset name""" + capabilities: DestinationCapabilitiesContext + """Instance of adjusted destination capabilities""" + + def __init__( + self, + database_name: str, + dataset_name: str, + staging_dataset_name: str, + capabilities: DestinationCapabilitiesContext, + ) -> None: if not dataset_name: raise ValueError(dataset_name) self.dataset_name = dataset_name + self.staging_dataset_name = staging_dataset_name self.database_name = database_name + self.capabilities = capabilities + self._query_tags: TJobQueryTags = None @abstractmethod def open_connection(self) -> TNativeConn: @@ -84,9 +116,12 @@ def has_dataset(self) -> bool: SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE """ - db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) - if len(db_params) == 2: + catalog_name, schema_name, _ = self._get_information_schema_components() + db_params: List[str] = [] + if catalog_name is not None: query += " catalog_name = %s AND " + db_params.append(catalog_name) + db_params.append(schema_name) query += "schema_name = %s" rows = self.execute_sql(query, *db_params) return len(rows) > 0 @@ -102,6 +137,7 @@ def truncate_tables(self, *tables: str) -> None: self.execute_many(statements) def drop_tables(self, *tables: str) -> None: + """Drops a set of tables if they exist""" if not tables: return statements = [ @@ -146,16 +182,50 @@ def execute_many( ret.append(result) return ret - @abstractmethod - def fully_qualified_dataset_name(self, escape: bool = True) -> str: - pass + def catalog_name(self, escape: bool = True) -> Optional[str]: + # default is no catalogue component of the name, which typically means that + # connection is scoped to a current database + return None + + def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = False) -> str: + if staging: + with self.with_staging_dataset(): + path = self.make_qualified_table_name_path(None, escape=escape) + else: + path = self.make_qualified_table_name_path(None, escape=escape) + return ".".join(path) def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: + return ".".join(self.make_qualified_table_name_path(table_name, escape=escape)) + + def make_qualified_table_name_path( + self, table_name: Optional[str], escape: bool = True + ) -> List[str]: + """Returns a list with path components leading from catalog to table_name. + Used to construct fully qualified names. `table_name` is optional. + """ + path: List[str] = [] + if catalog_name := self.catalog_name(escape=escape): + path.append(catalog_name) + dataset_name = self.capabilities.casefold_identifier(self.dataset_name) if escape: - table_name = self.capabilities.escape_identifier(table_name) - return f"{self.fully_qualified_dataset_name(escape=escape)}.{table_name}" + dataset_name = self.capabilities.escape_identifier(dataset_name) + path.append(dataset_name) + if table_name: + table_name = self.capabilities.casefold_identifier(table_name) + if escape: + table_name = self.capabilities.escape_identifier(table_name) + path.append(table_name) + return path + + def get_qualified_table_names(self, table_name: str, escape: bool = True) -> Tuple[str, str]: + """Returns qualified names for table and corresponding staging table as tuple.""" + with self.with_staging_dataset(): + staging_table_name = self.make_qualified_table_name(table_name, escape) + return self.make_qualified_table_name(table_name, escape), staging_table_name def escape_column_name(self, column_name: str, escape: bool = True) -> str: + column_name = self.capabilities.casefold_identifier(column_name) if escape: return self.capabilities.escape_identifier(column_name) return column_name @@ -173,13 +243,12 @@ def with_alternative_dataset_name( # restore previous dataset name self.dataset_name = current_dataset_name - def with_staging_dataset( - self, staging: bool = False - ) -> ContextManager["SqlClientBase[TNativeConn]"]: - dataset_name = self.dataset_name - if staging: - dataset_name = SqlClientBase.make_staging_dataset_name(dataset_name) - return self.with_alternative_dataset_name(dataset_name) + def with_staging_dataset(self) -> ContextManager["SqlClientBase[TNativeConn]"]: + return self.with_alternative_dataset_name(self.staging_dataset_name) + + def set_query_tags(self, tags: TJobQueryTags) -> None: + """Sets current schema (source), resource, load_id and table name when a job starts""" + self._query_tags = tags def _ensure_native_conn(self) -> None: if not self.native_connection: @@ -196,9 +265,17 @@ def is_dbapi_exception(ex: Exception) -> bool: mro = type.mro(type(ex)) return any(t.__name__ in ("DatabaseError", "DataError") for t in mro) - @staticmethod - def make_staging_dataset_name(dataset_name: str) -> str: - return dataset_name + "_staging" + def _get_information_schema_components(self, *tables: str) -> Tuple[str, str, List[str]]: + """Gets catalog name, schema name and name of the tables in format that can be directly + used to query INFORMATION_SCHEMA. catalog name is optional: in that case None is + returned in the first element of the tuple. + """ + schema_path = self.make_qualified_table_name_path(None, escape=False) + return ( + self.catalog_name(escape=False), + schema_path[-1], + [self.make_qualified_table_name_path(table, escape=False)[-1] for table in tables], + ) # # generate sql statements @@ -252,6 +329,11 @@ def _get_columns(self) -> List[str]: return [c[0] for c in self.native_cursor.description] def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: + """Fetches results as data frame in full or in specified chunks. + + May use native pandas/arrow reader if available. Depending on + the native implementation chunk size may vary. + """ from dlt.common.libs.pandas_sql import _wrap_result columns = self._get_columns() diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 4f8e29ae0d..e9aedf6aca 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional +from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional, Callable, Union import yaml from dlt.common.logger import pretty_format_exception @@ -6,6 +6,7 @@ from dlt.common.schema.typing import ( TTableSchema, TSortOrder, + TColumnProp, ) from dlt.common.schema.utils import ( get_columns_names_with_prop, @@ -16,12 +17,12 @@ DEFAULT_MERGE_STRATEGY, ) from dlt.common.storages.load_storage import ParsedLoadJobFileName +from dlt.common.storages.load_package import load_package as current_load_package from dlt.common.utils import uniq_id from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase -from dlt.pipeline.current import load_package as current_load_package class SqlJobParams(TypedDict, total=False): @@ -95,7 +96,7 @@ def _generate_clone_sql( """Drop and clone the table for supported destinations""" sql: List[str] = [] for table in table_chain: - with sql_client.with_staging_dataset(staging=True): + with sql_client.with_staging_dataset(): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) sql.append(f"DROP TABLE IF EXISTS {table_name};") @@ -112,12 +113,12 @@ def _generate_insert_sql( ) -> List[str]: sql: List[str] = [] for table in table_chain: - with sql_client.with_staging_dataset(staging=True): + with sql_client.with_staging_dataset(): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) columns = ", ".join( map( - sql_client.capabilities.escape_identifier, + sql_client.escape_column_name, get_columns_names_with_prop(table, "name"), ) ) @@ -158,6 +159,8 @@ def generate_sql( # type: ignore[return] merge_strategy = table_chain[0].get("x-merge-strategy", DEFAULT_MERGE_STRATEGY) if merge_strategy == "delete-insert": return cls.gen_merge_sql(table_chain, sql_client) + elif merge_strategy == "upsert": + return cls.gen_upsert_sql(table_chain, sql_client) elif merge_strategy == "scd2": return cls.gen_scd2_sql(table_chain, sql_client) @@ -342,6 +345,107 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: """ return f"CREATE TEMP TABLE {temp_table_name} AS {select_sql};" + @classmethod + def gen_update_table_prefix(cls, table_name: str) -> str: + return f"UPDATE {table_name} SET" + + @classmethod + def requires_temp_table_for_delete(cls) -> bool: + """Whether a temporary table is required to delete records. + + Must be `True` for destinations that don't support correlated subqueries. + """ + return False + + @classmethod + def _escape_list(cls, list_: List[str], escape_id: Callable[[str], str]) -> List[str]: + return list(map(escape_id, list_)) + + @classmethod + def _get_hard_delete_col_and_cond( + cls, + table: TTableSchema, + escape_id: Callable[[str], str], + escape_lit: Callable[[Any], Any], + invert: bool = False, + ) -> Tuple[Optional[str], Optional[str]]: + """Returns tuple of hard delete column name and SQL condition statement. + + Returns tuple of `None` values if no column has `hard_delete` hint. + Condition statement can be used to filter deleted records. + Set `invert=True` to filter non-deleted records instead. + """ + + col = get_first_column_name_with_prop(table, "hard_delete") + if col is None: + return (None, None) + cond = f"{escape_id(col)} IS NOT NULL" + if invert: + cond = f"{escape_id(col)} IS NULL" + if table["columns"][col]["data_type"] == "bool": + if invert: + cond += f" OR {escape_id(col)} = {escape_lit(False)}" + else: + cond = f"{escape_id(col)} = {escape_lit(True)}" + return (col, cond) + + @classmethod + def _get_unique_col( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + table: TTableSchema, + ) -> str: + """Returns name of first column in `table` with `unique` property. + + Raises `MergeDispositionException` if no such column exists. + """ + return cls._get_prop_col_or_raise( + table, + "unique", + MergeDispositionException( + sql_client.fully_qualified_dataset_name(), + sql_client.fully_qualified_dataset_name(staging=True), + [t["name"] for t in table_chain], + f"No `unique` column (e.g. `_dlt_id`) in table `{table['name']}`.", + ), + ) + + @classmethod + def _get_root_key_col( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + table: TTableSchema, + ) -> str: + """Returns name of first column in `table` with `root_key` property. + + Raises `MergeDispositionException` if no such column exists. + """ + return cls._get_prop_col_or_raise( + table, + "root_key", + MergeDispositionException( + sql_client.fully_qualified_dataset_name(), + sql_client.fully_qualified_dataset_name(staging=True), + [t["name"] for t in table_chain], + f"No `root_key` column (e.g. `_dlt_root_id`) in table `{table['name']}`.", + ), + ) + + @classmethod + def _get_prop_col_or_raise( + cls, table: TTableSchema, prop: Union[TColumnProp, str], exception: Exception + ) -> str: + """Returns name of first column in `table` with `prop` property. + + Raises `exception` if no such column exists. + """ + col = get_first_column_name_with_prop(table, prop) + if col is None: + raise exception + return col + @classmethod def gen_merge_sql( cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] @@ -361,30 +465,24 @@ def gen_merge_sql( sql: List[str] = [] root_table = table_chain[0] - escape_id = sql_client.capabilities.escape_identifier + escape_column_id = sql_client.escape_column_name escape_lit = sql_client.capabilities.escape_literal - if escape_id is None: - escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier if escape_lit is None: escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal # get top level table full identifiers - root_table_name = sql_client.make_qualified_table_name(root_table["name"]) - with sql_client.with_staging_dataset(staging=True): - staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) + root_table_name, staging_root_table_name = sql_client.get_qualified_table_names( + root_table["name"] + ) # get merge and primary keys from top level - primary_keys = list( - map( - escape_id, - get_columns_names_with_prop(root_table, "primary_key"), - ) + primary_keys = cls._escape_list( + get_columns_names_with_prop(root_table, "primary_key"), + escape_column_id, ) - merge_keys = list( - map( - escape_id, - get_columns_names_with_prop(root_table, "merge_key"), - ) + merge_keys = cls._escape_list( + get_columns_names_with_prop(root_table, "merge_key"), + escape_column_id, ) # if we do not have any merge keys to select from, we will fall back to a staged append, i.E. @@ -409,18 +507,9 @@ def gen_merge_sql( root_table_name, staging_root_table_name, key_clauses, for_delete=False ) # use unique hint to create temp table with all identifiers to delete - unique_columns = get_columns_names_with_prop(root_table, "unique") - if not unique_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - "There is no unique column (ie _dlt_id) in top table" - f" {root_table['name']} so it is not possible to link child tables to it.", - ) - # get first unique column - unique_column = escape_id(unique_columns[0]) - # create temp table with unique identifier + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, root_table) + ) create_delete_temp_table_sql, delete_temp_table_name = ( cls.gen_delete_temp_table_sql( root_table["name"], unique_column, key_table_clauses, sql_client @@ -432,39 +521,29 @@ def gen_merge_sql( # but uses temporary views instead for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) - root_key_columns = get_columns_names_with_prop(table, "root_key") - if not root_key_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - "There is no root foreign key (ie _dlt_root_id) in child table" - f" {table['name']} so it is not possible to refer to top level table" - f" {root_table['name']} unique column {unique_column}", - ) - root_key_column = escape_id(root_key_columns[0]) + root_key_column = escape_column_id( + cls._get_root_key_col(table_chain, sql_client, table) + ) sql.append( cls.gen_delete_from_sql( table_name, root_key_column, delete_temp_table_name, unique_column ) ) - # delete from top table now that child tables have been prcessed + # delete from top table now that child tables have been processed sql.append( cls.gen_delete_from_sql( root_table_name, unique_column, delete_temp_table_name, unique_column ) ) - # get name of column with hard_delete hint, if specified - not_deleted_cond: str = None - hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") - if hard_delete_col is not None: - # any value indicates a delete for non-boolean columns - not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" - if root_table["columns"][hard_delete_col]["data_type"] == "bool": - # only True values indicate a delete for boolean columns - not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" + # get hard delete information + hard_delete_col, not_deleted_cond = cls._get_hard_delete_col_and_cond( + root_table, + escape_column_id, + escape_lit, + invert=True, + ) # get dedup sort information dedup_sort = get_dedup_sort_tuple(root_table) @@ -472,7 +551,8 @@ def gen_merge_sql( insert_temp_table_name: str = None if len(table_chain) > 1: if len(primary_keys) > 0 or hard_delete_col is not None: - condition_columns = [hard_delete_col] if not_deleted_cond is not None else None + # condition_columns = [hard_delete_col] if not_deleted_cond is not None else None + condition_columns = None if hard_delete_col is None else [hard_delete_col] ( create_insert_temp_table_sql, insert_temp_table_name, @@ -490,9 +570,7 @@ def gen_merge_sql( # insert from staging to dataset for table in table_chain: - table_name = sql_client.make_qualified_table_name(table["name"]) - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1" if (len(primary_keys) > 0 and len(table_chain) > 1) or ( @@ -503,7 +581,7 @@ def gen_merge_sql( uniq_column = unique_column if table.get("parent") is None else root_key_column insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" - columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) + columns = list(map(escape_column_id, get_columns_names_with_prop(table, "name"))) col_str = ", ".join(columns) select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" if len(primary_keys) > 0 and len(table_chain) == 1: @@ -515,6 +593,93 @@ def gen_merge_sql( sql.append(f"INSERT INTO {table_name}({col_str}) {select_sql};") return sql + @classmethod + def gen_upsert_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: + sql: List[str] = [] + root_table = table_chain[0] + root_table_name, staging_root_table_name = sql_client.get_qualified_table_names( + root_table["name"] + ) + escape_column_id = sql_client.escape_column_name + escape_lit = sql_client.capabilities.escape_literal + if escape_lit is None: + escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal + + # process table hints + primary_keys = cls._escape_list( + get_columns_names_with_prop(root_table, "primary_key"), + escape_column_id, + ) + hard_delete_col, deleted_cond = cls._get_hard_delete_col_and_cond( + root_table, + escape_column_id, + escape_lit, + ) + + # generate merge statement for root table + on_str = " AND ".join([f"d.{c} = s.{c}" for c in primary_keys]) + root_table_column_names = list(map(escape_column_id, root_table["columns"].keys())) + update_str = ", ".join([c + " = " + "s." + c for c in root_table_column_names]) + col_str = ", ".join(["{alias}" + c for c in root_table_column_names]) + delete_str = ( + "" if hard_delete_col is None else f"WHEN MATCHED AND s.{deleted_cond} THEN DELETE" + ) + + sql.append(f""" + MERGE INTO {root_table_name} d USING {staging_root_table_name} s + ON {on_str} + {delete_str} + WHEN MATCHED + THEN UPDATE SET {update_str} + WHEN NOT MATCHED + THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")}); + """) + + # generate statements for child tables if they exist + child_tables = table_chain[1:] + if child_tables: + root_unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, root_table) + ) + for table in child_tables: + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, table) + ) + root_key_column = escape_column_id( + cls._get_root_key_col(table_chain, sql_client, table) + ) + table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) + + # delete records for elements no longer in the list + sql.append(f""" + DELETE FROM {table_name} + WHERE {root_key_column} IN (SELECT {root_unique_column} FROM {staging_root_table_name}) + AND {unique_column} NOT IN (SELECT {unique_column} FROM {staging_table_name}); + """) + + # insert records for new elements in the list + col_str = ", ".join(["{alias}" + escape_column_id(c) for c in table["columns"]]) + sql.append(f""" + MERGE INTO {table_name} d USING {staging_table_name} s + ON d.{unique_column} = s.{unique_column} + WHEN NOT MATCHED + THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")}); + """) + + # delete hard-deleted records + if hard_delete_col is not None: + sql.append(f""" + DELETE FROM {table_name} + WHERE {root_key_column} IN ( + SELECT {root_unique_column} + FROM {staging_root_table_name} + WHERE {deleted_cond} + ); + """) + return sql + @classmethod def gen_scd2_sql( cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] @@ -528,15 +693,17 @@ def gen_scd2_sql( """ sql: List[str] = [] root_table = table_chain[0] - root_table_name = sql_client.make_qualified_table_name(root_table["name"]) - with sql_client.with_staging_dataset(staging=True): - staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) + root_table_name, staging_root_table_name = sql_client.get_qualified_table_names( + root_table["name"] + ) # get column names caps = sql_client.capabilities - escape_id = caps.escape_identifier - from_, to = list(map(escape_id, get_validity_column_names(root_table))) # validity columns - hash_ = escape_id( + escape_column_id = sql_client.escape_column_name + from_, to = list( + map(escape_column_id, get_validity_column_names(root_table)) + ) # validity columns + hash_ = escape_column_id( get_first_column_name_with_prop(root_table, "x-row-version") ) # row hash column @@ -568,7 +735,7 @@ def gen_scd2_sql( """) # insert new active records in root table - columns = map(escape_id, list(root_table["columns"].keys())) + columns = map(escape_column_id, list(root_table["columns"].keys())) col_str = ", ".join([c for c in columns if c not in (from_, to)]) sql.append(f""" INSERT INTO {root_table_name} ({col_str}, {from_}, {to}) @@ -580,27 +747,15 @@ def gen_scd2_sql( # insert list elements for new active records in child tables child_tables = table_chain[1:] if child_tables: - unique_column: str = None - # use unique hint to create temp table with all identifiers to delete - unique_columns = get_columns_names_with_prop(root_table, "unique") - if not unique_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so" - " it is not possible to link child tables to it.", - ) - # get first unique column - unique_column = escape_id(unique_columns[0]) + unique_column = escape_column_id( + cls._get_unique_col(table_chain, sql_client, root_table) + ) # TODO: - based on deterministic child hashes (OK) # - if row hash changes all is right # - if it does not we only capture new records, while we should replace existing with those in stage # - this write disposition is way more similar to regular merge (how root tables are handled is different, other tables handled same) for table in child_tables: - table_name = sql_client.make_qualified_table_name(table["name"]) - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name, staging_table_name = sql_client.get_qualified_table_names(table["name"]) sql.append(f""" INSERT INTO {table_name} SELECT * @@ -609,15 +764,3 @@ def gen_scd2_sql( """) return sql - - @classmethod - def gen_update_table_prefix(cls, table_name: str) -> str: - return f"UPDATE {table_name} SET" - - @classmethod - def requires_temp_table_for_delete(cls) -> bool: - """Whether a temporary table is required to delete records. - - Must be `True` for destinations that don't support correlated subqueries. - """ - return False diff --git a/dlt/destinations/utils.py b/dlt/destinations/utils.py index c02460fe58..9dd8b83509 100644 --- a/dlt/destinations/utils.py +++ b/dlt/destinations/utils.py @@ -1,9 +1,23 @@ import re +from typing import Any, List, Optional, Tuple + +from dlt.common import logger +from dlt.common.schema import Schema +from dlt.common.schema.exceptions import SchemaCorruptedException +from dlt.common.schema.typing import MERGE_STRATEGIES, TTableSchema +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + get_first_column_name_with_prop, + has_column_with_prop, + pipeline_state_table, +) from typing import Any, cast, Tuple, Dict, Type from dlt.destinations.exceptions import DatabaseTransientException from dlt.extract import DltResource, resource as make_resource +RE_DATA_TYPE = re.compile(r"([A-Z]+)\((\d+)(?:,\s?(\d+))?\)") + def ensure_resource(data: Any) -> DltResource: """Wraps `data` in a DltResource if it's not a DltResource already.""" @@ -13,6 +27,134 @@ def ensure_resource(data: Any) -> DltResource: return cast(DltResource, make_resource(data, name=resource_name)) +def info_schema_null_to_bool(v: str) -> bool: + """Converts INFORMATION SCHEMA truth values to Python bool""" + if v in ("NO", "0"): + return False + elif v in ("YES", "1"): + return True + raise ValueError(v) + + +def parse_db_data_type_str_with_precision(db_type: str) -> Tuple[str, Optional[int], Optional[int]]: + """Parses a db data type with optional precision or precision and scale information""" + # Search for matches using the regular expression + match = RE_DATA_TYPE.match(db_type) + + # If the pattern matches, extract the type, precision, and scale + if match: + db_type = match.group(1) + precision = int(match.group(2)) + scale = int(match.group(3)) if match.group(3) else None + return db_type, precision, scale + + # If the pattern does not match, return the original type without precision and scale + return db_type, None, None + + +def get_pipeline_state_query_columns() -> TTableSchema: + """We get definition of pipeline state table without columns we do not need for the query""" + state_table = pipeline_state_table() + # we do not need version_hash to be backward compatible as long as we can + state_table["columns"].pop("version_hash") + return state_table + + +def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[Exception]: + log = logger.warning if warnings else logger.info + # collect all exceptions to show all problems in the schema + exception_log: List[Exception] = [] + + # verifies schema settings specific to sql job client + for table in schema.data_tables(): + table_name = table["name"] + if table.get("write_disposition") == "merge": + if "x-merge-strategy" in table and table["x-merge-strategy"] not in MERGE_STRATEGIES: # type: ignore[typeddict-item] + exception_log.append( + SchemaCorruptedException( + schema.name, + f'"{table["x-merge-strategy"]}" is not a valid merge strategy. ' # type: ignore[typeddict-item] + f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""", + ) + ) + if table.get("x-merge-strategy") == "delete-insert": + if not has_column_with_prop(table, "primary_key") and not has_column_with_prop( + table, "merge_key" + ): + log( + f"Table {table_name} has `write_disposition` set to `merge`" + " and `merge_strategy` set to `delete-insert`, but no primary or" + " merge keys defined." + " dlt will fall back to `append` for this table." + ) + elif table.get("x-merge-strategy") == "upsert": + if not has_column_with_prop(table, "primary_key"): + exception_log.append( + SchemaCorruptedException( + schema.name, + f"No primary key defined for table `{table['name']}`." + " `primary_key` needs to be set when using the `upsert`" + " merge strategy.", + ) + ) + if has_column_with_prop(table, "merge_key"): + log( + f"Found `merge_key` for table `{table['name']}` with" + " `upsert` merge strategy. Merge key is not supported" + " for this strategy and will be ignored." + ) + if has_column_with_prop(table, "hard_delete"): + if len(get_columns_names_with_prop(table, "hard_delete")) > 1: + exception_log.append( + SchemaCorruptedException( + schema.name, + f'Found multiple "hard_delete" column hints for table "{table_name}" in' + f' schema "{schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.', + ) + ) + if table.get("write_disposition") in ("replace", "append"): + log( + f"""The "hard_delete" column hint for column "{get_first_column_name_with_prop(table, 'hard_delete')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{schema.name}" will be ignored.' + ' The "hard_delete" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if has_column_with_prop(table, "dedup_sort"): + if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: + exception_log.append( + SchemaCorruptedException( + schema.name, + f'Found multiple "dedup_sort" column hints for table "{table_name}" in' + f' schema "{schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.', + ) + ) + if table.get("write_disposition") in ("replace", "append"): + log( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if table.get("write_disposition") == "merge" and not has_column_with_prop( + table, "primary_key" + ): + log( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when a' + " primary key has been specified." + ) + return exception_log + + def _convert_to_old_pyformat( new_style_string: str, args: Tuple[Any, ...], operational_error_cls: Type[Exception] ) -> Tuple[str, Dict[str, Any]]: diff --git a/dlt/extract/__init__.py b/dlt/extract/__init__.py index 03b2e59539..4029241634 100644 --- a/dlt/extract/__init__.py +++ b/dlt/extract/__init__.py @@ -4,13 +4,14 @@ from dlt.extract.decorators import source, resource, transformer, defer from dlt.extract.incremental import Incremental from dlt.extract.wrappers import wrap_additional_type -from dlt.extract.extractors import materialize_schema_item +from dlt.extract.extractors import materialize_schema_item, with_file_import __all__ = [ "DltResource", "DltSource", "with_table_name", "with_hints", + "with_file_import", "make_hints", "source", "resource", diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 2bb4a3ce87..1eccd86aad 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -35,22 +35,20 @@ from dlt.common.schema.schema import Schema from dlt.common.schema.typing import ( TColumnNames, + TFileFormat, TWriteDisposition, TWriteDispositionConfig, TAnySchemaColumns, TSchemaContract, TTableFormat, ) -from dlt.extract.hints import make_hints -from dlt.extract.utils import ( - simulate_func_call, - wrap_compat_transformer, - wrap_resource_gen, -) from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.typing import AnyFun, ParamSpec, Concatenate, TDataItem, TDataItems from dlt.common.utils import get_callable_name, get_module_name, is_inner_callable + +from dlt.extract.hints import make_hints +from dlt.extract.utils import simulate_func_call from dlt.extract.exceptions import ( CurrentSourceNotAvailable, DynamicNameNotStandaloneResource, @@ -64,8 +62,6 @@ SourceNotAFunction, CurrentSourceSchemaNotAvailable, ) -from dlt.extract.incremental import IncrementalResourceWrapper - from dlt.extract.items import TTableHintTemplate from dlt.extract.source import DltSource from dlt.extract.resource import DltResource, TUnboundDltResource, TDltResourceImpl @@ -196,11 +192,7 @@ def decorator( # source name is passed directly or taken from decorated function name effective_name = name or get_callable_name(f) - if not schema: - # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema - schema = _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) - - if name and name != schema.name: + if schema and name and name != schema.name: raise ExplicitSourceNameInvalid(name, schema.name) # wrap source extraction function in configuration with section @@ -210,16 +202,16 @@ def decorator( source_sections = (known_sections.SOURCES, source_section, effective_name) conf_f = with_config(f, spec=spec, sections=source_sections) - def _eval_rv(_rv: Any) -> TDltSourceImpl: + def _eval_rv(_rv: Any, schema_copy: Schema) -> TDltSourceImpl: """Evaluates return value from the source function or coroutine""" if _rv is None: - raise SourceDataIsNone(schema.name) + raise SourceDataIsNone(schema_copy.name) # if generator, consume it immediately if inspect.isgenerator(_rv): _rv = list(_rv) # convert to source - s = _impl_cls.from_data(schema.clone(update_normalizers=True), source_section, _rv) + s = _impl_cls.from_data(schema_copy, source_section, _rv) # apply hints if max_table_nesting is not None: s.max_table_nesting = max_table_nesting @@ -228,10 +220,20 @@ def _eval_rv(_rv: Any) -> TDltSourceImpl: s.root_key = root_key return s + def _make_schema() -> Schema: + if not schema: + # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema + return _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) + else: + # clone the schema passed to decorator, update normalizers, remove processing hints + # NOTE: source may be called several times in many different settings + return schema.clone(update_normalizers=True, remove_processing_hints=True) + @wraps(conf_f) def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: """Wrap a regular function, injection context must be a part of the wrap""" - with Container().injectable_context(SourceSchemaInjectableContext(schema)): + schema_copy = _make_schema() + with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name @@ -239,18 +241,19 @@ def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: ConfigSectionContext( pipeline_name=pipeline_name, sections=source_sections, - source_state_key=schema.name, + source_state_key=schema_copy.name, ) ): rv = conf_f(*args, **kwargs) - return _eval_rv(rv) + return _eval_rv(rv, schema_copy) @wraps(conf_f) async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: """In case of co-routine we must wrap the whole injection context in awaitable, there's no easy way to avoid some code duplication """ - with Container().injectable_context(SourceSchemaInjectableContext(schema)): + schema_copy = _make_schema() + with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name @@ -258,11 +261,11 @@ async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: ConfigSectionContext( pipeline_name=pipeline_name, sections=source_sections, - source_state_key=schema.name, + source_state_key=schema_copy.name, ) ): rv = await conf_f(*args, **kwargs) - return _eval_rv(rv) + return _eval_rv(rv, schema_copy) # get spec for wrapped function SPEC = get_fun_spec(conf_f) @@ -296,6 +299,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -316,6 +320,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -336,6 +341,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -359,6 +365,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -378,6 +385,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -413,9 +421,10 @@ def resource( If not present, the name of the decorated function will be used. table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. - max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. + write_disposition (TTableHintTemplate[TWriteDispositionConfig], optional): Controls how to write data to a table. Accepts a shorthand string literal or configuration dictionary. Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. @@ -433,7 +442,12 @@ def resource( This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to all resources of this source (if not overridden in the resource itself) - table_format (Literal["iceberg"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, other destinations ignore this hint. + + table_format (Literal["iceberg", "delta"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, and "delta" on the filesystem. + Other destinations ignore this hint. + + file_format (Literal["preferred", ...], optional): Format of the file in which resource data is stored. Useful when importing external files. Use `preferred` to force + a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. @@ -464,6 +478,7 @@ def make_resource(_name: str, _section: str, _data: Any) -> TDltResourceImpl: merge_key=merge_key, schema_contract=schema_contract, table_format=table_format, + file_format=file_format, ) resource = _impl_cls.from_data( @@ -574,10 +589,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: str = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -591,10 +610,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: TTableHintTemplate[str] = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -612,10 +635,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: str = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -629,10 +656,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: TTableHintTemplate[str] = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -646,10 +677,14 @@ def transformer( data_from: TUnboundDltResource = DltResource.Empty, name: TTableHintTemplate[str] = None, table_name: TTableHintTemplate[str] = None, + max_table_nesting: int = None, write_disposition: TTableHintTemplate[TWriteDisposition] = None, columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + schema_contract: TTableHintTemplate[TSchemaContract] = None, + table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, @@ -692,6 +727,8 @@ def transformer( table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. + write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. @@ -704,6 +741,14 @@ def transformer( merge_key (str | Sequence[str]): A column name or a list of column names that define a merge key. Typically used with "merge" write disposition to remove overlapping data ranges ie. to keep a single record for a given day. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to all resources of this source (if not overridden in the resource itself) + + table_format (Literal["iceberg", "delta"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, and "delta" on the filesystem. + Other destinations ignore this hint. + + file_format (Literal["preferred", ...], optional): Format of the file in which resource data is stored. Useful when importing external files. Use `preferred` to force + a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. + selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. spec (Type[BaseConfiguration], optional): A specification of configuration and secret values required by the source. @@ -722,10 +767,14 @@ def transformer( f, name=name, table_name=table_name, + max_table_nesting=max_table_nesting, write_disposition=write_disposition, columns=columns, primary_key=primary_key, merge_key=merge_key, + schema_contract=schema_contract, + table_format=table_format, + file_format=file_format, selected=selected, spec=spec, standalone=standalone, @@ -741,8 +790,11 @@ def _maybe_load_schema_for_callable(f: AnyFun, name: str) -> Optional[Schema]: try: file = inspect.getsourcefile(f) if file: - return SchemaStorage.load_schema_file(os.path.dirname(file), name) - + schema = SchemaStorage.load_schema_file( + os.path.dirname(file), name, remove_processing_hints=True + ) + schema.update_normalizers() + return schema except SchemaNotFoundError: pass return None diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index f8966c3ced..485a01eb99 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -31,7 +31,7 @@ from dlt.common.storages.load_package import ( ParsedLoadJobFileName, LoadPackageStateInjectableContext, - TPipelineStateDoc, + TLoadPackageState, commit_load_package_state, ) from dlt.common.utils import get_callable_name, get_full_class_name @@ -45,7 +45,6 @@ from dlt.extract.storage import ExtractStorage from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor from dlt.extract.utils import get_data_item_format -from dlt.pipeline.drop import drop_resources def data_to_sources( @@ -170,6 +169,9 @@ def add_item(item: Any) -> bool: class Extract(WithStepInfo[ExtractMetrics, ExtractInfo]): + original_data: Any + """Original data from which the extracted DltSource was created. Will be used to describe in extract info""" + def __init__( self, schema_storage: SchemaStorage, @@ -181,6 +183,7 @@ def __init__( self.collector = collector self.schema_storage = schema_storage self.extract_storage = ExtractStorage(normalize_storage_config) + # TODO: this should be passed together with DltSource to extract() self.original_data: Any = original_data super().__init__() @@ -367,10 +370,12 @@ def extract( source: DltSource, max_parallel_items: int, workers: int, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, ) -> str: # generate load package to be able to commit all the sources together later - load_id = self.extract_storage.create_load_package(source.discover_schema()) + load_id = self.extract_storage.create_load_package( + source.schema, reuse_exiting_package=True + ) with Container().injectable_context( SourceSchemaInjectableContext(source.schema) ), Container().injectable_context( @@ -388,7 +393,7 @@ def extract( ) ): if load_package_state_update: - load_package.state.update(load_package_state_update) # type: ignore[typeddict-item] + load_package.state.update(load_package_state_update) # reset resource states, the `extracted` list contains all the explicit resources and all their parents for resource in source.resources.extracted.values(): @@ -405,14 +410,10 @@ def extract( commit_load_package_state() return load_id - def commit_packages(self, pipline_state_doc: TPipelineStateDoc = None) -> None: - """Commits all extracted packages to normalize storage, and adds the pipeline state to the load package""" + def commit_packages(self) -> None: + """Commits all extracted packages to normalize storage""" # commit load packages for load_id, metrics in self._load_id_metrics.items(): - if pipline_state_doc: - package_state = self.extract_storage.new_packages.get_load_package_state(load_id) - package_state["pipeline_state"] = {**pipline_state_doc, "dlt_load_id": load_id} - self.extract_storage.new_packages.save_load_package_state(load_id, package_state) self.extract_storage.commit_new_load_package( load_id, self.schema_storage[metrics[0]["schema_name"]] ) diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index 48f0d6968e..4a1de2517d 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -1,14 +1,14 @@ from copy import copy -from typing import Set, Dict, Any, Optional, List +from typing import Set, Dict, Any, Optional, List, Union from dlt.common.configuration import known_sections, resolve_configuration, with_config from dlt.common import logger from dlt.common.configuration.specs import BaseConfiguration, configspec +from dlt.common.data_writers import DataWriterMetrics from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import MissingDependencyException - from dlt.common.runtime.collector import Collector, NULL_COLLECTOR -from dlt.common.typing import TDataItems, TDataItem +from dlt.common.typing import TDataItems, TDataItem, TLoaderFileFormat from dlt.common.schema import Schema, utils from dlt.common.schema.typing import ( TSchemaContractDict, @@ -17,9 +17,9 @@ TTableSchemaColumns, TPartialTableSchema, ) -from dlt.extract.hints import HintsMeta +from dlt.extract.hints import HintsMeta, TResourceHints from dlt.extract.resource import DltResource -from dlt.extract.items import TableNameMeta +from dlt.extract.items import DataItemWithMeta, TableNameMeta from dlt.extract.storage import ExtractorItemStorage from dlt.normalize.configuration import ItemsNormalizerConfiguration @@ -47,6 +47,50 @@ def materialize_schema_item() -> MaterializedEmptyList: return MaterializedEmptyList() +class ImportFileMeta(HintsMeta): + __slots__ = ("file_path", "metrics", "file_format") + + def __init__( + self, + file_path: str, + metrics: DataWriterMetrics, + file_format: TLoaderFileFormat = None, + hints: TResourceHints = None, + create_table_variant: bool = None, + ) -> None: + super().__init__(hints, create_table_variant) + self.file_path = file_path + self.metrics = metrics + self.file_format = file_format + + +def with_file_import( + file_path: str, + file_format: TLoaderFileFormat, + items_count: int = 0, + hints: Union[TResourceHints, TDataItem] = None, +) -> DataItemWithMeta: + """Marks file under `file_path` to be associated with current resource and imported into the load package as a file of + type `file_format`. + + You can provide optional `hints` that will be applied to the current resource. Note that you should avoid schema inference at + runtime if possible and if that is not possible - to do that only once per extract process. Use `make_hints` in `mark` module + to create hints. You can also pass Arrow table or Pandas data frame form which schema will be taken (but content discarded). + Create `TResourceHints` with `make_hints`. + + If number of records in `file_path` is known, pass it in `items_count` so `dlt` can generate correct extract metrics. + + Note that `dlt` does not sniff schemas from data and will not guess right file format for you. + """ + metrics = DataWriterMetrics(file_path, items_count, 0, 0, 0) + item: TDataItem = None + # if hints are dict assume that this is dlt schema, if not - that it is arrow table + if not isinstance(hints, dict): + item = hints + hints = None + return DataItemWithMeta(ImportFileMeta(file_path, metrics, file_format, hints, False), item) + + class Extractor: @configspec class ExtractorConfiguration(BaseConfiguration): @@ -78,7 +122,7 @@ def __init__( def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None: """Write `items` to `resource` optionally computing table schemas and revalidating/filtering data""" - if isinstance(meta, HintsMeta): + if isinstance(meta, HintsMeta) and meta.hints: # update the resource with new hints, remove all caches so schema is recomputed # and contracts re-applied resource.merge_hints(meta.hints, meta.create_table_variant) @@ -93,7 +137,7 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No self._write_to_static_table(resource, table_name, items, meta) else: # table has name or other hints depending on data items - self._write_to_dynamic_table(resource, items) + self._write_to_dynamic_table(resource, items, meta) def write_empty_items_file(self, table_name: str) -> None: table_name = self.naming.normalize_table_identifier(table_name) @@ -129,7 +173,24 @@ def _write_item( if isinstance(items, MaterializedEmptyList): self.resources_with_empty.add(resource_name) - def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> None: + def _import_item( + self, + table_name: str, + resource_name: str, + meta: ImportFileMeta, + ) -> None: + metrics = self.item_storage.import_items_file( + self.load_id, + self.schema.name, + table_name, + meta.file_path, + meta.metrics, + meta.file_format, + ) + self.collector.update(table_name, inc=metrics.items_count) + self.resources_with_items.add(resource_name) + + def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems, meta: Any) -> None: if not isinstance(items, list): items = [items] @@ -143,7 +204,10 @@ def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> N ) # write to storage with inferred table name if table_name not in self._filtered_tables: - self._write_item(table_name, resource.name, item) + if isinstance(meta, ImportFileMeta): + self._import_item(table_name, resource.name, meta) + else: + self._write_item(table_name, resource.name, item) def _write_to_static_table( self, resource: DltResource, table_name: str, items: TDataItems, meta: Any @@ -151,11 +215,16 @@ def _write_to_static_table( if table_name not in self._table_contracts: items = self._compute_and_update_table(resource, table_name, items, meta) if table_name not in self._filtered_tables: - self._write_item(table_name, resource.name, items) + if isinstance(meta, ImportFileMeta): + self._import_item(table_name, resource.name, meta) + else: + self._write_item(table_name, resource.name, items) def _compute_table(self, resource: DltResource, items: TDataItems, meta: Any) -> TTableSchema: """Computes a schema for a new or dynamic table and normalizes identifiers""" - return self.schema.normalize_table_identifiers(resource.compute_table_schema(items, meta)) + return utils.normalize_table_identifiers( + resource.compute_table_schema(items, meta), self.schema.naming + ) def _compute_and_update_table( self, resource: DltResource, table_name: str, items: TDataItems, meta: Any @@ -173,11 +242,11 @@ def _compute_and_update_table( # this is a new table so allow evolve once if schema_contract["columns"] != "evolve" and self.schema.is_new_table(table_name): - computed_table["x-normalizer"] = {"evolve-columns-once": True} # type: ignore[typeddict-unknown-key] + computed_table["x-normalizer"] = {"evolve-columns-once": True} existing_table = self.schema._schema_tables.get(table_name, None) if existing_table: # TODO: revise this. computed table should overwrite certain hints (ie. primary and merge keys) completely - diff_table = utils.diff_table(existing_table, computed_table) + diff_table = utils.diff_table(self.schema.name, existing_table, computed_table) else: diff_table = computed_table @@ -335,7 +404,7 @@ def _compute_table( computed_table = super()._compute_table(resource, item, Any) # Merge the columns to include primary_key and other hints that may be set on the resource if arrow_table: - utils.merge_table(computed_table, arrow_table) + utils.merge_table(self.schema.name, computed_table, arrow_table) else: arrow_table = copy(computed_table) arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema) @@ -353,8 +422,7 @@ def _compute_table( } # normalize arrow table before merging - arrow_table = self.schema.normalize_table_identifiers(arrow_table) - + arrow_table = utils.normalize_table_identifiers(arrow_table, self.schema.naming) # issue warnings when overriding computed with arrow override_warn: bool = False for col_name, column in arrow_table["columns"].items(): diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 6fd1928970..dce375afb0 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -5,6 +5,7 @@ from dlt.common.schema.typing import ( TColumnNames, TColumnProp, + TFileFormat, TPartialTableSchema, TTableSchema, TTableSchemaColumns, @@ -14,6 +15,7 @@ TTableFormat, TSchemaContract, DEFAULT_VALIDITY_COLUMN_NAMES, + MERGE_STRATEGIES, ) from dlt.common.schema.utils import ( DEFAULT_WRITE_DISPOSITION, @@ -25,6 +27,7 @@ ) from dlt.common.typing import TDataItem from dlt.common.utils import clone_dict_nested +from dlt.common.normalizers.json.relational import DataItemNormalizer from dlt.common.validation import validate_dict_ignoring_xkeys from dlt.extract.exceptions import ( DataItemRequiredForDynamicTableHints, @@ -48,6 +51,7 @@ class TResourceHints(TypedDict, total=False): incremental: Incremental[Any] schema_contract: TTableHintTemplate[TSchemaContract] table_format: TTableHintTemplate[TTableFormat] + file_format: TTableHintTemplate[TFileFormat] validator: ValidateItem original_columns: TTableHintTemplate[TAnySchemaColumns] @@ -72,6 +76,7 @@ def make_hints( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, ) -> TResourceHints: """A convenience function to create resource hints. Accepts both static and dynamic hints based on data. @@ -91,6 +96,7 @@ def make_hints( columns=clean_columns, # type: ignore schema_contract=schema_contract, # type: ignore table_format=table_format, # type: ignore + file_format=file_format, # type: ignore ) if not table_name: new_template.pop("name") @@ -209,6 +215,7 @@ def apply_hints( schema_contract: TTableHintTemplate[TSchemaContract] = None, additional_table_hints: Optional[Dict[str, TTableHintTemplate[Any]]] = None, table_format: TTableHintTemplate[TTableFormat] = None, + file_format: TTableHintTemplate[TFileFormat] = None, create_table_variant: bool = False, ) -> None: """Creates or modifies existing table schema by setting provided hints. Accepts both static and dynamic hints based on data. @@ -256,6 +263,7 @@ def apply_hints( merge_key, schema_contract, table_format, + file_format, ) else: t = self._clone_hints(t) @@ -320,6 +328,11 @@ def apply_hints( t["table_format"] = table_format else: t.pop("table_format", None) + if file_format is not None: + if file_format: + t["file_format"] = file_format + else: + t.pop("file_format", None) # set properties that can't be passed to make_hints if incremental is not None: @@ -331,6 +344,7 @@ def _set_hints( self, hints_template: TResourceHints, create_table_variant: bool = False ) -> None: DltResourceHints.validate_dynamic_hints(hints_template) + DltResourceHints.validate_write_disposition_hint(hints_template.get("write_disposition")) if create_table_variant: table_name: str = hints_template["name"] # type: ignore[assignment] # incremental cannot be specified in variant @@ -375,6 +389,7 @@ def merge_hints( incremental=hints_template.get("incremental"), schema_contract=hints_template.get("schema_contract"), table_format=hints_template.get("table_format"), + file_format=hints_template.get("file_format"), create_table_variant=create_table_variant, ) @@ -424,13 +439,11 @@ def _merge_write_disposition_dict(dict_: Dict[str, Any]) -> None: @staticmethod def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: - """Merges merge disposition dict into x-hints on in place.""" + """Merges merge disposition dict into x-hints in place.""" mddict: TMergeDispositionDict = deepcopy(dict_["write_disposition"]) if mddict is not None: - dict_["x-merge-strategy"] = ( - mddict["strategy"] if "strategy" in mddict else DEFAULT_MERGE_STRATEGY - ) + dict_["x-merge-strategy"] = mddict.get("strategy", DEFAULT_MERGE_STRATEGY) # add columns for `scd2` merge strategy if dict_.get("x-merge-strategy") == "scd2": if mddict.get("validity_column_names") is None: @@ -452,7 +465,7 @@ def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: "x-valid-to": True, "x-active-record-timestamp": mddict.get("active_record_timestamp"), } - hash_ = mddict.get("row_version_column_name", "_dlt_id") + hash_ = mddict.get("row_version_column_name", DataItemNormalizer.C_DLT_ID) dict_["columns"][hash_] = { "name": hash_, "nullable": False, @@ -484,3 +497,13 @@ def validate_dynamic_hints(template: TResourceHints) -> None: raise InconsistentTableTemplate( f"Table name {table_name} must be a function if any other table hint is a function" ) + + @staticmethod + def validate_write_disposition_hint(wd: TTableHintTemplate[TWriteDispositionConfig]) -> None: + if isinstance(wd, dict) and wd["disposition"] == "merge": + wd = cast(TMergeDispositionDict, wd) + if "strategy" in wd and wd["strategy"] not in MERGE_STRATEGIES: + raise ValueError( + f'`{wd["strategy"]}` is not a valid merge strategy. ' + f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""" + ) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index bcb6b1cc9a..11f989e0b2 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -269,11 +269,10 @@ def parse_native_representation(self, native_value: Any) -> None: self._primary_key = merged._primary_key self.allow_external_schedulers = merged.allow_external_schedulers self.row_order = merged.row_order + self.__is_resolved__ = self.__is_resolved__ else: # TODO: Maybe check if callable(getattr(native_value, '__lt__', None)) # Passing bare value `incremental=44` gets parsed as initial_value self.initial_value = native_value - if not self.is_partial(): - self.resolve() def get_state(self) -> IncrementalColumnState: """Returns an Incremental state for a particular cursor column""" @@ -357,6 +356,7 @@ def _join_external_scheduler(self) -> None: f"Specified Incremental last value type {param_type} is not supported. Please use" f" DateTime, Date, float, int or str to join external schedulers.({ex})" ) + return if param_type is Any: logger.warning( @@ -491,7 +491,8 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: return rows -Incremental.EMPTY = Incremental[Any]("") +Incremental.EMPTY = Incremental[Any]() +Incremental.EMPTY.__is_resolved__ = True class IncrementalResourceWrapper(ItemTransform[TDataItem]): diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 8b4cae4090..947e21f7b8 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -213,7 +213,7 @@ def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str] def compute_unique_values_with_index( self, item: "TAnyArrowItem", unique_columns: List[str] - ) -> List[Tuple[int, str]]: + ) -> List[Tuple[Any, str]]: if not unique_columns: return [] indices = item[self._dlt_index].to_pylist() @@ -318,12 +318,12 @@ def __call__( for i, uq_val in unique_values_index if uq_val in self.start_unique_hashes ] - # find rows with unique ids that were stored from previous run - remove_idx = pa.array(i for i, _ in unique_values_index) - # Filter the table - tbl = tbl.filter( - pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)) - ) + if len(unique_values_index) > 0: + # find rows with unique ids that were stored from previous run + remove_idx = pa.array(i for i, _ in unique_values_index) + tbl = tbl.filter( + pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)) + ) if ( self.last_value is None diff --git a/dlt/extract/items.py b/dlt/extract/items.py index fec31e2846..4cf8d2191f 100644 --- a/dlt/extract/items.py +++ b/dlt/extract/items.py @@ -160,6 +160,10 @@ class FilterItem(ItemTransform[bool]): def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: if isinstance(item, list): + # preserve empty lists + if len(item) == 0: + return item + if self._f_meta: item = [i for i in item if self._f_meta(i, meta)] else: diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index e35b35c6c8..7421cdaf60 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -1,4 +1,3 @@ -from copy import deepcopy import inspect from functools import partial from typing import ( @@ -14,6 +13,7 @@ ) from typing_extensions import TypeVar, Self +from dlt.common import logger from dlt.common.configuration.inject import get_fun_spec, with_config from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import known_sections @@ -396,6 +396,11 @@ def _gen_wrap(gen: TPipeStep) -> TPipeStep: else: # keep function as function to not evaluate generators before pipe starts self._pipe.replace_gen(partial(_gen_wrap, gen)) + else: + logger.warning( + f"Setting add_limit to a transformer {self.name} has no effect. Set the limit on" + " the top level resource." + ) return self def parallelize(self: TDltResourceImpl) -> TDltResourceImpl: diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 658f884c40..7732c4f056 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -11,6 +11,7 @@ from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnName, TSchemaContract +from dlt.common.schema.utils import normalize_table_identifiers from dlt.common.typing import StrAny, TDataItem from dlt.common.configuration.container import Container from dlt.common.pipeline import ( @@ -214,7 +215,7 @@ def from_data(cls, schema: Schema, section: str, data: Any) -> Self: def name(self) -> str: return self._schema.name - # TODO: 4 properties below must go somewhere else ie. into RelationalSchema which is Schema + Relational normalizer. + # TODO: max_table_nesting/root_key below must go somewhere else ie. into RelationalSchema which is Schema + Relational normalizer. @property def max_table_nesting(self) -> int: """A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON.""" @@ -222,49 +223,67 @@ def max_table_nesting(self) -> int: @max_table_nesting.setter def max_table_nesting(self, value: int) -> None: - RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value}) - - @property - def schema_contract(self) -> TSchemaContract: - return self.schema.settings["schema_contract"] - - @schema_contract.setter - def schema_contract(self, settings: TSchemaContract) -> None: - self.schema.set_schema_contract(settings) - - @property - def exhausted(self) -> bool: - """check all selected pipes wether one of them has started. if so, the source is exhausted.""" - for resource in self._resources.extracted.values(): - item = resource._pipe.gen - if inspect.isgenerator(item): - if inspect.getgeneratorstate(item) != "GEN_CREATED": - return True - return False + if value is None: + # this also check the normalizer type + config = RelationalNormalizer.get_normalizer_config(self._schema) + config.pop("max_nesting", None) + else: + RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value}) @property def root_key(self) -> bool: """Enables merging on all resources by propagating root foreign key to child tables. This option is most useful if you plan to change write disposition of a resource to disable/enable merge""" + # this also check the normalizer type config = RelationalNormalizer.get_normalizer_config(self._schema).get("propagation") + data_normalizer = self._schema.data_item_normalizer + assert isinstance(data_normalizer, RelationalNormalizer) return ( config is not None and "root" in config - and "_dlt_id" in config["root"] - and config["root"]["_dlt_id"] == "_dlt_root_id" + and data_normalizer.c_dlt_id in config["root"] + and config["root"][data_normalizer.c_dlt_id] == data_normalizer.c_dlt_root_id ) @root_key.setter def root_key(self, value: bool) -> None: + # this also check the normalizer type + config = RelationalNormalizer.get_normalizer_config(self._schema) + data_normalizer = self._schema.data_item_normalizer + assert isinstance(data_normalizer, RelationalNormalizer) + if value is True: RelationalNormalizer.update_normalizer_config( - self._schema, {"propagation": {"root": {"_dlt_id": TColumnName("_dlt_root_id")}}} + self._schema, + { + "propagation": { + "root": { + data_normalizer.c_dlt_id: TColumnName(data_normalizer.c_dlt_root_id) + } + } + }, ) else: if self.root_key: - propagation_config = RelationalNormalizer.get_normalizer_config(self._schema)[ - "propagation" - ] - propagation_config["root"].pop("_dlt_id") # type: ignore + propagation_config = config["propagation"] + propagation_config["root"].pop(data_normalizer.c_dlt_id) + + @property + def schema_contract(self) -> TSchemaContract: + return self.schema.settings.get("schema_contract") + + @schema_contract.setter + def schema_contract(self, settings: TSchemaContract) -> None: + self.schema.set_schema_contract(settings) + + @property + def exhausted(self) -> bool: + """check all selected pipes wether one of them has started. if so, the source is exhausted.""" + for resource in self._resources.extracted.values(): + item = resource._pipe.gen + if inspect.isgenerator(item): + if inspect.getgeneratorstate(item) != "GEN_CREATED": + return True + return False @property def resources(self) -> DltResourceDict: @@ -291,8 +310,8 @@ def discover_schema(self, item: TDataItem = None) -> Schema: for r in self.selected_resources.values(): # names must be normalized here with contextlib.suppress(DataItemRequiredForDynamicTableHints): - partial_table = self._schema.normalize_table_identifiers( - r.compute_table_schema(item) + partial_table = normalize_table_identifiers( + r.compute_table_schema(item), self._schema.naming ) schema.update_table(partial_table) return schema diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index 7d7302aab6..8494d3bba3 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -11,6 +11,7 @@ RetryCallState, ) +from dlt.common.known_env import DLT_DATA_DIR, DLT_PROJECT_DIR from dlt.common.exceptions import MissingDependencyException try: @@ -121,7 +122,7 @@ def __init__( dags_folder = conf.get("core", "dags_folder") # set the dlt project folder to dags - os.environ["DLT_PROJECT_DIR"] = dags_folder + os.environ[DLT_PROJECT_DIR] = dags_folder # check if /data mount is available if use_data_folder and os.path.exists("/home/airflow/gcs/data"): @@ -129,7 +130,7 @@ def __init__( else: # create random path data_dir = os.path.join(local_data_folder or gettempdir(), f"dlt_{uniq_id(8)}") - os.environ["DLT_DATA_DIR"] = data_dir + os.environ[DLT_DATA_DIR] = data_dir # delete existing config providers in container, they will get reloaded on next use if ConfigProvidersContext in Container(): @@ -400,7 +401,7 @@ def add_run( """ # make sure that pipeline was created after dag was initialized - if not pipeline.pipelines_dir.startswith(os.environ["DLT_DATA_DIR"]): + if not pipeline.pipelines_dir.startswith(os.environ[DLT_DATA_DIR]): raise ValueError( "Please create your Pipeline instance after AirflowTasks are created. The dlt" " pipelines directory is not set correctly." diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index c68931d7db..266581c785 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -1,3 +1,4 @@ +import sys import os from subprocess import CalledProcessError import giturlparse @@ -154,12 +155,12 @@ def _run_dbt_command( try: i = iter_stdout_with_result(self.venv, "python", "-c", script) while True: - print(next(i).strip()) + sys.stdout.write(next(i).strip()) except StopIteration as si: # return result from generator return si.value # type: ignore except CalledProcessError as cpe: - print(cpe.stderr) + sys.stderr.write(cpe.stderr) raise def run( diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index 8abc679ea2..836da516e9 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,11 +1,10 @@ -from typing import TYPE_CHECKING, Literal, Optional +from typing import Optional from dlt.common.configuration import configspec +from dlt.common.destination.capabilities import TLoaderParallelismStrategy from dlt.common.storages import LoadStorageConfiguration from dlt.common.runners.configuration import PoolRunnerConfiguration, TPoolType -TLoaderParallelismStrategy = Literal["parallel", "table-sequential", "sequential"] - @configspec class LoaderConfiguration(PoolRunnerConfiguration): diff --git a/dlt/load/load.py b/dlt/load/load.py index abbeee5ddf..2290d40a1e 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,19 +1,15 @@ import contextlib from functools import reduce -import datetime # noqa: 251 from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Sequence from concurrent.futures import Executor import os -from copy import deepcopy from dlt.common import logger from dlt.common.runtime.signals import sleep from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo from dlt.common.schema.utils import get_top_level_table -from dlt.common.schema.typing import TTableSchema from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.storages.load_package import ( LoadPackageStateInjectableContext, @@ -80,7 +76,6 @@ def __init__( self.initial_client_config = initial_client_config self.initial_staging_client_config = initial_staging_client_config self.destination = destination - self.capabilities = destination.capabilities() self.staging_destination = staging_destination self.pool = NullExecutor() self.load_storage: LoadStorage = self.create_storage(is_storage_owner) @@ -88,7 +83,7 @@ def __init__( super().__init__() def create_storage(self, is_storage_owner: bool) -> LoadStorage: - supported_file_formats = self.capabilities.supported_loader_file_formats + supported_file_formats = self.destination.capabilities().supported_loader_file_formats if self.staging_destination: supported_file_formats = ( self.staging_destination.capabilities().supported_loader_file_formats @@ -150,7 +145,7 @@ def w_spool_job( if job_info.file_format not in self.load_storage.supported_job_file_formats: raise LoadClientUnsupportedFileFormats( job_info.file_format, - self.capabilities.supported_loader_file_formats, + self.destination.capabilities().supported_loader_file_formats, file_path, ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") @@ -197,7 +192,11 @@ def w_spool_job( def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJob]]: # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs load_files = filter_new_jobs( - self.load_storage.list_new_jobs(load_id), self.capabilities, self.config + self.load_storage.list_new_jobs(load_id), + self.destination.capabilities( + self.destination.configuration(self.initial_client_config) + ), + self.config, ) file_count = len(load_files) if file_count == 0: @@ -259,13 +258,20 @@ def create_followup_jobs( schema.tables, starting_job.job_file_info().table_name ) # if all tables of chain completed, create follow up jobs - all_jobs = self.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs_states = self.load_storage.normalized_packages.list_all_jobs_with_states( + load_id + ) if table_chain := get_completed_table_chain( - schema, all_jobs, top_job_table, starting_job.job_file_info().job_id() + schema, all_jobs_states, top_job_table, starting_job.job_file_info().job_id() ): table_chain_names = [table["name"] for table in table_chain] + # create job infos that contain full path to job table_chain_jobs = [ - job for job in all_jobs if job.job_file_info.table_name in table_chain_names + self.load_storage.normalized_packages.job_to_job_info(load_id, *job_state) + for job_state in all_jobs_states + if job_state[1].table_name in table_chain_names + # job being completed is still in started_jobs + and job_state[0] in ("completed_jobs", "started_jobs") ] if follow_up_jobs := client.create_table_chain_completed_followup_jobs( table_chain, table_chain_jobs @@ -359,7 +365,7 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) ) ): job_client.complete_load(load_id) - self._maybe_trancate_staging_dataset(schema, job_client) + self._maybe_truncate_staging_dataset(schema, job_client) self.load_storage.complete_load_package(load_id, aborted) # collect package info @@ -374,8 +380,12 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) def load_single_package(self, load_id: str, schema: Schema) -> None: new_jobs = self.get_new_jobs_info(load_id) + # get dropped and truncated tables that were added in the extract step if refresh was requested + # NOTE: if naming convention was updated those names correspond to the old naming convention + # and they must be like that in order to drop existing tables dropped_tables = current_load_package()["state"].get("dropped_tables", []) truncated_tables = current_load_package()["state"].get("truncated_tables", []) + # initialize analytical storage ie. create dataset required by passed schema with self.get_destination_client(schema) as job_client: if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: @@ -432,10 +442,10 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: self.complete_package(load_id, schema, False) return # update counter we only care about the jobs that are scheduled to be loaded - package_info = self.load_storage.normalized_packages.get_load_package_info(load_id) - total_jobs = reduce(lambda p, c: p + len(c), package_info.jobs.values(), 0) - no_failed_jobs = len(package_info.jobs["failed_jobs"]) - no_completed_jobs = len(package_info.jobs["completed_jobs"]) + no_failed_jobs + package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) + total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) + no_failed_jobs = len(package_jobs["failed_jobs"]) + no_completed_jobs = len(package_jobs["completed_jobs"]) + no_failed_jobs self.collector.update("Jobs", no_completed_jobs, total_jobs) if no_failed_jobs > 0: self.collector.update( @@ -447,26 +457,28 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: remaining_jobs = self.complete_jobs(load_id, jobs, schema) if len(remaining_jobs) == 0: # get package status - package_info = self.load_storage.normalized_packages.get_load_package_info( + package_jobs = self.load_storage.normalized_packages.get_load_package_jobs( load_id ) # possibly raise on failed jobs if self.config.raise_on_failed_jobs: - if package_info.jobs["failed_jobs"]: - failed_job = package_info.jobs["failed_jobs"][0] + if package_jobs["failed_jobs"]: + failed_job = package_jobs["failed_jobs"][0] raise LoadClientJobFailed( load_id, - failed_job.job_file_info.job_id(), - failed_job.failed_message, + failed_job.job_id(), + self.load_storage.normalized_packages.get_job_failed_message( + load_id, failed_job + ), ) # possibly raise on too many retries if self.config.raise_on_max_retries: - for new_job in package_info.jobs["new_jobs"]: - r_c = new_job.job_file_info.retry_count + for new_job in package_jobs["new_jobs"]: + r_c = new_job.retry_count if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: raise LoadClientJobRetry( load_id, - new_job.job_file_info.job_id(), + new_job.job_id(), r_c, self.config.raise_on_max_retries, ) @@ -512,7 +524,7 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: return TRunMetrics(False, len(self.load_storage.list_normalized_packages())) - def _maybe_trancate_staging_dataset(self, schema: Schema, job_client: JobClientBase) -> None: + def _maybe_truncate_staging_dataset(self, schema: Schema, job_client: JobClientBase) -> None: """ Truncate the staging dataset if one used, and configuration requests truncation. diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 4e5099855b..67a813f5f2 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -1,8 +1,8 @@ -from typing import List, Set, Iterable, Callable, Optional, Sequence +from typing import List, Set, Iterable, Callable, Optional, Tuple, Sequence from itertools import groupby from dlt.common import logger -from dlt.common.storages.load_package import LoadJobInfo, PackageStorage +from dlt.common.storages.load_package import LoadJobInfo, PackageStorage, TJobState from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, get_child_tables, @@ -22,7 +22,7 @@ def get_completed_table_chain( schema: Schema, - all_jobs: Iterable[LoadJobInfo], + all_jobs: Iterable[Tuple[TJobState, ParsedLoadJobFileName]], top_merged_table: TTableSchema, being_completed_job_id: str = None, ) -> List[TTableSchema]: @@ -54,8 +54,8 @@ def get_completed_table_chain( else: # all jobs must be completed in order for merge to be created if any( - job.state not in ("failed_jobs", "completed_jobs") - and job.job_file_info.job_id() != being_completed_job_id + job[0] not in ("failed_jobs", "completed_jobs") + and job[1].job_id() != being_completed_job_id for job in table_jobs ): return None @@ -113,12 +113,15 @@ def init_client( ) ) + # get tables to drop + drop_table_names = {table["name"] for table in drop_tables} if drop_tables else set() + applied_update = _init_dataset_and_update_schema( job_client, expected_update, tables_with_jobs | dlt_tables, truncate_table_names, - drop_tables=drop_tables, + drop_tables=drop_table_names, ) # update the staging dataset if client supports this @@ -138,6 +141,7 @@ def init_client( staging_tables | {schema.version_table_name}, # keep only schema version staging_tables, # all eligible tables must be also truncated staging_info=True, + drop_tables=drop_table_names, # try to drop all the same tables on staging ) return applied_update @@ -149,7 +153,7 @@ def _init_dataset_and_update_schema( update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False, - drop_tables: Optional[List[TTableSchema]] = None, + drop_tables: Iterable[str] = None, ) -> TSchemaTables: staging_text = "for staging dataset" if staging_info else "" logger.info( @@ -158,12 +162,17 @@ def _init_dataset_and_update_schema( ) job_client.initialize_storage() if drop_tables: - drop_table_names = [table["name"] for table in drop_tables] if hasattr(job_client, "drop_tables"): logger.info( - f"Client for {job_client.config.destination_type} will drop tables {staging_text}" + f"Client for {job_client.config.destination_type} will drop tables" + f" {drop_tables} {staging_text}" + ) + job_client.drop_tables(*drop_tables, delete_schema=True) + else: + logger.warning( + f"Client for {job_client.config.destination_type} does not implement drop table." + f" Following tables {drop_tables} will not be dropped {staging_text}" ) - job_client.drop_tables(*drop_table_names, delete_schema=True) logger.info( f"Client for {job_client.config.destination_type} will update schema to package schema" diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 6678f6edee..5f84d57d7a 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -6,6 +6,7 @@ from dlt.common.data_writers import DataWriterMetrics from dlt.common.data_writers.writers import ArrowToObjectAdapter from dlt.common.json import custom_pua_decode, may_have_pua +from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict from dlt.common.schema.utils import has_table_seen_data @@ -149,7 +150,7 @@ def _normalize_chunk( continue # theres a new table or new columns in existing table # update schema and save the change - schema.update_table(partial_table) + schema.update_table(partial_table, normalize_identifiers=False) table_updates = schema_update.setdefault(table_name, []) table_updates.append(partial_table) @@ -200,6 +201,7 @@ def __call__( ) schema_updates.append(partial_update) logger.debug(f"Processed {line_no+1} lines from file {extracted_items_file}") + # empty json files are when replace write disposition is used in order to truncate table(s) if line is None and root_table_name in self.schema.tables: # TODO: we should push the truncate jobs via package state # not as empty jobs. empty jobs should be reserved for @@ -234,8 +236,9 @@ def _write_with_dlt_columns( schema = self.schema load_id = self.load_id schema_update: TSchemaUpdate = {} + data_normalizer = schema.data_item_normalizer - if add_dlt_id: + if add_dlt_id and isinstance(data_normalizer, RelationalNormalizer): table_update = schema.update_table( { "name": root_table_name, @@ -249,7 +252,7 @@ def _write_with_dlt_columns( new_columns.append( ( -1, - pa.field("_dlt_id", pyarrow.pyarrow.string(), nullable=False), + pa.field(data_normalizer.c_dlt_id, pyarrow.pyarrow.string(), nullable=False), lambda batch: pa.array(generate_dlt_ids(batch.num_rows)), ) ) @@ -375,3 +378,32 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch ) return base_schema_update + + +class FileImportNormalizer(ItemsNormalizer): + def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]: + logger.info( + f"Table {root_table_name} {self.item_storage.writer_spec.file_format} file" + f" {extracted_items_file} will be directly imported without normalization" + ) + completed_columns = self.schema.get_table_columns(root_table_name) + if not completed_columns: + logger.warning( + f"Table {root_table_name} has no completed columns for imported file" + f" {extracted_items_file} and will not be created! Pass column hints to the" + " resource or with dlt.mark.with_hints or create the destination table yourself." + ) + with self.normalize_storage.extracted_packages.storage.open_file( + extracted_items_file, "rb" + ) as f: + # TODO: sniff the schema depending on a file type + file_metrics = DataWriterMetrics(extracted_items_file, 0, f.tell(), 0, 0) + parts = ParsedLoadJobFileName.parse(extracted_items_file) + self.item_storage.import_items_file( + self.load_id, + self.schema.name, + parts.table_name, + self.normalize_storage.extracted_packages.storage.make_full_path(extracted_items_file), + file_metrics, + ) + return [] diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 75cb9be707..98154cd5cf 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -1,33 +1,23 @@ import os import itertools -from typing import Callable, List, Dict, NamedTuple, Sequence, Tuple, Set, Optional +from typing import List, Dict, Sequence, Optional, Callable from concurrent.futures import Future, Executor from dlt.common import logger from dlt.common.runtime.signals import sleep from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config -from dlt.common.configuration.container import Container -from dlt.common.data_writers import ( - DataWriter, - DataWriterMetrics, - TDataItemFormat, - resolve_best_writer_spec, - get_best_writer_spec, - is_native_writer, -) +from dlt.common.data_writers import DataWriterMetrics from dlt.common.data_writers.writers import EMPTY_DATA_WRITER_METRICS from dlt.common.runners import TRunMetrics, Runnable, NullExecutor from dlt.common.runtime import signals from dlt.common.runtime.collector import Collector, NULL_COLLECTOR -from dlt.common.schema.typing import TStoredSchema, TTableSchema +from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import merge_schema_updates from dlt.common.storages import ( NormalizeStorage, SchemaStorage, LoadStorage, - LoadStorageConfiguration, - NormalizeStorageConfiguration, ParsedLoadJobFileName, ) from dlt.common.schema import TSchemaUpdate, Schema @@ -40,20 +30,10 @@ ) from dlt.common.storages.exceptions import LoadPackageNotFound from dlt.common.storages.load_package import LoadPackageInfo -from dlt.common.utils import chunks from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed -from dlt.normalize.items_normalizers import ( - ArrowItemsNormalizer, - JsonLItemsNormalizer, - ItemsNormalizer, -) - - -class TWorkerRV(NamedTuple): - schema_updates: List[TSchemaUpdate] - file_metrics: List[DataWriterMetrics] +from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV # normalize worker wrapping function signature @@ -99,211 +79,19 @@ def create_storages(self) -> None: config=self.config._load_storage_config, ) - @staticmethod - def w_normalize_files( - config: NormalizeConfiguration, - normalize_storage_config: NormalizeStorageConfiguration, - loader_storage_config: LoadStorageConfiguration, - stored_schema: TStoredSchema, - load_id: str, - extracted_items_files: Sequence[str], - ) -> TWorkerRV: - destination_caps = config.destination_capabilities - schema_updates: List[TSchemaUpdate] = [] - # normalizers are cached per table name - item_normalizers: Dict[str, ItemsNormalizer] = {} - - preferred_file_format = ( - destination_caps.preferred_loader_file_format - or destination_caps.preferred_staging_file_format - ) - # TODO: capabilities.supported_*_formats can be None, it should have defaults - supported_file_formats = destination_caps.supported_loader_file_formats or [] - supported_table_formats = destination_caps.supported_table_formats or [] - - # process all files with data items and write to buffered item storage - with Container().injectable_context(destination_caps): - schema = Schema.from_stored_schema(stored_schema) - normalize_storage = NormalizeStorage(False, normalize_storage_config) - load_storage = LoadStorage(False, supported_file_formats, loader_storage_config) - - def _get_items_normalizer( - item_format: TDataItemFormat, table_schema: Optional[TTableSchema] - ) -> ItemsNormalizer: - table_name = table_schema["name"] - if table_name in item_normalizers: - return item_normalizers[table_name] - - if ( - "table_format" in table_schema - and table_schema["table_format"] not in supported_table_formats - ): - logger.warning( - "Destination does not support the configured `table_format` value " - f"`{table_schema['table_format']}` for table `{table_schema['name']}`. " - "The setting will probably be ignored." - ) - - items_preferred_file_format = preferred_file_format - items_supported_file_formats = supported_file_formats - if destination_caps.loader_file_format_adapter is not None: - items_preferred_file_format, items_supported_file_formats = ( - destination_caps.loader_file_format_adapter( - preferred_file_format, - ( - supported_file_formats.copy() - if isinstance(supported_file_formats, list) - else supported_file_formats - ), - table_schema=table_schema, - ) - ) - - # force file format - best_writer_spec = None - if config.loader_file_format: - if config.loader_file_format in items_supported_file_formats: - # TODO: pass supported_file_formats, when used in pipeline we already checked that - # but if normalize is used standalone `supported_loader_file_formats` may be unresolved - best_writer_spec = get_best_writer_spec( - item_format, config.loader_file_format - ) - else: - logger.warning( - f"The configured value `{config.loader_file_format}` " - "for `loader_file_format` is not supported for table " - f"`{table_schema['name']}` and will be ignored. Dlt " - "will use a supported format instead." - ) - - if best_writer_spec is None: - # find best spec among possible formats taking into account destination preference - best_writer_spec = resolve_best_writer_spec( - item_format, items_supported_file_formats, items_preferred_file_format - ) - # if best_writer_spec.file_format != preferred_file_format: - # logger.warning( - # f"For data items yielded as {item_format} jobs in file format" - # f" {preferred_file_format} cannot be created." - # f" {best_writer_spec.file_format} jobs will be used instead." - # " This may decrease the performance." - # ) - item_storage = load_storage.create_item_storage(best_writer_spec) - if not is_native_writer(item_storage.writer_cls): - logger.warning( - f"For data items yielded as {item_format} and job file format" - f" {best_writer_spec.file_format} native writer could not be found. A" - f" {item_storage.writer_cls.__name__} writer is used that internally" - f" converts {item_format}. This will degrade performance." - ) - cls = ArrowItemsNormalizer if item_format == "arrow" else JsonLItemsNormalizer - logger.info( - f"Created items normalizer {cls.__name__} with writer" - f" {item_storage.writer_cls.__name__} for item format {item_format} and file" - f" format {item_storage.writer_spec.file_format}" - ) - norm = item_normalizers[table_name] = cls( - item_storage, - normalize_storage, - schema, - load_id, - config, - ) - return norm - - def _gather_metrics_and_close( - parsed_fn: ParsedLoadJobFileName, in_exception: bool - ) -> List[DataWriterMetrics]: - writer_metrics: List[DataWriterMetrics] = [] - try: - try: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id, skip_flush=in_exception) - except Exception: - # if we had exception during flushing the writers, close them without flushing - if not in_exception: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id, skip_flush=True) - raise - finally: - # always gather metrics - for normalizer in item_normalizers.values(): - norm_metrics = normalizer.item_storage.closed_files(load_id) - writer_metrics.extend(norm_metrics) - for normalizer in item_normalizers.values(): - normalizer.item_storage.remove_closed_files(load_id) - except Exception as exc: - if in_exception: - # swallow exception if we already handle exceptions - return writer_metrics - else: - # enclose the exception during the closing in job failed exception - job_id = parsed_fn.job_id() if parsed_fn else "" - raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) - return writer_metrics - - parsed_file_name: ParsedLoadJobFileName = None - try: - root_tables: Set[str] = set() - for extracted_items_file in extracted_items_files: - parsed_file_name = ParsedLoadJobFileName.parse(extracted_items_file) - # normalize table name in case the normalization changed - # NOTE: this is the best we can do, until a full lineage information is in the schema - root_table_name = schema.naming.normalize_table_identifier( - parsed_file_name.table_name - ) - root_tables.add(root_table_name) - normalizer = _get_items_normalizer( - DataWriter.item_format_from_file_extension(parsed_file_name.file_format), - stored_schema["tables"].get(root_table_name, {"name": root_table_name}), - ) - logger.debug( - f"Processing extracted items in {extracted_items_file} in load_id" - f" {load_id} with table name {root_table_name} and schema {schema.name}" - ) - partial_updates = normalizer(extracted_items_file, root_table_name) - schema_updates.extend(partial_updates) - logger.debug(f"Processed file {extracted_items_file}") - except Exception as exc: - job_id = parsed_file_name.job_id() if parsed_file_name else "" - writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=True) - raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc - else: - writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=False) - - logger.info(f"Processed all items in {len(extracted_items_files)} files") - return TWorkerRV(schema_updates, writer_metrics) - - def update_table(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: + def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: for schema_update in schema_updates: for table_name, table_updates in schema_update.items(): logger.info( f"Updating schema for table {table_name} with {len(table_updates)} deltas" ) for partial_table in table_updates: - # merge columns - schema.update_table(partial_table) - - @staticmethod - def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[str]]: - # sort files so the same tables are in the same worker - files = list(sorted(files)) - - chunk_size = max(len(files) // no_groups, 1) - chunk_files = list(chunks(files, chunk_size)) - # distribute the remainder files to existing groups starting from the end - remainder_l = len(chunk_files) - no_groups - l_idx = 0 - while remainder_l > 0: - for idx, file in enumerate(reversed(chunk_files.pop())): - chunk_files[-l_idx - idx - remainder_l].append(file) # type: ignore - remainder_l -= 1 - l_idx = idx + 1 - return chunk_files + # merge columns where we expect identifiers to be normalized + schema.update_table(partial_table, normalize_identifiers=False) def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: workers: int = getattr(self.pool, "_max_workers", 1) - chunk_files = self.group_worker_files(files, workers) + chunk_files = group_worker_files(files, workers) schema_dict: TStoredSchema = schema.to_dict() param_chunk = [ ( @@ -319,10 +107,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW # return stats summary = TWorkerRV([], []) # push all tasks to queue - tasks = [ - (self.pool.submit(Normalize.w_normalize_files, *params), params) - for params in param_chunk - ] + tasks = [(self.pool.submit(w_normalize_files, *params), params) for params in param_chunk] while len(tasks) > 0: sleep(0.3) @@ -337,7 +122,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW result: TWorkerRV = pending.result() try: # gather schema from all manifests, validate consistency and combine - self.update_table(schema, result[0]) + self.update_schema(schema, result[0]) summary.schema_updates.extend(result.schema_updates) summary.file_metrics.extend(result.file_metrics) # update metrics @@ -358,7 +143,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW # TODO: it's time for a named tuple params = params[:3] + (schema_dict,) + params[4:] retry_pending: Future[TWorkerRV] = self.pool.submit( - Normalize.w_normalize_files, *params + w_normalize_files, *params ) tasks.append((retry_pending, params)) # remove finished tasks @@ -368,7 +153,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW return summary def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: - result = Normalize.w_normalize_files( + result = w_normalize_files( self.config, self.normalize_storage.config, self.load_storage.config, @@ -376,7 +161,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWor load_id, files, ) - self.update_table(schema, result.schema_updates) + self.update_schema(schema, result.schema_updates) self.collector.update("Files", len(result.file_metrics)) self.collector.update( "Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count @@ -399,7 +184,7 @@ def spool_files( # update normalizer specific info for table_name in table_metrics: table = schema.tables[table_name] - x_normalizer = table.setdefault("x-normalizer", {}) # type: ignore[typeddict-item] + x_normalizer = table.setdefault("x-normalizer", {}) # drop evolve once for all tables that seen data x_normalizer.pop("evolve-columns-once", None) # mark that table have seen data only if there was data diff --git a/dlt/normalize/worker.py b/dlt/normalize/worker.py new file mode 100644 index 0000000000..10d0a00eb1 --- /dev/null +++ b/dlt/normalize/worker.py @@ -0,0 +1,255 @@ +from typing import Callable, List, Dict, NamedTuple, Sequence, Set, Optional, Type + +from dlt.common import logger +from dlt.common.configuration.container import Container +from dlt.common.data_writers import ( + DataWriter, + DataWriterMetrics, + create_import_spec, + resolve_best_writer_spec, + get_best_writer_spec, + is_native_writer, +) +from dlt.common.utils import chunks +from dlt.common.schema.typing import TStoredSchema, TTableSchema +from dlt.common.storages import ( + NormalizeStorage, + LoadStorage, + LoadStorageConfiguration, + NormalizeStorageConfiguration, + ParsedLoadJobFileName, +) +from dlt.common.schema import TSchemaUpdate, Schema + +from dlt.normalize.configuration import NormalizeConfiguration +from dlt.normalize.exceptions import NormalizeJobFailed +from dlt.normalize.items_normalizers import ( + ArrowItemsNormalizer, + FileImportNormalizer, + JsonLItemsNormalizer, + ItemsNormalizer, +) + + +class TWorkerRV(NamedTuple): + schema_updates: List[TSchemaUpdate] + file_metrics: List[DataWriterMetrics] + + +def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[str]]: + # sort files so the same tables are in the same worker + files = list(sorted(files)) + + chunk_size = max(len(files) // no_groups, 1) + chunk_files = list(chunks(files, chunk_size)) + # distribute the remainder files to existing groups starting from the end + remainder_l = len(chunk_files) - no_groups + l_idx = 0 + while remainder_l > 0: + idx = 0 + for idx, file in enumerate(reversed(chunk_files.pop())): + chunk_files[-l_idx - idx - remainder_l].append(file) # type: ignore + remainder_l -= 1 + l_idx = idx + 1 + return chunk_files + + +def w_normalize_files( + config: NormalizeConfiguration, + normalize_storage_config: NormalizeStorageConfiguration, + loader_storage_config: LoadStorageConfiguration, + stored_schema: TStoredSchema, + load_id: str, + extracted_items_files: Sequence[str], +) -> TWorkerRV: + destination_caps = config.destination_capabilities + schema_updates: List[TSchemaUpdate] = [] + # normalizers are cached per table name + item_normalizers: Dict[str, ItemsNormalizer] = {} + + preferred_file_format = ( + destination_caps.preferred_loader_file_format + or destination_caps.preferred_staging_file_format + ) + # TODO: capabilities.supported_*_formats can be None, it should have defaults + supported_file_formats = destination_caps.supported_loader_file_formats or [] + supported_table_formats = destination_caps.supported_table_formats or [] + + # process all files with data items and write to buffered item storage + with Container().injectable_context(destination_caps): + schema = Schema.from_stored_schema(stored_schema) + normalize_storage = NormalizeStorage(False, normalize_storage_config) + load_storage = LoadStorage(False, supported_file_formats, loader_storage_config) + + def _get_items_normalizer( + parsed_file_name: ParsedLoadJobFileName, table_schema: TTableSchema + ) -> ItemsNormalizer: + item_format = DataWriter.item_format_from_file_extension(parsed_file_name.file_format) + + table_name = table_schema["name"] + if table_name in item_normalizers: + return item_normalizers[table_name] + + if ( + "table_format" in table_schema + and table_schema["table_format"] not in supported_table_formats + ): + logger.warning( + "Destination does not support the configured `table_format` value " + f"`{table_schema['table_format']}` for table `{table_schema['name']}`. " + "The setting will probably be ignored." + ) + + items_preferred_file_format = preferred_file_format + items_supported_file_formats = supported_file_formats + if destination_caps.loader_file_format_adapter is not None: + items_preferred_file_format, items_supported_file_formats = ( + destination_caps.loader_file_format_adapter( + preferred_file_format, + ( + supported_file_formats.copy() + if isinstance(supported_file_formats, list) + else supported_file_formats + ), + table_schema=table_schema, + ) + ) + + best_writer_spec = None + if item_format == "file": + # if we want to import file, create a spec that may be used only for importing + best_writer_spec = create_import_spec( + parsed_file_name.file_format, items_supported_file_formats # type: ignore[arg-type] + ) + + config_loader_file_format = config.loader_file_format + if file_format := table_schema.get("file_format"): + # resource has a file format defined so use it + if file_format == "preferred": + # use destination preferred + config_loader_file_format = items_preferred_file_format + else: + # use resource format + config_loader_file_format = file_format + logger.info( + f"A file format for table {table_name} was specified to {file_format} in the" + f" resource so {config_loader_file_format} format being used." + ) + + if config_loader_file_format and best_writer_spec is None: + # force file format + if config_loader_file_format in items_supported_file_formats: + # TODO: pass supported_file_formats, when used in pipeline we already checked that + # but if normalize is used standalone `supported_loader_file_formats` may be unresolved + best_writer_spec = get_best_writer_spec(item_format, config_loader_file_format) + else: + logger.warning( + f"The configured value `{config_loader_file_format}` " + "for `loader_file_format` is not supported for table " + f"`{table_name}` and will be ignored. Dlt " + "will use a supported format instead." + ) + + if best_writer_spec is None: + # find best spec among possible formats taking into account destination preference + best_writer_spec = resolve_best_writer_spec( + item_format, items_supported_file_formats, items_preferred_file_format + ) + # if best_writer_spec.file_format != preferred_file_format: + # logger.warning( + # f"For data items yielded as {item_format} jobs in file format" + # f" {preferred_file_format} cannot be created." + # f" {best_writer_spec.file_format} jobs will be used instead." + # " This may decrease the performance." + # ) + item_storage = load_storage.create_item_storage(best_writer_spec) + if not is_native_writer(item_storage.writer_cls): + logger.warning( + f"For data items in `{table_name}` yielded as {item_format} and job file format" + f" {best_writer_spec.file_format} native writer could not be found. A" + f" {item_storage.writer_cls.__name__} writer is used that internally" + f" converts {item_format}. This will degrade performance." + ) + cls: Type[ItemsNormalizer] + if item_format == "arrow": + cls = ArrowItemsNormalizer + elif item_format == "object": + cls = JsonLItemsNormalizer + else: + cls = FileImportNormalizer + logger.info( + f"Created items normalizer {cls.__name__} with writer" + f" {item_storage.writer_cls.__name__} for item format {item_format} and file" + f" format {item_storage.writer_spec.file_format}" + ) + norm = item_normalizers[table_name] = cls( + item_storage, + normalize_storage, + schema, + load_id, + config, + ) + return norm + + def _gather_metrics_and_close( + parsed_fn: ParsedLoadJobFileName, in_exception: bool + ) -> List[DataWriterMetrics]: + writer_metrics: List[DataWriterMetrics] = [] + try: + try: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=in_exception) + except Exception: + # if we had exception during flushing the writers, close them without flushing + if not in_exception: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=True) + raise + finally: + # always gather metrics + for normalizer in item_normalizers.values(): + norm_metrics = normalizer.item_storage.closed_files(load_id) + writer_metrics.extend(norm_metrics) + for normalizer in item_normalizers.values(): + normalizer.item_storage.remove_closed_files(load_id) + except Exception as exc: + if in_exception: + # swallow exception if we already handle exceptions + return writer_metrics + else: + # enclose the exception during the closing in job failed exception + job_id = parsed_fn.job_id() if parsed_fn else "" + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) + return writer_metrics + + parsed_file_name: ParsedLoadJobFileName = None + try: + root_tables: Set[str] = set() + for extracted_items_file in extracted_items_files: + parsed_file_name = ParsedLoadJobFileName.parse(extracted_items_file) + # normalize table name in case the normalization changed + # NOTE: this is the best we can do, until a full lineage information is in the schema + root_table_name = schema.naming.normalize_table_identifier( + parsed_file_name.table_name + ) + root_tables.add(root_table_name) + normalizer = _get_items_normalizer( + parsed_file_name, + stored_schema["tables"].get(root_table_name, {"name": root_table_name}), + ) + logger.debug( + f"Processing extracted items in {extracted_items_file} in load_id" + f" {load_id} with table name {root_table_name} and schema {schema.name}" + ) + partial_updates = normalizer(extracted_items_file, root_table_name) + schema_updates.extend(partial_updates) + logger.debug(f"Processed file {extracted_items_file}") + except Exception as exc: + job_id = parsed_file_name.job_id() if parsed_file_name else "" + writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=True) + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc + else: + writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=False) + + logger.info(f"Processed all items in {len(extracted_items_files)} files") + return TWorkerRV(schema_updates, writer_metrics) diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 20ba0b07d0..8041ca72e0 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -14,7 +14,7 @@ from dlt.pipeline.configuration import PipelineConfiguration, ensure_correct_pipeline_kwargs from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR -from dlt.pipeline.warnings import credentials_argument_deprecated, full_refresh_argument_deprecated +from dlt.pipeline.warnings import full_refresh_argument_deprecated TPipeline = TypeVar("TPipeline", bound=Pipeline, default=Pipeline) @@ -32,7 +32,6 @@ def pipeline( full_refresh: Optional[bool] = None, dev_mode: bool = False, refresh: Optional[TRefreshMode] = None, - credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, _impl_cls: Type[TPipeline] = Pipeline, # type: ignore[assignment] ) -> TPipeline: @@ -78,9 +77,6 @@ def pipeline( * `drop_resources`: Drop tables and resource state for all resources being processed. Source level state is not modified. (Note: schema history is erased) * `drop_data`: Wipe all data and resource state for all resources being processed. Schema is not modified. - credentials (Any, optional): Credentials for the `destination` ie. database connection string or a dictionary with google cloud credentials. - In most cases should be set to None, which lets `dlt` to use `secrets.toml` or environment variables to infer right credentials values. - progress(str, Collector): A progress monitor that shows progress bars, console or log messages with current information on sources, resources, data items etc. processed in `extract`, `normalize` and `load` stage. Pass a string with a collector name or configure your own by choosing from `dlt.progress` module. We support most of the progress libraries: try passing `tqdm`, `enlighten` or `alive_progress` or `log` to write to console/log. @@ -109,7 +105,6 @@ def pipeline( full_refresh: Optional[bool] = None, dev_mode: bool = False, refresh: Optional[TRefreshMode] = None, - credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, _impl_cls: Type[TPipeline] = Pipeline, # type: ignore[assignment] **injection_kwargs: Any, @@ -120,7 +115,6 @@ def pipeline( # is any of the arguments different from defaults has_arguments = bool(orig_args[0]) or any(orig_args[1].values()) - credentials_argument_deprecated("pipeline", credentials, destination) full_refresh_argument_deprecated("pipeline", full_refresh) if not has_arguments: @@ -153,7 +147,6 @@ def pipeline( destination, staging, dataset_name, - credentials, import_schema_path, export_schema_path, full_refresh if full_refresh is not None else dev_mode, @@ -173,31 +166,38 @@ def attach( pipeline_name: str = None, pipelines_dir: str = None, pipeline_salt: TSecretValue = None, - full_refresh: Optional[bool] = None, - dev_mode: bool = False, - credentials: Any = None, + destination: TDestinationReferenceArg = None, + staging: TDestinationReferenceArg = None, progress: TCollectorArg = _NULL_COLLECTOR, **injection_kwargs: Any, ) -> Pipeline: - """Attaches to the working folder of `pipeline_name` in `pipelines_dir` or in default directory. Requires that valid pipeline state exists in working folder.""" + """Attaches to the working folder of `pipeline_name` in `pipelines_dir` or in default directory. Requires that valid pipeline state exists in working folder. + Pre-configured `destination` and `staging` factories may be provided. If not present, default factories are created from pipeline state. + """ ensure_correct_pipeline_kwargs(attach, **injection_kwargs) - full_refresh_argument_deprecated("attach", full_refresh) # if working_dir not provided use temp folder if not pipelines_dir: pipelines_dir = get_dlt_pipelines_dir() progress = collector_from_name(progress) + destination = Destination.from_reference( + destination or injection_kwargs["destination_type"], + destination_name=injection_kwargs["destination_name"], + ) + staging = Destination.from_reference( + staging or injection_kwargs.get("staging_type", None), + destination_name=injection_kwargs.get("staging_name", None), + ) # create new pipeline instance p = Pipeline( pipeline_name, pipelines_dir, pipeline_salt, + destination, + staging, None, None, None, - credentials, - None, - None, - full_refresh if full_refresh is not None else dev_mode, + False, # always False as dev_mode so we do not wipe the working folder progress, True, last_config(**injection_kwargs), @@ -214,7 +214,6 @@ def run( destination: TDestinationReferenceArg = None, staging: TDestinationReferenceArg = None, dataset_name: str = None, - credentials: Any = None, table_name: str = None, write_disposition: TWriteDispositionConfig = None, columns: Sequence[TColumnSchema] = None, @@ -249,9 +248,6 @@ def run( dataset_name (str, optional):A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. If not provided, the value passed to `dlt.pipeline` will be used. If not provided at all then defaults to the `pipeline_name` - credentials (Any, optional): Credentials for the `destination` ie. database connection string or a dictionary with google cloud credentials. - In most cases should be set to None, which lets `dlt` to use `secrets.toml` or environment variables to infer right credentials values. - table_name (str, optional): The name of the table to which the data should be loaded within the `dataset`. This argument is required for a `data` that is a list/Iterable or Iterator without `__name__` attribute. The behavior of this argument depends on the type of the `data`: * generator functions: the function name is used as table name, `table_name` overrides this default @@ -272,13 +268,12 @@ def run( Returns: LoadInfo: Information on loaded data including the list of package ids and failed job statuses. Please not that `dlt` will not raise if a single job terminally fails. Such information is provided via LoadInfo. """ - destination = Destination.from_reference(destination, credentials=credentials) + destination = Destination.from_reference(destination) return pipeline().run( data, destination=destination, staging=staging, dataset_name=dataset_name, - credentials=credentials, table_name=table_name, write_disposition=write_disposition, columns=columns, diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 235ba3485a..723e0ded83 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -1,11 +1,13 @@ from typing import Any, Optional +import dlt from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration from dlt.common.typing import AnyFun, TSecretValue from dlt.common.utils import digest256 from dlt.common.destination import TLoaderFileFormat from dlt.common.pipeline import TRefreshMode +from dlt.common.configuration.exceptions import ConfigurationValueError @configspec @@ -18,6 +20,8 @@ class PipelineConfiguration(BaseConfiguration): staging_name: Optional[str] = None loader_file_format: Optional[TLoaderFileFormat] = None dataset_name: Optional[str] = None + dataset_name_layout: Optional[str] = None + """Layout for dataset_name, where %s is replaced with dataset_name. For example: 'prefix_%s'""" pipeline_salt: Optional[TSecretValue] = None restore_from_destination: bool = True """Enables the `run` method of the `Pipeline` object to restore the pipeline state and schemas from the destination""" @@ -41,6 +45,11 @@ def on_resolved(self) -> None: self.runtime.pipeline_name = self.pipeline_name if not self.pipeline_salt: self.pipeline_salt = TSecretValue(digest256(self.pipeline_name)) + if self.dataset_name_layout and "%s" not in self.dataset_name_layout: + raise ConfigurationValueError( + "The dataset_name_layout must contain a '%s' placeholder for dataset_name. For" + " example: 'prefix_%s'" + ) def ensure_correct_pipeline_kwargs(f: AnyFun, **kwargs: Any) -> None: diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index 25fd398623..2ae74e2532 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -1,7 +1,7 @@ """Easy access to active pipelines, state, sources and schemas""" from dlt.common.pipeline import source_state as _state, resource_state, get_current_pipe_name -from dlt.pipeline import pipeline as _pipeline +from dlt.pipeline.pipeline import Pipeline from dlt.extract.decorators import get_source_schema from dlt.common.storages.load_package import ( load_package, @@ -11,10 +11,15 @@ ) from dlt.extract.decorators import get_source_schema, get_source -pipeline = _pipeline -"""Alias for dlt.pipeline""" + +def pipeline() -> Pipeline: + """Currently active pipeline ie. the most recently created or run""" + from dlt import _pipeline + + return _pipeline() + + state = source_state = _state -"""Alias for dlt.state""" source_schema = get_source_schema source = get_source pipe_name = get_current_pipe_name diff --git a/dlt/pipeline/dbt.py b/dlt/pipeline/dbt.py index ee900005fd..0b6ec5f896 100644 --- a/dlt/pipeline/dbt.py +++ b/dlt/pipeline/dbt.py @@ -38,7 +38,7 @@ def get_venv( # keep venv inside pipeline if path is relative if not os.path.isabs(venv_path): pipeline._pipeline_storage.create_folder(venv_path, exists_ok=True) - venv_dir = pipeline._pipeline_storage.make_full_path(venv_path) + venv_dir = pipeline._pipeline_storage.make_full_path_safe(venv_path) else: venv_dir = venv_path # try to restore existing venv diff --git a/dlt/pipeline/drop.py b/dlt/pipeline/drop.py index 486bead2f4..cd982cf676 100644 --- a/dlt/pipeline/drop.py +++ b/dlt/pipeline/drop.py @@ -17,6 +17,7 @@ group_tables_by_resource, compile_simple_regexes, compile_simple_regex, + has_table_seen_data, ) from dlt.common import jsonpath from dlt.common.typing import REPattern @@ -24,11 +25,11 @@ class _DropInfo(TypedDict): tables: List[str] + tables_with_data: List[str] resource_states: List[str] resource_names: List[str] state_paths: List[str] schema_name: str - dataset_name: Optional[str] drop_all: bool resource_pattern: Optional[REPattern] warnings: List[str] @@ -39,7 +40,7 @@ class _DropResult: schema: Schema state: TPipelineState info: _DropInfo - dropped_tables: List[TTableSchema] + modified_tables: List[TTableSchema] def _create_modified_state( @@ -85,12 +86,12 @@ def drop_resources( """Generate a new schema and pipeline state with the requested resources removed. Args: - schema: The schema to modify. - state: The pipeline state to modify. + schema: The schema to modify. Note that schema is changed in place. + state: The pipeline state to modify. Note that state is changed in place. resources: Resource name(s) or regex pattern(s) matching resource names to drop. If empty, no resources will be dropped unless `drop_all` is True. state_paths: JSON path(s) relative to the source state to drop. - drop_all: If True, all resources will be dropped (supeseeds `resources`). + drop_all: If True, all resources will be dropped (supersedes `resources`). state_only: If True, only modify the pipeline state, not schema sources: Only wipe state for sources matching the name(s) or regex pattern(s) in this list If not set all source states will be modified according to `state_paths` and `resources` @@ -112,9 +113,6 @@ def drop_resources( state_paths = jsonpath.compile_paths(state_paths) - schema = schema.clone() - state = deepcopy(state) - resources = set(resources) if drop_all: resource_pattern = compile_simple_regex(TSimpleRegex("re:.*")) # Match everything @@ -128,28 +126,28 @@ def drop_resources( source_pattern = compile_simple_regex(TSimpleRegex("re:.*")) # Match everything if resource_pattern: - data_tables = { - t["name"]: t for t in schema.data_tables(seen_data_only=True) - } # Don't remove _dlt tables + # (1) Don't remove _dlt tables (2) Drop all selected tables from the schema + # (3) Mark tables that seen data to be dropped in destination + data_tables = {t["name"]: t for t in schema.data_tables(include_incomplete=True)} resource_tables = group_tables_by_resource(data_tables, pattern=resource_pattern) resource_names = list(resource_tables.keys()) - # TODO: If drop_tables - if not state_only: - tables_to_drop = list(chain.from_iterable(resource_tables.values())) - tables_to_drop.reverse() - else: - tables_to_drop = [] + tables_to_drop_from_schema = list(chain.from_iterable(resource_tables.values())) + tables_to_drop_from_schema.reverse() + tables_to_drop_from_schema_names = [t["name"] for t in tables_to_drop_from_schema] + tables_to_drop_from_dest = [t for t in tables_to_drop_from_schema if has_table_seen_data(t)] else: - tables_to_drop = [] + tables_to_drop_from_schema_names = [] + tables_to_drop_from_dest = [] + tables_to_drop_from_schema = [] resource_names = [] info: _DropInfo = dict( - tables=[t["name"] for t in tables_to_drop], + tables=tables_to_drop_from_schema_names if not state_only else [], + tables_with_data=[t["name"] for t in tables_to_drop_from_dest] if not state_only else [], resource_states=[], state_paths=[], - resource_names=resource_names, + resource_names=resource_names if not state_only else [], schema_name=schema.name, - dataset_name=None, drop_all=drop_all, resource_pattern=resource_pattern, warnings=[], @@ -158,7 +156,7 @@ def drop_resources( new_state, info = _create_modified_state( state, resource_pattern, source_pattern, state_paths, info ) - info["resource_names"] = resource_names + # info["resource_names"] = resource_names if not state_only else [] if resource_pattern and not resource_tables: info["warnings"].append( @@ -167,5 +165,7 @@ def drop_resources( f" {list(group_tables_by_resource(data_tables).keys())}" ) - dropped_tables = schema.drop_tables([t["name"] for t in tables_to_drop], seen_data_only=True) - return _DropResult(schema, new_state, info, dropped_tables) + if not state_only: + # drop only the selected tables + schema.drop_tables(tables_to_drop_from_schema_names) + return _DropResult(schema, new_state, info, tables_to_drop_from_dest) diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index 0defbc14eb..ce81b81433 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -12,8 +12,10 @@ from dlt.common.jsonpath import TAnyJsonPath from dlt.common.exceptions import TerminalException +from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TSimpleRegex from dlt.common.pipeline import pipeline_state as current_pipeline_state, TRefreshMode +from dlt.common.storages.load_package import TLoadPackageDropTablesState from dlt.pipeline.exceptions import ( PipelineNeverRan, PipelineStepFailed, @@ -83,24 +85,24 @@ def __init__( if not pipeline.default_schema_name: raise PipelineNeverRan(pipeline.pipeline_name, pipeline.pipelines_dir) + # clone schema to keep it as original in case we need to restore pipeline schema self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name].clone() drop_result = drop_resources( - # self._drop_schema, self._new_state, self.info = drop_resources( - self.schema, - pipeline.state, + # create clones to have separate schemas and state + self.schema.clone(), + deepcopy(pipeline.state), resources, state_paths, drop_all, state_only, ) - + # get modified schema and state self._new_state = drop_result.state - self.info = drop_result.info self._new_schema = drop_result.schema - self._dropped_tables = drop_result.dropped_tables - self.drop_tables = not state_only and bool(self._dropped_tables) - + self.info = drop_result.info + self._modified_tables = drop_result.modified_tables + self.drop_tables = not state_only and bool(self._modified_tables) self.drop_state = bool(drop_all or resources or state_paths) @property @@ -130,7 +132,9 @@ def __call__(self) -> None: self.pipeline._save_and_extract_state_and_schema( new_state, schema=self._new_schema, - load_package_state_update={"dropped_tables": self._dropped_tables}, + load_package_state_update=( + {"dropped_tables": self._modified_tables} if self.drop_tables else None + ), ) self.pipeline.normalize() @@ -159,30 +163,33 @@ def drop( def refresh_source( pipeline: "Pipeline", source: DltSource, refresh: TRefreshMode -) -> Dict[str, Any]: - """Run the pipeline's refresh mode on the given source, updating the source's schema and state. +) -> TLoadPackageDropTablesState: + """Run the pipeline's refresh mode on the given source, updating the provided `schema` and pipeline state. Returns: The new load package state containing tables that need to be dropped/truncated. """ - if pipeline.first_run: - return {} pipeline_state, _ = current_pipeline_state(pipeline._container) _resources_to_drop = list(source.resources.extracted) if refresh != "drop_sources" else [] + only_truncate = refresh == "drop_data" + drop_result = drop_resources( + # do not cline the schema, change in place source.schema, + # do not clone the state, change in place pipeline_state, resources=_resources_to_drop, drop_all=refresh == "drop_sources", state_paths="*" if refresh == "drop_sources" else [], + state_only=only_truncate, sources=source.name, ) - load_package_state = {} - if drop_result.dropped_tables: - key = "dropped_tables" if refresh != "drop_data" else "truncated_tables" - load_package_state[key] = drop_result.dropped_tables - if refresh != "drop_data": # drop_data is only data wipe, keep original schema - source.schema = drop_result.schema - if "sources" in drop_result.state: - pipeline_state["sources"] = drop_result.state["sources"] + load_package_state: TLoadPackageDropTablesState = {} + if drop_result.modified_tables: + if only_truncate: + load_package_state["truncated_tables"] = drop_result.modified_tables + else: + load_package_state["dropped_tables"] = drop_result.modified_tables + # if any tables should be dropped, we force state to extract + force_state_extract(pipeline_state) return load_package_state diff --git a/dlt/pipeline/mark.py b/dlt/pipeline/mark.py index 3956d9bbe2..5f3122e7a5 100644 --- a/dlt/pipeline/mark.py +++ b/dlt/pipeline/mark.py @@ -2,6 +2,7 @@ from dlt.extract import ( with_table_name, with_hints, + with_file_import, make_hints, materialize_schema_item as materialize_table_schema, ) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 9e068df7a9..1139302e70 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1,7 +1,7 @@ import contextlib import os -import datetime # noqa: 251 from contextlib import contextmanager +from copy import deepcopy, copy from functools import wraps from typing import ( Any, @@ -12,7 +12,6 @@ Optional, Sequence, Tuple, - Type, cast, get_type_hints, ContextManager, @@ -24,21 +23,21 @@ from dlt.common.json import json from dlt.common.pendulum import pendulum from dlt.common.configuration import inject_section, known_sections -from dlt.common.configuration.specs import RunConfiguration, CredentialsConfiguration +from dlt.common.configuration.specs import RunConfiguration from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ( ConfigFieldMissingException, ContextDefaultCannotBeCreated, + ConfigurationValueError, ) from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.configuration.resolve import initialize_credentials from dlt.common.destination.exceptions import ( DestinationIncompatibleLoaderFileFormatException, DestinationNoStagingMode, DestinationUndefinedEntity, + DestinationCapabilitiesException, ) from dlt.common.exceptions import MissingDependencyException -from dlt.common.normalizers import explicit_normalizers, import_normalizers from dlt.common.runtime import signals, initialize_runtime from dlt.common.schema.typing import ( TColumnNames, @@ -46,7 +45,6 @@ TWriteDispositionConfig, TAnySchemaColumns, TSchemaContract, - TTableSchema, ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound @@ -84,12 +82,12 @@ DestinationClientStagingConfiguration, DestinationClientDwhWithStagingConfiguration, ) +from dlt.common.normalizers.naming import NamingConvention from dlt.common.pipeline import ( ExtractInfo, LoadInfo, NormalizeInfo, PipelineContext, - StepInfo, TStepInfo, SupportsPipeline, TPipelineLocalState, @@ -104,7 +102,7 @@ from dlt.common.warnings import deprecated, Dlt04DeprecationWarning from dlt.common.versioned_state import json_encode_state, json_decode_state -from dlt.extract import DltSource, DltResource +from dlt.extract import DltSource from dlt.extract.exceptions import SourceExhausted from dlt.extract.extract import Extract, data_to_sources from dlt.normalize import Normalize @@ -125,7 +123,6 @@ PipelineStepFailed, SqlClientNotAvailable, FSClientNotAvailable, - PipelineNeverRan, ) from dlt.pipeline.trace import ( PipelineTrace, @@ -148,7 +145,6 @@ state_resource, default_pipeline_state, ) -from dlt.pipeline.warnings import credentials_argument_deprecated from dlt.common.storages.load_package import TLoadPackageState from dlt.pipeline.helpers import refresh_source from dlt.dataset import Dataset @@ -163,10 +159,8 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # backup and restore state should_extract_state = may_extract_state and self.config.restore_from_destination - with self.managed_state(extract_state=should_extract_state) as state: - # add the state to container as a context - with self._container.injectable_context(StateInjectableContext(state=state)): - return f(self, *args, **kwargs) + with self.managed_state(extract_state=should_extract_state): + return f(self, *args, **kwargs) return _wrap # type: ignore @@ -311,7 +305,6 @@ class Pipeline(SupportsPipeline): """The destination reference which is the Destination Class. `destination.destination_name` returns the name string""" dataset_name: str = None """Name of the dataset to which pipeline will be loaded to""" - credentials: Any = None is_active: bool = False """Tells if instance is currently active and available via dlt.pipeline()""" collector: _Collector @@ -327,7 +320,6 @@ def __init__( destination: TDestination, staging: TDestination, dataset_name: str, - credentials: Any, import_schema_path: str, export_schema_path: str, dev_mode: bool, @@ -361,14 +353,13 @@ def __init__( self._init_working_dir(pipeline_name, pipelines_dir) with self.managed_state() as state: + self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline) # changing the destination could be dangerous if pipeline has pending load packages - self._set_destinations(destination=destination, staging=staging) + self._set_destinations(destination=destination, staging=staging, initializing=True) # set the pipeline properties from state, destination and staging will not be set self._state_to_props(state) # we overwrite the state with the values from init self._set_dataset_name(dataset_name) - self.credentials = credentials - self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline) def drop(self, pipeline_name: str = None) -> "Pipeline": """Deletes local pipeline state, schemas and any working files. @@ -386,7 +377,6 @@ def drop(self, pipeline_name: str = None) -> "Pipeline": self.destination, self.staging, self.dataset_name, - self.credentials, self._schema_storage.config.import_schema_path, self._schema_storage.config.export_schema_path, self.dev_mode, @@ -454,15 +444,14 @@ def extract( resource.dataset = self.dataset resource.dataset.bind_table_name(resource.table_name) - # extract state - state: TPipelineStateDoc = None - if self.config.restore_from_destination: - # this will update state version hash so it will not be extracted again by with_state_sync - state = self._bump_version_and_extract_state( - self._container[StateInjectableContext].state, True, extract_step - ) + # this will update state version hash so it will not be extracted again by with_state_sync + self._bump_version_and_extract_state( + self._container[StateInjectableContext].state, + self.config.restore_from_destination, + extract_step, + ) # commit load packages with state - extract_step.commit_packages(state) + extract_step.commit_packages() return self._get_step_info(extract_step) except Exception as exc: # emit step info @@ -476,6 +465,32 @@ def extract( step_info, ) from exc + def _verify_destination_capabilities( + self, + caps: DestinationCapabilitiesContext, + loader_file_format: TLoaderFileFormat, + ) -> None: + # verify loader file format + if loader_file_format and loader_file_format not in caps.supported_loader_file_formats: + raise DestinationIncompatibleLoaderFileFormatException( + self.destination.destination_name, + (self.staging.destination_name if self.staging else None), + loader_file_format, + set(caps.supported_loader_file_formats), + ) + + # verify merge strategy + for table in self.default_schema.data_tables(include_incomplete=True): + # temp solution to prevent raising exceptions for destinations such as + # `fileystem` and `weaviate`, which do handle the `merge` write + # disposition, but don't implement any of the defined merge strategies + if caps.supported_merge_strategies is not None: + if "x-merge-strategy" in table and table["x-merge-strategy"] not in caps.supported_merge_strategies: # type: ignore[typeddict-item] + raise DestinationCapabilitiesException( + f"`{table.get('x-merge-strategy')}` merge strategy not supported" + f" for `{self.destination.destination_name}` destination." + ) + @with_runtime_trace() @with_schemas_sync @with_config_section((known_sections.NORMALIZE,)) @@ -505,13 +520,8 @@ def normalize( ) # run with destination context with self._maybe_destination_capabilities() as caps: - if loader_file_format and loader_file_format not in caps.supported_loader_file_formats: - raise DestinationIncompatibleLoaderFileFormatException( - self.destination.destination_name, - (self.staging.destination_name if self.staging else None), - loader_file_format, - set(caps.supported_loader_file_formats), - ) + self._verify_destination_capabilities(caps, loader_file_format) + # shares schema storage with the pipeline so we do not need to install normalize_step: Normalize = Normalize( collector=self.collector, @@ -544,15 +554,13 @@ def load( workers: int = 20, raise_on_failed_jobs: bool = False, ) -> LoadInfo: - """Loads the packages prepared by `normalize` method into the `dataset_name` at `destination`, using provided `credentials`""" + """Loads the packages prepared by `normalize` method into the `dataset_name` at `destination`, optionally using provided `credentials`""" # set destination and default dataset if provided (this is the reason we have state sync here) - self._set_destinations(destination=destination, staging=None) + self._set_destinations( + destination=destination, destination_credentials=credentials, staging=None + ) self._set_dataset_name(dataset_name) - credentials_argument_deprecated("pipeline.load", credentials, destination) - - self.credentials = credentials or self.credentials - # check if any schema is present, if not then no data was extracted if not self.default_schema_name: return None @@ -634,7 +642,6 @@ def run( dataset_name (str, optional):A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. If not provided, the value passed to `dlt.pipeline` will be used. If not provided at all then defaults to the `pipeline_name` - credentials (Any, optional): Credentials for the `destination` ie. database connection string or a dictionary with google cloud credentials. In most cases should be set to None, which lets `dlt` to use `secrets.toml` or environment variables to infer right credentials values. @@ -672,11 +679,11 @@ def run( signals.raise_if_signalled() self.activate() - self._set_destinations(destination=destination, staging=staging) + self._set_destinations( + destination=destination, destination_credentials=credentials, staging=staging + ) self._set_dataset_name(dataset_name) - credentials_argument_deprecated("pipeline.run", credentials, self.destination) - # sync state with destination if ( self.config.restore_from_destination @@ -743,6 +750,7 @@ def sync_destination( try: try: restored_schemas: Sequence[Schema] = None + remote_state = self._restore_state_from_destination() # if remote state is newer or same @@ -874,6 +882,11 @@ def state(self) -> TPipelineState: """Returns a dictionary with the pipeline state""" return self._get_state() + @property + def naming(self) -> NamingConvention: + """Returns naming convention of the default schema""" + return self._get_schema_or_create().naming + @property def last_trace(self) -> PipelineTrace: """Returns or loads last trace generated by pipeline. The trace is loaded from standard location.""" @@ -932,7 +945,7 @@ def drop_pending_packages(self, with_partial_loads: bool = True) -> None: normalize_storage.extracted_packages.delete_package(load_id) @with_schemas_sync - def sync_schema(self, schema_name: str = None, credentials: Any = None) -> TSchemaTables: + def sync_schema(self, schema_name: str = None) -> TSchemaTables: """Synchronizes the schema `schema_name` with the destination. If no name is provided, the default schema will be synchronized.""" if not schema_name and not self.default_schema_name: raise PipelineConfigMissing( @@ -968,7 +981,7 @@ def get_local_state_val(self, key: str) -> Any: state = self._get_state() return state["_local"][key] # type: ignore - def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlClientBase[Any]: + def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: """Returns a sql client configured to query/change the destination and dataset that were used to load the data. Use the client with `with` statement to manage opening and closing connection to the destination: >>> with pipeline.sql_client() as client: @@ -978,7 +991,7 @@ def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlCli >>> print(cursor.fetchall()) The client is authenticated and defaults all queries to dataset_name used by the pipeline. You can provide alternative - `schema_name` which will be used to normalize dataset name and alternative `credentials`. + `schema_name` which will be used to normalize dataset name. """ # if not self.default_schema_name and not schema_name: # raise PipelineConfigMissing( @@ -988,9 +1001,9 @@ def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlCli # "Sql Client is not available in a pipeline without a default schema. Extract some data first or restore the pipeline from the destination using 'restore_from_destination' flag. There's also `_inject_schema` method for advanced users." # ) schema = self._get_schema_or_create(schema_name) - return self._sql_job_client(schema, credentials).sql_client + return self._sql_job_client(schema).sql_client - def _fs_client(self, schema_name: str = None, credentials: Any = None) -> FSClientBase: + def _fs_client(self, schema_name: str = None) -> FSClientBase: """Returns a filesystem client configured to point to the right folder / bucket for each table. For example you may read all parquet files as bytes for one table with the following code: >>> files = pipeline._fs_client.list_table_files("customers") @@ -1002,18 +1015,18 @@ def _fs_client(self, schema_name: str = None, credentials: Any = None) -> FSClie NOTE: This currently is considered a private endpoint and will become stable after we have decided on the interface of FSClientBase. """ - client = self.destination_client(schema_name, credentials) + client = self.destination_client(schema_name) if isinstance(client, FSClientBase): return client raise FSClientNotAvailable(self.pipeline_name, self.destination.destination_name) - def destination_client(self, schema_name: str = None, credentials: Any = None) -> JobClientBase: + def destination_client(self, schema_name: str = None) -> JobClientBase: """Get the destination job client for the configured destination Use the client with `with` statement to manage opening and closing connection to the destination: >>> with pipeline.destination_client() as client: >>> client.drop_storage() # removes storage which typically wipes all data in it - The client is authenticated. You can provide alternative `schema_name` which will be used to normalize dataset name and alternative `credentials`. + The client is authenticated. You can provide alternative `schema_name` which will be used to normalize dataset name. If no schema name is provided and no default schema is present in the pipeline, and ad hoc schema will be created and discarded after use. """ schema = self._get_schema_or_create(schema_name) @@ -1027,8 +1040,8 @@ def _get_schema_or_create(self, schema_name: str = None) -> Schema: with self._maybe_destination_capabilities(): return Schema(self.pipeline_name) - def _sql_job_client(self, schema: Schema, credentials: Any = None) -> SqlJobClientBase: - client_config = self._get_destination_client_initial_config(credentials) + def _sql_job_client(self, schema: Schema) -> SqlJobClientBase: + client_config = self._get_destination_client_initial_config() client = self._get_destination_clients(schema, client_config)[0] if isinstance(client, SqlJobClientBase): return client @@ -1121,8 +1134,9 @@ def _extract_source( max_parallel_items: int, workers: int, refresh: Optional[TRefreshMode] = None, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, ) -> str: + load_package_state_update = copy(load_package_state_update or {}) # discover the existing pipeline schema try: # all live schemas are initially committed and during the extract will accumulate changes in memory @@ -1130,29 +1144,40 @@ def _extract_source( # this will (1) look for import schema if present # (2) load import schema an overwrite pipeline schema if import schema modified # (3) load pipeline schema if no import schema is present - pipeline_schema = self.schemas[source.schema.name] - pipeline_schema = pipeline_schema.clone() # use clone until extraction complete - # apply all changes in the source schema to pipeline schema - # NOTE: we do not apply contracts to changes done programmatically - pipeline_schema.update_schema(source.schema) - # replace schema in the source - source.schema = pipeline_schema + + # keep schema created by the source so we can apply changes from it later + source_schema = source.schema + # use existing pipeline schema as the source schema, clone until extraction complete + source.schema = self.schemas[source.schema.name].clone() + # refresh the pipeline schema ie. to drop certain tables before any normalizes change + if refresh: + # NOTE: we use original pipeline schema to detect dropped/truncated tables so we can drop + # the original names, before eventual new naming convention is applied + load_package_state_update.update(deepcopy(refresh_source(self, source, refresh))) + if refresh == "drop_sources": + # replace the whole source AFTER we got tables to drop + source.schema = source_schema + # NOTE: we do pass any programmatic changes from source schema to pipeline schema except settings below + # TODO: enable when we have full identifier lineage and we are able to merge table identifiers + if type(source.schema.naming) is not type(source_schema.naming): # noqa + source.schema_contract = source_schema.settings.get("schema_contract") + else: + source.schema.update_schema(source_schema) except FileNotFoundError: - pass + if refresh is not None: + logger.info( + f"Refresh flag {refresh} has no effect on source {source.name} because the" + " source is extracted for a first time" + ) - load_package_state_update = dict(load_package_state_update or {}) - if refresh: - load_package_state_update.update(refresh_source(self, source, refresh)) + # update the normalizers to detect any conflicts early + source.schema.update_normalizers() # extract into pipeline schema load_id = extract.extract( source, max_parallel_items, workers, load_package_state_update=load_package_state_update ) - # save import with fully discovered schema - # NOTE: moved to with_schema_sync, remove this if all test pass - # self._schema_storage.save_import_schema_if_not_exists(source.schema) - # update live schema but not update the store yet source.schema = self._schema_storage.set_live_schema(source.schema) @@ -1164,7 +1189,7 @@ def _extract_source( return load_id def _get_destination_client_initial_config( - self, destination: TDestination = None, credentials: Any = None, as_staging: bool = False + self, destination: TDestination = None, as_staging: bool = False ) -> DestinationClientConfiguration: destination = destination or self.destination if not destination: @@ -1175,19 +1200,9 @@ def _get_destination_client_initial_config( "Please provide `destination` argument to `pipeline`, `run` or `load` method" " directly or via .dlt config.toml file or environment variable.", ) - # create initial destination client config client_spec = destination.spec - # initialize explicit credentials - if not as_staging: - # explicit credentials passed to dlt.pipeline should not be applied to staging - credentials = credentials or self.credentials - if credentials is not None and not isinstance(credentials, CredentialsConfiguration): - # use passed credentials as initial value. initial value may resolve credentials - credentials = initialize_credentials( - client_spec.get_resolvable_fields()["credentials"], credentials - ) - # this client support many schemas and datasets + # this client supports many schemas and datasets if issubclass(client_spec, DestinationClientDwhConfiguration): if not self.dataset_name and self.dev_mode: logger.warning( @@ -1200,18 +1215,13 @@ def _get_destination_client_initial_config( ) if issubclass(client_spec, DestinationClientStagingConfiguration): - spec: DestinationClientDwhConfiguration = client_spec( - credentials=credentials, - as_staging=as_staging, - ) + spec: DestinationClientDwhConfiguration = client_spec(as_staging=as_staging) else: - spec = client_spec( - credentials=credentials, - ) + spec = client_spec() spec._bind_dataset_name(self.dataset_name, default_schema_name) return spec - return client_spec(credentials=credentials) + return client_spec() def _get_destination_clients( self, @@ -1260,10 +1270,28 @@ def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: "Please provide `destination` argument to `pipeline`, `run` or `load` method" " directly or via .dlt config.toml file or environment variable.", ) - return self.destination.capabilities() + # check if default schema is present + if ( + self.default_schema_name is not None + and self.default_schema_name in self._schema_storage + ): + naming = self.default_schema.naming + else: + naming = None + return self.destination.capabilities(naming=naming) def _get_staging_capabilities(self) -> Optional[DestinationCapabilitiesContext]: - return self.staging.capabilities() if self.staging is not None else None + if self.staging is None: + return None + # check if default schema is present + if ( + self.default_schema_name is not None + and self.default_schema_name in self._schema_storage + ): + naming = self.default_schema.naming + else: + naming = None + return self.staging.capabilities(naming=naming) def _validate_pipeline_name(self) -> None: try: @@ -1299,9 +1327,12 @@ def _set_destinations( destination_name: Optional[str] = None, staging: Optional[TDestinationReferenceArg] = None, staging_name: Optional[str] = None, + initializing: bool = False, + destination_credentials: Any = None, ) -> None: - # destination_mod = DestinationReference.from_name(destination) - if destination: + destination_changed = destination is not None and destination != self.destination + # set destination if provided but do not swap if factory is the same + if destination_changed: self.destination = Destination.from_reference( destination, destination_name=destination_name ) @@ -1320,7 +1351,8 @@ def _set_destinations( staging = "filesystem" staging_name = "filesystem" - if staging: + staging_changed = staging is not None and staging != self.staging + if staging_changed: staging_module = Destination.from_reference(staging, destination_name=staging_name) if staging_module and not issubclass( staging_module.spec, DestinationClientStagingConfiguration @@ -1328,17 +1360,27 @@ def _set_destinations( raise DestinationNoStagingMode(staging_module.destination_name) self.staging = staging_module - with self._maybe_destination_capabilities(): - # default normalizers must match the destination - self._set_default_normalizers() + if staging_changed or destination_changed: + # make sure that capabilities can be generated + with self._maybe_destination_capabilities(): + # update normalizers in all live schemas, only when destination changed + if destination_changed and not initializing: + for schema in self._schema_storage.live_schemas.values(): + schema.update_normalizers() + # set new context + if not initializing: + self._set_context(is_active=True) + # apply explicit credentials + if self.destination and destination_credentials: + self.destination.config_params["credentials"] = destination_credentials @contextmanager def _maybe_destination_capabilities( self, ) -> Iterator[DestinationCapabilitiesContext]: + caps: DestinationCapabilitiesContext = None + injected_caps: ContextManager[DestinationCapabilitiesContext] = None try: - caps: DestinationCapabilitiesContext = None - injected_caps: ContextManager[DestinationCapabilitiesContext] = None if self.destination: destination_caps = self._get_destination_capabilities() stage_caps = self._get_staging_capabilities() @@ -1358,9 +1400,6 @@ def _maybe_destination_capabilities( if injected_caps: injected_caps.__exit__(None, None, None) - def _set_default_normalizers(self) -> None: - _, self._default_naming, _ = import_normalizers(explicit_normalizers()) - def _set_dataset_name(self, new_dataset_name: str) -> None: if not new_dataset_name and not self.dataset_name: # dataset name is required but not provided - generate the default now @@ -1389,6 +1428,10 @@ def _set_dataset_name(self, new_dataset_name: str) -> None: new_dataset_name += self._pipeline_instance_id self.dataset_name = new_dataset_name + # normalizes the dataset name using the dataset_name_layout + if self.config.dataset_name_layout: + self.dataset_name = self.config.dataset_name_layout % self.dataset_name + def _set_default_schema_name(self, schema: Schema) -> None: assert self.default_schema_name is None self.default_schema_name = schema.name @@ -1508,11 +1551,15 @@ def _get_schemas_from_destination( @contextmanager def managed_state(self, *, extract_state: bool = False) -> Iterator[TPipelineState]: - # load or restore state + """Puts pipeline state in managed mode, where yielded state changes will be persisted or fully roll-backed on exception. + + Makes the state to be available via StateInjectableContext + """ state = self._get_state() - # TODO: we should backup schemas here try: - yield state + # add the state to container as a context + with self._container.injectable_context(StateInjectableContext(state=state)): + yield state except Exception: backup_state = self._get_state() # restore original pipeline props @@ -1580,7 +1627,7 @@ def _save_and_extract_state_and_schema( self, state: TPipelineState, schema: Schema, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, ) -> None: """Save given state + schema and extract creating a new load package @@ -1605,9 +1652,9 @@ def _bump_version_and_extract_state( state: TPipelineState, extract_state: bool, extract: Extract = None, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, schema: Optional[Schema] = None, - ) -> TPipelineStateDoc: + ) -> None: """Merges existing state into `state` and extracts state using `storage` if extract_state is True. Storage will be created on demand. In that case the extracted package will be immediately committed. @@ -1615,13 +1662,24 @@ def _bump_version_and_extract_state( _, hash_, _ = bump_pipeline_state_version_if_modified(self._props_to_state(state)) should_extract = hash_ != state["_local"].get("_last_extracted_hash") if should_extract and extract_state: - data, doc = state_resource(state) - extract_ = extract or Extract( - self._schema_storage, self._normalize_storage_config(), original_data=data + extract_ = extract or Extract(self._schema_storage, self._normalize_storage_config()) + # create or get load package upfront to get load_id to create state doc + schema = schema or self.default_schema + # note that we preferably retrieve existing package for `schema` + # same thing happens in extract_.extract so the load_id is preserved + load_id = extract_.extract_storage.create_load_package( + schema, reuse_exiting_package=True ) + data, doc = state_resource(state, load_id) + # keep the original data to be used in the metrics + if extract_.original_data is None: + extract_.original_data = data + # append pipeline state to package state + load_package_state_update = load_package_state_update or {} + load_package_state_update["pipeline_state"] = doc self._extract_source( extract_, - data_to_sources(data, self, schema or self.default_schema)[0], + data_to_sources(data, self, schema)[0], 1, 1, load_package_state_update=load_package_state_update, @@ -1630,9 +1688,7 @@ def _bump_version_and_extract_state( mark_state_extracted(state, hash_) # commit only if we created storage if not extract: - extract_.commit_packages(doc) - return doc - return None + extract_.commit_packages() def _list_schemas_sorted(self) -> List[str]: """Lists schema names sorted to have deterministic state""" diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index 41009f2909..11648328f2 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -4,7 +4,8 @@ import dlt from dlt.common.pendulum import pendulum from dlt.common.typing import DictStrAny -from dlt.common.schema.typing import STATE_TABLE_NAME, TTableSchemaColumns +from dlt.common.schema.typing import PIPELINE_STATE_TABLE_NAME +from dlt.common.schema.utils import pipeline_state_table from dlt.common.destination.reference import WithStateSync, Destination, StateInfo from dlt.common.versioned_state import ( generate_state_version_hash, @@ -24,20 +25,6 @@ PIPELINE_STATE_ENGINE_VERSION = 4 LOAD_PACKAGE_STATE_KEY = "pipeline_state" -# state table columns -STATE_TABLE_COLUMNS: TTableSchemaColumns = { - "version": {"name": "version", "data_type": "bigint", "nullable": False}, - "engine_version": {"name": "engine_version", "data_type": "bigint", "nullable": False}, - "pipeline_name": {"name": "pipeline_name", "data_type": "text", "nullable": False}, - "state": {"name": "state", "data_type": "text", "nullable": False}, - "created_at": {"name": "created_at", "data_type": "timestamp", "nullable": False}, - "version_hash": { - "name": "version_hash", - "data_type": "text", - "nullable": True, - }, # set to nullable so we can migrate existing tables -} - def generate_pipeline_state_version_hash(state: TPipelineState) -> str: return generate_state_version_hash(state, exclude_attrs=["_local"]) @@ -98,27 +85,28 @@ def state_doc(state: TPipelineState, load_id: str = None) -> TPipelineStateDoc: state = copy(state) state.pop("_local") state_str = compress_state(state) - doc: TPipelineStateDoc = { - "version": state["_state_version"], - "engine_version": state["_state_engine_version"], - "pipeline_name": state["pipeline_name"], - "state": state_str, - "created_at": pendulum.now(), - "version_hash": state["_version_hash"], - } - if load_id: - doc["dlt_load_id"] = load_id - return doc + info = StateInfo( + version=state["_state_version"], + engine_version=state["_state_engine_version"], + pipeline_name=state["pipeline_name"], + state=state_str, + created_at=pendulum.now(), + version_hash=state["_version_hash"], + _dlt_load_id=load_id, + ) + return info.as_doc() -def state_resource(state: TPipelineState) -> Tuple[DltResource, TPipelineStateDoc]: - doc = state_doc(state) +def state_resource(state: TPipelineState, load_id: str) -> Tuple[DltResource, TPipelineStateDoc]: + doc = state_doc(state, load_id) + state_table = pipeline_state_table() return ( dlt.resource( [doc], - name=STATE_TABLE_NAME, - write_disposition="append", - columns=STATE_TABLE_COLUMNS, + name=PIPELINE_STATE_TABLE_NAME, + write_disposition=state_table["write_disposition"], + file_format=state_table["file_format"], + columns=state_table["columns"], ), doc, ) diff --git a/dlt/pipeline/warnings.py b/dlt/pipeline/warnings.py index 8bee670cb7..ac46a4eef0 100644 --- a/dlt/pipeline/warnings.py +++ b/dlt/pipeline/warnings.py @@ -5,23 +5,6 @@ from dlt.common.destination import Destination, TDestinationReferenceArg -def credentials_argument_deprecated( - caller_name: str, credentials: t.Optional[t.Any], destination: TDestinationReferenceArg = None -) -> None: - if credentials is None: - return - - dest_name = Destination.to_name(destination) if destination else "postgres" - - warnings.warn( - f"The `credentials argument` to {caller_name} is deprecated and will be removed in a future" - " version. Pass the same credentials to the `destination` instance instead, e.g." - f" {caller_name}(destination=dlt.destinations.{dest_name}(credentials=...))", - Dlt04DeprecationWarning, - stacklevel=2, - ) - - def full_refresh_argument_deprecated(caller_name: str, full_refresh: t.Optional[bool]) -> None: """full_refresh argument is replaced with dev_mode""" if full_refresh is None: diff --git a/dlt/sources/helpers/requests/retry.py b/dlt/sources/helpers/requests/retry.py index 3f9d7d559e..7d7d6493ec 100644 --- a/dlt/sources/helpers/requests/retry.py +++ b/dlt/sources/helpers/requests/retry.py @@ -119,8 +119,8 @@ def _make_retry( retry_conds = [retry_if_status(status_codes), retry_if_exception_type(tuple(exceptions))] if condition is not None: if callable(condition): - retry_condition = [condition] - retry_conds.extend([retry_if_predicate(c) for c in retry_condition]) + condition = [condition] + retry_conds.extend([retry_if_predicate(c) for c in condition]) wait_cls = wait_exponential_retry_after if respect_retry_after_header else wait_exponential return Retrying( diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 29e6d8c77a..d2ca1c1ca6 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,17 +1,18 @@ -from base64 import b64encode -import dataclasses import math +import dataclasses +from abc import abstractmethod +from base64 import b64encode from typing import ( - List, + TYPE_CHECKING, + Any, Dict, Final, + Iterable, + List, Literal, Optional, Union, - Any, cast, - Iterable, - TYPE_CHECKING, ) from typing_extensions import Annotated from requests.auth import AuthBase @@ -24,7 +25,6 @@ from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.pendulum import pendulum from dlt.common.typing import TSecretStrValue - from dlt.sources.helpers import requests if TYPE_CHECKING: @@ -144,6 +144,76 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: return request +@configspec +class OAuth2ClientCredentials(OAuth2AuthBase): + """ + This class implements OAuth2 Client Credentials flow where the autorization service + gives permission without the end user approving. + This is often used for machine-to-machine authorization. + The client sends its client ID and client secret to the authorization service which replies + with a temporary access token. + With the access token, the client can access resource services. + """ + + def __init__( + self, + access_token_url: TSecretStrValue, + client_id: TSecretStrValue, + client_secret: TSecretStrValue, + access_token_request_data: Dict[str, Any] = None, + default_token_expiration: int = 3600, + session: Annotated[BaseSession, NotResolved()] = None, + ) -> None: + super().__init__() + self.access_token_url = access_token_url + self.client_id = client_id + self.client_secret = client_secret + if access_token_request_data is None: + self.access_token_request_data = {} + else: + self.access_token_request_data = access_token_request_data + self.default_token_expiration = default_token_expiration + self.token_expiry: pendulum.DateTime = pendulum.now() + + self.session = session if session is not None else requests.client.session + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + if self.access_token is None or self.is_token_expired(): + self.obtain_token() + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + def is_token_expired(self) -> bool: + return pendulum.now() >= self.token_expiry + + def obtain_token(self) -> None: + response = self.session.post(self.access_token_url, **self.build_access_token_request()) + response.raise_for_status() + response_json = response.json() + self.parse_native_representation(self.parse_access_token(response_json)) + expires_in_seconds = self.parse_expiration_in_seconds(response_json) + self.token_expiry = pendulum.now().add(seconds=expires_in_seconds) + + def build_access_token_request(self) -> Dict[str, Any]: + return { + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "client_credentials", + **self.access_token_request_data, + }, + } + + def parse_expiration_in_seconds(self, response_json: Any) -> int: + return int(response_json.get("expires_in", self.default_token_expiration)) + + def parse_access_token(self, response_json: Any) -> str: + return str(response_json.get("access_token")) + + @configspec class OAuthJWTAuth(BearerTokenAuth): """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI""" @@ -164,7 +234,7 @@ def __post_init__(self) -> None: self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None - # use default system session is not specified + # use default system session unless specified otherwise if self.session is None: self.session = requests.client.session diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index e6135b5c0f..2cc19f6624 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -95,7 +95,7 @@ def _create_request( self, path: str, method: HTTPMethod, - params: Dict[str, Any], + params: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, hooks: Optional[Hooks] = None, diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index 22cdc9b415..701f0c914b 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -91,6 +91,7 @@ def __init__( param_name: str, initial_value: int, value_step: int, + base_index: int = 0, maximum_value: Optional[int] = None, total_path: Optional[jsonpath.TJsonPath] = None, error_message_items: str = "items", @@ -101,6 +102,8 @@ def __init__( For example, 'page'. initial_value (int): The initial value of the numeric parameter. value_step (int): The step size to increment the numeric parameter. + base_index (int, optional): The index of the initial element. + Used to define 0-based or 1-based indexing. Defaults to 0. maximum_value (int, optional): The maximum value for the numeric parameter. If provided, pagination will stop once this value is reached or exceeded, even if more data is available. This allows you @@ -119,6 +122,7 @@ def __init__( self.param_name = param_name self.current_value = initial_value self.value_step = value_step + self.base_index = base_index self.maximum_value = maximum_value self.total_path = jsonpath.compile_path(total_path) if total_path else None self.error_message_items = error_message_items @@ -145,7 +149,7 @@ def update_state(self, response: Response) -> None: self.current_value += self.value_step - if (total is not None and self.current_value >= total) or ( + if (total is not None and self.current_value >= total + self.base_index) or ( self.maximum_value is not None and self.current_value >= self.maximum_value ): self._has_next_page = False @@ -219,14 +223,20 @@ def get_items(): def __init__( self, - initial_page: int = 0, + base_page: int = 0, + page: int = None, page_param: str = "page", total_path: jsonpath.TJsonPath = "total", maximum_page: Optional[int] = None, ): """ Args: - initial_page (int): The initial page number. + base_page (int): The index of the initial page from the API perspective. + Determines the page number that the API server uses for the starting + page. Normally, this is 0-based or 1-based (e.g., 1, 2, 3, ...) + indexing for the pages. Defaults to 0. + page (int): The page number for the first request. If not provided, + the initial value will be set to `base_page`. page_param (str): The query parameter name for the page number. Defaults to 'page'. total_path (jsonpath.TJsonPath): The JSONPath expression for @@ -238,9 +248,13 @@ def __init__( """ if total_path is None and maximum_page is None: raise ValueError("Either `total_path` or `maximum_page` must be provided.") + + page = page if page is not None else base_page + super().__init__( param_name=page_param, - initial_value=initial_page, + initial_value=page, + base_index=base_page, total_path=total_path, value_step=1, maximum_value=maximum_page, @@ -420,6 +434,10 @@ def update_request(self, request: Request) -> None: request.url = self._next_reference + # Clear the query parameters from the previous request otherwise they + # will be appended to the next URL in Session.prepare_request + request.params = None + class HeaderLinkPaginator(BaseNextUrlPaginator): """A paginator that uses the 'Link' header in HTTP responses diff --git a/docs/examples/archive/credentials/explicit.py b/docs/examples/archive/credentials/explicit.py index b1bc25fce6..f07c69360a 100644 --- a/docs/examples/archive/credentials/explicit.py +++ b/docs/examples/archive/credentials/explicit.py @@ -1,6 +1,7 @@ import os from typing import Iterator import dlt +from dlt.destinations import postgres @dlt.resource @@ -32,14 +33,14 @@ def simple_data( # you are free to pass credentials from custom location to destination pipeline = dlt.pipeline( - destination="postgres", credentials=dlt.secrets["custom.destination.credentials"] + destination=postgres(credentials=dlt.secrets["custom.destination.credentials"]) ) # see nice credentials object print(pipeline.credentials) # you can also pass credentials partially, only the password comes from the secrets or environment pipeline = dlt.pipeline( - destination="postgres", credentials="postgres://loader@localhost:5432/dlt_data" + destination=postgres(credentials="postgres://loader@localhost:5432/dlt_data") ) # now lets compare it with default location for config and credentials diff --git a/docs/examples/archive/quickstart.py b/docs/examples/archive/quickstart.py index f435fa3fab..6806c177ce 100644 --- a/docs/examples/archive/quickstart.py +++ b/docs/examples/archive/quickstart.py @@ -46,7 +46,6 @@ pipeline_name, destination=destination_name, dataset_name=dataset_name, - credentials=credentials, export_schema_path=export_schema_path, dev_mode=True, ) @@ -69,7 +68,9 @@ }, ] -load_info = pipeline.run(rows, table_name=table_name, write_disposition="replace") +load_info = pipeline.run( + rows, table_name=table_name, write_disposition="replace", credentials=credentials +) # 4. Optional error handling - print, raise or handle. print() diff --git a/docs/examples/conftest.py b/docs/examples/conftest.py index 87ccffe53b..be1a03990b 100644 --- a/docs/examples/conftest.py +++ b/docs/examples/conftest.py @@ -1,3 +1,4 @@ +import sys import os import pytest from unittest.mock import patch @@ -47,7 +48,11 @@ def _initial_providers(): ): # extras work when container updated glob_ctx.add_extras() - yield + try: + sys.path.insert(0, dname) + yield + finally: + sys.path.pop(0) def pytest_configure(config): diff --git a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py index 380912a9a7..ce4b2a12d0 100644 --- a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py +++ b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py @@ -86,7 +86,7 @@ def bigquery_insert( pipeline_name="csv_to_bigquery_insert", destination=bigquery_insert, dataset_name="mydata", - full_refresh=True, + dev_mode=True, ) load_info = pipeline.run(resource(url=OWID_DISASTERS_URL)) diff --git a/docs/examples/custom_destination_lancedb/.dlt/config.toml b/docs/examples/custom_destination_lancedb/.dlt/config.toml new file mode 100644 index 0000000000..4fd35e1159 --- /dev/null +++ b/docs/examples/custom_destination_lancedb/.dlt/config.toml @@ -0,0 +1,2 @@ +[lancedb] +db_path = "spotify.db" \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml b/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml new file mode 100644 index 0000000000..9c86df320c --- /dev/null +++ b/docs/examples/custom_destination_lancedb/.dlt/example.secrets.toml @@ -0,0 +1,7 @@ +[spotify] +client_id = "" +client_secret = "" + +# provide the openai api key here +[destination.lancedb.credentials] +embedding_model_provider_api_key = "" \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/.gitignore b/docs/examples/custom_destination_lancedb/.gitignore new file mode 100644 index 0000000000..c73564481b --- /dev/null +++ b/docs/examples/custom_destination_lancedb/.gitignore @@ -0,0 +1 @@ +spotify.db \ No newline at end of file diff --git a/docs/examples/custom_destination_lancedb/__init__.py b/docs/examples/custom_destination_lancedb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py new file mode 100644 index 0000000000..ba815d4fcd --- /dev/null +++ b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py @@ -0,0 +1,157 @@ +""" +--- +title: Custom Destination with LanceDB +description: Learn how use the custom destination to load to LanceDB. +keywords: [destination, credentials, example, lancedb, custom destination, vectorstore, AI, LLM] +--- + +This example showcases a Python script that demonstrates the integration of LanceDB, an open-source vector database, +as a custom destination within the dlt ecosystem. +The script illustrates the implementation of a custom destination as well as the population of the LanceDB vector +store with data from various sources. +This highlights the seamless interoperability between dlt and LanceDB. + +You can get a Spotify client ID and secret from https://developer.spotify.com/. + +We'll learn how to: +- Use the [custom destination](../dlt-ecosystem/destinations/destination.md) +- Delegate the embeddings to LanceDB using OpenAI Embeddings +""" + +__source_name__ = "spotify" + +import datetime # noqa: I251 +import os +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Any + +import lancedb # type: ignore +from lancedb.embeddings import get_registry # type: ignore +from lancedb.pydantic import LanceModel, Vector # type: ignore + +import dlt +from dlt.common.configuration import configspec +from dlt.common.schema import TTableSchema +from dlt.common.typing import TDataItems, TSecretStrValue +from dlt.sources.helpers import requests +from dlt.sources.helpers.rest_client import RESTClient, AuthConfigBase + +# access secrets to get openai key and instantiate embedding function +openai_api_key: str = dlt.secrets.get( + "destination.lancedb.credentials.embedding_model_provider_api_key" +) +func = get_registry().get("openai").create(name="text-embedding-3-small", api_key=openai_api_key) + + +class EpisodeSchema(LanceModel): + id: str # noqa: A003 + name: str + description: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() # type: ignore[valid-type] + release_date: datetime.date + href: str + + +@dataclass(frozen=True) +class Shows: + monday_morning_data_chat: str = "3Km3lBNzJpc1nOTJUtbtMh" + latest_space_podcast: str = "2p7zZVwVF6Yk0Zsb4QmT7t" + superdatascience_podcast: str = "1n8P7ZSgfVLVJ3GegxPat1" + lex_fridman: str = "2MAi0BvDc6GTFvKFPXnkCL" + + +@configspec +class SpotifyAuth(AuthConfigBase): + client_id: str = None + client_secret: TSecretStrValue = None + + def __call__(self, request) -> Any: + if not hasattr(self, "access_token"): + self.access_token = self._get_access_token() + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + def _get_access_token(self) -> Any: + auth_url = "https://accounts.spotify.com/api/token" + auth_response = requests.post( + auth_url, + { + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + }, + ) + return auth_response.json()["access_token"] + + +@dlt.source +def spotify_shows( + client_id: str = dlt.secrets.value, + client_secret: str = dlt.secrets.value, +): + spotify_base_api_url = "https://api.spotify.com/v1" + client = RESTClient( + base_url=spotify_base_api_url, + auth=SpotifyAuth(client_id=client_id, client_secret=client_secret), # type: ignore[arg-type] + ) + + for show in fields(Shows): + show_name = show.name + show_id = show.default + url = f"/shows/{show_id}/episodes" + yield dlt.resource( + client.paginate(url, params={"limit": 50}), + name=show_name, + write_disposition="merge", + primary_key="id", + parallelized=True, + max_table_nesting=0, + ) + + +@dlt.destination(batch_size=250, name="lancedb") +def lancedb_destination(items: TDataItems, table: TTableSchema) -> None: + db_path = Path(dlt.config.get("lancedb.db_path")) + db = lancedb.connect(db_path) + + # since we are embedding the description field, we need to do some additional cleaning + # for openai. Openai will not accept empty strings or input with more than 8191 tokens + for item in items: + item["description"] = item.get("description") or "No Description" + item["description"] = item["description"][0:8000] + try: + tbl = db.open_table(table["name"]) + except FileNotFoundError: + tbl = db.create_table(table["name"], schema=EpisodeSchema) + tbl.add(items) + + +if __name__ == "__main__": + db_path = Path(dlt.config.get("lancedb.db_path")) + db = lancedb.connect(db_path) + + for show in fields(Shows): + db.drop_table(show.name, ignore_missing=True) + + pipeline = dlt.pipeline( + pipeline_name="spotify", + destination=lancedb_destination, + dataset_name="spotify_podcast_data", + progress="log", + ) + + load_info = pipeline.run(spotify_shows()) + load_info.raise_on_failed_jobs() + print(load_info) + + row_counts = pipeline.last_trace.last_normalize_info + print(row_counts) + + query = "French AI scientist with Lex, talking about AGI and Meta and Llama" + table_to_query = "lex_fridman" + + tbl = db.open_table(table_to_query) + + results = tbl.search(query=query).to_list() + assert results diff --git a/docs/examples/custom_naming/.dlt/config.toml b/docs/examples/custom_naming/.dlt/config.toml new file mode 100644 index 0000000000..ba5c8ab73a --- /dev/null +++ b/docs/examples/custom_naming/.dlt/config.toml @@ -0,0 +1,2 @@ +[sources.sql_ci_no_collision.schema] +naming="sql_ci_no_collision" \ No newline at end of file diff --git a/docs/examples/custom_naming/__init__.py b/docs/examples/custom_naming/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/custom_naming/custom_naming.py b/docs/examples/custom_naming/custom_naming.py new file mode 100644 index 0000000000..e99e582213 --- /dev/null +++ b/docs/examples/custom_naming/custom_naming.py @@ -0,0 +1,90 @@ +""" +--- +title: Create and use own naming convention +description: We demonstrate how to create naming conventions that allow UNICODE letters and never generate collisions +keywords: [example] +--- + +This example shows how to add and use custom naming convention. Naming conventions translate identifiers found in source data into identifiers in +destination, where rules for a valid identifier are constrained. + +Custom naming conventions are classes that derive from `NamingConvention` that you can import from `dlt.common.normalizers.naming`. We recommend the following module layout: +1. Each naming convention resides in a separate Python module (file) +2. The class is always named `NamingConvention` + +There are two naming conventions in this example: +1. A variant of `sql_ci` that generates identifier collisions with a low (user defined) probability by appending a deterministic tag to each name. +2. A variant of `sql_cs` that allows for LATIN (ie. umlaut) characters + +With this example you will learn to: +* Create a naming convention module with a recommended layout +* Use naming convention by explicitly passing it to `duckdb` destination factory +* Use naming convention by configuring it config.toml +* Changing the declared case sensitivity by overriding `is_case_sensitive` property +* Providing custom normalization logic by overriding `normalize_identifier` method + +""" + +import dlt + +if __name__ == "__main__": + # sql_cs_latin2 module + import sql_cs_latin2 # type: ignore[import-not-found] + + # create postgres destination with a custom naming convention. pass sql_cs_latin2 as module + # NOTE: ql_cs_latin2 is case sensitive and postgres accepts UNICODE letters in identifiers + dest_ = dlt.destinations.postgres( + "postgresql://loader:loader@localhost:5432/dlt_data", naming_convention=sql_cs_latin2 + ) + # run a pipeline + pipeline = dlt.pipeline( + pipeline_name="sql_cs_latin2_pipeline", + destination=dest_, + dataset_name="example_data", + dev_mode=True, + ) + # Extract, normalize, and load the data + load_info = pipeline.run([{"StückId": 1}], table_name="Ausrüstung") + print(load_info) + # make sure nothing failed + load_info.raise_on_failed_jobs() + with pipeline.sql_client() as client: + # NOTE: we quote case sensitive identifers + with client.execute_query('SELECT "StückId" FROM "Ausrüstung"') as cur: + print(cur.description) + print(cur.fetchone()) + + # sql_ci_no_collision (configured in config toml) + # NOTE: pipeline with name `sql_ci_no_collision` will create default schema with the same name + # so we are free to use it in config.toml to just affect this pipeline and leave the postgres pipeline as it is + pipeline = dlt.pipeline( + pipeline_name="sql_ci_no_collision", + destination="duckdb", + dataset_name="example_data", + dev_mode=True, + ) + # duckdb is case insensitive so tables and columns below would clash but sql_ci_no_collision prevents that + data_1 = {"ItemID": 1, "itemid": "collides"} + load_info = pipeline.run([data_1], table_name="BigData") + load_info.raise_on_failed_jobs() + + data_2 = {"1Data": 1, "_1data": "collides"} + # use colliding table + load_info = pipeline.run([data_2], table_name="bigdata") + load_info.raise_on_failed_jobs() + + with pipeline.sql_client() as client: + from duckdb import DuckDBPyConnection + + conn: DuckDBPyConnection = client.native_connection + # tags are deterministic so we can just use the naming convention to get table names to select + first_table = pipeline.default_schema.naming.normalize_table_identifier("BigData") + sql = f"DESCRIBE TABLE {first_table}" + print(sql) + print(conn.sql(sql)) + second_table = pipeline.default_schema.naming.normalize_table_identifier("bigdata") + sql = f"DESCRIBE TABLE {second_table}" + print(sql) + print(conn.sql(sql)) + + # print(pipeline.default_schema.to_pretty_yaml()) diff --git a/docs/examples/custom_naming/sql_ci_no_collision.py b/docs/examples/custom_naming/sql_ci_no_collision.py new file mode 100644 index 0000000000..276107ea2b --- /dev/null +++ b/docs/examples/custom_naming/sql_ci_no_collision.py @@ -0,0 +1,34 @@ +from typing import ClassVar + +from dlt.common.normalizers.naming.sql_cs_v1 import NamingConvention as SqlNamingConvention +from dlt.common.schema.typing import DLT_NAME_PREFIX + + +class NamingConvention(SqlNamingConvention): + """Case insensitive naming convention with all identifiers lowercases but with unique short tag added""" + + # we will reuse the code we use for shortening + # 1 in 100 prob of collision for identifiers identical after normalization + _DEFAULT_COLLISION_PROB: ClassVar[float] = 0.01 + + def normalize_identifier(self, identifier: str) -> str: + # compute unique tag on original (not normalized) identifier + # NOTE: you may wrap method below in lru_cache if you often normalize the same names + tag = self._compute_tag(identifier, self._DEFAULT_COLLISION_PROB) + # lower case + norm_identifier = identifier.lower() + # add tag if (not a dlt identifier) and tag was not added before (simple heuristics) + if "_4" in norm_identifier: + _, existing_tag = norm_identifier.rsplit("_4", 1) + has_tag = len(existing_tag) == len(tag) + else: + has_tag = False + if not norm_identifier.startswith(DLT_NAME_PREFIX) and not has_tag: + norm_identifier = norm_identifier + "_4" + tag + # run identifier through standard sql cleaning and shortening + return super().normalize_identifier(norm_identifier) + + @property + def is_case_sensitive(self) -> bool: + # switch the naming convention to case insensitive + return False diff --git a/docs/examples/custom_naming/sql_cs_latin2.py b/docs/examples/custom_naming/sql_cs_latin2.py new file mode 100644 index 0000000000..7cf31cc76a --- /dev/null +++ b/docs/examples/custom_naming/sql_cs_latin2.py @@ -0,0 +1,21 @@ +from typing import ClassVar + +# NOTE: we use regex library that supports unicode +import regex as re + +from dlt.common.normalizers.naming.sql_cs_v1 import NamingConvention as SqlNamingConvention +from dlt.common.typing import REPattern + + +class NamingConvention(SqlNamingConvention): + """Case sensitive naming convention which allows basic unicode characters, including latin 2 characters""" + + RE_NON_ALPHANUMERIC: ClassVar[REPattern] = re.compile(r"[^\p{Latin}\d_]+") # type: ignore + + def normalize_identifier(self, identifier: str) -> str: + # typically you'd change how a single + return super().normalize_identifier(identifier) + + @property + def is_case_sensitive(self) -> bool: + return True diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 809a6cfbd6..5fbba98a21 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -25,7 +25,7 @@ import os import dlt -from dlt.destinations.impl.weaviate import weaviate_adapter +from dlt.destinations.adapters import weaviate_adapter from PyPDF2 import PdfReader diff --git a/docs/examples/postgres_to_postgres/postgres_to_postgres.py b/docs/examples/postgres_to_postgres/postgres_to_postgres.py index f5327ee236..c6502f236a 100644 --- a/docs/examples/postgres_to_postgres/postgres_to_postgres.py +++ b/docs/examples/postgres_to_postgres/postgres_to_postgres.py @@ -91,7 +91,7 @@ def pg_resource_chunked( order_date: str, load_type: str = "merge", columns: str = "*", - credentials: ConnectionStringCredentials = dlt.secrets["sources.postgres.credentials"], + credentials: ConnectionStringCredentials = None, ): print( f"dlt.resource write_disposition: `{load_type}` -- ", @@ -162,6 +162,7 @@ def table_desc(table_name, pk, schema_name, order_date, columns="*"): table["order_date"], load_type=load_type, columns=table["columns"], + credentials=dlt.secrets["sources.postgres.credentials"], ) ) @@ -170,7 +171,7 @@ def table_desc(table_name, pk, schema_name, order_date, columns="*"): pipeline_name=pipeline_name, destination="duckdb", dataset_name=target_schema_name, - full_refresh=True, + dev_mode=True, progress="alive_progress", ) else: @@ -178,8 +179,8 @@ def table_desc(table_name, pk, schema_name, order_date, columns="*"): pipeline_name=pipeline_name, destination="postgres", dataset_name=target_schema_name, - full_refresh=False, - ) # full_refresh=False + dev_mode=False, + ) # dev_mode=False # start timer startTime = pendulum.now() diff --git a/docs/examples/qdrant_zendesk/qdrant_zendesk.py b/docs/examples/qdrant_zendesk/qdrant_zendesk.py index 7fb55fe842..5416f2f2d0 100644 --- a/docs/examples/qdrant_zendesk/qdrant_zendesk.py +++ b/docs/examples/qdrant_zendesk/qdrant_zendesk.py @@ -38,8 +38,6 @@ from dlt.destinations.adapters import qdrant_adapter from qdrant_client import QdrantClient -from dlt.common.configuration.inject import with_config - # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) @@ -181,29 +179,22 @@ def get_pages( # make sure nothing failed load_info.raise_on_failed_jobs() - # running the Qdrant client to connect to your Qdrant database - - @with_config(sections=("destination", "qdrant", "credentials")) - def get_qdrant_client(location=dlt.secrets.value, api_key=dlt.secrets.value): - return QdrantClient( - url=location, - api_key=api_key, - ) - - # running the Qdrant client to connect to your Qdrant database - qdrant_client = get_qdrant_client() + # getting the authenticated Qdrant client to connect to your Qdrant database + with pipeline.destination_client() as destination_client: + from qdrant_client import QdrantClient - # view Qdrant collections you'll find your dataset here: - print(qdrant_client.get_collections()) + qdrant_client: QdrantClient = destination_client.db_client # type: ignore + # view Qdrant collections you'll find your dataset here: + print(qdrant_client.get_collections()) - # query Qdrant with prompt: getting tickets info close to "cancellation" - response = qdrant_client.query( - "zendesk_data_content", # collection/dataset name with the 'content' suffix -> tickets content table - query_text=["cancel", "cancel subscription"], # prompt to search - limit=3, # limit the number of results to the nearest 3 embeddings - ) + # query Qdrant with prompt: getting tickets info close to "cancellation" + response = qdrant_client.query( + "zendesk_data_content", # collection/dataset name with the 'content' suffix -> tickets content table + query_text="cancel subscription", # prompt to search + limit=3, # limit the number of results to the nearest 3 embeddings + ) - assert len(response) <= 3 and len(response) > 0 + assert len(response) <= 3 and len(response) > 0 - # make sure nothing failed - load_info.raise_on_failed_jobs() + # make sure nothing failed + load_info.raise_on_failed_jobs() diff --git a/docs/technical/README.md b/docs/technical/README.md deleted file mode 100644 index 6e2b5048a8..0000000000 --- a/docs/technical/README.md +++ /dev/null @@ -1,10 +0,0 @@ -## Finished documents - -1. [general_usage.md](general_usage.md) -2. [create_pipeline.md](create_pipeline.md) -3. [secrets_and_config.md](secrets_and_config.md) -4. [working_with_schemas.md](working_with_schemas.md) - -## In progress - -5. [customization_and_hacking.md](customization_and_hacking.md) diff --git a/docs/technical/create_pipeline.md b/docs/technical/create_pipeline.md deleted file mode 100644 index f6603d08b8..0000000000 --- a/docs/technical/create_pipeline.md +++ /dev/null @@ -1,441 +0,0 @@ -# Create Pipeline -marks features that are: - -⛔ not implemented, hard to add - -☮️ not implemented, easy to add - - -## Example from `dlt` module docstring -It is possible to create "intuitive" pipeline just by providing a list of objects to `dlt.run` methods No decorators and secret files, configurations are necessary. - -```python -import dlt -from dlt.sources.helpers import requests - -dlt.run( - requests.get("https://api.chess.com/pub/player/magnuscarlsen/games/2022/11").json()["games"], - destination="duckdb", - table_name="magnus_games" -) -``` - -Run your pipeline script -`$ python magnus_games.py` - -See and query your data with autogenerated Streamlit app -`$ dlt pipeline dlt_magnus_games show` - -## Source extractor function the preferred way -General guidelines: -1. the source extractor is a function decorated with `@dlt.source`. that function **yields** or **returns** a list of resources. -2. resources are generator functions that always **yield** data (enforced by exception which I hope is user friendly). Access to external endpoints, databases etc. should happen from that generator function. Generator functions may be decorated with `@dlt.resource` to provide alternative names, write disposition etc. -3. resource generator functions can be OFC parametrized and resources may be created dynamically -4. the resource generator function may yield **anything that is json serializable**. we prefer to yield _dict_ or list of dicts. -> yielding lists is much more efficient in terms of processing! -5. like any other iterator, the @dlt.source and @dlt.resource **can be iterated and thus extracted and loaded only once**, see example below. - -**Remarks:** - -1. the **@dlt.resource** let's you define the table schema hints: `name`, `write_disposition`, `columns` -2. the **@dlt.source** let's you define global schema props: `name` (which is also source name), `schema` which is Schema object if explicit schema is provided `nesting` to set nesting level etc. -3. decorators can also be used as functions ie in case of dlt.resource and `lazy_function` (see examples) - -```python -endpoints = ["songs", "playlist", "albums"] -# return list of resourced -return [dlt.resource(lazy_function(endpoint, name=endpoint) for endpoint in endpoints)] - -``` - -### Extracting data -Source function is not meant to extract the data, but in many cases getting some metadata ie. to generate dynamic resources (like in case of google sheets example) is unavoidable. The source function's body is evaluated **outside** the pipeline `run` (if `dlt.source` is a generator, it is immediately consumed). - -Actual extraction of the data should happen inside the `dlt.resource` which is lazily executed inside the `dlt` pipeline. - -> both a `dlt` source and resource are regular Python iterators and can be passed to any python function that accepts them ie to `list`. `dlt` will evaluate such iterators, also parallel and async ones and provide mock state to it. - -## Multiple resources and resource selection when loading -The source extraction function may contain multiple resources. The resources can be defined as multiple resource functions or created dynamically ie. with parametrized generators. -The user of the pipeline can check what resources are available and select the resources to load. - - -**each resource has a a separate resource function** -```python -from dlt.sources.helpers import requests -import dlt - -@dlt.source -def hubspot(...): - - @dlt.resource(write_disposition="replace") - def users(): - # calls to API happens here - ... - yield users - - @dlt.resource(write_disposition="append") - def transactions(): - ... - yield transactions - - # return a list of resources - return users, transactions - -# load all resources -taktile_data(1).run(destination=bigquery) -# load only decisions -taktile_data(1).with_resources("decisions").run(....) - -# alternative form: -source = taktile_data(1) -# select only decisions to be loaded -source.resources.select("decisions") -# see what is selected -print(source.selected_resources) -# same as this -print(source.resources.selected) -``` - -Except being accessible via `source.resources` dictionary, **every resource is available as an attribute of the source**. For the example above -```python -print(list(source.decisions)) # will iterate decisions resource -source.logs.selected = False # deselect resource -``` - -## Resources may be created dynamically -Here we implement a single parametrized function that **yields** data and we call it repeatedly. Mind that the function body won't be executed immediately, only later when generator is consumed in extract stage. - -```python - -@dlt.source -def spotify(): - - endpoints = ["songs", "playlists", "albums"] - - def get_resource(endpoint): - # here we yield the whole response - yield requests.get(url + "/" + endpoint).json() - - # here we yield resources because this produces cleaner code - for endpoint in endpoints: - # calling get_resource creates generator, the actual code of the function will be executed in extractor - yield dlt.resource(get_resource(endpoint), name=endpoint) - -``` - -## Unbound (parametrized) resources -Imagine the situation in which you have a resource for which you want (or require) user to pass some options ie. the number of records returned. - -> try it, it is ⚡ powerful - -1. In all examples above you do that via the source and returned resources are not parametrized. -OR -2. You can return a **parametrized (unbound)** resources from the source. - -```python - -@dlt.source -def chess(chess_api_url): - - # let people choose player title, the default is grand master - @dlt.resource - def players(title_filter="GM", max_results=10): - yield - - # ❗ return the players without the calling - return players - -s = chess("url") -# let's parametrize the resource to select masters. you simply call `bind` method on the resource to bind it -# if you do not bind it, the default values are used -s.players.bind("M", max_results=1000) -# load the masters -s.run() - -``` - -## A standalone @resource -A general purpose resource (ie. jsonl reader, generic sql query reader etc.) that you want to add to any of your sources or multiple instances of it to your pipelines? -Yeah definitely possible. Just replace `@source` with `@resource` decorator. - -```python -@dlt.resource(name="logs", write_disposition="append") -def taktile_data(initial_log_id, taktile_api_key=dlt.secret.value): - - # yes, this will also work but data will be obtained immediately when taktile_data() is called. - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - for item in resp.json()["result"]: - yield item - -# this will load the resource into default schema. see `general_usage.md) -dlt.run(source=taktile_data(1), destination=bigquery) - -``` -How standalone resource works: -1. It can be used like a source that contains only one resource (ie. single endpoint) -2. The main difference is that when extracted it will join the default schema in the pipeline (or explicitly passed schema) -3. It can be called from a `@source` function and then it becomes a resource of that source and joins the source schema - -## `dlt` state availability - -The state is a python dictionary-like object that is available within the `@dlt.source` and `@dlt.resource` decorated functions and may be read and written to. -The data within the state is loaded into destination together with any other extracted data and made automatically available to the source/resource extractor functions when they are run next time. -When using the state: -* Any JSON-serializable values can be written and the read from the state. -* The state available in the `dlt source` is read only and any changes will be discarded. Still it may be used to initialize the resources. -* The state available in the `dlt resource` is writable and written values will be available only once - -### State sharing and isolation across sources - -1. Each source and resources **in the same Python module** (no matter if they are standalone, inner or created dynamically) share the same state dictionary and is separated from other sources -2. Source accepts `section` argument which creates a separate state for that resource (and separate configuration as well). All sources with the same `section` share the state. -2. All the standalone resources and generators that do not belong to any source share the same state when being extracted (they are extracted withing ad-hoc created source) - -## Stream resources: dispatching data to several tables from single resources -What about resource like rasa tracker or singer tap that send a stream of events that should be routed to different tables? we have an answer (actually two): -1. in many cases the table name is based on the data item content (ie. you dispatch events of given type to different tables by event type). We can pass a function that takes the data item as input and returns table name. -```python -# send item to a table with name item["type"] -@dlt.resource(table_name=lambda i: i['type']) -def repo_events() -> Iterator[TDataItems]: - yield item -``` - -2. You can mark the yielded data with a table name (`dlt.mark.with_table_name`). This gives you full control on the name of the table - -see [here](docs/examples/sources/rasa/rasa.py) and [here](docs/examples/sources/singer_tap.py). - -## Source / resource config sections and arguments injection -You should read [secrets_and_config](secrets_and_config.md) now to understand how configs and credentials are passed to the decorated functions and how the users of them can configure their projects. - -Also look at the following [test](/tests/extract/test_decorators.py) : `test_source_sections` - -## Example sources and resources - -### With inner resource function -Resource functions can be placed inside the source extractor function. That lets them get access to source function input arguments and all the computations within the source function via so called closure. - -```python -from dlt.sources.helpers import requests -import dlt - -# the `dlt.source` tell the library that the decorated function is a source -# it will use function name `taktile_data` to name the source and the generated schema by default -# in general `@source` should **return** a list of resources or list of generators (function that yield data) -# @source may also **yield** resources or generators - if yielding is more convenient -# if @source returns or yields data - this will generate exception with a proper explanation. dlt user can always load the data directly without any decorators like in the previous example! -@dlt.source -def taktile_data(initial_log_id, taktile_api_key=dlt.secret.value): - - # the `dlt.resource` tells the `dlt.source` that the function defines a resource - # will use function name `logs` as resource/table name by default - # the function should **yield** the data items one by one or **yield** a list. - # here the decorator is optional: there are no parameters to `dlt.resource` - @dlt.resource - def logs(): - resp = requests.get( - "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, - headers={"Authorization": taktile_api_key}) - resp.raise_for_status() - # option 1: yield the whole list - yield resp.json()["result"] - # or -> this is useful if you deal with a stream of data and for that you need an API that supports that, for example you could yield lists containing paginated results - for item in resp.json()["result"]: - yield item - - # as mentioned we return a resource or a list of resources - return logs - # this will also work - # return logs() -``` - -### With outer generator yielding data, and @resource created dynamically -```python - -def taktile_logs_data(initial_log_id, taktile_api_key=dlt.secret.value) - yield data - - -@dlt.source -def taktile_data(initial_log_id, taktile_api_key): - # pass the arguments and convert to resource - return dlt.resource(taktile_logs_data(initial_log_id, taktile_api_key), name="logs", write_disposition="append") -``` - -### A source with resources defined elsewhere -Example of the above -```python -from taktile.resources import logs - -@dlt.source -def taktile_data(initial_log_id, taktile_api_key=dlt.secret.value): - return logs(initial_log_id, taktile_api_key) -``` - -## Advanced Topics - -### Transformers ⚡ -This happens all the time: -1. We have an endpoint that returns a list of users and then we must get each profile with a separate call. -2. The situation above is getting even more complicated when we need that list in two places in our source ie. we want to get the profiles but also a list of transactions per user. - -Ideally we would obtain the list only once and then call and yield from the profiles and transactions endpoint in parallel so the extraction time is minimized. - -Here's example how to do that: [run resources and transformers in parallel threads](/docs/examples/chess/chess.py) and test named `test_evolve_schema` - -More on transformers: -1. you can have unbound (parametrized) transformers as well -2. you can use pipe '|' operator to pipe data from resources to transformers instead of binding them statically with `data_from`. -> see our [singer tap](/docs/examples/singer_tap_jsonl_example.py) example where we pipe a stream of document from `jsonl` into `raw_singer_tap` which is a standalone, unbound ⚡ transformer. -3. If transformer yields just one element you can `return` it instead. This allows you to apply the `retry` and `defer` (parallel execution) decorators directly to it. - -#### Transformer example - -Here we have a list of huge documents and we want to load into several tables. - -```python -@dlt.source -def spotify(): - - # deselect by default, we do not want to load the huge doc - @dlt.resource(selected=False) - def get_huge_doc(): - return requests.get(...) - - # make songs and playlists to be dependent on get_huge_doc - @dlt.transformer(data_from=get_huge_doc) - def songs(huge_doc): - yield huge_doc["songs"] - - @dlt.transformer(data_from=get_huge_doc) - def playlists(huge_doc): - yield huge_doc["playlists"] - - # as you can see the get_huge_doc is not even returned, nevertheless it will be evaluated (only once) - # the huge doc will not be extracted and loaded - return songs, playlists - # we could also use the pipe operator, intead of providing_data from - # return get_huge_doc | songs, get_huge_doc | playlists -``` - -## Data item transformations - -You can attach any number of transformations to your resource that are evaluated on item per item basis. The available transformation types: -* map - transform the data item -* filter - filter the data item -* yield map - a map that returns iterator (so single row may generate many rows) - -You can add and insert transformations on the `DltResource` object (ie. decorated function) -* resource.add_map -* resource.add_filter -* resource.add_yield_map - -> Transformations always deal with single items even if you return lists. - -You can add transformations to a resource (also within a source) **after it is created**. This allows to customize existing pipelines. The transformations may -be distributed with the pipeline or written ad hoc in pipeline script. -```python -# anonymize creates nice deterministic hash for any hashable data type (not implemented yet:) -from dlt.helpers import anonymize - -# example transformation provided by the user -def anonymize_user(user_data): - user_data["user_id"] = anonymize(user_data["user_id"]) - user_data["user_email"] = anonymize(user_data["user_email"]) - return user_data - -@dlt.source -def pipedrive(...): - ... - - @dlt.resource(write_disposition="replace") - def users(): - ... - users = requests.get(...) - ... - yield users - - return users, deals, customers -``` - -in pipeline script: -1. we want to remove user with id == "me" -2. we want to anonymize user data -3. we want to pivot `user_props` into KV table - -```python -from pipedrive import pipedrive, anonymize_user - -source = pipedrive() -# access resource in the source by name and add filter and map transformation -source.users.add_filter(lambda user: user["user_id"] != "me").add_map(anonymize_user) -# now we want to yield user props to separate table. we define our own generator function -def pivot_props(user): - # keep user - yield user - # yield user props to user_props table - yield from [ - dlt.mark.with_table_name({"user_id": user["user_id"], "name": k, "value": v}, "user_props") for k, v in user["props"] - ] - -source.user.add_yield_map(pivot_props) -pipeline.run(source) -``` - -We provide a library of various concrete transformations: - -* ☮️ a recursive versions of the map, filter and flat map which can be applied to any nesting level of the data item (the standard transformations work on recursion level 0). Possible applications - - ☮️ recursive rename of dict keys - - ☮️ converting all values to strings - - etc. - -## Some CS Theory - -### The power of decorators - -With decorators dlt can inspect and modify the code being decorated. -1. it knows what are the sources and resources without running them -2. it knows input arguments so it knows the config values and secret values (see `secrets_and_config`). with those we can generate deployments automatically -3. it can inject config and secret values automatically -4. it wraps the functions into objects that provide additional functionalities -- sources and resources are iterators so you can write -```python -items = list(source()) - -for item in source()["logs"]: - ... -``` -- you can select which resources to load with `source().select(*names)` -- you can add mappings and filters to resources - -### The power of yielding: The preferred way to write resources - -The Python function that yields is not a function but magical object that `dlt` can control: -1. it is not executed when you call it! the call just creates a generator (see below). in the example above `taktile_data(1)` will not execute the code inside, it will just return an object composed of function code and input parameters. dlt has control over the object and can execute the code later. this is called `lazy execution` -2. i can control when and how much of the code is executed. the function that yields typically looks like that - -```python -def lazy_function(endpoint_name): - # INIT - this will be executed only once when dlt wants! - get_configuration() - from_item = dlt.current.state.get("last_item", 0) - l = get_item_list_from_api(api_key, endpoint_name) - - # ITERATOR - this will be executed many times also when dlt wants more data! - for item in l: - yield requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json() - # CLEANUP - # this will be executed only once after the last item was yielded! - dlt.current.state["last_item"] = item["id"] -``` - -3. dlt will execute this generator in extractor. the whole execution is atomic (including writing to state). if anything fails with exception the whole extract function fails. -4. the execution can be parallelized by using a decorator or a simple modifier function ie: -```python -for item in l: - yield deferred(requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json()) -``` \ No newline at end of file diff --git a/docs/technical/general_usage.md b/docs/technical/general_usage.md index 19c93bcf38..336c892c66 100644 --- a/docs/technical/general_usage.md +++ b/docs/technical/general_usage.md @@ -90,7 +90,7 @@ p.extract([label1, label2, label3], name="labels") # will use default schema "s **By default, one dataset can handle multiple schemas**. The pipeline configuration option `use_single_dataset` controls the dataset layout in the destination. By default it is set to True. In that case only one dataset is created at the destination - by default dataset name which is the same as pipeline name. The dataset name can also be explicitly provided into `dlt.pipeline` `dlt.run` and `Pipeline::load` methods. -All the tables from all the schemas are stored in that dataset. The table names are **not prefixed** with schema names!. If there are any name clashes, tables in the destination will be unions of the fields of all the tables with same name in the schemas. +All the tables from all the schemas are stored in that dataset. The table names are **not prefixed** with schema names!. If there are any name collisions, tables in the destination will be unions of the fields of all the tables with same name in the schemas. **Enabling one dataset per schema layout** If you set `use_single_dataset` to False: @@ -181,44 +181,6 @@ The `run`, `extract`, `normalize` and `load` method raise `PipelineStepFailed` w > should we add it? I have a runner in `dlt` that would be easy to modify -## the `Pipeline` object -There are many ways to create or get current pipeline object. -```python - -# create and get default pipeline -p1 = dlt.pipeline() -# create explicitly configured pipeline -p2 = dlt.pipeline(pipeline_name="pipe", destination=bigquery) -# get recently created pipeline -assert dlt.pipeline() is p2 -# load data with recently created pipeline -assert dlt.run(taktile_data()) is p2 -assert taktile_data().run() is p2 - -``` - -The `Pipeline` object provides following functionalities: -1. `run`, `extract`, `normalize` and `load` methods -2. a `pipeline.schema` dictionary-like object to enumerate and get the schemas in pipeline -3. schema get with `pipeline.schemas[name]` is a live object: any modification to it is automatically applied to the pipeline with the next `run`, `load` etc. see [working_with_schemas.md](working_with_schemas.md) -4. it returns `sql_client` and `native_client` to get direct access to the destination (if destination supports SQL - currently all of them do) -5. it has several methods to inspect the pipeline state and I think those should be exposed via `dlt pipeline` CLI - -for example: -- list the extracted files if any -- list the load packages ready to load -- list the failed jobs in package -- show info on destination: what are the datasets, the current load_id, the current schema etc. - - -## Examples -[we have some here](/examples/) - -## command line interface - - -## logging -I need your input for user friendly logging. What should we log? What is important to see? ## pipeline runtime setup diff --git a/docs/technical/working_with_schemas.md b/docs/technical/working_with_schemas.md index d94edb8727..532f0e5a1d 100644 --- a/docs/technical/working_with_schemas.md +++ b/docs/technical/working_with_schemas.md @@ -1,134 +1,7 @@ -## General approach to define schemas -marks features that are: - -⛔ not implemented, hard to add - -☮️ not implemented, easy to add - -## Schema components - -### Schema content hash and version -Each schema file contains content based hash `version_hash` that is used to -1. detect manual changes to schema (ie. user edits content) -2. detect if the destination database schema is synchronized with the file schema - -Each time the schema is saved, the version hash is updated. - -Each schema contains also numeric version which increases automatically whenever schema is updated and saved. This version is mostly for informative purposes and there are cases where the increasing order will be lost. - -> Schema in the database is only updated if its hash is not stored in `_dlt_versions` table. In principle many pipelines may send data to a single dataset. If table name clash then a single table with the union of the columns will be created. If columns clash and they have different types etc. then the load will fail. - -### ❗ Normalizer and naming convention - -The parent table is created from all top level fields, if field are dictionaries they will be flattened. **all the key names will be converted with the configured naming convention**. The current naming convention -1. converts to snake_case, small caps. removes all ascii characters except alphanum and underscore -2. add `_` if name starts with number -3. multiples of `_` are converted into single `_` -4. the parent-child relation is expressed as double `_` in names. - -The nested lists will be converted into child tables. - -The data normalizer and the naming convention are part of the schema configuration. In principle the source can set own naming convention or json unpacking mechanism. Or user can overwrite those in `config.toml` - -> The table and column names are mapped automatically. **you cannot rename the columns or tables by changing the `name` property - you must rename your source documents** - -> if you provide any schema elements that contain identifiers via decorators or arguments (ie. `table_name` or `columns`) all the names used will be converted via the naming convention when adding to the schema. For example if you execute `dlt.run(... table_name="CamelCase")` the data will be loaded into `camel_case` - -> 💡 use simple, short small caps identifiers for everything! - -☠️ not implemented! - -⛔ The schema holds lineage information (from json paths to tables/columns) and (1) automatically adapts to destination limits ie. postgres 64 chars by recomputing all names (2) let's user to change the naming convention ie. to verbatim naming convention of `duckdb` where everything is allowed as identifier. - -⛔ Any naming convention generates name clashes. `dlt` detects and fixes name clashes using lineage information - - -#### JSON normalizer settings -Yes those are part of the normalizer module and can be plugged in. -1. column propagation from parent to child tables -2. nesting level - -```yaml -normalizers: - names: dlt.common.normalizers.names.snake_case - json: - module: dlt.common.normalizers.json.relational - config: - max_nesting: 5 - propagation: - # for all root tables - root: - # propagate root dlt id - _dlt_id: _dlt_root_id - tables: - # for particular tables - blocks: - # propagate timestamp as block_timestamp to child tables - timestamp: block_timestamp - hash: block_hash -``` - -## Data types -"text", "double", "bool", "timestamp", "bigint", "binary", "complex", "decimal", "wei" -⛔ you cannot specify scale and precision for bigint, binary, text and decimal - -☮️ there's no time and date type - -wei is a datatype that tries to best represent native Ethereum 256bit integers and fixed point decimals. it works correcly on postgres and bigquery ## Schema settings The `settings` section of schema let's you define various global rules that impact how tables and columns are inferred from data. -> 💡 it is the best practice to use those instead of providing the exact column schemas via `columns` argument or by pasting them in `yaml`. Any ideas for improvements? tell me. - -### Column hint rules -You can define a global rules that will apply hints to a newly inferred columns. Those rules apply to normalized column names. You can use column names directly or with regular expressions. ❗ when lineages are implemented the regular expressions will apply to lineages not to column names. - -Example from ethereum schema -```yaml -settings: - default_hints: - foreign_key: - - _dlt_parent_id - not_null: - - re:^_dlt_id$ - - _dlt_root_id - - _dlt_parent_id - - _dlt_list_idx - unique: - - _dlt_id - cluster: - - block_hash - partition: - - block_timestamp -``` - -### Preferred data types -You can define rules that will set the data type for newly created columns. Put the rules under `preferred_types` key of `settings`. On the left side there's a rule on a column name, on the right side is the data type. ❗See the column hint rules for naming convention! - -Example: -```yaml -settings: - preferred_types: - timestamp: timestamp - re:^inserted_at$: timestamp - re:^created_at$: timestamp - re:^updated_at$: timestamp - re:^_dlt_list_idx$: bigint -``` - -### data type autodetectors -You can define a set of functions that will be used to infer the data type of the column from a value. The functions are run from top to bottom on the lists. Look in `detections.py` to see what is available. -```yaml -settings: - detections: - - timestamp - - iso_timestamp - - iso_date -``` - -⛔ we may define `all_text` function that will generate string only schemas by telling `dlt` that all types should be coerced to strings. - ### Table exclude and include filters You can define the include and exclude filters on tables but you are much better off transforming and filtering your source data in python. The current implementation is both weird and quite powerful. In essence you can exclude columns and whole tables with regular expressions to which the inputs are normalized lineages of the values. Example @@ -191,54 +64,3 @@ p.run() ``` > The `normalize` stage creates standalone load packages each containing data and schema with particular version. Those packages are of course not impacted by the "live" schema changes. - -## Attaching schemas to sources -The general approach when creating a new pipeline is to setup a few global schema settings and then let the table and column schemas to be generated from the resource hints and data itself. - -> ⛔ I do not have any cool "schema builder" api yet to see the global settings. - -The `dlt.source` decorator accepts a schema instance that you can create yourself and whatever you want. It also support a few typical use cases: - -### Schema created implicitly by decorator -If no schema instance is passed, the decorator creates a schema with the name set to source name and all the settings to default. - -### Automatically load schema file stored with source python module -If no schema instance is passed, and a file with a name `{source name}_schema.yml` exists in the same folder as the module with the decorated function, it will be automatically loaded and used as the schema. - -This should make easier to bundle a fully specified (or non trivially configured) schema with a source. - -### Schema is modified in the source function body -What if you can configure your schema or add some tables only inside your schema function, when ie. you have the source credentials and user settings? You could for example add detailed schemas of all the database tables when someone requests a table data to be loaded. This information is available only at the moment source function is called. - -Similarly to the `state`, source and resource function has current schema available via `dlt.current.source_schema` - -Example: - -```python - -# apply schema to the source -@dlt.source -def createx(nesting_level: int): - - schema = dlt.current.source_schema() - - # get default normalizer config - normalizer_conf = dlt.schema.normalizer_config() - # set hash names convention which produces short names without clashes but very ugly - if short_names_convention: - normalizer_conf["names"] = dlt.common.normalizers.names.hash_names - - # apply normalizer conf - schema = Schema("createx", normalizer_conf) - # set nesting level, yeah it's ugly - schema._normalizers_config["json"].setdefault("config", {})["max_nesting"] = nesting_level - # remove date detector and add type detector that forces all fields to strings - schema._settings["detections"].remove("iso_timestamp") - schema._settings["detections"].insert(0, "all_text") - schema.compile_settings() - - return dlt.resource(...) - -``` - -Also look at the following [test](/tests/extract/test_decorators.py) : `test_source_schema_context` diff --git a/docs/tools/fix_grammar_gpt.py b/docs/tools/fix_grammar_gpt.py index 065b53d470..9979a92b41 100644 --- a/docs/tools/fix_grammar_gpt.py +++ b/docs/tools/fix_grammar_gpt.py @@ -120,7 +120,7 @@ def get_chunk_length(chunk: List[str]) -> int: temperature=0, ) - fixed_chunks.append(response.choices[0].message.content) + fixed_chunks.append(response.choices[0].message.content) # type: ignore with open(file_path, "w", encoding="utf-8") as f: for c in fixed_chunks: diff --git a/docs/tools/prepare_examples_tests.py b/docs/tools/prepare_examples_tests.py index a300b1eb8f..dc0a3c82f9 100644 --- a/docs/tools/prepare_examples_tests.py +++ b/docs/tools/prepare_examples_tests.py @@ -3,6 +3,7 @@ """ import os import argparse +from typing import List import dlt.cli.echo as fmt @@ -10,13 +11,15 @@ # settings SKIP_FOLDERS = ["archive", ".", "_", "local_cache"] -SKIP_EXAMPLES = ["qdrant_zendesk"] +SKIP_EXAMPLES: List[str] = [] # the entry point for the script MAIN_CLAUSE = 'if __name__ == "__main__":' # some stuff to insert for setting up and tearing down fixtures TEST_HEADER = """ +import pytest + from tests.utils import skipifgithubfork """ @@ -52,8 +55,12 @@ os.unlink(test_example_file) continue - with open(example_file, "r", encoding="utf-8") as f: - lines = f.read().split("\n") + try: + with open(example_file, "r", encoding="utf-8") as f: + lines = f.read().split("\n") + except FileNotFoundError: + print(f"Example file {example_file} not found, test prep will be skipped") + continue processed_lines = TEST_HEADER.split("\n") main_clause_found = False @@ -62,7 +69,8 @@ # convert the main clause to a test function if line.startswith(MAIN_CLAUSE): main_clause_found = True - processed_lines.append("@skipifgithubfork") + processed_lines.append("@skipifgithubfork") # skip on forks + processed_lines.append("@pytest.mark.forked") # skip on forks processed_lines.append(f"def test_{example}():") else: processed_lines.append(line) diff --git a/docs/website/blog/2023-09-05-mongo-etl.md b/docs/website/blog/2023-09-05-mongo-etl.md index cd102c8895..8dfd953be4 100644 --- a/docs/website/blog/2023-09-05-mongo-etl.md +++ b/docs/website/blog/2023-09-05-mongo-etl.md @@ -168,7 +168,7 @@ Here's a code explanation of how it works under the hood: pipeline_name='from_json', destination='duckdb', dataset_name='mydata', - full_refresh=True, + dev_mode=True, ) # dlt works with lists of dicts, so wrap data to the list load_info = pipeline.run([data], table_name="json_data") diff --git a/docs/website/blog/2023-10-23-arrow-loading.md b/docs/website/blog/2023-10-23-arrow-loading.md index 2cdf4d90e7..25962c932e 100644 --- a/docs/website/blog/2023-10-23-arrow-loading.md +++ b/docs/website/blog/2023-10-23-arrow-loading.md @@ -50,7 +50,7 @@ chat_messages = dlt.resource( In this demo I just extract and normalize data and skip the loading step. ```py -pipeline = dlt.pipeline(destination="duckdb", full_refresh=True) +pipeline = dlt.pipeline(destination="duckdb", dev_mode=True) # extract first pipeline.extract(chat_messages) info = pipeline.normalize() @@ -98,7 +98,7 @@ chat_messages = dlt.resource( write_disposition="append", )("postgresql://loader:loader@localhost:5432/dlt_data") -pipeline = dlt.pipeline(destination="duckdb", full_refresh=True) +pipeline = dlt.pipeline(destination="duckdb", dev_mode=True) # extract first pipeline.extract(chat_messages) info = pipeline.normalize(workers=3, loader_file_format="parquet") diff --git a/docs/website/blog/2023-12-01-dlt-kestra-demo.md b/docs/website/blog/2023-12-01-dlt-kestra-demo.md index 9f1d7acba2..1b1c79562d 100644 --- a/docs/website/blog/2023-12-01-dlt-kestra-demo.md +++ b/docs/website/blog/2023-12-01-dlt-kestra-demo.md @@ -45,7 +45,7 @@ Wanna jump to the [GitHub repo](https://github.com/dlt-hub/dlt-kestra-demo)? ## HOW IT WORKS -To lay it all out clearly: Everything's automated in **`Kestra`**, with hassle-free data loading thanks to **`dlt`**, and the analytical thinking handled by OpenAI. Here's a diagram to help you understand the general outline of the entire process. +To lay it all out clearly: Everything's automated in **`Kestra`**, with hassle-free data loading thanks to **`dlt`**, and the analytical thinking handled by OpenAI. Here's a diagram to help you understand the general outline of the entire process. ![overview](https://storage.googleapis.com/dlt-blog-images/dlt_kestra_workflow_overview.png) @@ -59,12 +59,12 @@ Once you’ve opened [http://localhost:8080/](http://localhost:8080/) in your br ![Kestra](https://storage.googleapis.com/dlt-blog-images/dlt_kestra_kestra_ui.png) -Now, all you need to do is [create your flows](https://github.com/dlt-hub/dlt-kestra-demo/blob/main/README.md) and execute them. +Now, all you need to do is [create your flows](https://github.com/dlt-hub/dlt-kestra-demo/blob/main/README.md) and execute them. The great thing about **`Kestra`** is its ease of use - it's UI-based, declarative, and language-agnostic. Unless you're using a task like a [Python script](https://kestra.io/plugins/plugin-script-python/tasks/io.kestra.plugin.scripts.python.script), you don't even need to know how to code. -:::tip +:::tip If you're already considering ways to use **`Kestra`** for your projects, consult their [documentation](https://kestra.io/docs) and the [plugin](https://kestra.io/plugins) pages for further insights. ::: @@ -84,7 +84,7 @@ pipeline = dlt.pipeline( pipeline_name="standard_inbox", destination='bigquery', dataset_name="messages_data", - full_refresh=False, + dev_mode=False, ) # Set table name diff --git a/docs/website/blog/2024-02-06-practice-api-sources.md b/docs/website/blog/2024-02-06-practice-api-sources.md index 4e78fc48e4..248d4ae647 100644 --- a/docs/website/blog/2024-02-06-practice-api-sources.md +++ b/docs/website/blog/2024-02-06-practice-api-sources.md @@ -32,6 +32,7 @@ This article outlines 10 APIs, detailing their use cases, any free tier limitati ### Data talks club open source spotlight * [Video](https://www.youtube.com/watch?v=eMbhyOECpcE) * [Notebook](https://github.com/dlt-hub/dlt_demos/blob/main/spotlight_demo.ipynb) +* DTC Learners showcase (review again) ### Docs * [Getting started](https://dlthub.com/docs/getting-started) @@ -100,8 +101,166 @@ This article outlines 10 APIs, detailing their use cases, any free tier limitati - **Free:** Varies by API. - **Auth:** Depends on API. +### 11. News API +- **URL**: [News API](https://newsapi.ai/). +- **Use**: Get datasets containing current and historic news articles. +- **Free**: Access to current news articles. +- **Auth**: API-Key. + +### 12. Exchangerates API +- **URL**: [Exchangerate API](https://exchangeratesapi.io/). +- **Use**: Get realtime, intraday and historic currency rates. +- **Free**: 250 monthly requests. +- **Auth**: API-Key. + +### 13. Spotify API +- **URL**: [Spotify API](https://developer.spotify.com/documentation/web-api). +- **Use**: Get spotify content and metadata about songs. +- **Free**: Rate limit. +- **Auth**: API-Key. + +### 14. Football API +- **URL**: [FootBall API](https://www.api-football.com/). +- **Use**: Get information about Football Leagues & Cups. +- **Free**: 100 requests/day. +- **Auth**: API-Key. + +### 15. Yahoo Finance API +- **URL**: [Yahoo Finance API](https://rapidapi.com/sparior/api/yahoo-finance15/details). +- **Use**: Access a wide range of financial data. +- **Free**: 500 requests/month. +- **Auth**: API-Key. + +### 16. Basketball API + +- URL: [Basketball API](https://www.api-basketball.com/). +- Use: Get information about basketball leagues & cups. +- Free: 100 requests/day. +- Auth: API-Key. + +### 17. NY Times API + +- URL: [NY Times API](https://developer.nytimes.com/apis). +- Use: Get info about articles, books, movies and more. +- Free: 500 requests/day or 5 requests/minute. +- Auth: API-Key. + +### 18. Spoonacular API + +- URL: [Spoonacular API](https://spoonacular.com/food-api). +- Use: Get info about ingredients, recipes, products and menu items. +- Free: 150 requests/day and 1 request/sec. +- Auth: API-Key. + +### 19. Movie database alternative API + +- URL: [Movie database alternative API](https://rapidapi.com/rapidapi/api/movie-database-alternative/pricing). +- Use: Movie data for entertainment industry trend analysis. +- Free: 1000 requests/day and 10 requests/sec. +- Auth: API-Key. + +### 20. RAWG Video games database API + +- URL: [RAWG Video Games Database](https://rawg.io/apidocs). +- Use: Gather video game data, such as release dates, platforms, genres, and reviews. +- Free: Unlimited requests for limited endpoints. +- Auth: API key. + +### 21. Jikan API + +- **URL:** [Jikan API](https://jikan.moe/). +- **Use:** Access data from MyAnimeList for anime and manga projects. +- **Free:** Rate-limited. +- **Auth:** None. + +### 22. Open Library Books API + +- URL: [Open Library Books API](https://openlibrary.org/dev/docs/api/books). +- Use: Access data about millions of books, including titles, authors, and publication dates. +- Free: Unlimited. +- Auth: None. + +### 23. YouTube Data API + +- URL: [YouTube Data API](https://developers.google.com/youtube/v3/docs/search/list). +- Use: Access YouTube video data, channels, playlists, etc. +- Free: Limited quota. +- Auth: Google API key and OAuth 2.0. + +### 24. Reddit API + +- URL: [Reddit API](https://www.reddit.com/dev/api/). +- Use: Access Reddit data for social media analysis or content retrieval. +- Free: Rate-limited. +- Auth: OAuth 2.0. + +### 25. World Bank API + +- URL: [World bank API](https://documents.worldbank.org/en/publication/documents-reports/api). +- Use: Access economic and development data from the World Bank. +- Free: Unlimited. +- Auth: None. + Each API offers unique insights for data engineering, from ingestion to visualization. Check each API's documentation for up-to-date details on limitations and authentication. +## Using the above sources + +You can create a pipeline for the APIs discussed above by using `dlt's` REST API source. Let’s create a PokeAPI pipeline as an example. Follow these steps: + +1. Create a Rest API source: + + ```sh + dlt init rest_api duckdb + ``` + +2. The following directory structure gets generated: + + ```sh + rest_api_pipeline/ + ├── .dlt/ + │ ├── config.toml # configs for your pipeline + │ └── secrets.toml # secrets for your pipeline + ├── rest_api/ # folder with source-specific files + │ └── ... + ├── rest_api_pipeline.py # your main pipeline script + ├── requirements.txt # dependencies for your pipeline + └── .gitignore # ignore files for git (not required) + ``` + +3. Configure the source in `rest_api_pipeline.py`: + + ```py + def load_pokemon() -> None: + pipeline = dlt.pipeline( + pipeline_name="rest_api_pokemon", + destination='duckdb', + dataset_name="rest_api_data", + ) + + pokemon_source = rest_api_source( + { + "client": { + "base_url": "https://pokeapi.co/api/v2/", + }, + "resource_defaults": { + "endpoint": { + "params": { + "limit": 1000, + }, + }, + }, + "resources": [ + "pokemon", + "berry", + "location", + ], + } + ) + + ``` + +For a detailed guide on creating a pipeline using the Rest API source, please read the Rest API source [documentation here](https://dlthub.com/docs/dlt-ecosystem/verified-sources/rest_api). + ## Example projects Here are some examples from dlt users and working students: @@ -115,5 +274,19 @@ Here are some examples from dlt users and working students: - Japanese language demos [Notion calendar](https://stable.co.jp/blog/notion-calendar-dlt) and [exploring csv to bigquery with dlt](https://soonraah.github.io/posts/load-csv-data-into-bq-by-dlt/). - Demos with [Dagster](https://dagster.io/blog/dagster-dlt) and [Prefect](https://www.prefect.io/blog/building-resilient-data-pipelines-in-minutes-with-dlt-prefect). +## DTC learners showcase +Check out the incredible projects from our DTC learners: + +1. [e2e_de_project](https://github.com/scpkobayashi/e2e_de_project/tree/153d485bba3ea8f640d0ccf3ec9593790259a646) by [scpkobayashi](https://github.com/scpkobayashi). +2. [de-zoomcamp-project](https://github.com/theDataFixer/de-zoomcamp-project/tree/1737b6a9d556348c2d7d48a91e2a43bb6e12f594) by [theDataFixer](https://github.com/theDataFixer). +3. [data-engineering-zoomcamp2024-project2](https://github.com/pavlokurochka/data-engineering-zoomcamp2024-project2/tree/f336ed00870a74cb93cbd9783dbff594393654b8) by [pavlokurochka](https://github.com/pavlokurochka). +4. [de-zoomcamp-2024](https://github.com/snehangsude/de-zoomcamp-2024) by [snehangsude](https://github.com/snehangsude). +5. [zoomcamp-data-engineer-2024](https://github.com/eokwukwe/zoomcamp-data-engineer-2024) by [eokwukwe](https://github.com/eokwukwe). +6. [data-engineering-zoomcamp-alex](https://github.com/aaalexlit/data-engineering-zoomcamp-alex) by [aaalexlit](https://github.com/aaalexlit). +7. [Zoomcamp2024](https://github.com/alfredzou/Zoomcamp2024) by [alfredzou](https://github.com/alfredzou). +8. [data-engineering-zoomcamp](https://github.com/el-grudge/data-engineering-zoomcamp) by [el-grudge](https://github.com/el-grudge). + +Explore these projects to see the innovative solutions and hard work the learners have put into their data engineering journeys! + ## Showcase your project -If you want your project to be featured, let us know in the [#sharing-and-contributing channel of our community Slack](https://dlthub.com/community). +If you want your project to be featured, let us know in the [#sharing-and-contributing channel of our community Slack](https://dlthub.com/community). \ No newline at end of file diff --git a/docs/website/blog/2024-03-07-openapi-generation-chargebee.md b/docs/website/blog/2024-03-07-openapi-generation-chargebee.md index 3d77c3ea4c..97fc6e4865 100644 --- a/docs/website/blog/2024-03-07-openapi-generation-chargebee.md +++ b/docs/website/blog/2024-03-07-openapi-generation-chargebee.md @@ -7,7 +7,7 @@ authors: title: Data Engineer & ML Engineer url: https://github.com/dlt-hub/dlt image_url: https://avatars.githubusercontent.com/u/89419010?s=48&v=4 -tags: [data observability, data pipeline observability] +tags: [data observability, data pipeline observability, openapi] --- At dltHub, we have been pioneering the future of data pipeline generation, [making complex processes simple and scalable.](https://dlthub.com/product/#multiply-don't-add-to-our-productivity) We have not only been building dlt for humans, but also LLMs. diff --git a/docs/website/blog/2024-05-14-rest-api-source-client.md b/docs/website/blog/2024-05-14-rest-api-source-client.md index 18c8f1196e..ee20b43b41 100644 --- a/docs/website/blog/2024-05-14-rest-api-source-client.md +++ b/docs/website/blog/2024-05-14-rest-api-source-client.md @@ -7,7 +7,7 @@ authors: title: Open source Data Engineer url: https://github.com/adrianbr image_url: https://avatars.githubusercontent.com/u/5762770?v=4 -tags: [full code etl, yes code etl, etl, python elt] +tags: [rest-api, declarative etl] --- ## What is the REST API Source toolkit? diff --git a/docs/website/blog/2024-05-23-contributed-first-pipeline.md b/docs/website/blog/2024-05-23-contributed-first-pipeline.md index aae6e0f298..c6d9252da3 100644 --- a/docs/website/blog/2024-05-23-contributed-first-pipeline.md +++ b/docs/website/blog/2024-05-23-contributed-first-pipeline.md @@ -1,6 +1,6 @@ --- slug: contributed-first-pipeline -title: "How I contributed my first data pipeline to the open source." +title: "How I Contributed to My First Open Source Data Pipeline" image: https://storage.googleapis.com/dlt-blog-images/blog_my_first_data_pipeline.png authors: name: Aman Gupta @@ -78,13 +78,13 @@ def incremental_resource( With the steps defined above, I was able to load the data from Freshdesk to BigQuery and use the pipeline in production. Here’s a summary of the steps I followed: 1. Created a Freshdesk API token with sufficient privileges. -1. Created an API client to make requests to the Freshdesk API with rate limit and pagination. -1. Made incremental requests to this client based on the “updated_at” field in the response. -1. Ran the pipeline using the Python script. +2. Created an API client to make requests to the Freshdesk API with rate limit and pagination. +3. Made incremental requests to this client based on the “updated_at” field in the response. +4. Ran the pipeline using the Python script. While my journey from civil engineering to data engineering was initially intimidating, it has proved to be a profound learning experience. Writing a pipeline with **`dlt`** mirrors the simplicity of a GET request: you request data, yield it, and it flows from the source to its destination. Now, I help other clients integrate **`dlt`** to streamline their data workflows, which has been an invaluable part of my professional growth. In conclusion, diving into data engineering has expanded my technical skill set and provided a new lens through which I view challenges and solutions. As for me, the lens view mainly was concrete and steel a couple of years back, which has now begun to notice the pipelines of the data world. -Data engineering has proved both challenging, satisfying and a good carrier option for me till now. For those interested in the detailed workings of these pipelines, I encourage exploring dlt's [GitHub repository](https://github.com/dlt-hub/verified-sources) or diving into the [documentation](https://dlthub.com/docs/dlt-ecosystem/verified-sources/freshdesk). \ No newline at end of file +Data engineering has proved both challenging, satisfying, and a good career option for me till now. For those interested in the detailed workings of these pipelines, I encourage exploring dlt's [GitHub repository](https://github.com/dlt-hub/verified-sources) or diving into the [documentation](https://dlthub.com/docs/dlt-ecosystem/verified-sources/freshdesk). \ No newline at end of file diff --git a/docs/website/blog/2024-05-28-openapi-pipeline.md b/docs/website/blog/2024-05-28-openapi-pipeline.md new file mode 100644 index 0000000000..60faa062e0 --- /dev/null +++ b/docs/website/blog/2024-05-28-openapi-pipeline.md @@ -0,0 +1,97 @@ +--- +slug: openapi-pipeline +title: "Instant pipelines with dlt-init-openapi" +image: https://storage.googleapis.com/dlt-blog-images/openapi.png +authors: + name: Adrian Brudaru + title: Open source Data Engineer + url: https://github.com/adrianbr + image_url: https://avatars.githubusercontent.com/u/5762770?v=4 +tags: [openapi] +--- + +# The Future of Data Pipelines starts now. + +Dear dltHub Community, + +We are thrilled to announce the launch of our groundbreaking pipeline generator tool. + +We call it `dlt-init-openapi`. + +Just point it to an OpenAPI spec, select your endpoints, and you're done! + + +### What's OpenAPI again? + +[OpenAPI](https://www.openapis.org/) is the world's most widely used API description standard. You may have heard about swagger docs? those are docs generated from the spec. +In 2021 an information-security company named Assetnote scanned the web and unearthed [200,000 public +OpenAPI files](https://www.assetnote.io/resources/research/contextual-content-discovery-youve-forgotten-about-the-api-endpoints). +Modern API frameworks like [FastAPI](https://pypi.org/project/fastapi/) generate such specifications automatically. + +## How does it work? + +**A pipeline is a series of datapoints or decisions about how to extract and load the data**, expressed as code or config. I say decisions because building a pipeline can be boiled down to inspecting a documentation or response and deciding how to write the code. + +Our tool does its best to pick out the necessary details and detect the rest to generate the complete pipeline for you. + +The information required for taking those decisions comes from: +- The OpenAPI [Spec](https://github.com/dlt-hub/openapi-specs) (endpoints, auth) +- The dlt [REST API Source](https://dlthub.com/docs/dlt-ecosystem/verified-sources/rest_api) which attempts to detect pagination +- The [dlt init OpenAPI generator](https://dlthub.com/docs/dlt-ecosystem/verified-sources/openapi-generator) which attempts to detect incremental logic and dependent requests. + +### How well does it work? + +This is something we are also learning about. We did an internal hackathon where we each built a few pipelines with this generator. In our experiments with APIs for which we had credentials, it worked pretty well. + +However, we cannot undertake a big detour from our work to manually test each possible pipeline, so your feedback will be invaluable. +So please, if you try it, let us know how well it worked - and ideally, add the spec you used to our [repository](https://github.com/dlt-hub/openapi-specs). + +### What to do if it doesn't work? + +Once a pipeline is created, it is a **fully configurable instance of the REST API Source**. +So if anything did not go smoothly, you can make the final tweaks. +You can learn how to adjust the generated pipeline by reading our [REST API Source documentation](https://dlthub.com/docs/dlt-ecosystem/verified-sources/rest_api). + +### Are we using LLMS under the hood? + +No. This is a potential future enhancement, so maybe later. + +The pipelines are generated algorithmically with deterministic outcomes. This way, we have more control over the quality of the decisions. + +If we took an LLM-first approach, the errors would compound and put the burden back on the data person. + +We are however considering using LLM-assists for the things that the algorithmic approach can't detect. Another avenue could be generating the OpenAPI spec from website docs. +So we are eager to get feedback from you on what works and what needs work, enabling us to improve it. + +## Try it out now! + +**Video Walkthrough:** + + + + +**[Colab demo](https://colab.research.google.com/drive/1MRZvguOTZj1MlkEGzjiso8lQ_wr1MJRI?usp=sharing)** - Load data from Stripe API to DuckDB using dlt and OpenAPI + +**[Docs](https://dlthub.com/docs/dlt-ecosystem/verified-sources/openapi-generator)** for `dlt-init-openapi` + +dlt init openapi **[code repo.](https://github.com/dlt-hub/dlt-init-openapi)** + +**[Specs repository you can generate from.](https://github.com/dlt-hub/openapi-specs)** + +Showcase your pipeline in the community sources **[here](https://www.notion.so/dlthub/dltHub-Community-Sources-Snippets-7a7f7ddb39334743b1ba3debbdfb8d7f) + +## Next steps: Feedback, discussion and sharing. + +Solving data engineering headaches in the open source is a team sport. +We got this far with your feedback and help (especially on [REST API source](https://dlthub.com/docs/blog/rest-api-source-client)), and are counting on your continuous usage and engagement +to steer our pushing of what's possible into uncharted, but needed directions. + +So here's our call to action: + +- We're excited to see how you will use our new pipeline generator and we are +eager for your feedback. **[Join our community and let us know how we can improve dlt-init-openapi](https://dlthub.com/community)** +- Got an OpenAPI spec? **[Add it to our specs repository](https://github.com/dlt-hub/openapi-specs)** so others may use it. If the spec doesn't work, please note that in the PR and we will use it for R&D. + +*Thank you for being part of our community and for building the future of ETL together!* + +*- dltHub Team* diff --git a/docs/website/blog/2024-06-12-from-pandas-to-production.md b/docs/website/blog/2024-06-12-from-pandas-to-production.md new file mode 100644 index 0000000000..5dbd494a3e --- /dev/null +++ b/docs/website/blog/2024-06-12-from-pandas-to-production.md @@ -0,0 +1,211 @@ +--- +slug: pandas-to-production +title: "From Pandas to Production: How we built dlt as the right ELT tool for Normies" +image: https://storage.googleapis.com/dlt-blog-images/i-am-normal.png +authors: + name: Adrian Brudaru + title: Open source Data Engineer + url: https://github.com/adrianbr + image_url: https://avatars.githubusercontent.com/u/5762770?v=4 +tags: [pandas, production, etl, etl] +--- + + + +:::info +**TL;DR: dlt is a library for Normies: Problem solvers with antipathy for black boxes, gratuitous complexity and external dependencies.** + +**This post tells the story of how we got here.** + +Try it in colab: +* [Schema evolution](https://colab.research.google.com/drive/1H6HKFi-U1V4p0afVucw_Jzv1oiFbH2bu#scrollTo=e4y4sQ78P_OM) +* [Data Talks Club Open Source Spotlight](https://colab.research.google.com/drive/1D39_koejvi-eTtA_8AI33AHhMGGOklfb) + [Video](https://www.youtube.com/playlist?list=PL3MmuxUbc_hJ5t5nnjzC0F2zan76Dpsz0) +* [Hackernews Api demo](https://colab.research.google.com/drive/1DhaKW0tiSTHDCVmPjM-eoyL47BJ30xmP) +* [LLM-free pipeline generation demo](https://colab.research.google.com/drive/1MRZvguOTZj1MlkEGzjiso8lQ_wr1MJRI) +[4min Video](https://www.youtube.com/watch?v=b99qv9je12Q) + +But if you want to load pandas dfs to production databases, with all the best practices built-in, check out this [documentation](https://dlthub.com/docs/dlt-ecosystem/verified-sources/arrow-pandas) or this colab notebook that shows [easy handling of complex api data](https://colab.research.google.com/drive/1DhaKW0tiSTHDCVmPjM-eoyL47BJ30xmP#scrollTo=1wf1R0yQh7pv). + +Or check out more resources [at the end of the article](#call-to-action) +::: + +## I. The background story: Normal people load data too + +Hey, I’m Adrian, cofounder of dlt. I’ve been working in the data industry since 2012, doing all kinds of end-to-end things. + +In 2017, a hiring team called me a data engineer. As I saw that title brought me a lot of work offers, I kept it and went with it. + +But was I doing data engineering? Yes and no. Since my studies were not technical, I always felt some impostor syndrome calling myself a data engineer. I had started as an analyst, did more and more and became an end to end data professional that does everything from building the tech stack, collecting requirements, getting managers to agree on the metrics used 🙄, creating roadmap and hiring a team. + +Back in 2022 there was an online conference called [Normconf](https://normconf.com/) and I ‘felt seen’. As [I watched Normconf participants](https://www.youtube.com/@normconf), I could relate more to them than to the data engineer label. No, I am not just writing code and pushing best practices - I am actually just trying to get things done without getting bogged down in bad practice gotchas. And it seemed at this conference that many people felt this way. + +![normal](https://storage.googleapis.com/dlt-blog-images/i-am-normal.png) + +### Normies: Problem solvers with antipathy for black boxes, gratuitous complexity and external dependencies + +At Normconf, "normie" participants often embodied the three fundamental psychological needs identified in Self-Determination Theory: autonomy, competence, and relatedness. + +They talked about how they autonomously solved all kinds of problems, related on the pains and gains of their roles, and showed off their competence across the board, in solving problems. + +What they did, was what I also did as a data engineer: We start from a business problem, and work back through what needs to be done to understand and solve it. + +By very definition, Normie is someone not very specialised at one thing or another, and in our field, even data engineers are jacks of all trades. + +What undermines the Normie mission are things that clash with the basic needs, from uncustomisable products, to vendors that add bottlenecks and unreliable dependencies. + +### Encountering friction between data engineers and Python-first analysts + +Before becoming a co-founder of dlt I had 5 interesting years as a startup employee, a half-year nightmare in a corporation with no autonomy or mastery (I got fired for refusing the madness, and it was such a huge relief), followed by 5 fun, rewarding and adventure-filled years of freelancing. Much of my work was “build&hire” which usually meant building a first time data warehouse and hiring a team for it. The setups that I did were bespoke to the businesses that were getting them, including the teams - Meaning, the technical complexity was also tailored to the (lack of) technical culture of the companies I was building for. + +In this time, I saw an acute friction between data engineers and Python-first analysts, mostly around the fact that data engineers easily become a bottleneck and data scientists are forced to pick up the slack. And of course, this causes other issues that might further complicate the life of the data engineer, while still not being a good solution for the data consumers. + +So at this point I started building boilerplate code for data warehouses and learning how to better cater to the entire team. + + +### II. The initial idea: pandas.df.to_sql() with data engineering best practices + +After a few attempts I ended up with the hypothesis that df.to_sql() is the natural abstraction a data person would use - I have a table here, I want a table there, shouldn’t be harder than a function call right? + +Right. + +Except that particular function call is anything but data engineering complete. A single run will do what it promises. A production pipeline will also have many additional requirements. In the early days, we wrote up an ideal list of features that should be auto-handled (spoiler alert: today dlt does all that and more). Read on for the wish list: + +### Our dream: a tool that meets production pipelines requirements + +- Wouldn’t it be nice if we could auto-flatten and unpack nested structures into tables with generated join keys? + + +- Wouldn’t it be nice if data types were properly defined and managed? +- Wouldn’t it be nice if we could load the data incrementally, meaning retain some state to know where to start from? +- Wouldn’t it be nice if this incremental load was bound to a way to do incremental extraction? +- Wouldn’t it be nice if we didn’t run out of memory? +- Wouldn’t it be nice if we got alerted/notified when schemas change? +- Wouldn’t it be nice if schema changes were self healing? +- Wouldn’t it be nice if I could run it all in parallel, or do async calls? +- Wouldn’t it be nice if it ran on different databases too, from dev to prod? +- Wouldn’t it be nice if it offered requests with built in retries for those nasty unreliable apis (Hey Zendesk, why you fail on call 99998/100000?) +- Wouldn’t it be nice if we had some extraction helpers like pagination detection? + +Auto typing and unpacking with generated keys: +![keys](https://storage.googleapis.com/dlt-blog-images/generated_keys.png) + +Performance [docs](https://dlthub.com/docs/reference/performance) + + +### The initial steps + +How did we go about it? At first dlt was created as an engine to iron out its functionality. During this time, it was deployed it in several projects, from startups to enterprises, particularly to accelerate data pipeline building in a robust way. + +A while later, to prepare this engine for the general public, we created the current interface on top of it. We then tested it in a workshop with many “Normies” of which over 50% were pre-employment learners. + +For the workshop we broke down the steps to build an incremental pipeline into 20 steps. In the 6 hour workshop we asked people to react on Slack to each “checkpoint”. We then exported the slack data and loaded it with dlt, exposing the completion rate per checkpoint. Turns out, it was 100%. +Everyone who started, managed to build the pipeline. “This is it!” we thought, and spend the next 6 months preparing our docs and adding some plugins for easy deployment. + +## III. Launching dlt + +We finally launched dlt mid 2023 to the general public. Our initial community was mostly data engineers who had been using dlt without docs, +managing from reading code. As we hoped a lot of “normies” are using dlt, too! + +## dlt = code + docs + Slack support + +A product is a sum of many parts. For us dlt is not only the dlt library and interface, but also our docs and Slack community and the support and discussions there. + +In the early days of dlt we talked to Sebastian Ramirez from FastAPI who told us that he spends 2/3 of his FastAPI time writing documentation. + +In this vein, from the beginning docs were very important to us and we quickly adopted our own [docs standard](https://www.writethedocs.org/videos/eu/2017/the-four-kinds-of-documentation-and-why-you-need-to-understand-what-they-are-daniele-procida/). + +However, when we originally launched dlt, we found that different user types, especially Normies, expect different things from our docs, and because we asked for feedback, they told us. + +So overall, we were not satisfied to stop there. + +### "Can you make your docs more like my favorite tool's docs?" + +To this end we built and embedded our own docs helper in our docs. + +The result? The docs helper has been running for a year and we currently see around **300 questions per day.** Comparing this to other communities that do AI support on Slack, that’s almost 2 orders of magnitude difference in question volume by community size. + +We think this is a good thing, and a result of several factors. + +- Embedded in docs means at the right place at the right time. Available to anyone, whether they use Slack or not. +- Conversations are private and anonymous. This reduces the emotional barrier of asking. We suspect this is great for the many “Normies” / “problem solvers” that work in data. +- The questions are different than in our Slack community: Many questions are around “Setup and configuration”, “Troubleshooting” and “General questions” about dlt architecture. In Slack, we see the questions that our docs or assistant could not answer. +- The bot is conversational and will remember recent context, enabling it to be particularly helpful. This is different from the “question answering service” that many Slack bots offer, which do not keep context once a question was answered. By retaining context, it’s possible to reach a useful outcome even if it doesn’t come in the first reply. + +### dlt = “pip install and go” - the fastest way to create a pipeline and source + +dlt offers a small number of verified sources, but encourages you to build your own. As we +mentioned, creating an ad hoc dlt [pipeline and source](https://dlthub.com/docs/tutorial/load-data-from-an-api) is +[dramatically simpler](https://dlthub.com/docs/build-a-pipeline-tutorial#the-simplest-pipeline-1-liner-to-load-data-with-schema-evolution) compared to other python libraries. +Maintaining a custom dlt source in production takes no time at all because the pipeline won't break unless the source stops existing. + +The sources you build and run that are not shared back into the verified sources are what we call “private sources”. + +By the end of 2023, our community had already built 1,000 private sources, [2,000 by early March](https://dlthub.com/docs/blog/code-vs-buy). We +are now at the end of q2 2024 and we see 5,000 private sources. + +### Embracing LLM-free code generation + +We recently launched additional tooling that helps our users build sources. If you wish to try our python-first +dict-based declarative approach to building sources, check out the relevant post. + +- Rest api connector +- Openapi based pipeline generator that configures the rest api connector. + +Alena introduces the generator and troubleshoots the outcome in 4min: + + +Community videos for rest api source: [playlist](https://www.youtube.com/playlist?list=PLpTgUMBCn15rs2NkB4ise780UxLKImZTh). + +Both tools are LLM-free pipeline generators. I stress LLM free, because in our experience, GPT can +do some things to some extent - so if we ask it to complete 10 tasks to produce a pipeline, each +having 50-90% accuracy, we can expect very low success rates. + +To get around this problem, we built from the OpenAPI standard which contains information that can +be turned into a pipeline algorithmically. OpenAPI is an Api spec that’s also used by FastAPI and +constantly growing in popularity, with around 50% of apis currently supporting it. + +By leveraging the data in the spec, we are able to have a basic pipeline. Our generator also infers +some other pieces of information algorithmically to make the pipeline incremental and add some other useful details. + +### When generation doesn’t work + +Of course, generation doesn’t always work but you can take the generated pipeline and make the final +adjustments to have a standard REST API config-based pipeline that won’t suffer from code smells. + +### The benefit of minimalistic sources + +The real benefit of this declarative source is not at building time - A declarative interface requires +more upfront knowledge. Instead, by having this option, we enable minimalistic pipelines that anyone could +maintain, including non coders or human-assisted LLMs. After all, LLMs are particularly proficient at translating configurations back and forth. + +Want to influence us? we listen, so you’re welcome to discuss with us in our slack channel [**#4-discussions**](https://dlthub.com/community) + +### Towards a paid offering + +dlt is an open core product, meaning it won’t be gated to push you to the paid version at some point. +Instead, much like Kafka and Confluent, we will offer things around dlt to help you leverage it in your context. + +If you are interested to help us research what’s needed, you can apply for our design partnership +program, that aims to help you deploy dlt, while helping us learn about your challenges. + +## Call to action. + +If you like the idea of dlt, there is one thing that would help us: + +**Set aside 30min and try it.** + +See resource below. + +We often hear variations of “oh i postponed dlt so long but it only took a few minutes to get going, wish I hadn’t +installed [other tool] which took 2 weeks to set up properly and now we need to maintain or replace”, so don't be that guy. + + +Here are some notebooks and docs to open your appetite: + + +- An [API pipeline step by step tutorial](https://dlthub.com/docs/tutorial/load-data-from-an-api) to build a production pipeline from an api +- A colab demo of [schema evolution](https://colab.research.google.com/drive/1H6HKFi-U1V4p0afVucw_Jzv1oiFbH2bu#scrollTo=e4y4sQ78P_OM) (2min read) +- Docs: RestClient, the imperative class that powers the REST API source, featuring auto pagination https://dlthub.com/docs/general-usage/http/rest-client +- Docs: [Build a simple pipeline](https://dlthub.com/docs/walkthroughs/create-a-pipeline) +- Docs: [Build a complex pipeline](https://dlthub.com/docs/walkthroughs/create-a-pipeline) +- Docs: [capabilities overview](https://dlthub.com/docs/build-a-pipeline-tutorial) hub page +- Community & Help: [Slack join link.](https://dlthub.com/community) \ No newline at end of file diff --git a/docs/website/blog/2024-06-19-scd2-and-incremental-loading.md b/docs/website/blog/2024-06-19-scd2-and-incremental-loading.md new file mode 100644 index 0000000000..11c858c076 --- /dev/null +++ b/docs/website/blog/2024-06-19-scd2-and-incremental-loading.md @@ -0,0 +1,132 @@ +--- +slug: scd2-and-incremental-loading +title: "Slowly Changing Dimension Type2: Explanation and code" +image: https://storage.googleapis.com/dlt-blog-images/flowchart_for_scd2.png +authors: + name: Aman Gupta + title: Junior Data Engineer + url: https://github.com/dat-a-man + image_url: https://dlt-static.s3.eu-central-1.amazonaws.com/images/aman.png +tags: [scd2, incremental loading, slowly changing dimensions, python data pipelines] +--- + + + +:::info +**Check [this Colab Notebook](https://colab.research.google.com/drive/115cRdw1qvekZbXIQSXYkAZzLAqD9_x_I) for a short and sweet demo.** +::: + +# What is a slowly changing dimension? + +Slowly changing dimensions are a dimensional modelling technique created for historising changes in data. + +This technique only works if the dimensions change slower than we read the data, since we would not be able to track changes happening between reads. +For example, if someone changes their address once in a blue moon, we will capture the changes with daily loads - but if +they change their address 3x in a day, we will only see the last state and only capture 2 of the 4 versions of the address. + +However, they enable you to track things you could not before such as + +- Hard deletes +- Most of the changes and when they occurred +- Different versions of entities valid at different historical times + +## What is Slowly Changing Dimension Type 2 (SCD2)? and why use it? + +The Type 2 subtype of Slowly Changing Dimensions (SCD) manages changes in data over time. +When data changes, a new record is added to the database, but the old record remains unchanged. +Each record includes a timestamp or version number. This allows you to view both the historical +data and the most current data separately. + +Traditional data loading methods often involve updating existing records with new information, which results in the loss of historical data. + +SCD2 not only preserves an audit trail of data changes but also allows for accurate historical analysis and reporting. + +## SCD2 applications + +[Colab demo](https://colab.research.google.com/drive/115cRdw1qvekZbXIQSXYkAZzLAqD9_x_I) + +### Use Case 1: Versioning a record that changes + +In environments where maintaining a complete historical record of data changes is crucial, +such as in financial services or healthcare, SCD Type 2 plays a vital role. For instance, if a +customer's address changes, SCD2 ensures that the old address is preserved in historical +records while the new address is available for current transactions. This ability to view the +evolution of data over time supports auditing, tracking changes, and analyzing trends without losing +the context of past information. It allows organizations to track the lifecycle of a data +entity across different states. + +Here's an example with the customer address change. + +Before: + +| `_dlt_valid_from` | `_dlt_valid_to` | `customer_key` | `c1` | `c2` | +|-----------------------------|-----------------|----------------|-------------|------| +| 2024-04-09 18:27:53.734235 | NULL | 1 | 123 Elm St | TN | + +After update: + +| `_dlt_valid_from` | `_dlt_valid_to` | `customer_key` | `c1` | `c2` | +|-----------------------------|-----------------------------|----------------|-------------|------| +| 2024-04-09 18:27:53.734235 | 2024-05-01 17:00:00.000000 | 1 | 123 Elm St | TN | +| 2024-05-02 08:00:00.000000 | NULL | 1 | 456 Oak Ave | TN | + +In the updated state, the previous address record is closed with an `_dlt_valid_to` timestamp, and a new record is created +with the new address "456 Oak Ave" effective from May 2, 2024. The NULL in the `_dlt_valid_to` field for this +new record signifies that it is the current and active address. + +### Use Case 2: Tracking deletions + +This approach ensures that historical data is preserved for audit and compliance purposes, even though the +record is no longer active in the current dataset. It allows businesses to maintain integrity and a full +historical trail of their data changes. + +State Before Deletion: Customer Record Active + +| `_dlt_valid_from` | `_dlt_valid_to` | `customer_key` | `c1` | `c2` | +|-----------------------------|-----------------|----------------|-------------|------| +| 2024-04-09 18:27:53.734235 | NULL | 1 | 123 Elm St | TN | +This table shows the customer record when it was active, with an address at "123 Elm St". The `_dlt_valid_to` field is NULL, indicating that the record is currently active. + +State after deletion: Customer record marked as deleted + +| `_dlt_valid_from` | `_dlt_valid_to` | `customer_key` | `c1` | `c2` | +|-----------------------------|-----------------------------|----------------|-------------|------| +| 2024-04-09 18:27:53.734235 | 2024-06-01 10:00:00.000000 | 1 | 123 Elm St | TN | + +In this updated table, the record that was previously active is marked as deleted by updating the `_dlt_valid_to` field +to reflect the timestamp when the deletion was recognized, in this case, June 1, 2024, at 10:00 AM. The presence +of a non-NULL `_dlt_valid_to` date indicates that this record is no longer active as of that timestamp. + + +Learn how to customise your column names and validity dates in our [SDC2 docs](https://dlthub.com/docs/general-usage/incremental-loading#scd2-strategy). + + +### Surrogate keys, what are they? Why use? + +Every record in the SCD2 table needs its own id. We call this a surrogate key. We use it to identify the specific +record or version of an entity, and we can use it when joining to our fact tables for performance (as opposed to joining on entity id + validity time). + +### Simple steps to determine data loading strategy and write disposition + +This decision flowchart helps determine the most suitable data loading strategy and write disposition: + +1. Is your data stateful? Stateful data is subject to change, like your age. Stateless data does not change, for example, events that happened in the past are stateless. + + 1. If your data is stateless, such as logs, you can just increment by appending new logs. + 2. If it is stateful, do you need to track changes to it? + 1. If yes, then use SCD2 to track changes. + 2. If no, + 1. Can you extract it incrementally (new changes only)? + 1. If yes, load incrementally via merge. + 2. If no, re-load fully via replace. + +Below is a visual representation of steps discussed above: +![Image](https://storage.googleapis.com/dlt-blog-images/flowchart_for_scd2.png) + +### **Conclusion** + +Use SCD2 where it makes sense but keep in mind the shortcomings related to the read vs update frequency. +Use dlt to do it at loading and keep everything downstream clean and simple. + +Want to discuss? +[Join the dlt slack community!](https://dlthub.com/community) diff --git a/docs/website/blog/2024-06-21-google-forms-to-notion.md b/docs/website/blog/2024-06-21-google-forms-to-notion.md new file mode 100644 index 0000000000..ec1631bc44 --- /dev/null +++ b/docs/website/blog/2024-06-21-google-forms-to-notion.md @@ -0,0 +1,142 @@ +--- +slug: google-forms-to-notion +title: "Syncing Google Forms data with Notion using dlt" +authors: + name: Aman Gupta + title: Junior Data Engineer + url: https://github.com/dat-a-man + image_url: https://dlt-static.s3.eu-central-1.amazonaws.com/images/aman.png +tags: [google forms, cloud functions, google-forms-to-notion] +--- + +## Why do we do it? + +Hello, I'm Aman, and I assist the dlthub team with various data-related tasks. In a recent project, the Operations team needed to gather information through Google Forms and integrate it into a Notion database. Initially, they tried using the Zapier connector as a quick and cost-effective solution, but it didn’t work as expected. Since we’re at dlthub, where everyone is empowered to create pipelines, I stepped in to develop one that would automate this process. + +The solution involved setting up a workflow to automatically sync data from Google Forms to a Notion database. This was achieved using Google Sheets, Google Apps Script, and a `dlt` pipeline, ensuring that every new form submission was seamlessly transferred to the Notion database without the need for manual intervention. + +## Implementation + +So here are a few steps followed: + +**Step 1: Link Google Form to Google Sheet** + +Link the Google Form to a Google Sheet to save responses in the sheet. Follow [Google's documentation](https://support.google.com/docs/answer/2917686?hl=en#zippy=%2Cchoose-where-to-store-responses) for setup. + +**Step 2: Google Apps Script for Data Transfer** + +Create a Google Apps Script to send data from Google Sheets to a Notion database via a webhook. This script triggers every time a form response is saved. + +**Google Apps Script code:** + +```text +function sendWebhookOnEdit(e) { + var sheet = SpreadsheetApp.getActiveSpreadsheet().getActiveSheet(); + var range = sheet.getActiveRange(); + var updatedRow = range.getRow(); + var lastColumn = sheet.getLastColumn(); + var headers = sheet.getRange(1, 1, 1, lastColumn).getValues()[0]; + var updatedFields = {}; + var rowValues = sheet.getRange(updatedRow, 1, 1, lastColumn).getValues()[0]; + + for (var i = 0; i < headers.length; i++) { + updatedFields[headers[i]] = rowValues[i]; + } + + var jsonPayload = JSON.stringify(updatedFields); + Logger.log('JSON Payload: ' + jsonPayload); + + var url = 'https://your-webhook.cloudfunctions.net/to_notion_from_google_forms'; // Replace with your Cloud Function URL + var options = { + 'method': 'post', + 'contentType': 'application/json', + 'payload': jsonPayload + }; + + try { + var response = UrlFetchApp.fetch(url, options); + Logger.log('Response: ' + response.getContentText()); + } catch (error) { + Logger.log('Failed to send webhook: ' + error.toString()); + } +} +``` + +**Step 3: Deploying the ETL Pipeline** + +Deploy a `dlt` pipeline to Google Cloud Functions to handle data transfer from Google Sheets to the Notion database. The pipeline is triggered by the Google Apps Script. + +1. Create a Google Cloud function. +2. Create `main.py` with the Python code below. +3. Ensure `requirements.txt` includes `dlt`. +4. Deploy the pipeline to Google Cloud Functions. +5. Use the function URL in the Google Apps Script. + +:::note +This pipeline uses `@dlt.destination` decorator which is used to set up custom destinations. Using custom destinations is a part of `dlt's` reverse ETL capabilities. To read more about `dlt's` reverse ETL pipelines, please read the [documentation here.](https://dlthub.com/docs/dlt-ecosystem/destinations/destination) +::: + +**Python code for `main.py` (Google cloud functions) :** + +```py +import dlt +from dlt.common import json +from dlt.common.typing import TDataItems +from dlt.common.schema import TTableSchema +from datetime import datetime +from dlt.sources.helpers import requests + +@dlt.destination(name="notion", batch_size=1, naming_convention="direct", skip_dlt_columns_and_tables=True) +def insert_into_notion(items: TDataItems, table: TTableSchema) -> None: + api_key = dlt.secrets.value # Add your notion API key to "secrets.toml" + database_id = "your_notion_database_id" # Replace with your Notion Database ID + url = "https://api.notion.com/v1/pages" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Notion-Version": "2022-02-22" + } + + for item in items: + if isinstance(item.get('Timestamp'), datetime): + item['Timestamp'] = item['Timestamp'].isoformat() + data = { + "parent": {"database_id": database_id}, + "properties": { + "Timestamp": { + "title": [{ + "text": {"content": item.get('Timestamp')} + }] + }, + # Add other properties here + } + } + response = requests.post(url, headers=headers, data=json.dumps(data)) + print(response.status_code, response.text) + +def your_webhook(request): + data = request.get_json() + Event = [data] + + pipeline = dlt.pipeline( + pipeline_name='platform_to_notion', + destination=insert_into_notion, + dataset_name='webhooks', + full_refresh=True + ) + + pipeline.run(Event, table_name='webhook') + return 'Event received and processed successfully.' +``` + +### Step 4: Automation and Real-Time updates + +With everything set up, the workflow automates data transfer as follows: + +1. Form submission saves data in Google Sheets. +2. Google Apps Script sends a POST request to the Cloud Function. +3. The `dlt` pipeline processes the data and updates the Notion database. + +# Conclusion + +We initially considered using Zapier for this small task, but ultimately, handling it ourselves proved to be quite effective. Since we already use an orchestrator for our other automations, the only expense was the time I spent writing and testing the code. This experience demonstrates that `dlt` is a straightforward and flexible tool, suitable for a variety of scenarios. Essentially, wherever Python can be used, `dlt` can be applied effectively for data loading, provided it meets your specific needs. \ No newline at end of file diff --git a/docs/website/docs/dlt-ecosystem/destinations/athena.md b/docs/website/docs/dlt-ecosystem/destinations/athena.md index 93291bfe9a..2a8b8c6b9d 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/athena.md +++ b/docs/website/docs/dlt-ecosystem/destinations/athena.md @@ -102,8 +102,11 @@ Athena does not support JSON fields, so JSON is stored as a string. > ❗**Athena does not support TIME columns in parquet files**. `dlt` will fail such jobs permanently. Convert `datetime.time` objects to `str` or `datetime.datetime` to load them. -### Naming Convention -We follow our snake_case name convention. Keep the following in mind: +### Table and column identifiers +Athena uses case insensitive identifiers and **will lower case all the identifiers** that are stored in the INFORMATION SCHEMA. Do not use +[case sensitive naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations). Letter casing will be removed anyway and you risk to generate identifier collisions, which are detected by `dlt` and will fail the load process. + +Under the hood Athena uses different SQL engines for DDL (catalog) and DML/Queries: * DDL uses HIVE escaping with `````` * Other queries use PRESTO and regular SQL escaping. @@ -141,7 +144,7 @@ For every table created as an iceberg table, the Athena destination will create The `merge` write disposition is supported for Athena when using iceberg tables. > Note that: -> 1. there is a risk of tables ending up in inconsistent state in case a pipeline run fails mid flight, because Athena doesn't support transactions, and `dlt` uses multiple DELETE/UPDATE/INSERT statements to implement `merge`, +> 1. there is a risk of tables ending up in inconsistent state in case a pipeline run fails mid flight, because Athena doesn't support transactions, and `dlt` uses multiple DELETE/UPDATE/INSERT statements to implement `merge`, > 2. `dlt` creates additional helper tables called `insert_` and `delete_
` in the staging schema to work around Athena's lack of temporary tables. ### dbt support @@ -183,7 +186,7 @@ Here is an example of how to use the adapter to partition a table: from datetime import date import dlt -from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter +from dlt.destinations.adapters import athena_partition, athena_adapter data_items = [ (1, "A", date(2021, 1, 1)), diff --git a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md index 4f99901e37..51d124251a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/bigquery.md +++ b/docs/website/docs/dlt-ecosystem/destinations/bigquery.md @@ -136,6 +136,58 @@ def streamed_resource(): streamed_resource.apply_hints(additional_table_hints={"x-insert-api": "streaming"}) ``` +### Use BigQuery schema autodetect for nested fields +You can let BigQuery to infer schemas and create destination tables instead of `dlt`. As a consequence, nested fields (ie. `RECORD`), which `dlt` does not support at +this moment (they are stored as JSON), may be created. You select certain resources with [BigQuery Adapter](#bigquery-adapter) or all of them with the following config option: +```toml +[destination.bigquery] +autodetect_schema=true +``` +We recommend to yield [arrow tables](../verified-sources/arrow-pandas.md) from your resources and `parquet` file format to load the data. In that case the schemas generated by `dlt` and BigQuery +will be identical. BigQuery will also preserve the column order from the generated parquet files. You can convert `json` data into arrow tables with [pyarrow or duckdb](../verified-sources/arrow-pandas.md#loading-json-documents). + +```py +import pyarrow.json as paj + +import dlt +from dlt.destinations.adapters import bigquery_adapter + +@dlt.resource(name="cve") +def load_cve(): + with open("cve.json", 'rb') as f: + # autodetect arrow schema and yields arrow table + yield paj.read_json(f) + +pipeline = dlt.pipeline("load_json_struct", destination="bigquery") +pipeline.run( + bigquery_adapter(load_cve(), autodetect_schema=True) +) +``` +Above, we use `pyarrow` library to convert `json` document into `arrow` table and use `biguery_adapter` to enable schema autodetect for **cve** resource. + +Yielding Python dicts/lists and loading them as `jsonl` works as well. In many cases, the resulting nested structure is simpler than those obtained via pyarrow/duckdb and parquet. However there are slight differences in inferred types from `dlt` (BigQuery coerces types more aggressively). BigQuery also does not try to preserve the column order in relation to the order of fields in JSON. + +```py +import dlt +from dlt.destinations.adapters import bigquery_adapter + +@dlt.resource(name="cve", max_table_nesting=1) +def load_cve(): + with open("cve.json", 'rb') as f: + yield json.load(f) + +pipeline = dlt.pipeline("load_json_struct", destination="bigquery") +pipeline.run( + bigquery_adapter(load_cve(), autodetect_schema=True) +) +``` +In the example below we represent `json` data as tables up until nesting level 1. Above this nesting level, we let BigQuery to create nested fields. + +:::caution +If you yield data as Python objects (dicts) and load this data as `parquet`, the nested fields will be converted into strings. This is one of the consequences of +`dlt` not being able to infer nested fields. +::: + ## Supported File Formats You can configure the following file formats to load data to BigQuery: @@ -148,7 +200,11 @@ When staging is enabled: * [jsonl](../file-formats/jsonl.md) is used by default. * [parquet](../file-formats/parquet.md) is supported. -> ❗ **Bigquery cannot load JSON columns from `parquet` files**. `dlt` will fail such jobs permanently. Switch to `jsonl` to load and parse JSON properly. +:::caution +**Bigquery cannot load JSON columns from `parquet` files**. `dlt` will fail such jobs permanently. Instead: +* Switch to `jsonl` to load and parse JSON properly. +* Use schema [autodetect and nested fields](#use-bigquery-schema-autodetect-for-nested-fields) +::: ## Supported Column Hints @@ -167,6 +223,22 @@ BigQuery supports the following [column hints](https://dlthub.com/docs/general-u * `cluster` - creates a cluster column(s). Many columns per table are supported and only when a new table is created. +### Table and column identifiers +BigQuery uses case sensitive identifiers by default and this is what `dlt` assumes. If the dataset you use has case insensitive identifiers (you have such option +when you create it) make sure that you use case insensitive [naming convention](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) or you tell `dlt` about it so identifier collisions are properly detected. +```toml +[destination.bigquery] +has_case_sensitive_identifiers=false +``` + +You have an option to allow `dlt` to set the case sensitivity for newly created datasets. In that case it will follow the case sensitivity of current +naming convention (ie. the default **snake_case** will create dataset with case insensitive identifiers). +```toml +[destination.bigquery] +should_set_case_sensitivity_on_new_dataset=true +``` +The option above is off by default. + ## Staging Support BigQuery supports GCS as a file staging destination. `dlt` will upload files in the parquet format to GCS and ask BigQuery to copy their data directly into the database. @@ -232,7 +304,7 @@ Here is an example of how to use the `bigquery_adapter` method to apply hints to from datetime import date, timedelta import dlt -from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.adapters import bigquery_adapter @dlt.resource( diff --git a/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md b/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md index b1dde5a328..bf8e2bce02 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md +++ b/docs/website/docs/dlt-ecosystem/destinations/clickhouse.md @@ -37,7 +37,7 @@ or with `pip install "dlt[clickhouse]"`, which installs the `dlt` library and th ### 2. Setup ClickHouse database -To load data into ClickHouse, you need to create a ClickHouse database. While we recommend asking our GPT-4 assistant for details, we have provided a general outline of the process below: +To load data into ClickHouse, you need to create a ClickHouse database. While we recommend asking our GPT-4 assistant for details, we've provided a general outline of the process below: 1. You can use an existing ClickHouse database or create a new one. @@ -59,35 +59,52 @@ To load data into ClickHouse, you need to create a ClickHouse database. While we ```toml [destination.clickhouse.credentials] - database = "dlt" # The database name you created - username = "dlt" # ClickHouse username, default is usually "default" - password = "Dlt*12345789234567" # ClickHouse password if any - host = "localhost" # ClickHouse server host - port = 9000 # ClickHouse HTTP port, default is 9000 - http_port = 8443 # HTTP Port to connect to ClickHouse server's HTTP interface. Defaults to 8443. + database = "dlt" # The database name you created. + username = "dlt" # ClickHouse username, default is usually "default". + password = "Dlt*12345789234567" # ClickHouse password if any. + host = "localhost" # ClickHouse server host. + port = 9000 # ClickHouse native TCP protocol port, default is 9000. + http_port = 8443 # ClickHouse HTTP port, default is 9000. secure = 1 # Set to 1 if using HTTPS, else 0. - dataset_table_separator = "___" # Separator for dataset table names from dataset. ``` - :::info http_port - The `http_port` parameter specifies the port number to use when connecting to the ClickHouse server's HTTP interface. This is different from default port 9000, which is used for the native TCP - protocol. + :::info Network Ports + The `http_port` parameter specifies the port number to use when connecting to the ClickHouse server's HTTP interface. + The default non-secure HTTP port for ClickHouse is `8123`. + This is different from the default port `9000`, which is used for the native TCP protocol. - You must set `http_port` if you are not using external staging (i.e. you don't set the staging parameter in your pipeline). This is because dlt's built-in ClickHouse local storage staging uses the - [clickhouse-connect](https://github.com/ClickHouse/clickhouse-connect) library, which communicates with ClickHouse over HTTP. + You must set `http_port` if you are not using external staging (i.e. you don't set the `staging` parameter in your pipeline). This is because dlt's built-in ClickHouse local storage staging uses the [clickhouse-connect](https://github.com/ClickHouse/clickhouse-connect) library, which communicates with ClickHouse over HTTP. - Make sure your ClickHouse server is configured to accept HTTP connections on the port specified by `http_port`. For example, if you set `http_port = 8443`, then ClickHouse should be listening for - HTTP - requests on port 8443. If you are using external staging, you can omit the `http_port` parameter, since clickhouse-connect will not be used in this case. + Make sure your ClickHouse server is configured to accept HTTP connections on the port specified by `http_port`. For example: + + - If you set `http_port = 8123` (default non-secure HTTP port), then ClickHouse should be listening for HTTP requests on port 8123. + - If you set `http_port = 8443`, then ClickHouse should be listening for secure HTTPS requests on port 8443. + + If you're using external staging, you can omit the `http_port` parameter, since clickhouse-connect will not be used in this case. + + For local development and testing with ClickHouse running locally, it is recommended to use the default non-secure HTTP port `8123` by setting `http_port=8123` or omitting the parameter. + + Please see the [ClickHouse network port documentation](https://clickhouse.com/docs/en/guides/sre/network-ports) for further reference. ::: 2. You can pass a database connection string similar to the one used by the `clickhouse-driver` library. The credentials above will look like this: ```toml - # keep it at the top of your toml file, before any section starts. + # keep it at the top of your toml file before any section starts. destination.clickhouse.credentials="clickhouse://dlt:Dlt*12345789234567@localhost:9000/dlt?secure=1" ``` +### 3. Add configuration options + +You can set the following configuration options in the `.dlt/secrets.toml` file: + +```toml +[destination.clickhouse] +dataset_table_separator = "___" # The default separator for dataset table names from dataset. +table_engine_type = "merge_tree" # The default table engine to use. +dataset_sentinel_table_name = "dlt_sentinel_table" # The default name for sentinel tables. +``` + ## Write disposition All [write dispositions](../../general-usage/incremental-loading#choosing-a-write-disposition) are supported. @@ -104,7 +121,8 @@ Data is loaded into ClickHouse using the most efficient method depending on the `Clickhouse` does not support multiple datasets in one database, dlt relies on datasets to exist for multiple reasons. To make `clickhouse` work with `dlt`, tables generated by `dlt` in your `clickhouse` database will have their name prefixed with the dataset name separated by -the configurable `dataset_table_separator`. Additionally, a special sentinel table that does not contain any data will be created, so dlt knows which virtual datasets already exist in a +the configurable `dataset_table_separator`. +Additionally, a special sentinel table that doesn't contain any data will be created, so dlt knows which virtual datasets already exist in a clickhouse destination. @@ -115,14 +133,15 @@ destination. The `clickhouse` destination has a few specific deviations from the default sql destinations: -1. `Clickhouse` has an experimental `object` datatype, but we have found it to be a bit unpredictable, so the dlt clickhouse destination will load the complex datatype to a `text` column. If you need +1. `Clickhouse` has an experimental `object` datatype, but we've found it to be a bit unpredictable, so the dlt clickhouse destination will load the complex datatype to a `text` column. + If you need this feature, get in touch with our Slack community, and we will consider adding it. 2. `Clickhouse` does not support the `time` datatype. Time will be loaded to a `text` column. 3. `Clickhouse` does not support the `binary` datatype. Binary will be loaded to a `text` column. When loading from `jsonl`, this will be a base64 string, when loading from parquet this will be the `binary` object converted to `text`. -4. `Clickhouse` accepts adding columns to a populated table that are not null. -5. `Clickhouse` can produce rounding errors under certain conditions when using the float / double datatype. Make sure to use decimal if you cannot afford to have rounding errors. Loading the value - 12.7001 to a double column with the loader file format jsonl set will predictbly produce a rounding error for example. +4. `Clickhouse` accepts adding columns to a populated table that aren’t null. +5. `Clickhouse` can produce rounding errors under certain conditions when using the float / double datatype. Make sure to use decimal if you can’t afford to have rounding errors. Loading the value + 12.7001 to a double column with the loader file format jsonl set will predictably produce a rounding error, for example. ## Supported column hints @@ -130,31 +149,46 @@ ClickHouse supports the following [column hints](../../general-usage/schema#tabl - `primary_key` - marks the column as part of the primary key. Multiple columns can have this hint to create a composite primary key. -## Table Engine +## Choosing a Table Engine + +dlt defaults to `MergeTree` table engine. You can specify an alternate table engine in two ways: -By default, tables are created using the `ReplicatedMergeTree` table engine in ClickHouse. You can specify an alternate table engine using the `table_engine_type` with the clickhouse adapter: +### Setting a default table engine in the configuration + +You can set a default table engine for all resources and dlt tables by adding the `table_engine_type` parameter to your ClickHouse credentials in the `.dlt/secrets.toml` file: + +```toml +[destination.clickhouse] +# ... (other configuration options) +table_engine_type = "merge_tree" # The default table engine to use. +``` + +### Setting the table engine for specific resources + +You can also set the table engine for specific resources using the clickhouse_adapter, which will override the default engine set in `.dlt/secrets.toml`, for that resource: ```py from dlt.destinations.adapters import clickhouse_adapter - @dlt.resource() def my_resource(): - ... - + ... clickhouse_adapter(my_resource, table_engine_type="merge_tree") - ``` -Supported values are: +Supported values for `table_engine_type` are: + +- `merge_tree` (default) - creates tables using the `MergeTree` engine, suitable for most use cases. [Learn more about MergeTree](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree). +- `shared_merge_tree` - creates tables using the `SharedMergeTree` engine, optimized for cloud-native environments with shared storage. This table is **only** available on ClickHouse Cloud, and it the default selection if `merge_tree` is selected. [Learn more about SharedMergeTree](https://clickhouse.com/docs/en/cloud/reference/shared-merge-tree). +- `replicated_merge_tree` - creates tables using the `ReplicatedMergeTree` engine, which supports data replication across multiple nodes for high availability. [Learn more about ReplicatedMergeTree](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/replication). This defaults to `shared_merge_tree` on ClickHouse Cloud. +- Experimental support for the `Log` engine family with `stripe_log` and `tiny_log`. -- `merge_tree` - creates tables using the `MergeTree` engine -- `replicated_merge_tree` (default) - creates tables using the `ReplicatedMergeTree` engine +For local development and testing with ClickHouse running locally, the `MergeTree` engine is recommended. ## Staging support -ClickHouse supports Amazon S3, Google Cloud Storage and Azure Blob Storage as file staging destinations. +ClickHouse supports Amazon S3, Google Cloud Storage, and Azure Blob Storage as file staging destinations. `dlt` will upload Parquet or JSONL files to the staging location and use ClickHouse table functions to load the data directly from the staged files. @@ -214,7 +248,7 @@ dlt's staging mechanisms for ClickHouse. ### dbt support -Integration with [dbt](../transformations/dbt/dbt.md) is generally supported via dbt-clickhouse, but not tested by us. +Integration with [dbt](../transformations/dbt/dbt.md) is generally supported via dbt-clickhouse but not tested by us. ### Syncing of `dlt` state diff --git a/docs/website/docs/dlt-ecosystem/destinations/dremio.md b/docs/website/docs/dlt-ecosystem/destinations/dremio.md index 546f470938..c087d5dc0a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/dremio.md +++ b/docs/website/docs/dlt-ecosystem/destinations/dremio.md @@ -86,7 +86,7 @@ Data loading happens by copying a staged parquet files from an object storage bu Dremio does not support `CREATE SCHEMA` DDL statements. -Therefore, "Metastore" data sources, such as Hive or Glue, require that the dataset schema exists prior to running the dlt pipeline. `full_refresh=True` is unsupported for these data sources. +Therefore, "Metastore" data sources, such as Hive or Glue, require that the dataset schema exists prior to running the dlt pipeline. `dev_mode=True` is unsupported for these data sources. "Object Storage" data sources do not have this limitation. diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index d6ec36ae49..9ecd1ae6dc 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -37,7 +37,7 @@ All write dispositions are supported. ### Names normalization `dlt` uses the standard **snake_case** naming convention to keep identical table and column identifiers across all destinations. If you want to use the **duckdb** wide range of characters (i.e., emojis) for table and column names, you can switch to the **duck_case** naming convention, which accepts almost any string as an identifier: -* `\n` `\r` and `" are translated to `_` +* `\n` `\r` and `"` are translated to `_` * multiple `_` are translated to a single `_` Switch the naming convention using `config.toml`: @@ -51,7 +51,7 @@ or via the env variable `SCHEMA__NAMING` or directly in the code: dlt.config["schema.naming"] = "duck_case" ``` :::caution -**duckdb** identifiers are **case insensitive** but display names preserve case. This may create name clashes if, for example, you load JSON with +**duckdb** identifiers are **case insensitive** but display names preserve case. This may create name collisions if, for example, you load JSON with `{"Column": 1, "column": 2}` as it will map data to a single column. ::: @@ -164,8 +164,7 @@ destination.duckdb.credentials=":pipeline:" ```py p = pipeline_one = dlt.pipeline( pipeline_name="my_pipeline", - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(":pipeline:"), ) ``` diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md new file mode 100644 index 0000000000..dbf90da4b9 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -0,0 +1,211 @@ +--- +title: LanceDB +description: LanceDB is an open source vector database that can be used as a destination in dlt. +keywords: [ lancedb, vector database, destination, dlt ] +--- + +# LanceDB + +[LanceDB](https://lancedb.com/) is an open-source, high-performance vector database. It allows you to store data objects and perform similarity searches over them. +This destination helps you load data into LanceDB from [dlt resources](../../general-usage/resource.md). + +## Setup Guide + +### Choosing a Model Provider + +First, you need to decide which embedding model provider to use. You can find all supported providers by visiting the official [LanceDB docs](https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/). + + +### Install dlt with LanceDB + +To use LanceDB as a destination, make sure `dlt` is installed with the `lancedb` extra: + +```sh +pip install "dlt[lancedb]" +``` + +the lancedb extra only installs `dlt` and `lancedb`. You will need to install your model provider's SDK. + +You can find which libraries you need to also referring to the [LanceDB docs](https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/). + +### Configure the destination + +Configure the destination in the dlt secrets file located at `~/.dlt/secrets.toml` by default. Add the following section: + +```toml +[destination.lancedb] +embedding_model_provider = "cohere" +embedding_model = "embed-english-v3.0" +[destination.lancedb.credentials] +uri = ".lancedb" +api_key = "api_key" # API key to connect to LanceDB Cloud. Leave out if you are using LanceDB OSS. +embedding_model_provider_api_key = "embedding_model_provider_api_key" # Not needed for providers that don't need authentication (ollama, sentence-transformers). +``` + +- The `uri` specifies the location of your LanceDB instance. It defaults to a local, on-disk instance if not provided. +- The `api_key` is your api key for LanceDB Cloud connections. If you're using LanceDB OSS, you don't need to supply this key. +- The `embedding_model_provider` specifies the embedding provider used for generating embeddings. The default is `cohere`. +- The `embedding_model` specifies the model used by the embedding provider for generating embeddings. + Check with the embedding provider which options are available. + Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/. +- The `embedding_model_provider_api_key` is the API key for the embedding model provider used to generate embeddings. If you're using a provider that doesn't need authentication, say ollama, you don't need to supply this key. + +:::info Available Model Providers +- "gemini-text" +- "bedrock-text" +- "cohere" +- "gte-text" +- "imagebind" +- "instructor" +- "open-clip" +- "openai" +- "sentence-transformers" +- "huggingface" +- "colbert" +::: + +### Define your data source + +For example: + +```py +import dlt +from dlt.destinations.adapters import lancedb_adapter + + +movies = [ + { + "id": 1, + "title": "Blade Runner", + "year": 1982, + }, + { + "id": 2, + "title": "Ghost in the Shell", + "year": 1995, + }, + { + "id": 3, + "title": "The Matrix", + "year": 1999, + }, +] +``` + +### Create a pipeline: + +```py +pipeline = dlt.pipeline( + pipeline_name="movies", + destination="lancedb", + dataset_name="MoviesDataset", +) +``` + +### Run the pipeline: + +```py +info = pipeline.run( + lancedb_adapter( + movies, + embed="title", + ) +) +``` + +The data is now loaded into LanceDB. + +To use **vector search** after loading, you **must specify which fields LanceDB should generate embeddings for**. Do this by wrapping the data (or dlt resource) with the **`lancedb_adapter`** +function. + +## Using an Adapter to Specify Columns to Vectorise + +Out of the box, LanceDB will act as a normal database. To use LanceDB's embedding facilities, you'll need to specify which fields you'd like to embed in your dlt resource. + +The `lancedb_adapter` is a helper function that configures the resource for the LanceDB destination: + +```py +lancedb_adapter(data, embed) +``` + +It accepts the following arguments: + +- `data`: a dlt resource object, or a Python data structure (e.g. a list of dictionaries). +- `embed`: a name of the field or a list of names to generate embeddings for. + +Returns: [dlt resource](../../general-usage/resource.md) object that you can pass to the `pipeline.run()`. + +Example: + +```py +lancedb_adapter( + resource, + embed=["title", "description"], +) +``` + +Bear in mind that you can't use an adapter on a [dlt source](../../general-usage/source.md), only a [dlt resource](../../general-usage/resource.md). + +## Write disposition + +All [write dispositions](../../general-usage/incremental-loading.md#choosing-a-write-disposition) are supported by the LanceDB destination. + +### Replace + +The [replace](../../general-usage/full-loading.md) disposition replaces the data in the destination with the data from the resource. + +```py +info = pipeline.run( + lancedb_adapter( + movies, + embed="title", + ), + write_disposition="replace", +) +``` + +### Merge + +The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. + +```py +pipeline.run( + lancedb_adapter( + movies, + embed="title", + ), + write_disposition="merge", + primary_key="id", +) +``` + +### Append + +This is the default disposition. It will append the data to the existing data in the destination. + +## Additional Destination Options + +- `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". +- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__". +- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__". +- `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3. + + +## dbt support + +The LanceDB destination doesn't support dbt integration. + +## Syncing of `dlt` state + +The LanceDB destination supports syncing of the `dlt` state. + +## Current Limitations + +Adding new fields to an existing LanceDB table requires loading the entire table data into memory as a PyArrow table. +This is because PyArrow tables are immutable, so adding fields requires creating a new table with the updated schema. + +For huge tables, this may impact performance and memory usage since the full table must be loaded into memory to add the new fields. +Keep these considerations in mind when working with large datasets and monitor memory usage if adding fields to sizable existing tables. + + + diff --git a/docs/website/docs/dlt-ecosystem/destinations/mssql.md b/docs/website/docs/dlt-ecosystem/destinations/mssql.md index 6aac877d7b..0512fd5fca 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/mssql.md +++ b/docs/website/docs/dlt-ecosystem/destinations/mssql.md @@ -114,6 +114,15 @@ Data is loaded via INSERT statements by default. MSSQL has a limit of 1000 rows ## Supported column hints **mssql** will create unique indexes for all columns with `unique` hints. This behavior **may be disabled**. +### Table and column identifiers +SQL Server **with the default collation** uses case insensitive identifiers but will preserve the casing of identifiers that are stored in the INFORMATION SCHEMA. You can use [case sensitive naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) to keep the identifier casing. Note that you risk to generate identifier collisions, which are detected by `dlt` and will fail the load process. + +If you change SQL Server server/database collation to case sensitive, this will also affect the identifiers. Configure your destination as below in order to use case sensitive naming conventions without collisions: +```toml +[destination.mssql] +has_case_sensitive_identifiers=true +``` + ## Syncing of `dlt` state This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination). diff --git a/docs/website/docs/dlt-ecosystem/destinations/postgres.md b/docs/website/docs/dlt-ecosystem/destinations/postgres.md index ae504728c3..1281298312 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/postgres.md +++ b/docs/website/docs/dlt-ecosystem/destinations/postgres.md @@ -98,6 +98,10 @@ In the example above `arrow_table` will be converted to csv with **pyarrow** and ## Supported column hints `postgres` will create unique indexes for all columns with `unique` hints. This behavior **may be disabled**. +### Table and column identifiers +Postgres supports both case sensitive and case insensitive identifiers. All unquoted and lowercase identifiers resolve case-insensitively in SQL statements. Case insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case insensitive identifiers. Case sensitive (like **sql_cs_v1**) will generate +case sensitive identifiers that must be quoted in SQL statements. + ## Additional destination options The Postgres destination creates UNIQUE indexes by default on columns with the `unique` hint (i.e., `_dlt_id`). To disable this behavior: ```toml @@ -105,6 +109,28 @@ The Postgres destination creates UNIQUE indexes by default on columns with the ` create_indexes=false ``` +### Setting up `csv` format +You can provide [non-default](../file-formats/csv.md#default-settings) csv settings via configuration file or explicitly. +```toml +[destination.postgres.csv_format] +delimiter="|" +include_header=false +``` +or +```py +from dlt.destinations import postgres +from dlt.common.data_writers.configuration import CsvFormatConfiguration + +csv_format = CsvFormatConfiguration(delimiter="|", include_header=False) + +dest_ = postgres(csv_format=csv_format) +``` +Above we set `csv` file without header, with **|** as a separator. + +:::tip +You'll need those setting when [importing external files](../../general-usage/resource.md#import-external-files) +::: + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via dbt-postgres. diff --git a/docs/website/docs/dlt-ecosystem/destinations/redshift.md b/docs/website/docs/dlt-ecosystem/destinations/redshift.md index 7e0679ec6b..bb92d651f2 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/redshift.md +++ b/docs/website/docs/dlt-ecosystem/destinations/redshift.md @@ -93,10 +93,27 @@ Amazon Redshift supports the following column hints: - `cluster` - This hint is a Redshift term for table distribution. Applying it to a column makes it the "DISTKEY," affecting query and join performance. Check the following [documentation](https://docs.aws.amazon.com/redshift/latest/dg/c_best-practices-best-dist-key.html) for more info. - `sort` - This hint creates a SORTKEY to order rows on disk physically. It is used to improve query and join speed in Redshift. Please read the [sort key docs](https://docs.aws.amazon.com/redshift/latest/dg/c_best-practices-sort-key.html) to learn more. +### Table and column identifiers +Redshift **by default** uses case insensitive identifiers and **will lower case all the identifiers** that are stored in the INFORMATION SCHEMA. Do not use +[case sensitive naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations). Letter casing will be removed anyway and you risk to generate identifier collisions, which are detected by `dlt` and will fail the load process. + +You can [put Redshift in case sensitive mode](https://docs.aws.amazon.com/redshift/latest/dg/r_enable_case_sensitive_identifier.html). Configure your destination as below in order to use case sensitive naming conventions: +```toml +[destination.redshift] +has_case_sensitive_identifiers=true +``` + + ## Staging support Redshift supports s3 as a file staging destination. dlt will upload files in the parquet format to s3 and ask Redshift to copy their data directly into the db. Please refer to the [S3 documentation](./filesystem.md#aws-s3) to learn how to set up your s3 bucket with the bucket_url and credentials. The `dlt` Redshift loader will use the AWS credentials provided for s3 to access the s3 bucket if not specified otherwise (see config options below). Alternatively to parquet files, you can also specify jsonl as the staging file format. For this, set the `loader_file_format` argument of the `run` command of the pipeline to `jsonl`. +## Identifier names and case sensitivity +* Up to 127 characters +* Case insensitive +* Stores identifiers in lower case +* Has case sensitive mode, if enabled you must [enable case sensitivity in destination factory](../../general-usage/destination.md#control-how-dlt-creates-table-column-and-other-identifiers) + ### Authentication IAM Role If you would like to load from s3 without forwarding the AWS staging credentials but authorize with an IAM role connected to Redshift, follow the [Redshift documentation](https://docs.aws.amazon.com/redshift/latest/mgmt/authorizing-redshift-service.html) to create a role with access to s3 linked to your Redshift cluster and change your destination settings to use the IAM role: diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index deaaff3562..1a50505620 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -50,14 +50,14 @@ The instructions below assume that you use the default account setup that you ge --create database with standard settings CREATE DATABASE dlt_data; -- create new user - set your password here -CREATE USER loader WITH PASSWORD='' +CREATE USER loader WITH PASSWORD=''; -- we assign all permission to a role CREATE ROLE DLT_LOADER_ROLE; GRANT ROLE DLT_LOADER_ROLE TO USER loader; -- give database access to new role GRANT USAGE ON DATABASE dlt_data TO DLT_LOADER_ROLE; -- allow `dlt` to create new schemas -GRANT CREATE SCHEMA ON DATABASE dlt_data TO ROLE DLT_LOADER_ROLE +GRANT CREATE SCHEMA ON DATABASE dlt_data TO ROLE DLT_LOADER_ROLE; -- allow access to a warehouse named COMPUTE_WH GRANT USAGE ON WAREHOUSE COMPUTE_WH TO DLT_LOADER_ROLE; -- grant access to all future schemas and tables in the database @@ -71,9 +71,10 @@ You can also decrease the suspend time for your warehouse to 1 minute (**Admin** ### Authentication types Snowflake destination accepts three authentication types: +Snowflake destination accepts three authentication types: - password authentication - [key pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth) -- external authentication +- oauth authentication The **password authentication** is not any different from other databases like Postgres or Redshift. `dlt` follows the same syntax as the [SQLAlchemy dialect](https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#required-parameters). @@ -81,6 +82,7 @@ You can also pass credentials as a database connection string. For example: ```toml # keep it at the top of your toml file! before any section starts destination.snowflake.credentials="snowflake://loader:@kgiotue-wn98412/dlt_data?warehouse=COMPUTE_WH&role=DLT_LOADER_ROLE" + ``` In **key pair authentication**, you replace the password with a private key string that should be in Base64-encoded DER format ([DBT also recommends](https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication) base64-encoded private keys for Snowflake connections). The private key may also be encrypted. In that case, you must provide a passphrase alongside the private key. @@ -100,16 +102,32 @@ If you pass a passphrase in the connection string, please URL encode it. destination.snowflake.credentials="snowflake://loader:@kgiotue-wn98412/dlt_data?private_key=&private_key_passphrase=" ``` -In **external authentication**, you can use an OAuth provider like Okta or an external browser to authenticate. You pass your authenticator and refresh token as below: +In **oauth authentication**, you can use an OAuth provider like Snowflake, Okta or an external browser to authenticate. In case of Snowflake oauth, you pass your `authenticator` and refresh `token` as below: ```toml [destination.snowflake.credentials] database = "dlt_data" username = "loader" -authenticator="..." +authenticator="oauth" token="..." ``` or in the connection string as query parameters. -Refer to Snowflake [OAuth](https://docs.snowflake.com/en/user-guide/oauth-intro) for more details. + +In case of external authentication, you need to find documentation for your OAuth provider. Refer to Snowflake [OAuth](https://docs.snowflake.com/en/user-guide/oauth-intro) for more details. + +### Additional connection options +We pass all query parameters to `connect` function of Snowflake Python Connector. For example: +```toml +[destination.snowflake.credentials] +database = "dlt_data" +authenticator="oauth" +[destination.snowflake.credentials.query] +timezone="UTC" +# keep session alive beyond 4 hours +client_session_keep_alive=true +``` +Will set the timezone and session keep alive. Mind that if you use `toml` your configuration is typed. The alternative: +`"snowflake://loader/dlt_data?authenticator=oauth&timezone=UTC&client_session_keep_alive=true"` +will pass `client_session_keep_alive` as string to the connect method (which we didn't verify if it works). ## Write disposition All write dispositions are supported. @@ -124,21 +142,42 @@ The data is loaded using an internal Snowflake stage. We use the `PUT` command a * [insert-values](../file-formats/insert-format.md) is used by default * [parquet](../file-formats/parquet.md) is supported * [jsonl](../file-formats/jsonl.md) is supported +* [csv](../file-formats/csv.md) is supported When staging is enabled: * [jsonl](../file-formats/jsonl.md) is used by default * [parquet](../file-formats/parquet.md) is supported +* [csv](../file-formats/csv.md) is supported + +:::caution +When loading from `parquet`, Snowflake will store `complex` types (JSON) in `VARIANT` as a string. Use the `jsonl` format instead or use `PARSE_JSON` to update the `VARIANT` field after loading. +::: -> ❗ When loading from `parquet`, Snowflake will store `complex` types (JSON) in `VARIANT` as a string. Use the `jsonl` format instead or use `PARSE_JSON` to update the `VARIANT` field after loading. +### Custom csv formats +By default we support csv format [produced by our writers](../file-formats/csv.md#default-settings) which is comma delimited, with header and optionally quoted. + +You can configure your own formatting ie. when [importing](../../general-usage/resource.md#import-external-files) external `csv` files. +```toml +[destination.snowflake.csv_format] +delimiter="|" +include_header=false +on_error_continue=true +``` +Which will read, `|` delimited file, without header and will continue on errors. + +Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and we will insert NULL into them. ## Supported column hints Snowflake supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): * `cluster` - creates a cluster column(s). Many columns per table are supported and only when a new table is created. ### Table and column identifiers -Snowflake makes all unquoted identifiers uppercase and then resolves them case-insensitively in SQL statements. `dlt` (effectively) does not quote identifiers in DDL, preserving default behavior. +Snowflake supports both case sensitive and case insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case insensitive identifiers. Case sensitive (like **sql_cs_v1**) will generate +case sensitive identifiers that must be quoted in SQL statements. +:::note Names of tables and columns in [schemas](../../general-usage/schema.md) are kept in lower case like for all other destinations. This is the pattern we observed in other tools, i.e., `dbt`. In the case of `dlt`, it is, however, trivial to define your own uppercase [naming convention](../../general-usage/schema.md#naming-convention) +::: ## Staging support @@ -248,6 +287,50 @@ stage_name="DLT_STAGE" keep_staged_files=true ``` +### Setting up `csv` format +You can provide [non-default](../file-formats/csv.md#default-settings) csv settings via configuration file or explicitly. +```toml +[destination.snowflake.csv_format] +delimiter="|" +include_header=false +on_error_continue=true +``` +or +```py +from dlt.destinations import snowflake +from dlt.common.data_writers.configuration import CsvFormatConfiguration + +csv_format = CsvFormatConfiguration(delimiter="|", include_header=False, on_error_continue=True) + +dest_ = snowflake(csv_format=csv_format) +``` +Above we set `csv` file without header, with **|** as a separator and we request to ignore lines with errors. + +:::tip +You'll need those setting when [importing external files](../../general-usage/resource.md#import-external-files) +::: + +### Query Tagging +`dlt` [tags sessions](https://docs.snowflake.com/en/sql-reference/parameters#query-tag) that execute loading jobs with following job properties: +* **source** - name of the source (identical with the name of `dlt` schema) +* **resource** - name of the resource (if known, else empty string) +* **table** - name of the table loaded by the job +* **load_id** - load id of the job +* **pipeline_name** - name of the active pipeline (or empty string if not found) + +You can define query tag by defining a query tag placeholder in snowflake credentials: +```toml +[destination.snowflake.credentials] +query_tag='{{"source":"{source}", "resource":"{resource}", "table": "{table}", "load_id":"{load_id}", "pipeline_name":"{pipeline_name}"}}' +``` +which contains Python named formatters corresponding to tag names ie. `{source}` will assume the name of the dlt source. + +:::note +1. query tagging is off by default. `query_tag` configuration field is `None` by default and must be set to enable tagging. +2. only sessions associated with a job are tagged. sessions that migrate schemas remain untagged +3. jobs processing table chains (ie. sql merge jobs) will use top level table as **table** +::: + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-snowflake](https://github.com/dbt-labs/dbt-snowflake). Both password and key pair authentication are supported and shared with dbt runners. diff --git a/docs/website/docs/dlt-ecosystem/destinations/synapse.md b/docs/website/docs/dlt-ecosystem/destinations/synapse.md index 2e936f193e..6cfcb1ef8f 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/synapse.md +++ b/docs/website/docs/dlt-ecosystem/destinations/synapse.md @@ -148,6 +148,8 @@ Data is loaded via `INSERT` statements by default. The [table index type](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index) of the created tables can be configured at the resource level with the `synapse_adapter`: ```py +from dlt.destinations.adapters import synapse_adapter + info = pipeline.run( synapse_adapter( data=your_resource, diff --git a/docs/website/docs/dlt-ecosystem/destinations/weaviate.md b/docs/website/docs/dlt-ecosystem/destinations/weaviate.md index 11d1276ceb..c6597fadce 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/weaviate.md +++ b/docs/website/docs/dlt-ecosystem/destinations/weaviate.md @@ -252,7 +252,7 @@ it will be normalized to: so your best course of action is to clean up the data yourself before loading and use the default naming convention. Nevertheless, you can configure the alternative in `config.toml`: ```toml [schema] -naming="dlt.destinations.weaviate.impl.ci_naming" +naming="dlt.destinations.impl.weaviate.ci_naming" ``` ## Additional destination options diff --git a/docs/website/docs/dlt-ecosystem/file-formats/_set_the_format.mdx b/docs/website/docs/dlt-ecosystem/file-formats/_set_the_format.mdx new file mode 100644 index 0000000000..e2cce374a2 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/file-formats/_set_the_format.mdx @@ -0,0 +1,31 @@ +import CodeBlock from '@theme/CodeBlock'; + +There are several ways of configuring dlt to use {props.file_type} file format for normalization step and to store your data at the destination: + +1. You can set the loader_file_format argument to {props.file_type} in the run command: + +
+info = pipeline.run(some_source(), loader_file_format="{props.file_type}")
+
+ + +2. You can set the `loader_file_format` in `config.toml` or `secrets.toml`: + +
+[normalize]{'\n'}
+loader_file_format="{props.file_type}"
+
+ +3. You can set the `loader_file_format` via ENV variable: + +
+export NORMALIZE__LOADER_FILE_FORMAT="{props.file_type}"
+
+ +4. You can set the file type directly in [the resource decorator](../../general-usage/resource#pick-loader-file-format-for-a-particular-resource). + +
+@dlt.resource(file_format="{props.file_type}"){'\n'}
+def generate_rows(nr):{'\n'}
+    pass
+
diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md index 4a57a0e2d6..242a8282d1 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/csv.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md @@ -3,6 +3,7 @@ title: csv description: The csv file format keywords: [csv, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # CSV file format @@ -13,26 +14,35 @@ Internally we use two implementations: - **pyarrow** csv writer - very fast, multithreaded writer for the [arrow tables](../verified-sources/arrow-pandas.md) - **python stdlib writer** - a csv writer included in the Python standard library for Python objects - ## Supported Destinations -Supported by: **Postgres**, **Filesystem** +The `csv` format is supported by the following destinations: **Postgres**, **Filesystem**, **Snowflake** -By setting the `loader_file_format` argument to `csv` in the run command, the pipeline will store your data in the csv format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="csv") -``` + ## Default Settings `dlt` attempts to make both writers to generate similarly looking files * separators are commas * quotes are **"** and are escaped as **""** -* `NULL` values are empty strings +* `NULL` values both are empty strings and empty tokens as in the example below * UNIX new lines are used * dates are represented as ISO 8601 * quoting style is "when needed" +Example of NULLs: +```sh +text1,text2,text3 +A,B,C +A,,"" +``` + +In the last row both `text2` and `text3` values are NULL. Python `csv` writer +is not able to write unquoted `None` values so we had to settle for `""` + +Note: all destinations capable of writing csvs must support it. + ### Change settings You can change basic **csv** settings, this may be handy when working with **filesystem** destination. Other destinations are tested with standard settings: @@ -59,6 +69,15 @@ NORMALIZE__DATA_WRITER__INCLUDE_HEADER=False NORMALIZE__DATA_WRITER__QUOTING=quote_all ``` +### Destination settings +A few additional settings are available when copying `csv` to destination tables: +* **on_error_continue** - skip lines with errors (only Snowflake) +* **encoding** - encoding of the `csv` file + +:::tip +You'll need those setting when [importing external files](../../general-usage/resource.md#import-external-files) +::: + ## Limitations **arrow writer** diff --git a/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md b/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md index 641be9a106..c6742c2584 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/insert-format.md @@ -3,6 +3,7 @@ title: INSERT description: The INSERT file format keywords: [insert values, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # SQL INSERT File Format @@ -21,10 +22,8 @@ This file format is [compressed](../../reference/performance.md#disabling-and-en This format is used by default by: **DuckDB**, **Postgres**, **Redshift**. -It is also supported by: **filesystem**. +It is also supported by: **Filesystem**. -By setting the `loader_file_format` argument to `insert_values` in the run command, the pipeline will store your data in the INSERT format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="insert_values") -``` + diff --git a/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md b/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md index 7467c6f639..72168b38f0 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md @@ -3,6 +3,7 @@ title: jsonl description: The jsonl file format keywords: [jsonl, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # jsonl - JSON Delimited @@ -22,11 +23,8 @@ This file format is ## Supported Destinations -This format is used by default by: **BigQuery**, **Snowflake**, **filesystem**. +This format is used by default by: **BigQuery**, **Snowflake**, **Filesystem**. -By setting the `loader_file_format` argument to `jsonl` in the run command, the pipeline will store -your data in the jsonl format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="jsonl") -``` + diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 414eaf2cb8..5d85b7a557 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -3,6 +3,7 @@ title: Parquet description: The parquet file format keywords: [parquet, file formats] --- +import SetTheFormat from './_set_the_format.mdx'; # Parquet file format @@ -16,13 +17,11 @@ pip install "dlt[parquet]" ## Supported Destinations -Supported by: **BigQuery**, **DuckDB**, **Snowflake**, **filesystem**, **Athena**, **Databricks**, **Synapse** +Supported by: **BigQuery**, **DuckDB**, **Snowflake**, **Filesystem**, **Athena**, **Databricks**, **Synapse** -By setting the `loader_file_format` argument to `parquet` in the run command, the pipeline will store your data in the parquet format at the destination: +## How to configure -```py -info = pipeline.run(some_source(), loader_file_format="parquet") -``` + ## Destination AutoConfig `dlt` uses [destination capabilities](../../walkthroughs/create-new-destination.md#3-set-the-destination-capabilities) to configure the parquet writer: diff --git a/docs/website/docs/dlt-ecosystem/index.md b/docs/website/docs/dlt-ecosystem/index.md new file mode 100644 index 0000000000..740a3a3a39 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/index.md @@ -0,0 +1,18 @@ +--- +title: Integrations +description: List of integrations +keywords: ['integrations, sources, destinations'] +--- +import DocCardList from '@theme/DocCardList'; +import Link from '../_book-onboarding-call.md'; + +Speed up the process of creating data pipelines by using dlt's multiple pre-built sources and destinations: + +- Each [dlt verified source](verified-sources) allows you to create [pipelines](../general-usage/pipeline) that extract data from a particular source: a database, a cloud service, or an API. +- [Destinations](destinations) are where you want to load your data. dlt supports a variety of destinations, including databases, data warehouses, and data lakes. + + + +:::tip +Most source-destination pairs work seamlessly together. If the merge [write disposition](../general-usage/incremental-loading#choosing-a-write-disposition) is not supported by a destination (for example, [file sytem destination](destinations/filesystem)), dlt will automatically fall back to the [append](../general-usage/incremental-loading#append) write disposition. +::: \ No newline at end of file diff --git a/docs/website/docs/dlt-ecosystem/staging.md b/docs/website/docs/dlt-ecosystem/staging.md index e3a60dfa51..05e31a574b 100644 --- a/docs/website/docs/dlt-ecosystem/staging.md +++ b/docs/website/docs/dlt-ecosystem/staging.md @@ -7,9 +7,37 @@ keywords: [staging, destination] The goal of staging is to bring the data closer to the database engine so the modification of the destination (final) dataset happens faster and without errors. `dlt`, when asked, creates two staging areas: -1. A **staging dataset** used by the [merge and replace loads](../general-usage/incremental-loading.md#merge-incremental_loading) to deduplicate and merge data with the destination. Such staging dataset has the same name as the dataset passed to `dlt.pipeline` but with `_staging` suffix in the name. As a user you typically never see and directly interact with it. +1. A **staging dataset** used by the [merge and replace loads](../general-usage/incremental-loading.md#merge-incremental_loading) to deduplicate and merge data with the destination. 2. A **staging storage** which is typically a s3/gcp bucket where [loader files](file-formats/) are copied before they are loaded by the destination. +## Staging dataset +`dlt` creates a staging dataset when write disposition of any of the loaded resources requires it. It creates and migrates required tables exactly like for the +main dataset. Data in staging tables is truncated when load step begins and only for tables that will participate in it. +Such staging dataset has the same name as the dataset passed to `dlt.pipeline` but with `_staging` suffix in the name. Alternatively, you can provide your own staging dataset pattern or use a fixed name, identical for all the +configured datasets. +```toml +[destination.postgres] +staging_dataset_name_layout="staging_%s" +``` +Entry above switches the pattern to `staging_` prefix and for example for dataset with name **github_data** `dlt` will create **staging_github_data**. + +To configure static staging dataset name, you can do the following (we use destination factory) +```py +import dlt + +dest_ = dlt.destinations.postgres(staging_dataset_name_layout="_dlt_staging") +``` +All pipelines using `dest_` as destination will use **staging_dataset** to store staging tables. Make sure that your pipelines are not overwriting each other's tables. + +### Cleanup up staging dataset automatically +`dlt` does not truncate tables in staging dataset at the end of the load. Data that is left after contains all the extracted data and may be useful for debugging. +If you prefer to truncate it, put the following line in `config.toml`: + +```toml +[load] +truncate_staging_dataset=true +``` + ## Staging storage `dlt` allows to chain destinations where the first one (`staging`) is responsible for uploading the files from local filesystem to the remote storage. It then generates followup jobs for the second destination that (typically) copy the files from remote storage into destination. diff --git a/docs/website/docs/dlt-ecosystem/transformations/sql.md b/docs/website/docs/dlt-ecosystem/transformations/sql.md index ad37c61bd8..b358e97b4c 100644 --- a/docs/website/docs/dlt-ecosystem/transformations/sql.md +++ b/docs/website/docs/dlt-ecosystem/transformations/sql.md @@ -16,7 +16,7 @@ connection. pipeline = dlt.pipeline(destination="bigquery", dataset_name="crm") try: with pipeline.sql_client() as client: - client.sql_client.execute_sql( + client.execute_sql( "INSERT INTO customers VALUES (%s, %s, %s)", 10, "Fred", diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md index f9ceb99a90..cb14db7ae7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md @@ -70,9 +70,12 @@ The output file format is chosen automatically based on the destination's capabi * snowflake * filesystem * athena +* databricks +* dremio +* synapse -## Normalize configuration +## Add `_dlt_load_id` and `_dlt_id` to your tables `dlt` does not add any data lineage columns by default when loading Arrow tables. This is to give the best performance and avoid unnecessary data copying. @@ -120,6 +123,21 @@ pipeline.run(orders) Look at the [Connector X + Arrow Example](../../examples/connector_x_arrow/) to see how to load data from production databases fast. ::: +## Loading `json` documents +If you want to skip default `dlt` JSON normalizer, you can use any available method to convert json documents into tabular data. +* **pandas** has `read_json` and `json_normalize` methods +* **pyarrow** can infer table schema and convert json files into tables with `read_json` +* **duckdb** can do the same with `read_json_auto` + +```py +import duckdb + +conn = duckdb.connect() +table = conn.execute(f"SELECT * FROM read_json_auto('{json_file_path}')").fetch_arrow_table() +``` + +Note that **duckdb** and **pyarrow** methods will generate [nested types](#loading-nested-types) for nested data, which are only partially supported by `dlt`. + ## Supported Arrow data types The Arrow data types are translated to dlt data types as follows: @@ -141,7 +159,7 @@ The Arrow data types are translated to dlt data types as follows: ## Loading nested types All struct types are represented as `complex` and will be loaded as JSON (if destination permits) or a string. Currently we do not support **struct** types, -even if they are present in the destination. +even if they are present in the destination (except **BigQuery** which can be [configured to handle them](../destinations/bigquery.md#use-bigquery-schema-autodetect-for-nested-fields)) If you want to represent nested data as separated tables, you must yield panda frames and arrow tables as records. In the examples above: ```py diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md b/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md index 7b957e98ea..9cd6ad8079 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/google_sheets.md @@ -355,7 +355,7 @@ To read more about tables, columns, and datatypes, please refer to [our document `dlt` will **not modify** tables after they are created. So if you changed data types with hints, then you need to **delete the dataset** -or set `full_refresh=True`. +or set `dev_mode=True`. ::: ## Sources and resources diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md b/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md index 357d50582f..83077270c7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/hubspot.md @@ -283,5 +283,25 @@ verified source. 1. This function loads data incrementally and tracks the `occurred_at.last_value` parameter from the previous pipeline run. Refer to our official documentation for more information on [incremental loading](../../general-usage/incremental-loading.md). +### Additional info +If you encounter the following error while processing your request: +:::warning ERROR +Your request to HubSpot is too long to process. Maximum allowed query length is 2000 symbols, ... while your list is +2125 symbols long. +::: + +Please note that by default, HubSpot requests all default properties and all custom properties (which are +user-created properties in HubSpot). Therefore, you need to request specific properties for each entity (contacts, +companies, tickets, etc.). + +Default properties are defined in `settings.py`, and you can change them. + +The custom properties could cause the error as there might be too many of them available in your HubSpot. +To change this, you can pass `include_custom_props=False` when initializing the source: + +```py +info = p.run(hubspot(include_custom_props=False)) +``` +Or, if you wish to include them, you can modify `settings.py`. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/index.md b/docs/website/docs/dlt-ecosystem/verified-sources/index.md index d9ae2d1f21..7b5d9e2bcb 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/index.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/index.md @@ -6,14 +6,29 @@ keywords: ['verified source'] import DocCardList from '@theme/DocCardList'; import Link from '../../_book-onboarding-call.md'; -Pick one of our verified sources that we wrote or maintain ourselves. All of them are constantly tested on real data and distributed as simple Python code so they can be easily customized or hacked. +Choose from our collection of verified sources, developed and maintained by the dlt team and community. Each source is rigorously tested against a real API and provided as Python code for easy customization. -* Need more info? [Join our Slack community](https://dlthub.com/community) and ask in the tech help channel or . +Planning to use dlt in production and need a source that isn't listed? We're happy to build it for you: . -Do you plan to run dlt in production and source is missing? We are happy to build it. -* Source missing? [Request a new verified source](https://github.com/dlt-hub/verified-sources/issues/new?template=source-request.md) -* Missing endpoint or a feature? [Request or contribute](https://github.com/dlt-hub/verified-sources/issues/new?template=extend-a-source.md) +### Popular sources + +- [SQL databases](sql_database). Supports PostgreSQL, MySQL, MS SQL Server, BigQuery, Redshift, and more. +- [REST API generic source](rest_api). Loads data from REST APIs using declarative configuration. +- [OpenAPI source generator](openapi-generator). Generates a source from an OpenAPI 3.x spec using the REST API source. +- [Cloud and local storage](filesystem). Retrieves data from AWS S3, Google Cloud Storage, Azure Blob Storage, local files, and more. -Otherwise pick a source below: +### Full list of verified sources + +:::tip +If you're looking for a source that isn't listed and it provides a REST API, be sure to check out our [REST API generic source](rest_api) + source. +::: + + +### Get help + +* Source missing? [Request a new verified source.](https://github.com/dlt-hub/verified-sources/issues/new?template=source-request.md) +* Missing endpoint or a feature? [Request or contribute](https://github.com/dlt-hub/verified-sources/issues/new?template=extend-a-source.md) +* [Join our Slack community](https://dlthub.com/community) and ask in the technical-help channel. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md b/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md index 6fda0f8fe9..f6d57a5ba2 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/mongodb.md @@ -317,16 +317,28 @@ verified source. 1. To load a selected collection and rename it in the destination: ```py - # Create the MongoDB source and select the "collection_1" collection - source = mongodb().with_resources("collection_1") + # Create the MongoDB source and select the "collection_1" collection + source = mongodb().with_resources("collection_1") - # Apply the hint to rename the table in the destination - source.resources["collection_1"].apply_hints(table_name="loaded_data_1") + # Apply the hint to rename the table in the destination + source.resources["collection_1"].apply_hints(table_name="loaded_data_1") - # Run the pipeline - info = pipeline.run(source, write_disposition="replace") - print(info) + # Run the pipeline + info = pipeline.run(source, write_disposition="replace") + print(info) ``` +1. To load a selected collection, using Apache Arrow for data conversion: + ```py + # Load collection "movies", using Apache Arrow for converion + movies = mongodb_collection( + collection="movies", + data_item_format="arrow", + ) + + # Run the pipeline + info = pipeline.run(source) + print(info) + ``` diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md b/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md new file mode 100644 index 0000000000..6d69f09cd3 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/verified-sources/pg_replication.md @@ -0,0 +1,271 @@ +--- +title: Postgres replication +description: dlt verified source for Postgres replication +keywords: [postgres, postgres replication, database replication] +--- +import Header from './_source-info-header.md'; + +# Postgres replication + +
+ +[Postgres](https://www.postgresql.org/) is one of the most popular relational database management systems. This verified source uses Postgres replication functionality to efficiently process tables (a process often referred to as *Change Data Capture* or CDC). It uses [logical decoding](https://www.postgresql.org/docs/current/logicaldecoding.html) and the standard built-in `pgoutput` [output plugin](https://www.postgresql.org/docs/current/logicaldecoding-output-plugin.html). + +Resources that can be loaded using this verified source are: + +| Name | Description | +| -------------------- | ----------------------------------------------- | +| replication_resource | Load published messages from a replication slot | + +## Setup Guide + +### Setup user +To setup a Postgres user follow these steps: + +1. The Postgres user needs to have the `LOGIN` and `REPLICATION` attributes assigned: + + ```sql + CREATE ROLE replication_user WITH LOGIN REPLICATION; + ``` + +2. It also needs `GRANT` privilege on the database: + + ```sql + GRANT CREATE ON DATABASE dlt_data TO replication_user; + ``` + + +### Set up RDS +To setup a Postgres user on RDS follow these steps: + +1. You must enable replication for RDS Postgres instance via [Parameter Group](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_PostgreSQL.Replication.ReadReplicas.html) + +2. `WITH LOGIN REPLICATION;` does not work on RDS, instead do: + + ```sql + GRANT rds_replication TO replication_user; + ``` + +3. Do not fallback to non-SSL connection by setting connection parameters: + + ```toml + sources.pg_replication.credentials="postgresql://loader:password@host.rds.amazonaws.com:5432/dlt_data?sslmode=require&connect_timeout=300" + ``` +### Initialize the verified source + +To get started with your data pipeline, follow these steps: + +1. Enter the following command: + + ```sh + dlt init pg_replication duckdb + ``` + + It will initialize [the pipeline example](https://github.com/dlt-hub/verified-sources/blob/master/sources/pg_replication_pipeline.py) with a Postgres replication as the [source](https://dlthub.com/docs/general-usage/source) and [DuckDB](https://dlthub.com/docs/dlt-ecosystem/destinations/duckdb) as the [destination](https://dlthub.com/docs/dlt-ecosystem/destinations). + + +2. If you'd like to use a different destination, simply replace `duckdb` with the name of your preferred [destination](https://dlthub.com/docs/dlt-ecosystem/destinations). + +3. This source uses `sql_database` source, you can init it as follows: + + ```sh + dlt init sql_database duckdb + ``` + :::note + It is important to note that It is now only required if a user performs an initial load, specifically when `persist_snapshots` is set to `True`. + ::: + +4. After running these two commands, a new directory will be created with the necessary files and configuration settings to get started. + + For more information, read the guide on [how to add a verified source](https://dlthub.com/docs/walkthroughs/add-a-verified-source). + + :::note + You can omit the `[sql.sources.credentials]` section in `secrets.toml` as it is not required. + ::: + +### Add credentials + +1. In the `.dlt` folder, there's a file called `secrets.toml`. It's where you store sensitive information securely, like access tokens. Keep this file safe. + + Here's what the `secrets.toml` looks like: + + ```toml + [sources.pg_replication.credentials] + drivername = "postgresql" # please set me up! + database = "database" # please set me up! + password = "password" # please set me up! + username = "username" # please set me up! + host = "host" # please set me up! + port = 0 # please set me up! + ``` + +2. Credentials can be set as shown above. Alternatively, you can provide credentials in the `secrets.toml` file as follows: + + ```toml + sources.pg_replication.credentials="postgresql://username@password.host:port/database" + ``` + +3. Finally, follow the instructions in [Destinations](https://dlthub.com/docs/dlt-ecosystem/destinations/) to add credentials for your chosen destination. This will ensure that your data is properly routed. + +For more information, read the [Configuration section.](https://dlthub.com/docs/general-usage/credentials) + +## Run the pipeline + +1. Before running the pipeline, ensure that you have installed all the necessary dependencies by running the command: + ```sh + pip install -r requirements.txt + ``` +2. You're now ready to run the pipeline! To get started, run the following command: + ```sh + python pg_replication_pipeline.py + ``` +3. Once the pipeline has finished running, you can verify that everything loaded correctly by using the following command: + ```sh + dlt pipeline show + ``` + For example, the `pipeline_name` for the above pipeline example is `pg_replication_pipeline`, you may also use any custom name instead. + + + For more information, read the guide on [how to run a pipeline](https://dlthub.com/docs/walkthroughs/run-a-pipeline). + + +## Sources and resources + +`dlt` works on the principle of [sources](https://dlthub.com/docs/general-usage/source) and [resources](https://dlthub.com/docs/general-usage/resource). + +### Resource `replication_resource` + +This resource yields data items for changes in one or more Postgres tables. + +```py +@dlt.resource( + name=lambda args: args["slot_name"] + "_" + args["pub_name"], + standalone=True, +) +def replication_resource( + slot_name: str, + pub_name: str, + credentials: ConnectionStringCredentials = dlt.secrets.value, + include_columns: Optional[Dict[str, Sequence[str]]] = None, + columns: Optional[Dict[str, TTableSchemaColumns]] = None, + target_batch_size: int = 1000, + flush_slot: bool = True, +) -> Iterable[Union[TDataItem, DataItemWithMeta]]: + ... +``` + +`slot_name`: Replication slot name to consume messages. + +`pub_name`: Publication slot name to publish messages. + +`include_columns`: Maps table name(s) to sequence of names of columns to include in the generated data items. Any column not in the sequence is excluded. If not provided, all columns are included + +`columns`: Maps table name(s) to column hints to apply on the replicated table(s) + +`target_batch_size`: Desired number of data items yielded in a batch. Can be used to limit the data items in memory. + +`flush_slot`: Whether processed messages are discarded from the replication slot. The recommended value is "True". + +## Customization + +If you wish to create your own pipelines, you can leverage source and resource methods from this verified source. + +1. Define the source pipeline as: + + ```py + # Defining source pipeline + src_pl = dlt.pipeline( + pipeline_name="source_pipeline", + destination="postgres", + dataset_name="source_dataset", + dev_mode=True, + ) + ``` + + You can configure and use the `get_postgres_pipeline()` function available in the `pg_replication_pipeline.py` file to achieve the same functionality. + + :::note IMPORTANT + When working with large datasets from a Postgres database, it's important to consider the relevance of the source pipeline. For testing purposes, using the source pipeline can be beneficial to try out the data flow. However, in production use cases, there will likely be another process that mutates the Postgres database. In such cases, the user generally only needs to define a destination pipeline. + ::: + + +2. Similarly, define the destination pipeline. + + ```py + dest_pl = dlt.pipeline( + pipeline_name="pg_replication_pipeline", + destination='duckdb', + dataset_name="replicate_single_table", + dev_mode=True, + ) + ``` + +3. Define the slot and publication names as: + + ```py + slot_name = "example_slot" + pub_name = "example_pub" + ``` + +4. To initialize replication, you can use the `init_replication` function. A user can use this function to let `dlt` configure Postgres and make it ready for replication. + + ```py + # requires the Postgres user to have the REPLICATION attribute assigned + init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="my_source_table", + reset=True, + ) + ``` + + :::note + To replicate the entire schema, you can omit the `table_names` argument from the `init_replication` function. + ::: + +5. To snapshot the data to the destination during the initial load, you can use the `persist_snapshots=True` argument as follows: + ```py + snapshot = init_replication( # requires the Postgres user to have the REPLICATION attribute assigned + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="my_source_table", + persist_snapshots=True, # persist snapshot table(s) and let function return resource(s) for initial load + reset=True, + ) + ``` + +6. To load this snapshot to the destination, run the destination pipeline as: + + ```py + dest_pl.run(snapshot) + ``` + +7. After changes are made to the source, you can replicate the changes to the destination using the `replication_resource`, and run the pipeline as: + + ```py + # Create a resource that generates items for each change in the source table + changes = replication_resource(slot_name, pub_name) + + # Run the pipeline as + dest_pl.run(changes) + ``` + +8. To replicate tables with selected columns you can use the `include_columns` argument as follows: + + ```py + # requires the Postgres user to have the REPLICATION attribute assigned + initial_load = init_replication( + slot_name=slot_name, + pub_name=pub_name, + schema_name=src_pl.dataset_name, + table_names="my_source_table", + include_columns={ + "my_source_table": ("column1", "column2") + }, + reset=True, + ) + ``` + + Similarly, to replicate changes from selected columns, you can use the `table_names` and `include_columns` arguments in the `replication_resource` function. \ No newline at end of file diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md index 98725627b9..8d43e471c8 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api.md @@ -9,6 +9,61 @@ import Header from './_source-info-header.md'; This is a generic dlt source you can use to extract data from any REST API. It uses [declarative configuration](#source-configuration) to define the API endpoints, their [relationships](#define-resource-relationships), how to handle [pagination](#pagination), and [authentication](#authentication). +### Quick example + +Here's an example of how to configure the REST API source to load posts and related comments from a hypothetical blog API: + +```py +import dlt +from rest_api import rest_api_source + +source = rest_api_source({ + "client": { + "base_url": "https://api.example.com/", + "auth": { + "token": dlt.secrets["your_api_token"], + }, + "paginator": { + "type": "json_response", + "next_url_path": "paging.next", + }, + }, + "resources": [ + # "posts" will be used as the endpoint path, the resource name, + # and the table name in the destination. The HTTP client will send + # a request to "https://api.example.com/posts". + "posts", + + # The explicit configuration allows you to link resources + # and define parameters. + { + "name": "comments", + "endpoint": { + "path": "posts/{post_id}/comments", + "params": { + "post_id": { + "type": "resolve", + "resource": "posts", + "field": "id", + }, + "sort": "created_at", + }, + }, + }, + ], +}) + +pipeline = dlt.pipeline( + pipeline_name="rest_api_example", + destination="duckdb", + dataset_name="rest_api_data", +) + +load_info = pipeline.run(source) +``` + +Running this pipeline will create two tables in the DuckDB: `posts` and `comments` with the data from the respective API endpoints. The `comments` resource will fetch comments for each post by using the `id` field from the `posts` resource. + ## Setup guide ### Initialize the verified source @@ -187,12 +242,12 @@ config: RESTAPIConfig = { #### `client` -`client` contains the configuration to connect to the API's endpoints. It includes the following fields: +The `client` configuration is used to connect to the API's endpoints. It includes the following fields: - `base_url` (str): The base URL of the API. This string is prepended to all endpoint paths. For example, if the base URL is `https://api.example.com/v1/`, and the endpoint path is `users`, the full URL will be `https://api.example.com/v1/users`. -- `headers` (dict, optional): Additional headers to be sent with each request. -- `auth` (optional): Authentication configuration. It can be a simple token, a `AuthConfigBase` object, or a more complex authentication method. -- `paginator` (optional): Configuration for the default pagination to be used for resources that support pagination. See the [pagination](#pagination) section for more details. +- `headers` (dict, optional): Additional headers that are sent with each request. +- `auth` (optional): Authentication configuration. This can be a simple token, an `AuthConfigBase` object, or a more complex authentication method. +- `paginator` (optional): Configuration for the default pagination used for resources that support pagination. Refer to the [pagination](#pagination) section for more details. #### `resource_defaults` (optional) @@ -291,46 +346,69 @@ The REST API source will try to automatically handle pagination for you. This wo In some special cases, you may need to specify the pagination configuration explicitly. -:::note -Currently pagination is supported only for GET requests. To handle POST requests with pagination, you need to implement a [custom paginator](../../general-usage/http/rest-client.md#custom-paginator). -::: +To specify the pagination configuration, use the `paginator` field in the [client](#client) or [endpoint](#endpoint-configuration) configurations. You may either use a dictionary with a string alias in the `type` field along with the required parameters, or use a [paginator class instance](../../general-usage/http/rest-client.md#paginators). -These are the available paginators: +#### Example + +Suppose the API response for `https://api.example.com/posts` contains a `next` field with the URL to the next page: -| Paginator class | String Alias (`type`) | Description | -| -------------- | ------------ | ----------- | -| [JSONResponsePaginator](../../general-usage/http/rest-client.md#jsonresponsepaginator) | `json_response` | The links to the next page are in the body (JSON) of the response. | -| [HeaderLinkPaginator](../../general-usage/http/rest-client.md#headerlinkpaginator) | `header_link` | The links to the next page are in the response headers. | -| [OffsetPaginator](../../general-usage/http/rest-client.md#offsetpaginator) | `offset` | The pagination is based on an offset parameter. With total items count either in the response body or explicitly provided. | -| [PageNumberPaginator](../../general-usage/http/rest-client.md#pagenumberpaginator) | `page_number` | The pagination is based on a page number parameter. With total pages count either in the response body or explicitly provided. | -| [JSONCursorPaginator](../../general-usage/http/rest-client.md#jsonresponsecursorpaginator) | `cursor` | The pagination is based on a cursor parameter. The value of the cursor is in the response body (JSON). | -| SinglePagePaginator | `single_page` | The response will be interpreted as a single-page response, ignoring possible pagination metadata. | -| `None` | `auto` | Explicitly specify that the source should automatically detect the pagination method. | +```json +{ + "data": [ + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"} + ], + "pagination": { + "next": "https://api.example.com/posts?page=2" + } +} +``` -To specify the pagination configuration, use the `paginator` field in the [client](#client) or [endpoint](#endpoint-configuration) configurations. You may either use a dictionary with a string alias in the `type` field along with the required parameters, or use the paginator instance directly: +You can configure the pagination for the `posts` resource like this: ```py { - # ... + "path": "posts", "paginator": { - "type": "json_links", - "next_url_path": "paging.next", + "type": "json_response", + "next_url_path": "pagination.next", } } ``` -Or using the paginator instance: +Alternatively, you can use the paginator instance directly: ```py +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator + +# ... + { - # ... + "path": "posts", "paginator": JSONResponsePaginator( - next_url_path="paging.next" + next_url_path="pagination.next" ), } ``` -This is useful when you're [implementing and using a custom paginator](../../general-usage/http/rest-client.md#custom-paginator). +:::note +Currently pagination is supported only for GET requests. To handle POST requests with pagination, you need to implement a [custom paginator](../../general-usage/http/rest-client.md#custom-paginator). +::: + +These are the available paginators: + +| `type` | Paginator class | Description | +| ------------ | -------------- | ----------- | +| `json_response` | [JSONResponsePaginator](../../general-usage/http/rest-client.md#jsonresponsepaginator) | The link to the next page is in the body (JSON) of the response.
*Parameters:*
  • `next_url_path` (str) - the JSONPath to the next page URL
| +| `header_link` | [HeaderLinkPaginator](../../general-usage/http/rest-client.md#headerlinkpaginator) | The links to the next page are in the response headers.
*Parameters:*
  • `link_header` (str) - the name of the header containing the links. Default is "next".
| +| `offset` | [OffsetPaginator](../../general-usage/http/rest-client.md#offsetpaginator) | The pagination is based on an offset parameter. With total items count either in the response body or explicitly provided.
*Parameters:*
  • `limit` (int) - the maximum number of items to retrieve in each request
  • `offset` (int) - the initial offset for the first request. Defaults to `0`
  • `offset_param` (str) - the name of the query parameter used to specify the offset. Defaults to "offset"
  • `limit_param` (str) - the name of the query parameter used to specify the limit. Defaults to "limit"
  • `total_path` (str) - a JSONPath expression for the total number of items. If not provided, pagination is controlled by `maximum_offset`
  • `maximum_offset` (int) - optional maximum offset value. Limits pagination even without total count
| +| `page_number` | [PageNumberPaginator](../../general-usage/http/rest-client.md#pagenumberpaginator) | The pagination is based on a page number parameter. With total pages count either in the response body or explicitly provided.
*Parameters:*
  • `initial_page` (int) - the starting page number. Defaults to `0`
  • `page_param` (str) - the query parameter name for the page number. Defaults to "page"
  • `total_path` (str) - a JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page`
  • `maximum_page` (int) - optional maximum page number. Stops pagination once this page is reached
| +| `cursor` | [JSONResponseCursorPaginator](../../general-usage/http/rest-client.md#jsonresponsecursorpaginator) | The pagination is based on a cursor parameter. The value of the cursor is in the response body (JSON).
*Parameters:*
  • `cursor_path` (str) - the JSONPath to the cursor value. Defaults to "cursors.next"
  • `cursor_param` (str) - the query parameter name for the cursor. Defaults to "after"
| +| `single_page` | SinglePagePaginator | The response will be interpreted as a single-page response, ignoring possible pagination metadata. | +| `auto` | `None` | Explicitly specify that the source should automatically detect the pagination method. | + +For more complex pagination methods, you can implement a [custom paginator](../../general-usage/http/rest-client.md#implementing-a-custom-paginator), instantiate it, and use it in the configuration. ### Data selection @@ -387,11 +465,11 @@ Read more about [JSONPath syntax](https://github.com/h2non/jsonpath-ng?tab=readm ### Authentication -Many APIs require authentication to access their endpoints. The REST API source supports various authentication methods, such as token-based, query parameters, basic auth, etc. +For APIs that require authentication to access their endpoints, the REST API source supports various authentication methods, including token-based authentication, query parameters, basic authentication, and custom authentication. The authentication configuration is specified in the `auth` field of the [client](#client) either as a dictionary or as an instance of the [authentication class](../../general-usage/http/rest-client.md#authentication). #### Quick example -One of the most common method is token-based authentication. To authenticate with a token, you can use the `token` field in the `auth` configuration: +One of the most common methods is token-based authentication (also known as Bearer token authentication). To authenticate using this method, you can use the following shortcut: ```py { @@ -416,12 +494,14 @@ Available authentication types: | [BearTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | `bearer` | Bearer token authentication. | | [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | `http_basic` | Basic HTTP authentication. | | [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | `api_key` | API key authentication with key defined in the query parameters or in the headers. | +| [OAuth2ClientCredentials](../../general-usage/http/rest-client.md#oauth20-authorization) | N/A | OAuth 2.0 authorization with a temporary access token obtained from the authorization server. | To specify the authentication configuration, use the `auth` field in the [client](#client) configuration: ```py { "client": { + # ... "auth": { "type": "bearer", "token": dlt.secrets["your_api_token"], @@ -444,6 +524,23 @@ config = { } ``` +:::warning +Make sure to store your access tokens and other sensitive information in the `secrets.toml` file and never commit it to the version control system. +::: + +Available authentication types: + +| `type` | Authentication class | Description | +| ----------- | ------------------- | ----------- | +| `bearer` | [BearTokenAuth](../../general-usage/http/rest-client.md#bearer-token-authentication) | Bearer token authentication.
Parameters:
  • `token` (str)
| +| `http_basic` | [HTTPBasicAuth](../../general-usage/http/rest-client.md#http-basic-authentication) | Basic HTTP authentication.
Parameters:
  • `username` (str)
  • `password` (str)
| +| `api_key` | [APIKeyAuth](../../general-usage/http/rest-client.md#api-key-authentication) | API key authentication with key defined in the query parameters or in the headers.
Parameters:
  • `name` (str) - the name of the query parameter or header
  • `api_key` (str) - the API key value
  • `location` (str, optional) - the location of the API key in the request. Can be `query` or `header`. Default is `header`
| + + +For more complex authentication methods, you can implement a [custom authentication class](../../general-usage/http/rest-client.md#implementing-custom-authentication) and use it in the configuration. + + + ### Define resource relationships When you have a resource that depends on another resource, you can define the relationship using the `resolve` configuration. With it you link a path parameter in the child resource to a field in the parent resource's data. @@ -531,52 +628,115 @@ This will include the `id`, `title`, and `created_at` fields from the `issues` r Some APIs provide a way to fetch only new or changed data (most often by using a timestamp field like `updated_at`, `created_at`, or incremental IDs). This is called [incremental loading](../../general-usage/incremental-loading.md) and is very useful as it allows you to reduce the load time and the amount of data transferred. -When the API endpoint supports incremental loading, you can configure the source to load only the new or changed data using these two methods: +When the API endpoint supports incremental loading, you can configure dlt to load only the new or changed data using these two methods: -1. Defining a special parameter in the `params` section of the [endpoint configuration](#endpoint-configuration): +1. Defining a special parameter in the `params` section of the [endpoint configuration](#endpoint-configuration). +2. Specifying the `incremental` field in the endpoint configuration. - ```py - { - "": { - "type": "incremental", - "cursor_path": "", - "initial_value": "", - }, - } - ``` +Let's start with the first method. - For example, in the `issues` resource configuration in the GitHub example, we have: +### Incremental loading in `params` - ```py - { - "since": { +Imagine we have the following endpoint `https://api.example.com/posts` and it: +1. Accepts a `created_since` query parameter to fetch posts created after a certain date. +2. Returns a list of posts with the `created_at` field for each post. + +For example, if we query the endpoint with `https://api.example.com/posts?created_since=2024-01-25`, we get the following response: + +```json +{ + "results": [ + {"id": 1, "title": "Post 1", "created_at": "2024-01-26"}, + {"id": 2, "title": "Post 2", "created_at": "2024-01-27"}, + {"id": 3, "title": "Post 3", "created_at": "2024-01-28"} + ] +} +``` + +To enable the incremental loading for this endpoint, you can use the following endpoint configuration: + +```py +{ + "path": "posts", + "data_selector": "results", # Optional JSONPath to select the list of posts + "params": { + "created_since": { "type": "incremental", - "cursor_path": "updated_at", - "initial_value": "2024-01-25T11:21:28Z", + "cursor_path": "created_at", # The JSONPath to the field we want to track in each post + "initial_value": "2024-01-25", }, - } - ``` + }, +} +``` - This configuration tells the source to create an incremental object that will keep track of the `updated_at` field in the response and use it as a value for the `since` parameter in subsequent requests. +After you run the pipeline, dlt will keep track of the last `created_at` from all the posts fetched and use it as the `created_since` parameter in the next request. +So in our case, the next request will be made to `https://api.example.com/posts?created_since=2024-01-28` to fetch only the new posts created after `2024-01-28`. -2. Specifying the `incremental` field in the [endpoint configuration](#endpoint-configuration): +Let's break down the configuration. - ```py - { - "incremental": { - "start_param": "", - "end_param": "", - "cursor_path": "", - "initial_value": "", - "end_value": "", - } +1. We explicitly set `data_selector` to `"results"` to select the list of posts from the response. This is optional, if not set, dlt will try to auto-detect the data location. +2. We define the `created_since` parameter as an incremental parameter with the following fields: + +```py +{ + "created_since": { + "type": "incremental", + "cursor_path": "created_at", + "initial_value": "2024-01-25", + }, +} +``` + +- `type`: The type of the parameter definition. In this case, it must be set to `incremental`. +- `cursor_path`: The JSONPath to the field within each item in the list. The value of this field will be used in the next request. In the example above our items look like `{"id": 1, "title": "Post 1", "created_at": "2024-01-26"}` so to track the created time we set `cursor_path` to `"created_at"`. Note that the JSONPath starts from the root of the item (dict) and not from the root of the response. +- `initial_value`: The initial value for the cursor. This is the value that will initialize the state of incremental loading. In this case, it's `2024-01-25`. The value type should match the type of the field in the data item. + +### Incremental loading using the `incremental` field + +The alternative method is to use the `incremental` field in the [endpoint configuration](#endpoint-configuration). This method is more flexible and allows you to specify the start and end conditions for the incremental loading. + +Let's take the same example as above and configure it using the `incremental` field: + +```py +{ + "path": "posts", + "data_selector": "results", + "incremental": { + "start_param": "created_since", + "cursor_path": "created_at", + "initial_value": "2024-01-25", + }, +} +``` + +Note that we specify the query parameter name `created_since` in the `start_param` field and not in the `params` section. + +The full available configuration for the `incremental` field is: + +```py +{ + "incremental": { + "start_param": "", + "end_param": "", + "cursor_path": "", + "initial_value": "", + "end_value": "", } - ``` +} +``` + +The fields are: - This configuration is more flexible and allows you to specify the start and end conditions for the incremental loading. +- `start_param` (str): The name of the query parameter to be used as the start condition. If we use the example above, it would be `"created_since"`. +- `end_param` (str): The name of the query parameter to be used as the end condition. This is optional and can be omitted if you only need to track the start condition. This is useful when you need to fetch data within a specific range and the API supports end conditions (like `created_before` query parameter). +- `cursor_path` (str): The JSONPath to the field within each item in the list. This is the field that will be used to track the incremental loading. In the example above, it's `"created_at"`. +- `initial_value` (str): The initial value for the cursor. This is the value that will initialize the state of incremental loading. +- `end_value` (str): The end value for the cursor to stop the incremental loading. This is optional and can be omitted if you only need to track the start condition. If you set this field, `initial_value` needs to be set as well. See the [incremental loading](../../general-usage/incremental-loading.md#incremental-loading-with-a-cursor-field) guide for more details. +If you encounter issues with incremental loading, see the [troubleshooting section](../../general-usage/incremental-loading.md#troubleshooting) in the incremental loading guide. + ## Advanced configuration `rest_api_source()` function creates the [dlt source](../../general-usage/source.md) and lets you configure the following parameters: @@ -618,3 +778,92 @@ In this example, the source will ignore responses with a status code of 404, res - `content` (str, optional): A substring to search for in the response content. - `action` (str): The action to take when the condition is met. Currently supported actions: - `ignore`: Ignore the response. + +## Troubleshooting + +If you encounter issues while running the pipeline, enable [logging](../../running-in-production/running.md#set-the-log-level-and-format) for detailed information about the execution: + +```sh +RUNTIME__LOG_LEVEL=INFO python my_script.py +``` + +This also provides details on the HTTP requests. + +### Configuration issues + +#### Getting validation errors + +When you running the pipeline and getting a `DictValidationException`, it means that the [source configuration](#source-configuration) is incorrect. The error message provides details on the issue including the path to the field and the expected type. + +For example, if you have a source configuration like this: + +```py +config: RESTAPIConfig = { + "client": { + # ... + }, + "resources": [ + { + "name": "issues", + "params": { # <- Wrong: this should be inside + "sort": "updated", # the endpoint field below + }, + "endpoint": { + "path": "issues", + # "params": { # <- Correct configuration + # "sort": "updated", + # }, + }, + }, + # ... + ], +} +``` + +You will get an error like this: + +```sh +dlt.common.exceptions.DictValidationException: In path .: field 'resources[0]' +expects the following types: str, EndpointResource. Provided value {'name': 'issues', 'params': {'sort': 'updated'}, +'endpoint': {'path': 'issues', ... }} with type 'dict' is invalid with the following errors: +For EndpointResource: In path ./resources[0]: following fields are unexpected {'params'} +``` + +It means that in the first resource configuration (`resources[0]`), the `params` field should be inside the `endpoint` field. + +:::tip +Import the `RESTAPIConfig` type from the `rest_api` module to have convenient hints in your editor/IDE and use it to define the configuration object. + +```py +from rest_api import RESTAPIConfig +``` +::: + +#### Getting wrong data or no data + +If incorrect data is received from an endpoint, check the `data_selector` field in the [endpoint configuration](#endpoint-configuration). Ensure the JSONPath is accurate and points to the correct data in the response body. `rest_api` attempts to auto-detect the data location, which may not always succeed. See the [data selection](#data-selection) section for more details. + +#### Getting insufficient data or incorrect pagination + +Check the `paginator` field in the configuration. When not explicitly specified, the source tries to auto-detect the pagination method. If auto-detection fails, or the system is unsure, a warning is logged. For production environments, we recommend to specify an explicit paginator in the configuration. See the [pagination](#pagination) section for more details. Some APIs may have non-standard pagination methods, and you may need to implement a [custom paginator](../../general-usage/http/rest-client.md#implementing-a-custom-paginator). + +#### Incremental loading not working + +See the [troubleshooting guide](../../general-usage/incremental-loading.md#troubleshooting) for incremental loading issues. + +#### Getting HTTP 404 errors + +Some API may return 404 errors for resources that do not exist or have no data. Manage these responses by configuring the `ignore` action in [response actions](#response-actions). + +### Authentication issues + +If experiencing 401 (Unauthorized) errors, this could indicate: + +- Incorrect authorization credentials. Verify credentials in the `secrets.toml`. Refer to [Secret and configs](../../general-usage/credentials/configuration#understanding-the-exceptions) for more information. +- An incorrect authentication type. Consult the API documentation for the proper method. See the [authentication](#authentication) section for details. For some APIs, a [custom authentication method](../../general-usage/http/rest-client.md#custom-authentication) may be required. + +### General guidelines + +The `rest_api` source uses the [RESTClient](../../general-usage/http/rest-client.md) class for HTTP requests. Refer to the RESTClient [troubleshooting guide](../../general-usage/http/rest-client.md#troubleshooting) for debugging tips. + +For further assistance, join our [Slack community](https://dlthub.com/community). We're here to help! \ No newline at end of file diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index de3e5f4c35..a5869e99bd 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -161,7 +161,7 @@ good performance, preserves exact database types and we recommend it for large t ### **sqlalchemy** backend **sqlalchemy** (the default) yields table data as list of Python dictionaries. This data goes through regular extract -and normalize steps and does not require additional dependencies to be installed. It is the most robust (works with any destination, correctly represents data types) but also the slowest. You can use `detect_precision_hints` to pass exact database types to `dlt` schema. +and normalize steps and does not require additional dependencies to be installed. It is the most robust (works with any destination, correctly represents data types) but also the slowest. You can use `reflection_level="full_with_precision"` to pass exact database types to `dlt` schema. ### **pyarrow** backend @@ -177,9 +177,9 @@ pipeline = dlt.pipeline( ) def _double_as_decimal_adapter(table: sa.Table) -> None: - """Return double as double, not decimals, this is mysql thing""" + """Emits decimals instead of floats.""" for column in table.columns.values(): - if isinstance(column.type, sa.Double): # type: ignore + if isinstance(column.type, sa.Float): column.type.asdecimal = False sql_alchemy_source = sql_database( @@ -262,7 +262,7 @@ unsw_table = sql_table( chunk_size=100000, backend="connectorx", # keep source data types - detect_precision_hints=True, + reflection_level="full_with_precision", # just to demonstrate how to setup a separate connection string for connectorx backend_kwargs={"conn": "postgresql://loader:loader@localhost:5432/dlt_data"} ) @@ -271,7 +271,7 @@ pipeline = dlt.pipeline( pipeline_name="unsw_download", destination=filesystem(os.path.abspath("../_storage/unsw")), progress="log", - full_refresh=True, + dev_mode=True, ) info = pipeline.run( @@ -383,6 +383,132 @@ database = sql_database().parallelize() table = sql_table().parallelize() ``` +## Column reflection + +Columns and their data types are reflected with SQLAlchemy. The SQL types are then mapped to `dlt` types. +Most types are supported. + +The `reflection_level` argument controls how much information is reflected: + +- `reflection_level = "minimal"`: Only column names and nullability are detected. Data types are inferred from the data. +- `reflection_level = "full"`: Column names, nullability, and data types are detected. For decimal types we always add precision and scale. **This is the default.** +- `reflection_level = "full_with_precision"`: Column names, nullability, data types, and precision/scale are detected, also for types like text and binary. Integer sizes are set to bigint and to int for all other types. + +If the SQL type is unknown or not supported by `dlt` the column is skipped when using the `pyarrow` backend. +In other backend the type is inferred from data regardless of `reflection_level`, this often works, some types are coerced to strings +and `dataclass` based values from sqlalchemy are inferred as `complex` (JSON in most destinations). + +:::tip +If you use **full** (and above) reflection level you may encounter a situation where the data returned by sql alchemy or pyarrow backend +does not match the reflected data types. Most common symptoms are: +1. The destination complains that it cannot cast one type to another for a certain column. For example `connector-x` returns TIME in nanoseconds +and BigQuery sees it as bigint and fails to load. +2. You get `SchemaCorruptedException` or other coercion error during `normalize` step. +In that case you may try **minimal** reflection level where all data types are inferred from the returned data. From our experience this prevents +most of the coercion problems. +::: + +You can also override the sql type by passing a `type_adapter_callback` function. +This function takes an `sqlalchemy` data type and returns a new type (or `None` to force the column to be inferred from the data). + +This is useful for example when: +- You're loading a data type which is not supported by the destination (e.g. you need JSON type columns to be coerced to string) +- You're using an sqlalchemy dialect which uses custom types that don't inherit from standard sqlalchemy types. +- For certain types you prefer `dlt` to infer data type from the data and you return `None` + +Example, when loading timestamps from Snowflake you can make sure they translate to `timestamp` columns in the result schema: + +```py +import dlt +from snowflake.sqlalchemy import TIMESTAMP_NTZ +import sqlalchemy as sa + +def type_adapter_callback(sql_type): + if isinstance(sql_type, TIMESTAMP_NTZ): # Snowflake does not inherit from sa.DateTime + return sa.DateTime(timezone=True) + return sql_type # Use default detection for other types + +source = sql_database( + "snowflake://user:password@account/database?&warehouse=WH_123", + reflection_level="full", + type_adapter_callback=type_adapter_callback, + backend="pyarrow" +) + +dlt.pipeline("demo").run(source) +``` + +## Extended configuration +You are able to configure most of the arguments to `sql_database` and `sql_table` via toml files and environment variables. This is particularly useful with `sql_table` +because you can maintain a separate configuration for each table (below we show **secrets.toml** and **config.toml**, you are free to combine them into one.): +```toml +[sources.sql_database] +credentials="mssql+pyodbc://loader.database.windows.net/dlt_data?trusted_connection=yes&driver=ODBC+Driver+17+for+SQL+Server" +``` + +```toml +[sources.sql_database.chat_message] +backend="pandas" +chunk_size=1000 + +[sources.sql_database.chat_message.incremental] +cursor_path="updated_at" +``` +Example above will setup **backend** and **chunk_size** for a table with name **chat_message**. It will also enable incremental loading on a column named **updated_at**. +Table resource is instantiated as follows: +```py +table = sql_table(table="chat_message", schema="data") +``` + +Similarly, you can configure `sql_database` source. +```toml +[sources.sql_database] +credentials="mssql+pyodbc://loader.database.windows.net/dlt_data?trusted_connection=yes&driver=ODBC+Driver+17+for+SQL+Server" +schema="data" +backend="pandas" +chunk_size=1000 + +[sources.sql_database.chat_message.incremental] +cursor_path="updated_at" +``` +Note that we are able to configure incremental loading per table, even if it is a part of a dlt source. Source below will extract data using **pandas** backend +with **chunk_size** 1000. **chat_message** table will load data incrementally using **updated_at** column. All other tables will load fully. +```py +database = sql_database() +``` + +You can configure all the arguments this way (except adapter callback function). [Standard dlt rules apply](https://dlthub.com/docs/general-usage/credentials/configuration#configure-dlt-sources-and-resources). You can use environment variables [by translating the names properly](https://dlthub.com/docs/general-usage/credentials/config_providers#toml-vs-environment-variables) ie. +```sh +SOURCES__SQL_DATABASE__CREDENTIALS="mssql+pyodbc://loader.database.windows.net/dlt_data?trusted_connection=yes&driver=ODBC+Driver+17+for+SQL+Server" +SOURCES__SQL_DATABASE__BACKEND=pandas +SOURCES__SQL_DATABASE__CHUNK_SIZE=1000 +SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH=updated_at +``` + +### Configuring incremental loading +`dlt.sources.incremental` class is a [config spec](https://dlthub.com/docs/general-usage/credentials/config_specs) and can be configured like any other spec, here's an example that sets all possible options: +```toml +[sources.sql_database.chat_message.incremental] +cursor_path="updated_at" +initial_value=2024-05-27T07:32:00Z +end_value=2024-05-28T07:32:00Z +row_order="asc" +allow_external_schedulers=false +``` +Please note that we specify date times in **toml** as initial and end value. For env variables only strings are currently supported. + + +### Use SqlAlchemy Engine as credentials +You are able to pass an instance of **SqlAlchemy** `Engine` instance instead of credentials: +```py +from sqlalchemy import create_engine + +engine = create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") +table = sql_table(engine, table="chat_message", schema="data") +``` +Engine is used by `dlt` to open database connections and can work across multiple threads so is compatible with `parallelize` setting of dlt sources and resources. + + ## Troubleshooting ### Connect to mysql with SSL diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md b/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md index 5844844cca..8c39a5090e 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/stripe.md @@ -175,24 +175,7 @@ def incremental_stripe_source( After each run, 'initial_start_date' updates to the last loaded date. Subsequent runs then retrieve only new data using append mode, streamlining the process and preventing redundant data downloads. -For more information, read the [General Usage: Incremental loading](../../general-usage/incremental-loading). - -### Resource `metrics_resource` - -This function loads a dictionary with calculated metrics, including MRR and Churn rate, along with the current timestamp. - -```py -@dlt.resource(name="Metrics", write_disposition="append", primary_key="created") -def metrics_resource() -> Iterable[TDataItem]: - ... -``` - -Abrevations MRR and Churn rate are as follows: -- Monthly Recurring Revenue (MRR): - - Measures the predictable monthly revenue from all active subscriptions. It's the sum of the monthly-normalized subscription amounts. -- Churn rate: - - Indicates the rate subscribers leave a service over a specific period. Calculated by dividing the number of recent cancellations by the total subscribers from 30 days ago, adjusted for new subscribers. - +For more information, read the [Incremental loading](../../general-usage/incremental-loading). ## Customization ### Create your own pipeline @@ -236,7 +219,7 @@ verified source. ``` > For subsequent runs, the dlt module sets the previous "end_date" as "initial_start_date", ensuring incremental data retrieval. -1. To load data created after December 31, 2022, adjust the data range for stripe_source to prevent redundant loading. For incremental_stripe_source, the initial_start_date will auto-update to the last loaded date from the previous run. +1. To load data created after December 31, 2022, adjust the data range for stripe_source to prevent redundant loading. For `incremental_stripe_source`, the `initial_start_date` will auto-update to the last loaded date from the previous run. ```py source_single = stripe_source( @@ -249,21 +232,6 @@ verified source. load_info = pipeline.run(data=[source_single, source_incremental]) print(load_info) ``` - > To load data, maintain the pipeline name and destination dataset name. The pipeline name is vital for accessing the last run's [state](https://dlthub.com/docs/general-usage/state), which determines the incremental data load's end date. Altering these names can trigger a [“full_refresh”](https://dlthub.com/docs/general-usage/pipeline#do-experiments-with-full-refresh), disrupting the metadata (state) tracking for [incremental data loading](https://dlthub.com/docs/general-usage/incremental-loading). - -1. To load important metrics and store them in database: - - ```py - # Event is an endpoint with uneditable data, so we can use 'incremental_stripe_source'. - source_event = incremental_stripe_source(endpoints=("Event",)) - # Subscription is an endpoint with editable data, use stripe_source. - source_subs = stripe_source(endpoints=("Subscription",)) - load_info = pipeline.run(data=[source_subs, source_event]) - print(load_info) - resource = metrics_resource() - print(list(resource)) - load_info = pipeline.run(resource) - print(load_info) - ``` + > To load data, maintain the pipeline name and destination dataset name. The pipeline name is vital for accessing the last run's [state](../../general-usage/state), which determines the incremental data load's end date. Altering these names can trigger a [“full_refresh”](../../general-usage/pipeline#do-experiments-with-full-refresh), disrupting the metadata (state) tracking for [incremental data loading](../../general-usage/incremental-loading). diff --git a/docs/website/docs/general-usage/destination.md b/docs/website/docs/general-usage/destination.md index 760daa2fee..b30403d349 100644 --- a/docs/website/docs/general-usage/destination.md +++ b/docs/website/docs/general-usage/destination.md @@ -18,26 +18,27 @@ We recommend that you declare the destination type when creating a pipeline inst Above we want to use **filesystem** built-in destination. You can use shorthand types only for built-ins. -* Use full **destination class type** +* Use full **destination factory type** -Above we use built in **filesystem** destination by providing a class type `filesystem` from module `dlt.destinations`. You can pass [destinations from external modules](#declare-external-destination) as well. +Above we use built in **filesystem** destination by providing a factory type `filesystem` from module `dlt.destinations`. You can pass [destinations from external modules](#declare-external-destination) as well. -* Import **destination class** +* Import **destination factory** -Above we import destination class for **filesystem** and pass it to the pipeline. +Above we import destination factory for **filesystem** and pass it to the pipeline. -All examples above will create the same destination class with default parameters and pull required config and secret values from [configuration](credentials/configuration.md) - they are equivalent. +All examples above will create the same destination factory with default parameters and pull required config and secret values from [configuration](credentials/configuration.md) - they are equivalent. ### Pass explicit parameters and a name to a destination -You can instantiate **destination class** yourself to configure it explicitly. When doing this you work with destinations the same way you work with [sources](source.md) +You can instantiate **destination factory** yourself to configure it explicitly. When doing this you work with destinations the same way you work with [sources](source.md) -Above we import and instantiate the `filesystem` destination class. We pass explicit url of the bucket and name the destination to `production_az_bucket`. +Above we import and instantiate the `filesystem` destination factory. We pass explicit url of the bucket and name the destination to `production_az_bucket`. + +If destination is not named, its shorthand type (the Python factory name) serves as a destination name. Name your destination explicitly if you need several separate configurations of destinations of the same type (i.e. you wish to maintain credentials for development, staging and production storage buckets in the same config file). Destination name is also stored in the [load info](../running-in-production/running.md#inspect-and-save-the-load-info-and-trace) and pipeline traces so use them also when you need more descriptive names (other than, for example, `filesystem`). -If destination is not named, its shorthand type (the Python class name) serves as a destination name. Name your destination explicitly if you need several separate configurations of destinations of the same type (i.e. you wish to maintain credentials for development, staging and production storage buckets in the same config file). Destination name is also stored in the [load info](../running-in-production/running.md#inspect-and-save-the-load-info-and-trace) and pipeline traces so use them also when you need more descriptive names (other than, for example, `filesystem`). ## Configure a destination We recommend to pass the credentials and other required parameters to configuration via TOML files, environment variables or other [config providers](credentials/config_providers.md). This allows you, for example, to easily switch to production destinations after deployment. @@ -59,7 +60,7 @@ For named destinations you use their names in the config section Note that when you use [`dlt init` command](../walkthroughs/add-a-verified-source.md) to create or add a data source, `dlt` creates a sample configuration for selected destination. ### Pass explicit credentials -You can pass credentials explicitly when creating destination class instance. This replaces the `credentials` argument in `dlt.pipeline` and `pipeline.load` methods - which is now deprecated. You can pass the required credentials object, its dictionary representation or the supported native form like below: +You can pass credentials explicitly when creating destination factory instance. This replaces the `credentials` argument in `dlt.pipeline` and `pipeline.load` methods - which is now deprecated. You can pass the required credentials object, its dictionary representation or the supported native form like below: @@ -74,6 +75,23 @@ You can create and pass partial credentials and `dlt` will fill the missing data Please read how to use [various built in credentials types](credentials/config_specs.md). ::: +### Inspect destination capabilities +[Destination capabilities](../walkthroughs/create-new-destination.md#3-set-the-destination-capabilities) tell `dlt` what given destination can and cannot do. For example it tells which file formats it can load, what is maximum query or identifier length. Inspect destination capabilities as follows: +```py +import dlt +pipeline = dlt.pipeline("snowflake_test", destination="snowflake") +print(dict(pipeline.destination.capabilities())) +``` + +### Pass additional parameters and change destination capabilities +Destination factory accepts additional parameters that will be used to pre-configure it and change destination capabilities. +```py +import dlt +duck_ = dlt.destinations.duckdb(naming_convention="duck_case", recommended_file_size=120000) +print(dict(duck_.capabilities())) +``` +Example above is overriding `naming_convention` and `recommended_file_size` in the destination capabilities. + ### Configure multiple destinations in a pipeline To configure multiple destinations within a pipeline, you need to provide the credentials for each destination in the "secrets.toml" file. This example demonstrates how to configure a BigQuery destination named `destination_one`: @@ -86,7 +104,7 @@ private_key = "please set me up!" client_email = "please set me up!" ``` -You can then use this destination in your pipeline as follows: +You can then use this destination in your pipeline as follows: ```py import dlt from dlt.common.destination import Destination @@ -117,6 +135,56 @@ Obviously, dlt will access the destination when you instantiate [sql_client](../ ::: +## Control how `dlt` creates table, column and other identifiers +`dlt` maps identifiers found in the source data into destination identifiers (ie. table and columns names) using [naming conventions](naming-convention.md) which ensure that +character set, identifier length and other properties fit into what given destination can handle. For example our [default naming convention (**snake case**)](naming-convention.md#default-naming-convention-snake_case) converts all names in the source (ie. JSON document fields) into snake case, case insensitive identifiers. + +Each destination declares its preferred naming convention, support for case sensitive identifiers and case folding function that case insensitive identifiers follow. For example: +1. Redshift - by default does not support case sensitive identifiers and converts all of them to lower case. +2. Snowflake - supports case sensitive identifiers and considers upper cased identifiers as case insensitive (which is the default case folding) +3. DuckDb - does not support case sensitive identifiers but does not case fold them so it preserves the original casing in the information schema. +4. Athena - does not support case sensitive identifiers and converts all of them to lower case. +5. BigQuery - all identifiers are case sensitive, there's no case insensitive mode available via case folding (but it can be enabled in dataset level). + +You can change the naming convention used in [many different ways](naming-convention.md#configure-naming-convention), below we set the preferred naming convention on the Snowflake destination to `sql_cs` to switch Snowflake to case sensitive mode: +```py +import dlt +snow_ = dlt.destinations.snowflake(naming_convention="sql_cs_v1") +``` +Setting naming convention will impact all new schemas being created (ie. on first pipeline run) and will re-normalize all existing identifiers. + +:::caution +`dlt` prevents re-normalization of identifiers in tables that were already created at the destination. Use [refresh](pipeline.md#refresh-pipeline-data-and-state) mode to drop the data. You can also disable this behavior via [configuration](naming-convention.md#avoid-identifier-collisions) +::: + +:::note +Destinations that support case sensitive identifiers but use case folding convention to enable case insensitive identifiers are configured in case insensitive mode by default. Examples: Postgres, Snowflake, Oracle. +::: + +:::caution +If you use case sensitive naming convention with case insensitive destination, `dlt` will: +1. Fail the load if it detects identifier collision due to case folding +2. Warn if any case folding is applied by the destination. +::: + +### Enable case sensitive identifiers support +Selected destinations may be configured so they start accepting case sensitive identifiers. For example, it is possible to set case sensitive collation on **mssql** database and then tell `dlt` about it. +```py +from dlt.destinations import mssql +dest_ = mssql(has_case_sensitive_identifiers=True, naming_convention="sql_cs_v1") +``` +Above we can safely use case sensitive naming convention without worrying of name collisions. + +You can configure the case sensitivity, **but configuring destination capabilities is not currently supported**. +```toml +[destination.mssql] +has_case_sensitive_identifiers=true +``` + +:::note +In most cases setting the flag above just indicates to `dlt` that you switched the case sensitive option on a destination. `dlt` will not do that for you. Refer to destination documentation for details. +::: + ## Create new destination You have two ways to implement a new destination: 1. You can use `@dlt.destination` decorator and [implement a sink function](../dlt-ecosystem/destinations/destination.md). This is perfect way to implement reverse ETL destinations that push data back to REST APIs. diff --git a/docs/website/docs/general-usage/http/overview.md b/docs/website/docs/general-usage/http/overview.md index 94dc64eac5..2d193ceb2c 100644 --- a/docs/website/docs/general-usage/http/overview.md +++ b/docs/website/docs/general-usage/http/overview.md @@ -8,6 +8,10 @@ dlt has built-in support for fetching data from APIs: - [RESTClient](./rest-client.md) for interacting with RESTful APIs and paginating the results - [Requests wrapper](./requests.md) for making simple HTTP requests with automatic retries and timeouts +Additionally, dlt provides tools to simplify working with APIs: +- [REST API generic source](../../dlt-ecosystem/verified-sources/rest_api) integrates APIs using a [declarative configuration](../../dlt-ecosystem/verified-sources/rest_api#source-configuration) to minimize custom code. +- [OpenAPI source generator](../../dlt-ecosystem/verified-sources/openapi-generator) automatically creates declarative API configurations from [OpenAPI specifications](https://swagger.io/specification/). + ## Quick example Here's a simple pipeline that reads issues from the [dlt GitHub repository](https://github.com/dlt-hub/dlt/issues). The API endpoint is https://api.github.com/repos/dlt-hub/dlt/issues. The result is "paginated", meaning that the API returns a limited number of issues per page. The `paginate()` method iterates over all pages and yields the results which are then processed by the pipeline. diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index 1093428b0f..d3a06a1d28 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -164,7 +164,7 @@ def get_data(): #### HeaderLinkPaginator -This paginator handles pagination based on a link to the next page in the response headers (e.g., the `Link` header, as used by GitHub). +This paginator handles pagination based on a link to the next page in the response headers (e.g., the `Link` header, as used by GitHub API). **Parameters:** @@ -231,7 +231,8 @@ Note, that in this case, the `total_path` parameter is set explicitly to `None` **Parameters:** -- `initial_page`: The starting page number. Defaults to `1`. +- `base_page`: The index of the initial page from the API perspective. Normally, it's 0-based or 1-based (e.g., 1, 2, 3, ...) indexing for the pages. Defaults to 0. +- `page`: The page number for the first request. If not provided, the initial value will be set to `base_page`. - `page_param`: The query parameter name for the page number. Defaults to `"page"`. - `total_path`: A JSONPath expression for the total number of pages. If not provided, pagination is controlled by `maximum_page`. - `maximum_page`: Optional maximum page number. Stops pagination once this page is reached. @@ -305,7 +306,9 @@ client = RESTClient( ### Implementing a custom paginator -When working with APIs that use non-standard pagination schemes, or when you need more control over the pagination process, you can implement a custom paginator by subclassing the `BasePaginator` class and `update_state` and `update_request` methods: +When working with APIs that use non-standard pagination schemes, or when you need more control over the pagination process, you can implement a custom paginator by subclassing the `BasePaginator` class and implementing `init_request`, `update_state` and `update_request` methods: + +- `init_request(request: Request) -> None`: This method is called before making the first API call in the `RESTClient.paginate` method. You can use this method to set up the initial request query parameters, headers, etc. For example, you can set the initial page number or cursor value. - `update_state(response: Response) -> None`: This method updates the paginator's state based on the response of the API call. Typically, you extract pagination details (like the next page reference) from the response and store them in the paginator instance. @@ -325,6 +328,10 @@ class QueryParamPaginator(BasePaginator): self.page_param = page_param self.page = initial_page + def init_request(self, request: Request) -> None: + # This will set the initial page number (e.g. page=1) + self.update_request(request) + def update_state(self, response: Response) -> None: # Assuming the API returns an empty list when no more data is available if not response.json(): @@ -406,8 +413,11 @@ The available authentication methods are defined in the `dlt.sources.helpers.res - [BearerTokenAuth](#bearer-token-authentication) - [APIKeyAuth](#api-key-authentication) - [HttpBasicAuth](#http-basic-authentication) +- [OAuth2ClientCredentials](#oauth20-authorization) For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library. +For specific flavors of OAuth 2.0 you can [implement custom OAuth 2.0](#oauth2-authorization) +by subclassing `OAuth2ClientCredentials`. ### Bearer token authentication @@ -477,6 +487,57 @@ client = RESTClient(base_url="https://api.example.com", auth=auth) response = client.get("/protected/resource") ``` +### OAuth 2.0 authorization + +OAuth 2.0 is a common protocol for authorization. We have implemented two-legged authorization employed for server-to-server authorization because the end user (resource owner) does not need to grant approval. +The REST client acts as the OAuth client which obtains a temporary access token from the authorization server. This access token is then sent to the resource server to access protected content. If the access token is expired, the OAuth client automatically refreshes it. + +Unfortunately, most OAuth 2.0 implementations vary and thus you might need to subclass `OAuth2ClientCredentials` and implement `build_access_token_request()` to suite the requirements of the specific authorization server you want to interact with. + +**Parameters:** +- `access_token_url`: The url to obtain the temporary access token. +- `client_id`: Client credential to obtain authorization. Usually issued via a developer portal. +- `client_secret`: Client credential to obtain authorization. Usually issued via a developer portal. +- `access_token_request_data`: A dictionary with data required by the autorization server apart from the `client_id`, `client_secret`, and `"grant_type": "client_credentials"`. Defaults to `None`. +- `default_token_expiration`: The time in seconds after which the temporary access token expires. Defaults to 3600. + +**Example:** + +```py +from base64 import b64encode +from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client.auth import OAuth2ClientCredentials + +class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials): + """Used e.g. by Zoom Zoom Video Communications, Inc.""" + def build_access_token_request(self) -> Dict[str, Any]: + authentication: str = b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + return { + "headers": { + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": self.access_token_request_data, + } + +auth = OAuth2ClientCredentialsHTTPBasic( + access_token_url=dlt.secrets["sources.zoom.access_token_url"], # "https://zoom.us/oauth/token" + client_id=dlt.secrets["sources.zoom.client_id"], + client_secret=dlt.secrets["sources.zoom.client_secret"], + access_token_request_data={ + "grant_type": "account_credentials", + "account_id": dlt.secrets["sources.zoom.account_id"], + }, +) +client = RESTClient(base_url="https://api.zoom.us/v2", auth=auth) + +response = client.get("/users") +``` + + + ### Implementing custom authentication You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method: diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 72957402da..b21a5779bc 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -29,7 +29,7 @@ using `primary_key`. Use `write_disposition='merge'`.
-![write disposition flowchart](/img/write-dispo-choice.png) +![write disposition flowchart](https://storage.googleapis.com/dlt-blog-images/flowchart_for_scd2.png)
@@ -41,8 +41,9 @@ user's profile Stateless data cannot change - for example, a recorded event, suc Because stateless data does not need to be updated, we can just append it. -For stateful data, comes a second question - Can I extract it incrementally from the source? If not, -then we need to replace the entire data set. If however we can request the data incrementally such +For stateful data, comes a second question - Can I extract it incrementally from the source? If yes, you should use [slowly changing dimensions (Type-2)](#scd2-strategy), which allow you to maintain historical records of data changes over time. + +If not, then we need to replace the entire data set. If however we can request the data incrementally such as "all users added or modified since yesterday" then we can simply apply changes to our existing dataset with the merge write disposition. @@ -505,7 +506,7 @@ def get_events(last_created_at = dlt.sources.incremental("$", last_value_func=by ``` ### Using `last_value_func` for lookback -The example below uses the `last_value_func` to load data from the past month. +The example below uses the `last_value_func` to load data from the past month. ```py def lookback(event): last_value = None @@ -981,3 +982,64 @@ def search_tweets(twitter_bearer_token=dlt.secrets.value, search_terms=None, sta yield page ``` + +## Troubleshooting + +If you see that the incremental loading is not working as expected and the incremental values are not modified between pipeline runs, check the following: + +1. Make sure the `destination`, `pipeline_name` and `dataset_name` are the same between pipeline runs. + +2. Check if `dev_mode` is `False` in the pipeline configuration. Check if `refresh` for associated sources and resources is not enabled. + +3. Check the logs for `Bind incremental on ...` message. This message indicates that the incremental value was bound to the resource and shows the state of the incremental value. + +4. After the pipeline run, check the state of the pipeline. You can do this by running the following command: + +```sh +dlt pipeline -v info +``` + +For example, if your pipeline is defined as follows: + +```py +@dlt.resource +def my_resource( + incremental_object = dlt.sources.incremental("some_key", initial_value=0), +): + ... + +pipeline = dlt.pipeline( + pipeline_name="example_pipeline", + destination="duckdb", +) + +pipeline.run(my_resource) +``` + +You'll see the following output: + +```text +Attaching to pipeline +... + +sources: +{ + "example": { + "resources": { + "my_resource": { + "incremental": { + "some_key": { + "initial_value": 0, + "last_value": 42, + "unique_hashes": [ + "nmbInLyII4wDF5zpBovL" + ] + } + } + } + } + } +} +``` + +Verify that the `last_value` is updated between pipeline runs. \ No newline at end of file diff --git a/docs/website/docs/general-usage/naming-convention.md b/docs/website/docs/general-usage/naming-convention.md new file mode 100644 index 0000000000..bf6e650b9c --- /dev/null +++ b/docs/website/docs/general-usage/naming-convention.md @@ -0,0 +1,147 @@ +--- +title: Naming Convention +description: Control how dlt creates table, column and other identifiers +keywords: [identifiers, snake case, case sensitive, case insensitive, naming] +--- + +# Naming Convention +`dlt` creates table and column identifiers from the data. The data source ie. a stream of JSON documents may have identifiers (i.e. key names in a dictionary) with any Unicode characters, of any length and naming style. On the other hand, destinations require that you follow strict rules when you name tables, columns or collections. +A good example is [Redshift](../dlt-ecosystem/destinations/redshift.md#naming-convention) that accepts case-insensitive alphanumeric identifiers with maximum 127 characters. + +`dlt` groups tables from a single [source](source.md) in a [schema](schema.md). Each schema defines **naming convention** that tells `dlt` how to translate identifiers to the +namespace that the destination understands. Naming conventions are in essence functions that map strings from the source identifier format into destination identifier format. For example our **snake_case** (default) naming convention will translate `DealFlow` source identifier into `deal_flow` destination identifier. + +You can pick which naming convention to use. `dlt` provides a few to [choose from](#available-naming-conventions). You can [easily add your own](#write-your-own-naming-convention) as well. + +:::tip +Standard behavior of `dlt` is to **use the same naming convention for all destinations** so users see always the same table and column names in their databases. +::: + +### Use default naming convention (snake_case) +**snake_case** is case insensitive naming convention, converting source identifiers into lower case snake case identifiers with reduced alphabet. + +- Spaces around identifier are trimmed +- Keeps ascii alphanumerics and underscores, replaces all other characters with underscores (with the exceptions below) +- Replaces `+` and `*` with `x`, `-` with `_`, `@` with `a` and `|` with `l` +- Prepends `_` if name starts with number. +- Multiples of `_` are converted into single `_`. +- Replaces all trailing `_` with `x` + +Uses __ as patent-child separator for tables and flattened column names. + +:::tip +If you do not like **snake_case** your next safe option is **sql_ci** which generates SQL-safe, lower-case, case-insensitive identifiers without any +other transformations. To permanently change the default naming convention on a given machine: +1. set an environment variable `SCHEMA__NAMING` to `sql_ci_v1` OR +2. add the following line to your global `config.toml` (the one in your home dir ie. `~/.dlt/config.toml`) +```toml +[schema] +naming="sql_ci_v1" +``` +::: + +## Source identifiers vs destination identifiers +### Pick the right identifier form when defining resources +`dlt` keeps source (not normalized) identifiers during data [extraction](../reference/explainers/how-dlt-works.md#extract) and translates them during [normalization](../reference/explainers/how-dlt-works.md#normalize). For you it means: +1. If you write a [transformer](resource.md#process-resources-with-dlttransformer) or a [mapping/filtering function](resource.md#filter-transform-and-pivot-data), you will see the original data, without any normalization. Use the source identifiers to access the dicts! +2. If you define a `primary_key` or `cursor` that participate in [cursor field incremental loading](incremental-loading.md#incremental-loading-with-a-cursor-field) use the source identifiers (`dlt` uses them to inspect source data, `Incremental` class is just a filtering function). +3. When defining any other hints ie. `columns` or `merge_key` you can pick source or destination identifiers. `dlt` normalizes all hints together with your data. +4. `Schema` object (ie. obtained from the pipeline or from `dlt` source via `discover_schema`) **always contains destination (normalized) identifiers**. + +### Understand the identifier normalization +Identifiers are translated from source to destination form in **normalize** step. Here's how `dlt` picks the naming convention: + +* The default naming convention is **snake_case**. +* Each destination may define a preferred naming convention in [destination capabilities](destination.md#pass-additional-parameters-and-change-destination-capabilities). Some destinations (ie. Weaviate) need specialized naming convention and will override the default. +* You can [configure a naming convention explicitly](#set-and-adjust-naming-convention-explicitly). Such configuration overrides the destination settings. +* This naming convention is used when new schemas are created. It happens when pipeline is run for the first time. +* Schemas preserve naming convention when saved. Your running pipelines will maintain existing naming conventions if not requested otherwise. +* `dlt` applies final naming convention in `normalize` step. Jobs (files) in load package now have destination identifiers. Pipeline schema is duplicated, locked and saved in the load package and will be used by the destination. + +:::caution +If you change naming convention and `dlt` detects that a change in the destination identifiers for tables/collection/files that already exist and store data, +the normalize process will fail. This prevents an unwanted schema migration. New columns and tables will be created for identifiers that changed. +::: + +### Case sensitive and insensitive destinations +Naming convention declare if the destination identifiers they produce are case sensitive or insensitive. This helps `dlt` to [generate case sensitive / insensitive identifiers for the destinations that support both](destination.md#control-how-dlt-creates-table-column-and-other-identifiers). For example: if you pick case insensitive naming like **snake_case** or **sql_ci_v1**, with Snowflake, `dlt` will generate all upper-case identifiers that Snowflake sees as case insensitive. If you pick case sensitive naming like **sql_cs_v1**, `dlt` will generate quoted case-sensitive identifiers that preserve identifier capitalization. + +Note that many destinations are exclusively case insensitive, of which some preserve casing of identifiers (ie. **duckdb**) and some will case-fold identifiers when creating tables (ie. **Redshift**, **Athena** do lower case on the names). `dlt` is able to detect resulting identifier [collisions](#avoid-identifier-collisions) and stop the load process before data is mangled. + +### Identifier shortening +Identifier shortening happens during normalization. `dlt` takes the maximum length of the identifier from the destination capabilities and will trim the identifiers that are +too long. The default shortening behavior generates short deterministic hashes of the source identifiers and places them in the middle of the destination identifier. This +(with a high probability) avoids shortened identifier collisions. + +### 🚧 [WIP] Name convention changes are lossy +`dlt` does not store the source identifiers in the schema so when naming convention changes (or we increase the maximum identifier length), it is not able to generate a fully correct set of new identifiers. Instead it will re-normalize already normalized identifiers. We are currently working to store full identifier lineage - source identifiers will be stored and mapped to the destination in the schema. + +## Pick your own naming convention + +### Configure naming convention +You can use `config.toml`, environment variables or any other configuration provider to set the naming convention name. Configured naming convention **overrides all other settings** +- changes the naming convention stored in the already created schema +- overrides the destination capabilities preference. +```toml +[schema] +naming="sql_ci_v1" +``` +Configuration above will request **sql_ci_v1** for all pipelines (schemas). An environment variable `SCHEMA__NAMING` set to `sql_ci_v1` has the same effect. + +You have an option to set naming convention per source: +```toml +[sources.zendesk] +config="prop" +[sources.zendesk.schema] +naming="sql_cs_v1" +[sources.zendesk.credentials] +password="pass" +``` +Snippet above demonstrates how to apply certain naming for an example `zendesk` source. + +You can use naming conventions that you created yourself or got from other users. In that case you should pass a full Python import path to the [module that contain the naming convention](#write-your-own-naming-convention): +```toml +[schema] +naming="tests.common.cases.normalizers.sql_upper" +``` +`dlt` will import `tests.common.cases.normalizers.sql_upper` and use `NamingConvention` class found in it as the naming convention. + + +### Available naming conventions +You can pick from a few built-in naming conventions. + +* `snake_case` - the default +* `duck_case` - case sensitive, allows all unicode characters like emoji 💥 +* `direct` - case sensitive, allows all unicode characters, does not contract underscores +* `sql_cs_v1` - case sensitive, generates sql-safe identifiers +* `sql_ci_v1` - case insensitive, generates sql-safe lower case identifiers + +### Set and adjust naming convention explicitly +You can modify destination capabilities to + +## Avoid identifier collisions +`dlt` detects various types of identifier collisions and ignores the others. +1. `dlt` detects collisions if case sensitive naming convention is used on case insensitive destination +2. `dlt` detects collisions if change of naming convention changes the identifiers of tables already created in the destination +3. `dlt` detects collisions when naming convention is applied to column names of arrow tables + +`dlt` will not detect collision when normalizing source data. If you have a dictionary, keys will be merged if they collide after being normalized. +You can create a custom naming convention that does not generate collisions on data, see examples below. + + +## Write your own naming convention +Custom naming conventions are classes that derive from `NamingConvention` that you can import from `dlt.common.normalizers.naming`. We recommend the following module layout: +1. Each naming convention resides in a separate Python module (file) +2. The class is always named `NamingConvention` + +In that case you can use a fully qualified module name in [schema configuration](#configure-naming-convention) or pass module [explicitly](#set-and-adjust-naming-convention-explicitly). + +We include [two examples](../examples/custom_naming) of naming conventions that you may find useful: + +1. A variant of `sql_ci` that generates identifier collisions with a low (user defined) probability by appending a deterministic tag to each name. +2. A variant of `sql_cs` that allows for LATIN (ie. umlaut) characters + +:::note +Note that a fully qualified name of your custom naming convention will be stored in the `Schema` and `dlt` will attempt to import it when schema is loaded from storage. +You should distribute your custom naming conventions with your pipeline code or via a pip package from which it can be imported. +::: diff --git a/docs/website/docs/general-usage/pipeline.md b/docs/website/docs/general-usage/pipeline.md index d1f82f970a..f21d6f0686 100644 --- a/docs/website/docs/general-usage/pipeline.md +++ b/docs/website/docs/general-usage/pipeline.md @@ -98,33 +98,72 @@ You can reset parts or all of your sources by using the `refresh` argument to `d That means when you run the pipeline the sources/resources being processed will have their state reset and their tables either dropped or truncated depending on which refresh mode is used. +`refresh` option works with all relational/sql destinations and file buckets (`filesystem`). it does not work with vector databases (we are working on that) and +with custom destinations. + The `refresh` argument should have one of the following string values to decide the refresh mode: -* `drop_sources` - All sources being processed in `pipeline.run` or `pipeline.extract` are refreshed. - That means all tables listed in their schemas are dropped and state belonging to those sources and all their resources is completely wiped. - The tables are deleted both from pipeline's schema and from the destination database. +### Drop tables and pipeline state for a source with `drop_sources` +All sources being processed in `pipeline.run` or `pipeline.extract` are refreshed. +That means all tables listed in their schemas are dropped and state belonging to those sources and all their resources is completely wiped. +The tables are deleted both from pipeline's schema and from the destination database. - If you only have one source or run with all your sources together, then this is practically like running the pipeline again for the first time +If you only have one source or run with all your sources together, then this is practically like running the pipeline again for the first time - :::caution - This erases schema history for the selected sources and only the latest version is stored - :::: +:::caution +This erases schema history for the selected sources and only the latest version is stored +::: -* `drop_resources` - Limits the refresh to the resources being processed in `pipeline.run` or `pipeline.extract` (.e.g by using `source.with_resources(...)`). - Tables belonging to those resources are dropped and their resource state is wiped (that includes incremental state). - The tables are deleted both from pipeline's schema and from the destination database. +```py +import dlt - Source level state keys are not deleted in this mode (i.e. `dlt.state()[<'my_key>'] = ''`) +pipeline = dlt.pipeline("airtable_demo", destination="duckdb") +pipeline.run(airtable_emojis(), refresh="drop_sources") +``` +In example above we instruct `dlt` to wipe pipeline state belonging to `airtable_emojis` source and drop all the database tables in `duckdb` to +which data was loaded. The `airtable_emojis` source had two resources named "📆 Schedule" and "💰 Budget" loading to tables "_schedule" and "_budget". Here's +what `dlt` does step by step: +1. collects a list of tables to drop by looking for all the tables in the schema that are created in the destination. +2. removes existing pipeline state associated with `airtable_emojis` source +3. resets the schema associated with `airtable_emojis` source +4. executes `extract` and `normalize` steps. those will create fresh pipeline state and a schema +5. before it executes `load` step, the collected tables are dropped from staging and regular dataset +6. schema `airtable_emojis` (associated with the source) is be removed from `_dlt_version` table +7. executes `load` step as usual so tables are re-created and fresh schema and pipeline state are stored. + +### Selectively drop tables and resource state with `drop_resources` +Limits the refresh to the resources being processed in `pipeline.run` or `pipeline.extract` (.e.g by using `source.with_resources(...)`). +Tables belonging to those resources are dropped and their resource state is wiped (that includes incremental state). +The tables are deleted both from pipeline's schema and from the destination database. + +Source level state keys are not deleted in this mode (i.e. `dlt.state()[<'my_key>'] = ''`) + +:::caution +This erases schema history for all affected sources and only the latest schema version is stored. +::: - :::caution - This erases schema history for all affected schemas and only the latest schema version is stored - :::: +```py +import dlt -* `drop_data` - Same as `drop_resources` but instead of dropping tables from schema only the data is deleted from them (i.e. by `TRUNCATE ` in sql destinations). Resource state for selected resources is also wiped. - The schema remains unmodified in this case. +pipeline = dlt.pipeline("airtable_demo", destination="duckdb") +pipeline.run(airtable_emojis().with_resources("📆 Schedule"), refresh="drop_resources") +``` +Above we request that the state associated with "📆 Schedule" resource is reset and the table generated by it ("_schedule") is dropped. Other resources, +tables and state are not affected. Please check `drop_sources` for step by step description of what `dlt` does internally. + +### Selectively truncate tables and reset resource state with `drop_data` +Same as `drop_resources` but instead of dropping tables from schema only the data is deleted from them (i.e. by `TRUNCATE ` in sql destinations). Resource state for selected resources is also wiped. In case of [incremental resources](incremental-loading.md#incremental-loading-with-a-cursor-field) this will +reset the cursor state and fully reload the data from the `initial_value`. + +The schema remains unmodified in this case. +```py +import dlt + +pipeline = dlt.pipeline("airtable_demo", destination="duckdb") +pipeline.run(airtable_emojis().with_resources("📆 Schedule"), refresh="drop_data") +``` +Above the incremental state of the "📆 Schedule" is reset before `extract` step so data is fully reacquired. Just before `load` step starts, + the "_schedule" is truncated and new (full) table data will be inserted/copied. ## Display the loading progress diff --git a/docs/website/docs/general-usage/resource.md b/docs/website/docs/general-usage/resource.md index ac7f7e6b38..14f8d73b58 100644 --- a/docs/website/docs/general-usage/resource.md +++ b/docs/website/docs/general-usage/resource.md @@ -488,6 +488,59 @@ be adjusted after the `batch` is processed in the extract pipeline but before an You can emit columns as Pydantic model and use dynamic hints (ie. lambda for table name) as well. You should avoid redefining `Incremental` this way. ::: +### Import external files +You can import external files ie. `csv`, `parquet` and `jsonl` by yielding items marked with `with_file_import`, optionally passing table schema corresponding +the the imported file. `dlt` will not read, parse and normalize any names (ie. `csv` or `arrow` headers) and will attempt to copy the file into the destination as is. +```py +import os +import dlt + +from filesystem import filesystem + +columns: List[TColumnSchema] = [ + {"name": "id", "data_type": "bigint"}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, +] + +import_folder = "/tmp/import" + +@dlt.transformer(columns=columns) +def orders(items: Iterator[FileItemDict]): + for item in items: + # copy file locally + dest_file = os.path.join(import_folder, item["file_name"]) + # download file + item.fsspec.download(item["file_url"], dest_file) + # tell dlt to import the dest_file as `csv` + yield dlt.mark.with_file_import(dest_file, "csv") + + +# use filesystem verified source to glob a bucket +downloader = filesystem( + bucket_url="s3://my_bucket/csv", + file_glob="today/*.csv.gz") | orders + +info = pipeline.run(orders, destination="snowflake") +``` +In the example above, we glob all zipped csv files present on **my_bucket/csv/today** (using `filesystem` verified source) and send file descriptors to `orders` transformer. Transformer downloads and imports the files into extract package. At the end, `dlt` sends them to snowflake (the table will be created because we use `column` hints to define the schema). + +If imported `csv` files are not in `dlt` [default format](../dlt-ecosystem/file-formats/csv.md#default-settings), you may need to pass additional configuration. +```toml +[destination.snowflake.csv_format] +delimiter="|" +include_header=false +on_error_continue=true +``` + +You can sniff the schema from the data ie. using `duckdb` to infer the table schema from `csv` file. `dlt.mark.with_file_import` accepts additional arguments that you can use to pass hints at run time. + +:::note +* If you do not define any columns, the table will not be created in the destination. `dlt` will still attempt to load data into it, so you create a fitting table upfront, the load process will succeed. +* Files are imported using hard links if possible to avoid copying and duplicating storage space needed. +::: ### Duplicate and rename resources There are cases when you your resources are generic (ie. bucket filesystem) and you want to load several instances of it (ie. files from different folders) to separate tables. In example below we use `filesystem` source to load csvs from two different folders into separate tables: @@ -538,12 +591,30 @@ pipeline.run(generate_rows(10)) # load a list of resources pipeline.run([generate_rows(10), generate_rows(20)]) ``` + +### Pick loader file format for a particular resource +You can request a particular loader file format to be used for a resource. +```py +@dlt.resource(file_format="parquet") +def generate_rows(nr): + for i in range(nr): + yield {'id':i, 'example_string':'abc'} +``` +Resource above will be saved and loaded from a `parquet` file (if destination supports it). + +:::note +A special `file_format`: **preferred** will load resource using a format that is preferred by a destination. This settings supersedes the `loader_file_format` passed to `run` method. +::: + ### Do a full refresh -To do a full refresh of an `append` or `merge` resources you temporarily change the write -disposition to replace. You can use `apply_hints` method of a resource or just provide alternative -write disposition when loading: +To do a full refresh of an `append` or `merge` resources you set the `refresh` argument on `run` method to `drop_data`. This will truncate the tables without dropping them. + +```py +p.run(merge_source(), refresh="drop_data") +``` +You can also [fully drop the tables](pipeline.md#refresh-pipeline-data-and-state) in the `merge_source`: ```py -p.run(merge_source(), write_disposition="replace") +p.run(merge_source(), refresh="drop_sources") ``` diff --git a/docs/website/docs/general-usage/schema.md b/docs/website/docs/general-usage/schema.md index 989b023b01..0e3e3bba1f 100644 --- a/docs/website/docs/general-usage/schema.md +++ b/docs/website/docs/general-usage/schema.md @@ -42,8 +42,9 @@ characters, any lengths and naming styles. On the other hand the destinations ac namespaces for their identifiers. Like Redshift that accepts case-insensitive alphanumeric identifiers with maximum 127 characters. -Each schema contains `naming convention` that tells `dlt` how to translate identifiers to the -namespace that the destination understands. +Each schema contains [naming convention](naming-convention.md) that tells `dlt` how to translate identifiers to the +namespace that the destination understands. This convention can be configured, changed in code or enforced via +destination. The default naming convention: @@ -214,7 +215,7 @@ The precision for **bigint** is mapped to available integer types ie. TINYINT, I ## Schema settings The `settings` section of schema file lets you define various global rules that impact how tables -and columns are inferred from data. +and columns are inferred from data. For example you can assign **primary_key** hint to all columns with name `id` or force **timestamp** data type on all columns containing `timestamp` with an use of regex pattern. > 💡 It is the best practice to use those instead of providing the exact column schemas via `columns` > argument or by pasting them in `yaml`. @@ -222,8 +223,9 @@ and columns are inferred from data. ### Data type autodetectors You can define a set of functions that will be used to infer the data type of the column from a -value. The functions are run from top to bottom on the lists. Look in [`detections.py`](https://github.com/dlt-hub/dlt/blob/devel/dlt/common/schema/detections.py) to see what is -available. +value. The functions are run from top to bottom on the lists. Look in `detections.py` to see what is +available. **iso_timestamp** detector that looks for ISO 8601 strings and converts them to **timestamp** +is enabled by default. ```yaml settings: @@ -236,12 +238,24 @@ settings: - wei_to_double ``` +Alternatively you can add and remove detections from code: +```py + source = data_source() + # remove iso time detector + source.schema.remove_type_detection("iso_timestamp") + # convert UNIX timestamp (float, withing a year from NOW) into timestamp + source.schema.add_type_detection("timestamp") +``` +Above we modify a schema that comes with a source to detect UNIX timestamps with **timestamp** detector. + ### Column hint rules You can define a global rules that will apply hints of a newly inferred columns. Those rules apply -to normalized column names. You can use column names directly or with regular expressions. +to normalized column names. You can use column names directly or with regular expressions. `dlt` is matching +the column names **after they got normalized with naming convention**. -Example from ethereum schema: +By default, schema adopts hints rules from json(relational) normalizer to support correct hinting +of columns added by normalizer: ```yaml settings: @@ -249,36 +263,59 @@ settings: foreign_key: - _dlt_parent_id not_null: - - re:^_dlt_id$ + - _dlt_id - _dlt_root_id - _dlt_parent_id - _dlt_list_idx + - _dlt_load_id unique: - _dlt_id - cluster: - - block_hash + root_key: + - _dlt_root_id +``` +Above we require exact column name match for a hint to apply. You can also use regular expression (which we call `SimpleRegex`) as follows: +```yaml +settings: partition: - - block_timestamp + - re:_timestamp$ +``` +Above we add `partition` hint to all columns ending with `_timestamp`. You can do same thing in the code +```py + source = data_source() + # this will update existing hints with the hints passed + source.schema.merge_hints({"partition": ["re:_timestamp$"]}) ``` ### Preferred data types You can define rules that will set the data type for newly created columns. Put the rules under `preferred_types` key of `settings`. On the left side there's a rule on a column name, on the right -side is the data type. - -> ❗See the column hint rules for naming convention! +side is the data type. You can use column names directly or with regular expressions. +`dlt` is matching the column names **after they got normalized with naming convention**. Example: ```yaml settings: preferred_types: - timestamp: timestamp - re:^inserted_at$: timestamp - re:^created_at$: timestamp - re:^updated_at$: timestamp - re:^_dlt_list_idx$: bigint + re:timestamp: timestamp + inserted_at: timestamp + created_at: timestamp + updated_at: timestamp +``` + +Above we prefer `timestamp` data type for all columns containing **timestamp** substring and define a few exact matches ie. **created_at**. +Here's same thing in code +```py + source = data_source() + source.schema.update_preferred_types( + { + "re:timestamp": "timestamp", + "inserted_at": "timestamp", + "created_at": "timestamp", + "updated_at": "timestamp", + } + ) ``` ### Applying data types directly with `@dlt.resource` and `apply_hints` `dlt` offers the flexibility to directly apply data types and hints in your code, bypassing the need for importing and adjusting schemas. This approach is ideal for rapid prototyping and handling data sources with dynamic schema requirements. @@ -364,7 +401,6 @@ def textual(nesting_level: int): schema.remove_type_detection("iso_timestamp") # convert UNIX timestamp (float, withing a year from NOW) into timestamp schema.add_type_detection("timestamp") - schema._compile_settings() return dlt.resource([]) ``` diff --git a/docs/website/docs/general-usage/state.md b/docs/website/docs/general-usage/state.md index 4a9e453ea4..b34d37c8b1 100644 --- a/docs/website/docs/general-usage/state.md +++ b/docs/website/docs/general-usage/state.md @@ -96,7 +96,7 @@ about the pipeline, pipeline run (that the state belongs to) and state blob. if you are not able to implement it with the standard incremental construct. - Store the custom fields dictionaries, dynamic configurations and other source-scoped state. -## When not to use pipeline state +## Do not use pipeline state if it can grow to millions of records Do not use dlt state when it may grow to millions of elements. Do you plan to store modification timestamps of all of your millions of user records? This is probably a bad idea! In that case you @@ -109,6 +109,39 @@ could: [sqlclient](../dlt-ecosystem/transformations/sql.md) and load the data of interest. In that case try at least to process your user records in batches. +### Access data in the destination instead of pipeline state +In the example below, we load recent comments made by given `user_id`. We access `user_comments` table to select +maximum comment id for a given user. +```py +import dlt + +@dlt.resource(name="user_comments") +def comments(user_id: str): + current_pipeline = dlt.current.pipeline() + # find last comment id for given user_id by looking in destination + max_id: int = 0 + # on first pipeline run, user_comments table does not yet exist so do not check at all + # alternatively catch DatabaseUndefinedRelation which is raised when unknown table is selected + if not current_pipeline.first_run: + with current_pipeline.sql_client() as client: + # we may get last user comment or None which we replace with 0 + max_id = ( + client.execute_sql( + "SELECT MAX(_id) FROM user_comments WHERE user_id=?", user_id + )[0][0] + or 0 + ) + # use max_id to filter our results (we simulate API query) + yield from [ + {"_id": i, "value": letter, "user_id": user_id} + for i, letter in zip([1, 2, 3], ["A", "B", "C"]) + if i > max_id + ] +``` +When pipeline is first run, the destination dataset and `user_comments` table do not yet exist. We skip the destination +query by using `first_run` property of the pipeline. We also handle a situation where there are no comments for a user_id +by replacing None with 0 as `max_id`. + ## Inspect the pipeline state You can inspect pipeline state with diff --git a/docs/website/docs/reference/installation.md b/docs/website/docs/reference/installation.md index 3f40c3a545..a23ce82c97 100644 --- a/docs/website/docs/reference/installation.md +++ b/docs/website/docs/reference/installation.md @@ -110,7 +110,7 @@ You can install `dlt` in your virtual environment by running: pip install -U dlt ``` -## Install dlt via pixi and conda +## Install dlt via Pixi and Conda Install dlt using `pixi`: diff --git a/docs/website/docs/reference/performance.md b/docs/website/docs/reference/performance.md index 7d8280d8ee..075d351553 100644 --- a/docs/website/docs/reference/performance.md +++ b/docs/website/docs/reference/performance.md @@ -223,6 +223,12 @@ resources are: `round_robin` and `fifo`. `fifo` is an option for sequential extraction. It will result in every resource being fully extracted until the resource generator is expired, or a configured limit is reached, then the next resource will be evaluated. Resources are extracted in the order that you added them to your source. +:::tip +Switch to `fifo` when debugging sources with many resources and connected transformers, for example [rest_api](../dlt-ecosystem/verified-sources/rest_api.md). +Your data will be requested in deterministic and straightforward order - given data item (ie. user record you got from API) will be processed by all resources +and transformers until completion before starting with new one +::: + You can change this setting in your `config.toml` as follows: diff --git a/docs/website/docs/reference/performance_snippets/toml-snippets.toml b/docs/website/docs/reference/performance_snippets/toml-snippets.toml index 5e700c4e31..e1a640e7cf 100644 --- a/docs/website/docs/reference/performance_snippets/toml-snippets.toml +++ b/docs/website/docs/reference/performance_snippets/toml-snippets.toml @@ -71,7 +71,7 @@ max_parallel_items=10 # @@@DLT_SNIPPET_START normalize_workers_toml - [extract.data_writer] +[extract.data_writer] # force extract file rotation if size exceeds 1MiB file_max_bytes=1000000 diff --git a/docs/website/docs/running-in-production/running.md b/docs/website/docs/running-in-production/running.md index 9c52f58caa..377cf57f2c 100644 --- a/docs/website/docs/running-in-production/running.md +++ b/docs/website/docs/running-in-production/running.md @@ -105,13 +105,15 @@ package. In that case, for a correctly behaving pipeline, only minimum amount of behind. In `config.toml`: ```toml -load.delete_completed_jobs=true +[load] +delete_completed_jobs=true ``` -Also, by default, `dlt` leaves data in staging dataset, used during merge and replace load for deduplication. In order to clear it, put the following line in `config.toml`: +Also, by default, `dlt` leaves data in [staging dataset](../dlt-ecosystem/staging.md#staging-dataset), used during merge and replace load for deduplication. In order to clear it, put the following line in `config.toml`: ```toml -load.truncate_staging_dataset=true +[load] +truncate_staging_dataset=true ``` ## Using slack to send messages @@ -174,7 +176,7 @@ As with any other configuration, you can use environment variables instead of th - `RUNTIME__LOG_LEVEL` to set the log level - `LOG_FORMAT` to set the log format -`dlt` logs to a logger named **dlt**. `dlt` logger uses a regular python logger so you can configure the handlers +`dlt` logs to a logger named **dlt**. `dlt` logger uses a regular python logger so you can configure the handlers as per your requirement. For example, to put logs to the file: diff --git a/docs/website/docs/tutorial/grouping-resources.md b/docs/website/docs/tutorial/grouping-resources.md index 3ba95b7971..2bbfd231f2 100644 --- a/docs/website/docs/tutorial/grouping-resources.md +++ b/docs/website/docs/tutorial/grouping-resources.md @@ -106,7 +106,7 @@ You've noticed that there's a lot of code duplication in the `get_issues` and `g ```py import dlt -from dlt.sources.helpers import requests +from dlt.sources.helpers.rest_client import paginate BASE_GITHUB_URL = "https://api.github.com/repos/dlt-hub/dlt" @@ -231,7 +231,7 @@ The next step is to make our dlt GitHub source reusable so it can load data from ```py import dlt -from dlt.sources.helpers import requests +from dlt.sources.helpers.rest_client import paginate BASE_GITHUB_URL = "https://api.github.com/repos/{repo_name}" diff --git a/docs/website/docs/walkthroughs/add_credentials.md b/docs/website/docs/walkthroughs/add_credentials.md index 586d1c2a93..5b4f241d56 100644 --- a/docs/website/docs/walkthroughs/add_credentials.md +++ b/docs/website/docs/walkthroughs/add_credentials.md @@ -74,3 +74,93 @@ DESTINATION__BIGQUERY__CREDENTIALS__PRIVATE_KEY DESTINATION__BIGQUERY__CREDENTIALS__CLIENT_EMAIL DESTINATION__BIGQUERY__LOCATION ``` + +## Retrieving credentials from Google Cloud Secret Manager +To retrieve secrets from Google Cloud Secret Manager using Python, and convert them into a dictionary format, you'll need to follow these steps. First, ensure that you have the necessary permissions to access the secrets on Google Cloud, and have the `google-cloud-secret-manager` library installed. If not, you can install it using pip: + +```sh +pip install google-cloud-secret-manager +``` + +[Google Cloud Documentation: Secret Manager client libraries.](https://cloud.google.com/secret-manager/docs/reference/libraries) + +Here's how you can retrieve secrets and convert them into a dictionary: + +1. **Set up the Secret Manager client**: Create a client that will interact with the Secret Manager API. +2. **Access the secret**: Use the client to access the secret's latest version. +3. **Convert to a dictionary**: If the secret is stored in a structured format (like JSON), parse it into a Python dictionary. + +Assume we store secrets in JSON format with name "temp-secret": +```json +{"api_token": "ghp_Kskdgf98dugjf98ghd...."} +``` + +Set `.dlt/secrets.toml` as: + +```toml +[google_secrets.credentials] +"project_id" = "" +"private_key" = "-----BEGIN PRIVATE KEY-----\n....\n-----END PRIVATE KEY-----\n" +"client_email" = "....gserviceaccount.com" +``` +or `GOOGLE_SECRETS__CREDENTIALS` to the path of your service account key file. + +Retrieve the secrets stored in the Secret Manager as follows: + +```py +import json as json_lib # Rename the json import to avoid name conflict + +import dlt +from dlt.sources.helpers import requests +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs import GcpServiceAccountCredentials +from google.cloud import secretmanager + +@with_config(sections=("google_secrets",)) +def get_secret_dict(secret_id: str, credentials: GcpServiceAccountCredentials = dlt.secrets.value) -> dict: + """ + Retrieve a secret from Google Cloud Secret Manager and convert it to a dictionary. + """ + # Create the Secret Manager client with provided credentials + client = secretmanager.SecretManagerServiceClient(credentials=credentials.to_native_credentials()) + + # Build the resource name of the secret version + name = f"projects/{credentials.project_id}/secrets/{secret_id}/versions/latest" + + # Access the secret version + response = client.access_secret_version(request={"name": name}) + + # Decode the payload to a string and convert it to a dictionary + secret_string = response.payload.data.decode("UTF-8") + secret_dict = json_lib.loads(secret_string) + + return secret_dict + +# Retrieve secret data as a dictionary for use in other functions. +secret_data = get_secret_dict("temp-secret") + +# Set up the request URL and headers +url = "https://api.github.com/orgs/dlt-hub/repos" +headers = { + "Authorization": f"token {secret_data['api_token']}", # Use the API token from the secret data + "Accept": "application/vnd.github+json", # Set the Accept header for GitHub API +} + +# Make a request to the GitHub API to get the list of repositories +response = requests.get(url, headers=headers) + +# Set up the DLT pipeline +pipeline = dlt.pipeline( + pipeline_name="quick_start", destination="duckdb", dataset_name="mydata" +) +# Run the pipeline with the data from the GitHub API response +load_info = pipeline.run(response.json()) +# Print the load information to check the results +print(load_info) +``` + +### Points to Note: + +- **Permissions**: Ensure the service account or user credentials you are using have the necessary permissions to access the Secret Manager and the specific secrets. +- **Secret format**: This example assumes that the secret is stored in a JSON string format. If your secret is in a different format, you will need to adjust the parsing method accordingly. +- **Google Cloud authentication**: Make sure your environment is authenticated with Google Cloud. This can typically be done by setting credentials in `.dlt/secrets.toml` or setting the `GOOGLE_SECRETS__CREDENTIALS` environment variable to the path of your service account key file or the dict of credentials as a string. \ No newline at end of file diff --git a/docs/website/docs/walkthroughs/create-new-destination.md b/docs/website/docs/walkthroughs/create-new-destination.md index 1b72b81e3e..69e7b2fcc1 100644 --- a/docs/website/docs/walkthroughs/create-new-destination.md +++ b/docs/website/docs/walkthroughs/create-new-destination.md @@ -88,6 +88,10 @@ The default `escape_identifier` function identifier escapes `"` and '\' and quot You should avoid providing a custom `escape_literal` function by not enabling `insert-values` for your destination. +### Enable / disable case sensitive identifiers +Specify if destination supports case sensitive identifiers by setting `has_case_sensitive_identifiers` to `True` (or `False` if otherwise). Some case sensitive destinations (ie. **Snowflake** or **Postgres**) support case insensitive identifiers via. case folding ie. **Snowflake** considers all upper case identifiers as case insensitive (set `casefold_identifier` to `str.upper`), **Postgres** does the same with lower case identifiers (`str.lower`). +Some case insensitive destinations (ie. **Athena** or **Redshift**) case-fold (ie. lower case) all identifiers and store them as such. In that case set `casefold_identifier` to `str.lower` as well. + ## 4. Adjust the SQL client **sql client** is a wrapper over `dbapi` and its main role is to provide consistent interface for executing SQL statements, managing transactions and (probably the most important) to help handling errors via classifying exceptions. Here's a few things you should pay attention to: diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index d3d7def8fc..465212cae6 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -46,11 +46,8 @@ const sidebars = { type: 'category', label: 'Integrations', link: { - type: 'generated-index', - title: 'Integrations', - description: 'dlt fits everywhere where the data flows. check out our curated data sources, destinations and unexpected places where dlt runs', - slug: 'dlt-ecosystem', - keywords: ['getting started'], + type: 'doc', + id: 'dlt-ecosystem/index', }, items: [ { @@ -82,6 +79,7 @@ const sidebars = { 'dlt-ecosystem/verified-sources/mux', 'dlt-ecosystem/verified-sources/notion', 'dlt-ecosystem/verified-sources/personio', + 'dlt-ecosystem/verified-sources/pg_replication', 'dlt-ecosystem/verified-sources/pipedrive', 'dlt-ecosystem/verified-sources/rest_api', 'dlt-ecosystem/verified-sources/openapi-generator', @@ -116,6 +114,7 @@ const sidebars = { 'dlt-ecosystem/destinations/snowflake', 'dlt-ecosystem/destinations/athena', 'dlt-ecosystem/destinations/weaviate', + 'dlt-ecosystem/destinations/lancedb', 'dlt-ecosystem/destinations/qdrant', 'dlt-ecosystem/destinations/dremio', 'dlt-ecosystem/destinations/destination', @@ -157,6 +156,7 @@ const sidebars = { 'general-usage/incremental-loading', 'general-usage/full-loading', 'general-usage/schema', + 'general-usage/naming-convention', 'general-usage/schema-contracts', 'general-usage/schema-evolution', { diff --git a/poetry.lock b/poetry.lock index f6a6f98c1a..a7d754f5a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "about-time" @@ -2416,6 +2416,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "diff-cover" version = "7.7.0" @@ -2450,6 +2464,17 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "dnspython" version = "2.4.2" @@ -2658,21 +2683,27 @@ test = ["pytest (>=6)"] [[package]] name = "fastembed" -version = "0.1.1" +version = "0.2.6" description = "Fast, light, accurate library built for retrieval embedding generation" optional = true -python-versions = ">=3.8.0,<3.12" +python-versions = "<3.13,>=3.8.0" files = [ - {file = "fastembed-0.1.1-py3-none-any.whl", hash = "sha256:131413ae52cd72f4c8cced7a675f8269dbfd1a852abade3c815e265114bcc05a"}, - {file = "fastembed-0.1.1.tar.gz", hash = "sha256:f7e524ee4f74bb8aad16be5b687d1f77f608d40e96e292c87881dc36baf8f4c7"}, + {file = "fastembed-0.2.6-py3-none-any.whl", hash = "sha256:3e18633291722087abebccccd7fcdffafef643cb22d203370d7fad4fa83c10fb"}, + {file = "fastembed-0.2.6.tar.gz", hash = "sha256:adaed5b46e19cc1bbe5f98f2b3ffecfc4d2a48d27512e28ff5bfe92a42649a66"}, ] [package.dependencies] -onnx = ">=1.11,<2.0" -onnxruntime = ">=1.15,<2.0" +huggingface-hub = ">=0.20,<0.21" +loguru = ">=0.7.2,<0.8.0" +numpy = [ + {version = ">=1.21", markers = "python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, +] +onnx = ">=1.15.0,<2.0.0" +onnxruntime = ">=1.17.0,<2.0.0" requests = ">=2.31,<3.0" -tokenizers = ">=0.13,<0.14" -tqdm = ">=4.65,<5.0" +tokenizers = ">=0.15.1,<0.16.0" +tqdm = ">=4.66,<5.0" [[package]] name = "filelock" @@ -2782,6 +2813,21 @@ files = [ [package.dependencies] flake8 = ">=3.8.4" +[[package]] +name = "flake8-print" +version = "5.0.0" +description = "print statement checker plugin for flake8" +optional = false +python-versions = ">=3.7" +files = [ + {file = "flake8-print-5.0.0.tar.gz", hash = "sha256:76915a2a389cc1c0879636c219eb909c38501d3a43cc8dae542081c9ba48bdf9"}, + {file = "flake8_print-5.0.0-py3-none-any.whl", hash = "sha256:84a1a6ea10d7056b804221ac5e62b1cee1aefc897ce16f2e5c42d3046068f5d8"}, +] + +[package.dependencies] +flake8 = ">=3.0" +pycodestyle = "*" + [[package]] name = "flake8-tidy-imports" version = "4.10.0" @@ -3521,6 +3567,164 @@ files = [ {file = "google_re2-1.1-1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c6c9f64b9724ec38da8e514f404ac64e9a6a5e8b1d7031c2dadd05c1f4c16fd"}, {file = "google_re2-1.1-1-cp39-cp39-win32.whl", hash = "sha256:d1b751b9ab9f8e2ab2a36d72b909281ce65f328c9115a1685acae1a2d1afd7a4"}, {file = "google_re2-1.1-1-cp39-cp39-win_amd64.whl", hash = "sha256:ac775c75cec7069351d201da4e0fb0cae4c1c5ebecd08fa34e1be89740c1d80b"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5eaefe4705b75ca5f78178a50104b689e9282f868e12f119b26b4cffc0c7ee6e"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:e35f2c8aabfaaa4ce6420b3cae86c0c29042b1b4f9937254347e9b985694a171"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:35fd189cbaaaa39c9a6a8a00164c8d9c709bacd0c231c694936879609beff516"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:60475d222cebd066c80414831c8a42aa2449aab252084102ee05440896586e6a"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:871cb85b9b0e1784c983b5c148156b3c5314cb29ca70432dff0d163c5c08d7e5"}, + {file = "google_re2-1.1-2-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:94f4e66e34bdb8de91ec6cdf20ba4fa9fea1dfdcfb77ff1f59700d01a0243664"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1563577e2b720d267c4cffacc0f6a2b5c8480ea966ebdb1844fbea6602c7496f"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:49b7964532a801b96062d78c0222d155873968f823a546a3dbe63d73f25bb56f"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2362fd70eb639a75fd0187d28b4ba7b20b3088833d8ad7ffd8693d0ba159e1c2"}, + {file = "google_re2-1.1-2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86b80719636a4e21391e20a9adf18173ee6ae2ec956726fe2ff587417b5e8ba6"}, + {file = "google_re2-1.1-2-cp310-cp310-win32.whl", hash = "sha256:5456fba09df951fe8d1714474ed1ecda102a68ddffab0113e6c117d2e64e6f2b"}, + {file = "google_re2-1.1-2-cp310-cp310-win_amd64.whl", hash = "sha256:2ac6936a3a60d8d9de9563e90227b3aea27068f597274ca192c999a12d8baa8f"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d5a87b436028ec9b0f02fe19d4cbc19ef30441085cdfcdf1cce8fbe5c4bd5e9a"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:fc0d4163de9ed2155a77e7a2d59d94c348a6bbab3cff88922fab9e0d3d24faec"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:48b12d953bc796736e7831d67b36892fb6419a4cc44cb16521fe291e594bfe23"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:62c780c927cff98c1538439f0ff616f48a9b2e8837c676f53170d8ae5b9e83cb"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:04b2aefd768aa4edeef8b273327806c9cb0b82e90ff52eacf5d11003ac7a0db2"}, + {file = "google_re2-1.1-2-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:9c90175992346519ee7546d9af9a64541c05b6b70346b0ddc54a48aa0d3b6554"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22ad9ad9d125249d6386a2e80efb9de7af8260b703b6be7fa0ab069c1cf56ced"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f70971f6ffe5254e476e71d449089917f50ebf9cf60f9cec80975ab1693777e2"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f267499529e64a4abed24c588f355ebe4700189d434d84a7367725f5a186e48d"}, + {file = "google_re2-1.1-2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b632eff5e4cd44545a9c0e52f2e1becd55831e25f4dd4e0d7ec8ee6ca50858c1"}, + {file = "google_re2-1.1-2-cp311-cp311-win32.whl", hash = "sha256:a42c733036e8f242ee4e5f0e27153ad4ca44ced9e4ce82f3972938ddee528db0"}, + {file = "google_re2-1.1-2-cp311-cp311-win_amd64.whl", hash = "sha256:64f8eed4ca96905d99b5286b3d14b5ca4f6a025ff3c1351626a7df2f93ad1ddd"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5541efcca5b5faf7e0d882334a04fa479bad4e7433f94870f46272eec0672c4a"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:92309af35b6eb2d3b3dc57045cdd83a76370958ab3e0edd2cc4638f6d23f5b32"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:197cd9bcaba96d18c5bf84d0c32fca7a26c234ea83b1d3083366f4392cb99f78"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:1b896f171d29b541256cf26e10dccc9103ac1894683914ed88828ca6facf8dca"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:e022d3239b945014e916ca7120fee659b246ec26c301f9e0542f1a19b38a8744"}, + {file = "google_re2-1.1-2-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:2c73f8a9440873b68bee1198094377501065e85aaf6fcc0d2512c7589ffa06ca"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:901d86555bd7725506d651afaba7d71cd4abd13260aed6cfd7c641a45f76d4f6"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ce4710ff636701cfb56eb91c19b775d53b03749a23b7d2a5071bbbf4342a9067"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76a20e5ebdf5bc5d430530197e42a2eeb562f729d3a3fb51f39168283d676e66"}, + {file = "google_re2-1.1-2-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:77c9f4d4bb1c8de9d2642d3c4b8b615858ba764df025b3b4f1310266f8def269"}, + {file = "google_re2-1.1-2-cp38-cp38-win32.whl", hash = "sha256:94bd60785bf37ef130a1613738e3c39465a67eae3f3be44bb918540d39b68da3"}, + {file = "google_re2-1.1-2-cp38-cp38-win_amd64.whl", hash = "sha256:59efeb77c0dcdbe37794c61f29c5b1f34bc06e8ec309a111ccdd29d380644d70"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:221e38c27e1dd9ccb8e911e9c7aed6439f68ce81e7bb74001076830b0d6e931d"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d9145879e6c2e1b814445300b31f88a675e1f06c57564670d95a1442e8370c27"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:c8a12f0740e2a52826bdbf95569a4b0abdf413b4012fa71e94ad25dd4715c6e5"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:9c9998f71466f4db7bda752aa7c348b2881ff688e361108fe500caad1d8b9cb2"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:0c39f69b702005963a3d3bf78743e1733ad73efd7e6e8465d76e3009e4694ceb"}, + {file = "google_re2-1.1-2-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:6d0ce762dee8d6617d0b1788a9653e805e83a23046c441d0ea65f1e27bf84114"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ecf3619d98c9b4a7844ab52552ad32597cdbc9a5bdbc7e3435391c653600d1e2"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a1426a8cbd1fa004974574708d496005bd379310c4b1c7012be4bc75efde7a8"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1a30626ba48b4070f3eab272d860ef1952e710b088792c4d68dddb155be6bfc"}, + {file = "google_re2-1.1-2-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1b9c1ffcfbc3095b6ff601ec2d2bf662988f6ea6763bc1c9d52bec55881f8fde"}, + {file = "google_re2-1.1-2-cp39-cp39-win32.whl", hash = "sha256:32ecf995a252c0548404c1065ba4b36f1e524f1f4a86b6367a1a6c3da3801e30"}, + {file = "google_re2-1.1-2-cp39-cp39-win_amd64.whl", hash = "sha256:e7865410f3b112a3609739283ec3f4f6f25aae827ff59c6bfdf806fd394d753e"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3b21f83f0a201009c56f06fcc7294a33555ede97130e8a91b3f4cae01aed1d73"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b38194b91354a38db1f86f25d09cdc6ac85d63aee4c67b43da3048ce637adf45"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e7da3da8d6b5a18d6c3b61b11cc5b66b8564eaedce99d2312b15b6487730fc76"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:aeca656fb10d8638f245331aabab59c9e7e051ca974b366dd79e6a9efb12e401"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:2069d6dc94f5fa14a159bf99cad2f11e9c0f8ec3b7f44a4dde9e59afe5d1c786"}, + {file = "google_re2-1.1-3-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:2319a39305a4931cb5251451f2582713418a19bef2af7adf9e2a7a0edd939b99"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eb98fc131699756c6d86246f670a5e1c1cc1ba85413c425ad344cb30479b246c"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a6e038986d8ffe4e269f8532f03009f229d1f6018d4ac0dabc8aff876338f6e0"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8618343ee658310e0f53bf586fab7409de43ce82bf8d9f7eb119536adc9783fd"}, + {file = "google_re2-1.1-3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8140ca861cfe00602319cefe2c7b8737b379eb07fb328b51dc44584f47a2718"}, + {file = "google_re2-1.1-3-cp310-cp310-win32.whl", hash = "sha256:41f439c5c54e8a3a0a1fa2dbd1e809d3f643f862df7b16dd790f36a1238a272e"}, + {file = "google_re2-1.1-3-cp310-cp310-win_amd64.whl", hash = "sha256:fe20e97a33176d96d3e4b5b401de35182b9505823abea51425ec011f53ef5e56"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c39ff52b1765db039f690ee5b7b23919d8535aae94db7996079fbde0098c4d7"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:5420be674fd164041639ba4c825450f3d4bd635572acdde16b3dcd697f8aa3ef"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:ff53881cf1ce040f102a42d39db93c3f835f522337ae9c79839a842f26d97733"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:8d04600b0b53523118df2e413a71417c408f20dee640bf07dfab601c96a18a77"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:c4835d4849faa34a7fa1074098d81c420ed6c0707a3772482b02ce14f2a7c007"}, + {file = "google_re2-1.1-3-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:3309a9b81251d35fee15974d0ae0581a9a375266deeafdc3a3ac0d172a742357"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2b51cafee7e0bc72d0a4a454547bd8f257cde412ac9f1a2dc46a203b5e42cf4"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:83f5f1cb52f832c2297d271ee8c56cf5e9053448162e5d2223d513f729bad908"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:55865a1ace92be3f7953b2e2b38b901d8074a367aa491daee43260a53a7fc6f0"}, + {file = "google_re2-1.1-3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cec2167dd142e583e98c783bd0d28b8cf5a9cdbe1f7407ba4163fe3ccb613cb9"}, + {file = "google_re2-1.1-3-cp311-cp311-win32.whl", hash = "sha256:a0bc1fe96849e4eb8b726d0bba493f5b989372243b32fe20729cace02e5a214d"}, + {file = "google_re2-1.1-3-cp311-cp311-win_amd64.whl", hash = "sha256:e6310a156db96fc5957cb007dd2feb18476898654530683897469447df73a7cd"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8e63cd10ea006088b320e8c5d308da1f6c87aa95138a71c60dd7ca1c8e91927e"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_11_0_x86_64.whl", hash = "sha256:12b566830a334178733a85e416b1e0507dbc0ceb322827616fe51ef56c5154f1"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:442e18c9d46b225c1496919c16eafe8f8d9bb4091b00b4d3440da03c55bbf4ed"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:c54c00263a9c39b2dacd93e9636319af51e3cf885c080b9680a9631708326460"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:15a3caeeb327bc22e0c9f95eb76890fec8874cacccd2b01ff5c080ab4819bbec"}, + {file = "google_re2-1.1-3-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:59ec0d2cced77f715d41f6eafd901f6b15c11e28ba25fe0effdc1de554d78e75"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:185bf0e3441aed3840590f8e42f916e2920d235eb14df2cbc2049526803d3e71"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:586d3f2014eea5be14d8de53374d9b79fa99689160e00efa64b5fe93af326087"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc2575082de4ffd234d9607f3ae67ca22b15a1a88793240e2045f3b3a36a5795"}, + {file = "google_re2-1.1-3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:59c5ad438eddb3630def394456091284d7bbc5b89351987f94f3792d296d1f96"}, + {file = "google_re2-1.1-3-cp312-cp312-win32.whl", hash = "sha256:5b9878c53f2bf16f75bf71d4ddd57f6611351408d5821040e91c53ebdf82c373"}, + {file = "google_re2-1.1-3-cp312-cp312-win_amd64.whl", hash = "sha256:4fdecfeb213110d0a85bad335a8e7cdb59fea7de81a4fe659233f487171980f9"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2dd87bacab32b709c28d0145fe75a956b6a39e28f0726d867375dba5721c76c1"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:55d24c61fe35dddc1bb484593a57c9f60f9e66d7f31f091ef9608ed0b6dde79f"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a0cf1180d908622df648c26b0cd09281f92129805ccc56a39227fdbfeab95cb4"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:09586f07f3f88d432265c75976da1c619ab7192cd7ebdf53f4ae0776c19e4b56"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:539f1b053402203576e919a06749198da4ae415931ee28948a1898131ae932ce"}, + {file = "google_re2-1.1-3-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:abf0bcb5365b0e27a5a23f3da403dffdbbac2c0e3a3f1535a8b10cc121b5d5fb"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:19c83e5bbed7958213eeac3aa71c506525ce54faf03e07d0b96cd0a764890511"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3348e77330ff672dc44ec01894fa5d93c409a532b6d688feac55e714e9059920"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:06b63edb57c5ce5a13eabfd71155e346b9477dc8906dec7c580d4f70c16a7e0d"}, + {file = "google_re2-1.1-3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12fe57ba2914092b83338d61d8def9ebd5a2bd0fd8679eceb5d4c2748105d5c0"}, + {file = "google_re2-1.1-3-cp38-cp38-win32.whl", hash = "sha256:80796e08d24e606e675019fe8de4eb5c94bb765be13c384f2695247d54a6df75"}, + {file = "google_re2-1.1-3-cp38-cp38-win_amd64.whl", hash = "sha256:3c2257dedfe7cc5deb6791e563af9e071a9d414dad89e37ac7ad22f91be171a9"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43a0cd77c87c894f28969ac622f94b2e6d1571261dfdd785026848a25cfdc9b9"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:1038990b77fd66f279bd66a0832b67435ea925e15bb59eafc7b60fdec812b616"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fb5dda6875d18dd45f0f24ebced6d1f7388867c8fb04a235d1deab7ea479ce38"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:bb1d164965c6d57a351b421d2f77c051403766a8b75aaa602324ee2451fff77f"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a072ebfa495051d07ffecbf6ce21eb84793568d5c3c678c00ed8ff6b8066ab31"}, + {file = "google_re2-1.1-3-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:4eb66c8398c8a510adc97978d944b3b29c91181237218841ea1a91dc39ec0e54"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f7c8b57b1f559553248d1757b7fa5b2e0cc845666738d155dff1987c2618264e"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9162f6aa4f25453c682eb176f21b8e2f40205be9f667e98a54b3e1ff10d6ee75"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2d65ddf67fd7bf94705626871d463057d3d9a3538d41022f95b9d8f01df36e1"}, + {file = "google_re2-1.1-3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d140c7b9395b4d1e654127aa1c99bcc603ed01000b7bc7e28c52562f1894ec12"}, + {file = "google_re2-1.1-3-cp39-cp39-win32.whl", hash = "sha256:80c5fc200f64b2d903eeb07b8d6cefc620a872a0240c7caaa9aca05b20f5568f"}, + {file = "google_re2-1.1-3-cp39-cp39-win_amd64.whl", hash = "sha256:9eb6dbcee9b5dc4069bbc0634f2eb039ca524a14bed5868fdf6560aaafcbca06"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0db114d7e1aa96dbcea452a40136d7d747d60cbb61394965774688ef59cccd4e"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:82133958e003a1344e5b7a791b9a9dd7560b5c8f96936dbe16f294604524a633"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:9e74fd441d1f3d917d3303e319f61b82cdbd96b9a5ba919377a6eef1504a1e2b"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:734a2e7a4541c57253b5ebee24f3f3366ba3658bcad01da25fb623c78723471a"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:d88d5eecbc908abe16132456fae13690d0508f3ac5777f320ef95cb6cab9a961"}, + {file = "google_re2-1.1-4-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:b91db80b171ecec435a07977a227757dd487356701a32f556fa6fca5d0a40522"}, + {file = "google_re2-1.1-4-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b23129887a64bb9948af14c84705273ed1a40054e99433b4acccab4dcf6a226"}, + {file = "google_re2-1.1-4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5dc1a0cc7cd19261dcaf76763e2499305dbb7e51dc69555167cdb8af98782698"}, + {file = "google_re2-1.1-4-cp310-cp310-win32.whl", hash = "sha256:3b2ab1e2420b5dd9743a2d6bc61b64e5f708563702a75b6db86637837eaeaf2f"}, + {file = "google_re2-1.1-4-cp310-cp310-win_amd64.whl", hash = "sha256:92efca1a7ef83b6df012d432a1cbc71d10ff42200640c0f9a5ff5b343a48e633"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:854818fd4ce79787aca5ba459d6e5abe4ca9be2c684a5b06a7f1757452ca3708"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:4ceef51174b6f653b6659a8fdaa9c38960c5228b44b25be2a3bcd8566827554f"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:ee49087c3db7e6f5238105ab5299c09e9b77516fe8cfb0a37e5f1e813d76ecb8"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:dc2312854bdc01410acc5d935f1906a49cb1f28980341c20a68797ad89d8e178"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0dc0d2e42296fa84a3cb3e1bd667c6969389cd5cdf0786e6b1f911ae2d75375b"}, + {file = "google_re2-1.1-4-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6bf04ced98453b035f84320f348f67578024f44d2997498def149054eb860ae8"}, + {file = "google_re2-1.1-4-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d6b6ef11dc4ab322fa66c2f3561925f2b5372a879c3ed764d20e939e2fd3e5f"}, + {file = "google_re2-1.1-4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0dcde6646fa9a97fd3692b3f6ae7daf7f3277d7500b6c253badeefa11db8956a"}, + {file = "google_re2-1.1-4-cp311-cp311-win32.whl", hash = "sha256:5f4f0229deb057348893574d5b0a96d055abebac6debf29d95b0c0e26524c9f6"}, + {file = "google_re2-1.1-4-cp311-cp311-win_amd64.whl", hash = "sha256:4713ddbe48a18875270b36a462b0eada5e84d6826f8df7edd328d8706b6f9d07"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:40a698300b8faddbb325662973f839489c89b960087060bd389c376828978a04"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:103d2d7ac92ba23911a151fd1fc7035cbf6dc92a7f6aea92270ebceb5cd5acd3"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:51fb7182bccab05e8258a2b6a63dda1a6b4a9e8dfb9b03ec50e50c49c2827dd4"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:65383022abd63d7b620221eba7935132b53244b8b463d8fdce498c93cf58b7b7"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:396281fc68a9337157b3ffcd9392c6b7fcb8aab43e5bdab496262a81d56a4ecc"}, + {file = "google_re2-1.1-4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:8198adcfcff1c680e052044124621730fc48d08005f90a75487f5651f1ebfce2"}, + {file = "google_re2-1.1-4-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:81f7bff07c448aec4db9ca453d2126ece8710dbd9278b8bb09642045d3402a96"}, + {file = "google_re2-1.1-4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7dacf730fd7d6ec71b11d6404b0b26e230814bfc8e9bb0d3f13bec9b5531f8d"}, + {file = "google_re2-1.1-4-cp312-cp312-win32.whl", hash = "sha256:8c764f62f4b1d89d1ef264853b6dd9fee14a89e9b86a81bc2157fe3531425eb4"}, + {file = "google_re2-1.1-4-cp312-cp312-win_amd64.whl", hash = "sha256:0be2666df4bc5381a5d693585f9bbfefb0bfd3c07530d7e403f181f5de47254a"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:5cb1b63a0bfd8dd65d39d2f3b2e5ae0a06ce4b2ce5818a1d1fc78a786a252673"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:e41751ce6b67a95230edd0772226dc94c2952a2909674cd69df9804ed0125307"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:b998cfa2d50bf4c063e777c999a7e8645ec7e5d7baf43ad71b1e2e10bb0300c3"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:226ca3b0c2e970f3fc82001ac89e845ecc7a4bb7c68583e7a76cda70b61251a7"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:9adec1f734ebad7c72e56c85f205a281d8fe9bf6583bc21020157d3f2812ce89"}, + {file = "google_re2-1.1-4-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:9c34f3c64ba566af967d29e11299560e6fdfacd8ca695120a7062b6ed993b179"}, + {file = "google_re2-1.1-4-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e1b85385fe293838e0d0b6e19e6c48ba8c6f739ea92ce2e23b718afe7b343363"}, + {file = "google_re2-1.1-4-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4694daa8a8987cfb568847aa872f9990e930c91a68c892ead876411d4b9012c3"}, + {file = "google_re2-1.1-4-cp38-cp38-win32.whl", hash = "sha256:5e671e9be1668187e2995aac378de574fa40df70bb6f04657af4d30a79274ce0"}, + {file = "google_re2-1.1-4-cp38-cp38-win_amd64.whl", hash = "sha256:f66c164d6049a8299f6dfcfa52d1580576b4b9724d6fcdad2f36f8f5da9304b6"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:25cb17ae0993a48c70596f3a3ef5d659638106401cc8193f51c0d7961b3b3eb7"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:5f101f86d14ca94ca4dcf63cceaa73d351f2be2481fcaa29d9e68eeab0dc2a88"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:4e82591e85bf262a6d74cff152867e05fc97867c68ba81d6836ff8b0e7e62365"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:1f61c09b93ffd34b1e2557e5a9565039f935407a5786dbad46f64f1a484166e6"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:12b390ad8c7e74bab068732f774e75e0680dade6469b249a721f3432f90edfc3"}, + {file = "google_re2-1.1-4-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:1284343eb31c2e82ed2d8159f33ba6842238a56782c881b07845a6d85613b055"}, + {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c7b38e0daf2c06e4d3163f4c732ab3ad2521aecfed6605b69e4482c612da303"}, + {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, + {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = "sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, + {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, ] [[package]] @@ -3924,6 +4128,38 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "huggingface-hub" +version = "0.20.3" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "huggingface_hub-0.20.3-py3-none-any.whl", hash = "sha256:d988ae4f00d3e307b0c80c6a05ca6dbb7edba8bba3079f74cda7d9c2e562a7b6"}, + {file = "huggingface_hub-0.20.3.tar.gz", hash = "sha256:94e7f8e074475fbc67d6a71957b678e1b4a74ff1b64a644fd6cbb83da962d05d"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + [[package]] name = "humanfriendly" version = "10.0" @@ -4229,6 +4465,78 @@ completion = ["shtab"] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +[[package]] +name = "lancedb" +version = "0.6.13" +description = "lancedb" +optional = false +python-versions = ">=3.8" +files = [ + {file = "lancedb-0.6.13-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:4667353ca7fa187e94cb0ca4c5f9577d65eb5160f6f3fe9e57902d86312c3869"}, + {file = "lancedb-0.6.13-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e22533fe6f6b2d7037dcdbbb4019a62402bbad4ce18395be68f4aa007bf8bc0"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837eaceafb87e3ae4c261eef45c4f73715f892a36165572c3da621dbdb45afcf"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:61af2d72b2a2f0ea419874c3f32760fe5e51530da3be2d65251a0e6ded74419b"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:31b24e57ee313f4ce6255e45d42e8bee19b90ddcd13a9e07030ac04f76e7dfde"}, + {file = "lancedb-0.6.13-cp38-abi3-win_amd64.whl", hash = "sha256:b851182d8492b1e5b57a441af64c95da65ca30b045d6618dc7d203c6d60d70fa"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +cachetools = "*" +deprecation = "*" +overrides = ">=0.7" +pydantic = ">=1.10" +pylance = "0.10.12" +ratelimiter = ">=1.0,<2.0" +requests = ">=2.31.0" +retry = ">=0.9.2" +semver = "*" +tqdm = ">=4.27.0" + +[package.extras] +azure = ["adlfs (>=2024.2.0)"] +clip = ["open-clip", "pillow", "torch"] +dev = ["pre-commit", "ruff"] +docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] + +[[package]] +name = "lancedb" +version = "0.9.0" +description = "lancedb" +optional = false +python-versions = ">=3.8" +files = [ + {file = "lancedb-0.9.0-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:b1ca08797c72c93ae512aa1078f1891756da157d910fbae8e194fac3528fc1ac"}, + {file = "lancedb-0.9.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:15129791f03c2c04b95f914ced2c1556b43d73a24710207b9af77b6e4008bdeb"}, + {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f093d89447a2039b820d2540a0b64df3024e4549b6808ebd26b44fbe0345cc6"}, + {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:a8c1f6777e217d2277451038866d280fa5fb38bd161795e51703b043c26dd345"}, + {file = "lancedb-0.9.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:78dd5800a1148f89d33b7e98d1c8b1c42dee146f03580abc1ca83cb05273ff7f"}, + {file = "lancedb-0.9.0-cp38-abi3-win_amd64.whl", hash = "sha256:ba5bdc727d3bc131f17414f42372acde5817073feeb553793a3d20003caa1658"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +cachetools = "*" +deprecation = "*" +overrides = ">=0.7" +packaging = "*" +pydantic = ">=1.10" +pylance = "0.13.0" +ratelimiter = ">=1.0,<2.0" +requests = ">=2.31.0" +retry = ">=0.9.2" +tqdm = ">=4.27.0" + +[package.extras] +azure = ["adlfs (>=2024.2.0)"] +clip = ["open-clip", "pillow", "torch"] +dev = ["pre-commit", "ruff"] +docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] + [[package]] name = "lazy-object-proxy" version = "1.9.0" @@ -4377,6 +4685,24 @@ sqlalchemy = ["sqlalchemy"] test = ["mock", "pytest", "pytest-cov (<2.6)"] zmq = ["pyzmq"] +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = true +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + [[package]] name = "lxml" version = "4.9.3" @@ -4387,10 +4713,13 @@ files = [ {file = "lxml-4.9.3-cp27-cp27m-macosx_11_0_x86_64.whl", hash = "sha256:b0a545b46b526d418eb91754565ba5b63b1c0b12f9bd2f808c852d9b4b2f9b5c"}, {file = "lxml-4.9.3-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:075b731ddd9e7f68ad24c635374211376aa05a281673ede86cbe1d1b3455279d"}, {file = "lxml-4.9.3-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1e224d5755dba2f4a9498e150c43792392ac9b5380aa1b845f98a1618c94eeef"}, + {file = "lxml-4.9.3-cp27-cp27m-win32.whl", hash = "sha256:2c74524e179f2ad6d2a4f7caf70e2d96639c0954c943ad601a9e146c76408ed7"}, + {file = "lxml-4.9.3-cp27-cp27m-win_amd64.whl", hash = "sha256:4f1026bc732b6a7f96369f7bfe1a4f2290fb34dce00d8644bc3036fb351a4ca1"}, {file = "lxml-4.9.3-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0781a98ff5e6586926293e59480b64ddd46282953203c76ae15dbbbf302e8bb"}, {file = "lxml-4.9.3-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cef2502e7e8a96fe5ad686d60b49e1ab03e438bd9123987994528febd569868e"}, {file = "lxml-4.9.3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b86164d2cff4d3aaa1f04a14685cbc072efd0b4f99ca5708b2ad1b9b5988a991"}, {file = "lxml-4.9.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:42871176e7896d5d45138f6d28751053c711ed4d48d8e30b498da155af39aebd"}, + {file = "lxml-4.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:ae8b9c6deb1e634ba4f1930eb67ef6e6bf6a44b6eb5ad605642b2d6d5ed9ce3c"}, {file = "lxml-4.9.3-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:411007c0d88188d9f621b11d252cce90c4a2d1a49db6c068e3c16422f306eab8"}, {file = "lxml-4.9.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:cd47b4a0d41d2afa3e58e5bf1f62069255aa2fd6ff5ee41604418ca925911d76"}, {file = "lxml-4.9.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e2cb47860da1f7e9a5256254b74ae331687b9672dfa780eed355c4c9c3dbd23"}, @@ -4399,6 +4728,7 @@ files = [ {file = "lxml-4.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:97047f0d25cd4bcae81f9ec9dc290ca3e15927c192df17331b53bebe0e3ff96d"}, {file = "lxml-4.9.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:1f447ea5429b54f9582d4b955f5f1985f278ce5cf169f72eea8afd9502973dd5"}, {file = "lxml-4.9.3-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:57d6ba0ca2b0c462f339640d22882acc711de224d769edf29962b09f77129cbf"}, + {file = "lxml-4.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:9767e79108424fb6c3edf8f81e6730666a50feb01a328f4a016464a5893f835a"}, {file = "lxml-4.9.3-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:71c52db65e4b56b8ddc5bb89fb2e66c558ed9d1a74a45ceb7dcb20c191c3df2f"}, {file = "lxml-4.9.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d73d8ecf8ecf10a3bd007f2192725a34bd62898e8da27eb9d32a58084f93962b"}, {file = "lxml-4.9.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0a3d3487f07c1d7f150894c238299934a2a074ef590b583103a45002035be120"}, @@ -4418,6 +4748,7 @@ files = [ {file = "lxml-4.9.3-cp36-cp36m-macosx_11_0_x86_64.whl", hash = "sha256:64f479d719dc9f4c813ad9bb6b28f8390360660b73b2e4beb4cb0ae7104f1c12"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:dd708cf4ee4408cf46a48b108fb9427bfa00b9b85812a9262b5c668af2533ea5"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c31c7462abdf8f2ac0577d9f05279727e698f97ecbb02f17939ea99ae8daa98"}, + {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:e3cd95e10c2610c360154afdc2f1480aea394f4a4f1ea0a5eacce49640c9b190"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:4930be26af26ac545c3dffb662521d4e6268352866956672231887d18f0eaab2"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4aec80cde9197340bc353d2768e2a75f5f60bacda2bab72ab1dc499589b3878c"}, {file = "lxml-4.9.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:14e019fd83b831b2e61baed40cab76222139926b1fb5ed0e79225bc0cae14584"}, @@ -4427,6 +4758,7 @@ files = [ {file = "lxml-4.9.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bef4e656f7d98aaa3486d2627e7d2df1157d7e88e7efd43a65aa5dd4714916cf"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:46f409a2d60f634fe550f7133ed30ad5321ae2e6630f13657fb9479506b00601"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:4c28a9144688aef80d6ea666c809b4b0e50010a2aca784c97f5e6bf143d9f129"}, + {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:141f1d1a9b663c679dc524af3ea1773e618907e96075262726c7612c02b149a4"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:53ace1c1fd5a74ef662f844a0413446c0629d151055340e9893da958a374f70d"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:17a753023436a18e27dd7769e798ce302963c236bc4114ceee5b25c18c52c693"}, {file = "lxml-4.9.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7d298a1bd60c067ea75d9f684f5f3992c9d6766fadbc0bcedd39750bf344c2f4"}, @@ -4436,6 +4768,7 @@ files = [ {file = "lxml-4.9.3-cp37-cp37m-win_amd64.whl", hash = "sha256:120fa9349a24c7043854c53cae8cec227e1f79195a7493e09e0c12e29f918e52"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:4d2d1edbca80b510443f51afd8496be95529db04a509bc8faee49c7b0fb6d2cc"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:8d7e43bd40f65f7d97ad8ef5c9b1778943d02f04febef12def25f7583d19baac"}, + {file = "lxml-4.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:71d66ee82e7417828af6ecd7db817913cb0cf9d4e61aa0ac1fde0583d84358db"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:6fc3c450eaa0b56f815c7b62f2b7fba7266c4779adcf1cece9e6deb1de7305ce"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:65299ea57d82fb91c7f019300d24050c4ddeb7c5a190e076b5f48a2b43d19c42"}, {file = "lxml-4.9.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:eadfbbbfb41b44034a4c757fd5d70baccd43296fb894dba0295606a7cf3124aa"}, @@ -4445,6 +4778,7 @@ files = [ {file = "lxml-4.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:92af161ecbdb2883c4593d5ed4815ea71b31fafd7fd05789b23100d081ecac96"}, {file = "lxml-4.9.3-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:9bb6ad405121241e99a86efff22d3ef469024ce22875a7ae045896ad23ba2340"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:8ed74706b26ad100433da4b9d807eae371efaa266ffc3e9191ea436087a9d6a7"}, + {file = "lxml-4.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fbf521479bcac1e25a663df882c46a641a9bff6b56dc8b0fafaebd2f66fb231b"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:303bf1edce6ced16bf67a18a1cf8339d0db79577eec5d9a6d4a80f0fb10aa2da"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:5515edd2a6d1a5a70bfcdee23b42ec33425e405c5b351478ab7dc9347228f96e"}, {file = "lxml-4.9.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:690dafd0b187ed38583a648076865d8c229661ed20e48f2335d68e2cf7dc829d"}, @@ -4455,13 +4789,16 @@ files = [ {file = "lxml-4.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:4dd9a263e845a72eacb60d12401e37c616438ea2e5442885f65082c276dfb2b2"}, {file = "lxml-4.9.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6689a3d7fd13dc687e9102a27e98ef33730ac4fe37795d5036d18b4d527abd35"}, {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:f6bdac493b949141b733c5345b6ba8f87a226029cbabc7e9e121a413e49441e0"}, + {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:05186a0f1346ae12553d66df1cfce6f251589fea3ad3da4f3ef4e34b2d58c6a3"}, {file = "lxml-4.9.3-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c2006f5c8d28dee289f7020f721354362fa304acbaaf9745751ac4006650254b"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-macosx_11_0_x86_64.whl", hash = "sha256:5c245b783db29c4e4fbbbfc9c5a78be496c9fea25517f90606aa1f6b2b3d5f7b"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:4fb960a632a49f2f089d522f70496640fdf1218f1243889da3822e0a9f5f3ba7"}, + {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:50670615eaf97227d5dc60de2dc99fb134a7130d310d783314e7724bf163f75d"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:9719fe17307a9e814580af1f5c6e05ca593b12fb7e44fe62450a5384dbf61b4b"}, {file = "lxml-4.9.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:3331bece23c9ee066e0fb3f96c61322b9e0f54d775fccefff4c38ca488de283a"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-macosx_11_0_x86_64.whl", hash = "sha256:ed667f49b11360951e201453fc3967344d0d0263aa415e1619e85ae7fd17b4e0"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:8b77946fd508cbf0fccd8e400a7f71d4ac0e1595812e66025bac475a8e811694"}, + {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:e4da8ca0c0c0aea88fd46be8e44bd49716772358d648cce45fe387f7b92374a7"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:fe4bda6bd4340caa6e5cf95e73f8fea5c4bfc55763dd42f1b50a94c1b4a2fbd4"}, {file = "lxml-4.9.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f3df3db1d336b9356dd3112eae5f5c2b8b377f3bc826848567f10bfddfee77e9"}, {file = "lxml-4.9.3.tar.gz", hash = "sha256:48628bd53a426c9eb9bc066a923acaa0878d1e86129fd5359aee99285f4eed9c"}, @@ -4622,6 +4959,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -5407,35 +5754,36 @@ reference = ["Pillow", "google-re2"] [[package]] name = "onnxruntime" -version = "1.16.1" +version = "1.18.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = true python-versions = "*" files = [ - {file = "onnxruntime-1.16.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:28b2c7f444b4119950b69370801cd66067f403d19cbaf2a444735d7c269cce4a"}, - {file = "onnxruntime-1.16.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c24e04f33e7899f6aebb03ed51e51d346c1f906b05c5569d58ac9a12d38a2f58"}, - {file = "onnxruntime-1.16.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fa93b166f2d97063dc9f33c5118c5729a4a5dd5617296b6dbef42f9047b3e81"}, - {file = "onnxruntime-1.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:042dd9201b3016ee18f8f8bc4609baf11ff34ca1ff489c0a46bcd30919bf883d"}, - {file = "onnxruntime-1.16.1-cp310-cp310-win32.whl", hash = "sha256:c20aa0591f305012f1b21aad607ed96917c86ae7aede4a4dd95824b3d124ceb7"}, - {file = "onnxruntime-1.16.1-cp310-cp310-win_amd64.whl", hash = "sha256:5581873e578917bea76d6434ee7337e28195d03488dcf72d161d08e9398c6249"}, - {file = "onnxruntime-1.16.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:ef8c0c8abf5f309aa1caf35941380839dc5f7a2fa53da533be4a3f254993f120"}, - {file = "onnxruntime-1.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e680380bea35a137cbc3efd67a17486e96972901192ad3026ee79c8d8fe264f7"}, - {file = "onnxruntime-1.16.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e62cc38ce1a669013d0a596d984762dc9c67c56f60ecfeee0d5ad36da5863f6"}, - {file = "onnxruntime-1.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:025c7a4d57bd2e63b8a0f84ad3df53e419e3df1cc72d63184f2aae807b17c13c"}, - {file = "onnxruntime-1.16.1-cp311-cp311-win32.whl", hash = "sha256:9ad074057fa8d028df248b5668514088cb0937b6ac5954073b7fb9b2891ffc8c"}, - {file = "onnxruntime-1.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:d5e43a3478bffc01f817ecf826de7b25a2ca1bca8547d70888594ab80a77ad24"}, - {file = "onnxruntime-1.16.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:3aef4d70b0930e29a8943eab248cd1565664458d3a62b2276bd11181f28fd0a3"}, - {file = "onnxruntime-1.16.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:55a7b843a57c8ca0c8ff169428137958146081d5d76f1a6dd444c4ffcd37c3c2"}, - {file = "onnxruntime-1.16.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62c631af1941bf3b5f7d063d24c04aacce8cff0794e157c497e315e89ac5ad7b"}, - {file = "onnxruntime-1.16.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5671f296c3d5c233f601e97a10ab5a1dd8e65ba35c7b7b0c253332aba9dff330"}, - {file = "onnxruntime-1.16.1-cp38-cp38-win32.whl", hash = "sha256:eb3802305023dd05e16848d4e22b41f8147247894309c0c27122aaa08793b3d2"}, - {file = "onnxruntime-1.16.1-cp38-cp38-win_amd64.whl", hash = "sha256:fecfb07443d09d271b1487f401fbdf1ba0c829af6fd4fe8f6af25f71190e7eb9"}, - {file = "onnxruntime-1.16.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:de3e12094234db6545c67adbf801874b4eb91e9f299bda34c62967ef0050960f"}, - {file = "onnxruntime-1.16.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ff723c2a5621b5e7103f3be84d5aae1e03a20621e72219dddceae81f65f240af"}, - {file = "onnxruntime-1.16.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14a7fb3073aaf6b462e3d7fb433320f7700558a8892e5021780522dc4574292a"}, - {file = "onnxruntime-1.16.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:963159f1f699b0454cd72fcef3276c8a1aab9389a7b301bcd8e320fb9d9e8597"}, - {file = "onnxruntime-1.16.1-cp39-cp39-win32.whl", hash = "sha256:85771adb75190db9364b25ddec353ebf07635b83eb94b64ed014f1f6d57a3857"}, - {file = "onnxruntime-1.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:d32d2b30799c1f950123c60ae8390818381fd5f88bdf3627eeca10071c155dc5"}, + {file = "onnxruntime-1.18.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:5a3b7993a5ecf4a90f35542a4757e29b2d653da3efe06cdd3164b91167bbe10d"}, + {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:15b944623b2cdfe7f7945690bfb71c10a4531b51997c8320b84e7b0bb59af902"}, + {file = "onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e61ce5005118064b1a0ed73ebe936bc773a102f067db34108ea6c64dd62a179"}, + {file = "onnxruntime-1.18.0-cp310-cp310-win32.whl", hash = "sha256:a4fc8a2a526eb442317d280610936a9f73deece06c7d5a91e51570860802b93f"}, + {file = "onnxruntime-1.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:71ed219b768cab004e5cd83e702590734f968679bf93aa488c1a7ffbe6e220c3"}, + {file = "onnxruntime-1.18.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:3d24bd623872a72a7fe2f51c103e20fcca2acfa35d48f2accd6be1ec8633d960"}, + {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f15e41ca9b307a12550bfd2ec93f88905d9fba12bab7e578f05138ad0ae10d7b"}, + {file = "onnxruntime-1.18.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f45ca2887f62a7b847d526965686b2923efa72538c89b7703c7b3fe970afd59"}, + {file = "onnxruntime-1.18.0-cp311-cp311-win32.whl", hash = "sha256:9e24d9ecc8781323d9e2eeda019b4b24babc4d624e7d53f61b1fe1a929b0511a"}, + {file = "onnxruntime-1.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:f8608398976ed18aef450d83777ff6f77d0b64eced1ed07a985e1a7db8ea3771"}, + {file = "onnxruntime-1.18.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:f1d79941f15fc40b1ee67738b2ca26b23e0181bf0070b5fb2984f0988734698f"}, + {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e8caf3a8565c853a22d323a3eebc2a81e3de7591981f085a4f74f7a60aab2d"}, + {file = "onnxruntime-1.18.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:498d2b8380635f5e6ebc50ec1b45f181588927280f32390fb910301d234f97b8"}, + {file = "onnxruntime-1.18.0-cp312-cp312-win32.whl", hash = "sha256:ba7cc0ce2798a386c082aaa6289ff7e9bedc3dee622eef10e74830cff200a72e"}, + {file = "onnxruntime-1.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:1fa175bd43f610465d5787ae06050c81f7ce09da2bf3e914eb282cb8eab363ef"}, + {file = "onnxruntime-1.18.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:0284c579c20ec8b1b472dd190290a040cc68b6caec790edb960f065d15cf164a"}, + {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47353d036d8c380558a5643ea5f7964d9d259d31c86865bad9162c3e916d1f6"}, + {file = "onnxruntime-1.18.0-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:885509d2b9ba4b01f08f7fa28d31ee54b6477953451c7ccf124a84625f07c803"}, + {file = "onnxruntime-1.18.0-cp38-cp38-win32.whl", hash = "sha256:8614733de3695656411d71fc2f39333170df5da6c7efd6072a59962c0bc7055c"}, + {file = "onnxruntime-1.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:47af3f803752fce23ea790fd8d130a47b2b940629f03193f780818622e856e7a"}, + {file = "onnxruntime-1.18.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:9153eb2b4d5bbab764d0aea17adadffcfc18d89b957ad191b1c3650b9930c59f"}, + {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c7fd86eca727c989bb8d9c5104f3c45f7ee45f445cc75579ebe55d6b99dfd7c"}, + {file = "onnxruntime-1.18.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac67a4de9c1326c4d87bcbfb652c923039b8a2446bb28516219236bec3b494f5"}, + {file = "onnxruntime-1.18.0-cp39-cp39-win32.whl", hash = "sha256:6ffb445816d06497df7a6dd424b20e0b2c39639e01e7fe210e247b82d15a23b9"}, + {file = "onnxruntime-1.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:46de6031cb6745f33f7eca9e51ab73e8c66037fb7a3b6b4560887c5b55ab5d5d"}, ] [package.dependencies] @@ -5446,6 +5794,29 @@ packaging = "*" protobuf = "*" sympy = "*" +[[package]] +name = "openai" +version = "1.35.3" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.35.3-py3-none-any.whl", hash = "sha256:7b26544cef80f125431c073ffab3811d2421fbb9e30d3bd5c2436aba00b042d5"}, + {file = "openai-1.35.3.tar.gz", hash = "sha256:d6177087f150b381d49499be782d764213fdf638d391b29ca692b84dd675a389"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "openpyxl" version = "3.1.2" @@ -5592,71 +5963,68 @@ dev = ["black", "mypy", "pytest"] [[package]] name = "orjson" -version = "3.9.5" +version = "3.10.5" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "orjson-3.9.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ad6845912a71adcc65df7c8a7f2155eba2096cf03ad2c061c93857de70d699ad"}, - {file = "orjson-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e298e0aacfcc14ef4476c3f409e85475031de24e5b23605a465e9bf4b2156273"}, - {file = "orjson-3.9.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:83c9939073281ef7dd7c5ca7f54cceccb840b440cec4b8a326bda507ff88a0a6"}, - {file = "orjson-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e174cc579904a48ee1ea3acb7045e8a6c5d52c17688dfcb00e0e842ec378cabf"}, - {file = "orjson-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f8d51702f42c785b115401e1d64a27a2ea767ae7cf1fb8edaa09c7cf1571c660"}, - {file = "orjson-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f13d61c0c7414ddee1ef4d0f303e2222f8cced5a2e26d9774751aecd72324c9e"}, - {file = "orjson-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d748cc48caf5a91c883d306ab648df1b29e16b488c9316852844dd0fd000d1c2"}, - {file = "orjson-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bd19bc08fa023e4c2cbf8294ad3f2b8922f4de9ba088dbc71e6b268fdf54591c"}, - {file = "orjson-3.9.5-cp310-none-win32.whl", hash = "sha256:5793a21a21bf34e1767e3d61a778a25feea8476dcc0bdf0ae1bc506dc34561ea"}, - {file = "orjson-3.9.5-cp310-none-win_amd64.whl", hash = "sha256:2bcec0b1024d0031ab3eab7a8cb260c8a4e4a5e35993878a2da639d69cdf6a65"}, - {file = "orjson-3.9.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8547b95ca0e2abd17e1471973e6d676f1d8acedd5f8fb4f739e0612651602d66"}, - {file = "orjson-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87ce174d6a38d12b3327f76145acbd26f7bc808b2b458f61e94d83cd0ebb4d76"}, - {file = "orjson-3.9.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a960bb1bc9a964d16fcc2d4af5a04ce5e4dfddca84e3060c35720d0a062064fe"}, - {file = "orjson-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a7aa5573a949760d6161d826d34dc36db6011926f836851fe9ccb55b5a7d8e8"}, - {file = "orjson-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b2852afca17d7eea85f8e200d324e38c851c96598ac7b227e4f6c4e59fbd3df"}, - {file = "orjson-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa185959c082475288da90f996a82e05e0c437216b96f2a8111caeb1d54ef926"}, - {file = "orjson-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:89c9332695b838438ea4b9a482bce8ffbfddde4df92750522d928fb00b7b8dce"}, - {file = "orjson-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2493f1351a8f0611bc26e2d3d407efb873032b4f6b8926fed8cfed39210ca4ba"}, - {file = "orjson-3.9.5-cp311-none-win32.whl", hash = "sha256:ffc544e0e24e9ae69301b9a79df87a971fa5d1c20a6b18dca885699709d01be0"}, - {file = "orjson-3.9.5-cp311-none-win_amd64.whl", hash = "sha256:89670fe2732e3c0c54406f77cad1765c4c582f67b915c74fda742286809a0cdc"}, - {file = "orjson-3.9.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:15df211469625fa27eced4aa08dc03e35f99c57d45a33855cc35f218ea4071b8"}, - {file = "orjson-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9f17c59fe6c02bc5f89ad29edb0253d3059fe8ba64806d789af89a45c35269a"}, - {file = "orjson-3.9.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ca6b96659c7690773d8cebb6115c631f4a259a611788463e9c41e74fa53bf33f"}, - {file = "orjson-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a26fafe966e9195b149950334bdbe9026eca17fe8ffe2d8fa87fdc30ca925d30"}, - {file = "orjson-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9006b1eb645ecf460da067e2dd17768ccbb8f39b01815a571bfcfab7e8da5e52"}, - {file = "orjson-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebfdbf695734b1785e792a1315e41835ddf2a3e907ca0e1c87a53f23006ce01d"}, - {file = "orjson-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4a3943234342ab37d9ed78fb0a8f81cd4b9532f67bf2ac0d3aa45fa3f0a339f3"}, - {file = "orjson-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e6762755470b5c82f07b96b934af32e4d77395a11768b964aaa5eb092817bc31"}, - {file = "orjson-3.9.5-cp312-none-win_amd64.whl", hash = "sha256:c74df28749c076fd6e2157190df23d43d42b2c83e09d79b51694ee7315374ad5"}, - {file = "orjson-3.9.5-cp37-cp37m-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:88e18a74d916b74f00d0978d84e365c6bf0e7ab846792efa15756b5fb2f7d49d"}, - {file = "orjson-3.9.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d28514b5b6dfaf69097be70d0cf4f1407ec29d0f93e0b4131bf9cc8fd3f3e374"}, - {file = "orjson-3.9.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25b81aca8c7be61e2566246b6a0ca49f8aece70dd3f38c7f5c837f398c4cb142"}, - {file = "orjson-3.9.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:385c1c713b1e47fd92e96cf55fd88650ac6dfa0b997e8aa7ecffd8b5865078b1"}, - {file = "orjson-3.9.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9850c03a8e42fba1a508466e6a0f99472fd2b4a5f30235ea49b2a1b32c04c11"}, - {file = "orjson-3.9.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4449f84bbb13bcef493d8aa669feadfced0f7c5eea2d0d88b5cc21f812183af8"}, - {file = "orjson-3.9.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:86127bf194f3b873135e44ce5dc9212cb152b7e06798d5667a898a00f0519be4"}, - {file = "orjson-3.9.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0abcd039f05ae9ab5b0ff11624d0b9e54376253b7d3217a358d09c3edf1d36f7"}, - {file = "orjson-3.9.5-cp37-none-win32.whl", hash = "sha256:10cc8ad5ff7188efcb4bec196009d61ce525a4e09488e6d5db41218c7fe4f001"}, - {file = "orjson-3.9.5-cp37-none-win_amd64.whl", hash = "sha256:ff27e98532cb87379d1a585837d59b187907228268e7b0a87abe122b2be6968e"}, - {file = "orjson-3.9.5-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5bfa79916ef5fef75ad1f377e54a167f0de334c1fa4ebb8d0224075f3ec3d8c0"}, - {file = "orjson-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87dfa6ac0dae764371ab19b35eaaa46dfcb6ef2545dfca03064f21f5d08239f"}, - {file = "orjson-3.9.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:50ced24a7b23058b469ecdb96e36607fc611cbaee38b58e62a55c80d1b3ad4e1"}, - {file = "orjson-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b1b74ea2a3064e1375da87788897935832e806cc784de3e789fd3c4ab8eb3fa5"}, - {file = "orjson-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7cb961efe013606913d05609f014ad43edfaced82a576e8b520a5574ce3b2b9"}, - {file = "orjson-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1225d2d5ee76a786bda02f8c5e15017462f8432bb960de13d7c2619dba6f0275"}, - {file = "orjson-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f39f4b99199df05c7ecdd006086259ed25886cdbd7b14c8cdb10c7675cfcca7d"}, - {file = "orjson-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a461dc9fb60cac44f2d3218c36a0c1c01132314839a0e229d7fb1bba69b810d8"}, - {file = "orjson-3.9.5-cp38-none-win32.whl", hash = "sha256:dedf1a6173748202df223aea29de814b5836732a176b33501375c66f6ab7d822"}, - {file = "orjson-3.9.5-cp38-none-win_amd64.whl", hash = "sha256:fa504082f53efcbacb9087cc8676c163237beb6e999d43e72acb4bb6f0db11e6"}, - {file = "orjson-3.9.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6900f0248edc1bec2a2a3095a78a7e3ef4e63f60f8ddc583687eed162eedfd69"}, - {file = "orjson-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17404333c40047888ac40bd8c4d49752a787e0a946e728a4e5723f111b6e55a5"}, - {file = "orjson-3.9.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0eefb7cfdd9c2bc65f19f974a5d1dfecbac711dae91ed635820c6b12da7a3c11"}, - {file = "orjson-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:68c78b2a3718892dc018adbc62e8bab6ef3c0d811816d21e6973dee0ca30c152"}, - {file = "orjson-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:591ad7d9e4a9f9b104486ad5d88658c79ba29b66c5557ef9edf8ca877a3f8d11"}, - {file = "orjson-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6cc2cbf302fbb2d0b2c3c142a663d028873232a434d89ce1b2604ebe5cc93ce8"}, - {file = "orjson-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b26b5aa5e9ee1bad2795b925b3adb1b1b34122cb977f30d89e0a1b3f24d18450"}, - {file = "orjson-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ef84724f7d29dcfe3aafb1fc5fc7788dca63e8ae626bb9298022866146091a3e"}, - {file = "orjson-3.9.5-cp39-none-win32.whl", hash = "sha256:664cff27f85939059472afd39acff152fbac9a091b7137092cb651cf5f7747b5"}, - {file = "orjson-3.9.5-cp39-none-win_amd64.whl", hash = "sha256:91dda66755795ac6100e303e206b636568d42ac83c156547634256a2e68de694"}, - {file = "orjson-3.9.5.tar.gz", hash = "sha256:6daf5ee0b3cf530b9978cdbf71024f1c16ed4a67d05f6ec435c6e7fe7a52724c"}, + {file = "orjson-3.10.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:545d493c1f560d5ccfc134803ceb8955a14c3fcb47bbb4b2fee0232646d0b932"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4324929c2dd917598212bfd554757feca3e5e0fa60da08be11b4aa8b90013c1"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c13ca5e2ddded0ce6a927ea5a9f27cae77eee4c75547b4297252cb20c4d30e6"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b6c8e30adfa52c025f042a87f450a6b9ea29649d828e0fec4858ed5e6caecf63"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:338fd4f071b242f26e9ca802f443edc588fa4ab60bfa81f38beaedf42eda226c"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6970ed7a3126cfed873c5d21ece1cd5d6f83ca6c9afb71bbae21a0b034588d96"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:235dadefb793ad12f7fa11e98a480db1f7c6469ff9e3da5e73c7809c700d746b"}, + {file = "orjson-3.10.5-cp310-none-win32.whl", hash = "sha256:be79e2393679eda6a590638abda16d167754393f5d0850dcbca2d0c3735cebe2"}, + {file = "orjson-3.10.5-cp310-none-win_amd64.whl", hash = "sha256:c4a65310ccb5c9910c47b078ba78e2787cb3878cdded1702ac3d0da71ddc5228"}, + {file = "orjson-3.10.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cdf7365063e80899ae3a697def1277c17a7df7ccfc979990a403dfe77bb54d40"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b68742c469745d0e6ca5724506858f75e2f1e5b59a4315861f9e2b1df77775a"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d10cc1b594951522e35a3463da19e899abe6ca95f3c84c69e9e901e0bd93d38"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcbe82b35d1ac43b0d84072408330fd3295c2896973112d495e7234f7e3da2e1"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c0eb7e0c75e1e486c7563fe231b40fdd658a035ae125c6ba651ca3b07936f5"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:53ed1c879b10de56f35daf06dbc4a0d9a5db98f6ee853c2dbd3ee9d13e6f302f"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:099e81a5975237fda3100f918839af95f42f981447ba8f47adb7b6a3cdb078fa"}, + {file = "orjson-3.10.5-cp311-none-win32.whl", hash = "sha256:1146bf85ea37ac421594107195db8bc77104f74bc83e8ee21a2e58596bfb2f04"}, + {file = "orjson-3.10.5-cp311-none-win_amd64.whl", hash = "sha256:36a10f43c5f3a55c2f680efe07aa93ef4a342d2960dd2b1b7ea2dd764fe4a37c"}, + {file = "orjson-3.10.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:68f85ecae7af14a585a563ac741b0547a3f291de81cd1e20903e79f25170458f"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28afa96f496474ce60d3340fe8d9a263aa93ea01201cd2bad844c45cd21f5268"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cd684927af3e11b6e754df80b9ffafd9fb6adcaa9d3e8fdd5891be5a5cad51e"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d21b9983da032505f7050795e98b5d9eee0df903258951566ecc358f6696969"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ad1de7fef79736dde8c3554e75361ec351158a906d747bd901a52a5c9c8d24b"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d97531cdfe9bdd76d492e69800afd97e5930cb0da6a825646667b2c6c6c0211"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d69858c32f09c3e1ce44b617b3ebba1aba030e777000ebdf72b0d8e365d0b2b3"}, + {file = "orjson-3.10.5-cp312-none-win32.whl", hash = "sha256:64c9cc089f127e5875901ac05e5c25aa13cfa5dbbbd9602bda51e5c611d6e3e2"}, + {file = "orjson-3.10.5-cp312-none-win_amd64.whl", hash = "sha256:b2efbd67feff8c1f7728937c0d7f6ca8c25ec81373dc8db4ef394c1d93d13dc5"}, + {file = "orjson-3.10.5-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:03b565c3b93f5d6e001db48b747d31ea3819b89abf041ee10ac6988886d18e01"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:584c902ec19ab7928fd5add1783c909094cc53f31ac7acfada817b0847975f26"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a35455cc0b0b3a1eaf67224035f5388591ec72b9b6136d66b49a553ce9eb1e6"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1670fe88b116c2745a3a30b0f099b699a02bb3482c2591514baf5433819e4f4d"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:185c394ef45b18b9a7d8e8f333606e2e8194a50c6e3c664215aae8cf42c5385e"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ca0b3a94ac8d3886c9581b9f9de3ce858263865fdaa383fbc31c310b9eac07c9"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:dfc91d4720d48e2a709e9c368d5125b4b5899dced34b5400c3837dadc7d6271b"}, + {file = "orjson-3.10.5-cp38-none-win32.whl", hash = "sha256:c05f16701ab2a4ca146d0bca950af254cb7c02f3c01fca8efbbad82d23b3d9d4"}, + {file = "orjson-3.10.5-cp38-none-win_amd64.whl", hash = "sha256:8a11d459338f96a9aa7f232ba95679fc0c7cedbd1b990d736467894210205c09"}, + {file = "orjson-3.10.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:85c89131d7b3218db1b24c4abecea92fd6c7f9fab87441cfc342d3acc725d807"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66215277a230c456f9038d5e2d84778141643207f85336ef8d2a9da26bd7ca"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51bbcdea96cdefa4a9b4461e690c75ad4e33796530d182bdd5c38980202c134a"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbead71dbe65f959b7bd8cf91e0e11d5338033eba34c114f69078d59827ee139"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df58d206e78c40da118a8c14fc189207fffdcb1f21b3b4c9c0c18e839b5a214"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4057c3b511bb8aef605616bd3f1f002a697c7e4da6adf095ca5b84c0fd43595"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b39e006b00c57125ab974362e740c14a0c6a66ff695bff44615dcf4a70ce2b86"}, + {file = "orjson-3.10.5-cp39-none-win32.whl", hash = "sha256:eded5138cc565a9d618e111c6d5c2547bbdd951114eb822f7f6309e04db0fb47"}, + {file = "orjson-3.10.5-cp39-none-win_amd64.whl", hash = "sha256:cc28e90a7cae7fcba2493953cff61da5a52950e78dc2dacfe931a317ee3d8de7"}, + {file = "orjson-3.10.5.tar.gz", hash = "sha256:7a5baef8a4284405d96c90c7c62b755e9ef1ada84c2406c24a9ebec86b89f46d"}, +] + +[[package]] +name = "overrides" +version = "7.7.0" +description = "A decorator to automatically detect mismatch when overriding a method." +optional = false +python-versions = ">=3.6" +files = [ + {file = "overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49"}, + {file = "overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a"}, ] [[package]] @@ -6559,6 +6927,58 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pylance" +version = "0.10.12" +description = "python wrapper for Lance columnar format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pylance-0.10.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:30cbcca078edeb37e11ae86cf9287d81ce6c0c07ba77239284b369a4b361497b"}, + {file = "pylance-0.10.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e558163ff6035d518706cc66848497219ccc755e2972b8f3b1706a3e1fd800fd"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75afb39f71d7f12429f9b4d380eb6cf6aed179ae5a1c5d16cc768373a1521f87"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:3de391dfc3a99bdb245fd1e27ef242be769a94853f802ef57f246e9a21358d32"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:34a5278b90f4cbcf21261353976127aa2ffbbd7d068810f0a2b0c1aa0334022a"}, + {file = "pylance-0.10.12-cp38-abi3-win_amd64.whl", hash = "sha256:6cef5975d513097fd2c22692296c9a5a138928f38d02cd34ab63a7369abc1463"}, +] + +[package.dependencies] +numpy = ">=1.22" +pyarrow = ">=12,<15.0.1" + +[package.extras] +benchmarks = ["pytest-benchmark"] +dev = ["ruff (==0.2.2)"] +ray = ["ray[data]"] +tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] +torch = ["torch"] + +[[package]] +name = "pylance" +version = "0.13.0" +description = "python wrapper for Lance columnar format" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pylance-0.13.0-cp39-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2f3d6f9eec1f59f45dccb01075ba79868b8d37c8371d6210bcf6418217a0dd8b"}, + {file = "pylance-0.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f4861ab466c94b0f9a4b4e6de6e1dfa02f40e7242d8db87447bc7bb7d89606ac"}, + {file = "pylance-0.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3cb92547e145f5bfb0ea7d6f483953913b9bdd44c45bea84fc95a18da9f5853"}, + {file = "pylance-0.13.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:d1ddd7700924bc6b6b0774ea63d2aa23f9210a86cd6d6af0cdfa987df776d50d"}, + {file = "pylance-0.13.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c51d4b6e59cf4dc97c11a35b299f11e80dbdf392e2d8dc498573c26474a3c19e"}, + {file = "pylance-0.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:4018ba016f1445874960a4ba2ad5c80cb380f3116683282ee8beabd38fa8989d"}, +] + +[package.dependencies] +numpy = ">=1.22" +pyarrow = ">=12,<15.0.1" + +[package.extras] +benchmarks = ["pytest-benchmark"] +dev = ["ruff (==0.4.1)"] +ray = ["ray[data]"] +tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] +torch = ["torch"] + [[package]] name = "pymongo" version = "4.6.0" @@ -6596,6 +7016,7 @@ files = [ {file = "pymongo-4.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ab6bcc8e424e07c1d4ba6df96f7fb963bcb48f590b9456de9ebd03b88084fe8"}, {file = "pymongo-4.6.0-cp312-cp312-win32.whl", hash = "sha256:47aa128be2e66abd9d1a9b0437c62499d812d291f17b55185cb4aa33a5f710a4"}, {file = "pymongo-4.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:014e7049dd019a6663747ca7dae328943e14f7261f7c1381045dfc26a04fa330"}, + {file = "pymongo-4.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e24025625bad66895b1bc3ae1647f48f0a92dd014108fb1be404c77f0b69ca67"}, {file = "pymongo-4.6.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:288c21ab9531b037f7efa4e467b33176bc73a0c27223c141b822ab4a0e66ff2a"}, {file = "pymongo-4.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:747c84f4e690fbe6999c90ac97246c95d31460d890510e4a3fa61b7d2b87aa34"}, {file = "pymongo-4.6.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:055f5c266e2767a88bb585d01137d9c7f778b0195d3dbf4a487ef0638be9b651"}, @@ -7036,6 +7457,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -7043,8 +7465,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -7061,6 +7491,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -7068,6 +7499,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -7075,30 +7507,44 @@ files = [ [[package]] name = "qdrant-client" -version = "1.6.4" +version = "1.9.1" description = "Client library for the Qdrant vector search engine" optional = true -python-versions = ">=3.8,<3.13" +python-versions = ">=3.8" files = [ - {file = "qdrant_client-1.6.4-py3-none-any.whl", hash = "sha256:db4696978d6a62d78ff60f70b912383f1e467bda3053f732b01ddb5f93281b10"}, - {file = "qdrant_client-1.6.4.tar.gz", hash = "sha256:bbd65f383b6a55a9ccf4e301250fa925179340dd90cfde9b93ce4230fd68867b"}, + {file = "qdrant_client-1.9.1-py3-none-any.whl", hash = "sha256:b9b7e0e5c1a51410d8bb5106a869a51e12f92ab45a99030f27aba790553bd2c8"}, + {file = "qdrant_client-1.9.1.tar.gz", hash = "sha256:186b9c31d95aefe8f2db84b7746402d7365bd63b305550e530e31bde2002ce79"}, ] [package.dependencies] -fastembed = {version = "0.1.1", optional = true, markers = "python_version < \"3.12\" and extra == \"fastembed\""} +fastembed = {version = "0.2.6", optional = true, markers = "python_version < \"3.13\" and extra == \"fastembed\""} grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" -httpx = {version = ">=0.14.0", extras = ["http2"]} +httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, {version = ">=1.26", markers = "python_version >= \"3.12\""}, ] portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" -urllib3 = ">=1.26.14,<2.0.0" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (==0.2.6)"] + +[[package]] +name = "ratelimiter" +version = "1.2.0.post0" +description = "Simple python rate limiting object" +optional = false +python-versions = "*" +files = [ + {file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"}, + {file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"}, +] [package.extras] -fastembed = ["fastembed (==0.1.1)"] +test = ["pytest (>=3.0)", "pytest-asyncio"] [[package]] name = "redshift-connector" @@ -7327,6 +7773,21 @@ files = [ [package.dependencies] types-setuptools = ">=57.0.0" +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." +optional = false +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rfc3339-validator" version = "0.1.4" @@ -7962,6 +8423,7 @@ files = [ {file = "SQLAlchemy-1.4.49-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:03db81b89fe7ef3857b4a00b63dedd632d6183d4ea5a31c5d8a92e000a41fc71"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:95b9df9afd680b7a3b13b38adf6e3a38995da5e162cc7524ef08e3be4e5ed3e1"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a63e43bf3f668c11bb0444ce6e809c1227b8f067ca1068898f3008a273f52b09"}, + {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca46de16650d143a928d10842939dab208e8d8c3a9a8757600cae9b7c579c5cd"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f835c050ebaa4e48b18403bed2c0fda986525896efd76c245bdd4db995e51a4c"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c21b172dfb22e0db303ff6419451f0cac891d2e911bb9fbf8003d717f1bcf91"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-win32.whl", hash = "sha256:5fb1ebdfc8373b5a291485757bd6431de8d7ed42c27439f543c81f6c8febd729"}, @@ -7971,26 +8433,35 @@ files = [ {file = "SQLAlchemy-1.4.49-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5debe7d49b8acf1f3035317e63d9ec8d5e4d904c6e75a2a9246a119f5f2fdf3d"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win32.whl", hash = "sha256:82b08e82da3756765c2e75f327b9bf6b0f043c9c3925fb95fb51e1567fa4ee87"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win_amd64.whl", hash = "sha256:171e04eeb5d1c0d96a544caf982621a1711d078dbc5c96f11d6469169bd003f1"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f23755c384c2969ca2f7667a83f7c5648fcf8b62a3f2bbd883d805454964a800"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8396e896e08e37032e87e7fbf4a15f431aa878c286dc7f79e616c2feacdb366c"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66da9627cfcc43bbdebd47bfe0145bb662041472393c03b7802253993b6b7c90"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-win32.whl", hash = "sha256:9a06e046ffeb8a484279e54bda0a5abfd9675f594a2e38ef3133d7e4d75b6214"}, + {file = "SQLAlchemy-1.4.49-cp312-cp312-win_amd64.whl", hash = "sha256:7cf8b90ad84ad3a45098b1c9f56f2b161601e4670827d6b892ea0e884569bd1d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:36e58f8c4fe43984384e3fbe6341ac99b6b4e083de2fe838f0fdb91cebe9e9cb"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b31e67ff419013f99ad6f8fc73ee19ea31585e1e9fe773744c0f3ce58c039c30"}, + {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc22807a7e161c0d8f3da34018ab7c97ef6223578fcdd99b1d3e7ed1100a5db"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c14b29d9e1529f99efd550cd04dbb6db6ba5d690abb96d52de2bff4ed518bc95"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c40f3470e084d31247aea228aa1c39bbc0904c2b9ccbf5d3cfa2ea2dac06f26d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win32.whl", hash = "sha256:706bfa02157b97c136547c406f263e4c6274a7b061b3eb9742915dd774bbc264"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win_amd64.whl", hash = "sha256:a7f7b5c07ae5c0cfd24c2db86071fb2a3d947da7bd487e359cc91e67ac1c6d2e"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:4afbbf5ef41ac18e02c8dc1f86c04b22b7a2125f2a030e25bbb4aff31abb224b"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24e300c0c2147484a002b175f4e1361f102e82c345bf263242f0449672a4bccf"}, + {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:393cd06c3b00b57f5421e2133e088df9cabcececcea180327e43b937b5a7caa5"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:201de072b818f8ad55c80d18d1a788729cccf9be6d9dc3b9d8613b053cd4836d"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7653ed6817c710d0c95558232aba799307d14ae084cc9b1f4c389157ec50df5c"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win32.whl", hash = "sha256:647e0b309cb4512b1f1b78471fdaf72921b6fa6e750b9f891e09c6e2f0e5326f"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win_amd64.whl", hash = "sha256:ab73ed1a05ff539afc4a7f8cf371764cdf79768ecb7d2ec691e3ff89abbc541e"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:37ce517c011560d68f1ffb28af65d7e06f873f191eb3a73af5671e9c3fada08a"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1878ce508edea4a879015ab5215546c444233881301e97ca16fe251e89f1c55"}, + {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95ab792ca493891d7a45a077e35b418f68435efb3e1706cb8155e20e86a9013c"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0e8e608983e6f85d0852ca61f97e521b62e67969e6e640fe6c6b575d4db68557"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccf956da45290df6e809ea12c54c02ace7f8ff4d765d6d3dfb3655ee876ce58d"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win32.whl", hash = "sha256:f167c8175ab908ce48bd6550679cc6ea20ae169379e73c7720a28f89e53aa532"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win_amd64.whl", hash = "sha256:45806315aae81a0c202752558f0df52b42d11dd7ba0097bf71e253b4215f34f4"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:b6d0c4b15d65087738a6e22e0ff461b407533ff65a73b818089efc8eb2b3e1de"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a843e34abfd4c797018fd8d00ffffa99fd5184c421f190b6ca99def4087689bd"}, + {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:738d7321212941ab19ba2acf02a68b8ee64987b248ffa2101630e8fccb549e0d"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c890421651b45a681181301b3497e4d57c0d01dc001e10438a40e9a9c25ee77"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d26f280b8f0a8f497bc10573849ad6dc62e671d2468826e5c748d04ed9e670d5"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-win32.whl", hash = "sha256:ec2268de67f73b43320383947e74700e95c6770d0c68c4e615e9897e46296294"}, @@ -8220,56 +8691,129 @@ twisted = ["twisted"] [[package]] name = "tokenizers" -version = "0.13.3" -description = "Fast and Customizable Tokenizers" +version = "0.15.2" +description = "" optional = true -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, - {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, - {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, - {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, - {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, - {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, - {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, - {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, - {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, - {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, - {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, - {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, -] - -[package.extras] -dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] + {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, + {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"}, + {file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"}, + {file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"}, + {file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"}, + {file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"}, + {file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"}, + {file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d44ba80988ff9424e33e0a49445072ac7029d8c0e1601ad25a0ca5f41ed0c1d6"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:dce74266919b892f82b1b86025a613956ea0ea62a4843d4c4237be2c5498ed3a"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0ef06b9707baeb98b316577acb04f4852239d856b93e9ec3a299622f6084e4be"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c73e2e74bbb07910da0d37c326869f34113137b23eadad3fc00856e6b3d9930c"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eeb12daf02a59e29f578a865f55d87cd103ce62bd8a3a5874f8fdeaa82e336b"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba9f6895af58487ca4f54e8a664a322f16c26bbb442effd01087eba391a719e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccec77aa7150e38eec6878a493bf8c263ff1fa8a62404e16c6203c64c1f16a26"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f40604f5042ff210ba82743dda2b6aa3e55aa12df4e9f2378ee01a17e2855e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5645938a42d78c4885086767c70923abad047163d809c16da75d6b290cb30bbe"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05a77cbfebe28a61ab5c3891f9939cc24798b63fa236d84e5f29f3a85a200c00"}, + {file = "tokenizers-0.15.2-cp37-none-win32.whl", hash = "sha256:361abdc068e8afe9c5b818769a48624687fb6aaed49636ee39bec4e95e1a215b"}, + {file = "tokenizers-0.15.2-cp37-none-win_amd64.whl", hash = "sha256:7ef789f83eb0f9baeb4d09a86cd639c0a5518528f9992f38b28e819df397eb06"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4fe1f74a902bee74a3b25aff180fbfbf4f8b444ab37c4d496af7afd13a784ed2"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4b89038a684f40a6b15d6b09f49650ac64d951ad0f2a3ea9169687bbf2a8ba"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d05a1b06f986d41aed5f2de464c003004b2df8aaf66f2b7628254bcbfb72a438"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508711a108684111ec8af89d3a9e9e08755247eda27d0ba5e3c50e9da1600f6d"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:daa348f02d15160cb35439098ac96e3a53bacf35885072611cd9e5be7d333daa"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:494fdbe5932d3416de2a85fc2470b797e6f3226c12845cadf054dd906afd0442"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2d60f5246f4da9373f75ff18d64c69cbf60c3bca597290cea01059c336d2470"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93268e788825f52de4c7bdcb6ebc1fcd4a5442c02e730faa9b6b08f23ead0e24"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6fc7083ab404019fc9acafe78662c192673c1e696bd598d16dc005bd663a5cf9"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e39b41e5531d6b2122a77532dbea60e171ef87a3820b5a3888daa847df4153"}, + {file = "tokenizers-0.15.2-cp38-none-win32.whl", hash = "sha256:06cd0487b1cbfabefb2cc52fbd6b1f8d4c37799bd6c6e1641281adaa6b2504a7"}, + {file = "tokenizers-0.15.2-cp38-none-win_amd64.whl", hash = "sha256:5179c271aa5de9c71712e31cb5a79e436ecd0d7532a408fa42a8dbfa4bc23fd9"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82f8652a74cc107052328b87ea8b34291c0f55b96d8fb261b3880216a9f9e48e"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:02458bee6f5f3139f1ebbb6d042b283af712c0981f5bc50edf771d6b762d5e4f"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c9a09cd26cca2e1c349f91aa665309ddb48d71636370749414fbf67bc83c5343"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158be8ea8554e5ed69acc1ce3fbb23a06060bd4bbb09029431ad6b9a466a7121"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ddba9a2b0c8c81633eca0bb2e1aa5b3a15362b1277f1ae64176d0f6eba78ab1"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ef5dd1d39797044642dbe53eb2bc56435308432e9c7907728da74c69ee2adca"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:454c203164e07a860dbeb3b1f4a733be52b0edbb4dd2e5bd75023ffa8b49403a"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cf6b7f1d4dc59af960e6ffdc4faffe6460bbfa8dce27a58bf75755ffdb2526d"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2ef09bbc16519f6c25d0c7fc0c6a33a6f62923e263c9d7cca4e58b8c61572afb"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c9a2ebdd2ad4ec7a68e7615086e633857c85e2f18025bd05d2a4399e6c5f7169"}, + {file = "tokenizers-0.15.2-cp39-none-win32.whl", hash = "sha256:918fbb0eab96fe08e72a8c2b5461e9cce95585d82a58688e7f01c2bd546c79d0"}, + {file = "tokenizers-0.15.2-cp39-none-win_amd64.whl", hash = "sha256:524e60da0135e106b254bd71f0659be9f89d83f006ea9093ce4d1fab498c6d0d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"}, + {file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"}, +] + +[package.dependencies] +huggingface_hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] [[package]] @@ -8460,6 +9004,17 @@ files = [ {file = "types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"}, ] +[[package]] +name = "types-regex" +version = "2024.5.15.20240519" +description = "Typing stubs for regex" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-regex-2024.5.15.20240519.tar.gz", hash = "sha256:ef3f594a95a95d6b9b5704a1facf3511a73e4731209ddb8868461db4c42dc12b"}, + {file = "types_regex-2024.5.15.20240519-py3-none-any.whl", hash = "sha256:d5895079cc66f91ae8818aeef14e9337c492ceb87ad0ff3df8c1c04d418cb9dd"}, +] + [[package]] name = "types-requests" version = "2.31.0.2" @@ -8783,6 +9338,20 @@ files = [ {file = "win_precise_time-1.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f510fa92d9c39ea533c983e1d62c7bc66fdf0a3e3c3bdda48d4ebb634ff7034"}, ] +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = true +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + [[package]] name = "wrapt" version = "1.15.0" @@ -9074,6 +9643,7 @@ duckdb = ["duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] +lancedb = ["lancedb", "pyarrow"] motherduck = ["duckdb", "pyarrow"] mssql = ["pyodbc"] parquet = ["pyarrow"] @@ -9088,4 +9658,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "47136cc3a6247e709dfe04a810df7309d1a2bc7fe838592dd5f58dc39c2407c8" +content-hash = "1205791c3a090cf55617833ef566f1d55e6fcfa7209079bca92277f217130549" diff --git a/pyproject.toml b/pyproject.toml index 10e3bf47d5..88658a5206 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.4.13a0" +version = "0.5.2a0" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] @@ -44,7 +44,8 @@ astunparse = ">=1.6.3" gitpython = ">=3.1.29" pytz = ">=2022.6" giturlparse = ">=0.10.0" -orjson = {version = ">=3.6.7,<=3.9.10", markers="platform_python_implementation != 'PyPy'"} +# exclude some versions because of segfault bugs in orjson +orjson = {version = ">=3.6.7,<4,!=3.9.11,!=3.9.12,!=3.9.13,!=3.9.14,!=3.10.1", markers="platform_python_implementation != 'PyPy'"} tenacity = ">=8.0.2" jsonpath-ng = ">=1.5.3" fsspec = ">=2022.4.0" @@ -73,11 +74,12 @@ pipdeptree = {version = ">=2.9.0,<2.10", optional = true} pyathena = {version = ">=2.9.6", optional = true} weaviate-client = {version = ">=3.22", optional = true} adlfs = {version = ">=2022.4.0", optional = true} -pyodbc = {version = "^4.0.39", optional = true} -qdrant-client = {version = "^1.6.4", optional = true, extras = ["fastembed"]} -databricks-sql-connector = {version = ">=3", optional = true} +pyodbc = {version = ">=4.0.39", optional = true} +qdrant-client = {version = ">=1.8", optional = true, extras = ["fastembed"]} +databricks-sql-connector = {version = ">=2.9.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } +lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'" } deltalake = { version = ">=0.17.4", optional = true } [tool.poetry.extras] @@ -103,8 +105,10 @@ qdrant = ["qdrant-client"] databricks = ["databricks-sql-connector"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] +lancedb = ["lancedb", "pyarrow"] deltalake = ["deltalake", "pyarrow"] + [tool.poetry.scripts] dlt = "dlt.cli._dlt:_main" @@ -149,6 +153,8 @@ types-pytz = ">=2024.1.0.20240203" ruff = "^0.3.2" pyjwt = "^2.8.0" pytest-mock = "^3.14.0" +types-regex = "^2024.5.15.20240519" +flake8-print = "^5.0.0" [tool.poetry.group.pipeline] optional = true @@ -214,6 +220,8 @@ pandas = ">2" alive-progress = ">=3.0.1" pyarrow = ">=14.0.0" psycopg2-binary = ">=2.9" +lancedb = ">=0.6.13" +openai = ">=1.35" [tool.black] # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file line-length = 100 diff --git a/tests/cases.py b/tests/cases.py index d145ec1d94..fa346b8b49 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -173,7 +173,7 @@ def table_update_and_row( exclude_types: Sequence[TDataType] = None, exclude_columns: Sequence[str] = None -) -> Tuple[TTableSchemaColumns, StrAny]: +) -> Tuple[TTableSchemaColumns, Dict[str, Any]]: """Get a table schema and a row with all possible data types. Optionally exclude some data types from the schema and row. """ @@ -192,6 +192,7 @@ def table_update_and_row( def assert_all_data_types_row( db_row: Union[List[Any], TDataItems], + expected_row: Dict[str, Any] = None, parse_complex_strings: bool = False, allow_base64_binary: bool = False, timestamp_precision: int = 6, @@ -202,6 +203,7 @@ def assert_all_data_types_row( # content must equal # print(db_row) schema = schema or TABLE_UPDATE_COLUMNS_SCHEMA + expected_row = expected_row or TABLE_ROW_ALL_DATA_TYPES # Include only columns requested in schema if isinstance(db_row, dict): @@ -209,7 +211,7 @@ def assert_all_data_types_row( else: db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)} - expected_rows = {key: value for key, value in TABLE_ROW_ALL_DATA_TYPES.items() if key in schema} + expected_rows = {key: value for key, value in expected_row.items() if key in schema} # prepare date to be compared: convert into pendulum instance, adjust microsecond precision if "col4" in expected_rows: parsed_date = pendulum.instance(db_mapping["col4"]) diff --git a/tests/cli/common/test_cli_invoke.py b/tests/cli/common/test_cli_invoke.py index d367a97261..f856162479 100644 --- a/tests/cli/common/test_cli_invoke.py +++ b/tests/cli/common/test_cli_invoke.py @@ -6,6 +6,7 @@ from unittest.mock import patch import dlt +from dlt.common.known_env import DLT_DATA_DIR from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.runners.venv import Venv from dlt.common.utils import custom_environ, set_working_dir @@ -62,7 +63,7 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: shutil.copytree("tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) with set_working_dir(TEST_STORAGE_ROOT): - with custom_environ({"COMPETED_PROB": "1.0", "DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({"COMPETED_PROB": "1.0", DLT_DATA_DIR: get_dlt_data_dir()}): venv = Venv.restore_current() venv.run_script("dummy_pipeline.py") # we check output test_pipeline_command else @@ -96,7 +97,7 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: def test_invoke_init_chess_and_template(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): result = script_runner.run(["dlt", "init", "chess", "dummy"]) assert "Verified source chess was added to your project!" in result.stdout assert result.returncode == 0 @@ -116,7 +117,7 @@ def test_invoke_list_verified_sources(script_runner: ScriptRunner) -> None: def test_invoke_deploy_project(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): result = script_runner.run( ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] ) diff --git a/tests/common/cases/destinations/null.py b/tests/common/cases/destinations/null.py index b2054cd7e8..37e87d89cf 100644 --- a/tests/common/cases/destinations/null.py +++ b/tests/common/cases/destinations/null.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs: Any) -> None: spec = DestinationClientConfiguration - def capabilities(self) -> DestinationCapabilitiesContext: + def _raw_capabilities(self) -> DestinationCapabilitiesContext: return DestinationCapabilitiesContext.generic_capabilities() @property diff --git a/tests/common/cases/normalizers/__init__.py b/tests/common/cases/normalizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/cases/normalizers/snake_no_x.py b/tests/common/cases/normalizers/snake_no_x.py new file mode 100644 index 0000000000..af3a53cbce --- /dev/null +++ b/tests/common/cases/normalizers/snake_no_x.py @@ -0,0 +1,10 @@ +from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention + + +class NamingConvention(SnakeCaseNamingConvention): + def normalize_identifier(self, identifier: str) -> str: + identifier = super().normalize_identifier(identifier) + if identifier.endswith("x"): + print(identifier[:-1] + "_") + return identifier[:-1] + "_" + return identifier diff --git a/tests/common/cases/normalizers/sql_upper.py b/tests/common/cases/normalizers/sql_upper.py new file mode 100644 index 0000000000..f2175f06ad --- /dev/null +++ b/tests/common/cases/normalizers/sql_upper.py @@ -0,0 +1,18 @@ +from typing import Any, Sequence + +from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention + + +class NamingConvention(BaseNamingConvention): + PATH_SEPARATOR = "__" + + _CLEANUP_TABLE = str.maketrans(".\n\r'\"▶", "______") + + @property + def is_case_sensitive(self) -> bool: + return True + + def normalize_identifier(self, identifier: str) -> str: + identifier = super().normalize_identifier(identifier) + norm_identifier = identifier.translate(self._CLEANUP_TABLE).upper() + return self.shorten_identifier(norm_identifier, identifier, self.max_length) diff --git a/tests/common/cases/normalizers/title_case.py b/tests/common/cases/normalizers/title_case.py new file mode 100644 index 0000000000..2b93b476c8 --- /dev/null +++ b/tests/common/cases/normalizers/title_case.py @@ -0,0 +1,15 @@ +from typing import ClassVar +from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention + + +class NamingConvention(DirectNamingConvention): + """Test case sensitive naming that capitalizes first and last letter and leaves the rest intact""" + + PATH_SEPARATOR: ClassVar[str] = "__" + + def normalize_identifier(self, identifier: str) -> str: + # keep prefix + if identifier == "_dlt": + return "_dlt" + identifier = super().normalize_identifier(identifier) + return identifier[0].upper() + identifier[1:-1] + identifier[-1].upper() diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index e7083956b3..7c3138ea73 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -63,6 +63,7 @@ ) from dlt.common.pipeline import TRefreshMode +from dlt.destinations.impl.postgres.configuration import PostgresCredentials from tests.utils import preserve_environ from tests.common.configuration.utils import ( MockProvider, @@ -450,6 +451,43 @@ def test_invalid_native_config_value() -> None: assert py_ex.value.embedded_sections == () +def test_maybe_use_explicit_value() -> None: + # pass through dict and configs + c = ConnectionStringCredentials() + dict_explicit = {"explicit": "is_dict"} + config_explicit = BaseConfiguration() + assert resolve._maybe_parse_native_value(c, dict_explicit, ()) is dict_explicit + assert resolve._maybe_parse_native_value(c, config_explicit, ()) is config_explicit + + # postgres credentials have a default parameter (connect_timeout), which must be removed for explicit value + pg_c = PostgresCredentials() + explicit_value = resolve._maybe_parse_native_value( + pg_c, "postgres://loader@localhost:5432/dlt_data?a=b&c=d", () + ) + # NOTE: connect_timeout and password are not present + assert explicit_value == { + "drivername": "postgres", + "database": "dlt_data", + "username": "loader", + "host": "localhost", + "query": {"a": "b", "c": "d"}, + } + pg_c = PostgresCredentials() + explicit_value = resolve._maybe_parse_native_value( + pg_c, "postgres://loader@localhost:5432/dlt_data?connect_timeout=33", () + ) + assert explicit_value["connect_timeout"] == 33 + + +def test_optional_params_resolved_if_complete_native_value(environment: Any) -> None: + # this native value fully resolves configuration + environment["CREDENTIALS"] = "postgres://loader:pwd@localhost:5432/dlt_data?a=b&c=d" + # still this config value will be injected + environment["CREDENTIALS__CONNECT_TIMEOUT"] = "300" + c = resolve.resolve_configuration(PostgresCredentials()) + assert c.connect_timeout == 300 + + def test_on_resolved(environment: Any) -> None: with pytest.raises(RuntimeError): # head over hells @@ -583,7 +621,7 @@ class _SecretCredentials(RunConfiguration): "dlthub_telemetry": True, "dlthub_telemetry_endpoint": "https://telemetry-tracker.services4758.workers.dev", "dlthub_telemetry_segment_write_key": None, - "log_format": "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}", + "log_format": "{asctime}|[{levelname}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}", "log_level": "WARNING", "request_timeout": 60, "request_max_attempts": 5, diff --git a/tests/common/configuration/test_credentials.py b/tests/common/configuration/test_credentials.py index 7c184c16e5..1c6319b551 100644 --- a/tests/common/configuration/test_credentials.py +++ b/tests/common/configuration/test_credentials.py @@ -21,7 +21,8 @@ ) from dlt.common.configuration.specs.run_configuration import RunConfiguration -from tests.utils import preserve_environ +from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials +from tests.utils import TEST_DICT_CONFIG_PROVIDER, preserve_environ from tests.common.utils import json_case_path from tests.common.configuration.utils import environment @@ -64,6 +65,17 @@ """ % OAUTH_USER_INFO +def test_credentials_resolve_from_init_value() -> None: + c = SnowflakeCredentials.from_init_value("snowflake://loader:pass@localhost:5432/dlt_data") + assert c.is_resolved() + # incomplete not resolved + c = SnowflakeCredentials.from_init_value("snowflake://loader:pass@localhost") + assert c.is_resolved() is False + # invalid configuration that raises on resolve() + c = SnowflakeCredentials.from_init_value("snowflake://loader@localhost/dlt_data") + assert c.is_resolved() is False + + def test_connection_string_credentials_native_representation(environment) -> None: with pytest.raises(InvalidConnectionString): ConnectionStringCredentials().parse_native_representation(1) @@ -158,10 +170,10 @@ def test_connection_string_resolved_from_native_representation_env(environment: assert c.host == "aws.12.1" -def test_connection_string_from_init() -> None: +def test_connection_string_initializer() -> None: c = ConnectionStringCredentials("postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d") assert c.drivername == "postgres" - assert c.is_resolved() + assert not c.is_resolved() assert not c.is_partial() c = ConnectionStringCredentials( @@ -182,10 +194,31 @@ def test_connection_string_from_init() -> None: assert c.port == 5432 assert c.database == "dlt_data" assert c.query == {"a": "b", "c": "d"} - assert c.is_resolved() + assert not c.is_resolved() assert not c.is_partial() +def test_query_additional_params() -> None: + c = ConnectionStringCredentials("snowflake://user1:pass1@host1/db1?keep_alive=true") + assert c.query["keep_alive"] == "true" + assert c.to_url().query["keep_alive"] == "true" + + # try a typed param + with TEST_DICT_CONFIG_PROVIDER().values({"credentials": {"query": {"keep_alive": True}}}): + c = ConnectionStringCredentials("snowflake://user1:pass1@host1/db1") + assert c.is_resolved() is False + c = resolve_configuration(c) + assert c.query["keep_alive"] is True + assert c.get_query()["keep_alive"] is True + assert c.to_url().query["keep_alive"] == "True" + + +def test_connection_string_str_repr() -> None: + c = ConnectionStringCredentials("postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d") + # password and query string redacted + assert str(c) == "postgres://loader:***@localhost:5432/dlt_data" + + def test_gcp_service_credentials_native_representation(environment) -> None: with pytest.raises(InvalidGoogleNativeCredentialsType): GcpServiceAccountCredentials().parse_native_representation(1) @@ -203,11 +236,9 @@ def test_gcp_service_credentials_native_representation(environment) -> None: assert gcpc.private_key == "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n" assert gcpc.project_id == "chat-analytics" assert gcpc.client_email == "loader@iam.gserviceaccount.com" - # location is present but deprecated - assert gcpc.location == "US" # get native representation, it will also location _repr = gcpc.to_native_representation() - assert "location" in _repr + assert "project_id" in _repr # parse again gcpc_2 = GcpServiceAccountCredentials() gcpc_2.parse_native_representation(_repr) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index f0494e9898..13d68b53e9 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -570,7 +570,19 @@ def get_cf(aux: str = dlt.config.value, last_config: AuxTest = None): def test_inject_spec_into_argument_with_spec_type() -> None: # if signature contains argument with type of SPEC, it gets injected there - from dlt.destinations.impl.dummy import _configure, DummyClientConfiguration + import dlt + from dlt.common.configuration import known_sections + from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration + + @with_config( + spec=DummyClientConfiguration, + sections=( + known_sections.DESTINATION, + "dummy", + ), + ) + def _configure(config: DummyClientConfiguration = dlt.config.value) -> DummyClientConfiguration: + return config # _configure has argument of type DummyClientConfiguration that it returns # this type holds resolved configuration diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 43bad21ece..5271c68633 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -10,6 +10,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.inject import with_config from dlt.common.configuration.exceptions import LookupTrace +from dlt.common.known_env import DLT_DATA_DIR, DLT_PROJECT_DIR from dlt.common.configuration.providers.toml import ( SECRETS_TOML, CONFIG_TOML, @@ -219,8 +220,8 @@ def test_secrets_toml_credentials_from_native_repr( " KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n...\n-----END" " PRIVATE KEY-----\n" ) - # but project id got overridden from credentials.project_id - assert c.project_id.endswith("-credentials") + # project id taken from the same value, will not be overridden from any other configs + assert c.project_id.endswith("mock-project-id-source.credentials") # also try sql alchemy url (native repr) c2 = resolve.resolve_configuration(ConnectionStringCredentials(), sections=("databricks",)) assert c2.drivername == "databricks+connector" @@ -257,8 +258,8 @@ def test_toml_global_config() -> None: assert config._add_global_config is False # type: ignore[attr-defined] # set dlt data and settings dir - os.environ["DLT_DATA_DIR"] = "./tests/common/cases/configuration/dlt_home" - os.environ["DLT_PROJECT_DIR"] = "./tests/common/cases/configuration/" + os.environ[DLT_DATA_DIR] = "./tests/common/cases/configuration/dlt_home" + os.environ[DLT_PROJECT_DIR] = "./tests/common/cases/configuration/" # create instance with global toml enabled config = ConfigTomlProvider(add_global_config=True) assert config._add_global_config is True diff --git a/tests/common/data_writers/test_data_writers.py b/tests/common/data_writers/test_data_writers.py index 6cc7cb55ab..9b4e61a2f7 100644 --- a/tests/common/data_writers/test_data_writers.py +++ b/tests/common/data_writers/test_data_writers.py @@ -7,11 +7,9 @@ from dlt.common.data_writers.exceptions import DataWriterNotFound, SpecLookupFailed from dlt.common.typing import AnyFun -# from dlt.destinations.postgres import capabilities -from dlt.destinations.impl.redshift import capabilities as redshift_caps from dlt.common.data_writers.escape import ( escape_redshift_identifier, - escape_bigquery_identifier, + escape_hive_identifier, escape_redshift_literal, escape_postgres_literal, escape_duckdb_literal, @@ -29,8 +27,10 @@ DataWriter, DataWriterMetrics, EMPTY_DATA_WRITER_METRICS, + ImportFileWriter, InsertValuesWriter, JsonlWriter, + create_import_spec, get_best_writer_spec, resolve_best_writer_spec, is_native_writer, @@ -51,8 +51,10 @@ class _BytesIOWriter(DataWriter): @pytest.fixture def insert_writer() -> Iterator[DataWriter]: + from dlt.destinations import redshift + with io.StringIO() as f: - yield InsertValuesWriter(f, caps=redshift_caps()) + yield InsertValuesWriter(f, caps=redshift().capabilities()) @pytest.fixture @@ -154,7 +156,7 @@ def test_identifier_escape() -> None: def test_identifier_escape_bigquery() -> None: assert ( - escape_bigquery_identifier(", NULL'); DROP TABLE\"` -\\-") + escape_hive_identifier(", NULL'); DROP TABLE\"` -\\-") == "`, NULL'); DROP TABLE\"\\` -\\\\-`" ) @@ -259,3 +261,14 @@ def test_get_best_writer() -> None: assert WRITER_SPECS[get_best_writer_spec("arrow", "insert_values")] == ArrowToInsertValuesWriter with pytest.raises(DataWriterNotFound): get_best_writer_spec("arrow", "tsv") # type: ignore + + +def test_import_file_writer() -> None: + spec = create_import_spec("jsonl", ["jsonl"]) + assert spec.data_item_format == "file" + assert spec.file_format == "jsonl" + writer = DataWriter.writer_class_from_spec(spec) + assert writer is ImportFileWriter + w_ = writer(None) + with pytest.raises(NotImplementedError): + w_.write_header(None) diff --git a/tests/common/normalizers/custom_normalizers.py b/tests/common/normalizers/custom_normalizers.py index 3ae65c8b53..4a0f456eef 100644 --- a/tests/common/normalizers/custom_normalizers.py +++ b/tests/common/normalizers/custom_normalizers.py @@ -11,6 +11,13 @@ def normalize_identifier(self, identifier: str) -> str: return "column_" + identifier.lower() +class ColumnNamingConvention(SnakeCaseNamingConvention): + def normalize_identifier(self, identifier: str) -> str: + if identifier.startswith("column_"): + return identifier + return "column_" + identifier.lower() + + class DataItemNormalizer(RelationalNormalizer): def extend_schema(self) -> None: json_config = self.schema._normalizers_config["json"]["config"] diff --git a/tests/common/normalizers/test_import_normalizers.py b/tests/common/normalizers/test_import_normalizers.py index df6b973943..fe356de327 100644 --- a/tests/common/normalizers/test_import_normalizers.py +++ b/tests/common/normalizers/test_import_normalizers.py @@ -1,14 +1,23 @@ import os - import pytest from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.normalizers import explicit_normalizers, import_normalizers +from dlt.common.normalizers.typing import TNormalizersConfig +from dlt.common.normalizers.utils import ( + DEFAULT_NAMING_NAMESPACE, + explicit_normalizers, + import_normalizers, + naming_from_reference, + serialize_reference, +) from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer -from dlt.common.normalizers.naming import snake_case -from dlt.common.normalizers.naming import direct -from dlt.common.normalizers.naming.exceptions import InvalidNamingModule, UnknownNamingModule +from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.normalizers.naming.exceptions import ( + InvalidNamingType, + NamingTypeNotFound, + UnknownNamingModule, +) from tests.common.normalizers.custom_normalizers import ( DataItemNormalizer as CustomRelationalNormalizer, @@ -16,7 +25,7 @@ from tests.utils import preserve_environ -def test_default_normalizers() -> None: +def test_explicit_normalizers() -> None: config = explicit_normalizers() assert config["names"] is None assert config["json"] is None @@ -26,6 +35,12 @@ def test_default_normalizers() -> None: assert config["names"] == "direct" assert config["json"] == {"module": "custom"} + # pass modules and types, make sure normalizer config is serialized + config = explicit_normalizers(direct) + assert config["names"] == f"{DEFAULT_NAMING_NAMESPACE}.direct.NamingConvention" + config = explicit_normalizers(direct.NamingConvention) + assert config["names"] == f"{DEFAULT_NAMING_NAMESPACE}.direct.NamingConvention" + # use environ os.environ["SCHEMA__NAMING"] = "direct" os.environ["SCHEMA__JSON_NORMALIZER"] = '{"module": "custom"}' @@ -34,13 +49,75 @@ def test_default_normalizers() -> None: assert config["json"] == {"module": "custom"} -def test_default_normalizers_with_caps() -> None: +def test_explicit_normalizers_caps_ignored() -> None: # gets the naming convention from capabilities destination_caps = DestinationCapabilitiesContext.generic_capabilities() destination_caps.naming_convention = "direct" with Container().injectable_context(destination_caps): config = explicit_normalizers() - assert config["names"] == "direct" + assert config["names"] is None + + +def test_serialize_reference() -> None: + assert serialize_reference(None) is None + assert serialize_reference("module") == "module" + assert ( + serialize_reference(snake_case) == f"{DEFAULT_NAMING_NAMESPACE}.snake_case.NamingConvention" + ) + assert ( + serialize_reference(snake_case.NamingConvention) + == f"{DEFAULT_NAMING_NAMESPACE}.snake_case.NamingConvention" + ) + # test a wrong module and type + with pytest.raises(NamingTypeNotFound): + serialize_reference(pytest) + with pytest.raises(ValueError): + serialize_reference(Container) # type: ignore[arg-type] + + +def test_naming_from_reference() -> None: + assert naming_from_reference("snake_case").name() == "snake_case" + assert naming_from_reference("snake_case.NamingConvention").name() == "snake_case" + + # now not visible + with pytest.raises(UnknownNamingModule): + naming_from_reference("custom_normalizers") + + # temporarily add current file dir to paths and import module that clash with dlt predefined (no path) + import sys + + try: + sys.path.insert(0, os.path.dirname(__file__)) + assert naming_from_reference("custom_normalizers").name() == "custom_normalizers" + assert ( + naming_from_reference("custom_normalizers.NamingConvention").name() + == "custom_normalizers" + ) + assert ( + naming_from_reference("custom_normalizers.ColumnNamingConvention").name() + == "custom_normalizers" + ) + finally: + sys.path.pop(0) + + # non standard location + assert ( + naming_from_reference("dlt.destinations.impl.weaviate.naming").name() + == "dlt.destinations.impl.weaviate.naming" + ) + + # import module + assert naming_from_reference(snake_case).name() == "snake_case" + assert naming_from_reference(snake_case.NamingConvention).name() == "snake_case" + + with pytest.raises(ValueError): + naming_from_reference(snake_case.NamingConvention()) # type: ignore[arg-type] + + # with capabilities + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.max_identifier_length = 120 + naming = naming_from_reference(snake_case.NamingConvention, caps) + assert naming.max_length == 120 def test_import_normalizers() -> None: @@ -64,6 +141,40 @@ def test_import_normalizers() -> None: assert json_normalizer is CustomRelationalNormalizer +def test_import_normalizers_with_defaults() -> None: + explicit = explicit_normalizers() + default_: TNormalizersConfig = { + "names": "dlt.destinations.impl.weaviate.naming", + "json": {"module": "tests.common.normalizers.custom_normalizers"}, + } + config, naming, json_normalizer = import_normalizers(explicit, default_) + + assert config["names"] == "dlt.destinations.impl.weaviate.naming" + assert config["json"] == {"module": "tests.common.normalizers.custom_normalizers"} + assert naming.name() == "dlt.destinations.impl.weaviate.naming" + assert json_normalizer is CustomRelationalNormalizer + + # correctly overrides + explicit["names"] = "sql_cs_v1" + explicit["json"] = {"module": "dlt.common.normalizers.json.relational"} + config, naming, json_normalizer = import_normalizers(explicit, default_) + assert config["names"] == "sql_cs_v1" + assert config["json"] == {"module": "dlt.common.normalizers.json.relational"} + assert naming.name() == "sql_cs_v1" + assert json_normalizer is RelationalNormalizer + + +@pytest.mark.parametrize("sections", ("", "SOURCES__", "SOURCES__TEST_SCHEMA__")) +def test_config_sections(sections: str) -> None: + os.environ[f"{sections}SCHEMA__NAMING"] = "direct" + os.environ[f"{sections}SCHEMA__JSON_NORMALIZER"] = ( + '{"module": "tests.common.normalizers.custom_normalizers"}' + ) + config, _, _ = import_normalizers(explicit_normalizers(schema_name="test_schema")) + assert config["names"] == "direct" + assert config["json"] == {"module": "tests.common.normalizers.custom_normalizers"} + + def test_import_normalizers_with_caps() -> None: # gets the naming convention from capabilities destination_caps = DestinationCapabilitiesContext.generic_capabilities() @@ -74,6 +185,25 @@ def test_import_normalizers_with_caps() -> None: assert isinstance(naming, direct.NamingConvention) assert naming.max_length == 127 + _, naming, _ = import_normalizers(explicit_normalizers(snake_case)) + assert isinstance(naming, snake_case.NamingConvention) + assert naming.max_length == 127 + + # max table nesting generates relational normalizer + default_: TNormalizersConfig = { + "names": "dlt.destinations.impl.weaviate.naming", + "json": {"module": "tests.common.normalizers.custom_normalizers"}, + } + destination_caps.max_table_nesting = 0 + with Container().injectable_context(destination_caps): + config, _, relational = import_normalizers(explicit_normalizers()) + assert config["json"]["config"]["max_nesting"] == 0 + assert relational is RelationalNormalizer + + # wrong normalizer + config, _, relational = import_normalizers(explicit_normalizers(), default_) + assert "config" not in config["json"] + def test_import_invalid_naming_module() -> None: with pytest.raises(UnknownNamingModule) as py_ex: @@ -82,6 +212,7 @@ def test_import_invalid_naming_module() -> None: with pytest.raises(UnknownNamingModule) as py_ex: import_normalizers(explicit_normalizers("dlt.common.tests")) assert py_ex.value.naming_module == "dlt.common.tests" - with pytest.raises(InvalidNamingModule) as py_ex2: - import_normalizers(explicit_normalizers("dlt.pipeline")) + with pytest.raises(InvalidNamingType) as py_ex2: + import_normalizers(explicit_normalizers("dlt.pipeline.helpers")) assert py_ex2.value.naming_module == "dlt.pipeline" + assert py_ex2.value.naming_class == "helpers" diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index 502ce619dd..159e33da4d 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -2,16 +2,15 @@ from dlt.common.typing import StrAny, DictStrAny from dlt.common.normalizers.naming import NamingConvention -from dlt.common.schema.typing import TSimpleRegex +from dlt.common.schema.typing import TColumnName, TSimpleRegex from dlt.common.utils import digest128, uniq_id -from dlt.common.schema import Schema, TTableSchema +from dlt.common.schema import Schema from dlt.common.schema.utils import new_table from dlt.common.normalizers.json.relational import ( RelationalNormalizerConfigPropagation, DataItemNormalizer as RelationalNormalizer, DLT_ID_LENGTH_BYTES, - TDataItemRow, ) # _flatten, _get_child_row_hash, _normalize_row, normalize_data_item, @@ -30,7 +29,7 @@ def test_flatten_fix_field_name(norm: RelationalNormalizer) -> None: "f 2": [], "f!3": {"f4": "a", "f-5": "b", "f*6": {"c": 7, "c v": 8, "c x": []}}, } - flattened_row, lists = norm._flatten("mock_table", row, 0) # type: ignore[arg-type] + flattened_row, lists = norm._flatten("mock_table", row, 0) assert "f_1" in flattened_row # assert "f_2" in flattened_row assert "f_3__f4" in flattened_row @@ -63,12 +62,12 @@ def test_preserve_complex_value(norm: RelationalNormalizer) -> None: ) ) row_1 = {"value": 1} - flattened_row, _ = norm._flatten("with_complex", row_1, 0) # type: ignore[arg-type] - assert flattened_row["value"] == 1 # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("with_complex", row_1, 0) + assert flattened_row["value"] == 1 row_2 = {"value": {"complex": True}} - flattened_row, _ = norm._flatten("with_complex", row_2, 0) # type: ignore[arg-type] - assert flattened_row["value"] == row_2["value"] # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("with_complex", row_2, 0) + assert flattened_row["value"] == row_2["value"] # complex value is not flattened assert "value__complex" not in flattened_row @@ -79,12 +78,12 @@ def test_preserve_complex_value_with_hint(norm: RelationalNormalizer) -> None: norm.schema._compile_settings() row_1 = {"value": 1} - flattened_row, _ = norm._flatten("any_table", row_1, 0) # type: ignore[arg-type] - assert flattened_row["value"] == 1 # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("any_table", row_1, 0) + assert flattened_row["value"] == 1 row_2 = {"value": {"complex": True}} - flattened_row, _ = norm._flatten("any_table", row_2, 0) # type: ignore[arg-type] - assert flattened_row["value"] == row_2["value"] # type: ignore[typeddict-item] + flattened_row, _ = norm._flatten("any_table", row_2, 0) + assert flattened_row["value"] == row_2["value"] # complex value is not flattened assert "value__complex" not in flattened_row @@ -94,7 +93,7 @@ def test_child_table_linking(norm: RelationalNormalizer) -> None: # request _dlt_root_id propagation add_dlt_root_id_propagation(norm) - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) # should have 7 entries (root + level 1 + 3 * list + 2 * object) assert len(rows) == 7 # root elem will not have a root hash if not explicitly added, "extend" is added only to child @@ -142,7 +141,7 @@ def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: norm.schema.merge_hints({"primary_key": [TSimpleRegex("id")]}) norm.schema._compile_settings() - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) root = next(t for t in rows if t[0][0] == "table")[1] # record hash is random for primary keys, not based on their content # this is a change introduced in dlt 0.2.0a30 @@ -172,7 +171,7 @@ def test_yields_parents_first(norm: RelationalNormalizer) -> None: "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], "g": [{"id": "level2_g", "l": ["a"]}], } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) tables = list(r[0][0] for r in rows) # child tables are always yielded before parent tables expected_tables = [ @@ -218,7 +217,7 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: } ], } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) # normalizer must return parent table first and move in order of the list elements when yielding child tables # the yielding order if fully defined expected_parents = [ @@ -276,10 +275,10 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: def test_list_position(norm: RelationalNormalizer) -> None: - row: StrAny = { + row: DictStrAny = { "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}] } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) # root has no pos root = [t for t in rows if t[0][0] == "table"][0][1] assert "_dlt_list_idx" not in root @@ -290,13 +289,13 @@ def test_list_position(norm: RelationalNormalizer) -> None: # f_l must be ordered as it appears in the list for pos, elem in enumerate(["a", "b", "c"]): - row = next(t[1] for t in rows if t[0][0] == "table__f__l" and t[1]["value"] == elem) - assert row["_dlt_list_idx"] == pos + row_1 = next(t[1] for t in rows if t[0][0] == "table__f__l" and t[1]["value"] == elem) + assert row_1["_dlt_list_idx"] == pos # f_lo must be ordered - list of objects for pos, elem in enumerate(["a", "b", "c"]): - row = next(t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["e"] == elem) - assert row["_dlt_list_idx"] == pos + row_2 = next(t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["e"] == elem) + assert row_2["_dlt_list_idx"] == pos # def test_list_of_lists(norm: RelationalNormalizer) -> None: @@ -430,7 +429,7 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: "_dlt_id": row_id, "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}], } - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) children = [t for t in rows if t[0][0] != "table"] # all hashes must be different distinct_hashes = set([ch[1]["_dlt_id"] for ch in children]) @@ -449,19 +448,19 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: assert f_lo_p2["_dlt_id"] == digest128(f"{el_f['_dlt_id']}_table__f__lo_2", DLT_ID_LENGTH_BYTES) # same data with same table and row_id - rows_2 = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows_2 = list(norm._normalize_row(row, {}, ("table",))) children_2 = [t for t in rows_2 if t[0][0] != "table"] # corresponding hashes must be identical assert all(ch[0][1]["_dlt_id"] == ch[1][1]["_dlt_id"] for ch in zip(children, children_2)) # change parent table and all child hashes must be different - rows_4 = list(norm._normalize_row(row, {}, ("other_table",))) # type: ignore[arg-type] + rows_4 = list(norm._normalize_row(row, {}, ("other_table",))) children_4 = [t for t in rows_4 if t[0][0] != "other_table"] assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_4)) # change parent hash and all child hashes must be different row["_dlt_id"] = uniq_id() - rows_3 = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows_3 = list(norm._normalize_row(row, {}, ("table",))) children_3 = [t for t in rows_3 if t[0][0] != "table"] assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_3)) @@ -469,14 +468,16 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: def test_keeps_dlt_id(norm: RelationalNormalizer) -> None: h = uniq_id() row = {"a": "b", "_dlt_id": h} - rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + rows = list(norm._normalize_row(row, {}, ("table",))) root = [t for t in rows if t[0][0] == "table"][0][1] assert root["_dlt_id"] == h def test_propagate_hardcoded_context(norm: RelationalNormalizer) -> None: row = {"level": 1, "list": ["a", "b", "c"], "comp": [{"_timestamp": "a"}]} - rows = list(norm._normalize_row(row, {"_timestamp": 1238.9, "_dist_key": "SENDER_3000"}, ("table",))) # type: ignore[arg-type] + rows = list( + norm._normalize_row(row, {"_timestamp": 1238.9, "_dist_key": "SENDER_3000"}, ("table",)) + ) # context is not added to root element root = next(t for t in rows if t[0][0] == "table")[1] assert "_timestamp" in root @@ -506,7 +507,7 @@ def test_propagates_root_context(norm: RelationalNormalizer) -> None: "dependent_list": [1, 2, 3], "dependent_objects": [{"vx": "ax"}], } - normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # all non-root rows must have: non_root = [r for r in normalized_rows if r[0][1] is not None] assert all(r[1]["_dlt_root_id"] == "###" for r in non_root) @@ -522,12 +523,12 @@ def test_propagates_table_context( prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ "config" ]["propagation"] - prop_config["root"]["timestamp"] = "_partition_ts" # type: ignore[index] + prop_config["root"][TColumnName("timestamp")] = TColumnName("_partition_ts") # for table "table__lvl1" request to propagate "vx" and "partition_ovr" as "_partition_ts" (should overwrite root) - prop_config["tables"]["table__lvl1"] = { # type: ignore[index] - "vx": "__vx", - "partition_ovr": "_partition_ts", - "__not_found": "__not_found", + prop_config["tables"]["table__lvl1"] = { + TColumnName("vx"): TColumnName("__vx"), + TColumnName("partition_ovr"): TColumnName("_partition_ts"), + TColumnName("__not_found"): TColumnName("__not_found"), } if add_pk: @@ -545,7 +546,7 @@ def test_propagates_table_context( # to reproduce a bug where rows with _dlt_id set were not extended row["lvl1"][0]["_dlt_id"] = "row_id_lvl1" # type: ignore[index] - normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) non_root = [r for r in normalized_rows if r[0][1] is not None] # _dlt_root_id in all non root assert all(r[1]["_dlt_root_id"] == "###" for r in non_root) @@ -574,10 +575,10 @@ def test_propagates_table_context_to_lists(norm: RelationalNormalizer) -> None: prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ "config" ]["propagation"] - prop_config["root"]["timestamp"] = "_partition_ts" # type: ignore[index] + prop_config["root"][TColumnName("timestamp")] = TColumnName("_partition_ts") row = {"_dlt_id": "###", "timestamp": 12918291.1212, "lvl1": [1, 2, 3, [4, 5, 6]]} - normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("table",))) # _partition_ts == timestamp on all child tables non_root = [r for r in normalized_rows if r[0][1] is not None] assert all(r[1]["_partition_ts"] == 12918291.1212 for r in non_root) @@ -590,7 +591,7 @@ def test_removes_normalized_list(norm: RelationalNormalizer) -> None: # after normalizing the list that got normalized into child table must be deleted row = {"comp": [{"_timestamp": "a"}]} # get iterator - normalized_rows_i = norm._normalize_row(row, {}, ("table",)) # type: ignore[arg-type] + normalized_rows_i = norm._normalize_row(row, {}, ("table",)) # yield just one item root_row = next(normalized_rows_i) # root_row = next(r for r in normalized_rows if r[0][1] is None) @@ -614,7 +615,7 @@ def test_preserves_complex_types_list(norm: RelationalNormalizer) -> None: ) ) row = {"value": ["from", {"complex": True}]} - normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # make sure only 1 row is emitted, the list is not normalized assert len(normalized_rows) == 1 # value is kept in root row -> market as complex @@ -623,7 +624,7 @@ def test_preserves_complex_types_list(norm: RelationalNormalizer) -> None: # same should work for a list row = {"value": ["from", ["complex", True]]} # type: ignore[list-item] - normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # type: ignore[arg-type] + normalized_rows = list(norm._normalize_row(row, {}, ("event_slot",))) # make sure only 1 row is emitted, the list is not normalized assert len(normalized_rows) == 1 # value is kept in root row -> market as complex @@ -735,7 +736,7 @@ def test_table_name_meta_normalized() -> None: def test_parse_with_primary_key() -> None: schema = create_schema_with_name("discord") - schema.merge_hints({"primary_key": ["id"]}) # type: ignore[list-item] + schema._merge_hints({"primary_key": ["id"]}) # type: ignore[list-item] schema._compile_settings() add_dlt_root_id_propagation(schema.data_item_normalizer) # type: ignore[arg-type] diff --git a/tests/common/normalizers/test_naming.py b/tests/common/normalizers/test_naming.py index 3bf4762c35..84d36537e6 100644 --- a/tests/common/normalizers/test_naming.py +++ b/tests/common/normalizers/test_naming.py @@ -2,13 +2,29 @@ import string from typing import List, Type -from dlt.common.normalizers.naming import NamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention -from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention +from dlt.common.normalizers.naming import ( + NamingConvention, + snake_case, + direct, + duck_case, + sql_ci_v1, + sql_cs_v1, +) from dlt.common.typing import DictStrStr from dlt.common.utils import uniq_id +ALL_NAMING_CONVENTIONS = { + snake_case.NamingConvention, + direct.NamingConvention, + duck_case.NamingConvention, + sql_ci_v1.NamingConvention, + sql_cs_v1.NamingConvention, +} + +ALL_UNDERSCORE_PATH_CONVENTIONS = ALL_NAMING_CONVENTIONS - {direct.NamingConvention} + + LONG_PATH = "prospects_external_data__data365_member__member__feed_activities_created_post__items__comments__items__comments__items__author_details__educations" DENSE_PATH = "__".join(string.ascii_lowercase) LONG_IDENT = 10 * string.printable @@ -139,7 +155,7 @@ def test_shorten_identifier() -> None: assert len(norm_ident) == 20 -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) -> None: naming = convention() # None/empty ident raises @@ -164,7 +180,7 @@ def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) - assert tag in naming.normalize_identifier(RAW_IDENT) -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_normalize_path_shorting(convention: Type[NamingConvention]) -> None: naming = convention() path = naming.make_path(*LONG_PATH.split("__")) @@ -207,10 +223,11 @@ def test_normalize_path_shorting(convention: Type[NamingConvention]) -> None: assert len(naming.break_path(norm_path)) == 1 -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_normalize_path(convention: Type[NamingConvention]) -> None: naming = convention() raw_path_str = naming.make_path(*RAW_PATH) + assert convention.PATH_SEPARATOR in raw_path_str # count separators norm_path_str = naming.normalize_path(raw_path_str) assert len(naming.break_path(norm_path_str)) == len(RAW_PATH) @@ -248,7 +265,7 @@ def test_normalize_path(convention: Type[NamingConvention]) -> None: assert tag in tagged_raw_path_str -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) def test_shorten_fragments(convention: Type[NamingConvention]) -> None: # max length around the length of the path naming = convention() @@ -266,8 +283,30 @@ def test_shorten_fragments(convention: Type[NamingConvention]) -> None: assert naming.shorten_fragments(*RAW_PATH_WITH_EMPTY_IDENT) == norm_path -# 'event__parse_data__response_selector__default__response__response_templates' -# E 'event__parse_data__response_selector__default__response__responses' +@pytest.mark.parametrize("convention", ALL_UNDERSCORE_PATH_CONVENTIONS) +def test_normalize_break_path(convention: Type[NamingConvention]) -> None: + naming_unlimited = convention() + assert naming_unlimited.break_path("A__B__C") == ["A", "B", "C"] + # what if path has _a and _b which valid normalized idents + assert naming_unlimited.break_path("_a___b__C___D") == ["_a", "_b", "C", "_D"] + # skip empty identifiers from path + assert naming_unlimited.break_path("_a_____b") == ["_a", "_b"] + assert naming_unlimited.break_path("_a____b") == ["_a", "b"] + assert naming_unlimited.break_path("_a__ \t\r__b") == ["_a", "b"] + + +@pytest.mark.parametrize("convention", ALL_UNDERSCORE_PATH_CONVENTIONS) +def test_normalize_make_path(convention: Type[NamingConvention]) -> None: + naming_unlimited = convention() + assert naming_unlimited.make_path("A", "B") == "A__B" + assert naming_unlimited.make_path("_A", "_B") == "_A___B" + assert naming_unlimited.make_path("_A", "", "_B") == "_A___B" + assert naming_unlimited.make_path("_A", "\t\n ", "_B") == "_A___B" + + +def test_naming_convention_name() -> None: + assert snake_case.NamingConvention.name() == "snake_case" + assert direct.NamingConvention.name() == "direct" def assert_short_path(norm_path: str, naming: NamingConvention) -> None: diff --git a/tests/common/normalizers/test_naming_snake_case.py b/tests/common/normalizers/test_naming_snake_case.py index 6d619b5257..ee4f43e7f0 100644 --- a/tests/common/normalizers/test_naming_snake_case.py +++ b/tests/common/normalizers/test_naming_snake_case.py @@ -1,9 +1,7 @@ -from typing import Type import pytest from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention -from dlt.common.normalizers.naming.duck_case import NamingConvention as DuckCaseNamingConvention @pytest.fixture @@ -54,30 +52,9 @@ def test_normalize_path(naming_unlimited: NamingConvention) -> None: def test_normalize_non_alpha_single_underscore() -> None: - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "-=!*") == "_" - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!0*-") == "1_0_" - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!_0*-") == "1__0_" - - -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention)) -def test_normalize_break_path(convention: Type[NamingConvention]) -> None: - naming_unlimited = convention() - assert naming_unlimited.break_path("A__B__C") == ["A", "B", "C"] - # what if path has _a and _b which valid normalized idents - assert naming_unlimited.break_path("_a___b__C___D") == ["_a", "_b", "C", "_D"] - # skip empty identifiers from path - assert naming_unlimited.break_path("_a_____b") == ["_a", "_b"] - assert naming_unlimited.break_path("_a____b") == ["_a", "b"] - assert naming_unlimited.break_path("_a__ \t\r__b") == ["_a", "b"] - - -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention)) -def test_normalize_make_path(convention: Type[NamingConvention]) -> None: - naming_unlimited = convention() - assert naming_unlimited.make_path("A", "B") == "A__B" - assert naming_unlimited.make_path("_A", "_B") == "_A___B" - assert naming_unlimited.make_path("_A", "", "_B") == "_A___B" - assert naming_unlimited.make_path("_A", "\t\n ", "_B") == "_A___B" + assert SnakeCaseNamingConvention.RE_NON_ALPHANUMERIC.sub("_", "-=!*") == "_" + assert SnakeCaseNamingConvention.RE_NON_ALPHANUMERIC.sub("_", "1-=!0*-") == "1_0_" + assert SnakeCaseNamingConvention.RE_NON_ALPHANUMERIC.sub("_", "1-=!_0*-") == "1__0_" def test_normalizes_underscores(naming_unlimited: NamingConvention) -> None: diff --git a/tests/common/normalizers/test_naming_sql.py b/tests/common/normalizers/test_naming_sql.py new file mode 100644 index 0000000000..c290354c6a --- /dev/null +++ b/tests/common/normalizers/test_naming_sql.py @@ -0,0 +1,55 @@ +import pytest +from typing import Type +from dlt.common.normalizers.naming import NamingConvention, sql_ci_v1, sql_cs_v1 + +ALL_NAMING_CONVENTIONS = {sql_ci_v1.NamingConvention, sql_cs_v1.NamingConvention} + + +@pytest.mark.parametrize("convention", ALL_NAMING_CONVENTIONS) +def test_normalize_identifier(convention: Type[NamingConvention]) -> None: + naming = convention() + assert naming.normalize_identifier("event_value") == "event_value" + assert naming.normalize_identifier("event value") == "event_value" + assert naming.normalize_identifier("event-.!:*<>value") == "event_value" + # prefix leading digits + assert naming.normalize_identifier("1event_n'") == "_1event_n" + # remove trailing underscores + assert naming.normalize_identifier("123event_n'") == "_123event_n" + # contract underscores + assert naming.normalize_identifier("___a___b") == "_a_b" + # trim spaces + assert naming.normalize_identifier(" small love potion ") == "small_love_potion" + + # special characters converted to _ + assert naming.normalize_identifier("+-!$*@#=|:") == "_" + # leave single underscore + assert naming.normalize_identifier("_") == "_" + # some other cases + assert naming.normalize_identifier("+1") == "_1" + assert naming.normalize_identifier("-1") == "_1" + + +def test_case_sensitive_normalize() -> None: + naming = sql_cs_v1.NamingConvention() + # all lowercase and converted to snake + assert naming.normalize_identifier("123BaNaNa") == "_123BaNaNa" + # consecutive capital letters + assert naming.normalize_identifier("BANANA") == "BANANA" + assert naming.normalize_identifier("BAN_ANA") == "BAN_ANA" + assert naming.normalize_identifier("BANaNA") == "BANaNA" + # handling spaces + assert naming.normalize_identifier("Small Love Potion") == "Small_Love_Potion" + assert naming.normalize_identifier(" Small Love Potion ") == "Small_Love_Potion" + + +def test_case_insensitive_normalize() -> None: + naming = sql_ci_v1.NamingConvention() + # all lowercase and converted to snake + assert naming.normalize_identifier("123BaNaNa") == "_123banana" + # consecutive capital letters + assert naming.normalize_identifier("BANANA") == "banana" + assert naming.normalize_identifier("BAN_ANA") == "ban_ana" + assert naming.normalize_identifier("BANaNA") == "banana" + # handling spaces + assert naming.normalize_identifier("Small Love Potion") == "small_love_potion" + assert naming.normalize_identifier(" Small Love Potion ") == "small_love_potion" diff --git a/tests/common/schema/conftest.py b/tests/common/schema/conftest.py new file mode 100644 index 0000000000..53d02fc663 --- /dev/null +++ b/tests/common/schema/conftest.py @@ -0,0 +1,25 @@ +import pytest + +from dlt.common.configuration import resolve_configuration +from dlt.common.schema import Schema +from dlt.common.storages import SchemaStorageConfiguration, SchemaStorage + + +from tests.utils import autouse_test_storage, preserve_environ + + +@pytest.fixture +def schema() -> Schema: + return Schema("event") + + +@pytest.fixture +def schema_storage() -> SchemaStorage: + C = resolve_configuration( + SchemaStorageConfiguration(), + explicit_value={ + "import_schema_path": "tests/common/cases/schemas/rasa", + "external_schema_format": "json", + }, + ) + return SchemaStorage(C, makedirs=True) diff --git a/tests/common/schema/test_filtering.py b/tests/common/schema/test_filtering.py index 8cfac9309f..6634a38aa6 100644 --- a/tests/common/schema/test_filtering.py +++ b/tests/common/schema/test_filtering.py @@ -10,11 +10,6 @@ from tests.common.utils import load_json_case -@pytest.fixture -def schema() -> Schema: - return Schema("event") - - def test_row_field_filter(schema: Schema) -> None: _add_excludes(schema) bot_case: DictStrAny = load_json_case("mod_bot_case") diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 0a40953f53..1540d8a74a 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -1,3 +1,4 @@ +import os import pytest from copy import deepcopy from typing import Any, List @@ -6,21 +7,17 @@ from dlt.common import Wei, Decimal, pendulum, json from dlt.common.json import custom_pua_decode from dlt.common.schema import Schema, utils -from dlt.common.schema.typing import TSimpleRegex +from dlt.common.schema.typing import TSimpleRegex, TTableSchemaColumns from dlt.common.schema.exceptions import ( CannotCoerceColumnException, CannotCoerceNullException, ParentTableNotFoundException, + SchemaCorruptedException, TablePropertiesConflictException, ) from tests.common.utils import load_json_case -@pytest.fixture -def schema() -> Schema: - return Schema("event") - - def test_get_preferred_type(schema: Schema) -> None: _add_preferred_types(schema) @@ -204,11 +201,10 @@ def test_shorten_variant_column(schema: Schema) -> None: } _, new_table = schema.coerce_row("event_user", None, row_1) # schema assumes that identifiers are already normalized so confidence even if it is longer than 9 chars - schema.update_table(new_table) + schema.update_table(new_table, normalize_identifiers=False) assert "confidence" in schema.tables["event_user"]["columns"] # confidence_123456 # now variant is created and this will be normalized - # TODO: we should move the handling of variants to normalizer new_row_2, new_table = schema.coerce_row("event_user", None, {"confidence": False}) tag = schema.naming._compute_tag( "confidence__v_bool", collision_prob=schema.naming._DEFAULT_COLLISION_PROB @@ -219,6 +215,9 @@ def test_shorten_variant_column(schema: Schema) -> None: def test_coerce_complex_variant(schema: Schema) -> None: + # for this test use case sensitive naming convention + os.environ["SCHEMA__NAMING"] = "direct" + schema.update_normalizers() # create two columns to which complex type cannot be coerced row = {"floatX": 78172.128, "confidenceX": 1.2, "strX": "STR"} new_row, new_table = schema.coerce_row("event_user", None, row) @@ -252,12 +251,12 @@ def test_coerce_complex_variant(schema: Schema) -> None: c_new_columns_v = list(c_new_table_v["columns"].values()) # two new variant columns added assert len(c_new_columns_v) == 2 - assert c_new_columns_v[0]["name"] == "floatX__v_complex" - assert c_new_columns_v[1]["name"] == "confidenceX__v_complex" + assert c_new_columns_v[0]["name"] == "floatX▶v_complex" + assert c_new_columns_v[1]["name"] == "confidenceX▶v_complex" assert c_new_columns_v[0]["variant"] is True assert c_new_columns_v[1]["variant"] is True - assert c_new_row_v["floatX__v_complex"] == v_list - assert c_new_row_v["confidenceX__v_complex"] == v_dict + assert c_new_row_v["floatX▶v_complex"] == v_list + assert c_new_row_v["confidenceX▶v_complex"] == v_dict assert c_new_row_v["strX"] == json.dumps(v_dict) schema.update_table(c_new_table_v) @@ -265,8 +264,8 @@ def test_coerce_complex_variant(schema: Schema) -> None: c_row_v = {"floatX": v_list, "confidenceX": v_dict, "strX": v_dict} c_new_row_v, c_new_table_v = schema.coerce_row("event_user", None, c_row_v) assert c_new_table_v is None - assert c_new_row_v["floatX__v_complex"] == v_list - assert c_new_row_v["confidenceX__v_complex"] == v_dict + assert c_new_row_v["floatX▶v_complex"] == v_list + assert c_new_row_v["confidenceX▶v_complex"] == v_dict assert c_new_row_v["strX"] == json.dumps(v_dict) @@ -539,7 +538,7 @@ def test_infer_on_incomplete_column(schema: Schema) -> None: incomplete_col["primary_key"] = True incomplete_col["x-special"] = "spec" # type: ignore[typeddict-unknown-key] table = utils.new_table("table", columns=[incomplete_col]) - schema.update_table(table) + schema.update_table(table, normalize_identifiers=False) # make sure that column is still incomplete and has no default hints assert schema.get_table("table")["columns"]["I"] == { "name": "I", @@ -586,3 +585,59 @@ def test_update_table_adds_at_end(schema: Schema) -> None: table = schema.tables["eth"] # place new columns at the end assert list(table["columns"].keys()) == ["evm", "_dlt_load_id"] + + +def test_get_new_columns(schema: Schema) -> None: + # allow for casing in names + os.environ["SCHEMA__NAMING"] = "direct" + schema.update_normalizers() + + empty_table = utils.new_table("events") + schema.update_table(empty_table) + assert schema.get_new_table_columns("events", {}, case_sensitive=True) == [] + name_column = utils.new_column("name", "text") + id_column = utils.new_column("ID", "text") + existing_columns: TTableSchemaColumns = { + "id": id_column, + "name": name_column, + } + # no new columns + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=True) == [] + # one new column + address_column = utils.new_column("address", "complex") + schema.update_table(utils.new_table("events", columns=[address_column])) + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=True) == [ + address_column + ] + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=False) == [ + address_column + ] + # name is already present + schema.update_table(utils.new_table("events", columns=[name_column])) + # so it is not detected + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=True) == [ + address_column + ] + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=False) == [ + address_column + ] + # id is added with different casing + ID_column = utils.new_column("ID", "text") + schema.update_table(utils.new_table("events", columns=[ID_column])) + # case sensitive will detect + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=True) == [ + address_column, + ID_column, + ] + # insensitive doesn't + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=False) == [ + address_column + ] + + # existing columns are case sensitive + existing_columns["ID"] = ID_column + assert schema.get_new_table_columns("events", existing_columns, case_sensitive=True) == [ + address_column + ] + with pytest.raises(SchemaCorruptedException): + schema.get_new_table_columns("events", existing_columns, case_sensitive=False) diff --git a/tests/common/schema/test_merges.py b/tests/common/schema/test_merges.py index 8516414abd..893fd1db5f 100644 --- a/tests/common/schema/test_merges.py +++ b/tests/common/schema/test_merges.py @@ -2,10 +2,9 @@ import pytest from copy import copy, deepcopy -from dlt.common.schema import Schema, utils +from dlt.common.schema import utils from dlt.common.schema.exceptions import ( CannotCoerceColumnException, - CannotCoerceNullException, TablePropertiesConflictException, ) from dlt.common.schema.typing import TColumnSchemaBase, TStoredSchema, TTableSchema, TColumnSchema @@ -294,10 +293,10 @@ def test_diff_tables() -> None: empty = utils.new_table("table") del empty["resource"] print(empty) - partial = utils.diff_table(empty, deepcopy(table)) + partial = utils.diff_table("schema", empty, deepcopy(table)) # partial is simply table assert partial == table - partial = utils.diff_table(deepcopy(table), empty) + partial = utils.diff_table("schema", deepcopy(table), empty) # partial is empty assert partial == empty @@ -305,7 +304,7 @@ def test_diff_tables() -> None: changed = deepcopy(table) changed["description"] = "new description" changed["name"] = "new name" - partial = utils.diff_table(deepcopy(table), changed) + partial = utils.diff_table("schema", deepcopy(table), changed) print(partial) assert partial == {"name": "new name", "description": "new description", "columns": {}} @@ -313,7 +312,7 @@ def test_diff_tables() -> None: existing = deepcopy(table) changed["write_disposition"] = "append" changed["schema_contract"] = "freeze" - partial = utils.diff_table(deepcopy(existing), changed) + partial = utils.diff_table("schema", deepcopy(existing), changed) assert partial == { "name": "new name", "description": "new description", @@ -323,14 +322,14 @@ def test_diff_tables() -> None: } existing["write_disposition"] = "append" existing["schema_contract"] = "freeze" - partial = utils.diff_table(deepcopy(existing), changed) + partial = utils.diff_table("schema", deepcopy(existing), changed) assert partial == {"name": "new name", "description": "new description", "columns": {}} # detect changed column existing = deepcopy(table) changed = deepcopy(table) changed["columns"]["test"]["cluster"] = True - partial = utils.diff_table(existing, changed) + partial = utils.diff_table("schema", existing, changed) assert "test" in partial["columns"] assert "test_2" not in partial["columns"] assert existing["columns"]["test"] == table["columns"]["test"] != partial["columns"]["test"] @@ -339,7 +338,7 @@ def test_diff_tables() -> None: existing = deepcopy(table) changed = deepcopy(table) changed["columns"]["test"]["foreign_key"] = False - partial = utils.diff_table(existing, changed) + partial = utils.diff_table("schema", existing, changed) assert "test" in partial["columns"] # even if not present in tab_a at all @@ -347,7 +346,7 @@ def test_diff_tables() -> None: changed = deepcopy(table) changed["columns"]["test"]["foreign_key"] = False del existing["columns"]["test"]["foreign_key"] - partial = utils.diff_table(existing, changed) + partial = utils.diff_table("schema", existing, changed) assert "test" in partial["columns"] @@ -363,7 +362,7 @@ def test_diff_tables_conflicts() -> None: other = utils.new_table("table_2") with pytest.raises(TablePropertiesConflictException) as cf_ex: - utils.diff_table(table, other) + utils.diff_table("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "parent" @@ -371,7 +370,7 @@ def test_diff_tables_conflicts() -> None: changed = deepcopy(table) changed["columns"]["test"]["data_type"] = "bigint" with pytest.raises(CannotCoerceColumnException): - utils.diff_table(table, changed) + utils.diff_table("schema", table, changed) def test_merge_tables() -> None: @@ -391,7 +390,7 @@ def test_merge_tables() -> None: changed["new-prop-3"] = False # type: ignore[typeddict-unknown-key] # drop column so partial has it del table["columns"]["test"] - partial = utils.merge_table(table, changed) + partial = utils.merge_table("schema", table, changed) assert "test" in table["columns"] assert table["x-special"] == 129 # type: ignore[typeddict-item] assert table["description"] == "new description" @@ -420,7 +419,7 @@ def test_merge_tables_incomplete_columns() -> None: changed["columns"] = deepcopy({"test": COL_1_HINTS, "test_2": COL_2_HINTS}) # it is completed now changed["columns"]["test_2"]["data_type"] = "bigint" - partial = utils.merge_table(table, changed) + partial = utils.merge_table("schema", table, changed) assert list(partial["columns"].keys()) == ["test_2"] # test_2 goes to the end, it was incomplete in table so it got dropped before update assert list(table["columns"].keys()) == ["test", "test_2"] @@ -435,7 +434,7 @@ def test_merge_tables_incomplete_columns() -> None: changed["columns"] = deepcopy({"test": COL_1_HINTS, "test_2": COL_2_HINTS}) # still incomplete but changed changed["columns"]["test_2"]["nullable"] = False - partial = utils.merge_table(table, changed) + partial = utils.merge_table("schema", table, changed) assert list(partial["columns"].keys()) == ["test_2"] # incomplete -> incomplete stays in place assert list(table["columns"].keys()) == ["test_2", "test"] diff --git a/tests/common/schema/test_normalize_identifiers.py b/tests/common/schema/test_normalize_identifiers.py new file mode 100644 index 0000000000..646a693ea6 --- /dev/null +++ b/tests/common/schema/test_normalize_identifiers.py @@ -0,0 +1,421 @@ +from copy import deepcopy +import os +from typing import Callable +import pytest + +from dlt.common import json +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.container import Container +from dlt.common.normalizers.naming.naming import NamingConvention +from dlt.common.storages import SchemaStorageConfiguration +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils +from dlt.common.schema.exceptions import TableIdentifiersFrozen +from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX +from dlt.common.storages import SchemaStorage + +from tests.common.cases.normalizers import sql_upper +from tests.common.utils import load_json_case, load_yml_case + + +@pytest.fixture +def schema_storage_no_import() -> SchemaStorage: + C = resolve_configuration(SchemaStorageConfiguration()) + return SchemaStorage(C, makedirs=True) + + +@pytest.fixture +def cn_schema() -> Schema: + return Schema( + "column_default", + { + "names": "tests.common.normalizers.custom_normalizers", + "json": { + "module": "tests.common.normalizers.custom_normalizers", + "config": {"not_null": ["fake_id"]}, + }, + }, + ) + + +def test_save_store_schema_custom_normalizers( + cn_schema: Schema, schema_storage: SchemaStorage +) -> None: + schema_storage.save_schema(cn_schema) + schema_copy = schema_storage.load_schema(cn_schema.name) + assert_new_schema_values_custom_normalizers(schema_copy) + + +def test_new_schema_custom_normalizers(cn_schema: Schema) -> None: + assert_new_schema_values_custom_normalizers(cn_schema) + + +def test_save_load_incomplete_column( + schema: Schema, schema_storage_no_import: SchemaStorage +) -> None: + # make sure that incomplete column is saved and restored without default hints + incomplete_col = utils.new_column("I", nullable=False) + incomplete_col["primary_key"] = True + incomplete_col["x-special"] = "spec" # type: ignore[typeddict-unknown-key] + table = utils.new_table("table", columns=[incomplete_col]) + schema.update_table(table, normalize_identifiers=False) + schema_storage_no_import.save_schema(schema) + schema_copy = schema_storage_no_import.load_schema("event") + assert schema_copy.get_table("table")["columns"]["I"] == { + "name": "I", + "nullable": False, + "primary_key": True, + "x-special": "spec", + } + + +def test_schema_config_normalizers(schema: Schema, schema_storage_no_import: SchemaStorage) -> None: + # save snake case schema + assert schema._normalizers_config["names"] == "snake_case" + schema_storage_no_import.save_schema(schema) + # config direct naming convention + os.environ["SCHEMA__NAMING"] = "direct" + # new schema has direct naming convention + schema_direct_nc = Schema("direct_naming") + schema_storage_no_import.save_schema(schema_direct_nc) + assert schema_direct_nc._normalizers_config["names"] == "direct" + # still after loading the config is "snake" + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "snake_case" + # provide capabilities context + destination_caps = DestinationCapabilitiesContext.generic_capabilities() + destination_caps.naming_convention = "sql_cs_v1" + destination_caps.max_identifier_length = 127 + with Container().injectable_context(destination_caps): + # caps are ignored if schema is configured + schema_direct_nc = Schema("direct_naming") + assert schema_direct_nc._normalizers_config["names"] == "direct" + # but length is there + assert schema_direct_nc.naming.max_length == 127 + # when loading schema configuration is ignored + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "snake_case" + assert schema.naming.max_length == 127 + # but if we ask to update normalizers config schema is applied + schema.update_normalizers() + assert schema._normalizers_config["names"] == "direct" + + # load schema_direct_nc (direct) + schema_direct_nc = schema_storage_no_import.load_schema(schema_direct_nc.name) + assert schema_direct_nc._normalizers_config["names"] == "direct" + + # drop config + del os.environ["SCHEMA__NAMING"] + schema_direct_nc = schema_storage_no_import.load_schema(schema_direct_nc.name) + assert schema_direct_nc._normalizers_config["names"] == "direct" + + +def test_schema_normalizers_no_config( + schema: Schema, schema_storage_no_import: SchemaStorage +) -> None: + # convert schema to direct and save + os.environ["SCHEMA__NAMING"] = "direct" + schema.update_normalizers() + assert schema._normalizers_config["names"] == "direct" + schema_storage_no_import.save_schema(schema) + # make sure we drop the config correctly + del os.environ["SCHEMA__NAMING"] + schema_test = Schema("test") + assert schema_test.naming.name() == "snake_case" + # use capabilities without default naming convention + destination_caps = DestinationCapabilitiesContext.generic_capabilities() + assert destination_caps.naming_convention is None + destination_caps.max_identifier_length = 66 + with Container().injectable_context(destination_caps): + schema_in_caps = Schema("schema_in_caps") + assert schema_in_caps._normalizers_config["names"] == "snake_case" + assert schema_in_caps.naming.name() == "snake_case" + assert schema_in_caps.naming.max_length == 66 + schema_in_caps.update_normalizers() + assert schema_in_caps.naming.name() == "snake_case" + # old schema preserves convention when loaded + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "direct" + # update normalizer no effect + schema.update_normalizers() + assert schema._normalizers_config["names"] == "direct" + assert schema.naming.max_length == 66 + + # use caps with default naming convention + destination_caps = DestinationCapabilitiesContext.generic_capabilities() + destination_caps.naming_convention = "sql_cs_v1" + destination_caps.max_identifier_length = 127 + with Container().injectable_context(destination_caps): + schema_in_caps = Schema("schema_in_caps") + # new schema gets convention from caps + assert schema_in_caps._normalizers_config["names"] == "sql_cs_v1" + # old schema preserves convention when loaded + schema = schema_storage_no_import.load_schema(schema.name) + assert schema._normalizers_config["names"] == "direct" + # update changes to caps schema + schema.update_normalizers() + assert schema._normalizers_config["names"] == "sql_cs_v1" + assert schema.naming.max_length == 127 + + +@pytest.mark.parametrize("section", ("SOURCES__SCHEMA__NAMING", "SOURCES__THIS__SCHEMA__NAMING")) +def test_config_with_section(section: str) -> None: + os.environ["SOURCES__OTHER__SCHEMA__NAMING"] = "direct" + os.environ[section] = "sql_cs_v1" + this_schema = Schema("this") + that_schema = Schema("that") + assert this_schema.naming.name() == "sql_cs_v1" + expected_that_schema = ( + "snake_case" if section == "SOURCES__THIS__SCHEMA__NAMING" else "sql_cs_v1" + ) + assert that_schema.naming.name() == expected_that_schema + + # test update normalizers + os.environ[section] = "direct" + expected_that_schema = "snake_case" if section == "SOURCES__THIS__SCHEMA__NAMING" else "direct" + this_schema.update_normalizers() + assert this_schema.naming.name() == "direct" + that_schema.update_normalizers() + assert that_schema.naming.name() == expected_that_schema + + +def test_normalize_table_identifiers() -> None: + # load with snake case + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + issues_table = schema.tables["issues"] + issues_table_str = json.dumps(issues_table) + # normalize table to upper + issues_table_norm = utils.normalize_table_identifiers( + issues_table, sql_upper.NamingConvention() + ) + # nothing got changes in issues table + assert issues_table_str == json.dumps(issues_table) + # check normalization + assert issues_table_norm["name"] == "ISSUES" + assert "REACTIONS___1" in issues_table_norm["columns"] + # subsequent normalization does not change dict + assert issues_table_norm == utils.normalize_table_identifiers( + issues_table_norm, sql_upper.NamingConvention() + ) + + +def test_normalize_table_identifiers_idempotent() -> None: + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + # assert column generated from "reactions/+1" and "-1", it is a valid identifier even with three underscores + assert "reactions___1" in schema.tables["issues"]["columns"] + issues_table = schema.tables["issues"] + # this schema is already normalized so normalization is idempotent + assert schema.tables["issues"] == utils.normalize_table_identifiers(issues_table, schema.naming) + assert schema.tables["issues"] == utils.normalize_table_identifiers( + utils.normalize_table_identifiers(issues_table, schema.naming), schema.naming + ) + + +def test_normalize_table_identifiers_merge_columns() -> None: + # create conflicting columns + table_create = [ + {"name": "case", "data_type": "bigint", "nullable": False, "x-description": "desc"}, + {"name": "Case", "data_type": "double", "nullable": True, "primary_key": True}, + ] + # schema normalizing to snake case will conflict on case and Case + table = utils.new_table("blend", columns=table_create) # type: ignore[arg-type] + table_str = json.dumps(table) + norm_table = utils.normalize_table_identifiers(table, Schema("norm").naming) + # nothing got changed in original table + assert table_str == json.dumps(table) + # only one column + assert len(norm_table["columns"]) == 1 + assert norm_table["columns"]["case"] == { + "nullable": False, # remove default, preserve non default + "primary_key": True, + "name": "case", + "data_type": "double", + "x-description": "desc", + } + + +def test_update_normalizers() -> None: + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + # drop seen data + del schema.tables["issues"]["x-normalizer"] + del schema.tables["issues__labels"]["x-normalizer"] + del schema.tables["issues__assignees"]["x-normalizer"] + # save default hints in original form + default_hints = schema._settings["default_hints"] + + os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper" + schema.update_normalizers() + assert isinstance(schema.naming, sql_upper.NamingConvention) + # print(schema.to_pretty_yaml()) + assert_schema_identifiers_case(schema, str.upper) + + # resource must be old name + assert schema.tables["ISSUES"]["resource"] == "issues" + + # make sure normalizer config is replaced + assert schema._normalizers_config["names"] == "tests.common.cases.normalizers.sql_upper" + assert "allow_identifier_change_on_table_with_data" not in schema._normalizers_config + + # regexes are uppercased + new_default_hints = schema._settings["default_hints"] + for hint, regexes in default_hints.items(): + # same number of hints + assert len(regexes) == len(new_default_hints[hint]) + # but all upper cased + assert set(n.upper() for n in regexes) == set(new_default_hints[hint]) + + +def test_normalize_default_hints(schema_storage_no_import: SchemaStorage) -> None: + # use destination caps to force naming convention + from dlt.common.destination import DestinationCapabilitiesContext + from dlt.common.configuration.container import Container + + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + orig_schema = Schema.from_dict(eth_V9) + # save schema + schema_storage_no_import.save_schema(orig_schema) + + with Container().injectable_context( + DestinationCapabilitiesContext.generic_capabilities(naming_convention=sql_upper) + ) as caps: + assert caps.naming_convention is sql_upper + # creating a schema from dict keeps original normalizers + schema = Schema.from_dict(eth_V9) + assert_schema_identifiers_case(schema, str.lower) + assert schema._normalizers_config["names"].endswith("snake_case") + + # loading from storage keeps storage normalizers + storage_schema = schema_storage_no_import.load_schema("ethereum") + assert_schema_identifiers_case(storage_schema, str.lower) + assert storage_schema._normalizers_config["names"].endswith("snake_case") + + # new schema instance is created using caps/config + new_schema = Schema("new") + assert_schema_identifiers_case(new_schema, str.upper) + assert ( + new_schema._normalizers_config["names"] + == "tests.common.cases.normalizers.sql_upper.NamingConvention" + ) + + # attempt to update normalizers blocked by tables with data + with pytest.raises(TableIdentifiersFrozen): + schema.update_normalizers() + # also cloning with update normalizers + with pytest.raises(TableIdentifiersFrozen): + schema.clone(update_normalizers=True) + + # remove processing hints and normalize + norm_cloned = schema.clone(update_normalizers=True, remove_processing_hints=True) + assert_schema_identifiers_case(norm_cloned, str.upper) + assert ( + norm_cloned._normalizers_config["names"] + == "tests.common.cases.normalizers.sql_upper.NamingConvention" + ) + + norm_schema = Schema.from_dict( + deepcopy(eth_V9), remove_processing_hints=True, bump_version=False + ) + norm_schema.update_normalizers() + assert_schema_identifiers_case(norm_schema, str.upper) + assert ( + norm_schema._normalizers_config["names"] + == "tests.common.cases.normalizers.sql_upper.NamingConvention" + ) + + # both ways of obtaining schemas (cloning, cleaning dict) must generate identical schemas + assert norm_cloned.to_pretty_json() == norm_schema.to_pretty_json() + + # save to storage + schema_storage_no_import.save_schema(norm_cloned) + + # load schema out of caps + storage_schema = schema_storage_no_import.load_schema("ethereum") + assert_schema_identifiers_case(storage_schema, str.upper) + # the instance got converted into + assert storage_schema._normalizers_config["names"].endswith("sql_upper.NamingConvention") + assert storage_schema.stored_version_hash == storage_schema.version_hash + # cloned when bumped must have same version hash + norm_cloned._bump_version() + assert storage_schema.stored_version_hash == norm_cloned.stored_version_hash + + +def test_raise_on_change_identifier_table_with_data() -> None: + schema_dict: TStoredSchema = load_json_case("schemas/github/issues.schema") + schema = Schema.from_dict(schema_dict) # type: ignore[arg-type] + # mark issues table to seen data and change naming to sql upper + issues_table = schema.tables["issues"] + issues_table["x-normalizer"] = {"seen-data": True} + os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper" + with pytest.raises(TableIdentifiersFrozen) as fr_ex: + schema.update_normalizers() + # _dlt_version is the first table to be normalized, and since there are tables + # that have seen data, we consider _dlt_version also be materialized + assert fr_ex.value.table_name == "_dlt_version" + assert isinstance(fr_ex.value.from_naming, snake_case.NamingConvention) + assert isinstance(fr_ex.value.to_naming, sql_upper.NamingConvention) + # try again, get exception (schema was not partially modified) + with pytest.raises(TableIdentifiersFrozen) as fr_ex: + schema.update_normalizers() + + # use special naming convention that only changes column names ending with x to _ + issues_table["columns"]["columnx"] = {"name": "columnx", "data_type": "bigint"} + assert schema.tables["issues"] is issues_table + os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.snake_no_x" + with pytest.raises(TableIdentifiersFrozen) as fr_ex: + schema.update_normalizers() + assert fr_ex.value.table_name == "issues" + # allow to change tables with data + os.environ["SCHEMA__ALLOW_IDENTIFIER_CHANGE_ON_TABLE_WITH_DATA"] = "True" + schema.update_normalizers() + assert schema._normalizers_config["allow_identifier_change_on_table_with_data"] is True + + +def assert_schema_identifiers_case(schema: Schema, casing: Callable[[str], str]) -> None: + for table_name, table in schema.tables.items(): + assert table_name == casing(table_name) == table["name"] + if "parent" in table: + assert table["parent"] == casing(table["parent"]) + for col_name, column in table["columns"].items(): + assert col_name == casing(col_name) == column["name"] + + # make sure table prefixes are set + assert schema._dlt_tables_prefix == casing("_dlt") + assert schema.loads_table_name == casing("_dlt_loads") + assert schema.version_table_name == casing("_dlt_version") + assert schema.state_table_name == casing("_dlt_pipeline_state") + + def _case_regex(regex: str) -> str: + if regex.startswith(SIMPLE_REGEX_PREFIX): + return SIMPLE_REGEX_PREFIX + casing(regex[3:]) + else: + return casing(regex) + + # regexes are uppercased + new_default_hints = schema._settings["default_hints"] + for hint, regexes in new_default_hints.items(): + # but all upper cased + assert set(_case_regex(n) for n in regexes) == set(new_default_hints[hint]) + + +def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: + # check normalizers config + assert schema._normalizers_config["names"] == "tests.common.normalizers.custom_normalizers" + assert ( + schema._normalizers_config["json"]["module"] + == "tests.common.normalizers.custom_normalizers" + ) + # check if schema was extended by json normalizer + assert ["fake_id"] == schema.settings["default_hints"]["not_null"] + # call normalizers + assert schema.naming.normalize_identifier("a") == "column_a" + assert schema.naming.normalize_path("a__b") == "column_a__column_b" + assert schema.naming.normalize_identifier("1A_b") == "column_1a_b" + # assumes elements are normalized + assert schema.naming.make_path("A", "B", "!C") == "A__B__!C" + assert schema.naming.break_path("A__B__!C") == ["A", "B", "!C"] + row = list(schema.normalize_data_item({"bool": True}, "load_id", "a_table")) + assert row[0] == (("a_table", None), {"bool": True}) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 887b0aa9a0..93be165358 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -1,19 +1,17 @@ -from copy import deepcopy import os -from typing import List, Sequence, cast +from typing import Dict, List, Sequence import pytest +from copy import deepcopy from dlt.common import pendulum -from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.container import Container +from dlt.common.json import json +from dlt.common.data_types.typing import TDataType from dlt.common.schema.migrations import migrate_schema -from dlt.common.storages import SchemaStorageConfiguration -from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import DictValidationException -from dlt.common.normalizers.naming import snake_case, direct +from dlt.common.normalizers.naming import snake_case from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id -from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils, TColumnHint +from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils from dlt.common.schema.exceptions import ( InvalidSchemaName, ParentTableNotFoundException, @@ -28,50 +26,12 @@ ) from dlt.common.storages import SchemaStorage -from tests.utils import autouse_test_storage, preserve_environ from tests.common.utils import load_json_case, load_yml_case, COMMON_TEST_CASES_PATH SCHEMA_NAME = "event" EXPECTED_FILE_NAME = f"{SCHEMA_NAME}.schema.json" -@pytest.fixture -def schema_storage() -> SchemaStorage: - C = resolve_configuration( - SchemaStorageConfiguration(), - explicit_value={ - "import_schema_path": "tests/common/cases/schemas/rasa", - "external_schema_format": "json", - }, - ) - return SchemaStorage(C, makedirs=True) - - -@pytest.fixture -def schema_storage_no_import() -> SchemaStorage: - C = resolve_configuration(SchemaStorageConfiguration()) - return SchemaStorage(C, makedirs=True) - - -@pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def cn_schema() -> Schema: - return Schema( - "column_default", - { - "names": "tests.common.normalizers.custom_normalizers", - "json": { - "module": "tests.common.normalizers.custom_normalizers", - "config": {"not_null": ["fake_id"]}, - }, - }, - ) - - def test_normalize_schema_name(schema: Schema) -> None: assert schema.naming.normalize_table_identifier("BAN_ANA") == "ban_ana" assert schema.naming.normalize_table_identifier("event-.!:value") == "event_value" @@ -102,38 +62,6 @@ def test_new_schema(schema: Schema) -> None: utils.validate_stored_schema(stored_schema) -def test_new_schema_custom_normalizers(cn_schema: Schema) -> None: - assert_is_new_schema(cn_schema) - assert_new_schema_props_custom_normalizers(cn_schema) - - -def test_schema_config_normalizers(schema: Schema, schema_storage_no_import: SchemaStorage) -> None: - # save snake case schema - schema_storage_no_import.save_schema(schema) - # config direct naming convention - os.environ["SCHEMA__NAMING"] = "direct" - # new schema has direct naming convention - schema_direct_nc = Schema("direct_naming") - assert schema_direct_nc._normalizers_config["names"] == "direct" - # still after loading the config is "snake" - schema = schema_storage_no_import.load_schema(schema.name) - assert schema._normalizers_config["names"] == "snake_case" - # provide capabilities context - destination_caps = DestinationCapabilitiesContext.generic_capabilities() - destination_caps.naming_convention = "snake_case" - destination_caps.max_identifier_length = 127 - with Container().injectable_context(destination_caps): - # caps are ignored if schema is configured - schema_direct_nc = Schema("direct_naming") - assert schema_direct_nc._normalizers_config["names"] == "direct" - # but length is there - assert schema_direct_nc.naming.max_length == 127 - # also for loaded schema - schema = schema_storage_no_import.load_schema(schema.name) - assert schema._normalizers_config["names"] == "snake_case" - assert schema.naming.max_length == 127 - - def test_simple_regex_validator() -> None: # can validate only simple regexes assert utils.simple_regex_validator(".", "k", "v", str) is False @@ -394,33 +322,6 @@ def test_save_store_schema(schema: Schema, schema_storage: SchemaStorage) -> Non assert_new_schema_props(schema_copy) -def test_save_store_schema_custom_normalizers( - cn_schema: Schema, schema_storage: SchemaStorage -) -> None: - schema_storage.save_schema(cn_schema) - schema_copy = schema_storage.load_schema(cn_schema.name) - assert_new_schema_props_custom_normalizers(schema_copy) - - -def test_save_load_incomplete_column( - schema: Schema, schema_storage_no_import: SchemaStorage -) -> None: - # make sure that incomplete column is saved and restored without default hints - incomplete_col = utils.new_column("I", nullable=False) - incomplete_col["primary_key"] = True - incomplete_col["x-special"] = "spec" # type: ignore[typeddict-unknown-key] - table = utils.new_table("table", columns=[incomplete_col]) - schema.update_table(table) - schema_storage_no_import.save_schema(schema) - schema_copy = schema_storage_no_import.load_schema("event") - assert schema_copy.get_table("table")["columns"]["I"] == { - "name": "I", - "nullable": False, - "primary_key": True, - "x-special": "spec", - } - - def test_upgrade_engine_v1_schema() -> None: schema_dict: DictStrAny = load_json_case("schemas/ev1/event.schema") # ensure engine v1 @@ -479,7 +380,7 @@ def test_unknown_engine_upgrade() -> None: def test_preserve_column_order(schema: Schema, schema_storage: SchemaStorage) -> None: # python dicts are ordered from v3.6, add 50 column with random names update: List[TColumnSchema] = [ - schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50) + schema._infer_column("t" + uniq_id(), pendulum.now().timestamp()) for _ in range(50) ] schema.update_table(utils.new_table("event_test_order", columns=update)) @@ -496,7 +397,7 @@ def verify_items(table, update) -> None: verify_items(table, update) # add more columns update2: List[TColumnSchema] = [ - schema._infer_column(uniq_id(), pendulum.now().timestamp()) for _ in range(50) + schema._infer_column("t" + uniq_id(), pendulum.now().timestamp()) for _ in range(50) ] loaded_schema.update_table(utils.new_table("event_test_order", columns=update2)) table = loaded_schema.get_table_columns("event_test_order") @@ -648,6 +549,79 @@ def test_merge_hints(schema: Schema) -> None: for k in expected_hints: assert set(expected_hints[k]) == set(schema._settings["default_hints"][k]) # type: ignore[index] + # make sure that re:^_dlt_id$ and _dlt_id are equivalent when merging so we can use both forms + alt_form_hints = { + "not_null": ["re:^_dlt_id$"], + "foreign_key": ["_dlt_parent_id"], + } + schema.merge_hints(alt_form_hints) # type: ignore[arg-type] + # we keep the older forms so nothing changed + assert len(expected_hints) == len(schema._settings["default_hints"]) + for k in expected_hints: + assert set(expected_hints[k]) == set(schema._settings["default_hints"][k]) # type: ignore[index] + + # check normalize some regex forms + upper_hints = { + "not_null": [ + "_DLT_ID", + ], + "foreign_key": ["re:^_DLT_PARENT_ID$"], + } + schema.merge_hints(upper_hints) # type: ignore[arg-type] + # all upper form hints can be automatically converted to lower form + assert len(expected_hints) == len(schema._settings["default_hints"]) + for k in expected_hints: + assert set(expected_hints[k]) == set(schema._settings["default_hints"][k]) # type: ignore[index] + + # this form cannot be converted + upper_hints = { + "not_null": [ + "re:TU[b-b]a", + ], + } + schema.merge_hints(upper_hints) # type: ignore[arg-type] + assert "re:TU[b-b]a" in schema.settings["default_hints"]["not_null"] + + +def test_update_preferred_types(schema: Schema) -> None: + # no preferred types in the schema + assert "preferred_types" not in schema.settings + + expected: Dict[TSimpleRegex, TDataType] = { + TSimpleRegex("_dlt_id"): "bigint", + TSimpleRegex("re:^timestamp$"): "timestamp", + } + schema.update_preferred_types(expected) + assert schema.settings["preferred_types"] == expected + # no changes + schema.update_preferred_types(expected) + assert schema.settings["preferred_types"] == expected + + # add and replace, canonical form used to update / replace + updated: Dict[TSimpleRegex, TDataType] = { + TSimpleRegex("_dlt_id"): "decimal", + TSimpleRegex("timestamp"): "date", + TSimpleRegex("re:TU[b-c]a"): "text", + } + schema.update_preferred_types(updated) + assert schema.settings["preferred_types"] == { + "_dlt_id": "decimal", + "re:^timestamp$": "date", + "re:TU[b-c]a": "text", + } + + # will normalize some form of regex + updated = { + TSimpleRegex("_DLT_id"): "text", + TSimpleRegex("re:^TIMESTAMP$"): "timestamp", + } + schema.update_preferred_types(updated) + assert schema.settings["preferred_types"] == { + "_dlt_id": "text", + "re:^timestamp$": "timestamp", + "re:TU[b-c]a": "text", + } + def test_default_table_resource() -> None: """Parent tables without `resource` set default to table name""" @@ -766,9 +740,9 @@ def test_normalize_table_identifiers() -> None: assert "reactions___1" in schema.tables["issues"]["columns"] issues_table = deepcopy(schema.tables["issues"]) # this schema is already normalized so normalization is idempotent - assert schema.tables["issues"] == schema.normalize_table_identifiers(issues_table) - assert schema.tables["issues"] == schema.normalize_table_identifiers( - schema.normalize_table_identifiers(issues_table) + assert schema.tables["issues"] == utils.normalize_table_identifiers(issues_table, schema.naming) + assert schema.tables["issues"] == utils.normalize_table_identifiers( + utils.normalize_table_identifiers(issues_table, schema.naming), schema.naming ) @@ -780,7 +754,10 @@ def test_normalize_table_identifiers_merge_columns() -> None: ] # schema normalizing to snake case will conflict on case and Case table = utils.new_table("blend", columns=table_create) # type: ignore[arg-type] - norm_table = Schema("norm").normalize_table_identifiers(table) + table_str = json.dumps(table) + norm_table = utils.normalize_table_identifiers(table, Schema("norm").naming) + # nothing got changed in original table + assert table_str == json.dumps(table) # only one column assert len(norm_table["columns"]) == 1 assert norm_table["columns"]["case"] == { @@ -859,20 +836,21 @@ def test_group_tables_by_resource(schema: Schema) -> None: schema.update_table(utils.new_table("a_events", columns=[])) schema.update_table(utils.new_table("b_events", columns=[])) schema.update_table(utils.new_table("c_products", columns=[], resource="products")) - schema.update_table(utils.new_table("a_events__1", columns=[], parent_table_name="a_events")) + schema.update_table(utils.new_table("a_events___1", columns=[], parent_table_name="a_events")) schema.update_table( - utils.new_table("a_events__1__2", columns=[], parent_table_name="a_events__1") + utils.new_table("a_events___1___2", columns=[], parent_table_name="a_events___1") ) - schema.update_table(utils.new_table("b_events__1", columns=[], parent_table_name="b_events")) + schema.update_table(utils.new_table("b_events___1", columns=[], parent_table_name="b_events")) + # print(schema.to_pretty_yaml()) # All resources without filter expected_tables = { "a_events": [ schema.tables["a_events"], - schema.tables["a_events__1"], - schema.tables["a_events__1__2"], + schema.tables["a_events___1"], + schema.tables["a_events___1___2"], ], - "b_events": [schema.tables["b_events"], schema.tables["b_events__1"]], + "b_events": [schema.tables["b_events"], schema.tables["b_events___1"]], "products": [schema.tables["c_products"]], "_dlt_version": [schema.tables["_dlt_version"]], "_dlt_loads": [schema.tables["_dlt_loads"]], @@ -887,10 +865,10 @@ def test_group_tables_by_resource(schema: Schema) -> None: assert result == { "a_events": [ schema.tables["a_events"], - schema.tables["a_events__1"], - schema.tables["a_events__1__2"], + schema.tables["a_events___1"], + schema.tables["a_events___1___2"], ], - "b_events": [schema.tables["b_events"], schema.tables["b_events__1"]], + "b_events": [schema.tables["b_events"], schema.tables["b_events___1"]], } # With resources that has many top level tables @@ -919,3 +897,41 @@ def test_group_tables_by_resource(schema: Schema) -> None: {"columns": {}, "name": "mc_products__sub", "parent": "mc_products"}, ] } + + +def test_remove_processing_hints() -> None: + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + # here tables contain processing hints + schema = Schema.from_dict(eth_V9) + assert "x-normalizer" in schema.tables["blocks"] + + # clone with hints removal, note that clone does not bump version + cloned = schema.clone(remove_processing_hints=True) + assert "x-normalizer" not in cloned.tables["blocks"] + # clone does not touch original schema + assert "x-normalizer" in schema.tables["blocks"] + + # to string + to_yaml = schema.to_pretty_yaml() + assert "x-normalizer" in to_yaml + to_yaml = schema.to_pretty_yaml(remove_processing_hints=True) + assert "x-normalizer" not in to_yaml + to_json = schema.to_pretty_json() + assert "x-normalizer" in to_json + to_json = schema.to_pretty_json(remove_processing_hints=True) + assert "x-normalizer" not in to_json + + # load without hints + no_hints = schema.from_dict(eth_V9, remove_processing_hints=True, bump_version=False) + assert no_hints.stored_version_hash == cloned.stored_version_hash + + # now load without hints but with version bump + cloned._bump_version() + no_hints = schema.from_dict(eth_V9, remove_processing_hints=True) + assert no_hints.stored_version_hash == cloned.stored_version_hash + + +# def test_get_new_table_columns() -> None: +# pytest.fail(reason="must implement!") +# pass +# get_new_table_columns() diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index b67b028161..788da09533 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -1,6 +1,5 @@ import pytest import yaml -from copy import deepcopy from dlt.common import json from dlt.common.schema import utils diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index eae765398b..7a10e29097 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -39,38 +39,40 @@ def test_to_relative_path(test_storage: FileStorage) -> None: def test_make_full_path(test_storage: FileStorage) -> None: # fully within storage relative_path = os.path.join("dir", "to", "file") - path = test_storage.make_full_path(relative_path) + path = test_storage.make_full_path_safe(relative_path) assert path.endswith(os.path.join(TEST_STORAGE_ROOT, relative_path)) # overlapped with storage root_path = os.path.join(TEST_STORAGE_ROOT, relative_path) - path = test_storage.make_full_path(root_path) + path = test_storage.make_full_path_safe(root_path) assert path.endswith(root_path) assert path.count(TEST_STORAGE_ROOT) == 2 # absolute path with different root than TEST_STORAGE_ROOT does not lead into storage so calculating full path impossible with pytest.raises(ValueError): - test_storage.make_full_path(os.path.join("/", root_path)) + test_storage.make_full_path_safe(os.path.join("/", root_path)) # relative path out of the root with pytest.raises(ValueError): - test_storage.make_full_path("..") + test_storage.make_full_path_safe("..") # absolute overlapping path - path = test_storage.make_full_path(os.path.abspath(root_path)) + path = test_storage.make_full_path_safe(os.path.abspath(root_path)) assert path.endswith(root_path) - assert test_storage.make_full_path("") == test_storage.storage_path - assert test_storage.make_full_path(".") == test_storage.storage_path + assert test_storage.make_full_path_safe("") == test_storage.storage_path + assert test_storage.make_full_path_safe(".") == test_storage.storage_path def test_in_storage(test_storage: FileStorage) -> None: # always relative to storage root - assert test_storage.in_storage("a/b/c") is True - assert test_storage.in_storage(f"../{TEST_STORAGE_ROOT}/b/c") is True - assert test_storage.in_storage("../a/b/c") is False - assert test_storage.in_storage("../../../a/b/c") is False - assert test_storage.in_storage("/a") is False - assert test_storage.in_storage(".") is True - assert test_storage.in_storage(os.curdir) is True - assert test_storage.in_storage(os.path.realpath(os.curdir)) is False + assert test_storage.is_path_in_storage("a/b/c") is True + assert test_storage.is_path_in_storage(f"../{TEST_STORAGE_ROOT}/b/c") is True + assert test_storage.is_path_in_storage("../a/b/c") is False + assert test_storage.is_path_in_storage("../../../a/b/c") is False + assert test_storage.is_path_in_storage("/a") is False + assert test_storage.is_path_in_storage(".") is True + assert test_storage.is_path_in_storage(os.curdir) is True + assert test_storage.is_path_in_storage(os.path.realpath(os.curdir)) is False assert ( - test_storage.in_storage(os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT)) + test_storage.is_path_in_storage( + os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT) + ) is True ) @@ -164,7 +166,7 @@ def test_rmtree_ro(test_storage: FileStorage) -> None: test_storage.create_folder("protected") path = test_storage.save("protected/barbapapa.txt", "barbapapa") os.chmod(path, stat.S_IREAD) - os.chmod(test_storage.make_full_path("protected"), stat.S_IREAD) + os.chmod(test_storage.make_full_path_safe("protected"), stat.S_IREAD) with pytest.raises(PermissionError): test_storage.delete_folder("protected", recursively=True, delete_ro=False) test_storage.delete_folder("protected", recursively=True, delete_ro=True) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index ecbc5d296d..45bc8d157e 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -8,10 +8,8 @@ from dlt.common import sleep from dlt.common.schema import Schema from dlt.common.storages import PackageStorage, LoadStorage, ParsedLoadJobFileName +from dlt.common.storages.exceptions import LoadPackageAlreadyCompleted, LoadPackageNotCompleted from dlt.common.utils import uniq_id - -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage -from tests.utils import autouse_test_storage from dlt.common.pendulum import pendulum from dlt.common.configuration.container import Container from dlt.common.storages.load_package import ( @@ -23,6 +21,9 @@ clear_destination_state, ) +from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage + def test_is_partially_loaded(load_storage: LoadStorage) -> None: load_id, file_name = start_loading_file( @@ -243,6 +244,177 @@ def test_build_parse_job_path(load_storage: LoadStorage) -> None: ParsedLoadJobFileName.parse("tab.id.wrong_retry.jsonl") +def test_load_package_listings(load_storage: LoadStorage) -> None: + # 100 csv files + load_id = create_load_package(load_storage.new_packages, 100) + new_jobs = load_storage.new_packages.list_new_jobs(load_id) + assert len(new_jobs) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_1")) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_2")) == 0 + assert len(load_storage.new_packages.list_all_jobs_with_states(load_id)) == 100 + assert len(load_storage.new_packages.list_started_jobs(load_id)) == 0 + assert len(load_storage.new_packages.list_failed_jobs(load_id)) == 0 + assert load_storage.new_packages.is_package_completed(load_id) is False + with pytest.raises(LoadPackageNotCompleted): + load_storage.new_packages.list_failed_jobs_infos(load_id) + # add a few more files + add_new_jobs(load_storage.new_packages, load_id, 7, "items_2") + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_1")) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_2")) == 7 + j_w_s = load_storage.new_packages.list_all_jobs_with_states(load_id) + assert len(j_w_s) == 107 + assert all(job[0] == "new_jobs" for job in j_w_s) + with pytest.raises(FileNotFoundError): + load_storage.new_packages.get_job_failed_message(load_id, j_w_s[0][1]) + # get package infos + package_jobs = load_storage.new_packages.get_load_package_jobs(load_id) + assert len(package_jobs["new_jobs"]) == 107 + # other folders empty + assert len(package_jobs["started_jobs"]) == 0 + package_info = load_storage.new_packages.get_load_package_info(load_id) + assert len(package_info.jobs["new_jobs"]) == 107 + assert len(package_info.jobs["completed_jobs"]) == 0 + assert package_info.load_id == load_id + # full path + assert package_info.package_path == load_storage.new_packages.storage.make_full_path(load_id) + assert package_info.state == "new" + assert package_info.completed_at is None + + # move some files + new_jobs = sorted(load_storage.new_packages.list_new_jobs(load_id)) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[0])) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[1])) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[-1])) + load_storage.new_packages.start_job(load_id, os.path.basename(new_jobs[-2])) + + assert len(load_storage.new_packages.list_started_jobs(load_id)) == 4 + assert len(load_storage.new_packages.list_new_jobs(load_id)) == 103 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_1")) == 100 + assert len(load_storage.new_packages.list_job_with_states_for_table(load_id, "items_2")) == 7 + package_jobs = load_storage.new_packages.get_load_package_jobs(load_id) + assert len(package_jobs["new_jobs"]) == 103 + assert len(package_jobs["started_jobs"]) == 4 + package_info = load_storage.new_packages.get_load_package_info(load_id) + assert len(package_info.jobs["new_jobs"]) == 103 + assert len(package_info.jobs["started_jobs"]) == 4 + + # complete and fail some + load_storage.new_packages.complete_job(load_id, os.path.basename(new_jobs[0])) + load_storage.new_packages.fail_job(load_id, os.path.basename(new_jobs[1]), None) + load_storage.new_packages.fail_job(load_id, os.path.basename(new_jobs[-1]), "error!") + path = load_storage.new_packages.retry_job(load_id, os.path.basename(new_jobs[-2])) + assert ParsedLoadJobFileName.parse(path).retry_count == 1 + assert ( + load_storage.new_packages.get_job_failed_message( + load_id, ParsedLoadJobFileName.parse(new_jobs[1]) + ) + is None + ) + assert ( + load_storage.new_packages.get_job_failed_message( + load_id, ParsedLoadJobFileName.parse(new_jobs[-1]) + ) + == "error!" + ) + # can't move again + with pytest.raises(FileNotFoundError): + load_storage.new_packages.complete_job(load_id, os.path.basename(new_jobs[0])) + assert len(load_storage.new_packages.list_started_jobs(load_id)) == 0 + # retry back in new + assert len(load_storage.new_packages.list_new_jobs(load_id)) == 104 + package_jobs = load_storage.new_packages.get_load_package_jobs(load_id) + assert len(package_jobs["new_jobs"]) == 104 + assert len(package_jobs["started_jobs"]) == 0 + assert len(package_jobs["completed_jobs"]) == 1 + assert len(package_jobs["failed_jobs"]) == 2 + assert len(load_storage.new_packages.list_failed_jobs(load_id)) == 2 + package_info = load_storage.new_packages.get_load_package_info(load_id) + assert len(package_info.jobs["new_jobs"]) == 104 + assert len(package_info.jobs["started_jobs"]) == 0 + assert len(package_info.jobs["completed_jobs"]) == 1 + assert len(package_info.jobs["failed_jobs"]) == 2 + + # complete package + load_storage.new_packages.complete_loading_package(load_id, "aborted") + assert load_storage.new_packages.is_package_completed(load_id) + with pytest.raises(LoadPackageAlreadyCompleted): + load_storage.new_packages.complete_loading_package(load_id, "aborted") + + for job in package_info.jobs["failed_jobs"] + load_storage.new_packages.list_failed_jobs_infos( # type: ignore[operator] + load_id + ): + if job.job_file_info.table_name == "items_1": + assert job.failed_message is None + elif job.job_file_info.table_name == "items_2": + assert job.failed_message == "error!" + else: + raise AssertionError() + assert job.created_at is not None + assert job.elapsed is not None + assert job.file_size > 0 + assert job.state == "failed_jobs" + # must be abs path! + assert os.path.isabs(job.file_path) + + +def test_get_load_package_info_perf(load_storage: LoadStorage) -> None: + import time + + st_t = time.time() + for _ in range(10000): + load_storage.loaded_packages.storage.make_full_path("198291092.121/new/ABD.CX.gx") + # os.path.basename("198291092.121/new/ABD.CX.gx") + print(time.time() - st_t) + + st_t = time.time() + load_id = create_load_package(load_storage.loaded_packages, 10000) + print(time.time() - st_t) + + st_t = time.time() + # move half of the files to failed + for file_name in load_storage.loaded_packages.list_new_jobs(load_id)[:1000]: + load_storage.loaded_packages.start_job(load_id, os.path.basename(file_name)) + load_storage.loaded_packages.fail_job( + load_id, os.path.basename(file_name), f"FAILED {file_name}" + ) + print(time.time() - st_t) + + st_t = time.time() + load_storage.loaded_packages.get_load_package_info(load_id) + print(time.time() - st_t) + + st_t = time.time() + table_stat = {} + for file in load_storage.loaded_packages.list_new_jobs(load_id): + parsed = ParsedLoadJobFileName.parse(file) + table_stat[parsed.table_name] = parsed + print(time.time() - st_t) + + +def create_load_package( + package_storage: PackageStorage, new_jobs: int, table_name="items_1" +) -> str: + schema = Schema("test") + load_id = create_load_id() + package_storage.create_package(load_id) + package_storage.save_schema(load_id, schema) + add_new_jobs(package_storage, load_id, new_jobs, table_name) + return load_id + + +def add_new_jobs( + package_storage: PackageStorage, load_id: str, new_jobs: int, table_name="items_1" +) -> None: + for _ in range(new_jobs): + file_name = PackageStorage.build_job_file_name( + table_name, ParsedLoadJobFileName.new_file_id(), 0, False, "csv" + ) + file_path = os.path.join(TEST_STORAGE_ROOT, file_name) + with open(file_path, "wt", encoding="utf-8") as f: + f.write("a|b|c") + package_storage.import_job(load_id, file_path) + + def test_migrate_to_load_package_state() -> None: """ Here we test that an existing load package without a state will not error diff --git a/tests/common/storages/test_load_storage.py b/tests/common/storages/test_load_storage.py index e8686ac2f9..49deaff23e 100644 --- a/tests/common/storages/test_load_storage.py +++ b/tests/common/storages/test_load_storage.py @@ -33,7 +33,7 @@ def test_complete_successful_package(load_storage: LoadStorage) -> None: # but completed packages are deleted load_storage.maybe_remove_completed_jobs(load_id) assert not load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) assert_package_info(load_storage, load_id, "loaded", "completed_jobs", jobs_count=0) # delete completed package @@ -56,7 +56,7 @@ def test_complete_successful_package(load_storage: LoadStorage) -> None: ) # has completed loads assert load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) load_storage.delete_loaded_package(load_id) assert not load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) @@ -82,14 +82,14 @@ def test_complete_package_failed_jobs(load_storage: LoadStorage) -> None: assert load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) # has completed loads assert load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) assert_package_info(load_storage, load_id, "loaded", "failed_jobs") # get failed jobs info failed_files = sorted(load_storage.loaded_packages.list_failed_jobs(load_id)) - # job + message - assert len(failed_files) == 2 + # only jobs + assert len(failed_files) == 1 assert load_storage.loaded_packages.storage.has_file(failed_files[0]) failed_info = load_storage.list_failed_jobs_in_loaded_package(load_id) assert failed_info[0].file_path == load_storage.loaded_packages.storage.make_full_path( @@ -117,7 +117,7 @@ def test_abort_package(load_storage: LoadStorage) -> None: assert_package_info(load_storage, load_id, "normalized", "failed_jobs") load_storage.complete_load_package(load_id, True) assert load_storage.loaded_packages.storage.has_folder( - load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") + load_storage.loaded_packages.get_job_state_folder_path(load_id, "completed_jobs") ) assert_package_info(load_storage, load_id, "aborted", "failed_jobs") diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index e97fac8a9e..ffbd2ecf1b 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -1,12 +1,10 @@ import os -import shutil import pytest import yaml from dlt.common import json -from dlt.common.normalizers import explicit_normalizers +from dlt.common.normalizers.utils import explicit_normalizers from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TStoredSchema from dlt.common.storages.exceptions import ( InStorageSchemaModified, SchemaNotFoundError, @@ -20,9 +18,9 @@ ) from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT +from tests.common.storages.utils import prepare_eth_import_folder from tests.common.utils import ( load_yml_case, - yml_case_path, COMMON_TEST_CASES_PATH, IMPORTED_VERSION_HASH_ETH_V9, ) @@ -234,7 +232,7 @@ def test_getter(storage: SchemaStorage) -> None: def test_getter_with_import(ie_storage: SchemaStorage) -> None: with pytest.raises(KeyError): ie_storage["ethereum"] - prepare_import_folder(ie_storage) + prepare_eth_import_folder(ie_storage) # schema will be imported schema = ie_storage["ethereum"] assert schema.name == "ethereum" @@ -260,17 +258,17 @@ def test_getter_with_import(ie_storage: SchemaStorage) -> None: def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: - prepare_import_folder(ie_storage) + prepare_eth_import_folder(ie_storage) # we have ethereum schema to be imported but we create new schema and save it schema = Schema("ethereum") schema_hash = schema.version_hash ie_storage.save_schema(schema) assert schema.version_hash == schema_hash # we linked schema to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # load schema and make sure our new schema is here schema = ie_storage.load_schema("ethereum") - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() assert schema._stored_version_hash == schema_hash assert schema.version_hash == schema_hash assert schema.previous_hashes == [] @@ -283,11 +281,11 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> None: # as in test_save_store_schema_over_import but we export the new schema immediately to overwrite the imported schema - prepare_import_folder(synced_storage) + prepare_eth_import_folder(synced_storage) schema = Schema("ethereum") schema_hash = schema.version_hash synced_storage.save_schema(schema) - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # import schema is overwritten fs = FileStorage(synced_storage.config.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") @@ -353,6 +351,28 @@ def test_schema_from_file() -> None: ) +def test_save_initial_import_schema(ie_storage: LiveSchemaStorage) -> None: + # no schema in regular storage + with pytest.raises(SchemaNotFoundError): + ie_storage.load_schema("ethereum") + + # save initial import schema where processing hints are removed + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + schema = Schema.from_dict(eth_V9) + ie_storage.save_import_schema_if_not_exists(schema) + # should be available now + eth = ie_storage.load_schema("ethereum") + assert "x-normalizer" not in eth.tables["blocks"] + + # won't overwrite initial schema + del eth_V9["tables"]["blocks__uncles"] + schema = Schema.from_dict(eth_V9) + ie_storage.save_import_schema_if_not_exists(schema) + # should be available now + eth = ie_storage.load_schema("ethereum") + assert "blocks__uncles" in eth.tables + + def test_live_schema_instances(live_storage: LiveSchemaStorage) -> None: schema = Schema("simple") live_storage.save_schema(schema) @@ -474,22 +494,14 @@ def test_new_live_schema_committed(live_storage: LiveSchemaStorage) -> None: # assert schema.settings["schema_sealed"] is True -def prepare_import_folder(storage: SchemaStorage) -> None: - shutil.copy( - yml_case_path("schemas/eth/ethereum_schema_v8"), - os.path.join(storage.storage.storage_path, "../import/ethereum.schema.yaml"), - ) - - def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: - prepare_import_folder(synced_storage) - eth_V9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") + prepare_eth_import_folder(synced_storage) schema = synced_storage.load_schema("ethereum") # is linked to imported schema - schema._imported_version_hash = eth_V9["version_hash"] + schema._imported_version_hash = IMPORTED_VERSION_HASH_ETH_V9() # also was saved in storage assert synced_storage.has_schema("ethereum") - # and has link to imported schema s well (load without import) + # and has link to imported schema as well (load without import) schema = storage.load_schema("ethereum") - assert schema._imported_version_hash == eth_V9["version_hash"] + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() return schema diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index 3bfc3374a4..1b5a68948b 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -21,9 +21,12 @@ ) from dlt.common.storages import DataItemStorage, FileStorage from dlt.common.storages.fsspec_filesystem import FileItem, FileItemDict +from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.typing import StrAny, TDataItems from dlt.common.utils import uniq_id +from tests.common.utils import load_yml_case + TEST_SAMPLE_FILES = "tests/common/storages/samples" MINIMALLY_EXPECTED_RELATIVE_PATHS = { "csv/freshman_kgs.csv", @@ -199,3 +202,12 @@ def assert_package_info( # get dict package_info.asdict() return package_info + + +def prepare_eth_import_folder(storage: SchemaStorage) -> Schema: + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + # remove processing hints before installing as import schema + # ethereum schema is a "dirty" schema with processing hints + eth = Schema.from_dict(eth_V9, remove_processing_hints=True) + storage._export_schema(eth, storage.config.import_schema_path) + return eth diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 24b0928463..e6e2ecad2c 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -1,10 +1,14 @@ +from typing import Dict import pytest from dlt.common.destination.reference import DestinationClientDwhConfiguration, Destination from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.exceptions import InvalidDestinationReference, UnknownDestinationModule from dlt.common.schema import Schema +from dlt.common.typing import is_subclass +from dlt.common.normalizers.naming import sql_ci_v1, sql_cs_v1 +from tests.common.configuration.utils import environment from tests.utils import ACTIVE_DESTINATIONS @@ -32,6 +36,96 @@ def test_custom_destination_module() -> None: ) # a full type name +def test_arguments_propagated_to_config() -> None: + dest = Destination.from_reference( + "dlt.destinations.duckdb", create_indexes=None, unknown_param="A" + ) + # None for create_indexes is not a default and it is passed on, unknown_param is removed because it is unknown + assert dest.config_params == {"create_indexes": None} + assert dest.caps_params == {} + + # test explicit config value being passed + import dlt + + dest = Destination.from_reference( + "dlt.destinations.duckdb", create_indexes=dlt.config.value, unknown_param="A" + ) + assert dest.config_params == {"create_indexes": dlt.config.value} + assert dest.caps_params == {} + + dest = Destination.from_reference( + "dlt.destinations.weaviate", naming_convention="duck_case", create_indexes=True + ) + # create indexes are not known + assert dest.config_params == {} + + # create explicit caps + dest = Destination.from_reference( + "dlt.destinations.dummy", + naming_convention="duck_case", + recommended_file_size=4000000, + loader_file_format="parquet", + ) + from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration + + assert dest.config_params == {"loader_file_format": "parquet"} + # loader_file_format is a legacy param that is duplicated as preferred_loader_file_format + assert dest.caps_params == { + "naming_convention": "duck_case", + "recommended_file_size": 4000000, + } + # instantiate configs + caps = dest.capabilities() + assert caps.naming_convention == "duck_case" + assert caps.preferred_loader_file_format == "parquet" + assert caps.recommended_file_size == 4000000 + init_config = DummyClientConfiguration() + config = dest.configuration(init_config) + assert config.loader_file_format == "parquet" # type: ignore[attr-defined] + + +def test_factory_config_injection(environment: Dict[str, str]) -> None: + environment["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" + from dlt.destinations import dummy + + # caps will resolve from config without client + assert dummy().capabilities().preferred_loader_file_format == "parquet" + + caps = dummy().client(Schema("client")).capabilities + assert caps.preferred_loader_file_format == "parquet" + + environment.clear() + caps = dummy().client(Schema("client")).capabilities + assert caps.preferred_loader_file_format == "jsonl" + + environment["DESTINATION__DUMMY__LOADER_FILE_FORMAT"] = "parquet" + environment["DESTINATION__DUMMY__FAIL_PROB"] = "0.435" + + # config will partially resolve without client + config = dummy().configuration(None, accept_partial=True) + assert config.fail_prob == 0.435 + assert config.loader_file_format == "parquet" + + dummy_ = dummy().client(Schema("client")) + assert dummy_.capabilities.preferred_loader_file_format == "parquet" + assert dummy_.config.fail_prob == 0.435 + + # test named destination + environment.clear() + import os + from dlt.destinations import filesystem + from dlt.destinations.impl.filesystem.configuration import ( + FilesystemDestinationClientConfiguration, + ) + + filesystem_ = filesystem(destination_name="local") + abs_path = os.path.abspath("_storage") + environment["DESTINATION__LOCAL__BUCKET_URL"] = abs_path + init_config = FilesystemDestinationClientConfiguration()._bind_dataset_name(dataset_name="test") + configured_bucket_url = filesystem_.client(Schema("test"), init_config).config.bucket_url + assert configured_bucket_url.endswith("_storage") + + def test_import_module_by_path() -> None: # importing works directly from dlt destinations dest = Destination.from_reference("dlt.destinations.postgres") @@ -54,17 +148,7 @@ def test_import_module_by_path() -> None: def test_import_all_destinations() -> None: # this must pass without the client dependencies being imported for dest_type in ACTIVE_DESTINATIONS: - # generic destination needs a valid callable, otherwise instantiation will fail - additional_args = {} - if dest_type == "destination": - - def dest_callable(items, table) -> None: - pass - - additional_args["destination_callable"] = dest_callable - dest = Destination.from_reference( - dest_type, None, dest_type + "_name", "production", **additional_args - ) + dest = Destination.from_reference(dest_type, None, dest_type + "_name", "production") assert dest.destination_type == "dlt.destinations." + dest_type assert dest.destination_name == dest_type + "_name" assert dest.config_params["environment"] == "production" @@ -73,6 +157,90 @@ def dest_callable(items, table) -> None: assert isinstance(dest.capabilities(), DestinationCapabilitiesContext) +def test_base_adjust_capabilities() -> None: + # return without modifications + caps = DestinationCapabilitiesContext.generic_capabilities() + caps_props = dict(caps) + adj_caps = Destination.adjust_capabilities(caps, None, None) + assert caps is adj_caps + assert dict(adj_caps) == caps_props + + # caps that support case sensitive idents may be put into case sensitive mode + caps = DestinationCapabilitiesContext.generic_capabilities() + assert caps.has_case_sensitive_identifiers is True + assert caps.casefold_identifier is str + # this one is already in case sensitive mode + assert caps.generates_case_sensitive_identifiers() is True + # applying cs naming has no effect + caps = Destination.adjust_capabilities(caps, None, sql_cs_v1.NamingConvention()) + assert caps.generates_case_sensitive_identifiers() is True + # same for ci naming, adjustment is only from case insensitive to sensitive + caps = Destination.adjust_capabilities(caps, None, sql_ci_v1.NamingConvention()) + assert caps.generates_case_sensitive_identifiers() is True + + # switch to case sensitive if supported by changing case folding function + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.casefold_identifier = str.lower + assert caps.generates_case_sensitive_identifiers() is False + caps = Destination.adjust_capabilities(caps, None, sql_cs_v1.NamingConvention()) + assert caps.casefold_identifier is str + assert caps.generates_case_sensitive_identifiers() is True + # ci naming has no effect + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.casefold_identifier = str.upper + caps = Destination.adjust_capabilities(caps, None, sql_ci_v1.NamingConvention()) + assert caps.casefold_identifier is str.upper + assert caps.generates_case_sensitive_identifiers() is False + + # this one does not support case sensitive identifiers and is casefolding + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.has_case_sensitive_identifiers = False + caps.casefold_identifier = str.lower + assert caps.generates_case_sensitive_identifiers() is False + caps = Destination.adjust_capabilities(caps, None, sql_cs_v1.NamingConvention()) + # no effect + assert caps.casefold_identifier is str.lower + assert caps.generates_case_sensitive_identifiers() is False + + +def test_instantiate_all_factories() -> None: + from dlt import destinations + + impls = dir(destinations) + for impl in impls: + var_ = getattr(destinations, impl) + if not is_subclass(var_, Destination): + continue + dest = var_() + + assert dest.destination_name + assert dest.destination_type + # custom destination is named after the callable + if dest.destination_type != "dlt.destinations.destination": + assert dest.destination_type.endswith(dest.destination_name) + else: + assert dest.destination_name == "dummy_custom_destination" + assert dest.spec + assert dest.spec() + # partial configuration may always be created + init_config = dest.spec.credentials_type()() + init_config.__is_resolved__ = True + assert dest.configuration(init_config, accept_partial=True) + assert dest.capabilities() + + mod_dest = var_( + destination_name="fake_name", environment="prod", naming_convention="duck_case" + ) + assert ( + mod_dest.config_params.items() + >= {"destination_name": "fake_name", "environment": "prod"}.items() + ) + assert mod_dest.caps_params == {"naming_convention": "duck_case"} + assert mod_dest.destination_name == "fake_name" + caps = mod_dest.capabilities() + assert caps.naming_convention == "duck_case" + + def test_import_destination_config() -> None: # importing destination by type will work dest = Destination.from_reference(ref="dlt.destinations.duckdb", environment="stage") @@ -97,6 +265,7 @@ def test_import_destination_config() -> None: ref="duckdb", destination_name="my_destination", environment="devel" ) assert dest.destination_type == "dlt.destinations.duckdb" + assert dest.destination_name == "my_destination" assert dest.config_params["environment"] == "devel" config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" @@ -183,6 +352,43 @@ def test_normalize_dataset_name() -> None: ) +def test_normalize_staging_dataset_name() -> None: + # default normalized staging dataset + assert ( + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="Dataset", default_schema_name="default") + .normalize_staging_dataset_name(Schema("private")) + == "dataset_private_staging" + ) + # different layout + assert ( + DestinationClientDwhConfiguration(staging_dataset_name_layout="%s__STAGING") + ._bind_dataset_name(dataset_name="Dataset", default_schema_name="private") + .normalize_staging_dataset_name(Schema("private")) + == "dataset_staging" + ) + # without placeholder + assert ( + DestinationClientDwhConfiguration(staging_dataset_name_layout="static_staging") + ._bind_dataset_name(dataset_name="Dataset", default_schema_name="default") + .normalize_staging_dataset_name(Schema("private")) + == "static_staging" + ) + # empty dataset -> empty staging + assert ( + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name=None, default_schema_name="private") + .normalize_staging_dataset_name(Schema("private")) + is None + ) + assert ( + DestinationClientDwhConfiguration(staging_dataset_name_layout="static_staging") + ._bind_dataset_name(dataset_name=None, default_schema_name="default") + .normalize_staging_dataset_name(Schema("private")) + == "static_staging" + ) + + def test_normalize_dataset_name_none_default_schema() -> None: # if default schema is None, suffix is not added assert ( diff --git a/tests/common/test_json.py b/tests/common/test_json.py index 79037ebf93..b7d25589a7 100644 --- a/tests/common/test_json.py +++ b/tests/common/test_json.py @@ -6,6 +6,7 @@ from dlt.common import json, Decimal, pendulum from dlt.common.arithmetics import numeric_default_context +from dlt.common import known_env from dlt.common.json import ( _DECIMAL, _WEI, @@ -306,7 +307,7 @@ def test_garbage_pua_string(json_impl: SupportsJson) -> None: def test_change_pua_start() -> None: import inspect - os.environ["DLT_JSON_TYPED_PUA_START"] = "0x0FA179" + os.environ[known_env.DLT_JSON_TYPED_PUA_START] = "0x0FA179" from importlib import reload try: @@ -316,7 +317,7 @@ def test_change_pua_start() -> None: assert MOD_PUA_START == int("0x0FA179", 16) finally: # restore old start - os.environ["DLT_JSON_TYPED_PUA_START"] = hex(PUA_START) + os.environ[known_env.DLT_JSON_TYPED_PUA_START] = hex(PUA_START) from importlib import reload reload(inspect.getmodule(SupportsJson)) diff --git a/tests/common/utils.py b/tests/common/utils.py index a234937e56..32741128b8 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -9,14 +9,24 @@ from dlt.common import json from dlt.common.typing import StrAny -from dlt.common.schema import utils +from dlt.common.schema import utils, Schema from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration.providers import environ as environ_provider COMMON_TEST_CASES_PATH = "./tests/common/cases/" -# for import schema tests, change when upgrading the schema version -IMPORTED_VERSION_HASH_ETH_V9 = "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" + + +def IMPORTED_VERSION_HASH_ETH_V9() -> str: + # for import schema tests, change when upgrading the schema version + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + assert eth_V9["version_hash"] == "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" + # remove processing hints before installing as import schema + # ethereum schema is a "dirty" schema with processing hints + eth = Schema.from_dict(eth_V9, remove_processing_hints=True) + return eth.stored_version_hash + + # test sentry DSN TEST_SENTRY_DSN = ( "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" diff --git a/tests/conftest.py b/tests/conftest.py index 020487d878..e819e26ebb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,3 +111,34 @@ def _create_pipeline_instance_id(self) -> str: # disable databricks logging for log in ["databricks.sql.client"]: logging.getLogger(log).setLevel("WARNING") + + # disable httpx request logging (too verbose when testing qdrant) + logging.getLogger("httpx").setLevel("WARNING") + + # reset and init airflow db + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + + try: + from airflow.utils import db + import contextlib + import io + + for log in [ + "airflow.models.crypto", + "airflow.models.variable", + "airflow", + "alembic", + "alembic.runtime.migration", + ]: + logging.getLogger(log).setLevel("ERROR") + + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): + db.resetdb() + + except Exception: + pass diff --git a/tests/destinations/test_custom_destination.py b/tests/destinations/test_custom_destination.py index 6834006689..6ebf7f6ef3 100644 --- a/tests/destinations/test_custom_destination.py +++ b/tests/destinations/test_custom_destination.py @@ -8,12 +8,13 @@ from copy import deepcopy from dlt.common.configuration.specs.base_configuration import configspec +from dlt.common.schema.schema import Schema from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.destination.reference import Destination from dlt.common.destination.exceptions import InvalidDestinationReference -from dlt.common.configuration.exceptions import ConfigFieldMissingException +from dlt.common.configuration.exceptions import ConfigFieldMissingException, ConfigurationValueError from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.inject import get_fun_spec from dlt.common.configuration.specs import BaseConfiguration @@ -38,7 +39,7 @@ def _run_through_sink( batch_size: int = 10, ) -> List[Tuple[TDataItems, TTableSchema]]: """ - runs a list of items through the sink destination and returns colleceted calls + runs a list of items through the sink destination and returns collected calls """ calls: List[Tuple[TDataItems, TTableSchema]] = [] @@ -55,7 +56,7 @@ def items_resource() -> TDataItems: nonlocal items yield items - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) p.run([items_resource()]) return calls @@ -126,6 +127,34 @@ def global_sink_func(items: TDataItems, table: TTableSchema) -> None: global_calls.append((items, table)) +def test_capabilities() -> None: + # test default caps + dest = dlt.destination()(global_sink_func)() + caps = dest.capabilities() + assert caps.preferred_loader_file_format == "typed-jsonl" + assert caps.supported_loader_file_formats == ["typed-jsonl", "parquet"] + assert caps.naming_convention == "direct" + assert caps.max_table_nesting == 0 + client_caps = dest.client(Schema("schema")).capabilities + assert dict(caps) == dict(client_caps) + + # test modified caps + dest = dlt.destination( + loader_file_format="parquet", + batch_size=0, + name="my_name", + naming_convention="snake_case", + max_table_nesting=10, + )(global_sink_func)() + caps = dest.capabilities() + assert caps.preferred_loader_file_format == "parquet" + assert caps.supported_loader_file_formats == ["typed-jsonl", "parquet"] + assert caps.naming_convention == "snake_case" + assert caps.max_table_nesting == 10 + client_caps = dest.client(Schema("schema")).capabilities + assert dict(caps) == dict(client_caps) + + def test_instantiation() -> None: # also tests _DESTINATIONS calls: List[Tuple[TDataItems, TTableSchema]] = [] @@ -140,23 +169,23 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va # test decorator calls = [] - p = dlt.pipeline("sink_test", destination=dlt.destination()(local_sink_func), full_refresh=True) + p = dlt.pipeline("sink_test", destination=dlt.destination()(local_sink_func), dev_mode=True) p.run([1, 2, 3], table_name="items") assert len(calls) == 1 # local func does not create entry in destinations - assert not _DESTINATIONS + assert "local_sink_func" not in _DESTINATIONS # test passing via from_reference calls = [] p = dlt.pipeline( "sink_test", destination=Destination.from_reference("destination", destination_callable=local_sink_func), - full_refresh=True, + dev_mode=True, ) p.run([1, 2, 3], table_name="items") assert len(calls) == 1 # local func does not create entry in destinations - assert not _DESTINATIONS + assert "local_sink_func" not in _DESTINATIONS # test passing string reference global global_calls @@ -167,7 +196,7 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va "destination", destination_callable="tests.destinations.test_custom_destination.global_sink_func", ), - full_refresh=True, + dev_mode=True, ) p.run([1, 2, 3], table_name="items") assert len(global_calls) == 1 @@ -182,9 +211,9 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va p = dlt.pipeline( "sink_test", destination=Destination.from_reference("destination", destination_callable=None), - full_refresh=True, + dev_mode=True, ) - with pytest.raises(PipelineStepFailed): + with pytest.raises(ConfigurationValueError): p.run([1, 2, 3], table_name="items") # pass invalid string reference will fail on instantiation @@ -194,7 +223,7 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va destination=Destination.from_reference( "destination", destination_callable="does.not.exist" ), - full_refresh=True, + dev_mode=True, ) # using decorator without args will also work @@ -206,7 +235,7 @@ def simple_decorator_sink(items, table, my_val=dlt.config.value): assert my_val == "something" calls.append((items, table)) - p = dlt.pipeline("sink_test", destination=simple_decorator_sink, full_refresh=True) # type: ignore + p = dlt.pipeline("sink_test", destination=simple_decorator_sink, dev_mode=True) # type: ignore p.run([1, 2, 3], table_name="items") assert len(calls) == 1 @@ -265,7 +294,7 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: assert str(i) in collected_items # no errors are set, all items should be processed - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) load_id = p.run([items(), items2()]).loads_ids[0] assert_items_in_range(calls["items"], 0, 100) assert_items_in_range(calls["items2"], 0, 100) @@ -278,7 +307,7 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: # provoke errors calls = {} provoke_error = {"items": 25, "items2": 45} - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) with pytest.raises(PipelineStepFailed): p.run([items(), items2()]) @@ -335,7 +364,7 @@ def snake_sink(items, table): assert table["columns"]["snake_case"]["name"] == "snake_case" assert table["columns"]["camel_case"]["name"] == "camel_case" - dlt.pipeline("sink_test", destination=snake_sink, full_refresh=True).run(resource()) + dlt.pipeline("sink_test", destination=snake_sink, dev_mode=True).run(resource()) # check default (which is direct) @dlt.destination() @@ -345,7 +374,7 @@ def direct_sink(items, table): assert table["columns"]["snake_case"]["name"] == "snake_case" assert table["columns"]["camelCase"]["name"] == "camelCase" - dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource()) + dlt.pipeline("sink_test", destination=direct_sink, dev_mode=True).run(resource()) def test_file_batch() -> None: @@ -368,7 +397,7 @@ def direct_sink(file_path, table): with pyarrow.parquet.ParquetFile(file_path) as reader: assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50) - dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=direct_sink, dev_mode=True).run( [resource1(), resource2()] ) @@ -384,25 +413,23 @@ def my_sink(file_path, table, my_val=dlt.config.value): # if no value is present, it should raise with pytest.raises(ConfigFieldMissingException): - dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) # we may give the value via __callable__ function - dlt.pipeline("sink_test", destination=my_sink(my_val="something"), full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_sink(my_val="something"), dev_mode=True).run( [1, 2, 3], table_name="items" ) # right value will pass os.environ["DESTINATION__MY_SINK__MY_VAL"] = "something" - dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( - [1, 2, 3], table_name="items" - ) + dlt.pipeline("sink_test", destination=my_sink, dev_mode=True).run([1, 2, 3], table_name="items") # wrong value will raise os.environ["DESTINATION__MY_SINK__MY_VAL"] = "wrong" with pytest.raises(PipelineStepFailed): - dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -413,13 +440,13 @@ def other_sink(file_path, table, my_val=dlt.config.value): # if no value is present, it should raise with pytest.raises(ConfigFieldMissingException): - dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=other_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) # right value will pass os.environ["DESTINATION__SOME_NAME__MY_VAL"] = "something" - dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=other_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -437,7 +464,7 @@ def my_gcp_sink( # missing spec with pytest.raises(ConfigFieldMissingException): - dlt.pipeline("sink_test", destination=my_gcp_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_gcp_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -447,7 +474,7 @@ def my_gcp_sink( os.environ["CREDENTIALS__USERNAME"] = "my_user_name" # now it will run - dlt.pipeline("sink_test", destination=my_gcp_sink, full_refresh=True).run( + dlt.pipeline("sink_test", destination=my_gcp_sink, dev_mode=True).run( [1, 2, 3], table_name="items" ) @@ -471,14 +498,14 @@ def sink_func_with_spec( # call fails because `my_predefined_val` is required part of spec, even if not injected with pytest.raises(ConfigFieldMissingException): - info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), full_refresh=True).run( + info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), dev_mode=True).run( [1, 2, 3], table_name="items" ) info.raise_on_failed_jobs() # call happens now os.environ["MY_PREDEFINED_VAL"] = "VAL" - info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), full_refresh=True).run( + info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), dev_mode=True).run( [1, 2, 3], table_name="items" ) info.raise_on_failed_jobs() @@ -550,7 +577,7 @@ def test_sink(items, table): found_dlt_column_value = True # test with and without removing - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) p.run([{"id": 1, "value": "1"}], table_name="some_table") assert found_dlt_column != remove_stuff @@ -579,7 +606,7 @@ def nesting_sink(items, table): def source(): yield dlt.resource(data, name="data") - p = dlt.pipeline("sink_test_max_nesting", destination=nesting_sink, full_refresh=True) + p = dlt.pipeline("sink_test_max_nesting", destination=nesting_sink, dev_mode=True) p.run(source()) # fall back to source setting diff --git a/tests/extract/data_writers/test_buffered_writer.py b/tests/extract/data_writers/test_buffered_writer.py index b6da132de9..5cad5a35b9 100644 --- a/tests/extract/data_writers/test_buffered_writer.py +++ b/tests/extract/data_writers/test_buffered_writer.py @@ -264,6 +264,27 @@ def test_import_file(writer_type: Type[DataWriter]) -> None: assert metrics.file_size == 231 +@pytest.mark.parametrize("writer_type", ALL_WRITERS) +def test_import_file_with_extension(writer_type: Type[DataWriter]) -> None: + now = time.time() + with get_writer(writer_type) as writer: + # won't destroy the original + metrics = writer.import_file( + "tests/extract/cases/imported.any", + DataWriterMetrics("", 1, 231, 0, 0), + with_extension="any", + ) + assert len(writer.closed_files) == 1 + assert os.path.isfile(metrics.file_path) + # extension is correctly set + assert metrics.file_path.endswith(".any") + assert writer.closed_files[0] == metrics + assert metrics.created <= metrics.last_modified + assert metrics.created >= now + assert metrics.items_count == 1 + assert metrics.file_size == 231 + + @pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index db888c95e4..f9775fd218 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -42,7 +42,7 @@ ) from dlt.extract.items import TableNameMeta -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9 +from tests.common.utils import load_yml_case def test_default_resource() -> None: @@ -107,7 +107,10 @@ def test_load_schema_for_callable() -> None: schema = s.schema assert schema.name == "ethereum" == s.name # the schema in the associated file has this hash - assert schema.stored_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + eth_v9 = load_yml_case("schemas/eth/ethereum_schema_v9") + # source removes processing hints so we do + reference_schema = Schema.from_dict(eth_v9, remove_processing_hints=True) + assert schema.stored_version_hash == reference_schema.stored_version_hash def test_unbound_parametrized_transformer() -> None: @@ -341,6 +344,41 @@ class Columns3(BaseModel): assert t["columns"]["b"]["data_type"] == "double" +def test_not_normalized_identifiers_in_hints() -> None: + @dlt.resource( + primary_key="ID", + merge_key=["Month", "Day"], + columns=[{"name": "Col1", "data_type": "bigint"}], + table_name="🐫Camels", + ) + def CamelResource(): + yield ["🐫"] * 10 + + camels = CamelResource() + # original names are kept + assert camels.name == "CamelResource" + assert camels.table_name == "🐫Camels" + assert camels.columns == {"Col1": {"data_type": "bigint", "name": "Col1"}} + table = camels.compute_table_schema() + columns = table["columns"] + assert "ID" in columns + assert "Month" in columns + assert "Day" in columns + assert "Col1" in columns + assert table["name"] == "🐫Camels" + + # define as part of a source + camel_source = DltSource(Schema("snake_case"), "camel_section", [camels]) + schema = camel_source.discover_schema() + # all normalized + table = schema.get_table("_camels") + columns = table["columns"] + assert "id" in columns + assert "month" in columns + assert "day" in columns + assert "col1" in columns + + def test_resource_name_from_generator() -> None: def some_data(): yield [1, 2, 3] @@ -565,6 +603,21 @@ def created_global(): _assert_source_schema(created_global(), "global") +def test_source_schema_removes_processing_hints() -> None: + eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") + assert "x-normalizer" in eth_V9["tables"]["blocks"] + + @dlt.source(schema=Schema.from_dict(eth_V9)) + def created_explicit(): + schema = dlt.current.source_schema() + assert schema.name == "ethereum" + assert "x-normalizer" not in schema.tables["blocks"] + return dlt.resource([1, 2, 3], name="res") + + source = created_explicit() + assert "x-normalizer" not in source.schema.tables["blocks"] + + def test_source_state_context() -> None: @dlt.resource(selected=False) def main(): @@ -849,6 +902,18 @@ def test_standalone_transformer(next_item_mode: str) -> None: ] +def test_transformer_required_args() -> None: + @dlt.transformer + def path_params(id_, workspace_id, load_id, base: bool = False): + yield {"id": id_, "workspace_id": workspace_id, "load_id": load_id} + + data = list([1, 2, 3] | path_params(121, 343)) + assert len(data) == 3 + assert data[0] == {"id": 1, "workspace_id": 121, "load_id": 343} + + # @dlt + + @dlt.transformer(standalone=True, name=lambda args: args["res_name"]) def standalone_tx_with_name(item: TDataItem, res_name: str, init: int = dlt.config.value): return res_name * item * init diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index dc978b997a..dbec417f97 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -125,6 +125,7 @@ def with_table_hints(): {"id": 1, "pk2": "B"}, make_hints( write_disposition="merge", + file_format="preferred", columns=[{"name": "id", "precision": 16}, {"name": "text", "data_type": "decimal"}], primary_key="pk2", ), @@ -143,6 +144,7 @@ def with_table_hints(): assert "pk" in table["columns"] assert "text" in table["columns"] assert table["write_disposition"] == "merge" + assert table["file_format"] == "preferred" # make table name dynamic yield dlt.mark.with_hints( diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index 9bf580b76a..d285181c55 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -510,6 +510,19 @@ def test_pipe_copy_on_fork() -> None: assert elems[0].item is not elems[1].item +def test_pipe_pass_empty_list() -> None: + def _gen(): + yield [] + + pipe = Pipe.from_data("data", _gen()) + elems = list(PipeIterator.from_pipe(pipe)) + assert elems[0].item == [] + + pipe = Pipe.from_data("data", [[]]) + elems = list(PipeIterator.from_pipe(pipe)) + assert elems[0].item == [] + + def test_clone_single_pipe() -> None: doc = {"e": 1, "l": 2} parent = Pipe.from_data("data", [doc]) diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index bb6fb70983..26158177ff 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -17,6 +17,7 @@ from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration from dlt.common.configuration import ConfigurationValueError from dlt.common.pendulum import pendulum, timedelta +from dlt.common import Decimal from dlt.common.pipeline import NormalizeInfo, StateInjectableContext, resource_state from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id, digest128, chunks @@ -197,7 +198,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): yield from source_items2 p = dlt.pipeline( - pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + pipeline_name=uniq_id(), + destination=dlt.destinations.duckdb(credentials=duckdb.connect(":memory:")), ) p.run(some_data()).raise_on_failed_jobs() p.run(some_data()).raise_on_failed_jobs() @@ -237,7 +239,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): yield from source_items2 p = dlt.pipeline( - pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + pipeline_name=uniq_id(), + destination=dlt.destinations.duckdb(credentials=duckdb.connect(":memory:")), ) p.run(some_data()).raise_on_failed_jobs() p.run(some_data()).raise_on_failed_jobs() @@ -443,7 +446,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): yield from source_items p = dlt.pipeline( - pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + pipeline_name=uniq_id(), + destination=dlt.destinations.duckdb(credentials=duckdb.connect(":memory:")), ) p.run(some_data()).raise_on_failed_jobs() @@ -786,6 +790,43 @@ def some_data(first: bool, last_timestamp=dlt.sources.incremental("ts")): p.run(some_data(False)) +@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"pandas"}) +@pytest.mark.parametrize( + "id_value", + ("1231231231231271872", b"1231231231231271872", pendulum.now(), 1271.78, Decimal("1231.87")), +) +def test_primary_key_types(item_type: TestDataItemFormat, id_value: Any) -> None: + """Case when deduplication filter is empty for an Arrow table.""" + p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") + now = pendulum.now() + + data = [ + { + "delta": str(i), + "ts": now.add(days=i), + "_id": id_value, + } + for i in range(-10, 10) + ] + source_items = data_to_item_format(item_type, data) + start = now.add(days=-10) + + @dlt.resource + def some_data( + last_timestamp=dlt.sources.incremental("ts", initial_value=start, primary_key="_id"), + ): + yield from source_items + + info = p.run(some_data()) + info.raise_on_failed_jobs() + norm_info = p.last_trace.last_normalize_info + assert norm_info.row_counts["some_data"] == 20 + # load incrementally + info = p.run(some_data()) + norm_info = p.last_trace.last_normalize_info + assert "some_data" not in norm_info.row_counts + + @pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) def test_replace_resets_state(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 7b2613776d..a170c6977d 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -39,6 +39,39 @@ def switch_to_fifo(): del os.environ["EXTRACT__NEXT_ITEM_MODE"] +def test_basic_source() -> None: + def basic_gen(): + yield 1 + + schema = Schema("test") + s = DltSource.from_data(schema, "section", basic_gen) + assert s.name == "test" + assert s.section == "section" + assert s.max_table_nesting is None + assert s.root_key is False + assert s.schema_contract is None + assert s.exhausted is False + assert s.schema is schema + assert len(s.resources) == 1 + assert s.resources == s.selected_resources + + # set some props + s.max_table_nesting = 10 + assert s.max_table_nesting == 10 + s.root_key = True + assert s.root_key is True + s.schema_contract = "evolve" + assert s.schema_contract == "evolve" + + s.max_table_nesting = None + s.root_key = False + s.schema_contract = None + + assert s.max_table_nesting is None + assert s.root_key is False + assert s.schema_contract is None + + def test_call_data_resource() -> None: with pytest.raises(TypeError): DltResource.from_data([1], name="t")() @@ -1274,6 +1307,8 @@ def empty_gen(): primary_key=["a", "b"], merge_key=["c", "a"], schema_contract="freeze", + table_format="delta", + file_format="jsonl", ) table = empty_r.compute_table_schema() assert table["columns"]["a"] == { @@ -1288,11 +1323,15 @@ def empty_gen(): assert table["parent"] == "parent" assert empty_r.table_name == "table" assert table["schema_contract"] == "freeze" + assert table["table_format"] == "delta" + assert table["file_format"] == "jsonl" # reset empty_r.apply_hints( table_name="", parent_table_name="", + table_format="", + file_format="", primary_key=[], merge_key="", columns={}, diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py b/tests/helpers/airflow_tests/test_airflow_wrapper.py index 533d16c998..ac12f70037 100644 --- a/tests/helpers/airflow_tests/test_airflow_wrapper.py +++ b/tests/helpers/airflow_tests/test_airflow_wrapper.py @@ -150,8 +150,7 @@ def test_regular_run() -> None: pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_standalone", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( @@ -170,8 +169,7 @@ def dag_regular(): pipeline_dag_regular = dlt.pipeline( pipeline_name="pipeline_dag_regular", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) tasks_list = tasks.add_run( pipeline_dag_regular, @@ -214,8 +212,7 @@ def dag_decomposed(): pipeline_dag_decomposed = dlt.pipeline( pipeline_name="pipeline_dag_decomposed", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks_list = tasks.add_run( pipeline_dag_decomposed, @@ -247,8 +244,7 @@ def test_run() -> None: pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_standalone", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( @@ -268,8 +264,7 @@ def dag_regular(): pipeline_dag_regular = dlt.pipeline( pipeline_name="pipeline_dag_regular", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) task = tasks.run(pipeline_dag_regular, mock_data_source()) @@ -292,8 +287,7 @@ def test_parallel_run(): pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( @@ -315,8 +309,7 @@ def dag_parallel(): pipeline_dag_parallel = dlt.pipeline( pipeline_name="pipeline_dag_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks_list = tasks.add_run( pipeline_dag_parallel, @@ -349,8 +342,7 @@ def test_parallel_incremental(): pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) pipeline_standalone.run(mock_data_incremental_source()) @@ -369,8 +361,7 @@ def dag_parallel(): pipeline_dag_parallel = dlt.pipeline( pipeline_name="pipeline_dag_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks.add_run( pipeline_dag_parallel, @@ -401,8 +392,7 @@ def test_parallel_isolated_run(): pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( @@ -424,8 +414,7 @@ def dag_parallel(): pipeline_dag_parallel = dlt.pipeline( pipeline_name="pipeline_dag_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks_list = tasks.add_run( pipeline_dag_parallel, @@ -466,8 +455,7 @@ def test_parallel_run_single_resource(): pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) pipeline_standalone.run(mock_data_single_resource()) pipeline_standalone_counts = load_table_counts( @@ -489,8 +477,7 @@ def dag_parallel(): pipeline_dag_parallel = dlt.pipeline( pipeline_name="pipeline_dag_parallel", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks_list = tasks.add_run( pipeline_dag_parallel, @@ -555,8 +542,7 @@ def dag_fail_3(): pipeline_fail_3 = dlt.pipeline( pipeline_name="pipeline_fail_3", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) tasks.add_run( pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True @@ -582,8 +568,7 @@ def dag_fail_4(): pipeline_fail_3 = dlt.pipeline( pipeline_name="pipeline_fail_3", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) tasks.add_run( pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True @@ -611,8 +596,7 @@ def dag_fail_5(): pipeline_fail_3 = dlt.pipeline( pipeline_name="pipeline_fail_3", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) tasks.add_run( pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True @@ -886,8 +870,9 @@ def dag_parallel(): pipe = dlt.pipeline( pipeline_name="test_pipeline", dataset_name="mock_data", - destination="duckdb", - credentials=os.path.join("_storage", "test_pipeline.duckdb"), + destination=dlt.destinations.duckdb( + credentials=os.path.join("_storage", "test_pipeline.duckdb") + ), ) task = tasks.add_run( pipe, @@ -954,8 +939,7 @@ def dag_regular(): call_dag = dlt.pipeline( pipeline_name="callable_dag", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks.run(call_dag, callable_source) @@ -991,8 +975,7 @@ def dag_regular(): call_dag = dlt.pipeline( pipeline_name="callable_dag", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=quackdb_path, + destination=dlt.destinations.duckdb(credentials=quackdb_path), ) tasks.run(call_dag, mock_data_source, on_before_run=on_before_run) diff --git a/tests/helpers/airflow_tests/test_join_airflow_scheduler.py b/tests/helpers/airflow_tests/test_join_airflow_scheduler.py index 8c1992c506..d737f254e3 100644 --- a/tests/helpers/airflow_tests/test_join_airflow_scheduler.py +++ b/tests/helpers/airflow_tests/test_join_airflow_scheduler.py @@ -292,8 +292,7 @@ def test_scheduler_pipeline_state() -> None: pipeline = dlt.pipeline( pipeline_name="pipeline_dag_regular", dataset_name="mock_data_" + uniq_id(), - destination="duckdb", - credentials=":pipeline:", + destination=dlt.destinations.duckdb(credentials=":pipeline:"), ) now = pendulum.now() diff --git a/tests/libs/pyarrow/test_pyarrow_normalizer.py b/tests/libs/pyarrow/test_pyarrow_normalizer.py index 63abcbc92a..d975702ad8 100644 --- a/tests/libs/pyarrow/test_pyarrow_normalizer.py +++ b/tests/libs/pyarrow/test_pyarrow_normalizer.py @@ -3,8 +3,8 @@ import pyarrow as pa import pytest -from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationClash -from dlt.common.normalizers import explicit_normalizers, import_normalizers +from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationCollision +from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers from dlt.common.schema.utils import new_column, TColumnSchema from dlt.common.destination import DestinationCapabilitiesContext @@ -65,7 +65,7 @@ def test_field_normalization_clash() -> None: {"col^New": "hello", "col_new": 1}, ] ) - with pytest.raises(NameNormalizationClash): + with pytest.raises(NameNormalizationCollision): _normalize(table, []) diff --git a/tests/load/athena_iceberg/test_athena_adapter.py b/tests/load/athena_iceberg/test_athena_adapter.py index 3144eb9cc9..19c176a374 100644 --- a/tests/load/athena_iceberg/test_athena_adapter.py +++ b/tests/load/athena_iceberg/test_athena_adapter.py @@ -2,7 +2,7 @@ import dlt from dlt.destinations import filesystem -from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition +from dlt.destinations.adapters import athena_adapter, athena_partition # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -40,7 +40,7 @@ def not_partitioned_table(): "athena_test", destination="athena", staging=filesystem("s3://not-a-real-bucket"), - full_refresh=True, + dev_mode=True, ) pipeline.extract([partitioned_table, not_partitioned_table]) diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py index 4fe01752ee..0ef935a8bc 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -1,15 +1,9 @@ import pytest import os -import datetime # noqa: I251 from typing import Iterator, Any import dlt -from dlt.common import pendulum -from dlt.common.utils import uniq_id -from tests.cases import table_update_and_row, assert_all_data_types_row -from tests.pipeline.utils import assert_load_info, load_table_counts - -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.pipeline.utils import load_table_counts from dlt.destinations.exceptions import DatabaseTerminalException diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index b16790b07d..a74ab11860 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -16,13 +16,15 @@ ) from dlt.common.configuration.specs import gcp_credentials from dlt.common.configuration.specs.exceptions import InvalidGoogleNativeCredentialsType +from dlt.common.schema.utils import new_table from dlt.common.storages import FileStorage from dlt.common.utils import digest128, uniq_id, custom_environ from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, preserve_environ +from dlt.destinations.impl.bigquery.bigquery_adapter import AUTODETECT_SCHEMA_HINT +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import json_case_path as common_json_case_path from tests.common.configuration.utils import environment from tests.load.utils import ( @@ -217,15 +219,15 @@ def test_bigquery_configuration() -> None: assert config.fingerprint() == digest128("chat-analytics-rasa-ci") # credential location is deprecated - os.environ["CREDENTIALS__LOCATION"] = "EU" - config = resolve_configuration( - BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), - sections=("destination", "bigquery"), - ) - assert config.location == "US" - assert config.credentials.location == "EU" - # but if it is set, we propagate it to the config - assert config.get_location() == "EU" + # os.environ["CREDENTIALS__LOCATION"] = "EU" + # config = resolve_configuration( + # BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + # sections=("destination", "bigquery"), + # ) + # assert config.location == "US" + # assert config.credentials.location == "EU" + # # but if it is set, we propagate it to the config + # assert config.get_location() == "EU" os.environ["LOCATION"] = "ATLANTIS" config = resolve_configuration( BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), @@ -245,6 +247,27 @@ def test_bigquery_configuration() -> None: ) +def test_bigquery_autodetect_configuration(client: BigQueryClient) -> None: + # no schema autodetect + assert client._should_autodetect_schema("event_slot") is False + assert client._should_autodetect_schema("_dlt_loads") is False + # add parent table + child = new_table("event_slot__values", "event_slot") + client.schema.update_table(child) + assert client._should_autodetect_schema("event_slot__values") is False + # enable global config + client.config.autodetect_schema = True + assert client._should_autodetect_schema("event_slot") is True + assert client._should_autodetect_schema("_dlt_loads") is False + assert client._should_autodetect_schema("event_slot__values") is True + # enable hint per table + client.config.autodetect_schema = False + client.schema.get_table("event_slot")[AUTODETECT_SCHEMA_HINT] = True # type: ignore[typeddict-unknown-key] + assert client._should_autodetect_schema("event_slot") is True + assert client._should_autodetect_schema("_dlt_loads") is False + assert client._should_autodetect_schema("event_slot__values") is True + + def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: # non existing job with pytest.raises(LoadJobNotExistsException): @@ -290,7 +313,7 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) @pytest.mark.parametrize("location", ["US", "EU"]) def test_bigquery_location(location: str, file_storage: FileStorage, client) -> None: with cm_yield_client_with_storage( - "bigquery", default_config_values={"credentials": {"location": location}} + "bigquery", default_config_values={"location": location} ) as client: user_table_name = prepare_table(client) load_json = { diff --git a/tests/load/bigquery/test_bigquery_streaming_insert.py b/tests/load/bigquery/test_bigquery_streaming_insert.py index c80f6ed65a..c950a46f91 100644 --- a/tests/load/bigquery/test_bigquery_streaming_insert.py +++ b/tests/load/bigquery/test_bigquery_streaming_insert.py @@ -1,7 +1,7 @@ import pytest import dlt -from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.adapters import bigquery_adapter from tests.pipeline.utils import assert_load_info @@ -12,7 +12,7 @@ def test_resource(): bigquery_adapter(test_resource, insert_api="streaming") - pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", full_refresh=True) + pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", dev_mode=True) pack = pipe.run(test_resource, table_name="test_streaming_items44") assert_load_info(pack) @@ -41,10 +41,12 @@ def test_resource(): pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery") info = pipe.run(test_resource) + # pick the failed job + failed_job = info.load_packages[0].jobs["failed_jobs"][0] assert ( """BigQuery streaming insert can only be used with `append`""" """ write_disposition, while the given resource has `merge`.""" - ) in info.asdict()["load_packages"][0]["jobs"][0]["failed_message"] + ) in failed_job.failed_message def test_bigquery_streaming_nested_data(): @@ -54,7 +56,7 @@ def test_resource(): bigquery_adapter(test_resource, insert_api="streaming") - pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", full_refresh=True) + pipe = dlt.pipeline(pipeline_name="insert_test", destination="bigquery", dev_mode=True) pack = pipe.run(test_resource, table_name="test_streaming_items") assert_load_info(pack) diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index df564192dc..63ac645113 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -1,6 +1,8 @@ import os from copy import deepcopy from typing import Iterator, Dict, Any, List +from dlt.common.destination.exceptions import DestinationSchemaTampered +from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision from dlt.destinations.impl.bigquery.bigquery_adapter import ( PARTITION_HINT, CLUSTER_HINT, @@ -18,20 +20,26 @@ GcpServiceAccountCredentials, ) from dlt.common.pendulum import pendulum -from dlt.common.schema import Schema +from dlt.common.schema import Schema, utils from dlt.common.utils import custom_environ from dlt.common.utils import uniq_id + from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate +from dlt.destinations import bigquery from dlt.destinations.impl.bigquery.bigquery import BigQueryClient -from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter +from dlt.destinations.adapters import bigquery_adapter from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration + from dlt.extract import DltResource -from tests.load.pipeline.utils import ( + +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, drop_active_pipeline_data, + TABLE_UPDATE, + sequence_generator, + empty_schema, ) -from tests.load.utils import TABLE_UPDATE, sequence_generator, empty_schema # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -54,15 +62,30 @@ def test_configuration() -> None: @pytest.fixture def gcp_client(empty_schema: Schema) -> BigQueryClient: + return create_client(empty_schema) + + +@pytest.fixture +def ci_gcp_client(empty_schema: Schema) -> BigQueryClient: + empty_schema._normalizers_config["names"] = "tests.common.cases.normalizers.title_case" + empty_schema.update_normalizers() + # make the destination case insensitive + return create_client(empty_schema, has_case_sensitive_identifiers=False) + + +def create_client(schema: Schema, has_case_sensitive_identifiers: bool = True) -> BigQueryClient: # return a client without opening connection creds = GcpServiceAccountCredentials() creds.project_id = "test_project_id" # noinspection PydanticTypeChecker - return BigQueryClient( - empty_schema, - BigQueryClientConfiguration(credentials=creds)._bind_dataset_name( - dataset_name=f"test_{uniq_id()}" - ), + return bigquery().client( + schema, + BigQueryClientConfiguration( + credentials=creds, + has_case_sensitive_identifiers=has_case_sensitive_identifiers, + # let modify destination caps + should_set_case_sensitivity_on_new_dataset=True, + )._bind_dataset_name(dataset_name=f"test_{uniq_id()}"), ) @@ -89,9 +112,9 @@ def test_create_table(gcp_client: BigQueryClient) -> None: sqlfluff.parse(sql, dialect="bigquery") assert sql.startswith("CREATE TABLE") assert "event_test_table" in sql - assert "`col1` INTEGER NOT NULL" in sql + assert "`col1` INT64 NOT NULL" in sql assert "`col2` FLOAT64 NOT NULL" in sql - assert "`col3` BOOLEAN NOT NULL" in sql + assert "`col3` BOOL NOT NULL" in sql assert "`col4` TIMESTAMP NOT NULL" in sql assert "`col5` STRING " in sql assert "`col6` NUMERIC(38,9) NOT NULL" in sql @@ -100,7 +123,7 @@ def test_create_table(gcp_client: BigQueryClient) -> None: assert "`col9` JSON NOT NULL" in sql assert "`col10` DATE" in sql assert "`col11` TIME" in sql - assert "`col1_precision` INTEGER NOT NULL" in sql + assert "`col1_precision` INT64 NOT NULL" in sql assert "`col4_precision` TIMESTAMP NOT NULL" in sql assert "`col5_precision` STRING(25) " in sql assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql @@ -119,9 +142,9 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert sql.startswith("ALTER TABLE") assert sql.count("ALTER TABLE") == 1 assert "event_test_table" in sql - assert "ADD COLUMN `col1` INTEGER NOT NULL" in sql + assert "ADD COLUMN `col1` INT64 NOT NULL" in sql assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql - assert "ADD COLUMN `col3` BOOLEAN NOT NULL" in sql + assert "ADD COLUMN `col3` BOOL NOT NULL" in sql assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5` STRING" in sql assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql @@ -130,7 +153,7 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert "ADD COLUMN `col9` JSON NOT NULL" in sql assert "ADD COLUMN `col10` DATE" in sql assert "ADD COLUMN `col11` TIME" in sql - assert "ADD COLUMN `col1_precision` INTEGER NOT NULL" in sql + assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql assert "ADD COLUMN `col4_precision` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5_precision` STRING(25)" in sql assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql @@ -144,6 +167,47 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql +def test_create_table_case_insensitive(ci_gcp_client: BigQueryClient) -> None: + # in case insensitive mode + assert ci_gcp_client.capabilities.has_case_sensitive_identifiers is False + # case sensitive naming convention + assert ci_gcp_client.sql_client.dataset_name.startswith("Test") + with ci_gcp_client.with_staging_dataset(): + assert ci_gcp_client.sql_client.dataset_name.endswith("staginG") + assert ci_gcp_client.sql_client.staging_dataset_name.endswith("staginG") + + ci_gcp_client.schema.update_table( + utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)) + ) + sql = ci_gcp_client._get_table_update_sql( + "Event_test_tablE", + list(ci_gcp_client.schema.get_table_columns("Event_test_tablE").values()), + False, + )[0] + sqlfluff.parse(sql, dialect="bigquery") + # everything capitalized + + # every line starts with "Col" + for line in sql.split("\n")[1:]: + assert line.startswith("`Col") + + # generate collision + ci_gcp_client.schema.update_table( + utils.new_table("event_TEST_table", columns=deepcopy(TABLE_UPDATE)) + ) + assert "Event_TEST_tablE" in ci_gcp_client.schema.tables + with pytest.raises(SchemaIdentifierNormalizationCollision) as coll_ex: + ci_gcp_client.update_stored_schema([]) + assert coll_ex.value.conflict_identifier_name == "Event_test_tablE" + assert coll_ex.value.table_name == "Event_TEST_tablE" + + # make it case sensitive + ci_gcp_client.capabilities.has_case_sensitive_identifiers = True + # now the check passes, we are stopped because it is not allowed to change schema in the loader + with pytest.raises(DestinationSchemaTampered): + ci_gcp_client.update_stored_schema([]) + + def test_create_table_with_partition_and_cluster(gcp_client: BigQueryClient) -> None: mod_update = deepcopy(TABLE_UPDATE) # timestamp @@ -946,7 +1010,7 @@ def sources() -> List[DltResource]: pipeline = destination_config.setup_pipeline( f"bigquery_{uniq_id()}", - full_refresh=True, + dev_mode=True, ) pipeline.run(sources()) diff --git a/tests/load/cases/loading/csv_header.csv b/tests/load/cases/loading/csv_header.csv new file mode 100644 index 0000000000..14c7514e51 --- /dev/null +++ b/tests/load/cases/loading/csv_header.csv @@ -0,0 +1,3 @@ +id|name|description|ordered_at|price +1|item|value|2024-04-12|128.4 +1|"item"|value with space|2024-04-12|128.4 \ No newline at end of file diff --git a/tests/load/cases/loading/csv_no_header.csv b/tests/load/cases/loading/csv_no_header.csv new file mode 100644 index 0000000000..1e3a63494e --- /dev/null +++ b/tests/load/cases/loading/csv_no_header.csv @@ -0,0 +1,2 @@ +1|item|value|2024-04-12|128.4 +1|"item"|value with space|2024-04-12|128.4 \ No newline at end of file diff --git a/tests/load/cases/loading/csv_no_header.csv.gz b/tests/load/cases/loading/csv_no_header.csv.gz new file mode 100644 index 0000000000..310950f484 Binary files /dev/null and b/tests/load/cases/loading/csv_no_header.csv.gz differ diff --git a/tests/load/cases/loading/cve.json b/tests/load/cases/loading/cve.json new file mode 100644 index 0000000000..58796ed8c5 --- /dev/null +++ b/tests/load/cases/loading/cve.json @@ -0,0 +1,397 @@ +{ + "CVE_data_meta": { + "ASSIGNER": "security@apache.org", + "ID": "CVE-2021-44228", + "STATE": "PUBLIC", + "TITLE": "Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints" + }, + "affects": { + "vendor": { + "vendor_data": [ + { + "product": { + "product_data": [ + { + "product_name": "Apache Log4j2", + "version": { + "version_data": [ + { + "version_affected": ">=", + "version_name": "log4j-core", + "version_value": "2.0-beta9" + }, + { + "version_affected": "<", + "version_name": "log4j-core", + "version_value": "2.3.1" + }, + { + "version_affected": ">=", + "version_name": "log4j-core", + "version_value": "2.4" + }, + { + "version_affected": "<", + "version_name": "log4j-core" + }, + { + "version_affected": ">=", + "version_name": "log4j-core", + "version_value": "2.13.0" + }, + { + "version_affected": "<", + "version_name": "log4j-core", + "version_value": "2.15.0" + } + ] + } + } + ] + }, + "vendor_name": "Apache Software Foundation" + } + ] + } + }, + "credit": [ + { + "lang": "eng", + "value": "This issue was discovered by Chen Zhaojun of Alibaba Cloud Security Team." + } + ], + "data_format": "MITRE", + "data_type": "CVE", + "data_version": "4.0", + "description": { + "description_data": [ + { + "lang": "eng", + "value": "Apache Log4j2 2.0-beta9 through 2.15.0 (excluding security releases 2.12.2, 2.12.3, and 2.3.1) JNDI features used in configuration, log messages, and parameters do not protect against attacker controlled LDAP and other JNDI related endpoints. An attacker who can control log messages or log message parameters can execute arbitrary code loaded from LDAP servers when message lookup substitution is enabled. From log4j 2.15.0, this behavior has been disabled by default. From version 2.16.0 (along with 2.12.2, 2.12.3, and 2.3.1), this functionality has been completely removed. Note that this vulnerability is specific to log4j-core and does not affect log4net, log4cxx, or other Apache Logging Services projects." + } + ] + }, + "generator": { + "engine": "Vulnogram 0.0.9" + }, + "impact": [ + { + "other": "critical" + } + ], + "problemtype": { + "problemtype_data": [ + { + "description": [ + { + "lang": "eng", + "value": "CWE-502 Deserialization of Untrusted Data" + } + ] + }, + { + "description": [ + { + "lang": "eng", + "value": "CWE-400 Uncontrolled Resource Consumption" + } + ] + }, + { + "description": [ + { + "lang": "eng", + "value": "CWE-20 Improper Input Validation" + } + ] + } + ] + }, + "references": { + "reference_data": [ + { + "refsource": "MISC", + "url": "https://logging.apache.org/log4j/2.x/security.html", + "name": "https://logging.apache.org/log4j/2.x/security.html" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211210 CVE-2021-44228: Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints", + "url": "http://www.openwall.com/lists/oss-security/2021/12/10/1" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211210 Re: CVE-2021-44228: Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints", + "url": "http://www.openwall.com/lists/oss-security/2021/12/10/2" + }, + { + "refsource": "CISCO", + "name": "20211210 Vulnerability in Apache Log4j Library Affecting Cisco Products: December 2021", + "url": "https://tools.cisco.com/security/center/content/CiscoSecurityAdvisory/cisco-sa-apache-log4j-qRuKNEbd" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211210 Re: CVE-2021-44228: Apache Log4j2 JNDI features do not protect against attacker controlled LDAP and other JNDI related endpoints", + "url": "http://www.openwall.com/lists/oss-security/2021/12/10/3" + }, + { + "refsource": "CONFIRM", + "name": "https://security.netapp.com/advisory/ntap-20211210-0007/", + "url": "https://security.netapp.com/advisory/ntap-20211210-0007/" + }, + { + "refsource": "CONFIRM", + "name": "https://security.netapp.com/advisory/ntap-20211210-0007/", + "url": "https://security.netapp.com/advisory/ntap-20211210-0007/" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165225/Apache-Log4j2-2.14.1-Remote-Code-Execution.html" + }, + { + "refsource": "CONFIRM", + "name": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032", + "url": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032" + }, + { + "refsource": "CONFIRM", + "name": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032", + "url": "https://psirt.global.sonicwall.com/vuln-detail/SNWLID-2021-0032" + }, + { + "refsource": "CONFIRM", + "name": "https://www.oracle.com/security-alerts/alert-cve-2021-44228.html", + "url": "https://www.oracle.com/security-alerts/alert-cve-2021-44228.html" + }, + { + "refsource": "DEBIAN", + "name": "DSA-5020", + "url": "https://www.debian.org/security/2021/dsa-5020" + }, + { + "refsource": "MLIST", + "name": "[debian-lts-announce] 20211212 [SECURITY] [DLA 2842-1] apache-log4j2 security update", + "url": "https://lists.debian.org/debian-lts-announce/2021/12/msg00007.html" + }, + { + "refsource": "FEDORA", + "name": "FEDORA-2021-f0f501d01f", + "url": "https://lists.fedoraproject.org/archives/list/package-announce@lists.fedoraproject.org/message/VU57UJDCFIASIO35GC55JMKSRXJMCDFM/" + }, + { + "refsource": "MS", + "name": "Microsoft\u2019s Response to CVE-2021-44228 Apache Log4j 2", + "url": "https://msrc-blog.microsoft.com/2021/12/11/microsofts-response-to-cve-2021-44228-apache-log4j2/" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211213 Re: CVE-2021-4104: Deserialization of untrusted data in JMSAppender in Apache Log4j 1.2", + "url": "http://www.openwall.com/lists/oss-security/2021/12/13/2" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211213 CVE-2021-4104: Deserialization of untrusted data in JMSAppender in Apache Log4j 1.2", + "url": "http://www.openwall.com/lists/oss-security/2021/12/13/1" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211214 CVE-2021-45046: Apache Log4j2 Thread Context Message Pattern and Context Lookup Pattern vulnerable to a denial of service attack", + "url": "http://www.openwall.com/lists/oss-security/2021/12/14/4" + }, + { + "refsource": "CISCO", + "name": "20211210 A Vulnerability in Apache Log4j Library Affecting Cisco Products: December 2021", + "url": "https://tools.cisco.com/security/center/content/CiscoSecurityAdvisory/cisco-sa-apache-log4j-qRuKNEbd" + }, + { + "refsource": "CERT-VN", + "name": "VU#930724", + "url": "https://www.kb.cert.org/vuls/id/930724" + }, + { + "refsource": "MISC", + "name": "https://twitter.com/kurtseifried/status/1469345530182455296", + "url": "https://twitter.com/kurtseifried/status/1469345530182455296" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-661247.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-661247.pdf" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165260/VMware-Security-Advisory-2021-0028.html", + "url": "http://packetstormsecurity.com/files/165260/VMware-Security-Advisory-2021-0028.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165270/Apache-Log4j2-2.14.1-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165270/Apache-Log4j2-2.14.1-Remote-Code-Execution.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165261/Apache-Log4j2-2.14.1-Information-Disclosure.html", + "url": "http://packetstormsecurity.com/files/165261/Apache-Log4j2-2.14.1-Information-Disclosure.html" + }, + { + "refsource": "CONFIRM", + "name": "https://www.intel.com/content/www/us/en/security-center/advisory/intel-sa-00646.html", + "url": "https://www.intel.com/content/www/us/en/security-center/advisory/intel-sa-00646.html" + }, + { + "refsource": "CISCO", + "name": "20211210 Vulnerabilities in Apache Log4j Library Affecting Cisco Products: December 2021", + "url": "https://tools.cisco.com/security/center/content/CiscoSecurityAdvisory/cisco-sa-apache-log4j-qRuKNEbd" + }, + { + "refsource": "MLIST", + "name": "[oss-security] 20211215 Re: CVE-2021-45046: Apache Log4j2 Thread Context Message Pattern and Context Lookup Pattern vulnerable to a denial of service attack", + "url": "http://www.openwall.com/lists/oss-security/2021/12/15/3" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165282/Log4j-Payload-Generator.html", + "url": "http://packetstormsecurity.com/files/165282/Log4j-Payload-Generator.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165281/Log4j2-Log4Shell-Regexes.html", + "url": "http://packetstormsecurity.com/files/165281/Log4j2-Log4Shell-Regexes.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165307/Log4j-Remote-Code-Execution-Word-Bypassing.html", + "url": "http://packetstormsecurity.com/files/165307/Log4j-Remote-Code-Execution-Word-Bypassing.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165311/log4j-scan-Extensive-Scanner.html", + "url": "http://packetstormsecurity.com/files/165311/log4j-scan-Extensive-Scanner.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165306/L4sh-Log4j-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165306/L4sh-Log4j-Remote-Code-Execution.html" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-714170.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-714170.pdf" + }, + { + "refsource": "FEDORA", + "name": "FEDORA-2021-66d6c484f3", + "url": "https://lists.fedoraproject.org/archives/list/package-announce@lists.fedoraproject.org/message/M5CSVUNV4HWZZXGOKNSK6L7RPM7BOKIB/" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165371/VMware-Security-Advisory-2021-0028.4.html", + "url": "http://packetstormsecurity.com/files/165371/VMware-Security-Advisory-2021-0028.4.html" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-397453.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-397453.pdf" + }, + { + "refsource": "CONFIRM", + "name": "https://cert-portal.siemens.com/productcert/pdf/ssa-479842.pdf", + "url": "https://cert-portal.siemens.com/productcert/pdf/ssa-479842.pdf" + }, + { + "url": "https://www.oracle.com/security-alerts/cpujan2022.html", + "refsource": "MISC", + "name": "https://www.oracle.com/security-alerts/cpujan2022.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165532/Log4Shell-HTTP-Header-Injection.html", + "url": "http://packetstormsecurity.com/files/165532/Log4Shell-HTTP-Header-Injection.html" + }, + { + "refsource": "MISC", + "name": "https://github.com/cisagov/log4j-affected-db/blob/develop/SOFTWARE-LIST.md", + "url": "https://github.com/cisagov/log4j-affected-db/blob/develop/SOFTWARE-LIST.md" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165642/VMware-vCenter-Server-Unauthenticated-Log4Shell-JNDI-Injection-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165642/VMware-vCenter-Server-Unauthenticated-Log4Shell-JNDI-Injection-Remote-Code-Execution.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/165673/UniFi-Network-Application-Unauthenticated-Log4Shell-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/165673/UniFi-Network-Application-Unauthenticated-Log4Shell-Remote-Code-Execution.html" + }, + { + "refsource": "FULLDISC", + "name": "20220314 APPLE-SA-2022-03-14-7 Xcode 13.3", + "url": "http://seclists.org/fulldisclosure/2022/Mar/23" + }, + { + "refsource": "MISC", + "name": "https://www.bentley.com/en/common-vulnerability-exposure/be-2022-0001", + "url": "https://www.bentley.com/en/common-vulnerability-exposure/be-2022-0001" + }, + { + "refsource": "MISC", + "name": "https://github.com/cisagov/log4j-affected-db", + "url": "https://github.com/cisagov/log4j-affected-db" + }, + { + "refsource": "CONFIRM", + "name": "https://support.apple.com/kb/HT213189", + "url": "https://support.apple.com/kb/HT213189" + }, + { + "url": "https://www.oracle.com/security-alerts/cpuapr2022.html", + "refsource": "MISC", + "name": "https://www.oracle.com/security-alerts/cpuapr2022.html" + }, + { + "refsource": "MISC", + "name": "https://github.com/nu11secur1ty/CVE-mitre/tree/main/CVE-2021-44228", + "url": "https://github.com/nu11secur1ty/CVE-mitre/tree/main/CVE-2021-44228" + }, + { + "refsource": "MISC", + "name": "https://www.nu11secur1ty.com/2021/12/cve-2021-44228.html", + "url": "https://www.nu11secur1ty.com/2021/12/cve-2021-44228.html" + }, + { + "refsource": "FULLDISC", + "name": "20220721 Open-Xchange Security Advisory 2022-07-21", + "url": "http://seclists.org/fulldisclosure/2022/Jul/11" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/167794/Open-Xchange-App-Suite-7.10.x-Cross-Site-Scripting-Command-Injection.html", + "url": "http://packetstormsecurity.com/files/167794/Open-Xchange-App-Suite-7.10.x-Cross-Site-Scripting-Command-Injection.html" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/167917/MobileIron-Log4Shell-Remote-Command-Execution.html", + "url": "http://packetstormsecurity.com/files/167917/MobileIron-Log4Shell-Remote-Command-Execution.html" + }, + { + "refsource": "FULLDISC", + "name": "20221208 Intel Data Center Manager <= 5.1 Local Privileges Escalation", + "url": "http://seclists.org/fulldisclosure/2022/Dec/2" + }, + { + "refsource": "MISC", + "name": "http://packetstormsecurity.com/files/171626/AD-Manager-Plus-7122-Remote-Code-Execution.html", + "url": "http://packetstormsecurity.com/files/171626/AD-Manager-Plus-7122-Remote-Code-Execution.html" + } + ] + }, + "source": { + "discovery": "UNKNOWN" + } +} \ No newline at end of file diff --git a/tests/load/cases/loading/header.jsonl b/tests/load/cases/loading/header.jsonl new file mode 100644 index 0000000000..c2f9fee551 --- /dev/null +++ b/tests/load/cases/loading/header.jsonl @@ -0,0 +1,2 @@ +{"id": 1, "name": "item", "description": "value", "ordered_at": "2024-04-12", "price": 128.4} +{"id": 1, "name": "item", "description": "value with space", "ordered_at": "2024-04-12", "price": 128.4} \ No newline at end of file diff --git a/tests/load/clickhouse/clickhouse-compose.yml b/tests/load/clickhouse/clickhouse-compose.yml new file mode 100644 index 0000000000..b6415b120a --- /dev/null +++ b/tests/load/clickhouse/clickhouse-compose.yml @@ -0,0 +1,26 @@ +--- +services: + clickhouse: + image: clickhouse/clickhouse-server + ports: + - "9000:9000" + - "8123:8123" + environment: + - CLICKHOUSE_DB=dlt_data + - CLICKHOUSE_USER=loader + - CLICKHOUSE_PASSWORD=loader + - CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT=1 + volumes: + - clickhouse_data:/var/lib/clickhouse/ + - clickhouse_logs:/var/log/clickhouse-server/ + restart: unless-stopped + healthcheck: + test: [ "CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8123/ping" ] + interval: 3s + timeout: 5s + retries: 5 + + +volumes: + clickhouse_data: + clickhouse_logs: diff --git a/tests/load/clickhouse/test_clickhouse_adapter.py b/tests/load/clickhouse/test_clickhouse_adapter.py index 36d3ac07f7..e8e2b327c0 100644 --- a/tests/load/clickhouse/test_clickhouse_adapter.py +++ b/tests/load/clickhouse/test_clickhouse_adapter.py @@ -1,61 +1,112 @@ +from typing import Generator, Dict, cast + import dlt +from dlt.common.utils import custom_environ from dlt.destinations.adapters import clickhouse_adapter +from dlt.destinations.impl.clickhouse.sql_client import ClickHouseSqlClient +from dlt.destinations.impl.clickhouse.typing import TDeployment +from tests.load.clickhouse.utils import get_deployment_type from tests.pipeline.utils import assert_load_info def test_clickhouse_adapter() -> None: @dlt.resource - def merge_tree_resource(): + def merge_tree_resource() -> Generator[Dict[str, int], None, None]: yield {"field1": 1, "field2": 2} + # `ReplicatedMergeTree` has been supplanted by `ReplacingMergeTree` on CH Cloud, + # which is automatically selected even if `MergeTree` is selected. + # See https://clickhouse.com/docs/en/cloud/reference/shared-merge-tree. + + # The `Log` Family of engines are only supported in self-managed deployments. + # So can't test in CH Cloud CI. + @dlt.resource - def replicated_merge_tree_resource(): + def replicated_merge_tree_resource() -> Generator[Dict[str, int], None, None]: yield {"field1": 1, "field2": 2} @dlt.resource - def not_annotated_resource(): + def not_annotated_resource() -> Generator[Dict[str, int], None, None]: + """Non annotated resource will default to `SharedMergeTree` for CH cloud + and `MergeTree` for self-managed installation.""" yield {"field1": 1, "field2": 2} clickhouse_adapter(merge_tree_resource, table_engine_type="merge_tree") clickhouse_adapter(replicated_merge_tree_resource, table_engine_type="replicated_merge_tree") - pipe = dlt.pipeline(pipeline_name="adapter_test", destination="clickhouse", full_refresh=True) - pack = pipe.run([merge_tree_resource, replicated_merge_tree_resource, not_annotated_resource]) + pipe = dlt.pipeline(pipeline_name="adapter_test", destination="clickhouse", dev_mode=True) + + with pipe.sql_client() as client: + deployment_type: TDeployment = get_deployment_type(cast(ClickHouseSqlClient, client)) + + if deployment_type == "ClickHouseCloud": + pack = pipe.run( + [ + merge_tree_resource, + replicated_merge_tree_resource, + not_annotated_resource, + ] + ) + else: + # `ReplicatedMergeTree` not supported if only a single node. + pack = pipe.run([merge_tree_resource, not_annotated_resource]) assert_load_info(pack) with pipe.sql_client() as client: - # get map of table names to full table names + # Get a map of table names to full table names. tables = {} for table in client._list_tables(): if "resource" in table: tables[table.split("___")[1]] = table - assert (len(tables.keys())) == 3 + if deployment_type == "ClickHouseCloud": + assert (len(tables.keys())) == 3 + else: + assert (len(tables.keys())) == 2 - # check content + # Check the table content. for full_table_name in tables.values(): with client.execute_query(f"SELECT * FROM {full_table_name};") as cursor: res = cursor.fetchall() assert tuple(res[0])[:2] == (1, 2) - # check table format - # fails now, because we do not have a cluster (I think), it will fall back to SharedMergeTree - for full_table_name in tables.values(): + # Check the table engine. + for table_name, full_table_name in tables.items(): with client.execute_query( - "SELECT database, name, engine, engine_full FROM system.tables WHERE name =" - f" '{full_table_name}';" + "SELECT database, name, engine, engine_full FROM system.tables " + f"WHERE name = '{full_table_name}';" ) as cursor: res = cursor.fetchall() - # this should test that two tables should be replicatedmergetree tables - assert tuple(res[0])[2] == "SharedMergeTree" + if table_name in ( + "merge_tree_resource", + "replicated_merge_tree_resource", + ): + if deployment_type == "ClickHouseCloud": + assert tuple(res[0])[2] in ( + "MergeTree", + "SharedMergeTree", + "ReplicatedMergeTree", + ) + else: + assert tuple(res[0])[2] in ("MergeTree",) + else: + # Non annotated resource needs to default to detected installation + # type, i.e. cloud or self-managed. + # CI runs on CH cloud, so will be `SharedMergeTree`. + if deployment_type == "ClickHouseCloud": + assert tuple(res[0])[2] == "SharedMergeTree" + else: + assert tuple(res[0])[2] == "MergeTree" - # we can check the gen table sql though + # We can check the generated table's SQL, though. with pipe.destination_client() as dest_client: - for table in tables.keys(): + for table in tables: sql = dest_client._get_table_update_sql( # type: ignore[attr-defined] - table, pipe.default_schema.tables[table]["columns"].values(), generate_alter=False + table, + pipe.default_schema.tables[table]["columns"].values(), + generate_alter=False, ) - if table == "merge_tree_resource": - assert "ENGINE = MergeTree" in sql[0] - else: + if table == "replicated_merge_tree_resource": assert "ENGINE = ReplicatedMergeTree" in sql[0] + else: + assert "ENGINE = MergeTree" or "ENGINE = SharedMergeTree" in sql[0] diff --git a/tests/load/clickhouse/test_clickhouse_configuration.py b/tests/load/clickhouse/test_clickhouse_configuration.py index eb02155406..a4e8abc8dd 100644 --- a/tests/load/clickhouse/test_clickhouse_configuration.py +++ b/tests/load/clickhouse/test_clickhouse_configuration.py @@ -1,8 +1,7 @@ -from typing import Any, Iterator +from typing import Iterator import pytest -import dlt from dlt.common.configuration.resolve import resolve_configuration from dlt.common.libs.sql_alchemy import make_url from dlt.common.utils import digest128 @@ -11,11 +10,6 @@ ClickHouseCredentials, ClickHouseClientConfiguration, ) -from dlt.destinations.impl.snowflake.configuration import ( - SnowflakeClientConfiguration, - SnowflakeCredentials, -) -from tests.common.configuration.utils import environment from tests.load.utils import yield_client_with_storage @@ -27,8 +21,8 @@ def client() -> Iterator[ClickHouseClient]: def test_clickhouse_connection_string_with_all_params() -> None: url = ( "clickhouse://user1:pass1@host1:9000/testdb?allow_experimental_lightweight_delete=1&" - "allow_experimental_object_type=1&connect_timeout=230&enable_http_compression=1&secure=0" - "&send_receive_timeout=1000" + "allow_experimental_object_type=1&connect_timeout=230&date_time_input_format=best_effort&" + "enable_http_compression=1&secure=0&send_receive_timeout=1000" ) creds = ClickHouseCredentials() @@ -53,15 +47,15 @@ def test_clickhouse_configuration() -> None: # def empty fingerprint assert ClickHouseClientConfiguration().fingerprint() == "" # based on host - c = resolve_configuration( - SnowflakeCredentials(), + config = resolve_configuration( + ClickHouseCredentials(), explicit_value="clickhouse://user1:pass1@host1:9000/db1", ) - assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128("host1") + assert ClickHouseClientConfiguration(credentials=config).fingerprint() == digest128("host1") def test_clickhouse_connection_settings(client: ClickHouseClient) -> None: - """Test experimental settings are set correctly for session.""" + """Test experimental settings are set correctly for the session.""" conn = client.sql_client.open_connection() cursor1 = conn.cursor() cursor2 = conn.cursor() @@ -74,3 +68,4 @@ def test_clickhouse_connection_settings(client: ClickHouseClient) -> None: assert ("allow_experimental_lightweight_delete", "1") in res assert ("enable_http_compression", "1") in res + assert ("date_time_input_format", "best_effort") in res diff --git a/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py b/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py index 481cd420c6..b2edb12d49 100644 --- a/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py +++ b/tests/load/clickhouse/test_clickhouse_gcs_s3_compatibility.py @@ -22,7 +22,7 @@ def dummy_data() -> Generator[Dict[str, int], None, None]: pipeline_name="gcs_s3_compatibility", destination="clickhouse", staging=gcp_bucket, - full_refresh=True, + dev_mode=True, ) pack = pipe.run([dummy_data]) assert_load_info(pack) diff --git a/tests/load/clickhouse/test_clickhouse_table_builder.py b/tests/load/clickhouse/test_clickhouse_table_builder.py index fd3bf50907..433383b631 100644 --- a/tests/load/clickhouse/test_clickhouse_table_builder.py +++ b/tests/load/clickhouse/test_clickhouse_table_builder.py @@ -6,6 +6,7 @@ from dlt.common.schema import Schema from dlt.common.utils import custom_environ, digest128 from dlt.common.utils import uniq_id +from dlt.destinations import clickhouse from dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseCredentials, @@ -18,7 +19,7 @@ def clickhouse_client(empty_schema: Schema) -> ClickHouseClient: # Return a client without opening connection. creds = ClickHouseCredentials() - return ClickHouseClient( + return clickhouse().client( empty_schema, ClickHouseClientConfiguration(credentials=creds)._bind_dataset_name(f"test_{uniq_id()}"), ) @@ -138,7 +139,9 @@ def test_clickhouse_alter_table(clickhouse_client: ClickHouseClient) -> None: @pytest.mark.usefixtures("empty_schema") -def test_clickhouse_create_table_with_primary_keys(clickhouse_client: ClickHouseClient) -> None: +def test_clickhouse_create_table_with_primary_keys( + clickhouse_client: ClickHouseClient, +) -> None: mod_update = deepcopy(TABLE_UPDATE) mod_update[1]["primary_key"] = True @@ -170,3 +173,28 @@ def test_clickhouse_create_table_with_hints(client: ClickHouseClient) -> None: # No hints. assert "`col3` boolean NOT NULL" in sql assert "`col4` timestamp with time zone NOT NULL" in sql + + +def test_clickhouse_table_engine_configuration() -> None: + with custom_environ( + { + "DESTINATION__CLICKHOUSE__CREDENTIALS__HOST": "localhost", + "DESTINATION__CLICKHOUSE__DATASET_NAME": f"test_{uniq_id()}", + } + ): + config = resolve_configuration( + ClickHouseClientConfiguration(), sections=("destination", "clickhouse") + ) + assert config.table_engine_type == "merge_tree" + + with custom_environ( + { + "DESTINATION__CLICKHOUSE__CREDENTIALS__HOST": "localhost", + "DESTINATION__CLICKHOUSE__TABLE_ENGINE_TYPE": "replicated_merge_tree", + "DESTINATION__CLICKHOUSE__DATASET_NAME": f"test_{uniq_id()}", + } + ): + config = resolve_configuration( + ClickHouseClientConfiguration(), sections=("destination", "clickhouse") + ) + assert config.table_engine_type == "replicated_merge_tree" diff --git a/tests/load/clickhouse/utils.py b/tests/load/clickhouse/utils.py new file mode 100644 index 0000000000..5c34d52148 --- /dev/null +++ b/tests/load/clickhouse/utils.py @@ -0,0 +1,9 @@ +from dlt.destinations.impl.clickhouse.sql_client import ClickHouseSqlClient +from dlt.destinations.impl.clickhouse.typing import TDeployment + + +def get_deployment_type(client: ClickHouseSqlClient) -> TDeployment: + cloud_mode = int(client.execute_sql(""" + SELECT value FROM system.settings WHERE name = 'cloud_mode' + """)[0][0]) + return "ClickHouseCloud" if cloud_mode else "ClickHouseOSS" diff --git a/tests/load/conftest.py b/tests/load/conftest.py index fefaeee077..a110b1198f 100644 --- a/tests/load/conftest.py +++ b/tests/load/conftest.py @@ -2,8 +2,8 @@ import pytest from typing import Iterator -from tests.load.utils import ALL_BUCKETS, DEFAULT_BUCKETS, WITH_GDRIVE_BUCKETS -from tests.utils import preserve_environ +from tests.load.utils import ALL_BUCKETS, DEFAULT_BUCKETS, WITH_GDRIVE_BUCKETS, drop_pipeline +from tests.utils import preserve_environ, patch_home_dir @pytest.fixture(scope="function", params=DEFAULT_BUCKETS) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index cc353f5894..f6a06180c9 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -6,7 +6,6 @@ from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.common.configuration import resolve_configuration -from tests.utils import preserve_environ # mark all tests as essential, do not remove pytestmark = pytest.mark.essential diff --git a/tests/load/dremio/test_dremio_client.py b/tests/load/dremio/test_dremio_client.py index d0002dc343..efc72c0652 100644 --- a/tests/load/dremio/test_dremio_client.py +++ b/tests/load/dremio/test_dremio_client.py @@ -1,6 +1,8 @@ import pytest from dlt.common.schema import TColumnSchema, Schema + +from dlt.destinations import dremio from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration, DremioCredentials from dlt.destinations.impl.dremio.dremio import DremioClient from tests.load.utils import empty_schema @@ -10,11 +12,11 @@ def dremio_client(empty_schema: Schema) -> DremioClient: creds = DremioCredentials() creds.database = "test_database" - return DremioClient( + # ignore any configured values + creds.resolve() + return dremio(credentials=creds).client( empty_schema, - DremioClientConfiguration(credentials=creds)._bind_dataset_name( - dataset_name="test_dataset" - ), + DremioClientConfiguration()._bind_dataset_name(dataset_name="test_dataset"), ) diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index 8f6bf195e2..f4088a7608 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -15,9 +15,8 @@ from dlt.destinations.impl.duckdb.exceptions import InvalidInMemoryDuckdbCredentials from dlt.pipeline.exceptions import PipelineStepFailed -from tests.load.pipeline.utils import drop_pipeline from tests.pipeline.utils import assert_table -from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ, TEST_STORAGE_ROOT +from tests.utils import patch_home_dir, autouse_test_storage, TEST_STORAGE_ROOT # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -57,7 +56,7 @@ def test_duckdb_open_conn_default() -> None: delete_quack_db() -def test_duckdb_in_memory_mode_via_factory(preserve_environ): +def test_duckdb_in_memory_mode_via_factory(): delete_quack_db() try: import duckdb @@ -68,7 +67,9 @@ def test_duckdb_in_memory_mode_via_factory(preserve_environ): # Check if passing :memory: to factory fails with pytest.raises(PipelineStepFailed) as exc: - p = dlt.pipeline(pipeline_name="booboo", destination="duckdb", credentials=":memory:") + p = dlt.pipeline( + pipeline_name="booboo", destination=dlt.destinations.duckdb(credentials=":memory:") + ) p.run([1, 2, 3]) assert isinstance(exc.value.exception, InvalidInMemoryDuckdbCredentials) @@ -86,7 +87,7 @@ def test_duckdb_in_memory_mode_via_factory(preserve_environ): with pytest.raises(PipelineStepFailed) as exc: p = dlt.pipeline( pipeline_name="booboo", - destination=Destination.from_reference("duckdb", credentials=":memory:"), # type: ignore[arg-type] + destination=Destination.from_reference("duckdb", credentials=":memory:"), ) p.run([1, 2, 3], table_name="numbers") @@ -204,7 +205,9 @@ def test_duckdb_database_path() -> None: def test_keeps_initial_db_path() -> None: db_path = "_storage/path_test_quack.duckdb" - p = dlt.pipeline(pipeline_name="quack_pipeline", credentials=db_path, destination="duckdb") + p = dlt.pipeline( + pipeline_name="quack_pipeline", destination=dlt.destinations.duckdb(credentials=db_path) + ) print(p.pipelines_dir) with p.sql_client() as conn: # still cwd @@ -252,7 +255,7 @@ def test_duck_database_path_delete() -> None: db_folder = "_storage/db_path" os.makedirs(db_folder) db_path = f"{db_folder}/path_test_quack.duckdb" - p = dlt.pipeline(pipeline_name="deep_quack_pipeline", credentials=db_path, destination="duckdb") + p = dlt.pipeline(pipeline_name="deep_quack_pipeline", destination=duckdb(credentials=db_path)) p.run([1, 2, 3], table_name="table", dataset_name="dataset") # attach the pipeline p = dlt.attach(pipeline_name="deep_quack_pipeline") @@ -273,7 +276,7 @@ def test_case_sensitive_database_name() -> None: cs_quack = os.path.join(TEST_STORAGE_ROOT, "QuAcK") os.makedirs(cs_quack, exist_ok=True) db_path = os.path.join(cs_quack, "path_TEST_quack.duckdb") - p = dlt.pipeline(pipeline_name="NOT_QUAck", credentials=db_path, destination="duckdb") + p = dlt.pipeline(pipeline_name="NOT_QUAck", destination=duckdb(credentials=db_path)) with p.sql_client() as conn: conn.execute_sql("DESCRIBE;") diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 545f182ece..85f86ce84d 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -5,6 +5,7 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema +from dlt.destinations import duckdb from dlt.destinations.impl.duckdb.duck import DuckDbClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration @@ -22,7 +23,7 @@ @pytest.fixture def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection - return DuckDbClient( + return duckdb().client( empty_schema, DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_" + uniq_id()), ) @@ -117,7 +118,7 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: assert '"col4" TIMESTAMP WITH TIME ZONE NOT NULL' in sql # same thing with indexes - client = DuckDbClient( + client = duckdb().client( client.schema, DuckDbClientConfiguration(create_indexes=True)._bind_dataset_name( dataset_name="test_" + uniq_id() diff --git a/tests/load/duckdb/test_motherduck_client.py b/tests/load/duckdb/test_motherduck_client.py index 15326c89dc..764e1654c6 100644 --- a/tests/load/duckdb/test_motherduck_client.py +++ b/tests/load/duckdb/test_motherduck_client.py @@ -14,7 +14,7 @@ MotherDuckClientConfiguration, ) -from tests.utils import patch_home_dir, preserve_environ, skip_if_not_active +from tests.utils import patch_home_dir, skip_if_not_active # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -28,7 +28,7 @@ def test_motherduck_configuration() -> None: assert cred.password == "TOKEN" assert cred.database == "dlt_data" assert cred.is_partial() is False - assert cred.is_resolved() is True + assert cred.is_resolved() is False cred = MotherDuckCredentials() cred.parse_native_representation("md:///?token=TOKEN") diff --git a/tests/load/filesystem/test_aws_credentials.py b/tests/load/filesystem/test_aws_credentials.py index 1a41144744..b782e76b7e 100644 --- a/tests/load/filesystem/test_aws_credentials.py +++ b/tests/load/filesystem/test_aws_credentials.py @@ -1,6 +1,7 @@ import pytest from typing import Dict +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration from dlt.common.utils import digest128 from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs.aws_credentials import AwsCredentials @@ -8,7 +9,7 @@ from tests.common.configuration.utils import environment from tests.load.utils import ALL_FILESYSTEM_DRIVERS -from tests.utils import preserve_environ, autouse_test_storage +from tests.utils import autouse_test_storage # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -101,6 +102,11 @@ def test_aws_credentials_from_boto3(environment: Dict[str, str]) -> None: assert c.aws_access_key_id == "fake_access_key" +def test_aws_credentials_from_unknown_object() -> None: + with pytest.raises(InvalidBoto3Session): + AwsCredentials().parse_native_representation(CredentialsConfiguration()) + + def test_aws_credentials_for_profile(environment: Dict[str, str]) -> None: import botocore.exceptions @@ -136,6 +142,24 @@ def test_aws_credentials_with_endpoint_url(environment: Dict[str, str]) -> None: } +def test_explicit_filesystem_credentials() -> None: + import dlt + from dlt.destinations import filesystem + + # try filesystem which uses union of credentials that requires bucket_url to resolve + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + destination=filesystem( + bucket_url="s3://test", + destination_name="uniq_s3_bucket", + credentials={"aws_access_key_id": "key_id", "aws_secret_access_key": "key"}, + ), + ) + config = p.destination_client().config + assert isinstance(config.credentials, AwsCredentials) + assert config.credentials.is_resolved() + + def set_aws_credentials_env(environment: Dict[str, str]) -> None: environment["AWS_ACCESS_KEY_ID"] = "fake_access_key" environment["AWS_SECRET_ACCESS_KEY"] = "fake_secret_key" diff --git a/tests/load/filesystem/test_azure_credentials.py b/tests/load/filesystem/test_azure_credentials.py index 4ee2ec46db..2353491737 100644 --- a/tests/load/filesystem/test_azure_credentials.py +++ b/tests/load/filesystem/test_azure_credentials.py @@ -17,7 +17,7 @@ from dlt.common.storages.configuration import FilesystemConfiguration from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AZ_BUCKET from tests.common.configuration.utils import environment -from tests.utils import preserve_environ, autouse_test_storage +from tests.utils import autouse_test_storage from dlt.common.storages.fsspec_filesystem import fsspec_from_config # mark all tests as essential, do not remove diff --git a/tests/load/filesystem/test_filesystem_client.py b/tests/load/filesystem/test_filesystem_client.py index fbfd08271b..f16e75c7e6 100644 --- a/tests/load/filesystem/test_filesystem_client.py +++ b/tests/load/filesystem/test_filesystem_client.py @@ -2,13 +2,22 @@ import os from unittest import mock from pathlib import Path +from urllib.parse import urlparse import pytest +from dlt.common.configuration.specs.azure_credentials import AzureCredentials +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + extract_inner_hint, +) +from dlt.common.schema.schema import Schema +from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.time import ensure_pendulum_datetime from dlt.common.utils import digest128, uniq_id from dlt.common.storages import FileStorage, ParsedLoadJobFileName +from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import ( FilesystemDestinationClientConfiguration, INIT_FILE_NAME, @@ -39,11 +48,43 @@ def logger_autouse() -> None: ] -def test_filesystem_destination_configuration() -> None: - assert FilesystemDestinationClientConfiguration().fingerprint() == "" - assert FilesystemDestinationClientConfiguration( - bucket_url="s3://cool" - ).fingerprint() == digest128("s3://cool") +@pytest.mark.parametrize( + "url, exp", + ( + (None, ""), + ("/path/path2", digest128("")), + ("s3://cool", digest128("s3://cool")), + ("s3://cool.domain/path/path2", digest128("s3://cool.domain")), + ), +) +def test_filesystem_destination_configuration(url, exp) -> None: + assert FilesystemDestinationClientConfiguration(bucket_url=url).fingerprint() == exp + + +def test_filesystem_factory_buckets(with_gdrive_buckets_env: str) -> None: + proto = urlparse(with_gdrive_buckets_env).scheme + credentials_type = extract_inner_hint( + FilesystemConfiguration.PROTOCOL_CREDENTIALS.get(proto, CredentialsConfiguration) + ) + + # test factory figuring out the right credentials + filesystem_ = filesystem(with_gdrive_buckets_env) + client = filesystem_.client( + Schema("test"), + initial_config=FilesystemDestinationClientConfiguration()._bind_dataset_name("test"), + ) + assert client.config.protocol == proto or "file" + assert isinstance(client.config.credentials, credentials_type) + assert issubclass(client.config.credentials_type(client.config), credentials_type) + assert filesystem_.capabilities() + + # factory gets initial credentials + filesystem_ = filesystem(with_gdrive_buckets_env, credentials=credentials_type()) + client = filesystem_.client( + Schema("test"), + initial_config=FilesystemDestinationClientConfiguration()._bind_dataset_name("test"), + ) + assert isinstance(client.config.credentials, credentials_type) @pytest.mark.parametrize("write_disposition", ("replace", "append", "merge")) diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index 270e1ff70c..3cad7dda2c 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -1,9 +1,10 @@ import os import posixpath -from typing import Union, Dict +from typing import Tuple, Union, Dict from urllib.parse import urlparse +from fsspec import AbstractFileSystem import pytest from tenacity import retry, stop_after_attempt, wait_fixed @@ -20,9 +21,10 @@ from dlt.destinations.impl.filesystem.configuration import ( FilesystemDestinationClientConfiguration, ) -from tests.common.storages.utils import assert_sample_files +from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders +from tests.common.storages.utils import TEST_SAMPLE_FILES, assert_sample_files from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET -from tests.utils import preserve_environ, autouse_test_storage +from tests.utils import autouse_test_storage from .utils import self_signed_cert from tests.common.configuration.utils import environment @@ -97,27 +99,26 @@ def check_file_changed(file_url_: str): @pytest.mark.parametrize("load_content", (True, False)) @pytest.mark.parametrize("glob_filter", ("**", "**/*.csv", "*.txt", "met_csv/A803/*.csv")) -def test_filesystem_dict( - with_gdrive_buckets_env: str, load_content: bool, glob_filter: str -) -> None: +def test_glob_files(with_gdrive_buckets_env: str, load_content: bool, glob_filter: str) -> None: bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] - config = get_config() - # enable caches - config.read_only = True - if config.protocol in ["memory", "file"]: - pytest.skip(f"{config.protocol} not supported in this test") - glob_folder = "standard_source/samples" - # may contain query string - bucket_url_parsed = urlparse(bucket_url) - bucket_url = bucket_url_parsed._replace( - path=posixpath.join(bucket_url_parsed.path, glob_folder) - ).geturl() - filesystem, _ = fsspec_from_config(config) + bucket_url, config, filesystem = glob_test_setup(bucket_url, "standard_source/samples") # use glob to get data all_file_items = list(glob_files(filesystem, bucket_url, glob_filter)) + # assert len(all_file_items) == 0 assert_sample_files(all_file_items, filesystem, config, load_content, glob_filter) +def test_glob_overlapping_path_files(with_gdrive_buckets_env: str) -> None: + bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] + # "standard_source/sample" overlaps with a real existing "standard_source/samples". walk operation on azure + # will return all files from "standard_source/samples" and report the wrong "standard_source/sample" path to the user + # here we test we do not have this problem with out glob + bucket_url, _, filesystem = glob_test_setup(bucket_url, "standard_source/sample") + # use glob to get data + all_file_items = list(glob_files(filesystem, bucket_url)) + assert len(all_file_items) == 0 + + @pytest.mark.skipif("s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured") def test_filesystem_instance_from_s3_endpoint(environment: Dict[str, str]) -> None: """Test that fsspec instance is correctly configured when using endpoint URL. @@ -199,7 +200,7 @@ def test_s3_wrong_client_certificate(default_buckets_env: str, self_signed_cert: def test_filesystem_destination_config_reports_unused_placeholders(mocker) -> None: with custom_environ({"DATASET_NAME": "BOBO"}): - extra_placeholders = { + extra_placeholders: TExtraPlaceholders = { "value": 1, "otters": "lab", "dlt": "labs", @@ -211,7 +212,7 @@ def test_filesystem_destination_config_reports_unused_placeholders(mocker) -> No FilesystemDestinationClientConfiguration( bucket_url="file:///tmp/dirbobo", layout="{schema_name}/{table_name}/{otters}-x-{x}/{load_id}.{file_id}.{timestamp}.{ext}", - extra_placeholders=extra_placeholders, # type: ignore + extra_placeholders=extra_placeholders, ) ) logger_spy.assert_called_once_with("Found unused layout placeholders: value, dlt, dlthub") @@ -227,7 +228,7 @@ def test_filesystem_destination_passed_parameters_override_config_values() -> No "DESTINATION__FILESYSTEM__EXTRA_PLACEHOLDERS": json.dumps(config_extra_placeholders), } ): - extra_placeholders = { + extra_placeholders: TExtraPlaceholders = { "new_value": 1, "dlt": "labs", "dlthub": "platform", @@ -263,3 +264,26 @@ def test_filesystem_destination_passed_parameters_override_config_values() -> No bound_config = filesystem_destination.configuration(filesystem_config) assert bound_config.current_datetime == config_now assert bound_config.extra_placeholders == config_extra_placeholders + + +def glob_test_setup( + bucket_url: str, glob_folder: str +) -> Tuple[str, FilesystemConfiguration, AbstractFileSystem]: + config = get_config() + # enable caches + config.read_only = True + if config.protocol in ["file"]: + pytest.skip(f"{config.protocol} not supported in this test") + + # may contain query string + bucket_url_parsed = urlparse(bucket_url) + bucket_url = bucket_url_parsed._replace( + path=posixpath.join(bucket_url_parsed.path, glob_folder) + ).geturl() + filesystem, _ = fsspec_from_config(config) + if config.protocol == "memory": + mem_path = os.path.join("m", "standard_source") + if not filesystem.isdir(mem_path): + filesystem.mkdirs(mem_path) + filesystem.upload(TEST_SAMPLE_FILES, mem_path, recursive=True) + return bucket_url, config, filesystem diff --git a/tests/load/filesystem/test_gcs_credentials.py b/tests/load/filesystem/test_gcs_credentials.py new file mode 100644 index 0000000000..febfa27ea4 --- /dev/null +++ b/tests/load/filesystem/test_gcs_credentials.py @@ -0,0 +1,32 @@ +import pytest + +import dlt +from dlt.destinations import filesystem +from dlt.sources.credentials import GcpOAuthCredentials +from tests.load.utils import ALL_FILESYSTEM_DRIVERS + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + +if "gs" not in ALL_FILESYSTEM_DRIVERS: + pytest.skip("gcs filesystem driver not configured", allow_module_level=True) + + +def test_explicit_filesystem_credentials() -> None: + # resolve gcp oauth + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + destination=filesystem( + "gcs://test", + destination_name="uniq_gcs_bucket", + credentials={ + "project_id": "pxid", + "refresh_token": "123token", + "client_id": "cid", + "client_secret": "s", + }, + ), + ) + config = p.destination_client().config + assert config.credentials.is_resolved() + assert isinstance(config.credentials, GcpOAuthCredentials) diff --git a/tests/load/filesystem/test_object_store_rs_credentials.py b/tests/load/filesystem/test_object_store_rs_credentials.py index 4e43b7c5d8..524cd4425d 100644 --- a/tests/load/filesystem/test_object_store_rs_credentials.py +++ b/tests/load/filesystem/test_object_store_rs_credentials.py @@ -29,9 +29,11 @@ FS_CREDS: Dict[str, Any] = dlt.secrets.get("destination.filesystem.credentials") -assert ( - FS_CREDS is not None -), "`destination.filesystem.credentials` must be configured for these tests." +if FS_CREDS is None: + pytest.skip( + msg="`destination.filesystem.credentials` must be configured for these tests.", + allow_module_level=True, + ) def can_connect(bucket_url: str, object_store_rs_credentials: Dict[str, str]) -> bool: @@ -86,6 +88,7 @@ def test_aws_object_store_rs_credentials() -> None: creds = AwsCredentials( aws_access_key_id=FS_CREDS["aws_access_key_id"], aws_secret_access_key=FS_CREDS["aws_secret_access_key"], + # region_name must be configured in order for data lake to work region_name=FS_CREDS["region_name"], ) assert creds.aws_session_token is None @@ -138,6 +141,7 @@ def test_gcp_object_store_rs_credentials() -> None: creds = GcpServiceAccountCredentialsWithoutDefaults( project_id=FS_CREDS["project_id"], private_key=FS_CREDS["private_key"], + # private_key_id must be configured in order for data lake to work private_key_id=FS_CREDS["private_key_id"], client_email=FS_CREDS["client_email"], ) diff --git a/tests/load/lancedb/__init__.py b/tests/load/lancedb/__init__.py new file mode 100644 index 0000000000..fb4bf0b35d --- /dev/null +++ b/tests/load/lancedb/__init__.py @@ -0,0 +1,3 @@ +from tests.utils import skip_if_not_active + +skip_if_not_active("lancedb") diff --git a/tests/load/lancedb/test_config.py b/tests/load/lancedb/test_config.py new file mode 100644 index 0000000000..c1d658d4fe --- /dev/null +++ b/tests/load/lancedb/test_config.py @@ -0,0 +1,35 @@ +import os +from typing import Iterator + +import pytest + +from dlt.common.configuration import resolve_configuration +from dlt.common.utils import digest128 +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBClientConfiguration, +) +from tests.load.utils import ( + drop_active_pipeline_data, +) + + +# Mark all tests as essential, do not remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +def test_lancedb_configuration() -> None: + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER"] = "colbert" + os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL"] = "text-embedding-3-small" + + config = resolve_configuration( + LanceDBClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "lancedb"), + ) + assert config.embedding_model_provider == "colbert" + assert config.embedding_model == "text-embedding-3-small" diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py new file mode 100644 index 0000000000..e817a2f6c8 --- /dev/null +++ b/tests/load/lancedb/test_pipeline.py @@ -0,0 +1,435 @@ +from typing import Iterator, Generator, Any, List + +import pytest + +import dlt +from dlt.common import json +from dlt.common.typing import DictStrStr, DictStrAny +from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + lancedb_adapter, + VECTORIZE_HINT, +) +from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient +from tests.load.lancedb.utils import assert_table +from tests.load.utils import sequence_generator, drop_active_pipeline_data +from tests.pipeline.utils import assert_load_info + + +# Mark all tests as essential, do not remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +def test_adapter_and_hints() -> None: + generator_instance1 = sequence_generator() + + @dlt.resource(columns=[{"name": "content", "data_type": "text"}]) + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + assert some_data.columns["content"] == {"name": "content", "data_type": "text"} # type: ignore[index] + + lancedb_adapter( + some_data, + embed=["content"], + ) + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + } + + +def test_basic_state_and_schema() -> None: + generator_instance1 = sequence_generator() + + @dlt.resource + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + lancedb_adapter( + some_data, + embed=["content"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"test_pipeline_append_dataset{uniq_id()}", + ) + info = pipeline.run( + some_data(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore + # Check if we can get a stored schema and state. + schema = client.get_stored_schema() + print("Print dataset name", client.dataset_name) + assert schema + state = client.get_stored_state("test_pipeline_append") + assert state + + +def test_pipeline_append() -> None: + generator_instance1 = sequence_generator() + generator_instance2 = sequence_generator() + + @dlt.resource + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + lancedb_adapter( + some_data, + embed=["content"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + info = pipeline.run( + some_data(), + ) + assert_load_info(info) + + data = next(generator_instance2) + assert_table(pipeline, "some_data", items=data) + + info = pipeline.run( + some_data(), + ) + assert_load_info(info) + + data.extend(next(generator_instance2)) + assert_table(pipeline, "some_data", items=data) + + +def test_explicit_append() -> None: + """Append should work even when the primary key is specified.""" + data = [ + {"doc_id": 1, "content": "1"}, + {"doc_id": 2, "content": "2"}, + {"doc_id": 3, "content": "3"}, + ] + + @dlt.resource(primary_key="doc_id") + def some_data() -> Generator[List[DictStrAny], Any, None]: + yield data + + lancedb_adapter( + some_data, + embed=["content"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + info = pipeline.run( + some_data(), + ) + + assert_table(pipeline, "some_data", items=data) + + info = pipeline.run( + some_data(), + write_disposition="append", + ) + assert_load_info(info) + + data.extend(data) + assert_table(pipeline, "some_data", items=data) + + +def test_pipeline_replace() -> None: + generator_instance1 = sequence_generator() + generator_instance2 = sequence_generator() + + @dlt.resource + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + lancedb_adapter( + some_data, + embed=["content"], + ) + + uid = uniq_id() + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_replace", + destination="lancedb", + dataset_name="test_pipeline_replace_dataset" + + uid, # lancedb doesn't mandate any name normalization + ) + + info = pipeline.run( + some_data(), + write_disposition="replace", + ) + assert_load_info(info) + assert info.dataset_name == f"test_pipeline_replace_dataset{uid}" + + data = next(generator_instance2) + assert_table(pipeline, "some_data", items=data) + + info = pipeline.run( + some_data(), + write_disposition="replace", + ) + assert_load_info(info) + + data = next(generator_instance2) + assert_table(pipeline, "some_data", items=data) + + +def test_pipeline_merge() -> None: + data = [ + { + "doc_id": 1, + "merge_id": "shawshank-redemption-1994", + "title": "The Shawshank Redemption", + "description": ( + "Two imprisoned men find redemption through acts of decency over the years." + ), + }, + { + "doc_id": 2, + "merge_id": "the-godfather-1972", + "title": "The Godfather", + "description": ( + "A crime dynasty's aging patriarch transfers control to his reluctant son." + ), + }, + { + "doc_id": 3, + "merge_id": "the-dark-knight-2008", + "title": "The Dark Knight", + "description": ( + "The Joker wreaks havoc on Gotham, challenging The Dark Knight's ability to fight" + " injustice." + ), + }, + { + "doc_id": 4, + "merge_id": "pulp-fiction-1994", + "title": "Pulp Fiction", + "description": ( + "The lives of two mob hitmen, a boxer, a gangster and his wife, and a pair of diner" + " bandits intertwine in four tales of violence and redemption." + ), + }, + { + "doc_id": 5, + "merge_id": "schindlers-list-1993", + "title": "Schindler's List", + "description": ( + "In German-occupied Poland during World War II, industrialist Oskar Schindler" + " gradually becomes concerned for his Jewish workforce after witnessing their" + " persecution by the Nazis." + ), + }, + { + "doc_id": 6, + "merge_id": "the-lord-of-the-rings-the-return-of-the-king-2003", + "title": "The Lord of the Rings: The Return of the King", + "description": ( + "Gandalf and Aragorn lead the World of Men against Sauron's army to draw his gaze" + " from Frodo and Sam as they approach Mount Doom with the One Ring." + ), + }, + { + "doc_id": 7, + "merge_id": "the-matrix-1999", + "title": "The Matrix", + "description": ( + "A computer hacker learns from mysterious rebels about the true nature of his" + " reality and his role in the war against its controllers." + ), + }, + ] + + @dlt.resource(primary_key="doc_id") + def movies_data() -> Any: + yield data + + @dlt.resource(primary_key="doc_id", merge_key=["merge_id", "title"]) + def movies_data_explicit_merge_keys() -> Any: + yield data + + lancedb_adapter( + movies_data, + embed=["description"], + ) + + lancedb_adapter( + movies_data_explicit_merge_keys, + embed=["description"], + ) + + pipeline = dlt.pipeline( + pipeline_name="movies", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + info = pipeline.run( + movies_data(), + write_disposition="merge", + dataset_name=f"MoviesDataset{uniq_id()}", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data", items=data) + + # Change some data. + data[0]["title"] = "The Shawshank Redemption 2" + + info = pipeline.run( + movies_data(), + write_disposition="merge", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data", items=data) + + info = pipeline.run( + movies_data(), + write_disposition="merge", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data", items=data) + + # Test with explicit merge keys. + info = pipeline.run( + movies_data_explicit_merge_keys(), + write_disposition="merge", + ) + assert_load_info(info) + assert_table(pipeline, "movies_data_explicit_merge_keys", items=data) + + +def test_pipeline_with_schema_evolution() -> None: + data = [ + { + "doc_id": 1, + "content": "1", + }, + { + "doc_id": 2, + "content": "2", + }, + ] + + @dlt.resource() + def some_data() -> Generator[List[DictStrAny], Any, None]: + yield data + + lancedb_adapter(some_data, embed=["content"]) + + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestSchemaEvolutionDataset{uniq_id()}", + ) + pipeline.run( + some_data(), + ) + + assert_table(pipeline, "some_data", items=data) + + aggregated_data = data.copy() + + data = [ + { + "doc_id": 3, + "content": "3", + "new_column": "new", + }, + { + "doc_id": 4, + "content": "4", + "new_column": "new", + }, + ] + + pipeline.run( + some_data(), + ) + + table_schema = pipeline.default_schema.tables["some_data"] + assert "new_column" in table_schema["columns"] + + aggregated_data.extend(data) + + assert_table(pipeline, "some_data", items=aggregated_data) + + +def test_merge_github_nested() -> None: + pipe = dlt.pipeline(destination="lancedb", dataset_name="github1", dev_mode=True) + assert pipe.dataset_name.startswith("github1_202") + + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", + ) as f: + data = json.load(f) + + info = pipe.run( + lancedb_adapter(data[:17], embed=["title", "body"]), + table_name="issues", + write_disposition="merge", + primary_key="id", + ) + assert_load_info(info) + # assert if schema contains tables with right names + print(pipe.default_schema.tables.keys()) + assert set(pipe.default_schema.tables.keys()) == { + "_dlt_version", + "_dlt_loads", + "issues", + "_dlt_pipeline_state", + "issues__labels", + "issues__assignees", + } + assert {t["name"] for t in pipe.default_schema.data_tables()} == { + "issues", + "issues__labels", + "issues__assignees", + } + assert {t["name"] for t in pipe.default_schema.dlt_tables()} == { + "_dlt_version", + "_dlt_loads", + "_dlt_pipeline_state", + } + issues = pipe.default_schema.tables["issues"] + assert issues["columns"]["id"]["primary_key"] is True + # Make sure vectorization is enabled for. + assert issues["columns"]["title"][VECTORIZE_HINT] # type: ignore[literal-required] + assert issues["columns"]["body"][VECTORIZE_HINT] # type: ignore[literal-required] + assert VECTORIZE_HINT not in issues["columns"]["url"] + assert_table(pipe, "issues", expected_items_count=17) + + +def test_empty_dataset_allowed() -> None: + # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. + pipe = dlt.pipeline(destination="lancedb", dev_mode=True) + client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] + + assert pipe.dataset_name is None + info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) + # Dataset in load info is empty. + assert info.dataset_name is None + client = pipe.destination_client() # type: ignore[assignment] + assert client.dataset_name is None + assert client.sentinel_table == "dltSentinelTable" + assert_table(pipe, "content", expected_items_count=3) diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py new file mode 100644 index 0000000000..dc3ea5304b --- /dev/null +++ b/tests/load/lancedb/utils.py @@ -0,0 +1,74 @@ +from typing import Union, List, Any, Dict + +import numpy as np +from lancedb.embeddings import TextEmbeddingFunction # type: ignore + +import dlt +from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient + + +def assert_unordered_dicts_equal( + dict_list1: List[Dict[str, Any]], dict_list2: List[Dict[str, Any]] +) -> None: + """ + Assert that two lists of dictionaries contain the same dictionaries, ignoring None values. + + Args: + dict_list1 (List[Dict[str, Any]]): The first list of dictionaries to compare. + dict_list2 (List[Dict[str, Any]]): The second list of dictionaries to compare. + + Raises: + AssertionError: If the lists have different lengths or contain different dictionaries. + """ + assert len(dict_list1) == len(dict_list2), "Lists have different length" + + dict_set1 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1} + dict_set2 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2} + + assert dict_set1 == dict_set2, "Lists contain different dictionaries" + + +def assert_table( + pipeline: dlt.Pipeline, + table_name: str, + expected_items_count: int = None, + items: List[Any] = None, +) -> None: + client: LanceDBClient = pipeline.destination_client() # type: ignore[assignment] + qualified_table_name = client.make_qualified_table_name(table_name) + + exists = client.table_exists(qualified_table_name) + assert exists + + records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + + if expected_items_count is not None: + assert expected_items_count == len(records) + + if items is None: + return + + drop_keys = [ + "_dlt_id", + "_dlt_load_id", + dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + ] + objects_without_dlt_or_special_keys = [ + {k: v for k, v in record.items() if k not in drop_keys} for record in records + ] + + assert_unordered_dicts_equal(objects_without_dlt_or_special_keys, items) + + +class MockEmbeddingFunc(TextEmbeddingFunction): + def generate_embeddings( + self, + texts: Union[List[str], np.ndarray], # type: ignore[type-arg] + *args, + **kwargs, + ) -> List[np.ndarray]: # type: ignore[type-arg] + return [np.array(None)] + + def ndims(self) -> int: + return 2 diff --git a/tests/load/mssql/test_mssql_credentials.py b/tests/load/mssql/test_mssql_configuration.py similarity index 77% rename from tests/load/mssql/test_mssql_credentials.py rename to tests/load/mssql/test_mssql_configuration.py index 7d49196531..75af101e23 100644 --- a/tests/load/mssql/test_mssql_credentials.py +++ b/tests/load/mssql/test_mssql_configuration.py @@ -1,15 +1,46 @@ +import os import pyodbc import pytest from dlt.common.configuration import resolve_configuration, ConfigFieldMissingException from dlt.common.exceptions import SystemConfigurationException +from dlt.common.schema import Schema -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials +from dlt.destinations import mssql +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration # mark all tests as essential, do not remove pytestmark = pytest.mark.essential +def test_mssql_factory() -> None: + schema = Schema("schema") + dest = mssql() + client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is False + assert client.config.has_case_sensitive_identifiers is False + assert client.capabilities.has_case_sensitive_identifiers is False + assert client.capabilities.casefold_identifier is str + + # set args explicitly + dest = mssql(has_case_sensitive_identifiers=True, create_indexes=True) + client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + # set args via config + os.environ["DESTINATION__CREATE_INDEXES"] = "True" + os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True" + dest = mssql() + client = dest.client(schema, MsSqlClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + def test_mssql_credentials_defaults() -> None: creds = MsSqlCredentials() assert creds.port == 1433 diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index f7a87c14ee..d6cf3ec3e8 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -6,7 +6,8 @@ pytest.importorskip("dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed") -from dlt.destinations.impl.mssql.mssql import MsSqlClient +from dlt.destinations import mssql +from dlt.destinations.impl.mssql.mssql import MsSqlJobClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials from tests.load.utils import TABLE_UPDATE, empty_schema @@ -16,9 +17,9 @@ @pytest.fixture -def client(empty_schema: Schema) -> MsSqlClient: +def client(empty_schema: Schema) -> MsSqlJobClient: # return client without opening connection - return MsSqlClient( + return mssql().client( empty_schema, MsSqlClientConfiguration(credentials=MsSqlCredentials())._bind_dataset_name( dataset_name="test_" + uniq_id() @@ -26,7 +27,7 @@ def client(empty_schema: Schema) -> MsSqlClient: ) -def test_create_table(client: MsSqlClient) -> None: +def test_create_table(client: MsSqlJobClient) -> None: # non existing table sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] sqlfluff.parse(sql, dialect="tsql") @@ -50,7 +51,7 @@ def test_create_table(client: MsSqlClient) -> None: assert '"col11_precision" time(3) NOT NULL' in sql -def test_alter_table(client: MsSqlClient) -> None: +def test_alter_table(client: MsSqlJobClient) -> None: # existing table has no columns sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0] sqlfluff.parse(sql, dialect="tsql") diff --git a/tests/load/pipeline/conftest.py b/tests/load/pipeline/conftest.py index 34227a8041..a2ba65494b 100644 --- a/tests/load/pipeline/conftest.py +++ b/tests/load/pipeline/conftest.py @@ -1,8 +1,2 @@ -from tests.utils import ( - patch_home_dir, - preserve_environ, - autouse_test_storage, - duckdb_pipeline_location, -) +from tests.utils import autouse_test_storage, duckdb_pipeline_location from tests.pipeline.utils import drop_dataset_from_env -from tests.load.pipeline.utils import drop_pipeline diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 0bddfaabee..6d78968996 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -9,14 +9,14 @@ import dlt from dlt.common import pendulum -from dlt.common.time import reduce_pendulum_datetime_precision, ensure_pendulum_datetime +from dlt.common.time import reduce_pendulum_datetime_precision from dlt.common.utils import uniq_id + from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import assert_load_info, select_data from tests.utils import ( TestDataItemFormat, arrow_item_from_pandas, - preserve_environ, TPythonTableFormat, ) from tests.cases import arrow_table_all_data_types @@ -192,7 +192,7 @@ def test_parquet_column_names_are_normalized( def some_data(): yield tbl - pipeline = dlt.pipeline("arrow_" + uniq_id(), destination=destination_config.destination) + pipeline = destination_config.setup_pipeline("arrow_" + uniq_id()) pipeline.extract(some_data()) # Find the extracted file diff --git a/tests/load/pipeline/test_athena.py b/tests/load/pipeline/test_athena.py index 272cc701d5..3197a19d14 100644 --- a/tests/load/pipeline/test_athena.py +++ b/tests/load/pipeline/test_athena.py @@ -9,15 +9,15 @@ from tests.pipeline.utils import assert_load_info, load_table_counts from tests.pipeline.utils import load_table_counts from dlt.destinations.exceptions import CantExtractTablePrefix -from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter -from dlt.destinations.fs_client import FSClientBase +from dlt.destinations.adapters import athena_partition, athena_adapter -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration from tests.load.utils import ( TEST_FILE_LAYOUTS, FILE_LAYOUT_MANY_TABLES_ONE_FOLDER, FILE_LAYOUT_CLASSIC, FILE_LAYOUT_TABLE_NOT_FIRST, + destinations_configs, + DestinationTestConfiguration, ) # mark all tests as essential, do not remove @@ -208,7 +208,7 @@ def my_source() -> Any: @pytest.mark.parametrize("layout", TEST_FILE_LAYOUTS) def test_athena_file_layouts(destination_config: DestinationTestConfiguration, layout) -> None: # test wether strange file layouts still work in all staging configs - pipeline = destination_config.setup_pipeline("athena_file_layout", full_refresh=True) + pipeline = destination_config.setup_pipeline("athena_file_layout", dev_mode=True) os.environ["DESTINATION__FILESYSTEM__LAYOUT"] = layout resources = [ @@ -242,7 +242,7 @@ def test_athena_file_layouts(destination_config: DestinationTestConfiguration, l ) def test_athena_partitioned_iceberg_table(destination_config: DestinationTestConfiguration): """Load an iceberg table with partition hints and verifiy partitions are created correctly.""" - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), dev_mode=True) data_items = [ (1, "A", datetime.date.fromisoformat("2021-01-01")), diff --git a/tests/load/pipeline/test_bigquery.py b/tests/load/pipeline/test_bigquery.py index 68533a5d43..fd0a55e273 100644 --- a/tests/load/pipeline/test_bigquery.py +++ b/tests/load/pipeline/test_bigquery.py @@ -1,10 +1,12 @@ import pytest +import io -from dlt.common import Decimal +import dlt +from dlt.common import Decimal, json +from dlt.common.typing import TLoaderFileFormat from tests.pipeline.utils import assert_load_info -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration -from tests.load.utils import delete_dataset +from tests.load.utils import destinations_configs, DestinationTestConfiguration # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -16,7 +18,7 @@ ids=lambda x: x.name, ) def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("test_bigquery_numeric_types") + pipeline = destination_config.setup_pipeline("test_bigquery_numeric_types", dev_mode=True) columns = [ {"name": "col_big_numeric", "data_type": "decimal", "precision": 47, "scale": 9}, @@ -39,3 +41,107 @@ def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration row = q.fetchone() assert row[0] == data[0]["col_big_numeric"] assert row[1] == data[0]["col_numeric"] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["bigquery"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("file_format", ("parquet", "jsonl")) +def test_bigquery_autodetect_schema( + destination_config: DestinationTestConfiguration, file_format: TLoaderFileFormat +) -> None: + from dlt.destinations.adapters import bigquery_adapter + from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient + + @dlt.resource(name="cve", max_table_nesting=0, file_format=file_format) + def load_cve(stage: int): + with open("tests/load/cases/loading/cve.json", "rb") as f: + cve = json.load(f) + if stage == 0: + # remove a whole struct field + del cve["references"] + if stage == 1: + # remove a field from struct + for item in cve["references"]["reference_data"]: + del item["refsource"] + if file_format == "jsonl": + yield cve + else: + import pyarrow.json as paj + + table = paj.read_json(io.BytesIO(json.dumpb(cve))) + yield table + + pipeline = destination_config.setup_pipeline("test_bigquery_autodetect_schema", dev_mode=True) + # run without one nested field + cve = bigquery_adapter(load_cve(0), autodetect_schema=True) + info = pipeline.run(cve) + assert_load_info(info) + client: BigQuerySqlClient + with pipeline.sql_client() as client: # type: ignore[assignment] + table = client.native_connection.get_table( + client.make_qualified_table_name("cve", escape=False) + ) + field = next(field for field in table.schema if field.name == "source") + # not repeatable + assert field.field_type == "RECORD" + assert field.mode == "NULLABLE" + field = next(field for field in table.schema if field.name == "credit") + if file_format == "parquet": + # parquet wraps struct into repeatable list + field = field.fields[0] + assert field.name == "list" + assert field.field_type == "RECORD" + assert field.mode == "REPEATED" + # no references + field = next((field for field in table.schema if field.name == "references"), None) + assert field is None + + # evolve schema - add nested field + cve = bigquery_adapter(load_cve(1), autodetect_schema=True) + info = pipeline.run(cve) + assert_load_info(info) + with pipeline.sql_client() as client: # type: ignore[assignment] + table = client.native_connection.get_table( + client.make_qualified_table_name("cve", escape=False) + ) + field = next(field for field in table.schema if field.name == "references") + field = field.fields[0] + assert field.name == "reference_data" + if file_format == "parquet": + # parquet wraps struct into repeatable list + field = field.fields[0] + assert field.name == "list" + assert field.mode == "REPEATED" + # and enclosed in another type 🤷 + field = field.fields[0] + else: + assert field.mode == "REPEATED" + # make sure url is there + nested_field = next(f for f in field.fields if f.name == "url") + assert nested_field.field_type == "STRING" + # refsource not there + nested_field = next((f for f in field.fields if f.name == "refsource"), None) + assert nested_field is None + + # evolve schema - add field to a nested struct + cve = bigquery_adapter(load_cve(2), autodetect_schema=True) + info = pipeline.run(cve) + assert_load_info(info) + with pipeline.sql_client() as client: # type: ignore[assignment] + table = client.native_connection.get_table( + client.make_qualified_table_name("cve", escape=False) + ) + field = next(field for field in table.schema if field.name == "references") + field = field.fields[0] + if file_format == "parquet": + # parquet wraps struct into repeatable list + field = field.fields[0] + assert field.name == "list" + assert field.mode == "REPEATED" + # and enclosed in another type 🤷 + field = field.fields[0] + # it looks like BigQuery can evolve structs and the field is added + nested_field = next(f for f in field.fields if f.name == "refsource") diff --git a/tests/load/pipeline/test_clickhouse.py b/tests/load/pipeline/test_clickhouse.py index 2ba5cfdcb8..8ad3a7f1a7 100644 --- a/tests/load/pipeline/test_clickhouse.py +++ b/tests/load/pipeline/test_clickhouse.py @@ -5,10 +5,7 @@ import dlt from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import ( - destinations_configs, - DestinationTestConfiguration, -) +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import load_table_counts @@ -18,7 +15,7 @@ ids=lambda x: x.name, ) def test_clickhouse_destination_append(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", full_refresh=True) + pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", dev_mode=True) try: diff --git a/tests/load/pipeline/test_csv_loading.py b/tests/load/pipeline/test_csv_loading.py new file mode 100644 index 0000000000..6a2be2eb40 --- /dev/null +++ b/tests/load/pipeline/test_csv_loading.py @@ -0,0 +1,172 @@ +import os +from typing import List +import pytest + +import dlt +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.schema.typing import TColumnSchema +from dlt.common.typing import TLoaderFileFormat +from dlt.common.utils import uniq_id + +from tests.cases import arrow_table_all_data_types, prepare_shuffled_tables +from tests.pipeline.utils import ( + assert_data_table_counts, + assert_load_info, + assert_only_table_columns, + load_tables_to_dicts, +) +from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.utils import TestDataItemFormat + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("item_type", ["object", "table"]) +def test_load_csv( + destination_config: DestinationTestConfiguration, item_type: TestDataItemFormat +) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + # do not save state so the state job is not created + pipeline.config.restore_from_destination = False + + table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + # convert to pylist when loading from objects, this will kick the csv-reader in + if item_type == "object": + table, shuffled_table, shuffled_removed_column = ( + table.to_pylist(), + shuffled_table.to_pylist(), + shuffled_removed_column.to_pylist(), + ) + + load_info = pipeline.run( + [shuffled_removed_column, shuffled_table, table], + table_name="table", + loader_file_format="csv", + ) + assert_load_info(load_info) + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + assert_data_table_counts(pipeline, {"table": 5432 * 3}) + load_tables_to_dicts(pipeline, "table") + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("file_format", (None, "csv")) +@pytest.mark.parametrize("compression", (True, False)) +def test_custom_csv_no_header( + destination_config: DestinationTestConfiguration, + file_format: TLoaderFileFormat, + compression: bool, +) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = str(not compression) + csv_format = CsvFormatConfiguration(delimiter="|", include_header=False) + # apply to collected config + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + # this will apply this to config when client instance is created + pipeline.destination.config_params["csv_format"] = csv_format + # verify + assert pipeline.destination_client().config.csv_format == csv_format # type: ignore[attr-defined] + # create a resource that imports file + + columns: List[TColumnSchema] = [ + {"name": "id", "data_type": "bigint"}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, + ] + hints = dlt.mark.make_hints(columns=columns) + import_file = "tests/load/cases/loading/csv_no_header.csv" + if compression: + import_file += ".gz" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "csv", 2, hints=hints)], + table_name="no_header", + loader_file_format=file_format, + ) + info.raise_on_failed_jobs() + print(info) + assert_only_table_columns(pipeline, "no_header", [col["name"] for col in columns]) + rows = load_tables_to_dicts(pipeline, "no_header") + assert len(rows["no_header"]) == 2 + # we should have twp files loaded + jobs = info.load_packages[0].jobs["completed_jobs"] + assert len(jobs) == 2 + job_extensions = [os.path.splitext(job.job_file_info.file_name())[1] for job in jobs] + assert ".csv" in job_extensions + # we allow state to be saved to make sure it is not in csv format (which would broke) + # the loading. state is always saved in destination preferred format + preferred_ext = "." + pipeline.destination.capabilities().preferred_loader_file_format + assert preferred_ext in job_extensions + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_custom_wrong_header(destination_config: DestinationTestConfiguration) -> None: + csv_format = CsvFormatConfiguration(delimiter="|", include_header=True) + # apply to collected config + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + # this will apply this to config when client instance is created + pipeline.destination.config_params["csv_format"] = csv_format + # verify + assert pipeline.destination_client().config.csv_format == csv_format # type: ignore[attr-defined] + # create a resource that imports file + + columns: List[TColumnSchema] = [ + {"name": "object_id", "data_type": "bigint", "nullable": False}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, + ] + hints = dlt.mark.make_hints(columns=columns) + import_file = "tests/load/cases/loading/csv_header.csv" + # snowflake will pass here because we do not match + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "csv", 2, hints=hints)], + table_name="no_header", + ) + assert info.has_failed_jobs + assert len(info.load_packages[0].jobs["failed_jobs"]) == 1 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_empty_csv_from_arrow(destination_config: DestinationTestConfiguration) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + os.environ["RESTORE_FROM_DESTINATION"] = "False" + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) + table, _, _ = arrow_table_all_data_types("arrow-table", include_json=False) + + load_info = pipeline.run( + table.schema.empty_table(), table_name="arrow_table", loader_file_format="csv" + ) + assert_load_info(load_info) + assert len(load_info.load_packages[0].jobs["completed_jobs"]) == 1 + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + assert_data_table_counts(pipeline, {"arrow_table": 0}) + with pipeline.sql_client() as client: + with client.execute_query("SELECT * FROM arrow_table") as cur: + columns = [col.name for col in cur.description] + assert len(cur.fetchall()) == 0 + + # all columns in order, also casefold to the destination casing (we use cursor.description) + casefold = pipeline.destination.capabilities().casefold_identifier + assert columns == list( + map(casefold, pipeline.default_schema.get_table_columns("arrow_table").keys()) + ) diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 1dc225594f..86ee1a646e 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -11,8 +11,8 @@ from dlt.helpers.dbt.exceptions import DBTProcessingError, PrerequisitesException from tests.pipeline.utils import select_data +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.utils import ACTIVE_SQL_DESTINATIONS -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts diff --git a/tests/load/pipeline/test_dremio.py b/tests/load/pipeline/test_dremio.py index 9a4c96c922..66d1b0be4f 100644 --- a/tests/load/pipeline/test_dremio.py +++ b/tests/load/pipeline/test_dremio.py @@ -12,9 +12,7 @@ ids=lambda x: x.name, ) def test_dremio(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline( - "dremio-test", dataset_name="bar", full_refresh=True - ) + pipeline = destination_config.setup_pipeline("dremio-test", dataset_name="bar", dev_mode=True) @dlt.resource(name="items", write_disposition="replace") def items() -> Iterator[Any]: diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index 313ba63a2c..e1c6ec9d79 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -17,11 +17,11 @@ ) from dlt.destinations.job_client_impl import SqlJobClientBase -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.utils import destinations_configs, DestinationTestConfiguration def _attach(pipeline: Pipeline) -> Pipeline: - return dlt.attach(pipeline.pipeline_name, pipeline.pipelines_dir) + return dlt.attach(pipeline.pipeline_name, pipelines_dir=pipeline.pipelines_dir) @dlt.source(section="droppable", name="droppable") @@ -91,13 +91,14 @@ def assert_dropped_resource_tables(pipeline: Pipeline, resources: List[str]) -> client: SqlJobClientBase with pipeline.destination_client(pipeline.default_schema_name) as client: # type: ignore[assignment] # Check all tables supposed to be dropped are not in dataset - for table in dropped_tables: - exists, _ = client.get_storage_table(table) - assert not exists + storage_tables = list(client.get_storage_tables(dropped_tables)) + # no columns in all tables + assert all(len(table[1]) == 0 for table in storage_tables) + # Check tables not from dropped resources still exist - for table in expected_tables: - exists, _ = client.get_storage_table(table) - assert exists + storage_tables = list(client.get_storage_tables(expected_tables)) + # all tables have columns + assert all(len(table[1]) > 0 for table in storage_tables) def assert_dropped_resource_states(pipeline: Pipeline, resources: List[str]) -> None: @@ -178,7 +179,7 @@ def test_drop_command_only_state(destination_config: DestinationTestConfiguratio def test_drop_command_only_tables(destination_config: DestinationTestConfiguration) -> None: """Test drop only tables and makes sure that schema and state are synced""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) pipeline.run(source) sources_state = pipeline.state["sources"] @@ -334,9 +335,8 @@ def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None # Verify original _dlt tables were not deleted with attached._sql_job_client(attached.default_schema) as client: - for tbl in dlt_tables: - exists, _ = client.get_storage_table(tbl) - assert exists + storage_tables = list(client.get_storage_tables(dlt_tables)) + assert all(len(table[1]) > 0 for table in storage_tables) @pytest.mark.parametrize( diff --git a/tests/load/pipeline/test_duckdb.py b/tests/load/pipeline/test_duckdb.py index 3f9821cee0..3dcfffe348 100644 --- a/tests/load/pipeline/test_duckdb.py +++ b/tests/load/pipeline/test_duckdb.py @@ -1,16 +1,14 @@ import pytest import os +from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision from dlt.common.time import ensure_pendulum_datetime from dlt.destinations.exceptions import DatabaseTerminalException from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import TABLE_UPDATE_ALL_INT_PRECISIONS, TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import airtable_emojis, load_table_counts -from tests.load.pipeline.utils import ( - destinations_configs, - DestinationTestConfiguration, -) @pytest.mark.parametrize( @@ -44,7 +42,7 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No "🦚Peacock__peacock": 3, "🦚Peacocks🦚": 1, "🦚WidePeacock": 1, - "🦚WidePeacock__peacock": 3, + "🦚WidePeacock__Peacock": 3, } # this will fail - duckdb preserves case but is case insensitive when comparing identifiers @@ -54,7 +52,10 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No table_name="🦚peacocks🦚", loader_file_format=destination_config.file_format, ) - assert isinstance(pip_ex.value.__context__, DatabaseTerminalException) + assert isinstance(pip_ex.value.__context__, SchemaIdentifierNormalizationCollision) + assert pip_ex.value.__context__.conflict_identifier_name == "🦚Peacocks🦚" + assert pip_ex.value.__context__.identifier_name == "🦚peacocks🦚" + assert pip_ex.value.__context__.identifier_type == "table" # show tables and columns with pipeline.sql_client() as client: diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index efbdc082f1..3f0352cab7 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -14,13 +14,14 @@ from dlt.common.utils import uniq_id from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import arrow_table_all_data_types, table_update_and_row, assert_all_data_types_row from tests.common.utils import load_json_case from tests.utils import ALL_TEST_DATA_ITEM_FORMATS, TestDataItemFormat, skip_if_not_active from dlt.destinations.path_utils import create_path -from tests.load.pipeline.utils import ( +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, ) @@ -34,7 +35,7 @@ @pytest.fixture def local_filesystem_pipeline() -> dlt.Pipeline: os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "_storage" - return dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", full_refresh=True) + return dlt.pipeline(pipeline_name="fs_pipe", destination="filesystem", dev_mode=True) def test_pipeline_merge_write_disposition(default_buckets_env: str) -> None: @@ -499,7 +500,7 @@ def count(*args, **kwargs) -> Any: return count - extra_placeholders = { + extra_placeholders: TExtraPlaceholders = { "who": "marcin", "action": "says", "what": "no potato", @@ -600,9 +601,11 @@ def _collect_files(p) -> List[str]: found.append(os.path.join(basedir, file).replace(client.dataset_path, "")) return found - def _collect_table_counts(p) -> Dict[str, int]: + def _collect_table_counts(p, *items: str) -> Dict[str, int]: + expected_items = set(items).intersection({"items", "items2", "items3"}) + print(expected_items) return load_table_counts( - p, "items", "items2", "items3", "_dlt_loads", "_dlt_version", "_dlt_pipeline_state" + p, *expected_items, "_dlt_loads", "_dlt_version", "_dlt_pipeline_state" ) # generate 4 loads from 2 pipelines, store load ids @@ -615,7 +618,7 @@ def _collect_table_counts(p) -> Dict[str, int]: # first two loads p1.run([1, 2, 3], table_name="items").loads_ids[0] load_id_2_1 = p2.run([4, 5, 6], table_name="items").loads_ids[0] - assert _collect_table_counts(p1) == { + assert _collect_table_counts(p1, "items") == { "items": 6, "_dlt_loads": 2, "_dlt_pipeline_state": 2, @@ -642,7 +645,7 @@ def some_data(): p2.run([4, 5, 6], table_name="items").loads_ids[0] # no migration here # 4 loads for 2 pipelines, one schema and state change on p2 changes so 3 versions and 3 states - assert _collect_table_counts(p1) == { + assert _collect_table_counts(p1, "items", "items2") == { "items": 9, "items2": 3, "_dlt_loads": 4, @@ -653,8 +656,8 @@ def some_data(): # test accessors for state s1 = c1.get_stored_state("p1") s2 = c1.get_stored_state("p2") - assert s1.dlt_load_id == load_id_1_2 # second load - assert s2.dlt_load_id == load_id_2_1 # first load + assert s1._dlt_load_id == load_id_1_2 # second load + assert s2._dlt_load_id == load_id_2_1 # first load assert s1_old.version != s1.version assert s2_old.version == s2.version @@ -797,13 +800,15 @@ def table_3(): # check opening of file values = [] - for line in fs_client.read_text(t1_files[0]).split("\n"): + for line in fs_client.read_text(t1_files[0], encoding="utf-8").split("\n"): if line: values.append(json.loads(line)["value"]) assert values == [1, 2, 3, 4, 5] # check binary read - assert fs_client.read_bytes(t1_files[0]) == str.encode(fs_client.read_text(t1_files[0])) + assert fs_client.read_bytes(t1_files[0]) == str.encode( + fs_client.read_text(t1_files[0], encoding="utf-8") + ) # check truncate fs_client.truncate_tables(["table_1"]) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index a3f5083ae6..e829c8d730 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -11,48 +11,75 @@ from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext from dlt.common.schema.utils import has_table_seen_data -from dlt.common.schema.exceptions import SchemaException +from dlt.common.schema.exceptions import SchemaCorruptedException +from dlt.common.schema.typing import TLoaderMergeStrategy from dlt.common.typing import StrAny from dlt.common.utils import digest128 +from dlt.common.destination import TDestination +from dlt.common.destination.exceptions import DestinationCapabilitiesException from dlt.extract import DltResource from dlt.sources.helpers.transform import skip_first, take_first from dlt.pipeline.exceptions import PipelineStepFailed from tests.pipeline.utils import assert_load_info, load_table_counts, select_data -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.utils import ( + normalize_storage_table_cols, + destinations_configs, + DestinationTestConfiguration, +) # uncomment add motherduck tests # NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts # ACTIVE_DESTINATIONS += ["motherduck"] +def skip_if_not_supported( + merge_strategy: TLoaderMergeStrategy, + destination: TDestination, +) -> None: + if merge_strategy not in destination.capabilities().supported_merge_strategies: + pytest.skip( + f"`{merge_strategy}` merge strategy not supported for `{destination.destination_name}`" + " destination." + ) + + @pytest.mark.essential @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name ) -def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_merge_on_keys_in_schema( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: p = destination_config.setup_pipeline("eth_2", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) + with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) # make block uncles unseen to trigger filtering loader in loader for child tables if has_table_seen_data(schema.tables["blocks__uncles"]): - del schema.tables["blocks__uncles"]["x-normalizer"] # type: ignore[typeddict-item] + del schema.tables["blocks__uncles"]["x-normalizer"] assert not has_table_seen_data(schema.tables["blocks__uncles"]) - with open( - "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", - "r", - encoding="utf-8", - ) as f: - data = json.load(f) + @dlt.resource( + table_name="blocks", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + ) + def data(slice_: slice = None): + with open( + "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", + "r", + encoding="utf-8", + ) as f: + yield json.load(f) if slice_ is None else json.load(f)[slice_] # take only the first block. the first block does not have uncles so this table should not be created and merged info = p.run( - data[:1], - table_name="blocks", - write_disposition="merge", + data(slice(1)), schema=schema, loader_file_format=destination_config.file_format, ) @@ -69,8 +96,6 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # if the table would be created before the whole load would fail because new columns have hints info = p.run( data, - table_name="blocks", - write_disposition="merge", schema=schema, loader_file_format=destination_config.file_format, ) @@ -80,8 +105,6 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # make sure we have same record after merging full dataset again info = p.run( data, - table_name="blocks", - write_disposition="merge", schema=schema, loader_file_format=destination_config.file_format, ) @@ -96,19 +119,28 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name ) -def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_merge_on_ad_hoc_primary_key( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: p = destination_config.setup_pipeline("github_1", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) - with open( - "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" - ) as f: - data = json.load(f) - # note: NodeId will be normalized to "node_id" which exists in the schema - info = p.run( - data[:17], + @dlt.resource( table_name="issues", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, primary_key="NodeId", + ) + def data(slice_: slice = None): + with open( + "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + ) as f: + yield json.load(f) if slice_ is None else json.load(f)[slice_] + + # note: NodeId will be normalized to "node_id" which exists in the schema + info = p.run( + data(slice(0, 17)), loader_file_format=destination_config.file_format, ) assert_load_info(info) @@ -121,10 +153,7 @@ def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfigur assert p.default_schema.tables["issues"]["columns"]["node_id"]["nullable"] is False info = p.run( - data[5:], - table_name="issues", - write_disposition="merge", - primary_key="node_id", + data(slice(5, None)), loader_file_format=destination_config.file_format, ) assert_load_info(info) @@ -307,9 +336,10 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) assert github_2_counts["issues"] == 100 - 45 + 1 with p._sql_job_client(p.default_schema) as job_c: - _, table_schema = job_c.get_storage_table("issues") - assert "url" in table_schema - assert "m_a1" not in table_schema # unbound columns were not created + _, storage_cols = job_c.get_storage_table("issues") + storage_cols = normalize_storage_table_cols("issues", storage_cols, p.default_schema) + assert "url" in storage_cols + assert "m_a1" not in storage_cols # unbound columns were not created @pytest.mark.parametrize( @@ -319,6 +349,8 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf ) def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", dev_mode=True) + # do not save state to destination so jobs counting is easier + p.config.restore_from_destination = False github_data = github() # generate some complex types github_data.max_table_nesting = 2 @@ -568,14 +600,23 @@ def duplicates_no_child(): destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name, ) -def test_complex_column_missing(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_complex_column_missing( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: table_name = "test_complex_column_missing" - @dlt.resource(name=table_name, write_disposition="merge", primary_key="id") + @dlt.resource( + name=table_name, + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + primary_key="id", + ) def r(data): yield data p = destination_config.setup_pipeline("abstract", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) data = [{"id": 1, "simple": "foo", "complex": [1, 2, 3]}] info = p.run(r(data), loader_file_format=destination_config.file_format) @@ -597,14 +638,21 @@ def r(data): ids=lambda x: x.name, ) @pytest.mark.parametrize("key_type", ["primary_key", "merge_key", "no_key"]) -def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_hard_delete_hint( + destination_config: DestinationTestConfiguration, + key_type: str, + merge_strategy: TLoaderMergeStrategy, +) -> None: + if merge_strategy == "upsert" and key_type != "primary_key": + pytest.skip("`upsert` merge strategy requires `primary_key`") # no_key setting will have the effect that hard deletes have no effect, since hard delete records # can not be matched table_name = "test_hard_delete_hint" @dlt.resource( name=table_name, - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, columns={"deleted": {"hard_delete": True}}, ) def data_resource(data): @@ -619,6 +667,7 @@ def data_resource(data): pass p = destination_config.setup_pipeline(f"abstract_{key_type}", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) # insert two records data = [ @@ -660,6 +709,8 @@ def data_resource(data): {"id": 3, "val": "foo", "deleted": False}, {"id": 3, "val": "bar", "deleted": False}, ] + if merge_strategy == "upsert": + del data[0] # `upsert` requires unique `primary_key` info = p.run(data_resource(data), loader_file_format=destination_config.file_format) assert_load_info(info) counts = load_table_counts(p, table_name)[table_name] @@ -752,12 +803,16 @@ def data_resource(data): destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name, ) -def test_hard_delete_hint_config(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_hard_delete_hint_config( + destination_config: DestinationTestConfiguration, + merge_strategy: TLoaderMergeStrategy, +) -> None: table_name = "test_hard_delete_hint_non_bool" @dlt.resource( name=table_name, - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, primary_key="id", columns={ "deleted_timestamp": {"data_type": "timestamp", "nullable": True, "hard_delete": True} @@ -767,6 +822,7 @@ def data_resource(data): yield data p = destination_config.setup_pipeline("abstract", dev_mode=True) + skip_if_not_supported(merge_strategy, p.destination) # insert two records data = [ @@ -977,15 +1033,60 @@ def r(): @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, subset=["duckdb"]), + destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name, ) -def test_invalid_merge_strategy(destination_config: DestinationTestConfiguration) -> None: - @dlt.resource(write_disposition={"disposition": "merge", "strategy": "foo"}) # type: ignore[call-overload] +def test_merge_strategy_config(destination_config: DestinationTestConfiguration, capsys) -> None: + # merge strategy invalid + with pytest.raises(ValueError): + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "foo"}) # type: ignore[call-overload] + def invalid_resource(): + yield {"foo": "bar"} + + p = dlt.pipeline( + pipeline_name="dummy_pipeline", + destination="dummy", + full_refresh=True, + ) + + # merge strategy not supported by destination + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "scd2"}) def r(): yield {"foo": "bar"} - p = destination_config.setup_pipeline("abstract", full_refresh=True) + assert "scd2" not in p.destination.capabilities().supported_merge_strategies + with pytest.raises(DestinationCapabilitiesException): + p.run(r()) + + # `upsert` merge strategy without `primary_key` should error + # this check only happens for SQL destinations + p = destination_config.setup_pipeline("sql_pipeline", dev_mode=True) + # assert "upsert" in p.destination.capabilities().supported_merge_strategies + p.drop() + r.apply_hints( + write_disposition={"disposition": "merge", "strategy": "upsert"}, + ) + assert "primary_key" not in r._hints with pytest.raises(PipelineStepFailed) as pip_ex: p.run(r()) - assert isinstance(pip_ex.value.__context__, SchemaException) + assert isinstance(pip_ex.value.__context__, SchemaCorruptedException) + + # TODO: figure out how to test logs on GitHub CI + # section below is commented out because it fails on GitHub CI + # https://github.com/dlt-hub/dlt/pull/1466#discussion_r1658991754 + + # `upsert` merge strategy with `merge_key` should log warning + # p.drop() + # r.apply_hints( + # write_disposition={"disposition": "merge", "strategy": "upsert"}, + # primary_key="foo", + # merge_key="foo", + # ) + # assert "primary_key" in r._hints + # assert "merge_key" in r._hints + # p.run(r()) + # assert ( + # "Merge key is not supported for this strategy and will be ignored." + # in capsys.readouterr().err + # ) diff --git a/tests/load/test_parallelism.py b/tests/load/pipeline/test_parallelism.py similarity index 98% rename from tests/load/test_parallelism.py rename to tests/load/pipeline/test_parallelism.py index a1a09a4d6b..656357fb00 100644 --- a/tests/load/test_parallelism.py +++ b/tests/load/pipeline/test_parallelism.py @@ -55,7 +55,7 @@ def t() -> TDataItems: yield {"num": i} # we load n items for 3 tables in one run - p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p = dlt.pipeline("sink_test", destination=test_sink, dev_mode=True) p.run( [ dlt.resource(table_name="t1")(t), diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index ad44cd6f5c..ffee515b90 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -1,11 +1,11 @@ from copy import deepcopy import gzip import os -from typing import Any, Callable, Iterator, Tuple, List, cast +from typing import Any, Iterator, List, cast, Tuple, Callable import pytest +from unittest import mock import dlt - from dlt.common import json, sleep from dlt.common.pipeline import SupportsPipeline from dlt.common.destination import Destination @@ -14,19 +14,22 @@ from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME +from dlt.common.schema.utils import new_table from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DatabaseUndefinedRelation +from dlt.destinations import filesystem, redshift +from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.extract.exceptions import ResourceNameMissing -from dlt.extract import DltSource +from dlt.extract.source import DltSource from dlt.pipeline.exceptions import ( CannotRestorePipelineException, PipelineConfigMissing, PipelineStepFailed, ) -from tests.utils import TEST_STORAGE_ROOT, data_to_item_format, preserve_environ +from tests.utils import TEST_STORAGE_ROOT, data_to_item_format from tests.pipeline.utils import ( assert_data_table_counts, assert_load_info, @@ -40,12 +43,11 @@ TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row, delete_dataset, -) -from tests.load.pipeline.utils import ( drop_active_pipeline_data, - REPLACE_STRATEGIES, + destinations_configs, + DestinationTestConfiguration, ) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.pipeline.utils import REPLACE_STRATEGIES # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -137,10 +139,27 @@ def data_fun() -> Iterator[Any]: destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), ids=lambda x: x.name, ) -def test_default_schema_name(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize("use_single_dataset", [True, False]) +@pytest.mark.parametrize( + "naming_convention", + [ + "duck_case", + "snake_case", + "sql_cs_v1", + ], +) +def test_default_schema_name( + destination_config: DestinationTestConfiguration, + use_single_dataset: bool, + naming_convention: str, +) -> None: + os.environ["SCHEMA__NAMING"] = naming_convention destination_config.setup() dataset_name = "dataset_" + uniq_id() - data = ["a", "b", "c"] + data = [ + {"id": idx, "CamelInfo": uniq_id(), "GEN_ERIC": alpha} + for idx, alpha in [(0, "A"), (0, "B"), (0, "C")] + ] p = dlt.pipeline( "test_default_schema_name", @@ -149,16 +168,25 @@ def test_default_schema_name(destination_config: DestinationTestConfiguration) - staging=destination_config.staging, dataset_name=dataset_name, ) + p.config.use_single_dataset = use_single_dataset p.extract(data, table_name="test", schema=Schema("default")) p.normalize() info = p.load() + print(info) # try to restore pipeline r_p = dlt.attach("test_default_schema_name", TEST_STORAGE_ROOT) schema = r_p.default_schema assert schema.name == "default" - assert_table(p, "test", data, info=info) + # check if dlt ables have exactly the required schemas + # TODO: uncomment to check dlt tables schemas + # assert ( + # r_p.default_schema.tables[PIPELINE_STATE_TABLE_NAME]["columns"] + # == pipeline_state_table()["columns"] + # ) + + # assert_table(p, "test", data, info=info) @pytest.mark.parametrize( @@ -495,10 +523,16 @@ def test_dataset_name_change(destination_config: DestinationTestConfiguration) - def test_pipeline_explicit_destination_credentials( destination_config: DestinationTestConfiguration, ) -> None: + from dlt.destinations import postgres + from dlt.destinations.impl.postgres.configuration import PostgresCredentials + # explicit credentials resolved p = dlt.pipeline( - destination=Destination.from_reference("postgres", destination_name="mydest"), - credentials="postgresql://loader:loader@localhost:7777/dlt_data", + destination=Destination.from_reference( + "postgres", + destination_name="mydest", + credentials="postgresql://loader:loader@localhost:7777/dlt_data", + ), ) c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] assert c.config.credentials.port == 7777 # type: ignore[attr-defined] @@ -507,8 +541,11 @@ def test_pipeline_explicit_destination_credentials( # explicit credentials resolved ignoring the config providers os.environ["DESTINATION__MYDEST__CREDENTIALS__HOST"] = "HOST" p = dlt.pipeline( - destination=Destination.from_reference("postgres", destination_name="mydest"), - credentials="postgresql://loader:loader@localhost:5432/dlt_data", + destination=Destination.from_reference( + "postgres", + destination_name="mydest", + credentials="postgresql://loader:loader@localhost:5432/dlt_data", + ), ) c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] @@ -517,20 +554,35 @@ def test_pipeline_explicit_destination_credentials( os.environ["DESTINATION__MYDEST__CREDENTIALS__USERNAME"] = "UN" os.environ["DESTINATION__MYDEST__CREDENTIALS__PASSWORD"] = "PW" p = dlt.pipeline( - destination=Destination.from_reference("postgres", destination_name="mydest"), - credentials="postgresql://localhost:5432/dlt_data", + destination=Destination.from_reference( + "postgres", + destination_name="mydest", + credentials="postgresql://localhost:5432/dlt_data", + ), ) c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] assert c.config.credentials.username == "UN" # type: ignore[attr-defined] - # host is also overridden - assert c.config.credentials.host == "HOST" # type: ignore[attr-defined] + # host is taken form explicit credentials + assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] # instance of credentials will be simply passed - # c = RedshiftCredentials("postgresql://loader:loader@localhost/dlt_data") - # assert c.is_resolved() - # p = dlt.pipeline(destination="postgres", credentials=c) - # inner_c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] - # assert inner_c is c + cred = PostgresCredentials("postgresql://user:pass@localhost/dlt_data") + p = dlt.pipeline(destination=postgres(credentials=cred)) + inner_c = p.destination_client() + assert inner_c.config.credentials is cred + + # with staging + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + staging=filesystem("_storage"), + destination=redshift(credentials="redshift://loader:password@localhost:5432/dlt_data"), + ) + config = p.destination_client().config + assert config.credentials.is_resolved() + assert ( + config.credentials.to_native_representation() + == "redshift://loader:password@localhost:5432/dlt_data?connect_timeout=15" + ) # do not remove - it allows us to filter tests by destination @@ -677,9 +729,8 @@ def gen2(): # restore from destination, check state p = dlt.pipeline( pipeline_name="source_1_pipeline", - destination="duckdb", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), dataset_name="shared_dataset", - credentials="duckdb:///_storage/test_quack.duckdb", ) p.sync_destination() # we have our separate state @@ -694,9 +745,8 @@ def gen2(): p = dlt.pipeline( pipeline_name="source_2_pipeline", - destination="duckdb", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), dataset_name="shared_dataset", - credentials="duckdb:///_storage/test_quack.duckdb", ) p.sync_destination() # we have our separate state @@ -773,7 +823,6 @@ def test_snowflake_delete_file_after_copy(destination_config: DestinationTestCon assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) -# do not remove - it allows us to filter tests by destination @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, all_staging_configs=True, file_format="parquet"), @@ -947,8 +996,7 @@ def table_3(make_data=False): load_table_counts(pipeline, "table_3") assert "x-normalizer" not in pipeline.default_schema.tables["table_3"] assert ( - pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True + pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] is True ) # load with one empty job, table 3 not created @@ -990,18 +1038,9 @@ def table_3(make_data=False): # print(v5) # check if seen data is market correctly - assert ( - pipeline.default_schema.tables["table_3"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True - ) - assert ( - pipeline.default_schema.tables["table_2"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True - ) - assert ( - pipeline.default_schema.tables["table_1"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] - is True - ) + assert pipeline.default_schema.tables["table_3"]["x-normalizer"]["seen-data"] is True + assert pipeline.default_schema.tables["table_2"]["x-normalizer"]["seen-data"] is True + assert pipeline.default_schema.tables["table_1"]["x-normalizer"]["seen-data"] is True job_client, _ = pipeline._get_destination_clients(schema) @@ -1013,12 +1052,38 @@ def table_3(make_data=False): if job_client.should_load_data_to_staging_dataset( job_client.schema.tables[table_name] ): - with client.with_staging_dataset(staging=True): + with client.with_staging_dataset(): tab_name = client.make_qualified_table_name(table_name) with client.execute_query(f"SELECT * FROM {tab_name}") as cur: assert len(cur.fetchall()) == 0 +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_query_all_info_tables_fallback(destination_config: DestinationTestConfiguration) -> None: + pipeline = destination_config.setup_pipeline( + "parquet_test_" + uniq_id(), dataset_name="parquet_test_" + uniq_id() + ) + with mock.patch.object(SqlJobClientBase, "INFO_TABLES_QUERY_THRESHOLD", 0): + info = pipeline.run([1, 2, 3], table_name="digits_1") + assert_load_info(info) + # create empty table + client: SqlJobClientBase + # we must add it to schema + pipeline.default_schema._schema_tables["existing_table"] = new_table("existing_table") + with pipeline.destination_client() as client: # type: ignore[assignment] + sql = client._get_table_update_sql( + "existing_table", [{"name": "_id", "data_type": "bigint"}], False + ) + client.sql_client.execute_many(sql) + # remove it from schema + del pipeline.default_schema._schema_tables["existing_table"] + # store another table + info = pipeline.run([1, 2, 3], table_name="digits_2") + assert_data_table_counts(pipeline, {"digits_1": 3, "digits_2": 3}) + + # @pytest.mark.skip(reason="Finalize the test: compare some_data values to values from database") # @pytest.mark.parametrize( # "destination_config", diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index a64ee300cd..a4001b7faa 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -6,45 +6,11 @@ from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration -from tests.cases import arrow_table_all_data_types, prepare_shuffled_tables -from tests.pipeline.utils import assert_data_table_counts, assert_load_info, load_tables_to_dicts +from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.pipeline.utils import assert_load_info, load_tables_to_dicts from tests.utils import TestDataItemFormat -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -@pytest.mark.parametrize("item_type", ["object", "table"]) -def test_postgres_load_csv( - destination_config: DestinationTestConfiguration, item_type: TestDataItemFormat -) -> None: - os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" - pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() - - # convert to pylist when loading from objects, this will kick the csv-reader in - if item_type == "object": - table, shuffled_table, shuffled_removed_column = ( - table.to_pylist(), - shuffled_table.to_pylist(), - shuffled_removed_column.to_pylist(), - ) - - load_info = pipeline.run( - [shuffled_removed_column, shuffled_table, table], - table_name="table", - loader_file_format="csv", - ) - assert_load_info(load_info) - job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path - assert job.endswith("csv") - assert_data_table_counts(pipeline, {"table": 5432 * 3}) - load_tables_to_dicts(pipeline, "table") - - @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, subset=["postgres"]), @@ -64,7 +30,7 @@ def test_postgres_encoded_binary( blob_table = blob_table.to_pylist() print(blob_table) - pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), dev_mode=True) load_info = pipeline.run(blob_table, table_name="table", loader_file_format="csv") assert_load_info(load_info) job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path @@ -76,31 +42,3 @@ def test_postgres_encoded_binary( # print(bytes(data["table"][0]["hash"])) # data in postgres equals unencoded blob assert data["table"][0]["hash"].tobytes() == blob - - -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_postgres_empty_csv_from_arrow(destination_config: DestinationTestConfiguration) -> None: - os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" - os.environ["RESTORE_FROM_DESTINATION"] = "False" - pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, _, _ = arrow_table_all_data_types("arrow-table", include_json=False) - - load_info = pipeline.run( - table.schema.empty_table(), table_name="table", loader_file_format="csv" - ) - assert_load_info(load_info) - assert len(load_info.load_packages[0].jobs["completed_jobs"]) == 1 - job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path - assert job.endswith("csv") - assert_data_table_counts(pipeline, {"table": 0}) - with pipeline.sql_client() as client: - with client.execute_query('SELECT * FROM "table"') as cur: - columns = [col.name for col in cur.description] - assert len(cur.fetchall()) == 0 - - # all columns in order - assert columns == list(pipeline.default_schema.get_table_columns("table").keys()) diff --git a/tests/load/pipeline/test_redshift.py b/tests/load/pipeline/test_redshift.py index 29293693f5..bfdc15459c 100644 --- a/tests/load/pipeline/test_redshift.py +++ b/tests/load/pipeline/test_redshift.py @@ -4,7 +4,7 @@ import dlt from dlt.common.utils import uniq_id -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.cases import table_update_and_row, assert_all_data_types_row from tests.pipeline.utils import assert_load_info diff --git a/tests/load/pipeline/test_refresh_modes.py b/tests/load/pipeline/test_refresh_modes.py index 02ed560068..f4bf3b0311 100644 --- a/tests/load/pipeline/test_refresh_modes.py +++ b/tests/load/pipeline/test_refresh_modes.py @@ -2,21 +2,30 @@ import pytest import dlt +from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.common.pipeline import resource_state -from dlt.destinations.sql_client import DBApiCursor -from dlt.pipeline.state_sync import load_pipeline_state_from_destination +from dlt.common.utils import uniq_id from dlt.common.typing import DictStrAny from dlt.common.pipeline import pipeline_state as current_pipeline_state -from tests.utils import clean_test_storage, preserve_environ +from dlt.destinations.sql_client import DBApiCursor +from dlt.extract.source import DltSource +from dlt.pipeline.state_sync import load_pipeline_state_from_destination + +from tests.utils import clean_test_storage from tests.pipeline.utils import ( + _is_filesystem, assert_load_info, + load_table_counts, load_tables_to_dicts, assert_only_table_columns, table_exists, ) from tests.load.utils import destinations_configs, DestinationTestConfiguration +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + def assert_source_state_is_wiped(state: DictStrAny) -> None: # Keys contains only "resources" or is empty @@ -66,7 +75,7 @@ def some_data_2(): yield {"id": 7} yield {"id": 8} - @dlt.resource + @dlt.resource(primary_key="id", write_disposition="merge") def some_data_3(): if first_run: dlt.state()["source_key_3"] = "source_value_3" @@ -103,7 +112,6 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration): # First run pipeline so destination so tables are created info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) assert_load_info(info) - # Second run of pipeline with only selected resources info = pipeline.run( refresh_source(first_run=False, drop_sources=True).with_resources( @@ -114,8 +122,6 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration): assert set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) == { "some_data_1", "some_data_2", - # Table has never seen data and is not dropped - "some_data_4", } # No "name" column should exist as table was dropped and re-created without it @@ -163,7 +169,7 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration): new_table_names = set( t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True) ) - assert new_table_names == {"some_data_1", "some_data_2", "some_data_4"} + assert new_table_names == {"some_data_1", "some_data_2"} # Run again with all tables to ensure they are re-created # The new schema in this case should match the schema of the first run exactly @@ -430,10 +436,76 @@ def test_refresh_argument_to_extract(destination_config: DestinationTestConfigur tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) # All other data tables removed - assert tables == {"some_data_3", "some_data_4"} + assert tables == {"some_data_3"} # Run again without refresh to confirm refresh option doesn't persist on pipeline pipeline.extract(refresh_source(first_run=False).with_resources("some_data_2")) tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) - assert tables == {"some_data_2", "some_data_3", "some_data_4"} + assert tables == {"some_data_2", "some_data_3"} + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, default_staging_configs=True, all_buckets_filesystem_configs=True + ), + ids=lambda x: x.name, +) +def test_refresh_staging_dataset(destination_config: DestinationTestConfiguration): + data = [ + {"id": 1, "pop": 1}, + {"id": 2, "pop": 3}, + {"id": 2, "pop": 4}, # duplicate + ] + + pipeline = destination_config.setup_pipeline("test_refresh_staging_dataset" + uniq_id()) + + source = DltSource( + dlt.Schema("data_x"), + "data_section", + [ + dlt.resource(data, name="data_1", primary_key="id", write_disposition="merge"), + dlt.resource(data, name="data_2", primary_key="id", write_disposition="append"), + ], + ) + # create two tables so two tables need to be dropped + info = pipeline.run(source) + assert_load_info(info) + + # make data so inserting on mangled tables is not possible + data_i = [ + {"id": "A", "pop": 0.1}, + {"id": "B", "pop": 0.3}, + {"id": "A", "pop": 0.4}, + ] + source_i = DltSource( + dlt.Schema("data_x"), + "data_section", + [ + dlt.resource(data_i, name="data_1", primary_key="id", write_disposition="merge"), + dlt.resource(data_i, name="data_2", primary_key="id", write_disposition="append"), + ], + ) + info = pipeline.run(source_i, refresh="drop_resources") + assert_load_info(info) + + # now replace the whole source and load different tables + source_i = DltSource( + dlt.Schema("data_x"), + "data_section", + [ + dlt.resource(data_i, name="data_1_v2", primary_key="id", write_disposition="merge"), + dlt.resource(data_i, name="data_2_v2", primary_key="id", write_disposition="append"), + ], + ) + info = pipeline.run(source_i, refresh="drop_sources") + assert_load_info(info) + + # tables got dropped + if _is_filesystem(pipeline): + assert load_table_counts(pipeline, "data_1", "data_2") == {} + else: + with pytest.raises(DestinationUndefinedEntity): + load_table_counts(pipeline, "data_1", "data_2") + load_table_counts(pipeline, "data_1_v2", "data_1_v2") diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 464b5aea1f..12bc69abe0 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -4,12 +4,12 @@ from dlt.common.utils import uniq_id from tests.pipeline.utils import assert_load_info, load_table_counts, load_tables_to_dicts -from tests.load.pipeline.utils import ( +from tests.load.utils import ( drop_active_pipeline_data, destinations_configs, DestinationTestConfiguration, - REPLACE_STRATEGIES, ) +from tests.load.pipeline.utils import REPLACE_STRATEGIES @pytest.mark.essential diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index b287619e8c..c3968e2e74 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -6,15 +6,17 @@ import dlt from dlt.common import pendulum -from dlt.common.schema.schema import Schema +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.schema.schema import Schema, utils +from dlt.common.schema.utils import normalize_table_identifiers from dlt.common.utils import uniq_id from dlt.common.destination.exceptions import DestinationUndefinedEntity +from dlt.common.destination.reference import WithStateSync from dlt.load import Load from dlt.pipeline.exceptions import SqlClientNotAvailable from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import ( - STATE_TABLE_COLUMNS, load_pipeline_state_from_destination, state_resource, ) @@ -24,12 +26,12 @@ from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_DECODED from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9, yml_case_path as common_yml_case_path from tests.common.configuration.utils import environment -from tests.load.pipeline.utils import drop_active_pipeline_data from tests.pipeline.utils import assert_query_data from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, get_normalized_dataset_name, + drop_active_pipeline_data, ) @@ -64,9 +66,10 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] with pytest.raises(DestinationUndefinedEntity): load_pipeline_state_from_destination(p.pipeline_name, job_client) - # sync the schema - p.sync_schema() - # check if schema exists + # sync the schema + p.sync_schema() + # check if schema exists + with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] stored_schema = job_client.get_stored_schema() assert stored_schema is not None # dataset exists, still no table @@ -77,90 +80,102 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - initial_state["_local"]["_last_extracted_at"] = pendulum.now() initial_state["_local"]["_last_extracted_hash"] = initial_state["_version_hash"] # add _dlt_id and _dlt_load_id - resource, _ = state_resource(initial_state) + resource, _ = state_resource(initial_state, "not_used_load_id") resource.apply_hints( columns={ "_dlt_id": {"name": "_dlt_id", "data_type": "text", "nullable": False}, "_dlt_load_id": {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, - **STATE_TABLE_COLUMNS, + **utils.pipeline_state_table()["columns"], } ) - schema.update_table(schema.normalize_table_identifiers(resource.compute_table_schema())) + schema.update_table( + normalize_table_identifiers(resource.compute_table_schema(), schema.naming) + ) # do not bump version here or in sync_schema, dlt won't recognize that schema changed and it won't update it in storage # so dlt in normalize stage infers _state_version table again but with different column order and the column order in schema is different # then in database. parquet is created in schema order and in Redshift it must exactly match the order. # schema.bump_version() - p.sync_schema() + p.sync_schema() + with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] stored_schema = job_client.get_stored_schema() assert stored_schema is not None # table is there but no state assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None - # extract state - with p.managed_state(extract_state=True): - pass - # just run the existing extract - p.normalize(loader_file_format=destination_config.file_format) - p.load() + + # extract state + with p.managed_state(extract_state=True): + pass + # just run the existing extract + p.normalize(loader_file_format=destination_config.file_format) + p.load() + + with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) - local_state = p._get_state() - local_state.pop("_local") - assert stored_state == local_state - # extract state again - with p.managed_state(extract_state=True) as managed_state: - # this will be saved - managed_state["sources"] = {"source": dict(JSON_TYPED_DICT_DECODED)} - p.normalize(loader_file_format=destination_config.file_format) - p.load() + local_state = p._get_state() + local_state.pop("_local") + assert stored_state == local_state + # extract state again + with p.managed_state(extract_state=True) as managed_state: + # this will be saved + managed_state["sources"] = {"source": dict(JSON_TYPED_DICT_DECODED)} + p.normalize(loader_file_format=destination_config.file_format) + p.load() + + with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) - assert stored_state["sources"] == {"source": JSON_TYPED_DICT_DECODED} - local_state = p._get_state() - local_state.pop("_local") - assert stored_state == local_state - # use the state context manager again but do not change state - with p.managed_state(extract_state=True): - pass - # version not changed - new_local_state = p._get_state() - new_local_state.pop("_local") - assert local_state == new_local_state - p.normalize(loader_file_format=destination_config.file_format) - info = p.load() - assert len(info.loads_ids) == 0 + assert stored_state["sources"] == {"source": JSON_TYPED_DICT_DECODED} + local_state = p._get_state() + local_state.pop("_local") + assert stored_state == local_state + # use the state context manager again but do not change state + with p.managed_state(extract_state=True): + pass + # version not changed + new_local_state = p._get_state() + new_local_state.pop("_local") + assert local_state == new_local_state + p.normalize(loader_file_format=destination_config.file_format) + info = p.load() + assert len(info.loads_ids) == 0 + + with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] new_stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) - # new state should not be stored - assert new_stored_state == stored_state - - # change the state in context manager but there's no extract - with p.managed_state(extract_state=False) as managed_state: - managed_state["sources"] = {"source": "test2"} # type: ignore[dict-item] - new_local_state = p._get_state() - new_local_state_local = new_local_state.pop("_local") - assert local_state != new_local_state - # version increased - assert local_state["_state_version"] + 1 == new_local_state["_state_version"] - # last extracted hash does not match current version hash - assert new_local_state_local["_last_extracted_hash"] != new_local_state["_version_hash"] - - # use the state context manager again but do not change state - # because _last_extracted_hash is not present (or different), the version will not change but state will be extracted anyway - with p.managed_state(extract_state=True): - pass - new_local_state_2 = p._get_state() - new_local_state_2_local = new_local_state_2.pop("_local") - assert new_local_state == new_local_state_2 - # there's extraction timestamp - assert "_last_extracted_at" in new_local_state_2_local - # and extract hash is == hash - assert new_local_state_2_local["_last_extracted_hash"] == new_local_state_2["_version_hash"] - # but the version didn't change - assert new_local_state["_state_version"] == new_local_state_2["_state_version"] - p.normalize(loader_file_format=destination_config.file_format) - info = p.load() - assert len(info.loads_ids) == 1 + # new state should not be stored + assert new_stored_state == stored_state + + # change the state in context manager but there's no extract + with p.managed_state(extract_state=False) as managed_state: + managed_state["sources"] = {"source": "test2"} # type: ignore[dict-item] + new_local_state = p._get_state() + new_local_state_local = new_local_state.pop("_local") + assert local_state != new_local_state + # version increased + assert local_state["_state_version"] + 1 == new_local_state["_state_version"] + # last extracted hash does not match current version hash + assert new_local_state_local["_last_extracted_hash"] != new_local_state["_version_hash"] + + # use the state context manager again but do not change state + # because _last_extracted_hash is not present (or different), the version will not change but state will be extracted anyway + with p.managed_state(extract_state=True): + pass + new_local_state_2 = p._get_state() + new_local_state_2_local = new_local_state_2.pop("_local") + assert new_local_state == new_local_state_2 + # there's extraction timestamp + assert "_last_extracted_at" in new_local_state_2_local + # and extract hash is == hash + assert new_local_state_2_local["_last_extracted_hash"] == new_local_state_2["_version_hash"] + # but the version didn't change + assert new_local_state["_state_version"] == new_local_state_2["_state_version"] + p.normalize(loader_file_format=destination_config.file_format) + info = p.load() + assert len(info.loads_ids) == 1 + + with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] new_stored_state_2 = load_pipeline_state_from_destination(p.pipeline_name, job_client) - # the stored state changed to next version - assert new_stored_state != new_stored_state_2 - assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"] + # the stored state changed to next version + assert new_stored_state != new_stored_state_2 + assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"] @pytest.mark.parametrize( @@ -183,6 +198,7 @@ def test_silently_skip_on_invalid_credentials( destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) +@pytest.mark.essential @pytest.mark.parametrize( "destination_config", destinations_configs( @@ -191,13 +207,25 @@ def test_silently_skip_on_invalid_credentials( ids=lambda x: x.name, ) @pytest.mark.parametrize("use_single_dataset", [True, False]) +@pytest.mark.parametrize( + "naming_convention", + [ + "tests.common.cases.normalizers.title_case", + "snake_case", + ], +) def test_get_schemas_from_destination( - destination_config: DestinationTestConfiguration, use_single_dataset: bool + destination_config: DestinationTestConfiguration, + use_single_dataset: bool, + naming_convention: str, ) -> None: + set_naming_env(destination_config.destination, naming_convention) + pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) p.config.use_single_dataset = use_single_dataset def _make_dn_name(schema_name: str) -> str: @@ -208,9 +236,10 @@ def _make_dn_name(schema_name: str) -> str: default_schema = Schema("state") p._inject_schema(default_schema) + + # just sync schema without name - will use default schema + p.sync_schema() with p.destination_client() as job_client: - # just sync schema without name - will use default schema - p.sync_schema() assert get_normalized_dataset_name( job_client ) == default_schema.naming.normalize_table_identifier(dataset_name) @@ -226,9 +255,9 @@ def _make_dn_name(schema_name: str) -> str: ) == schema_two.naming.normalize_table_identifier(_make_dn_name("two")) schema_three = Schema("three") p._inject_schema(schema_three) + # sync schema with a name + p.sync_schema(schema_three.name) with p._get_destination_clients(schema_three)[0] as job_client: - # sync schema with a name - p.sync_schema(schema_three.name) assert get_normalized_dataset_name( job_client ) == schema_three.naming.normalize_table_identifier(_make_dn_name("three")) @@ -268,18 +297,34 @@ def _make_dn_name(schema_name: str) -> str: assert len(restored_schemas) == 3 +@pytest.mark.essential @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, default_vector_configs=True, all_buckets_filesystem_configs=True + default_sql_configs=True, + all_staging_configs=True, + default_vector_configs=True, + all_buckets_filesystem_configs=True, ), ids=lambda x: x.name, ) -def test_restore_state_pipeline(destination_config: DestinationTestConfiguration) -> None: +@pytest.mark.parametrize( + "naming_convention", + [ + "tests.common.cases.normalizers.title_case", + "snake_case", + ], +) +def test_restore_state_pipeline( + destination_config: DestinationTestConfiguration, naming_convention: str +) -> None: + set_naming_env(destination_config.destination, naming_convention) + # enable restoring from destination os.environ["RESTORE_FROM_DESTINATION"] = "True" pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) def some_data_gen(param: str) -> Any: dlt.current.source_state()[param] = param @@ -366,7 +411,7 @@ def some_data(): p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) # now attach locally os.environ["RESTORE_FROM_DESTINATION"] = "True" - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) assert p.dataset_name == dataset_name assert p.default_schema_name is None # restore @@ -409,8 +454,19 @@ def test_ignore_state_unfinished_load(destination_config: DestinationTestConfigu @dlt.resource def some_data(param: str) -> Any: dlt.current.source_state()[param] = param - yield param + yield {"col1": param, param: 1} + + job_client: WithStateSync + # Load some complete load packages with state to the destination + p.run(some_data("state1"), loader_file_format=destination_config.file_format) + p.run(some_data("state2"), loader_file_format=destination_config.file_format) + p.run(some_data("state3"), loader_file_format=destination_config.file_format) + + with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] + state = load_pipeline_state_from_destination(pipeline_name, job_client) + assert state and state["_state_version"] == 3 + # Simulate a load package that stores state but is not completed (no entry in loads table) def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = False): # complete in local storage but skip call to the database self.load_storage.complete_load_package(load_id, aborted) @@ -419,11 +475,18 @@ def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = Fa p.run(some_data("fix_1"), loader_file_format=destination_config.file_format) # assert complete_package.called - job_client: SqlJobClientBase with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] # state without completed load id is not visible state = load_pipeline_state_from_destination(pipeline_name, job_client) - assert state is None + # Restored state version has not changed + assert state and state["_state_version"] == 3 + + newest_schema_hash = p.default_schema.version_hash + p._wipe_working_folder() + p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p.sync_destination() + + assert p.default_schema.version_hash == newest_schema_hash @pytest.mark.parametrize( @@ -451,6 +514,9 @@ def test_restore_schemas_while_import_schemas_exist( # make sure schema got imported schema = p.schemas["ethereum"] assert "blocks" in schema.tables + # allow to modify tables even if naming convention is changed. some of the tables in ethereum schema + # have processing hints that lock the table schema. so when weaviate changes naming convention we have an exception + os.environ["SCHEMA__ALLOW_IDENTIFIER_CHANGE_ON_TABLE_WITH_DATA"] = "true" # extract some additional data to upgrade schema in the pipeline p.run( @@ -467,7 +533,7 @@ def test_restore_schemas_while_import_schemas_exist( assert normalized_labels in schema.tables # re-attach the pipeline - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) p.run( ["C", "D", "E"], table_name="annotations", loader_file_format=destination_config.file_format ) @@ -496,7 +562,7 @@ def test_restore_schemas_while_import_schemas_exist( assert normalized_annotations in schema.tables # check if attached to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # extract some data with restored pipeline p.run( ["C", "D", "E"], table_name="blacklist", loader_file_format=destination_config.file_format @@ -604,7 +670,9 @@ def some_data(param: str) -> Any: prod_state = production_p.state assert p.state["_state_version"] == prod_state["_state_version"] - 1 # re-attach production and sync - ra_production_p = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) + ra_production_p = destination_config.attach_pipeline( + pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT + ) ra_production_p.sync_destination() # state didn't change because production is ahead of local with its version # nevertheless this is potentially dangerous situation 🤷 @@ -613,10 +681,18 @@ def some_data(param: str) -> Any: # get all the states, notice version 4 twice (one from production, the other from local) try: with p.sql_client() as client: + # use sql_client to escape identifiers properly state_table = client.make_qualified_table_name(p.default_schema.state_table_name) - + c_version = client.escape_column_name( + p.default_schema.naming.normalize_identifier("version") + ) + c_created_at = client.escape_column_name( + p.default_schema.naming.normalize_identifier("created_at") + ) assert_query_data( - p, f"SELECT version FROM {state_table} ORDER BY created_at DESC", [5, 4, 4, 3, 2] + p, + f"SELECT {c_version} FROM {state_table} ORDER BY {c_created_at} DESC", + [5, 4, 4, 3, 2], ) except SqlClientNotAvailable: pytest.skip(f"destination {destination_config.destination} does not support sql client") @@ -669,7 +745,7 @@ def some_data(param: str) -> Any: assert p.dataset_name == dataset_name print("---> no state sync last attach") - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) # this will prevent from creating of _dlt_pipeline_state p.config.restore_from_destination = False data4 = some_data("state4") @@ -686,7 +762,7 @@ def some_data(param: str) -> Any: assert p.state["_local"]["first_run"] is False # attach again to make the `run` method check the destination print("---> last attach") - p = dlt.attach(pipeline_name=pipeline_name) + p = destination_config.attach_pipeline(pipeline_name=pipeline_name) p.config.restore_from_destination = True data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") @@ -696,8 +772,31 @@ def some_data(param: str) -> Any: def prepare_import_folder(p: Pipeline) -> None: - os.makedirs(p._schema_storage.config.import_schema_path, exist_ok=True) - shutil.copy( - common_yml_case_path("schemas/eth/ethereum_schema_v5"), - os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml"), - ) + from tests.common.storages.utils import prepare_eth_import_folder + + prepare_eth_import_folder(p._schema_storage) + + +def set_naming_env(destination: str, naming_convention: str) -> None: + # snake case is for default convention so do not set it + if naming_convention != "snake_case": + # path convention to test weaviate ci_naming + if destination == "weaviate": + if naming_convention.endswith("sql_upper"): + pytest.skip(f"{naming_convention} not supported on weaviate") + else: + naming_convention = "dlt.destinations.impl.weaviate.ci_naming" + os.environ["SCHEMA__NAMING"] = naming_convention + + +def assert_naming_to_caps(destination: str, caps: DestinationCapabilitiesContext) -> None: + naming = Schema("test").naming + if ( + not caps.has_case_sensitive_identifiers + and caps.casefold_identifier is not str + and naming.is_case_sensitive + ): + pytest.skip( + f"Skipping for case insensitive destination {destination} with case folding because" + f" naming {naming.name()} is case sensitive" + ) diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index e8baa33ff3..b33c5a2590 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -17,12 +17,11 @@ from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import arrow_table_all_data_types -from tests.pipeline.utils import assert_load_info, load_table_counts -from tests.load.pipeline.utils import ( +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, ) -from tests.pipeline.utils import load_tables_to_dicts +from tests.pipeline.utils import load_tables_to_dicts, assert_load_info, load_table_counts from tests.utils import TPythonTableFormat @@ -104,7 +103,7 @@ def test_core_functionality( validity_column_names: List[str], active_record_timestamp: Optional[pendulum.DateTime], ) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", @@ -243,7 +242,7 @@ def r(data): ) @pytest.mark.parametrize("simple", [True, False]) def test_child_table(destination_config: DestinationTestConfiguration, simple: bool) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", write_disposition={"disposition": "merge", "strategy": "scd2"} @@ -386,7 +385,7 @@ def r(data): ids=lambda x: x.name, ) def test_grandchild_table(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", write_disposition={"disposition": "merge", "strategy": "scd2"} @@ -479,7 +478,7 @@ def r(data): ids=lambda x: x.name, ) def test_validity_column_name_conflict(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", @@ -525,7 +524,7 @@ def test_active_record_timestamp( destination_config: DestinationTestConfiguration, active_record_timestamp: Optional[TAnyDateTime], ) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", @@ -572,7 +571,7 @@ def _make_scd2_r(table_: Any) -> DltResource: }, ).add_map(add_row_hash_to_table("row_hash")) - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) info = p.run(_make_scd2_r(table), loader_file_format=destination_config.file_format) assert_load_info(info) # make sure we have scd2 columns in schema @@ -608,7 +607,7 @@ def _make_scd2_r(table_: Any) -> DltResource: ids=lambda x: x.name, ) def test_user_provided_row_hash(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("abstract", full_refresh=True) + p = destination_config.setup_pipeline("abstract", dev_mode=True) @dlt.resource( table_name="dim_test", diff --git a/tests/load/pipeline/test_snowflake_pipeline.py b/tests/load/pipeline/test_snowflake_pipeline.py new file mode 100644 index 0000000000..cfb30e737e --- /dev/null +++ b/tests/load/pipeline/test_snowflake_pipeline.py @@ -0,0 +1,59 @@ +import os +import pytest + +import dlt + +from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import DatabaseUndefinedRelation + +from tests.load.snowflake.test_snowflake_client import QUERY_TAG +from tests.pipeline.utils import assert_load_info +from tests.load.utils import destinations_configs, DestinationTestConfiguration + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) +def test_snowflake_case_sensitive_identifiers( + destination_config: DestinationTestConfiguration, +) -> None: + # enable query tagging + os.environ["DESTINATION__SNOWFLAKE__CREDENTIALS__QUERY_TAG"] = QUERY_TAG + snow_ = dlt.destinations.snowflake(naming_convention="sql_cs_v1") + + dataset_name = "CaseSensitive_Dataset_" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_snowflake_case_sensitive_identifiers", dataset_name=dataset_name, destination=snow_ + ) + caps = pipeline.destination.capabilities() + assert caps.naming_convention == "sql_cs_v1" + + destination_client = pipeline.destination_client() + # assert snowflake caps to be in case sensitive mode + assert destination_client.capabilities.casefold_identifier is str + + # load some case sensitive data + info = pipeline.run([{"Id": 1, "Capital": 0.0}], table_name="Expenses") + assert_load_info(info) + with pipeline.sql_client() as client: + assert client.has_dataset() + # use the same case sensitive dataset + with client.with_alternative_dataset_name(dataset_name): + assert client.has_dataset() + # make it case insensitive (upper) + with client.with_alternative_dataset_name(dataset_name.upper()): + assert not client.has_dataset() + # keep case sensitive but make lowercase + with client.with_alternative_dataset_name(dataset_name.lower()): + assert not client.has_dataset() + + # must use quoted identifiers + rows = client.execute_sql('SELECT "Id", "Capital" FROM "Expenses"') + print(rows) + with pytest.raises(DatabaseUndefinedRelation): + client.execute_sql('SELECT "Id", "Capital" FROM Expenses') diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index e0e2154b57..7f1427f20f 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -8,15 +8,13 @@ from dlt.common.schema.typing import TDataType from tests.load.pipeline.test_merge_disposition import github -from tests.pipeline.utils import load_table_counts -from tests.pipeline.utils import assert_load_info +from tests.pipeline.utils import load_table_counts, assert_load_info from tests.load.utils import ( - TABLE_ROW_ALL_DATA_TYPES, - TABLE_UPDATE_COLUMNS_SCHEMA, + destinations_configs, + DestinationTestConfiguration, assert_all_data_types_row, ) from tests.cases import table_update_and_row -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @dlt.resource( @@ -65,12 +63,17 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: ) == 4 ) + # pipeline state is loaded with preferred format, so allows (possibly) for two job formats + caps = pipeline.destination.capabilities() + # NOTE: preferred_staging_file_format goes first because here we test staged loading and + # default caps will be modified so preferred_staging_file_format is used as main + preferred_format = caps.preferred_staging_file_format or caps.preferred_loader_file_format assert ( len( [ x for x in package_info.jobs["completed_jobs"] - if x.job_file_info.file_format == destination_config.file_format + if x.job_file_info.file_format in (destination_config.file_format, preferred_format) ] ) == 4 diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index 16c589352e..ba2f6bf172 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -1,7 +1,7 @@ import pytest import dlt from typing import Any -from tests.load.pipeline.utils import ( +from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, ) @@ -124,9 +124,13 @@ def source(): ) # schemaless destinations allow adding of root key without the pipeline failing - # for now this is only the case for dremio + # they do not mind adding NOT NULL columns to tables with existing data (id NOT NULL is supported at all) # doing this will result in somewhat useless behavior - destination_allows_adding_root_key = destination_config.destination in ["dremio", "clickhouse"] + destination_allows_adding_root_key = destination_config.destination in [ + "dremio", + "clickhouse", + "athena", + ] if destination_allows_adding_root_key and not with_root_key: pipeline.run( diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index d762029ddd..679c2d6da9 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -1,67 +1 @@ -from typing import Any, Iterator, List, Sequence, TYPE_CHECKING, Callable -import pytest - -import dlt -from dlt.common.destination.reference import WithStagingDataset - -from dlt.common.configuration.container import Container -from dlt.common.pipeline import LoadInfo, PipelineContext - -from tests.load.utils import DestinationTestConfiguration, destinations_configs -from dlt.destinations.exceptions import CantExtractTablePrefix - -if TYPE_CHECKING: - from dlt.destinations.impl.filesystem.filesystem import FilesystemClient - REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] - - -@pytest.fixture(autouse=True) -def drop_pipeline(request) -> Iterator[None]: - yield - if "no_load" in request.keywords: - return - try: - drop_active_pipeline_data() - except CantExtractTablePrefix: - # for some tests we test that this exception is raised, - # so we suppress it here - pass - - -def drop_active_pipeline_data() -> None: - """Drops all the datasets for currently active pipeline, wipes the working folder and then deactivated it.""" - if Container()[PipelineContext].is_active(): - # take existing pipeline - p = dlt.pipeline() - - def _drop_dataset(schema_name: str) -> None: - with p.destination_client(schema_name) as client: - try: - client.drop_storage() - print("dropped") - except Exception as exc: - print(exc) - if isinstance(client, WithStagingDataset): - with client.with_staging_dataset(): - try: - client.drop_storage() - print("staging dropped") - except Exception as exc: - print(exc) - - # drop_func = _drop_dataset_fs if _is_filesystem(p) else _drop_dataset_sql - # take all schemas and if destination was set - if p.destination: - if p.config.use_single_dataset: - # drop just the dataset for default schema - if p.default_schema_name: - _drop_dataset(p.default_schema_name) - else: - # for each schema, drop the dataset - for schema_name in p.schema_names: - _drop_dataset(schema_name) - - # p._wipe_working_folder() - # deactivate context - Container()[PipelineContext].deactivate() diff --git a/tests/load/postgres/test_postgres_client.py b/tests/load/postgres/test_postgres_client.py index a0fbd85b5b..d8cd996dcf 100644 --- a/tests/load/postgres/test_postgres_client.py +++ b/tests/load/postgres/test_postgres_client.py @@ -11,7 +11,7 @@ from dlt.destinations.impl.postgres.postgres import PostgresClient from dlt.destinations.impl.postgres.sql_client import psycopg2 -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy, preserve_environ +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage from tests.common.configuration.utils import environment diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 7566b8afce..86bd67db9a 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -4,8 +4,9 @@ from dlt.common.exceptions import TerminalValueError from dlt.common.utils import uniq_id -from dlt.common.schema import Schema +from dlt.common.schema import Schema, utils +from dlt.destinations import postgres from dlt.destinations.impl.postgres.postgres import PostgresClient from dlt.destinations.impl.postgres.configuration import ( PostgresClientConfiguration, @@ -25,16 +26,34 @@ @pytest.fixture def client(empty_schema: Schema) -> PostgresClient: + return create_client(empty_schema) + + +@pytest.fixture +def cs_client(empty_schema: Schema) -> PostgresClient: + # change normalizer to case sensitive + empty_schema._normalizers_config["names"] = "tests.common.cases.normalizers.title_case" + empty_schema.update_normalizers() + return create_client(empty_schema) + + +def create_client(empty_schema: Schema) -> PostgresClient: # return client without opening connection - return PostgresClient( - empty_schema, - PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), + config = PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ) + return postgres().client(empty_schema, config) def test_create_table(client: PostgresClient) -> None: + # make sure we are in case insensitive mode + assert client.capabilities.generates_case_sensitive_identifiers() is False + # check if dataset name is properly folded + assert client.sql_client.dataset_name == client.config.dataset_name # identical to config + assert ( + client.sql_client.staging_dataset_name + == client.config.staging_dataset_name_layout % client.config.dataset_name + ) # non existing table sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] sqlfluff.parse(sql, dialect="postgres") @@ -102,7 +121,7 @@ def test_alter_table(client: PostgresClient) -> None: assert '"col11_precision" time (3) without time zone NOT NULL' in sql -def test_create_table_with_hints(client: PostgresClient) -> None: +def test_create_table_with_hints(client: PostgresClient, empty_schema: Schema) -> None: mod_update = deepcopy(TABLE_UPDATE) # timestamp mod_update[0]["primary_key"] = True @@ -119,8 +138,8 @@ def test_create_table_with_hints(client: PostgresClient) -> None: assert '"col4" timestamp with time zone NOT NULL' in sql # same thing without indexes - client = PostgresClient( - client.schema, + client = postgres().client( + empty_schema, PostgresClientConfiguration( create_indexes=False, credentials=PostgresCredentials(), @@ -129,3 +148,28 @@ def test_create_table_with_hints(client: PostgresClient) -> None: sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="postgres") assert '"col2" double precision NOT NULL' in sql + + +def test_create_table_case_sensitive(cs_client: PostgresClient) -> None: + # did we switch to case sensitive + assert cs_client.capabilities.generates_case_sensitive_identifiers() is True + # check dataset names + assert cs_client.sql_client.dataset_name.startswith("Test") + with cs_client.with_staging_dataset(): + assert cs_client.sql_client.dataset_name.endswith("staginG") + assert cs_client.sql_client.staging_dataset_name.endswith("staginG") + # check tables + cs_client.schema.update_table( + utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)) + ) + sql = cs_client._get_table_update_sql( + "Event_test_tablE", + list(cs_client.schema.get_table_columns("Event_test_tablE").values()), + False, + )[0] + sqlfluff.parse(sql, dialect="postgres") + # everything capitalized + assert cs_client.sql_client.fully_qualified_dataset_name(escape=False)[0] == "T" # Test + # every line starts with "Col" + for line in sql.split("\n")[1:]: + assert line.startswith('"Col') diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py index d50b50282a..73f53221ed 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -1,15 +1,20 @@ import pytest from typing import Iterator +from tempfile import TemporaryDirectory +import os import dlt from dlt.common import json from dlt.common.utils import uniq_id +from dlt.common.typing import DictStrStr +from dlt.destinations.adapters import qdrant_adapter from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter, VECTORIZE_HINT -from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient +from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient from tests.pipeline.utils import assert_load_info from tests.load.qdrant.utils import drop_active_pipeline_data, assert_collection from tests.load.utils import sequence_generator +from tests.utils import preserve_environ # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -68,6 +73,8 @@ def some_data(): assert schema state = client.get_stored_state("test_pipeline_append") assert state + state = client.get_stored_state("unknown_pipeline") + assert state is None def test_pipeline_append() -> None: @@ -316,8 +323,8 @@ def test_merge_github_nested() -> None: primary_key="id", ) assert_load_info(info) + # assert if schema contains tables with right names - print(p.default_schema.tables.keys()) assert set(p.default_schema.tables.keys()) == { "_dlt_version", "_dlt_loads", @@ -358,3 +365,20 @@ def test_empty_dataset_allowed() -> None: assert client.dataset_name is None assert client.sentinel_collection == "DltSentinelCollection" assert_collection(p, "content", expected_items_count=3) + + +def test_qdrant_local_parallelism_disabled(preserve_environ) -> None: + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "20" + + with TemporaryDirectory() as tmpdir: + p = dlt.pipeline(destination=dlt.destinations.qdrant(path=tmpdir)) + + # Data writer limit ensures that we create multiple load files to the same table + @dlt.resource + def q_data(): + for i in range(222): + yield {"doc_id": i, "content": f"content {i}"} + + info = p.run(q_data) + + assert_load_info(info) diff --git a/tests/load/qdrant/test_restore_state.py b/tests/load/qdrant/test_restore_state.py new file mode 100644 index 0000000000..31bc725d24 --- /dev/null +++ b/tests/load/qdrant/test_restore_state.py @@ -0,0 +1,70 @@ +from typing import TYPE_CHECKING +import pytest +from qdrant_client import models + +import dlt +from tests.load.utils import destinations_configs, DestinationTestConfiguration + +from dlt.common.destination.reference import JobClientBase, WithStateSync +from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_vector_configs=True, subset=["qdrant"]), + ids=lambda x: x.name, +) +def test_uncommitted_state(destination_config: DestinationTestConfiguration): + """Load uncommitted state into qdrant, meaning that data is written to the state + table but load is not completed (nothing is added to loads table) + + Ensure that state restoration does not include such state + """ + # Type hint of JobClientBase with WithStateSync mixin + + pipeline = destination_config.setup_pipeline("uncommitted_state", dev_mode=True) + + state_val = 0 + + @dlt.resource + def dummy_table(): + dlt.current.resource_state("dummy_table")["val"] = state_val + yield [1, 2, 3] + + # Create > 10 load packages to be above pagination size when restoring state + for _ in range(12): + state_val += 1 + pipeline.extract(dummy_table) + + pipeline.normalize() + info = pipeline.load(raise_on_failed_jobs=True) + + client: QdrantClient + with pipeline.destination_client() as client: # type: ignore[assignment] + state = client.get_stored_state(pipeline.pipeline_name) + + assert state and state.version == state_val + + # Delete last 10 _dlt_loads entries so pagination is triggered when restoring state + with pipeline.destination_client() as client: # type: ignore[assignment] + table_name = client._make_qualified_collection_name( + pipeline.default_schema.loads_table_name + ) + p_load_id = pipeline.default_schema.naming.normalize_identifier("load_id") + + client.db_client.delete( + table_name, + points_selector=models.Filter( + must=[ + models.FieldCondition( + key=p_load_id, match=models.MatchAny(any=info.loads_ids[2:]) + ) + ] + ), + ) + + with pipeline.destination_client() as client: # type: ignore[assignment] + state = client.get_stored_state(pipeline.pipeline_name) + + # Latest committed state is restored + assert state and state.version == 2 diff --git a/tests/load/qdrant/utils.py b/tests/load/qdrant/utils.py index 74d5db9715..e96e06be87 100644 --- a/tests/load/qdrant/utils.py +++ b/tests/load/qdrant/utils.py @@ -5,7 +5,7 @@ from dlt.common.pipeline import PipelineContext from dlt.common.configuration.container import Container -from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient +from dlt.destinations.impl.qdrant.qdrant_job_client import QdrantClient def assert_unordered_list_equal(list1: List[Any], list2: List[Any]) -> None: @@ -20,16 +20,16 @@ def assert_collection( expected_items_count: int = None, items: List[Any] = None, ) -> None: - client: QdrantClient = pipeline.destination_client() # type: ignore[assignment] + client: QdrantClient + with pipeline.destination_client() as client: # type: ignore[assignment] + # Check if collection exists + exists = client._collection_exists(collection_name) + assert exists - # Check if collection exists - exists = client._collection_exists(collection_name) - assert exists - - qualified_collection_name = client._make_qualified_collection_name(collection_name) - point_records, offset = client.db_client.scroll( - qualified_collection_name, with_payload=True, limit=50 - ) + qualified_collection_name = client._make_qualified_collection_name(collection_name) + point_records, offset = client.db_client.scroll( + qualified_collection_name, with_payload=True, limit=50 + ) if expected_items_count is not None: assert expected_items_count == len(point_records) @@ -55,10 +55,11 @@ def has_collections(client): if Container()[PipelineContext].is_active(): # take existing pipeline p = dlt.pipeline() - client: QdrantClient = p.destination_client() # type: ignore[assignment] + client: QdrantClient - if has_collections(client): - client.drop_storage() + with p.destination_client() as client: # type: ignore[assignment] + if has_collections(client): + client.drop_storage() p._wipe_working_folder() # deactivate context diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index 03bb57c3b4..bb923df673 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -6,13 +6,18 @@ from dlt.common import json, pendulum from dlt.common.configuration.resolve import resolve_configuration +from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.storages import FileStorage from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DatabaseTerminalException -from dlt.destinations.impl.redshift.configuration import RedshiftCredentials +from dlt.destinations import redshift +from dlt.destinations.impl.redshift.configuration import ( + RedshiftCredentials, + RedshiftClientConfiguration, +) from dlt.destinations.impl.redshift.redshift import RedshiftClient, psycopg2 from tests.common.utils import COMMON_TEST_CASES_PATH @@ -42,6 +47,34 @@ def test_postgres_and_redshift_credentials_defaults() -> None: assert red_cred.port == 5439 +def test_redshift_factory() -> None: + schema = Schema("schema") + dest = redshift() + client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.staging_iam_role is None + assert client.config.has_case_sensitive_identifiers is False + assert client.capabilities.has_case_sensitive_identifiers is False + assert client.capabilities.casefold_identifier is str.lower + + # set args explicitly + dest = redshift(has_case_sensitive_identifiers=True, staging_iam_role="LOADER") + client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.staging_iam_role == "LOADER" + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + # set args via config + os.environ["DESTINATION__STAGING_IAM_ROLE"] = "LOADER" + os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True" + dest = redshift() + client = dest.client(schema, RedshiftClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.staging_iam_role == "LOADER" + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + @skipifpypy def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> None: caps = client.capabilities diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index 2427bc7cfe..37ca20232d 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -3,9 +3,10 @@ from copy import deepcopy from dlt.common.utils import uniq_id, custom_environ, digest128 -from dlt.common.schema import Schema +from dlt.common.schema import Schema, utils from dlt.common.configuration import resolve_configuration +from dlt.destinations import redshift from dlt.destinations.impl.redshift.redshift import RedshiftClient from dlt.destinations.impl.redshift.configuration import ( RedshiftClientConfiguration, @@ -20,12 +21,25 @@ @pytest.fixture def client(empty_schema: Schema) -> RedshiftClient: + return create_client(empty_schema) + + +@pytest.fixture +def cs_client(empty_schema: Schema) -> RedshiftClient: + empty_schema._normalizers_config["names"] = "tests.common.cases.normalizers.title_case" + empty_schema.update_normalizers() + # make the destination case sensitive + return create_client(empty_schema, has_case_sensitive_identifiers=True) + + +def create_client(schema: Schema, has_case_sensitive_identifiers: bool = False) -> RedshiftClient: # return client without opening connection - return RedshiftClient( - empty_schema, - RedshiftClientConfiguration(credentials=RedshiftCredentials())._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), + return redshift().client( + schema, + RedshiftClientConfiguration( + credentials=RedshiftCredentials(), + has_case_sensitive_identifiers=has_case_sensitive_identifiers, + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) @@ -54,6 +68,7 @@ def test_redshift_configuration() -> None: def test_create_table(client: RedshiftClient) -> None: + assert client.capabilities.generates_case_sensitive_identifiers() is False # non existing table sql = ";".join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)) sqlfluff.parse(sql, dialect="redshift") @@ -104,6 +119,29 @@ def test_alter_table(client: RedshiftClient) -> None: assert '"col11_precision" time without time zone NOT NULL' in sql +def test_create_table_case_sensitive(cs_client: RedshiftClient) -> None: + # did we switch to case sensitive + assert cs_client.capabilities.generates_case_sensitive_identifiers() is True + # check dataset names + assert cs_client.sql_client.dataset_name.startswith("Test") + + # check tables + cs_client.schema.update_table( + utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)) + ) + sql = cs_client._get_table_update_sql( + "Event_test_tablE", + list(cs_client.schema.get_table_columns("Event_test_tablE").values()), + False, + )[0] + sqlfluff.parse(sql, dialect="redshift") + # everything capitalized + assert cs_client.sql_client.fully_qualified_dataset_name(escape=False)[0] == "T" # Test + # every line starts with "Col" + for line in sql.split("\n")[1:]: + assert line.startswith('"Col') + + def test_create_table_with_hints(client: RedshiftClient) -> None: mod_update = deepcopy(TABLE_UPDATE) # timestamp diff --git a/tests/load/snowflake/test_snowflake_client.py b/tests/load/snowflake/test_snowflake_client.py new file mode 100644 index 0000000000..15207d450d --- /dev/null +++ b/tests/load/snowflake/test_snowflake_client.py @@ -0,0 +1,62 @@ +import os +from typing import Iterator +from pytest_mock import MockerFixture +import pytest + +from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials +from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.destinations.job_client_impl import SqlJobClientBase + +from dlt.destinations.sql_client import TJobQueryTags + +from tests.load.utils import yield_client_with_storage + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + +QUERY_TAG = ( + '{{"source":"{source}", "resource":"{resource}", "table": "{table}", "load_id":"{load_id}",' + ' "pipeline_name":"{pipeline_name}"}}' +) +QUERY_TAGS_DICT: TJobQueryTags = { + "source": "test_source", + "resource": "test_resource", + "table": "test_table", + "load_id": "1109291083091", + "pipeline_name": "test_pipeline", +} + + +@pytest.fixture(scope="function") +def client() -> Iterator[SqlJobClientBase]: + os.environ["CREDENTIALS__QUERY_TAG"] = QUERY_TAG + yield from yield_client_with_storage("snowflake") + + +def test_query_tag(client: SnowflakeClient, mocker: MockerFixture): + assert client.config.credentials.query_tag == QUERY_TAG + # make sure we generate proper query + execute_sql_spy = mocker.spy(client.sql_client, "execute_sql") + # reset the query if tags are not set + client.sql_client.set_query_tags(None) + execute_sql_spy.assert_called_once_with(sql="ALTER SESSION UNSET QUERY_TAG") + execute_sql_spy.reset_mock() + client.sql_client.set_query_tags({}) # type: ignore[typeddict-item] + execute_sql_spy.assert_called_once_with(sql="ALTER SESSION UNSET QUERY_TAG") + execute_sql_spy.reset_mock() + # set query tags + client.sql_client.set_query_tags(QUERY_TAGS_DICT) + execute_sql_spy.assert_called_once_with( + sql=( + 'ALTER SESSION SET QUERY_TAG = \'{"source":"test_source", "resource":"test_resource",' + ' "table": "test_table", "load_id":"1109291083091", "pipeline_name":"test_pipeline"}\'' + ) + ) + # remove query tag from config + client.sql_client.credentials.query_tag = None + execute_sql_spy.reset_mock() + client.sql_client.set_query_tags(QUERY_TAGS_DICT) + execute_sql_spy.assert_not_called + execute_sql_spy.reset_mock() + client.sql_client.set_query_tags(None) + execute_sql_spy.assert_not_called diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index 610aab7c20..10d93d104c 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -1,10 +1,14 @@ import os import pytest from pathlib import Path -from dlt.common.libs.sql_alchemy import make_url +from urllib3.util import parse_url + +from dlt.common.configuration.utils import add_config_to_env +from tests.utils import TEST_DICT_CONFIG_PROVIDER pytest.importorskip("snowflake") +from dlt.common.libs.sql_alchemy import make_url from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.utils import digest128 @@ -20,12 +24,20 @@ # mark all tests as essential, do not remove pytestmark = pytest.mark.essential +# PEM key +PKEY_PEM_STR = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8") +# base64 encoded DER key +PKEY_DER_STR = Path("./tests/common/cases/secrets/encrypted-private-key-base64").read_text("utf8") + +PKEY_PASSPHRASE = "12345" + def test_connection_string_with_all_params() -> None: - url = "snowflake://user1:pass1@host1/db1?application=dltHub_dlt&warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr" + url = "snowflake://user1:pass1@host1/db1?warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr&authenticator=oauth&token=TOK" creds = SnowflakeCredentials() creds.parse_native_representation(url) + assert not creds.is_resolved() assert creds.database == "db1" assert creds.username == "user1" @@ -35,6 +47,8 @@ def test_connection_string_with_all_params() -> None: assert creds.role == "role1" assert creds.private_key == "cGs=" assert creds.private_key_passphrase == "paphr" + assert creds.authenticator == "oauth" + assert creds.token == "TOK" expected = make_url(url) to_url_value = str(creds.to_url()) @@ -43,23 +57,103 @@ def test_connection_string_with_all_params() -> None: assert make_url(creds.to_native_representation()) == expected assert to_url_value == str(expected) + +def test_custom_application(): + creds = SnowflakeCredentials() creds.application = "custom" - url = "snowflake://user1:pass1@host1/db1?application=custom&warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr" + url = "snowflake://user1:pass1@host1/db1?authenticator=oauth&warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr&token=TOK" creds.parse_native_representation(url) + assert not creds.is_resolved() expected = make_url(url) to_url_value = str(creds.to_url()) assert make_url(creds.to_native_representation()) == expected assert to_url_value == str(expected) - assert "application=custom" in str(expected) + assert "application=custom" not in str(expected) -def test_to_connector_params() -> None: - # PEM key - pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8") +def test_set_all_from_env(environment) -> None: + url = "snowflake://user1:pass1@host1/db1?authenticator=oauth&warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr&token=TOK" + c = SnowflakeCredentials(url) + add_config_to_env(c) + # resolve from environments + creds = resolve_configuration(SnowflakeCredentials()) + assert creds.is_resolved() + assert creds.database == "db1" + assert creds.username == "user1" + assert creds.password == "pass1" + assert creds.host == "host1" + assert creds.warehouse == "warehouse1" + assert creds.role == "role1" + assert creds.private_key == "cGs=" + assert creds.private_key_passphrase == "paphr" + assert creds.authenticator == "oauth" + assert creds.token == "TOK" + +def test_only_authenticator() -> None: + url = "snowflake://user1@host1/db1" + # password, pk or authenticator must be specified + with pytest.raises(ConfigurationValueError): + resolve_configuration(SnowflakeCredentials(url)) + c = resolve_configuration(SnowflakeCredentials("snowflake://user1@host1/db1?authenticator=uri")) + assert c.authenticator == "uri" + assert c.token is None + # token not present + assert c.to_connector_params() == { + "authenticator": "uri", + "user": "user1", + "password": None, + "account": "host1", + "database": "db1", + "application": "dltHub_dlt", + } + c = resolve_configuration( + SnowflakeCredentials("snowflake://user1@host1/db1?authenticator=oauth&token=TOK") + ) + assert c.to_connector_params() == { + "authenticator": "oauth", + "token": "TOK", + "user": "user1", + "password": None, + "account": "host1", + "database": "db1", + "application": "dltHub_dlt", + } + + +# def test_no_query(environment) -> None: +# c = SnowflakeCredentials("snowflake://user1:pass1@host1/db1") +# assert str(c.to_url()) == "snowflake://user1:pass1@host1/db1" +# print(c.to_url()) + + +def test_query_additional_params() -> None: + c = SnowflakeCredentials("snowflake://user1:pass1@host1/db1?keep_alive=true") + assert c.to_connector_params()["keep_alive"] == "true" + + # try a typed param + with TEST_DICT_CONFIG_PROVIDER().values({"credentials": {"query": {"keep_alive": True}}}): + c = SnowflakeCredentials("snowflake://user1:pass1@host1/db1") + print(c.__is_resolved__) + assert c.is_resolved() is False + c = resolve_configuration(c) + assert c.to_connector_params()["keep_alive"] is True + # serialize to str + assert c.to_url().query["keep_alive"] == "True" + + +def test_overwrite_query_value_from_explicit() -> None: + # value specified in the query is preserved over the value set in config + c = SnowflakeCredentials("snowflake://user1@host1/db1?authenticator=uri") + c.authenticator = "oauth" + assert c.to_url().query["authenticator"] == "oauth" + assert c.to_connector_params()["authenticator"] == "oauth" + + +def test_to_connector_params_private_key() -> None: creds = SnowflakeCredentials() - creds.private_key = pkey_str # type: ignore[assignment] - creds.private_key_passphrase = "12345" # type: ignore[assignment] + creds.private_key = PKEY_PEM_STR # type: ignore[assignment] + creds.private_key_passphrase = PKEY_PASSPHRASE # type: ignore[assignment] creds.username = "user1" creds.database = "db1" creds.host = "host1" @@ -82,12 +176,9 @@ def test_to_connector_params() -> None: application=SNOWFLAKE_APPLICATION_ID, ) - # base64 encoded DER key - pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key-base64").read_text("utf8") - creds = SnowflakeCredentials() - creds.private_key = pkey_str # type: ignore[assignment] - creds.private_key_passphrase = "12345" # type: ignore[assignment] + creds.private_key = PKEY_DER_STR # type: ignore[assignment] + creds.private_key_passphrase = PKEY_PASSPHRASE # type: ignore[assignment] creds.username = "user1" creds.database = "db1" creds.host = "host1" @@ -127,7 +218,8 @@ def test_snowflake_credentials_native_value(environment) -> None: ) assert c.is_resolved() assert c.password == "pass" - assert "application=dlt" in str(c.to_url()) + assert c.application == "dlt" + assert "application=dlt" not in str(c.to_url()) # # but if password is specified - it is final c = resolve_configuration( SnowflakeCredentials(), @@ -138,14 +230,16 @@ def test_snowflake_credentials_native_value(environment) -> None: # set PK via env del os.environ["CREDENTIALS__PASSWORD"] - os.environ["CREDENTIALS__PRIVATE_KEY"] = "pk" + os.environ["CREDENTIALS__PRIVATE_KEY"] = PKEY_DER_STR + os.environ["CREDENTIALS__PRIVATE_KEY_PASSPHRASE"] = PKEY_PASSPHRASE c = resolve_configuration( SnowflakeCredentials(), explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1", ) assert c.is_resolved() - assert c.private_key == "pk" - assert "application=dlt" in str(c.to_url()) + assert c.private_key == PKEY_DER_STR + assert c.private_key_passphrase == PKEY_PASSPHRASE + assert c.password is None # check with application = "" it should not be in connection string os.environ["CREDENTIALS__APPLICATION"] = "" @@ -154,7 +248,18 @@ def test_snowflake_credentials_native_value(environment) -> None: explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1", ) assert c.is_resolved() + assert c.application == "" assert "application=" not in str(c.to_url()) + conn_params = c.to_connector_params() + assert isinstance(conn_params.pop("private_key"), bytes) + assert conn_params == { + "warehouse": "warehouse1", + "role": "role1", + "user": "user1", + "password": None, + "account": "host1", + "database": "db1", + } def test_snowflake_configuration() -> None: diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index bdbe888fb5..1fc0034f43 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -4,13 +4,13 @@ import sqlfluff from dlt.common.utils import uniq_id -from dlt.common.schema import Schema +from dlt.common.schema import Schema, utils +from dlt.destinations import snowflake from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient from dlt.destinations.impl.snowflake.configuration import ( SnowflakeClientConfiguration, SnowflakeCredentials, ) -from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate from tests.load.utils import TABLE_UPDATE, empty_schema @@ -18,12 +18,24 @@ pytestmark = pytest.mark.essential +@pytest.fixture +def cs_client(empty_schema: Schema) -> SnowflakeClient: + # change normalizer to case sensitive + empty_schema._normalizers_config["names"] = "tests.common.cases.normalizers.title_case" + empty_schema.update_normalizers() + return create_client(empty_schema) + + @pytest.fixture def snowflake_client(empty_schema: Schema) -> SnowflakeClient: + return create_client(empty_schema) + + +def create_client(schema: Schema) -> SnowflakeClient: # return client without opening connection creds = SnowflakeCredentials() - return SnowflakeClient( - empty_schema, + return snowflake().client( + schema, SnowflakeClientConfiguration(credentials=creds)._bind_dataset_name( dataset_name="test_" + uniq_id() ), @@ -31,6 +43,22 @@ def snowflake_client(empty_schema: Schema) -> SnowflakeClient: def test_create_table(snowflake_client: SnowflakeClient) -> None: + # make sure we are in case insensitive mode + assert snowflake_client.capabilities.generates_case_sensitive_identifiers() is False + # check if dataset name is properly folded + assert ( + snowflake_client.sql_client.fully_qualified_dataset_name(escape=False) + == snowflake_client.config.dataset_name.upper() + ) + with snowflake_client.sql_client.with_staging_dataset(): + assert ( + snowflake_client.sql_client.fully_qualified_dataset_name(escape=False) + == ( + snowflake_client.config.staging_dataset_name_layout + % snowflake_client.config.dataset_name + ).upper() + ) + statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, False) assert len(statements) == 1 sql = statements[0] @@ -81,6 +109,31 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert '"COL2" FLOAT NOT NULL' in sql +def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None: + # did we switch to case sensitive + assert cs_client.capabilities.generates_case_sensitive_identifiers() is True + # check dataset names + assert cs_client.sql_client.dataset_name.startswith("Test") + with cs_client.with_staging_dataset(): + assert cs_client.sql_client.dataset_name.endswith("staginG") + assert cs_client.sql_client.staging_dataset_name.endswith("staginG") + # check tables + cs_client.schema.update_table( + utils.new_table("event_test_table", columns=deepcopy(TABLE_UPDATE)) + ) + sql = cs_client._get_table_update_sql( + "Event_test_tablE", + list(cs_client.schema.get_table_columns("Event_test_tablE").values()), + False, + )[0] + sqlfluff.parse(sql, dialect="snowflake") + # everything capitalized + assert cs_client.sql_client.fully_qualified_dataset_name(escape=False)[0] == "T" # Test + # every line starts with "Col" + for line in sql.split("\n")[1:]: + assert line.startswith('"Col') + + def test_create_table_with_partition_and_cluster(snowflake_client: SnowflakeClient) -> None: mod_update = deepcopy(TABLE_UPDATE) # timestamp diff --git a/tests/load/synapse/test_synapse_configuration.py b/tests/load/synapse/test_synapse_configuration.py index f366d87d09..8aaea03b0f 100644 --- a/tests/load/synapse/test_synapse_configuration.py +++ b/tests/load/synapse/test_synapse_configuration.py @@ -1,8 +1,11 @@ +import os import pytest from dlt.common.configuration import resolve_configuration from dlt.common.exceptions import SystemConfigurationException +from dlt.common.schema import Schema +from dlt.destinations import synapse from dlt.destinations.impl.synapse.configuration import ( SynapseClientConfiguration, SynapseCredentials, @@ -14,7 +17,42 @@ def test_synapse_configuration() -> None: # By default, unique indexes should not be created. - assert SynapseClientConfiguration().create_indexes is False + c = SynapseClientConfiguration() + assert c.create_indexes is False + assert c.has_case_sensitive_identifiers is False + assert c.staging_use_msi is False + + +def test_synapse_factory() -> None: + schema = Schema("schema") + dest = synapse() + client = dest.client(schema, SynapseClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is False + assert client.config.staging_use_msi is False + assert client.config.has_case_sensitive_identifiers is False + assert client.capabilities.has_case_sensitive_identifiers is False + assert client.capabilities.casefold_identifier is str + + # set args explicitly + dest = synapse(has_case_sensitive_identifiers=True, create_indexes=True, staging_use_msi=True) + client = dest.client(schema, SynapseClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.staging_use_msi is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str + + # set args via config + os.environ["DESTINATION__CREATE_INDEXES"] = "True" + os.environ["DESTINATION__STAGING_USE_MSI"] = "True" + os.environ["DESTINATION__HAS_CASE_SENSITIVE_IDENTIFIERS"] = "True" + dest = synapse() + client = dest.client(schema, SynapseClientConfiguration()._bind_dataset_name("dataset")) + assert client.config.create_indexes is True + assert client.config.staging_use_msi is True + assert client.config.has_case_sensitive_identifiers is True + assert client.capabilities.has_case_sensitive_identifiers is True + assert client.capabilities.casefold_identifier is str def test_parse_native_representation() -> None: diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 9ee2ebe202..1a92a20f1e 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -7,17 +7,18 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema, TColumnHint -from dlt.destinations.impl.synapse.synapse import SynapseClient +from dlt.destinations import synapse +from dlt.destinations.impl.synapse.synapse import ( + SynapseClient, + HINT_TO_SYNAPSE_ATTR, + TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, +) from dlt.destinations.impl.synapse.configuration import ( SynapseClientConfiguration, SynapseCredentials, ) from tests.load.utils import TABLE_UPDATE, empty_schema -from dlt.destinations.impl.synapse.synapse import ( - HINT_TO_SYNAPSE_ATTR, - TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, -) # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -26,7 +27,7 @@ @pytest.fixture def client(empty_schema: Schema) -> SynapseClient: # return client without opening connection - client = SynapseClient( + client = synapse().client( empty_schema, SynapseClientConfiguration(credentials=SynapseCredentials())._bind_dataset_name( dataset_name="test_" + uniq_id() @@ -39,7 +40,7 @@ def client(empty_schema: Schema) -> SynapseClient: @pytest.fixture def client_with_indexes_enabled(empty_schema: Schema) -> SynapseClient: # return client without opening connection - client = SynapseClient( + client = synapse().client( empty_schema, SynapseClientConfiguration( credentials=SynapseCredentials(), create_indexes=True diff --git a/tests/load/synapse/test_synapse_table_indexing.py b/tests/load/synapse/test_synapse_table_indexing.py index a9d426ad4a..d877b769cc 100644 --- a/tests/load/synapse/test_synapse_table_indexing.py +++ b/tests/load/synapse/test_synapse_table_indexing.py @@ -1,20 +1,14 @@ import os import pytest from typing import Iterator, List, Any, Union -from textwrap import dedent import dlt from dlt.common.schema import TColumnSchema -from dlt.destinations.sql_client import SqlClientBase - -from dlt.destinations.impl.synapse import synapse_adapter +from dlt.destinations.adapters import synapse_adapter from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType from tests.load.utils import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES -from tests.load.pipeline.utils import ( - drop_pipeline, -) # this import ensures all test data gets removed from tests.load.synapse.utils import get_storage_table_index_type # mark all tests as essential, do not remove diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 30de51f069..be917672f1 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -4,11 +4,11 @@ from unittest import mock import pytest from unittest.mock import patch -from typing import List +from typing import List, Tuple from dlt.common.exceptions import TerminalException, TerminalValueError from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName -from dlt.common.storages.load_package import LoadJobInfo +from dlt.common.storages.load_package import LoadJobInfo, TJobState from dlt.common.storages.load_storage import JobFileFormatUnsupported from dlt.common.destination.reference import LoadJob, TDestination from dlt.common.schema.utils import ( @@ -31,7 +31,6 @@ clean_test_storage, init_test_logging, TEST_DICT_CONFIG_PROVIDER, - preserve_environ, ) from tests.load.utils import prepare_load_package from tests.utils import skip_if_not_active, TEST_STORAGE_ROOT @@ -97,15 +96,11 @@ def test_unsupported_write_disposition() -> None: with ThreadPoolExecutor() as pool: load.run(pool) # job with unsupported write disp. is failed - exception_file = [ - f - for f in load.load_storage.normalized_packages.list_failed_jobs(load_id) - if f.endswith(".exception") - ][0] - assert ( - "LoadClientUnsupportedWriteDisposition" - in load.load_storage.normalized_packages.storage.load(exception_file) + failed_job = load.load_storage.normalized_packages.list_failed_jobs(load_id)[0] + failed_message = load.load_storage.normalized_packages.get_job_failed_message( + load_id, ParsedLoadJobFileName.parse(failed_job) ) + assert "LoadClientUnsupportedWriteDisposition" in failed_message def test_get_new_jobs_info() -> None: @@ -125,7 +120,7 @@ def test_get_completed_table_chain_single_job_per_table() -> None: schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) top_job_table = get_top_level_table(schema.tables, "event_user") - all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, top_job_table) is None # fake being completed assert ( @@ -144,12 +139,12 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load.load_storage.normalized_packages.start_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) is None load.load_storage.normalized_packages.complete_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) == [ schema.get_table("event_loop_interrupted") ] @@ -485,9 +480,7 @@ def test_extend_table_chain() -> None: # no jobs for bot assert _extend_tables_with_table_chain(schema, ["event_bot"], ["event_user"]) == set() # skip unseen tables - del schema.tables["event_user__parse_data__entities"][ # type:ignore[typeddict-item] - "x-normalizer" - ] + del schema.tables["event_user__parse_data__entities"]["x-normalizer"] entities_chain = { name for name in schema.data_table_names() @@ -533,25 +526,15 @@ def test_get_completed_table_chain_cases() -> None: # child completed, parent not event_user = schema.get_table("event_user") event_user_entities = schema.get_table("event_user__parse_data__entities") - event_user_job = LoadJobInfo( + event_user_job: Tuple[TJobState, ParsedLoadJobFileName] = ( "started_jobs", - "path", - 0, - None, - 0, ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl"), - None, ) - event_user_entities_job = LoadJobInfo( + event_user_entities_job: Tuple[TJobState, ParsedLoadJobFileName] = ( "completed_jobs", - "path", - 0, - None, - 0, ParsedLoadJobFileName( "event_user__parse_data__entities", "event_user__parse_data__entities_id", 0, "jsonl" ), - None, ) chain = get_completed_table_chain(schema, [event_user_job, event_user_entities_job], event_user) assert chain is None @@ -561,24 +544,21 @@ def test_get_completed_table_chain_cases() -> None: schema, [event_user_job, event_user_entities_job], event_user, - event_user_job.job_file_info.job_id(), + event_user_job[1].job_id(), ) # full chain assert chain == [event_user, event_user_entities] # parent failed, child completed chain = get_completed_table_chain( - schema, [event_user_job._replace(state="failed_jobs"), event_user_entities_job], event_user + schema, [("failed_jobs", event_user_job[1]), event_user_entities_job], event_user ) assert chain == [event_user, event_user_entities] # both failed chain = get_completed_table_chain( schema, - [ - event_user_job._replace(state="failed_jobs"), - event_user_entities_job._replace(state="failed_jobs"), - ], + [("failed_jobs", event_user_job[1]), ("failed_jobs", event_user_entities_job[1])], event_user, ) assert chain == [event_user, event_user_entities] @@ -589,16 +569,16 @@ def test_get_completed_table_chain_cases() -> None: event_user["write_disposition"] = w_d # type:ignore[typeddict-item] chain = get_completed_table_chain( - schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + schema, [event_user_job], event_user, event_user_job[1].job_id() ) assert chain == user_chain # but if child is present and incomplete... chain = get_completed_table_chain( schema, - [event_user_job, event_user_entities_job._replace(state="new_jobs")], + [event_user_job, ("new_jobs", event_user_entities_job[1])], event_user, - event_user_job.job_file_info.job_id(), + event_user_job[1].job_id(), ) # noting is returned assert chain is None @@ -607,9 +587,9 @@ def test_get_completed_table_chain_cases() -> None: deep_child = schema.tables[ "event_user__parse_data__response_selector__default__response__response_templates" ] - del deep_child["x-normalizer"] # type:ignore[typeddict-item] + del deep_child["x-normalizer"] chain = get_completed_table_chain( - schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + schema, [event_user_job], event_user, event_user_job[1].job_id() ) user_chain.remove(deep_child) assert chain == user_chain @@ -784,7 +764,7 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) ) - completed_path = load.load_storage.loaded_packages.get_job_folder_path( + completed_path = load.load_storage.loaded_packages.get_job_state_folder_path( load_id, "completed_jobs" ) if should_delete_completed: diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 1c035f7f68..38155a8b09 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -11,10 +11,14 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient from tests.utils import TEST_STORAGE_ROOT, skipifpypy -from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage -from tests.load.pipeline.utils import destinations_configs +from tests.load.utils import ( + expect_load_file, + prepare_table, + yield_client_with_storage, + destinations_configs, +) -DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse"] +DEFAULT_SUBSET = ["duckdb", "redshift", "postgres", "mssql", "synapse", "motherduck"] @pytest.fixture @@ -176,7 +180,6 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage ids=lambda x: x.name, ) def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) -> None: - mocked_caps = client.sql_client.__class__.capabilities writer_type = client.capabilities.insert_values_writer_type insert_sql = prepare_insert_statement(10, writer_type) @@ -185,10 +188,10 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - elif writer_type == "select_union": pre, post, sep = ("SELECT ", "", " UNION ALL\n") + # caps are instance and are attr of sql client instance so it is safe to mock them + client.sql_client.capabilities.max_query_length = 2 # this guarantees that we execute inserts line by line - with patch.object(mocked_caps, "max_query_length", 2), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # print(mocked_fragments.mock_calls) @@ -211,9 +214,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # set query length so it reads data until separator ("," or " UNION ALL") (followed by \n) query_length = (idx - start_idx - 1) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + client.sql_client.capabilities.max_query_length = query_length + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' @@ -221,9 +223,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - # so it reads until "\n" query_length = (idx - start_idx) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + client.sql_client.capabilities.max_query_length = query_length + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on separator ("," or " UNION ALL") @@ -235,9 +236,8 @@ def test_query_split(client: InsertValuesJobClient, file_storage: FileStorage) - elif writer_type == "select_union": offset = 1 query_length = (len(insert_sql) - start_idx - offset) * 2 - with patch.object(mocked_caps, "max_query_length", query_length), patch.object( - client.sql_client, "execute_fragments" - ) as mocked_fragments: + client.sql_client.capabilities.max_query_length = query_length + with patch.object(client.sql_client, "execute_fragments") as mocked_fragments: user_table_name = prepare_table(client) expect_load_file(client, file_storage, insert_sql, user_table_name) # split in 2 on ',' @@ -251,22 +251,21 @@ def assert_load_with_max_query( max_query_length: int, ) -> None: # load and check for real - mocked_caps = client.sql_client.__class__.capabilities - with patch.object(mocked_caps, "max_query_length", max_query_length): - user_table_name = prepare_table(client) - insert_sql = prepare_insert_statement( - insert_lines, client.capabilities.insert_values_writer_type - ) - expect_load_file(client, file_storage, insert_sql, user_table_name) - canonical_name = client.sql_client.make_qualified_table_name(user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] - assert rows_count == insert_lines - # get all uniq ids in order - rows = client.sql_client.execute_sql( - f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;" - ) - v_ids = list(map(lambda i: i[0], rows)) - assert list(map(str, range(0, insert_lines))) == v_ids + client.sql_client.capabilities.max_query_length = max_query_length + user_table_name = prepare_table(client) + insert_sql = prepare_insert_statement( + insert_lines, client.capabilities.insert_values_writer_type + ) + expect_load_file(client, file_storage, insert_sql, user_table_name) + canonical_name = client.sql_client.make_qualified_table_name(user_table_name) + rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] + assert rows_count == insert_lines + # get all uniq ids in order + rows = client.sql_client.execute_sql( + f"SELECT _dlt_id FROM {canonical_name} ORDER BY timestamp ASC;" + ) + v_ids = list(map(lambda i: i[0], rows)) + assert list(map(str, range(0, insert_lines))) == v_ids client.sql_client.execute_sql(f"DELETE FROM {canonical_name}") diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 7e360a6664..614eb17da1 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -5,9 +5,10 @@ from unittest.mock import patch import pytest import datetime # noqa: I251 -from typing import Iterator, Tuple, List, Dict, Any, Mapping, MutableMapping +from typing import Iterator, Tuple, List, Dict, Any from dlt.common import json, pendulum +from dlt.common.normalizers.naming import NamingConvention from dlt.common.schema import Schema from dlt.common.schema.typing import ( LOADS_TABLE_NAME, @@ -15,7 +16,7 @@ TWriteDisposition, TTableSchema, ) -from dlt.common.schema.utils import new_table, new_column +from dlt.common.schema.utils import new_table, new_column, pipeline_state_table from dlt.common.storages import FileStorage from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id @@ -26,7 +27,7 @@ ) from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.destination.reference import WithStagingDataset +from dlt.common.destination.reference import StateInfo, WithStagingDataset from tests.cases import table_update_and_row, assert_all_data_types_row from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage @@ -41,8 +42,18 @@ cm_yield_client_with_storage, write_dataset, prepare_table, + normalize_storage_table_cols, + destinations_configs, + DestinationTestConfiguration, +) + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential +TEST_NAMING_CONVENTIONS = ( + "snake_case", + "tests.common.cases.normalizers.sql_upper", + "tests.common.cases.normalizers.title_case", ) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @pytest.fixture @@ -51,10 +62,20 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") -def client(request) -> Iterator[SqlJobClientBase]: +def client(request, naming) -> Iterator[SqlJobClientBase]: yield from yield_client_with_storage(request.param.destination) +@pytest.fixture(scope="function") +def naming(request) -> str: + # NOTE: this fixture is forced by `client` fixture which requires it goes first + # so sometimes there's no request available + if hasattr(request, "param"): + os.environ["SCHEMA__NAMING"] = request.param + return request.param + return None + + @pytest.mark.order(1) @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name @@ -64,18 +85,25 @@ def test_initialize_storage(client: SqlJobClientBase) -> None: @pytest.mark.order(2) +@pytest.mark.parametrize("naming", TEST_NAMING_CONVENTIONS, indirect=True) @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) -def test_get_schema_on_empty_storage(client: SqlJobClientBase) -> None: +def test_get_schema_on_empty_storage(naming: str, client: SqlJobClientBase) -> None: # test getting schema on empty dataset without any tables - exists, _ = client.get_storage_table(VERSION_TABLE_NAME) - assert exists is False + version_table_name = client.schema.version_table_name + table_name, table_columns = list(client.get_storage_tables([version_table_name]))[0] + assert table_name == version_table_name + assert len(table_columns) == 0 schema_info = client.get_stored_schema() assert schema_info is None schema_info = client.get_stored_schema_by_hash("8a0298298823928939") assert schema_info is None + # now try to get several non existing tables + storage_tables = list(client.get_storage_tables(["no_table_1", "no_table_2"])) + assert [("no_table_1", {}), ("no_table_2", {})] == storage_tables + @pytest.mark.order(3) @pytest.mark.parametrize( @@ -90,17 +118,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: # check is event slot has variant assert schema_update["event_slot"]["columns"]["value"]["variant"] is True # now we have dlt tables - exists, _ = client.get_storage_table(VERSION_TABLE_NAME) - assert exists is True - exists, _ = client.get_storage_table(LOADS_TABLE_NAME) - assert exists is True + storage_tables = list(client.get_storage_tables([VERSION_TABLE_NAME, LOADS_TABLE_NAME])) + assert set([table[0] for table in storage_tables]) == {VERSION_TABLE_NAME, LOADS_TABLE_NAME} + assert [len(table[1]) > 0 for table in storage_tables] == [True, True] # verify if schemas stored this_schema = client.get_stored_schema_by_hash(schema.version_hash) newest_schema = client.get_stored_schema() # should point to the same schema assert this_schema == newest_schema # check fields - assert this_schema.version == 1 == schema.version + # NOTE: schema version == 2 because we updated default hints after loading the schema + assert this_schema.version == 2 == schema.version assert this_schema.version_hash == schema.stored_version_hash assert this_schema.engine_version == schema.ENGINE_VERSION assert this_schema.schema_name == schema.name @@ -120,7 +148,7 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: this_schema = client.get_stored_schema_by_hash(schema.version_hash) newest_schema = client.get_stored_schema() assert this_schema == newest_schema - assert this_schema.version == schema.version == 2 + assert this_schema.version == schema.version == 3 assert this_schema.version_hash == schema.stored_version_hash # simulate parallel write: initial schema is modified differently and written alongside the first one @@ -128,14 +156,14 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: first_schema = Schema.from_dict(json.loads(first_version_schema)) first_schema.tables["event_bot"]["write_disposition"] = "replace" first_schema._bump_version() - assert first_schema.version == this_schema.version == 2 + assert first_schema.version == this_schema.version == 3 # wait to make load_newest_schema deterministic sleep(1) client._update_schema_in_storage(first_schema) this_schema = client.get_stored_schema_by_hash(first_schema.version_hash) newest_schema = client.get_stored_schema() assert this_schema == newest_schema # error - assert this_schema.version == first_schema.version == 2 + assert this_schema.version == first_schema.version == 3 assert this_schema.version_hash == first_schema.stored_version_hash # get schema with non existing hash @@ -157,15 +185,17 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: assert this_schema == newest_schema -@pytest.mark.essential +@pytest.mark.parametrize("naming", TEST_NAMING_CONVENTIONS, indirect=True) @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) -def test_complete_load(client: SqlJobClientBase) -> None: +def test_complete_load(naming: str, client: SqlJobClientBase) -> None: + loads_table_name = client.schema.loads_table_name + version_table_name = client.schema.version_table_name client.update_stored_schema() load_id = "182879721.182912" client.complete_load(load_id) - load_table = client.sql_client.make_qualified_table_name(LOADS_TABLE_NAME) + load_table = client.sql_client.make_qualified_table_name(loads_table_name) load_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {load_table}")) assert len(load_rows) == 1 assert load_rows[0][0] == load_id @@ -176,10 +206,13 @@ def test_complete_load(client: SqlJobClientBase) -> None: assert type(load_rows[0][3]) is datetime.datetime assert load_rows[0][4] == client.schema.version_hash # make sure that hash in loads exists in schema versions table - versions_table = client.sql_client.make_qualified_table_name(VERSION_TABLE_NAME) + versions_table = client.sql_client.make_qualified_table_name(version_table_name) + version_hash_column = client.sql_client.escape_column_name( + client.schema.naming.normalize_identifier("version_hash") + ) version_rows = list( client.sql_client.execute_sql( - f"SELECT * FROM {versions_table} WHERE version_hash = %s", load_rows[0][4] + f"SELECT * FROM {versions_table} WHERE {version_hash_column} = %s", load_rows[0][4] ) ) assert len(version_rows) == 1 @@ -190,11 +223,11 @@ def test_complete_load(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, subset=["redshift", "postgres", "duckdb"]), + destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name, ) -def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: +def test_schema_update_create_table(client: SqlJobClientBase) -> None: # infer typical rasa event schema schema = client.schema table_name = "event_test_table" + uniq_id() @@ -215,8 +248,8 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: assert table_update["timestamp"]["sort"] is True assert table_update["sender_id"]["cluster"] is True assert table_update["_dlt_id"]["unique"] is True - exists, _ = client.get_storage_table(table_name) - assert exists is True + _, storage_columns = list(client.get_storage_tables([table_name]))[0] + assert len(storage_columns) > 0 @pytest.mark.parametrize( @@ -225,7 +258,15 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: indirect=True, ids=lambda x: x.name, ) -def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: +@pytest.mark.parametrize("dataset_name", (None, "_hidden_ds")) +def test_schema_update_create_table_bigquery(client: SqlJobClientBase, dataset_name: str) -> None: + # patch dataset name + if dataset_name: + # drop existing dataset + client.drop_storage() + client.sql_client.dataset_name = dataset_name + "_" + uniq_id() + client.initialize_storage() + # infer typical rasa event schema schema = client.schema # this will be partition @@ -241,14 +282,11 @@ def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: table_update = schema_update["event_test_table"]["columns"] assert table_update["timestamp"]["partition"] is True assert table_update["_dlt_id"]["nullable"] is False - exists, storage_table = client.get_storage_table("event_test_table") - assert exists is True - assert storage_table["timestamp"]["partition"] is True - assert storage_table["sender_id"]["cluster"] is True - exists, storage_table = client.get_storage_table("_dlt_version") - assert exists is True - assert storage_table["version"]["partition"] is False - assert storage_table["version"]["cluster"] is False + _, storage_columns = client.get_storage_table("event_test_table") + # check if all columns present + assert storage_columns.keys() == client.schema.tables["event_test_table"]["columns"].keys() + _, storage_columns = client.get_storage_table("_dlt_version") + assert storage_columns.keys() == client.schema.tables["_dlt_version"]["columns"].keys() @pytest.mark.parametrize( @@ -285,10 +323,11 @@ def test_schema_update_alter_table(client: SqlJobClientBase) -> None: assert len(schema_update[table_name]["columns"]) == 2 assert schema_update[table_name]["columns"]["col3"]["data_type"] == "double" assert schema_update[table_name]["columns"]["col4"]["data_type"] == "timestamp" - _, storage_table = client.get_storage_table(table_name) + _, storage_table_cols = client.get_storage_table(table_name) # 4 columns - assert len(storage_table) == 4 - assert storage_table["col4"]["data_type"] == "timestamp" + assert len(storage_table_cols) == 4 + storage_table_cols = normalize_storage_table_cols(table_name, storage_table_cols, schema) + assert storage_table_cols["col4"]["data_type"] == "timestamp" @pytest.mark.parametrize( @@ -323,10 +362,11 @@ def test_drop_tables(client: SqlJobClientBase) -> None: # Drop tables from the first schema client.schema = schema tables_to_drop = ["event_slot", "event_user"] - for tbl in tables_to_drop: - del schema.tables[tbl] + schema.drop_tables(tables_to_drop) schema._bump_version() - client.drop_tables(*tables_to_drop) + + # add one fake table to make sure one table can be ignored + client.drop_tables(tables_to_drop[0], "not_exists", *tables_to_drop[1:]) client._update_schema_in_storage(schema) # Schema was deleted, load it in again if isinstance(client, WithStagingDataset): with contextlib.suppress(DatabaseUndefinedRelation): @@ -341,9 +381,7 @@ def test_drop_tables(client: SqlJobClientBase) -> None: client.drop_tables(*tables_to_drop, delete_schema=False) # Verify requested tables are dropped - for tbl in tables_to_drop: - exists, _ = client.get_storage_table(tbl) - assert not exists + assert all(len(table[1]) == 0 for table in client.get_storage_tables(tables_to_drop)) # Verify _dlt_version schema is updated and old versions deleted table_name = client.sql_client.make_qualified_table_name(VERSION_TABLE_NAME) @@ -376,14 +414,13 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: for name, column in table_update.items(): assert column.items() >= TABLE_UPDATE_COLUMNS_SCHEMA[name].items() # now get the actual schema from the db - exists, storage_table = client.get_storage_table(table_name) - assert exists is True + _, storage_table = list(client.get_storage_tables([table_name]))[0] + assert len(storage_table) > 0 # column order must match TABLE_UPDATE storage_columns = list(storage_table.values()) for c, expected_c in zip(TABLE_UPDATE, storage_columns): - # print(c["name"]) - # print(c["data_type"]) - assert c["name"] == expected_c["name"] + # storage columns are returned with column names as in information schema + assert client.capabilities.casefold_identifier(c["name"]) == expected_c["name"] # athena does not know wei data type and has no JSON type, time is not supported with parquet tables if client.config.destination_type == "athena" and c["data_type"] in ( "wei", @@ -429,8 +466,7 @@ def _assert_columns_order(sql_: str) -> None: if hasattr(client.sql_client, "escape_ddl_identifier"): col_name = client.sql_client.escape_ddl_identifier(c["name"]) else: - col_name = client.capabilities.escape_identifier(c["name"]) - print(col_name) + col_name = client.sql_client.escape_column_name(c["name"]) # find column names idx = sql_.find(col_name, idx) assert idx > 0, f"column {col_name} not found in script" @@ -441,10 +477,11 @@ def _assert_columns_order(sql_: str) -> None: _assert_columns_order(sql) +@pytest.mark.parametrize("naming", TEST_NAMING_CONVENTIONS, indirect=True) @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) -def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) -> None: +def test_data_writer_load(naming: str, client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") rows, table_name = prepare_schema(client, "simple_row") @@ -462,8 +499,10 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) - write_dataset(client, f, [rows[1]], client.schema.get_table(table_name)["columns"]) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) + f_int_name = client.schema.naming.normalize_identifier("f_int") + f_int_name_quoted = client.sql_client.escape_column_name(f_int_name) db_row = client.sql_client.execute_sql( - f"SELECT * FROM {canonical_name} WHERE f_int = {rows[1]['f_int']}" + f"SELECT * FROM {canonical_name} WHERE {f_int_name_quoted} = {rows[1][f_int_name]}" )[0] assert db_row[3] is None assert db_row[5] is None @@ -509,53 +548,68 @@ def test_data_writer_string_escape_edge( assert row_value == expected +@pytest.mark.parametrize("naming", TEST_NAMING_CONVENTIONS, indirect=True) @pytest.mark.parametrize("write_disposition", ["append", "replace"]) @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) def test_load_with_all_types( - client: SqlJobClientBase, write_disposition: TWriteDisposition, file_storage: FileStorage + naming: str, + client: SqlJobClientBase, + write_disposition: TWriteDisposition, + file_storage: FileStorage, ) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") table_name = "event_test_table" + uniq_id() - column_schemas, data_types = table_update_and_row( + column_schemas, data_row = table_update_and_row( exclude_types=( ["time"] if client.config.destination_type in ["databricks", "clickhouse"] else None ), ) # we should have identical content with all disposition types - client.schema.update_table( + partial = client.schema.update_table( new_table( table_name, write_disposition=write_disposition, columns=list(column_schemas.values()) ) ) + # get normalized schema + table_name = partial["name"] + column_schemas = partial["columns"] + normalize_rows([data_row], client.schema.naming) client.schema._bump_version() client.update_stored_schema() - if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined] + should_load_to_staging = client.should_load_data_to_staging_dataset(client.schema.tables[table_name]) # type: ignore[attr-defined] + if should_load_to_staging: with client.with_staging_dataset(): # type: ignore[attr-defined] # create staging for merge dataset client.initialize_storage() client.update_stored_schema() - with client.sql_client.with_staging_dataset( - client.should_load_data_to_staging_dataset(client.schema.tables[table_name]) # type: ignore[attr-defined] + with client.sql_client.with_alternative_dataset_name( + client.sql_client.staging_dataset_name + if should_load_to_staging + else client.sql_client.dataset_name ): canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row with io.BytesIO() as f: - write_dataset(client, f, [data_types], column_schemas) + write_dataset(client, f, [data_row], column_schemas) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) db_row = list(client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]) - # content must equal - assert_all_data_types_row( - db_row, - schema=column_schemas, - allow_base64_binary=client.config.destination_type in ["clickhouse"], - ) + assert len(db_row) == len(data_row) + # assert_all_data_types_row has many hardcoded columns so for now skip that part + if naming == "snake_case": + # content must equal + assert_all_data_types_row( + db_row, + data_row, + schema=column_schemas, + allow_base64_binary=client.config.destination_type in ["clickhouse"], + ) @pytest.mark.parametrize( @@ -638,7 +692,7 @@ def test_write_dispositions( # merge on client level, without loader, loads to staging dataset. so this table is empty assert len(db_rows) == 0 # check staging - with client.sql_client.with_staging_dataset(staging=True): + with client.sql_client.with_staging_dataset(): db_rows = list( client.sql_client.execute_sql( f"SELECT * FROM {client.sql_client.make_qualified_table_name(t)} ORDER" @@ -716,6 +770,53 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon assert client.sql_client.has_dataset() +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize( + "naming_convention", + [ + "tests.common.cases.normalizers.title_case", + "snake_case", + ], +) +def test_get_stored_state( + destination_config: DestinationTestConfiguration, + naming_convention: str, + file_storage: FileStorage, +) -> None: + os.environ["SCHEMA__NAMING"] = naming_convention + + with cm_yield_client_with_storage( + destination_config.destination, default_config_values={"default_schema_name": None} + ) as client: + # event schema with event table + if not client.capabilities.preferred_loader_file_format: + pytest.skip( + "preferred loader file format not set, destination will only work with staging" + ) + # load pipeline state + state_table = pipeline_state_table() + partial = client.schema.update_table(state_table) + print(partial) + client.schema._bump_version() + client.update_stored_schema() + + state_info = StateInfo(1, 4, "pipeline", "compressed", pendulum.now(), None, "_load_id") + doc = state_info.as_doc() + norm_doc = {client.schema.naming.normalize_identifier(k): v for k, v in doc.items()} + with io.BytesIO() as f: + # use normalized columns + write_dataset(client, f, [norm_doc], partial["columns"]) + query = f.getvalue().decode() + expect_load_file(client, file_storage, query, partial["name"]) + client.complete_load("_load_id") + + # get state + stored_state = client.get_stored_state("pipeline") + assert doc == stored_state.as_doc() + + @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name ) @@ -810,10 +911,19 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: def prepare_schema(client: SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]: client.update_stored_schema() rows = load_json_case(case) + # normalize rows + normalize_rows(rows, client.schema.naming) # use first row to infer table table: TTableSchemaColumns = {k: client.schema._infer_column(k, v) for k, v in rows[0].items()} table_name = f"event_{case}_{uniq_id()}" - client.schema.update_table(new_table(table_name, columns=list(table.values()))) + partial = client.schema.update_table(new_table(table_name, columns=list(table.values()))) client.schema._bump_version() client.update_stored_schema() - return rows, table_name + # return normalized name + return rows, partial["name"] + + +def normalize_rows(rows: List[Dict[str, Any]], naming: NamingConvention) -> None: + for row in rows: + for k in list(row.keys()): + row[naming.normalize_identifier(k)] = row.pop(k) diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 26d7884179..e167f0ceda 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -1,3 +1,4 @@ +import os import pytest import datetime # noqa: I251 from typing import Iterator, Any @@ -22,8 +23,20 @@ from dlt.common.time import ensure_pendulum_datetime from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage -from tests.load.utils import yield_client_with_storage, prepare_table, AWS_BUCKET -from tests.load.pipeline.utils import destinations_configs +from tests.load.utils import ( + yield_client_with_storage, + prepare_table, + AWS_BUCKET, + destinations_configs, +) + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential +TEST_NAMING_CONVENTIONS = ( + "snake_case", + "tests.common.cases.normalizers.sql_upper", + "tests.common.cases.normalizers.title_case", +) @pytest.fixture @@ -32,10 +45,20 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") -def client(request) -> Iterator[SqlJobClientBase]: +def client(request, naming) -> Iterator[SqlJobClientBase]: yield from yield_client_with_storage(request.param.destination) +@pytest.fixture(scope="function") +def naming(request) -> str: + # NOTE: this fixture is forced by `client` fixture which requires it goes first + # so sometimes there's no request available + if hasattr(request, "param"): + os.environ["SCHEMA__NAMING"] = request.param + return request.param + return None + + @pytest.mark.parametrize( "client", destinations_configs( @@ -105,6 +128,30 @@ def test_malformed_query_parameters(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) +@pytest.mark.parametrize("naming", TEST_NAMING_CONVENTIONS, indirect=True) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_has_dataset(naming: str, client: SqlJobClientBase) -> None: + with client.sql_client.with_alternative_dataset_name("not_existing"): + assert not client.sql_client.has_dataset() + client.update_stored_schema() + assert client.sql_client.has_dataset() + + +@pytest.mark.parametrize("naming", TEST_NAMING_CONVENTIONS, indirect=True) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_create_drop_dataset(naming: str, client: SqlJobClientBase) -> None: + # client.sql_client.create_dataset() + with pytest.raises(DatabaseException): + client.sql_client.create_dataset() + client.sql_client.drop_dataset() + with pytest.raises(DatabaseUndefinedRelation): + client.sql_client.drop_dataset() + + @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -141,7 +188,6 @@ def test_malformed_execute_parameters(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -189,7 +235,6 @@ def test_execute_sql(client: SqlJobClientBase) -> None: assert len(rows) == 0 -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -212,7 +257,6 @@ def test_execute_ddl(client: SqlJobClientBase) -> None: assert rows[0][0] == Decimal("1.0") -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -255,7 +299,6 @@ def test_execute_query(client: SqlJobClientBase) -> None: assert len(rows) == 0 -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -307,7 +350,6 @@ def test_execute_df(client: SqlJobClientBase) -> None: assert df_3 is None -@pytest.mark.essential @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) @@ -490,7 +532,10 @@ def test_transaction_isolation(client: SqlJobClientBase) -> None: def test_thread(thread_id: Decimal) -> None: # make a copy of the sql_client thread_client = client.sql_client.__class__( - client.sql_client.dataset_name, client.sql_client.credentials + client.sql_client.dataset_name, + client.sql_client.staging_dataset_name, + client.sql_client.credentials, + client.capabilities, ) with thread_client: with thread_client.begin_transaction(): diff --git a/tests/load/utils.py b/tests/load/utils.py index 8048d9fe51..95083b7d31 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -13,6 +13,7 @@ from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.configuration.specs import CredentialsConfiguration from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, JobClientBase, @@ -24,13 +25,15 @@ from dlt.common.destination import TLoaderFileFormat, Destination from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT from dlt.common.data_writers import DataWriter +from dlt.common.pipeline import PipelineContext from dlt.common.schema import TTableSchemaColumns, Schema from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration -from dlt.common.schema.utils import new_table +from dlt.common.schema.utils import new_table, normalize_table_identifiers from dlt.common.storages import ParsedLoadJobFileName, LoadStorage, PackageStorage from dlt.common.typing import StrAny from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import CantExtractTablePrefix from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.job_client_impl import SqlJobClientBase @@ -126,6 +129,8 @@ class DestinationTestConfiguration: force_iceberg: bool = False supports_dbt: bool = True disable_compression: bool = False + dev_mode: bool = False + credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any]]] = None @property def name(self) -> str: @@ -140,33 +145,55 @@ def name(self) -> str: name += f"-{self.extra_info}" return name + @property + def factory_kwargs(self) -> Dict[str, Any]: + return { + k: getattr(self, k) + for k in [ + "bucket_url", + "stage_name", + "staging_iam_role", + "staging_use_msi", + "force_iceberg", + ] + if getattr(self, k, None) is not None + } + def setup(self) -> None: """Sets up environment variables for this destination configuration""" - os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = self.bucket_url or "" - os.environ["DESTINATION__STAGE_NAME"] = self.stage_name or "" - os.environ["DESTINATION__STAGING_IAM_ROLE"] = self.staging_iam_role or "" - os.environ["DESTINATION__STAGING_USE_MSI"] = str(self.staging_use_msi) or "" - os.environ["DESTINATION__FORCE_ICEBERG"] = str(self.force_iceberg) or "" + for k, v in self.factory_kwargs.items(): + os.environ[f"DESTINATION__{k.upper()}"] = str(v) - """For the filesystem destinations we disable compression to make analyzing the result easier""" + # For the filesystem destinations we disable compression to make analyzing the result easier if self.destination == "filesystem" or self.disable_compression: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + if self.credentials is not None: + for key, value in dict(self.credentials).items(): + os.environ[f"DESTINATION__CREDENTIALS__{key.upper()}"] = str(value) + def setup_pipeline( self, pipeline_name: str, dataset_name: str = None, dev_mode: bool = False, **kwargs ) -> dlt.Pipeline: """Convenience method to setup pipeline with this configuration""" + self.dev_mode = dev_mode self.setup() pipeline = dlt.pipeline( pipeline_name=pipeline_name, - destination=self.destination, - staging=self.staging, + destination=kwargs.pop("destination", self.destination), + staging=kwargs.pop("staging", self.staging), dataset_name=dataset_name or pipeline_name, dev_mode=dev_mode, **kwargs, ) return pipeline + def attach_pipeline(self, pipeline_name: str, **kwargs) -> dlt.Pipeline: + """Attach to existing pipeline keeping the dev_mode""" + # remember dev_mode from setup_pipeline + pipeline = dlt.attach(pipeline_name, **kwargs) + return pipeline + def destinations_configs( default_sql_configs: bool = False, @@ -255,8 +282,16 @@ def destinations_configs( assert set(SQL_DESTINATIONS) == {d.destination for d in destination_configs} if default_vector_configs: - # for now only weaviate - destination_configs += [DestinationTestConfiguration(destination="weaviate")] + destination_configs += [ + DestinationTestConfiguration(destination="weaviate"), + DestinationTestConfiguration(destination="lancedb"), + DestinationTestConfiguration( + destination="qdrant", + credentials=dict(path=str(Path(FILE_BUCKET) / "qdrant_data")), + extra_info="local-file", + ), + DestinationTestConfiguration(destination="qdrant", extra_info="server"), + ] if default_staging_configs or all_staging_configs: destination_configs += [ @@ -489,6 +524,60 @@ def destinations_configs( return destination_configs +@pytest.fixture(autouse=True) +def drop_pipeline(request, preserve_environ) -> Iterator[None]: + # NOTE: keep `preserve_environ` to make sure fixtures are executed in order`` + yield + if "no_load" in request.keywords: + return + try: + drop_active_pipeline_data() + except CantExtractTablePrefix: + # for some tests we test that this exception is raised, + # so we suppress it here + pass + + +def drop_active_pipeline_data() -> None: + """Drops all the datasets for currently active pipeline, wipes the working folder and then deactivated it.""" + if Container()[PipelineContext].is_active(): + try: + # take existing pipeline + p = dlt.pipeline() + + def _drop_dataset(schema_name: str) -> None: + with p.destination_client(schema_name) as client: + try: + client.drop_storage() + print("dropped") + except Exception as exc: + print(exc) + if isinstance(client, WithStagingDataset): + with client.with_staging_dataset(): + try: + client.drop_storage() + print("staging dropped") + except Exception as exc: + print(exc) + + # drop_func = _drop_dataset_fs if _is_filesystem(p) else _drop_dataset_sql + # take all schemas and if destination was set + if p.destination: + if p.config.use_single_dataset: + # drop just the dataset for default schema + if p.default_schema_name: + _drop_dataset(p.default_schema_name) + else: + # for each schema, drop the dataset + for schema_name in p.schema_names: + _drop_dataset(schema_name) + + # p._wipe_working_folder() + finally: + # always deactivate context, working directory will be wiped when the next test starts + Container()[PipelineContext].deactivate() + + @pytest.fixture def empty_schema() -> Schema: schema = Schema("event") @@ -580,6 +669,9 @@ def yield_client( ) schema_storage = SchemaStorage(storage_config) schema = schema_storage.load_schema(schema_name) + schema.update_normalizers() + # NOTE: schema version is bumped because new default hints are added + schema._bump_version() # create client and dataset client: SqlJobClientBase = None @@ -626,7 +718,8 @@ def yield_client_with_storage( ) as client: client.initialize_storage() yield client - client.sql_client.drop_dataset() + if client.is_storage_initialized(): + client.sql_client.drop_dataset() if isinstance(client, WithStagingDataset): with client.with_staging_dataset(): if client.is_storage_initialized(): @@ -680,7 +773,7 @@ def prepare_load_package( shutil.copy( path, load_storage.new_packages.storage.make_full_path( - load_storage.new_packages.get_job_folder_path(load_id, "new_jobs") + load_storage.new_packages.get_job_state_folder_path(load_id, "new_jobs") ), ) schema_path = Path("./tests/load/cases/loading/schema.json") @@ -708,3 +801,15 @@ def sequence_generator() -> Generator[List[Dict[str, str]], None, None]: while True: yield [{"content": str(count + i)} for i in range(3)] count += 3 + + +def normalize_storage_table_cols( + table_name: str, cols: TTableSchemaColumns, schema: Schema +) -> TTableSchemaColumns: + """Normalize storage table columns back into schema naming""" + # go back to schema naming convention. this is a hack - will work here to + # reverse snowflake UPPER case folding + storage_table = normalize_table_identifiers( + new_table(table_name, columns=cols.values()), schema.naming # type: ignore[arg-type] + ) + return storage_table["columns"] diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index ee42ab59d8..fc46d00d05 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -4,9 +4,13 @@ import dlt from dlt.common import json +from dlt.common.schema.exceptions import ( + SchemaCorruptedException, + SchemaIdentifierNormalizationCollision, +) from dlt.common.utils import uniq_id -from dlt.destinations.impl.weaviate import weaviate_adapter +from dlt.destinations.adapters import weaviate_adapter from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient @@ -244,7 +248,8 @@ def movies_data(): assert_class(pipeline, "MoviesData", items=data) -def test_pipeline_with_schema_evolution(): +@pytest.mark.parametrize("vectorized", (True, False), ids=("vectorized", "not-vectorized")) +def test_pipeline_with_schema_evolution(vectorized: bool): data = [ { "doc_id": 1, @@ -260,7 +265,8 @@ def test_pipeline_with_schema_evolution(): def some_data(): yield data - weaviate_adapter(some_data, vectorize=["content"]) + if vectorized: + weaviate_adapter(some_data, vectorize=["content"]) pipeline = dlt.pipeline( pipeline_name="test_pipeline_append", @@ -280,17 +286,22 @@ def some_data(): "doc_id": 3, "content": "3", "new_column": "new", + "new_vec_column": "lorem lorem", }, { "doc_id": 4, "content": "4", "new_column": "new", + "new_vec_column": "lorem lorem", }, ] - pipeline.run( - some_data(), - ) + some_data_2 = some_data() + + if vectorized: + weaviate_adapter(some_data_2, vectorize=["new_vec_column"]) + + pipeline.run(some_data_2) table_schema = pipeline.default_schema.tables["SomeData"] assert "new_column" in table_schema["columns"] @@ -298,6 +309,8 @@ def some_data(): aggregated_data.extend(data) aggregated_data[0]["new_column"] = None aggregated_data[1]["new_column"] = None + aggregated_data[0]["new_vec_column"] = None + aggregated_data[1]["new_vec_column"] = None assert_class(pipeline, "SomeData", items=aggregated_data) @@ -391,7 +404,7 @@ def test_vectorize_property_without_data() -> None: primary_key="vAlue", columns={"vAlue": {"data_type": "text"}}, ) - assert isinstance(pipe_ex.value.__context__, PropertyNameConflict) + assert isinstance(pipe_ex.value.__context__, SchemaIdentifierNormalizationCollision) # set the naming convention to case insensitive os.environ["SCHEMA__NAMING"] = "dlt.destinations.impl.weaviate.ci_naming" diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 8c3344f152..dc2110d2f6 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -5,6 +5,7 @@ from dlt.common.schema import Schema from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision from dlt.common.utils import uniq_id from dlt.common.schema.typing import TWriteDisposition, TColumnSchema, TTableSchemaColumns @@ -13,7 +14,7 @@ from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema.utils import new_table +from dlt.common.schema.utils import new_table, normalize_table_identifiers from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE, @@ -58,11 +59,11 @@ def make_client(naming_convention: str) -> Iterator[WeaviateClient]: "test_schema", {"names": f"dlt.destinations.impl.weaviate.{naming_convention}", "json": None}, ) - _client = get_client_instance(schema) - try: - yield _client - finally: - _client.drop_storage() + with get_client_instance(schema) as _client: + try: + yield _client + finally: + _client.drop_storage() @pytest.fixture @@ -114,11 +115,18 @@ def test_case_sensitive_properties_create(client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + normalize_table_identifiers( + new_table(class_name, columns=table_create), client.schema.naming + ) ) client.schema._bump_version() - with pytest.raises(PropertyNameConflict): + with pytest.raises(SchemaIdentifierNormalizationCollision) as clash_ex: client.update_stored_schema() + assert clash_ex.value.identifier_type == "column" + assert clash_ex.value.identifier_name == "coL1" + assert clash_ex.value.conflict_identifier_name == "col1" + assert clash_ex.value.table_name == "ColClass" + assert clash_ex.value.naming_name == "dlt.destinations.impl.weaviate.naming" def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: @@ -129,7 +137,9 @@ def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] ci_client.schema.update_table( - ci_client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + normalize_table_identifiers( + new_table(class_name, columns=table_create), ci_client.schema.naming + ) ) ci_client.schema._bump_version() ci_client.update_stored_schema() @@ -146,16 +156,20 @@ def test_case_sensitive_properties_add(client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + normalize_table_identifiers( + new_table(class_name, columns=table_create), client.schema.naming + ) ) client.schema._bump_version() client.update_stored_schema() client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_update)) + normalize_table_identifiers( + new_table(class_name, columns=table_update), client.schema.naming + ) ) client.schema._bump_version() - with pytest.raises(PropertyNameConflict): + with pytest.raises(SchemaIdentifierNormalizationCollision): client.update_stored_schema() # _, table_columns = client.get_storage_table("ColClass") @@ -171,12 +185,13 @@ def test_load_case_sensitive_data(client: WeaviateClient, file_storage: FileStor client.schema.update_table(new_table(class_name, columns=[table_create["col1"]])) client.schema._bump_version() client.update_stored_schema() - # prepare a data item where is name clash due to Weaviate being CI + # prepare a data item where is name clash due to Weaviate being CS data_clash = {"col1": 72187328, "coL1": 726171} # write row with io.BytesIO() as f: write_dataset(client, f, [data_clash], table_create) query = f.getvalue().decode() + class_name = client.schema.naming.normalize_table_identifier(class_name) with pytest.raises(PropertyNameConflict): expect_load_file(client, file_storage, query, class_name) @@ -202,6 +217,7 @@ def test_load_case_sensitive_data_ci(ci_client: WeaviateClient, file_storage: Fi with io.BytesIO() as f: write_dataset(ci_client, f, [data_clash], table_create) query = f.getvalue().decode() + class_name = ci_client.schema.naming.normalize_table_identifier(class_name) expect_load_file(ci_client, file_storage, query, class_name) response = ci_client.query_class(class_name, ["col1"]).do() objects = response["data"]["Get"][ci_client.make_qualified_class_name(class_name)] diff --git a/tests/load/weaviate/utils.py b/tests/load/weaviate/utils.py index 1b2a74fcb8..b391c2fa38 100644 --- a/tests/load/weaviate/utils.py +++ b/tests/load/weaviate/utils.py @@ -22,53 +22,57 @@ def assert_class( expected_items_count: int = None, items: List[Any] = None, ) -> None: - client: WeaviateClient = pipeline.destination_client() # type: ignore[assignment] - vectorizer_name: str = client._vectorizer_config["vectorizer"] # type: ignore[assignment] - - # Check if class exists - schema = client.get_class_schema(class_name) - assert schema is not None - - columns = pipeline.default_schema.get_table_columns(class_name) - - properties = {prop["name"]: prop for prop in schema["properties"]} - assert set(properties.keys()) == set(columns.keys()) - - # make sure expected columns are vectorized - for column_name, column in columns.items(): - prop = properties[column_name] - assert prop["moduleConfig"][vectorizer_name]["skip"] == ( - not column.get(VECTORIZE_HINT, False) - ) - # tokenization - if TOKENIZATION_HINT in column: - assert prop["tokenization"] == column[TOKENIZATION_HINT] # type: ignore[literal-required] - - # if there's a single vectorize hint, class must have vectorizer enabled - if get_columns_names_with_prop(pipeline.default_schema.get_table(class_name), VECTORIZE_HINT): - assert schema["vectorizer"] == vectorizer_name - else: - assert schema["vectorizer"] == "none" - - # response = db_client.query.get(class_name, list(properties.keys())).do() - response = client.query_class(class_name, list(properties.keys())).do() - objects = response["data"]["Get"][client.make_qualified_class_name(class_name)] - - if expected_items_count is not None: - assert expected_items_count == len(objects) - - if items is None: - return - - # TODO: Remove this once we have a better way comparing the data - drop_keys = ["_dlt_id", "_dlt_load_id"] - objects_without_dlt_keys = [ - {k: v for k, v in obj.items() if k not in drop_keys} for obj in objects - ] - - # pytest compares content wise but ignores order of elements of dict - # assert sorted(objects_without_dlt_keys, key=lambda d: d['doc_id']) == sorted(data, key=lambda d: d['doc_id']) - assert_unordered_list_equal(objects_without_dlt_keys, items) + client: WeaviateClient + with pipeline.destination_client() as client: # type: ignore[assignment] + vectorizer_name: str = client._vectorizer_config["vectorizer"] # type: ignore[assignment] + + # Check if class exists + schema = client.get_class_schema(class_name) + assert schema is not None + + columns = pipeline.default_schema.get_table_columns(class_name) + + properties = {prop["name"]: prop for prop in schema["properties"]} + assert set(properties.keys()) == set(columns.keys()) + + # make sure expected columns are vectorized + for column_name, column in columns.items(): + prop = properties[column_name] + if client._is_collection_vectorized(class_name): + assert prop["moduleConfig"][vectorizer_name]["skip"] == ( + not column.get(VECTORIZE_HINT, False) + ) + # tokenization + if TOKENIZATION_HINT in column: + assert prop["tokenization"] == column[TOKENIZATION_HINT] # type: ignore[literal-required] + + # if there's a single vectorize hint, class must have vectorizer enabled + if get_columns_names_with_prop( + pipeline.default_schema.get_table(class_name), VECTORIZE_HINT + ): + assert schema["vectorizer"] == vectorizer_name + else: + assert schema["vectorizer"] == "none" + + # response = db_client.query.get(class_name, list(properties.keys())).do() + response = client.query_class(class_name, list(properties.keys())).do() + objects = response["data"]["Get"][client.make_qualified_class_name(class_name)] + + if expected_items_count is not None: + assert expected_items_count == len(objects) + + if items is None: + return + + # TODO: Remove this once we have a better way comparing the data + drop_keys = ["_dlt_id", "_dlt_load_id"] + objects_without_dlt_keys = [ + {k: v for k, v in obj.items() if k not in drop_keys} for obj in objects + ] + + # pytest compares content wise but ignores order of elements of dict + # assert sorted(objects_without_dlt_keys, key=lambda d: d['doc_id']) == sorted(data, key=lambda d: d['doc_id']) + assert_unordered_list_equal(objects_without_dlt_keys, items) def delete_classes(p, class_list): @@ -87,10 +91,9 @@ def schema_has_classes(client): if Container()[PipelineContext].is_active(): # take existing pipeline p = dlt.pipeline() - client = p.destination_client() - - if schema_has_classes(client): - client.drop_storage() + with p.destination_client() as client: + if schema_has_classes(client): + client.drop_storage() p._wipe_working_folder() # deactivate context diff --git a/tests/normalize/test_max_nesting.py b/tests/normalize/test_max_nesting.py index 4015836232..5def1617dc 100644 --- a/tests/normalize/test_max_nesting.py +++ b/tests/normalize/test_max_nesting.py @@ -62,7 +62,7 @@ def bot_events(): pipeline = dlt.pipeline( pipeline_name=pipeline_name, destination=dummy(timeout=0.1), - full_refresh=True, + dev_mode=True, ) pipeline.run(bot_events) @@ -169,7 +169,7 @@ def some_data(): pipeline = dlt.pipeline( pipeline_name=pipeline_name, destination=dummy(timeout=0.1), - full_refresh=True, + dev_mode=True, ) pipeline.run(some_data(), write_disposition="append") diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 3891c667c3..7463184be7 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -16,6 +16,7 @@ from dlt.extract.extract import ExtractStorage from dlt.normalize import Normalize +from dlt.normalize.worker import group_worker_files from dlt.normalize.exceptions import NormalizeJobFailed from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES @@ -510,28 +511,28 @@ def test_collect_metrics_on_exception(raw_normalize: Normalize) -> None: def test_group_worker_files() -> None: files = ["f%03d" % idx for idx in range(0, 100)] - assert Normalize.group_worker_files([], 4) == [] - assert Normalize.group_worker_files(["f001"], 1) == [["f001"]] - assert Normalize.group_worker_files(["f001"], 100) == [["f001"]] - assert Normalize.group_worker_files(files[:4], 4) == [["f000"], ["f001"], ["f002"], ["f003"]] - assert Normalize.group_worker_files(files[:5], 4) == [ + assert group_worker_files([], 4) == [] + assert group_worker_files(["f001"], 1) == [["f001"]] + assert group_worker_files(["f001"], 100) == [["f001"]] + assert group_worker_files(files[:4], 4) == [["f000"], ["f001"], ["f002"], ["f003"]] + assert group_worker_files(files[:5], 4) == [ ["f000"], ["f001"], ["f002"], ["f003", "f004"], ] - assert Normalize.group_worker_files(files[:8], 4) == [ + assert group_worker_files(files[:8], 4) == [ ["f000", "f001"], ["f002", "f003"], ["f004", "f005"], ["f006", "f007"], ] - assert Normalize.group_worker_files(files[:8], 3) == [ + assert group_worker_files(files[:8], 3) == [ ["f000", "f001"], ["f002", "f003", "f006"], ["f004", "f005", "f007"], ] - assert Normalize.group_worker_files(files[:5], 3) == [ + assert group_worker_files(files[:5], 3) == [ ["f000"], ["f001", "f003"], ["f002", "f004"], @@ -539,7 +540,7 @@ def test_group_worker_files() -> None: # check if sorted files = ["tab1.1", "chd.3", "tab1.2", "chd.4", "tab1.3"] - assert Normalize.group_worker_files(files, 3) == [ + assert group_worker_files(files, 3) == [ ["chd.3"], ["chd.4", "tab1.2"], ["tab1.1", "tab1.3"], @@ -730,19 +731,22 @@ def test_removal_of_normalizer_schema_section_and_add_seen_data(raw_normalize: N extracted_schema.tables["event__random_table"] = new_table("event__random_table") # add x-normalizer info (and other block to control) - extracted_schema.tables["event"]["x-normalizer"] = {"evolve-columns-once": True} # type: ignore + extracted_schema.tables["event"]["x-normalizer"] = {"evolve-columns-once": True} extracted_schema.tables["event"]["x-other-info"] = "blah" # type: ignore - extracted_schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] = {"seen-data": True, "random-entry": 1234} # type: ignore - extracted_schema.tables["event__random_table"]["x-normalizer"] = {"evolve-columns-once": True} # type: ignore + extracted_schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] = { + "seen-data": True, + "random-entry": 1234, + } + extracted_schema.tables["event__random_table"]["x-normalizer"] = {"evolve-columns-once": True} normalize_pending(raw_normalize, extracted_schema) schema = raw_normalize.schema_storage.load_schema("event") # seen data gets added, schema settings get removed - assert schema.tables["event"]["x-normalizer"] == {"seen-data": True} # type: ignore - assert schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] == { # type: ignore + assert schema.tables["event"]["x-normalizer"] == {"seen-data": True} + assert schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] == { "seen-data": True, "random-entry": 1234, } # no data seen here, so seen-data is not set and evolve settings stays until first data is seen - assert schema.tables["event__random_table"]["x-normalizer"] == {"evolve-columns-once": True} # type: ignore + assert schema.tables["event__random_table"]["x-normalizer"] == {"evolve-columns-once": True} assert "x-other-info" in schema.tables["event"] diff --git a/tests/normalize/utils.py b/tests/normalize/utils.py index 0ce099d4b6..dffb3f1bb6 100644 --- a/tests/normalize/utils.py +++ b/tests/normalize/utils.py @@ -1,15 +1,10 @@ -from typing import Mapping, cast +from dlt.destinations import duckdb, redshift, postgres, bigquery, filesystem -from dlt.destinations.impl.duckdb import capabilities as duck_insert_caps -from dlt.destinations.impl.redshift import capabilities as rd_insert_caps -from dlt.destinations.impl.postgres import capabilities as pg_insert_caps -from dlt.destinations.impl.bigquery import capabilities as jsonl_caps -from dlt.destinations.impl.filesystem import capabilities as filesystem_caps - -DEFAULT_CAPS = pg_insert_caps -INSERT_CAPS = [duck_insert_caps, rd_insert_caps, pg_insert_caps] -JSONL_CAPS = [jsonl_caps, filesystem_caps] +# callables to capabilities +DEFAULT_CAPS = postgres().capabilities +INSERT_CAPS = [duckdb().capabilities, redshift().capabilities, DEFAULT_CAPS] +JSONL_CAPS = [bigquery().capabilities, filesystem().capabilities] ALL_CAPABILITIES = INSERT_CAPS + JSONL_CAPS diff --git a/tests/pipeline/cases/github_pipeline/github_pipeline.py b/tests/pipeline/cases/github_pipeline/github_pipeline.py index aa0f6d0e0e..f4cdc2bcf2 100644 --- a/tests/pipeline/cases/github_pipeline/github_pipeline.py +++ b/tests/pipeline/cases/github_pipeline/github_pipeline.py @@ -33,11 +33,21 @@ def load_issues( if __name__ == "__main__": - p = dlt.pipeline("dlt_github_pipeline", destination="duckdb", dataset_name="github_3") + # pick the destination name + if len(sys.argv) < 1: + raise RuntimeError(f"Please provide destination name in args ({sys.argv})") + dest_ = sys.argv[1] + if dest_ == "filesystem": + import os + from dlt.destinations import filesystem + + dest_ = filesystem(os.path.abspath(os.path.join("_storage", "data"))) # type: ignore + + p = dlt.pipeline("dlt_github_pipeline", destination=dest_, dataset_name="github_3") github_source = github() - if len(sys.argv) > 1: + if len(sys.argv) > 2: # load only N issues - limit = int(sys.argv[1]) + limit = int(sys.argv[2]) github_source.add_limit(limit) info = p.run(github_source) print(info) diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 0c03a8209d..4cdccb1e34 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -9,7 +9,11 @@ import dlt from dlt.common import json, Decimal from dlt.common.utils import uniq_id -from dlt.common.libs.pyarrow import NameNormalizationClash, remove_columns, normalize_py_arrow_item +from dlt.common.libs.pyarrow import ( + NameNormalizationCollision, + remove_columns, + normalize_py_arrow_item, +) from dlt.pipeline.exceptions import PipelineStepFailed @@ -17,8 +21,8 @@ arrow_table_all_data_types, prepare_shuffled_tables, ) +from tests.pipeline.utils import assert_only_table_columns, load_tables_to_dicts from tests.utils import ( - preserve_environ, TPythonTableFormat, arrow_item_from_pandas, arrow_item_from_table, @@ -223,7 +227,7 @@ def data_frames(): with pytest.raises(PipelineStepFailed) as py_ex: pipeline.extract(data_frames()) - assert isinstance(py_ex.value.__context__, NameNormalizationClash) + assert isinstance(py_ex.value.__context__, NameNormalizationCollision) @pytest.mark.parametrize("item_type", ["arrow-table", "arrow-batch"]) @@ -507,6 +511,48 @@ def test_empty_arrow(item_type: TPythonTableFormat) -> None: assert norm_info.row_counts["items"] == 0 +def test_import_file_with_arrow_schema() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + + # Define the schema based on the CSV input + schema = pa.schema( + [ + ("id", pa.int64()), + ("name", pa.string()), + ("description", pa.string()), + ("ordered_at", pa.date32()), + ("price", pa.float64()), + ] + ) + + # Create empty arrays for each field + empty_arrays = [ + pa.array([], type=pa.int64()), + pa.array([], type=pa.string()), + pa.array([], type=pa.string()), + pa.array([], type=pa.date32()), + pa.array([], type=pa.float64()), + ] + + # Create an empty table with the defined schema + empty_table = pa.Table.from_arrays(empty_arrays, schema=schema) + + # columns should be created from empty table + import_file = "tests/load/cases/loading/header.jsonl" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2, hints=empty_table)], + table_name="no_header", + ) + info.raise_on_failed_jobs() + assert_only_table_columns(pipeline, "no_header", schema.names) + rows = load_tables_to_dicts(pipeline, "no_header") + assert len(rows["no_header"]) == 2 + + @pytest.mark.parametrize("item_type", ["pandas", "arrow-table", "arrow-batch"]) def test_extract_adds_dlt_load_id(item_type: TPythonTableFormat) -> None: os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index ccf926cc62..319055184a 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -1,11 +1,14 @@ import sys +from subprocess import CalledProcessError import pytest import tempfile import shutil +from unittest.mock import patch from importlib.metadata import version as pkg_version import dlt from dlt.common import json, pendulum +from dlt.common.known_env import DLT_DATA_DIR from dlt.common.json import custom_pua_decode from dlt.common.runners import Venv from dlt.common.storages.exceptions import StorageMigrationError @@ -14,17 +17,59 @@ from dlt.common.storages import FileStorage from dlt.common.schema.typing import ( LOADS_TABLE_NAME, - STATE_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, VERSION_TABLE_NAME, TStoredSchema, ) from dlt.common.configuration.resolve import resolve_configuration +from dlt.destinations import duckdb, filesystem from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient +from tests.pipeline.utils import airtable_emojis, load_table_counts from tests.utils import TEST_STORAGE_ROOT, test_storage -if sys.version_info > (3, 11): + +def test_simulate_default_naming_convention_change() -> None: + # checks that (future) change in the naming convention won't affect existing pipelines + pipeline = dlt.pipeline("simulated_snake_case", destination="duckdb") + assert pipeline.naming.name() == "snake_case" + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + info.raise_on_failed_jobs() + # normalized names + assert pipeline.last_trace.last_normalize_info.row_counts["_schedule"] == 3 + assert "_schedule" in pipeline.default_schema.tables + + # mock the mod + # from dlt.common.normalizers import utils + + with patch("dlt.common.normalizers.utils.DEFAULT_NAMING_MODULE", "duck_case"): + duck_pipeline = dlt.pipeline("simulated_duck_case", destination="duckdb") + assert duck_pipeline.naming.name() == "duck_case" + print(airtable_emojis().schema.naming.name()) + + # run new and old pipelines + info = duck_pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + info.raise_on_failed_jobs() + print(duck_pipeline.last_trace.last_normalize_info.row_counts) + assert duck_pipeline.last_trace.last_normalize_info.row_counts["📆 Schedule"] == 3 + assert "📆 Schedule" in duck_pipeline.default_schema.tables + + # old pipeline should keep its naming convention + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + info.raise_on_failed_jobs() + # normalized names + assert pipeline.last_trace.last_normalize_info.row_counts["_schedule"] == 3 + assert pipeline.naming.name() == "snake_case" + + +if sys.version_info >= (3, 12): pytest.skip("Does not run on Python 3.12 and later", allow_module_level=True) @@ -38,7 +83,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -50,7 +95,9 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # load 20 issues print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_pipeline.py", "20" + "../tests/pipeline/cases/github_pipeline/github_pipeline.py", + "duckdb", + "20", ) ) # load schema and check _dlt_loads definition @@ -66,20 +113,26 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: ) # check the dlt state table assert { - "version_hash" not in github_schema["tables"][STATE_TABLE_NAME]["columns"] + "version_hash" + not in github_schema["tables"][PIPELINE_STATE_TABLE_NAME]["columns"] } # check loads table without attaching to pipeline duckdb_cfg = resolve_configuration( DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) - with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, + "%s_staging", + duckdb_cfg.credentials, + duckdb().capabilities(), + ) as client: rows = client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME}") # make sure we have just 4 columns assert len(rows[0]) == 4 rows = client.execute_sql("SELECT * FROM issues") assert len(rows) == 20 - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME}") + rows = client.execute_sql(f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME}") # only 5 columns + 2 dlt columns assert len(rows[0]) == 5 + 2 # inspect old state @@ -99,7 +152,16 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in current version venv = Venv.restore_current() # load all issues - print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_pipeline.py")) + try: + print( + venv.run_script( + "../tests/pipeline/cases/github_pipeline/github_pipeline.py", "duckdb" + ) + ) + except CalledProcessError as cpe: + print(f"script stdout: {cpe.stdout}") + print(f"script stderr: {cpe.stderr}") + raise # hash hash in schema github_schema = json.loads( test_storage.load( @@ -108,13 +170,16 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: ) assert github_schema["engine_version"] == 9 assert "schema_version_hash" in github_schema["tables"][LOADS_TABLE_NAME]["columns"] + # print(github_schema["tables"][PIPELINE_STATE_TABLE_NAME]) # load state state_dict = json.loads( test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json") ) assert "_version_hash" in state_dict - with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, "%s_staging", duckdb_cfg.credentials, duckdb().capabilities() + ) as client: rows = client.execute_sql( f"SELECT * FROM {LOADS_TABLE_NAME} ORDER BY inserted_at" ) @@ -131,7 +196,9 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # two schema versions rows = client.execute_sql(f"SELECT * FROM {VERSION_TABLE_NAME}") assert len(rows) == 2 - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME} ORDER BY version") + rows = client.execute_sql( + f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME} ORDER BY version" + ) # we have hash columns assert len(rows[0]) == 6 + 2 assert len(rows) == 2 @@ -141,23 +208,82 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: assert rows[1][7] == state_dict["_version_hash"] # attach to existing pipeline - pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) - created_at_value = pipeline.state["sources"]["github"]["resources"]["load_issues"][ - "incremental" - ]["created_at"]["last_value"] - assert isinstance(created_at_value, pendulum.DateTime) - assert created_at_value == pendulum.parse("2023-02-17T09:52:12Z") - pipeline = pipeline.drop() - # print(pipeline.working_dir) - assert pipeline.dataset_name == GITHUB_DATASET - assert pipeline.default_schema_name is None - # sync from destination - pipeline.sync_destination() - # print(pipeline.working_dir) - # we have updated schema - assert pipeline.default_schema.ENGINE_VERSION == 9 - # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped - assert pipeline.default_schema.stored_version_hash == github_schema["version_hash"] + pipeline = dlt.attach( + GITHUB_PIPELINE_NAME, destination=duckdb(credentials=duckdb_cfg.credentials) + ) + assert_github_pipeline_end_state(pipeline, github_schema, 2) + + +def test_filesystem_pipeline_with_dlt_update(test_storage: FileStorage) -> None: + shutil.copytree("tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + + # execute in test storage + with set_working_dir(TEST_STORAGE_ROOT): + # store dlt data in test storage (like patch_home_dir) + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + # create virtual env with (0.4.9) where filesystem started to store state + with Venv.create(tempfile.mkdtemp(), ["dlt==0.4.9"]) as venv: + try: + print(venv.run_script("github_pipeline.py", "filesystem", "20")) + except CalledProcessError as cpe: + print(f"script stdout: {cpe.stdout}") + print(f"script stderr: {cpe.stderr}") + raise + # load all issues + venv = Venv.restore_current() + try: + print(venv.run_script("github_pipeline.py", "filesystem")) + except CalledProcessError as cpe: + print(f"script stdout: {cpe.stdout}") + print(f"script stderr: {cpe.stderr}") + raise + # hash hash in schema + github_schema = json.loads( + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" + ) + ) + # attach to existing pipeline + pipeline = dlt.attach(GITHUB_PIPELINE_NAME, destination=filesystem("_storage/data")) + # assert end state + assert_github_pipeline_end_state(pipeline, github_schema, 2) + # load new state + fs_client = pipeline._fs_client() + state_files = sorted(fs_client.list_table_files("_dlt_pipeline_state")) + # first file is in old format + state_1 = json.loads(fs_client.read_text(state_files[0], encoding="utf-8")) + assert "dlt_load_id" in state_1 + # seconds is new + state_2 = json.loads(fs_client.read_text(state_files[1], encoding="utf-8")) + assert "_dlt_load_id" in state_2 + + +def assert_github_pipeline_end_state( + pipeline: dlt.Pipeline, orig_schema: TStoredSchema, schema_updates: int +) -> None: + # get tables counts + table_counts = load_table_counts(pipeline, *pipeline.default_schema.data_table_names()) + assert table_counts == {"issues": 100, "issues__assignees": 31, "issues__labels": 34} + dlt_counts = load_table_counts(pipeline, *pipeline.default_schema.dlt_table_names()) + assert dlt_counts == {"_dlt_version": schema_updates, "_dlt_loads": 2, "_dlt_pipeline_state": 2} + + # check state + created_at_value = pipeline.state["sources"]["github"]["resources"]["load_issues"][ + "incremental" + ]["created_at"]["last_value"] + assert isinstance(created_at_value, pendulum.DateTime) + assert created_at_value == pendulum.parse("2023-02-17T09:52:12Z") + pipeline = pipeline.drop() + # print(pipeline.working_dir) + assert pipeline.dataset_name == GITHUB_DATASET + assert pipeline.default_schema_name is None + # sync from destination + pipeline.sync_destination() + # print(pipeline.working_dir) + # we have updated schema + assert pipeline.default_schema.ENGINE_VERSION == 9 + # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped + assert pipeline.default_schema.stored_version_hash == orig_schema["version_hash"] def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: @@ -166,7 +292,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -182,7 +308,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_normalize.py", + "../tests/pipeline/cases/github_pipeline/github_normalize.py" ) ) # switch to current version and make sure the load package loads and schema migrates @@ -192,7 +318,9 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) - with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, "%s_staging", duckdb_cfg.credentials, duckdb().capabilities() + ) as client: rows = client.execute_sql("SELECT * FROM issues") assert len(rows) == 70 github_schema = json.loads( @@ -201,7 +329,9 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) ) # attach to existing pipeline - pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) + pipeline = dlt.attach( + GITHUB_PIPELINE_NAME, destination=duckdb(credentials=duckdb_cfg.credentials) + ) # get the schema from schema storage before we sync github_schema = json.loads( test_storage.load( @@ -217,7 +347,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: assert pipeline.state["_version_hash"] is not None # but in db there's no hash - we loaded an old package with backward compatible schema with pipeline.sql_client() as client: - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME}") + rows = client.execute_sql(f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME}") # no hash assert len(rows[0]) == 5 + 2 assert len(rows) == 1 @@ -227,7 +357,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: # this will sync schema to destination pipeline.sync_schema() # we have hash now - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME}") + rows = client.execute_sql(f"SELECT * FROM {PIPELINE_STATE_TABLE_NAME}") assert len(rows[0]) == 6 + 2 @@ -237,7 +367,7 @@ def test_normalize_package_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} diff --git a/tests/pipeline/test_drop_helpers.py b/tests/pipeline/test_drop_helpers.py new file mode 100644 index 0000000000..9a09d9f866 --- /dev/null +++ b/tests/pipeline/test_drop_helpers.py @@ -0,0 +1,209 @@ +import pytest +from copy import deepcopy + +import dlt +from dlt.common.schema.typing import LOADS_TABLE_NAME, PIPELINE_STATE_TABLE_NAME, VERSION_TABLE_NAME +from dlt.common.versioned_state import decompress_state +from dlt.pipeline.drop import drop_resources +from dlt.pipeline.helpers import DropCommand, refresh_source + +from tests.pipeline.utils import airtable_emojis, assert_load_info + + +@pytest.mark.parametrize("seen_data", [True, False], ids=["seen_data", "no_data"]) +def test_drop_helper_utils(seen_data: bool) -> None: + pipeline = dlt.pipeline("test_drop_helpers_no_table_drop", destination="duckdb") + # extract first which should produce tables that didn't seen data + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + if seen_data: + pipeline.run(source) + else: + pipeline.extract(source) + + # drop nothing + drop_info = drop_resources(pipeline.default_schema.clone(), pipeline.state) + assert drop_info.modified_tables == [] + assert drop_info.info["tables"] == [] + + # drop all resources + drop_info = drop_resources(pipeline.default_schema.clone(), pipeline.state, drop_all=True) + # no tables to drop + tables_to_drop = ( + {"_schedule", "_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else set() + ) + tables_to_drop_schema = ( + tables_to_drop if seen_data else {"_schedule", "_peacock", "_wide_peacock"} + ) + assert {t["name"] for t in drop_info.modified_tables} == tables_to_drop + # no state mods + assert drop_info.state["sources"]["airtable_emojis"] == {"resources": {}} + assert set(drop_info.info["tables"]) == tables_to_drop_schema + assert set(drop_info.info["tables_with_data"]) == tables_to_drop + # all tables got dropped + assert drop_info.schema.data_tables(include_incomplete=True) == [] + # dlt tables still there + assert set(drop_info.schema.dlt_table_names()) == { + VERSION_TABLE_NAME, + LOADS_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, + } + # same but with refresh + source_clone = source.clone() + source_clone.schema = pipeline.default_schema.clone() + with pipeline.managed_state() as state: + emoji_state = deepcopy(state["sources"]["airtable_emojis"]) + package_state = refresh_source(pipeline, source_clone, refresh="drop_sources") + # managed state modified + assert state["sources"]["airtable_emojis"] == {"resources": {}} + # restore old state for next tests + state["sources"]["airtable_emojis"] = emoji_state + if seen_data: + assert {t["name"] for t in package_state["dropped_tables"]} == tables_to_drop + else: + assert package_state == {} + assert source_clone.schema.data_tables(include_incomplete=True) == [] + + # drop only selected resources + tables_to_drop = {"_schedule"} if seen_data else set() + # seen_data means full run so we generate child tables in that case + left_in_schema = ( + {"_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else {"_peacock", "_wide_peacock"} + ) + drop_info = drop_resources( + pipeline.default_schema.clone(), pipeline.state, resources=["📆 Schedule"] + ) + assert set(t["name"] for t in drop_info.modified_tables) == tables_to_drop + # no changes in state + assert drop_info.state == pipeline.state + assert set(drop_info.info["tables"]) == {"_schedule"} + assert set(drop_info.schema.data_table_names(include_incomplete=True)) == left_in_schema + source_clone = source_clone.with_resources("📆 Schedule") + source_clone.schema = pipeline.default_schema.clone() + with pipeline.managed_state() as state: + package_state = refresh_source(pipeline, source_clone, refresh="drop_resources") + # state not modified + assert state["sources"]["airtable_emojis"] == {"resources": {"🦚Peacock": {"🦚🦚🦚": "🦚"}}} + if seen_data: + assert {t["name"] for t in package_state["dropped_tables"]} == tables_to_drop + else: + assert package_state == {} + assert set(source_clone.schema.data_table_names(include_incomplete=True)) == left_in_schema + + # truncate only + tables_to_truncate = ( + {"_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else set() + ) + all_in_schema = ( + {"_schedule", "_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else {"_schedule", "_peacock", "_wide_peacock"} + ) + drop_info = drop_resources( + pipeline.default_schema.clone(), + pipeline.state, + resources=["🦚Peacock", "🦚WidePeacock"], + state_only=True, + ) + assert set(t["name"] for t in drop_info.modified_tables) == tables_to_truncate + # state is modified + assert drop_info.state["sources"]["airtable_emojis"] == {"resources": {}} + assert drop_info.info["tables"] == [] + # no tables with data will be dropped + assert drop_info.info["tables_with_data"] == [] + assert set(drop_info.schema.data_table_names(include_incomplete=True)) == all_in_schema + source_clone = source_clone.with_resources("🦚Peacock", "🦚WidePeacock") + source_clone.schema = pipeline.default_schema.clone() + with pipeline.managed_state() as state: + package_state = refresh_source(pipeline, source_clone, refresh="drop_data") + # state modified + assert state["sources"]["airtable_emojis"] == {"resources": {}} + if seen_data: + assert {t["name"] for t in package_state["truncated_tables"]} == tables_to_truncate + else: + assert package_state == {} + assert set(source_clone.schema.data_table_names(include_incomplete=True)) == all_in_schema + + +def test_drop_unknown_resource() -> None: + pipeline = dlt.pipeline("test_drop_unknown_resource", destination="duckdb") + # extract first which should produce tables that didn't seen data + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + info = pipeline.run(source) + assert_load_info(info) + drop = DropCommand(pipeline, resources=["💰Budget"]) + assert drop.is_empty + + source.schema = pipeline.default_schema + package_state = refresh_source( + pipeline, source.with_resources("💰Budget"), refresh="drop_resources" + ) + assert package_state == {} + + info = pipeline.run(source.with_resources("💰Budget"), refresh="drop_resources") + # nothing loaded + assert_load_info(info, 0) + + +def test_modified_state_in_package() -> None: + pipeline = dlt.pipeline("test_modified_state_in_package", destination="duckdb") + # extract first which should produce tables that didn't seen data + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + pipeline.extract(source) + # run again to change peacock state again + info = pipeline.extract(source) + normalize_storage = pipeline._get_normalize_storage() + package_state = normalize_storage.extracted_packages.get_load_package_state(info.loads_ids[0]) + pipeline_state = decompress_state(package_state["pipeline_state"]["state"]) + assert pipeline_state["sources"]["airtable_emojis"] == { + "resources": {"🦚Peacock": {"🦚🦚🦚": "🦚🦚"}} + } + + # remove state + info = pipeline.extract(airtable_emojis().with_resources("🦚Peacock"), refresh="drop_resources") + normalize_storage = pipeline._get_normalize_storage() + package_state = normalize_storage.extracted_packages.get_load_package_state(info.loads_ids[0]) + # nothing to drop + assert "dropped_tables" not in package_state + pipeline_state = decompress_state(package_state["pipeline_state"]["state"]) + # the state was reset to the original + assert pipeline_state["sources"]["airtable_emojis"] == { + "resources": {"🦚Peacock": {"🦚🦚🦚": "🦚"}} + } + + +def test_drop_tables_force_extract_state() -> None: + # if any tables will be dropped, state must be extracted even if it is not changed + pipeline = dlt.pipeline("test_drop_tables_force_extract_state", destination="duckdb") + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + info = pipeline.run(source) + assert_load_info(info) + # dropping schedule should not change the state + info = pipeline.run(airtable_emojis().with_resources("📆 Schedule"), refresh="drop_resources") + assert_load_info(info) + storage = pipeline._get_load_storage() + package_state = storage.get_load_package_state(info.loads_ids[0]) + assert package_state["dropped_tables"][0]["name"] == "_schedule" + assert "pipeline_state" in package_state + + # here we drop and set state to original, so without forcing state extract state would not be present + info = pipeline.run(airtable_emojis().with_resources("🦚Peacock"), refresh="drop_resources") + assert_load_info(info) + storage = pipeline._get_load_storage() + package_state = storage.get_load_package_state(info.loads_ids[0]) + # child table also dropped + assert len(package_state["dropped_tables"]) == 2 + assert "pipeline_state" in package_state diff --git a/tests/pipeline/test_import_export_schema.py b/tests/pipeline/test_import_export_schema.py index 6f40e1d1eb..eb36d36ba3 100644 --- a/tests/pipeline/test_import_export_schema.py +++ b/tests/pipeline/test_import_export_schema.py @@ -117,7 +117,7 @@ def test_import_schema_is_respected() -> None: destination=dummy(completed_prob=1), import_schema_path=IMPORT_SCHEMA_PATH, export_schema_path=EXPORT_SCHEMA_PATH, - full_refresh=True, + dev_mode=True, ) p.extract(EXAMPLE_DATA, table_name="person") # starts with import schema v 1 that is dirty -> 2 @@ -153,7 +153,7 @@ def resource(): destination=dummy(completed_prob=1), import_schema_path=IMPORT_SCHEMA_PATH, export_schema_path=EXPORT_SCHEMA_PATH, - full_refresh=True, + dev_mode=True, ) p.run(source()) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index f838f31333..7c7dac8e71 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -7,7 +7,7 @@ import random import threading from time import sleep -from typing import Any, Tuple, cast +from typing import Any, List, Tuple, cast from tenacity import retry_if_exception, Retrying, stop_after_attempt import pytest @@ -15,10 +15,8 @@ import dlt from dlt.common import json, pendulum from dlt.common.configuration.container import Container -from dlt.common.configuration.exceptions import ConfigFieldMissingException -from dlt.common.configuration.specs.aws_credentials import AwsCredentials -from dlt.common.configuration.specs.exceptions import NativeValueError -from dlt.common.configuration.specs.gcp_credentials import GcpOAuthCredentials +from dlt.common.configuration.exceptions import ConfigFieldMissingException, InvalidNativeValue +from dlt.common.data_writers.exceptions import FileImportNotFound, SpecLookupFailed from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import WithStateSync from dlt.common.destination.exceptions import ( @@ -32,6 +30,8 @@ from dlt.common.exceptions import PipelineStateNotAvailable from dlt.common.pipeline import LoadInfo, PipelineContext from dlt.common.runtime.collector import LogCollector +from dlt.common.schema.exceptions import TableIdentifiersFrozen +from dlt.common.schema.typing import TColumnSchema from dlt.common.schema.utils import new_column, new_table from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id @@ -44,9 +44,11 @@ from dlt.extract import DltResource, DltSource from dlt.extract.extractors import MaterializedEmptyList from dlt.load.exceptions import LoadClientJobFailed +from dlt.normalize.exceptions import NormalizeJobFailed from dlt.pipeline.exceptions import InvalidPipelineName, PipelineNotActive, PipelineStepFailed from dlt.pipeline.helpers import retry_load +from dlt.pipeline.pipeline import Pipeline from tests.common.utils import TEST_SENTRY_DSN from tests.common.configuration.utils import environment from tests.utils import TEST_STORAGE_ROOT, skipifnotwindows @@ -55,7 +57,9 @@ assert_data_table_counts, assert_load_info, airtable_emojis, + assert_only_table_columns, load_data_table_counts, + load_tables_to_dicts, many_delayed, ) @@ -83,6 +87,36 @@ def test_default_pipeline() -> None: assert p.default_schema_name in ["dlt_pytest", "dlt"] +def test_default_pipeline_dataset_layout(environment) -> None: + # Set dataset_name_layout to "bobby_%s" + dataset_name_layout = "bobby_%s" + environment["DATASET_NAME_LAYOUT"] = dataset_name_layout + + p = dlt.pipeline() + # this is a name of executing test harness or blank pipeline on windows + possible_names = ["dlt_pytest", "dlt_pipeline"] + possible_dataset_names = [ + dataset_name_layout % "dlt_pytest_dataset", + dataset_name_layout % "dlt_pipeline_dataset", + ] + assert p.pipeline_name in possible_names + assert p.pipelines_dir == os.path.abspath(os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines")) + assert p.runtime_config.pipeline_name == p.pipeline_name + # dataset that will be used to load data is the pipeline name + assert p.dataset_name in possible_dataset_names + assert p.destination is None + assert p.default_schema_name is None + + # this is the same pipeline + p2 = dlt.pipeline() + assert p is p2 + + # this will create default schema + p.extract(["a", "b", "c"], table_name="data") + # `_pipeline` is removed from default schema name + assert p.default_schema_name in ["dlt_pytest", "dlt"] + + def test_default_pipeline_dataset() -> None: # dummy does not need a dataset p = dlt.pipeline(destination="dummy") @@ -94,6 +128,40 @@ def test_default_pipeline_dataset() -> None: assert p.dataset_name in possible_dataset_names +def test_default_pipeline_dataset_name(environment) -> None: + environment["DATASET_NAME"] = "dataset" + environment["DATASET_NAME_LAYOUT"] = "prefix_%s" + + p = dlt.pipeline(destination="filesystem") + assert p.dataset_name == "prefix_dataset" + + +def test_default_pipeline_dataset_layout_exception(environment) -> None: + # Set dataset_name_layout without placeholder %s + environment["DATASET_NAME_LAYOUT"] = "bobby_" + + with pytest.raises(ValueError): + dlt.pipeline(destination="filesystem") + + +def test_default_pipeline_dataset_layout_placeholder(environment) -> None: + # Set dataset_name_layout only with placeholder + environment["DATASET_NAME_LAYOUT"] = "%s" + + possible_dataset_names = ["dlt_pytest_dataset", "dlt_pipeline_dataset"] + p = dlt.pipeline(destination="filesystem") + assert p.dataset_name in possible_dataset_names + + +def test_default_pipeline_dataset_layout_empty(environment) -> None: + # Set dataset_name_layout empty + environment["DATASET_NAME_LAYOUT"] = "" + + possible_dataset_names = ["dlt_pytest_dataset", "dlt_pipeline_dataset"] + p = dlt.pipeline(destination="filesystem") + assert p.dataset_name in possible_dataset_names + + def test_run_dev_mode_default_dataset() -> None: p = dlt.pipeline(dev_mode=True, destination="filesystem") assert p.dataset_name.endswith(p._pipeline_instance_id) @@ -112,6 +180,39 @@ def test_run_dev_mode_default_dataset() -> None: assert p.dataset_name and p.dataset_name.endswith(p._pipeline_instance_id) +def test_run_dev_mode_default_dataset_layout(environment) -> None: + # Set dataset_name_layout to "bobby_%s" + dataset_name_layout = "bobby_%s" + environment["DATASET_NAME_LAYOUT"] = dataset_name_layout + + p = dlt.pipeline(dev_mode=True, destination="filesystem") + assert p.dataset_name in [ + dataset_name_layout % f"dlt_pytest_dataset{p._pipeline_instance_id}", + dataset_name_layout % f"dlt_pipeline_dataset{p._pipeline_instance_id}", + ] + # restore this pipeline + r_p = dlt.attach(dev_mode=False) + assert r_p.dataset_name in [ + dataset_name_layout % f"dlt_pytest_dataset{p._pipeline_instance_id}", + dataset_name_layout % f"dlt_pipeline_dataset{p._pipeline_instance_id}", + ] + + # dummy does not need dataset + p = dlt.pipeline(dev_mode=True, destination="dummy") + assert p.dataset_name is None + + # simulate set new dataset + p._set_destinations("filesystem") + assert p.dataset_name is None + p._set_dataset_name(None) + + # full refresh is still observed + assert p.dataset_name in [ + dataset_name_layout % f"dlt_pytest_dataset{p._pipeline_instance_id}", + dataset_name_layout % f"dlt_pipeline_dataset{p._pipeline_instance_id}", + ] + + def test_run_dev_mode_underscored_dataset() -> None: p = dlt.pipeline(dev_mode=True, dataset_name="_main_") assert p.dataset_name.endswith(p._pipeline_instance_id) @@ -171,6 +272,16 @@ def test_invalid_dataset_name() -> None: assert p.dataset_name == "!" +def test_invalid_dataset_layout(environment) -> None: + # Set dataset_name_prefix to "bobby" + dataset_name_layout = "bobby_%s" + environment["DATASET_NAME_LAYOUT"] = dataset_name_layout + + # this is invalid dataset name but it will be normalized within a destination + p = dlt.pipeline(dataset_name="!") + assert p.dataset_name == dataset_name_layout % "!" + + def test_pipeline_context_deferred_activation() -> None: ctx = Container()[PipelineContext] assert ctx.is_active() is False @@ -201,7 +312,8 @@ def test_pipeline_context() -> None: assert ctx.pipeline() is p3 assert p3.is_active is True assert p2.is_active is False - assert Container()[DestinationCapabilitiesContext].naming_convention == "snake_case" + # no default naming convention + assert Container()[DestinationCapabilitiesContext].naming_convention is None # restore previous p2 = dlt.attach("another pipeline") @@ -259,49 +371,16 @@ def test_deterministic_salt(environment) -> None: def test_destination_explicit_credentials(environment: Any) -> None: + from dlt.destinations import motherduck + # test redshift p = dlt.pipeline( pipeline_name="postgres_pipeline", - destination="redshift", - credentials="redshift://loader:loader@localhost:5432/dlt_data", - ) - config = p._get_destination_client_initial_config() - assert config.credentials.is_resolved() - # with staging - p = dlt.pipeline( - pipeline_name="postgres_pipeline", - staging="filesystem", - destination="redshift", - credentials="redshift://loader:loader@localhost:5432/dlt_data", - ) - config = p._get_destination_client_initial_config(p.destination) - assert config.credentials.is_resolved() - config = p._get_destination_client_initial_config(p.staging, as_staging=True) - assert config.credentials is None - p._wipe_working_folder() - # try filesystem which uses union of credentials that requires bucket_url to resolve - p = dlt.pipeline( - pipeline_name="postgres_pipeline", - destination="filesystem", - credentials={"aws_access_key_id": "key_id", "aws_secret_access_key": "key"}, + destination=motherduck(credentials="md://user:password@/dlt_data"), ) - config = p._get_destination_client_initial_config(p.destination) - assert isinstance(config.credentials, AwsCredentials) - assert config.credentials.is_resolved() - # resolve gcp oauth - p = dlt.pipeline( - pipeline_name="postgres_pipeline", - destination="filesystem", - credentials={ - "project_id": "pxid", - "refresh_token": "123token", - "client_id": "cid", - "client_secret": "s", - }, - ) - config = p._get_destination_client_initial_config(p.destination) - assert isinstance(config.credentials, GcpOAuthCredentials) + config = p.destination_client().config assert config.credentials.is_resolved() + assert config.credentials.to_native_representation() == "md://user:password@/dlt_data" def test_destination_staging_config(environment: Any) -> None: @@ -354,14 +433,15 @@ def test_destination_credentials_in_factory(environment: Any) -> None: assert dest_config.credentials.database == "some_db" -@pytest.mark.skip(reason="does not work on CI. probably takes right credentials from somewhere....") def test_destination_explicit_invalid_credentials_filesystem(environment: Any) -> None: # if string cannot be parsed p = dlt.pipeline( - pipeline_name="postgres_pipeline", destination="filesystem", credentials="PR8BLEM" + pipeline_name="postgres_pipeline", + destination=filesystem(bucket_url="s3://test", destination_name="uniq_s3_bucket"), ) - with pytest.raises(NativeValueError): - p._get_destination_client_initial_config(p.destination) + with pytest.raises(PipelineStepFailed) as pip_ex: + p.run([1, 2, 3], table_name="data", credentials="PR8BLEM") + assert isinstance(pip_ex.value.__cause__, InvalidNativeValue) def test_extract_source_twice() -> None: @@ -1539,10 +1619,13 @@ def autodetect(): pipeline = pipeline.drop() source = autodetect() + assert "timestamp" in source.schema.settings["detections"] source.schema.remove_type_detection("timestamp") + assert "timestamp" not in source.schema.settings["detections"] pipeline = dlt.pipeline(destination="duckdb") pipeline.run(source) + assert "timestamp" not in pipeline.default_schema.settings["detections"] assert pipeline.default_schema.get_table("numbers")["columns"]["value"]["data_type"] == "bigint" @@ -1969,7 +2052,7 @@ def source(): assert len(load_info.loads_ids) == 1 -def test_pipeline_load_info_metrics_schema_is_not_chaning() -> None: +def test_pipeline_load_info_metrics_schema_is_not_changing() -> None: """Test if load info schema is idempotent throughout multiple load cycles ## Setup @@ -2025,7 +2108,6 @@ def demand_map(): pipeline_name="quick_start", destination="duckdb", dataset_name="mydata", - # export_schema_path="schemas", ) taxi_load_info = pipeline.run( @@ -2243,7 +2325,7 @@ def test_data(): pipeline = dlt.pipeline( pipeline_name="test_staging_cleared", destination="duckdb", - full_refresh=True, + dev_mode=True, ) info = pipeline.run(test_data, table_name="staging_cleared") @@ -2260,3 +2342,310 @@ def test_data(): with client.execute_query(f"SELECT * FROM {pipeline.dataset_name}.staging_cleared") as cur: assert len(cur.fetchall()) == 3 + + +def test_change_naming_convention_name_collision() -> None: + duck_ = dlt.destinations.duckdb(naming_convention="duck_case", recommended_file_size=120000) + caps = duck_.capabilities() + assert caps.naming_convention == "duck_case" + assert caps.recommended_file_size == 120000 + + # use duck case to load data into duckdb so casing and emoji are preserved + pipeline = dlt.pipeline("test_change_naming_convention_name_collision", destination=duck_) + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + assert_load_info(info) + # make sure that emojis got in + assert "🦚Peacock" in pipeline.default_schema.tables + assert "🔑id" in pipeline.default_schema.tables["🦚Peacock"]["columns"] + assert load_data_table_counts(pipeline) == { + "📆 Schedule": 3, + "🦚Peacock": 1, + "🦚WidePeacock": 1, + "🦚Peacock__peacock": 3, + "🦚WidePeacock__Peacock": 3, + } + with pipeline.sql_client() as client: + rows = client.execute_sql("SELECT 🔑id FROM 🦚Peacock") + # 🔑id value is 1 + assert rows[0][0] == 1 + + # change naming convention and run pipeline again so we generate name clashes + os.environ["SOURCES__AIRTABLE_EMOJIS__SCHEMA__NAMING"] = "sql_ci_v1" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock")) + # see conflicts early + assert pip_ex.value.step == "extract" + assert isinstance(pip_ex.value.__cause__, TableIdentifiersFrozen) + + # all good if we drop tables + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + refresh="drop_resources", + ) + assert_load_info(info) + # case insensitive normalization + assert load_data_table_counts(pipeline) == { + "_schedule": 3, + "_peacock": 1, + "_widepeacock": 1, + "_peacock__peacock": 3, + "_widepeacock__peacock": 3, + } + + +def test_change_to_more_lax_naming_convention_name_collision() -> None: + # use snake_case which is strict and then change to duck_case which accepts snake_case names without any changes + # still we want to detect collisions + pipeline = dlt.pipeline( + "test_change_to_more_lax_naming_convention_name_collision", destination="duckdb" + ) + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + assert_load_info(info) + assert "_peacock" in pipeline.default_schema.tables + + # use duck case to load data into duckdb so casing and emoji are preserved + duck_ = dlt.destinations.duckdb(naming_convention="duck_case") + + # changing destination to one with a separate naming convention raises immediately + with pytest.raises(TableIdentifiersFrozen): + pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + destination=duck_, + ) + + # refresh on the source level will work + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + destination=duck_, + refresh="drop_sources", + ) + assert_load_info(info) + # make sure that emojis got in + assert "🦚Peacock" in pipeline.default_schema.tables + + +def test_change_naming_convention_column_collision() -> None: + duck_ = dlt.destinations.duckdb(naming_convention="duck_case") + + data = {"Col": "A"} + pipeline = dlt.pipeline("test_change_naming_convention_column_collision", destination=duck_) + info = pipeline.run([data], table_name="data") + assert_load_info(info) + + os.environ["SCHEMA__NAMING"] = "sql_ci_v1" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run([data], table_name="data") + assert isinstance(pip_ex.value.__cause__, TableIdentifiersFrozen) + + +def test_import_jsonl_file() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + columns: List[TColumnSchema] = [ + {"name": "id", "data_type": "bigint", "nullable": False}, + {"name": "name", "data_type": "text"}, + {"name": "description", "data_type": "text"}, + {"name": "ordered_at", "data_type": "date"}, + {"name": "price", "data_type": "decimal"}, + ] + import_file = "tests/load/cases/loading/header.jsonl" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2)], + table_name="no_header", + loader_file_format="jsonl", + columns=columns, + ) + info.raise_on_failed_jobs() + print(info) + assert_imported_file(pipeline, "no_header", columns, 2) + + # use hints to infer + hints = dlt.mark.make_hints(columns=columns) + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2, hints=hints)], + table_name="no_header_2", + ) + info.raise_on_failed_jobs() + assert_imported_file(pipeline, "no_header_2", columns, 2, expects_state=False) + + +def test_import_file_without_sniff_schema() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + import_file = "tests/load/cases/loading/header.jsonl" + info = pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2)], + table_name="no_header", + ) + assert info.has_failed_jobs + print(info) + + +def test_import_non_existing_file() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + # this file does not exist + import_file = "tests/load/cases/loading/X_header.jsonl" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run( + [dlt.mark.with_file_import(import_file, "jsonl", 2)], + table_name="no_header", + ) + inner_ex = pip_ex.value.__cause__ + assert isinstance(inner_ex, FileImportNotFound) + assert inner_ex.import_file_path == import_file + + +def test_import_unsupported_file_format() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + # this file does not exist + import_file = "tests/load/cases/loading/csv_no_header.csv" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run( + [dlt.mark.with_file_import(import_file, "csv", 2)], + table_name="no_header", + ) + inner_ex = pip_ex.value.__cause__ + assert isinstance(inner_ex, NormalizeJobFailed) + assert isinstance(inner_ex.__cause__, SpecLookupFailed) + + +def test_import_unknown_file_format() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_jsonl_import", + destination="duckdb", + dev_mode=True, + ) + # this file does not exist + import_file = "tests/load/cases/loading/csv_no_header.csv" + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run( + [dlt.mark.with_file_import(import_file, "unknown", 2)], # type: ignore[arg-type] + table_name="no_header", + ) + inner_ex = pip_ex.value.__cause__ + assert isinstance(inner_ex, NormalizeJobFailed) + # can't figure format from extension + assert isinstance(inner_ex.__cause__, ValueError) + + +def test_static_staging_dataset() -> None: + # share database and staging dataset + duckdb_ = dlt.destinations.duckdb( + "_storage/test_static_staging_dataset.db", staging_dataset_name_layout="_dlt_staging" + ) + + pipeline_1 = dlt.pipeline("test_static_staging_dataset_1", destination=duckdb_, dev_mode=True) + pipeline_2 = dlt.pipeline("test_static_staging_dataset_2", destination=duckdb_, dev_mode=True) + # staging append (without primary key) + info = pipeline_1.run([1, 2, 3], table_name="digits", write_disposition="merge") + assert_load_info(info) + info = pipeline_2.run(["a", "b", "c", "d"], table_name="letters", write_disposition="merge") + assert_load_info(info) + with pipeline_1.sql_client() as client: + with client.with_alternative_dataset_name("_dlt_staging"): + assert client.has_dataset() + schemas = client.execute_sql("SELECT schema_name FROM _dlt_staging._dlt_version") + assert {s[0] for s in schemas} == { + "test_static_staging_dataset_1", + "test_static_staging_dataset_2", + } + + assert_data_table_counts(pipeline_1, {"digits": 3}) + assert_data_table_counts(pipeline_2, {"letters": 4}) + + +def test_underscore_tables_and_columns() -> None: + pipeline = dlt.pipeline("test_underscore_tables_and_columns", destination="duckdb") + + @dlt.resource + def ids(_id=dlt.sources.incremental("_id", initial_value=2)): + yield from [{"_id": i, "value": l} for i, l in zip([1, 2, 3], ["A", "B", "C"])] + + info = pipeline.run(ids, table_name="_ids") + assert_load_info(info) + print(pipeline.default_schema.to_pretty_yaml()) + assert pipeline.last_trace.last_normalize_info.row_counts["_ids"] == 2 + + +def test_access_pipeline_in_resource() -> None: + pipeline = dlt.pipeline("test_access_pipeline_in_resource", destination="duckdb") + + @dlt.resource(name="user_comments") + def comments(user_id: str): + current_pipeline = dlt.current.pipeline() + # find last comment id for given user_id by looking in destination + max_id: int = 0 + # on first pipeline run, user_comments table does not yet exist so do not check at all + # alternatively catch DatabaseUndefinedRelation which is raised when unknown table is selected + if not current_pipeline.first_run: + with current_pipeline.sql_client() as client: + # we may get last user comment or None which we replace with 0 + max_id = ( + client.execute_sql( + "SELECT MAX(_id) FROM user_comments WHERE user_id=?", user_id + )[0][0] + or 0 + ) + # use max_id to filter our results + yield from [ + {"_id": i, "value": l, "user_id": user_id} + for i, l in zip([1, 2, 3], ["A", "B", "C"]) + if i > max_id + ] + + info = pipeline.run(comments("USER_A")) + assert_load_info(info) + assert pipeline.last_trace.last_normalize_info.row_counts["user_comments"] == 3 + info = pipeline.run(comments("USER_A")) + # no more data for USER_A + assert_load_info(info, 0) + info = pipeline.run(comments("USER_B")) + assert_load_info(info) + assert pipeline.last_trace.last_normalize_info.row_counts["user_comments"] == 3 + + +def assert_imported_file( + pipeline: Pipeline, + table_name: str, + columns: List[TColumnSchema], + expected_rows: int, + expects_state: bool = True, +) -> None: + assert_only_table_columns(pipeline, table_name, [col["name"] for col in columns]) + rows = load_tables_to_dicts(pipeline, table_name) + assert len(rows[table_name]) == expected_rows + # we should have twp files loaded + jobs = pipeline.last_trace.last_load_info.load_packages[0].jobs["completed_jobs"] + job_extensions = [os.path.splitext(job.job_file_info.file_name())[1] for job in jobs] + assert ".jsonl" in job_extensions + if expects_state: + assert ".insert_values" in job_extensions + # check extract trace if jsonl is really there + extract_info = pipeline.last_trace.last_extract_info + jobs = extract_info.load_packages[0].jobs["new_jobs"] + # find jsonl job + jsonl_job = next(job for job in jobs if job.job_file_info.table_name == table_name) + assert jsonl_job.job_file_info.file_format == "jsonl" + # find metrics for table + assert ( + extract_info.metrics[extract_info.loads_ids[0]][0]["table_metrics"][table_name].items_count + == expected_rows + ) diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index 7208216c9f..308cdcd91d 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -40,7 +40,11 @@ class BaseModel: # type: ignore[no-redef] @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs( + default_sql_configs=True, default_vector_configs=True, local_filesystem_configs=True + ), + ids=lambda x: x.name, ) def test_create_pipeline_all_destinations(destination_config: DestinationTestConfiguration) -> None: # create pipelines, extract and normalize. that should be possible without installing any dependencies @@ -51,11 +55,11 @@ def test_create_pipeline_all_destinations(destination_config: DestinationTestCon ) # are capabilities injected caps = p._container[DestinationCapabilitiesContext] - print(caps.naming_convention) - # are right naming conventions created - assert p._default_naming.max_length == min( - caps.max_column_identifier_length, caps.max_identifier_length - ) + if caps.naming_convention: + assert p.naming.name() == caps.naming_convention + else: + assert p.naming.name() == "snake_case" + p.extract([1, "2", 3], table_name="data") # is default schema with right naming convention assert p.default_schema.naming.max_length == min( @@ -469,6 +473,61 @@ def users(): assert set(table.schema.names) == {"id", "name", "_dlt_load_id", "_dlt_id"} +def test_resource_file_format() -> None: + os.environ["RESTORE_FROM_DESTINATION"] = "False" + + def jsonl_data(): + yield [ + { + "id": 1, + "name": "item", + "description": "value", + "ordered_at": "2024-04-12", + "price": 128.4, + }, + { + "id": 1, + "name": "item", + "description": "value with space", + "ordered_at": "2024-04-12", + "price": 128.4, + }, + ] + + # preferred file format will use destination preferred format + jsonl_preferred = dlt.resource(jsonl_data, file_format="preferred", name="jsonl_preferred") + assert jsonl_preferred.compute_table_schema()["file_format"] == "preferred" + + jsonl_r = dlt.resource(jsonl_data, file_format="jsonl", name="jsonl_r") + assert jsonl_r.compute_table_schema()["file_format"] == "jsonl" + + jsonl_pq = dlt.resource(jsonl_data, file_format="parquet", name="jsonl_pq") + assert jsonl_pq.compute_table_schema()["file_format"] == "parquet" + + info = dlt.pipeline("example", destination="duckdb").run([jsonl_preferred, jsonl_r, jsonl_pq]) + info.raise_on_failed_jobs() + # check file types on load jobs + load_jobs = { + job.job_file_info.table_name: job.job_file_info + for job in info.load_packages[0].jobs["completed_jobs"] + } + assert load_jobs["jsonl_r"].file_format == "jsonl" + assert load_jobs["jsonl_pq"].file_format == "parquet" + assert load_jobs["jsonl_preferred"].file_format == "insert_values" + + # test not supported format + csv_r = dlt.resource(jsonl_data, file_format="csv", name="csv_r") + assert csv_r.compute_table_schema()["file_format"] == "csv" + info = dlt.pipeline("example", destination="duckdb").run(csv_r) + info.raise_on_failed_jobs() + # fallback to preferred + load_jobs = { + job.job_file_info.table_name: job.job_file_info + for job in info.load_packages[0].jobs["completed_jobs"] + } + assert load_jobs["csv_r"].file_format == "insert_values" + + def test_pick_matching_file_format(test_storage: FileStorage) -> None: from dlt.destinations import filesystem diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index 8cbc1ca516..11c45d72cc 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -1,20 +1,25 @@ import os import shutil +from typing_extensions import get_type_hints import pytest import dlt - +from dlt.common.pendulum import pendulum from dlt.common.exceptions import ( PipelineStateNotAvailable, ResourceNameNotAvailable, ) from dlt.common.schema import Schema +from dlt.common.schema.utils import pipeline_state_table from dlt.common.source import get_current_pipe_name from dlt.common.storages import FileStorage from dlt.common import pipeline as state_module +from dlt.common.storages.load_package import TPipelineStateDoc from dlt.common.utils import uniq_id -from dlt.common.destination.reference import Destination +from dlt.common.destination.reference import Destination, StateInfo +from dlt.common.validation import validate_dict +from dlt.destinations.utils import get_pipeline_state_query_columns from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import ( @@ -41,6 +46,56 @@ def some_data_resource_state(): dlt.current.resource_state()["last_value"] = last_value + 1 +def test_state_repr() -> None: + """Verify that all possible state representations match""" + table = pipeline_state_table() + state_doc_hints = get_type_hints(TPipelineStateDoc) + sync_class_hints = get_type_hints(StateInfo) + info = StateInfo(1, 4, "pipeline", "compressed", pendulum.now(), "hash", "_load_id") + state_doc = info.as_doc() + # just in case hardcode column order + reference_cols = [ + "version", + "engine_version", + "pipeline_name", + "state", + "created_at", + "version_hash", + "_dlt_load_id", + ] + # doc and table must be in the same order with the same name + assert ( + len(table["columns"]) + == len(state_doc_hints) + == len(sync_class_hints) + == len(state_doc) + == len(reference_cols) + ) + for col, hint, class_hint, val, ref_col in zip( + table["columns"].values(), state_doc_hints, sync_class_hints, state_doc, reference_cols + ): + assert col["name"] == hint == class_hint == val == ref_col + + # validate info + validate_dict(TPipelineStateDoc, state_doc, "$") + + info = StateInfo(1, 4, "pipeline", "compressed", pendulum.now()) + state_doc = info.as_doc() + assert "_dlt_load_id" not in state_doc + assert "version_hash" not in state_doc + + # we drop hash in query + compat_table = get_pipeline_state_query_columns() + assert list(compat_table["columns"].keys()) == [ + "version", + "engine_version", + "pipeline_name", + "state", + "created_at", + "_dlt_load_id", + ] + + def test_restore_state_props() -> None: p = dlt.pipeline( pipeline_name="restore_state_props", diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index a46529b861..4b46bb7c3e 100644 --- a/tests/pipeline/test_schema_contracts.py +++ b/tests/pipeline/test_schema_contracts.py @@ -177,8 +177,7 @@ def get_pipeline(): return dlt.pipeline( pipeline_name="contracts_" + uniq_id(), - destination="duckdb", - credentials=duckdb.connect(":memory:"), + destination=dlt.destinations.duckdb(credentials=duckdb.connect(":memory:")), dev_mode=True, ) diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index 7affcc5a81..f2e0058891 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -6,6 +6,7 @@ import dlt from dlt.common import json, sleep +from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.common.pipeline import LoadInfo from dlt.common.schema.utils import get_table_format from dlt.common.typing import DictStrAny @@ -47,12 +48,14 @@ def budget(): @dlt.resource(name="🦚Peacock", selected=False, primary_key="🔑id") def peacock(): - dlt.current.resource_state()["🦚🦚🦚"] = "🦚" + r_state = dlt.current.resource_state() + r_state.setdefault("🦚🦚🦚", "") + r_state["🦚🦚🦚"] += "🦚" yield [{"peacock": [1, 2, 3], "🔑id": 1}] @dlt.resource(name="🦚WidePeacock", selected=False) def wide_peacock(): - yield [{"peacock": [1, 2, 3]}] + yield [{"Peacock": [1, 2, 3]}] return budget, schedule, peacock, wide_peacock @@ -198,7 +201,7 @@ def _load_tables_to_dicts_sql( for table_name in table_names: table_rows = [] columns = schema.get_table_columns(table_name).keys() - query_columns = ",".join(map(p.sql_client().capabilities.escape_identifier, columns)) + query_columns = ",".join(map(p.sql_client().escape_column_name, columns)) with p.sql_client() as c: query_columns = ",".join(map(c.escape_column_name, columns)) diff --git a/tests/sources/helpers/rest_client/api_router.py b/tests/sources/helpers/rest_client/api_router.py new file mode 100644 index 0000000000..661a4d3468 --- /dev/null +++ b/tests/sources/helpers/rest_client/api_router.py @@ -0,0 +1,61 @@ +import re +from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING, Dict, Any + +import requests_mock + +from dlt.common import json + +if TYPE_CHECKING: + RequestCallback = Callable[ + [requests_mock.Request, requests_mock.Context], Union[str, Dict[str, Any], List[Any]] + ] + ResponseSerializer = Callable[[requests_mock.Request, requests_mock.Context], str] +else: + RequestCallback = Callable + ResponseSerializer = Callable + + +class Route(NamedTuple): + method: str + pattern: Pattern[str] + callback: ResponseSerializer + + +class APIRouter: + def __init__(self, base_url: str): + self.routes: List[Route] = [] + self.base_url = base_url + + def _add_route(self, method: str, pattern: str, func: RequestCallback) -> RequestCallback: + compiled_pattern = re.compile(f"{self.base_url}{pattern}") + + def serialize_response(request, context): + result = func(request, context) + + if isinstance(result, dict) or isinstance(result, list): + return json.dumps(result) + + return result + + self.routes.append(Route(method, compiled_pattern, serialize_response)) + return serialize_response + + def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: + def decorator(func: RequestCallback) -> RequestCallback: + return self._add_route("GET", pattern, func) + + return decorator + + def post(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: + def decorator(func: RequestCallback) -> RequestCallback: + return self._add_route("POST", pattern, func) + + return decorator + + def register_routes(self, mocker: requests_mock.Mocker) -> None: + for route in self.routes: + mocker.register_uri( + route.method, + route.pattern, + text=route.callback, + ) diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 7453c63d14..10dd23877d 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -1,156 +1,119 @@ -import re -from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING, Dict, List, Any import base64 - -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode import pytest import requests_mock -from dlt.common import json - -if TYPE_CHECKING: - RequestCallback = Callable[ - [requests_mock.Request, requests_mock.Context], Union[str, Dict[str, Any], List[Any]] - ] - ResponseSerializer = Callable[[requests_mock.Request, requests_mock.Context], str] -else: - RequestCallback = Callable - ResponseSerializer = Callable - -MOCK_BASE_URL = "https://api.example.com" - - -class Route(NamedTuple): - method: str - pattern: Pattern[str] - callback: ResponseSerializer - +from dlt.sources.helpers.rest_client import RESTClient -class APIRouter: - def __init__(self, base_url: str): - self.routes: List[Route] = [] - self.base_url = base_url +from .api_router import APIRouter +from .paginators import PageNumberPaginator, OffsetPaginator, CursorPaginator - def _add_route(self, method: str, pattern: str, func: RequestCallback) -> RequestCallback: - compiled_pattern = re.compile(f"{self.base_url}{pattern}") - def serialize_response(request, context): - result = func(request, context) +MOCK_BASE_URL = "https://api.example.com" +DEFAULT_PAGE_SIZE = 5 +DEFAULT_TOTAL_PAGES = 5 +DEFAULT_LIMIT = 10 - if isinstance(result, dict) or isinstance(result, list): - return json.dumps(result) - return result +router = APIRouter(MOCK_BASE_URL) - self.routes.append(Route(method, compiled_pattern, serialize_response)) - return serialize_response - def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: - def decorator(func: RequestCallback) -> RequestCallback: - return self._add_route("GET", pattern, func) +def generate_posts(count=DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES): + return [{"id": i, "title": f"Post {i}"} for i in range(count)] - return decorator - def post(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: - def decorator(func: RequestCallback) -> RequestCallback: - return self._add_route("POST", pattern, func) +def generate_comments(post_id, count=50): + return [{"id": i, "body": f"Comment {i} for post {post_id}"} for i in range(count)] - return decorator - def register_routes(self, mocker: requests_mock.Mocker) -> None: - for route in self.routes: - mocker.register_uri( - route.method, - route.pattern, - text=route.callback, - ) +def get_page_number(qs, key="page", default=1): + return int(qs.get(key, [default])[0]) -router = APIRouter(MOCK_BASE_URL) +def create_next_page_url(request, paginator, use_absolute_url=True): + scheme, netloc, path, _, _ = urlsplit(request.url) + query = urlencode(paginator.next_page_url_params) + if use_absolute_url: + return urlunsplit([scheme, netloc, path, query, ""]) + else: + return f"{path}?{query}" -def serialize_page( - records, - page_number, - total_pages, - request_url, - records_key="data", - use_absolute_url=True, +def paginate_by_page_number( + request, records, records_key="data", use_absolute_url=True, index_base=1 ): - """Serialize a page of records into a dict with pagination metadata.""" - if records_key is None: - return records + page_number = get_page_number(request.qs, default=index_base) + paginator = PageNumberPaginator(records, page_number, index_base=index_base) response = { - records_key: records, - "page": page_number, - "total_pages": total_pages, + records_key: paginator.page_records, + **paginator.metadata, } - if page_number < total_pages: - next_page = page_number + 1 - - scheme, netloc, path, _, _ = urlsplit(request_url) - if use_absolute_url: - next_page_url = urlunsplit([scheme, netloc, path, f"page={next_page}", ""]) - else: - next_page_url = f"{path}?page={next_page}" - - response["next_page"] = next_page_url + if paginator.next_page_url_params: + response["next_page"] = create_next_page_url(request, paginator, use_absolute_url) return response -def generate_posts(count=100): - return [{"id": i, "title": f"Post {i}"} for i in range(count)] - - -def generate_comments(post_id, count=50): - return [{"id": i, "body": f"Comment {i} for post {post_id}"} for i in range(count)] - - -def get_page_number(qs, key="page", default=1): - return int(qs.get(key, [default])[0]) - - -def paginate_response(request, records, page_size=10, records_key="data", use_absolute_url=True): - page_number = get_page_number(request.qs) - total_records = len(records) - total_pages = (total_records + page_size - 1) // page_size - start_index = (page_number - 1) * 10 - end_index = start_index + 10 - records_slice = records[start_index:end_index] - return serialize_page( - records_slice, - page_number, - total_pages, - request.url, - records_key, - use_absolute_url, - ) - - @pytest.fixture(scope="module") def mock_api_server(): with requests_mock.Mocker() as m: - @router.get(r"/posts_no_key(\?page=\d+)?$") - def posts_no_key(request, context): - return paginate_response(request, generate_posts(), records_key=None) - @router.get(r"/posts(\?page=\d+)?$") def posts(request, context): - return paginate_response(request, generate_posts()) + return paginate_by_page_number(request, generate_posts()) + + @router.get(r"/posts_zero_based(\?page=\d+)?$") + def posts_zero_based(request, context): + return paginate_by_page_number(request, generate_posts(), index_base=0) + + @router.get(r"/posts_header_link(\?page=\d+)?$") + def posts_header_link(request, context): + records = generate_posts() + page_number = get_page_number(request.qs) + paginator = PageNumberPaginator(records, page_number) + + response = paginator.page_records + + if paginator.next_page_url_params: + next_page_url = create_next_page_url(request, paginator) + context.headers["Link"] = f'<{next_page_url}>; rel="next"' + + return response @router.get(r"/posts_relative_next_url(\?page=\d+)?$") def posts_relative_next_url(request, context): - return paginate_response(request, generate_posts(), use_absolute_url=False) + return paginate_by_page_number(request, generate_posts(), use_absolute_url=False) + + @router.get(r"/posts_offset_limit(\?offset=\d+&limit=\d+)?$") + def posts_offset_limit(request, context): + records = generate_posts() + offset = int(request.qs.get("offset", [0])[0]) + limit = int(request.qs.get("limit", [DEFAULT_LIMIT])[0]) + paginator = OffsetPaginator(records, offset, limit) + + return { + "data": paginator.page_records, + **paginator.metadata, + } + + @router.get(r"/posts_cursor(\?cursor=\d+)?$") + def posts_cursor(request, context): + records = generate_posts() + cursor = int(request.qs.get("cursor", [0])[0]) + paginator = CursorPaginator(records, cursor) + + return { + "data": paginator.page_records, + **paginator.metadata, + } @router.get(r"/posts/(\d+)/comments") def post_comments(request, context): post_id = int(request.url.split("/")[-2]) - return paginate_response(request, generate_comments(post_id)) + return paginate_by_page_number(request, generate_comments(post_id)) @router.get(r"/posts/\d+$") def post_detail(request, context): @@ -169,7 +132,29 @@ def post_detail_404(request, context): @router.get(r"/posts_under_a_different_key$") def posts_with_results_key(request, context): - return paginate_response(request, generate_posts(), records_key="many-results") + return paginate_by_page_number(request, generate_posts(), records_key="many-results") + + @router.post(r"/posts/search$") + def search_posts(request, context): + body = request.json() + page_size = body.get("page_size", DEFAULT_PAGE_SIZE) + page_number = body.get("page", 1) + + # Simulate a search with filtering + records = generate_posts() + ids_greater_than = body.get("ids_greater_than", 0) + records = [r for r in records if r["id"] > ids_greater_than] + + total_records = len(records) + total_pages = (total_records + page_size - 1) // page_size + start_index = (page_number - 1) * page_size + end_index = start_index + page_size + records_slice = records[start_index:end_index] + + return { + "data": records_slice, + "next_page": page_number + 1 if page_number < total_pages else None, + } @router.get("/protected/posts/basic-auth") def protected_basic_auth(request, context): @@ -177,7 +162,7 @@ def protected_basic_auth(request, context): creds = "user:password" creds_base64 = base64.b64encode(creds.encode()).decode() if auth == f"Basic {creds_base64}": - return paginate_response(request, generate_posts()) + return paginate_by_page_number(request, generate_posts()) context.status_code = 401 return {"error": "Unauthorized"} @@ -185,7 +170,7 @@ def protected_basic_auth(request, context): def protected_bearer_token(request, context): auth = request.headers.get("Authorization") if auth == "Bearer test-token": - return paginate_response(request, generate_posts()) + return paginate_by_page_number(request, generate_posts()) context.status_code = 401 return {"error": "Unauthorized"} @@ -193,7 +178,7 @@ def protected_bearer_token(request, context): def protected_bearer_token_plain_text_erorr(request, context): auth = request.headers.get("Authorization") if auth == "Bearer test-token": - return paginate_response(request, generate_posts()) + return paginate_by_page_number(request, generate_posts()) context.status_code = 401 return "Unauthorized" @@ -201,13 +186,23 @@ def protected_bearer_token_plain_text_erorr(request, context): def protected_api_key(request, context): api_key = request.headers.get("x-api-key") if api_key == "test-api-key": - return paginate_response(request, generate_posts()) + return paginate_by_page_number(request, generate_posts()) context.status_code = 401 return {"error": "Unauthorized"} @router.post("/oauth/token") def oauth_token(request, context): - return {"access_token": "test-token", "expires_in": 3600} + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + + @router.post("/oauth/token-expires-now") + def oauth_token_expires_now(request, context): + if oauth_authorize(request): + return {"access_token": "test-token", "expires_in": 0} + context.status_code = 401 + return {"error": "Unauthorized"} @router.post("/auth/refresh") def refresh_token(request, context): @@ -217,12 +212,47 @@ def refresh_token(request, context): context.status_code = 401 return {"error": "Invalid refresh token"} + @router.post("/custom-oauth/token") + def custom_oauth_token(request, context): + qs = parse_qs(request.text) + if ( + qs.get("grant_type")[0] == "account_credentials" + and qs.get("account_id")[0] == "test-account-id" + and request.headers["Authorization"] + == "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==" + ): + return {"access_token": "test-token", "expires_in": 3600} + context.status_code = 401 + return {"error": "Unauthorized"} + router.register_routes(m) yield m -def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10): +@pytest.fixture +def rest_client() -> RESTClient: + return RESTClient( + base_url="https://api.example.com", + headers={"Accept": "application/json"}, + ) + + +def oauth_authorize(request): + qs = parse_qs(request.text) + grant_type = qs.get("grant_type")[0] + if "jwt-bearer" in grant_type: + return True + if "client_credentials" in grant_type: + return ( + qs["client_secret"][0] == "test-client-secret" + and qs["client_id"][0] == "test-client-id" + ) + + +def assert_pagination(pages, page_size=DEFAULT_PAGE_SIZE, total_pages=DEFAULT_TOTAL_PAGES): assert len(pages) == total_pages for i, page in enumerate(pages): - assert page == [{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)] + assert page == [ + {"id": i, "title": f"Post {i}"} for i in range(i * page_size, (i + 1) * page_size) + ] diff --git a/tests/sources/helpers/rest_client/paginators.py b/tests/sources/helpers/rest_client/paginators.py new file mode 100644 index 0000000000..fdd8e6f4d8 --- /dev/null +++ b/tests/sources/helpers/rest_client/paginators.py @@ -0,0 +1,125 @@ +class BasePaginator: + def __init__(self, records): + self.records = records + + @property + def page_records(self): + """Return records for the current page.""" + raise NotImplementedError + + @property + def metadata(self): + """Return metadata for the current page. + E.g. total number of records, current page number, etc. + """ + raise NotImplementedError + + @property + def next_page_url_params(self): + """Return URL parameters for the next page. + This is used to generate the URL for the next page in the response. + """ + raise NotImplementedError + + +class PageNumberPaginator(BasePaginator): + def __init__(self, records, page_number, page_size=5, index_base=1): + """Paginate records by page number. + + Args: + records: List of records to paginate. + page_number: Page number to return. + page_size: Maximum number of records to return per page. + index_base: Index of the start page. E.g. zero-based + index or 1-based index. + """ + super().__init__(records) + self.page_number = page_number + self.index_base = index_base + self.page_size = page_size + + @property + def page_records(self): + start_index = (self.page_number - self.index_base) * self.page_size + end_index = start_index + self.page_size + return self.records[start_index:end_index] + + @property + def metadata(self): + return {"page": self.page_number, "total_pages": self.total_pages} + + @property + def next_page_url_params(self): + return {"page": self.next_page_number} if self.next_page_number else {} + + @property + def total_pages(self): + total_records = len(self.records) + return (total_records + self.page_size - 1) // self.page_size + + @property + def next_page_number(self): + return ( + self.page_number + 1 + if self.page_number + 1 < self.total_pages + self.index_base + else None + ) + + +class OffsetPaginator(BasePaginator): + def __init__(self, records, offset, limit=10): + """Paginate records by offset. + + Args: + records: List of records to paginate. + offset: Offset to start slicing from. + limit: Maximum number of records to return. + """ + super().__init__(records) + self.offset = offset + self.limit = limit + + @property + def page_records(self): + return self.records[self.offset : self.offset + self.limit] + + @property + def metadata(self): + return {"total_records": len(self.records), "offset": self.offset, "limit": self.limit} + + @property + def next_page_url_params(self): + if self.offset + self.limit < len(self.records): + return {"offset": self.offset + self.limit, "limit": self.limit} + return {} + + +class CursorPaginator(BasePaginator): + def __init__(self, records, cursor, limit=5): + """Paginate records by cursor. + + Here, cursor is the index of the record to start slicing from. + + Args: + records: List of records to paginate. + cursor: Cursor to start slicing from. + limit: Maximum number of records to return. + """ + super().__init__(records) + self.cursor = cursor + self.limit = limit + + @property + def page_records(self): + return self.records[self.cursor : self.cursor + self.limit] + + @property + def metadata(self): + next_index = self.cursor + self.limit + + if next_index < len(self.records): + next_cursor = next_index + else: + next_cursor = None + + return {"next_cursor": next_cursor} diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index bd65affe62..ed227cd3cd 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,25 +1,30 @@ import os +from base64 import b64encode +from typing import Any, Dict, cast +from unittest.mock import patch + import pytest -from typing import Any, cast -from dlt.common import logger from requests import PreparedRequest, Request, Response from requests.auth import AuthBase +from requests.exceptions import HTTPError + +from dlt.common import logger from dlt.common.typing import TSecretStrValue from dlt.sources.helpers.requests import Client from dlt.sources.helpers.rest_client import RESTClient -from dlt.sources.helpers.rest_client.client import Hooks -from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator - -from dlt.sources.helpers.rest_client.auth import AuthConfigBase from dlt.sources.helpers.rest_client.auth import ( - BearerTokenAuth, APIKeyAuth, + AuthConfigBase, + BearerTokenAuth, HttpBasicAuth, + OAuth2ClientCredentials, OAuthJWTAuth, ) +from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator, BaseReferencePaginator -from .conftest import assert_pagination +from .conftest import DEFAULT_PAGE_SIZE, DEFAULT_TOTAL_PAGES, assert_pagination def load_private_key(name="private_key.pem"): @@ -31,13 +36,40 @@ def load_private_key(name="private_key.pem"): TEST_PRIVATE_KEY = load_private_key() -@pytest.fixture -def rest_client() -> RESTClient: +def build_rest_client(auth=None) -> RESTClient: return RESTClient( base_url="https://api.example.com", headers={"Accept": "application/json"}, session=Client().session, + auth=auth, + ) + + +@pytest.fixture +def rest_client() -> RESTClient: + return build_rest_client() + + +@pytest.fixture +def rest_client_oauth() -> RESTClient: + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + return build_rest_client(auth=auth) + + +@pytest.fixture +def rest_client_immediate_oauth_expiry(auth=None) -> RESTClient: + credentials_expiring_now = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token-expires-now"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, ) + return build_rest_client(auth=credentials_expiring_now) @pytest.mark.usefixtures("mock_api_server") @@ -163,6 +195,113 @@ def test_api_key_auth_success(self, rest_client: RESTClient): assert response.status_code == 200 assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + def test_oauth2_client_credentials_flow_auth_success(self, rest_client_oauth: RESTClient): + response = rest_client_oauth.get("/protected/posts/bearer-token") + + assert response.status_code == 200 + assert "test-token" in response.request.headers["Authorization"] + + pages_iter = rest_client_oauth.paginate("/protected/posts/bearer-token") + + assert_pagination(list(pages_iter)) + + def test_oauth2_client_credentials_flow_wrong_client_id(self, rest_client: RESTClient): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "invalid-client-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + session=Client().session, + ) + + with pytest.raises(HTTPError) as e: + rest_client.get("/protected/posts/bearer-token", auth=auth) + assert e.type == HTTPError + assert e.match("401 Client Error") + + def test_oauth2_client_credentials_flow_wrong_client_secret(self, rest_client: RESTClient): + auth = OAuth2ClientCredentials( + access_token_url=cast(TSecretStrValue, "https://api.example.com/oauth/token"), + client_id=cast(TSecretStrValue, "test-client-id"), + client_secret=cast(TSecretStrValue, "invalid-client-secret"), + session=Client().session, + ) + + with pytest.raises(HTTPError) as e: + rest_client.get( + "/protected/posts/bearer-token", + auth=auth, + ) + assert e.type == HTTPError + assert e.match("401 Client Error") + + def test_oauth_token_expired_refresh(self, rest_client_immediate_oauth_expiry: RESTClient): + rest_client = rest_client_immediate_oauth_expiry + auth = cast(OAuth2ClientCredentials, rest_client.auth) + + with patch.object(auth, "obtain_token", wraps=auth.obtain_token) as mock_obtain_token: + assert auth.access_token is None + response = rest_client.get("/protected/posts/bearer-token") + mock_obtain_token.assert_called_once() + assert response.status_code == 200 + assert auth.access_token is not None + expiry_0 = auth.token_expiry + auth.token_expiry = auth.token_expiry.subtract(seconds=1) + expiry_1 = auth.token_expiry + assert expiry_0 > expiry_1 + assert auth.is_token_expired() + + response = rest_client.get("/protected/posts/bearer-token") + assert mock_obtain_token.call_count == 2 + assert response.status_code == 200 + expiry_2 = auth.token_expiry + assert expiry_2 > expiry_1 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + def test_oauth_customized_token_request(self, rest_client: RESTClient): + class OAuth2ClientCredentialsHTTPBasic(OAuth2ClientCredentials): + """OAuth 2.0 as required by e.g. Zoom Video Communications, Inc.""" + + def build_access_token_request(self) -> Dict[str, Any]: + authentication: str = b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + return { + "headers": { + "Authorization": f"Basic {authentication}", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "grant_type": "account_credentials", + **self.access_token_request_data, + }, + } + + auth = OAuth2ClientCredentialsHTTPBasic( + access_token_url=cast(TSecretStrValue, "https://api.example.com/custom-oauth/token"), + client_id=cast(TSecretStrValue, "test-account-id"), + client_secret=cast(TSecretStrValue, "test-client-secret"), + access_token_request_data={ + "account_id": cast(TSecretStrValue, "test-account-id"), + }, + session=Client().session, + ) + + assert auth.build_access_token_request() == { + "headers": { + "Authorization": "Basic dGVzdC1hY2NvdW50LWlkOnRlc3QtY2xpZW50LXNlY3JldA==", + "Content-Type": "application/x-www-form-urlencoded", + }, + "data": { + "grant_type": "account_credentials", + "account_id": "test-account-id", + }, + } + + rest_client.auth = auth + pages_iter = rest_client.paginate("/protected/posts/bearer-token") + + assert_pagination(list(pages_iter)) + def test_oauth_jwt_auth_success(self, rest_client: RESTClient): auth = OAuthJWTAuth( client_id="test-client-id", @@ -255,3 +394,39 @@ def _fake_send(*args, **kwargs): result = rest_client.get("/posts/1") assert result.status_code == 200 + + def test_paginate_json_body_without_params(self, rest_client) -> None: + # leave 3 pages of data + posts_skip = (DEFAULT_TOTAL_PAGES - 3) * DEFAULT_PAGE_SIZE + + class JSONBodyPageCursorPaginator(BaseReferencePaginator): + def update_state(self, response): + self._next_reference = response.json().get("next_page") + + def update_request(self, request): + if request.json is None: + request.json = {} + + request.json["page"] = self._next_reference + + page_generator = rest_client.paginate( + path="/posts/search", + method="POST", + json={"ids_greater_than": posts_skip - 1}, + paginator=JSONBodyPageCursorPaginator(), + ) + result = [post for page in list(page_generator) for post in page] + for i in range(3 * DEFAULT_PAGE_SIZE): + assert result[i] == {"id": posts_skip + i, "title": f"Post {posts_skip + i}"} + + def test_post_json_body_without_params(self, rest_client) -> None: + # leave two pages of data + posts_skip = (DEFAULT_TOTAL_PAGES - 2) * DEFAULT_PAGE_SIZE + result = rest_client.post( + path="/posts/search", + json={"ids_greater_than": posts_skip - 1}, + ) + returned_posts = result.json()["data"] + assert len(returned_posts) == DEFAULT_PAGE_SIZE # only one page is returned + for i in range(DEFAULT_PAGE_SIZE): + assert returned_posts[i] == {"id": posts_skip + i, "title": f"Post {posts_skip + i}"} diff --git a/tests/sources/helpers/rest_client/test_mock_api_server.py b/tests/sources/helpers/rest_client/test_mock_api_server.py new file mode 100644 index 0000000000..cfdc920f3e --- /dev/null +++ b/tests/sources/helpers/rest_client/test_mock_api_server.py @@ -0,0 +1,310 @@ +import pytest + + +@pytest.mark.usefixtures("mock_api_server") +class TestMockAPIServer: + @pytest.mark.parametrize( + "test_case", + [ + # Page number is one-based + { + "url": "/posts", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + {"id": 4, "title": "Post 4"}, + ], + "page": 1, + "total_pages": 5, + "next_page": "https://api.example.com/posts?page=2", + }, + }, + { + "url": "/posts?page=2", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 5, "title": "Post 5"}, + {"id": 6, "title": "Post 6"}, + {"id": 7, "title": "Post 7"}, + {"id": 8, "title": "Post 8"}, + {"id": 9, "title": "Post 9"}, + ], + "page": 2, + "total_pages": 5, + "next_page": "https://api.example.com/posts?page=3", + }, + }, + { + "url": "/posts?page=3", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 10, "title": "Post 10"}, + {"id": 11, "title": "Post 11"}, + {"id": 12, "title": "Post 12"}, + {"id": 13, "title": "Post 13"}, + {"id": 14, "title": "Post 14"}, + ], + "page": 3, + "total_pages": 5, + "next_page": "https://api.example.com/posts?page=4", + }, + }, + { + "url": "/posts?page=4", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 15, "title": "Post 15"}, + {"id": 16, "title": "Post 16"}, + {"id": 17, "title": "Post 17"}, + {"id": 18, "title": "Post 18"}, + {"id": 19, "title": "Post 19"}, + ], + "page": 4, + "total_pages": 5, + "next_page": "https://api.example.com/posts?page=5", + }, + }, + { + "url": "/posts?page=5", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 20, "title": "Post 20"}, + {"id": 21, "title": "Post 21"}, + {"id": 22, "title": "Post 22"}, + {"id": 23, "title": "Post 23"}, + {"id": 24, "title": "Post 24"}, + ], + "page": 5, + "total_pages": 5, + }, + }, + # Page number is zero-based + { + "url": "/posts_zero_based", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + {"id": 4, "title": "Post 4"}, + ], + "page": 0, + "total_pages": 5, + "next_page": "https://api.example.com/posts_zero_based?page=1", + }, + }, + { + "url": "/posts_zero_based?page=1", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 5, "title": "Post 5"}, + {"id": 6, "title": "Post 6"}, + {"id": 7, "title": "Post 7"}, + {"id": 8, "title": "Post 8"}, + {"id": 9, "title": "Post 9"}, + ], + "page": 1, + "total_pages": 5, + "next_page": "https://api.example.com/posts_zero_based?page=2", + }, + }, + { + "url": "/posts_zero_based?page=2", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 10, "title": "Post 10"}, + {"id": 11, "title": "Post 11"}, + {"id": 12, "title": "Post 12"}, + {"id": 13, "title": "Post 13"}, + {"id": 14, "title": "Post 14"}, + ], + "page": 2, + "total_pages": 5, + "next_page": "https://api.example.com/posts_zero_based?page=3", + }, + }, + { + "url": "/posts_zero_based?page=3", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 15, "title": "Post 15"}, + {"id": 16, "title": "Post 16"}, + {"id": 17, "title": "Post 17"}, + {"id": 18, "title": "Post 18"}, + {"id": 19, "title": "Post 19"}, + ], + "page": 3, + "total_pages": 5, + "next_page": "https://api.example.com/posts_zero_based?page=4", + }, + }, + { + "url": "/posts_zero_based?page=4", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 20, "title": "Post 20"}, + {"id": 21, "title": "Post 21"}, + {"id": 22, "title": "Post 22"}, + {"id": 23, "title": "Post 23"}, + {"id": 24, "title": "Post 24"}, + ], + "page": 4, + "total_pages": 5, + }, + }, + # Test offset-limit pagination + { + "url": "/posts_offset_limit", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + {"id": 4, "title": "Post 4"}, + {"id": 5, "title": "Post 5"}, + {"id": 6, "title": "Post 6"}, + {"id": 7, "title": "Post 7"}, + {"id": 8, "title": "Post 8"}, + {"id": 9, "title": "Post 9"}, + ], + "total_records": 25, + "offset": 0, + "limit": 10, + }, + }, + { + "url": "/posts_offset_limit?offset=10&limit=10", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 10, "title": "Post 10"}, + {"id": 11, "title": "Post 11"}, + {"id": 12, "title": "Post 12"}, + {"id": 13, "title": "Post 13"}, + {"id": 14, "title": "Post 14"}, + {"id": 15, "title": "Post 15"}, + {"id": 16, "title": "Post 16"}, + {"id": 17, "title": "Post 17"}, + {"id": 18, "title": "Post 18"}, + {"id": 19, "title": "Post 19"}, + ], + "total_records": 25, + "offset": 10, + "limit": 10, + }, + }, + { + "url": "/posts_offset_limit?offset=20&limit=10", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 20, "title": "Post 20"}, + {"id": 21, "title": "Post 21"}, + {"id": 22, "title": "Post 22"}, + {"id": 23, "title": "Post 23"}, + {"id": 24, "title": "Post 24"}, + ], + "total_records": 25, + "offset": 20, + "limit": 10, + }, + }, + # Test cursor pagination + { + "url": "/posts_cursor", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 0, "title": "Post 0"}, + {"id": 1, "title": "Post 1"}, + {"id": 2, "title": "Post 2"}, + {"id": 3, "title": "Post 3"}, + {"id": 4, "title": "Post 4"}, + ], + "next_cursor": 5, + }, + }, + { + "url": "/posts_cursor?cursor=5", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 5, "title": "Post 5"}, + {"id": 6, "title": "Post 6"}, + {"id": 7, "title": "Post 7"}, + {"id": 8, "title": "Post 8"}, + {"id": 9, "title": "Post 9"}, + ], + "next_cursor": 10, + }, + }, + { + "url": "/posts_cursor?cursor=10", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 10, "title": "Post 10"}, + {"id": 11, "title": "Post 11"}, + {"id": 12, "title": "Post 12"}, + {"id": 13, "title": "Post 13"}, + {"id": 14, "title": "Post 14"}, + ], + "next_cursor": 15, + }, + }, + { + "url": "/posts_cursor?cursor=15", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 15, "title": "Post 15"}, + {"id": 16, "title": "Post 16"}, + {"id": 17, "title": "Post 17"}, + {"id": 18, "title": "Post 18"}, + {"id": 19, "title": "Post 19"}, + ], + "next_cursor": 20, + }, + }, + { + "url": "/posts_cursor?cursor=20", + "status_code": 200, + "expected_json": { + "data": [ + {"id": 20, "title": "Post 20"}, + {"id": 21, "title": "Post 21"}, + {"id": 22, "title": "Post 22"}, + {"id": 23, "title": "Post 23"}, + {"id": 24, "title": "Post 24"}, + ], + "next_cursor": None, + }, + }, + ], + ) + def test_paginate_success(self, test_case, rest_client): + response = rest_client.get(test_case["url"]) + assert response.status_code == test_case["status_code"] + assert response.json() == test_case["expected_json"] + + @pytest.mark.skip(reason="Not implemented") + def test_paginate_by_page_number_invalid_page(self, rest_client): + response = rest_client.get("/posts?page=6") + assert response.status_code == 404 + assert response.json() == {"error": "Not Found"} diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 9ca54e814c..a5f9d888a2 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -3,6 +3,7 @@ import pytest from requests.models import Response, Request +from requests import Session from dlt.sources.helpers.rest_client.paginators import ( SinglePagePaginator, @@ -10,9 +11,13 @@ PageNumberPaginator, HeaderLinkPaginator, JSONResponsePaginator, + JSONResponseCursorPaginator, ) +from .conftest import assert_pagination + +@pytest.mark.usefixtures("mock_api_server") class TestHeaderLinkPaginator: def test_update_state_with_next(self): paginator = HeaderLinkPaginator() @@ -29,7 +34,18 @@ def test_update_state_without_next(self): paginator.update_state(response) assert paginator.has_next_page is False + def test_client_pagination(self, rest_client): + pages_iter = rest_client.paginate( + "/posts_header_link", + paginator=HeaderLinkPaginator(), + ) + + pages = list(pages_iter) + + assert_pagination(pages) + +@pytest.mark.usefixtures("mock_api_server") class TestJSONResponsePaginator: @pytest.mark.parametrize( "test_case", @@ -157,7 +173,44 @@ def test_update_request(self, test_case): paginator.update_request(request) assert request.url == test_case["expected"] + def test_no_duplicate_params_on_update_request(self): + paginator = JSONResponsePaginator() + + request = Request( + method="GET", + url="http://example.com/api/resource", + params={"param1": "value1"}, + ) + + session = Session() + + response = Mock(Response, json=lambda: {"next": "/api/resource?page=2¶m1=value1"}) + paginator.update_state(response) + paginator.update_request(request) + + assert request.url == "http://example.com/api/resource?page=2¶m1=value1" + + # RESTClient._send_request() calls Session.prepare_request() which + # updates the URL with the query parameters from the request object. + prepared_request = session.prepare_request(request) + # The next request should just use the "next" URL without any duplicate parameters. + assert prepared_request.url == "http://example.com/api/resource?page=2¶m1=value1" + + def test_client_pagination(self, rest_client): + pages_iter = rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator( + next_url_path="next_page", + ), + ) + + pages = list(pages_iter) + + assert_pagination(pages) + + +@pytest.mark.usefixtures("mock_api_server") class TestSinglePagePaginator: def test_update_state(self): paginator = SinglePagePaginator() @@ -172,7 +225,18 @@ def test_update_state_with_next(self): paginator.update_state(response) assert paginator.has_next_page is False + def test_client_pagination(self, rest_client): + pages_iter = rest_client.paginate( + "/posts", + paginator=SinglePagePaginator(), + ) + + pages = list(pages_iter) + assert_pagination(pages, total_pages=1) + + +@pytest.mark.usefixtures("mock_api_server") class TestOffsetPaginator: def test_update_state(self): paginator = OffsetPaginator(offset=0, limit=10) @@ -238,40 +302,55 @@ def test_maximum_offset(self): assert paginator.current_value == 100 assert paginator.has_next_page is False + def test_client_pagination(self, rest_client): + pages_iter = rest_client.paginate( + "/posts_offset_limit", + paginator=OffsetPaginator(offset=0, limit=5, total_path="total_records"), + ) + + pages = list(pages_iter) + assert_pagination(pages) + + +@pytest.mark.usefixtures("mock_api_server") class TestPageNumberPaginator: def test_update_state(self): - paginator = PageNumberPaginator(initial_page=1, total_path="total_pages") + paginator = PageNumberPaginator(base_page=1, page=1, total_path="total_pages") response = Mock(Response, json=lambda: {"total_pages": 3}) paginator.update_state(response) assert paginator.current_value == 2 assert paginator.has_next_page is True + paginator.update_state(response) + assert paginator.current_value == 3 + assert paginator.has_next_page is True + # Test for reaching the end paginator.update_state(response) assert paginator.has_next_page is False def test_update_state_with_string_total_pages(self): - paginator = PageNumberPaginator(1) + paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {"total": "3"}) paginator.update_state(response) assert paginator.current_value == 2 assert paginator.has_next_page is True def test_update_state_with_invalid_total_pages(self): - paginator = PageNumberPaginator(1) + paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {"total_pages": "invalid"}) with pytest.raises(ValueError): paginator.update_state(response) def test_update_state_without_total_pages(self): - paginator = PageNumberPaginator(1) + paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {}) with pytest.raises(ValueError): paginator.update_state(response) def test_update_request(self): - paginator = PageNumberPaginator(initial_page=1, page_param="page") + paginator = PageNumberPaginator(base_page=1, page=1, page_param="page") request = Mock(Request) response = Mock(Response, json=lambda: {"total": 3}) paginator.update_state(response) @@ -283,7 +362,7 @@ def test_update_request(self): assert request.params["page"] == 3 def test_maximum_page(self): - paginator = PageNumberPaginator(initial_page=1, maximum_page=3, total_path=None) + paginator = PageNumberPaginator(base_page=1, page=1, maximum_page=3, total_path=None) response = Mock(Response, json=lambda: {"items": []}) paginator.update_state(response) # Page 1 assert paginator.current_value == 2 @@ -292,3 +371,60 @@ def test_maximum_page(self): paginator.update_state(response) # Page 2 assert paginator.current_value == 3 assert paginator.has_next_page is False + + def test_client_pagination_one_based(self, rest_client): + pages_iter = rest_client.paginate( + "/posts", + paginator=PageNumberPaginator(base_page=1, page=1, total_path="total_pages"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) + + def test_client_pagination_one_based_default_page(self, rest_client): + pages_iter = rest_client.paginate( + "/posts", + paginator=PageNumberPaginator(base_page=1, total_path="total_pages"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) + + def test_client_pagination_zero_based(self, rest_client): + pages_iter = rest_client.paginate( + "/posts_zero_based", + paginator=PageNumberPaginator(base_page=0, page=0, total_path="total_pages"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) + + +@pytest.mark.usefixtures("mock_api_server") +class TestJSONResponseCursorPaginator: + def test_update_state(self): + paginator = JSONResponseCursorPaginator(cursor_path="next_cursor") + response = Mock(Response, json=lambda: {"next_cursor": "cursor-2", "results": []}) + paginator.update_state(response) + assert paginator._next_reference == "cursor-2" + assert paginator.has_next_page is True + + def test_update_request(self): + paginator = JSONResponseCursorPaginator(cursor_path="next_cursor") + paginator._next_reference = "cursor-2" + request = Request(method="GET", url="http://example.com/api/resource") + paginator.update_request(request) + assert request.params["cursor"] == "cursor-2" + + def test_client_pagination(self, rest_client): + pages_iter = rest_client.paginate( + "/posts_cursor", + paginator=JSONResponseCursorPaginator(cursor_path="next_cursor"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) diff --git a/tests/utils.py b/tests/utils.py index 580c040706..bf3aafdb77 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,13 +45,22 @@ "motherduck", "mssql", "qdrant", + "lancedb", "destination", "synapse", "databricks", "clickhouse", "dremio", } -NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "destination"} +NON_SQL_DESTINATIONS = { + "filesystem", + "weaviate", + "dummy", + "motherduck", + "qdrant", + "lancedb", + "destination", +} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS # exclude destination configs (for now used for athena and athena iceberg separation) @@ -173,7 +182,7 @@ def unload_modules() -> Iterator[None]: @pytest.fixture(autouse=True) -def wipe_pipeline() -> Iterator[None]: +def wipe_pipeline(preserve_environ) -> Iterator[None]: """Wipes pipeline local state and deactivates it""" container = Container() if container[PipelineContext].is_active(): diff --git a/tox.ini b/tox.ini index ed6c69c585..059f6a586a 100644 --- a/tox.ini +++ b/tox.ini @@ -7,3 +7,6 @@ banned-modules = datetime = use dlt.common.pendulum open = use dlt.common.open pendulum = use dlt.common.pendulum extend-immutable-calls = dlt.sources.incremental +per-file-ignores = + tests/*: T20 + docs/*: T20 \ No newline at end of file