diff --git a/.env.docker b/.env.docker index cb0394806..278bd3e50 100644 --- a/.env.docker +++ b/.env.docker @@ -53,7 +53,7 @@ ONETL_PG_PASSWORD=ohtae0luxeshi1uraeluMoh9IShah7ai # Oracle ONETL_ORA_HOST=oracle ONETL_ORA_PORT=1521 -ONETL_ORA_SERVICE_NAME=XEPDB1 +ONETL_ORA_SERVICE_NAME=FREEPDB1 ONETL_ORA_USER=onetl ONETL_ORA_PASSWORD=Yoequ2Hoeceit4ch diff --git a/.env.local b/.env.local index 2e05030f3..0327f1cc0 100644 --- a/.env.local +++ b/.env.local @@ -53,7 +53,7 @@ export ONETL_PG_PASSWORD=ohtae0luxeshi1uraeluMoh9IShah7ai # Oracle export ONETL_ORA_HOST=localhost export ONETL_ORA_PORT=1522 -export ONETL_ORA_SERVICE_NAME=XEPDB1 +export ONETL_ORA_SERVICE_NAME=FREEPDB1 export ONETL_ORA_USER=onetl export ONETL_ORA_PASSWORD=Yoequ2Hoeceit4ch diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index a3d39f471..b7e93d401 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -7,7 +7,7 @@ on: - master env: - DEFAULT_PYTHON: '3.11' + DEFAULT_PYTHON: '3.12' permissions: contents: read diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 434905856..1b3aee7f5 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -14,6 +14,8 @@ concurrency: cancel-in-progress: true env: + # flake8-commas is failing on Python 3.12 + # as well as bandit https://github.com/PyCQA/bandit/issues/1077 DEFAULT_PYTHON: '3.11' jobs: @@ -35,7 +37,9 @@ jobs: python-version: ${{ env.DEFAULT_PYTHON }} - name: Install Kerberos headers - run: sudo apt-get install --no-install-recommends libkrb5-dev + run: | + sudo apt-get update + sudo apt-get install --no-install-recommends libkrb5-dev - name: Cache pip uses: actions/cache@v3 diff --git a/.github/workflows/data/clickhouse/matrix.yml b/.github/workflows/data/clickhouse/matrix.yml index cf52893ab..6ee63269e 100644 --- a/.github/workflows/data/clickhouse/matrix.yml +++ b/.github/workflows/data/clickhouse/matrix.yml @@ -6,13 +6,13 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/core/matrix.yml b/.github/workflows/data/core/matrix.yml index 78cb7f316..522e2160a 100644 --- a/.github/workflows/data/core/matrix.yml +++ b/.github/workflows/data/core/matrix.yml @@ -6,13 +6,13 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/ftp/matrix.yml b/.github/workflows/data/ftp/matrix.yml index 4a5c92575..58061a5ed 100644 --- a/.github/workflows/data/ftp/matrix.yml +++ b/.github/workflows/data/ftp/matrix.yml @@ -3,7 +3,7 @@ min: &min os: ubuntu-latest max: &max - python-version: '3.11' + python-version: '3.12' os: ubuntu-latest matrix: diff --git a/.github/workflows/data/ftps/matrix.yml b/.github/workflows/data/ftps/matrix.yml index 4497a7371..1ff40b12d 100644 --- a/.github/workflows/data/ftps/matrix.yml +++ b/.github/workflows/data/ftps/matrix.yml @@ -3,7 +3,7 @@ min: &min os: ubuntu-latest max: &max - python-version: '3.11' + python-version: '3.12' os: ubuntu-latest matrix: diff --git a/.github/workflows/data/hdfs/matrix.yml b/.github/workflows/data/hdfs/matrix.yml index 
eba19818e..465c7eb81 100644 --- a/.github/workflows/data/hdfs/matrix.yml +++ b/.github/workflows/data/hdfs/matrix.yml @@ -8,14 +8,14 @@ min: &min max: &max hadoop-version: hadoop3-hdfs spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest hadoop-version: hadoop3-hdfs spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/hive/matrix.yml b/.github/workflows/data/hive/matrix.yml index 17c1c3a6c..370d22d95 100644 --- a/.github/workflows/data/hive/matrix.yml +++ b/.github/workflows/data/hive/matrix.yml @@ -6,13 +6,13 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/kafka/matrix.yml b/.github/workflows/data/kafka/matrix.yml index 7797d37d6..cf4668e9a 100644 --- a/.github/workflows/data/kafka/matrix.yml +++ b/.github/workflows/data/kafka/matrix.yml @@ -9,14 +9,14 @@ min: &min max: &max kafka-version: 3.5.1 spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest kafka-version: latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/local-fs/matrix.yml b/.github/workflows/data/local-fs/matrix.yml index 4329d6582..df9bdeffe 100644 --- a/.github/workflows/data/local-fs/matrix.yml +++ b/.github/workflows/data/local-fs/matrix.yml @@ -16,33 +16,25 @@ min_excel: &min_excel java-version: 8 os: ubuntu-latest -max_excel: &max_excel - spark-version: 3.4.1 - python-version: '3.11' - java-version: 20 - os: ubuntu-latest - max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest matrix: small: - - <<: *max_excel - <<: *max full: - <<: *min - <<: *min_avro - <<: *min_excel - - <<: *max_excel - <<: *max nightly: - <<: *min diff --git a/.github/workflows/data/local-fs/tracked.txt b/.github/workflows/data/local-fs/tracked.txt index c763aed1e..013c04894 100644 --- a/.github/workflows/data/local-fs/tracked.txt +++ b/.github/workflows/data/local-fs/tracked.txt @@ -1 +1,2 @@ **/*local_fs* +**/*local-fs* diff --git a/.github/workflows/data/mongodb/matrix.yml b/.github/workflows/data/mongodb/matrix.yml index f91e1baaa..63aca7454 100644 --- a/.github/workflows/data/mongodb/matrix.yml +++ b/.github/workflows/data/mongodb/matrix.yml @@ -6,14 +6,14 @@ min: &min os: ubuntu-latest max: &max - spark-version: 3.4.1 - python-version: '3.11' + spark-version: 3.4.2 + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/mssql/matrix.yml b/.github/workflows/data/mssql/matrix.yml index b9941b583..a2c115bc2 100644 --- a/.github/workflows/data/mssql/matrix.yml +++ b/.github/workflows/data/mssql/matrix.yml @@ -6,13 +6,13 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 
os: ubuntu-latest diff --git a/.github/workflows/data/mysql/matrix.yml b/.github/workflows/data/mysql/matrix.yml index 39ba7034f..435ce9f90 100644 --- a/.github/workflows/data/mysql/matrix.yml +++ b/.github/workflows/data/mysql/matrix.yml @@ -6,13 +6,13 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/oracle/matrix.yml b/.github/workflows/data/oracle/matrix.yml index 20086bf04..dcf725f51 100644 --- a/.github/workflows/data/oracle/matrix.yml +++ b/.github/workflows/data/oracle/matrix.yml @@ -6,32 +6,41 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest matrix: small: - - oracle-version: 21.3.0-slim-faststart - db-name: XEPDB1 + - oracle-image: gvenzl/oracle-free + oracle-version: 23.3-slim-faststart + db-name: FREEPDB1 <<: *max full: - - oracle-version: 11.2.0.2-slim-faststart + - oracle-image: gvenzl/oracle-xe + oracle-version: 11.2.0.2-slim-faststart db-name: XE <<: *min - - oracle-version: 21.3.0-slim-faststart + - oracle-image: gvenzl/oracle-xe + oracle-version: 21.3.0-slim-faststart db-name: XEPDB1 <<: *max + - oracle-image: gvenzl/oracle-free + oracle-version: 23.3-slim-faststart + db-name: FREEPDB1 + <<: *max nightly: - - oracle-version: 11.2.0.2-slim-faststart + - oracle-image: gvenzl/oracle-xe + oracle-version: 11.2.0.2-slim-faststart db-name: XE <<: *min - - oracle-version: latest-faststart - db-name: XEPDB1 + - oracle-image: gvenzl/oracle-free + oracle-version: slim-faststart + db-name: FREEPDB1 <<: *latest diff --git a/.github/workflows/data/postgres/matrix.yml b/.github/workflows/data/postgres/matrix.yml index c5233c5e8..1f2c6077c 100644 --- a/.github/workflows/data/postgres/matrix.yml +++ b/.github/workflows/data/postgres/matrix.yml @@ -6,13 +6,13 @@ min: &min max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/s3/matrix.yml b/.github/workflows/data/s3/matrix.yml index 2b6fbdb32..64357c229 100644 --- a/.github/workflows/data/s3/matrix.yml +++ b/.github/workflows/data/s3/matrix.yml @@ -10,14 +10,14 @@ min: &min max: &max minio-version: 2023.7.18 spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest minio-version: latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/samba/matrix.yml b/.github/workflows/data/samba/matrix.yml index a4a3afe30..fc1573e02 100644 --- a/.github/workflows/data/samba/matrix.yml +++ b/.github/workflows/data/samba/matrix.yml @@ -3,7 +3,7 @@ min: &min os: ubuntu-latest max: &max - python-version: '3.11' + python-version: '3.12' os: ubuntu-latest matrix: diff --git a/.github/workflows/data/sftp/matrix.yml b/.github/workflows/data/sftp/matrix.yml index 12e2ecd79..44852908b 100644 --- a/.github/workflows/data/sftp/matrix.yml +++ b/.github/workflows/data/sftp/matrix.yml @@ -3,7 +3,7 @@ min: &min os: ubuntu-latest max: &max - python-version: '3.11' + 
python-version: '3.12' os: ubuntu-latest matrix: diff --git a/.github/workflows/data/teradata/matrix.yml b/.github/workflows/data/teradata/matrix.yml index 05da497c8..5391077c9 100644 --- a/.github/workflows/data/teradata/matrix.yml +++ b/.github/workflows/data/teradata/matrix.yml @@ -1,12 +1,12 @@ max: &max spark-version: 3.5.0 - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest latest: &latest spark-version: latest - python-version: '3.11' + python-version: '3.12' java-version: 20 os: ubuntu-latest diff --git a/.github/workflows/data/webdav/matrix.yml b/.github/workflows/data/webdav/matrix.yml index 227e7a330..87b492c31 100644 --- a/.github/workflows/data/webdav/matrix.yml +++ b/.github/workflows/data/webdav/matrix.yml @@ -3,7 +3,7 @@ min: &min os: ubuntu-latest max: &max - python-version: '3.11' + python-version: '3.12' os: ubuntu-latest matrix: diff --git a/.github/workflows/dev-release.yml b/.github/workflows/dev-release.yml index 76f362733..8dd3dc5d1 100644 --- a/.github/workflows/dev-release.yml +++ b/.github/workflows/dev-release.yml @@ -9,7 +9,7 @@ on: workflow_dispatch: env: - DEFAULT_PYTHON: '3.11' + DEFAULT_PYTHON: '3.12' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -42,6 +42,10 @@ jobs: - name: Upgrade pip run: python -m pip install --upgrade pip setuptools wheel + - name: Fix logo in Readme + run: | + sed -i "s#image:: docs/#image:: https://raw.githubusercontent.com/MobileTeleSystems/onetl/$GITHUB_SHA/docs/#g" README.rst + - name: Build package run: python setup.py sdist bdist_wheel diff --git a/.github/workflows/get-matrix.yml b/.github/workflows/get-matrix.yml index b9d160b42..466ca483a 100644 --- a/.github/workflows/get-matrix.yml +++ b/.github/workflows/get-matrix.yml @@ -47,7 +47,7 @@ on: value: ${{ jobs.get-matrix.outputs.matrix-webdav }} env: - DEFAULT_PYTHON: '3.11' + DEFAULT_PYTHON: '3.12' jobs: get-matrix: @@ -375,7 +375,7 @@ jobs: files_from_source_file: .github/workflows/data/mysql/tracked.txt files_ignore_from_source_file: .github/workflows/data/mysql/ignored.txt - - name: Print MSSQL MySQL changed + - name: Print MSSQL files changed run: | echo '${{ steps.changed-mysql.outputs.all_changed_files }}' @@ -405,7 +405,7 @@ jobs: files_from_source_file: .github/workflows/data/oracle/tracked.txt files_ignore_from_source_file: .github/workflows/data/oracle/ignored.txt - - name: Print Oracle MySQL changed + - name: Print Oracle files changed run: | echo '${{ steps.changed-oracle.outputs.all_changed_files }}' diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 7608ebe6e..be5958b1b 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -10,7 +10,7 @@ concurrency: cancel-in-progress: true env: - DEFAULT_PYTHON: '3.11' + DEFAULT_PYTHON: '3.12' jobs: get-matrix: @@ -183,6 +183,7 @@ jobs: uses: ./.github/workflows/test-oracle.yml with: + oracle-image: ${{ matrix.oracle-image }} oracle-version: ${{ matrix.oracle-version }} db-name: ${{ matrix.db-name }} spark-version: ${{ matrix.spark-version }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e0d227ea9..84cf5817c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,7 @@ on: - '[0-9]+.[0-9]+.[0-9]+' env: - DEFAULT_PYTHON: '3.11' + DEFAULT_PYTHON: '3.12' jobs: release: @@ -46,8 +46,11 @@ jobs: run: python -m pip install --upgrade pip setuptools wheel - name: Install dependencies + run: pip install -I -r requirements/core.txt -r 
requirements/docs.txt + + - name: Fix logo in Readme run: | - pip install -I -r requirements/core.txt -r requirements/docs.txt + sed -i "s#image:: docs/#image:: https://raw.githubusercontent.com/MobileTeleSystems/onetl/$GITHUB_REF_NAME/docs/#g" README.rst - name: Build package run: python setup.py sdist bdist_wheel @@ -57,7 +60,7 @@ jobs: - name: Get changelog run: | - cat docs/changelog/${{ github.ref_name }}.rst > changelog.rst + cat docs/changelog/$GITHUB_REF_NAME.rst > changelog.rst - name: Fix Github links run: | diff --git a/.github/workflows/test-clickhouse.yml b/.github/workflows/test-clickhouse.yml index 42170ed2e..903e50665 100644 --- a/.github/workflows/test-clickhouse.yml +++ b/.github/workflows/test-clickhouse.yml @@ -92,7 +92,7 @@ jobs: ./pytest_runner.sh -m clickhouse - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: clickhouse-${{ inputs.clickhouse-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-core.yml b/.github/workflows/test-core.yml index 0ed807efa..f3f24f7a8 100644 --- a/.github/workflows/test-core.yml +++ b/.github/workflows/test-core.yml @@ -72,7 +72,7 @@ jobs: ./run_tests.sh onetl/_util - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: core-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-ftp.yml b/.github/workflows/test-ftp.yml index 4ed90ec89..648a5ca36 100644 --- a/.github/workflows/test-ftp.yml +++ b/.github/workflows/test-ftp.yml @@ -77,7 +77,7 @@ jobs: COMPOSE_PROJECT_NAME: ${{ github.run_id }}-ftp${{ inputs.ftp-version }} - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ftp-${{ inputs.ftp-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-ftps.yml b/.github/workflows/test-ftps.yml index 741bd7a0e..4725420a0 100644 --- a/.github/workflows/test-ftps.yml +++ b/.github/workflows/test-ftps.yml @@ -77,7 +77,7 @@ jobs: COMPOSE_PROJECT_NAME: ${{ github.run_id }}-ftps${{ inputs.ftps-version }} - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ftps-${{ inputs.ftps-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-greenplum.yml b/.github/workflows/test-greenplum.yml index 67026aab0..26bac4a68 100644 --- a/.github/workflows/test-greenplum.yml +++ b/.github/workflows/test-greenplum.yml @@ -78,7 +78,8 @@ jobs: - name: Set up Postgres client if: runner.os == 'Linux' run: | - sudo apt-get update && sudo apt-get install --no-install-recommends postgresql-client + sudo apt-get update + sudo apt-get install --no-install-recommends postgresql-client - name: Upgrade pip run: python -m pip install --upgrade pip setuptools wheel @@ -110,7 +111,7 @@ jobs: GREENPLUM_PACKAGES_PASSWORD: ${{ secrets.GREENPLUM_PACKAGES_PASSWORD }} - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: greenplum-${{ inputs.greenplum-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-hdfs.yml b/.github/workflows/test-hdfs.yml index e48f4dd0a..b0c53c475 100644 --- 
a/.github/workflows/test-hdfs.yml +++ b/.github/workflows/test-hdfs.yml @@ -45,7 +45,8 @@ jobs: - name: Set up Kerberos libs if: runner.os == 'Linux' run: | - sudo apt-get update && sudo apt-get install --no-install-recommends libkrb5-dev gcc + sudo apt-get update + sudo apt-get install --no-install-recommends libkrb5-dev gcc - name: Cache Ivy uses: actions/cache@v3 @@ -107,7 +108,7 @@ jobs: COMPOSE_PROJECT_NAME: ${{ github.run_id }}-hadoop${{ inputs.hadoop-version }} - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: hdfs-${{ inputs.hadoop-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-hive.yml b/.github/workflows/test-hive.yml index cce56d675..16e9ef3f4 100644 --- a/.github/workflows/test-hive.yml +++ b/.github/workflows/test-hive.yml @@ -74,7 +74,7 @@ jobs: ./pytest_runner.sh -m hive - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: hive-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-kafka.yml b/.github/workflows/test-kafka.yml index b6641557b..3b8894bcc 100644 --- a/.github/workflows/test-kafka.yml +++ b/.github/workflows/test-kafka.yml @@ -126,7 +126,7 @@ jobs: ./pytest_runner.sh -m kafka - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: kafka-${{ inputs.kafka-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-local-fs.yml b/.github/workflows/test-local-fs.yml index f23beeb49..a15c04337 100644 --- a/.github/workflows/test-local-fs.yml +++ b/.github/workflows/test-local-fs.yml @@ -74,7 +74,7 @@ jobs: ./pytest_runner.sh -m local_fs - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: local-fs-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-mongodb.yml b/.github/workflows/test-mongodb.yml index 38086671e..2229c9a20 100644 --- a/.github/workflows/test-mongodb.yml +++ b/.github/workflows/test-mongodb.yml @@ -90,7 +90,7 @@ jobs: ./pytest_runner.sh -m mongodb - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: mongodb-${{ inputs.mongodb-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-mssql.yml b/.github/workflows/test-mssql.yml index d3b2a21a8..59a88978a 100644 --- a/.github/workflows/test-mssql.yml +++ b/.github/workflows/test-mssql.yml @@ -93,7 +93,7 @@ jobs: ./pytest_runner.sh -m mssql - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: mssql-${{ inputs.mssql-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-mysql.yml b/.github/workflows/test-mysql.yml index 16745e904..a03bd6304 100644 --- a/.github/workflows/test-mysql.yml +++ b/.github/workflows/test-mysql.yml @@ -92,7 +92,7 @@ jobs: ./pytest_runner.sh -m mysql - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: 
mysql-${{ inputs.mysql-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-oracle.yml b/.github/workflows/test-oracle.yml index dcd51b42a..0fbcbc73b 100644 --- a/.github/workflows/test-oracle.yml +++ b/.github/workflows/test-oracle.yml @@ -2,13 +2,16 @@ name: Tests for Oracle on: workflow_call: inputs: + oracle-image: + required: true + type: string oracle-version: required: true type: string db-name: required: false type: string - default: XEPDB1 + default: FREEPDB1 spark-version: required: true type: string @@ -32,7 +35,7 @@ jobs: runs-on: ${{ inputs.os }} services: oracle: - image: gvenzl/oracle-xe:${{ inputs.oracle-version }} + image: "${{ inputs.oracle-image }}:${{ inputs.oracle-version }}" env: TZ: UTC ORACLE_PASSWORD: maaxohmiGe9eep5x @@ -109,7 +112,7 @@ jobs: ./pytest_runner.sh -m oracle - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: oracle-${{ inputs.oracle-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-postgres.yml b/.github/workflows/test-postgres.yml index 30c91dfca..601a57a02 100644 --- a/.github/workflows/test-postgres.yml +++ b/.github/workflows/test-postgres.yml @@ -91,7 +91,7 @@ jobs: ./pytest_runner.sh -m postgres - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: postgres-${{ inputs.postgres-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-s3.yml b/.github/workflows/test-s3.yml index 99a51269e..365331e37 100644 --- a/.github/workflows/test-s3.yml +++ b/.github/workflows/test-s3.yml @@ -92,7 +92,7 @@ jobs: ./pytest_runner.sh -m s3 - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: s3-${{ inputs.minio-version }}-spark-${{ inputs.spark-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-samba.yml b/.github/workflows/test-samba.yml index d823a9ae7..d82005d89 100644 --- a/.github/workflows/test-samba.yml +++ b/.github/workflows/test-samba.yml @@ -75,7 +75,7 @@ jobs: COMPOSE_PROJECT_NAME: ${{ github.run_id }}-samba${{ inputs.server-version }} - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: samba-${{ inputs.server-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-sftp.yml b/.github/workflows/test-sftp.yml index bd630710b..5ea61b2e6 100644 --- a/.github/workflows/test-sftp.yml +++ b/.github/workflows/test-sftp.yml @@ -69,7 +69,7 @@ jobs: ./pytest_runner.sh -m sftp - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: sftp-${{ inputs.openssh-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-teradata.yml b/.github/workflows/test-teradata.yml index 20ef294b7..b865cfff6 100644 --- a/.github/workflows/test-teradata.yml +++ b/.github/workflows/test-teradata.yml @@ -73,7 +73,7 @@ jobs: ./pytest_runner.sh -m teradata - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: teradata-spark-${{ inputs.spark-version 
}}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/test-webdav.yml b/.github/workflows/test-webdav.yml index fda365489..bece2465f 100644 --- a/.github/workflows/test-webdav.yml +++ b/.github/workflows/test-webdav.yml @@ -77,7 +77,7 @@ jobs: COMPOSE_PROJECT_NAME: ${{ github.run_id }}-webdav${{ inputs.webdav-version }} - name: Upload coverage results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: webdav-${{ inputs.webdav-version }}-python-${{ inputs.python-version }}-os-${{ inputs.os }} path: reports/* diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1df7f5306..670f981e0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ concurrency: cancel-in-progress: true env: - DEFAULT_PYTHON: '3.11' + DEFAULT_PYTHON: '3.12' jobs: get-matrix: @@ -175,6 +175,7 @@ jobs: uses: ./.github/workflows/test-oracle.yml with: + oracle-image: ${{ matrix.oracle-image }} oracle-version: ${{ matrix.oracle-version }} db-name: ${{ matrix.db-name }} spark-version: ${{ matrix.spark-version }} @@ -361,7 +362,7 @@ jobs: run: pip install -I coverage pytest - name: Download all coverage reports - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: path: reports diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f0addf08..1f1f1fc6d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: - id: codespell args: [-w] - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.10.0 + rev: v2.11.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2', --preserve-quotes] @@ -46,7 +46,7 @@ repos: hooks: - id: docker-compose-check - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.0 hooks: - id: isort - repo: https://github.com/pre-commit/pygrep-hooks @@ -64,7 +64,7 @@ repos: - id: pyupgrade args: [--py37-plus, --keep-runtime-typing] - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.11.0 hooks: - id: black language_version: python3 diff --git a/.readthedocs.yml b/.readthedocs.yml index 923741b22..13358b8b3 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.11" + python: "3.12" python: install: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 4dc2d824f..8c21a15f6 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -70,7 +70,7 @@ Create virtualenv and install dependencies: -r requirements/tests/mysql.txt \ -r requirements/tests/postgres.txt \ -r requirements/tests/oracle.txt \ - -r requirements/tests/spark-3.4.1.txt + -r requirements/tests/spark-3.5.0.txt Enable pre-commit hooks ~~~~~~~~~~~~~~~~~~~~~~~ @@ -324,35 +324,7 @@ Examples for adding changelog entries to your Pull Requests .. tip:: - See `pyproject.toml <../../pyproject.toml>`_ for all available categories - (``tool.towncrier.type``). - -.. _Towncrier philosophy: - https://towncrier.readthedocs.io/en/stable/#philosophy - - -Examples for adding changelog entries to your Pull Requests -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: rst - :caption: docs/changelog/next_release/1234.doc.1.rst - - Added a ``:github:user:`` role to Sphinx config -- by :github:user:`someuser` - -.. code-block:: rst - :caption: docs/changelog/next_release/2345.bugfix.rst - - Fixed behavior of ``WebDAV`` connector -- by :github:user:`someuser` - -.. 
code-block:: rst - :caption: docs/changelog/next_release/3456.feature.rst - - Added support of ``timeout`` in ``S3`` connector - -- by :github:user:`someuser`, :github:user:`anotheruser` and :github:user:`otheruser` - -.. tip:: - - See `pyproject.toml <../../pyproject.toml>`_ for all available categories + See `pyproject.toml `_ for all available categories (``tool.towncrier.type``). .. _Towncrier philosophy: diff --git a/README.rst b/README.rst index 792eb1dc4..01ff7a7f5 100644 --- a/README.rst +++ b/README.rst @@ -23,7 +23,7 @@ onETL |Logo| -.. |Logo| image:: docs/static/logo_wide.svg +.. |Logo| image:: docs/_static/logo_wide.svg :alt: onETL logo :target: https://github.com/MobileTeleSystems/onetl @@ -44,14 +44,14 @@ Goals Non-goals --------- -* onETL is not a Spark replacement. It just provides additional functionality that Spark does not have, and simplifies UX for end users. +* onETL is not a Spark replacement. It just provides additional functionality that Spark does not have, and improves UX for end users. * onETL is not a framework, as it does not have requirements to project structure, naming, the way of running ETL/ELT processes, configuration, etc. All of that should be implemented in some other tool. * onETL is deliberately developed without any integration with scheduling software like Apache Airflow. All integrations should be implemented as separated tools. * Only batch operations, no streaming. For streaming prefer `Apache Flink `_. Requirements ------------ -* **Python 3.7 - 3.11** +* **Python 3.7 - 3.12** * PySpark 2.3.x - 3.5.x (depends on used connector) * Java 8+ (required by Spark, see below) * Kerberos libs & GCC (required by ``Hive``, ``HDFS`` and ``SparkHDFS`` connectors) @@ -177,9 +177,9 @@ Compatibility matrix +--------------------------------------------------------------+-------------+-------------+-------+ | `3.3.x `_ | 3.7 - 3.10 | 8u201 - 17 | 2.12 | +--------------------------------------------------------------+-------------+-------------+-------+ -| `3.4.x `_ | 3.7 - 3.11 | 8u362 - 20 | 2.12 | +| `3.4.x `_ | 3.7 - 3.12 | 8u362 - 20 | 2.12 | +--------------------------------------------------------------+-------------+-------------+-------+ -| `3.5.x `_ | 3.8 - 3.11 | 8u371 - 20 | 2.12 | +| `3.5.x `_ | 3.8 - 3.12 | 8u371 - 20 | 2.12 | +--------------------------------------------------------------+-------------+-------------+-------+ .. _pyspark-install: @@ -328,7 +328,7 @@ Read data from MSSQL, transform & write to Hive. # >>> INFO:|MSSQL| Connection is available - # Initialize DB reader + # Initialize DBReader reader = DBReader( connection=mssql, source="dbo.demo_table", @@ -365,7 +365,7 @@ Read data from MSSQL, transform & write to Hive. # Initialize Hive connection hive = Hive(cluster="rnd-dwh", spark=spark) - # Initialize DB writer + # Initialize DBWriter db_writer = DBWriter( connection=hive, target="dl_sb.demo_table", @@ -626,7 +626,7 @@ Read files directly from S3 path, convert them to dataframe, transform it and th spark=spark, ) - # Initialize DB writer + # Initialize DBWriter db_writer = DBWriter( connection=postgres, # write to specific table diff --git a/docker-compose.yml b/docker-compose.yml index bdcfe3954..724e3d084 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,7 +9,7 @@ services: context: . 
target: base args: - SPARK_VERSION: 3.4.1 + SPARK_VERSION: 3.5.0 env_file: .env.docker volumes: - ./:/app/ @@ -124,13 +124,18 @@ services: - onetl oracle: - image: ${ORACLE_IMAGE:-gvenzl/oracle-xe:21.3.0-slim-faststart} + image: ${ORACLE_IMAGE:-gvenzl/oracle-free:23.3-slim-faststart} restart: unless-stopped env_file: .env.dependencies ports: - 1522:1521 networks: - onetl + healthcheck: + test: ["CMD", "healthcheck.sh"] + interval: 10s + timeout: 5s + retries: 10 ftp: image: ${FTP_IMAGE:-chonjay21/ftps:latest} diff --git a/docker/Dockerfile b/docker/Dockerfile index 817d4eab2..0faa2591e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim as base +FROM python:3.12-slim as base LABEL maintainer="DataOps.ETL" ARG ONETL_USER_HOME=/usr/local/onetl @@ -45,7 +45,7 @@ COPY --chown=onetl:onetl ./run_tests.sh ./pytest_runner.sh ./combine_coverage.sh COPY --chown=onetl:onetl ./docker/wait-for-it.sh /app/docker/wait-for-it.sh RUN chmod +x /app/run_tests.sh /app/pytest_runner.sh /app/combine_coverage.sh /app/docker/wait-for-it.sh -ARG SPARK_VERSION=3.4.1 +ARG SPARK_VERSION=3.5.0 # Spark is heavy, and version change is quite rare COPY --chown=onetl:onetl ./requirements/tests/spark-${SPARK_VERSION}.txt /app/requirements/tests/ RUN pip install -r /app/requirements/tests/spark-${SPARK_VERSION}.txt diff --git a/docs/_static/icon.svg b/docs/_static/icon.svg new file mode 100644 index 000000000..a4d737f81 --- /dev/null +++ b/docs/_static/icon.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/docs/_static/logo.svg b/docs/_static/logo.svg new file mode 100644 index 000000000..76527ebf1 --- /dev/null +++ b/docs/_static/logo.svg @@ -0,0 +1,214 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/_static/logo_wide.svg b/docs/_static/logo_wide.svg new file mode 100644 index 000000000..981bf0148 --- /dev/null +++ b/docs/_static/logo_wide.svg @@ -0,0 +1,329 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/changelog/0.10.0.rst b/docs/changelog/0.10.0.rst new file mode 100644 index 000000000..79cc1d2d5 --- /dev/null +++ b/docs/changelog/0.10.0.rst @@ -0,0 +1,386 @@ +0.10.0 (2023-12-18) +=================== + +Breaking Changes +---------------- + +- Upgrade ``etl-entities`` from v1 to v2 (:github:pull:`172`). + + This implies that ``HWM`` classes are now have different internal structure than they used to. + + Before: + + .. 
code-block:: python + + from etl_entities.old_hwm import IntHWM as OldIntHWM + from etl_entities.source import Column, Table + from etl_entities.process import Process + + hwm = OldIntHWM( + process=Process(name="myprocess", task="abc", dag="cde", host="myhost"), + source=Table(name="schema.table", instance="postgres://host:5432/db"), + column=Column(name="col1"), + value=123, + ) + + After: + + .. code-block:: python + + from etl_entities.hwm import ColumnIntHWM + + hwm = ColumnIntHWM( + name="some_unique_name", + description="any value you want", + source="schema.table", + expression="col1", + value=123, + ) + + **Breaking change:** If you used HWM classes from ``etl_entities`` module, you should rewrite your code to make it compatible with new version. + + .. dropdown:: More details + + - ``HWM`` classes used by previous onETL versions were moved from ``etl_entities`` to ``etl_entities.old_hwm`` submodule. They are here for compatibility reasons, but are planned to be removed in ``etl-entities`` v3 release. + - New ``HWM`` classes have flat structure instead of nested. + - New ``HWM`` classes have mandatory ``name`` attribute (it was known as ``qualified_name`` before). + - Type aliases used while serializing and deserializing ``HWM`` objects to ``dict`` representation were changed too: ``int`` -> ``column_int``. + - HWM Store implementations now can handle only new ``HWM`` classes, old ones are **NOT** supported. + + To make migration simpler, you can use new method: + + .. code-block:: python + + old_hwm = OldIntHWM(...) + new_hwm = old_hwm.as_new_hwm() + + Which automatically converts all fields from old structure to new one, including ``qualified_name`` -> ``name``. + +- **Breaking changes:** + + * Methods ``BaseHWMStore.get()`` and ``BaseHWMStore.save()`` were renamed to ``get_hwm()`` and ``set_hwm()``. + * They now can be used only with new HWM classes from ``etl_entities.hwm``, **old HWM classes are not supported**. + + If you used them in your code, please update it accordingly. + +- YAMLHWMStore **CANNOT read files created by older onETL versions** (0.9.x or older). + + If you use it, please: + + * Find ``.yml`` file for specific HWM. Path can be found in logs, it is usually in form ``/home/USERNAME/.local/share/onETL/yml_hwm_store/QUALIFIED_NAME.yml`` (on Linux). + * Take latest ``value`` from file content. + * Delete the file. + * Update ``DBReader(where=...)`` value to include filter like ``hwm_column >= old_value`` (it should match column type). + * Run your code. ``DBReader.run()`` will get new HWM value, and save it to ``.yml`` file with new structure. + * Undo changes of ``DBReader(where=...)``. + + But most of users use other HWM store implementations which do not have such issues. + +- Several classes and functions were moved from ``onetl`` to ``etl_entities``: + + .. list-table:: + :header-rows: 1 + :widths: 30 30 + + * - onETL ``0.9.x`` and older + - onETL ``0.10.x`` and newer + + * - + .. code-block:: python + + from onetl.hwm.store import ( + detect_hwm_store, + BaseHWMStore, + HWMStoreClassRegistry, + register_hwm_store_class, + HWMStoreManager, + MemoryHWMStore, + ) + + - + .. code-block:: python + + from etl_entities.hwm_store import ( + detect_hwm_store, + BaseHWMStore, + HWMStoreClassRegistry, + register_hwm_store_class, + HWMStoreManager, + MemoryHWMStore, + ) + + They still can be imported from old module, but this is deprecated and will be removed in v1.0.0 release. 
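A minimal end-to-end sketch of the relocated imports and renamed store methods described above, assuming ``set_hwm()`` accepts the HWM object and ``get_hwm()`` accepts its ``name`` (class and module paths are the ones shown in this changelog; the exact method signatures are an assumption, not taken from this diff):

.. code-block:: python

    # Sketch only: assumes set_hwm(hwm) / get_hwm(name) signatures.
    from etl_entities.hwm import ColumnIntHWM
    from etl_entities.hwm_store import MemoryHWMStore

    hwm = ColumnIntHWM(
        name="some_unique_name",  # mandatory in the new HWM classes
        source="schema.table",
        expression="col1",
        value=123,
    )

    store = MemoryHWMStore()
    store.set_hwm(hwm)  # was BaseHWMStore.save() in onETL 0.9.x
    restored = store.get_hwm("some_unique_name")  # was BaseHWMStore.get()
    assert restored.value == 123

``MemoryHWMStore`` is ephemeral, so a snippet like this is mostly useful in tests.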
+ +- Change the way of passing ``HWM`` to ``DBReader`` and ``FileDownloader`` classes: + + .. list-table:: + :header-rows: 1 + :widths: 30 30 + + * - onETL ``0.9.x`` and older + - onETL ``0.10.x`` and newer + + * - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + hwm_column="col1", + ) + + - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + hwm=DBReader.AutoDetectHWM( + # name is mandatory now! + name="my_unique_hwm_name", + expression="col1", + ), + ) + + * - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + hwm_column=( + "col1", + "cast(col1 as date)", + ), + ) + + - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + hwm=DBReader.AutoDetectHWM( + # name is mandatory now! + name="my_unique_hwm_name", + expression="cast(col1 as date)", + ), + ) + + * - + .. code-block:: python + + downloader = FileDownloader( + connection=..., + source_path=..., + target_path=..., + hwm_type="file_list", + ) + + - + .. code-block:: python + + downloader = FileDownloader( + connection=..., + source_path=..., + target_path=..., + hwm=FileListHWM( + # name is mandatory now! + name="another_unique_hwm_name", + ), + ) + + New HWM classes have a **mandatory** ``name`` attribute which should be passed explicitly, + instead of generating it automatically under the hood. + + Automatic ``name`` generation using the old ``DBReader.hwm_column`` / ``FileDownloader.hwm_type`` + syntax is still supported, but will be removed in v1.0.0 release. (:github:pull:`179`) + +- Performance of read Incremental and Batch strategies has been drastically improved. (:github:pull:`182`). + + .. dropdown:: Before and after in details + + ``DBReader.run()`` + incremental/batch strategy behavior in versions 0.9.x and older: + + - Get table schema by making query ``SELECT * FROM table WHERE 1=0`` (if ``DBReader.columns`` has ``*``) + - Expand ``*`` to real column names from table, add here ``hwm_column``, remove duplicates (as some RDBMS does not allow that). + - Create dataframe from query like ``SELECT hwm_expression AS hwm_column, ...other table columns... FROM table WHERE hwm_expression > prev_hwm.value``. + - Determine HWM class using dataframe schema: ``df.schema[hwm_column].dataType``. + - Determine max HWM column value using Spark: ``df.select(max(hwm_column)).collect()``. + - Use ``max(hwm_column)`` as next HWM value, and save it to HWM Store. + - Return dataframe to user. + + This was far from ideal: + + - Dataframe content (all rows or just changed ones) was loaded from the source to Spark only to get min/max values of a specific column. + + - Step of fetching table schema and then substituting column names in the next query caused some unexpected errors. + + For example, source contains columns with mixed name case, like ``"CamelColumn"`` or ``"spaced column"``. + + Column names were *not* escaped during query generation, leading to queries that cannot be executed by the database. + + So users have to *explicitly* pass column names to ``DBReader``, wrapping columns with mixed naming with ``"``: + + .. code:: python + + reader = DBReader( + connection=..., + source=..., + columns=[ # passing '*' here leads to wrong SQL query generation + "normal_column", + '"CamelColumn"', + '"spaced column"', + ..., + ], + ) + + - Using ``DBReader`` with ``IncrementalStrategy`` could lead to reading rows already read before.
+ + Dataframe was created from query with WHERE clause like ``hwm.expression > prev_hwm.value``, + not ``hwm.expression > prev_hwm.value AND hwm.expression <= current_hwm.value``. + + So if new rows appeared in the source **after** HWM value is determined, + they can be read by accessing dataframe content (because Spark dataframes are lazy), + leading to inconsistencies between HWM value and dataframe content. + + This may lead to issues when ``DBReader.run()`` reads some data, updates HWM value, and the next call of ``DBReader.run()`` + will read rows that were already read in the previous run. + + ``DBReader.run()`` + incremental/batch strategy behavior in versions 0.10.x and newer: + + - Detect type of HWM expression: ``SELECT hwm.expression FROM table WHERE 1=0``. + - Determine corresponding Spark type ``df.schema[0]`` and then determine matching HWM class (if ``DBReader.AutoDetectHWM`` is used). + - Get min/max values by querying the source: ``SELECT MAX(hwm.expression) FROM table WHERE hwm.expression >= prev_hwm.value``. + - Use ``max(hwm.expression)`` as next HWM value, and save it to HWM Store. + - Create dataframe from query ``SELECT ... table columns ... FROM table WHERE hwm.expression > prev_hwm.value AND hwm.expression <= current_hwm.value``, baking new HWM value into the query. + - Return dataframe to user. + + Improvements: + + - Allow source to calculate min/max instead of loading everything to Spark. This should be **faster** on large amounts of data (**up to x2**), because we do not transfer all the data from the source to Spark. This can be even faster if the source has indexes for HWM column. + - Columns list is passed to source as-is, without any resolving on ``DBReader`` side. So you can pass ``DBReader(columns=["*"])`` to read tables with mixed columns naming. + - Restrict dataframe content to always match HWM values, which leads to never reading the same row twice. + + **Breaking change**: HWM column is no longer implicitly added to the dataframe. It was a part of ``SELECT`` clause, but now it is mentioned only in ``WHERE`` clause. + + So if you had code like this, you have to rewrite it: + + .. list-table:: + :header-rows: 1 + :widths: 20 20 + + * - onETL ``0.9.x`` and older + - onETL ``0.10.x`` and newer + + * - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + columns=[ + "col1", + "col2", + ], + hwm_column="hwm_col", + ) + + df = reader.run() + # hwm_column value is in the dataframe + assert df.columns == ["col1", "col2", "hwm_col"] + + - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + columns=[ + "col1", + "col2", + # add hwm_column explicitly + "hwm_col", + ], + hwm_column="hwm_col", + ) + + df = reader.run() + # if columns list is not updated, + # this will fail + assert df.columns == ["col1", "col2", "hwm_col"] + + * - + .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + columns=[ + "col1", + "col2", + ], + hwm_column=( + "hwm_col", + "cast(hwm_col as int)", + ), + ) + + df = reader.run() + # hwm_expression value is in the dataframe + assert df.columns == ["col1", "col2", "hwm_col"] - .. code-block:: python + + reader = DBReader( + connection=..., + source=..., + columns=[ + "col1", + "col2", + # add hwm_expression explicitly + "cast(hwm_col as int) as hwm_col", + ], + hwm_column=( + "hwm_col", + "cast(hwm_col as int)", + ), + ) + + df = reader.run() + # if columns list is not updated, + # this will fail + assert df.columns == ["col1", "col2", "hwm_col"] + + But most users just use ``columns=["*"]`` anyway, so they won't see any changes. + +- ``FileDownloader.run()`` now updates HWM in HWM Store not after each file is successfully downloaded, + but after all files were handled. + + This is because: + + * FileDownloader can be used with ``DownloadOptions(workers=N)``, which could lead to a race condition - one thread can save one HWM value to HWM store while another thread saves a different value. + * FileDownloader can download hundreds and thousands of files, and issuing a request to HWM Store for each file could potentially DDoS HWM Store. (:github:pull:`189`) + + There is an exception handler which tries to save HWM to HWM store if the download process was interrupted. But if it was interrupted by force, like sending a ``SIGKILL`` event, + HWM will not be saved to HWM store, so some already downloaded files may be downloaded again next time. + + But an unexpected process kill may produce other negative impacts, like some files being downloaded partially, so this is an expected behavior. + + +Features +-------- + +- Add Python 3.12 compatibility. (:github:pull:`167`) +- ``Excel`` file format can now be used with Spark 3.5.0. (:github:pull:`187`) +- ``SnapshotBatchStrategy`` and ``IncrementalBatchStrategy`` do not raise exceptions if the source does not contain any data. + Instead they stop at the first iteration and return an empty dataframe. (:github:pull:`188`) +- Cache result of ``connection.check()`` in high-level classes like ``DBReader``, ``FileDownloader`` and so on. This makes logs less verbose. (:github:pull:`190`) + +Bug Fixes +--------- + +- Fix ``@slot`` and ``@hook`` decorators returning methods with missing arguments in signature (Pylance, VS Code). (:github:pull:`183`) +- Kafka connector documentation said that it does support reading topic data incrementally by passing ``group.id`` or ``groupIdPrefix``. + Actually, this is not true, because Spark does not send information to Kafka about which messages were consumed. + So currently users can only read the whole topic, no incremental reads are supported. diff --git a/docs/changelog/index.rst b/docs/changelog/index.rst index 276dd3cf6..29163e700 100644 --- a/docs/changelog/index.rst +++ b/docs/changelog/index.rst @@ -4,6 +4,7 @@ DRAFT NEXT_RELEASE + 0.10.0 0.9.5 0.9.4 0.9.3 diff --git a/docs/conf.py b/docs/conf.py index 06a5b08aa..291f04069 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,6 +57,8 @@ "sphinxcontrib.autodoc_pydantic", "sphinxcontrib.towncrier", # provides `towncrier-draft-entries` directive "sphinxcontrib.plantuml", + "sphinx.ext.extlinks", + "sphinx_favicon", ] numpydoc_show_class_members = True autodoc_pydantic_model_show_config = False @@ -104,8 +106,11 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ["static"] -html_logo = "./static/logo.svg" +html_static_path = ["_static"] +html_logo = "./_static/logo.svg" +favicons = [ + {"rel": "icon", "href": "icon.svg", "type": "image/svg+xml"}, +] # The master toctree document. master_doc = "index" @@ -119,6 +124,12 @@ # The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" +# Create an alias for etl-entities lib in onetl documentation +extlinks = { + "etl-entities": ("https://etl-entities.readthedocs.io/en/stable/%s", None), +} + + # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False diff --git a/docs/connection/db_connection/kafka/read.rst b/docs/connection/db_connection/kafka/read.rst index 3e6c846b1..d502c453e 100644 --- a/docs/connection/db_connection/kafka/read.rst +++ b/docs/connection/db_connection/kafka/read.rst @@ -7,8 +7,7 @@ For reading data from Kafka, use :obj:`DBReader 1010``. - -The ``id`` column here is called ``High WaterMark`` or ``HWM`` for short. -It is used by different :ref:`strategy` to implement some complex logic -of filtering source data. - - -Supported types ---------------- - -HWM column have to be one of the following types: - -1. Integer (of any length, like ``INTEGER``, ``SHORT``, ``LONG``). - -2. Decimal **without** fractional part (Oracle specific, because ``INTEGER`` type here is ``NUMERIC``) - -3. Date - -4. Datetime (a.k.a ``TIMESTAMP``) - -Other column types (like ``DOUBLE`` or ``VARCHAR``) are **not** supported. -But it is possible to use some expression as HWM, like ``CAST(column as TYPE)`` or ``func(column) as hwm``. - -See strategies and :ref:`db-reader` documentation for examples. - - -Restrictions ------------- - -- HWM column values cannot decrease over time, they can only increase -- HWM column cannot contain ``NULL`` values because they cannot be tracked properly, and thus will be skipped - - -Recommendations ---------------- - -- It is highly recommended for HWM column values to be unique, like a table primary key. - - Otherwise, new rows with the same column value as stored in HWM will be skipped, - because they will not match the WHERE clause. - -- It is recommended to add an index for HWM column. - - Filtering by non-indexed rows requires sequential scan for all rows in a table, with complexity ``O(n)``, - which can take *a lot* of time on large tables, especially if RDBMS is not distributed. - - Filtering by row index has complexity ``O(log n)`` (for B-tree), which is very effective. - - .. note:: - - Filtering performance depends on the index implementation and internal RDBMS optimization engine. - - .. note:: - - Use indexes which support ``<`` and ``>`` operations. - Hash-based indexes, which support only ``=`` and ``IN`` operations, cannot be used diff --git a/docs/hwm/file_hwm.rst b/docs/hwm/file_hwm.rst deleted file mode 100644 index f0bc98228..000000000 --- a/docs/hwm/file_hwm.rst +++ /dev/null @@ -1,53 +0,0 @@ -.. _file-hwm: - -File HWM -======== - -What is HWM? -------------- - -Sometimes it's necessary to read/download only new files from a source folder. - -For example, there is a folder with files: - -.. code:: bash - - $ hdfs dfs -ls /path - - 2MB 2023-09-09 10:13 /path/my/file123 - 4Mb 2023-09-09 10:15 /path/my/file234 - -When new file is being added to this folder: - -.. 
code:: bash - - $ hdfs dfs -ls /path - - 2MB 2023-09-09 10:13 /path/my/file123 - 4Mb 2023-09-09 10:15 /path/my/file234 - 5Mb 2023-09-09 10:20 /path/my/file345 # new one - -To download only new files, if is required to somehow track them, and then filter using the information -from a previous run. - -This technique is called ``High WaterMark`` or ``HWM`` for short. -It is used by different :ref:`strategy` to implement some complex logic -of filtering source data. - - -Supported types ---------------- - -There are a several ways to track HWM value: - - * Save the entire file list, and then select only files not present in this list - (``file_list``) - * Save max modified time of all files, and then select only files with ``modified_time`` - higher than this value - * If file name contains some incrementing value, e.g. id or datetime, - parse it and save max value of all files, then select only files with higher value - * and so on - -Currently the only HWM type implemented for files is ``file_list``. Other ones can be implemented on-demand - -See strategies and :ref:`file-downloader` documentation for examples. diff --git a/docs/hwm/index.rst b/docs/hwm/index.rst deleted file mode 100644 index 25615ddce..000000000 --- a/docs/hwm/index.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. _hwm: - -HWM -========= - -.. toctree:: - :maxdepth: 3 - :caption: HWM - - column_hwm - file_hwm diff --git a/docs/hwm_store/detect_hwm_store.rst b/docs/hwm_store/detect_hwm_store.rst deleted file mode 100644 index fc07f8bcc..000000000 --- a/docs/hwm_store/detect_hwm_store.rst +++ /dev/null @@ -1,8 +0,0 @@ -.. _detect-hwm-store: - -Detect HWM Store decorator -================================================================= - -.. currentmodule:: onetl.hwm.store.hwm_store_class_registry - -.. autodecorator:: detect_hwm_store diff --git a/docs/hwm_store/index.rst b/docs/hwm_store/index.rst index 30b46c474..884a72c1e 100644 --- a/docs/hwm_store/index.rst +++ b/docs/hwm_store/index.rst @@ -1,19 +1,13 @@ -.. _hwm-store: +.. _hwm: -HWM Store -========= +HWM +=== -.. toctree:: - :maxdepth: 3 - :caption: HWM store - - yaml_hwm_store - memory_hwm_store - register_hwm_store_class - detect_hwm_store +Since ``onetl>=0.10.0`` version, the HWM Store and HWM classes have been moved to a separate library :etl-entities:`etl-entities <>`. -:ref:`hwm` values are persisted in HWM stores. +The only class was left intact is YamlHWMStore, **which is default** in onETL: -It is also possible to register your own HWN Store using :ref:`register-hwm-store-class`. +.. toctree:: + :maxdepth: 2 -You can select store based on config values using :ref:`detect-hwm-store`. + yaml_hwm_store diff --git a/docs/hwm_store/memory_hwm_store.rst b/docs/hwm_store/memory_hwm_store.rst deleted file mode 100644 index 19ab33c82..000000000 --- a/docs/hwm_store/memory_hwm_store.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. _memory-hwm-store: - -In-memory HWM Store (ephemeral) -================================================================= - -.. currentmodule:: onetl.hwm.store.memory_hwm_store - -.. autoclass:: MemoryHWMStore - :members: get, save, clear, __enter__ diff --git a/docs/hwm_store/register_hwm_store_class.rst b/docs/hwm_store/register_hwm_store_class.rst deleted file mode 100644 index c8843aaab..000000000 --- a/docs/hwm_store/register_hwm_store_class.rst +++ /dev/null @@ -1,8 +0,0 @@ -.. _register-hwm-store-class: - -Register own HWM Store decorator -================================================================= - -.. 
currentmodule:: onetl.hwm.store.hwm_store_class_registry - -.. autodecorator:: register_hwm_store_class diff --git a/docs/hwm_store/yaml_hwm_store.rst b/docs/hwm_store/yaml_hwm_store.rst index 866c91569..b9d268fa7 100644 --- a/docs/hwm_store/yaml_hwm_store.rst +++ b/docs/hwm_store/yaml_hwm_store.rst @@ -1,9 +1,9 @@ .. _yaml-hwm-store: -YAML HWM Store (local, default) +YAML HWM Store ================================================================= .. currentmodule:: onetl.hwm.store.yaml_hwm_store .. autoclass:: YAMLHWMStore - :members: get, save, __enter__ + :members: get_hwm, save_hwm, __enter__ diff --git a/docs/index.rst b/docs/index.rst index 54ced3d06..71d3fc250 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,7 @@ .. include:: ../README.rst :end-before: |Logo| -.. image:: static/logo_wide.svg +.. image:: _static/logo_wide.svg :alt: onETL logo .. include:: ../README.rst @@ -48,10 +48,9 @@ .. toctree:: :maxdepth: 2 - :caption: HWM and incremental reads + :caption: Read strategies and HWM :hidden: - hwm/index strategy/index hwm_store/index diff --git a/docs/install/spark.rst b/docs/install/spark.rst index 861527341..6c8610680 100644 --- a/docs/install/spark.rst +++ b/docs/install/spark.rst @@ -279,6 +279,7 @@ Can be used to embed ``.jar`` files to a default Spark classpath. * Download ``package.jar`` file (it's usually something like ``some-package_1.0.0.jar``). Local file name does not matter, but it should be unique. * Move it to ``$SPARK_HOME/jars/`` folder, e.g. ``^/.local/lib/python3.7/site-packages/pyspark/jars/`` or ``/opt/spark/3.2.3/jars/``. * Create Spark session **WITHOUT** passing Package name to ``spark.jars.packages`` + .. code:: python # no need to set spark.jars.packages or any other spark.jars.* option diff --git a/docs/static/logo.svg b/docs/static/logo.svg deleted file mode 100644 index 49a8efa1e..000000000 --- a/docs/static/logo.svg +++ /dev/null @@ -1,216 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/static/logo_wide.svg b/docs/static/logo_wide.svg deleted file mode 100644 index bcf927fee..000000000 --- a/docs/static/logo_wide.svg +++ /dev/null @@ -1,331 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/strategy/incremental_batch_strategy.rst b/docs/strategy/incremental_batch_strategy.rst index 417472ccc..485f10de2 100644 --- a/docs/strategy/incremental_batch_strategy.rst +++ b/docs/strategy/incremental_batch_strategy.rst @@ -1,7 +1,7 @@ .. 
_incremental-batch-strategy: -Incremental batch strategy -================================================================= +Incremental Batch Strategy +========================== .. currentmodule:: onetl.strategy.incremental_strategy diff --git a/docs/strategy/incremental_strategy.rst b/docs/strategy/incremental_strategy.rst index fdfe3d73a..b3c80be99 100644 --- a/docs/strategy/incremental_strategy.rst +++ b/docs/strategy/incremental_strategy.rst @@ -1,7 +1,7 @@ .. _incremental-strategy: -Incremental strategy -================================================================= +Incremental Strategy +==================== .. currentmodule:: onetl.strategy.incremental_strategy diff --git a/docs/strategy/snapshot_batch_strategy.rst b/docs/strategy/snapshot_batch_strategy.rst index 3d99d110c..0386112f5 100644 --- a/docs/strategy/snapshot_batch_strategy.rst +++ b/docs/strategy/snapshot_batch_strategy.rst @@ -1,7 +1,7 @@ .. _snapshot-batch-strategy: -Snapshot batch strategy -================================================================= +Snapshot Batch Strategy +======================= .. currentmodule:: onetl.strategy.snapshot_strategy diff --git a/docs/strategy/snapshot_strategy.rst b/docs/strategy/snapshot_strategy.rst index 6b86aeed0..6c13097c1 100644 --- a/docs/strategy/snapshot_strategy.rst +++ b/docs/strategy/snapshot_strategy.rst @@ -1,7 +1,7 @@ .. _snapshot-strategy: -Snapshot strategy -================================================================= +Snapshot Strategy +================= .. currentmodule:: onetl.strategy.snapshot_strategy diff --git a/onetl/VERSION b/onetl/VERSION index b0bb87854..78bc1abd1 100644 --- a/onetl/VERSION +++ b/onetl/VERSION @@ -1 +1 @@ -0.9.5 +0.10.0 diff --git a/onetl/_internal.py b/onetl/_internal.py index 02b098b8b..1476a6414 100644 --- a/onetl/_internal.py +++ b/onetl/_internal.py @@ -22,7 +22,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Any -from etl_entities import ProcessStackManager +from etl_entities.process import ProcessStackManager from pydantic import SecretStr if TYPE_CHECKING: @@ -163,7 +163,7 @@ def generate_temp_path(root: PurePath) -> PurePath: .. 
code:: python - from etl_entities import Process + from etl_entities.process import Process from pathlib import Path @@ -180,41 +180,3 @@ def generate_temp_path(root: PurePath) -> PurePath: current_process = ProcessStackManager.get_current() current_dt = datetime.now().strftime(DATETIME_FORMAT) return root / "onetl" / current_process.host / current_process.full_name / current_dt - - -def get_sql_query( - table: str, - columns: list[str] | None = None, - where: str | None = None, - hint: str | None = None, - compact: bool = False, -) -> str: - """ - Generates a SQL query using input arguments - """ - - if compact: - indent = " " - else: - indent = os.linesep + " " * 7 - - hint = f" /*+ {hint} */" if hint else "" - - columns_str = "*" - if columns: - columns_str = indent + f",{indent}".join(column for column in columns) - - if columns_str.strip() == "*": - columns_str = indent + "*" - - where_str = "" - if where: - where_str = "WHERE" + indent + where - - return os.linesep.join( - [ - f"SELECT{hint}{columns_str}", - f"FROM{indent}{table}", - where_str, - ], - ).strip() diff --git a/onetl/base/__init__.py b/onetl/base/__init__.py index c46eb5e27..733b2d524 100644 --- a/onetl/base/__init__.py +++ b/onetl/base/__init__.py @@ -25,7 +25,7 @@ from onetl.base.base_file_limit import BaseFileLimit from onetl.base.contains_exception import ContainsException from onetl.base.contains_get_df_schema import ContainsGetDFSchemaMethod -from onetl.base.contains_get_min_max_bounds import ContainsGetMinMaxBounds +from onetl.base.contains_get_min_max_values import ContainsGetMinMaxValues from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol from onetl.base.path_stat_protocol import PathStatProtocol from onetl.base.pure_path_protocol import PurePathProtocol diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index 56a7c5ae5..88c9b11a8 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -15,16 +15,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable - -from etl_entities import Table +from typing import TYPE_CHECKING, Any from onetl.base.base_connection import BaseConnection -from onetl.hwm import Statement +from onetl.hwm import Window if TYPE_CHECKING: + from etl_entities.hwm import HWM, ColumnHWM from pyspark.sql import DataFrame - from pyspark.sql.types import StructType + from pyspark.sql.types import StructField, StructType class BaseDBDialect(ABC): @@ -32,9 +31,11 @@ class BaseDBDialect(ABC): Collection of methods used for validating input values before passing them to read_source_as_df/write_df_to_target """ - @classmethod + def __init__(self, connection: BaseDBConnection) -> None: + self.connection = connection + @abstractmethod - def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: + def validate_name(self, value: str) -> str: """Check if ``source`` or ``target`` value is valid. Raises @@ -45,9 +46,8 @@ def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: If value is invalid """ - @classmethod @abstractmethod - def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | None) -> list[str] | None: + def validate_columns(self, columns: list[str] | None) -> list[str] | None: """Check if ``columns`` value is valid. 
Raises @@ -58,26 +58,20 @@ def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | Non If value is invalid """ - @classmethod @abstractmethod - def validate_hwm_column( - cls, - connection: BaseDBConnection, - hwm_column: str | None, - ) -> str | None: - """Check if ``hwm_column`` value is valid. + def validate_hwm(self, hwm: HWM | None) -> HWM | None: + """Check if ``HWM`` class is valid. Raises ------ TypeError - If value type is invalid + If hwm type is invalid ValueError - If value is invalid + If hwm is invalid """ - @classmethod @abstractmethod - def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None: + def validate_df_schema(self, df_schema: StructType | None) -> StructType | None: """Check if ``df_schema`` value is valid. Raises @@ -88,9 +82,8 @@ def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType If value is invalid """ - @classmethod @abstractmethod - def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None: + def validate_where(self, where: Any) -> Any | None: """Check if ``where`` value is valid. Raises @@ -101,9 +94,8 @@ def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None: If value is invalid """ - @classmethod @abstractmethod - def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None: + def validate_hint(self, hint: Any) -> Any | None: """Check if ``hint`` value is valid. Raises @@ -114,38 +106,10 @@ def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None: If value is invalid """ - @classmethod - @abstractmethod - def validate_hwm_expression(cls, connection: BaseDBConnection, value: Any) -> str | None: - """Check if ``hwm_expression`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - """ - Convert multiple WHERE conditions to one - """ - - @classmethod @abstractmethod - def _expression_with_alias(cls, expression: Any, alias: str) -> Any: + def detect_hwm_class(self, field: StructField) -> type[ColumnHWM] | None: """ - Return "expression AS alias" statement - """ - - @classmethod - @abstractmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: - """ - Return "arg1 COMPARATOR arg2" statement + Detects hwm column type based on specific data types in connections data stores """ @@ -156,6 +120,10 @@ class BaseDBConnection(BaseConnection): Dialect = BaseDBDialect + @property + def dialect(self): + return self.Dialect(self) + @property @abstractmethod def instance_url(self) -> str: @@ -171,8 +139,8 @@ def read_source_as_df( hint: Any | None = None, where: Any | None = None, df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, + window: Window | None = None, + limit: int | None = None, ) -> DataFrame: """ Reads the source to dataframe. |support_hooks| diff --git a/onetl/base/base_file_df_connection.py b/onetl/base/base_file_df_connection.py index 1f0be4d02..fef5a38cc 100644 --- a/onetl/base/base_file_df_connection.py +++ b/onetl/base/base_file_df_connection.py @@ -83,7 +83,7 @@ def check_if_format_supported( Validate if specific file format is supported. |support_hooks| Raises - ------- + ------ RuntimeError If file format is not supported. 
""" diff --git a/onetl/base/base_file_format.py b/onetl/base/base_file_format.py index 8bc0a4c4a..c1ac32dd1 100644 --- a/onetl/base/base_file_format.py +++ b/onetl/base/base_file_format.py @@ -32,7 +32,7 @@ def check_if_supported(self, spark: SparkSession) -> None: Check if Spark session does support this file format. |support_hooks| Raises - ------- + ------ RuntimeError If file format is not supported. """ @@ -64,7 +64,7 @@ def check_if_supported(self, spark: SparkSession) -> None: Check if Spark session does support this file format. |support_hooks| Raises - ------- + ------ RuntimeError If file format is not supported. """ diff --git a/onetl/base/contains_get_min_max_bounds.py b/onetl/base/contains_get_min_max_values.py similarity index 83% rename from onetl/base/contains_get_min_max_bounds.py rename to onetl/base/contains_get_min_max_values.py index 73bb68a59..9ac74b91d 100644 --- a/onetl/base/contains_get_min_max_bounds.py +++ b/onetl/base/contains_get_min_max_values.py @@ -18,18 +18,19 @@ from typing_extensions import Protocol, runtime_checkable +from onetl.hwm.window import Window + @runtime_checkable -class ContainsGetMinMaxBounds(Protocol): +class ContainsGetMinMaxValues(Protocol): """ - Protocol for objects containing ``get_min_max_bounds`` method + Protocol for objects containing ``get_min_max_values`` method """ - def get_min_max_bounds( + def get_min_max_values( self, source: str, - column: str, - expression: str | None = None, + window: Window, hint: Any | None = None, where: Any | None = None, ) -> tuple[Any, Any]: diff --git a/onetl/connection/db_connection/clickhouse/connection.py b/onetl/connection/db_connection/clickhouse/connection.py index 2ef521072..94e46c739 100644 --- a/onetl/connection/db_connection/clickhouse/connection.py +++ b/onetl/connection/db_connection/clickhouse/connection.py @@ -16,13 +16,15 @@ import logging import warnings -from typing import ClassVar, Optional +from typing import Any, ClassVar, Optional from onetl._util.classproperty import classproperty from onetl.connection.db_connection.clickhouse.dialect import ClickhouseDialect from onetl.connection.db_connection.jdbc_connection import JDBCConnection +from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions from onetl.connection.db_connection.jdbc_mixin import JDBCStatementType from onetl.hooks import slot, support_hooks +from onetl.hwm import Window from onetl.impl import GenericOptions # do not import PySpark here, as we allow user to use `Clickhouse.get_packages()` for creating Spark session @@ -173,6 +175,27 @@ def jdbc_url(self) -> str: return f"jdbc:clickhouse://{self.host}:{self.port}?{parameters}".rstrip("?") + @slot + def get_min_max_values( + self, + source: str, + window: Window, + hint: Any | None = None, + where: Any | None = None, + options: JDBCReadOptions | None = None, + ) -> tuple[Any, Any]: + min_value, max_value = super().get_min_max_values( + source=source, + window=window, + hint=hint, + where=where, + options=options, + ) + # Clickhouse for some reason return min/max=0 if there are no rows + if min_value == max_value == 0: + return None, None + return min_value, max_value + @staticmethod def _build_statement( statement: str, diff --git a/onetl/connection/db_connection/clickhouse/dialect.py b/onetl/connection/db_connection/clickhouse/dialect.py index 56fe44b33..3fa6eb9ad 100644 --- a/onetl/connection/db_connection/clickhouse/dialect.py +++ b/onetl/connection/db_connection/clickhouse/dialect.py @@ -20,20 +20,16 @@ class 
ClickhouseDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"modulo(halfMD5({partition_column}), {num_partitions})" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S") return f"CAST('{result}' AS DateTime)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.strftime("%Y-%m-%d") return f"CAST('{result}' AS Date)" - - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"modulo(halfMD5({partition_column}), {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/db_connection/dialect.py b/onetl/connection/db_connection/db_connection/dialect.py index 5c9472189..da52e817c 100644 --- a/onetl/connection/db_connection/db_connection/dialect.py +++ b/onetl/connection/db_connection/db_connection/dialect.py @@ -14,117 +14,147 @@ from __future__ import annotations -import operator +import os from datetime import date, datetime -from typing import Any, Callable, ClassVar, Dict +from typing import TYPE_CHECKING, Any from onetl.base import BaseDBDialect -from onetl.hwm import Statement +from onetl.hwm import Edge, Window +from onetl.hwm.store import SparkTypeToHWM + +if TYPE_CHECKING: + from etl_entities.hwm import ColumnHWM + from pyspark.sql.types import StructField class DBDialect(BaseDBDialect): - _compare_statements: ClassVar[Dict[Callable, str]] = { - operator.ge: "{} >= {}", - operator.gt: "{} > {}", - operator.le: "{} <= {}", - operator.lt: "{} < {}", - operator.eq: "{} == {}", - operator.ne: "{} != {}", - } - - @classmethod - def _escape_column(cls, value: str) -> str: - return f'"{value}"' + def detect_hwm_class(self, field: StructField) -> type[ColumnHWM] | None: + return SparkTypeToHWM.get(field.dataType.typeName()) # type: ignore + + def get_sql_query( + self, + table: str, + columns: list[str] | None = None, + where: str | list[str] | None = None, + hint: str | None = None, + limit: int | None = None, + compact: bool = False, + ) -> str: + """ + Generates a SQL query using input arguments + """ - @classmethod - def _expression_with_alias(cls, expression: str, alias: str) -> str: - return f"{expression} AS {alias}" + if compact: + indent = " " + else: + indent = os.linesep + " " * 7 + + hint = f" /*+ {hint} */" if hint else "" + + columns_str = indent + "*" + if columns: + columns_str = indent + f",{indent}".join(column for column in columns) + + where = where or [] + if isinstance(where, str): + where = [where] + + if limit == 0: + # LIMIT 0 not valid in some databases + where = ["1 = 0"] - @classmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: - template = cls._compare_statements[comparator] - return template.format(arg1, cls._serialize_datetime_value(arg2)) + where_clauses = [] + if len(where) == 1: + where_clauses.append("WHERE" + indent + where[0]) + else: + for i, item in enumerate(where): + directive = "WHERE" if i == 0 else " AND" + where_clauses.append(directive + indent + f"({item})") - @classmethod - def 
_merge_conditions(cls, conditions: list[Any]) -> Any: - if len(conditions) == 1: - return conditions[0] + query_parts = [ + f"SELECT{hint}{columns_str}", + f"FROM{indent}{table}", + *where_clauses, + f"LIMIT{indent}{limit}" if limit else "", + ] - return " AND ".join(f"({item})" for item in conditions) + return os.linesep.join(filter(None, query_parts)).strip() - @classmethod - def _condition_assembler( - cls, + def apply_window( + self, condition: Any, - start_from: Statement | None, - end_at: Statement | None, + window: Window | None = None, ) -> Any: - conditions = [condition] - - if start_from: - condition1 = cls._get_compare_statement( - comparator=start_from.operator, - arg1=start_from.expression, - arg2=start_from.value, - ) - conditions.append(condition1) - - if end_at: - condition2 = cls._get_compare_statement( - comparator=end_at.operator, - arg1=end_at.expression, - arg2=end_at.value, - ) - conditions.append(condition2) - - result: list[Any] = list(filter(None, conditions)) - if not result: + conditions = [ + condition, + self._edge_to_where(window.expression, window.start_from, position="start") if window else None, + self._edge_to_where(window.expression, window.stop_at, position="end") if window else None, + ] + return list(filter(None, conditions)) + + def escape_column(self, value: str) -> str: + return f'"{value}"' + + def aliased(self, expression: str, alias: str) -> str: + return f"{expression} AS {alias}" + + def get_max_value(self, value: Any) -> str: + """ + Generate `MAX(value)` clause for given value + """ + result = self._serialize_value(value) + return f"MAX({result})" + + def get_min_value(self, value: Any) -> str: + """ + Generate `MIN(value)` clause for given value + """ + result = self._serialize_value(value) + return f"MIN({result})" + + def _edge_to_where( + self, + expression: str, + edge: Edge, + position: str, + ) -> Any: + if not edge.is_set(): return None - return cls._merge_conditions(result) + operators: dict[tuple[str, bool], str] = { + ("start", True): ">=", + ("start", False): "> ", + ("end", True): "<=", + ("end", False): "< ", + } + + operator = operators[(position, edge.including)] + value = self._serialize_value(edge.value) + return f"{expression} {operator} {value}" - @classmethod - def _serialize_datetime_value(cls, value: Any) -> str | int | dict: + def _serialize_value(self, value: Any) -> str | int | dict: """ Transform the value into an SQL Dialect-supported form. 
""" if isinstance(value, datetime): - return cls._get_datetime_value_sql(value) + return self._serialize_datetime(value) if isinstance(value, date): - return cls._get_date_value_sql(value) + return self._serialize_date(value) return str(value) - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def _serialize_datetime(self, value: datetime) -> str: """ Transform the datetime value into supported by SQL Dialect """ result = value.isoformat() return repr(result) - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: """ Transform the date value into supported by SQL Dialect """ result = value.isoformat() return repr(result) - - @classmethod - def _get_max_value_sql(cls, value: Any) -> str: - """ - Generate `MAX(value)` clause for given value - """ - result = cls._serialize_datetime_value(value) - return f"MAX({result})" - - @classmethod - def _get_min_value_sql(cls, value: Any) -> str: - """ - Generate `MIN(value)` clause for given value - """ - result = cls._serialize_datetime_value(value) - return f"MIN({result})" diff --git a/onetl/connection/db_connection/dialect_mixins/__init__.py b/onetl/connection/db_connection/dialect_mixins/__init__.py index 4d889b276..43f1df40d 100644 --- a/onetl/connection/db_connection/dialect_mixins/__init__.py +++ b/onetl/connection/db_connection/dialect_mixins/__init__.py @@ -1,27 +1,24 @@ -from onetl.connection.db_connection.dialect_mixins.support_columns_list import ( - SupportColumnsList, +from onetl.connection.db_connection.dialect_mixins.not_support_columns import ( + NotSupportColumns, +) +from onetl.connection.db_connection.dialect_mixins.not_support_df_schema import ( + NotSupportDFSchema, ) -from onetl.connection.db_connection.dialect_mixins.support_columns_none import ( - SupportColumnsNone, +from onetl.connection.db_connection.dialect_mixins.not_support_hint import ( + NotSupportHint, ) -from onetl.connection.db_connection.dialect_mixins.support_df_schema_none import ( - SupportDfSchemaNone, +from onetl.connection.db_connection.dialect_mixins.not_support_where import ( + NotSupportWhere, ) -from onetl.connection.db_connection.dialect_mixins.support_df_schema_struct import ( - SupportDfSchemaStruct, +from onetl.connection.db_connection.dialect_mixins.requires_df_schema import ( + RequiresDFSchema, ) -from onetl.connection.db_connection.dialect_mixins.support_hint_none import ( - SupportHintNone, +from onetl.connection.db_connection.dialect_mixins.support_columns_list import ( + SupportColumns, ) from onetl.connection.db_connection.dialect_mixins.support_hint_str import ( SupportHintStr, ) -from onetl.connection.db_connection.dialect_mixins.support_hwm_column_str import ( - SupportHWMColumnStr, -) -from onetl.connection.db_connection.dialect_mixins.support_hwm_expression_none import ( - SupportHWMExpressionNone, -) from onetl.connection.db_connection.dialect_mixins.support_hwm_expression_str import ( SupportHWMExpressionStr, ) @@ -31,9 +28,6 @@ from onetl.connection.db_connection.dialect_mixins.support_name_with_schema_only import ( SupportNameWithSchemaOnly, ) -from onetl.connection.db_connection.dialect_mixins.support_where_none import ( - SupportWhereNone, -) from onetl.connection.db_connection.dialect_mixins.support_where_str import ( SupportWhereStr, ) diff --git a/onetl/connection/db_connection/dialect_mixins/support_columns_none.py b/onetl/connection/db_connection/dialect_mixins/not_support_columns.py similarity index 65% rename from 
onetl/connection/db_connection/dialect_mixins/support_columns_none.py rename to onetl/connection/db_connection/dialect_mixins/not_support_columns.py index 1432a6c4a..74fca09cf 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_columns_none.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_columns.py @@ -5,12 +5,12 @@ from onetl.base import BaseDBConnection -class SupportColumnsNone: - @classmethod +class NotSupportColumns: + connection: BaseDBConnection + def validate_columns( - cls, - connection: BaseDBConnection, + self, columns: Any, ) -> None: if columns is not None: - raise ValueError(f"'columns' parameter is not supported by {connection.__class__.__name__}") + raise ValueError(f"'columns' parameter is not supported by {self.connection.__class__.__name__}") diff --git a/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py b/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py new file mode 100644 index 000000000..8ae04939b --- /dev/null +++ b/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Any + +from onetl.base import BaseDBConnection + + +class NotSupportDFSchema: + connection: BaseDBConnection + + def validate_df_schema( + self, + df_schema: Any, + ) -> None: + if df_schema: + raise ValueError(f"'df_schema' parameter is not supported by {self.connection.__class__.__name__}") diff --git a/onetl/connection/db_connection/dialect_mixins/support_hint_none.py b/onetl/connection/db_connection/dialect_mixins/not_support_hint.py similarity index 66% rename from onetl/connection/db_connection/dialect_mixins/support_hint_none.py rename to onetl/connection/db_connection/dialect_mixins/not_support_hint.py index a7fcf84d2..bbc8e13a4 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hint_none.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_hint.py @@ -5,12 +5,12 @@ from onetl.base import BaseDBConnection -class SupportHintNone: - @classmethod +class NotSupportHint: + connection: BaseDBConnection + def validate_hint( - cls, - connection: BaseDBConnection, + self, hint: Any, ) -> None: if hint is not None: - raise TypeError(f"'hint' parameter is not supported by {connection.__class__.__name__}") + raise TypeError(f"'hint' parameter is not supported by {self.connection.__class__.__name__}") diff --git a/onetl/connection/db_connection/dialect_mixins/support_where_none.py b/onetl/connection/db_connection/dialect_mixins/not_support_where.py similarity index 86% rename from onetl/connection/db_connection/dialect_mixins/support_where_none.py rename to onetl/connection/db_connection/dialect_mixins/not_support_where.py index 6eb9dff93..0bd06f76f 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_where_none.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_where.py @@ -19,12 +19,12 @@ from onetl.base import BaseDBConnection -class SupportWhereNone: - @classmethod +class NotSupportWhere: + connection: BaseDBConnection + def validate_where( - cls, - connection: BaseDBConnection, + self, where: Any, ) -> None: if where is not None: - raise TypeError(f"'where' parameter is not supported by {connection.__class__.__name__}") + raise TypeError(f"'where' parameter is not supported by {self.connection.__class__.__name__}") diff --git a/onetl/connection/db_connection/dialect_mixins/support_df_schema_struct.py b/onetl/connection/db_connection/dialect_mixins/requires_df_schema.py 
similarity index 73% rename from onetl/connection/db_connection/dialect_mixins/support_df_schema_struct.py rename to onetl/connection/db_connection/dialect_mixins/requires_df_schema.py index 1c6029b20..a66bff553 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_df_schema_struct.py +++ b/onetl/connection/db_connection/dialect_mixins/requires_df_schema.py @@ -8,13 +8,13 @@ from onetl.base import BaseDBConnection -class SupportDfSchemaStruct: - @classmethod +class RequiresDFSchema: + connection: BaseDBConnection + def validate_df_schema( - cls, - connection: BaseDBConnection, + self, df_schema: StructType | None, ) -> StructType: if df_schema: return df_schema - raise ValueError(f"'df_schema' parameter is mandatory for {connection.__class__.__name__}") + raise ValueError(f"'df_schema' parameter is mandatory for {self.connection.__class__.__name__}") diff --git a/onetl/connection/db_connection/dialect_mixins/support_columns_list.py b/onetl/connection/db_connection/dialect_mixins/support_columns_list.py index f83356f71..0ab4eb24b 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_columns_list.py +++ b/onetl/connection/db_connection/dialect_mixins/support_columns_list.py @@ -2,14 +2,10 @@ from typing import Any -from onetl.base import BaseDBConnection - -class SupportColumnsList: - @classmethod +class SupportColumns: def validate_columns( - cls, - connection: BaseDBConnection, + self, columns: Any, ) -> list[str] | None: if columns is None: diff --git a/onetl/connection/db_connection/dialect_mixins/support_df_schema_none.py b/onetl/connection/db_connection/dialect_mixins/support_df_schema_none.py deleted file mode 100644 index 6657631a0..000000000 --- a/onetl/connection/db_connection/dialect_mixins/support_df_schema_none.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pyspark.sql.types import StructType - -from onetl.base import BaseDBConnection - - -class SupportDfSchemaNone: - @classmethod - def validate_df_schema( - cls, - connection: BaseDBConnection, - df_schema: StructType | None, - ) -> None: - if df_schema: - raise ValueError(f"'df_schema' parameter is not supported by {connection.__class__.__name__}") diff --git a/onetl/connection/db_connection/dialect_mixins/support_hint_str.py b/onetl/connection/db_connection/dialect_mixins/support_hint_str.py index 144ac1623..73f3b2875 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hint_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_hint_str.py @@ -6,10 +6,10 @@ class SupportHintStr: - @classmethod + connection: BaseDBConnection + def validate_hint( - cls, - connection: BaseDBConnection, + self, hint: Any, ) -> str | None: if hint is None: @@ -17,7 +17,7 @@ def validate_hint( if not isinstance(hint, str): raise TypeError( - f"{connection.__class__.__name__} requires 'hint' parameter type to be 'str', " + f"{self.connection.__class__.__name__} requires 'hint' parameter type to be 'str', " f"got {hint.__class__.__name__!r}", ) diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py deleted file mode 100644 index a27e6af76..000000000 --- a/onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from onetl.base import BaseDBConnection - - -class SupportHWMColumnStr: - @classmethod - def validate_hwm_column( - 
cls, - connection: BaseDBConnection, - hwm_column: str | None, - ) -> str | None: - if not isinstance(hwm_column, str): - raise ValueError( - f"{connection.__class__.__name__} requires 'hwm_column' parameter type to be 'str', " - f"got {type(hwm_column)}", - ) - - return hwm_column diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py deleted file mode 100644 index c1f5cc0c3..000000000 --- a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_none.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from onetl.base import BaseDBConnection - - -class SupportHWMExpressionNone: - @classmethod - def validate_hwm_expression(cls, connection: BaseDBConnection, value: Any) -> str | None: - if value is not None: - raise ValueError( - f"'hwm_expression' parameter is not supported by {connection.__class__.__name__}", - ) - return value diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py index 09892b57e..0be0fe615 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py @@ -1,20 +1,21 @@ from __future__ import annotations -from typing import Any +from etl_entities.hwm import HWM from onetl.base import BaseDBConnection class SupportHWMExpressionStr: - @classmethod - def validate_hwm_expression(cls, connection: BaseDBConnection, value: Any) -> str | None: - if value is None: - return None + connection: BaseDBConnection - if not isinstance(value, str): + def validate_hwm(self, hwm: HWM | None) -> HWM | None: + if not hwm or hwm.expression is None: + return hwm + + if not isinstance(hwm.expression, str): raise TypeError( - f"{connection.__class__.__name__} requires 'hwm_expression' parameter type to be 'str', " - f"got {value.__class__.__name__!r}", + f"{self.connection.__class__.__name__} requires 'hwm.expression' parameter type to be 'str', " + f"got {hwm.expression.__class__.__name__!r}", ) - return value + return hwm diff --git a/onetl/connection/db_connection/dialect_mixins/support_name_any.py b/onetl/connection/db_connection/dialect_mixins/support_name_any.py index 8ecb34fd6..eb87b8097 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_name_any.py +++ b/onetl/connection/db_connection/dialect_mixins/support_name_any.py @@ -1,11 +1,6 @@ from __future__ import annotations -from etl_entities import Table - -from onetl.base import BaseDBConnection - class SupportNameAny: - @classmethod - def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: + def validate_name(self, value: str) -> str: return value diff --git a/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py b/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py index eb374ca3a..4b9f4a29d 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py +++ b/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py @@ -1,14 +1,9 @@ from __future__ import annotations -from etl_entities import Table - -from onetl.base import BaseDBConnection - class SupportNameWithSchemaOnly: - @classmethod - def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: - if value.name.count(".") != 1: + def 
validate_name(self, value: str) -> str: + if value.count(".") != 1: raise ValueError( f"Name should be passed in `schema.name` format, got '{value}'", ) diff --git a/onetl/connection/db_connection/dialect_mixins/support_where_str.py b/onetl/connection/db_connection/dialect_mixins/support_where_str.py index dbdc39145..1354b5345 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_where_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_where_str.py @@ -6,10 +6,10 @@ class SupportWhereStr: - @classmethod + connection: BaseDBConnection + def validate_where( - cls, - connection: BaseDBConnection, + self, where: Any, ) -> str | None: if where is None: @@ -17,7 +17,7 @@ def validate_where( if not isinstance(where, str): raise TypeError( - f"{connection.__class__.__name__} requires 'where' parameter type to be 'str', " + f"{self.connection.__class__.__name__} requires 'where' parameter type to be 'str', " f"got {where.__class__.__name__!r}", ) diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index d1eedff7f..ad9287e7d 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -23,7 +23,6 @@ from etl_entities.instance import Host from pydantic import validator -from onetl._internal import get_sql_query from onetl._util.classproperty import classproperty from onetl._util.java import try_import_java_class from onetl._util.scala import get_default_scala_version @@ -43,7 +42,7 @@ from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions from onetl.exception import MISSING_JVM_CLASS_MSG, TooManyParallelJobsError from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement +from onetl.hwm import Window from onetl.impl import GenericOptions from onetl.log import log_lines, log_with_indent @@ -267,25 +266,29 @@ def read_source_as_df( hint: str | None = None, where: str | None = None, df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, + window: Window | None = None, + limit: int | None = None, options: GreenplumReadOptions | None = None, ) -> DataFrame: read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) - query = get_sql_query(table=source, columns=columns, where=where) - log_lines(log, query) + where = self.dialect.apply_window(where, window) + fake_query_for_log = self.dialect.get_sql_query(table=source, columns=columns, where=where, limit=limit) + log_lines(log, fake_query_for_log) df = self.spark.read.format("greenplum").options(**self._connector_params(source), **read_options).load() self._check_expected_jobs_number(df, action="read") if where: - df = df.filter(where) + for item in where: + df = df.filter(item) if columns: df = df.selectExpr(*columns) + if limit is not None: + df = df.limit(limit) + log.info("|Spark| DataFrame successfully created from SQL statement ") return df @@ -323,7 +326,7 @@ def get_df_schema( ) -> StructType: log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source) - query = get_sql_query(source, columns=columns, where="1=0", compact=True) + query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True) jdbc_options = 
self.JDBCOptions.parse(options).copy(update={"fetchsize": 0}) log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__) @@ -335,32 +338,30 @@ def get_df_schema( return df.schema @slot - def get_min_max_bounds( + def get_min_max_values( self, source: str, - column: str, - expression: str | None = None, - hint: str | None = None, - where: str | None = None, + window: Window, + hint: Any | None = None, + where: Any | None = None, options: JDBCOptions | None = None, ) -> tuple[Any, Any]: - log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) - + log.info("|%s| Getting min and max values for %r ...", self.__class__.__name__, window.expression) jdbc_options = self.JDBCOptions.parse(options).copy(update={"fetchsize": 1}) - query = get_sql_query( + query = self.dialect.get_sql_query( table=source, columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - self.Dialect._escape_column("min"), + self.dialect.aliased( + self.dialect.get_min_value(window.expression), + self.dialect.escape_column("min"), ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - self.Dialect._escape_column("max"), + self.dialect.aliased( + self.dialect.get_max_value(window.expression), + self.dialect.escape_column("max"), ), ], - where=where, + where=self.dialect.apply_window(where, window), ) log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__) @@ -372,8 +373,8 @@ def get_min_max_bounds( max_value = row["max"] log.info("|%s| Received values:", self.__class__.__name__) - log_with_indent(log, "MIN(%r) = %r", column, min_value) - log_with_indent(log, "MAX(%r) = %r", column, max_value) + log_with_indent(log, "MIN(%s) = %r", window.expression, min_value) + log_with_indent(log, "MAX(%s) = %r", window.expression, max_value) return min_value, max_value diff --git a/onetl/connection/db_connection/greenplum/dialect.py b/onetl/connection/db_connection/greenplum/dialect.py index a998811aa..723503a77 100644 --- a/onetl/connection/db_connection/greenplum/dialect.py +++ b/onetl/connection/db_connection/greenplum/dialect.py @@ -18,10 +18,9 @@ from onetl.connection.db_connection.db_connection import DBDialect from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, - SupportHintNone, - SupportHWMColumnStr, + NotSupportDFSchema, + NotSupportHint, + SupportColumns, SupportHWMExpressionStr, SupportNameWithSchemaOnly, SupportWhereStr, @@ -30,20 +29,17 @@ class GreenplumDialect( # noqa: WPS215 SupportNameWithSchemaOnly, - SupportColumnsList, - SupportDfSchemaNone, + SupportColumns, + NotSupportDFSchema, SupportWhereStr, - SupportHintNone, + NotSupportHint, SupportHWMExpressionStr, - SupportHWMColumnStr, DBDialect, ): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"cast('{result}' as timestamp)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"cast('{result}' as date)" diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 97cc034f5..ddcc3ea8f 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -16,12 +16,12 @@ import logging from textwrap import dedent -from typing import 
TYPE_CHECKING, Any, ClassVar, Iterable, Tuple +from typing import TYPE_CHECKING, Any, ClassVar, Iterable from etl_entities.instance import Cluster from pydantic import validator -from onetl._internal import clear_statement, get_sql_query +from onetl._internal import clear_statement from onetl._util.spark import inject_spark_param from onetl.connection.db_connection.db_connection import DBConnection from onetl.connection.db_connection.hive.dialect import HiveDialect @@ -32,7 +32,7 @@ ) from onetl.connection.db_connection.hive.slots import HiveSlots from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement +from onetl.hwm import Window from onetl.log import log_lines, log_with_indent if TYPE_CHECKING: @@ -364,18 +364,18 @@ def read_source_as_df( hint: str | None = None, where: str | None = None, df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, + window: Window | None = None, + limit: int | None = None, ) -> DataFrame: - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) - sql_text = get_sql_query( + query = self.dialect.get_sql_query( table=source, columns=columns, - where=where, + where=self.dialect.apply_window(where, window), hint=hint, + limit=limit, ) - return self.sql(sql_text) + return self.sql(query) @slot def get_df_schema( @@ -383,8 +383,8 @@ def get_df_schema( source: str, columns: list[str] | None = None, ) -> StructType: - log.info("|%s| Fetching schema of table table %r ...", self.__class__.__name__, source) - query = get_sql_query(source, columns=columns, where="1=0", compact=True) + log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source) + query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True) log.debug("|%s| Executing SQL query:", self.__class__.__name__) log_lines(log, query, level=logging.DEBUG) @@ -394,43 +394,42 @@ def get_df_schema( return df.schema @slot - def get_min_max_bounds( + def get_min_max_values( self, source: str, - column: str, - expression: str | None = None, - hint: str | None = None, - where: str | None = None, - ) -> Tuple[Any, Any]: - log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) + window: Window, + hint: Any | None = None, + where: Any | None = None, + ) -> tuple[Any, Any]: + log.info("|%s| Getting min and max values for expression %r ...", self.__class__.__name__, window.expression) - sql_text = get_sql_query( + query = self.dialect.get_sql_query( table=source, columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - self.Dialect._escape_column("min"), + self.dialect.aliased( - self.Dialect._get_min_value_sql(expression or column) is replaced here + self.dialect.get_min_value(window.expression), + self.dialect.escape_column("min"), ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - self.Dialect._escape_column("max"), + self.dialect.aliased( + self.dialect.get_max_value(window.expression), + self.dialect.escape_column("max"), ), ], - where=where, + where=self.dialect.apply_window(where, window), hint=hint, ) - log.debug("|%s| Executing SQL query:", self.__class__.__name__) - log_lines(log, sql_text, level=logging.DEBUG) + log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__) + log_lines(log, query) - df = self._execute_sql(sql_text) + df = self._execute_sql(query) row = df.collect()[0] min_value = row["min"] max_value = row["max"] log.info("|%s| Received values:",
self.__class__.__name__) - log_with_indent(log, "MIN(%s) = %r", column, min_value) - log_with_indent(log, "MAX(%s) = %r", column, max_value) + log_with_indent(log, "MIN(%s) = %r", window.expression, min_value) + log_with_indent(log, "MAX(%s) = %r", window.expression, max_value) return min_value, max_value diff --git a/onetl/connection/db_connection/hive/dialect.py b/onetl/connection/db_connection/hive/dialect.py index 552e66559..7acd4dbd2 100644 --- a/onetl/connection/db_connection/hive/dialect.py +++ b/onetl/connection/db_connection/hive/dialect.py @@ -16,10 +16,9 @@ from onetl.connection.db_connection.db_connection import DBDialect from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, + NotSupportDFSchema, + SupportColumns, SupportHintStr, - SupportHWMColumnStr, SupportHWMExpressionStr, SupportNameWithSchemaOnly, SupportWhereStr, @@ -28,14 +27,12 @@ class HiveDialect( # noqa: WPS215 SupportNameWithSchemaOnly, - SupportColumnsList, - SupportDfSchemaNone, + SupportColumns, + NotSupportDFSchema, SupportWhereStr, SupportHintStr, SupportHWMExpressionStr, - SupportHWMColumnStr, DBDialect, ): - @classmethod - def _escape_column(cls, value: str) -> str: + def escape_column(self, value: str) -> str: return f"`{value}`" diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index f5b611910..b689efcc1 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -20,7 +20,7 @@ from etl_entities.instance import Host -from onetl._internal import clear_statement, get_sql_query +from onetl._internal import clear_statement from onetl.connection.db_connection.db_connection import DBConnection from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect from onetl.connection.db_connection.jdbc_connection.options import ( @@ -33,7 +33,7 @@ from onetl.connection.db_connection.jdbc_mixin import JDBCMixin from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement +from onetl.hwm import Window from onetl.log import log_lines, log_with_indent if TYPE_CHECKING: @@ -156,8 +156,8 @@ def read_source_as_df( hint: str | None = None, where: str | None = None, df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, + window: Window | None = None, + limit: int | None = None, options: JDBCReadOptions | None = None, ) -> DataFrame: read_options = self._set_lower_upper_bound( @@ -172,12 +172,12 @@ def read_source_as_df( if read_options.partition_column: if read_options.partitioning_mode == JDBCPartitioningMode.MOD: - partition_column = self.Dialect._get_partition_column_mod( + partition_column = self.dialect.get_partition_column_mod( read_options.partition_column, read_options.num_partitions, ) elif read_options.partitioning_mode == JDBCPartitioningMode.HASH: - partition_column = self.Dialect._get_partition_column_hash( + partition_column = self.dialect.get_partition_column_hash( read_options.partition_column, read_options.num_partitions, ) @@ -189,17 +189,18 @@ def read_source_as_df( # have the same name as the field in the table ( 2.4 version ) # https://github.com/apache/spark/pull/21379 alias = "generated_" + secrets.token_hex(5) - alias_escaped = self.Dialect._escape_column(alias) - aliased_column = 
self.Dialect._expression_with_alias(partition_column, alias_escaped) + alias_escaped = self.dialect.escape_column(alias) + aliased_column = self.dialect.aliased(partition_column, alias_escaped) read_options = read_options.copy(update={"partition_column": alias_escaped}) new_columns.append(aliased_column) - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) - query = get_sql_query( + where = self.dialect.apply_window(where, window) + query = self.dialect.get_sql_query( table=source, columns=new_columns, where=where, hint=hint, + limit=limit, ) result = self.sql(query, read_options) @@ -236,7 +237,7 @@ def get_df_schema( ) -> StructType: log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source) - query = get_sql_query(source, columns=columns, where="1=0", compact=True) + query = self.dialect.get_sql_query(source, columns=columns, limit=0, compact=True) read_options = self._exclude_partition_options(self.ReadOptions.parse(options), fetchsize=0) log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__) @@ -280,32 +281,30 @@ def options_to_jdbc_params( return result @slot - def get_min_max_bounds( + def get_min_max_values( self, source: str, - column: str, - expression: str | None = None, - hint: str | None = None, - where: str | None = None, + window: Window, + hint: Any | None = None, + where: Any | None = None, options: JDBCReadOptions | None = None, ) -> tuple[Any, Any]: - log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) - + log.info("|%s| Getting min and max values for expression %r ...", self.__class__.__name__, window.expression) read_options = self._exclude_partition_options(self.ReadOptions.parse(options), fetchsize=1) - query = get_sql_query( + query = self.dialect.get_sql_query( table=source, columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - self.Dialect._escape_column("min"), + self.dialect.aliased( + self.dialect.get_min_value(window.expression), + self.dialect.escape_column("min"), ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - self.Dialect._escape_column("max"), + self.dialect.aliased( + self.dialect.get_max_value(window.expression), + self.dialect.escape_column("max"), ), ], - where=where, + where=self.dialect.apply_window(where, window), hint=hint, ) @@ -318,8 +317,8 @@ def get_min_max_bounds( max_value = row["max"] log.info("|%s| Received values:", self.__class__.__name__) - log_with_indent(log, "MIN(%s) = %r", column, min_value) - log_with_indent(log, "MAX(%s) = %r", column, max_value) + log_with_indent(log, "MIN(%s) = %r", window.expression, min_value) + log_with_indent(log, "MAX(%s) = %r", window.expression, max_value) return min_value, max_value @@ -377,9 +376,9 @@ def _set_lower_upper_bound( options.partition_column, ) - min_partition_value, max_partition_value = self.get_min_max_bounds( + min_partition_value, max_partition_value = self.get_min_max_values( source=table, - column=options.partition_column, + window=Window(options.partition_column), where=where, hint=hint, options=options, diff --git a/onetl/connection/db_connection/jdbc_connection/dialect.py b/onetl/connection/db_connection/jdbc_connection/dialect.py index 790a0c300..f051cb92b 100644 --- a/onetl/connection/db_connection/jdbc_connection/dialect.py +++ b/onetl/connection/db_connection/jdbc_connection/dialect.py @@ -18,10 +18,9 @@ from 
onetl.connection.db_connection.db_connection import DBDialect from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, + NotSupportDFSchema, + SupportColumns, SupportHintStr, - SupportHWMColumnStr, SupportHWMExpressionStr, SupportNameWithSchemaOnly, SupportWhereStr, @@ -30,20 +29,17 @@ class JDBCDialect( # noqa: WPS215 SupportNameWithSchemaOnly, - SupportColumnsList, - SupportDfSchemaNone, + SupportColumns, + NotSupportDFSchema, SupportWhereStr, SupportHintStr, SupportHWMExpressionStr, - SupportHWMColumnStr, DBDialect, ): - @classmethod @abstractmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: ... - @classmethod @abstractmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: ... diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py index 111b61b08..0f32f4e9b 100644 --- a/onetl/connection/db_connection/kafka/connection.py +++ b/onetl/connection/db_connection/kafka/connection.py @@ -46,7 +46,7 @@ from onetl.connection.db_connection.kafka.slots import KafkaSlots from onetl.exception import MISSING_JVM_CLASS_MSG, TargetAlreadyExistsError from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement +from onetl.hwm.window import Window from onetl.log import log_collection, log_with_indent if TYPE_CHECKING: @@ -265,8 +265,8 @@ def read_source_as_df( hint: Any | None = None, where: Any | None = None, df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, + window: Window | None = None, + limit: int | None = None, options: KafkaReadOptions = KafkaReadOptions(), # noqa: B008, WPS404 ) -> DataFrame: log.info("|%s| Reading data from topic %r", self.__class__.__name__, source) diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index e8c35ccaa..b0507467c 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -17,57 +17,49 @@ import logging +from etl_entities.hwm import HWM + from onetl._util.spark import get_spark_version -from onetl.base import BaseDBConnection from onetl.connection.db_connection.db_connection.dialect import DBDialect from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsNone, - SupportDfSchemaNone, - SupportHintNone, - SupportHWMExpressionNone, + NotSupportColumns, + NotSupportDFSchema, + NotSupportHint, + NotSupportWhere, SupportNameAny, - SupportWhereNone, ) log = logging.getLogger(__name__) class KafkaDialect( # noqa: WPS215 - SupportColumnsNone, - SupportDfSchemaNone, - SupportHintNone, - SupportWhereNone, + NotSupportColumns, + NotSupportDFSchema, + NotSupportHint, + NotSupportWhere, SupportNameAny, - SupportHWMExpressionNone, DBDialect, ): - valid_hwm_columns = {"offset", "timestamp"} + SUPPORTED_HWM_COLUMNS = {"offset", "timestamp"} + + def validate_hwm( + self, + hwm: HWM | None, + ) -> HWM | None: + if not hwm: + return None - @classmethod - def validate_hwm_column( - cls, - connection: BaseDBConnection, - hwm_column: str | None, - ) -> str | None: - if not isinstance(hwm_column, str): + if hwm.expression not in self.SUPPORTED_HWM_COLUMNS: raise ValueError( - 
f"{connection.__class__.__name__} requires 'hwm_column' parameter type to be 'str', " - f"got {type(hwm_column)}", + f"hwm.expression={hwm.expression!r} is not supported by {self.connection.__class__.__name__}. " + f"Valid values are: {self.SUPPORTED_HWM_COLUMNS}", ) - cls.validate_column(connection, hwm_column) - - return hwm_column - - @classmethod - def validate_column(cls, connection: BaseDBConnection, column: str) -> None: - if column not in cls.valid_hwm_columns: - raise ValueError(f"{column} is not a valid hwm column. Valid options are: {cls.valid_hwm_columns}") - - if column == "timestamp": + if hwm.expression == "timestamp": # Spark version less 3.x does not support reading from Kafka with the timestamp parameter - spark_version = get_spark_version(connection.spark) # type: ignore[attr-defined] + spark_version = get_spark_version(self.connection.spark) # type: ignore[attr-defined] if spark_version.major < 3: raise ValueError( f"Spark version must be 3.x for the timestamp column. Current version is: {spark_version}", ) + return hwm diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py index 280596d5d..facdf9fee 100644 --- a/onetl/connection/db_connection/mongodb/connection.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json import logging import warnings from typing import TYPE_CHECKING, Any @@ -37,7 +38,7 @@ ) from onetl.exception import MISSING_JVM_CLASS_MSG from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement +from onetl.hwm import Window from onetl.impl import GenericOptions from onetl.log import log_dataframe_schema, log_json, log_options, log_with_indent @@ -64,8 +65,8 @@ class MongoDB(DBConnection): * MongoDB server versions: 4.0 or higher * Spark versions: 3.2.x - 3.4.x + * Scala versions: 2.12 - 2.13 * Java versions: 8 - 20 - * Scala versions: 2.11 - 2.13 See `official documentation `_. @@ -81,7 +82,7 @@ class MongoDB(DBConnection): pip install onetl[spark] # latest PySpark version # or - pip install onetl pyspark=3.4.1 # pass specific PySpark version + pip install onetl pyspark==3.4.2 # pass specific PySpark version See :ref:`install-spark` installation instruction for more details. 
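This part of the changeset replaces the old ``Statement``-based bounds (``start_from``/``end_at``) with a single ``Window`` object whose edges each dialect turns into filter conditions. Below is a minimal illustrative sketch (not part of the diff) of how the new pieces are expected to compose; it assumes ``Window`` and ``Edge`` accept as keyword arguments the same fields the dialect code reads back (``start_from``, ``stop_at``, ``value``, ``including``), and uses a placeholder ``connection`` object.

.. code:: python

    from datetime import datetime

    from onetl.hwm import Edge, Window

    # Hypothetical window over an "updated_at" column; the constructor kwargs are
    # assumed from the attributes accessed in DBDialect.apply_window above.
    window = Window(
        "updated_at",
        start_from=Edge(value=datetime(2023, 9, 1), including=True),
        stop_at=Edge(value=datetime(2023, 9, 9), including=False),
    )

    # `connection` stands for any DBConnection instance (e.g. Hive or a JDBC source);
    # the new `dialect` property instantiates the per-connection Dialect.
    dialect = connection.dialect

    # apply_window returns a list of conditions: the user-supplied filter plus
    # one condition per window edge that is set.
    where = dialect.apply_window("status = 'OK'", window)

    query = dialect.get_sql_query(
        table="schema.table",
        columns=["id", "updated_at"],
        where=where,
        limit=10,
    )
    # Roughly:
    #   SELECT id, updated_at
    #   FROM schema.table
    #   WHERE (status = 'OK')
    #    AND (updated_at >= ...)
    #    AND (updated_at < ...)
    #   LIMIT 10
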
@@ -359,7 +360,7 @@ def pipeline( log.info("|%s| Executing aggregation pipeline:", self.__class__.__name__) read_options = self.PipelineOptions.parse(options).dict(by_alias=True, exclude_none=True) - pipeline = self.Dialect.prepare_pipeline(pipeline) + pipeline = self.dialect.prepare_pipeline(pipeline) log_with_indent(log, "collection = %r", collection) log_json(log, pipeline, name="pipeline") @@ -370,7 +371,7 @@ def pipeline( log_options(log, read_options) read_options["collection"] = collection - read_options["aggregation.pipeline"] = self.Dialect.convert_to_str(pipeline) + read_options["aggregation.pipeline"] = json.dumps(pipeline) read_options["connection.uri"] = self.connection_url spark_reader = self.spark.read.format("mongodb").options(**read_options) @@ -400,46 +401,56 @@ def check(self): return self @slot - def get_min_max_bounds( + def get_min_max_values( self, source: str, - column: str, - expression: str | None = None, # noqa: U100 - hint: dict | None = None, # noqa: U100 - where: dict | None = None, + window: Window, + hint: Any | None = None, + where: Any | None = None, options: MongoDBReadOptions | dict | None = None, ) -> tuple[Any, Any]: - log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) + log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, window.expression) read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) # The '_id' field must be specified in the request. - pipeline = [{"$group": {"_id": 1, "min": {"$min": f"${column}"}, "max": {"$max": f"${column}"}}}] - if where: - pipeline.insert(0, {"$match": where}) - - pipeline = self.Dialect.prepare_pipeline(pipeline) - - read_options["connection.uri"] = self.connection_url - read_options["collection"] = source - read_options["aggregation.pipeline"] = self.Dialect.convert_to_str(pipeline) + pipeline: list[dict[str, Any]] = [ + { + "$group": { + "_id": 1, + "min": {"$min": f"${window.expression}"}, + "max": {"$max": f"${window.expression}"}, + }, + }, + ] + final_where = self.dialect.apply_window(where, window) + if final_where: + pipeline.insert(0, {"$match": final_where}) - if hint: - read_options["hint"] = self.Dialect.convert_to_str(hint) + pipeline = self.dialect.prepare_pipeline(pipeline) log.info("|%s| Executing aggregation pipeline:", self.__class__.__name__) log_with_indent(log, "collection = %r", source) log_json(log, pipeline, "pipeline") log_json(log, hint, "hint") + read_options["connection.uri"] = self.connection_url + read_options["collection"] = source + read_options["aggregation.pipeline"] = json.dumps(pipeline) + if hint: + read_options["hint"] = json.dumps(hint) + df = self.spark.read.format("mongodb").options(**read_options).load() - row = df.collect()[0] - min_value = row["min"] - max_value = row["max"] + data = df.collect() + if data: + min_value = data[0]["min"] + max_value = data[0]["max"] + else: + min_value = max_value = None log.info("|%s| Received values:", self.__class__.__name__) - log_with_indent(log, "MIN(%s) = %r", column, min_value) - log_with_indent(log, "MAX(%s) = %r", column, max_value) + log_with_indent(log, "MIN(%s) = %r", window.expression, min_value) + log_with_indent(log, "MAX(%s) = %r", window.expression, max_value) return min_value, max_value @@ -451,23 +462,19 @@ def read_source_as_df( hint: dict | None = None, where: dict | None = None, df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, + window: Window | None 
= None, + limit: int | None = None, options: MongoDBReadOptions | dict | None = None, ) -> DataFrame: read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) - final_where = self.Dialect._condition_assembler( - condition=where, - start_from=start_from, - end_at=end_at, - ) - pipeline = self.Dialect.prepare_pipeline({"$match": final_where}) if final_where else None + final_where = self.dialect.apply_window(where, window) + pipeline = self.dialect.prepare_pipeline({"$match": final_where}) if final_where else None if pipeline: - read_options["aggregation.pipeline"] = self.Dialect.convert_to_str(pipeline) + read_options["aggregation.pipeline"] = json.dumps(pipeline) if hint: - read_options["hint"] = self.Dialect.convert_to_str(hint) + read_options["hint"] = json.dumps(hint) read_options["connection.uri"] = self.connection_url read_options["collection"] = source @@ -486,6 +493,9 @@ def read_source_as_df( if columns: df = df.select(*columns) + if limit is not None: + df = df.limit(limit) + log.info("|Spark| DataFrame successfully created from SQL statement ") return df diff --git a/onetl/connection/db_connection/mongodb/dialect.py b/onetl/connection/db_connection/mongodb/dialect.py index d3d388f72..51c6a64d7 100644 --- a/onetl/connection/db_connection/mongodb/dialect.py +++ b/onetl/connection/db_connection/mongodb/dialect.py @@ -14,20 +14,17 @@ from __future__ import annotations -import json -import operator from datetime import datetime -from typing import Any, Callable, ClassVar, Dict, Iterable, Mapping +from typing import Any, Iterable, Mapping -from onetl.base.base_db_connection import BaseDBConnection from onetl.connection.db_connection.db_connection.dialect import DBDialect from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsNone, - SupportDfSchemaStruct, - SupportHWMColumnStr, - SupportHWMExpressionNone, + NotSupportColumns, + RequiresDFSchema, + SupportHWMExpressionStr, SupportNameAny, ) +from onetl.hwm import Edge, Window _upper_level_operators = frozenset( # noqa: WPS527 [ @@ -75,25 +72,13 @@ class MongoDBDialect( # noqa: WPS215 SupportNameAny, - SupportHWMExpressionNone, - SupportColumnsNone, - SupportDfSchemaStruct, - SupportHWMColumnStr, + NotSupportColumns, + RequiresDFSchema, + SupportHWMExpressionStr, DBDialect, ): - _compare_statements: ClassVar[Dict[Callable, str]] = { - operator.ge: "$gte", - operator.gt: "$gt", - operator.le: "$lte", - operator.lt: "$lt", - operator.eq: "$eq", - operator.ne: "$ne", - } - - @classmethod def validate_where( - cls, - connection: BaseDBConnection, + self, where: Any, ) -> dict | None: if where is None: @@ -101,18 +86,16 @@ def validate_where( if not isinstance(where, dict): raise ValueError( - f"{connection.__class__.__name__} requires 'where' parameter type to be 'dict', " + f"{self.connection.__class__.__name__} requires 'where' parameter type to be 'dict', " f"got {where.__class__.__name__!r}", ) for key in where: - cls._validate_top_level_keys_in_where_parameter(key) + self._validate_top_level_keys_in_where_parameter(key) return where - @classmethod def validate_hint( - cls, - connection: BaseDBConnection, + self, hint: Any, ) -> dict | None: if hint is None: @@ -120,70 +103,74 @@ def validate_hint( if not isinstance(hint, dict): raise ValueError( - f"{connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " + f"{self.connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " f"got {hint.__class__.__name__!r}", ) return hint - @classmethod def 
prepare_pipeline( - cls, + self, pipeline: Any, ) -> Any: """ Prepares pipeline (list or dict) to MongoDB syntax, but without converting it to string. """ - if isinstance(pipeline, datetime): - return {"$date": pipeline.astimezone().isoformat()} - if isinstance(pipeline, Mapping): - return {cls.prepare_pipeline(key): cls.prepare_pipeline(value) for key, value in pipeline.items()} + return {self.prepare_pipeline(key): self.prepare_pipeline(value) for key, value in pipeline.items()} if isinstance(pipeline, Iterable) and not isinstance(pipeline, str): - return [cls.prepare_pipeline(item) for item in pipeline] + return [self.prepare_pipeline(item) for item in pipeline] - return pipeline + return self._serialize_value(pipeline) - @classmethod - def convert_to_str( - cls, - value: Any, - ) -> str: + def apply_window( + self, + condition: Any, + window: Window | None = None, + ) -> Any: + result = super().apply_window(condition, window) + if not result: + return {} + if len(result) == 1: + return result[0] + return {"$and": result} + + def _serialize_value(self, value: Any) -> str | int | dict: """ - Converts the given dictionary, list or primitive to a string. + Transform the value into an SQL Dialect-supported form. """ - return json.dumps(cls.prepare_pipeline(value)) + if isinstance(value, datetime): + return {"$date": value.astimezone().isoformat()} - @classmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - if len(conditions) == 1: - return conditions[0] + return value - return {"$and": conditions} - - @classmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> dict: - """ - Returns the comparison statement in MongoDB syntax: + def _edge_to_where( + self, + expression: str, + edge: Edge, + position: str, + ) -> Any: + if not expression or not edge.is_set(): + return None - .. code:: + operators: dict[tuple[str, bool], str] = { + ("start", True): "$gte", + ("start", False): "$gt", + ("end", True): "$lte", + ("end", False): "$lt", + } - { - "field": { - "$gt": "some_value", - } - } - """ + operator = operators[(position, edge.including)] + value = self._serialize_value(edge.value) return { - arg1: { - cls._compare_statements[comparator]: arg2, + expression: { + operator: value, }, } - @classmethod - def _validate_top_level_keys_in_where_parameter(cls, key: str): + def _validate_top_level_keys_in_where_parameter(self, key: str): """ Checks the 'where' parameter for illegal operators, such as ``$match``, ``$merge`` or ``$changeStream``. 
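The rewritten MongoDB dialect drops the operator-callable mapping in favour of edge-based helpers: ``_edge_to_where`` picks ``$gt``/``$gte``/``$lt``/``$lte`` from the edge position and its inclusiveness, ``_serialize_value`` turns datetimes into extended-JSON ``{"$date": ...}`` documents, and ``apply_window`` merges the resulting conditions under ``$and``. A minimal standalone sketch of that filter-building behaviour (simplified stand-ins, not the actual onETL classes):

.. code:: python

    from __future__ import annotations

    from dataclasses import dataclass
    from datetime import datetime
    from typing import Any


    @dataclass
    class Edge:
        # simplified stand-in for onetl.hwm.Edge
        value: Any = None
        including: bool = True

        def is_set(self) -> bool:
            return self.value is not None


    def serialize_value(value: Any) -> Any:
        # datetimes become MongoDB extended JSON, everything else is passed through as-is
        if isinstance(value, datetime):
            return {"$date": value.astimezone().isoformat()}
        return value


    def edge_to_where(expression: str, edge: Edge, position: str) -> dict | None:
        if not expression or not edge.is_set():
            return None
        operators = {
            ("start", True): "$gte",
            ("start", False): "$gt",
            ("end", True): "$lte",
            ("end", False): "$lt",
        }
        return {expression: {operators[(position, edge.including)]: serialize_value(edge.value)}}


    def merge_conditions(conditions: list) -> dict:
        conditions = [condition for condition in conditions if condition]
        if not conditions:
            return {}
        if len(conditions) == 1:
            return conditions[0]
        return {"$and": conditions}


    # user filter plus a window "2023-01-01 <= updated_at < 2023-02-01"
    where = {"status": "active"}
    start = edge_to_where("updated_at", Edge(datetime(2023, 1, 1)), "start")
    stop = edge_to_where("updated_at", Edge(datetime(2023, 2, 1), including=False), "end")
    print(merge_conditions([where, start, stop]))
    # {'$and': [{'status': 'active'},
    #           {'updated_at': {'$gte': {'$date': '2023-01-01T00:00:00+...'}}},
    #           {'updated_at': {'$lt': {'$date': '2023-02-01T00:00:00+...'}}}]}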
diff --git a/onetl/connection/db_connection/mssql/connection.py b/onetl/connection/db_connection/mssql/connection.py index b4ab8427b..dac49e3cf 100644 --- a/onetl/connection/db_connection/mssql/connection.py +++ b/onetl/connection/db_connection/mssql/connection.py @@ -94,7 +94,7 @@ class MSSQL(JDBCConnection): For example: ``{"connectRetryCount": 3, "connectRetryInterval": 10}`` See `MSSQL JDBC driver properties documentation - `_ + `_ for more details Examples diff --git a/onetl/connection/db_connection/mssql/dialect.py b/onetl/connection/db_connection/mssql/dialect.py index 95e4ff022..f39568423 100644 --- a/onetl/connection/db_connection/mssql/dialect.py +++ b/onetl/connection/db_connection/mssql/dialect.py @@ -20,21 +20,17 @@ class MSSQLDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"CONVERT(BIGINT, HASHBYTES ( 'SHA' , {partition_column} )) % {num_partitions}" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"CAST('{result}' AS datetime2)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"CAST('{result}' AS date)" - - # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"CONVERT(BIGINT, HASHBYTES ( 'SHA' , {partition_column} )) % {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/mysql/connection.py b/onetl/connection/db_connection/mysql/connection.py index 326ae7a81..3af51a2d1 100644 --- a/onetl/connection/db_connection/mysql/connection.py +++ b/onetl/connection/db_connection/mysql/connection.py @@ -93,7 +93,7 @@ class MySQL(JDBCConnection): For example: ``{"useSSL": "false"}`` See `MySQL JDBC driver properties documentation - `_ + `_ for more details Examples diff --git a/onetl/connection/db_connection/mysql/dialect.py b/onetl/connection/db_connection/mysql/dialect.py index b3cd70a55..59f663aed 100644 --- a/onetl/connection/db_connection/mysql/dialect.py +++ b/onetl/connection/db_connection/mysql/dialect.py @@ -20,24 +20,19 @@ class MySQLDialect(JDBCDialect): - @classmethod - def _escape_column(cls, value: str) -> str: + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16), 16, 2), 2, 10), {num_partitions})" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"MOD({partition_column}, {num_partitions})" + + def escape_column(self, value: str) -> str: return f"`{value}`" - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def _serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S.%f") return f"STR_TO_DATE('{result}', '%Y-%m-%d %H:%i:%s.%f')" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def 
_serialize_date(self, value: date) -> str: result = value.strftime("%Y-%m-%d") return f"STR_TO_DATE('{result}', '%Y-%m-%d')" - - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16),16, 2), 2, 10), {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" diff --git a/onetl/connection/db_connection/oracle/connection.py b/onetl/connection/db_connection/oracle/connection.py index 5166ee71e..a3ec87921 100644 --- a/onetl/connection/db_connection/oracle/connection.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -34,6 +34,7 @@ from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions from onetl.connection.db_connection.oracle.dialect import OracleDialect from onetl.hooks import slot, support_hooks +from onetl.hwm import Window from onetl.impl import GenericOptions from onetl.log import BASE_LOG_INDENT, log_lines @@ -83,7 +84,7 @@ class Oracle(JDBCConnection): .. dropdown:: Version compatibility - * Oracle Server versions: 23c, 21c, 19c, 18c, 12.2 and probably 11.2 (tested, but that's not official). + * Oracle Server versions: 23, 21, 19, 18, 12.2 and probably 11.2 (tested, but that's not official). * Spark versions: 2.3.x - 3.5.x * Java versions: 8 - 20 @@ -144,7 +145,7 @@ class Oracle(JDBCConnection): For example: ``{"defaultBatchValue": 100}`` See `Oracle JDBC driver properties documentation - `_ + `_ for more details Examples @@ -247,19 +248,17 @@ def instance_url(self) -> str: return f"{super().instance_url}/{self.service_name}" @slot - def get_min_max_bounds( + def get_min_max_values( self, source: str, - column: str, - expression: str | None = None, - hint: str | None = None, - where: str | None = None, + window: Window, + hint: Any | None = None, + where: Any | None = None, options: JDBCReadOptions | None = None, ) -> tuple[Any, Any]: - min_value, max_value = super().get_min_max_bounds( + min_value, max_value = super().get_min_max_values( source=source, - column=column, - expression=expression, + window=window, hint=hint, where=where, options=options, diff --git a/onetl/connection/db_connection/oracle/dialect.py b/onetl/connection/db_connection/oracle/dialect.py index fb3fa715d..9484524fd 100644 --- a/onetl/connection/db_connection/oracle/dialect.py +++ b/onetl/connection/db_connection/oracle/dialect.py @@ -20,20 +20,38 @@ class OracleDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + def get_sql_query( + self, + table: str, + columns: list[str] | None = None, + where: str | list[str] | None = None, + hint: str | None = None, + limit: int | None = None, + compact: bool = False, + ) -> str: + # https://stackoverflow.com/questions/27965130/how-to-select-column-from-table-in-oracle + new_columns = columns or ["*"] + if len(new_columns) > 1: + new_columns = [table + ".*" if column.strip() == "*" else column for column in new_columns] + return super().get_sql_query( + table=table, + columns=new_columns, + where=where, + hint=hint, + limit=limit, + compact=compact, + ) + + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"ora_hash({partition_column}, {num_partitions})" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"MOD({partition_column}, {num_partitions})" + + def 
_serialize_datetime(self, value: datetime) -> str: result = value.strftime("%Y-%m-%d %H:%M:%S") return f"TO_DATE('{result}', 'YYYY-MM-DD HH24:MI:SS')" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.strftime("%Y-%m-%d") return f"TO_DATE('{result}', 'YYYY-MM-DD')" - - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"ora_hash({partition_column}, {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" diff --git a/onetl/connection/db_connection/postgres/dialect.py b/onetl/connection/db_connection/postgres/dialect.py index 05a44471e..54124ce3b 100644 --- a/onetl/connection/db_connection/postgres/dialect.py +++ b/onetl/connection/db_connection/postgres/dialect.py @@ -16,26 +16,22 @@ from datetime import date, datetime -from onetl.connection.db_connection.dialect_mixins import SupportHintNone +from onetl.connection.db_connection.dialect_mixins import NotSupportHint from onetl.connection.db_connection.jdbc_connection import JDBCDialect -class PostgresDialect(SupportHintNone, JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: +class PostgresDialect(NotSupportHint, JDBCDialect): + # https://stackoverflow.com/a/9812029 + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"'{result}'::timestamp" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"'{result}'::date" - - # https://stackoverflow.com/a/9812029 - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/teradata/dialect.py b/onetl/connection/db_connection/teradata/dialect.py index c449debc6..7845dc360 100644 --- a/onetl/connection/db_connection/teradata/dialect.py +++ b/onetl/connection/db_connection/teradata/dialect.py @@ -20,21 +20,17 @@ class TeradataDialect(JDBCDialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: + # https://docs.teradata.com/r/w4DJnG9u9GdDlXzsTXyItA/lkaegQT4wAakj~K_ZmW1Dg + def get_partition_column_hash(self, partition_column: str, num_partitions: int) -> str: + return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" + + def get_partition_column_mod(self, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} mod {num_partitions}" + + def _serialize_datetime(self, value: datetime) -> str: result = value.isoformat() return f"CAST('{result}' AS TIMESTAMP)" - @classmethod - def _get_date_value_sql(cls, value: date) -> str: + def _serialize_date(self, value: date) -> str: result = value.isoformat() return f"CAST('{result}' AS DATE)" - - # 
https://docs.teradata.com/r/w4DJnG9u9GdDlXzsTXyItA/lkaegQT4wAakj~K_ZmW1Dg - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} mod {num_partitions}" diff --git a/onetl/connection/file_df_connection/spark_s3/connection.py b/onetl/connection/file_df_connection/spark_s3/connection.py index dd881988c..73dc1cb07 100644 --- a/onetl/connection/file_df_connection/spark_s3/connection.py +++ b/onetl/connection/file_df_connection/spark_s3/connection.py @@ -63,8 +63,8 @@ class SparkS3(SparkFileDFConnection): .. dropdown:: Version compatibility * Spark versions: 3.2.x - 3.5.x (only with Hadoop 3.x libraries) + * Scala versions: 2.12 - 2.13 * Java versions: 8 - 20 - * Scala versions: 2.11 - 2.13 .. warning:: diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index 00484db2b..625599cd6 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -1,35 +1,43 @@ from __future__ import annotations +import textwrap +import warnings from logging import getLogger -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import frozendict -from etl_entities import Column, Table -from pydantic import Field, root_validator, validator +from etl_entities.hwm import HWM, ColumnHWM +from etl_entities.old_hwm import IntHWM as OldColumnHWM +from etl_entities.source import Column, Table +from pydantic import Field, PrivateAttr, root_validator, validator -from onetl._internal import uniq_ignore_case from onetl._util.spark import try_import_pyspark from onetl.base import ( BaseDBConnection, ContainsGetDFSchemaMethod, - ContainsGetMinMaxBounds, + ContainsGetMinMaxValues, ) from onetl.hooks import slot, support_hooks +from onetl.hwm import AutoDetectHWM, Edge, Window from onetl.impl import FrozenModel, GenericOptions from onetl.log import ( entity_boundary_log, log_collection, log_dataframe_schema, + log_hwm, log_json, log_options, log_with_indent, ) +from onetl.strategy.batch_hwm_strategy import BatchHWMStrategy +from onetl.strategy.hwm_strategy import HWMStrategy +from onetl.strategy.strategy_manager import StrategyManager log = getLogger(__name__) if TYPE_CHECKING: from pyspark.sql.dataframe import DataFrame - from pyspark.sql.types import StructType + from pyspark.sql.types import StructField, StructType @support_hooks @@ -85,7 +93,12 @@ class DBReader(FrozenModel): .. note:: - Some connections does not have columns. + Some sources does not have columns. + + .. note:: + + It is recommended to pass column names explicitly to avoid selecting too many columns, + and to avoid adding unexpected columns to dataframe if source DDL is changed. where : Any, default: ``None`` Custom ``where`` for SQL query or MongoDB pipeline. @@ -109,21 +122,35 @@ class DBReader(FrozenModel): Some sources does not support data filtering. - hwm_column : str or tuple[str, any], default: ``None`` - Column to be used as :ref:`column-hwm` value. + hwm : type[HWM] | None, default: ``None`` + HWM class to be used as :etl-entities:`HWM ` value. + + .. 
code:: python + + from onetl.hwm import AutoDetectHWM + + hwm = AutoDetectHWM( + name="some_unique_hwm_name", + expression="hwm_column", + ) - If you want to use some SQL expression as HWM value, you can pass it as tuple - ``("column_name", "expression")``, like: + HWM value will be fetched using ``hwm_column`` SQL query. + + If you want to use some SQL expression as HWM value, you can use it as well: .. code:: python - hwm_column = ("hwm_column", "cast(hwm_column_orig as date)") + from onetl.hwm import AutoDetectHWM - HWM value will be fetched using ``max(cast(hwm_column_orig as date)) as hwm_column`` SQL query. + hwm = AutoDetectHWM( + name="some_unique_hwm_name", + expression="cast(hwm_column_orig as date)", + ) .. note:: - Some sources does not support ``("column_name", "expression")`` syntax. + Some sources does not support passing expressions and can be used only with column/field + names which present in the source. hint : Any, default: ``None`` Hint expression used for querying the data. @@ -302,6 +329,7 @@ class DBReader(FrozenModel): from onetl.db import DBReader from onetl.connection import Postgres from onetl.strategy import IncrementalStrategy + from onetl.hwm import AutoDetectHWM from pyspark.sql import SparkSession maven_packages = Postgres.get_packages() @@ -322,7 +350,10 @@ class DBReader(FrozenModel): reader = DBReader( connection=postgres, source="fiddle.dummy", - hwm_column="d_age", # mandatory for IncrementalStrategy + hwm=DBReader.AutoDetectHWM( # mandatory for IncrementalStrategy + name="some_unique_hwm_name", + expression="d_age", + ), ) # read data from table "fiddle.dummy" @@ -331,126 +362,129 @@ class DBReader(FrozenModel): df = reader.run() """ + AutoDetectHWM = AutoDetectHWM + connection: BaseDBConnection - source: Table = Field(alias="table") - columns: Optional[List[str]] = None - hwm_column: Optional[Column] = None - hwm_expression: Optional[str] = None + source: str = Field(alias="table") + columns: Optional[List[str]] = Field(default=None, min_items=1) where: Optional[Any] = None hint: Optional[Any] = None df_schema: Optional[StructType] = None + hwm_column: Optional[Union[str, tuple]] = None + hwm_expression: Optional[str] = None + hwm: Optional[ColumnHWM] = None options: Optional[GenericOptions] = None - @validator("source", pre=True, always=True) + _connection_checked: bool = PrivateAttr(default=False) + + @validator("source", always=True) def validate_source(cls, source, values): connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - if isinstance(source, str): - # source="dbschema.table" or source="table", If source="dbschema.some.table" in class Table will raise error. 
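The ``hwm_column`` / ``hwm_expression`` pair is deprecated in favour of a single ``hwm`` object (see the ``validate_hwm`` validator below). A hedged before/after sketch of the migration, reusing the ``postgres`` connection from the docstring example above; the HWM name is illustrative:

.. code:: python

    from onetl.db import DBReader

    # onETL < 0.10: column name, optionally wrapped in a tuple with an SQL expression (deprecated)
    reader = DBReader(
        connection=postgres,
        source="fiddle.dummy",
        hwm_column=("hwm_column", "cast(hwm_column_orig as date)"),
    )

    # onETL >= 0.10: a named HWM object carrying the same expression
    reader = DBReader(
        connection=postgres,
        source="fiddle.dummy",
        hwm=DBReader.AutoDetectHWM(
            name="some_unique_hwm_name",
            expression="cast(hwm_column_orig as date)",
        ),
    )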
- source = Table(name=source, instance=connection.instance_url) - # Here Table(name='source', db='dbschema', instance='some_instance') - return dialect.validate_name(connection, source) - - @validator("where", pre=True, always=True) + return connection.dialect.validate_name(source) + + @validator("columns", always=True) # noqa: WPS231 + def validate_columns(cls, value: list[str] | None, values: dict) -> list[str] | None: + connection: BaseDBConnection = values["connection"] + return connection.dialect.validate_columns(value) + + @validator("where", always=True) def validate_where(cls, where: Any, values: dict) -> Any: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - result = dialect.validate_where(connection, where) + result = connection.dialect.validate_where(where) if isinstance(result, dict): return frozendict.frozendict(result) # type: ignore[attr-defined, operator] return result - @validator("hint", pre=True, always=True) + @validator("hint", always=True) def validate_hint(cls, hint: Any, values: dict) -> Any: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - result = dialect.validate_hint(connection, hint) + result = connection.dialect.validate_hint(hint) if isinstance(result, dict): return frozendict.frozendict(result) # type: ignore[attr-defined, operator] return result - @validator("df_schema", pre=True, always=True) + @validator("df_schema", always=True) def validate_df_schema(cls, df_schema: StructType | None, values: dict) -> StructType | None: connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - return dialect.validate_df_schema(connection, df_schema) + return connection.dialect.validate_df_schema(df_schema) - @root_validator(pre=True) # noqa: WPS231 - def validate_hwm_column(cls, values: dict) -> dict: - hwm_column: str | tuple[str, str] | Column | None = values.get("hwm_column") - df_schema: StructType | None = values.get("df_schema") - hwm_expression: str | None = values.get("hwm_expression") + @root_validator(skip_on_failure=True) + def validate_hwm(cls, values: dict) -> dict: # noqa: WPS231 connection: BaseDBConnection = values["connection"] - - if hwm_column is None or isinstance(hwm_column, Column): - return values - - if not hwm_expression and not isinstance(hwm_column, str): - # ("new_hwm_column", "cast(hwm_column as date)") noqa: E800 - hwm_column, hwm_expression = hwm_column # noqa: WPS434 - - if not hwm_expression: - raise ValueError( - "When the 'hwm_column' field is a tuple, then it must be " - "specified as tuple('column_name', 'expression'). Otherwise, " - "the 'hwm_column' field should be a string.", - ) - - if df_schema is not None and hwm_column not in df_schema.fieldNames(): - raise ValueError( - "'df_schema' struct must contain column specified in 'hwm_column'. " - "Otherwise DBReader cannot determine HWM type for this column", + source: str = values["source"] + hwm_column: str | tuple[str, str] | None = values.get("hwm_column") + hwm_expression: str | None = values.get("hwm_expression") + hwm: ColumnHWM | None = values.get("hwm") + + if hwm_column is not None: + if hwm: + raise ValueError("Please pass either DBReader(hwm=...) or DBReader(hwm_column=...), not both") + + if not hwm_expression and isinstance(hwm_column, tuple): + hwm_column, hwm_expression = hwm_column # noqa: WPS434 + + if not hwm_expression: + error_message = textwrap.dedent( + """ + When the 'hwm_column' field is a tuple, then it must be + specified as tuple('column_name', 'expression'). 
+ + Otherwise, the 'hwm_column' field should be a string. + """, + ) + raise ValueError(error_message) + + # convert old parameters to new one + old_hwm = OldColumnHWM( + source=Table(name=source, instance=connection.instance_url), # type: ignore[arg-type] + column=Column(name=hwm_column), # type: ignore[arg-type] + ) + warnings.warn( + textwrap.dedent( + f""" + Passing "hwm_column" in DBReader class is deprecated since version 0.10.0, + and will be removed in v1.0.0. + + Instead use: + hwm=DBReader.AutoDetectHWM( + name={old_hwm.qualified_name!r}, + expression={hwm_column!r}, + ) + """, + ), + UserWarning, + stacklevel=2, + ) + hwm = AutoDetectHWM( + name=old_hwm.qualified_name, + expression=hwm_expression or hwm_column, ) - dialect = connection.Dialect - dialect.validate_hwm_column(connection, hwm_column) # type: ignore - - values["hwm_column"] = Column(name=hwm_column) # type: ignore - values["hwm_expression"] = hwm_expression - - return values - - @root_validator(pre=True) # noqa: WPS231 - def validate_columns(cls, values: dict) -> dict: - connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - - columns: list[str] | str | None = values.get("columns") - columns_list: list[str] | None - if isinstance(columns, str): - columns_list = columns.split(",") - else: - columns_list = columns - - columns_list = dialect.validate_columns(connection, columns_list) - if columns_list is None: - return values - - if not columns_list: - raise ValueError("Parameter 'columns' can not be an empty list") - - hwm_column = values.get("hwm_column") - hwm_expression = values.get("hwm_expression") - - result: list[str] = [] - already_visited: set[str] = set() + if hwm and not hwm.expression: + raise ValueError("`hwm.expression` cannot be None") - for item in columns_list: - column = item.strip() + if hwm and not hwm.entity: + hwm = hwm.copy(update={"entity": source}) - if not column: - raise ValueError(f"Column name cannot be empty string, got {item!r}") + if hwm and hwm.entity != source: + error_message = textwrap.dedent( + f""" + Passed `hwm.source` is different from `source`. - if column.casefold() in already_visited: - raise ValueError(f"Duplicated column name {item!r}") + `hwm`: + {hwm!r} - if hwm_expression and hwm_column and hwm_column.name.casefold() == column.casefold(): - raise ValueError(f"{item!r} is an alias for HWM, it cannot be used as 'columns' name") + `source`: + {source!r} - result.append(column) - already_visited.add(column.casefold()) + This is not allowed. 
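A condensed restatement of the entity handling enforced by ``validate_hwm``: the HWM is bound to the reader's ``source`` when its ``entity`` is empty, and a mismatching ``entity`` is rejected. This is a sketch assuming a pydantic-style HWM object exposing ``entity`` and ``copy(update=...)``, not the actual validator:

.. code:: python

    def bind_hwm_to_source(hwm, source: str):
        # simplified restatement of DBReader.validate_hwm's entity checks (illustrative name)
        if hwm is None:
            return None
        if not hwm.entity:
            # entity not set explicitly: bind the HWM to the reader's source
            return hwm.copy(update={"entity": source})
        if hwm.entity != source:
            raise ValueError(f"Passed `hwm.source` is different from `source`: {hwm.entity!r} != {source!r}")
        return hwm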
+ """, + ) + raise ValueError(error_message) - values["columns"] = result + values["hwm"] = connection.dialect.validate_hwm(hwm) + values["hwm_column"] = None + values["hwm_expression"] = None return values @validator("options", pre=True, always=True) @@ -467,41 +501,6 @@ def validate_options(cls, options, values): return None - @validator("hwm_expression", pre=True, always=True) # noqa: WPS238, WPS231 - def validate_hwm_expression(cls, hwm_expression, values): - connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - return dialect.validate_hwm_expression(connection=connection, value=hwm_expression) - - def get_df_schema(self) -> StructType: - if self.df_schema: - return self.df_schema - - if not self.df_schema and isinstance(self.connection, ContainsGetDFSchemaMethod): - return self.connection.get_df_schema( - source=str(self.source), - columns=self._resolve_all_columns(), - **self._get_read_kwargs(), - ) - - raise ValueError( - "|DBReader| You should specify `df_schema` field to use DBReader with " - f"{self.connection.__class__.__name__} connection", - ) - - def get_min_max_bounds(self, column: str, expression: str | None = None) -> tuple[Any, Any]: - if isinstance(self.connection, ContainsGetMinMaxBounds): - return self.connection.get_min_max_bounds( - source=str(self.source), - column=column, - expression=expression, - hint=self.hint, - where=self.where, - **self._get_read_kwargs(), - ) - - raise ValueError(f"{self.connection.__class__.__name__} connection does not support batch strategies") - @slot def run(self) -> DataFrame: """ @@ -532,41 +531,214 @@ def run(self) -> DataFrame: df = reader.run() """ - # avoid circular imports - from onetl.db.db_reader.strategy_helper import ( - HWMStrategyHelper, - NonHWMStrategyHelper, - StrategyHelper, + entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") + + self._check_strategy() + + if not self._connection_checked: + self._log_parameters() + self.connection.check() + self._connection_checked = True + + window, limit = self._calculate_window_and_limit() + df = self.connection.read_source_as_df( + source=str(self.source), + columns=self.columns, + hint=self.hint, + where=self.where, + df_schema=self.df_schema, + window=window, + limit=limit, + **self._get_read_kwargs(), ) - entity_boundary_log(log, msg="DBReader starts") + entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") + return df + + def _check_strategy(self): + strategy = StrategyManager.get_current() + class_name = type(self).__name__ + strategy_name = type(strategy).__name__ - self._log_parameters() - self.connection.check() + if self.hwm: + if not isinstance(strategy, HWMStrategy): + raise RuntimeError(f"{class_name}(hwm=...) 
cannot be used with {strategy_name}") + self._prepare_hwm(strategy, self.hwm) + + elif isinstance(strategy, HWMStrategy): + raise RuntimeError(f"{strategy_name} cannot be used without {class_name}(hwm=...)") + + def _prepare_hwm(self, strategy: HWMStrategy, hwm: ColumnHWM): + if not strategy.hwm: + # first run within the strategy + if isinstance(hwm, AutoDetectHWM): + strategy.hwm = self._autodetect_hwm(hwm) + else: + strategy.hwm = hwm + strategy.fetch_hwm() + return + + if not isinstance(strategy.hwm, ColumnHWM) or strategy.hwm.name != hwm.name: + # exception raised when inside one strategy >1 processes on the same table but with different hwm columns + # are executed, example: test_postgres_strategy_incremental_hwm_set_twice + error_message = textwrap.dedent( + f""" + Detected wrong {type(strategy).__name__} usage. + + Previous run: + {strategy.hwm!r} + Current run: + {hwm!r} + + Probably you've executed code which looks like this: + with {strategy.__class__.__name__}(...): + DBReader(hwm=one_hwm, ...).run() + DBReader(hwm=another_hwm, ...).run() + + Please change it to: + with {strategy.__class__.__name__}(...): + DBReader(hwm=one_hwm, ...).run() + + with {strategy.__class__.__name__}(...): + DBReader(hwm=another_hwm, ...).run() + """, + ) + raise ValueError(error_message) + + strategy.validate_hwm_attributes(hwm, strategy.hwm, origin=self.__class__.__name__) + + def _autodetect_hwm(self, hwm: HWM) -> HWM: + field = self._get_hwm_field(hwm) + field_type = field.dataType + detected_hwm_type = self.connection.dialect.detect_hwm_class(field) + + if detected_hwm_type: + log.info( + "|%s| Detected HWM type: %r", + self.__class__.__name__, + detected_hwm_type.__name__, + ) + return detected_hwm_type.deserialize(hwm.dict()) + + error_message = textwrap.dedent( + f""" + Cannot detect HWM type for field {hwm.expression!r} of type {field_type!r} + + Check that column or expression type is supported by {self.connection.__class__.__name__}. 
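``_check_strategy`` makes the reader/strategy pairing explicit in both directions. A hedged usage sketch (``reader_with_hwm`` and ``reader_without_hwm`` are assumed to be ``DBReader`` instances configured with and without ``hwm=``; the exact messages may differ):

.. code:: python

    from onetl.strategy import IncrementalStrategy

    # an HWM strategy without hwm= on the reader
    with IncrementalStrategy():
        reader_without_hwm.run()
        # RuntimeError: IncrementalStrategy cannot be used without DBReader(hwm=...)

    # hwm= on the reader, but .run() called outside of any HWM strategy
    reader_with_hwm.run()
    # RuntimeError: DBReader(hwm=...) cannot be used with SnapshotStrategy  (default strategy; assumption)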
+ """, + ) + raise RuntimeError(error_message) - helper: StrategyHelper - if self.hwm_column: - helper = HWMStrategyHelper(reader=self, hwm_column=self.hwm_column, hwm_expression=self.hwm_expression) + def _get_hwm_field(self, hwm: HWM) -> StructField: + log.info( + "|%s| Getting Spark type for HWM expression: %r", + self.__class__.__name__, + hwm.expression, + ) + + result: StructField + if self.df_schema: + schema = {field.name.casefold(): field for field in self.df_schema} + column = hwm.expression.casefold() + if column not in schema: + raise ValueError(f"HWM column {column!r} not found in dataframe schema") + + result = schema[column] + elif isinstance(self.connection, ContainsGetDFSchemaMethod): + df_schema = self.connection.get_df_schema( + source=self.source, + columns=[hwm.expression], + **self._get_read_kwargs(), + ) + result = df_schema[0] else: - helper = NonHWMStrategyHelper(reader=self) + raise ValueError( + "You should specify `df_schema` field to use DBReader with " + f"{self.connection.__class__.__name__} connection", + ) - start_from, end_at = helper.get_boundaries() + log.info("|%s| Got Spark field: %s", self.__class__.__name__, result) + return result - df = self.connection.read_source_as_df( - source=str(self.source), - columns=self._resolve_all_columns(), + def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: + if not self.hwm: + # SnapshotStrategy - always select all the data from source + return None, None + + strategy: HWMStrategy = StrategyManager.get_current() # type: ignore[assignment] + + start_value = strategy.current.value + stop_value = strategy.stop if isinstance(strategy, BatchHWMStrategy) else None + + if start_value is not None and stop_value is not None: + # we already have start and stop values, nothing to do + window = Window(self.hwm.expression, start_from=strategy.current, stop_at=strategy.next) + strategy.update_hwm(window.stop_at.value) + return window, None + + if not isinstance(self.connection, ContainsGetMinMaxValues): + raise ValueError( + f"{self.connection.__class__.__name__} connection does not support {strategy.__class__.__name__}", + ) + + # strategy does not have start/stop/current value - use min/max values from source to fill them up + min_value, max_value = self.connection.get_min_max_values( + source=self.source, + window=Window( + self.hwm.expression, + # always include both edges, > vs >= are applied only to final dataframe + start_from=Edge(value=start_value), + stop_at=Edge(value=stop_value), + ), hint=self.hint, where=self.where, - df_schema=self.df_schema, - start_from=start_from, - end_at=end_at, **self._get_read_kwargs(), ) - df = helper.save(df) - entity_boundary_log(log, msg="DBReader ends", char="-") + if min_value is None or max_value is None: + log.warning("|%s| No data in source %r", self.__class__.__name__, self.source) + # return limit=0 to always return empty dataframe from the source. 
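The window/limit calculation folds the deleted strategy-helper logic into one decision. A rough, runnable restatement of its possible outcomes (illustrative function, not the real method; the batch/incremental split is simplified away):

.. code:: python

    def plan_read(hwm, start, stop, min_value, max_value):
        # mirrors the outcomes of DBReader._calculate_window_and_limit
        if hwm is None:
            return None, None            # no HWM: snapshot read, no window and no limit
        if start is not None and stop is not None:
            return (start, stop), None   # both boundaries already known to the strategy
        if min_value is None or max_value is None:
            return None, 0               # empty source: limit=0 keeps both dataframe and HWM untouched
        # otherwise missing boundaries are filled from MIN/MAX of the HWM expression
        return (start if start is not None else min_value, max_value), None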
+ # otherwise dataframe may start returning some data whether HWM is not being set + return None, 0 + + # returned value type may not always be the same type as expected, force cast to HWM type + hwm = strategy.hwm.copy() # type: ignore[union-attr] + + try: + min_value = hwm.set_value(min_value).value + max_value = hwm.set_value(max_value).value + except ValueError as e: + hwm_class_name = type(hwm).__name__ + error_message = textwrap.dedent( + f""" + Expression {hwm.expression!r} returned values: + min: {min_value!r} of type {type(min_value).__name__!r} + max: {max_value!r} of type {type(min_value).__name__!r} + which are not compatible with {hwm_class_name}. + + Please check if selected combination of HWM class and expression is valid. + """, + ) + raise ValueError(error_message) from e - return df + if isinstance(strategy, BatchHWMStrategy): + if strategy.start is None: + strategy.start = min_value + + if strategy.stop is None: + strategy.stop = max_value + + window = Window(self.hwm.expression, start_from=strategy.current, stop_at=strategy.next) + else: + # for IncrementalStrategy fix only max value to avoid difference between real dataframe content and HWM value + window = Window( + self.hwm.expression, + start_from=strategy.current, + stop_at=Edge(value=max_value), + ) + + strategy.update_hwm(window.stop_at.value) + return window, None def _log_parameters(self) -> None: log.info("|%s| -> |Spark| Reading DataFrame from source using parameters:", self.connection.__class__.__name__) @@ -581,68 +753,16 @@ def _log_parameters(self) -> None: if self.where: log_json(log, self.where, "where") - if self.hwm_column: - log_with_indent(log, "hwm_column = '%s'", self.hwm_column) - - if self.hwm_expression: - log_json(log, self.hwm_expression, "hwm_expression") - if self.df_schema: empty_df = self.connection.spark.createDataFrame([], self.df_schema) # type: ignore log_dataframe_schema(log, empty_df) + if self.hwm: + log_hwm(log, self.hwm) + options = self.options.dict(by_alias=True, exclude_none=True) if self.options else None log_options(log, options) - def _resolve_all_columns(self) -> list[str] | None: - """ - Unwraps "*" in columns list to real column names from existing table. - - Also adds 'hwm_column' to the result if it is not present. 
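How the final window is assembled differs between strategies; the constructions below are the ones used in the code above, wrapped into an illustrative helper (``hwm``, ``strategy`` and ``max_value`` are supplied by the caller):

.. code:: python

    from onetl.hwm import Edge, Window


    def build_window(hwm, strategy, max_value, is_batch: bool) -> Window:
        if is_batch:
            # batch strategies own both boundaries: current step -> next step
            return Window(hwm.expression, start_from=strategy.current, stop_at=strategy.next)
        # IncrementalStrategy: pin only the upper bound to MAX(expression), so the stored HWM
        # never gets ahead of the rows actually present in the dataframe
        return Window(hwm.expression, start_from=strategy.current, stop_at=Edge(value=max_value))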
- """ - - if not isinstance(self.connection, ContainsGetDFSchemaMethod): - # Some databases have no `get_df_schema` method - return self.columns - - columns: list[str] = [] - original_columns = self.columns or ["*"] - - for column in original_columns: - if column == "*": - schema = self.connection.get_df_schema( - source=str(self.source), - columns=["*"], - **self._get_read_kwargs(), - ) - field_names = schema.fieldNames() - columns.extend(field_names) - else: - columns.append(column) - - columns = uniq_ignore_case(columns) - - if not self.hwm_column: - return columns - - hwm_statement = self.hwm_column.name - if self.hwm_expression: - hwm_statement = self.connection.Dialect._expression_with_alias( # noqa: WPS437 - self.hwm_expression, - self.hwm_column.name, - ) - - columns_normalized = [column_name.casefold() for column_name in columns] - hwm_column_name = self.hwm_column.name.casefold() - - if hwm_column_name in columns_normalized: - column_index = columns_normalized.index(hwm_column_name) - columns[column_index] = hwm_statement - else: - columns.append(hwm_statement) - - return columns - def _get_read_kwargs(self) -> dict: if self.options: return {"options": self.options} diff --git a/onetl/db/db_reader/strategy_helper.py b/onetl/db/db_reader/strategy_helper.py deleted file mode 100644 index ad50894a2..000000000 --- a/onetl/db/db_reader/strategy_helper.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
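The file deleted below carried the old ``Statement``-based boundary API; the same information now travels as a single ``Window`` with two ``Edge`` objects. A short comparison, with the removed form kept as comments and illustrative boundary values:

.. code:: python

    from onetl.hwm import Edge, Window

    # old (removed): a pair of Statement objects, each carrying its own comparison operator, e.g.
    #   start_from = Statement(expression="updated_at", operator=operator.gt, value=last_value)
    #   end_at     = Statement(expression="updated_at", operator=operator.le, value=stop_value)

    # new: one Window whose edges carry the values (and their inclusiveness)
    last_value, stop_value = 100, 200
    window = Window("updated_at", start_from=Edge(value=last_value), stop_at=Edge(value=stop_value))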
- -from __future__ import annotations - -from logging import getLogger -from typing import TYPE_CHECKING, NoReturn, Optional, Tuple - -from etl_entities import HWM, Column, ColumnHWM -from pydantic import Field, root_validator, validator -from typing_extensions import Protocol - -from onetl.db.db_reader.db_reader import DBReader -from onetl.hwm import Statement -from onetl.hwm.store import HWMClassRegistry -from onetl.impl import FrozenModel -from onetl.strategy.batch_hwm_strategy import BatchHWMStrategy -from onetl.strategy.hwm_strategy import HWMStrategy -from onetl.strategy.strategy_manager import StrategyManager - -log = getLogger(__name__) - -if TYPE_CHECKING: - from pyspark.sql.dataframe import DataFrame - - -# ColumnHWM has abstract method serialize_value, so it's not possible to create a class instance -# small hack to bypass this exception -class MockColumnHWM(ColumnHWM): - def serialize_value(self): - """Fake implementation of ColumnHWM.serialize_value""" - - -class StrategyHelper(Protocol): - def save(self, df: DataFrame) -> DataFrame: - """Saves HWM value to HWMStore""" - - def get_boundaries(self) -> tuple[Statement | None, Statement | None]: - """Returns ``(min_boundary, max_boundary)`` for applying HWM to source""" - - -class NonHWMStrategyHelper(FrozenModel): - reader: DBReader - - def get_boundaries(self) -> Tuple[Optional[Statement], Optional[Statement]]: - return None, None - - @root_validator(pre=True) - def validate_current_strategy(cls, values): - reader = values.get("reader") - strategy = StrategyManager.get_current() - - if isinstance(strategy, HWMStrategy): - raise ValueError( - f"{strategy.__class__.__name__} cannot be used " - f"without `hwm_column` passed into {reader.__class__.__name__}", - ) - - return values - - def save(self, df: DataFrame) -> DataFrame: - return df - - -class HWMStrategyHelper(FrozenModel): - reader: DBReader - hwm_column: Column - hwm_expression: Optional[str] = None - strategy: HWMStrategy = Field(default_factory=StrategyManager.get_current) - - class Config: - validate_all = True - - @validator("strategy", always=True, pre=True) - def validate_strategy_is_hwm(cls, strategy, values): - reader = values.get("reader") - - if not isinstance(strategy, HWMStrategy): - raise ValueError( - f"{strategy.__class__.__name__} cannot be used " - f"with `hwm_column` passed into {reader.__class__.__name__}", - ) - - return strategy - - @validator("strategy", always=True) - def validate_strategy_matching_reader(cls, strategy, values): - if strategy.hwm is None: - return strategy - - reader = values.get("reader") - hwm_column = values.get("hwm_column") - - if not isinstance(strategy.hwm, ColumnHWM): - cls.raise_wrong_hwm_type(reader, type(strategy.hwm)) - - if strategy.hwm.source != reader.source or strategy.hwm.column != hwm_column: - raise ValueError( - f"{reader.__class__.__name__} was created " - f"with `hwm_column={reader.hwm_column}` and `source={reader.source}` " - f"but current HWM is created for ", - f"`column={strategy.hwm.column}` and `source={strategy.hwm.source}` ", - ) - - return strategy - - @validator("strategy", always=True) - def init_hwm(cls, strategy, values): - reader = values.get("reader") - hwm_column = values.get("hwm_column") - - if strategy.hwm is None: - # Small hack used only to generate qualified_name - strategy.hwm = MockColumnHWM(source=reader.source, column=hwm_column) - - if not strategy.hwm: - strategy.fetch_hwm() - - hwm_type: type[HWM] | None = type(strategy.hwm) - if hwm_type == MockColumnHWM: - # Remove HWM type set by 
hack above - hwm_type = None - - detected_hwm_type = cls.detect_hwm_column_type(reader, hwm_column) - - if not hwm_type: - hwm_type = detected_hwm_type - - if hwm_type != detected_hwm_type: - raise TypeError( - f'Type of "{hwm_column}" column is matching ' - f'"{detected_hwm_type.__name__}" which is different from "{hwm_type.__name__}"', - ) - - if hwm_type == MockColumnHWM or not issubclass(hwm_type, ColumnHWM): - cls.raise_wrong_hwm_type(reader, hwm_type) - - strategy.hwm = hwm_type(source=reader.source, column=hwm_column, value=strategy.hwm.value) - return strategy - - @validator("strategy", always=True) - def detect_hwm_column_boundaries(cls, strategy, values): - if not isinstance(strategy, BatchHWMStrategy): - return strategy - - if strategy.has_upper_limit and (strategy.has_lower_limit or strategy.hwm): - # values already set by previous reader runs within the strategy - return strategy - - reader = values.get("reader") - hwm_column = values.get("hwm_column") - hwm_expression = values.get("hwm_expression") - - min_hwm_value, max_hwm_value = reader.get_min_max_bounds(hwm_column.name, hwm_expression) - if min_hwm_value is None or max_hwm_value is None: - raise ValueError( - "Unable to determine max and min values. ", - f"Table '{reader.source}' column '{hwm_column}' cannot be used as `hwm_column`", - ) - - if not strategy.has_lower_limit and not strategy.hwm: - strategy.start = min_hwm_value - - if not strategy.has_upper_limit: - strategy.stop = max_hwm_value - - return strategy - - @staticmethod - def raise_wrong_hwm_type(reader: DBReader, hwm_type: type[HWM]) -> NoReturn: - raise ValueError( - f"{hwm_type.__name__} cannot be used with {reader.__class__.__name__}", - ) - - @staticmethod - def detect_hwm_column_type(reader: DBReader, hwm_column: Column) -> type[HWM]: - schema = {field.name.casefold(): field for field in reader.get_df_schema()} - column = hwm_column.name.casefold() - hwm_column_type = schema[column].dataType.typeName() - return HWMClassRegistry.get(hwm_column_type) - - def save(self, df: DataFrame) -> DataFrame: - from pyspark.sql import functions as F # noqa: N812 - - log.info("|DBReader| Calculating max value for column %r in the dataframe...", self.hwm_column.name) - max_df = df.select(F.max(self.hwm_column.name).alias("max_value")) - row = max_df.collect()[0] - max_hwm_value = row["max_value"] - log.info("|DBReader| Max value is: %r", max_hwm_value) - - self.strategy.update_hwm(max_hwm_value) - return df - - def get_boundaries(self) -> tuple[Statement | None, Statement | None]: - start_from: Statement | None = None - end_at: Statement | None = None - hwm: ColumnHWM | None = self.strategy.hwm # type: ignore - - if hwm is None: - return None, None - - if self.strategy.current_value is not None: - start_from = Statement( - expression=self.hwm_expression or hwm.name, - operator=self.strategy.current_value_comparator, - value=self.strategy.current_value, - ) - - if self.strategy.next_value is not None: - end_at = Statement( - expression=self.hwm_expression or hwm.name, - operator=self.strategy.next_value_comparator, - value=self.strategy.next_value, - ) - - return start_from, end_at diff --git a/onetl/db/db_writer/db_writer.py b/onetl/db/db_writer/db_writer.py index 460bc20d1..5f4264042 100644 --- a/onetl/db/db_writer/db_writer.py +++ b/onetl/db/db_writer/db_writer.py @@ -17,8 +17,7 @@ from logging import getLogger from typing import TYPE_CHECKING, Optional -from etl_entities import Table -from pydantic import Field, validator +from pydantic import Field, PrivateAttr, 
validator from onetl.base import BaseDBConnection from onetl.hooks import slot, support_hooks @@ -151,18 +150,15 @@ class DBWriter(FrozenModel): """ connection: BaseDBConnection - target: Table = Field(alias="table") + target: str = Field(alias="table") options: Optional[GenericOptions] = None + _connection_checked: bool = PrivateAttr(default=False) + @validator("target", pre=True, always=True) def validate_target(cls, target, values): connection: BaseDBConnection = values["connection"] - dialect = connection.Dialect - if isinstance(target, str): - # target="dbschema.table" or target="table", If target="dbschema.some.table" in class Table will raise error. - target = Table(name=target, instance=connection.instance_url) - # Here Table(name='target', db='dbschema', instance='some_instance') - return dialect.validate_name(connection, target) + return connection.dialect.validate_name(target) @validator("options", pre=True, always=True) def validate_options(cls, options, values): @@ -202,18 +198,21 @@ def run(self, df: DataFrame): if df.isStreaming: raise ValueError(f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames.") - entity_boundary_log(log, msg="DBWriter starts") + entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") + + if not self._connection_checked: + self._log_parameters() + log_dataframe_schema(log, df) + self.connection.check() + self._connection_checked = True - self._log_parameters() - log_dataframe_schema(log, df) - self.connection.check() self.connection.write_df_to_target( df=df, target=str(self.target), **self._get_write_kwargs(), ) - entity_boundary_log(log, msg="DBWriter ends", char="-") + entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") def _log_parameters(self) -> None: log.info("|Spark| -> |%s| Writing DataFrame to target using parameters:", self.connection.__class__.__name__) diff --git a/onetl/file/file_df_reader/file_df_reader.py b/onetl/file/file_df_reader/file_df_reader.py index 3d6a7715a..8fdef08a4 100644 --- a/onetl/file/file_df_reader/file_df_reader.py +++ b/onetl/file/file_df_reader/file_df_reader.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Iterable, Optional from ordered_set import OrderedSet -from pydantic import validator +from pydantic import PrivateAttr, validator from onetl._util.spark import try_import_pyspark from onetl.base import BaseFileDFConnection, BaseReadableFileFormat, PurePathProtocol @@ -116,6 +116,8 @@ class FileDFReader(FrozenModel): df_schema: Optional[StructType] = None options: FileDFReaderOptions = FileDFReaderOptions() + _connection_checked: bool = PrivateAttr(default=False) + @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> DataFrame: """ @@ -205,10 +207,13 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DataFrame: ) """ + entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") + if files is None and not self.source_path: raise ValueError("Neither file list nor `source_path` are passed") - self._log_parameters(files) + if not self._connection_checked: + self._log_parameters(files) paths: FileSet[PurePathProtocol] = FileSet() if files is not None: @@ -216,10 +221,14 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DataFrame: elif self.source_path: paths = FileSet([self.source_path]) - self.connection.check() - log_with_indent(log, "") + if not self._connection_checked: + self.connection.check() + log_with_indent(log, "") + self._connection_checked = True - return 
self._read_files(paths) + df = self._read_files(paths) + entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") + return df def _read_files(self, paths: FileSet[PurePathProtocol]) -> DataFrame: log.info("|%s| Paths to be read:", self.__class__.__name__) @@ -235,8 +244,6 @@ def _read_files(self, paths: FileSet[PurePathProtocol]) -> DataFrame: ) def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> None: - entity_boundary_log(log, msg=f"{self.__class__.__name__} starts") - log.info("|%s| -> |Spark| Reading files using parameters:", self.connection.__class__.__name__) log_with_indent(log, "source_path = %s", f"'{self.source_path}'" if self.source_path else "None") log_with_indent(log, "format = %r", self.format) diff --git a/onetl/file/file_df_writer/file_df_writer.py b/onetl/file/file_df_writer/file_df_writer.py index c0cc6711b..f2dd52438 100644 --- a/onetl/file/file_df_writer/file_df_writer.py +++ b/onetl/file/file_df_writer/file_df_writer.py @@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING -from pydantic import validator +from pydantic import PrivateAttr, validator from onetl.base import BaseFileDFConnection, BaseWritableFileFormat, PurePathProtocol from onetl.file.file_df_writer.options import FileDFWriterOptions @@ -99,6 +99,8 @@ class FileDFWriter(FrozenModel): target_path: PurePathProtocol options: FileDFWriterOptions = FileDFWriterOptions() + _connection_checked: bool = PrivateAttr(default=False) + @slot def run(self, df: DataFrame) -> None: """ @@ -122,13 +124,16 @@ def run(self, df: DataFrame) -> None: writer.run(df) """ + entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") + if df.isStreaming: raise ValueError(f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames.") - entity_boundary_log(log, f"{self.__class__.__name__} starts") + if not self._connection_checked: + self._log_parameters(df) + self.connection.check() + self._connection_checked = True - self._log_parameters(df) - self.connection.check() self.connection.write_df_as_files( df=df, path=self.target_path, @@ -136,7 +141,7 @@ def run(self, df: DataFrame) -> None: options=self.options, ) - entity_boundary_log(log, f"{self.__class__.__name__} ends", char="-") + entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-") def _log_parameters(self, df: DataFrame) -> None: log.info("|Spark| -> |%s| Writing dataframe using parameters:", self.connection.__class__.__name__) diff --git a/onetl/file/file_downloader/file_downloader.py b/onetl/file/file_downloader/file_downloader.py index 4a51cc74b..fa1257303 100644 --- a/onetl/file/file_downloader/file_downloader.py +++ b/onetl/file/file_downloader/file_downloader.py @@ -17,25 +17,27 @@ import logging import os import shutil +import textwrap import warnings from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum -from typing import Iterable, List, Optional, Tuple, Type +from typing import Generator, Iterable, List, Optional, Tuple, Type, Union -from etl_entities import HWM, FileHWM, RemoteFolder +from etl_entities.hwm import FileHWM, FileListHWM +from etl_entities.instance import AbsolutePath +from etl_entities.old_hwm import FileListHWM as OldFileListHWM +from etl_entities.source import RemoteFolder from ordered_set import OrderedSet -from pydantic import Field, validator +from pydantic import Field, PrivateAttr, root_validator, validator from onetl._internal import generate_temp_path from onetl.base import BaseFileConnection, 
BaseFileFilter, BaseFileLimit -from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol -from onetl.base.pure_path_protocol import PurePathProtocol +from onetl.base.path_protocol import PathProtocol from onetl.file.file_downloader.options import FileDownloaderOptions from onetl.file.file_downloader.result import DownloadResult from onetl.file.file_set import FileSet from onetl.file.filter.file_hwm import FileHWMFilter from onetl.hooks import slot, support_hooks -from onetl.hwm.store import HWMClassRegistry from onetl.impl import ( FailedRemoteFile, FileExistBehavior, @@ -48,6 +50,7 @@ from onetl.log import ( entity_boundary_log, log_collection, + log_hwm, log_lines, log_options, log_with_indent, @@ -127,11 +130,12 @@ class FileDownloader(FrozenModel): options : :obj:`~FileDownloader.Options` | dict | None, default: ``None`` File downloading options. See :obj:`~FileDownloader.Options` - hwm_type : str | type[HWM] | None, default: ``None`` - HWM type to detect changes in incremental run. See :ref:`file-hwm` + hwm : type[HWM] | None, default: ``None`` + + HWM class to detect changes in incremental run. See :etl-entities:`File HWM ` .. warning :: - Used only in :obj:`onetl.strategy.incremental_strategy.IncrementalStrategy`. + Used only in :obj:`IncrementalStrategy `. Examples -------- @@ -192,6 +196,7 @@ class FileDownloader(FrozenModel): from onetl.connection import SFTP from onetl.file import FileDownloader from onetl.strategy import IncrementalStrategy + from etl_entities.hwm import FileListHWM sftp = SFTP(...) @@ -200,7 +205,9 @@ class FileDownloader(FrozenModel): connection=sftp, source_path="/path/to/remote/source", local_path="/path/to/local", - hwm_type="file_list", # mandatory for IncrementalStrategy + hwm=FileListHWM( + name="my_unique_hwm_name", directory="/path/to/remote/source" + ), # mandatory for IncrementalStrategy ) # download files to "/path/to/local", but only new ones @@ -209,6 +216,8 @@ class FileDownloader(FrozenModel): """ + Options = FileDownloaderOptions + connection: BaseFileConnection local_path: LocalPath @@ -218,11 +227,12 @@ class FileDownloader(FrozenModel): filters: List[BaseFileFilter] = Field(default_factory=list, alias="filter") limits: List[BaseFileLimit] = Field(default_factory=list, alias="limit") - hwm_type: Optional[Type[FileHWM]] = None + hwm: Optional[FileHWM] = None + hwm_type: Optional[Union[Type[OldFileListHWM], str]] = None options: FileDownloaderOptions = FileDownloaderOptions() - Options = FileDownloaderOptions + _connection_checked: bool = PrivateAttr(default=False) @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResult: # noqa: WPS231 @@ -240,10 +250,10 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResul File list to download. If empty, download files from ``source_path`` to ``local_path``, - applying ``filter``, ``limit`` and ``hwm_type`` to each one (if set). + applying ``filter``, ``limit`` and ``hwm`` to each one (if set). - If not, download to ``local_path`` **all** input files, **without** - any filtering, limiting and excluding files covered by :ref:`file-hwm` + If not, download to ``local_path`` **all** input files, **ignoring** + filters, limits and HWM. 
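FileDownloader follows the same deprecation path as DBReader: the string-based ``hwm_type="file_list"`` is converted into an explicit ``FileListHWM`` bound to ``source_path``. A hedged before/after sketch (the ``sftp`` connection is assumed to be configured as in the docstring example above; the HWM name is illustrative):

.. code:: python

    from etl_entities.hwm import FileListHWM

    from onetl.file import FileDownloader

    # onETL < 0.10 (deprecated)
    downloader = FileDownloader(
        connection=sftp,
        source_path="/path/to/remote/source",
        local_path="/path/to/local",
        hwm_type="file_list",
    )

    # onETL >= 0.10
    downloader = FileDownloader(
        connection=sftp,
        source_path="/path/to/remote/source",
        local_path="/path/to/local",
        hwm=FileListHWM(name="my_unique_hwm_name", directory="/path/to/remote/source"),
    )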
Returns ------- @@ -252,7 +262,7 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResul Download result object Raises - ------- + ------ :obj:`onetl.exception.DirectoryNotFoundError` ``source_path`` does not found @@ -336,24 +346,28 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResul assert not downloaded_files.missing """ + entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") + + if not self._connection_checked: + self._log_parameters(files) + self._check_strategy() if files is None and not self.source_path: raise ValueError("Neither file list nor `source_path` are passed") - self._log_parameters(files) - # Check everything - self._check_local_path() - self.connection.check() - log_with_indent(log, "") + if not self._connection_checked: + self._check_local_path() + self.connection.check() - if self.source_path: - self._check_source_path() + if self.source_path: + self._check_source_path() - if files is None: - log.info("|%s| File list is not passed to `run` method", self.__class__.__name__) + self._connection_checked = True + if files is None: + log.debug("|%s| File list is not passed to `run` method", self.__class__.__name__) files = self.view_files() if not files: @@ -372,15 +386,15 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResul shutil.rmtree(self.local_path) self.local_path.mkdir() - if self.hwm_type is not None: - result = self._download_files_incremental(to_download) - else: - result = self._download_files(to_download) + if self.hwm: + self._init_hwm(self.hwm) + result = self._download_files(to_download) if current_temp_dir: self._remove_temp_dir(current_temp_dir) self._log_result(result) + entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-") return result @slot @@ -394,7 +408,7 @@ def view_files(self) -> FileSet[RemoteFile]: This method can return different results depending on :ref:`strategy` Raises - ------- + ------ :obj:`onetl.exception.DirectoryNotFoundError` ``source_path`` does not found @@ -429,14 +443,19 @@ def view_files(self) -> FileSet[RemoteFile]: } """ + if not self.source_path: + raise ValueError("Cannot call `.view_files()` without `source_path`") + log.debug("|%s| Getting files list from path '%s'", self.connection.__class__.__name__, self.source_path) - self._check_source_path() + if not self._connection_checked: + self._check_source_path() + result = FileSet() filters = self.filters.copy() - if self.hwm_type: - filters.append(FileHWMFilter(hwm=self._init_hwm())) + if self.hwm: + filters.append(FileHWMFilter(hwm=self._init_hwm(self.hwm))) try: for root, _dirs, files in self.connection.walk(self.source_path, filters=filters, limits=self.limits): @@ -462,20 +481,59 @@ def _validate_source_path(cls, source_path): def _validate_temp_path(cls, temp_path): return LocalPath(temp_path).resolve() if temp_path else None - @validator("hwm_type", pre=True, always=True) - def _validate_hwm_type(cls, hwm_type, values): + @root_validator(skip_on_failure=True) + def _validate_hwm(cls, values): + connection = values["connection"] source_path = values.get("source_path") + hwm_type = values.get("hwm_type") + hwm = values.get("hwm") - if hwm_type: - if not source_path: - raise ValueError("If `hwm_type` is passed, `source_path` must be specified") + if (hwm or hwm_type) and not source_path: + raise ValueError("If `hwm` is passed, `source_path` must be specified") + + if hwm_type and (hwm_type == "file_list" or issubclass(hwm_type, OldFileListHWM)): + 
remote_file_folder = RemoteFolder(name=source_path, instance=connection.instance_url) + old_hwm = OldFileListHWM(source=remote_file_folder) + warnings.warn( + textwrap.dedent( + f""" + Passing "hwm_type" to FileDownloader class is deprecated since version 0.10.0, + and will be removed in v1.0.0. + + Instead use: + hwm=FileListHWM(name={old_hwm.qualified_name!r}) + """, + ), + UserWarning, + stacklevel=2, + ) + hwm = FileListHWM( + name=old_hwm.qualified_name, + directory=source_path, + ) - if isinstance(hwm_type, str): - hwm_type = HWMClassRegistry.get(hwm_type) + if hwm and not hwm.entity: + hwm = hwm.copy(update={"entity": AbsolutePath(source_path)}) - cls._check_hwm_type(hwm_type) + if hwm and hwm.entity != source_path: + error_message = textwrap.dedent( + f""" + Passed `hwm.directory` is different from `source_path`. - return hwm_type + `hwm`: + {hwm!r} + + `source_path`: + {source_path!r} + + This is not allowed. + """, + ) + raise ValueError(error_message) + + values["hwm"] = hwm + values["hwm_type"] = None + return values @validator("filters", pre=True) def _validate_filters(cls, filters): @@ -519,46 +577,67 @@ def _validate_limits(cls, limits): def _check_strategy(self): strategy = StrategyManager.get_current() + class_name = self.__class__.__name__ + strategy_name = strategy.__class__.__name__ - if self.hwm_type: + if self.hwm: if not isinstance(strategy, HWMStrategy): - raise ValueError("`hwm_type` cannot be used in snapshot strategy.") - elif getattr(strategy, "offset", None): # this check should be somewhere in IncrementalStrategy, - # but the logic is quite messy - raise ValueError("If `hwm_type` is passed you can't specify an `offset`") + raise ValueError(f"{class_name}(hwm=...) cannot be used with {strategy_name}") + + offset = getattr(strategy, "offset", None) + if offset is not None: + raise ValueError(f"{class_name}(hwm=...) cannot be used with {strategy_name}(offset={offset}, ...)") if isinstance(strategy, BatchHWMStrategy): - raise ValueError("`hwm_type` cannot be used in batch strategy.") + raise ValueError(f"{class_name}(hwm=...) cannot be used with {strategy_name}") - def _init_hwm(self) -> FileHWM: + def _init_hwm(self, hwm: FileHWM) -> FileHWM: strategy: HWMStrategy = StrategyManager.get_current() - if strategy.hwm is None: - remote_file_folder = RemoteFolder(name=self.source_path, instance=self.connection.instance_url) - strategy.hwm = self.hwm_type(source=remote_file_folder) - if not strategy.hwm: + strategy.hwm = self.hwm strategy.fetch_hwm() + return strategy.hwm + + if not isinstance(strategy.hwm, FileHWM) or strategy.hwm.name != hwm.name: + # exception raised when inside one strategy >1 processes on the same table but with different hwm columns + # are executed, example: test_postgres_strategy_incremental_hwm_set_twice + error_message = textwrap.dedent( + f""" + Detected wrong {type(strategy).__name__} usage. 
+ + Previous run: + {strategy.hwm!r} + Current run: + {self.hwm!r} + + Probably you've executed code which looks like this: + with {strategy.__class__.__name__}(...): + FileDownloader(hwm=one_hwm, ...).run() + FileDownloader(hwm=another_hwm, ...).run() + + Please change it to: + with {strategy.__class__.__name__}(...): + FileDownloader(hwm=one_hwm, ...).run() + + with {strategy.__class__.__name__}(...): + FileDownloader(hwm=another_hwm, ...).run() + """, + ) + raise ValueError(error_message) - file_hwm = strategy.hwm - - # to avoid issues when HWM store returned HWM with unexpected type - self._check_hwm_type(file_hwm.__class__) - return file_hwm - - def _download_files_incremental(self, to_download: DOWNLOAD_ITEMS_TYPE) -> DownloadResult: - self._init_hwm() - return self._download_files(to_download) + strategy.validate_hwm_attributes(hwm, strategy.hwm, origin=self.__class__.__name__) + return strategy.hwm def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> None: - entity_boundary_log(log, msg="FileDownloader starts") - log.info("|%s| -> |Local FS| Downloading files using parameters:", self.connection.__class__.__name__) log_with_indent(log, "source_path = %s", f"'{self.source_path}'" if self.source_path else "None") log_with_indent(log, "local_path = '%s'", self.local_path) log_with_indent(log, "temp_path = %s", f"'{self.temp_path}'" if self.temp_path else "None") log_collection(log, "filters", self.filters) log_collection(log, "limits", self.limits) + if self.hwm: + log_hwm(log, self.hwm) log_options(log, self.options.dict(by_alias=True)) if self.options.delete_source: @@ -628,7 +707,7 @@ def _check_local_path(self): self.local_path.mkdir(exist_ok=True, parents=True) - def _download_files( + def _download_files( # noqa: WPS231 self, to_download: DOWNLOAD_ITEMS_TYPE, ) -> DownloadResult: @@ -640,17 +719,25 @@ def _download_files( self._create_dirs(to_download) + strategy = StrategyManager.get_current() result = DownloadResult() - for status, file in self._bulk_download(to_download): - if status == FileDownloadStatus.SUCCESSFUL: - result.successful.add(file) - elif status == FileDownloadStatus.FAILED: - result.failed.add(file) - elif status == FileDownloadStatus.SKIPPED: - result.skipped.add(file) - elif status == FileDownloadStatus.MISSING: - result.missing.add(file) - + source_files: list[RemotePath] = [] + try: # noqa: WPS501 + for status, source_file, target_file in self._bulk_download(to_download): + if status == FileDownloadStatus.SUCCESSFUL: + result.successful.add(target_file) + source_files.append(source_file) + elif status == FileDownloadStatus.FAILED: + result.failed.add(source_file) + elif status == FileDownloadStatus.SKIPPED: + result.skipped.add(source_file) + elif status == FileDownloadStatus.MISSING: + result.missing.add(source_file) + finally: + if self.hwm: + # always update HWM in HWM store, even if downloader is interrupted + strategy.update_hwm(source_files) + strategy.save_hwm() return result def _create_dirs( @@ -673,10 +760,9 @@ def _create_dirs( def _bulk_download( self, to_download: DOWNLOAD_ITEMS_TYPE, - ) -> list[tuple[FileDownloadStatus, PurePathProtocol | PathWithStatsProtocol]]: + ) -> Generator[tuple[FileDownloadStatus, RemotePath, LocalPath | None], None, None]: workers = self.options.workers files_count = len(to_download) - result = [] real_workers = workers if files_count < workers: @@ -698,27 +784,20 @@ def _bulk_download( executor.submit(self._download_file, source_file, target_file, tmp_file) for source_file, target_file, 
tmp_file in to_download ] - for future in as_completed(futures): - result.append(future.result()) + yield from (future.result() for future in as_completed(futures)) else: log.debug("|%s| Using plain old for-loop", self.__class__.__name__) - for source_file, target_file, tmp_file in to_download: - result.append( - self._download_file( - source_file, - target_file, - tmp_file, - ), - ) - - return result + yield from ( + self._download_file(source_file, target_file, tmp_file) + for source_file, target_file, tmp_file in to_download + ) def _download_file( # noqa: WPS231, WPS213 self, source_file: RemotePath, local_file: LocalPath, tmp_file: LocalPath | None, - ) -> tuple[FileDownloadStatus, PurePathProtocol | PathWithStatsProtocol]: + ) -> tuple[FileDownloadStatus, RemotePath, LocalPath | None]: if tmp_file: log.info( "|%s| Downloading file '%s' to '%s' (via tmp '%s')", @@ -732,7 +811,7 @@ def _download_file( # noqa: WPS231, WPS213 if not self.connection.path_exists(source_file): log.warning("|%s| Missing file '%s', skipping", self.__class__.__name__, source_file) - return FileDownloadStatus.MISSING, source_file + return FileDownloadStatus.MISSING, source_file, None try: remote_file = self.connection.resolve_file(source_file) @@ -744,7 +823,7 @@ def _download_file( # noqa: WPS231, WPS213 if self.options.if_exists == FileExistBehavior.IGNORE: log.warning("|Local FS| File %s already exists, skipping", path_repr(local_file)) - return FileDownloadStatus.SKIPPED, remote_file + return FileDownloadStatus.SKIPPED, remote_file, None replace = True @@ -766,16 +845,11 @@ def _download_file( # noqa: WPS231, WPS213 # Direct download self.connection.download_file(remote_file, local_file, replace=replace) - if self.hwm_type: - strategy = StrategyManager.get_current() - strategy.hwm.update(remote_file) - strategy.save_hwm() - # Delete Remote if self.options.delete_source: self.connection.remove_file(remote_file) - return FileDownloadStatus.SUCCESSFUL, local_file + return FileDownloadStatus.SUCCESSFUL, remote_file, local_file except Exception as e: if log.isEnabledFor(logging.DEBUG): @@ -791,11 +865,12 @@ def _download_file( # noqa: WPS231, WPS213 e, exc_info=False, ) - return FileDownloadStatus.FAILED, FailedRemoteFile( + failed_file = FailedRemoteFile( path=remote_file.path, stats=remote_file.stats, exception=e, ) + return FileDownloadStatus.FAILED, failed_file, None def _remove_temp_dir(self, temp_dir: LocalPath) -> None: log.info("|Local FS| Removing temp directory '%s'", temp_dir) @@ -809,11 +884,3 @@ def _log_result(self, result: DownloadResult) -> None: log_with_indent(log, "") log.info("|%s| Download result:", self.__class__.__name__) log_lines(log, str(result)) - entity_boundary_log(log, msg=f"{self.__class__.__name__} ends", char="-") - - @staticmethod - def _check_hwm_type(hwm_type: type[HWM]) -> None: - if not issubclass(hwm_type, FileHWM): - raise ValueError( - f"`hwm_type` class should be a inherited from FileHWM, got {hwm_type.__name__}", - ) diff --git a/onetl/file/file_mover/file_mover.py b/onetl/file/file_mover/file_mover.py index 7d27eb8a8..8dd48be6f 100644 --- a/onetl/file/file_mover/file_mover.py +++ b/onetl/file/file_mover/file_mover.py @@ -21,7 +21,7 @@ from typing import Iterable, List, Optional, Tuple from ordered_set import OrderedSet -from pydantic import Field, validator +from pydantic import Field, PrivateAttr, validator from onetl.base import BaseFileConnection, BaseFileFilter, BaseFileLimit from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol @@ -152,6 +152,8 
@@ class FileMover(FrozenModel): """ + Options = FileMoverOptions + connection: BaseFileConnection target_path: RemotePath @@ -162,7 +164,7 @@ class FileMover(FrozenModel): options: FileMoverOptions = FileMoverOptions() - Options = FileMoverOptions + _connection_checked: bool = PrivateAttr(default=False) @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: # noqa: WPS231 @@ -188,7 +190,7 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: Move result object Raises - ------- + ------ :obj:`onetl.exception.DirectoryNotFoundError` ``source_path`` does not found @@ -272,22 +274,25 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: assert not moved_files.missing """ + entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") + if files is None and not self.source_path: raise ValueError("Neither file list nor `source_path` are passed") - self._log_parameters(files) + if not self._connection_checked: + self._log_parameters(files) - # Check everything - self.connection.check() - self._check_target_path() - log_with_indent(log, "") + self.connection.check() + self._check_target_path() + log_with_indent(log, "") - if self.source_path: - self._check_source_path() + if self.source_path: + self._check_source_path() - if files is None: - log.info("|%s| File list is not passed to `run` method", self.__class__.__name__) + self._connection_checked = True + if files is None: + log.debug("|%s| File list is not passed to `run` method", self.__class__.__name__) files = self.view_files() if not files: @@ -303,6 +308,7 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: result = self._move_files(to_move) self._log_result(result) + entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-") return result @slot @@ -312,7 +318,7 @@ def view_files(self) -> FileSet[RemoteFile]: after ``filter`` and ``limit`` applied (if any). 
|support_hooks| Raises - ------- + ------ :obj:`onetl.exception.DirectoryNotFoundError` ``source_path`` does not found @@ -347,9 +353,14 @@ def view_files(self) -> FileSet[RemoteFile]: } """ + if not self.source_path: + raise ValueError("Cannot call `.view_files()` without `source_path`") + log.debug("|%s| Getting files list from path '%s'", self.connection.__class__.__name__, self.source_path) - self._check_source_path() + if not self._connection_checked: + self._check_source_path() + result = FileSet() try: @@ -365,8 +376,6 @@ def view_files(self) -> FileSet[RemoteFile]: return result def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> None: - entity_boundary_log(log, msg="FileMover starts") - connection_class = self.connection.__class__.__name__ log.info("|%s| -> |%s| Moving files using parameters:", connection_class, connection_class) log_with_indent(log, "source_path = %s", f"'{self.source_path}'" if self.source_path else "None") @@ -566,4 +575,3 @@ def _log_result(self, result: MoveResult) -> None: log_with_indent(log, "") log.info("|%s| Move result:", self.__class__.__name__) log_lines(log, str(result)) - entity_boundary_log(log, msg=f"{self.__class__.__name__} ends", char="-") diff --git a/onetl/file/file_uploader/file_uploader.py b/onetl/file/file_uploader/file_uploader.py index e9e8c550a..e2ec9cc48 100644 --- a/onetl/file/file_uploader/file_uploader.py +++ b/onetl/file/file_uploader/file_uploader.py @@ -21,7 +21,7 @@ from typing import Iterable, Optional, Tuple from ordered_set import OrderedSet -from pydantic import validator +from pydantic import PrivateAttr, validator from onetl._internal import generate_temp_path from onetl.base import BaseFileConnection @@ -141,6 +141,8 @@ class FileUploader(FrozenModel): """ + Options = FileUploaderOptions + connection: BaseFileConnection target_path: RemotePath @@ -150,7 +152,7 @@ class FileUploader(FrozenModel): options: FileUploaderOptions = FileUploaderOptions() - Options = FileUploaderOptions + _connection_checked: bool = PrivateAttr(default=False) @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> UploadResult: @@ -172,7 +174,7 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> UploadResult: Upload result object Raises - ------- + ------ :obj:`onetl.exception.DirectoryNotFoundError` ``local_path`` does not found @@ -269,22 +271,23 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> UploadResult: assert not uploaded_files.missing """ + entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") + if files is None and not self.local_path: raise ValueError("Neither file list nor `local_path` are passed") - self._log_parameters(files) + if not self._connection_checked: + self._log_parameters(files) - # Check everything - if self.local_path: - self._check_local_path() - - self.connection.check() - log_with_indent(log, "") + if self.local_path: + self._check_local_path() + self.connection.check() + self._connection_checked = True self.connection.create_dir(self.target_path) if files is None: - log.info("|%s| File list is not passed to `run` method", self.__class__.__name__) + log.debug("|%s| File list is not passed to `run` method", self.__class__.__name__) files = self.view_files() if not files: @@ -311,6 +314,7 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> UploadResult: self._remove_temp_dir(current_temp_dir) self._log_result(result) + entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-") return 
result @slot @@ -319,7 +323,7 @@ def view_files(self) -> FileSet[LocalPath]: Get file list in the ``local_path``. |support_hooks| Raises - ------- + ------ :obj:`onetl.exception.DirectoryNotFoundError` ``local_path`` does not found @@ -354,9 +358,14 @@ def view_files(self) -> FileSet[LocalPath]: } """ + if not self.local_path: + raise ValueError("Cannot call `.view_files()` without `local_path`") + log.debug("|Local FS| Getting files list from path '%s'", self.local_path) - self._check_local_path() + if not self._connection_checked: + self._check_local_path() + result = FileSet() try: @@ -383,8 +392,6 @@ def _validate_temp_path(cls, temp_path): return RemotePath(temp_path) if temp_path else None def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> None: - entity_boundary_log(log, msg="FileUploader starts") - log.info("|Local FS| -> |%s| Uploading files using parameters:'", self.connection.__class__.__name__) log_with_indent(log, "local_path = %s", f"'{self.local_path}'" if self.local_path else "None") log_with_indent(log, "target_path = '%s'", self.target_path) @@ -605,4 +612,3 @@ def _log_result(self, result: UploadResult) -> None: log.info("") log.info("|%s| Upload result:", self.__class__.__name__) log_lines(log, str(result)) - entity_boundary_log(log, msg=f"{self.__class__.__name__} ends", char="-") diff --git a/onetl/file/filter/file_hwm.py b/onetl/file/filter/file_hwm.py index c0f5362cb..395443fc3 100644 --- a/onetl/file/filter/file_hwm.py +++ b/onetl/file/filter/file_hwm.py @@ -14,7 +14,7 @@ from __future__ import annotations -from etl_entities import FileHWM +from etl_entities.hwm import FileHWM from onetl.base import BaseFileFilter, PathProtocol from onetl.impl import FrozenModel @@ -30,7 +30,7 @@ class FileHWMFilter(BaseFileFilter, FrozenModel): Parameters ---------- - hwm : :obj:`etl_entities.FileHWM` + hwm : :obj:`etl_entities.hwm.FileHWM` File HWM instance """ @@ -47,7 +47,7 @@ def match(self, path: PathProtocol) -> bool: return not self.hwm.covers(path) def __str__(self): - return self.hwm.qualified_name + return self.hwm.name def __repr__(self): - return f"{self.hwm.__class__.__name__}(qualified_name={self.hwm.qualified_name!r})" + return f"{self.hwm.__class__.__name__}(name={self.hwm.name!r})" diff --git a/onetl/file/format/excel.py b/onetl/file/format/excel.py index ffd11a5da..3f5b2bdcf 100644 --- a/onetl/file/format/excel.py +++ b/onetl/file/format/excel.py @@ -69,7 +69,7 @@ class Excel(ReadWriteFileFormat): .. dropdown:: Version compatibility - * Spark versions: 3.2.x - 3.4.x. + * Spark versions: 3.2.x - 3.5.x. .. warning:: @@ -100,7 +100,7 @@ class Excel(ReadWriteFileFormat): from pyspark.sql import SparkSession # Create Spark session with Excel package loaded - maven_packages = Excel.get_packages(spark_version="3.4.1") + maven_packages = Excel.get_packages(spark_version="3.5.0") spark = ( SparkSession.builder.appName("spark-app-name") .config("spark.jars.packages", ",".join(maven_packages)) @@ -150,7 +150,7 @@ def get_packages( If ``None``, ``spark_version`` is used to determine Scala version. version: str, optional - Package version in format ``major.minor.patch``. Default is ``0.19.0``. + Package version in format ``major.minor.patch``. Default is ``0.20.3``. .. 
warning:: @@ -168,12 +168,12 @@ def get_packages( from onetl.file.format import Excel - Excel.get_packages(spark_version="3.4.1") - Excel.get_packages(spark_version="3.4.1", scala_version="2.13") + Excel.get_packages(spark_version="3.5.0") + Excel.get_packages(spark_version="3.5.0", scala_version="2.13") Excel.get_packages( - spark_version="3.4.1", + spark_version="3.5.0", scala_version="2.13", - package_version="0.19.0", + package_version="0.20.3", ) """ @@ -187,7 +187,7 @@ def get_packages( raise ValueError(f"Package version should be at least 0.15, got {package_version}") log.warning("Passed custom package version %r, it is not guaranteed to be supported", package_version) else: - version = Version.parse("0.19.0") + version = Version.parse("0.20.3") spark_ver = Version.parse(spark_version) if spark_ver < (3, 2): diff --git a/onetl/hooks/hook.py b/onetl/hooks/hook.py index 6988db3f8..e49039a3c 100644 --- a/onetl/hooks/hook.py +++ b/onetl/hooks/hook.py @@ -8,14 +8,13 @@ from functools import wraps from typing import Callable, Generator, Generic, TypeVar -from typing_extensions import ParamSpec, Protocol, runtime_checkable +from typing_extensions import Protocol, runtime_checkable from onetl.log import NOTICE logger = logging.getLogger(__name__) T = TypeVar("T") -P = ParamSpec("P") class HookPriority(int, Enum): @@ -36,7 +35,7 @@ class HookPriority(int, Enum): @dataclass # noqa: WPS338 -class Hook(Generic[P, T]): # noqa: WPS338 +class Hook(Generic[T]): # noqa: WPS338 """ Hook representation. @@ -70,7 +69,7 @@ def some_func(*args, **kwargs): hook = Hook(callback=some_func, enabled=True, priority=HookPriority.FIRST) """ - callback: Callable[P, T] + callback: Callable[..., T] enabled: bool = True priority: HookPriority = HookPriority.NORMAL @@ -198,7 +197,7 @@ def hook_disabled(): ) self.enabled = True - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T | ContextDecorator: + def __call__(self, *args, **kwargs) -> T | ContextDecorator: """ Calls the original callback with passed args. @@ -361,7 +360,7 @@ def process_result(self, result: T) -> T | None: return None -def hook(inp: Callable[P, T] | None = None, enabled: bool = True, priority: HookPriority = HookPriority.NORMAL): +def hook(inp: Callable[..., T] | None = None, enabled: bool = True, priority: HookPriority = HookPriority.NORMAL): """ Initialize hook from callable/context manager. @@ -423,7 +422,7 @@ def process_result(self, result): ... 
""" - def inner_wrapper(callback: Callable[P, T]): # noqa: WPS430 + def inner_wrapper(callback: Callable[..., T]): # noqa: WPS430 if isinstance(callback, Hook): raise TypeError("@hook decorator can be applied only once") diff --git a/onetl/hooks/slot.py b/onetl/hooks/slot.py index 3fc849574..b6f80efe9 100644 --- a/onetl/hooks/slot.py +++ b/onetl/hooks/slot.py @@ -8,7 +8,7 @@ from functools import partial, wraps from typing import Any, Callable, ContextManager, TypeVar -from typing_extensions import ParamSpec, Protocol +from typing_extensions import Protocol from onetl.exception import SignatureError from onetl.hooks.hook import CanProcessResult, Hook, HookPriority @@ -17,13 +17,12 @@ from onetl.hooks.method_inheritance_stack import MethodInheritanceStack from onetl.log import NOTICE -logger = logging.getLogger(__name__) +Method = TypeVar("Method", bound=Callable[..., Any]) -P = ParamSpec("P") -T = TypeVar("T") +logger = logging.getLogger(__name__) -def _unwrap_method(method: Callable[P, T]) -> Callable[P, T]: +def _unwrap_method(method: Method) -> Method: """Unwrap @classmethod and @staticmethod to get original function""" return getattr(method, "__func__", method) @@ -83,20 +82,20 @@ def method(self, arg): @MyClass.method.bind @hook - def hook(self, arg): + def callable(self, arg): if arg == "some": do_something() @MyClass.method.bind @hook(priority=HookPriority.FIRST, enabled=True) - def another_hook(self, arg): + def another_callable(self, arg): if arg == "another": raise NotAllowed() obj = MyClass() - obj.method(1) # will call both hook(obj, 1) and another_hook(obj, 1) + obj.method(1) # will call both callable(obj, 1) and another_callable(obj, 1) """ def inner_wrapper(hook): # noqa: WPS430 @@ -624,7 +623,7 @@ def bind(self): ... -def slot(method: Callable[P, T]) -> Callable[P, T]: +def slot(method: Method) -> Method: """ Decorator which enables hooks functionality on a specific class method. diff --git a/onetl/hwm/__init__.py b/onetl/hwm/__init__.py index d6e52d105..333d3958a 100644 --- a/onetl/hwm/__init__.py +++ b/onetl/hwm/__init__.py @@ -1 +1,2 @@ -from onetl.hwm.statement import Statement +from onetl.hwm.auto_hwm import AutoDetectHWM +from onetl.hwm.window import Edge, Window diff --git a/onetl/hwm/auto_hwm.py b/onetl/hwm/auto_hwm.py new file mode 100644 index 000000000..aae9c89fd --- /dev/null +++ b/onetl/hwm/auto_hwm.py @@ -0,0 +1,21 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from etl_entities.hwm import ColumnHWM +from typing_extensions import Literal + + +class AutoDetectHWM(ColumnHWM): + value: Literal[None] = None diff --git a/onetl/hwm/statement.py b/onetl/hwm/statement.py deleted file mode 100644 index 7063f0bf0..000000000 --- a/onetl/hwm/statement.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Callable - - -@dataclass -class Statement: - expression: Callable | str - operator: Any - value: Any diff --git a/onetl/hwm/store/__init__.py b/onetl/hwm/store/__init__.py index 85f66473c..4f71ba0d5 100644 --- a/onetl/hwm/store/__init__.py +++ b/onetl/hwm/store/__init__.py @@ -11,15 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import textwrap +import warnings +from importlib import import_module -from onetl.hwm.store.base_hwm_store import BaseHWMStore -from onetl.hwm.store.hwm_class_registry import HWMClassRegistry, register_hwm_class -from onetl.hwm.store.hwm_store_class_registry import ( - HWMStoreClassRegistry, - default_hwm_store_class, - detect_hwm_store, - register_hwm_store_class, +from onetl.hwm.store.hwm_class_registry import ( + SparkTypeToHWM, + register_spark_type_to_hwm_type_mapping, ) -from onetl.hwm.store.hwm_store_manager import HWMStoreManager -from onetl.hwm.store.memory_hwm_store import MemoryHWMStore -from onetl.hwm.store.yaml_hwm_store import YAMLHWMStore +from onetl.hwm.store.yaml_hwm_store import YAMLHWMStore, default_hwm_store_class + +deprecated_imports = { + "MemoryHWMStore", + "BaseHWMStore", + "HWMStoreClassRegistry", + "HWMStoreManager", + "detect_hwm_store", + "register_hwm_store_class", +} + + +def __getattr__(name: str): + if name in deprecated_imports: + msg = f""" + This import is deprecated since v0.10.0: + + from onetl.hwm.store import {name} + + Please use instead: + + from etl_entities.hwm_store import {name} + """ + + warnings.warn( + textwrap.dedent(msg), + UserWarning, + stacklevel=2, + ) + + if name == "HWMStoreManager": + from etl_entities.hwm_store import HWMStoreStackManager + + return HWMStoreStackManager + + return getattr(import_module("etl_entities.hwm_store"), name) + + raise ImportError(f"cannot import name {name!r} from {__name__!r}") diff --git a/onetl/hwm/store/base_hwm_store.py b/onetl/hwm/store/base_hwm_store.py deleted file mode 100644 index e529de907..000000000 --- a/onetl/hwm/store/base_hwm_store.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
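# A short sketch of how the module-level __getattr__ above behaves for callers,
# assuming etl_entities >= 2.0 is installed (it provides etl_entities.hwm_store).
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # old import path still resolves, but emits a UserWarning with migration hints
    from onetl.hwm.store import MemoryHWMStore  # noqa: F401

assert any("deprecated" in str(warning.message) for warning in caught)

# "HWMStoreManager" is special-cased and re-exported as HWMStoreStackManager
from etl_entities.hwm_store import HWMStoreStackManager
from onetl.hwm.store import HWMStoreManager

assert HWMStoreManager is HWMStoreStackManager

# preferred import path going forward
from etl_entities.hwm_store import MemoryHWMStore  # noqa: F811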
- -from __future__ import annotations - -import logging -import os -from abc import ABC, abstractmethod -from typing import Any - -from etl_entities import HWM - -from onetl.impl import BaseModel, path_repr -from onetl.log import log_with_indent - -log = logging.getLogger(__name__) - - -class BaseHWMStore(BaseModel, ABC): - def __enter__(self): - """ - HWM store context manager. - - Enter this context to use this HWM store instance as current one (instead default). - - Examples - -------- - - .. code:: python - - with hwm_store: - db_reader.run() - """ - # hack to avoid circular imports - from onetl.hwm.store import HWMStoreManager - - log.debug("|%s| Entered stack at level %d", self.__class__.__name__, HWMStoreManager.get_current_level()) - HWMStoreManager.push(self) - - self._log_parameters() - return self - - def __exit__(self, _exc_type, _exc_value, _traceback): - from onetl.hwm.store import HWMStoreManager - - log.debug("|%s| Exiting stack at level %d", self, HWMStoreManager.get_current_level() - 1) - HWMStoreManager.pop() - return False - - @abstractmethod - def get(self, name: str) -> HWM | None: - """ - Get HWM by qualified name from HWM store. |support_hooks| - - Parameters - ---------- - name : str - HWM qualified name - - Returns - ------- - HWM object, if it exists in HWM store, or None - - Examples - -------- - - .. code:: python - - from etl_entities import IntHWM - - # just to generate qualified name using HWM parts - empty_hwm = IntHWM(column=..., table=..., process=...) - real_hwm = hwm_store.get(empty_hwm.qualified_name) - """ - - @abstractmethod - def save(self, hwm: HWM) -> Any: - """ - Save HWM object to HWM Store. |support_hooks| - - Parameters - ---------- - hwm : :obj:`etl_entities.hwm.HWM` - HWM object - - Returns - ------- - HWM location, like URL of file path. - - Examples - -------- - - .. code:: python - - from etl_entities import IntHWM - - hwm = IntHWM(value=..., column=..., table=..., process=...) - hwm_location = hwm_store.save(hwm) - """ - - def _log_parameters(self) -> None: - log.info("|onETL| Using %s as HWM Store", self.__class__.__name__) - options = self.dict(by_alias=True, exclude_none=True) - - if options: - log.info("|%s| Using options:", self.__class__.__name__) - for option, value in options.items(): - if isinstance(value, os.PathLike): - log_with_indent(log, "%s = %s", option, path_repr(value)) - else: - log_with_indent(log, "%s = %r", option, value) diff --git a/onetl/hwm/store/hwm_class_registry.py b/onetl/hwm/store/hwm_class_registry.py index 09f2ff8b0..d9a2ce674 100644 --- a/onetl/hwm/store/hwm_class_registry.py +++ b/onetl/hwm/store/hwm_class_registry.py @@ -14,13 +14,12 @@ from __future__ import annotations -from typing import Any, Callable, ClassVar, Iterator, Optional +from typing import ClassVar -from etl_entities import HWM, DateHWM, DateTimeHWM, FileListHWM, IntHWM -from pydantic import StrictInt +from etl_entities.hwm import HWM, ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM -class HWMClassRegistry: +class SparkTypeToHWM: """Registry class for HWM types Examples @@ -28,42 +27,43 @@ class HWMClassRegistry: .. 
code:: python - from etl_entities import IntHWM, DateHWM - from onetl.hwm.store import HWMClassRegistry + from etl_entities.hwm import ColumnIntHWM, ColumnDateHWM + from onetl.hwm.store import SparkTypeToHWM - HWMClassRegistry.get("int") == IntHWM - HWMClassRegistry.get("integer") == IntHWM # multiple type names are supported + assert SparkTypeToHWM.get("integer") == ColumnIntHWM + assert SparkTypeToHWM.get("short") == ColumnIntHWM # multiple type names are supported - HWMClassRegistry.get("date") == DateHWM + assert SparkTypeToHWM.get("date") == ColumnDateHWM - HWMClassRegistry.get("unknown") # raise KeyError + assert SparkTypeToHWM.get("unknown") is None """ _mapping: ClassVar[dict[str, type[HWM]]] = { - "byte": IntHWM, - "integer": IntHWM, - "short": IntHWM, - "long": IntHWM, - "date": DateHWM, - "timestamp": DateTimeHWM, - "file_list": FileListHWM, + "byte": ColumnIntHWM, + "integer": ColumnIntHWM, + "short": ColumnIntHWM, + "long": ColumnIntHWM, + "date": ColumnDateHWM, + "timestamp": ColumnDateTimeHWM, + # for Oracle which does not differ between int and float/double - everything is Decimal + "float": ColumnIntHWM, + "double": ColumnIntHWM, + "fractional": ColumnIntHWM, + "decimal": ColumnIntHWM, + "numeric": ColumnIntHWM, } @classmethod - def get(cls, type_name: str) -> type[HWM]: - result = cls._mapping.get(type_name) - if not result: - raise KeyError(f"Unknown HWM type {type_name!r}") - - return result + def get(cls, type_name: str) -> type[HWM] | None: + return cls._mapping.get(type_name) @classmethod def add(cls, type_name: str, klass: type[HWM]) -> None: cls._mapping[type_name] = klass -def register_hwm_class(*type_names: str): +def register_spark_type_to_hwm_type_mapping(*type_names: str): """Decorator for registering some HWM class with a type name or names Examples @@ -72,43 +72,23 @@ def register_hwm_class(*type_names: str): .. code:: python from etl_entities import HWM - from onetl.hwm.store import HWMClassRegistry - from onetl.hwm.store import HWMClassRegistry, register_hwm_class + from onetl.hwm.store import SparkTypeToHWM + from onetl.hwm.store import SparkTypeToHWM, register_spark_type_to_hwm_type_mapping - @register_hwm_class("somename", "anothername") + @register_spark_type_to_hwm_type_mapping("somename", "anothername") class MyHWM(HWM): ... 
- HWMClassRegistry.get("somename") == MyClass - HWMClassRegistry.get("anothername") == MyClass + assert SparkTypeToHWM.get("somename") == MyClass + assert SparkTypeToHWM.get("anothername") == MyClass """ def wrapper(cls: type[HWM]): for type_name in type_names: - HWMClassRegistry.add(type_name, cls) - + SparkTypeToHWM.add(type_name, cls) return cls return wrapper - - -class Decimal(StrictInt): - @classmethod - def __get_validators__(cls) -> Iterator[Callable]: - yield cls.validate - - @classmethod - def validate(cls, value: Any) -> int: - if round(float(value)) != float(value): - raise ValueError(f"{cls.__name__} cannot have fraction part") - return int(value) - - -@register_hwm_class("float", "double", "fractional", "decimal", "numeric") -class DecimalHWM(IntHWM): - """Same as IntHWM, but allows to pass values like 123.000 (float without fractional part)""" - - value: Optional[Decimal] = None diff --git a/onetl/hwm/store/hwm_store_class_registry.py b/onetl/hwm/store/hwm_store_class_registry.py deleted file mode 100644 index 8567ee96a..000000000 --- a/onetl/hwm/store/hwm_store_class_registry.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from functools import wraps -from typing import Any, Callable, ClassVar, Collection, Mapping - -from onetl.hwm.store.base_hwm_store import BaseHWMStore - - -class HWMStoreClassRegistry: - """Registry class of different HWM stores. - - Examples - -------- - - .. code:: python - - from onetl.hwm.store import HWMStoreClassRegistry, YAMLHWMStore, MemoryHWMStore - - HWMStoreClassRegistry.get("yml") == YAMLHWMStore - HWMStoreClassRegistry.get("memory") == MemoryHWMStore - - HWMStoreClassRegistry.get() == YAMLHWMStore # default - - HWMStoreClassRegistry.get("unknown") # raise KeyError - - """ - - _default: type[BaseHWMStore | None] = type(None) - _mapping: ClassVar[dict[str, type[BaseHWMStore]]] = {} - - @classmethod - def get(cls, type_name: str | None = None) -> type: - if not type_name: - return cls._default - - result = cls._mapping.get(type_name) - if not result: - raise KeyError(f"Unknown HWM Store type {type_name!r}") - - return result - - @classmethod - def add(cls, type_name: str, klass: type[BaseHWMStore]) -> None: - assert isinstance(type_name, str) # noqa: S101 - assert issubclass(klass, BaseHWMStore) # noqa: S101 - - cls._mapping[type_name] = klass - - @classmethod - def set_default(cls, klass: type[BaseHWMStore]) -> None: - cls._default = klass - - @classmethod - def known_types(cls) -> Collection[str]: - return cls._mapping.keys() - - -def default_hwm_store_class(klass: type[BaseHWMStore]) -> type[BaseHWMStore]: - """Decorator for setting up some Store class as default one - - Examples - -------- - - .. code:: python - - from onetl.hwm.store import ( - HWMStoreClassRegistry, - default_hwm_store_class, - BaseHWMStore, - ) - - - @default_hwm_store_class - class MyClass(BaseHWMStore): - ... 
- - - HWMStoreClassRegistry.get() == MyClass # default - - """ - - HWMStoreClassRegistry.set_default(klass) - return klass - - -def register_hwm_store_class(*type_names: str): - """Decorator for registering some Store class with a name - - Examples - -------- - - .. code:: python - - from onetl.hwm.store import ( - HWMStoreClassRegistry, - register_hwm_store_class, - BaseHWMStore, - ) - - - @register_hwm_store_class("somename") - class MyClass(BaseHWMStore): - ... - - - HWMStoreClassRegistry.get("somename") == MyClass - - """ - - def wrapper(cls: type[BaseHWMStore]): - for type_name in type_names: - HWMStoreClassRegistry.add(type_name, cls) - - return cls - - return wrapper - - -def parse_config(value: Any, key: str) -> tuple[str, list, Mapping]: - if not isinstance(value, (str, Mapping)): - raise ValueError(f"Wrong value {value!r} for {key!r} config item") - - store_type = "unknown" - args: list[Any] = [] - kwargs: Mapping[str, Any] = {} - - if isinstance(value, str): - return value, args, kwargs - - for item in HWMStoreClassRegistry.known_types(): - if item not in value: - continue - - store_type = item - child = value[item] - - args, kwargs = parse_child_item(child) - - return store_type, args, kwargs - - -def parse_child_item(child: Any) -> tuple[list, Mapping]: - store_args: list[Any] = [] - store_kwargs: Mapping[str, Any] = {} - - if not child: - return store_args, store_kwargs - - if isinstance(child, str): - store_args = [child] - elif isinstance(child, Mapping): - store_kwargs = child - else: - store_args = child - - return store_args, store_kwargs - - -def dict_item_getter(key: str) -> Callable: - def wrapper(conf): # noqa: WPS430 - return resolve_attr(conf, key) - - return wrapper - - -def resolve_attr(conf: Mapping, hwm_key: str) -> str | Mapping: - obj = {} - - try: - if "." not in hwm_key: - obj = conf[hwm_key] - else: - for name in hwm_key.split("."): - obj = conf[name] - conf = obj - except Exception as e: - raise ValueError("The configuration does not contain a required key") from e - - return obj - - -def detect_hwm_store(key: str) -> Callable: - """Detect HWM store by config object - - Parameters - ---------- - key : str - The name of the section in the config that stores information about hwm - - .. warning :: - - **DO NOT** use dot ``.`` in config keys - - Examples - -------- - - Config - - .. code:: yaml - - # if HWM store can be created with no args - hwm_store: yaml - - or - - .. code:: yaml - - # named constructor args - hwm_store: - atlas: - url: http://some.atlas.url - user: username - password: password - - Config could be nested: - - .. code:: yaml - - myetl: - env: - hwm_store: yaml - - ``run.py`` - - .. code:: python - - import hydra - from omegaconf import DictConfig - from onetl.hwm.store import detect_hwm_store - - - # key=... 
is a path to config item, delimited by dot ``.`` - @hydra.main(config="../conf") - @detect_hwm_store(key="myetl.env.hwm_store") - def main(config: DictConfig): - pass - - """ - - if not isinstance(key, str): - raise ValueError("key name must be a string") - - def pre_wrapper(func: Callable): # noqa: WPS430 - @wraps(func) - def wrapper(config: Mapping, *args, **kwargs): - if not config: - raise ValueError("Config must be specified") - - if not key: - raise ValueError("Key value must be specified") - - get_hwm_spec = dict_item_getter(key) - root = get_hwm_spec(config) - - if not root: - return func(config, *args, **kwargs) - - store_type, store_args, store_kwargs = parse_config(root, key) - store = HWMStoreClassRegistry.get(store_type) - - with store(*store_args, **store_kwargs): - return func(config, *args, **kwargs) - - return wrapper - - return pre_wrapper diff --git a/onetl/hwm/store/hwm_store_manager.py b/onetl/hwm/store/hwm_store_manager.py deleted file mode 100644 index dbb59d0b2..000000000 --- a/onetl/hwm/store/hwm_store_manager.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import ClassVar - -from onetl.hwm.store.base_hwm_store import BaseHWMStore -from onetl.hwm.store.hwm_store_class_registry import HWMStoreClassRegistry - - -class HWMStoreManager: - _stack: ClassVar[list[BaseHWMStore]] = [] - - @classmethod - def push(cls, hwm_store: BaseHWMStore) -> None: - cls._stack.append(hwm_store) - - @classmethod - def pop(cls) -> BaseHWMStore: - return cls._stack.pop() - - @classmethod - def get_current_level(cls) -> int: - return len(cls._stack) - - @classmethod - def get_current(cls) -> BaseHWMStore: - if cls._stack: - return cls._stack[-1] - - default_store_type = HWMStoreClassRegistry.get() - return default_store_type() diff --git a/onetl/hwm/store/memory_hwm_store.py b/onetl/hwm/store/memory_hwm_store.py deleted file mode 100644 index ada859bb8..000000000 --- a/onetl/hwm/store/memory_hwm_store.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
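# A sketch of the etl_entities classes replacing the deleted HWMStoreManager and
# MemoryHWMStore. It assumes etl_entities >= 2.0 keeps the same context-manager and
# stack semantics as the removed implementation shown above; the HWM name and source
# are placeholders.
from etl_entities.hwm import ColumnIntHWM
from etl_entities.hwm_store import HWMStoreStackManager, MemoryHWMStore

with MemoryHWMStore() as store:
    # entering the context pushes this store on top of the stack
    assert HWMStoreStackManager.get_current() is store

    # stores expose get_hwm()/set_hwm() instead of the old get()/save()
    store.set_hwm(ColumnIntHWM(name="my_unique_name", source="my_source", value=123))
    assert store.get_hwm("my_unique_name").value == 123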
- -from __future__ import annotations - -from typing import Dict - -from etl_entities import HWM -from pydantic import PrivateAttr - -from onetl.hooks import slot, support_hooks -from onetl.hwm.store.base_hwm_store import BaseHWMStore -from onetl.hwm.store.hwm_store_class_registry import register_hwm_store_class - - -@register_hwm_store_class("memory", "in-memory") -@support_hooks -class MemoryHWMStore(BaseHWMStore): - """In-memory local store for HWM values. |support_hooks| - - .. note:: - - This class should be used in tests only, because all saved HWM values - will be deleted after exiting the context - - Examples - -------- - - .. code:: python - - from onetl.connection import Hive, Postgres - from onetl.db import DBReader - from onetl.strategy import IncrementalStrategy - from onetl.hwm.store import MemoryHWMStore - - from pyspark.sql import SparkSession - - maven_packages = Postgres.get_packages() - spark = ( - SparkSession.builder.appName("spark-app-name") - .config("spark.jars.packages", ",".join(maven_packages)) - .getOrCreate() - ) - - postgres = Postgres( - host="postgres.domain.com", - user="myuser", - password="*****", - database="target_database", - spark=spark, - ) - - hive = Hive(cluster="rnd-dwh", spark=spark) - - reader = DBReader( - connection=postgres, - source="public.mydata", - columns=["id", "data"], - hwm_column="id", - ) - - writer = DBWriter(connection=hive, target="newtable") - - with MemoryHWMStore(): - with IncrementalStrategy(): - df = reader.run() - writer.run(df) - - # will store HWM value in RAM - - # values are lost after exiting the context - """ - - _data: Dict[str, HWM] = PrivateAttr(default_factory=dict) - - @slot - def get(self, name: str) -> HWM | None: - return self._data.get(name, None) - - @slot - def save(self, hwm: HWM) -> None: - self._data[hwm.qualified_name] = hwm - - @slot - def clear(self) -> None: - """ - Clears all stored HWM values. |support_hooks| - """ - self._data.clear() diff --git a/onetl/hwm/store/yaml_hwm_store.py b/onetl/hwm/store/yaml_hwm_store.py index 0cc52aadd..31064d572 100644 --- a/onetl/hwm/store/yaml_hwm_store.py +++ b/onetl/hwm/store/yaml_hwm_store.py @@ -19,26 +19,54 @@ from typing import ClassVar import yaml -from etl_entities import HWM, HWMTypeRegistry +from etl_entities.hwm import HWM, HWMTypeRegistry +from etl_entities.hwm_store import ( + BaseHWMStore, + HWMStoreClassRegistry, + register_hwm_store_class, +) from platformdirs import user_data_dir from pydantic import validator from onetl.hooks import slot, support_hooks -from onetl.hwm.store.base_hwm_store import BaseHWMStore -from onetl.hwm.store.hwm_store_class_registry import ( - default_hwm_store_class, - register_hwm_store_class, -) from onetl.impl import FrozenModel, LocalPath DATA_PATH = LocalPath(user_data_dir("onETL", "ONEtools")) +def default_hwm_store_class(klass: type[BaseHWMStore]) -> type[BaseHWMStore]: + """Decorator for setting up some Store class as default one + + Examples + -------- + + .. code:: python + + from onetl.hwm.store import ( + HWMStoreClassRegistry, + default_hwm_store_class, + BaseHWMStore, + ) + + + @default_hwm_store_class + class MyClass(BaseHWMStore): + ... + + + HWMStoreClassRegistry.get() == MyClass # default + + """ + + HWMStoreClassRegistry.set_default(klass) + return klass + + @default_hwm_store_class -@register_hwm_store_class("yaml", "yml") +@register_hwm_store_class("yaml") @support_hooks class YAMLHWMStore(BaseHWMStore, FrozenModel): - r"""YAML local store for HWM values. Used as default HWM store. 
|support_hooks| + r"""YAML **local store** for HWM values. Used as default HWM store. |support_hooks| Parameters ---------- @@ -91,7 +119,7 @@ class YAMLHWMStore(BaseHWMStore, FrozenModel): connection=postgres, source="public.mydata", columns=["id", "data"], - hwm_column="id", + hwm=DBReader.AutoDetectHWM(name="some_unique_name", expression="id"), ) writer = DBWriter(connection=hive, target="newtable") @@ -170,20 +198,20 @@ def validate_path(cls, path): return path @slot - def get(self, name: str) -> HWM | None: + def get_hwm(self, name: str) -> HWM | None: # type: ignore data = self._load(name) if not data: return None latest = sorted(data, key=operator.itemgetter("modified_time"))[-1] - return HWMTypeRegistry.parse(latest) + return HWMTypeRegistry.parse(latest) # type: ignore @slot - def save(self, hwm: HWM) -> LocalPath: - data = self._load(hwm.qualified_name) - self._dump(hwm.qualified_name, [hwm.serialize()] + data) - return self.get_file_path(hwm.qualified_name) + def set_hwm(self, hwm: HWM) -> LocalPath: # type: ignore + data = self._load(hwm.name) + self._dump(hwm.name, [hwm.serialize()] + data) + return self.get_file_path(hwm.name) @classmethod def cleanup_file_name(cls, name: str) -> str: diff --git a/onetl/hwm/window.py b/onetl/hwm/window.py new file mode 100644 index 000000000..f4eba5bc7 --- /dev/null +++ b/onetl/hwm/window.py @@ -0,0 +1,35 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Edge: + value: Any = None + including: bool = True + + def is_set(self) -> bool: + return self.value is not None + + +@dataclass +class Window: + expression: str + start_from: Edge = field(default_factory=Edge) + stop_at: Edge = field(default_factory=Edge) diff --git a/onetl/impl/local_path.py b/onetl/impl/local_path.py index 443b2920d..9708fc200 100644 --- a/onetl/impl/local_path.py +++ b/onetl/impl/local_path.py @@ -13,6 +13,7 @@ # limitations under the License. 
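# A minimal sketch of the new Edge/Window containers from onetl/hwm/window.py; the
# expression name and boundary values are illustrative. including=False on the lower
# edge mirrors how HWMStrategy.current now returns the previous HWM value as an
# exclusive boundary.
from onetl.hwm import Edge, Window

window = Window(
    expression="updated_at",
    start_from=Edge(value="2023-01-01", including=False),  # exclusive lower bound
    stop_at=Edge(value="2023-02-01"),  # including=True by default
)

assert window.start_from.is_set()
assert window.stop_at.is_set()

# an Edge without a value means "no boundary yet", e.g. on the very first run
assert not Edge().is_set()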
import os +import sys from pathlib import Path, PurePosixPath, PureWindowsPath @@ -20,8 +21,10 @@ class LocalPath(Path): def __new__(cls, *args, **kwargs): if cls is LocalPath: cls = LocalWindowsPath if os.name == "nt" else LocalPosixPath - self = cls._from_parts(args) - return self # noqa: WPS331 + if sys.version_info < (3, 12): + return cls._from_parts(args) + else: + return object.__new__(cls) # noqa: WPS503 class LocalPosixPath(LocalPath, PurePosixPath): diff --git a/onetl/log.py b/onetl/log.py index e4ccd889f..158e6e1b6 100644 --- a/onetl/log.py +++ b/onetl/log.py @@ -21,9 +21,10 @@ from contextlib import redirect_stdout from enum import Enum from textwrap import dedent -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any, Collection, Iterable from deprecated import deprecated +from etl_entities.hwm import HWM if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -459,7 +460,7 @@ def log_options( log_with_indent(logger, "%s = %r", name, None, indent=indent, stacklevel=stacklevel, **kwargs) -def log_dataframe_schema(logger: logging.Logger, df: DataFrame, stacklevel: int = 1): +def log_dataframe_schema(logger: logging.Logger, df: DataFrame, indent: int = 0, stacklevel: int = 1): """Log dataframe schema in the following format: Examples @@ -478,7 +479,7 @@ def log_dataframe_schema(logger: logging.Logger, df: DataFrame, stacklevel: int """ stacklevel += 1 - log_with_indent(logger, "df_schema:", stacklevel=stacklevel) + log_with_indent(logger, "df_schema:", indent=indent, stacklevel=stacklevel) schema_tree = io.StringIO() with redirect_stdout(schema_tree): @@ -487,4 +488,54 @@ def log_dataframe_schema(logger: logging.Logger, df: DataFrame, stacklevel: int df.printSchema() for line in schema_tree.getvalue().splitlines(): - log_with_indent(logger, "%s", line, indent=4, stacklevel=stacklevel) + log_with_indent(logger, "%s", line, indent=indent + 4, stacklevel=stacklevel) + + +def log_hwm(logger: logging.Logger, hwm: HWM, indent: int = 0, stacklevel: int = 1): + """Log HWM in the following format: + + Examples + -------- + + .. code:: python + + hwm = ColumnIntHWM(name="my_unique_name", source="my_source", value=123) + log_hwm(logger, hwm) + + .. code-block:: text + + INFO onetl.module hwm = ColumnIntHWM( + INFO onetl.module name = "my_unique_name", + INFO onetl.module entity = "my_source", + INFO onetl.module expression = None, + INFO onetl.module value = 123, + INFO onetl.module ) + + .. 
code-block:: text + + INFO onetl.module hwm = FileListHWM( + INFO onetl.module name = "my_unique_name", + INFO onetl.module entity = "my_source", + INFO onetl.module expression = None, + INFO onetl.module value = [ + INFO onetl.module AbsolutePath("/some/file1.csv"), + INFO onetl.module AbsolutePath("/some/file2.csv"), + INFO onetl.module AbsolutePath("/some/file3.csv"), + INFO onetl.module ] + INFO onetl.module ) + """ + stacklevel += 1 + + log_with_indent(logger, "hwm = %s(", type(hwm).__name__, indent=indent, stacklevel=stacklevel) + log_with_indent(logger, "name = %r,", hwm.name, indent=indent + 4, stacklevel=stacklevel) + if hwm.description: + log_with_indent(logger, "description = %r,", hwm.name, indent=indent + 4, stacklevel=stacklevel) + log_with_indent(logger, "entity = %r,", hwm.entity, indent=indent + 4, stacklevel=stacklevel) + log_with_indent(logger, "expression = %r,", hwm.expression, indent=indent + 4, stacklevel=stacklevel) + if hwm.value is not None: + if isinstance(hwm.value, Collection): + log_collection(logger, "value", hwm.value, max_items=10, indent=indent + 4, stacklevel=stacklevel) + else: + log_with_indent(logger, "value = %r,", hwm.value, indent=indent + 4, stacklevel=stacklevel) + + log_with_indent(logger, ")", indent=indent, stacklevel=stacklevel) diff --git a/onetl/strategy/base_strategy.py b/onetl/strategy/base_strategy.py index b287d08f1..e22ebd5ca 100644 --- a/onetl/strategy/base_strategy.py +++ b/onetl/strategy/base_strategy.py @@ -15,8 +15,8 @@ from __future__ import annotations import logging -from typing import Any +from onetl.hwm import Edge from onetl.impl import BaseModel from onetl.log import log_with_indent @@ -51,12 +51,12 @@ def __exit__(self, exc_type, _exc_value, _traceback): return False @property - def current_value(self) -> Any: - pass + def current(self) -> Edge: + return Edge() @property - def next_value(self) -> Any: - pass + def next(self) -> Edge: + return Edge() def enter_hook(self) -> None: pass diff --git a/onetl/strategy/batch_hwm_strategy.py b/onetl/strategy/batch_hwm_strategy.py index 8790891a3..711a4ccff 100644 --- a/onetl/strategy/batch_hwm_strategy.py +++ b/onetl/strategy/batch_hwm_strategy.py @@ -15,12 +15,12 @@ from __future__ import annotations import logging -import operator from textwrap import dedent -from typing import Any, Callable, ClassVar +from typing import Any, ClassVar from pydantic import validator +from onetl.hwm import Edge from onetl.strategy.hwm_strategy import HWMStrategy log = logging.getLogger(__name__) @@ -63,26 +63,10 @@ def __next__(self): else: log.info("|%s| Next iteration", self.__class__.__name__) - return self.current_value + return self.current, self.next @property def is_first_run(self) -> bool: - return self._iteration == 0 - - @property - def is_finished(self) -> bool: - return self.current_value is not None and self.has_upper_limit and self.current_value >= self.stop - - @property - def has_lower_limit(self) -> bool: - return self.start is not None - - @property - def has_upper_limit(self) -> bool: - return self.stop is not None - - @property - def current_value(self) -> Any: if self._iteration < 0: raise RuntimeError( dedent( @@ -98,28 +82,49 @@ def current_value(self) -> Any: ), ) - result = super().current_value + return self._iteration == 0 - if result is None: - result = self.start + @property + def is_finished(self) -> bool: + if self._iteration >= self.MAX_ITERATIONS: + # prevent possible infinite loops in unexpected cases + return True - self.check_argument_is_set("start", result) + if 
self.current.is_set() and self.stop is not None: + return self.current.value >= self.stop - return result + return False + + def check_has_data(self, value: Any): + if not self.is_first_run and value is None: + log.info( + "|%s| No start or stop values are set, exiting after %s iteration(s)", + self.__class__.__name__, + self._iteration, + ) + raise StopIteration - def check_argument_is_set(self, name: str, value: Any) -> None: - if value is None and not self.is_first_run: - raise ValueError(f"{name!r} argument of {self.__class__.__name__} cannot be empty!") + @property + def current(self) -> Edge: + result = super().current + if not result.is_set(): + result = Edge( + value=self.start, + including=True, + ) + + self.check_has_data(result.value) + return result def check_hwm_increased(self, next_value: Any) -> None: - if self.current_value is None: + if not self.current.is_set(): return - if self.stop is not None and self.current_value == self.stop: - # if rows all have the same hwm_column value, this is not an error, read them all + if self.stop is not None and self.current.value == self.stop: + # if rows all have the same expression value, this is not an error, read them all return - if next_value is not None and self.current_value >= next_value: + if next_value is not None and self.current.value >= next_value: # negative or zero step - exception # DateHWM with step value less than one day - exception raise ValueError( @@ -127,7 +132,7 @@ def check_hwm_increased(self, next_value: Any) -> None: ) if self.stop is not None: - expected_iterations = int((self.stop - self.current_value) / self.step) + expected_iterations = int((self.stop - self.current.value) / self.step) if expected_iterations >= self.MAX_ITERATIONS: raise ValueError( f"step={self.step!r} parameter of {self.__class__.__name__} leads to " @@ -135,33 +140,22 @@ def check_hwm_increased(self, next_value: Any) -> None: ) @property - def next_value(self) -> Any: - if self.current_value is not None: - result = self.current_value + self.step + def next(self) -> Edge: + if self.current.is_set(): + result = Edge(value=self.current.value + self.step) else: - result = self.stop - - self.check_argument_is_set("stop", result) + result = Edge(value=self.stop) - if self.has_upper_limit: - result = min(result, self.stop) + self.check_has_data(result.value) - self.check_hwm_increased(result) + if self.stop is not None: + result.value = min(result.value, self.stop) + self.check_hwm_increased(result.value) return result def update_hwm(self, value: Any) -> None: - # no rows has been read, going to next iteration - if self.hwm is not None: - self.hwm.update(self.next_value) - - super().update_hwm(value) - - @property - def current_value_comparator(self) -> Callable: - if not self.hwm: - # if start == 0 and hwm is not set - # SQL should be `hwm_column >= 0` instead of `hwm_column > 0` - return operator.ge - - return super().current_value_comparator + # batch strategy ticks determined by step size only, + # not by real HWM value read from source + if self.hwm: + self.hwm.update(self.next.value) diff --git a/onetl/strategy/hwm_store/__init__.py b/onetl/strategy/hwm_store/__init__.py index 8687365bb..124fd3ccc 100644 --- a/onetl/strategy/hwm_store/__init__.py +++ b/onetl/strategy/hwm_store/__init__.py @@ -20,23 +20,25 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from onetl.hwm.store import ( - BaseHWMStore, - HWMClassRegistry, - HWMStoreClassRegistry, - HWMStoreManager, + from etl_entities.hwm_store import BaseHWMStore, 
HWMStoreClassRegistry + from etl_entities.hwm_store import HWMStoreStackManager as HWMStoreManager + from etl_entities.hwm_store import ( MemoryHWMStore, - YAMLHWMStore, - default_hwm_store_class, detect_hwm_store, - register_hwm_class, register_hwm_store_class, ) + from onetl.hwm.store import ( + SparkTypeToHWM, + YAMLHWMStore, + default_hwm_store_class, + register_spark_type_to_hwm_type_mapping, + ) + __all__ = [ # noqa: WPS410 "BaseHWMStore", - "HWMClassRegistry", - "register_hwm_class", + "SparkTypeToHWM", + "register_spark_type_to_hwm_type_mapping", "HWMStoreClassRegistry", "default_hwm_store_class", "detect_hwm_store", diff --git a/onetl/strategy/hwm_strategy.py b/onetl/strategy/hwm_strategy.py index b89c2ba67..c50c74372 100644 --- a/onetl/strategy/hwm_strategy.py +++ b/onetl/strategy/hwm_strategy.py @@ -15,14 +15,16 @@ from __future__ import annotations import logging -import operator import os -from typing import Any, Callable, Collection, Optional +import textwrap +import warnings +from typing import Any, Optional -from etl_entities import HWM +from etl_entities.hwm import HWM +from etl_entities.hwm_store import HWMStoreStackManager -from onetl.hwm.store import HWMStoreManager -from onetl.log import log_collection, log_with_indent +from onetl.hwm import Edge +from onetl.log import log_hwm, log_with_indent from onetl.strategy.base_strategy import BaseStrategy log = logging.getLogger(__name__) @@ -32,22 +34,17 @@ class HWMStrategy(BaseStrategy): hwm: Optional[HWM] = None @property - def current_value(self) -> Any: - if self.hwm: - return self.hwm.value + def current(self) -> Edge: + if self.hwm and self.hwm.value is not None: + return Edge( + value=self.hwm.value, + including=False, + ) - return super().current_value - - @property - def current_value_comparator(self) -> Callable: - return operator.gt - - @property - def next_value_comparator(self) -> Callable: - return operator.le + return super().current def update_hwm(self, value: Any) -> None: - if self.hwm is not None and value is not None: + if self.hwm and value is not None: self.hwm.update(value) def enter_hook(self) -> None: @@ -57,28 +54,75 @@ def enter_hook(self) -> None: def fetch_hwm(self) -> None: class_name = self.__class__.__name__ - - if self.hwm is not None: - hwm_store = HWMStoreManager.get_current() - - log.info("|%s| Loading HWM from %s:", class_name, hwm_store.__class__.__name__) - log_with_indent(log, "qualified_name = %r", self.hwm.qualified_name) - - result = hwm_store.get(self.hwm.qualified_name) - - if result is not None: - self.hwm = result - log.info("|%s| Got HWM:", class_name) - self._log_hwm(self.hwm) - else: - log.warning( - "|%s| HWM does not exist in %r. ALL ROWS/FILES WILL BE READ!", - class_name, - hwm_store.__class__.__name__, - ) - else: + if not self.hwm: # entering strategy context, HWM will be set later by DBReader.run or FileDownloader.run - log.debug("|%s| HWM will not be loaded, skipping", class_name) + log.debug("|%s| HWM will not be fetched, skipping", class_name) + return + + hwm_store = HWMStoreStackManager.get_current() + log.info("|%s| Fetching HWM from %s:", class_name, hwm_store.__class__.__name__) + log_with_indent(log, "name = %r", self.hwm.name) + + result = hwm_store.get_hwm(self.hwm.name) + if result is None: + log.warning( + "|%s| HWM does not exist in %r. 
ALL ROWS/FILES WILL BE READ!", + class_name, + hwm_store.__class__.__name__, + ) + return + + log.info("|%s| Fetched HWM:", class_name) + log_hwm(log, result) + + self.validate_hwm_type(self.hwm, result) + self.validate_hwm_attributes(self.hwm, result, origin=hwm_store.__class__.__name__) + + self.hwm.set_value(result.value) + if self.hwm != result: + log.info("|%s| Final HWM:", class_name) + log_hwm(log, self.hwm) + + def validate_hwm_type(self, current_hwm: HWM, new_hwm: HWM): + hwm_type = type(current_hwm) + + if not isinstance(new_hwm, hwm_type): + message = textwrap.dedent( + f""" + Cannot cast HWM of type {type(new_hwm).__name__!r} as {hwm_type.__name__!r}. + + Please: + * Check that you set correct HWM name, it should be unique. + * Check that your HWM store contains valid value and type for this HWM name. + """, + ) + raise TypeError(message) + + def validate_hwm_attributes(self, current_hwm: HWM, new_hwm: HWM, origin: str): + attributes = [("entity", True), ("expression", False), ("description", False)] + + for attribute, mandatory in attributes: + if getattr(current_hwm, attribute) != getattr(new_hwm, attribute): + # exception raised when inside one strategy >1 processes on the same table but with different entities + # are executed, example: test_postgres_strategy_incremental_hwm_set_twice + message = textwrap.dedent( + f""" + Detected HWM with different `{attribute}` attribute. + + Current HWM: + {current_hwm!r} + HWM in {origin}: + {new_hwm!r} + + Please: + * Check that you set correct HWM name, it should be unique. + * Check that attributes are consistent in both code and HWM Store. + """, + ) + if mandatory: + raise ValueError(message) + + warnings.warn(message, UserWarning, stacklevel=2) def exit_hook(self, failed: bool = False) -> None: if not failed: @@ -87,31 +131,23 @@ def exit_hook(self, failed: bool = False) -> None: def save_hwm(self) -> None: class_name = self.__class__.__name__ - if self.hwm is not None: - hwm_store = HWMStoreManager.get_current() - - log.info("|%s| Saving HWM to %r:", class_name, hwm_store.__class__.__name__) - self._log_hwm(self.hwm) - log_with_indent(log, "qualified_name = %r", self.hwm.qualified_name) + if not self.hwm: + log.debug("|%s| HWM value is not set, do not save", class_name) + return - location = hwm_store.save(self.hwm) - log.info("|%s| HWM has been saved", class_name) + hwm_store = HWMStoreStackManager.get_current() - if location: - if isinstance(location, os.PathLike): - log_with_indent(log, "location = '%s'", os.fspath(location)) - else: - log_with_indent(log, "location = %r", location) - else: - log.debug("|%s| HWM value is not set, do not save", class_name) + log.info("|%s| Saving HWM to %r:", class_name, hwm_store.__class__.__name__) + log_hwm(log, self.hwm) - def _log_hwm(self, hwm: HWM) -> None: - log_with_indent(log, "type = %s", hwm.__class__.__name__) + location = hwm_store.set_hwm(self.hwm) # type: ignore + log.info("|%s| HWM has been saved", class_name) - if isinstance(hwm.value, Collection): - log_collection(log, "value", hwm.value, max_items=10) - else: - log_with_indent(log, "value = %r", hwm.value) + if location: + if isinstance(location, os.PathLike): + log_with_indent(log, "location = '%s'", os.fspath(location)) + else: + log_with_indent(log, "location = %r", location) @classmethod def _log_exclude_fields(cls) -> set[str]: diff --git a/onetl/strategy/incremental_strategy.py b/onetl/strategy/incremental_strategy.py index ed4d0ee31..04cc87903 100644 --- a/onetl/strategy/incremental_strategy.py +++ 
b/onetl/strategy/incremental_strategy.py @@ -16,7 +16,7 @@ from typing import Any, Optional -from etl_entities import HWM +from etl_entities.hwm import HWM from onetl.impl import BaseModel from onetl.strategy.batch_hwm_strategy import BatchHWMStrategy @@ -30,7 +30,7 @@ class OffsetMixin(BaseModel): def fetch_hwm(self) -> None: super().fetch_hwm() - if self.hwm and self.offset is not None: + if self.hwm and self.hwm.value is not None and self.offset is not None: self.hwm -= self.offset @@ -41,13 +41,13 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): by filtering items not covered by the previous :ref:`HWM` value. For :ref:`db-reader`: - First incremental run is just the same as :obj:`onetl.strategy.snapshot_strategy.SnapshotStrategy`: + First incremental run is just the same as :obj:`SnapshotStrategy `: .. code:: sql SELECT id, data FROM mydata; - Then the max value of ``id`` column (e.g. ``1000``) will be saved as ``ColumnHWM`` subclass to :ref:`hwm-store`. + Then the max value of ``id`` column (e.g. ``1000``) will be saved as ``HWM`` to :ref:`HWM Store `. Next incremental run will read only new data from the source: @@ -72,11 +72,11 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): This allows to resume reading process from the *last successful run*. For :ref:`file-downloader`: - Behavior depends on ``hwm_type`` parameter. + Behavior depends on ``hwm`` type. - ``hwm_type="file_list"``: - First incremental run is just the same as :obj:`onetl.strategy.snapshot_strategy.SnapshotStrategy` - all - files are downloaded: + ``hwm=FileListHWM(...)``: + First incremental run is just the same as :obj:`SnapshotStrategy ` - + all files are downloaded: .. code:: bash @@ -94,7 +94,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): ] ) - Then the downloaded files list is saved as ``FileListHWM`` object into :ref:`hwm-store`: + Then the downloaded files list is saved as ``FileListHWM`` object into :ref:`HWM Store `: .. code:: python @@ -123,7 +123,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): ] ) - New files will be added to the ``FileListHWM`` and saved to :ref:`hwm-store`: + New files will be added to the ``FileListHWM`` and saved to :ref:`HWM Store `: .. code:: python @@ -135,22 +135,10 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): .. warning:: - FileDownload updates HWM in HWM Store after downloading **each** file, because files are downloading - one after another, not in one batch. - - .. warning:: - - If code inside the context manager raised an exception, like: - - .. code:: python - - with IncrementalStrategy(): - download_result = downloader.run() # something went wrong here - uploader.run(download_result.success) # or here - # or here... - - When FileDownloader **will** update HWM in HWM Store, because: + FileDownload updates HWM in HWM Store at the end of ``.run()`` call, + **NOT** while exiting strategy context. This is because: + * FileDownloader does not raise exceptions if some file cannot be downloaded. * FileDownloader creates files on local filesystem, and file content may differ for different :obj:`modes `. * It can remove files from the source if :obj:`delete_source ` is set to ``True``. @@ -202,7 +190,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): .. warning:: - Cannot be used with :ref:`file-downloader` and ``hwm_type="file_list"`` + Cannot be used with :ref:`file-downloader` and ``hwm=FileListHWM(...)`` .. 
note:: @@ -220,6 +208,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): from onetl.connection import Postgres from onetl.db import DBReader from onetl.strategy import IncrementalStrategy + from onetl.hwm import AutoDetectHWM from pyspark.sql import SparkSession @@ -242,7 +231,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): connection=postgres, source="public.mydata", columns=["id", "data"], - hwm_column="id", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="id"), ) writer = DBWriter(connection=hive, target="newtable") @@ -277,9 +266,9 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): SELECT id, data FROM public.mydata - WHERE id > 900; --- from HWM-offset (EXCLUDING first row) + WHERE id > 900; -- from HWM-offset (EXCLUDING first row) - ``hwm_column`` can be a date or datetime, not only integer: + ``hwm.expression`` can be a date or datetime, not only integer: .. code:: python @@ -289,7 +278,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): connection=postgres, source="public.mydata", columns=["business_dt", "data"], - hwm_column="business_dt", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="business_dt"), ) with IncrementalStrategy(offset=timedelta(days=1)): @@ -303,15 +292,16 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): SELECT business_dt, data FROM public.mydata - WHERE business_dt > CAST('2021-01-09' AS DATE); + WHERE business_dt > CAST('2021-01-09' AS DATE); -- from HWM-offset (EXCLUDING first row) - Incremental run with :ref:`file-downloader` and ``hwm_type="file_list"``: + Incremental run with :ref:`file-downloader` and ``hwm=FileListHWM(...)``: .. code:: python from onetl.connection import SFTP from onetl.file import FileDownloader from onetl.strategy import SnapshotStrategy + from etl_entities import FileListHWM sftp = SFTP( host="sftp.domain.com", @@ -323,7 +313,7 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): connection=sftp, source_path="/remote", local_path="/local", - hwm_type="file_list", + hwm=FileListHWM(name="some_hwm_name"), ) with IncrementalStrategy(): @@ -340,7 +330,7 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): Cannot be used with :ref:`file-downloader` - Same as :obj:`onetl.strategy.incremental_strategy.IncrementalStrategy`, + Same as :obj:`IncrementalStrategy `, but reads data from the source in sequential batches (1..N) like: .. code:: sql @@ -358,8 +348,8 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): .. warning:: - Unlike :obj:`onetl.strategy.snapshot_strategy.SnapshotBatchStrategy`, - it **saves** current HWM value after **each batch** into :ref:`hwm-store`. + Unlike :obj:`SnapshotBatchStrategy `, + it **saves** current HWM value after **each batch** into :ref:`HWM Store `. So if code inside the context manager raised an exception, like: @@ -406,7 +396,7 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): stop : Any, default: ``None`` - If passed, the value will be used for generating WHERE clauses with ``hwm_column`` filter, + If passed, the value will be used for generating WHERE clauses with ``hwm.expression`` filter, as a stop value for the last batch. If not set, the value is determined by a separated query: @@ -419,7 +409,7 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): .. note:: - ``stop`` should be the same type as ``hwm_column`` value, + ``stop`` should be the same type as ``hwm.expression`` value, e.g. 
:obj:`datetime.datetime` for ``TIMESTAMP`` column, :obj:`datetime.date` for ``DATE``, and so on offset : Any, default: ``None`` @@ -482,7 +472,8 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): from onetl.connection import Postgres, Hive from onetl.db import DBReader - from onetl.strategy import IncrementalStrategy + from onetl.strategy import IncrementalBatchStrategy + from onetl.hwm import AutoDetectHWM from pyspark.sql import SparkSession @@ -507,7 +498,7 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): connection=postgres, source="public.mydata", columns=["id", "data"], - hwm_column="id", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="id"), ) writer = DBWriter(connection=hive, target="newtable") @@ -601,7 +592,7 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): ... N: WHERE id > 1900 AND id <= 2000; -- until stop - ``hwm_column`` can be a date or datetime, not only integer: + ``hwm.expression`` can be a date or datetime, not only integer: .. code:: python @@ -611,7 +602,7 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): connection=postgres, source="public.mydata", columns=["business_dt", "data"], - hwm_column="business_dt", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="business_dt"), ) with IncrementalBatchStrategy( @@ -647,11 +638,8 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy): """ def __next__(self): - result = super().__next__() - self.save_hwm() - - return result + return super().__next__() @classmethod def _log_exclude_fields(cls) -> set[str]: diff --git a/onetl/strategy/snapshot_strategy.py b/onetl/strategy/snapshot_strategy.py index 951f1d58c..3bf6913cb 100644 --- a/onetl/strategy/snapshot_strategy.py +++ b/onetl/strategy/snapshot_strategy.py @@ -67,6 +67,7 @@ class SnapshotStrategy(BaseStrategy): from onetl.connection import Postgres from onetl.db import DBReader from onetl.strategy import SnapshotStrategy + from onetl.hwm import AutoDetectHWM from pyspark.sql import SparkSession @@ -89,7 +90,7 @@ class SnapshotStrategy(BaseStrategy): connection=postgres, source="public.mydata", columns=["id", "data"], - hwm_column="id", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="id"), ) writer = DBWriter(connection=hive, target="newtable") @@ -136,7 +137,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): Cannot be used with :ref:`file-downloader` - Same as :obj:`onetl.strategy.snapshot_strategy.SnapshotStrategy`, + Same as :obj:`SnapshotStrategy `, but reads data from the source in sequential batches (1..N) like: .. code:: sql @@ -155,7 +156,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): .. note:: This strategy uses HWM column value to filter data for each batch, - but **does not** save it into :ref:`hwm-store`. + but does **NOT** save it into :ref:`HWM Store `. So every run starts from the beginning, not from the previous HWM value. .. note:: @@ -194,7 +195,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): start : Any, default: ``None`` - If passed, the value will be used for generating WHERE clauses with ``hwm_column`` filter, + If passed, the value will be used for generating WHERE clauses with ``hwm.expression`` filter, as a start value for the first batch. If not set, the value is determined by a separated query: @@ -207,12 +208,12 @@ class SnapshotBatchStrategy(BatchHWMStrategy): .. note:: - ``start`` should be the same type as ``hwm_column`` value, + ``start`` should be the same type as ``hwm.expression`` value, e.g. 
:obj:`datetime.datetime` for ``TIMESTAMP`` column, :obj:`datetime.date` for ``DATE``, and so on stop : Any, default: ``None`` - If passed, the value will be used for generating WHERE clauses with ``hwm_column`` filter, + If passed, the value will be used for generating WHERE clauses with ``hwm.expression`` filter, as a stop value for the last batch. If not set, the value is determined by a separated query: @@ -225,7 +226,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): .. note:: - ``stop`` should be the same type as ``hwm_column`` value, + ``stop`` should be the same type as ``hwm.expression`` value, e.g. :obj:`datetime.datetime` for ``TIMESTAMP`` column, :obj:`datetime.date` for ``DATE``, and so on Examples @@ -238,6 +239,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): from onetl.connection import Postgres, Hive from onetl.db import DBReader from onetl.strategy import SnapshotBatchStrategy + from onetl.hwm import AutoDetectHWM from pyspark.sql import SparkSession @@ -262,7 +264,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): connection=postgres, source="public.mydata", columns=["id", "data"], - hwm_column="id", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="id"), ) writer = DBWriter(connection=hive, target="newtable") @@ -377,7 +379,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): ... N: WHERE id > 1900 AND id <= 2000; -- until stop - ``hwm_column`` can be a date or datetime, not only integer: + ``hwm.expression`` can be a date or datetime, not only integer: .. code:: python @@ -387,7 +389,7 @@ class SnapshotBatchStrategy(BatchHWMStrategy): connection=postgres, source="public.mydata", columns=["business_dt", "data"], - hwm_column="business_dt", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="business_dt"), ) with SnapshotBatchStrategy( diff --git a/requirements/core.txt b/requirements/core.txt index 96c8959e5..01398e154 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -1,5 +1,5 @@ deprecated -etl-entities>=1.4,<1.5 +etl-entities>=2.1.2,<2.2 evacuator>=1.0,<1.1 frozendict humanize diff --git a/requirements/docs.txt b/requirements/docs.txt index 4ff1db3e9..f6062e62f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -6,10 +6,10 @@ pygments-csv-lexer sphinx sphinx-copybutton sphinx-design +sphinx-favicon sphinx-plantuml sphinx-tabs sphinx-toolbox sphinx_substitution_extensions sphinxcontrib-towncrier towncrier -urllib3 diff --git a/requirements/tests/mssql.txt b/requirements/tests/mssql.txt index 71352cc8f..c80997dc2 100644 --- a/requirements/tests/mssql.txt +++ b/requirements/tests/mssql.txt @@ -1,2 +1 @@ -# https://github.com/pymssql/pymssql/issues/832 -pymssql<2.2.8 +pymssql diff --git a/requirements/tests/spark-3.3.2.txt b/requirements/tests/spark-3.3.3.txt similarity index 80% rename from requirements/tests/spark-3.3.2.txt rename to requirements/tests/spark-3.3.3.txt index 588c10868..259340bf6 100644 --- a/requirements/tests/spark-3.3.2.txt +++ b/requirements/tests/spark-3.3.3.txt @@ -1,5 +1,5 @@ numpy>=1.16,<1.24 pandas>=1.0,<2 pyarrow>=1.0 -pyspark==3.3.2 +pyspark==3.3.3 sqlalchemy<2.0 diff --git a/requirements/tests/spark-3.4.1.txt b/requirements/tests/spark-3.4.2.txt similarity index 76% rename from requirements/tests/spark-3.4.1.txt rename to requirements/tests/spark-3.4.2.txt index d86df7dfa..c7173637d 100644 --- a/requirements/tests/spark-3.4.1.txt +++ b/requirements/tests/spark-3.4.2.txt @@ -1,5 +1,5 @@ numpy>=1.16 pandas>=1.0 pyarrow>=1.0 -pyspark==3.4.1 +pyspark==3.4.2 sqlalchemy diff --git 
a/setup.cfg b/setup.cfg index 6a799b4c5..435f97203 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ max-cognitive-score = 20 # Max amount of cognitive complexity per module max-cognitive-average = 25 max-imports = 25 -max-imported-names = 50 +max-imported-names = 55 # Max of expression usages in a module max-module-expressions = 15 # Max of expression usages in a function @@ -214,6 +214,8 @@ ignore = WPS615, # RST213: Inline emphasis start-string without end-string. RST213, +# RST304: Unknown interpreted text role + RST304, # RST307: Error in "code" directive RST307, # WPS428 Found statement that has no effect @@ -259,12 +261,14 @@ ignore = # WPS604 Found incorrect node inside `class` body: pass WPS604, # WPS100 Found wrong module name: util - WPS100 + WPS100, # WPS436 Found protected module import: onetl._util # https://github.com/wemake-services/wemake-python-styleguide/issues/1441 - WPS436 + WPS436, # WPS201 Found module with too many imports: 26 > 25 - WPS201 + WPS201, +# WPS429 Found multiple assign targets + WPS429 # http://flake8.pycqa.org/en/latest/user/options.html?highlight=per-file-ignores#cmdoption-flake8-per-file-ignores per-file-ignores = @@ -382,6 +386,8 @@ per-file-ignores = WPS609, # WPS325 Found inconsistent `yield` statement WPS325, +# WPS360 Found an unnecessary use of a raw string + WPS360 # S106 Possible hardcoded password [test usage] S106, # WPS118 Found too long name @@ -405,9 +411,11 @@ per-file-ignores = # WPS520 Found compare with falsy constant: == [] WPS520, # B017 `pytest.raises(Exception)` should be considered evil - B017 + B017, # WPS202 Found too many module members: 40 > 35 - WPS202 + WPS202, +# WPS210 Found too many local variables: 21 > 20 + WPS210 [darglint] diff --git a/setup.py b/setup.py index f8b560707..6cff75434 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ def parse_requirements(file: Path) -> list[str]: "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Java Libraries", "Topic :: Software Development :: Libraries :: Python Modules", diff --git a/tests/.coveragerc b/tests/.coveragerc index 218452499..295122b8a 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -1,5 +1,5 @@ [run] -branch = True +branch = true omit = tests/* parallel = true data_file = reports/.coverage diff --git a/tests/fixtures/global_hwm_store.py b/tests/fixtures/global_hwm_store.py index 556d38b36..f10a0089d 100644 --- a/tests/fixtures/global_hwm_store.py +++ b/tests/fixtures/global_hwm_store.py @@ -1,6 +1,5 @@ import pytest - -from onetl.hwm.store import MemoryHWMStore +from etl_entities.hwm_store import MemoryHWMStore @pytest.fixture(scope="function", autouse=True) diff --git a/tests/fixtures/hwm_delta.py b/tests/fixtures/hwm_delta.py index dd981367e..271d3839e 100644 --- a/tests/fixtures/hwm_delta.py +++ b/tests/fixtures/hwm_delta.py @@ -2,49 +2,62 @@ from datetime import date, datetime, timedelta import pytest -from etl_entities import ( - Column, - DateHWM, - DateTimeHWM, - FileListHWM, - IntHWM, - RemoteFolder, - Table, -) +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM, FileListHWM @pytest.fixture( params=[ ( - IntHWM( - source=Table(name=secrets.token_hex(5), db=secrets.token_hex(5), instance="proto://domain.com"), - column=Column(name=secrets.token_hex(5)), + ColumnIntHWM( + name=secrets.token_hex(5), + # no 
source + expression=secrets.token_hex(5), + value=10, + ), + 5, + ), + ( + ColumnIntHWM( + name=secrets.token_hex(5), + source=secrets.token_hex(5), + expression=secrets.token_hex(5), value=10, ), 5, ), ( - DateHWM( - source=Table(name=secrets.token_hex(5), db=secrets.token_hex(5), instance="proto://domain.com"), - column=Column(name=secrets.token_hex(5)), + ColumnDateHWM( + name=secrets.token_hex(5), + source=secrets.token_hex(5), + expression=secrets.token_hex(5), value=date(year=2023, month=8, day=15), ), timedelta(days=31), ), ( - DateTimeHWM( - source=Table(name=secrets.token_hex(5), db=secrets.token_hex(5), instance="proto://domain.com"), - column=Column(name=secrets.token_hex(5)), + ColumnDateTimeHWM( + name=secrets.token_hex(5), + source=secrets.token_hex(5), + expression=secrets.token_hex(5), value=datetime(year=2023, month=8, day=15, hour=11, minute=22, second=33), ), timedelta(seconds=50), ), ( FileListHWM( - source=RemoteFolder(name=f"/absolute/{secrets.token_hex(5)}", instance="ftp://ftp.server:21"), - value=["some/path", "another.file"], + name=secrets.token_hex(5), + # not directory + value=["/some/path", "/another.file"], + ), + "/third.file", + ), + ( + FileListHWM( + name=secrets.token_hex(5), + directory="/absolute/path", + value=["/absolute/path/file1", "/absolute/path/file2"], ), - "third.file", + "/absolute/path/file3", ), ], ) diff --git a/tests/fixtures/spark.py b/tests/fixtures/spark.py index 452c0f978..2dbe213a0 100644 --- a/tests/fixtures/spark.py +++ b/tests/fixtures/spark.py @@ -77,10 +77,9 @@ def maven_packages(): # There is no MongoDB connector for Spark less than 3.2 packages.extend(MongoDB.get_packages(spark_version=pyspark_version)) - if pyspark_version < (3, 5): - # There is no Excel files support for Spark less than 3.2 - # And there is still no package released for 3.5.0 https://github.com/crealytics/spark-excel/issues/787 - packages.extend(Excel.get_packages(spark_version=pyspark_version)) + # There is no Excel files support for Spark less than 3.2 + # And there is still no package released for 3.5.0 https://github.com/crealytics/spark-excel/issues/787 + packages.extend(Excel.get_packages(spark_version=pyspark_version)) return packages diff --git a/tests/libs/dummy/dummy.py b/tests/libs/dummy/dummy.py index b7a4d0a78..dad183864 100644 --- a/tests/libs/dummy/dummy.py +++ b/tests/libs/dummy/dummy.py @@ -1,5 +1,4 @@ -from onetl.hwm.store import register_hwm_store_class -from onetl.hwm.store.base_hwm_store import BaseHWMStore +from etl_entities.hwm_store import BaseHWMStore, register_hwm_store_class @register_hwm_store_class("dummy") diff --git a/tests/tests_integration/test_file_format_integration/test_avro_integration.py b/tests/tests_integration/test_file_format_integration/test_avro_integration.py index cb687776c..eaffd6499 100644 --- a/tests/tests_integration/test_file_format_integration/test_avro_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_avro_integration.py @@ -13,8 +13,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -72,7 +71,7 @@ def test_avro_reader( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -116,4 +115,4 @@ def 
test_avro_writer( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") diff --git a/tests/tests_integration/test_file_format_integration/test_csv_integration.py b/tests/tests_integration/test_file_format_integration/test_csv_integration.py index 289e88273..5dbfd20e1 100644 --- a/tests/tests_integration/test_file_format_integration/test_csv_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_csv_integration.py @@ -16,8 +16,7 @@ from tests.util.assert_df import assert_equal_df from tests.util.spark_df import reset_column_names except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas or pyspark", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -53,10 +52,11 @@ def test_csv_reader_with_infer_schema( # csv does not have header, so columns are named like "_c0", "_c1", etc expected_df = reset_column_names(expected_df) + first_column = expected_df.schema[0].name assert read_df.schema != df.schema assert read_df.schema == expected_df.schema - assert_equal_df(read_df, expected_df) + assert_equal_df(read_df, expected_df, order_by=first_column) @pytest.mark.parametrize( @@ -89,7 +89,7 @@ def test_csv_reader_with_options( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -131,4 +131,4 @@ def test_csv_writer_with_options( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") diff --git a/tests/tests_integration/test_file_format_integration/test_excel_integration.py b/tests/tests_integration/test_file_format_integration/test_excel_integration.py index f9aaad38f..9228abd3d 100644 --- a/tests/tests_integration/test_file_format_integration/test_excel_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_excel_integration.py @@ -16,8 +16,7 @@ from tests.util.assert_df import assert_equal_df from tests.util.spark_df import reset_column_names except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas or pyspark", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -33,8 +32,6 @@ def test_excel_reader_with_infer_schema( spark_version = get_spark_version(spark) if spark_version < (3, 2): pytest.skip("Excel files are supported on Spark 3.2+ only") - if spark_version >= (3, 5): - pytest.skip("Excel files are not supported on Spark 3.5+ yet") file_df_connection, source_path, _ = local_fs_file_df_connection_with_path_and_files df = file_df_dataframe @@ -55,10 +52,11 @@ def test_excel_reader_with_infer_schema( # excel does not have header, so columns are named like "_c0", "_c1", etc expected_df = reset_column_names(expected_df) + first_column = expected_df.schema[0].name assert read_df.schema != df.schema assert read_df.schema == expected_df.schema - assert_equal_df(read_df, expected_df) + assert_equal_df(read_df, expected_df, order_by=first_column) @pytest.mark.parametrize("format", ["xlsx", "xls"]) @@ -83,8 +81,6 @@ def test_excel_reader_with_options( spark_version = get_spark_version(spark) if spark_version < (3, 2): pytest.skip("Excel files are supported on Spark 3.2+ only") - if 
spark_version >= (3, 5): - pytest.skip("Excel files are not supported on Spark 3.5+ yet") local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files df = file_df_dataframe @@ -100,7 +96,7 @@ def test_excel_reader_with_options( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -121,8 +117,6 @@ def test_excel_writer( spark_version = get_spark_version(spark) if spark_version < (3, 2): pytest.skip("Excel files are supported on Spark 3.2+ only") - if spark_version >= (3, 5): - pytest.skip("Excel files are not supported on Spark 3.5+ yet") file_df_connection, source_path = local_fs_file_df_connection_with_path df = file_df_dataframe @@ -145,4 +139,4 @@ def test_excel_writer( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") diff --git a/tests/tests_integration/test_file_format_integration/test_json_integration.py b/tests/tests_integration/test_file_format_integration/test_json_integration.py index 9f195233a..f1fbd1380 100644 --- a/tests/tests_integration/test_file_format_integration/test_json_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_json_integration.py @@ -12,8 +12,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -47,7 +46,7 @@ def test_json_reader( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") def test_json_writer_is_not_supported( diff --git a/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py b/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py index ce955b261..f4678e17d 100644 --- a/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py @@ -12,8 +12,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -47,7 +46,7 @@ def test_jsonline_reader( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -85,4 +84,4 @@ def test_jsonline_writer( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") diff --git a/tests/tests_integration/test_file_format_integration/test_orc_integration.py b/tests/tests_integration/test_file_format_integration/test_orc_integration.py index 9c11e43fa..a848f0f25 100644 --- a/tests/tests_integration/test_file_format_integration/test_orc_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_orc_integration.py @@ -12,8 +12,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", 
allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -47,7 +46,7 @@ def test_orc_reader( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -85,4 +84,4 @@ def test_orc_writer( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") diff --git a/tests/tests_integration/test_file_format_integration/test_parquet_integration.py b/tests/tests_integration/test_file_format_integration/test_parquet_integration.py index 79065e889..41d492c43 100644 --- a/tests/tests_integration/test_file_format_integration/test_parquet_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_parquet_integration.py @@ -12,8 +12,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -47,7 +46,7 @@ def test_parquet_reader( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -85,4 +84,4 @@ def test_parquet_writer( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") diff --git a/tests/tests_integration/test_file_format_integration/test_xml_integration.py b/tests/tests_integration/test_file_format_integration/test_xml_integration.py index d03a6f61d..2be9d33a4 100644 --- a/tests/tests_integration/test_file_format_integration/test_xml_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_xml_integration.py @@ -13,8 +13,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection] @@ -60,7 +59,7 @@ def test_xml_reader( read_df = reader.run() assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") def test_xml_reader_with_infer_schema( @@ -90,7 +89,7 @@ def test_xml_reader_with_infer_schema( assert set(read_df.columns) == set( expected_xml_attributes_df.columns, ) # "DataFrames have different column types: StructField('id', IntegerType(), True), StructField('id', LongType(), True), etc." 
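The ``order_by="id"`` arguments added to ``assert_equal_df`` throughout these file-format tests make the comparisons order-insensitive: Spark gives no guarantee about row order when a DataFrame is read back from files, so both frames have to be sorted before being compared. The real helper lives in ``tests/util/assert_df.py`` and its implementation is not part of this diff; the sketch below only illustrates the assumed intent of the new ``order_by`` parameter:

.. code:: python

    # Illustrative sketch only -- not the real tests/util/assert_df.py implementation
    from typing import Optional

    import pandas
    from pyspark.sql import DataFrame


    def assert_equal_df(df: DataFrame, other_frame, order_by: Optional[str] = None) -> None:
        left = df.toPandas()
        right = other_frame.toPandas() if isinstance(other_frame, DataFrame) else other_frame

        if order_by:
            # Spark does not guarantee row order, so normalize both sides before comparing
            left = left.sort_values(order_by).reset_index(drop=True)
            right = right.sort_values(order_by).reset_index(drop=True)

        pandas.testing.assert_frame_equal(left, right, check_dtype=False)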
- assert_equal_df(read_df, expected_xml_attributes_df) + assert_equal_df(read_df, expected_xml_attributes_df, order_by="id") @pytest.mark.parametrize( @@ -133,7 +132,7 @@ def test_xml_writer( assert read_df.count() assert read_df.schema == df.schema - assert_equal_df(read_df, df) + assert_equal_df(read_df, df, order_by="id") @pytest.mark.parametrize( @@ -166,4 +165,4 @@ def test_xml_reader_with_attributes( read_df = reader.run() assert read_df.count() assert read_df.schema == expected_xml_attributes_df.schema - assert_equal_df(read_df, expected_xml_attributes_df) + assert_equal_df(read_df, expected_xml_attributes_df, order_by="id") diff --git a/tests/tests_integration/tests_core_integration/test_file_df_reader_integration/test_common_file_df_reader_integration.py b/tests/tests_integration/tests_core_integration/test_file_df_reader_integration/test_common_file_df_reader_integration.py index ad58c0034..31e5f32b6 100644 --- a/tests/tests_integration/tests_core_integration/test_file_df_reader_integration/test_common_file_df_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_df_reader_integration/test_common_file_df_reader_integration.py @@ -28,8 +28,7 @@ from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas or pyspark", allow_module_level=True) def test_file_df_reader_run( diff --git a/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py b/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py index 448710183..91f74bc8f 100644 --- a/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py @@ -15,8 +15,7 @@ try: from tests.util.assert_df import assert_equal_df except ImportError: - # pandas and spark can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) @pytest.mark.parametrize( diff --git a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py index 0a932dd46..f3f594e1a 100644 --- a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py @@ -7,7 +7,7 @@ from pathlib import Path, PurePosixPath import pytest -from etl_entities import FileListHWM +from etl_entities.hwm import FileListHWM from onetl.exception import DirectoryNotFoundError, NotAFileError from onetl.file import FileDownloader @@ -881,10 +881,11 @@ def test_file_downloader_detect_hwm_type_snapshot_batch_strategy( connection=file_connection, local_path=local_path, source_path=remote_path, - hwm_type="file_list", + hwm=FileListHWM(name=secrets.token_hex(5)), ) - with pytest.raises(ValueError, match="`hwm_type` cannot be used in batch strategy"): + error_message = "FileDownloader(hwm=...) 
cannot be used with SnapshotBatchStrategy" + with pytest.raises(ValueError, match=re.escape(error_message)): with SnapshotBatchStrategy(step=100500): downloader.run() @@ -900,10 +901,11 @@ def test_file_downloader_detect_hwm_type_incremental_batch_strategy( connection=file_connection, local_path=local_path, source_path=remote_path, - hwm_type="file_list", + hwm=FileListHWM(name=secrets.token_hex(5)), ) - with pytest.raises(ValueError, match="`hwm_type` cannot be used in batch strategy"): + error_message = "FileDownloader(hwm=...) cannot be used with IncrementalBatchStrategy" + with pytest.raises(ValueError, match=re.escape(error_message)): with IncrementalBatchStrategy( step=timedelta(days=5), ): @@ -922,10 +924,11 @@ def test_file_downloader_detect_hwm_type_snapshot_strategy( connection=file_connection, local_path=local_path, source_path=remote_path, - hwm_type="file_list", + hwm=FileListHWM(name=secrets.token_hex(5)), ) - with pytest.raises(ValueError, match="`hwm_type` cannot be used in snapshot strategy"): + error_message = "FileDownloader(hwm=...) cannot be used with SnapshotStrategy" + with pytest.raises(ValueError, match=re.escape(error_message)): downloader.run() @@ -941,10 +944,11 @@ def test_file_downloader_file_hwm_strategy_with_wrong_parameters( connection=file_connection, local_path=local_path, source_path=remote_path, - hwm_type="file_list", + hwm=FileListHWM(name=secrets.token_hex(5)), ) - with pytest.raises(ValueError, match="If `hwm_type` is passed you can't specify an `offset`"): + error_message = "FileDownloader(hwm=...) cannot be used with IncrementalStrategy(offset=1, ...)" + with pytest.raises(ValueError, match=re.escape(error_message)): with IncrementalStrategy(offset=1): downloader.run() @@ -952,26 +956,18 @@ def test_file_downloader_file_hwm_strategy_with_wrong_parameters( downloader.run() -@pytest.mark.parametrize( - "hwm_type", - [ - "file_list", - FileListHWM, - ], -) def test_file_downloader_file_hwm_strategy( file_connection_with_path_and_files, tmp_path_factory, - hwm_type, ): - file_connection, remote_path, uploaded_files = file_connection_with_path_and_files + file_connection, remote_path, _ = file_connection_with_path_and_files local_path = tmp_path_factory.mktemp("local_path") downloader = FileDownloader( connection=file_connection, local_path=local_path, - hwm_type=hwm_type, source_path=remote_path, + hwm=FileListHWM(name=secrets.token_hex(5)), ) with IncrementalStrategy(): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py index 643187e22..692c3ea45 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py @@ -1,7 +1,15 @@ +from string import ascii_letters + import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import Clickhouse from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.clickhouse @@ -26,6 +34,7 @@ def test_clickhouse_reader_snapshot(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -88,6 +97,7 @@ def test_clickhouse_reader_snapshot_without_set_database(spark, processing, 
load schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -127,7 +137,11 @@ def test_clickhouse_reader_snapshot_with_columns(spark, processing, load_table_d assert table_df.columns != table_df_with_columns.columns assert table_df_with_columns.columns == columns # dataframe content is unchanged - processing.assert_equal_df(table_df_with_columns, other_frame=table_df) + processing.assert_equal_df( + table_df_with_columns, + other_frame=table_df, + order_by="id_int", + ) reader3 = DBReader( connection=clickhouse, @@ -141,6 +155,71 @@ def test_clickhouse_reader_snapshot_with_columns(spark, processing, load_table_d assert count_df.collect()[0][0] == table_df.count() +def test_clickhouse_reader_snapshot_with_columns_duplicated(spark, processing, prepare_schema_table): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader1 = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + ) + df1 = reader1.run() + + reader2 = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + columns=[ + "*", + "id_int", + ], + ) + + # Clickhouse can detect that column is already a part of * and does not produce duplicates + df2 = reader2.run() + assert df1.columns == df2.columns + + +def test_clickhouse_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + # create table with mixed column names, e.g. IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str(alphabet=ascii_letters + " _").strip() + # wrap column names in DDL with quotes to preserve case + table_fields[f'"{new_name}"'] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=clickhouse, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_clickhouse_reader_snapshot_with_where(spark, processing, load_table_data): clickhouse = Clickhouse( host=processing.host, @@ -178,6 +257,7 @@ def test_clickhouse_reader_snapshot_with_where(spark, processing, load_table_dat schema=load_table_data.schema, table=load_table_data.table, df=table_df1, + order_by="id_int", ) reader3 = DBReader( @@ -225,3 +305,68 @@ def test_clickhouse_reader_snapshot_with_columns_and_where(spark, processing, lo count_df = reader2.run() assert count_df.collect()[0][0] == table_df.count() + + +def test_clickhouse_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 
60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py index 917066c63..0e60a6689 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py @@ -1,7 +1,15 @@ +from string import ascii_letters + import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import Greenplum from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.greenplum @@ -86,6 +94,72 @@ def test_greenplum_reader_snapshot_with_columns(spark, processing, load_table_da assert count_df.collect()[0][0] == table_df.count() +def test_greenplum_reader_snapshot_with_columns_duplicated(spark, processing, prepare_schema_table): + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + + reader1 = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + ) + df1 = reader1.run() + + reader2 = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + columns=[ + "*", + "id_int", + ], + ) + + df2 = reader2.run() + assert df2.columns == df1.columns + ["id_int"] + + +def test_greenplum_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + + # create table with mixed column names, e.g. 
IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str(alphabet=ascii_letters + " _").strip() + # wrap column names in DDL with quotes to preserve case + table_fields[f'"{new_name}"'] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=greenplum, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_greenplum_reader_snapshot_with_where(spark, processing, load_table_data): greenplum = Greenplum( host=processing.host, @@ -172,3 +246,69 @@ def test_greenplum_reader_snapshot_with_columns_and_where(spark, processing, loa count_df = reader2.run() assert count_df.collect()[0][0] == table_df.count() + + +def test_greenplum_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + + reader = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py index cf31e2786..3a3a0a8a0 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py @@ -1,7 +1,13 @@ import pytest +try: + import pandas +except ImportError: + 
pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import Hive from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.hive @@ -19,6 +25,7 @@ def test_hive_reader(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -51,7 +58,11 @@ def test_hive_reader_snapshot_with_columns(spark, processing, load_table_data): assert table_df.columns != table_df_with_columns.columns assert table_df_with_columns.columns == columns # dataframe content is unchanged - processing.assert_equal_df(table_df_with_columns, other_frame=table_df) + processing.assert_equal_df( + table_df_with_columns, + other_frame=table_df, + order_by="id_int", + ) reader3 = DBReader( connection=hive, @@ -65,6 +76,56 @@ def test_hive_reader_snapshot_with_columns(spark, processing, load_table_data): assert count_df.collect()[0][0] == table_df.count() +def test_hive_reader_snapshot_with_columns_duplicated(spark, prepare_schema_table): + hive = Hive(cluster="rnd-dwh", spark=spark) + + reader1 = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + ) + df1 = reader1.run() + + reader2 = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + columns=[ + "*", + "id_int", + ], + ) + + df2 = reader2.run() + assert df2.columns == df1.columns + ["id_int"] + + +def test_hive_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + hive = Hive(cluster="rnd-dwh", spark=spark) + + # create table with mixed column names, e.g. IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str() + # wrap column names in DDL with quotes to preserve case + table_fields[f"`{new_name}`"] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=hive, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_hive_reader_snapshot_with_where(spark, processing, load_table_data): hive = Hive(cluster="rnd-dwh", spark=spark) @@ -80,6 +141,7 @@ def test_hive_reader_snapshot_with_where(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=table_df, + order_by="id_int", ) reader2 = DBReader( @@ -133,3 +195,60 @@ def test_hive_reader_non_existing_table(spark, get_schema_table): with pytest.raises(AnalysisException) as excinfo: reader.run() assert "does not exists" in str(excinfo.value) + + +def test_hive_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + hive = Hive(cluster="rnd-dwh", spark=spark) + reader = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read 
+ df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py index 6f5d3e545..635e257b9 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py @@ -3,6 +3,11 @@ import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl._util.spark import get_spark_version from onetl.connection import Kafka from onetl.db import DBReader @@ -189,3 +194,30 @@ def test_kafka_reader_topic_does_not_exist(spark, kafka_processing): with pytest.raises(ValueError, match="Topic 'missing' doesn't exist"): reader.run() + + +@pytest.mark.parametrize("group_id_option", ["group.id", "groupIdPrefix"]) +def test_kafka_reader_with_group_id(group_id_option, spark, kafka_processing, schema): + if get_spark_version(spark).major < 3: + pytest.skip("Spark 3.x or later is required to pass group.id") + + topic, processing, expected_df = kafka_processing + + kafka = Kafka( + spark=spark, + addresses=[f"{processing.host}:{processing.port}"], + cluster="cluster", + extra={group_id_option: "test"}, + ) + + reader = DBReader( + connection=kafka, + source=topic, + ) + df = reader.run() + processing.assert_equal_df(processing.json_deserialize(df, df_schema=schema), other_frame=expected_df) + + # Spark does not report to Kafka which messages were read, so Kafka does not remember latest offsets for groupId + # https://stackoverflow.com/a/64003569 + df = reader.run() + processing.assert_equal_df(processing.json_deserialize(df, df_schema=schema), other_frame=expected_df) diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mongodb_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mongodb_reader_integration.py index bbd4d1f00..5071270d2 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mongodb_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mongodb_reader_integration.py @@ -1,5 +1,10 @@ import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import MongoDB
from onetl.db import DBReader @@ -52,6 +57,7 @@ def test_mongodb_reader_snapshot(spark, processing, load_table_data, df_schema): schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="_id", ) @@ -98,6 +104,7 @@ def test_mongodb_reader_snapshot_with_where(spark, processing, load_table_data, schema=load_table_data.schema, table=load_table_data.table, df=table_df1, + order_by="_id", ) one_reader = DBReader( @@ -119,3 +126,69 @@ def test_mongodb_reader_snapshot_with_where(spark, processing, load_table_data, empty_df = empty_reader.run() assert not empty_df.count() + + +def test_mongodb_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table, df_schema): + mongo = MongoDB( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mongo, + source=prepare_schema_table.table, + df_schema=df_schema, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="_id") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="_id") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="_id") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="_id") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py index 12ace609d..5409b85e9 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py @@ -1,7 +1,15 @@ +from string import ascii_letters + import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import MSSQL from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.mssql @@ -27,6 +35,7 @@ def test_mssql_reader_snapshot(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -108,7 +117,11 @@ def test_mssql_reader_snapshot_with_columns(spark, processing, 
load_table_data): assert table_df.columns != table_df_with_columns.columns assert table_df_with_columns.columns == columns # dataframe content is unchanged - processing.assert_equal_df(table_df_with_columns, other_frame=table_df) + processing.assert_equal_df( + table_df_with_columns, + other_frame=table_df, + order_by="id_int", + ) reader3 = DBReader( connection=mssql, @@ -122,6 +135,65 @@ def test_mssql_reader_snapshot_with_columns(spark, processing, load_table_data): assert count_df.collect()[0][0] == table_df.count() +def test_mssql_reader_snapshot_with_columns_duplicated(spark, processing, prepare_schema_table): + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=prepare_schema_table.full_name, + columns=[ + "*", + "id_int", + ], + ) + with pytest.raises(Exception, match="The column 'id_int' was specified multiple times"): + reader.run() + + +def test_mssql_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + # create table with mixed column names, e.g. IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str(alphabet=ascii_letters + " _").strip() + # wrap column names in DDL with quotes to preserve case + table_fields[f'"{new_name}"'] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=mssql, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_mssql_reader_snapshot_with_where(spark, processing, load_table_data): mssql = MSSQL( host=processing.host, @@ -160,6 +232,7 @@ def test_mssql_reader_snapshot_with_where(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=table_df1, + order_by="id_int", ) reader3 = DBReader( @@ -208,3 +281,69 @@ def test_mssql_reader_snapshot_with_columns_and_where(spark, processing, load_ta count_df = reader2.run() assert count_df.collect()[0][0] == table_df.count() + + +def test_mssql_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + reader = DBReader( + connection=mssql, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, 
max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py index e0865866a..65b922732 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py @@ -1,7 +1,15 @@ +from string import ascii_letters + import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import MySQL from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.mysql @@ -27,6 +35,7 @@ def test_mysql_reader_snapshot(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -90,6 +99,7 @@ def test_mysql_reader_snapshot_with_not_set_database(spark, processing, load_tab schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -129,7 +139,11 @@ def test_mysql_reader_snapshot_with_columns(spark, processing, load_table_data): assert table_df.columns != table_df_with_columns.columns assert table_df_with_columns.columns == columns # dataframe content is unchanged - processing.assert_equal_df(table_df_with_columns, other_frame=table_df) + processing.assert_equal_df( + table_df_with_columns, + other_frame=table_df, + order_by="id_int", + ) reader3 = DBReader( connection=mysql, @@ -143,6 +157,63 @@ def test_mysql_reader_snapshot_with_columns(spark, processing, load_table_data): assert count_df.collect()[0][0] == table_df.count() +def test_mysql_reader_snapshot_with_columns_duplicated(spark, processing, prepare_schema_table): + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=prepare_schema_table.full_name, + columns=[ + "*", + "id_int", + ], + ) + with pytest.raises(Exception, match="Duplicate column name"): + reader.run() + + +def test_mysql_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + 
database=processing.database, + spark=spark, + ) + + # create table with mixed column names, e.g. IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str(alphabet=ascii_letters + " _").strip() + # wrap column names in DDL with quotes to preserve case + table_fields[f"`{new_name}`"] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=mysql, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_mysql_reader_snapshot_with_where(spark, processing, load_table_data): mysql = MySQL( host=processing.host, @@ -179,6 +250,7 @@ def test_mysql_reader_snapshot_with_where(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=table_df1, + order_by="id_int", ) reader3 = DBReader( @@ -226,3 +298,68 @@ def test_mysql_reader_snapshot_with_columns_and_where(spark, processing, load_ta count_df = reader2.run() assert count_df.collect()[0][0] == table_df.count() + + +def test_mysql_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=mysql, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py index b379923ef..126420c61 100644 --- 
a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py @@ -1,7 +1,15 @@ +from string import ascii_letters + import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import Oracle from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.oracle @@ -27,6 +35,7 @@ def test_oracle_reader_snapshot(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=df, + order_by="id_int", ) @@ -108,7 +117,11 @@ def test_oracle_reader_snapshot_with_columns(spark, processing, load_table_data) assert table_df.columns != table_df_with_columns.columns assert table_df_with_columns.columns == columns # dataframe content is unchanged - processing.assert_equal_df(table_df_with_columns, other_frame=table_df) + processing.assert_equal_df( + table_df_with_columns, + other_frame=table_df, + order_by="id_int", + ) reader3 = DBReader( connection=oracle, @@ -122,6 +135,66 @@ def test_oracle_reader_snapshot_with_columns(spark, processing, load_table_data) assert count_df.collect()[0][0] == table_df.count() +def test_oracle_reader_snapshot_with_columns_duplicated(spark, processing, prepare_schema_table): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + reader = DBReader( + connection=oracle, + source=prepare_schema_table.full_name, + columns=[ + "*", + "ID_INT", + ], + ) + # https://stackoverflow.com/questions/27965130/how-to-select-column-from-table-in-oracle + with pytest.raises(Exception, match="java.sql.SQLSyntaxErrorException"): + reader.run() + + +def test_oracle_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + # create table with mixed column names, e.g. 
IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str(alphabet=ascii_letters + " _").strip() + # wrap column names in DDL with quotes to preserve case + table_fields[f'"{new_name}"'] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=oracle, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_oracle_reader_snapshot_with_where(spark, processing, load_table_data): oracle = Oracle( host=processing.host, @@ -159,6 +232,7 @@ def test_oracle_reader_snapshot_with_where(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=table_df1, + order_by="id_int", ) reader3 = DBReader( @@ -207,3 +281,69 @@ def test_oracle_reader_snapshot_with_columns_and_where(spark, processing, load_t count_df = reader2.run() assert count_df.collect()[0][0] == table_df.count() + + +def test_oracle_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + ) + + reader = DBReader( + connection=oracle, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py index 617eba903..bbb2ee472 100644 --- 
a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py @@ -1,7 +1,15 @@ +from string import ascii_letters + import pytest +try: + import pandas +except ImportError: + pytest.skip("Missing pandas", allow_module_level=True) + from onetl.connection import Postgres from onetl.db import DBReader +from tests.util.rand import rand_str pytestmark = pytest.mark.postgres @@ -26,6 +34,7 @@ def test_postgres_reader_snapshot(spark, processing, load_table_data): schema=load_table_data.schema, table=load_table_data.table, df=table_df, + order_by="id_int", ) @@ -106,7 +115,11 @@ def test_postgres_reader_snapshot_with_columns(spark, processing, load_table_dat assert table_df_with_columns.columns == columns # dataframe content is unchanged - processing.assert_equal_df(table_df_with_columns, other_frame=table_df) + processing.assert_equal_df( + table_df_with_columns, + other_frame=table_df, + order_by="id_int", + ) reader3 = DBReader( connection=postgres, @@ -120,6 +133,65 @@ def test_postgres_reader_snapshot_with_columns(spark, processing, load_table_dat assert count_df.collect()[0][0] == table_df.count() +def test_postgres_reader_snapshot_with_columns_duplicate(spark, processing, prepare_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + columns=[ + "*", + "id_int", + ], + ) + + error_msg = r"(The column `id_int` already exists|Found duplicate column\(s\) in the data schema: `id_int`)" + with pytest.raises(Exception, match=error_msg): + reader.run() + + +def test_postgres_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + # create table with mixed column names, e.g. 
IdInt + full_name, schema, table = get_schema_table + column_names = [] + table_fields = {} + for original_name in processing.column_names: + column_type = processing.get_column_type(original_name) + new_name = rand_str(alphabet=ascii_letters + " _").strip() + # wrap column names in DDL with quotes to preserve case + table_fields[f'"{new_name}"'] = column_type + column_names.append(new_name) + + processing.create_table(schema=schema, table=table, fields=table_fields) + + # before 0.10 this caused errors because * in column names was replaced with real column names, + # but they were not escaped + reader = DBReader( + connection=postgres, + source=full_name, + columns=["*"], + ) + + df = reader.run() + assert df.columns == column_names + + def test_postgres_reader_snapshot_with_where(spark, processing, load_table_data): postgres = Postgres( host=processing.host, @@ -156,6 +228,7 @@ def test_postgres_reader_snapshot_with_where(spark, processing, load_table_data) schema=load_table_data.schema, table=load_table_data.table, df=table_df1, + order_by="id_int", ) reader3 = DBReader( @@ -227,6 +300,7 @@ def test_postgres_reader_snapshot_with_pydantic_options(spark, processing, load_ schema=load_table_data.schema, table=load_table_data.table, df=table_df, + order_by="id_int", ) @@ -268,4 +342,70 @@ def test_postgres_reader_different_options(spark, processing, load_table_data, o schema=load_table_data.schema, table=load_table_data.table, df=table_df, + order_by="id_int", ) + + +def test_postgres_reader_snapshot_nothing_to_read(spark, processing, prepare_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = reader.run() + assert not df.count() + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + total_span = pandas.concat([first_span, second_span], ignore_index=True) + + # .run() is not called, but dataframes are lazy, so it now contains all data from the source + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") + + # read data explicitly + df = reader.run() + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_clickhouse_writer_integration.py 
b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_clickhouse_writer_integration.py index 60ed7a1d7..459794b35 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_clickhouse_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_clickhouse_writer_integration.py @@ -28,4 +28,5 @@ def test_clickhouse_writer_snapshot(spark, processing, prepare_schema_table): schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py index 8ca74b06d..07b30cb07 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py @@ -39,6 +39,7 @@ def test_hive_writer_target_does_not_exist(spark, processing, get_schema_table, schema=get_schema_table.schema, table=get_schema_table.table, df=df, + order_by="id_int", ) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py index b35bfabad..cf045b310 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py @@ -175,7 +175,10 @@ def test_kafka_writer_key_column(spark, kafka_processing, kafka_spark_df): pd_df = processing.get_expected_df(topic, num_messages=df.count()) assert len(pd_df) == df.count() - processing.assert_equal_df(df, other_frame=pd_df.drop(columns=["partition", "headers", "topic"], axis=1)) + processing.assert_equal_df( + df, + other_frame=pd_df.drop(columns=["partition", "headers", "topic"], axis=1), + ) def test_kafka_writer_topic_column(spark, kafka_processing, caplog, kafka_spark_df): @@ -304,7 +307,10 @@ def test_kafka_writer_mode(spark, kafka_processing, kafka_spark_df): read_df = df.withColumn("key", lit(None)).withColumn("topic", lit(topic)).withColumn("partition", lit(0)) # Check that second dataframe record is appended to first dataframe in same topic - processing.assert_equal_df(pd_df.drop(columns=["headers"], axis=1), other_frame=read_df.union(read_df)) + processing.assert_equal_df( + pd_df.drop(columns=["headers"], axis=1), + other_frame=read_df.union(read_df), + ) def test_kafka_writer_mode_error(spark, kafka_processing, kafka_spark_df): diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mongodb_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mongodb_writer_integration.py index d5cd94fed..503d2ee2c 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mongodb_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mongodb_writer_integration.py @@ -19,7 +19,6 @@ {"if_exists": "ignore"}, ], ) -@pytest.mark.flaky(reruns=2) def test_mongodb_writer_snapshot(spark, processing, get_schema_table, options, caplog): df = processing.create_spark_df(spark=spark) @@ -47,9 
+46,12 @@ def test_mongodb_writer_snapshot(spark, processing, get_schema_table, options, c schema=get_schema_table.schema, table=get_schema_table.table, df=df, + order_by="_id", ) +# old MongoDB versions sometimes lose part of the data during insert +@pytest.mark.flaky(reruns=2) def test_mongodb_writer_if_exists_append(spark, processing, get_schema_table): df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) df1 = df[df._id < 1001] @@ -76,9 +78,12 @@ def test_mongodb_writer_if_exists_append(spark, processing, get_schema_table): schema=get_schema_table.schema, table=get_schema_table.table, df=df, + order_by="_id", ) +# old MongoDB versions sometimes lose part of the data during insert +@pytest.mark.flaky(reruns=2) def test_mongodb_writer_if_exists_replace_entire_collection(spark, processing, get_schema_table): df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) df1 = df[df._id < 1001] @@ -105,10 +110,13 @@ def test_mongodb_writer_if_exists_replace_entire_collection(spark, processing, g schema=get_schema_table.schema, table=get_schema_table.table, df=df2, + order_by="_id", ) -def test_mongodb_writer_if_exists_error(spark, processing, get_schema_table, caplog): +# old MongoDB versions sometimes lose part of the data during insert +@pytest.mark.flaky(reruns=2) +def test_mongodb_writer_if_exists_error(spark, processing, get_schema_table): df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) mongo = MongoDB( @@ -137,9 +145,12 @@ def test_mongodb_writer_if_exists_error(spark, processing, get_schema_table, cap schema=get_schema_table.schema, table=get_schema_table.table, df=df, + order_by="_id", ) +# old MongoDB versions sometimes lose part of the data during insert +@pytest.mark.flaky(reruns=2) def test_mongodb_writer_if_exists_ignore(spark, processing, get_schema_table, caplog): df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) df1 = df[df._id < 1001] @@ -174,4 +185,5 @@ def test_mongodb_writer_if_exists_ignore(spark, processing, get_schema_table, ca schema=get_schema_table.schema, table=get_schema_table.table, df=df1, + order_by="_id", ) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mssql_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mssql_writer_integration.py index b2d313c9b..3e6cf35b6 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mssql_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mssql_writer_integration.py @@ -29,4 +29,5 @@ def test_mssql_writer_snapshot(spark, processing, prepare_schema_table): schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mysql_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mysql_writer_integration.py index a4e375a3b..86bc7cbb1 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mysql_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_mysql_writer_integration.py @@ -29,4 +29,5 @@ def test_mysql_writer_snapshot(spark, processing, prepare_schema_table): schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) diff --git
a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_oracle_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_oracle_writer_integration.py index f7d6dbf58..f5083bab8 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_oracle_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_oracle_writer_integration.py @@ -30,4 +30,5 @@ def test_oracle_writer_snapshot(spark, processing, prepare_schema_table): schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_postgres_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_postgres_writer_integration.py index cda43c8a8..ed6519485 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_postgres_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_postgres_writer_integration.py @@ -40,6 +40,7 @@ def test_postgres_writer_snapshot(spark, processing, get_schema_table, options): schema=get_schema_table.schema, table=get_schema_table.table, df=df, + order_by="id_int", ) @@ -67,6 +68,7 @@ def test_postgres_writer_snapshot_with_dict_options(spark, processing, prepare_s schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) @@ -94,6 +96,7 @@ def test_postgres_writer_snapshot_with_pydantic_options(spark, processing, prepa schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) @@ -124,6 +127,7 @@ def test_postgres_writer_if_exists_append(spark, processing, prepare_schema_tabl schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df, + order_by="id_int", ) @@ -217,4 +221,5 @@ def test_postgres_writer_if_exists_replace_entire_table(spark, processing, prepa schema=prepare_schema_table.schema, table=prepare_schema_table.table, df=df2, + order_by="id_int", ) diff --git a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py index 410a6a02d..73d8abfb2 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py @@ -6,8 +6,7 @@ try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import Clickhouse @@ -135,7 +134,7 @@ def test_clickhouse_connection_execute_ddl(spark, processing, get_schema_table, assert not clickhouse.execute( f""" ALTER TABLE {table_name} ADD INDEX {table}_id_int_idx (id_int) TYPE minmax GRANULARITY 8192{suffix} - """, + """, ) assert not clickhouse.execute(f"ALTER TABLE {table_name} DROP INDEX {table}_id_int_idx{suffix}") diff --git a/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py b/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py index 91a52bb9d..a424282d3 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py +++ 
b/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py @@ -5,8 +5,7 @@ try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import Greenplum diff --git a/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py b/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py index 8a90578cb..69d7bae6b 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py @@ -7,8 +7,7 @@ try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import Hive diff --git a/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py index 96e23183b..9a875671a 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py @@ -5,8 +5,7 @@ try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import MSSQL diff --git a/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py index dad717b1c..72a6b3b8f 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py @@ -5,8 +5,7 @@ try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import MySQL diff --git a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py index 40e813279..6bd96b259 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py @@ -1,13 +1,12 @@ -import contextlib import logging +from contextlib import suppress import pytest try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import Oracle @@ -86,7 +85,8 @@ def test_oracle_connection_sql(spark, processing, load_table_data, suffix): filtered_df = table_df[table_df.ID_INT < 50] processing.assert_equal_df(df=df, other_frame=filtered_df, order_by="id_int") - with pytest.raises(Exception): + with suppress(Exception): + # new syntax in Oracle 23, but fails on older versions oracle.sql(f"SELECT 1{suffix}") @@ -120,9 +120,9 @@ def test_oracle_connection_fetch(spark, processing, load_table_data, suffix): with pytest.raises(Exception): oracle.fetch(f"SHOW TABLES{suffix}") - # wrong syntax - with pytest.raises(Exception): - oracle.fetch(f"SELECT 1{suffix}") + with suppress(Exception): + # new syntax in Oracle 23, but 
fails on older versions + oracle.sql(f"SELECT 1{suffix}") @pytest.mark.parametrize("suffix", ["", ";"]) @@ -920,7 +920,7 @@ def package_finalizer(): selected_df = table_df[table_df.ID_INT < 10] processing.assert_equal_df(df=df, other_frame=selected_df, order_by="id_int") - with contextlib.suppress(Exception): + with suppress(Exception): # Oracle 11 does not support selecting from pipelined function without TABLE(...), but 18 does df = oracle.fetch(f"SELECT * FROM {func}_pkg.func_pipelined(10)") processing.assert_equal_df(df=df, other_frame=selected_df, order_by="id_int") diff --git a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py index 9f2d2253b..1fb74095f 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py @@ -5,8 +5,7 @@ try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import Postgres diff --git a/tests/tests_integration/tests_hwm_store_integration.py b/tests/tests_integration/tests_hwm_store_integration.py index 2245cadbd..8b6938b33 100644 --- a/tests/tests_integration/tests_hwm_store_integration.py +++ b/tests/tests_integration/tests_hwm_store_integration.py @@ -1,10 +1,12 @@ +import secrets import tempfile import pytest +from etl_entities.hwm_store import MemoryHWMStore from onetl.connection import Postgres from onetl.db import DBReader -from onetl.hwm.store import MemoryHWMStore, YAMLHWMStore +from onetl.hwm.store import YAMLHWMStore from onetl.strategy import IncrementalStrategy hwm_store = [ @@ -26,7 +28,13 @@ def test_postgres_hwm_store_integration_with_reader(spark, processing, prepare_s ) hwm_column = "hwm_int" - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + hwm_name = secrets.token_hex(5) + + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there is a span span_length = 100 @@ -46,11 +54,10 @@ def test_postgres_hwm_store_integration_with_reader(spark, processing, prepare_s with hwm_store: # incremental run - with IncrementalStrategy() as strategy: + with IncrementalStrategy(): reader.run() - strategy_hwm = strategy.hwm # HWM value was saved into the storage - saved_hwm = hwm_store.get(strategy_hwm.qualified_name) + saved_hwm = hwm_store.get(hwm_name) assert saved_hwm.value == span[hwm_column].max() diff --git a/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py b/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py index 25477b3cc..5371e8e27 100644 --- a/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py +++ b/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py @@ -3,19 +3,21 @@ from datetime import date, datetime, timedelta import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM + +from tests.util.rand import rand_str try: import pandas from tests.util.to_pandas import to_pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + 
pytest.skip("Missing pandas", allow_module_level=True) + +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Postgres from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalBatchStrategy, IncrementalStrategy pytestmark = pytest.mark.postgres @@ -37,66 +39,15 @@ def test_postgres_strategy_incremental_batch_outside_loop( reader = DBReader( connection=postgres, source=load_table_data.full_name, - hwm_column="hwm_int", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), ) - with pytest.raises(RuntimeError): + error_msg = "Invalid IncrementalBatchStrategy usage!" + with pytest.raises(RuntimeError, match=re.escape(error_msg)): with IncrementalBatchStrategy(step=1): reader.run() -def test_postgres_strategy_incremental_batch_unknown_hwm_column( - spark, - processing, - prepare_schema_table, -): - postgres = Postgres( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - reader = DBReader( - connection=postgres, - source=prepare_schema_table.full_name, - hwm_column="unknown_column", - ) - - with pytest.raises(Exception): - with IncrementalBatchStrategy(step=1) as batches: - for _ in batches: - reader.run() - - -def test_postgres_strategy_incremental_batch_duplicated_hwm_column( - spark, - processing, - prepare_schema_table, - load_table_data, -): - postgres = Postgres( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - reader = DBReader( - connection=postgres, - source=prepare_schema_table.full_name, - columns=["id_int AS hwm_int"], # previous HWM cast implementation is not supported anymore - hwm_column="hwm_int", - ) - - with pytest.raises(Exception): - with IncrementalBatchStrategy(step=1) as batches: - for _ in batches: - reader.run() - - def test_postgres_strategy_incremental_batch_where(spark, processing, prepare_schema_table): postgres = Postgres( host=processing.host, @@ -111,20 +62,19 @@ def test_postgres_strategy_incremental_batch_where(spark, processing, prepare_sc connection=postgres, source=prepare_schema_table.full_name, where="float_value < 51 OR float_value BETWEEN 101 AND 120", - hwm_column="hwm_int", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), ) # there are 2 spans - # 0..100 first_span_begin = 0 first_span_end = 100 first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) - # 101..250 second_span_begin = 101 second_span_end = 200 second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + # insert first span processing.insert_data( schema=prepare_schema_table.schema, table=prepare_schema_table.table, @@ -141,7 +91,8 @@ def test_postgres_strategy_incremental_batch_where(spark, processing, prepare_sc else: first_df = first_df.union(next_df) - processing.assert_equal_df(df=first_df, other_frame=first_span[:51]) + # read only rows 0..50 (according to where) + processing.assert_equal_df(df=first_df, other_frame=first_span[:51], order_by="id_int") # insert second span processing.insert_data( @@ -160,7 +111,8 @@ def test_postgres_strategy_incremental_batch_where(spark, processing, prepare_sc else: second_df = second_df.union(next_df) - processing.assert_equal_df(df=second_df, other_frame=second_span[:19]) + # read only rows 101..119 (according to where) + 
processing.assert_equal_df(df=second_df, other_frame=second_span[:19], order_by="id_int") def test_postgres_strategy_incremental_batch_hwm_set_twice( @@ -180,78 +132,97 @@ def test_postgres_strategy_incremental_batch_hwm_set_twice( step = 1 table1 = load_table_data.full_name - table2 = f"{secrets.token_hex()}.{secrets.token_hex()}" + table2 = f"{secrets.token_hex(5)}.{secrets.token_hex(5)}" - hwm_column1 = "hwm_int" - hwm_column2 = "hwm_datetime" - - reader1 = DBReader(connection=postgres, table=table1, hwm_column=hwm_column1) - reader2 = DBReader(connection=postgres, table=table2, hwm_column=hwm_column1) - reader3 = DBReader(connection=postgres, table=table1, hwm_column=hwm_column2) + reader1 = DBReader( + connection=postgres, + table=table1, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) + reader2 = DBReader( + connection=postgres, + table=table1, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) + reader3 = DBReader( + connection=postgres, + table=table2, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) with IncrementalBatchStrategy(step=step) as batches: for _ in batches: reader1.run() - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Detected wrong IncrementalBatchStrategy usage.", + ): reader2.run() - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Detected wrong IncrementalBatchStrategy usage.", + ): reader3.run() - break - -# Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, step", + "hwm_column, new_type, step", [ - ("float_value", 1.0), - ("text_string", "abc"), + ("hwm_int", "date", 200), + ("hwm_date", "integer", timedelta(days=20)), + ("hwm_datetime", "integer", timedelta(weeks=2)), ], ) -def test_postgres_strategy_incremental_batch_wrong_hwm_type(spark, processing, prepare_schema_table, hwm_column, step): +def test_postgres_strategy_incremental_batch_different_hwm_type_in_store( + spark, + processing, + load_table_data, + hwm_column, + new_type, + step, +): postgres = Postgres( host=processing.host, + port=processing.port, user=processing.user, password=processing.password, database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) - data = processing.create_pandas_df() - - # insert first span - processing.insert_data( - schema=prepare_schema_table.schema, - table=prepare_schema_table.table, - values=data, + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), ) - with pytest.raises((KeyError, ValueError)): - # incremental run + with IncrementalBatchStrategy(step=step) as batches: + for _ in batches: + reader.run() + + # change table schema + new_fields = {column_name: processing.get_column_type(column_name) for column_name in processing.column_names} + new_fields[hwm_column] = new_type + + processing.drop_table(schema=load_table_data.schema, table=load_table_data.table) + processing.create_table(schema=load_table_data.schema, table=load_table_data.table, fields=new_fields) + + with pytest.raises(TypeError, match="Cannot cast HWM of type .* as .*"): with IncrementalBatchStrategy(step=step) as batches: for _ in batches: reader.run() -@pytest.mark.parametrize( - "hwm_column, new_type, step", - [ - ("hwm_int", "date", 200), - ("hwm_date", "integer", timedelta(days=20)), - 
("hwm_datetime", "integer", timedelta(weeks=2)), - ], -) -def test_postgres_strategy_incremental_batch_different_hwm_type_in_store( +def test_postgres_strategy_incremental_batch_different_hwm_source_in_store( spark, processing, load_table_data, - hwm_column, - new_type, - step, ): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + postgres = Postgres( host=processing.host, port=processing.port, @@ -261,23 +232,61 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store( spark=spark, ) - reader = DBReader(connection=postgres, source=load_table_data.full_name, hwm_column=hwm_column) + old_hwm = ColumnIntHWM(name=hwm_name, source=load_table_data.full_name, expression="hwm_int", description="abc") + # change HWM entity in HWM store + fake_hwm = old_hwm.copy(update={"entity": rand_str()}) + hwm_store.set_hwm(fake_hwm) - with IncrementalBatchStrategy(step=step) as batches: - for _ in batches: - reader.run() + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + hwm=old_hwm, + ) + with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"): + with IncrementalBatchStrategy(step=50) as batches: + for _ in batches: + reader.run() - processing.drop_table(schema=load_table_data.schema, table=load_table_data.table) - new_fields = {column_name: processing.get_column_type(column_name) for column_name in processing.column_names} - new_fields[hwm_column] = new_type - processing.create_table(schema=load_table_data.schema, table=load_table_data.table, fields=new_fields) +@pytest.mark.parametrize("attribute", ["expression", "description"]) +def test_postgres_strategy_incremental_batch_different_hwm_optional_attribute_in_store( + spark, + processing, + load_table_data, + attribute, +): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) - with pytest.raises(ValueError): - with IncrementalBatchStrategy(step=step) as batches: + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + old_hwm = ColumnIntHWM(name=hwm_name, source=load_table_data.full_name, expression="hwm_int", description="abc") + + # change attribute value in HWM store + fake_hwm = old_hwm.copy(update={attribute: rand_str()}) + hwm_store.set_hwm(fake_hwm) + + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + hwm=old_hwm, + ) + with pytest.warns(UserWarning, match=f"Detected HWM with different `{attribute}` attribute"): + with IncrementalBatchStrategy(step=50) as batches: for _ in batches: reader.run() + # attributes from DBReader have higher priority, except value + new_hwm = hwm_store.get_hwm(name=hwm_name) + assert new_hwm.dict(exclude={"value", "modified_time"}) == old_hwm.dict(exclude={"value", "modified_time"}) + @pytest.mark.parametrize( "hwm_column, step", @@ -293,7 +302,7 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store( ("hwm_datetime", "abc"), ], ) -def test_postgres_strategy_incremental_batch_step_wrong_type( +def test_postgres_strategy_incremental_batch_wrong_step_type( spark, processing, prepare_schema_table, @@ -308,7 +317,11 @@ def test_postgres_strategy_incremental_batch_step_wrong_type( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + 
hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) # there are 2 spans with a gap between @@ -373,7 +386,11 @@ def test_postgres_strategy_incremental_batch_step_negative( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) # there are 2 spans with a gap between @@ -442,7 +459,7 @@ def test_postgres_strategy_incremental_batch_step_too_small( reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), ) # there are 2 spans with a gap between @@ -490,10 +507,10 @@ def test_postgres_strategy_incremental_batch_step_too_small( @pytest.mark.parametrize( "hwm_type, hwm_column, step, per_iter", [ - (IntHWM, "hwm_int", 20, 30), # step < per_iter - (IntHWM, "hwm_int", 30, 30), # step == per_iter - (DateHWM, "hwm_date", timedelta(days=20), 20), # per_iter value is calculated to cover the step value - (DateTimeHWM, "hwm_datetime", timedelta(weeks=2), 20), # same + (ColumnIntHWM, "hwm_int", 20, 30), # step < per_iter + (ColumnIntHWM, "hwm_int", 30, 30), # step == per_iter + (ColumnDateHWM, "hwm_date", timedelta(days=20), 20), # per_iter value is calculated to cover the step value + (ColumnDateTimeHWM, "hwm_datetime", timedelta(weeks=2), 20), # same ], ) @pytest.mark.parametrize( @@ -519,7 +536,7 @@ def test_postgres_strategy_incremental_batch( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() postgres = Postgres( host=processing.host, @@ -529,9 +546,13 @@ def test_postgres_strategy_incremental_batch( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + hwm_name = secrets.token_hex(5) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between # 0..100 @@ -555,7 +576,7 @@ def test_postgres_strategy_incremental_batch( ) # hwm is not in the store - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None # fill up hwm storage with last value, e.g. 
100 first_df = None @@ -571,14 +592,14 @@ def test_postgres_strategy_incremental_batch( # same behavior as SnapshotBatchStrategy, no rows skipped if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=first_df, other_frame=first_span) # hwm is set - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max @@ -595,7 +616,7 @@ def test_postgres_strategy_incremental_batch( second_df = None with IncrementalBatchStrategy(step=step) as batches: for _ in batches: - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert first_span_max <= hwm.value <= second_span_max @@ -608,19 +629,19 @@ def test_postgres_strategy_incremental_batch( else: second_df = second_df.union(next_df) - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert first_span_max <= hwm.value <= second_span_max - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed @@ -656,7 +677,11 @@ def test_postgres_strategy_incremental_batch_stop( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) # there is a span 0..100 span_begin = 0 @@ -731,7 +756,7 @@ def test_postgres_strategy_incremental_batch_offset( source=prepare_schema_table.full_name, # the error is raised if hwm_expr is set, and hwm_column in the columns list # but if columns list is not passed, this is not an error - hwm_column=(hwm_column, hwm_column), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), ) # there are 2 spans with a gap between @@ -780,55 +805,15 @@ def test_postgres_strategy_incremental_batch_offset( total_span = pandas.concat([first_span, second_span], ignore_index=True) if full: - total_df = total_df.sort(total_df.id_int.asc()) # all the data has been read - processing.assert_equal_df(df=total_df, other_frame=total_span) + processing.assert_equal_df(df=total_df, other_frame=total_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=total_df, other_frame=total_span) -@pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, step, func", - [ - ( - "hwm_int", - "hwm1_int", - "text_string::int", - IntHWM, - 10, - str, - ), - ( - "hwm_date", - "hwm1_date", - "text_string::date", - DateHWM, - 
timedelta(days=10), - lambda x: x.isoformat(), - ), - ( - "hwm_datetime", - "HWM1_DATETIME", - "text_string::timestamp", - DateTimeHWM, - timedelta(hours=100), - lambda x: x.isoformat(), - ), - ], -) -def test_postgres_strategy_incremental_batch_with_hwm_expr( - spark, - processing, - prepare_schema_table, - hwm_source, - hwm_column, - hwm_expr, - hwm_type, - step, - func, -): +def test_postgres_strategy_incremental_batch_nothing_to_read(spark, processing, prepare_schema_table): postgres = Postgres( host=processing.host, port=processing.port, @@ -838,37 +823,54 @@ def test_postgres_strategy_incremental_batch_with_hwm_expr( spark=spark, ) + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + step = 10 + reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - # the error is raised if hwm_expr is set, and hwm_column in the columns list - # but here hwm_column is not in the columns list, no error - columns=["*"], - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), ) - # there are 2 spans with a gap between span_gap = 10 span_length = 50 - # 0..100 + # there are 2 spans with a gap between + + # 0..50 first_span_begin = 0 first_span_end = first_span_begin + span_length - # 110..210 + # 60..110 second_span_begin = first_span_end + span_gap second_span_end = second_span_begin + span_length first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) - first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column.lower()] = first_span[hwm_source] + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + df = None + counter = 0 + with IncrementalBatchStrategy(step=step) as batches: + for _ in batches: + next_df = reader.run() + counter += 1 + + if df is None: + df = next_df + else: + df = df.union(next_df) - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column.lower()] = second_span[hwm_source] + # exactly 1 batch with empty result + assert counter == 1 + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None # insert first span processing.insert_data( @@ -877,19 +879,43 @@ def test_postgres_strategy_incremental_batch_with_hwm_expr( values=first_span, ) - # incremental run - first_df = None + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + df = None with IncrementalBatchStrategy(step=step) as batches: for _ in batches: next_df = reader.run() - if first_df is None: - first_df = next_df + if df is None: + df = next_df else: - first_df = first_df.union(next_df) + df = df.union(next_df) - # all the data has been read - processing.assert_equal_df(df=first_df.orderBy("id_int"), other_frame=first_span_with_hwm) + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + df = None + counter = 0 + with IncrementalBatchStrategy(step=step) as batches: + for _ in batches: + next_df = reader.run() + counter += 1 + + if df is None: + df = next_df + else: + df = 
df.union(next_df) + + # exactly 1 batch with empty result + assert counter == 1 + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max # insert second span processing.insert_data( @@ -898,20 +924,23 @@ def test_postgres_strategy_incremental_batch_with_hwm_expr( values=second_span, ) - second_df = None + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + df = None with IncrementalBatchStrategy(step=step) as batches: for _ in batches: next_df = reader.run() + counter += 1 - if second_df is None: - second_df = next_df + if df is None: + df = next_df else: - second_df = second_df.union(next_df) + df = df.union(next_df) - if issubclass(hwm_type, IntHWM): - # only changed data has been read - processing.assert_equal_df(df=second_df.orderBy("id_int"), other_frame=second_span_with_hwm) - else: - # date and datetime values have a random part - # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max diff --git a/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot.py b/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py similarity index 65% rename from tests/tests_integration/tests_strategy_integration/test_strategy_snapshot.py rename to tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py index e5a43aad4..e1984ab05 100644 --- a/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot.py +++ b/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py @@ -4,20 +4,19 @@ from datetime import date, datetime, timedelta import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM try: import pandas - from tests.util.to_pandas import to_pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) + +from etl_entities.hwm import ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Postgres from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager -from onetl.strategy import IncrementalStrategy, SnapshotBatchStrategy, SnapshotStrategy +from onetl.strategy import SnapshotBatchStrategy, SnapshotStrategy pytestmark = pytest.mark.postgres @@ -31,48 +30,22 @@ def test_postgres_strategy_snapshot_hwm_column_present(spark, processing, prepar database=processing.database, spark=spark, ) - column = secrets.token_hex() - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=column) - - with SnapshotStrategy(): - with pytest.raises(ValueError, match="SnapshotStrategy cannot be used with `hwm_column` passed into DBReader"): - reader.run() - - -def test_postgres_strategy_snapshot(spark, processing, prepare_schema_table): - postgres = Postgres( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name) - - # there is a span 0..50 - span_begin = 0 - span_end = 100 - span = 
processing.create_pandas_df(min_id=span_begin, max_id=span_end) - - # insert span - processing.insert_data( - schema=prepare_schema_table.schema, - table=prepare_schema_table.table, - values=span, + column = secrets.token_hex(5) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=column), ) - # snapshot run + error_message = "DBReader(hwm=...) cannot be used with SnapshotStrategy" with SnapshotStrategy(): - total_df = reader.run() - - processing.assert_equal_df(df=total_df, other_frame=span) + with pytest.raises(RuntimeError, match=re.escape(error_message)): + reader.run() @pytest.mark.parametrize( "hwm_column, step", [ - ("hwm_int", 1.5), ("hwm_int", "abc"), ("hwm_int", timedelta(hours=10)), ("hwm_date", 10), @@ -83,7 +56,7 @@ def test_postgres_strategy_snapshot(spark, processing, prepare_schema_table): ("hwm_datetime", "abc"), ], ) -def test_postgres_strategy_snapshot_batch_step_wrong_type( +def test_postgres_strategy_snapshot_batch_wrong_step_type( spark, processing, prepare_schema_table, @@ -99,7 +72,11 @@ def test_postgres_strategy_snapshot_batch_step_wrong_type( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) with pytest.raises((TypeError, ValueError)): with SnapshotBatchStrategy(step=step) as part: @@ -132,12 +109,16 @@ def test_postgres_strategy_snapshot_batch_step_negative( spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) error_msg = "HWM value is not increasing, please check options passed to SnapshotBatchStrategy" with pytest.raises(ValueError, match=error_msg): - with SnapshotBatchStrategy(step=step) as part: - for _ in part: + with SnapshotBatchStrategy(step=step) as batches: + for _ in batches: reader.run() @@ -169,7 +150,7 @@ def test_postgres_strategy_snapshot_batch_step_too_small( reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), ) error_msg = f"step={step!r} parameter of SnapshotBatchStrategy leads to generating too many iterations" @@ -196,10 +177,11 @@ def test_postgres_strategy_snapshot_batch_outside_loop( reader = DBReader( connection=postgres, source=load_table_data.full_name, - hwm_column="hwm_int", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), ) - with pytest.raises(RuntimeError): + error_message = "Invalid SnapshotBatchStrategy usage!" 
+ with pytest.raises(RuntimeError, match=re.escape(error_message)): with SnapshotBatchStrategy(step=1): reader.run() @@ -217,81 +199,43 @@ def test_postgres_strategy_snapshot_batch_hwm_set_twice(spark, processing, load_ step = 20 table1 = load_table_data.full_name - table2 = f"{secrets.token_hex()}.{secrets.token_hex()}" + table2 = f"{secrets.token_hex(5)}.{secrets.token_hex(5)}" - hwm_column1 = "hwm_int" - hwm_column2 = "hwm_datetime" - - reader1 = DBReader(connection=postgres, table=table1, hwm_column=hwm_column1) - reader2 = DBReader(connection=postgres, table=table2, hwm_column=hwm_column1) - reader3 = DBReader(connection=postgres, table=table1, hwm_column=hwm_column2) + reader1 = DBReader( + connection=postgres, + table=table1, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) + reader2 = DBReader( + connection=postgres, + table=table1, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) + reader3 = DBReader( + connection=postgres, + table=table2, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) with SnapshotBatchStrategy(step=step) as batches: for _ in batches: reader1.run() - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Detected wrong SnapshotBatchStrategy usage.", + ): reader2.run() - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Detected wrong SnapshotBatchStrategy usage.", + ): reader3.run() break -def test_postgres_strategy_snapshot_batch_unknown_hwm_column( - spark, - processing, - prepare_schema_table, -): - postgres = Postgres( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - - reader = DBReader( - connection=postgres, - source=prepare_schema_table.full_name, - hwm_column="unknown_column", # there is no such column in a table - ) - - with pytest.raises(Exception): - with SnapshotBatchStrategy(step=1) as batches: - for _ in batches: - reader.run() - - -def test_postgres_strategy_snapshot_batch_duplicated_hwm_column( - spark, - processing, - prepare_schema_table, -): - postgres = Postgres( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - - reader = DBReader( - connection=postgres, - source=prepare_schema_table.full_name, - columns=["id_int AS hwm_int"], # previous HWM cast implementation is not supported anymore - hwm_column="hwm_int", - ) - - with pytest.raises(Exception): - with SnapshotBatchStrategy(step=1) as batches: - for _ in batches: - reader.run() - - def test_postgres_strategy_snapshot_batch_where(spark, processing, prepare_schema_table): postgres = Postgres( host=processing.host, @@ -306,7 +250,7 @@ def test_postgres_strategy_snapshot_batch_where(spark, processing, prepare_schem connection=postgres, source=prepare_schema_table.full_name, where="float_value < 50 OR float_value = 50.50", - hwm_column="hwm_int", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), ) # there is a span 0..100 @@ -331,16 +275,24 @@ def test_postgres_strategy_snapshot_batch_where(spark, processing, prepare_schem else: df = df.union(next_df) - processing.assert_equal_df(df=df, other_frame=span[:51]) + processing.assert_equal_df(df=df, other_frame=span[:51], order_by="id_int") @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column, step, per_iter", + "hwm_column, step, 
per_iter", [ - (IntHWM, "hwm_int", 10, 11), # yes, 11, ids are 0..10, and the first row is included in snapshot strategy - (DateHWM, "hwm_date", timedelta(days=4), 30), # per_iter value is calculated to cover the step value - (DateTimeHWM, "hwm_datetime", timedelta(hours=100), 30), # same + ( + "hwm_int", + 10, + 11, + ), # yes, 11, ids are 0..10, and the first row is included in snapshot strategy + ( + "hwm_date", + timedelta(days=4), + 30, + ), # per_iter value is calculated to cover the step value + ("hwm_datetime", timedelta(hours=100), 30), # same ], ) @pytest.mark.parametrize( @@ -359,14 +311,14 @@ def test_postgres_strategy_snapshot_batch( spark, processing, prepare_schema_table, - hwm_type, hwm_column, step, per_iter, span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) postgres = Postgres( host=processing.host, @@ -376,12 +328,14 @@ def test_postgres_strategy_snapshot_batch( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) - - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # hwm is not in the store - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None # there are 2 spans with a gap between # 0..100 @@ -413,7 +367,7 @@ def test_postgres_strategy_snapshot_batch( with SnapshotBatchStrategy(step=step) as batches: for _ in batches: # no hwm saves on each iteration - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None next_df = reader.run() assert next_df.count() <= per_iter @@ -423,35 +377,27 @@ def test_postgres_strategy_snapshot_batch( else: total_df = total_df.union(next_df) - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None # no hwm saves after exiting the context - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None # all the rows will be read total_span = pandas.concat([first_span, second_span], ignore_index=True) - total_df = total_df.sort(total_df.id_int.asc()) - processing.assert_equal_df(df=total_df, other_frame=total_span) + processing.assert_equal_df(df=total_df, other_frame=total_span, order_by="id_int") -@pytest.mark.parametrize( - "hwm_column, step", - [ - ("hwm_int", 10), # yes, 11, ids are 0..10, and the first row is included in snapshot strategy - ("hwm_date", timedelta(days=4)), - ("hwm_datetime", timedelta(hours=100)), - ], -) def test_postgres_strategy_snapshot_batch_ignores_hwm_value( spark, processing, prepare_schema_table, - hwm_column, - step, ): - span_length = 10 - span_gap = 1 + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "id_int" + hwm = ColumnIntHWM(name=hwm_name, expression=hwm_column) + step = 10 postgres = Postgres( host=processing.host, @@ -464,38 +410,22 @@ def test_postgres_strategy_snapshot_batch_ignores_hwm_value( reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - columns=[hwm_column, "*"], - hwm_column=hwm_column, + hwm=hwm, ) - # there are 2 spans with a gap between - # 0..100 - first_span_begin = 0 - first_span_end = span_length - first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) - - # 150..200 - second_span_begin = first_span_end + 
span_gap - second_span_end = second_span_begin + span_length - second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + span = processing.create_pandas_df() # insert first span processing.insert_data( schema=prepare_schema_table.schema, table=prepare_schema_table.table, - values=first_span, + values=span, ) - # init hwm with 100 value - with IncrementalStrategy(): - reader.run() - - # insert second span - processing.insert_data( - schema=prepare_schema_table.schema, - table=prepare_schema_table.table, - values=second_span, - ) + # set HWM value in HWM store + first_span_max = span[hwm_column].max() + fake_hwm = hwm.copy().set_value(first_span_max) + store.set_hwm(fake_hwm) # snapshot run total_df = None @@ -508,12 +438,11 @@ def test_postgres_strategy_snapshot_batch_ignores_hwm_value( else: total_df = total_df.union(next_df) - # init hwm value will be ignored - # all the rows will be read - total_span = pandas.concat([first_span, second_span], ignore_index=True) + # all the rows are read, HWM store is completely ignored + processing.assert_equal_df(df=total_df, other_frame=span, order_by="id_int") - total_df = total_df.sort(total_df.id_int.asc()) - processing.assert_equal_df(df=total_df, other_frame=total_span) + # HWM in hwm store is left intact + assert store.get_hwm(hwm_name) == fake_hwm @@ -528,7 +457,13 @@ def test_postgres_strategy_snapshot_batch_ignores_hwm_value( ) @pytest.mark.parametrize("span_length", [100, 40, 5]) def test_postgres_strategy_snapshot_batch_stop( - spark, processing, prepare_schema_table, hwm_column, step, stop, span_length # noqa: C812 + spark, + processing, + prepare_schema_table, + hwm_column, + step, + stop, + span_length, ): postgres = Postgres( host=processing.host, @@ -538,7 +473,11 @@ def test_postgres_strategy_snapshot_batch_stop( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) # there is a span 0..100 span_begin = 0 @@ -563,18 +502,12 @@ def test_postgres_strategy_snapshot_batch_stop( else: total_df = total_df.union(next_df) - total_pandas_df = processing.fix_pandas_df(to_pandas(total_df)) - - # only a small part of input data has been read - # so instead of checking the whole dataframe a partial comparison should be performed - for column in total_pandas_df.columns: - total_pandas_df[column].isin(span[column]).all() - # check that stop clause working as expected - assert (total_pandas_df[hwm_column] <= stop).all() + total_span = span[span[hwm_column] <= stop] + processing.assert_equal_df(df=total_df, other_frame=total_span, order_by="id_int") -def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, prepare_schema_table): # noqa: C812 +def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, prepare_schema_table): hwm_column = "hwm_int" postgres = Postgres( host=processing.host, @@ -584,7 +517,11 @@ def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, pr database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), 
expression=hwm_column), + ) step = 10 @@ -609,10 +546,6 @@ def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, pr values=first_span, ) - # init hwm with 100 value - with IncrementalStrategy(): - reader.run() - # insert second span processing.insert_data( schema=prepare_schema_table.schema, @@ -631,7 +564,7 @@ def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, pr first_df = first_df.union(reader.run()) raise_counter += step - # raise exception somethere in the middle of the read process + # raise exception somewhere in the middle of the read process if raise_counter >= span_gap + (span_length // 2): raise ValueError("some error") @@ -646,49 +579,12 @@ def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, pr else: total_df = total_df.union(next_df) - # all the rows will be read + # all the rows are read total_span = pandas.concat([first_span, second_span], ignore_index=True) - total_df = total_df.sort(total_df.id_int.asc()) + processing.assert_equal_df(df=total_df, other_frame=total_span, order_by="id_int") - processing.assert_equal_df(df=total_df, other_frame=total_span) - -@pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, step, func", - [ - ( - "hwm_int", - "hwm1_int", - "text_string::int", - 10, - str, - ), - ( - "hwm_date", - "hwm1_date", - "text_string::date", - timedelta(days=10), - lambda x: x.isoformat(), - ), - ( - "hwm_datetime", - "HWM1_DATETIME", - "text_string::timestamp", - timedelta(hours=100), - lambda x: x.isoformat(), - ), - ], -) -def test_postgres_strategy_snapshot_batch_with_hwm_expr( - spark, - processing, - prepare_schema_table, - hwm_source, - hwm_column, - hwm_expr, - step, - func, -): +def test_postgres_strategy_snapshot_batch_nothing_to_read(spark, processing, prepare_schema_table): postgres = Postgres( host=processing.host, port=processing.port, @@ -698,38 +594,105 @@ spark=spark, ) + hwm_name = secrets.token_hex(5) + step = 10 + reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - columns=processing.column_names, - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression="hwm_int"), ) - # there is a span 0..100 - span_begin = 0 - span_end = 100 - span = processing.create_pandas_df(min_id=span_begin, max_id=span_end) + span_gap = 10 + span_length = 50 - span["text_string"] = span[hwm_source].apply(func) - span_with_hwm = span.copy() - span_with_hwm[hwm_column.lower()] = span[hwm_source] + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + df = None + counter = 0 + with SnapshotBatchStrategy(step=step) as batches: + for _ in batches: + next_df = reader.run() + counter += 1 + + if df is None: + df = next_df + else: + df = df.union(next_df) + + # exactly 1 batch with empty result + assert counter == 1 + assert not df.count() # insert first span processing.insert_data( schema=prepare_schema_table.schema, table=prepare_schema_table.table, - values=span, + values=first_span, ) - total_df = None + # .run() is not called - dataframe still empty (unlike 
SnapshotStrategy) + assert not df.count() + + # read data + df = None with SnapshotBatchStrategy(step=step) as batches: for _ in batches: next_df = reader.run() - if total_df is None: - total_df = next_df + if df is None: + df = next_df else: - total_df = total_df.union(next_df) + df = df.union(next_df) + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data again - same output, HWM is not saved to HWM store + df = None + with SnapshotBatchStrategy(step=step) as batches: + for _ in batches: + next_df = reader.run() - # all the data has been read - processing.assert_equal_df(df=total_df.orderBy("id_int"), other_frame=span_with_hwm) + if df is None: + df = next_df + else: + df = df.union(next_df) + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe contains only old data (unlike SnapshotStrategy) + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + + # read data + df = None + with SnapshotBatchStrategy(step=step) as batches: + for _ in batches: + next_df = reader.run() + counter += 1 + + if df is None: + df = next_df + else: + df = df.union(next_df) + + total_span = pandas.concat([first_span, second_span], ignore_index=True) + processing.assert_equal_df(df=df, other_frame=total_span, order_by="id_int") diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py index 8ffec9e7b..aad831c01 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py @@ -1,3 +1,5 @@ +import secrets + import pytest from onetl.connection import Kafka @@ -27,7 +29,7 @@ def test_strategy_kafka_with_batch_strategy_error(strategy, spark): spark=spark, ), table="topic", - hwm_column="offset", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="offset"), ) - with pytest.raises(ValueError, match="connection does not support batch strategies"): + with pytest.raises(RuntimeError): reader.run() diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py index 71763dc01..2809cdb1b 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py @@ -1,11 +1,12 @@ +import secrets from datetime import timedelta import pytest -from etl_entities import DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import MongoDB from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy 
import IncrementalBatchStrategy pytestmark = pytest.mark.mongodb @@ -37,9 +38,9 @@ def df_schema(): @pytest.mark.parametrize( "hwm_type, hwm_column, step, per_iter", [ - (IntHWM, "hwm_int", 20, 30), # step < per_iter - (IntHWM, "hwm_int", 30, 30), # step == per_iter - (DateTimeHWM, "hwm_datetime", timedelta(weeks=2), 20), # same + (ColumnIntHWM, "hwm_int", 20, 30), # step < per_iter + (ColumnIntHWM, "hwm_int", 30, 30), # step == per_iter + (ColumnDateTimeHWM, "hwm_datetime", timedelta(weeks=2), 20), # same ], ) @pytest.mark.parametrize( @@ -66,7 +67,8 @@ def test_mongodb_strategy_incremental_batch( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) mongodb = MongoDB( host=processing.host, @@ -76,9 +78,12 @@ def test_mongodb_strategy_incremental_batch( database=processing.database, spark=spark, ) - reader = DBReader(connection=mongodb, table=prepare_schema_table.table, hwm_column=hwm_column, df_schema=df_schema) - - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=mongodb, + table=prepare_schema_table.table, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + df_schema=df_schema, + ) # there are 2 spans with a gap between # 0..100 @@ -102,7 +107,7 @@ def test_mongodb_strategy_incremental_batch( ) # hwm is not in the store - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None # fill up hwm storage with last value, e.g. 100 first_df = None @@ -118,14 +123,14 @@ def test_mongodb_strategy_incremental_batch( # same behavior as SnapshotBatchStrategy, no rows skipped if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="_id") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=first_df, other_frame=first_span) # hwm is set - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max @@ -142,7 +147,7 @@ def test_mongodb_strategy_incremental_batch( second_df = None with IncrementalBatchStrategy(step=step) as batches: for _ in batches: - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert first_span_max <= hwm.value <= second_span_max @@ -155,19 +160,19 @@ def test_mongodb_strategy_incremental_batch( else: second_df = second_df.union(next_df) - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert first_span_max <= hwm.value <= second_span_max - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="_id") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed @@ -188,7 +193,7 @@ def test_mongodb_strategy_incremental_batch_where(spark, processing, prepare_sch connection=mongodb, table=prepare_schema_table.table, 
where={"$or": [{"float_value": {"$lt": 51}}, {"float_value": {"$gt": 101, "$lt": 120}}]}, - hwm_column="hwm_int", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), df_schema=df_schema, ) @@ -219,7 +224,7 @@ def test_mongodb_strategy_incremental_batch_where(spark, processing, prepare_sch else: first_df = first_df.union(next_df) - processing.assert_equal_df(df=first_df, other_frame=first_span[:51]) + processing.assert_equal_df(df=first_df, other_frame=first_span[:51], order_by="_id") # insert second span processing.insert_data( @@ -238,4 +243,4 @@ def test_mongodb_strategy_incremental_batch_where(spark, processing, prepare_sch else: second_df = second_df.union(next_df) - processing.assert_equal_df(df=second_df, other_frame=second_span[:19]) + processing.assert_equal_df(df=second_df, other_frame=second_span[:19], order_by="_id") diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py index 87de9fd46..9a628c3e6 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py @@ -1,9 +1,11 @@ +import secrets + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Clickhouse from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalStrategy pytestmark = pytest.mark.clickhouse @@ -13,9 +15,9 @@ @pytest.mark.parametrize( "hwm_type, hwm_column", [ - (IntHWM, "hwm_int"), - (DateHWM, "hwm_date"), - (DateTimeHWM, "hwm_datetime"), + (ColumnIntHWM, "hwm_int"), + (ColumnDateHWM, "hwm_date"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -34,7 +36,8 @@ def test_clickhouse_strategy_incremental( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) clickhouse = Clickhouse( host=processing.host, @@ -43,9 +46,11 @@ def test_clickhouse_strategy_incremental( password=processing.password, spark=spark, ) - reader = DBReader(connection=clickhouse, source=prepare_schema_table.full_name, hwm_column=hwm_column) - - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -74,13 +79,10 @@ def test_clickhouse_strategy_incremental( with IncrementalStrategy(): first_df = reader.run() - hwm = store.get(hwm.qualified_name) - assert hwm is not None - assert isinstance(hwm, hwm_type) - assert hwm.value == first_span_max + assert store.get_hwm(hwm_name).value == first_span_max # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -92,30 +94,133 @@ def test_clickhouse_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() - assert 
store.get(hwm.qualified_name).value == second_span_max + assert store.get_hwm(hwm_name).value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_clickhouse_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + ) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + @pytest.mark.flaky( # sometimes test fails with vague error on Spark side, e.g. 
`An error occurred while calling o58.version` only_rerun="py4j.protocol.Py4JError", ) @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - # Fail if HWM is Numeric, or Decimal with fractional part, or string - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", Exception, "Missing columns"), ], ) -def test_clickhouse_strategy_incremental_wrong_type(spark, processing, prepare_schema_table, hwm_column): +def test_clickhouse_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): clickhouse = Clickhouse( host=processing.host, port=processing.port, @@ -123,7 +228,11 @@ def test_clickhouse_strategy_incremental_wrong_type(spark, processing, prepare_s password=processing.password, spark=spark, ) - reader = DBReader(connection=clickhouse, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -134,34 +243,67 @@ def test_clickhouse_strategy_incremental_wrong_type(spark, processing, prepare_s values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() +def test_clickhouse_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + ) + + reader = DBReader( + connection=clickhouse, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=secrets.token_hex(5), expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with pytest.raises(Exception, match="There is no supertype for types String, UInt8 because"): + with IncrementalStrategy(): + reader.run() + + @pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ ( "hwm_int", - "hwm1_int", "CAST(text_string AS Integer)", - IntHWM, + ColumnIntHWM, str, ), ( "hwm_date", - "hwm1_date", "CAST(text_string AS Date)", - DateHWM, + ColumnDateHWM, lambda x: x.isoformat(), ), ( "hwm_datetime", - "HWM1_DATETIME", "CAST(text_string AS DateTime)", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.isoformat(), ), ], @@ -171,7 +313,6 @@ def test_clickhouse_strategy_incremental_with_hwm_expr( processing, prepare_schema_table, hwm_source, - hwm_column, hwm_expr, hwm_type, func, @@ -187,7 +328,7 @@ def test_clickhouse_strategy_incremental_with_hwm_expr( reader = DBReader( connection=clickhouse, source=prepare_schema_table.full_name, - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -206,12 +347,7 @@ def test_clickhouse_strategy_incremental_with_hwm_expr( second_span = 
processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column] = second_span[hwm_source] # insert first span processing.insert_data( @@ -225,7 +361,7 @@ def test_clickhouse_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -237,10 +373,10 @@ def test_clickhouse_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py index 0d1167864..3497266fa 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py @@ -3,12 +3,15 @@ from datetime import timedelta import pytest +from etl_entities.hwm import ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager + +from tests.util.rand import rand_str try: import pandas except ImportError: - # pandas can be missing if someone runs tests for file connections only - pass + pytest.skip("Missing pandas", allow_module_level=True) from onetl.connection import Postgres from onetl.db import DBReader @@ -41,23 +44,35 @@ def test_postgres_strategy_incremental_different_hwm_type_in_store( spark=spark, ) - reader = DBReader(connection=postgres, source=load_table_data.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) with IncrementalStrategy(): reader.run() - processing.drop_table(schema=load_table_data.schema, table=load_table_data.table) - + # change table schema new_fields = {column_name: processing.get_column_type(column_name) for column_name in processing.column_names} new_fields[hwm_column] = new_type + + processing.drop_table(schema=load_table_data.schema, table=load_table_data.table) processing.create_table(schema=load_table_data.schema, table=load_table_data.table, fields=new_fields) - with pytest.raises(ValueError): + with pytest.raises(TypeError, match="Cannot cast HWM of type .* as .*"): with IncrementalStrategy(): reader.run() -def test_postgres_strategy_incremental_hwm_set_twice(spark, 
processing, load_table_data): +def test_postgres_strategy_incremental_different_hwm_source_in_store( + spark, + processing, + load_table_data, +): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + postgres = Postgres( host=processing.host, port=processing.port, @@ -67,31 +82,31 @@ def test_postgres_strategy_incremental_hwm_set_twice(spark, processing, load_tab spark=spark, ) - table1 = load_table_data.full_name - table2 = f"{secrets.token_hex()}.{secrets.token_hex()}" - - hwm_column1 = "hwm_int" - hwm_column2 = "hwm_datetime" + old_hwm = ColumnIntHWM(name=hwm_name, source=load_table_data.full_name, expression="hwm_int", description="abc") + # change HWM entity in HWM store + fake_hwm = old_hwm.copy(update={"entity": rand_str()}) + hwm_store.set_hwm(fake_hwm) - reader1 = DBReader(connection=postgres, table=table1, hwm_column=hwm_column1) - reader2 = DBReader(connection=postgres, table=table2, hwm_column=hwm_column1) - reader3 = DBReader(connection=postgres, table=table1, hwm_column=hwm_column2) - - with IncrementalStrategy(): - reader1.run() - - with pytest.raises(ValueError): - reader2.run() - - with pytest.raises(ValueError): - reader3.run() + reader = DBReader( + connection=postgres, + source=load_table_data.full_name, + hwm=old_hwm, + ) + with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"): + with IncrementalStrategy(): + reader.run() -def test_postgres_strategy_incremental_unknown_hwm_column( +@pytest.mark.parametrize("attribute", ["expression", "description"]) +def test_postgres_strategy_incremental_different_hwm_optional_attribute_in_store( spark, processing, - prepare_schema_table, + load_table_data, + attribute, ): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + postgres = Postgres( host=processing.host, port=processing.port, @@ -101,22 +116,27 @@ def test_postgres_strategy_incremental_unknown_hwm_column( spark=spark, ) + old_hwm = ColumnIntHWM(name=hwm_name, source=load_table_data.full_name, expression="hwm_int", description="abc") + + # change attribute value in HWM store + fake_hwm = old_hwm.copy(update={attribute: rand_str()}) + hwm_store.set_hwm(fake_hwm) + reader = DBReader( connection=postgres, - source=prepare_schema_table.full_name, - hwm_column="unknown_column", # there is no such column in a table + source=load_table_data.full_name, + hwm=old_hwm, ) - - with pytest.raises(Exception): + with pytest.warns(UserWarning, match=f"Detected HWM with different `{attribute}` attribute"): with IncrementalStrategy(): reader.run() + # attributes from DBReader have higher priority, except value + new_hwm = hwm_store.get_hwm(name=hwm_name) + assert new_hwm.dict(exclude={"value", "modified_time"}) == old_hwm.dict(exclude={"value", "modified_time"}) -def test_postgres_strategy_incremental_duplicated_hwm_column( - spark, - processing, - prepare_schema_table, -): + +def test_postgres_strategy_incremental_hwm_set_twice(spark, processing, load_table_data): postgres = Postgres( host=processing.host, port=processing.port, @@ -126,16 +146,39 @@ def test_postgres_strategy_incremental_duplicated_hwm_column( spark=spark, ) - reader = DBReader( + table1 = load_table_data.full_name + table2 = f"{secrets.token_hex(5)}.{secrets.token_hex(5)}" + + reader1 = DBReader( connection=postgres, - source=prepare_schema_table.full_name, - columns=["id_int AS hwm_int"], # previous HWM cast implementation is not supported anymore - hwm_column="hwm_int", + table=table1, + 
hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) + reader2 = DBReader( + connection=postgres, + table=table1, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), + ) + reader3 = DBReader( + connection=postgres, + table=table2, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), ) - with pytest.raises(Exception): - with IncrementalStrategy(): - reader.run() + with IncrementalStrategy(): + reader1.run() + + with pytest.raises( + ValueError, + match="Detected wrong IncrementalStrategy usage.", + ): + reader2.run() + + with pytest.raises( + ValueError, + match="Detected wrong IncrementalStrategy usage.", + ): + reader3.run() def test_postgres_strategy_incremental_where(spark, processing, prepare_schema_table): @@ -154,7 +197,7 @@ def test_postgres_strategy_incremental_where(spark, processing, prepare_schema_t connection=postgres, source=prepare_schema_table.full_name, where="id_int < 1000 OR id_int = 1000", - hwm_column="hwm_int", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="hwm_int"), ) # there are 2 spans with a gap between @@ -184,7 +227,7 @@ def test_postgres_strategy_incremental_where(spark, processing, prepare_schema_t first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -197,7 +240,7 @@ def test_postgres_strategy_incremental_where(spark, processing, prepare_schema_t second_df = reader.run() # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") @pytest.mark.parametrize( @@ -229,8 +272,7 @@ def test_postgres_strategy_incremental_offset( reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - columns=["*", hwm_column], - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), ) # there are 2 spans with a gap between @@ -269,10 +311,10 @@ def test_postgres_strategy_incremental_offset( next_df = reader.run() total_span = pandas.concat([second_span, first_span], ignore_index=True) - processing.assert_equal_df(df=next_df, other_frame=total_span) + processing.assert_equal_df(df=next_df, other_frame=total_span, order_by="id_int") -def test_postgres_strategy_incremental_handle_exception(spark, processing, prepare_schema_table): # noqa: C812 +def test_postgres_strategy_incremental_handle_exception(spark, processing, prepare_schema_table): postgres = Postgres( host=processing.host, port=processing.port, @@ -282,8 +324,14 @@ def test_postgres_strategy_incremental_handle_exception(spark, processing, prepa spark=spark, ) + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) span_gap = 10 span_length = 50 @@ -327,11 +375,13 @@ def test_postgres_strategy_incremental_handle_exception(spark, processing, prepa # and then process is retried with IncrementalStrategy(): - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) - + reader = DBReader( + connection=postgres, + 
source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) second_df = reader.run() # all the data from the second span has been read # like there was no exception - second_df = second_df.sort(second_df.id_int.asc()) - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py index 1a3962714..2a22d74cc 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py @@ -1,58 +1,58 @@ import contextlib import secrets -from etl_entities import FileListHWM, RemoteFolder -from etl_entities.instance import RelativePath +import pytest +from etl_entities.hwm import ColumnIntHWM, FileListHWM +from etl_entities.hwm_store import HWMStoreStackManager +from etl_entities.instance import AbsolutePath from onetl.file import FileDownloader -from onetl.hwm.store import YAMLHWMStore from onetl.strategy import IncrementalStrategy +from tests.util.rand import rand_str -def test_file_downloader_increment( +def test_file_downloader_incremental_strategy( file_connection_with_path_and_files, tmp_path_factory, tmp_path, ): + hwm_store = HWMStoreStackManager.get_current() + file_connection, remote_path, uploaded_files = file_connection_with_path_and_files - hwm_store = YAMLHWMStore(path=tmp_path_factory.mktemp("hwmstore")) # noqa: S306 local_path = tmp_path_factory.mktemp("local_path") + hwm_name = secrets.token_hex(5) + downloader = FileDownloader( connection=file_connection, source_path=remote_path, local_path=local_path, - hwm_type="file_list", + hwm=FileListHWM(name=hwm_name), ) # load first batch of the files - with hwm_store: - with IncrementalStrategy(): - available = downloader.view_files() - downloaded = downloader.run() + with IncrementalStrategy(): + available = downloader.view_files() + downloaded = downloader.run() - # without HWM value all the files are shown and uploaded - assert len(available) == len(downloaded.successful) == len(uploaded_files) - assert sorted(available) == sorted(uploaded_files) - - remote_file_folder = RemoteFolder(name=remote_path, instance=file_connection.instance_url) - file_hwm = FileListHWM(source=remote_file_folder) - file_hwm_name = file_hwm.qualified_name + # without HWM value all the files are shown and uploaded + assert len(available) == len(downloaded.successful) == len(uploaded_files) + assert sorted(available) == sorted(uploaded_files) - source_files = {RelativePath(file.relative_to(remote_path)) for file in uploaded_files} - assert source_files == hwm_store.get(file_hwm_name).value + source_files = {AbsolutePath(file) for file in uploaded_files} + assert source_files == hwm_store.get_hwm(hwm_name).value for _ in "first_inc", "second_inc": new_file_name = f"{secrets.token_hex(5)}.txt" tmp_file = tmp_path / new_file_name tmp_file.write_text(f"{secrets.token_hex(10)}") + target_file = remote_path / new_file_name file_connection.upload_file(tmp_file, remote_path / new_file_name) - with hwm_store: - with IncrementalStrategy(): - available = downloader.view_files() - downloaded = 
downloader.run() + with IncrementalStrategy(): + available = downloader.view_files() + downloaded = downloader.run() # without HWM value all the files are shown and uploaded assert len(available) == len(downloaded.successful) == 1 @@ -62,93 +62,217 @@ def test_file_downloader_increment( assert downloaded.missing_count == 0 assert downloaded.failed_count == 0 - source_files.add(RelativePath(new_file_name)) - assert source_files == hwm_store.get(file_hwm_name).value + source_files.add(AbsolutePath(target_file)) + assert source_files == hwm_store.get_hwm(hwm_name).value -def test_file_downloader_increment_fail( +def test_file_downloader_incremental_strategy_fail( file_connection_with_path_and_files, tmp_path_factory, tmp_path, ): + hwm_store = HWMStoreStackManager.get_current() + file_connection, remote_path, uploaded_files = file_connection_with_path_and_files - hwm_store = YAMLHWMStore(path=tmp_path_factory.mktemp("hwmstore")) local_path = tmp_path_factory.mktemp("local_path") + hwm_name = secrets.token_hex(5) + downloader = FileDownloader( connection=file_connection, source_path=remote_path, local_path=local_path, - hwm_type="file_list", + hwm=FileListHWM(name=hwm_name), ) - with hwm_store: - with IncrementalStrategy(): - available = downloader.view_files() - downloaded = downloader.run() - - # without HWM value all the files are shown and uploaded - assert len(available) == len(downloaded.successful) == len(uploaded_files) - assert sorted(available) == sorted(uploaded_files) + with IncrementalStrategy(): + available = downloader.view_files() + downloaded = downloader.run() - remote_file_folder = RemoteFolder(name=remote_path, instance=file_connection.instance_url) - file_hwm = FileListHWM(source=remote_file_folder) - file_hwm_name = file_hwm.qualified_name + # without HWM value all the files are shown and uploaded + assert len(available) == len(downloaded.successful) == len(uploaded_files) + assert sorted(available) == sorted(uploaded_files) - # HWM is updated in HWMStore - source_files = {RelativePath(file.relative_to(remote_path)) for file in uploaded_files} - assert source_files == hwm_store.get(file_hwm_name).value + # HWM is updated in HWMStore + source_files = {AbsolutePath(file) for file in uploaded_files} + assert source_files == hwm_store.get_hwm(hwm_name).value - for _ in "first_inc", "second_inc": - new_file_name = f"{secrets.token_hex(5)}.txt" - tmp_file = tmp_path / new_file_name - tmp_file.write_text(f"{secrets.token_hex(10)}") + for _ in "first_inc", "second_inc": + new_file_name = f"{secrets.token_hex(5)}.txt" + tmp_file = tmp_path / new_file_name + tmp_file.write_text(f"{secrets.token_hex(10)}") - file_connection.upload_file(tmp_file, remote_path / new_file_name) + target_file = remote_path / new_file_name + file_connection.upload_file(tmp_file, target_file) - # while loading data, a crash occurs before exiting the context manager - with contextlib.suppress(RuntimeError): - available = downloader.view_files() - downloaded = downloader.run() - # simulating a failure after download - raise RuntimeError("some exception") + # while loading data, a crash occurs before exiting the context manager + with contextlib.suppress(RuntimeError): + available = downloader.view_files() + downloaded = downloader.run() + # simulating a failure after download + raise RuntimeError("some exception") - assert len(available) == len(downloaded.successful) == 1 - assert downloaded.successful[0].name == tmp_file.name - assert downloaded.successful[0].read_text() == tmp_file.read_text() - assert 
downloaded.skipped_count == 0 - assert downloaded.missing_count == 0 - assert downloaded.failed_count == 0 + assert len(available) == len(downloaded.successful) == 1 + assert downloaded.successful[0].name == tmp_file.name + assert downloaded.successful[0].read_text() == tmp_file.read_text() + assert downloaded.skipped_count == 0 + assert downloaded.missing_count == 0 + assert downloaded.failed_count == 0 - # HWM is saved after downloading each file, not after exiting from .run - source_files.add(RelativePath(new_file_name)) - assert source_files == hwm_store.get(file_hwm_name).value + # HWM is saved at the end of `FileDownloader.run()` call`, not after exiting from strategy + source_files.add(AbsolutePath(target_file)) + assert source_files == hwm_store.get_hwm(hwm_name).value -def test_file_downloader_increment_hwm_is_ignored_for_user_input( +def test_file_downloader_incremental_strategy_hwm_is_ignored_for_user_input( file_connection_with_path_and_files, tmp_path_factory, - tmp_path, ): file_connection, remote_path, uploaded_files = file_connection_with_path_and_files - hwm_store = YAMLHWMStore(path=tmp_path_factory.mktemp("hwm_store")) local_path = tmp_path_factory.mktemp("local_path") + file_hwm_name = secrets.token_hex(5) downloader = FileDownloader( connection=file_connection, source_path=remote_path, local_path=local_path, - hwm_type="file_list", + hwm=FileListHWM(name=file_hwm_name), options=FileDownloader.Options(if_exists="replace_file"), ) - with hwm_store: - with IncrementalStrategy(): - # load first batch of the files - downloader.run() + with IncrementalStrategy(): + # load first batch of the files + downloader.run() - # download files from list - download_result = downloader.run(uploaded_files) + # download files from list + download_result = downloader.run(uploaded_files) # all the files are downloaded, HWM is ignored assert len(download_result.successful) == len(uploaded_files) + + +def test_file_downloader_incremental_strategy_different_hwm_type_in_store( + file_connection_with_path_and_files, + tmp_path_factory, +): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + file_connection, remote_path, _ = file_connection_with_path_and_files + local_path = tmp_path_factory.mktemp("local_path") + + downloader = FileDownloader( + connection=file_connection, + source_path=remote_path, + local_path=local_path, + hwm=FileListHWM(name=hwm_name), + ) + + # HWM Store contains HWM with same name, but different type + hwm_store.set_hwm(ColumnIntHWM(name=hwm_name, expression="hwm_int")) + + with pytest.raises(TypeError, match="Cannot cast HWM of type .* as .*"): + with IncrementalStrategy(): + downloader.run() + + +def test_file_downloader_incremental_strategy_different_hwm_directory_in_store( + file_connection_with_path_and_files, + tmp_path_factory, +): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + file_connection, remote_path, _ = file_connection_with_path_and_files + local_path = tmp_path_factory.mktemp("local_path") + + downloader = FileDownloader( + connection=file_connection, + source_path=remote_path, + local_path=local_path, + hwm=FileListHWM(name=hwm_name), + ) + + # HWM Store contains HWM with same name, but different directory + hwm_store.set_hwm(FileListHWM(name=hwm_name, directory=local_path)) + with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"): + with IncrementalStrategy(): + downloader.run() + + +@pytest.mark.parametrize("attribute", ["expression", 
"description"]) +def test_file_downloader_incremental_strategy_different_hwm_optional_attribute_in_store( + file_connection_with_path_and_files, + tmp_path_factory, + attribute, +): + hwm_store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + file_connection, remote_path, _ = file_connection_with_path_and_files + local_path = tmp_path_factory.mktemp("local_path") + + old_hwm = FileListHWM(name=hwm_name, directory=AbsolutePath(remote_path), expression="some", description="another") + # HWM Store contains HWM with same name, but different optional attribute + fake_hwm = old_hwm.copy(update={attribute: rand_str()}) + hwm_store.set_hwm(fake_hwm) + + downloader = FileDownloader( + connection=file_connection, + source_path=remote_path, + local_path=local_path, + hwm=old_hwm, + ) + with pytest.warns(UserWarning, match=f"Detected HWM with different `{attribute}` attribute"): + with IncrementalStrategy(): + downloader.run() + + # attributes from FileDownloader have higher priority, except value + new_hwm = hwm_store.get_hwm(name=hwm_name) + assert new_hwm.dict(exclude={"value", "modified_time"}) == old_hwm.dict(exclude={"value", "modified_time"}) + + +def test_file_downloader_incremental_strategy_hwm_set_twice( + file_connection_with_path_and_files, + tmp_path_factory, +): + file_connection, remote_path, _ = file_connection_with_path_and_files + local_path = tmp_path_factory.mktemp("local_path") + + downloader1 = FileDownloader( + connection=file_connection, + source_path=remote_path, + local_path=local_path, + hwm=FileListHWM(name=secrets.token_hex(5)), + ) + + downloader2 = FileDownloader( + connection=file_connection, + source_path=remote_path, + local_path=local_path, + hwm=FileListHWM(name=secrets.token_hex(5)), + ) + + file_connection.create_dir(remote_path / "different") + file_connection.write_text(remote_path / "different/file.txt", "abc") + downloader3 = FileDownloader( + connection=file_connection, + source_path=remote_path / "different", + local_path=local_path, + hwm=FileListHWM(name=secrets.token_hex(5)), + ) + + with IncrementalStrategy(): + downloader1.run() + + with pytest.raises( + ValueError, + match="Detected wrong IncrementalStrategy usage.", + ): + downloader2.run() + + with pytest.raises( + ValueError, + match="Detected wrong IncrementalStrategy usage.", + ): + downloader3.run() diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py index b4750cd9b..7a2be6b68 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py @@ -1,9 +1,11 @@ +import secrets + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Greenplum from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalStrategy pytestmark = pytest.mark.greenplum @@ -13,9 +15,9 @@ @pytest.mark.parametrize( "hwm_type, hwm_column", [ - (IntHWM, "hwm_int"), - (DateHWM, "hwm_date"), - (DateTimeHWM, "hwm_datetime"), + (ColumnIntHWM, "hwm_int"), + 
(ColumnDateHWM, "hwm_date"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -34,7 +36,8 @@ def test_greenplum_strategy_incremental( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) greenplum = Greenplum( host=processing.host, @@ -45,9 +48,11 @@ def test_greenplum_strategy_incremental( spark=spark, extra=processing.extra, ) - reader = DBReader(connection=greenplum, source=prepare_schema_table.full_name, hwm_column=hwm_column) - - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -76,7 +81,7 @@ def test_greenplum_strategy_incremental( with IncrementalStrategy(): first_df = reader.run() - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max @@ -94,7 +99,7 @@ def test_greenplum_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() - assert store.get(hwm.qualified_name).value == second_span_max + assert store.get_hwm(hwm_name).value == second_span_max if "int" in hwm_column: # only changed data has been read @@ -105,15 +110,121 @@ def test_greenplum_strategy_incremental( processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_greenplum_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert 
hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", Exception, "column .* does not exist"), ], ) -def test_greenplum_strategy_incremental_wrong_type(spark, processing, prepare_schema_table, hwm_column): +def test_greenplum_strategy_incremental_wrong_wm_type( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): greenplum = Greenplum( host=processing.host, port=processing.port, @@ -123,7 +234,11 @@ def test_greenplum_strategy_incremental_wrong_type(spark, processing, prepare_sc spark=spark, extra=processing.extra, ) - reader = DBReader(connection=greenplum, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -134,34 +249,78 @@ def test_greenplum_strategy_incremental_wrong_type(spark, processing, prepare_sc values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() +def test_greenplum_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + reader = DBReader( + connection=greenplum, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=hwm_name, expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # type is exactly as set by user + assert isinstance(hwm, ColumnIntHWM) + + # due to alphabetic sort min=0 and max=99 + assert hwm.value == 99 + processing.assert_equal_df(df=df, other_frame=data[data.hwm_int < 100], order_by="id_int") + + @pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ ( "hwm_int", - "hwm1_int", "cast(text_string as int)", - IntHWM, + ColumnIntHWM, str, ), ( "hwm_date", - "hwm1_date", "cast(text_string as date)", - DateHWM, + 
ColumnDateHWM, lambda x: x.isoformat(), ), ( "hwm_datetime", - "HWM1_DATETIME", "cast(text_string as timestamp)", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.isoformat(), ), ], @@ -171,7 +330,6 @@ def test_greenplum_strategy_incremental_with_hwm_expr( processing, prepare_schema_table, hwm_source, - hwm_column, hwm_expr, hwm_type, func, @@ -189,7 +347,7 @@ def test_greenplum_strategy_incremental_with_hwm_expr( reader = DBReader( connection=greenplum, source=prepare_schema_table.full_name, - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -208,12 +366,7 @@ def test_greenplum_strategy_incremental_with_hwm_expr( second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column] = second_span[hwm_source] # insert first span processing.insert_data( @@ -227,7 +380,7 @@ def test_greenplum_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm, order_by="id_int") + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -239,10 +392,10 @@ def test_greenplum_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm, order_by="id_int") + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py index d53ca3988..c576a5d8c 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py @@ -1,5 +1,8 @@ +import secrets + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Hive from onetl.db import DBReader @@ -10,11 +13,11 @@ @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_column", + "hwm_type, hwm_column", [ - "hwm_int", - "hwm_date", - "hwm_datetime", + (ColumnIntHWM, "hwm_int"), + (ColumnDateHWM, "hwm_date"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -24,9 +27,24 @@ (10, 50), ], ) -def test_hive_strategy_incremental(spark, processing, prepare_schema_table, hwm_column, span_gap, span_length): 
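# --- Editor's note, not part of the diff: a minimal sketch of the DBReader HWM API these
# --- tests are being migrated to. It assumes an active SparkSession named `spark`, the
# --- Hive cluster "rnd-dwh" used throughout this file, and a placeholder table
# --- "my_schema.my_table" with an integer column hwm_int; the HWM name is any unique string.
from onetl.connection import Hive
from onetl.db import DBReader
from onetl.strategy import IncrementalStrategy

hive = Hive(cluster="rnd-dwh", spark=spark)  # `spark` is assumed to exist
reader = DBReader(
    connection=hive,
    source="my_schema.my_table",
    # old API: hwm_column="hwm_int"
    # new API: a named HWM object; the HWM type is detected from the expression result
    hwm=DBReader.AutoDetectHWM(name="my_unique_hwm_name", expression="hwm_int"),
)

with IncrementalStrategy():
    df = reader.run()  # only rows with hwm_int above the stored HWM value are read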
+def test_hive_strategy_incremental( + spark, + processing, + prepare_schema_table, + hwm_type, + hwm_column, + span_gap, + span_length, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hive = Hive(cluster="rnd-dwh", spark=spark) - reader = DBReader(connection=hive, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -41,6 +59,9 @@ def test_hive_strategy_incremental(spark, processing, prepare_schema_table, hwm_ first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_spant_end) second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + # insert first span processing.insert_data( schema=prepare_schema_table.schema, @@ -52,8 +73,13 @@ def test_hive_strategy_incremental(spark, processing, prepare_schema_table, hwm_ with IncrementalStrategy(): first_df = reader.run() + hwm = store.get_hwm(hwm_name) + assert hwm is not None + assert isinstance(hwm, hwm_type) + assert hwm.value == first_span_max + # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -65,26 +91,130 @@ def test_hive_strategy_incremental(spark, processing, prepare_schema_table, hwm_ with IncrementalStrategy(): second_df = reader.run() + assert store.get_hwm(hwm_name).value == second_span_max + if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_hive_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): + hive = Hive(cluster="rnd-dwh", spark=spark) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = 
store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", Exception, r"column .* cannot be resolved|cannot resolve .* given input columns"), ], ) -def test_hive_strategy_incremental_wrong_type(spark, processing, prepare_schema_table, hwm_column): +def test_hive_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): hive = Hive(cluster="rnd-dwh", spark=spark) - reader = DBReader(connection=hive, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -95,22 +225,61 @@ def test_hive_strategy_incremental_wrong_type(spark, processing, prepare_schema_ values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() +def test_hive_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + hive = Hive(cluster="rnd-dwh", spark=spark) + + reader = DBReader( + connection=hive, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=hwm_name, expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # type is exactly as set by user + assert isinstance(hwm, ColumnIntHWM) + + # due to alphabetic sort min=0 and max=99 + assert hwm.value == 99 + processing.assert_equal_df(df=df, other_frame=data[data.hwm_int < 100], order_by="id_int") + + @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_column, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ - 
("hwm_int", "CAST(text_string AS INT)", "hwm1_int", IntHWM, str), - ("hwm_date", "CAST(text_string AS DATE)", "hwm1_date", DateHWM, lambda x: x.isoformat()), + ("hwm_int", "CAST(text_string AS INT)", ColumnIntHWM, str), + ("hwm_date", "CAST(text_string AS DATE)", ColumnDateHWM, lambda x: x.isoformat()), ( "hwm_datetime", "CAST(text_string AS TIMESTAMP)", - "HWM1_DATETIME", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.isoformat(), ), ], @@ -121,7 +290,6 @@ def test_hive_strategy_incremental_with_hwm_expr( prepare_schema_table, hwm_source, hwm_expr, - hwm_column, hwm_type, func, ): @@ -130,7 +298,7 @@ def test_hive_strategy_incremental_with_hwm_expr( reader = DBReader( connection=hive, source=prepare_schema_table.full_name, - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -149,12 +317,7 @@ def test_hive_strategy_incremental_with_hwm_expr( second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column] = second_span[hwm_source] # insert first span processing.insert_data( @@ -168,7 +331,7 @@ def test_hive_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -180,10 +343,10 @@ def test_hive_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py index 71528bfc8..888300a2d 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py @@ -1,9 +1,11 @@ +import secrets + import pytest -from etl_entities import DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import MongoDB from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalStrategy pytestmark = pytest.mark.mongodb @@ -35,8 +37,8 @@ def df_schema(): @pytest.mark.parametrize( "hwm_type, hwm_column", [ - (IntHWM, "hwm_int"), - (DateTimeHWM, 
"hwm_datetime"), + (ColumnIntHWM, "hwm_int"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -56,7 +58,7 @@ def test_mongodb_strategy_incremental( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() mongodb = MongoDB( host=processing.host, @@ -67,15 +69,15 @@ def test_mongodb_strategy_incremental( spark=spark, ) + hwm_name = secrets.token_hex(5) + reader = DBReader( connection=mongodb, table=prepare_schema_table.table, - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), df_schema=df_schema, ) - hwm = hwm_type(source=reader.source, column=reader.hwm_column) - # there are 2 spans with a gap between # 0..100 @@ -103,13 +105,13 @@ def test_mongodb_strategy_incremental( with IncrementalStrategy(): first_df = reader.run() - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="_id") # insert second span processing.insert_data( @@ -121,26 +123,138 @@ def test_mongodb_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() - assert store.get(hwm.qualified_name).value == second_span_max + assert store.get_hwm(hwm_name).value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="_id") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_mongodb_strategy_incremental_nothing_to_read( + spark, + processing, + df_schema, + prepare_schema_table, +): + mongodb = MongoDB( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=mongodb, + source=prepare_schema_table.table, + df_schema=df_schema, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with 
IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="_id") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="_id") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", ValueError, "not found in dataframe schema"), ], ) -def test_mongodb_strategy_incremental_wrong_type(spark, processing, prepare_schema_table, df_schema, hwm_column): +def test_mongodb_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + df_schema, + hwm_column, + exception_type, + error_message, +): mongodb = MongoDB( host=processing.host, port=processing.port, @@ -153,7 +267,7 @@ def test_mongodb_strategy_incremental_wrong_type(spark, processing, prepare_sche reader = DBReader( connection=mongodb, table=prepare_schema_table.table, - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), df_schema=df_schema, ) @@ -166,7 +280,58 @@ def test_mongodb_strategy_incremental_wrong_type(spark, processing, prepare_sche values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() + + +def test_mongodb_strategy_incremental_explicit_hwm_type( + spark, + processing, + df_schema, + prepare_schema_table, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + mongodb = MongoDB( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + reader = DBReader( + connection=mongodb, + source=prepare_schema_table.table, + df_schema=df_schema, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=hwm_name, expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # type is exactly as set by user + assert isinstance(hwm, ColumnIntHWM) + + # MongoDB does not support comparison str < int + assert not df.count() + + # but HWM is updated to max value from the source. yes, that's really weird case. 
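# --- Editor's note, not part of the diff: sketch of the "explicit HWM type" pattern used in
# --- the tests above, assuming a configured onetl connection object `connection`, a
# --- placeholder table "my_schema.my_table" whose text_string column holds integer-like
# --- strings, and an HWM store already set up by the surrounding strategy code.
from etl_entities.hwm import ColumnIntHWM
from etl_entities.hwm_store import HWMStoreStackManager
from onetl.db import DBReader
from onetl.strategy import IncrementalStrategy

reader = DBReader(
    connection=connection,  # any supported onetl DB connection is assumed here
    source="my_schema.my_table",
    # skip auto-detection: tell DBReader the expression yields integer values
    hwm=ColumnIntHWM(name="my_unique_hwm_name", expression="text_string"),
)

with IncrementalStrategy():
    df = reader.run()

hwm = HWMStoreStackManager.get_current().get_hwm(name="my_unique_hwm_name")
assert isinstance(hwm, ColumnIntHWM)  # the stored HWM keeps exactly the user-chosen type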
+ # garbage in (wrong HWM type for specific expression) - garbage out (wrong dataframe content) + assert hwm.value == 99 diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py index e2cb342ca..90447d3b3 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py @@ -1,9 +1,11 @@ +import secrets + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import MSSQL from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalStrategy pytestmark = pytest.mark.mssql @@ -13,9 +15,9 @@ @pytest.mark.parametrize( "hwm_type, hwm_column", [ - (IntHWM, "hwm_int"), - (DateHWM, "hwm_date"), - (DateTimeHWM, "hwm_datetime"), + (ColumnIntHWM, "hwm_int"), + (ColumnDateHWM, "hwm_date"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -34,7 +36,8 @@ def test_mssql_strategy_incremental( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) mssql = MSSQL( host=processing.host, @@ -45,9 +48,11 @@ def test_mssql_strategy_incremental( spark=spark, extra={"trustServerCertificate": "true"}, ) - reader = DBReader(connection=mssql, source=prepare_schema_table.full_name, hwm_column=hwm_column) - - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=mssql, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -76,13 +81,13 @@ def test_mssql_strategy_incremental( with IncrementalStrategy(): first_df = reader.run() - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -94,26 +99,132 @@ def test_mssql_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() - assert store.get(hwm.qualified_name).value == second_span_max + assert store.get_hwm(hwm_name).value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_mssql_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + + 
store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=mssql, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", Exception, "Invalid column name"), ], ) -def test_mssql_strategy_incremental_wrong_type(spark, processing, prepare_schema_table, hwm_column): +def test_mssql_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): mssql = MSSQL( host=processing.host, port=processing.port, @@ -123,7 +234,11 @@ def test_mssql_strategy_incremental_wrong_type(spark, processing, prepare_schema spark=spark, extra={"trustServerCertificate": "true"}, ) - reader = DBReader(connection=mssql, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=mssql, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -134,34 +249,78 @@ def test_mssql_strategy_incremental_wrong_type(spark, processing, 
prepare_schema values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() +def test_mssql_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "true"}, + ) + reader = DBReader( + connection=mssql, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=hwm_name, expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # type is exactly as set by user + assert isinstance(hwm, ColumnIntHWM) + + # due to alphabetic sort min=0 and max=99 + assert hwm.value == 99 + processing.assert_equal_df(df=df, other_frame=data[data.hwm_int < 100], order_by="id_int") + + @pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ ( "hwm_int", - "hwm1_int", "CAST(text_string AS int)", - IntHWM, + ColumnIntHWM, str, ), ( "hwm_date", - "hwm1_date", "CAST(text_string AS Date)", - DateHWM, + ColumnDateHWM, lambda x: x.isoformat(), ), ( "hwm_datetime", - "HWM1_DATETIME", "CAST(text_string AS datetime2)", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.isoformat(), ), ], @@ -171,7 +330,6 @@ def test_mssql_strategy_incremental_with_hwm_expr( processing, prepare_schema_table, hwm_source, - hwm_column, hwm_expr, hwm_type, func, @@ -189,7 +347,7 @@ def test_mssql_strategy_incremental_with_hwm_expr( reader = DBReader( connection=mssql, source=prepare_schema_table.full_name, - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -208,12 +366,7 @@ def test_mssql_strategy_incremental_with_hwm_expr( second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column] = second_span[hwm_source] # insert first span processing.insert_data( @@ -227,7 +380,7 @@ def test_mssql_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -239,10 +392,10 @@ def test_mssql_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm) + 
processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py index dc6e0a3b3..29bba214e 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py @@ -1,9 +1,11 @@ +import secrets + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import MySQL from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalStrategy pytestmark = pytest.mark.mysql @@ -13,9 +15,9 @@ @pytest.mark.parametrize( "hwm_type, hwm_column", [ - (IntHWM, "hwm_int"), - (DateHWM, "hwm_date"), - (DateTimeHWM, "hwm_datetime"), + (ColumnIntHWM, "hwm_int"), + (ColumnDateHWM, "hwm_date"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -34,7 +36,8 @@ def test_mysql_strategy_incremental( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) mysql = MySQL( host=processing.host, @@ -44,9 +47,11 @@ def test_mysql_strategy_incremental( database=processing.database, spark=spark, ) - reader = DBReader(connection=mysql, source=prepare_schema_table.full_name, hwm_column=hwm_column) - - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=mysql, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -75,13 +80,13 @@ def test_mysql_strategy_incremental( with IncrementalStrategy(): first_df = reader.run() - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -93,26 +98,131 @@ def test_mysql_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() - assert store.get(hwm.qualified_name).value == second_span_max + assert store.get_hwm(hwm_name).value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_mysql_strategy_incremental_nothing_to_read(spark, processing, 
prepare_schema_table): + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=mysql, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", Exception, "Unknown column"), ], ) -def test_mysql_strategy_incremental_wrong_hwm_type(spark, processing, prepare_schema_table, hwm_column): +def test_mysql_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): mysql = MySQL( host=processing.host, port=processing.port, @@ -121,7 +231,11 @@ def test_mysql_strategy_incremental_wrong_hwm_type(spark, processing, prepare_sc spark=spark, database=processing.database, ) - reader = DBReader(connection=mysql, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=mysql, + source=prepare_schema_table.full_name, + 
hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -132,34 +246,77 @@ def test_mysql_strategy_incremental_wrong_hwm_type(spark, processing, prepare_sc values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() +def test_mysql_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + database=processing.database, + ) + reader = DBReader( + connection=mysql, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=hwm_name, expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # type is exactly as set by user + assert isinstance(hwm, ColumnIntHWM) + + # due to alphabetic sort min=0 and max=99 + assert hwm.value == 99 + processing.assert_equal_df(df=df, other_frame=data[data.hwm_int < 100], order_by="id_int") + + @pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ ( "hwm_int", - "hwm1_int", "(text_string+0)", - IntHWM, + ColumnIntHWM, str, ), ( "hwm_date", - "hwm1_date", "STR_TO_DATE(text_string, '%Y-%m-%d')", - DateHWM, + ColumnDateHWM, lambda x: x.isoformat(), ), ( "hwm_datetime", - "HWM1_DATETIME", "STR_TO_DATE(text_string, '%Y-%m-%dT%H:%i:%s.%f')", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.isoformat(), ), ], @@ -169,7 +326,6 @@ def test_mysql_strategy_incremental_with_hwm_expr( processing, prepare_schema_table, hwm_source, - hwm_column, hwm_expr, hwm_type, func, @@ -186,7 +342,7 @@ def test_mysql_strategy_incremental_with_hwm_expr( reader = DBReader( connection=mysql, source=prepare_schema_table.full_name, - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -205,12 +361,7 @@ def test_mysql_strategy_incremental_with_hwm_expr( second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column] = second_span[hwm_source] # insert first span processing.insert_data( @@ -224,7 +375,7 @@ def test_mysql_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -236,10 +387,10 @@ def test_mysql_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if 
issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py index 7d7142d64..3e07546f4 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py @@ -1,5 +1,9 @@ +import secrets +from datetime import datetime + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Oracle from onetl.db import DBReader @@ -12,11 +16,12 @@ # Do not fail in such the case @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_column", + "hwm_type, hwm_column", [ - "HWM_INT", - "HWM_DATE", - "HWM_DATETIME", + (ColumnIntHWM, "HWM_INT"), + # there is no Date type in Oracle + (ColumnDateTimeHWM, "HWM_DATE"), + (ColumnDateTimeHWM, "HWM_DATETIME"), ], ) @pytest.mark.parametrize( @@ -30,10 +35,14 @@ def test_oracle_strategy_incremental( spark, processing, prepare_schema_table, + hwm_type, hwm_column, span_gap, span_length, ): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + oracle = Oracle( host=processing.host, port=processing.port, @@ -43,7 +52,11 @@ def test_oracle_strategy_incremental( service_name=processing.service_name, spark=spark, ) - reader = DBReader(connection=oracle, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=oracle, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -58,6 +71,18 @@ def test_oracle_strategy_incremental( first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + first_span_max = first_span[hwm_column.lower()].max() + second_span_max = second_span[hwm_column.lower()].max() + + if "datetime" in hwm_column.lower(): + # Oracle store datetime only with second precision + first_span_max = first_span_max.replace(microsecond=0, nanosecond=0) + second_span_max = second_span_max.replace(microsecond=0, nanosecond=0) + elif "date" in hwm_column.lower(): + # Oracle does not support date type, convert to datetime + first_span_max = datetime.fromisoformat(first_span_max.isoformat()) + second_span_max = datetime.fromisoformat(second_span_max.isoformat()) + # insert first span processing.insert_data( schema=prepare_schema_table.schema, @@ -69,8 +94,13 @@ def test_oracle_strategy_incremental( with IncrementalStrategy(): first_df = reader.run() + hwm = 
store.get_hwm(hwm_name) + assert hwm is not None + assert isinstance(hwm, hwm_type) + assert hwm.value == first_span_max + # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -82,24 +112,132 @@ def test_oracle_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() + assert store.get_hwm(hwm_name).value == second_span_max + if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed processing.assert_subset_df(df=second_df, other_frame=second_span) +def test_oracle_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + sid=processing.sid, + service_name=processing.service_name, + spark=spark, + ) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "HWM_INT" + + reader = DBReader( + connection=oracle, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span["hwm_int"].max() + second_span_max = second_span["hwm_int"].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max + + # Fail if HWM 
is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "FLOAT_VALUE", - "TEXT_STRING", + ("FLOAT_VALUE", ValueError, "Expression 'FLOAT_VALUE' returned values"), + ("TEXT_STRING", RuntimeError, "Cannot detect HWM type for"), + ("UNKNOWN_COLUMN", Exception, "java.sql.SQLSyntaxErrorException"), ], ) -def test_oracle_strategy_incremental_wrong_hwm_type(spark, processing, prepare_schema_table, hwm_column): +def test_oracle_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): oracle = Oracle( host=processing.host, port=processing.port, @@ -109,7 +247,11 @@ def test_oracle_strategy_incremental_wrong_hwm_type(spark, processing, prepare_s service_name=processing.service_name, spark=spark, ) - reader = DBReader(connection=oracle, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=oracle, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -120,34 +262,78 @@ def test_oracle_strategy_incremental_wrong_hwm_type(spark, processing, prepare_s values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() +def test_oracle_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + sid=processing.sid, + service_name=processing.service_name, + spark=spark, + ) + reader = DBReader( + connection=oracle, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=hwm_name, expression="TEXT_STRING"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # type is exactly as set by user + assert isinstance(hwm, ColumnIntHWM) + + # due to alphabetic sort min=0 and max=99 + assert hwm.value == 99 + processing.assert_equal_df(df=df, other_frame=data[data.hwm_int < 100], order_by="id_int") + + @pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ ( "hwm_int", - "HWM1_INT", "TO_NUMBER(TEXT_STRING)", - IntHWM, + ColumnIntHWM, str, ), ( "hwm_date", - "HWM1_DATE", "TO_DATE(TEXT_STRING, 'YYYY-MM-DD')", - DateHWM, + ColumnDateHWM, lambda x: x.isoformat(), ), ( "hwm_datetime", - "hwm1_datetime", "TO_DATE(TEXT_STRING, 'YYYY-MM-DD HH24:MI:SS')", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.strftime("%Y-%m-%d %H:%M:%S"), ), ], @@ -157,7 +343,6 @@ def test_oracle_strategy_incremental_with_hwm_expr( processing, prepare_schema_table, hwm_source, - hwm_column, hwm_expr, hwm_type, func, @@ -175,7 +360,7 @@ def test_oracle_strategy_incremental_with_hwm_expr( reader = DBReader( connection=oracle, source=prepare_schema_table.full_name, - hwm_column=(hwm_column, hwm_expr), + 
hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -194,12 +379,7 @@ def test_oracle_strategy_incremental_with_hwm_expr( second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column.upper()] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column.upper()] = second_span[hwm_source] # insert first span processing.insert_data( @@ -213,7 +393,7 @@ def test_oracle_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -225,10 +405,10 @@ def test_oracle_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py index aff7cb0a0..424949751 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py @@ -1,9 +1,11 @@ +import secrets + import pytest -from etl_entities import DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ColumnDateHWM, ColumnDateTimeHWM, ColumnIntHWM +from etl_entities.hwm_store import HWMStoreStackManager from onetl.connection import Postgres from onetl.db import DBReader -from onetl.hwm.store import HWMStoreManager from onetl.strategy import IncrementalStrategy pytestmark = pytest.mark.postgres @@ -13,9 +15,9 @@ @pytest.mark.parametrize( "hwm_type, hwm_column", [ - (IntHWM, "hwm_int"), - (DateHWM, "hwm_date"), - (DateTimeHWM, "hwm_datetime"), + (ColumnIntHWM, "hwm_int"), + (ColumnDateHWM, "hwm_date"), + (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( @@ -34,7 +36,8 @@ def test_postgres_strategy_incremental( span_gap, span_length, ): - store = HWMStoreManager.get_current() + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) postgres = Postgres( host=processing.host, @@ -44,9 +47,12 @@ def test_postgres_strategy_incremental( database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) - hwm = hwm_type(source=reader.source, column=reader.hwm_column) + reader = DBReader( + connection=postgres, + 
source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) # there are 2 spans with a gap between @@ -72,19 +78,19 @@ def test_postgres_strategy_incremental( ) # hwm is not in the store - assert store.get(hwm.qualified_name) is None + assert store.get_hwm(hwm_name) is None # incremental run with IncrementalStrategy(): first_df = reader.run() - hwm = store.get(hwm.qualified_name) + hwm = store.get_hwm(hwm_name) assert hwm is not None assert isinstance(hwm, hwm_type) assert hwm.value == first_span_max # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -96,11 +102,11 @@ def test_postgres_strategy_incremental( with IncrementalStrategy(): second_df = reader.run() - assert store.get(hwm.qualified_name).value == second_span_max + assert store.get_hwm(hwm_name).value == second_span_max if "int" in hwm_column: # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed @@ -108,27 +114,24 @@ def test_postgres_strategy_incremental( @pytest.mark.parametrize( - "hwm_source, hwm_column, hwm_expr, hwm_type, func", + "hwm_source, hwm_expr, hwm_type, func", [ ( "hwm_int", - "hwm1_int", "text_string::int", - IntHWM, + ColumnIntHWM, str, ), ( "hwm_date", - "hwm1_date", "text_string::date", - DateHWM, + ColumnDateHWM, lambda x: x.isoformat(), ), ( "hwm_datetime", - "HWM1_DATETIME", "text_string::timestamp", - DateTimeHWM, + ColumnDateTimeHWM, lambda x: x.isoformat(), ), ], @@ -138,7 +141,6 @@ def test_postgres_strategy_incremental_with_hwm_expr( processing, prepare_schema_table, hwm_source, - hwm_column, hwm_expr, hwm_type, func, @@ -155,9 +157,7 @@ def test_postgres_strategy_incremental_with_hwm_expr( reader = DBReader( connection=postgres, source=prepare_schema_table.full_name, - # hwm_column is present in the dataframe even if it was not passed in columns list - columns=[column for column in processing.column_names if column != hwm_column], - hwm_column=(hwm_column, hwm_expr), + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expr), ) # there are 2 spans with a gap between @@ -176,12 +176,7 @@ def test_postgres_strategy_incremental_with_hwm_expr( second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) first_span["text_string"] = first_span[hwm_source].apply(func) - first_span_with_hwm = first_span.copy() - first_span_with_hwm[hwm_column.lower()] = first_span[hwm_source] - second_span["text_string"] = second_span[hwm_source].apply(func) - second_span_with_hwm = second_span.copy() - second_span_with_hwm[hwm_column.lower()] = second_span[hwm_source] # insert first span processing.insert_data( @@ -195,7 +190,7 @@ def test_postgres_strategy_incremental_with_hwm_expr( first_df = reader.run() # all the data has been read - processing.assert_equal_df(df=first_df, other_frame=first_span_with_hwm) + processing.assert_equal_df(df=first_df, other_frame=first_span, order_by="id_int") # insert second span processing.insert_data( @@ -207,24 +202,129 @@ def test_postgres_strategy_incremental_with_hwm_expr( with IncrementalStrategy(): second_df = reader.run() - if 
issubclass(hwm_type, IntHWM): + if issubclass(hwm_type, ColumnIntHWM): # only changed data has been read - processing.assert_equal_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_equal_df(df=second_df, other_frame=second_span, order_by="id_int") else: # date and datetime values have a random part # so instead of checking the whole dataframe a partial comparison should be performed - processing.assert_subset_df(df=second_df, other_frame=second_span_with_hwm) + processing.assert_subset_df(df=second_df, other_frame=second_span) + + +def test_postgres_strategy_incremental_nothing_to_read(spark, processing, prepare_schema_table): + postgres = Postgres( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + + store = HWMStoreStackManager.get_current() + hwm_name = secrets.token_hex(5) + hwm_column = "hwm_int" + + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression=hwm_column), + ) + + span_gap = 10 + span_length = 50 + + # there are 2 spans with a gap between + + # 0..50 + first_span_begin = 0 + first_span_end = first_span_begin + span_length + + # 60..110 + second_span_begin = first_span_end + span_gap + second_span_end = second_span_begin + span_length + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + first_span_max = first_span[hwm_column].max() + second_span_max = second_span[hwm_column].max() + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=first_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value is None + + # set hwm value to 50 + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=first_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # insert second span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=second_span, + ) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == first_span_max + + # read data + with IncrementalStrategy(): + df = reader.run() + + processing.assert_equal_df(df=df, other_frame=second_span, order_by="id_int") + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == second_span_max # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column", + "hwm_column, exception_type, error_message", [ - "float_value", - "text_string", + ("float_value", ValueError, "Expression 'float_value' returned values"), + ("text_string", RuntimeError, "Cannot detect HWM type for"), + ("unknown_column", Exception, "column .* does not exist"), ], ) -def 
test_postgres_strategy_incremental_wrong_hwm_type(spark, processing, prepare_schema_table, hwm_column): +def test_postgres_strategy_incremental_wrong_hwm( + spark, + processing, + prepare_schema_table, + hwm_column, + exception_type, + error_message, +): postgres = Postgres( host=processing.host, user=processing.user, @@ -232,7 +332,11 @@ def test_postgres_strategy_incremental_wrong_hwm_type(spark, processing, prepare database=processing.database, spark=spark, ) - reader = DBReader(connection=postgres, source=prepare_schema_table.full_name, hwm_column=hwm_column) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), + ) data = processing.create_pandas_df() @@ -243,7 +347,42 @@ def test_postgres_strategy_incremental_wrong_hwm_type(spark, processing, prepare values=data, ) - with pytest.raises((KeyError, ValueError)): + with pytest.raises(exception_type, match=error_message): # incremental run with IncrementalStrategy(): reader.run() + + +def test_postgres_strategy_incremental_explicit_hwm_type( + spark, + processing, + prepare_schema_table, +): + postgres = Postgres( + host=processing.host, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + ) + reader = DBReader( + connection=postgres, + source=prepare_schema_table.full_name, + # tell DBReader that text_string column contains integer values, and can be used for HWM + hwm=ColumnIntHWM(name=secrets.token_hex(5), expression="text_string"), + ) + + data = processing.create_pandas_df() + data["text_string"] = data["hwm_int"].apply(str) + + # insert first span + processing.insert_data( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + values=data, + ) + + # incremental run + with pytest.raises(Exception, match="operator does not exist: text <= integer"): + with IncrementalStrategy(): + reader.run() diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py index fa3f6ad10..eb4f68c68 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py @@ -6,6 +6,8 @@ from onetl.connection import Hive from onetl.db import DBReader +pytestmark = pytest.mark.hive + def test_reader_deprecated_import(): msg = textwrap.dedent( @@ -54,23 +56,7 @@ def test_reader_hive_with_read_options(spark_mock): (), {}, set(), - " \t\n", - [""], - [" \t\n"], - ["", "abc"], - [" \t\n", "abc"], - "", - " \t\n", - ",abc", - "abc,", - "cde,,abc", - "cde, ,abc", - "*,*,cde", - "abc,abc,cde", - "abc,ABC,cde", - ["*", "*", "cde"], - ["abc", "abc", "cde"], - ["abc", "ABC", "cde"], + "any,string", ], ) def test_reader_invalid_columns(spark_mock, columns): @@ -85,9 +71,6 @@ def test_reader_invalid_columns(spark_mock, columns): @pytest.mark.parametrize( "columns, real_columns", [ - ("*", ["*"]), - ("abc, cde", ["abc", "cde"]), - ("*, abc", ["*", "abc"]), (["*"], ["*"]), (["abc", "cde"], ["abc", "cde"]), (["*", "abc"], ["*", "abc"]), @@ -104,129 +87,74 @@ def test_reader_valid_columns(spark_mock, columns, real_columns): @pytest.mark.parametrize( - "hwm_column", + "hwm_column, real_hwm_expression", [ - "wrong/name", - "wrong@name", - "wrong=name", - "wrong#name", - [], - {}, - (), - set(), - frozenset(), - ("name",), - ["name"], - {"name"}, - ("wrong/name", "statement"), - ("wrong@name", "statement"), - 
("wrong=name", "statement"), - ("wrong#name", "statement"), - ["wrong/name", "statement"], - ["wrong@name", "statement"], - ["wrong=name", "statement"], - ["wrong#name", "statement"], - ("wrong/name", "statement", "too", "many"), - ("wrong@name", "statement", "too", "many"), - ("wrong=name", "statement", "too", "many"), - ("wrong#name", "statement", "too", "many"), - ["wrong/name", "statement", "too", "many"], - ["wrong@name", "statement", "too", "many"], - ["wrong=name", "statement", "too", "many"], - ["wrong#name", "statement", "too", "many"], - {"wrong/name", "statement", "too", "many"}, - {"wrong@name", "statement", "too", "many"}, - {"wrong=name", "statement", "too", "many"}, - {"wrong#name", "statement", "too", "many"}, - (None, "statement"), - [None, "statement"], - # this is the same as hwm_column="name", - # but if user implicitly passed a tuple - # both of values should be set to avoid unexpected errors - ("name", None), - ["name", None], + ("hwm_column", "hwm_column"), + (("hwm_column", "expression"), "expression"), + (("hwm_column", "hwm_column"), "hwm_column"), ], ) -def test_reader_invalid_hwm_column(spark_mock, hwm_column): - with pytest.raises(ValueError): - DBReader( +def test_reader_deprecated_hwm_column(spark_mock, hwm_column, real_hwm_expression): + error_msg = 'Passing "hwm_column" in DBReader class is deprecated since version 0.10.0' + with pytest.warns(UserWarning, match=error_msg): + reader = DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), table="schema.table", hwm_column=hwm_column, ) + assert isinstance(reader.hwm, reader.AutoDetectHWM) + assert reader.hwm.entity == "schema.table" + assert reader.hwm.expression == real_hwm_expression -@pytest.mark.parametrize( - "hwm_column, real_hwm_column, real_hwm_expression", - [ - ("hwm_column", "hwm_column", None), - (("hwm_column", "expression"), "hwm_column", "expression"), - (("hwm_column", "hwm_column"), "hwm_column", "hwm_column"), - ], -) -def test_reader_valid_hwm_column(spark_mock, hwm_column, real_hwm_column, real_hwm_expression): + +def test_reader_autofill_hwm_source(spark_mock): reader = DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), table="schema.table", - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM( + name="some_name", + expression="some_expression", + ), ) - assert reader.hwm_column.name == real_hwm_column - assert reader.hwm_expression == real_hwm_expression + assert reader.hwm.entity == "schema.table" + assert reader.hwm.expression == "some_expression" -@pytest.mark.parametrize( - "columns, hwm_column", - [ - (["a", "b", "c", "d"], "d"), - (["a", "b", "c", "d"], "D"), - (["a", "b", "c", "D"], "d"), - ("a, b, c, d", "d"), - ("a, b, c, d", "D"), - ("a, b, c, D", "d"), - (["*", "d"], "d"), - (["*", "d"], "D"), - (["*", "D"], "d"), - ("*, d", "d"), - ("*, d", "D"), - ("*, D", "d"), - (["*"], "d"), - (["*"], "D"), - (["*"], ("d", "cast")), - (["*"], ("D", "cast")), - ], -) -def test_reader_hwm_column_and_columns_are_not_in_conflict(spark_mock, columns, hwm_column): - DBReader( +def test_reader_hwm_has_same_source(spark_mock): + reader = DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), - table="schema.table", - columns=columns, - hwm_column=hwm_column, + source="schema.table", + hwm=DBReader.AutoDetectHWM( + name="some_name", + source="schema.table", + expression="some_expression", + ), ) + assert reader.hwm.entity == "schema.table" + assert reader.hwm.expression == "some_expression" -@pytest.mark.parametrize( - "columns, hwm_column", - [ - (["a", "b", "c", 
"d"], ("d", "cast")), - (["a", "b", "c", "d"], ("D", "cast")), - (["a", "b", "c", "D"], ("d", "cast")), - ("a, b, c, d", ("d", "cast")), - ("a, b, c, d", ("D", "cast")), - ("a, b, c, D", ("d", "cast")), - (["*", "d"], ("d", "cast")), - (["*", "d"], ("D", "cast")), - (["*", "D"], ("d", "cast")), - ("*, d", ("d", "cast")), - ("*, d", ("D", "cast")), - ("*, D", ("d", "cast")), - ], -) -def test_reader_hwm_column_and_columns_are_in_conflict(spark_mock, columns, hwm_column): - with pytest.raises(ValueError): + +def test_reader_hwm_has_different_source(spark_mock): + error_msg = "Passed `hwm.source` is different from `source`" + with pytest.raises(ValueError, match=error_msg): DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), table="schema.table", - columns=columns, - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM( + name="some_name", + source="another.table", + expression="some_expression", + ), + ) + + +def test_reader_no_hwm_expression(spark_mock): + with pytest.raises(ValueError, match="`hwm.expression` cannot be None"): + DBReader( + connection=Hive(cluster="rnd-dwh", spark=spark_mock), + table="schema.table", + hwm=DBReader.AutoDetectHWM(name="some_name"), ) diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py index 5fd228f5b..3506a415f 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py @@ -1,5 +1,6 @@ +import secrets + import pytest -from etl_entities import Column from onetl.connection import Kafka from onetl.db import DBReader @@ -45,6 +46,7 @@ def test_kafka_reader_unsupported_parameters(spark_mock, df_schema): where={"col1": 1}, table="table", ) + with pytest.raises( ValueError, match="'hint' parameter is not supported by Kafka ", @@ -54,6 +56,7 @@ def test_kafka_reader_unsupported_parameters(spark_mock, df_schema): hint={"col1": 1}, table="table", ) + with pytest.raises( ValueError, match="'df_schema' parameter is not supported by Kafka ", @@ -65,8 +68,7 @@ def test_kafka_reader_unsupported_parameters(spark_mock, df_schema): ) -@pytest.mark.parametrize("hwm_column", ["offset", Column(name="offset")]) -def test_kafka_reader_valid_hwm_column(spark_mock, hwm_column): +def test_kafka_reader_hwm_offset_is_valid(spark_mock): kafka = Kafka( addresses=["localhost:9092"], cluster="my_cluster", @@ -76,11 +78,11 @@ def test_kafka_reader_valid_hwm_column(spark_mock, hwm_column): DBReader( connection=kafka, table="table", - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="offset"), ) -def test_kafka_reader_hwm_column_by_version(spark_mock, mocker): +def test_kafka_reader_hwm_timestamp_depends_on_spark_version(spark_mock, mocker): kafka = Kafka( addresses=["localhost:9092"], cluster="my_cluster", @@ -90,7 +92,7 @@ def test_kafka_reader_hwm_column_by_version(spark_mock, mocker): DBReader( connection=kafka, table="table", - hwm_column="timestamp", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="timestamp"), ) mocker.patch.object(spark_mock, "version", new="2.4.0") @@ -98,12 +100,11 @@ def test_kafka_reader_hwm_column_by_version(spark_mock, mocker): DBReader( connection=kafka, table="table", - hwm_column="timestamp", + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="timestamp"), ) -@pytest.mark.parametrize("hwm_column", ["unknown", '("some", "thing")']) -def 
test_kafka_reader_invalid_hwm_column(spark_mock, hwm_column): +def test_kafka_reader_invalid_hwm_column(spark_mock): kafka = Kafka( addresses=["localhost:9092"], cluster="my_cluster", @@ -112,10 +113,10 @@ def test_kafka_reader_invalid_hwm_column(spark_mock, hwm_column): with pytest.raises( ValueError, - match="is not a valid hwm column", + match="hwm.expression='unknown' is not supported by Kafka", ): DBReader( connection=kafka, table="table", - hwm_column=hwm_column, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="unknown"), ) diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py index 2a164c16d..d2465b957 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py @@ -134,22 +134,6 @@ def test_mongodb_reader_without_df_schema(spark_mock): ) -def test_mongodb_reader_error_pass_hwm_expression(spark_mock, df_schema): - mongo = MongoDB( - host="host", - user="user", - password="password", - database="database", - spark=spark_mock, - ) - - with pytest.raises( - ValueError, - match="'hwm_expression' parameter is not supported by MongoDB", - ): - DBReader(connection=mongo, table="table", df_schema=df_schema, hwm_column=("hwm_int", "expr")) - - def test_mongodb_reader_error_pass_columns(spark_mock, df_schema): mongo = MongoDB( host="host", @@ -164,19 +148,3 @@ def test_mongodb_reader_error_pass_columns(spark_mock, df_schema): match="'columns' parameter is not supported by MongoDB", ): DBReader(connection=mongo, table="table", columns=["_id", "test"], df_schema=df_schema) - - -def test_mongodb_reader_hwm_column_not_in_df_schema(spark_mock, df_schema): - mongo = MongoDB( - host="host", - user="user", - password="password", - database="database", - spark=spark_mock, - ) - - with pytest.raises( - ValueError, - match="'df_schema' struct must contain column specified in 'hwm_column'.*", - ): - DBReader(connection=mongo, table="table", hwm_column="_id2", df_schema=df_schema) diff --git a/tests/tests_unit/test_file/test_file_downloader_unit.py b/tests/tests_unit/test_file/test_file_downloader_unit.py index ffec5f339..b5e62b0cd 100644 --- a/tests/tests_unit/test_file/test_file_downloader_unit.py +++ b/tests/tests_unit/test_file/test_file_downloader_unit.py @@ -3,7 +3,16 @@ from unittest.mock import Mock import pytest -from etl_entities import HWM, ColumnHWM, DateHWM, DateTimeHWM, IntHWM +from etl_entities.hwm import ( + HWM, + ColumnDateHWM, + ColumnDateTimeHWM, + ColumnHWM, + ColumnIntHWM, + FileListHWM, +) +from etl_entities.instance import AbsolutePath +from etl_entities.old_hwm import FileListHWM as OldFileListHWM from onetl.base import BaseFileConnection from onetl.core import FileFilter, FileLimit @@ -31,55 +40,124 @@ def test_file_downloader_deprecated_import(): assert OldFileDownloader is FileDownloader +@pytest.mark.parametrize( + "hwm_type", + [ + "file_list", + OldFileListHWM, + ], +) +def test_file_downloader_hwm_type_deprecated(hwm_type): + warning_msg = 'Passing "hwm_type" to FileDownloader class is deprecated since version 0.10.0' + connection = Mock(spec=BaseFileConnection) + connection.instance_url = "abc" + + with pytest.warns(UserWarning, match=warning_msg): + downloader = FileDownloader( + connection=connection, + local_path="/local/path", + source_path="/source/path", + hwm_type=hwm_type, + ) + + assert isinstance(downloader.hwm, FileListHWM) + assert 
downloader.hwm.entity == AbsolutePath("/source/path") + + def test_file_downloader_unknown_hwm_type(): - with pytest.raises(KeyError, match="Unknown HWM type 'abc'"): + # fails on pydantic issubclass(hwm_type, OldFileListHWM) in FileDownloader + with pytest.raises(ValueError): FileDownloader( - connection=Mock(), - local_path="/path", - source_path="/path", + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", + source_path="/source/path", hwm_type="abc", ) @pytest.mark.parametrize( - "hwm_type, hwm_type_name", + "hwm_type", + [ + ColumnIntHWM, + ColumnDateHWM, + ColumnDateTimeHWM, + ColumnHWM, + HWM, + ], +) +def test_file_downloader_wrong_hwm_type(hwm_type): + # pydantic validation fails, as new hwm classes are passed into hwm_type + with pytest.raises(ValueError): + FileDownloader( + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", + source_path="/source/path", + hwm_type=hwm_type, + ) + + +@pytest.mark.parametrize( + "hwm_type", [ - ("byte", "IntHWM"), - ("integer", "IntHWM"), - ("short", "IntHWM"), - ("long", "IntHWM"), - ("date", "DateHWM"), - ("timestamp", "DateTimeHWM"), - (IntHWM, "IntHWM"), - (DateHWM, "DateHWM"), - (DateTimeHWM, "DateTimeHWM"), - (HWM, "HWM"), - (ColumnHWM, "ColumnHWM"), + "file_list", + OldFileListHWM, ], ) -def test_file_downloader_wrong_hwm_type(hwm_type, hwm_type_name): - with pytest.raises(ValueError, match=f"`hwm_type` class should be a inherited from FileHWM, got {hwm_type_name}"): +def test_file_downloader_hwm_type_without_source_path(hwm_type): + warning_msg = "If `hwm` is passed, `source_path` must be specified" + with pytest.raises(ValueError, match=warning_msg): FileDownloader( - connection=Mock(), - local_path="/path", - source_path="/path", + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", hwm_type=hwm_type, ) -def test_file_downloader_hwm_type_without_source_path(): - with pytest.raises(ValueError, match="If `hwm_type` is passed, `source_path` must be specified"): +def test_file_downloader_hwm_without_source_path(): + warning_msg = "If `hwm` is passed, `source_path` must be specified" + with pytest.raises(ValueError, match=warning_msg): + FileDownloader( + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", + hwm=FileListHWM(name="abc"), + ) + + +def test_file_downloader_hwm_autofill_directory(): + downloader = FileDownloader( + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", + source_path="/source/path", + hwm=FileListHWM(name="abc"), + ) + assert downloader.hwm.entity == AbsolutePath("/source/path") + + +def test_file_downloader_hwm_with_same_directory(): + downloader = FileDownloader( + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", + source_path="/source/path", + hwm=FileListHWM(name="abc", directory="/source/path"), + ) + assert downloader.hwm.entity == AbsolutePath("/source/path") + + +def test_file_downloader_hwm_with_different_directory_error(): + error_msg = "Passed `hwm.directory` is different from `source_path`" + with pytest.raises(ValueError, match=error_msg): FileDownloader( - connection=Mock(), - local_path="/path", - hwm_type="file_list", + connection=Mock(spec=BaseFileConnection), + local_path="/local/path", + source_path="/source/path", + hwm=FileListHWM(name="abc", directory="/another/path"), ) def test_file_downloader_filter_default(): downloader = FileDownloader( connection=Mock(spec=BaseFileConnection), - local_path="/path", + local_path="/local/path", ) assert downloader.filters == [] @@ -89,7 +167,7 @@ 
def test_file_downloader_filter_none(): with pytest.warns(UserWarning, match=re.escape("filter=None is deprecated in v0.8.0, use filters=[] instead")): downloader = FileDownloader( connection=Mock(spec=BaseFileConnection), - local_path="/path", + local_path="/local/path", filter=None, ) @@ -107,7 +185,7 @@ def test_file_downloader_filter_legacy(file_filter): with pytest.warns(UserWarning, match=re.escape("filter=... is deprecated in v0.8.0, use filters=[...] instead")): downloader = FileDownloader( connection=Mock(spec=BaseFileConnection), - local_path="/path", + local_path="/local/path", filter=file_filter, ) @@ -117,7 +195,7 @@ def test_file_downloader_filter_legacy(file_filter): def test_file_downloader_limit_default(): downloader = FileDownloader( connection=Mock(spec=BaseFileConnection), - local_path="/path", + local_path="/local/path", ) assert downloader.limits == [] @@ -127,7 +205,7 @@ def test_file_downloader_limit_none(): with pytest.warns(UserWarning, match=re.escape("limit=None is deprecated in v0.8.0, use limits=[] instead")): downloader = FileDownloader( connection=Mock(spec=BaseFileConnection), - local_path="/path", + local_path="/local/path", limit=None, ) @@ -145,7 +223,7 @@ def test_file_downloader_limit_legacy(file_limit): with pytest.warns(UserWarning, match=re.escape("limit=... is deprecated in v0.8.0, use limits=[...] instead")): downloader = FileDownloader( connection=Mock(spec=BaseFileConnection), - local_path="/path", + local_path="/local/path", limit=file_limit, ) diff --git a/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py b/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py index e94386120..9424e8b1c 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py @@ -32,16 +32,16 @@ def test_excel_get_packages_package_version_not_supported(): "spark_version, scala_version, package_version, packages", [ # Detect Scala version by Spark version - ("3.2.4", None, None, ["com.crealytics:spark-excel_2.12:3.2.4_0.19.0"]), - ("3.4.1", None, None, ["com.crealytics:spark-excel_2.12:3.4.1_0.19.0"]), + ("3.2.4", None, None, ["com.crealytics:spark-excel_2.12:3.2.4_0.20.3"]), + ("3.5.0", None, None, ["com.crealytics:spark-excel_2.12:3.5.0_0.20.3"]), # Override Scala version - ("3.2.4", "2.12", None, ["com.crealytics:spark-excel_2.12:3.2.4_0.19.0"]), - ("3.2.4", "2.13", None, ["com.crealytics:spark-excel_2.13:3.2.4_0.19.0"]), - ("3.4.1", "2.12", None, ["com.crealytics:spark-excel_2.12:3.4.1_0.19.0"]), - ("3.4.1", "2.13", None, ["com.crealytics:spark-excel_2.13:3.4.1_0.19.0"]), + ("3.2.4", "2.12", None, ["com.crealytics:spark-excel_2.12:3.2.4_0.20.3"]), + ("3.2.4", "2.13", None, ["com.crealytics:spark-excel_2.13:3.2.4_0.20.3"]), + ("3.5.0", "2.12", None, ["com.crealytics:spark-excel_2.12:3.5.0_0.20.3"]), + ("3.5.0", "2.13", None, ["com.crealytics:spark-excel_2.13:3.5.0_0.20.3"]), # Override package version ("3.2.0", None, "0.16.0", ["com.crealytics:spark-excel_2.12:3.2.0_0.16.0"]), - ("3.4.1", None, "0.18.0", ["com.crealytics:spark-excel_2.12:3.4.1_0.18.0"]), + ("3.5.0", None, "0.18.0", ["com.crealytics:spark-excel_2.12:3.5.0_0.18.0"]), ], ) def test_excel_get_packages(caplog, spark_version, scala_version, package_version, packages): diff --git a/tests/tests_unit/test_internal_unit/test_generate_temp_path.py b/tests/tests_unit/test_internal_unit/test_generate_temp_path.py index 340fb7081..faad170fe 100644 --- 
a/tests/tests_unit/test_internal_unit/test_generate_temp_path.py +++ b/tests/tests_unit/test_internal_unit/test_generate_temp_path.py @@ -3,7 +3,7 @@ from pathlib import PurePath import pytest -from etl_entities import Process +from etl_entities.process import Process from onetl._internal import generate_temp_path diff --git a/tests/tests_unit/test_internal_unit/test_get_sql_query.py b/tests/tests_unit/test_internal_unit/test_get_sql_query.py deleted file mode 100644 index adcebeafe..000000000 --- a/tests/tests_unit/test_internal_unit/test_get_sql_query.py +++ /dev/null @@ -1,134 +0,0 @@ -import textwrap - -import pytest - -from onetl._internal import get_sql_query - - -@pytest.mark.parametrize( - "columns", - [ - None, - "*", - ["*"], - [], - ], -) -def test_get_sql_query_no_columns(columns): - result = get_sql_query( - table="default.test", - columns=columns, - ) - - expected = textwrap.dedent( - """ - SELECT - * - FROM - default.test - """, - ).strip() - - assert result == expected - - -def test_get_sql_query_columns(): - result = get_sql_query( - table="default.test", - columns=["d_id", "d_name", "d_age"], - ) - expected = textwrap.dedent( - """ - SELECT - d_id, - d_name, - d_age - FROM - default.test - """, - ).strip() - - assert result == expected - - -def test_get_sql_query_where(): - result = get_sql_query( - table="default.test", - where="d_id > 100", - ) - - expected = textwrap.dedent( - """ - SELECT - * - FROM - default.test - WHERE - d_id > 100 - """, - ).strip() - - assert result == expected - - -def test_get_sql_query_hint(): - result = get_sql_query( - table="default.test", - hint="NOWAIT", - ) - - expected = textwrap.dedent( - """ - SELECT /*+ NOWAIT */ - * - FROM - default.test - """, - ).strip() - - assert result == expected - - -def test_get_sql_query_compact_false(): - result = get_sql_query( - table="default.test", - hint="NOWAIT", - columns=["d_id", "d_name", "d_age"], - where="d_id > 100", - compact=False, - ) - - expected = textwrap.dedent( - """ - SELECT /*+ NOWAIT */ - d_id, - d_name, - d_age - FROM - default.test - WHERE - d_id > 100 - """, - ).strip() - - assert result == expected - - -def test_get_sql_query_compact_true(): - result = get_sql_query( - table="default.test", - hint="NOWAIT", - columns=["d_id", "d_name", "d_age"], - where="d_id > 100", - compact=True, - ) - - expected = textwrap.dedent( - """ - SELECT /*+ NOWAIT */ d_id, d_name, d_age - FROM default.test - WHERE d_id > 100 - """, - ).strip() - - assert result == expected diff --git a/tests/tests_unit/test_plugins/test_autoimport_success.py b/tests/tests_unit/test_plugins/test_autoimport_success.py index e00c98f30..240f64eb8 100644 --- a/tests/tests_unit/test_plugins/test_autoimport_success.py +++ b/tests/tests_unit/test_plugins/test_autoimport_success.py @@ -1,7 +1,8 @@ import sys +from etl_entities.hwm_store import HWMStoreClassRegistry + import onetl -from onetl.hwm.store import HWMStoreClassRegistry def test_autoimport_success(request): diff --git a/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py b/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py new file mode 100644 index 000000000..8faed9256 --- /dev/null +++ b/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py @@ -0,0 +1,264 @@ +import textwrap + +import pytest + +from onetl.connection import Oracle, Postgres + +pytestmark = [pytest.mark.postgres] + + +@pytest.mark.parametrize( + "columns", + [ + None, + "*", + ["*"], + [], + ], +) +def test_db_dialect_get_sql_query_no_columns(spark_mock, columns): + 
conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + columns=columns, + ) + + expected = textwrap.dedent( + """ + SELECT + * + FROM + default.test + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_columns(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + columns=["*", "d_id", "d_name", "d_age"], + ) + expected = textwrap.dedent( + """ + SELECT + *, + d_id, + d_name, + d_age + FROM + default.test + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_columns_oracle(spark_mock): + conn = Oracle(host="some_host", user="user", sid="database", password="passwd", spark=spark_mock) + + # same as for other databases + result = conn.dialect.get_sql_query( + table="default.test", + columns=["*"], + ) + expected = textwrap.dedent( + """ + SELECT + * + FROM + default.test + """, + ).strip() + + assert result == expected + + # but this is different + result = conn.dialect.get_sql_query( + table="default.test", + columns=["*", "d_id", "d_name", "d_age"], + ) + expected = textwrap.dedent( + """ + SELECT + default.test.*, + d_id, + d_name, + d_age + FROM + default.test + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_where_string(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + where="d_id > 100", + ) + + expected = textwrap.dedent( + """ + SELECT + * + FROM + default.test + WHERE + d_id > 100 + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_where_list(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + where=["d_id > 100", "d_id < 200"], + ) + + expected = textwrap.dedent( + """ + SELECT + * + FROM + default.test + WHERE + (d_id > 100) + AND + (d_id < 200) + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_hint(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + hint="NOWAIT", + ) + + expected = textwrap.dedent( + """ + SELECT /*+ NOWAIT */ + * + FROM + default.test + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_limit(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + limit=5, + ) + + expected = textwrap.dedent( + """ + SELECT + * + FROM + default.test + LIMIT + 5 + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_limit_0(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + limit=0, + ) + + expected = textwrap.dedent( + """ + SELECT + * + FROM + default.test + WHERE + 1 = 0 + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_compact_false(spark_mock): + conn = Postgres(host="some_host", 
user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + hint="NOWAIT", + columns=["d_id", "d_name", "d_age"], + where=["d_id > 100", "d_id < 200"], + limit=5, + compact=False, + ) + + expected = textwrap.dedent( + """ + SELECT /*+ NOWAIT */ + d_id, + d_name, + d_age + FROM + default.test + WHERE + (d_id > 100) + AND + (d_id < 200) + LIMIT + 5 + """, + ).strip() + + assert result == expected + + +def test_db_dialect_get_sql_query_compact_true(spark_mock): + conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) + + result = conn.dialect.get_sql_query( + table="default.test", + hint="NOWAIT", + columns=["d_id", "d_name", "d_age"], + where=["d_id > 100", "d_id < 200"], + limit=5, + compact=True, + ) + + expected = textwrap.dedent( + """ + SELECT /*+ NOWAIT */ d_id, d_name, d_age + FROM default.test + WHERE (d_id > 100) + AND (d_id < 200) + LIMIT 5 + """, + ).strip() + + assert result == expected diff --git a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py index d53b4d614..48882f185 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py @@ -202,7 +202,16 @@ def test_mongodb_with_extra(spark_mock): assert mongo.connection_url == "mongodb://user:password@host:27017/database?opt1=value1&tls=true" -def test_mongodb_convert_list_to_str(): +def test_mongodb_convert_list_to_str(spark_mock): + mongo = MongoDB( + host="host", + user="user", + password="password", + database="database", + extra={"tls": "true", "opt1": "value1"}, + spark=spark_mock, + ) + where = [ {"$or": [{"col_1": {"$gt": 1, "$eq": True}}, {"col_2": {"$eq": None}}]}, { @@ -213,13 +222,27 @@ def test_mongodb_convert_list_to_str(): }, ] - assert MongoDB.Dialect.convert_to_str(where) == ( - '[{"$or": [{"col_1": {"$gt": 1, "$eq": true}}, {"col_2": {"$eq": null}}]}, ' - '{"$and": [{"col_3": {"$eq": "Hello"}}, {"col_4": {"$eq": {"$date": "2022-12-23T08:22:33.456000+00:00"}}}]}]' - ) + assert mongo.dialect.prepare_pipeline(where) == [ + {"$or": [{"col_1": {"$gt": 1, "$eq": True}}, {"col_2": {"$eq": None}}]}, + { + "$and": [ + {"col_3": {"$eq": "Hello"}}, + {"col_4": {"$eq": {"$date": "2022-12-23T08:22:33.456000+00:00"}}}, + ], + }, + ] + +def test_mongodb_convert_dict_to_str(spark_mock): + mongo = MongoDB( + host="host", + user="user", + password="password", + database="database", + extra={"tls": "true", "opt1": "value1"}, + spark=spark_mock, + ) -def test_mongodb_convert_dict_to_str(): where = { "$and": [ {"$or": [{"col_1": {"$gt": 1, "$eq": True}}, {"col_2": {"$eq": None}}]}, @@ -232,12 +255,17 @@ def test_mongodb_convert_dict_to_str(): ], } - assert MongoDB.Dialect.convert_to_str(where) == ( - '{"$and": ' - '[{"$or": [{"col_1": {"$gt": 1, "$eq": true}}, {"col_2": {"$eq": null}}]}, ' - '{"$and": [{"col_3": {"$eq": "Hello"}}, {"col_4": {"$eq": {"$date": "2022-12-23T08:22:33.456000+00:00"}}}]}]' - "}" - ) + assert mongo.dialect.prepare_pipeline(where) == { + "$and": [ + {"$or": [{"col_1": {"$gt": 1, "$eq": True}}, {"col_2": {"$eq": None}}]}, + { + "$and": [ + {"col_3": {"$eq": "Hello"}}, + {"col_4": {"$eq": {"$date": "2022-12-23T08:22:33.456000+00:00"}}}, + ], + }, + ], + } @pytest.mark.parametrize( diff --git a/tests/tests_unit/tests_hwm_store_unit/test_hwm_store_common_unit.py 
b/tests/tests_unit/tests_hwm_store_unit/test_hwm_store_common_unit.py deleted file mode 100644 index a8965a06e..000000000 --- a/tests/tests_unit/tests_hwm_store_unit/test_hwm_store_common_unit.py +++ /dev/null @@ -1,191 +0,0 @@ -import tempfile - -import pytest -from omegaconf import OmegaConf - -from onetl.hwm.store import ( - HWMStoreManager, - MemoryHWMStore, - YAMLHWMStore, - detect_hwm_store, -) - -hwm_store = [ - MemoryHWMStore(), - YAMLHWMStore(path=tempfile.mktemp("hwmstore")), # noqa: S306 NOSONAR -] - - -@pytest.mark.parametrize("hwm_store", hwm_store) -def test_hwm_store_get_save(hwm_store, hwm_delta): - hwm, delta = hwm_delta - assert hwm_store.get(hwm.qualified_name) is None - - hwm_store.save(hwm) - assert hwm_store.get(hwm.qualified_name) == hwm - - hwm2 = hwm + delta - hwm_store.save(hwm2) - assert hwm_store.get(hwm.qualified_name) == hwm2 - - -@pytest.mark.parametrize( - "hwm_store_class, input_config, key", - [ - ( - YAMLHWMStore, - {"hwm_store": None}, - "hwm_store", - ), - ( - YAMLHWMStore, - {"env": {"hwm_store": "yml"}}, - "env.hwm_store", - ), - ( - YAMLHWMStore, - {"hwm_store": "yml"}, - "hwm_store", - ), - ( - YAMLHWMStore, - {"some_store": "yml"}, - "some_store", - ), - ( - YAMLHWMStore, - {"hwm_store": "yaml"}, - "hwm_store", - ), - ( - MemoryHWMStore, - {"hwm_store": "memory"}, - "hwm_store", - ), - ( - MemoryHWMStore, - {"some_store": "memory"}, - "some_store", - ), - ( - MemoryHWMStore, - {"hwm_store": "in-memory"}, - "hwm_store", - ), - ( - MemoryHWMStore, - {"hwm_store": {"memory": None}}, - "hwm_store", - ), - ( - MemoryHWMStore, - {"hwm_store": {"memory": []}}, - "hwm_store", - ), - ( - MemoryHWMStore, - {"hwm_store": {"memory": {}}}, - "hwm_store", - ), - ( - YAMLHWMStore, - {"hwm_store": {"yml": {"path": tempfile.mktemp("hwmstore"), "encoding": "utf8"}}}, # noqa: S306 NOSONAR - "hwm_store", - ), - ], -) -@pytest.mark.parametrize("config_constructor", [dict, OmegaConf.create]) -def test_hwm_store_unit_detect(hwm_store_class, input_config, config_constructor, key): - @detect_hwm_store(key) - def main(config): - assert isinstance(HWMStoreManager.get_current(), hwm_store_class) - - conf = config_constructor(input_config) - main(conf) - - -@pytest.mark.parametrize( - "input_config", - [ - {"hwm_store": 1}, - {"hwm_store": "unknown"}, - {"hwm_store": {"unknown": None}}, - ], -) -@pytest.mark.parametrize("config_constructor", [dict, OmegaConf.create]) -def test_hwm_store_unit_detect_failure(input_config, config_constructor): - @detect_hwm_store("hwm_store") - def main(config): # NOSONAR - pass - - conf = config_constructor(input_config) - with pytest.raises((KeyError, ValueError)): - main(conf) - - conf = config_constructor({"nested": input_config}) - with pytest.raises((KeyError, ValueError)): - main(conf) - - conf = config_constructor({"even": {"more": {"nested": input_config}}}) - with pytest.raises((KeyError, ValueError)): - main(conf) - - -@pytest.mark.parametrize( - "input_config", - [ - {"hwm_store": {"memory": 1}}, - {"hwm_store": {"memory": {"unknown": "arg"}}}, - {"hwm_store": {"memory": ["too_many_arg"]}}, - {"hwm_store": {"yml": 1}}, - {"hwm_store": {"yml": tempfile.mktemp("hwmstore")}}, # noqa: S306 NOSONAR - {"hwm_store": {"yml": [tempfile.mktemp("hwmstore")]}}, # noqa: S306 NOSONAR - {"hwm_store": {"yml": [tempfile.mktemp("hwmstore"), "utf8"]}}, # noqa: S306 NOSONAR - {"hwm_store": {"yml": {"unknown": "arg"}}}, - {"not_hwm_store": "yml"}, - ], -) -@pytest.mark.parametrize("config_constructor", [dict, OmegaConf.create]) -def 
test_hwm_store_unit_wrong_options(input_config, config_constructor): - @detect_hwm_store("hwm_store") - def main(config): # NOSONAR - pass - - conf = config_constructor(input_config) - - with pytest.raises((TypeError, ValueError)): - main(conf) - - conf = config_constructor({"nested": input_config}) - with pytest.raises((TypeError, ValueError)): - main(conf) - - conf = config_constructor({"even": {"more": {"nested": input_config}}}) - with pytest.raises((TypeError, ValueError)): - main(conf) - - -@pytest.mark.parametrize( - "config, key", - [ - ({"some": "yml"}, "unknown"), - ({"some": "yml"}, "some.unknown"), - ({"some": "yml"}, "some.yaml.unknown"), - ({"var": {"hwm_store": "yml"}}, "var.hwm_store."), - ({"var": {"hwm_store": "yml"}}, "var..hwm_store"), - ({"some": "yml"}, 12), - ({}, "var.hwm_store"), - ({"var": {"hwm_store": "yml"}}, ""), - ({}, ""), - ], -) -@pytest.mark.parametrize("config_constructor", [dict, OmegaConf.create]) -def test_hwm_store_wrong_config_and_key_value_error(config_constructor, config, key): - with pytest.raises(ValueError): - - @detect_hwm_store(key) - def main(config): - ... - - conf = config_constructor(config) - main(conf) diff --git a/tests/tests_unit/tests_hwm_store_unit/test_memory_hwm_store_unit.py b/tests/tests_unit/tests_hwm_store_unit/test_memory_hwm_store_unit.py deleted file mode 100644 index 342f3003d..000000000 --- a/tests/tests_unit/tests_hwm_store_unit/test_memory_hwm_store_unit.py +++ /dev/null @@ -1,18 +0,0 @@ -import logging - -import pytest - -from onetl.hwm.store import HWMStoreManager, MemoryHWMStore, YAMLHWMStore - - -def test_hwm_store_memory_context_manager(caplog): - hwm_store = MemoryHWMStore() - - with caplog.at_level(logging.INFO): - with hwm_store as store: - assert HWMStoreManager.get_current() == store - - assert "|onETL| Using MemoryHWMStore as HWM Store" in caplog.text - - assert HWMStoreManager.get_current() != hwm_store - assert isinstance(HWMStoreManager.get_current(), YAMLHWMStore) diff --git a/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py b/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py index cb965290e..c177b608e 100644 --- a/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py +++ b/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py @@ -1,11 +1,23 @@ import logging import secrets import shutil +import textwrap from pathlib import Path import pytest +from etl_entities.hwm_store import BaseHWMStore as OriginalBaseHWMStore +from etl_entities.hwm_store import ( + HWMStoreClassRegistry as OriginalHWMStoreClassRegistry, +) +from etl_entities.hwm_store import HWMStoreStackManager +from etl_entities.hwm_store import HWMStoreStackManager as OriginalHWMStoreManager +from etl_entities.hwm_store import MemoryHWMStore as OriginalMemoryHWMStore +from etl_entities.hwm_store import detect_hwm_store as original_detect_hwm_store +from etl_entities.hwm_store import ( + register_hwm_store_class as original_register_hwm_store_class, +) -from onetl.hwm.store import HWMStoreManager, YAMLHWMStore +from onetl.hwm.store import YAMLHWMStore def test_hwm_store_yaml_path(request, tmp_path_factory, hwm_delta): @@ -24,7 +36,7 @@ def finalizer(): assert not list(path.glob("**/*")) - store.save(hwm) + store.set_hwm(hwm) empty = True for item in path.glob("**/*"): @@ -65,10 +77,10 @@ def finalizer(): store = YAMLHWMStore(path=path) with pytest.raises(OSError): - store.get(hwm.qualified_name) + store.get_hwm(hwm.name) with pytest.raises(OSError): - store.save(hwm) + store.set_hwm(hwm) def 
test_hwm_store_yaml_context_manager(caplog): @@ -79,13 +91,13 @@ def test_hwm_store_yaml_context_manager(caplog): with caplog.at_level(logging.INFO): with hwm_store as store: - assert HWMStoreManager.get_current() == store + assert HWMStoreStackManager.get_current() == store - assert "|onETL| Using YAMLHWMStore as HWM Store" in caplog.text - assert "path = '" in caplog.text + assert "Using YAMLHWMStore as HWM Store" in caplog.text + assert "path = " in caplog.text assert "encoding = 'utf-8'" in caplog.text - assert HWMStoreManager.get_current() == hwm_store + assert HWMStoreStackManager.get_current() == hwm_store def test_hwm_store_yaml_context_manager_with_path(caplog, request, tmp_path_factory): @@ -103,14 +115,14 @@ def finalizer(): with caplog.at_level(logging.INFO): with hwm_store as store: - assert HWMStoreManager.get_current() == store + assert HWMStoreStackManager.get_current() == store - assert "|onETL| Using YAMLHWMStore as HWM Store" in caplog.text - assert f"path = '{path}' (kind='directory'" in caplog.text + assert "Using YAMLHWMStore as HWM Store" in caplog.text + assert str(path) in caplog.text assert "encoding = 'utf-8'" in caplog.text - assert HWMStoreManager.get_current() != hwm_store - assert isinstance(HWMStoreManager.get_current(), YAMLHWMStore) + assert HWMStoreStackManager.get_current() != hwm_store + assert isinstance(HWMStoreStackManager.get_current(), YAMLHWMStore) def test_hwm_store_yaml_context_manager_with_encoding(caplog, request, tmp_path_factory): @@ -128,14 +140,14 @@ def finalizer(): with caplog.at_level(logging.INFO): with hwm_store as store: - assert HWMStoreManager.get_current() == store + assert HWMStoreStackManager.get_current() == store - assert "|onETL| Using YAMLHWMStore as HWM Store" in caplog.text - assert f"path = '{path}' (kind='directory'" in caplog.text + assert "Using YAMLHWMStore as HWM Store" in caplog.text + assert str(path) in caplog.text assert "encoding = 'cp-1251'" in caplog.text - assert HWMStoreManager.get_current() != hwm_store - assert isinstance(HWMStoreManager.get_current(), YAMLHWMStore) + assert HWMStoreStackManager.get_current() != hwm_store + assert isinstance(HWMStoreStackManager.get_current(), YAMLHWMStore) @pytest.mark.parametrize( @@ -175,3 +187,44 @@ def finalizer(): ) def test_hwm_store_yaml_cleanup_file_name(qualified_name, file_name): assert YAMLHWMStore.cleanup_file_name(qualified_name) == file_name + + +def test_hwm_store_no_deprecation_warning_yaml_hwm_store(): + with pytest.warns(None) as record: + from onetl.hwm.store import YAMLHWMStore + + YAMLHWMStore() + assert not record + + +@pytest.mark.parametrize( + "import_name, original_import", + [ + ("MemoryHWMStore", OriginalMemoryHWMStore), + ("BaseHWMStore", OriginalBaseHWMStore), + ("HWMStoreClassRegistry", OriginalHWMStoreClassRegistry), + ("HWMStoreManager", OriginalHWMStoreManager), + ("detect_hwm_store", original_detect_hwm_store), + ("register_hwm_store_class", original_register_hwm_store_class), + ], +) +def test_hwm_store_deprecation_warning_matching_cases(import_name, original_import): + msg = textwrap.dedent( + f""" + This import is deprecated since v0.10.0: + + from onetl.hwm.store import {import_name} + + Please use instead: + + from etl_entities.hwm_store import {import_name} + """, + ) + + with pytest.warns(UserWarning) as record: + from onetl.hwm.store import __getattr__ + + assert __getattr__(import_name) is original_import + + assert record + assert msg in str(record[0].message) diff --git 
a/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py b/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py index 5ca6b09bf..eed85ae7b 100644 --- a/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py +++ b/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py @@ -39,7 +39,7 @@ def test_strategy_batch_step_is_empty(step, strategy): (SnapshotBatchStrategy, {"step": 1}), ], ) -def test_strategy_hwm_column_not_set(check, strategy, kwargs, spark_mock): +def test_strategy_hwm_not_set(check, strategy, kwargs, spark_mock): check.return_value = None with strategy(**kwargs): diff --git a/tests/util/assert_df.py b/tests/util/assert_df.py index 7e12dd490..f7adad032 100644 --- a/tests/util/assert_df.py +++ b/tests/util/assert_df.py @@ -66,4 +66,5 @@ def assert_subset_df( columns = [column.lower() for column in columns] for column in columns: # noqa: WPS528 - assert small_pdf[column].isin(large_pdf[column]).all() # noqa: S101 + difference = ~small_pdf[column].isin(large_pdf[column]) + assert not difference.all(), large_pdf[difference] diff --git a/tests/util/rand.py b/tests/util/rand.py new file mode 100644 index 000000000..0b0a7b56e --- /dev/null +++ b/tests/util/rand.py @@ -0,0 +1,7 @@ +from random import randint +from string import ascii_lowercase + + +def rand_str(alphabet: str = ascii_lowercase, length: int = 10) -> str: + alphabet_length = len(alphabet) + return "".join(alphabet[randint(0, alphabet_length - 1)] for _ in range(length)) # noqa: S311