diff --git a/.codecov.yml b/.codecov.yml index b338711073..e45b6334f3 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -2,3 +2,8 @@ comment: behavior: once require_changes: true after_n_builds: 2 + +coverage: + status: + project: false + patch: true diff --git a/.dep-versions b/.dep-versions index bc4a192741..222f447e36 100644 --- a/.dep-versions +++ b/.dep-versions @@ -8,8 +8,8 @@ enzyme=v0.0.149 # For a custom PL version, update the package version here and at # 'doc/requirements.txt -pennylane=0.40.0-dev16 +pennylane=0.40.0-dev20 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' -lightning=0.40.0-dev11 +lightning=0.40.0-dev41 diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile deleted file mode 100644 index 5df2438d9b..0000000000 --- a/.devcontainer/Dockerfile +++ /dev/null @@ -1,30 +0,0 @@ -FROM --platform=linux/amd64 python:3-slim - -# Add non-root user to the image. - -ARG USERNAME="catalyst" -ARG USER_UID=1000 -ARG USER_GID=$USER_UID - -RUN groupadd --gid $USER_GID $USERNAME -RUN useradd --uid $USER_UID --gid $USER_GID -m $USERNAME -RUN apt-get update && apt-get install -y --no-install-recommends sudo \ - && apt-get clean && rm -rf /var/lib/apt/lists/* -RUN echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME -RUN chmod 0440 /etc/sudoers.d/$USERNAME - -## Image specific instructions: Install Catalyst straight from PyPI. ## - -# Install git. -RUN apt-get update && apt-get install -y --no-install-recommends git \ - && apt-get clean && rm -rf /var/lib/apt/lists/* - -USER $USERNAME - -# Install Python kernel for use with Jupyter Notebooks. -RUN pip install --no-cache-dir ipykernel - -RUN pip install --no-cache-dir pennylane-catalyst - -ENV SHELL /bin/bash -ENV PATH="/home/${USERNAME}/.local/bin:${PATH}" diff --git a/.devcontainer/dev/Dockerfile b/.devcontainer/dev/Dockerfile deleted file mode 100644 index ffdddc2ebf..0000000000 --- a/.devcontainer/dev/Dockerfile +++ /dev/null @@ -1,37 +0,0 @@ -FROM --platform=linux/amd64 python:3-slim - -# Add non-root user to the image. - -ARG USERNAME="catalyst" -ARG USER_UID=1000 -ARG USER_GID=$USER_UID - -RUN groupadd --gid $USER_GID $USERNAME -RUN useradd --uid $USER_UID --gid $USER_GID -m $USERNAME -RUN apt-get update && apt-get install -y --no-install-recommends sudo \ - && apt-get clean && rm -rf /var/lib/apt/lists/* -RUN echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME -RUN chmod 0440 /etc/sudoers.d/$USERNAME - -## Image specific instructions: Ensure developer requirements are met. ## -## Catalyst will be installed from source from within the running container. ## - -# Install the required C++ build toolchain and git. -RUN apt-get update && apt-get install -y --no-install-recommends git curl \ - build-essential ninja-build clang lld ccache libomp-dev \ - && apt-get clean && rm -rf /var/lib/apt/lists/* - -USER $USERNAME - -# Install a recent version of CMake not yet available via apt. -RUN pip install --no-cache-dir "cmake>=3.20" -# Install Python kernel for use with Jupyter Notebooks. -RUN pip install --no-cache-dir ipykernel - -# Install the Rust toolchain for use with LLVM. 
-ENV PATH="/home/${USERNAME}/.cargo/bin:${PATH}" -RUN curl --proto "=https" --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -RUN rustup component add llvm-tools-preview - -ENV SHELL /bin/bash -ENV PATH="/home/${USERNAME}/.local/bin:${PATH}" diff --git a/.devcontainer/dev/devcontainer.json b/.devcontainer/dev/devcontainer.json deleted file mode 100644 index 47522d8b47..0000000000 --- a/.devcontainer/dev/devcontainer.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "name": "CatalystDev", - "build": { - "dockerfile": "Dockerfile" - }, - "postCreateCommand": "/bin/bash ./.devcontainer/dev/post-install.sh", - "customizations": { - "vscode": { - "extensions": [ - "ms-python.python", - "ms-python.vscode-pylance", - "ms-python.pylint", - "ms-python.isort", - "ms-toolsai.jupyter", - "ms-vscode.cpptools", - "twxs.cmake", - "llvm-vs-code-extensions.vscode-mlir", - "revng.llvm-ir", - "colejcummins.llvm-syntax-highlighting" - ], - "settings": { - "editor.formatOnSave": true, - "files.trimTrailingWhitespace": true, - "files.insertFinalNewline": true, - "files.trimFinalNewlines": true, - "python.formatting.provider": "black", - "python.linting.pylintEnabled": true, - "python.linting.enabled": true, - "C_Cpp.default.cppStandard": "c++20", - "C_Cpp.default.includePath": [ - "${containerWorkspaceFolder}/mlir/include", - "${containerWorkspaceFolder}/mlir/lib/**", - "${containerWorkspaceFolder}/mlir/build/include", - "${containerWorkspaceFolder}/mlir/llvm-project/mlir/include", - "${containerWorkspaceFolder}/mlir/llvm-project/build/tools/mlir/include", - "${containerWorkspaceFolder}/mlir/llvm-project/llvm/include", - "${containerWorkspaceFolder}/mlir/llvm-project/build/include" - ], - "mlir.server_path": "${containerWorkspaceFolder}/mlir/build/bin/quantum-lsp-server" - } - } - } -} diff --git a/.devcontainer/dev/post-install.sh b/.devcontainer/dev/post-install.sh deleted file mode 100644 index d425ed0b62..0000000000 --- a/.devcontainer/dev/post-install.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -git submodule update --init --depth=1 -pip install -r requirements.txt -make all diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index fc57249d14..0000000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "name": "CatalystUser", - "build": { - "dockerfile": "Dockerfile" - }, - "customizations": { - "vscode": { - "extensions": [ - "ms-python.python", - "ms-python.vscode-pylance", - "ms-toolsai.jupyter" - ] - } - } -} diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 304b23ee5d..fdf81109d7 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -204,7 +204,7 @@ jobs: needs: [constants, build-dependencies] strategy: fail-fast: false - max-parallel: 2 + max-parallel: 1 matrix: python_version: [{major_minor: "3.10", patch: "14", package: "python3.10", alternative: "310"}, {major_minor: "3.11", patch: "9", package: "python3.11", alternative: "311"}, @@ -299,7 +299,7 @@ jobs: needs: [constants, catalyst-linux-wheels-arm64] strategy: fail-fast: false - max-parallel: 2 + max-parallel: 1 matrix: python_version: [{major_minor: "3.10", patch: "14", package: "python3.10"}, {major_minor: "3.11", patch: "9", package: "python3.11"}, diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index b731f62f40..c703aa5970 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml 
+++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -10,6 +10,9 @@ on: - ready_for_review push: branches: [ main ] + schedule: + # Thursdays we test the standalone plugin + - cron: '35 4 * * 4' workflow_dispatch: workflow_call: @@ -184,7 +187,7 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ - -DCMAKE_CXX_VISIBILITY_PRESET=protected \ + -DCMAKE_CXX_VISIBILITY_PRESET=default \ -DLLVM_ENABLE_LLD=ON # TODO: when updating LLVM, test to see if mlir/unittests/Bytecode/BytecodeTest.cpp:55 is passing @@ -217,7 +220,7 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=protected \ + -DCMAKE_CXX_VISIBILITY_PRESET=default \ -DLLVM_ENABLE_LLD=ON LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build mhlo-build --target check-mlir-hlo @@ -238,7 +241,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR=$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm \ -DENZYME_STATIC_LIB=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=protected \ + -DCMAKE_CXX_VISIBILITY_PRESET=default \ -DCMAKE_CXX_FLAGS="-fuse-ld=lld" cmake --build enzyme-build --target EnzymeStatic-19 @@ -339,10 +342,7 @@ jobs: cmake -S runtime -B runtime-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=$GITHUB_WORKSPACE/runtime-build/lib \ - -DPYTHON_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DPython_ROOT_DIR=$(python${{ matrix.python_version }} -c "import sys; print(sys.prefix)") \ - -DPYTHON_VERSION_TO_FIND=${{ matrix.python_version }} \ - -Dpybind11_DIR=$(python${{ matrix.python_version }} -c "import pybind11; print(pybind11.get_cmake_dir())") \ + -DPython_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DENABLE_OPENQASM=ON cmake --build runtime-build --target rt_capi rtd_openqasm rtd_null_qubit @@ -357,6 +357,16 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc + # Build OQD-Runtime + - name: Build OQD-Runtime + run: | + C_COMPILER=$(which gcc) \ + CXX_COMPILER=$(which g++) \ + OQD_BUILD_DIR=$GITHUB_WORKSPACE/oqd-build \ + RT_BUILD_DIR=$GITHUB_WORKSPACE/runtime-build \ + PYTHON=$(which python${{ matrix.python_version }}) \ + make oqd + # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -378,6 +388,15 @@ jobs: cmake --build quantum-build --target check-dialects catalyst-cli + - name: Build Plugin wheel + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + CCACHE_DIR="$(pwd)/.ccache" \ + MLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ + make plugin-wheel + - name: Build wheel run: | PYTHON=python${{ matrix.python_version }} \ @@ -386,6 +405,7 @@ jobs: DIALECTS_BUILD_DIR=$GITHUB_WORKSPACE/quantum-build \ RT_BUILD_DIR=$GITHUB_WORKSPACE/runtime-build \ OQC_BUILD_DIR=$GITHUB_WORKSPACE/oqc-build \ + OQD_BUILD_DIR=$GITHUB_WORKSPACE/oqd-build \ ENZYME_BUILD_DIR=$GITHUB_WORKSPACE/enzyme-build \ make wheel @@ -401,6 +421,15 @@ jobs: path: wheel/ retention-days: 14 + - name: Upload Standalone Plugin Wheel Artifact + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + uses: actions/upload-artifact@v4 + with: + name: standalone-plugin-manylinux_2_28_x86_64-wheel-py-${{ matrix.python_version }}.zip + path: standalone_plugin_wheel/dist + retention-days: 
14 +  test-wheels: needs: [constants, catalyst-linux-wheels-x86-64, determine_runner] strategy: @@ -422,6 +451,14 @@ jobs: name: catalyst-manylinux_2_28_x86_64-wheel-py-${{ matrix.python_version }}.zip path: dist + - name: Download Standalone Plugin Wheel Artifact + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + uses: actions/download-artifact@v4 + with: + name: standalone-plugin-manylinux_2_28_x86_64-wheel-py-${{ matrix.python_version }}.zip + path: standalone_plugin_wheel/wheel + - name: Set up Python ${{ matrix.python_version }} uses: actions/setup-python@v5 with: @@ -444,6 +481,12 @@ jobs: run: | python${{ matrix.python_version }} -m pip install dist/*.whl --extra-index-url https://test.pypi.org/simple + - name: Install Standalone Plugin + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + python${{ matrix.python_version }} -m pip install standalone_plugin_wheel/wheel/*.whl --no-deps + - name: Run Python Pytest Tests run: | python${{ matrix.python_version }} -m pytest frontend/test/pytest -n auto @@ -451,3 +494,9 @@ jobs: python${{ matrix.python_version }} -m pytest frontend/test/async_tests # python${{ matrix.python_version }} -m pytest frontend/test/pytest --runbraket=LOCAL -n auto python${{ matrix.python_version }} -m pytest frontend/test/test_oqc/oqc -n auto + + - name: Run Standalone Plugin Tests + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + python${{ matrix.python_version }} -m pytest standalone_plugin_wheel/standalone_plugin/test -n auto diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index f8c127d53e..153b3b816c 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -10,6 +10,9 @@ on: - ready_for_review push: branches: [ main ] + schedule: + # Thursdays we test the standalone plugin + - cron: '35 4 * * 4' workflow_dispatch: workflow_call: @@ -48,6 +51,12 @@ jobs: - name: Checkout Catalyst repo uses: actions/checkout@v4 + # Python 3.10 was dropped from the GH images on macOS arm + - name: Install Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + # Cache external project sources # Hopefully these can be shared with the main check-catalyst action since we don't run this # build in a container. 
@@ -150,7 +159,7 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ - -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_CXX_VISIBILITY_PRESET=default # TODO: when updating LLVM, test to see if mlir/unittests/Bytecode/BytecodeTest.cpp:55 is passing # and remove filter @@ -182,7 +191,7 @@ jobs: -DLLVM_ENABLE_LLD=OFF \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_CXX_VISIBILITY_PRESET=default cmake --build mhlo-build --target check-mlir-hlo @@ -201,7 +210,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR=$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm \ -DENZYME_STATIC_LIB=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_CXX_VISIBILITY_PRESET=default cmake --build enzyme-build --target EnzymeStatic-19 @@ -227,6 +236,13 @@ jobs: - name: Checkout Catalyst repo uses: actions/checkout@v4 + # Python 3.10 was dropped from the GH images on macOS arm + - name: Install Python 3.10 + uses: actions/setup-python@v5 + if: ${{ matrix.python_version }} == '3.10' + with: + python-version: '3.10' + - name: Install Dependencies (System) run: | brew install libomp @@ -296,13 +312,12 @@ jobs: # Build Catalyst-Runtime - name: Build Catalyst-Runtime run: | + # On GH images, gfortran is only available as a specific version. + export FC=gfortran-14 cmake -S runtime -B runtime-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=$GITHUB_WORKSPACE/runtime-build/lib \ - -DPYTHON_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DPython_ROOT_DIR=$(python${{ matrix.python_version }} -c "import sys; print(sys.prefix)") \ - -DPYTHON_VERSION_TO_FIND=${{ matrix.python_version }} \ - -Dpybind11_DIR=$(python${{ matrix.python_version }} -c "import pybind11; print(pybind11.get_cmake_dir())") \ + -DPython_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DENABLE_OPENQASM=ON cmake --build runtime-build --target rt_capi rtd_openqasm rtd_null_qubit @@ -321,6 +336,14 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc + # Build OQD-Runtime + - name: Build OQD-Runtime + run: | + OQD_BUILD_DIR=$GITHUB_WORKSPACE/oqd-build \ + RT_BUILD_DIR=$GITHUB_WORKSPACE/runtime-build \ + PYTHON=$(which python${{ matrix.python_version }}) \ + make oqd + # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -342,6 +365,14 @@ jobs: cmake --build quantum-build --target check-dialects catalyst-cli + - name: Build Plugin wheel + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + MLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ + make plugin-wheel + - name: Build wheel run: | PYTHON=python${{ matrix.python_version }} \ @@ -350,6 +381,7 @@ jobs: DIALECTS_BUILD_DIR=$GITHUB_WORKSPACE/quantum-build \ RT_BUILD_DIR=$GITHUB_WORKSPACE/runtime-build \ OQC_BUILD_DIR=$GITHUB_WORKSPACE/oqc-build \ + OQD_BUILD_DIR=$GITHUB_WORKSPACE/oqd-build \ ENZYME_BUILD_DIR=$GITHUB_WORKSPACE/enzyme-build \ make wheel @@ -365,6 +397,15 @@ jobs: path: wheel/ retention-days: 14 + - name: Upload Standalone Plugin Wheel Artifact + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + uses: actions/upload-artifact@v4 + with: + name: standalone-plugin-macos_arm64-wheel-py-${{ matrix.python_version }}.zip + path: standalone_plugin_wheel/dist + 
retention-days: 14 + test-wheels: needs: [constants, catalyst-macos-wheels-arm64] strategy: @@ -380,12 +421,27 @@ jobs: - name: Checkout Catalyst repo uses: actions/checkout@v4 + # Python 3.10 was dropped from the GH images on macOS arm + - name: Install Python 3.10 + uses: actions/setup-python@v5 + if: ${{ matrix.python_version }} == '3.10' + with: + python-version: '3.10' + - name: Download Wheel Artifact uses: actions/download-artifact@v4 with: name: catalyst-macos_arm64-wheel-py-${{ matrix.python_version }}.zip path: dist + - name: Download Standalone Plugin Wheel Artifact + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + uses: actions/download-artifact@v4 + with: + name: standalone-plugin-macos_arm64-wheel-py-${{ matrix.python_version }}.zip + path: standalone_plugin_wheel/wheel + - name: Setup Python version # There are multiple Python versions installed on the GitHub image, 3.10 - 3.12 is already # available under /Library/Frameworks/Python.framework/Versions/, but homebrew also provides @@ -410,6 +466,12 @@ jobs: run: | python${{ matrix.python_version }} -m pip install dist/*.whl --extra-index-url https://test.pypi.org/simple + - name: Install Standalone Plugin + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + python${{ matrix.python_version }} -m pip install standalone_plugin_wheel/wheel/*.whl --no-deps + - name: Run Python Pytest Tests run: | python${{ matrix.python_version }} -m pytest frontend/test/pytest -n auto @@ -417,3 +479,9 @@ jobs: python${{ matrix.python_version }} -m pytest frontend/test/async_tests # python${{ matrix.python_version }} -m pytest frontend/test/pytest --runbraket=LOCAL -n auto python${{ matrix.python_version }} -m pytest frontend/test/test_oqc/oqc -n auto + + - name: Run Standalone Plugin Tests + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + python${{ matrix.python_version }} -m pytest standalone_plugin_wheel/standalone_plugin/test -n auto diff --git a/.github/workflows/build-wheel-macos-x86_64.yaml b/.github/workflows/build-wheel-macos-x86_64.yaml index eef0b3c1a0..b51bf58b6e 100644 --- a/.github/workflows/build-wheel-macos-x86_64.yaml +++ b/.github/workflows/build-wheel-macos-x86_64.yaml @@ -10,6 +10,9 @@ on: - ready_for_review push: branches: [ main ] + schedule: + # Thursdays we test the standalone plugin + - cron: '35 4 * * 4' workflow_dispatch: workflow_call: @@ -146,7 +149,7 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ - -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_CXX_VISIBILITY_PRESET=default # TODO: when updating LLVM, test to see if mlir/unittests/Bytecode/BytecodeTest.cpp:55 is passing # and remove filter @@ -178,7 +181,7 @@ jobs: -DLLVM_ENABLE_LLD=OFF \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_CXX_VISIBILITY_PRESET=default cmake --build mhlo-build --target check-mlir-hlo @@ -197,7 +200,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR=$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm \ -DENZYME_STATIC_LIB=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_CXX_VISIBILITY_PRESET=default cmake --build enzyme-build --target EnzymeStatic-19 @@ -290,10 +293,7 @@ jobs: cmake -S runtime -B runtime-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ 
-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=$GITHUB_WORKSPACE/runtime-build/lib \ - -DPYTHON_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DPython_ROOT_DIR=$(python${{ matrix.python_version }} -c "import sys; print(sys.prefix)") \ - -DPYTHON_VERSION_TO_FIND=${{ matrix.python_version }} \ - -Dpybind11_DIR=$(python${{ matrix.python_version }} -c "import pybind11; print(pybind11.get_cmake_dir())") \ + -DPython_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DENABLE_OPENQASM=ON cmake --build runtime-build --target rt_capi rtd_openqasm rtd_null_qubit @@ -306,6 +306,14 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc + # Build OQD-Runtime + - name: Build OQD-Runtime + run: | + OQD_BUILD_DIR="$(pwd)/oqd-build" \ + RT_BUILD_DIR="$(pwd)/runtime-build" \ + PYTHON=$(which python${{ matrix.python_version }}) \ + make oqd + - name: Test Catalyst-Runtime run: | python${{ matrix.python_version }} -m pip install 'amazon-braket-pennylane-plugin>1.27.1' @@ -332,6 +340,15 @@ jobs: cmake --build quantum-build --target check-dialects catalyst-cli + - name: Build Plugin wheel + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + CCACHE_DIR="$(pwd)/.ccache" \ + LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ + MLIR_DIR="$(pwd)/llvm-build/lib/cmake/mlir" \ + make plugin-wheel + - name: Build wheel run: | PYTHON=python${{ matrix.python_version }} \ @@ -340,6 +357,7 @@ jobs: DIALECTS_BUILD_DIR=$GITHUB_WORKSPACE/quantum-build \ RT_BUILD_DIR=$GITHUB_WORKSPACE/runtime-build \ OQC_BUILD_DIR=$GITHUB_WORKSPACE/oqc-build \ + OQD_BUILD_DIR=$GITHUB_WORKSPACE/oqd-build \ ENZYME_BUILD_DIR=$GITHUB_WORKSPACE/enzyme-build \ make wheel @@ -355,6 +373,15 @@ jobs: path: wheel/ retention-days: 14 + - name: Upload Standalone Plugin Wheel Artifact + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + uses: actions/upload-artifact@v4 + with: + name: standalone-plugin-macos_x86_64-wheel-py-${{ matrix.python_version }}.zip + path: standalone_plugin_wheel/dist + retention-days: 14 + test-wheels: needs: [constants, catalyst-macos-wheels-x86-64] strategy: @@ -376,6 +403,14 @@ jobs: name: catalyst-macos_x86_64-wheel-py-${{ matrix.python_version }}.zip path: dist + - name: Download Standalone Plugin Wheel Artifact + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + uses: actions/download-artifact@v4 + with: + name: standalone-plugin-macos_x86_64-wheel-py-${{ matrix.python_version }}.zip + path: standalone_plugin_wheel/wheel + - name: Set up Python ${{ matrix.python_version }} uses: actions/setup-python@v5 with: @@ -398,6 +433,12 @@ jobs: run: | python${{ matrix.python_version }} -m pip install dist/*.whl --extra-index-url https://test.pypi.org/simple + - name: Install Standalone Plugin + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + python${{ matrix.python_version }} -m pip install standalone_plugin_wheel/wheel/*.whl --no-deps + - name: Run Python Pytest Tests run: | python${{ matrix.python_version }} -m pytest frontend/test/pytest -n auto @@ -406,3 +447,9 @@ jobs: python${{ matrix.python_version }} -m pytest frontend/test/async_tests # python${{ matrix.python_version }} -m pytest frontend/test/pytest --runbraket=LOCAL -n auto python${{ matrix.python_version }} -m pytest frontend/test/test_oqc/oqc -n auto + + - name: Run Standalone Plugin Tests + # Run only on Thursday at the given time + if: github.event.schedule == '35 4 * * 4' + run: | + 
python${{ matrix.python_version }} -m pytest standalone_plugin_wheel/standalone_plugin/test -n auto diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 1e979b2260..396eea19e6 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -47,7 +47,7 @@ jobs: run: | sudo apt-get update sudo apt-get -y -q install ninja-build make cmake clang libomp-dev - python3 -m pip install nanobind + python3 -m pip install nanobind pybind11 - name: Build Catalyst-Runtime run: | @@ -65,6 +65,8 @@ jobs: OQC_BUILD_DIR="$(pwd)/oqc-build" \ make oqc + OQD_BUILD_DIR="$(pwd)/oqd-build" \ + make oqd - name: Upload Catalyst-Runtime Artifact uses: actions/upload-artifact@v4 @@ -81,7 +83,16 @@ jobs: name: oqc-build-${{ matrix.compiler }} path: | oqc-build/*.so - oqc-build/*.toml + oqc-build/backend/*.toml + retention-days: 1 + + - name: Upload OQD-Runtime Artifact + uses: actions/upload-artifact@v4 + with: + name: oqd-build-${{ matrix.compiler }} + path: | + oqd-build/*.so + oqd-build/backend/*.toml retention-days: 1 llvm: @@ -373,6 +384,17 @@ jobs: DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ make dialects + - name: Build Standalone Plugin + run: | + CCACHE_DIR="$(pwd)/.ccache" \ + C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ + CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + MLIR_DIR="$(pwd)/llvm-build/lib/cmake/mlir" \ + make plugin + mkdir -p $(pwd)/quantum-build/lib + mv mlir/standalone/build/lib/StandalonePlugin.so "$(pwd)/quantum-build/lib" + - name: Upload Quantum Build Artifact uses: actions/upload-artifact@v4 with: @@ -380,6 +402,7 @@ jobs: path: | quantum-build/bin quantum-build/python_packages/* + quantum-build/lib/StandalonePlugin.so retention-days: 1 - name: Cache CCache on main branch @@ -435,12 +458,26 @@ jobs: name: runtime-build-${{ matrix.compiler }} path: runtime-build/lib + - name: Download OQC-Runtime Artifact + uses: actions/download-artifact@v4 + with: + name: oqc-build-${{ matrix.compiler }} + path: oqc-build + + - name: Download OQD-Runtime Artifact + uses: actions/download-artifact@v4 + with: + name: oqd-build-${{ matrix.compiler }} + path: oqd-build + - name: Add Frontend Dependencies to PATH run: | echo "$(pwd)/llvm-build/bin" >> $GITHUB_PATH echo "PYTHONPATH=$PYTHONPATH:$(pwd)/quantum-build/python_packages/quantum" >> $GITHUB_ENV echo "RUNTIME_LIB_DIR=$(pwd)/runtime-build/lib" >> $GITHUB_ENV echo "MLIR_LIB_DIR=$(pwd)/llvm-build/lib" >> $GITHUB_ENV + echo "OQC_LIB_DIR=$(pwd)/oqc-build" >> $GITHUB_ENV + echo "OQD_LIB_DIR=$(pwd)/oqd-build" >> $GITHUB_ENV echo "CATALYST_BIN_DIR=$(pwd)/quantum-build/bin" >> $GITHUB_ENV chmod +x $(pwd)/quantum-build/bin/catalyst-cli # artifact upload does not preserve permissions @@ -509,17 +546,24 @@ jobs: name: runtime-build-${{ matrix.compiler }} path: runtime-build/lib + - name: Download OQD-Runtime Artifact + uses: actions/download-artifact@v4 + with: + name: oqd-build-${{ matrix.compiler }} + path: oqd-build + - name: Add Frontend Dependencies to PATH run: | echo "PYTHONPATH=$PYTHONPATH:$(pwd)/quantum-build/python_packages/quantum" >> $GITHUB_ENV echo "RUNTIME_LIB_DIR=$(pwd)/runtime-build/lib" >> $GITHUB_ENV echo "MLIR_LIB_DIR=$(pwd)/llvm-build/lib" >> $GITHUB_ENV + echo "OQD_LIB_DIR=$(pwd)/oqd-build" >> $GITHUB_ENV echo "CATALYST_BIN_DIR=$(pwd)/quantum-build/bin" >> $GITHUB_ENV chmod +x $(pwd)/quantum-build/bin/catalyst-cli # 
artifact upload does not preserve permissions - name: Run Python Pytest Tests (backend=lightning.kokkos) run: | - make pytest TEST_BACKEND="lightning.kokkos" + make pytest TEST_BACKEND="lightning.kokkos" SKIP_OQD="true" # frontend-tests-openqasm-device: # name: Frontend Tests (backend="openqasm3") @@ -589,23 +633,14 @@ jobs: compiler: ${{ fromJson(needs.constants.outputs.compilers) }} steps: - # - name: Collect Workflow Telemetry - # uses: catchpoint/workflow-telemetry-action@v2 - - name: Checkout the repo uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update - sudo apt-get -y -q install cmake ninja-build libomp-dev lcov libasan6 - python3 -m pip install nanobind - - - name: Install additional dependencies (OpenQasm device) - if: ${{ matrix.backend == 'openqasm' }} - run: | - pip install numpy amazon-braket-sdk - echo "AWS_DEFAULT_REGION=us-east-1" >> $GITHUB_ENV + sudo apt-get -y -q install cmake ninja-build + python3 -m pip install nanobind pybind11 - name: Download Catalyst-Runtime Artifact uses: actions/download-artifact@v4 @@ -614,7 +649,6 @@ jobs: path: runtime-build/lib - name: Build Runtime test suite for OQC device - if: ${{ matrix.backend == 'oqc' }} run: | C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ @@ -622,36 +656,13 @@ jobs: RT_BUILD_DIR="$(pwd)/runtime-build" \ make test-oqc - - name: Build Runtime test suite for Lightning simulator - if: ${{ matrix.backend == 'lightning' }} + - name: Build Runtime test suite for OQD device + if: ${{ matrix.backend == 'oqd' }} run: | C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ - COMPILER_LAUNCHER="" \ - ENABLE_OPENQASM=OFF \ - ENABLE_ASAN=ON \ - make test-runtime - - - name: Build Runtime test suite with both Lightning and Lightning-Kokkos simulators - if: ${{ matrix.backend == 'lightning-kokkos' }} - run: | - # ASAN fails w/ leaks, odr-violation, and double-free from Kokkos_Profiling.cpp - C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ - CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ - COMPILER_LAUNCHER="" \ - ENABLE_OPENQASM=OFF \ - ENABLE_ASAN=OFF \ - make test-runtime - - - name: Build Runtime test suite for OpenQasm device - if: ${{ matrix.backend == 'openqasm' }} - run: | - # Asan prevents dlopen from working? 
- C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ - CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ - COMPILER_LAUNCHER="" \ - ENABLE_ASAN=OFF \ - make test-runtime + RT_BUILD_DIR="$(pwd)/runtime-build" \ + make test-oqd runtime-code-cov: name: Runtime Code Coverage (Linux) @@ -665,16 +676,21 @@ jobs: - name: Install dependencies run: | sudo apt-get update - sudo apt-get -y -q install cmake ninja-build libomp-dev lcov - python3 -m pip install nanobind + sudo apt-get -y -q install cmake ninja-build lcov + python3 -m pip install nanobind pybind11 + + - name: Install additional dependencies for OpenQasm device + run: | + pip install numpy pybind11 amazon-braket-sdk + echo "AWS_DEFAULT_REGION=us-east-1" >> $GITHUB_ENV - - name: Build Runtime test suite for Lightning simulator + - name: Run the main runtime test suite for coverage run: | C_COMPILER=$(which gcc) \ CXX_COMPILER=$(which g++) \ COMPILER_LAUNCHER="" \ make coverage-runtime - mv runtime/build/coverage.info coverage-${{ github.job }}.info + mv runtime/build_cov/coverage.info coverage-${{ github.job }}.info - name: Upload to Codecov uses: codecov/codecov-action@v4 diff --git a/.github/workflows/check-for-wheel-build.yml b/.github/workflows/check-for-wheel-build.yml index 1a9b802deb..579e332b08 100644 --- a/.github/workflows/check-for-wheel-build.yml +++ b/.github/workflows/check-for-wheel-build.yml @@ -19,9 +19,9 @@ jobs: - name: Pull Request Build needs Wheel Builds to pass id: needs_wheel_builds if: steps.is_pr.outputs.is_pr == 'true' - run: echo "needs_wheel_builds=${{ contains(github.event.pull_request.labels.*.name, 'requires-wheel-builds') }}" >>$GITHUB_OUTPUT + run: echo "needs_wheel_builds=${{ contains(github.event.pull_request.labels.*.name, 'reviewer:require-wheels') }}" >> $GITHUB_OUTPUT - # If the trigger for this workflow (on pull_request) is not a labelling event, then only build the wheels if the + # If the trigger for this workflow (on pull_request) is a labelling event, then only build the wheels if the # label being added is `author:build-wheels`. If the pull_request event is not a labelling event (eg: new commit is pushed) # then build wheels as long as the `author:build-wheels` label is present - name: Build Wheels for Pull Request @@ -33,7 +33,7 @@ jobs: github.event.label.name == 'author:build-wheels' }}" >> $GITHUB_OUTPUT - # If a pr has the `requires-wheel-builds` label, that means that the Workflows which build wheels need to successfully run against it + # If a pr has the `reviewer:require-wheels` label, that means that the Workflows which build wheels need to successfully run against it # However, the PR does not have the `author:build-wheels` label, meaning it is not ready for the wheel workflows to run against it yet. # In that condition, this step will fail, causing the entire workflow to fail. # And since this job is required to pass on all jobs, it will cause the merging of the pull request to be blocked. 
diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml index 1c5a994d06..50476f0620 100644 --- a/.github/workflows/check-pl-compat.yaml +++ b/.github/workflows/check-pl-compat.yaml @@ -119,7 +119,7 @@ jobs: path: lightning_build fetch-depth: 0 - - name: Download PennyLane-Lightning (latest) + - name: Download PennyLane-Lightning (release-candidate) if: ${{ inputs.lightning == 'release-candidate' }} uses: actions/checkout@v4 with: diff --git a/.github/workflows/constants.yaml b/.github/workflows/constants.yaml index 6c285ae668..e42f89d859 100644 --- a/.github/workflows/constants.yaml +++ b/.github/workflows/constants.yaml @@ -83,7 +83,7 @@ jobs: - name: Runtime Backend Devices id: rt_backends - run: echo 'rt_backends=["lightning", "lightning-kokkos", "openqasm", "oqc"]' >> $GITHUB_OUTPUT + run: echo 'rt_backends=["lightning", "lightning-kokkos", "openqasm", "oqc", "oqd"]' >> $GITHUB_OUTPUT - name: Compilers (All) id: compilers diff --git a/.github/workflows/notify-failed-jobs.yaml b/.github/workflows/notify-failed-jobs.yaml index fb4569801e..694236ec0b 100644 --- a/.github/workflows/notify-failed-jobs.yaml +++ b/.github/workflows/notify-failed-jobs.yaml @@ -10,6 +10,7 @@ on: - Build Catalyst Wheel on Linux (x86_64) - Build Catalyst Wheel on macOS (arm64) - Build Catalyst Wheel on macOS (x86_64) + - Build nightly Catalyst releases for TestPyPI jobs: on-failure: diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh b/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh index 714cd62d67..488cf37b4b 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh +++ b/.github/workflows/scripts/linux_arm64/rh8/build_catalyst.sh @@ -41,12 +41,7 @@ export PATH=/catalyst/llvm-build/bin:/opt/_internal/cpython-${PYTHON_VERSION}.${ cmake -S runtime -B runtime-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=runtime-build/lib \ - -DPYTHON_EXECUTABLE=/usr/bin/python3 \ - -DPython_ROOT_DIR=$(/usr/bin/python3 -c "import sys; print(sys.prefix)") \ - -DPYTHON_VERSION_TO_FIND=${PYTHON_VERSION} \ - -DPYTHON_INCLUDE_DIR=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/include/python${PYTHON_VERSION} \ - -DPYTHON_LIBRARY=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/lib \ - -Dpybind11_DIR=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/lib/python${PYTHON_VERSION}/site-packages/pybind11/share/cmake/pybind11 \ + -DPython_EXECUTABLE=${PYTHON} \ -DENABLE_OPENQASM=ON cmake --build runtime-build --target rt_capi rtd_openqasm rtd_null_qubit @@ -55,6 +50,10 @@ export OQC_BUILD_DIR="/catalyst/oqc-build" export RT_BUILD_DIR="/catalyst/runtime-build" make oqc +# Build OQD +export OQD_BUILD_DIR="/catalyst/oqd-build" +make oqd + # Build Catalyst dialects cmake -S mlir -B quantum-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ @@ -85,7 +84,9 @@ export MHLO_BUILD_DIR=/catalyst/mhlo-build export DIALECTS_BUILD_DIR=/catalyst/quantum-build export RT_BUILD_DIR=/catalyst/runtime-build export OQC_BUILD_DIR=/catalyst/oqc-build +export OQD_BUILD_DIR=/catalyst/oqd-build export ENZYME_BUILD_DIR=/catalyst/enzyme-build +export PYTHON=/usr/bin/python3 make wheel # Exclude libopenblas as we rely on the openblas/lapack library shipped by scipy diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_enzyme.sh b/.github/workflows/scripts/linux_arm64/rh8/build_enzyme.sh index db18d11141..6f5aadbe9d 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_enzyme.sh +++ 
b/.github/workflows/scripts/linux_arm64/rh8/build_enzyme.sh @@ -38,7 +38,7 @@ cmake -S /catalyst/mlir/Enzyme/enzyme -B /catalyst/enzyme-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_DIR=/catalyst/llvm-build/lib/cmake/llvm \ -DENZYME_STATIC_LIB=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=protected \ + -DCMAKE_CXX_VISIBILITY_PRESET=default \ -DCMAKE_CXX_FLAGS="-fuse-ld=lld" cmake --build /catalyst/enzyme-build --target EnzymeStatic-19 diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh b/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh index 9f17391f8a..67082dc4c4 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh +++ b/.github/workflows/scripts/linux_arm64/rh8/build_llvm.sh @@ -47,6 +47,6 @@ cmake -S /catalyst/mlir/llvm-project/llvm -B /catalyst/llvm-build -G Ninja \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=/usr/bin/python3 \ -DPython3_NumPy_INCLUDE_DIRS=/opt/_internal/cpython-${PYTHON_VERSION}.${PYTHON_SUBVERSION}/lib/python${PYTHON_VERSION}/site-packages/numpy/core/include \ - -DCMAKE_CXX_VISIBILITY_PRESET=protected + -DCMAKE_CXX_VISIBILITY_PRESET=default LIT_FILTER_OUT="Bytecode|tosa-to-tensor" cmake --build /catalyst/llvm-build --target check-mlir llvm-symbolizer diff --git a/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh b/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh index b6b4561f32..4b7dd07152 100644 --- a/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh +++ b/.github/workflows/scripts/linux_arm64/rh8/build_mhlo.sh @@ -49,6 +49,6 @@ cmake -S /catalyst/mlir/mlir-hlo -B /catalyst/mhlo-build -G Ninja \ -DLLVM_ENABLE_LLD=ON \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=protected + -DCMAKE_CXX_VISIBILITY_PRESET=default LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build /catalyst/mhlo-build --target check-mlir-hlo diff --git a/.gitignore b/.gitignore index c52bbecb05..c5b0d31f27 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ build __pycache__ pennylane_catalyst.egg-info *.so +*.dylib .dylibs # Testing files and directories @@ -30,3 +31,6 @@ venv # Cache files .cache + +# Standalone-plugin example +mlir/standalone diff --git a/MANIFEST.in b/MANIFEST.in index 6452376cfd..49183a875b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,3 +3,4 @@ recursive-include frontend/catalyst/lib * recursive-include frontend/catalyst/enzyme * recursive-include frontend/mlir_quantum * recursive-include frontend/catalyst/third_party/cuda/ *.toml +recursive-include frontend/catalyst/third_party/oqd/ *.toml diff --git a/Makefile b/Makefile index 67fa13bc58..1fc7b8908d 100644 --- a/Makefile +++ b/Makefile @@ -12,13 +12,14 @@ MHLO_BUILD_DIR ?= $(MK_DIR)/mlir/mlir-hlo/bazel-build DIALECTS_BUILD_DIR ?= $(MK_DIR)/mlir/build RT_BUILD_DIR ?= $(MK_DIR)/runtime/build OQC_BUILD_DIR ?= $(MK_DIR)/frontend/catalyst/third_party/oqc/src/build +OQD_BUILD_DIR ?= $(MK_DIR)/frontend/catalyst/third_party/oqd/src/build ENZYME_BUILD_DIR ?= $(MK_DIR)/mlir/Enzyme/build COVERAGE_REPORT ?= term-missing ENABLE_OPENQASM?=ON TEST_BACKEND ?= "lightning.qubit" TEST_BRAKET ?= NONE ENABLE_ASAN ?= OFF -TOML_SPECS ?= $(shell find ./runtime ./frontend -name '*.toml') +TOML_SPECS ?= $(shell find ./runtime ./frontend -name '*.toml' -not -name 'pyproject.toml') PLATFORM := $(shell uname -s) ifeq ($(PLATFORM),Linux) @@ -58,6 +59,12 @@ endif # Export variables so that they can be set here without needing to also set them in sub-make files. 
export ENABLE_ASAN ASAN_COMMAND +# Flag for verbose pip install output +PIP_VERBOSE_FLAG := +ifeq ($(VERBOSE),1) +PIP_VERBOSE_FLAG := --verbose +endif + .PHONY: help help: @echo "Please use \`make <target>' where <target> is one of" @@ -66,6 +73,7 @@ help: @echo " mlir to build MLIR and custom Catalyst dialects" @echo " runtime to build Catalyst Runtime" @echo " oqc to build Catalyst-OQC Runtime" + @echo " oqd to build Catalyst-OQD Runtime" @echo " test to run the Catalyst test suites" @echo " docs to build the documentation for Catalyst" @echo " clean to uninstall Catalyst and delete all temporary and cache files" @@ -73,6 +81,7 @@ help: @echo " clean-mlir to clean build files of MLIR and custom Catalyst dialects" @echo " clean-runtime to clean build files of Catalyst Runtime" @echo " clean-oqc to clean build files of OQC Runtime" + @echo " clean-oqd to clean build files of OQD Runtime" @echo " clean-all to uninstall Catalyst and delete all temporary, cache, and build files" @echo " clean-docs to delete all built documentation" @echo " coverage to generate a coverage report" @@ -81,8 +90,8 @@ .PHONY: all catalyst -all: runtime oqc mlir frontend -catalyst: runtime dialects frontend +all: runtime oqc oqd mlir frontend +catalyst: runtime dialects plugin oqd frontend .PHONY: frontend frontend: @@ -90,10 +99,10 @@ frontend: # Uninstall pennylane before updating Catalyst, since pip will not replace two development # versions of a package with the same version tag (e.g. 0.38-dev0). $(PYTHON) -m pip uninstall -y pennylane - $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple + $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG) rm -r frontend/PennyLane_Catalyst.egg-info -.PHONY: mlir llvm mhlo enzyme dialects runtime oqc +.PHONY: mlir llvm mhlo enzyme dialects runtime oqc oqd mlir: $(MAKE) -C mlir all @@ -110,12 +119,15 @@ dialects: $(MAKE) -C mlir dialects runtime: - $(MAKE) -C runtime all + $(MAKE) -C runtime runtime oqc: $(MAKE) -C frontend/catalyst/third_party/oqc/src oqc -.PHONY: test test-runtime test-frontend lit pytest test-demos test-oqc test-toml-spec +oqd: + $(MAKE) -C frontend/catalyst/third_party/oqd/src oqd + +.PHONY: test test-runtime test-frontend lit pytest test-demos test-oqc test-oqd test-toml-spec test: test-runtime test-frontend test-demos test-toml-spec: @@ -132,6 +144,9 @@ test-frontend: lit pytest test-oqc: $(MAKE) -C frontend/catalyst/third_party/oqc/src test +test-oqd: + $(MAKE) -C frontend/catalyst/third_party/oqd/src test + lit: ifeq ($(ENABLE_ASAN),ON) ifneq ($(findstring clang,$(C_COMPILER)),clang) @@ -173,10 +188,14 @@ wheel: cp $(RT_BUILD_DIR)/lib/librtd* $(MK_DIR)/frontend/catalyst/lib cp $(RT_BUILD_DIR)/lib/catalyst_callback_registry.so $(MK_DIR)/frontend/catalyst/lib cp $(RT_BUILD_DIR)/lib/openqasm_python_module.so $(MK_DIR)/frontend/catalyst/lib + cp $(RT_BUILD_DIR)/lib/liblapacke.* $(MK_DIR)/frontend/catalyst/lib || true # optional cp $(RT_BUILD_DIR)/lib/librt_capi.* $(MK_DIR)/frontend/catalyst/lib cp $(RT_BUILD_DIR)/lib/backend/*.toml $(MK_DIR)/frontend/catalyst/lib/backend cp $(OQC_BUILD_DIR)/librtd_oqc* $(MK_DIR)/frontend/catalyst/lib + cp $(OQC_BUILD_DIR)/oqc_python_module.so $(MK_DIR)/frontend/catalyst/lib cp $(OQC_BUILD_DIR)/backend/*.toml $(MK_DIR)/frontend/catalyst/lib/backend + cp $(OQD_BUILD_DIR)/librtd_oqd* $(MK_DIR)/frontend/catalyst/lib + cp $(OQD_BUILD_DIR)/backend/*.toml $(MK_DIR)/frontend/catalyst/lib/backend cp $(COPY_FLAGS) $(LLVM_BUILD_DIR)/lib/libmlir_float16_utils.* 
$(MK_DIR)/frontend/catalyst/lib cp $(COPY_FLAGS) $(LLVM_BUILD_DIR)/lib/libmlir_c_runner_utils.* $(MK_DIR)/frontend/catalyst/lib cp $(COPY_FLAGS) $(LLVM_BUILD_DIR)/lib/libmlir_async_runtime.* $(MK_DIR)/frontend/catalyst/lib @@ -194,32 +213,42 @@ wheel: $(PYTHON) -m pip wheel --no-deps . -w dist rm -r $(MK_DIR)/build + rm -r frontend/PennyLane_Catalyst.egg-info + +plugin-wheel: plugin + mkdir -p $(MK_DIR)/standalone_plugin_wheel/standalone_plugin/lib + cp $(COPY_FLAGS) $(DIALECTS_BUILD_DIR)/lib/StandalonePlugin.* $(MK_DIR)/standalone_plugin_wheel/standalone_plugin/lib + + $(PYTHON) -m pip wheel --no-deps $(MK_DIR)/standalone_plugin_wheel -w $(MK_DIR)/standalone_plugin_wheel/dist + + rm -r $(MK_DIR)/standalone_plugin_wheel/standalone_plugin/lib + rm -r $(MK_DIR)/standalone_plugin_wheel/standalone_plugin.egg-info + rm -r $(MK_DIR)/standalone_plugin_wheel/build .PHONY: clean clean-all clean: @echo "uninstall catalyst and delete all temporary and cache files" $(PYTHON) -m pip uninstall -y pennylane-catalyst - rm -rf $(MK_DIR)/frontend/mlir_quantum $(MK_DIR)/frontend/catalyst/lib + find frontend/catalyst -name "*.so" -exec rm -v {} + + git restore frontend/catalyst/_configuration.py + rm -rf $(MK_DIR)/frontend/catalyst/_revision.py + rm -rf $(MK_DIR)/frontend/mlir_quantum $(MK_DIR)/frontend/catalyst/lib $(MK_DIR)/frontend/catalyst/bin rm -rf dist __pycache__ rm -rf .coverage coverage_html_report + rm -rf .benchmarks -clean-all: clean-frontend clean-mlir clean-runtime clean-oqc - @echo "uninstall catalyst and delete all temporary, cache, and build files" - $(PYTHON) -m pip uninstall -y pennylane-catalyst - rm -rf dist __pycache__ - rm -rf .coverage coverage_html_report/ - -.PHONY: clean-frontend -clean-frontend: - find frontend/catalyst -name "*.so" -exec rm -v {} + +clean-all: clean clean-mlir clean-runtime clean-oqc clean-oqd -.PHONY: clean-mlir clean-dialects clean-llvm clean-mhlo clean-enzyme +.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-mhlo clean-enzyme clean-mlir: $(MAKE) -C mlir clean clean-dialects: $(MAKE) -C mlir clean-dialects +clean-plugin: + $(MAKE) -C mlir clean-plugin + clean-llvm: $(MAKE) -C mlir clean-llvm @@ -229,13 +258,16 @@ clean-mhlo: clean-enzyme: $(MAKE) -C mlir clean-enzyme -.PHONY: clean-runtime clean-oqc +.PHONY: clean-runtime clean-oqc clean-oqd clean-runtime: $(MAKE) -C runtime clean clean-oqc: $(MAKE) -C frontend/catalyst/third_party/oqc/src clean +clean-oqd: + $(MAKE) -C frontend/catalyst/third_party/oqd/src clean + .PHONY: coverage coverage-frontend coverage-runtime coverage: coverage-frontend coverage-runtime @@ -250,6 +282,10 @@ endif coverage-runtime: $(MAKE) -C runtime coverage +.PHONY: plugin +plugin: + $(MAKE) -C mlir plugin + .PHONY: format format: ifeq ($(shell test $(BLACKVERSIONMAJOR) -lt 22; echo $$?), 0) diff --git a/README.md b/README.md index d33080ce98..78aac581df 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ [![PyPI](https://img.shields.io/pypi/v/PennyLane-Catalyst.svg?style=flat-square)](https://pypi.org/project/PennyLane-Catalyst) [![Forum](https://img.shields.io/discourse/https/discuss.pennylane.ai/posts.svg?logo=discourse&style=flat-square)](https://discuss.pennylane.ai) [![License](https://img.shields.io/pypi/l/PennyLane.svg?logo=apache&style=flat-square)](https://www.apache.org/licenses/LICENSE-2.0) -[![Dev 
Container](https://img.shields.io/static/v1?label=Dev%20Container&message=Launch&color=blue&logo=visualstudiocode&style=flat-square)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/PennyLaneAI/catalyst)

@@ -73,7 +72,7 @@ In addition, we also provide a Python frontend for [PennyLane](https://pennylane ## Installation -Catalyst is officially supported on Linux (aarch64/arm64, x86_64) and macOS (aarch64/arm64, x86_64) platforms, +Catalyst is officially supported on Linux (x86_64, aarch64) and macOS (arm64, x86_64) platforms, and pre-built binaries are being distributed via the Python Package Index (PyPI) for Python versions 3.10 and higher. To install it, simply run the following ``pip`` command: @@ -81,15 +80,8 @@ higher. To install it, simply run the following ``pip`` command: pip install pennylane-catalyst ``` -Pre-built packages for Windows are not yet available, and comptability with Windows -is untested and cannot be guaranteed. If you are using one of these platforms, please -try out our Docker and Dev Container images described in the [documentation](https://docs.pennylane.ai/projects/catalyst/en/latest/dev/installation.html#dev-containers) -or click this button: - -[![Dev Container](https://img.shields.io/static/v1?label=Dev%20Container&message=Launch&color=blue&logo=visualstudiocode&style=flat-square)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/PennyLaneAI/catalyst). - If you wish to contribute to Catalyst or develop against our runtime or compiler, instructions for -[building from source](https://docs.pennylane.ai/projects/catalyst/en/latest/dev/installation.html#building-from-source) +[building from source](https://docs.pennylane.ai/projects/catalyst/en/latest/dev/installation.html#minimal-building-from-source-guide) are also available. ## Trying Catalyst with PennyLane @@ -159,9 +151,25 @@ If you are doing research using Catalyst, please cite our paper: author = {David Ittah and Ali Asadi and Erick Ochoa Lopez and Sergei Mironov and Samuel Banning and Romain Moyard and Mai Jacob Peng and Josh Izaac}, title = {Catalyst: a Python JIT compiler for auto-differentiable hybrid quantum programs}, journal = {Journal of Open Source Software} -} +} ``` ## License Catalyst is **free** and **open source**, released under the Apache License, Version 2.0. + +## Acknowledgements + +Catalyst makes use of the following libraries and tools, which are under their own respective +licenses: + +- [JAX](https://github.com/jax-ml/jax) +- [TensorFlow](https://github.com/tensorflow/tensorflow) +- [OpenXLA](https://github.com/openxla/xla) +- [LLVM/MLIR](https://github.com/llvm/llvm-project) +- [EnzymeAD](https://github.com/EnzymeAD/Enzyme) +- [pybind11](https://github.com/pybind/pybind11) +- [nanobind](https://github.com/wjakob/nanobind) +- [accelerate-lapacke](https://github.com/lepus2589/accelerate-lapacke) +- [LAPACK](https://github.com/Reference-LAPACK/lapack) +- [TOML++](https://github.com/marzer/tomlplusplus) diff --git a/bin/format.py b/bin/format.py index 31257cbfec..2f59138b76 100644 --- a/bin/format.py +++ b/bin/format.py @@ -33,7 +33,7 @@ def parse_version(version_string): - version_rgx = "version (\d+)" + version_rgx = r"version (\d+)" m = re.search(version_rgx, version_string) return int(m.group(1)) diff --git a/bin/toml-check.py b/bin/toml-check.py index 52009b4fb3..3fe2231ea9 100755 --- a/bin/toml-check.py +++ b/bin/toml-check.py @@ -22,54 +22,8 @@ import sys from argparse import ArgumentParser -from textwrap import dedent - -try: - from lark import Lark, LarkError, UnexpectedInput -except ImportError as e: - raise RuntimeError( - "toml-check.py requires `lark` library. 
Consider using `pip install lark`" - ) from e - -parser = Lark( - dedent( - """ - start: schema_body \ - gates_native_section \ - gates_decomp_section \ - gates_matrix_section \ - gates_observables_section \ - measurement_processes_section \ - compilation_section \ - options_section? - schema_body: schema_decl - gates_native_section: "[operators.gates.native]" gate_decls - gates_decomp_section: "[operators.gates.decomp]" gate_decls - gates_matrix_section: "[operators.gates.matrix]" gate_decls - gates_observables_section: "[operators.observables]" gate_decls - measurement_processes_section: "[measurement_processes]" gate_decls - compilation_section: "[compilation]" flag_decl* - options_section: "[options]" option_decl* - schema_decl: "schema" "=" "2" - gate_decls: (gate_decl)* - gate_decl: name "=" "{" (gate_trait ("," gate_trait)*)? "}" - gate_trait: gate_condition | gate_properties - gate_condition: "condition" "=" "[" ( "\\"finiteshots\\"" | "\\"analytic\\"" ) "]" - gate_properties: "properties" "=" "[" gate_property ("," gate_property)* "]" - gate_property: "\\"controllable\\"" | "\\"invertible\\"" | "\\"differentiable\\"" - flag_decl: ( "qjit_compatible" | "runtime_code_generation" | \ - "mid_circuit_measurement" | "dynamic_qubit_management" ) "=" boolean - option_decl: name "=" (name | "\\"" name "\\"") - name: /[a-zA-Z0-9_]+/ - boolean: "true" | "false" - COMMENT: "#" /./* - %import common.WS - %ignore WS - %ignore COMMENT - """ - ) -) +from pennylane.devices.toml_check import LarkError, UnexpectedInput, parser if __name__ == "__main__": ap = ArgumentParser(prog="toml-check.py") diff --git a/doc/catalyst-cli/catalyst-cli.rst b/doc/catalyst-cli/catalyst-cli.rst index b4e248a5c9..b04c3aaf90 100644 --- a/doc/catalyst-cli/catalyst-cli.rst +++ b/doc/catalyst-cli/catalyst-cli.rst @@ -2,7 +2,8 @@ Catalyst Command Line Interface =============================== Catalyst includes a standalone command-line-interface compiler tool ``catalyst-cli`` that -quantum-compiles MLIR input files into an object file, independent of the Catalyst Python frontend. +compiles quantum programs written in our MLIR dialects into an object file, +independent of the Catalyst Python frontend. This compiler tool combines three stages of compilation: @@ -174,14 +175,11 @@ two quantum-optimization passes: this case the two RX gates the become adjacent after the two Hadamard gates have been removed by the ``remove-chained-self-inverse`` pass. -To define the pass pipeline, we must supply the name of the function to which each pass applies -using the ``func-name`` argument. The ``func-name`` argument is specific to the two passes we are -applying and is not a general requirement. To apply these two passes to our ``my_circuit`` function, -we can do so as follows: +To apply these two passes to our ``my_circuit`` function, we can do so as follows: .. 
code-block:: - pipe(remove-chained-self-inverse{func-name=my_circuit};merge-rotations{func-name=my_circuit}) + pipe(remove-chained-self-inverse;merge-rotations) Finally, we'll use the option ``--mlir-print-ir-after-all`` to print the resulting MLIR after each pass that is applied, and the ``-o`` option to set the name of the output IR file: @@ -190,7 +188,7 @@ pass that is applied, and the ``-o`` option to set the name of the output IR fil catalyst-cli my_circuit.mlir \ --tool=opt \ - --catalyst-pipeline="pipe(remove-chained-self-inverse{func-name=my_circuit};merge-rotations{func-name=my_circuit})" \ + --catalyst-pipeline="pipe(remove-chained-self-inverse;merge-rotations)" \ --mlir-print-ir-after-all \ -o my_circuit-llvm.mlir @@ -238,3 +236,22 @@ optimized MLIR. For a list of transformation passes currently available in Catalyst, see the :ref:`catalyst-s-transformation-library` documentation. The available passes are also listed in the ``catalyst-cli --help`` message. + +MLIR Plugins +------------ + +``mlir-opt``-like tools are able to take plugins as inputs. +These plugins are shared objects that include dialects and passes written by third parties. +This means that you can write dialects and passes that can be used with ``catalyst-cli`` and ``quantum-opt``. + +As an example, the `LLVM repository includes a very simple plugin `_. +To build it, simply run ``make plugin`` and the standalone plugin +will be built in the root directory of the Catalyst project. + +With this, you can now run your own passes by using the following flags: + +``catalyst-cli --load-dialect-plugin=$YOUR_PLUGIN --load-pass-plugin=$YOUR_PLUGIN $YOUR_PASS_NAME file.mlir`` + +Concretely for the example plugin, you can use the following command: + +``catalyst-cli --tool=opt --load-pass-plugin=standalone/build/lib/StandalonePlugin.so --load-dialect-plugin=standalone/build/lib/StandalonePlugin.so --pass-pipeline='builtin.module(standalone-switch-bar-foo)' a.mlir`` diff --git a/doc/dev/custom_devices.rst b/doc/dev/custom_devices.rst index ffadafb5e4..7e814d97a8 100644 --- a/doc/dev/custom_devices.rst +++ b/doc/dev/custom_devices.rst @@ -153,7 +153,8 @@ Integration with Python devices There are two things that are needed in order to integrate with PennyLane devices: * Adding a ``get_c_interface`` method to your ``qml.devices.Device`` class. -* Adding a ``config`` class variable pointing to your configuration file. This file should be a `toml file `_ with fields that describe what gates and features are supported by your device. +* Adding a ``config_filepath`` class variable pointing to your configuration file. This file should be a `toml file `_ with fields that describe what gates and features are supported by your device. +* Optionally, adding a ``device_kwargs`` dictionary for runtime parameters to pass from the PennyLane device to the ``QuantumDevice`` upon initialization. If you already have a custom PennyLane device defined in Python and have added a shared object that corresponds to your implementation of the ``QuantumDevice`` class, then all you need to do is to add a ``get_c_interface`` method to your PennyLane device. The ``get_c_interface`` method should be a static method that takes no parameters and returns the complete path to your shared library with the ``QuantumDevice`` implementation. 
@@ -170,7 +171,7 @@ The Pennylane device API allows you to build a QJIT compatible device in a simpl class CustomDevice(qml.devices.Device): """Custom Device""" - config = pathlib.Path("absolute/path/to/configuration/file.toml") + config_filepath = pathlib.Path("absolute/path/to/configuration/file.toml") @staticmethod def get_c_interface(): @@ -196,137 +197,140 @@ headers and fields are generally required, unless stated otherwise. .. code-block:: toml - # Which version of the specification format is being used. - schema = 2 - - # The union of all gate types listed in this section must match what - # the device considers "supported" through PennyLane's device API. - # The gate definition has the following format: - # - # GATE = { properties = [ PROPS ], condition = [ COND ] } - # - # Where: - # - # PROPS: zero or more comma-separated quoted strings: - # "controllable", "invertible", "differentiable" - # COND: quoted string, on of: - # "analytic", "finiteshots" - # - [operators.gates.native] - - QubitUnitary = { properties = [ "controllable", "invertible"] } - PauliX = { properties = [ "controllable", "invertible"] } - PauliY = { properties = [ "controllable", "invertible"] } - PauliZ = { properties = [ "controllable", "invertible"] } - MultiRZ = { properties = [ "controllable", "invertible" ] } - Hadamard = { properties = [ "controllable", "invertible"] } - S = { properties = [ "controllable", "invertible" ] } - T = { properties = [ "controllable", "invertible" ] } - CNOT = { properties = [ "invertible" ] } - SWAP = { properties = [ "controllable", "invertible" ] } - CSWAP = { properties = [ "invertible" ] } - Toffoli = { properties = [ "controllable", "invertible" ] } - CY = { properties = [ "invertible" ] } - CZ = { properties = [ "invertible" ] } - PhaseShift = { properties = [ "controllable", "invertible" ] } - ControlledPhaseShift = { properties = [ "invertible" ] } - RX = { properties = [ "controllable", "invertible" ] } - RY = { properties = [ "controllable", "invertible" ] } - RZ = { properties = [ "controllable", "invertible" ] } - Rot = { properties = [ "controllable", "invertible" ] } - CRX = { properties = [ "invertible" ] } - CRY = { properties = [ "invertible" ] } - CRZ = { properties = [ "invertible" ] } - CRot = { properties = [ "invertible" ] } - Identity = { properties = [ "controllable", "invertible" ] } - IsingXX = { properties = [ "controllable", "invertible" ] } - IsingYY = { properties = [ "controllable", "invertible" ] } - IsingZZ = { properties = [ "controllable", "invertible" ] } - IsingXY = { properties = [ "controllable", "invertible" ] } - - # Operators that should be decomposed according to the algorithm used - # by PennyLane's device API. - # Optional, since gates not listed in this list will typically be decomposed by - # default, but can be useful to express a deviation from this device's regular - # strategy in PennyLane. 
- [operators.gates.decomp] - - SX = {} - ISWAP = {} - PSWAP = {} - SISWAP = {} - SQISW = {} - CPhase = {} - BasisState = {} - StatePrep = {} - ControlledQubitUnitary = {} - MultiControlledX = {} - SingleExcitation = {} - SingleExcitationPlus = {} - SingleExcitationMinus = {} - DoubleExcitation = {} - DoubleExcitationPlus = {} - DoubleExcitationMinus = {} - QubitCarry = {} - QubitSum = {} - OrbitalRotation = {} - QFT = {} - ECR = {} - - # Gates which should be translated to QubitUnitary - [operators.gates.matrix] - - DiagonalQubitUnitary = {} - - # Observables supported by the device - [operators.observables] - - PauliX = {} - PauliY = {} - PauliZ = {} - Hadamard = {} - Hermitian = {} - Identity = {} - Projector = {} - SparseHamiltonian = {} - Hamiltonian = {} - Sum = {} - SProd = {} - Prod = {} - Exp = {} - - [measurement_processes] - - Expval = {} - Var = {} - Probs = {} - Sample = {} - Counts = { condition = [ "finiteshots" ] } - - [compilation] - - # If the device is compatible with qjit - qjit_compatible = true - # If the device requires run time generation of the quantum circuit. - runtime_code_generation = false - # If the device supports mid circuit measurements natively - mid_circuit_measurement = true - # This field is currently unchecked but it is reserved for the purpose of - # determining if the device supports dynamic qubit allocation/deallocation. - dynamic_qubit_management = false - - [options] - # Options is an optional field. - # These options represent runtime parameters that can be passed to the device - # upon the device initialization. - # The option key will be the key in a dictionary. - # The string corresponds to a field queried in the `qml.Device` instance. - option_key = "option_field" - # In the above example, a dictionary will be constructed at run time. - # The dictionary will contain the string key "option_key" and its value - # will be the value in `qml.Device` `option_field`. - # The value can be any Python type, but will be converted to a string. - # During the initialization of your `class QuantumDevice`, the dictionary - # will be sent to the constructor of your implementation of `class QuantumDevice`. - # The dictionary will be a JSON string like the following: - # { 'option_key': option_field } + schema = 3 + + # The set of all gate types supported at the runtime execution interface of the + # device, i.e., what is supported by the `execute` method. The gate definitions + # should have the following format: + # + # GATE = { properties = [ PROPS ], conditions = [ CONDS ] } + # + # where PROPS and CONS are zero or more comma separated quoted strings. + # + # PROPS: additional support provided for each gate. + # - "controllable": if a controlled version of this gate is supported. + # - "invertible": if the adjoint of this operation is supported. + # - "differentiable": if device gradient is supported for this gate. + # CONDS: constraints on the support for each gate. + # - "analytic" or "finiteshots": if this operation is only supported in + # either analytic execution or with shots, respectively. 
+ # + [operators.gates] + + PauliX = { properties = ["controllable", "invertible"] } + PauliY = { properties = ["controllable", "invertible"] } + PauliZ = { properties = ["controllable", "invertible"] } + RY = { properties = ["controllable", "invertible", "differentiable"] } + RZ = { properties = ["controllable", "invertible", "differentiable"] } + CRY = { properties = ["invertible", "differentiable"] } + CRZ = { properties = ["invertible", "differentiable"] } + CNOT = { properties = ["invertible"] } + + # Observables supported by the device for measurements. The observables defined + # in this section should have the following format: + # + # OBSERVABLE = { conditions = [ CONDS ] } + # + # where CONDS is zero or more comma separated quoted strings, same as above. + # + # CONDS: constraints on the support for each observable. + # - "analytic" or "finiteshots": if this observable is only supported in + # either analytic execution or with shots, respectively. + # - "terms-commute": if a composite operator is only supported under the + # condition that its terms commute. + # + [operators.observables] + + PauliX = { } + PauliY = { } + PauliZ = { } + Hamiltonian = { conditions = [ "terms-commute" ] } + Sum = { conditions = [ "terms-commute" ] } + SProd = { } + Prod = { } + + # Types of measurement processes supported on the device. The measurements in + # this section should have the following format: + # + # MEASUREMENT_PROCESS = { conditions = [ CONDS ] } + # + # where CONDS is zero or more comma separated quoted strings, same as above. + # + # CONDS: constraints on the support for each measurement process. + # - "analytic" or "finiteshots": if this measurement is only supported + # in either analytic execution or with shots, respectively. + # + [measurement_processes] + + ExpectationMP = { } + SampleMP = { } + CountsMP = { conditions = ["finiteshots"] } + StateMP = { conditions = ["analytic"] } + + # Additional support that the device may provide that informs the compilation + # process. All accepted fields and their default values are listed below. + [compilation] + + # Whether the device is compatible with qjit. + qjit_compatible = false + + # Whether the device requires run time generation of the quantum circuit. + runtime_code_generation = false + + # Whether the device supports allocating and releasing qubits during execution. + dynamic_qubit_management = false + + # Whether simultaneous measurements on overlapping wires is supported. + overlapping_observables = true + + # Whether simultaneous measurements of non-commuting observables is supported. + # If false, a circuit with multiple non-commuting measurements will have to be + # split into multiple executions for each subset of commuting measurements. + non_commuting_observables = false + + # Whether the device supports initial state preparation. + initial_state_prep = false + + # The methods of handling mid-circuit measurements that the device supports, + # e.g., "one-shot", "tree-traversal", "device", etc. An empty list indicates + # that the device does not support mid-circuit measurements. + supported_mcm_methods = [ ] + +This TOML file is used by both Catalyst frontend and PennyLane. Regular circuit execution is +performed by your implementation of ``Device.execute``, whereas for a QJIT-compiled workflow, +execution is performed by the ``QuantumDevice``. The TOML file should declare the capabilities +of the two execution interfaces. 
If one of the interfaces have additional support that the other +does not have, include them in a separate section: + +.. code-block:: toml + + # Gates supported by the Python implementation of Device.execute but not by the QuantumDevice. + [pennylane.operators.gates] + + MultiControlledX = { } + + # Observables supported by the QuantumDevice but not by your implementation of Device.execute. + [qjit.operators.observables] + + Sum = { } + +Additionally, any runtime parameters to be passed to the ``QuantumDevice`` upon initialization +should be specified in a dictionary class property ``device_kwargs`` that links keyword arguments +of the ``QuantumDevice`` constructor to variables. For example: + +.. code-block:: python + + class CustomDevice(qml.devices.Device): + """Custom Device""" + + config_filepath = pathlib.Path("absolute/path/to/configuration/file.toml") + + def __init__(self, wires, do_something=False, special_param=""): + ... + self.device_kwargs = { + 'cpp_do_something' = do_something, + 'cpp_special_param' = special_param + } + +In the above example, a dictionary will be constructed at runtime and passed to the constructor of +the ``QuantumDevice`` implementation. diff --git a/doc/dev/devices.rst b/doc/dev/devices.rst index 8fd3dead5e..e030285439 100644 --- a/doc/dev/devices.rst +++ b/doc/dev/devices.rst @@ -96,3 +96,15 @@ Supported backend devices include: See the `Catalyst configuration file `__ for natively supported instructions. + * - ``oqd.default`` + + - Experimental support for execution on `Open Quantum Design (OQD) `__ + trapped-ion hardware. To use OQD with Catalyst, use the ``backend`` argument to specify the + OQD backend to use when initializing the device: + + .. code-block:: python + + dev = qml.device("oqd", backend="default", shots=1024, wires=2) + + See the `Catalyst configuration file `__ + for natively supported instructions. diff --git a/doc/dev/installation.rst b/doc/dev/installation.rst index b38945599c..b3a72a0db8 100644 --- a/doc/dev/installation.rst +++ b/doc/dev/installation.rst @@ -2,8 +2,8 @@ Installation ============ -Catalyst is officially supported on Linux (x86_64, aarch64) and macOS (arm64, x86_64) -platforms, and pre-built binaries are being distributed via the Python Package Index (PyPI) for +Catalyst is officially supported on Linux (x86_64, aarch64) and macOS (arm64, x86_64) +platforms, and pre-built binaries are being distributed via the Python Package Index (PyPI) for Python versions 3.10 and higher. To install it, simply run the following ``pip`` command: .. code-block:: console @@ -19,62 +19,22 @@ Python versions 3.10 and higher. To install it, simply run the following ``pip`` The easiest method of installation is to run ``xcode-select --install`` from the Terminal app. -Pre-built packages for Windows are not yet available, and compatibility with other platforms is -untested and cannot be guaranteed. If you are using one of these platforms, please -try out our Docker and Dev Container images described in the `next section <#dev-containers>`_. +Pre-built packages for Windows are not yet available, and compatibility is untested and cannot +be guaranteed. If you would like to use Catalyst on Windows, we recommend trying the +`WSL `_. If you wish to contribute to Catalyst or develop against our runtime or compiler, instructions for -building from source are also included `further down <#minimal-building-from-source-guide>`_. 
- -Dev Containers --------------- - - -Try out Catalyst in self-contained, ready-to-go environments called -`Dev Containers `__: - -.. image:: https://img.shields.io/static/v1?label=Dev%20Container&message=Launch&color=blue&logo=visualstudiocode&style=flat-square - :alt: Try Catalyst in Dev Container - :target: https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/PennyLaneAI/catalyst - :align: center - -| You will need an existing installation of `Docker `_, - `VS Code `_, and the VS Code - `Dev Containers `__ - extension. - -If desired, the Docker images can also be used in a standalone fashion: - -| `Docker: User Installation `_ -| `Docker: Developer Installation `_ - -The user image provides an officially supported environment and automatically installs the latest -release of Catalyst. The developer image only provides the right environment to build Catalyst from -source, and requires launching the post-install script at ``.devcontainer/dev/post-install.sh`` -from within the root of the running container. - -.. note:: - - Due to `a bug `_ in the Dev - Containers extension, clicking on the "Launch" badge will not prompt for a choice between the User - and Dev containers. Instead, the User container is automatically chosen. - - As a workaround, you can clone the `Catalyst repository `_ - first, open it as a VS Code Workspace, and then reopen the Workspace in a Dev Container via the - ``Reopen in Container`` command. - +building from source are detailed `below <#minimal-building-from-source-guide>`_. Minimal Building From Source Guide ---------------------------------- -Most developers might want to build Catalyst from source instead of using a pre-shipped package. In this section we present a minimal building-from-source installation guide. - -The next section provides a more detailed guide, which we **strongly** recommend the user to read through. Importantly, each component of Catalyst, namely the Python frontend, the MLIR compiler, and the runtime library, can be built and tested indenpendently, which this minimal installation guide does not go over. - - -The essential steps are: +This is an abbreviated set of instructions that can be copy-pasted into the terminal of most +common systems. For information on pre-requisites, how to build individual components, or if +you are encoutering issues, please consult the detailed guide +`in the next section <#detailed-building-from-source-guide>`_. .. tabs:: @@ -90,9 +50,9 @@ The essential steps are: .. 
code-block:: console # Install common requirements - sudo apt install clang lld ccache libomp-dev ninja-build make cmake + sudo apt install clang lld ccache libomp-dev ninja-build make cmake - # Clone the Catalyst repository + # Clone the Catalyst repository git clone --recurse-submodules --shallow-submodules https://github.com/PennyLaneAI/catalyst.git # Install specific requirements for Catalyst @@ -114,17 +74,17 @@ The essential steps are: pip install cmake ninja # If not present yet, install Homebrew (https://brew.sh/) - brew install libomp ccache + brew install libomp ccache gfortran # Add ccache drop-in compiler replacements to the PATH export PATH=/usr/local/opt/ccache/libexec:$PATH - # Clone the Catalyst repository + # Clone the Catalyst repository git clone --recurse-submodules --shallow-submodules https://github.com/PennyLaneAI/catalyst.git # Install specific requirements for Catalyst cd catalyst - pip install -r requirements.txt + pip install -r requirements.txt # Build Catalyst make all @@ -132,17 +92,11 @@ The essential steps are: # Test that everything is built properly make test -These steps should give you the full functionality of Catalyst. - Detailed Building From Source Guide ----------------------------------- -.. note:: - This section is a detailed building-from-source guide. Some commands in this section has already been included in the minimal guide. - - To build Catalyst from source, developers should follow the instructions provided below for building all three modules: the Python frontend, the MLIR compiler, and the runtime library. @@ -156,7 +110,9 @@ installed and available on the path (depending on the platform): - The `clang `_ compiler, `LLD `_ linker (Linux only), `CCache `_ compiler cache (optional, recommended), and - `OpenMP `_. + `OpenMP `_. Additionaly, the + `GFortran `_ compiler is + required on ARM macOS systems. - The `Ninja `_, `Make `_, and `CMake `_ (v3.20 or greater) build tools. @@ -208,7 +164,7 @@ They can be installed via: xcode-select --install pip install cmake ninja - brew install libomp ccache + brew install libomp ccache gfortran export PATH=/usr/local/opt/ccache/libexec:$PATH @@ -395,7 +351,7 @@ To build and test documentation for Catalyst, you will need to install Additionally, `doxygen `_ is required to build C++ documentation, and `pandoc `_ to render Jupyter Notebooks. -They can be installed via +They can be installed via .. tabs:: @@ -430,16 +386,16 @@ Known Issues .. group-tab:: Linux Debian/Ubuntu - If you get this error: + If you get this error: .. code-block:: console - + cannot find -lstdc++: No such file or directory - you might need to install a recent version of ``libstdc``. E.g.: + you might need to install a recent version of ``libstdc``. E.g.: .. code-block:: console - + sudo apt install libstdc++-12-dev (See user's report `here `_) @@ -451,13 +407,13 @@ Known Issues Under Ubuntu 24.04, if you get this error: .. code-block:: console - + fatal error: 'Python.h' file not found - + you might need to install the Python Dev package: .. code-block:: console - + sudo apt install python3-dev (See user's report `here `_) @@ -477,9 +433,9 @@ Known Issues Install a Frontend-Only Development Environment from TestPyPI Wheels -------------------------------------------------------------------- -It is possible to work on the source code repository and test the changes without -having to compile Catalyst. This is ideal for situations where the changes do not target the -runtime or the MLIR infrastructure, and only concern the frontend. 
It basically +It is possible to work on the source code repository and test the changes without +having to compile Catalyst. This is ideal for situations where the changes do not target the +runtime or the MLIR infrastructure, and only concern the frontend. It basically makes use of the shared libraries already shipped with the TestPyPI Catalyst wheels. Essential Steps @@ -531,7 +487,7 @@ How Does it Work? The provided script first creates and activates a Python virtual environment, so the system Python configurations do not get affected, nor other virtual environments. -In a second step, it obtains the latest Catalyst wheel from the TestPyPI server and creates hard +In a second step, it obtains the latest Catalyst wheel from the TestPyPI server and creates hard links from the wheel code to the frontend code of the repository, in order to allow working directly with the frontend code of the repository and at the same time test the changes while using the installed Catalyst wheel libraries, hence avoiding compilation. @@ -539,7 +495,7 @@ using the installed Catalyst wheel libraries, hence avoiding compilation. Further Steps ^^^^^^^^^^^^^ -If everything goes well, ``git status`` should not report any changed files. +If everything goes well, ``git status`` should not report any changed files. Before making changes to the frontend, make sure you create a new branch: @@ -550,7 +506,7 @@ Before making changes to the frontend, make sure you create a new branch: Once in the new branch, make the wanted changes. Use the IDE of your preference. You can test the changes by executing your sample code under the same virtual environment you used -with the scripts. As files in the repository are hard-linked to the Wheel code, you are actually +with the scripts. As files in the repository are hard-linked to the Wheel code, you are actually changing the code stored at the Python ``site-packages`` folder as well, and you will be automatically using the shared libraries provided by the Python wheels. Again, there is no need to compile Catalyst from source. @@ -559,7 +515,7 @@ You can commit your changes as usual. Once ready, push the new branch to the rem repository: .. code-block:: console - + git push -u origin new-branch-name Now you can go to GitHub and issue a Pull Request based on the new branch. diff --git a/doc/dev/plugins.rst b/doc/dev/plugins.rst new file mode 100644 index 0000000000..3a368c171f --- /dev/null +++ b/doc/dev/plugins.rst @@ -0,0 +1,353 @@ +MLIR Plugins +============ + +This page outlines documentation on how to start developping an MLIR plugin that can work with Catalyst. +An MLIR plugin is a shared object that implements a compilation pass compatible with the MLIR framework. +Catalyst is built on top of MLIR, this means that MLIR plugins work with Catalyst. +This can enable anyone to build quantum compilation passes and new dialects as well. + +Building the Standalone Plugin +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Catalyst comes with ``Makefile`` rules to build the `standalone-plugin from MLIR upstream's source code `_. +Simply type + +``make plugin`` + +and in the ``catalyst/mlir/standalone/build/lib`` folder, you will find the ``StandalonePlugin.so`` plugin. +The ``StandalonePlugin.so`` file is a simple plugin that has its own dialect (called Standalone dialect) and a single transformation that transforms symbol names from ``bar`` to ``foo``. +It is intended to show how one would build an MLIR plugin, rather than showing all the features to build a usable MLIR plugin. 
+ +You can use the ``StandalonePlugin.so`` plugin + +* with either ``quantum-opt`` or ``catalyst-cli``, +* load it from Python and transform a quantum program. + +For example, if you are interested in using it from the command line interface, you can use the following flags to load the standalone plugin: + +* ``--load-pass-plugin=/path/to/StandalonePlugin.so`` +* ``--load-dialect-plugin=/path/to/StandalonePlugin.so`` + +This allows all normal flags to work. +For example using ``quantum-opt --help`` while loading your pass plugin will enable you to see the documentation available for the standalone pass. + +.. code-block:: + + --standalone-switch-bar-foo - Switches the name of a FuncOp named `bar` to `foo` and folds. + +Taking into account the description of the pass ``standalone-switch-bar-foo``, let's write the most minimal program that would be transformed by this transformation. + +.. code-block:: mlir + + module @module { + func.func private @bar() -> (tensor) { + %c = stablehlo.constant dense<0> : tensor + return %c : tensor + } + } + +And you can schedule this pass as any other pass + +.. code-block:: + + quantum-opt --load-pass-plugin=/path/to/StandalonePlugin.so --standalone-switch-bar-to-foo example.mlir' + +And you have your transformed program + +.. code-block:: mlir + + module @module { + func.func private @foo() -> tensor { + %c = stablehlo.constant dense<0> : tensor + return %c : tensor + } + } + +Notice that the name of the function ``bar`` has been changed to ``foo``. + +Pass Plugins vs Dialect Plugins +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You may now be asking, "how come we used the option ``--load-pass-plugin`` but we didn't use the option ``--load-dialect-plugin``?" +The ``--load-pass-plugin`` option is used to load passes, while the ``--load-dialect-plugin`` is used to load dialects. +As mentioned earlier, the ``StandalonePlugin.so`` file also contains a dialect. +It is a simple dialect intended only for testing purposes, and it only contains a single operation. It is the ``standalone.foo`` operation. +(Please do not confuse this operation with symbols named ``foo``). + +We can write a program that contains operations in the standalone dialect: + +.. code-block:: mlir + + module @module { + func.func private @bar() -> (i32) { + %0 = arith.constant 0 : i32 + %1 = standalone.foo %0 : i32 + return %1 : i32 + } + } + +But if we try to run it, using the same command as shown earlier + +.. code-block:: + + quantum-opt --load-pass-plugin=/path/to/StandalonePlugin.so --standalone-switch-bar-to-foo example.mlir' + +the compilation will fail with the following message: + +.. code-block:: + + example.mlir:4:10: error: Dialect `standalone' not found for custom op 'standalone.foo' + %1 = standalone.foo %0 : i32 + ^ + a.mlir:4:10: note: Registered dialects: acc, affine, amdgpu, amx, arith, arm_neon, arm_sme, arm_sve, async, bufferization, builtin, catalyst, cf, chlo, complex, dlti, emitc, func, gpu, gradient, index, irdl, linalg, llvm, math, memref, mesh, mhlo, mitigation, ml_program, mpi, nvgpu, nvvm, omp, pdl, pdl_interp, polynomial, quant, quantum, rocdl, scf, shape, sparse_tensor, spirv, stablehlo, tensor, test, tosa, transform, ub, vector, vhlo, x86vector, xegpu ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management + +To be able to parse this dialect, we need to load the dialect which is stored in the same file + +.. 
code-block:: + + quantum-opt --load-pass-plugin=/path/to/StandalonePlugin.so --load-dialect-plugin-/path/to/StandalonePlugin.so --standalone-switch-bar-to-foo example.mlir' + +Now, you can parse the program without the error and run the ``standalone-switch-bar-to-foo`` pass. + +Creating your own Pass Plugin +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Catalyst lists LLVM as a git submodule in its repository and the LLVM project already contains an example standalone plugin. +When running ``make standalone-plugin`` Catalyst will copy the directory containing the standalone plugin and patch it to make sure that it works with Catalyst. +However, as mentioned earlier, the standalone plugin is a bare bones example. +You may be wondering, well, how can I make a standalone plugin but that is able to change some aspects of the quantum program? +For that, you will need to change the build script for the standalone plugin. +For now, we found that the following process is the easiest one: + +1. Add the standalone plugin directory as a subdirectory of Catalyst: + +.. code-block:: diff + + diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt + index c0b8dfd6c..1b5c2e528 100644 + --- a/mlir/CMakeLists.txt + +++ b/mlir/CMakeLists.txt + @@ -73,6 +73,7 @@ add_subdirectory(include) + add_subdirectory(lib) + add_subdirectory(tools) + add_subdirectory(test) + +add_subdirectory(standalone) + + if(QUANTUM_ENABLE_BINDINGS_PYTHON) + message(STATUS "Enabling Python API") + +You will also need to make the following change: + +.. code-block:: diff + + diff --git a/mlir/standalone/CMakeLists.txt b/mlir/standalone/CMakeLists.txt + index e999ae34d..fd6ee8f10 100644 + --- a/mlir/standalone/CMakeLists.txt + +++ b/mlir/standalone/CMakeLists.txt + @@ -1,6 +1,3 @@ + -cmake_minimum_required(VERSION 3.20.0) + -project(standalone-dialect LANGUAGES CXX C) + - + set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) + + set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to") + +.. code-block:: diff + + diff --git a/mlir/standalone/CMakeLists.txt b/mlir/standalone/CMakeLists.txt + index 280cd80e1..fd6ee8f10 100644 + --- a/mlir/standalone/CMakeLists.txt + +++ b/mlir/standalone/CMakeLists.txt + @@ -32,8 +32,8 @@ if(MLIR_ENABLE_BINDINGS_PYTHON) + mlir_configure_python_dev_packages() + endif() + + -set(STANDALONE_SOURCE_DIR ${PROJECT_SOURCE_DIR}) + -set(STANDALONE_BINARY_DIR ${PROJECT_BINARY_DIR}) + +set(STANDALONE_SOURCE_DIR ${PROJECT_SOURCE_DIR}/standalone) + +set(STANDALONE_BINARY_DIR ${PROJECT_BINARY_DIR}/standalone) + include_directories(${LLVM_INCLUDE_DIRS}) + include_directories(${MLIR_INCLUDE_DIRS}) + include_directories(${STANDALONE_SOURCE_DIR}/include) + +With these changes, you should now be able to use ``make all`` and build the standalone plugin. +Please note that the location of the ``StandalonePlugin.so`` shared object has changed. +It will now be stored in the ``mlir/build/lib/`` folder. + +2. Include the header files in the standalone plugin pass. + +.. 
code-block:: diff + + diff --git a/mlir/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/standalone/lib/Standalone/StandalonePasses.cpp + index a23d0420f..83e2ce255 100644 + --- a/mlir/standalone/lib/Standalone/StandalonePasses.cpp + +++ b/mlir/standalone/lib/Standalone/StandalonePasses.cpp + @@ -12,6 +12,7 @@ + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" + + #include "Standalone/StandalonePasses.h" + +#include "Quantum/IR/QuantumOps.h" + + namespace mlir::standalone { + #define GEN_PASS_DEF_STANDALONESWITCHBARFOO + +You can type ``make all`` and see the compilation succeed. +Please note that Catalyst has three custom dialects, the Quantum, Catalyst and Gradient dialect. +Depending on which dialect you are interested in, you can include the definition of the operations in that way. + +3. Marking dialects as dependent in the pass TableGen file. + +.. code-block:: diff + + diff --git a/mlir/standalone/include/Standalone/StandalonePasses.td b/mlir/standalone/include/Standalone/StandalonePasses.td + index dc8fb43d2..29510d74d 100644 + --- a/mlir/standalone/include/Standalone/StandalonePasses.td + +++ b/mlir/standalone/include/Standalone/StandalonePasses.td + @@ -26,6 +26,10 @@ def StandaloneSwitchBarFoo: Pass<"standalone-switch-bar-foo", "::mlir::ModuleOp" + ``` + }]; + + + let dependentDialects = [ + + "catalyst::quantum::QuantumDialect" + + ]; + + + } + + #endif // STANDALONE_PASS + +LLVM and MLIR use an embedded DSL to declare passes called `Tablegen `_. +This saves LLVM and MLIR developers time, because Tablegen generates C++ files that are mostly just boilerplate code. +We are not going to go in depth into Tablegen, you just need to know that transformations require to register which passes are used. +In this example, since we are interested in using the quantum dialect, we will add the Quantum Dialect in the list of dependent dialects. + +One also needs to link the MLIRQuantum library and change the plugin tool to catalyst-cli. + +.. code-block:: diff + + diff --git a/mlir/standalone/lib/Standalone/CMakeLists.txt b/mlir/standalone/lib/Standalone/CMakeLists.txt + index 0f1705a25..8874e410d 100644 + --- a/mlir/standalone/lib/Standalone/CMakeLists.txt + +++ b/mlir/standalone/lib/Standalone/CMakeLists.txt + @@ -10,9 +10,11 @@ add_mlir_dialect_library(MLIRStandalone + DEPENDS + MLIRStandaloneOpsIncGen + MLIRStandalonePassesIncGen + + MLIRQuantum + + LINK_LIBS PUBLIC + MLIRIR + MLIRInferTypeOpInterface + MLIRFuncDialect + + MLIRQuantum + ) + +.. code-block:: diff + + diff --git a/mlir/standalone/standalone-plugin/CMakeLists.txt b/mlir/standalone/standalone-plugin/CMakeLists.txt + index 3e3383608..2dbeea9d5 100644 + --- a/mlir/standalone/standalone-plugin/CMakeLists.txt + +++ b/mlir/standalone/standalone-plugin/CMakeLists.txt + @@ -5,7 +5,7 @@ add_llvm_library(StandalonePlugin + DEPENDS + MLIRStandalone + PLUGIN_TOOL + - mlir-opt + + catalyst-cli + + LINK_LIBS + MLIRStandalone + +Please note that if you are using the Catalyst or Gradient dialects, you should also add MLIRCatalyst and MLIRGradient to the list of dependences and libraries to be linked. + +4. Modify the standalone plugin to modify quantum operations. + +Here we will create a very simple pass that will change a the quantum qubit allocation from 1 to 42 (for illustration purposes). +We recommend reading MLIR tutorials on how to write MLIR passes, reading the Catalyst source to understand the Catalyst IR, and submitting issues if you are having troubles building your own plugin. 
+ +The first thing we need to do is change the ``OpRewritePattern`` to match against our ``quantum::AllocOp`` which denotes how many qubits should be allocated for a given quantum program. + +.. code-block:: diff + + diff --git a/mlir/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/standalone/lib/Standalone/StandalonePasses.cpp + index 83e2ce255..504cf2d20 100644 + --- a/mlir/standalone/lib/Standalone/StandalonePasses.cpp + +++ b/mlir/standalone/lib/Standalone/StandalonePasses.cpp + @@ -19,10 +19,10 @@ namespace mlir::standalone { + #include "Standalone/StandalonePasses.h.inc" + + namespace { + -class StandaloneSwitchBarFooRewriter : public OpRewritePattern { + +class StandaloneSwitchBarFooRewriter : public OpRewritePattern { + public: + - using OpRewritePattern::OpRewritePattern; + - LogicalResult matchAndRewrite(func::FuncOp op, + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(catalyst::quantum::AllocOp op, + PatternRewriter &rewriter) const final { + if (op.getSymName() == "bar") { + rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); }); + +The next step is changing the contents of the function itself: + +.. code-block:: diff + + diff --git a/mlir/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/standalone/lib/Standalone/StandalonePasses.cpp + index 83e2ce255..e8a7f805e 100644 + --- a/mlir/standalone/lib/Standalone/StandalonePasses.cpp + +++ b/mlir/standalone/lib/Standalone/StandalonePasses.cpp + @@ -19,15 +19,21 @@ namespace mlir::standalone { + #include "Standalone/StandalonePasses.h.inc" + + namespace { + -class StandaloneSwitchBarFooRewriter : public OpRewritePattern { + +class StandaloneSwitchBarFooRewriter : public OpRewritePattern { + public: + - using OpRewritePattern::OpRewritePattern; + - LogicalResult matchAndRewrite(func::FuncOp op, + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(catalyst::quantum::AllocOp op, + PatternRewriter &rewriter) const final { + - if (op.getSymName() == "bar") { + - rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); }); + + // get the number of qubits allocated + + if (op.getNqubitsAttr().value_or(0) == 1) { + + Type i64 = rewriter.getI64Type(); + + auto fortytwo = rewriter.getIntegerAttr(i64, 42); + + + + // modify the allocation to change the number of qubits to 42. + + rewriter.modifyOpInPlace(op, [&]() { op.setNqubitsAttrAttr(fortytwo); }); + return success(); + } + + // failure indicates that nothing was modified. + return failure(); + } + }; + +And then we can run ``make all`` again. +The shared object of the standalone plugin should be available in ``mlir/build/lib/StandalonePlugin.so``. +This shared object can be used with ``catalyst-cli`` and ``quantum-opt``. +From here, you can change the name of the pass, change the name of the shared object, and implement more complex transformations. + + +5. Build your own python wheel and ship your plugin. + +Now that you have your ``StandalonePlugin.so``, you can ship it in a python wheel. +To allow users to run your pass, we have provided a class called :class:`~.passes.Pass` and :class:`~.passes.PassPlugin`. +You can extend these classes and allow the user to import your derived classes and run passes as a decorator. +For example: + +.. code-block:: python + + @SwitchBarToFoo + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def qnode(): + return qml.state() + + @qml.qjit + def module(): + return qnode() + +If you inspect the MLIR sources, you'll find that the number of qubits allocated will be 42. 
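+One way to inspect those sources (a sketch that reuses the ``qnode`` from the example above; the
+exact location of the intermediate files may differ on your setup) is to compile with
+``keep_intermediate=True`` so that the IR after each compilation stage is written to disk:
+
+.. code-block:: python
+
+    @qml.qjit(keep_intermediate=True)
+    def module():
+        return qnode()
+
+    module()
+    # The .mlir files produced during compilation (kept in a folder named after
+    # ``module``) show the IR after each stage, including the stage where the
+    # plugin pass rewrites the qubit allocation to 42.
+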
+Take a look into the ``standalone_plugin_wheel`` make rule to see how we test shipping a plugin. +For more information, please consult our `dialect guide <../dev/dialects.html>`_, our `compiler passes guide <../dev/transforms.html>`_, and the `MLIR documentation `_. + diff --git a/doc/index.rst b/doc/index.rst index 24240b5006..d5c8ab7f9f 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -84,6 +84,7 @@ Catalyst Compiler Core MLIR Dialects Compiler Passes + Compiler Plugins Quantum Runtime dev/debugging dev/custom_devices diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b6ed1b7af0..2b90faf3ab 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,11 +2,20 @@

New features since last release

+* Catalyst can now load local MLIR plugins from Python, including support for `entry_points`.
+  [(#1317)](https://github.com/PennyLaneAI/catalyst/pull/1317)
+  [(#1361)](https://github.com/PennyLaneAI/catalyst/pull/1361)
+  [(#1370)](https://github.com/PennyLaneAI/catalyst/pull/1370)
+
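+  For example, a plugin shared object could be handed to `qjit` roughly as follows (a sketch
+  assuming the new `pass_plugins` and `dialect_plugins` compile options accept paths to the
+  plugin; the path below is only a placeholder for wherever the plugin was built):
+
+  ```python
+  from pathlib import Path
+  import pennylane as qml
+  from catalyst import qjit
+
+  plugin = Path("path/to/StandalonePlugin.so")  # placeholder path to a built plugin
+
+  @qjit(pass_plugins=[plugin], dialect_plugins=[plugin])
+  @qml.qnode(qml.device("lightning.qubit", wires=1))
+  def circuit():
+      # Loading only makes the plugin's passes available; they still need to be
+      # scheduled, e.g. through the plugin's entry point or a pass decorator.
+      return qml.state()
+  ```
+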

Improvements 🛠

+* Lightning runtime shot-measurement support for Hermitian observables. + [(#451)](https://github.com/PennyLaneAI/catalyst/pull/451) + * Replace pybind11 with nanobind for C++/Python bindings in the frontend and in the runtime. [(#1173)](https://github.com/PennyLaneAI/catalyst/pull/1173) [(#1293)](https://github.com/PennyLaneAI/catalyst/pull/1293) + [(#1391)](https://github.com/PennyLaneAI/catalyst/pull/1391) Nanobind has been developed as a natural successor to the pybind11 library and offers a number of [advantages](https://nanobind.readthedocs.io/en/latest/why.html#major-additions), in particular, @@ -17,8 +26,8 @@ Frontend no longer uses pybind11 to connect to the compiler. Instead, it uses subprocess instead. [(#1285)](https://github.com/PennyLaneAI/catalyst/pull/1285) -* Add a MLIR decomposition for the gate set {"T", "S", "Z", "Hadamard", "RZ", "PhaseShift", "CNOT"} to - the gate set {RX, RY, MS}. It is useful for trapped ion devices. It can be used thanks to +* Add a MLIR decomposition for the gate set {"T", "S", "Z", "Hadamard", "RZ", "PhaseShift", "CNOT"} + to the gate set {RX, RY, MS}. It is useful for trapped ion devices. It can be used thanks to `quantum-opt --ions-decomposition`. [(#1226)](https://github.com/PennyLaneAI/catalyst/pull/1226) @@ -39,16 +48,74 @@ * Improves the readability of conditional passes in pipelines [(#1194)](https://github.com/PennyLaneAI/catalyst/pull/1194) +* Cleans up the output of compiler instrumentation. + [(#1343)](https://github.com/PennyLaneAI/catalyst/pull/1343) + +* Generate stable ABI wheels for Python 3.12 and up. + [(#1357)](https://github.com/PennyLaneAI/catalyst/pull/1357) + [(#1385)](https://github.com/PennyLaneAI/catalyst/pull/1385) + +* A new circuit optimization pass, `--disentangle-CNOT`, is available. + [(#1154)](https://github.com/PennyLaneAI/catalyst/pull/1154) + + The pass disentangles CNOT gates whenever possible, e.g. when the control bit + is known to be in |0>, the pass removes the CNOT. The pass uses a finite state + machine to propagate simple one-qubit states, in order to determine + the input states to the CNOT. + + The algorithm is taken from [Relaxed Peephole Optimization: A Novel Compiler Optimization for Quantum Circuits, by Ji Liu, Luciano Bello, and Huiyang Zhou](https://arxiv.org/abs/2012.07711). + +* A new circuit optimization pass, `--disentangle-SWAP`, is available. + [(#1297)](https://github.com/PennyLaneAI/catalyst/pull/1297) + + The pass disentangles SWAP gates whenever possible by using a finite state + machine to propagate simple one-qubit states, similar to the `--disentangle-CNOT` pass. + + The algorithm is taken from [Relaxed Peephole Optimization: A Novel Compiler Optimization for Quantum Circuits, by Ji Liu, Luciano Bello, and Huiyang Zhou](https://arxiv.org/abs/2012.07711). +
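+  As an illustration of the circuits these passes target, the CNOT below has its control wire in
+  a state known to be |0> at compile time, so `--disentangle-CNOT` can remove it (a sketch using
+  only standard PennyLane and Catalyst APIs; the pass itself operates at the MLIR level and must
+  be scheduled explicitly):
+
+  ```python
+  import pennylane as qml
+  from catalyst import qjit
+
+  @qjit
+  @qml.qnode(qml.device("lightning.qubit", wires=2))
+  def circuit():
+      # Wire 0 has not been acted on, so it is still |0> and this CNOT is the
+      # identity; the disentangle pass can therefore drop it during compilation.
+      qml.CNOT(wires=[0, 1])
+      return qml.probs(wires=[0, 1])
+  ```
+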

Breaking changes 💔

-* Handling for the legacy operator arithmetic (the `Hamiltonian` and `Tensor` classes in PennyLane)
+* The `sample` and `counts` measurement primitives now support dynamic shot values across Catalyst,
+  although on the PennyLane side the device shots are still constrained to a static integer literal.
+
+  To support this, `SampleOp` and `CountsOp` in MLIR no longer carry the shots attribute, since
+  integer attributes are tied to literal values and must be static.
+
+  `DeviceInitOp` now takes in an optional SSA argument for shots, and the device init runtime CAPI
+  will take in this SSA shots value as an argument and set it as the device shots.
+  The sample and counts runtime CAPI functions no longer take in the shots argument and will
+  retrieve shots from the device.
+
+  Correspondingly, the device C++ interface should no longer parse the `DeviceInitOp`'s attributes
+  dictionary for the shots. For now we still keep the shots as an attribute so device implementors
+  can have time to migrate, but we will remove shots from the attribute dictionary in the next release.
+
+  [(#1170)](https://github.com/PennyLaneAI/catalyst/pull/1170)
+  [(#1310)](https://github.com/PennyLaneAI/catalyst/pull/1310)
+
+* The `toml` module has been migrated to PennyLane with an updated schema for declaring device
+  capabilities. Devices with TOML files using `schema = 2` will not be compatible with the latest
+  Catalyst. See [Custom Devices](https://docs.pennylane.ai/projects/catalyst/en/stable/dev/custom_devices.html)
+  for updated instructions on integrating your device with Catalyst and PennyLane.
+  [(#1275)](https://github.com/PennyLaneAI/catalyst/pull/1275)
+
+* Handling for the legacy operator arithmetic (the `Hamiltonian` and `Tensor` classes in PennyLane)
  is removed.
  [(#1308)](https://github.com/PennyLaneAI/catalyst/pull/1308)
+

Bug fixes 🐛

+ +* Fix bug introduced in 0.8 that breaks nested invocations of `qml.adjoint` and `qml.ctrl`. + [(#1301)](https://github.com/PennyLaneAI/catalyst/issues/1301) + +* Fix a bug in `debug.compile_executable` which would generate incorrect stride information for + array arguments of the function, in particular when non-64bit datatypes are used. + [(#1338)](https://github.com/PennyLaneAI/catalyst/pull/1338) +
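+  A small sketch that exercises the fixed path (assuming `debug.compile_executable` is given the
+  qjit-compiled function together with example arguments, as in the Catalyst debugging guide):
+
+  ```python
+  import numpy as np
+  from catalyst import debug, qjit
+
+  @qjit
+  def f(x):
+      return x + 1
+
+  # Non-64-bit array arguments such as float32 previously produced incorrect strides.
+  x = np.ones((2, 3), dtype=np.float32)
+  print(debug.compile_executable(f, x))
+  ```
+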

Deprecations 👋

Internal changes ⚙️

+* Catalyst no longer depends on or pins the `scipy` package, instead OpenBLAS is sourced directly + from [`scipy-openblas32`](https://pypi.org/project/scipy-openblas32/) or Accelerate is used. + [(#1322)](https://github.com/PennyLaneAI/catalyst/pull/1322) + [(#1328)](https://github.com/PennyLaneAI/catalyst/pull/1328) + * The `QuantumExtension` module (previously implemented with pybind11) has been removed. This module was not included in the distributed wheels and has been deprecated to align with our adoption of Python's stable ABI, which pybind11 does not support. @@ -59,25 +126,95 @@ [(#1307)](https://github.com/PennyLaneAI/catalyst/pull/1307) [(#1312)](https://github.com/PennyLaneAI/catalyst/pull/1312) +* `catalyst-cli` and `quantum-opt` are compiled with `default` visibility, which allows for MLIR plugins to work. + [(#1287)](https://github.com/PennyLaneAI/catalyst/pull/1287) + +* Sink patching of autograph's allowlist. + [(#1332)](https://github.com/PennyLaneAI/catalyst/pull/1332) + [(#1337)](https://github.com/PennyLaneAI/catalyst/pull/1337) + +* Each qnode now has its own transformation schedule. + Instead of relying on the name of the qnode, each qnode now has a transformation module, + which denotes the transformation schedule, embedded in its MLIR representation. + [(#1323)](https://github.com/PennyLaneAI/catalyst/pull/1323) + +* The `apply_registered_pass_p` primitive is removed. The API for scheduling passes + to run using the transform dialect has been refactored. In particular, + passes are appended to a tuple as they are being registered and they will + be run in order. If there are no local passes, the global `pass_pipeline` is + scheduled. Furthermore, this commit also reworks the caching mechanism for + primitives, which is important as qnodes and functions are primitives and + now that we can apply passes to them, they are distinct based on which + passes have been scheduled to run on them. + [(#1317)](https://github.com/PennyLaneAI/catalyst/pull/1317) + +* Replace Python C-API calls with Stable ABI calls. + [(#1354)](https://github.com/PennyLaneAI/catalyst/pull/1354) + +* A framework for loading and interacting with databases containing hardware information and + calibration data for Open Quantum Design (OQD) trapped-ion quantum devices has been added. + [(#1348)](https://github.com/PennyLaneAI/catalyst/pull/1348) + + A new module, `catalyst.utils.toml_utils`, was also added to assist in loading information from + these databases, which are stored as text files in TOML format. In particular, this module + contains a new function, :func:`~.utils.toml_utils.safe_eval`, to safely evaluate mathematical + expressions: + + ```python + >>> from catalyst.utils.toml_utils import safe_eval + >>> safe_eval("2 * math.pi * 1e9") + 6283185307.179586 + ``` + +* A default backend for OQD trapped-ion quantum devices has been added. + [(#1355)](https://github.com/PennyLaneAI/catalyst/pull/1355) + +* `expval` and `var` operations no longer keep the static shots attribute, as a step towards supporting dynamic shots across catalyst. + [(#1360)](https://github.com/PennyLaneAI/catalyst/pull/1360) + +* A new `ion` dialect has been added for Catalyst programs targeting OQD trapped-ion quantum devices. 
+ [(#1260)](https://github.com/PennyLaneAI/catalyst/pull/1260) + [(#1372)](https://github.com/PennyLaneAI/catalyst/pull/1372) + + The `ion` dialect defines the set of physical properties of the device, such as the ion species + and their atomic energy levels, as well as the operations to manipulate the qubits in the + trapped-ion system, such as laser pulse durations, polarizations, detuning frequencies, etc. + + A new pass, `--quantum-to-ion`, has also been added to convert logical gate-based circuits in the + Catalyst `quantum` dialect to laser pulse operations in the `ion` dialect. This pass accepts + logical quantum gates from the set {RX, RY, Mølmer–Sørensen (MS)}. Doing so enables the insertion + of physical device parameters into the IR, which will be necessary when lowering to OQD's backend + calls. The physical parameters are read in from [TOML](https://toml.io/en/) files during the + `--quantum-to-ion` conversion. The TOML files are assumed to exist by the pass (the paths to the + TOML file locations are taken in as pass options), with the intention that they are generated + immediately before compilation during hardware-calibration runs. + +* IR is now extended to support literal values as opposed to SSA Values for static parameters of + quantum gates by adding a new gate called StaticCustomOp with lowering to regular customOp. + [(#1387)](https://github.com/PennyLaneAI/catalyst/pull/1387) +

Documentation 📝

-* A new tutorial going through how to write a new MLIR pass is available. The tutorial writes an empty pass that prints hello world. The code of the tutorial is at [a separate github branch](https://github.com/PennyLaneAI/catalyst/commit/ba7b3438667963b307c07440acd6d7082f1960f3). +* A new tutorial going through how to write a new MLIR pass is available. The tutorial writes an + empty pass that prints hello world. The code of the tutorial is at + [a separate github branch](https://github.com/PennyLaneAI/catalyst/commit/ba7b3438667963b307c07440acd6d7082f1960f3). [(#872)](https://github.com/PennyLaneAI/catalyst/pull/872) -

Bug fixes 🐛

-
-* Fix bug introduced in 0.8 that breaks nested invocations of `qml.adjoint` and `qml.ctrl`.
-  [(#1301)](https://github.com/PennyLaneAI/catalyst/issues/1301)
+* Updated catalyst-cli documentation to reflect the removal of the func-name option for
+  transformation passes.
+  [(#1368)](https://github.com/PennyLaneAI/catalyst/pull/1368)

Contributors ✍️

This release contains contributions from (in alphabetical order): +Astral Cai, Joey Carter, David Ittah, Erick Ochoa Lopez, Mehrdad Malekmohammadi, William Maxwell Romain Moyard, +Shuli Shu, +Ritu Thombre, Raul Torres, Paul Haochen Wang. diff --git a/doc/requirements.txt b/doc/requirements.txt index cda7cc4559..1ae7387ba4 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -30,6 +30,6 @@ lxml_html_clean # Pre-install PL development wheels --extra-index-url https://test.pypi.org/simple/ -pennylane-lightning-kokkos==0.40.0-dev11 -pennylane-lightning==0.40.0-dev11 -pennylane==0.40.0-dev16 +pennylane-lightning-kokkos==0.40.0-dev41 +pennylane-lightning==0.40.0-dev41 +pennylane==0.40.0-dev20 diff --git a/frontend/catalyst/_version.py b/frontend/catalyst/_version.py index 8dbc64777e..5b6f37e69b 100644 --- a/frontend/catalyst/_version.py +++ b/frontend/catalyst/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.10.0-dev8" +__version__ = "0.10.0-dev35" diff --git a/frontend/catalyst/autograph/transformer.py b/frontend/catalyst/autograph/transformer.py index 06b85d47cd..7e17d11736 100644 --- a/frontend/catalyst/autograph/transformer.py +++ b/frontend/catalyst/autograph/transformer.py @@ -21,16 +21,18 @@ by Catalyst. """ import copy +import functools import inspect from contextlib import ContextDecorator import pennylane as qml -from malt.core import ag_ctx, converter +from malt.core import ag_ctx, config, converter from malt.impl.api import PyToPy import catalyst from catalyst.autograph import ag_primitives, operator_update from catalyst.utils.exceptions import AutoGraphError +from catalyst.utils.patching import Patcher class CatalystTransformer(PyToPy): @@ -132,17 +134,26 @@ def transform_ast(self, node, ctx): return node -def run_autograph(fn): +def run_autograph(fn, *modules): """Decorator that converts the given function into graph form.""" - user_context = converter.ProgramContext(TOPLEVEL_OPTIONS) + allowed_modules = tuple(config.Convert(module) for module in modules) + allowed_modules += ag_primitives.module_allowlist + user_context = converter.ProgramContext(TOPLEVEL_OPTIONS) new_fn, module, source_map = TRANSFORMER.transform(fn, user_context) new_fn.ag_module = module new_fn.ag_source_map = source_map new_fn.ag_unconverted = fn - return new_fn + @functools.wraps(new_fn) + def wrapper(*args, **kwargs): + with Patcher( + (ag_primitives, "module_allowlist", allowed_modules), + ): + return new_fn(*args, **kwargs) + + return wrapper def autograph_source(fn): diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 863598f5d1..45dc82afb8 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -116,36 +116,29 @@ def get_default_flags(options): # Discover the LAPACK library provided by scipy & add link against it. # Doing this here ensures we will always have the correct library name. 
+ lib_name = "openblas" + package_name = "scipy_openblas32" + path_within_package = "lib" + file_extension = ".so" if platform.system() == "Linux" else ".dylib" # pragma: no branch + + if platform.system() == "Darwin" and platform.machine() == "arm64": # pragma: nocover + # use our own build of LAPACKe to interface with Accelerate + lapack_lib_name = "lapacke.3" + else: + package_spec = importlib.util.find_spec(package_name) + package_directory = path.dirname(package_spec.origin) + lapack_lib_path = path.join(package_directory, path_within_package) - if platform.system() == "Linux": - file_path_within_package = "../scipy.libs/" - file_extension = ".so" - else: # pragma: nocover - msg = "Attempting to use catalyst on an unsupported system" - assert platform.system() == "Darwin", msg - file_path_within_package = ".dylibs/" - file_extension = ".dylib" - - package_name = "scipy" - scipy_package = importlib.util.find_spec(package_name) - package_directory = path.dirname(scipy_package.origin) - scipy_lib_path = path.join(package_directory, file_path_within_package) - - file_prefix = "libopenblas" - search_pattern = path.join(scipy_lib_path, f"{file_prefix}*{file_extension}") - search_result = glob.glob(search_pattern) - if not search_result: - raise CompileError( - f'Unable to find OpenBLAS library at "{search_pattern}". ' - "Please ensure that SciPy is installed and available via pip." - ) - openblas_so_file = search_result[0] - openblas_lib_name = path.basename(openblas_so_file)[3 : -len(file_extension)] + search_pattern = path.join(lapack_lib_path, f"lib*{lib_name}*{file_extension}") + search_result = glob.glob(search_pattern) + if not search_result: # pragma: nocover + raise CompileError( + f'Unable to find OpenBLAS library at "{search_pattern}". ' + "Please ensure that scipy is installed and available via pip." + ) - lib_path_flags += [ - f"-Wl,-rpath,{scipy_lib_path}", - f"-L{scipy_lib_path}", - ] + lib_path_flags += [f"-Wl,-rpath,{lapack_lib_path}", f"-L{lapack_lib_path}"] + lapack_lib_name = path.basename(search_result[0])[3 : -len(file_extension)] system_flags = [] if platform.system() == "Linux": @@ -153,7 +146,8 @@ def get_default_flags(options): # RPATH influences search paths globally while RUNPATH only works for # a single file, but not its dependencies. 
system_flags += ["-Wl,-no-as-needed", "-Wl,--disable-new-dtags"] - elif platform.system() == "Darwin": # pragma: nocover + else: # pragma: nocover + assert platform.system() == "Darwin", f"Unsupported OS {platform.system()}" system_flags += ["-Wl,-arch_errors_fatal"] # The exception handling mechanism requires linking against @@ -172,7 +166,7 @@ def get_default_flags(options): "-lrt_capi", "-lpthread", "-lmlir_c_runner_utils", # required for memref.copy - f"-l{openblas_lib_name}", # required for custom_calls lib + f"-l{lapack_lib_name}", # required for custom_calls lib "-lcustom_calls", "-lmlir_async_runtime", ] @@ -289,6 +283,14 @@ def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspac cmd += ["--module-name", module_name, "--workspace", str(workspace)] if not self.options.lower_to_llvm: cmd += ["--tool", "opt"] + if self.options.pass_plugins: + plugins = self.options.pass_plugins + for plugin in plugins: + cmd += ["--load-pass-plugin", str(plugin)] + if self.options.dialect_plugins: + plugins = self.options.dialect_plugins + for plugin in plugins: + cmd += ["--load-dialect-plugin", str(plugin)] if self.options.keep_intermediate: cmd += ["--keep-intermediate"] # The async tests are not included as part of coverage. diff --git a/frontend/catalyst/device/__init__.py b/frontend/catalyst/device/__init__.py index 1f6230fcb6..4fba62139e 100644 --- a/frontend/catalyst/device/__init__.py +++ b/frontend/catalyst/device/__init__.py @@ -22,7 +22,6 @@ extract_backend_info, get_device_capabilities, get_device_shots, - get_device_toml_config, ) __all__ = ( @@ -31,5 +30,4 @@ "extract_backend_info", "get_device_capabilities", "get_device_shots", - "get_device_toml_config", ) diff --git a/frontend/catalyst/device/decomposition.py b/frontend/catalyst/device/decomposition.py index 7adb261a6d..f9e61844ac 100644 --- a/frontend/catalyst/device/decomposition.py +++ b/frontend/catalyst/device/decomposition.py @@ -23,6 +23,7 @@ import jax import pennylane as qml from pennylane import transform +from pennylane.devices.capabilities import DeviceCapabilities from pennylane.devices.preprocess import decompose from pennylane.measurements import ( CountsMP, @@ -38,18 +39,17 @@ from catalyst.logging import debug_logger from catalyst.tracing.contexts import EvaluationContext from catalyst.utils.exceptions import CompileError -from catalyst.utils.toml import DeviceCapabilities logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -def check_alternative_support(op, capabilities): +def check_alternative_support(op, capabilities: DeviceCapabilities): """Verify that aliased operations aren't supported via alternative definitions.""" if isinstance(op, qml.ops.Controlled): # "Cast" away the specialized class for gates like Toffoli, ControlledQubitUnitary, etc. - supported = capabilities.native_ops.get(op.base.name) + supported = capabilities.operations.get(op.base.name) if supported and supported.controllable: return [qml.ops.Controlled(op.base, op.control_wires, op.control_values, op.work_wires)] @@ -60,6 +60,7 @@ def catalyst_decomposer(op, capabilities: DeviceCapabilities): """A decomposer for catalyst, to be passed to the decompose transform. Takes an operator and returns the default decomposition, unless the operator should decompose to a QubitUnitary. 
Raises a CompileError for MidMeasureMP""" + if isinstance(op, MidMeasureMP): raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.") @@ -67,11 +68,13 @@ def catalyst_decomposer(op, capabilities: DeviceCapabilities): if alternative_decomp is not None: return alternative_decomp - if capabilities.native_ops.get("QubitUnitary"): - # If the device supports unitary matrices, apply the relevant conversions and fallbacks. - if op.name in capabilities.to_matrix_ops or ( - op.has_matrix and isinstance(op, qml.ops.Controlled) - ): + if op.name in getattr(capabilities, "to_matrix_ops", {}): + return _decompose_to_matrix(op) + + if op.has_matrix and isinstance(op, qml.ops.Controlled): + + # If the device supports unitary matrices, apply the fallback. + if "QubitUnitary" in capabilities.operations: return _decompose_to_matrix(op) return op.decomposition() @@ -79,17 +82,17 @@ def catalyst_decomposer(op, capabilities: DeviceCapabilities): @transform @debug_logger -def catalyst_decompose(tape: qml.tape.QuantumTape, ctx, capabilities): +def catalyst_decompose(tape: qml.tape.QuantumTape, ctx, capabilities: DeviceCapabilities): """Decompose operations until the stopping condition is met. In a single call of the catalyst_decompose function, the PennyLane operations are decomposed in the same manner as in PennyLane (for each operator on the tape, checking if the operator - passes the stopping_condition, and using its `decompostion` method if not, called recursively + passes the stopping_condition, and using its `decomposition` method if not, called recursively until a supported operation is found or an error is hit, then moving on to the next operator on the tape.) Once all operators on the tape are supported operators, the resulting tape is iterated over, - and for each HybridOp, the catalyst_decompose function is called on each of it's regions. + and for each HybridOp, the catalyst_decompose function is called on each of its regions. This continues to call catalyst_decompose recursively until the tapes on all the HybridOps have been passed to the decompose function. """ @@ -99,7 +102,7 @@ def catalyst_decompose(tape: qml.tape.QuantumTape, ctx, capabilities): # only supports qml.StatePrep and qml.BasisState. A default strategy for handling any PennyLane # operator of type qml.StatePrepBase will be needed before this conditional can be removed. 
if len(tape) == 0 or type(tape[0]) in (qml.StatePrep, qml.BasisState): - skip_initial_state_prep = capabilities.initial_state_prep_flag + skip_initial_state_prep = capabilities.initial_state_prep else: skip_initial_state_prep = False @@ -133,7 +136,7 @@ def _decompose_to_matrix(op): return [op] -def _decompose_nested_tapes(op, ctx, capabilities): +def _decompose_nested_tapes(op, ctx, capabilities: DeviceCapabilities): new_regions = [] for region in op.regions: if region.quantum_tape is None: @@ -194,22 +197,28 @@ def null_postprocessing(results): return [new_tape], null_postprocessing -def catalyst_acceptance(op: qml.operation.Operation, capabilities: DeviceCapabilities) -> str: - """Check whether or not an Operator is supported.""" - op_support = capabilities.native_ops +def catalyst_acceptance(op: qml.operation.Operator, capabilities: DeviceCapabilities) -> str | None: + """Check whether an Operator is supported and returns the name of the operation or None.""" + + op_support = capabilities.operations if isinstance(op, qml.ops.Adjoint): match = catalyst_acceptance(op.base, capabilities) - if not match or not op_support[match].invertible: - return None + if match and op_support[match].invertible: + return match + + # There are cases where a custom controlled gate, e.g., CH, is supported, but its + # base, i.e., H, is not labeled controllable. In this case, we don't want to use + # this branch to check the support for this operation. elif type(op) is qml.ops.ControlledOp: match = catalyst_acceptance(op.base, capabilities) - if not match or not op_support[match].controllable: - return None - else: - match = op.name if op.name in op_support else None + if match and op_support[match].controllable: + return match - return match + elif op.name in op_support: + return op.name + + return None @transform @@ -237,7 +246,7 @@ def measurements_from_counts(tape, device_wires): new_tape = type(tape)(new_operations, [qml.counts(wires=measured_wires)], shots=tape.shots) def postprocessing_counts(results): - """A processing function to get expecation values from counts.""" + """A processing function to get expectation values from counts.""" states = results[0][0] counts_outcomes = results[0][1] results_processed = [] @@ -295,7 +304,7 @@ def measurements_from_samples(tape, device_wires): new_tape = type(tape)(new_operations, [qml.sample(wires=measured_wires)], shots=tape.shots) def postprocessing_samples(results): - """A processing function to get expecation values from samples.""" + """A processing function to get expectation values from samples.""" samples = results[0] results_processed = [] for m in tape.measurements: diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index 9543e742eb..421bd3f1e9 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -26,6 +26,7 @@ from typing import Any, Dict, Optional import pennylane as qml +from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties from pennylane.transforms import ( diagonalize_measurements, split_non_commuting, @@ -49,15 +50,6 @@ from catalyst.third_party.cuda import SoftwareQQPP from catalyst.utils.exceptions import CompileError from catalyst.utils.runtime_environment import get_lib_path -from catalyst.utils.toml import ( - DeviceCapabilities, - OperationProperties, - ProgramFeatures, - TOMLDocument, - intersect_operations, - load_device_capabilities, - read_toml_file, -) logger = logging.getLogger(__name__) 
logger.addHandler(logging.NullHandler()) @@ -110,17 +102,21 @@ "Sum", ] +RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] + # The runtime interface does not care about specific gate properties, so set them all to True. RUNTIME_OPERATIONS = { - op: OperationProperties(invertible=True, controllable=True, differentiable=True) + op: OperatorProperties(invertible=True, controllable=True, differentiable=True) for op in RUNTIME_OPERATIONS } RUNTIME_OBSERVABLES = { - obs: OperationProperties(invertible=True, controllable=True, differentiable=True) + obs: OperatorProperties(invertible=True, controllable=True, differentiable=True) for obs in RUNTIME_OBSERVABLES } +RUNTIME_MPS = {mp: [] for mp in RUNTIME_MPS} + # TODO: This should be removed after implementing `get_c_interface` # for the following backend devices: SUPPORTED_RT_DEVICES = { @@ -155,7 +151,7 @@ def extract_backend_info( dname = device.name if isinstance(device, qml.devices.LegacyDeviceFacade): - dname = device.target_device.short_name + dname = device.target_device.short_name # pragma: no cover device_name = "" device_lpath = "" @@ -200,56 +196,73 @@ def extract_backend_info( device.target_device._s3_folder # pylint: disable=protected-access ) - for k, v in capabilities.options.items(): - if hasattr(device, v) and not k in device_kwargs: - device_kwargs[k] = getattr(device, v) + for k, v in getattr(device, "device_kwargs", {}).items(): + if k not in device_kwargs: # pragma: no branch + device_kwargs[k] = v return BackendInfo(dname, device_name, device_lpath, device_kwargs) +def intersect_operations( + a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties] +) -> Dict[str, OperatorProperties]: + """Intersects two sets of operator properties""" + return {k: a[k] & b[k] for k in (a.keys() & b.keys())} + + +def intersect_mps(a: dict[str, list], b: dict[str, list]) -> dict[str, list]: + """Intersects two sets of measurement processes""" + # In the dictionary, each measurement process is associated with a list of conditions. + # Therefore, the intersection is really the union of constraints from both measurement + # processes declarations, thus the | operator. + return {k: list(set(a[k]) | set(b[k])) for k in (a.keys() & b.keys())} + + @debug_logger def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> DeviceCapabilities: """Calculate the set of supported quantum gates for the QJIT device from the gates allowed on the target quantum device.""" + # Supported gates of the target PennyLane's device qjit_capabilities = deepcopy(target_capabilities) - # Gates and observables that Catalyst runtime supports - qir_gates = RUNTIME_OPERATIONS - qir_observables = RUNTIME_OBSERVABLES - - # Intersection of the above - qjit_capabilities.native_ops = intersect_operations(target_capabilities.native_ops, qir_gates) - qjit_capabilities.native_obs = intersect_operations( - target_capabilities.native_obs, qir_observables + # Intersection of gates and observables supported by the device and by Catalyst runtime. 
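# Toy illustration of the two intersection helpers defined above, assuming a namedtuple in
# place of OperatorProperties (names here are illustrative only): gate flags are AND-ed across
# the device and runtime tables, while measurement-process constraints are union-ed, since
# each side may add its own conditions.
from collections import namedtuple

Props = namedtuple("Props", ["invertible", "controllable", "differentiable"])

def and_props(x, y):
    # plays the role of the & operator used by intersect_operations
    return Props(*(a and b for a, b in zip(x, y)))

device_ops = {"RX": Props(True, True, True), "Rot": Props(True, False, True)}
runtime_ops = {"RX": Props(True, True, True), "CNOT": Props(True, True, True)}
common = {k: and_props(device_ops[k], runtime_ops[k]) for k in device_ops.keys() & runtime_ops.keys()}
assert common == {"RX": Props(True, True, True)}  # only gates known to both sides survive

device_mps = {"ExpectationMP": ["finiteshots"], "SampleMP": []}
runtime_mps = {"ExpectationMP": [], "CountsMP": []}
merged = {k: sorted(set(device_mps[k]) | set(runtime_mps[k])) for k in device_mps.keys() & runtime_mps.keys()}
assert merged == {"ExpectationMP": ["finiteshots"]}  # constraints from both sides accumulate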
+ qjit_capabilities.operations = intersect_operations( + target_capabilities.operations, RUNTIME_OPERATIONS + ) + qjit_capabilities.observables = intersect_operations( + target_capabilities.observables, RUNTIME_OBSERVABLES + ) + qjit_capabilities.measurement_processes = intersect_mps( + target_capabilities.measurement_processes, RUNTIME_MPS ) # Control-flow gates to be lowered down to the LLVM control-flow instructions - qjit_capabilities.native_ops.update( + qjit_capabilities.operations.update( { - "Cond": OperationProperties(invertible=True, controllable=True, differentiable=True), - "WhileLoop": OperationProperties( + "Cond": OperatorProperties(invertible=True, controllable=True, differentiable=True), + "WhileLoop": OperatorProperties( invertible=True, controllable=True, differentiable=True ), - "ForLoop": OperationProperties(invertible=True, controllable=True, differentiable=True), + "ForLoop": OperatorProperties(invertible=True, controllable=True, differentiable=True), } ) - # Optionally enable runtime-powered mid-circuit measurments - if target_capabilities.mid_circuit_measurement_flag: # pragma: no branch - qjit_capabilities.native_ops.update( + # Optionally enable runtime-powered mid-circuit measurements + if target_capabilities.supported_mcm_methods: # pragma: no branch + qjit_capabilities.operations.update( { - "MidCircuitMeasure": OperationProperties( + "MidCircuitMeasure": OperatorProperties( invertible=False, controllable=False, differentiable=False ) } ) - # Optionally enable runtime-powered quantum gate adjointing (inversions) - if any(ng.invertible for ng in target_capabilities.native_ops.values()): - qjit_capabilities.native_ops.update( + # Optionally enable runtime-powered adjoint of quantum gates (inversions) + if any(ng.invertible for ng in target_capabilities.operations.values()): # pragma: no branch + qjit_capabilities.operations.update( { - "HybridAdjoint": OperationProperties( + "HybridAdjoint": OperatorProperties( invertible=True, controllable=True, differentiable=True ) } @@ -257,13 +270,13 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev # TODO: Optionally enable runtime-powered quantum gate controlling once they # are supported natively in MLIR. - # if any(ng.controllable for ng in target_capabilities.native_ops.values()): - # qjit_capabilities.native_ops.update( + # if any(ng.controllable for ng in target_capabilities.operations.values()): + # qjit_capabilities.operations.update( # { - # "HybridCtrl": OperationProperties( + # "HybridCtrl": OperatorProperties( # invertible=True, controllable=True, differentiable=True # ) - # + # } # ) return qjit_capabilities @@ -301,14 +314,24 @@ def __init__(self, original_device): super().__init__(wires=original_device.wires, shots=original_device.shots) # Capability loading - original_device_capabilities = get_device_capabilities(original_device) - backend = QJITDevice.extract_backend_info(original_device, original_device_capabilities) + device_capabilities = get_device_capabilities(original_device) + + # TODO: This is a temporary measure to ensure consistency of behaviour. Remove this + # when customizable multi-pathway decomposition is implemented. 
(Epic 74474) + if hasattr(original_device, "_to_matrix_ops"): + _to_matrix_ops = getattr(original_device, "_to_matrix_ops") + setattr(device_capabilities, "to_matrix_ops", _to_matrix_ops) + if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"): + raise CompileError( + "The device that specifies to_matrix_ops must support QubitUnitary." + ) + + backend = QJITDevice.extract_backend_info(original_device, device_capabilities) self.backend_name = backend.c_interface_name self.backend_lib = backend.lpath self.backend_kwargs = backend.kwargs - - self.capabilities = get_qjit_device_capabilities(original_device_capabilities) + self.capabilities = get_qjit_device_capabilities(device_capabilities) @debug_logger def preprocess( @@ -380,41 +403,41 @@ def _measurement_transform_program(self): if isinstance(self.original_device, SoftwareQQPP): return measurement_program - supports_sum_observables = "Sum" in self.capabilities.native_obs + supports_sum_observables = "Sum" in self.capabilities.observables - if self.capabilities.non_commuting_observables_flag is False: + if self.capabilities.non_commuting_observables is False: measurement_program.add_transform(split_non_commuting) elif not supports_sum_observables: measurement_program.add_transform(split_to_single_terms) # if no observables are supported, we apply a transform to convert *everything* to the # readout basis, using either sample or counts based on device specification - if not self.capabilities.native_obs: + if not self.capabilities.observables: if not split_non_commuting in measurement_program: # this *should* be redundant, a TOML that doesn't have observables should have # a False non_commuting_observables flag, but we aren't enforcing that measurement_program.add_transform(split_non_commuting) - if "Sample" in self.capabilities.measurement_processes: + if "SampleMP" in self.capabilities.measurement_processes: measurement_program.add_transform(measurements_from_samples, self.wires) - elif "Counts" in self.capabilities.measurement_processes: + elif "CountsMP" in self.capabilities.measurement_processes: measurement_program.add_transform(measurements_from_counts, self.wires) else: raise RuntimeError("The device does not support observables or sample/counts") - elif not self.capabilities.measurement_processes - {"Counts", "Sample"}: - # ToDo: this branch should become unneccessary when selective conversion of + elif not self.capabilities.measurement_processes.keys() - {"CountsMP", "SampleMP"}: + # ToDo: this branch should become unnecessary when selective conversion of # unsupported MPs is finished, see ToDo below - if not split_non_commuting in measurement_program: + if not split_non_commuting in measurement_program: # pragma: no branch measurement_program.add_transform(split_non_commuting) mp_transform = ( measurements_from_samples - if "Sample" in self.capabilities.measurement_processes + if "SampleMP" in self.capabilities.measurement_processes else measurements_from_counts ) measurement_program.add_transform(mp_transform, self.wires) # if only some observables are supported, we try to diagonalize those that aren't - elif not {"PauliX", "PauliY", "PauliZ", "Hadamard"}.issubset(self.capabilities.native_obs): + elif not {"PauliX", "PauliY", "PauliZ", "Hadamard"}.issubset(self.capabilities.observables): if not split_non_commuting in measurement_program: # the device might support non commuting measurements but not all the # Pauli + Hadamard observables, so here it is needed @@ -427,7 +450,7 @@ def 
_measurement_transform_program(self): } # checking which base observables are unsupported and need to be diagonalized supported_observables = {"PauliX", "PauliY", "PauliZ", "Hadamard"}.intersection( - self.capabilities.native_obs + self.capabilities.observables ) supported_observables = [_obs_dict[obs] for obs in supported_observables] @@ -447,6 +470,7 @@ def execute(self, circuits, execution_config): raise RuntimeError("QJIT devices cannot execute tapes.") +# pragma: no cover def filter_out_modifiers(operations): """Remove Adjoint/Control from operations. @@ -464,16 +488,20 @@ def is_not_modifier(op): return set(filter(is_not_modifier, operations)) -def get_device_toml_config(device) -> TOMLDocument: +def _load_device_capabilities(device) -> DeviceCapabilities: """Get the contents of the device config file.""" - if hasattr(device, "config"): - # The expected case: device specifies its own config. - toml_file = device.config + + # TODO: This code exists purely for testing. Find a better way for a device to customize + # its capabilities as seen by Catalyst. + if hasattr(device, "qjit_capabilities"): + return device.qjit_capabilities + + if getattr(device, "config_filepath") is not None: + toml_file = device.config_filepath + else: - # TODO: Remove this section when `qml.devices.Device`s are guaranteed to have their own config file - # field. + # TODO: Remove this section when devices are guaranteed to have their own config file device_lpath = pathlib.Path(get_lib_path("runtime", "RUNTIME_LIB_DIR")) - name = device.short_name if isinstance(device, qml.devices.LegacyDevice) else device.name # The toml files name convention we follow is to replace # the dots with underscores in the device short name. @@ -482,32 +510,31 @@ def get_device_toml_config(device) -> TOMLDocument: toml_file = device_lpath.parent / "lib" / "backend" / toml_file_name try: - config = read_toml_file(toml_file) + capabilities = DeviceCapabilities.from_toml_file(toml_file, "qjit") + except FileNotFoundError as e: raise CompileError( "Attempting to compile program for incompatible device: " f"Config file ({toml_file}) does not exist" ) from e - return config + return capabilities def get_device_capabilities(device) -> DeviceCapabilities: - """Get or load DeviceCapabilities structure from device""" + """Get or load the original DeviceCapabilities from device""" + assert not isinstance(device, QJITDevice) - # TODO: This code exists purely for testing. Find another way to customize device - # support easily without injecting code into the package.
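# Hypothetical usage of the loading path above: capabilities are read directly from the
# device's TOML file and then narrowed by the shots setting before any support checks run.
# "my_device.toml" is a placeholder path, not a file that ships with Catalyst.
from pennylane.devices.capabilities import DeviceCapabilities

capabilities = DeviceCapabilities.from_toml_file("my_device.toml", "qjit")

# With finite shots, entries that only hold for analytic execution are dropped (and vice
# versa); the filtered object backs all later acceptance checks.
shot_based = capabilities.filter(finite_shots=True)
analytic = capabilities.filter(finite_shots=False)

print(sorted(shot_based.operations))             # gate names supported with shots
print(sorted(shot_based.measurement_processes))  # e.g. "SampleMP", "CountsMP", ...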
- if hasattr(device, "qjit_capabilities"): - return device.qjit_capabilities + shots_present = bool(device.shots) + device_capabilities = _load_device_capabilities(device) - program_features = ProgramFeatures(shots_present=bool(device.shots)) - device_config = get_device_toml_config(device) - return load_device_capabilities(device_config, program_features) + return device_capabilities.filter(finite_shots=shots_present) def check_device_wires(wires): """Validate requirements Catalyst imposes on device wires.""" + if wires is None: raise AttributeError("Catalyst does not support device instances without set wires.") diff --git a/frontend/catalyst/device/verification.py b/frontend/catalyst/device/verification.py index 9feb85c1ca..532410e81c 100644 --- a/frontend/catalyst/device/verification.py +++ b/frontend/catalyst/device/verification.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Sequence, Union from pennylane import transform +from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties from pennylane.measurements import ( ExpectationMP, MutualInfoMP, @@ -38,7 +39,6 @@ from catalyst.jax_tracer import HybridOp, has_nested_tapes, nested_quantum_regions from catalyst.tracing.contexts import EvaluationContext from catalyst.utils.exceptions import CompileError, DifferentiableCompileError -from catalyst.utils.toml import OperationProperties def _verify_nested( @@ -84,11 +84,12 @@ def _verify_observable(obs: Operation, _obs_checker: Callable) -> bool: if isinstance(obs, CompositeOp): for o in obs.operands: _verify_observable(o, _obs_checker) + elif isinstance(obs, SymbolicOp): _verify_observable(obs.base, _obs_checker) -EMPTY_PROPERTIES = OperationProperties(False, False, False) +EMPTY_PROPERTIES = OperatorProperties() @transform @@ -113,7 +114,8 @@ def verify_operations(tape: QuantumTape, grad_method, qjit_device): DifferentiableCompileError: gradient-related error CompileError: compilation error """ - op_support = qjit_device.capabilities.native_ops + + supported_ops = qjit_device.capabilities.operations def _paramshift_op_checker(op): if not isinstance(op, HybridOp): @@ -131,7 +133,7 @@ def _adj_diff_op_checker(op): op_name = op.base.name else: op_name = op.name - if not op_support.get(op_name, EMPTY_PROPERTIES).differentiable: + if not supported_ops.get(op_name, EMPTY_PROPERTIES).differentiable: raise DifferentiableCompileError( f"{op.name} is non-differentiable on '{qjit_device.original_device.name}' device" ) @@ -154,7 +156,7 @@ def _ctrl_op_checker(op, in_control): elif not in_control: return isinstance(op, HybridCtrl) - if not op_support.get(op.name, EMPTY_PROPERTIES).controllable: + if not supported_ops.get(op.name, EMPTY_PROPERTIES).controllable: raise CompileError( f"{op.name} is not controllable on '{qjit_device.original_device.name}' device" ) @@ -179,7 +181,7 @@ def _inv_op_checker(op, in_inverse): elif not in_inverse: return isinstance(op, HybridAdjoint) - if not op_support.get(op.name, EMPTY_PROPERTIES).invertible: + if not supported_ops.get(op.name, EMPTY_PROPERTIES).invertible: raise CompileError( f"{op.name} is not invertible on '{qjit_device.original_device.name}' device" ) @@ -196,9 +198,9 @@ def _op_checker(op, state): pass # Don't check StatePrep since StatePrep is not in the list of device capabilities. # It is only valid when the TOML file has the initial_state_prep_flag. 
- elif isinstance(op, StatePrepBase) and qjit_device.capabilities.initial_state_prep_flag: + elif isinstance(op, StatePrepBase) and qjit_device.capabilities.initial_state_prep: pass - elif not op.name in op_support: + elif not op.name in supported_ops: raise CompileError( f"{op.name} is not supported on '{qjit_device.original_device.name}' device" ) @@ -248,7 +250,7 @@ def validate_observables_adjoint_diff(tape: QuantumTape, qjit_device): """Validate that the observables on the tape support adjoint differentiation""" def _obs_checker(obs): - if not qjit_device.capabilities.native_obs.get(obs.name, EMPTY_PROPERTIES).differentiable: + if not qjit_device.capabilities.observables.get(obs.name, EMPTY_PROPERTIES).differentiable: raise DifferentiableCompileError( f"{obs.name} is non-differentiable on " f"'{qjit_device.original_device.name}' device" @@ -263,14 +265,14 @@ def _obs_checker(obs): @transform def validate_measurements( - tape: QuantumTape, capabilities: dict, name: str, shots: Union[int, Shots] + tape: QuantumTape, capabilities: DeviceCapabilities, name: str, shots: Union[int, Shots] ) -> (Sequence[QuantumTape], Callable): """Validates the observables and measurements for a circuit against the capabilites from the TOML file. Args: tape (QuantumTape or QNode or Callable): a quantum circuit. - capabilities (dict): specifies the capabilities of the qjitted device + capabilities (DeviceCapabilities): specifies the capabilities of the qjitted device name: the name of the device to use in error messages shots: the shots on the device to use in error messages @@ -284,7 +286,7 @@ def validate_measurements( """ def _obs_checker(obs): - if not obs.name in capabilities.native_obs: + if not obs.name in capabilities.observables: raise CompileError( f"{m.obs} is not supported as an observable on the '{name}' device with Catalyst" ) @@ -309,10 +311,10 @@ def _obs_checker(obs): f"Sample-based measurements like {m} cannot work with shots=None. " "Please specify a finite number of shots." 
) - mp_name = m.return_type.value if m.return_type else type(m).__name__ - if not mp_name.title() in capabilities.measurement_processes: + mp_name = type(m).__name__ + if not mp_name in capabilities.measurement_processes: raise CompileError( - f"{type(m)} is not a supported measurement process on '{name}' with Catalyst" + f"{mp_name} is not a supported measurement process on '{name}' with Catalyst" ) return (tape,), lambda x: x[0] diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index b3bf0a193f..70610267b1 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -29,12 +29,18 @@ ) from pennylane.capture import PlxprInterpreter -from catalyst.device import extract_backend_info, get_device_capabilities +from catalyst.device import ( + extract_backend_info, + get_device_capabilities, + get_device_shots, +) from catalyst.jax_extras import make_jaxpr2, transient_jax_config +from catalyst.jax_extras.tracing import bind_flexible_primitive from catalyst.jax_primitives import ( AbstractQbit, AbstractQreg, compbasis_p, + counts_p, expval_p, gphase_p, namedobs_p, @@ -183,7 +189,7 @@ def __setattr__(self, __name: str, __value) -> None: def setup(self): if self.stateref is None: - qdevice_p.bind(**_get_device_kwargs(self._device)) + qdevice_p.bind(get_device_shots(self._device) or 0, **_get_device_kwargs(self._device)) self.stateref = { "qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {} @@ -212,7 +218,7 @@ def interpret_operation(self, op, is_adjoint=False): *in_qubits, *op.data, op=op.name, - params_len=len(op.data), + ctrl_value_len=0, qubits_len=len(op.wires), adjoint=is_adjoint, ) diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index 9f0951bbfa..1a5823f8f8 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -960,3 +960,37 @@ def bind(self, *args, **params): eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info) trace.frame.add_eqn(eqn) return out_tracers if self.multiple_results else out_tracers.pop() + + +def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args, **static_args): + """ + Calls the primitive.bind() method with dyn_args as positional arguments to the bind, + and static_args as keyword arguments. + + flexible_args is a dictionary containing the flexible arguments. + These are the arguments that can either be static or dynamic. This method + will bind a flexible argument as static only if it is a single literal or a list of only integer, float, + or boolean literals. In the static case, the bound primitive's param name is the flexible arg's + key, and the jaxpr param value is the flexible arg's value. + + If a flexible argument is received as a tracer, it will be bound dynamically with + the flexible arg's value. + + This ensures that in the jaxpr, dynamic args become SSA arguments to the primitive, + and static args become literal-valued parameters of the jaxpr.
+ """ + + static_literal_pool = (int, float, bool) + + for flex_arg_name, flex_arg_value in flexible_args.items(): + if type(flex_arg_value) in static_literal_pool: + static_args[flex_arg_name] = flex_arg_value + elif isinstance(flex_arg_value, list): + if flex_arg_value and all(type(arg) in static_literal_pool for arg in flex_arg_value): + static_args[flex_arg_name] = flex_arg_value + else: + dyn_args += (*flex_arg_value,) + else: + dyn_args += (flex_arg_value,) + + return primitive.bind(*dyn_args, **static_args) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index ad901df934..9f58dff29b 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -15,7 +15,6 @@ of quantum operations, measurements, and observables to JAXPR. """ -import copy import sys from dataclasses import dataclass from enum import Enum @@ -26,6 +25,7 @@ import numpy as np import pennylane as qml from jax._src import api_util, core, source_info_util, util +from jax._src.interpreters import partial_eval as pe from jax._src.lax.lax import _nary_lower_hlo, cos_p, sin_p from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -42,21 +42,14 @@ MulIOp, SubIOp, ) -from jaxlib.mlir.dialects.builtin import ModuleOp -from jaxlib.mlir.dialects.func import CallOp, FunctionType +from jaxlib.mlir.dialects.func import FunctionType from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp -from mlir_quantum.dialects._transform_ops_gen import ( - ApplyRegisteredPassOp, - NamedSequenceOp, -) -from mlir_quantum.dialects._transform_ops_gen import YieldOp as TransformYieldOp from mlir_quantum.dialects.catalyst import ( AssertionOp, CallbackCallOp, CallbackOp, - LaunchKernelOp, PrintOp, ) from mlir_quantum.dialects.gradient import ( @@ -93,6 +86,7 @@ SetBasisStateOp, SetStateOp, StateOp, + StaticCustomOp, TensorOp, VarianceOp, ) @@ -107,6 +101,15 @@ infer_output_type_jaxpr, while_loop_expansion_strategy, ) +from catalyst.jax_primitives_utils import ( + cache, + create_call_op, + get_cached, + get_call_jaxpr, + get_symbolref, + lower_callable, + lower_jaxpr, +) from catalyst.utils.calculate_grad_shape import Signature, calculate_grad_shape from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.types import convert_shaped_arrays_to_tensors @@ -195,18 +198,6 @@ def _obs_lowering(aval): return (ir.OpaqueType.get("quantum", "obs"),) -# -# Transform Module Type -# -class AbstractTransformMod(AbstractValue): - """Abstract transform module type.""" - - -def _transform_mod_lowering(aval): - assert isinstance(aval, AbstractTransformMod) - return (ir.OpaqueType.get("transform", 'op<"builtin.module">'),) - - # # registration # @@ -219,9 +210,6 @@ def _transform_mod_lowering(aval): core.raise_to_shaped_mappings[AbstractObs] = lambda aval, _: aval mlir.ir_type_handlers[AbstractObs] = _obs_lowering -core.raise_to_shaped_mappings[AbstractTransformMod] = lambda aval, _: aval -mlir.ir_type_handlers[AbstractTransformMod] = _transform_mod_lowering - class Folding(Enum): """ @@ -289,9 +277,6 @@ class Folding(Enum): value_and_grad_p.multiple_results = True assert_p = core.Primitive("assert") assert_p.multiple_results = True -apply_registered_pass_p = core.Primitive("apply_registered_pass") -transform_named_sequence_p = 
core.Primitive("transform_named_sequence") -transform_named_sequence_p.multiple_results = True set_state_p = jax.core.Primitive("state_prep") set_state_p.multiple_results = True set_basis_state_p = jax.core.Primitive("set_basis_state") @@ -334,8 +319,7 @@ def _python_callback_lowering( fn_ty = FunctionType.get(inputs=params_ty, results=results_ty) fn_ty_attr = ir.TypeAttr.get(fn_ty) cache_key = (callback_id, *params_ty, *results_ty) - if cache_key in jax_ctx.module_context.cached_primitive_lowerings: - callbackOp = jax_ctx.module_context.cached_primitive_lowerings[cache_key] + if callbackOp := get_cached(jax_ctx, cache_key): symbol = callbackOp.sym_name.value symbol_attr = ir.FlatSymbolRefAttr.get(symbol) return CallbackCallOp(results_ty, symbol_attr, args).results @@ -347,7 +331,8 @@ def _python_callback_lowering( # TODO: Name mangling for callbacks name = callback.__name__ callbackOp = CallbackOp(f"callback_{name}_{callback_id}", *attrs) - jax_ctx.module_context.cached_primitive_lowerings[cache_key] = callbackOp + + cache(jax_ctx, cache_key, callbackOp) symbol = callbackOp.sym_name.value symbol_attr = ir.FlatSymbolRefAttr.get(symbol) retval = CallbackCallOp(results_ty, symbol_attr, args).results @@ -360,8 +345,8 @@ def _python_callback_lowering( rev = custom_grad._bwd fwd_jaxpr = custom_grad._fwd_jaxpr rev_jaxpr = custom_grad._bwd_jaxpr - mlir_fwd = get_or_create_funcop(jax_ctx, fwd, fwd_jaxpr) - mlir_rev = get_or_create_funcop(jax_ctx, rev, rev_jaxpr) + mlir_fwd = lower_callable(jax_ctx, fwd, fwd_jaxpr) + mlir_rev = lower_callable(jax_ctx, rev, rev_jaxpr) sym_fwd = mlir_fwd.sym_name.value + ".fwd" argc = len(args) @@ -412,331 +397,15 @@ def _print_lowering(jax_ctx: mlir.LoweringRuleContext, *args, string=None, memre return PrintOp(val=val, const_val=None, print_descriptor=memref).results -# -# transform dialect lowering -# - - -def get_named_sequence_in_module(mod): - for op in mod.body.operations: - if op.operation.name == "transform.named_sequence": - return op.operation - return None - - -# -# transform_named_sequence -# -@transform_named_sequence_p.def_abstract_eval -def _transform_named_sequence_p_abstract_eval(*args): - return () - - -@transform_named_sequence_p.def_impl -def _transform_named_sequence_p_def_impl(*args): # pragma: no cover - raise NotImplementedError() - - -def _transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, *args): - transform_mod_type = ir.OpaqueType.get("transform", 'op<"builtin.module">') - module = jax_ctx.module_context.module - - # We wish to generate the transformer module, and place it in the top-level module - # The transformer module must be marked with the "transform.with_named_sequence" attribute - # The transformer module has a single block, and the block contains the - # "transform.named_sequence @__transform_main" operation - - with ir.InsertionPoint(module.body): - transformer_module = ModuleOp() - with_named_sequence_attr = ir.UnitAttr.get(jax_ctx.module_context.context) - transformer_module.operation.attributes["transform.with_named_sequence"] = ( - with_named_sequence_attr - ) - bb_transformer = transformer_module.body - - functype = ir.FunctionType.get(inputs=[transform_mod_type], results=[]) - functype_attr = ir.TypeAttr.get(functype) - - # Insert the transform.named_sequence op into the transformer module - # Note that InsertionPoint(Block) inserts after the last operation but still inside the block. 
- with ir.InsertionPoint(bb_transformer): - named_sequence_op = NamedSequenceOp( - sym_name="__transform_main", - function_type=functype_attr, - ) - - # transform.named_sequence op is the "main function" of the transform dialect - # and thus needs an entry block (which also should be its only block) - # The argument of the block is the payload module - bb_named_sequence = ir.Block.create_at_start( - named_sequence_op.body, arg_types=[transform_mod_type] - ) - - # The transform.named_sequence needs a terminator called "transform.yield" - with ir.InsertionPoint(bb_named_sequence): - transform_yield_op = TransformYieldOp(operands_=[]) # pylint: disable=unused-variable - - return named_sequence_op.results - - -# -# apply_registered_pass -# -@apply_registered_pass_p.def_abstract_eval -def _apply_registered_pass_abstract_eval(*args, pass_name, options=None): - return AbstractTransformMod() - - -@apply_registered_pass_p.def_impl -def _apply_registered_pass_def_impl(*args, pass_name, options=None): # pragma: no cover - raise NotImplementedError() - - -def _apply_registered_pass_lowering( - jax_ctx: mlir.LoweringRuleContext, *args, pass_name, options=None -): - transform_mod_type = ir.OpaqueType.get("transform", 'op<"builtin.module">') - module = jax_ctx.module_context.module - named_sequence_op = None - # module is a nested module - # parent_module is the root module - # E.g., - # - # ```mlir/pseudocode - # module @root { - # module @inner { - # func.func @qnode - # } - # module @transform { - # } - # } - # ``` - # - # When this function is executed we are likely - # somewhere around func.func @qnode. - # - # jax_ctx.module_context.module holds a reference to @inner - # - # This means that it's parent is @root. - parent_module = module.parent - for op in reversed(parent_module.regions[0].blocks[0].operations): - # Look for the module @transform that holds the transformation schedule - # TODO: Find a better way to search for the module with the transform schedule. - if op.operation.name == "builtin.module": - named_sequence_op = get_named_sequence_in_module(op) - if named_sequence_op: - break - assert ( - named_sequence_op is not None - ), """ - transform.apply_registered_pass must be placed in a transform.named_sequence, - but none exist in the module. - """ - - # If there already is a apply_registered_pass, - # insert after the last pass in the existing pass sequence. - # Note that ir.InsertionPoint(op) sets the insertion point to immediately BEFORE the op - named_sequence_op_block = named_sequence_op.regions[0].blocks[0] - first_op_in_block = named_sequence_op_block.operations[0].operation - - assert first_op_in_block.name in ( - "transform.apply_registered_pass", - "transform.yield", - ), """ - Unexpected operation in transform.named_sequence! - Only transform.apply_registered_pass and transform.yield are allowed. - """ - - if first_op_in_block.name == "transform.apply_registered_pass": - _ = len(named_sequence_op_block.operations) - yield_op = named_sequence_op_block.operations[_ - 1].operation - current_last_pass = named_sequence_op_block.operations[_ - 2].operation - with ir.InsertionPoint(yield_op): - apply_registered_pass_op = ApplyRegisteredPassOp( - result=transform_mod_type, - target=current_last_pass.result, - pass_name=pass_name, - options=options, - ) - - # otherwise it's the first pass, i.e. 
only a yield op is in the block - # so insert right before the yield op - else: - ip = named_sequence_op.regions[0].blocks[0] - with ir.InsertionPoint(ip.operations[len(ip.operations) - 1]): - apply_registered_pass_op = ApplyRegisteredPassOp( - result=transform_mod_type, - target=ip.arguments[0], - pass_name=pass_name, - options=options, - ) - - return apply_registered_pass_op.results - - # # module # -def lower_callable_to_funcop(ctx, callable_, call_jaxpr): - """Lower callable to either a FuncOp""" - if isinstance(call_jaxpr, core.Jaxpr): - call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) - - kwargs = {} - kwargs["ctx"] = ctx.module_context - kwargs["name"] = callable_.__name__ - kwargs["jaxpr"] = call_jaxpr - kwargs["effects"] = [] - kwargs["name_stack"] = ctx.name_stack - func_op = mlir.lower_jaxpr_to_fun(**kwargs) - - if isinstance(callable_, qml.QNode): - func_op.attributes["qnode"] = ir.UnitAttr.get() - # "best", the default option in PennyLane, chooses backprop on the device - # if supported and parameter-shift otherwise. Emulating the same behaviour - # would require generating code to query the device. - # For simplicity, Catalyst instead defaults to parameter-shift. - diff_method = ( - "parameter-shift" if callable_.diff_method == "best" else str(callable_.diff_method) - ) - func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) - - return func_op - - -def get_or_create_funcop(ctx, callable_, call_jaxpr): - """Get funcOp from cache, or create it from scratch""" - if func_op := ctx.module_context.cached_primitive_lowerings.get(callable_): - return func_op - func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr) - ctx.module_context.cached_primitive_lowerings[callable_] = func_op - return func_op - - -def get_symbolref(ctx, func_op): - """Get symbolref by deciding whether to constructo a symbolref or flatsymbolref""" - is_call_same_module = ctx.module_context.module.operation == func_op.parent - if is_call_same_module: - return ir.FlatSymbolRefAttr.get(func_op.name.value) - parent = func_op.parent - parent_name = parent.operation.attributes["sym_name"].value - child_name = func_op.name.value - return ir.SymbolRefAttr.get([parent_name, child_name]) - - -def create_call_op(ctx, func_op, *args): - """Create a func::CallOp from JAXPR.""" - output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) - flat_output_types = util.flatten(output_types) - mlir_args = mlir.flatten_lowering_ir_args(args) - symbol_ref = get_symbolref(ctx, func_op) - is_call_same_module = ctx.module_context.module.operation == func_op.parent - constructor = CallOp if is_call_same_module else LaunchKernelOp - return constructor(flat_output_types, symbol_ref, mlir_args) - - -def create_module_op(ctx, name): - """Create a module with name name""" - - symbol_table = ctx.module_context.symbol_table - parent = ctx.module_context.module - with ir.InsertionPoint(parent.body): - module = ModuleOp() - symbol_attr = ir._symbolNameAttr(name, ctx.module_context.context) - module.operation.attributes["sym_name"] = symbol_attr - symbol_table.insert(module) - - return module - - -class NestedModule: - """Context manager for the nested module""" - - def __init__(self, ctx, name): - self.ctx = ctx - self.moduleOp = create_module_op(ctx, name) - self.old_module_context = ctx.module_context - - def __enter__(self): - self.ctx.module_context = copy.copy(self.ctx.module_context) - self.ctx.module_context.module = self.moduleOp - self.ctx.module_context.cached_primitive_lowerings = {} - return self.moduleOp - - def 
__exit__(self, exc_type, exc_val, exc_tb): - self.ctx.module_context = self.old_module_context - - @quantum_kernel_p.def_impl -def _quantum_kernel_def_impl(*args, call_jaxpr, qnode): # pragma: no cover +def _quantum_kernel_def_impl(*args, call_jaxpr, qnode, pipeline=None): # pragma: no cover raise NotImplementedError() -def lower_callable(ctx, callable_, call_jaxpr): - """Lowers _callable to MLIR. - - If callable_ is a qnode, then we will first create a module, then - create a FuncOp corresponding to call_jaxpr. Otherwise, a FuncOp - will be created in the current module. This function might - add more than one FuncOps. This depends on the contents of call_jaxpr. - - Args: - ctx: LoweringRuleContext - callable_: python function - call_jaxpr: jaxpr representing callable_ - Returns: - FuncOp - """ - if not isinstance(callable_, qml.QNode): - return get_or_create_funcop(ctx, callable_, call_jaxpr) - - return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr) - - -def lower_qnode_to_funcop(ctx, callable_, call_jaxpr): - """Lowers callable_ to MLIR. - - Will create ModuleOp and then lower the callable_ to a - FuncOp inside the ModuleOp. The ModuleOp may have more - than one FuncOp. This depends on the contents of call_jaxpr. - - Args: - ctx: LoweringRuleContext - callable_: qml.Qnode - call_jaxpr: jaxpr representing callable_ - Returns: - FuncOp - """ - assert isinstance(callable_, qml.QNode), "This function expects qnodes" - - name = "module_" + callable_.__name__ - # pylint: disable-next=no-member - with NestedModule(ctx, name) as module, ir.InsertionPoint(module.regions[0].blocks[0]) as ip: - ctx.module_context.ip = ip - func_op = get_or_create_funcop(ctx, callable_, call_jaxpr) - func_op.sym_visibility = ir.StringAttr.get("public") - - return func_op - - -def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr): - """A wrapper around lower_qnode_to_funcop that will cache the FuncOp. 
- - Args: - ctx: LoweringRuleContext - callable_: qml.Qnode - call_jaxpr: jaxpr representing callable_ - Returns: - FuncOp - """ - if func_op := ctx.module_context.cached_primitive_lowerings.get(callable_): - return func_op - func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr) - ctx.module_context.cached_primitive_lowerings[callable_] = func_op - return func_op - - -def _quantum_kernel_lowering(ctx, *args, call_jaxpr, qnode): +def _quantum_kernel_lowering(ctx, *args, call_jaxpr, qnode, pipeline=None): """Lower's qnodes to moduleOp Args: @@ -747,9 +416,11 @@ def _quantum_kernel_lowering(ctx, *args, call_jaxpr, qnode): Returns: List[mlir.Value] corresponding """ - assert isinstance(qnode, qml.QNode), "This function expects qnodes" - func_op = get_or_create_qnode_funcop(ctx, qnode, call_jaxpr) + if pipeline is None: + pipeline = tuple() + + func_op = lower_callable(ctx, qnode, call_jaxpr, pipeline) call_op = create_call_op(ctx, func_op, *args) return call_op.results @@ -771,7 +442,7 @@ def _func_lowering(ctx, *args, call_jaxpr, fn): call_jaxpr: the jaxpr representation of the fn fn: the function being compiled """ - func_op = get_or_create_funcop(ctx, fn, call_jaxpr) + func_op = lower_callable(ctx, fn, call_jaxpr) call_op = create_call_op(ctx, func_op, *args) return call_op.results @@ -808,15 +479,6 @@ def _grad_abstract(*args, jaxpr, fn, grad_params): return tuple(transformed_signature.get_results()) -def _get_call_jaxpr(jaxpr): - """Extracts the `call_jaxpr` from a JAXPR if it exists.""" "" - for eqn in jaxpr.eqns: - primitive = eqn.primitive - if primitive in {func_p, quantum_kernel_p}: - return eqn.params["call_jaxpr"] - raise AssertionError("No call_jaxpr found in the JAXPR.") - - def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): """Lowering function to gradient. 
Args: @@ -839,9 +501,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): new_argnums = [num + offset for num in argnums] argnum_numpy = np.array(new_argnums) diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy) - func_call_jaxpr = _get_call_jaxpr(jaxpr) - lower_callable(ctx, fn, func_call_jaxpr) - func_op = ctx.module_context.cached_primitive_lowerings[fn] + func_op = lower_jaxpr(ctx, jaxpr) symbol_ref = get_symbolref(ctx, func_op) output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) @@ -911,14 +571,13 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params): constants.append(constantVals) consts_and_args = constants + args - func_call_jaxpr = _get_call_jaxpr(jaxpr) + func_call_jaxpr = get_call_jaxpr(jaxpr) func_args = consts_and_args[: len(func_call_jaxpr.invars)] val_result_types = flat_output_types[: len(flat_output_types) - len(argnums)] gradient_result_types = flat_output_types[len(flat_output_types) - len(argnums) :] - lower_callable(ctx, fn, func_call_jaxpr) + func_op = lower_jaxpr(ctx, jaxpr) - func_op = ctx.module_context.cached_primitive_lowerings[fn] symbol_ref = get_symbolref(ctx, func_op) return ValueAndGradOp( val_result_types, @@ -963,16 +622,15 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params): for const in jaxpr.consts ] consts_and_args = constants + args - func_call_jaxpr = _get_call_jaxpr(jaxpr) + func_call_jaxpr = get_call_jaxpr(jaxpr) func_args = consts_and_args[: len(func_call_jaxpr.invars)] tang_args = consts_and_args[len(func_call_jaxpr.invars) :] - lower_callable(ctx, fn, func_call_jaxpr) + func_op = lower_jaxpr(ctx, jaxpr) assert ( len(flat_output_types) % 2 == 0 ), f"The total number of result tensors is expected to be even, not {len(flat_output_types)}" - func_op = ctx.module_context.cached_primitive_lowerings[fn] symbol_ref = get_symbolref(ctx, func_op) return JVPOp( flat_output_types[: len(flat_output_types) // 2], @@ -1015,15 +673,14 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params): for const in jaxpr.consts ] consts_and_args = constants + args - func_call_jaxpr = _get_call_jaxpr(jaxpr) + func_call_jaxpr = get_call_jaxpr(jaxpr) func_args = consts_and_args[: len(func_call_jaxpr.invars)] cotang_args = consts_and_args[len(func_call_jaxpr.invars) :] func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)] vjp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :] - lower_callable(ctx, fn, func_call_jaxpr) + func_op = lower_jaxpr(ctx, jaxpr) - func_op = ctx.module_context.cached_primitive_lowerings[fn] symbol_ref = get_symbolref(ctx, func_op) return VJPOp( func_result_types, @@ -1073,10 +730,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn): jaxpr: the jaxpr representation of the circuit fn: the function to be mitigated """ - func_call_jaxpr = _get_call_jaxpr(jaxpr) - - lower_callable(ctx, fn, func_call_jaxpr) - func_op = ctx.module_context.cached_primitive_lowerings[fn] + func_op = lower_jaxpr(ctx, jaxpr) symbol_ref = get_symbolref(ctx, func_op) output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) @@ -1107,20 +761,27 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn): # qdevice # @qdevice_p.def_impl -def _qdevice_def_impl(ctx, rtd_lib, rtd_name, rtd_kwargs): # pragma: no cover +def _qdevice_def_impl(ctx, shots, rtd_lib, rtd_name, rtd_kwargs): # pragma: no cover raise NotImplementedError() @qdevice_p.def_abstract_eval -def _qdevice_abstract_eval(rtd_lib, rtd_name, rtd_kwargs): +def _qdevice_abstract_eval(shots, 
rtd_lib, rtd_name, rtd_kwargs): return () -def _qdevice_lowering(jax_ctx: mlir.LoweringRuleContext, rtd_lib, rtd_name, rtd_kwargs): +def _qdevice_lowering( + jax_ctx: mlir.LoweringRuleContext, shots: ir.Value, rtd_lib, rtd_name, rtd_kwargs +): ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True + + shots_value = TensorExtractOp(ir.IntegerType.get_signless(64, ctx), shots, []).result DeviceInitOp( - ir.StringAttr.get(rtd_lib), ir.StringAttr.get(rtd_name), ir.StringAttr.get(rtd_kwargs) + ir.StringAttr.get(rtd_lib), + ir.StringAttr.get(rtd_name), + ir.StringAttr.get(rtd_kwargs), + shots=shots_value, ) return () @@ -1312,13 +973,17 @@ def _gphase_lowering( # @qinst_p.def_abstract_eval def _qinst_abstract_eval( - *qubits_or_params, op=None, qubits_len=0, params_len=0, ctrl_len=0, adjoint=False + *qubits_or_params, + op=None, + qubits_len=0, + ctrl_len=0, + ctrl_value_len=0, + adjoint=False, + static_params=None, ): # The signature here is: (using * to denote zero or more) - # qubits*, params*, ctrl_qubits*, ctrl_values* - qubits = qubits_or_params[:qubits_len] - ctrl_qubits = qubits_or_params[-2 * ctrl_len : -ctrl_len] - all_qubits = qubits + ctrl_qubits + # qubits*, ctrl_qubits*, ctrl_values*, params* + all_qubits = qubits_or_params[: qubits_len + ctrl_len] for idx in range(qubits_len + ctrl_len): qubit = all_qubits[idx] assert isinstance(qubit, AbstractQbit) @@ -1336,17 +1001,19 @@ def _qinst_lowering( *qubits_or_params, op=None, qubits_len=0, - params_len=0, ctrl_len=0, + ctrl_value_len=0, adjoint=False, + static_params=None, ): + assert ctrl_value_len == ctrl_len, "Control values must be the same length as control qubits" ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True qubits = qubits_or_params[:qubits_len] - params = qubits_or_params[qubits_len : qubits_len + params_len] - ctrl_qubits = qubits_or_params[qubits_len + params_len : qubits_len + params_len + ctrl_len] - ctrl_values = qubits_or_params[qubits_len + params_len + ctrl_len :] + ctrl_qubits = qubits_or_params[qubits_len : qubits_len + ctrl_len] + ctrl_values = qubits_or_params[qubits_len + ctrl_len : qubits_len + ctrl_len + ctrl_value_len] + params = qubits_or_params[qubits_len + ctrl_len + ctrl_value_len :] for qubit in qubits: assert ir.OpaqueType.isinstance(qubit.type) @@ -1369,13 +1036,31 @@ def _qinst_lowering( p = TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result ctrl_values_i1.append(p) + params_attr = ( + None + if not static_params + else ir.DenseF64ArrayAttr.get([ir.FloatAttr.get_f64(val) for val in static_params]) + ) + if len(float_params) > 0: + assert ( + params_attr is None + ), "Static parameters are not allowed when having dynamic parameters" + name_attr = ir.StringAttr.get(op) name_str = str(name_attr) name_str = name_str.replace('"', "") if name_str == "MultiRZ": - assert len(float_params) == 1, "MultiRZ takes one float parameter" - float_param = float_params[0] + assert len(float_params) <= 1, "MultiRZ takes at most one dynamic float parameter" + assert ( + not static_params or len(static_params) <= 1 + ), "MultiRZ takes at most one static float parameter" + # TODO: Add support for MultiRZ with static params + float_param = ( + TensorExtractOp(ir.F64Type.get(), mlir.ir_constant(static_params[0]), []) + if len(float_params) == 0 + else float_params[0] + ) return MultiRZOp( out_qubits=[qubit.type for qubit in qubits], out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], @@ -1385,7 +1070,17 @@ def _qinst_lowering( in_ctrl_values=ctrl_values_i1, 
adjoint=adjoint, ).results - + if params_attr: + return StaticCustomOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + static_params=params_attr, + in_qubits=qubits, + gate_name=name_attr, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results return CustomOp( out_qubits=[qubit.type for qubit in qubits], out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], @@ -1653,33 +1348,78 @@ def _hamiltonian_lowering(jax_ctx: mlir.LoweringRuleContext, coeffs: ir.Value, * # # sample measurement # -@sample_p.def_abstract_eval -def _sample_abstract_eval(obs, shots, shape): - assert isinstance(obs, AbstractObs) +def sample_staging_rule(jaxpr_trace, obs, shots, num_qubits): + """ + The result shape of `sample_p` is (shots, num_qubits). + In jax, the default `def_abstract_eval` method for binding primitives keeps the abstract aval in + the dynamic shape dimension, instead of the SSA value for the shape, i.e. + + c:i64[] = ... + d:AbstractObs = ... + e:f64[ShapedArray(int64[], weak_type=True),1] = sample[num_qubits=1] d c + + To ensure that the result DShapedArray is actually constructed with the tracer value, + we need to provide a custom staging rule for the primitive, where we manually link + the tracer to the output shape. This will now correctly produce + + e:f64[c,1] = sample[num_qubits=1] d c + + This works because when jax processes a primitive during making jaxprs, the default + is to only look at the abstract avals of the primitive. Providing a custom staging rule + circumvents the above default logic. + + See jax._src.interpreters.partial_eval.process_primitive and default_process_primitive, + https://github.com/jax-ml/jax/blob/a54319ec1886ed920d50cacf10e147a743888464/jax/_src/interpreters/partial_eval.py#L1881C7-L1881C24 + """ if obs.primitive is compbasis_p: - assert shape == (shots, obs.num_qubits) + assert num_qubits == obs.num_qubits + + out_shape = core.DShapedArray((shots, num_qubits), jax.numpy.dtype("float64")) + out_tracer = pe.DynamicJaxprTracer(jaxpr_trace, out_shape) + + if isinstance(shots, int): + invars = [jaxpr_trace.getvar(obs)] + params = {"shots": shots, "num_qubits": num_qubits} else: - assert shape == (shots,) + invars = [jaxpr_trace.getvar(obs), jaxpr_trace.getvar(shots)] + params = {"num_qubits": num_qubits} + + eqn = pe.new_jaxpr_eqn( + invars, + [jaxpr_trace.makevar(out_tracer)], + sample_p, + params, + jax.core.no_effects, + ) + jaxpr_trace.frame.add_eqn(eqn) + return out_tracer - return core.ShapedArray(shape, jax.numpy.float64) + +pe.custom_staging_rules[sample_p] = sample_staging_rule @sample_p.def_impl -def _sample_def_impl(ctx, obs, shots, shape): # pragma: no cover +def _sample_def_impl(ctx, obs, shots, num_qubits): # pragma: no cover raise NotImplementedError() -def _sample_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape: tuple): +def _sample_lowering( + jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: Union[int, ir.Value], num_qubits: int +): + # Note: result shape of sample op is (shots, number_of_qubits) ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True - i64_type = ir.IntegerType.get_signless(64, ctx) - shots_attr = ir.IntegerAttr.get(i64_type, shots) f64_type = ir.F64Type.get() - result_type = ir.RankedTensorType.get(shape, f64_type) + result_shape = ( + (shots, num_qubits) + if isinstance(shots, int) + else (ir.ShapedType.get_dynamic_size(), num_qubits) + ) + result_type = 
ir.RankedTensorType.get(result_shape, f64_type) - return SampleOp(result_type, obs, shots_attr).results + return SampleOp(result_type, obs).results # @@ -1699,37 +1439,43 @@ def _counts_abstract_eval(obs, shots, shape): else: assert shape == (2,) - return core.ShapedArray(shape, jax.numpy.float64), core.ShapedArray(shape, jax.numpy.int64) + return core.ShapedArray(shape, jax.numpy.dtype("float64")), core.ShapedArray( + shape, jax.numpy.dtype("int64") + ) -def _counts_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape: tuple): +def _counts_lowering( + jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: Union[int, ir.Value], shape: tuple +): + # Note: result shape of counts op is (tensor, tensor) + # where N = 2**number_of_qubits + # This means even with dynamic shots, result shape is still static. ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True i64_type = ir.IntegerType.get_signless(64, ctx) - shots_attr = ir.IntegerAttr.get(i64_type, shots) f64_type = ir.F64Type.get() eigvals_type = ir.RankedTensorType.get(shape, f64_type) counts_type = ir.RankedTensorType.get(shape, i64_type) - return CountsOp(eigvals_type, counts_type, obs, shots_attr).results + return CountsOp(eigvals_type, counts_type, obs).results # # expval measurement # @expval_p.def_abstract_eval -def _expval_abstract_eval(obs, shots, shape=None): +def _expval_abstract_eval(obs, shape=None): assert isinstance(obs, AbstractObs) return core.ShapedArray((), jax.numpy.float64) @expval_p.def_impl -def _expval_def_impl(ctx, obs, shots, shape=None): # pragma: no cover +def _expval_def_impl(ctx, obs, shape=None): # pragma: no cover raise NotImplementedError() -def _expval_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape=None): +def _expval_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shape=None): ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True @@ -1737,11 +1483,9 @@ def _expval_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: in assert ir.OpaqueType(obs.type).dialect_namespace == "quantum" assert ir.OpaqueType(obs.type).data == "obs" - i64_type = ir.IntegerType.get_signless(64, ctx) - shots_attr = ir.IntegerAttr.get(i64_type, shots) if shots is not None else None result_type = ir.F64Type.get() - mres = ExpvalOp(result_type, obs, shots=shots_attr).result + mres = ExpvalOp(result_type, obs).result result_from_elements_op = ir.RankedTensorType.get((), result_type) from_elements_op = FromElementsOp(result_from_elements_op, mres) return from_elements_op.results @@ -1751,17 +1495,17 @@ def _expval_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: in # var measurement # @var_p.def_abstract_eval -def _var_abstract_eval(obs, shots, shape=None): +def _var_abstract_eval(obs, shape=None): assert isinstance(obs, AbstractObs) return core.ShapedArray((), jax.numpy.float64) @var_p.def_impl -def _var_def_impl(ctx, obs, shots, shape=None): # pragma: no cover +def _var_def_impl(ctx, obs, shape=None): # pragma: no cover raise NotImplementedError() -def _var_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape=None): +def _var_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shape=None): ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True @@ -1769,11 +1513,9 @@ def _var_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, assert ir.OpaqueType(obs.type).dialect_namespace == "quantum" assert ir.OpaqueType(obs.type).data == "obs" 
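# Standalone sketch of the static-vs-dynamic split that the measurement changes above rely
# on, mirroring the bind_flexible_primitive helper added earlier in this patch: a literal
# shot count stays a jaxpr parameter (static result shape), while a traced shot count
# becomes an SSA operand (dynamic result shape). The Tracer class below is a stand-in for a
# real JAX tracer, not the actual type.
class Tracer:
    """Placeholder for a traced (dynamic) value."""

def split_flexible(flexible_args):
    static, dynamic = {}, []
    for name, value in flexible_args.items():
        if type(value) in (int, float, bool):
            static[name] = value    # becomes a literal parameter of the primitive
        else:
            dynamic.append(value)   # becomes a positional operand of the primitive
    return static, dynamic

assert split_flexible({"shots": 1000}) == ({"shots": 1000}, [])
traced_shots = Tracer()
assert split_flexible({"shots": traced_shots}) == ({}, [traced_shots])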
- i64_type = ir.IntegerType.get_signless(64, ctx) - shots_attr = ir.IntegerAttr.get(i64_type, shots) if shots is not None else None result_type = ir.F64Type.get() - mres = VarianceOp(result_type, obs, shots=shots_attr).result + mres = VarianceOp(result_type, obs).result result_from_elements_op = ir.RankedTensorType.get((), result_type) from_elements_op = FromElementsOp(result_from_elements_op, mres) return from_elements_op.results @@ -2385,8 +2127,6 @@ def _cos_lowering2(ctx, x): (assert_p, _assert_lowering), (python_callback_p, _python_callback_lowering), (value_and_grad_p, _value_and_grad_lowering), - (apply_registered_pass_p, _apply_registered_pass_lowering), - (transform_named_sequence_p, _transform_named_sequence_lowering), (set_state_p, _set_state_lowering), (set_basis_state_p, _set_basis_state_lowering), (sin_p, _sin_lowering2), diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py new file mode 100644 index 0000000000..696392ed7f --- /dev/null +++ b/frontend/catalyst/jax_primitives_utils.py @@ -0,0 +1,283 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains some helper functions for translating JAX +primitives to MLIR""" + +import copy +import functools + +import pennylane as qml +from jax._src import core, util +from jax._src.lib.mlir import ir +from jax.interpreters import mlir +from jaxlib.mlir.dialects.builtin import ModuleOp +from jaxlib.mlir.dialects.func import CallOp +from mlir_quantum.dialects._transform_ops_gen import ( + ApplyRegisteredPassOp, + NamedSequenceOp, + YieldOp, +) +from mlir_quantum.dialects.catalyst import LaunchKernelOp + + +def get_call_jaxpr(jaxpr): + """Extracts the `call_jaxpr` from a JAXPR if it exists.""" "" + for eqn in jaxpr.eqns: + if eqn.params.get("call_jaxpr"): + return eqn.params["call_jaxpr"] + raise AssertionError("No call_jaxpr found in the JAXPR.") + + +def get_call_equation(jaxpr): + """Extracts the equation which has a call_jaxpr.""" + for eqn in jaxpr.eqns: + if eqn.params.get("call_jaxpr"): + return eqn + raise AssertionError("No call_jaxpr found in the JAXPR.") + + +def lower_jaxpr(ctx, jaxpr): + """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p""" + equation = get_call_equation(jaxpr) + call_jaxpr = equation.params["call_jaxpr"] + callable_ = equation.params.get("fn") + if callable_ is None: + callable_ = equation.params.get("qnode") + pipeline = equation.params.get("pipeline") + return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline) + + +def lower_callable(ctx, callable_, call_jaxpr, pipeline=None): + """Lowers _callable to MLIR. + + If callable_ is a qnode, then we will first create a module, then + create a FuncOp corresponding to call_jaxpr. Otherwise, a FuncOp + will be created in the current module. This function might + add more than one FuncOps. This depends on the contents of call_jaxpr. 
+ + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + Returns: + FuncOp + """ + if pipeline is None: + pipeline = tuple() + + if not isinstance(callable_, qml.QNode): + return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline) + + return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline) + + +def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline): + """Get funcOp from cache, or create it from scratch""" + key = (callable_, *pipeline) + if func_op := get_cached(ctx, key): + return func_op + func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr) + cache(ctx, key, func_op) + return func_op + + +def lower_callable_to_funcop(ctx, callable_, call_jaxpr): + """Lower callable to either a FuncOp""" + if isinstance(call_jaxpr, core.Jaxpr): + call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) + + kwargs = {} + kwargs["ctx"] = ctx.module_context + if not isinstance(callable_, functools.partial): + name = callable_.__name__ + else: + name = callable_.func.__name__ + ".partial" + kwargs["name"] = name + kwargs["jaxpr"] = call_jaxpr + kwargs["effects"] = [] + kwargs["name_stack"] = ctx.name_stack + func_op = mlir.lower_jaxpr_to_fun(**kwargs) + + if isinstance(callable_, qml.QNode): + func_op.attributes["qnode"] = ir.UnitAttr.get() + # "best", the default option in PennyLane, chooses backprop on the device + # if supported and parameter-shift otherwise. Emulating the same behaviour + # would require generating code to query the device. + # For simplicity, Catalyst instead defaults to parameter-shift. + diff_method = ( + "parameter-shift" if callable_.diff_method == "best" else str(callable_.diff_method) + ) + func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) + + return func_op + + +def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline): + """A wrapper around lower_qnode_to_funcop that will cache the FuncOp. + + Args: + ctx: LoweringRuleContext + callable_: qml.Qnode + call_jaxpr: jaxpr representing callable_ + Returns: + FuncOp + """ + key = (callable_, *pipeline) + if func_op := get_cached(ctx, key): + return func_op + func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline) + cache(ctx, key, func_op) + return func_op + + +def lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline): + """Lowers callable_ to MLIR. + + Will create ModuleOp and then lower the callable_ to a + FuncOp inside the ModuleOp. The ModuleOp may have more + than one FuncOp. This depends on the contents of call_jaxpr. 
+ + Args: + ctx: LoweringRuleContext + callable_: qml.Qnode + call_jaxpr: jaxpr representing callable_ + Returns: + FuncOp + """ + assert isinstance(callable_, qml.QNode), "This function expects qnodes" + + name = "module_" + callable_.__name__ + # pylint: disable-next=no-member + with NestedModule(ctx, name) as module, ir.InsertionPoint(module.regions[0].blocks[0]) as ip: + transform_named_sequence_lowering(ctx, pipeline) + ctx.module_context.ip = ip + func_op = get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline) + func_op.sym_visibility = ir.StringAttr.get("public") + + return func_op + + +def get_cached(ctx, key): + """Looks for key in the cache""" + return ctx.module_context.cached_primitive_lowerings.get(key) + + +def cache(ctx, key, val): + """Caches value in cache with key""" + ctx.module_context.cached_primitive_lowerings[key] = val + + +def get_symbolref(ctx, func_op): + """Get symbolref by deciding whether to constructo a symbolref or flatsymbolref""" + is_call_same_module = ctx.module_context.module.operation == func_op.parent + if is_call_same_module: + return ir.FlatSymbolRefAttr.get(func_op.name.value) + parent = func_op.parent + parent_name = parent.operation.attributes["sym_name"].value + child_name = func_op.name.value + return ir.SymbolRefAttr.get([parent_name, child_name]) + + +def create_call_op(ctx, func_op, *args): + """Create a func::CallOp from JAXPR.""" + output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) + flat_output_types = util.flatten(output_types) + mlir_args = mlir.flatten_lowering_ir_args(args) + symbol_ref = get_symbolref(ctx, func_op) + is_call_same_module = ctx.module_context.module.operation == func_op.parent + constructor = CallOp if is_call_same_module else LaunchKernelOp + return constructor(flat_output_types, symbol_ref, mlir_args) + + +def create_module_op(ctx, name): + """Create a module with name name""" + + symbol_table = ctx.module_context.symbol_table + parent = ctx.module_context.module + with ir.InsertionPoint(parent.body): + module = ModuleOp() + symbol_attr = ir._symbolNameAttr(name, ctx.module_context.context) + module.operation.attributes["sym_name"] = symbol_attr + symbol_table.insert(module) + + return module + + +class NestedModule: + """Context manager for the nested module""" + + def __init__(self, ctx, name): + self.ctx = ctx + self.moduleOp = create_module_op(ctx, name) + self.old_module_context = ctx.module_context + + def __enter__(self): + self.ctx.module_context = copy.copy(self.ctx.module_context) + self.ctx.module_context.module = self.moduleOp + self.ctx.module_context.cached_primitive_lowerings = {} + return self.moduleOp + + def __exit__(self, exc_type, exc_val, exc_tb): + self.ctx.module_context = self.old_module_context + + +def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline): + """Generate a transform module embedded in the current module and schedule + the transformations in pipeline""" + + transform_mod_type = ir.OpaqueType.get("transform", 'op<"builtin.module">') + module = jax_ctx.module_context.module + + # We wish to generate the transformer module, and place it in the top-level module + # The transformer module must be marked with the "transform.with_named_sequence" attribute + # The transformer module has a single block, and the block contains the + # "transform.named_sequence @__transform_main" operation + + with ir.InsertionPoint(module.body): + transformer_module = ModuleOp() + with_named_sequence_attr = ir.UnitAttr.get(jax_ctx.module_context.context) + 
transformer_module.operation.attributes["transform.with_named_sequence"] = ( + with_named_sequence_attr + ) + bb_transformer = transformer_module.body + + functype = ir.FunctionType.get(inputs=[transform_mod_type], results=[]) + functype_attr = ir.TypeAttr.get(functype) + + # Insert the transform.named_sequence op into the transformer module + # Note that InsertionPoint(Block) inserts after the last operation but still inside the block. + with ir.InsertionPoint(bb_transformer): + named_sequence_op = NamedSequenceOp( + sym_name="__transform_main", + function_type=functype_attr, + ) + + # transform.named_sequence op is the "main function" of the transform dialect + # and thus needs an entry block (which also should be its only block) + # The argument of the block is the payload module + bb_named_sequence = ir.Block.create_at_start( + named_sequence_op.body, arg_types=[transform_mod_type] + ) + + # The transform.named_sequence needs a terminator called "transform.yield" + with ir.InsertionPoint(bb_named_sequence): + target = bb_named_sequence.arguments[0] + for _pass in pipeline: + apply_registered_pass_op = ApplyRegisteredPassOp( + result=transform_mod_type, target=target, pass_name=_pass.name + ) + target = apply_registered_pass_op.result + transform_yield_op = YieldOp(operands_=[]) # pylint: disable=unused-variable + + return named_sequence_op.results diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 651c7b4d04..2eb4dc7396 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -17,17 +17,25 @@ """ import logging +from collections.abc import Sequence from contextlib import contextmanager from dataclasses import dataclass from functools import partial, reduce -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import jax import jax.numpy as jnp import pennylane as qml from pennylane import QubitUnitary, QueuingManager from pennylane.devices import QubitDevice -from pennylane.measurements import DensityMatrixMP, MeasurementProcess, StateMP +from pennylane.measurements import ( + CountsMP, + ExpectationMP, + MeasurementProcess, + ProbabilityMP, + StateMP, + VarianceMP, +) from pennylane.operation import AnyWires, Operation, Operator, Wires from pennylane.ops import Adjoint, Controlled, ControlledOp from pennylane.tape import QuantumTape @@ -62,6 +70,7 @@ tree_unflatten, wrap_init, ) +from catalyst.jax_extras.tracing import bind_flexible_primitive from catalyst.jax_primitives import ( AbstractQreg, compbasis_p, @@ -88,7 +97,6 @@ var_p, ) from catalyst.logging import debug_logger, debug_logger_init -from catalyst.passes import add_mlir_quantum_decomposition from catalyst.tracing.contexts import ( EvaluationContext, EvaluationMode, @@ -540,8 +548,9 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs): } with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION): jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs) + plugins = EvaluationContext.get_plugins() - return jaxpr, out_type, out_treedef + return jaxpr, out_type, out_treedef, plugins @debug_logger @@ -682,12 +691,14 @@ def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint= else: qubits = qrp.extract(op.wires) controlled_qubits = qrp.extract(controlled_wires) - qubits2 = qinst_p.bind( - *[*qubits, *op.parameters, *controlled_qubits, *controlled_values], + qubits2 = bind_flexible_primitive( + 
qinst_p, + {"static_params": op.parameters}, + *[*qubits, *controlled_qubits, *controlled_values], op=op.name, qubits_len=len(qubits), - params_len=len(op.parameters), ctrl_len=len(controlled_qubits), + ctrl_value_len=len(controlled_values), adjoint=adjoint, ) qrp.insert(op.wires, qubits2[: len(qubits)]) @@ -856,23 +867,19 @@ def trace_quantum_measurements( if isinstance(o, MeasurementProcess): # Check if the measurement is supported shot-vector where num_of_total_copies > 1 - if device.shots.num_copies > 1 and o.return_type.value != "sample": # qml.sample() + if device.shots.num_copies > 1 and not isinstance(o, qml.measurements.SampleMP): raise NotImplementedError( - f"Measurement {o.return_type.value} is not supported a shot-vector. " + f"Measurement {type(o).__name__} is not supported a shot-vector. " "Use qml.sample() instead." ) - if isinstance(device, qml.devices.LegacyDevice): - m_wires = o.wires if o.wires else range(device.num_wires) - else: - m_wires = o.wires if o.wires else range(len(device.wires)) + m_wires = o.wires if o.wires else range(len(device.wires)) obs_tracers, nqubits = trace_observables(o.obs, qrp, m_wires) using_compbasis = obs_tracers.primitive == compbasis_p - if o.return_type.value == "sample": - results = [] # list of results per copy + if isinstance(o, qml.measurements.SampleMP): if shots is None: # needed for old device API only raise ValueError( @@ -883,7 +890,9 @@ def trace_quantum_measurements( out_classical_tracers.append(o.mv) else: shape = (shots, nqubits) if using_compbasis else (shots,) - result = sample_p.bind(obs_tracers, shots=shots, shape=shape) + result = bind_flexible_primitive( + sample_p, {"shots": shots}, obs_tracers, num_qubits=nqubits + ) if using_compbasis: result = jnp.astype(result, jnp.int64) @@ -901,22 +910,24 @@ def trace_quantum_measurements( out_classical_tracers.append(reshaped_result) - elif o.return_type.value == "expval": - out_classical_tracers.append(expval_p.bind(obs_tracers, shots=shots)) - elif o.return_type.value == "var": - out_classical_tracers.append(var_p.bind(obs_tracers, shots=shots)) - elif o.return_type.value == "probs": + elif type(o) is ExpectationMP: + out_classical_tracers.append(expval_p.bind(obs_tracers)) + elif type(o) is VarianceMP: + out_classical_tracers.append(var_p.bind(obs_tracers)) + elif type(o) is ProbabilityMP: assert using_compbasis shape = (2**nqubits,) out_classical_tracers.append(probs_p.bind(obs_tracers, shape=shape)) - elif o.return_type.value == "counts": + elif type(o) is CountsMP: if shots is None: # needed for old device API only raise ValueError( "qml.sample cannot work with shots=None. " "Please specify a finite number of shots." ) shape = (2**nqubits,) if using_compbasis else (2,) - results = counts_p.bind(obs_tracers, shots=shots, shape=shape) + results = bind_flexible_primitive( + counts_p, {"shots": shots}, obs_tracers, shape=shape + ) if using_compbasis: results = (jnp.asarray(results[0], jnp.int64), results[1]) out_classical_tracers.extend(results) @@ -931,7 +942,7 @@ def trace_quantum_measurements( ) else: out_tree = counts_tree - elif isinstance(o, StateMP) and not isinstance(o, DensityMatrixMP): + elif type(o) is StateMP: assert using_compbasis shape = (2**nqubits,) out_classical_tracers.append(state_p.bind(obs_tracers, shape=shape)) @@ -1136,8 +1147,6 @@ def trace_quantum_function( out_type: JAXPR output type (list of abstract values with explicitness flags). 
out_tree: PyTree shapen of the result """ - # Add the decomposition passes with the transform dialect - add_mlir_quantum_decomposition(f, device) with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: # (1) - Classical tracing @@ -1195,7 +1204,12 @@ def is_leaf(obj): # We just need to ensure there is a tape cut in between each. # Each tape will be outlined into its own function with mlir pass # -split-multiple-tapes + + # TODO: device shots is now always a concrete integer or None + # When PennyLane allows dynamic shots, update tracing to accept dynamic shots too + device_shots = get_device_shots(device) or 0 qdevice_p.bind( + device_shots, rtd_lib=device.backend_lib, rtd_name=device.backend_name, rtd_kwargs=str(device.backend_kwargs), diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 86722edd56..2553dab071 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -28,17 +28,15 @@ import pennylane as qml from jax.interpreters import mlir from jax.tree_util import tree_flatten, tree_unflatten -from malt.core import config as ag_config import catalyst -from catalyst.autograph import ag_primitives, run_autograph +from catalyst.autograph import run_autograph from catalyst.compiled_functions import CompilationCache, CompiledFunction from catalyst.compiler import CompileOptions, Compiler from catalyst.debug.instruments import instrument from catalyst.from_plxpr import trace_from_pennylane from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr from catalyst.logging import debug_logger, debug_logger_init -from catalyst.passes import PipelineNameUniquer, _inject_transform_named_sequence from catalyst.qfunc import QFunc from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import ( @@ -89,6 +87,8 @@ def qjit( seed=None, experimental_capture=False, circuit_transform_pipeline=None, + pass_plugins=None, + dialect_plugins=None, ): # pylint: disable=too-many-arguments,unused-argument """A just-in-time decorator for PennyLane and JAX programs using Catalyst. @@ -142,7 +142,9 @@ def qjit( ``lightning.gpu``. The default value is None, which means no seeding is performed, and all processes are random. A seed is expected to be an unsigned 32-bit integer. Currently, the following measurement processes are seeded: :func:`~.measure`, - :func:`qml.sample() `, :func:`qml.counts() `. + :func:`qml.sample() `, :func:`qml.counts() `, + :func:`qml.probs() `, :func:`qml.expval() `, + :func:`qml.var() `. experimental_capture (bool): If set to ``True``, the qjit decorator will use PennyLane's experimental program capture capabilities to capture the decorated function for compilation. @@ -155,6 +157,8 @@ def qjit( dictionaries of valid keyword arguments and values for the specific pass. The order of keys in this dictionary will determine the pass pipeline. If not specified, the default pass pipeline will be applied. + pass_plugins (Optional[List[Path]]): List of paths to pass plugins. + dialect_plugins (Optional[List[Path]]): List of paths to dialect plugins. Returns: QJIT object. 
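# --- Usage sketch (not part of the patch) ---
# The new pass_plugins/dialect_plugins options documented above take paths to compiled
# MLIR plugin libraries. The plugin path below is purely hypothetical; running this
# requires a real plugin shared library on disk.
from pathlib import Path
import pennylane as qml
from catalyst import qjit

plugin = Path("/path/to/libCustomCatalystPlugin.so")  # hypothetical location

@qjit(pass_plugins=[plugin], dialect_plugins=[plugin])
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliZ(0))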
@@ -496,17 +500,7 @@ def __init__(self, fn, compile_options): fn, compile_options.static_argnames, compile_options.static_argnums ) - # Patch the conversion rules by adding the included modules before the block list - include_convertlist = tuple( - ag_config.Convert(rule) for rule in self.compile_options.autograph_include - ) - self.patched_module_allowlist = include_convertlist + ag_primitives.module_allowlist - - # Pre-compile with the patched conversion rules - with Patcher( - (ag_primitives, "module_allowlist", self.patched_module_allowlist), - ): - self.user_function = self.pre_compilation() + self.user_function = self.pre_compilation() # Static arguments require values, so we cannot AOT compile. if self.user_sig is not None and not self.compile_options.static_argnums: @@ -546,13 +540,9 @@ def aot_compile(self): # TODO: awkward, refactor or redesign the target feature if self.compile_options.target in ("jaxpr", "mlir", "binary"): - # Capture with the patched conversion rules - with Patcher( - (ag_primitives, "module_allowlist", self.patched_module_allowlist), - ): - self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture( - self.user_sig or () - ) + self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture( + self.user_sig or () + ) if self.compile_options.target in ("mlir", "binary"): self.mlir_module, self.mlir = self.generate_ir() @@ -591,13 +581,7 @@ def jit_compile(self, args, **kwargs): if self.compiled_function and self.compiled_function.shared_object: self.compiled_function.shared_object.close() - # Capture with the patched conversion rules - with Patcher( - (ag_primitives, "module_allowlist", self.patched_module_allowlist), - ): - self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture( - args, **kwargs - ) + self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args, **kwargs) self.mlir_module, self.mlir = self.generate_ir() self.compiled_function, self.qir = self.compile() @@ -622,12 +606,10 @@ def jit_compile(self, args, **kwargs): @debug_logger def pre_compilation(self): """Perform pre-processing tasks on the Python function, such as AST transformations.""" - processed_fn = self.original_function - if self.compile_options.autograph: - processed_fn = run_autograph(self.original_function) + return run_autograph(self.original_function, *self.compile_options.autograph_include) - return processed_fn + return self.original_function @instrument(size_from=0) @debug_logger @@ -650,33 +632,20 @@ def capture(self, args, **kwargs): dynamic_sig = get_abstract_signature(dynamic_args) full_sig = merge_static_args(dynamic_sig, args, static_argnums) - def fn_with_transform_named_sequence(*args, **kwargs): - """ - This function behaves exactly like the user function being jitted, - taking in the same arguments and producing the same results, except - it injects a transform_named_sequence jax primitive at the beginning - of the jaxpr when being traced. - - Note that we do not overwrite self.original_function and self.user_function; - this fn_with_transform_named_sequence is ONLY used here to produce tracing - results with a transform_named_sequence primitive at the beginning of the - jaxpr. It is never executed or used anywhere, except being traced here. 
- """ - _inject_transform_named_sequence() - return self.user_function(*args, **kwargs) - if self.compile_options.experimental_capture: return trace_from_pennylane( - fn_with_transform_named_sequence, static_argnums, abstracted_axes, full_sig, kwargs + self.user_function, static_argnums, abstracted_axes, full_sig, kwargs ) def closure(qnode, *args, **kwargs): params = {} params["static_argnums"] = kwargs.pop("static_argnums", static_argnums) params["_out_tree_expected"] = [] + default_pass_pipeline = self.compile_options.circuit_transform_pipeline + pass_pipeline = params.get("pass_pipeline", default_pass_pipeline) + params["pass_pipeline"] = pass_pipeline return QFunc.__call__( qnode, - pass_pipeline=self.compile_options.circuit_transform_pipeline, *args, **dict(params, **kwargs), ) @@ -685,15 +654,16 @@ def closure(qnode, *args, **kwargs): (qml.QNode, "__call__", closure), ): # TODO: improve PyTree handling - jaxpr, out_type, treedef = trace_to_jaxpr( - fn_with_transform_named_sequence, + jaxpr, out_type, treedef, plugins = trace_to_jaxpr( + self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, ) + self.compile_options.pass_plugins.update(plugins) + self.compile_options.dialect_plugins.update(plugins) - PipelineNameUniquer.reset() return jaxpr, out_type, treedef, dynamic_sig @instrument(size_from=0, has_finegrained=True) diff --git a/frontend/catalyst/passes.py b/frontend/catalyst/passes.py index dbe5391009..c163e444f6 100644 --- a/frontend/catalyst/passes.py +++ b/frontend/catalyst/passes.py @@ -34,21 +34,75 @@ import copy import functools -from typing import Optional +from importlib.metadata import entry_points +from pathlib import Path +from typing import TypeAlias import pennylane as qml -from catalyst.jax_primitives import apply_registered_pass_p, transform_named_sequence_p from catalyst.tracing.contexts import EvaluationContext +PipelineDict: TypeAlias = dict[str, dict[str, str]] + + +class Pass: + """Class intended to hold options for passes""" + + def __init__(self, name, *options, **valued_options): + self.options = options + self.valued_options = valued_options + if "." 
in name: + resolution_functions = entry_points(group="catalyst.passes_resolution") + key, passname = name.split(".") + resolution_function = resolution_functions[key + ".passes"] + module = resolution_function.load() + path, name = module.name2pass(passname) + assert EvaluationContext.is_tracing() + EvaluationContext.add_plugin(path) + + self.name = name + + def __repr__(self): + return ( + self.name + + " ".join(f"--{option}" for option in self.options) + + " ".join(f"--{option}={value}" for option, value in self.valued_options) + ) + + +class PassPlugin(Pass): + """Class intended to hold options for pass plugins""" + + def __init__( + self, path: Path, name: str, *options: list[str], **valued_options: dict[str, str] + ): + assert EvaluationContext.is_tracing() + EvaluationContext.add_plugin(path) + self.path = path + super().__init__(name, *options, **valued_options) + + +def dictionary_to_tuple_of_passes(pass_pipeline: PipelineDict): + """Convert dictionary of passes into tuple of passes""" + + if type(pass_pipeline) != dict: + return pass_pipeline + + passes = tuple() + pass_names = _API_name_to_pass_name() + for API_name, pass_options in pass_pipeline.items(): + name = pass_names.get(API_name, API_name) + passes += (Pass(name, **pass_options),) + return passes + ## API ## # pylint: disable=line-too-long -def pipeline(pass_pipeline: Optional[dict[str, dict[str, str]]] = None): +@functools.singledispatch +def pipeline(pass_pipeline: PipelineDict): """Configures the Catalyst MLIR pass pipeline for quantum circuit transformations for a QNode within a qjit-compiled program. Args: - fn (QNode): The QNode to run the pass pipeline on. pass_pipeline (dict[str, dict[str, str]]): A dictionary that specifies the pass pipeline order, and optionally arguments for each pass in the pipeline. Keys of this dictionary should correspond to names of passes found in the `catalyst.passes `. @@ -131,58 +185,27 @@ def fn(x): will always take precedence over global pass pipelines. """ - def _decorator(fn=None, **kwargs): - if fn is None: - return functools.partial(pipeline, **kwargs) - - if not isinstance(fn, qml.QNode): - raise TypeError(f"A QNode is expected, got the classical function {fn}") - - if pass_pipeline is None: - # TODO: design a default peephole pipeline - return fn - - fn_original_name = fn.__name__ - wrapped_qnode_function = fn.func - fn_clone = copy.copy(fn) - uniquer = str(_rename_to_unique()) - fn_clone.__name__ = fn_original_name + "_transformed" + uniquer - - pass_names = _API_name_to_pass_name() - - def wrapper(*args, **kwrags): - # TODO: we should not match pass targets by function name. - # The quantum scope work will likely put each qnode into a module - # instead of a `func.func ... attributes {qnode}`. - # When that is in place, the qnode's module can have a proper attribute - # (as opposed to discardable) that records its transform schedule, i.e. - # module_with_transform @name_of_module { - # // transform schedule - # } { - # // contents of the module - # } - # This eliminates the need for matching target functions by name. 
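# --- Usage sketch (not part of the patch) ---
# With the rewritten decorator below, the pipeline dictionary is converted into Pass
# objects and threaded through the QNode call as a `pass_pipeline` keyword argument,
# instead of binding apply_registered_pass_p and matching targets by function name.
import pennylane as qml
from catalyst import qjit
from catalyst.passes import pipeline

my_pipeline = {"cancel_inverses": {}, "merge_rotations": {}}

@qjit
@pipeline(my_pipeline)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
    qml.RX(x, wires=0)
    qml.Hadamard(wires=0)
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliZ(0))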
+ def _decorator(qnode=None): + if not isinstance(qnode, qml.QNode): + raise TypeError(f"A QNode is expected, got the classical function {qnode}") - if EvaluationContext.is_tracing(): - for API_name, pass_options in pass_pipeline.items(): - opt = "" - for option, option_value in pass_options.items(): - opt += " " + str(option) + "=" + str(option_value) - apply_registered_pass_p.bind( - pass_name=pass_names[API_name], - options=f"func-name={fn_original_name}" + "_transformed" + uniquer + opt, - ) - return wrapped_qnode_function(*args, **kwrags) + clone = copy.copy(qnode) + clone.__name__ += "_transformed" - fn_clone.func = wrapper - fn_clone._peephole_transformed = True # pylint: disable=protected-access + @functools.wraps(clone) + def wrapper(*args, **kwargs): + if EvaluationContext.is_tracing(): + passes = kwargs.pop("pass_pipeline", tuple()) + passes += dictionary_to_tuple_of_passes(pass_pipeline) + kwargs["pass_pipeline"] = passes + return clone(*args, **kwargs) - return fn_clone + return wrapper return _decorator -def cancel_inverses(fn=None): +def cancel_inverses(qnode=None): """ Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass for cancelling two neighbouring self-inverse @@ -292,29 +315,66 @@ def circuit(x: float): %2 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs %3 = quantum.expval %2 : f64 """ - if not isinstance(fn, qml.QNode): - raise TypeError(f"A QNode is expected, got the classical function {fn}") + if not isinstance(qnode, qml.QNode): + raise TypeError(f"A QNode is expected, got the classical function {qnode}") + + clone = copy.copy(qnode) + clone.__name__ += "_cancel_inverses" + + @functools.wraps(clone) + def wrapper(*args, **kwargs): + pass_pipeline = kwargs.pop("pass_pipeline", tuple()) + pass_pipeline += (Pass("remove-chained-self-inverse"),) + kwargs["pass_pipeline"] = pass_pipeline + return clone(*args, **kwargs) + + return wrapper + + +def apply_pass(pass_name, *flags, **valued_options): + """Applies a single pass to the qnode""" - funcname = fn.__name__ - wrapped_qnode_function = fn.func - uniquer = str(_rename_to_unique()) + def decorator(qnode): - def wrapper(*args, **kwrags): - if EvaluationContext.is_tracing(): - apply_registered_pass_p.bind( - pass_name="remove-chained-self-inverse", - options=f"func-name={funcname}" + "_cancel_inverses" + uniquer, - ) - return wrapped_qnode_function(*args, **kwrags) + if not isinstance(qnode, qml.QNode): + # Technically, this apply pass is general enough that it can apply to + # classical functions too. However, since we lack the current infrastructure + # to denote a function, let's limit it to qnodes + raise TypeError(f"A QNode is expected, got the classical function {qnode}") - fn_clone = copy.copy(fn) - fn_clone.func = wrapper - fn_clone.__name__ = funcname + "_cancel_inverses" + uniquer + def qnode_call(*args, **kwargs): + pass_pipeline = kwargs.get("pass_pipeline", []) + pass_pipeline.append(Pass(pass_name, *flags, **valued_options)) + kwargs["pass_pipeline"] = pass_pipeline + return qnode(*args, **kwargs) - return fn_clone + return qnode_call + return decorator -def merge_rotations(fn=None): + +def apply_pass_plugin(plugin_name, pass_name, *flags, **valued_options): + """Applies a pass plugin""" + + def decorator(qnode): + if not isinstance(qnode, qml.QNode): + # Technically, this apply pass is general enough that it can apply to + # classical functions too. 
However, since we lack the current infrastructure + # to denote a function, let's limit it to qnodes + raise TypeError(f"A QNode is expected, got the classical function {qnode}") + + def qnode_call(*args, **kwargs): + pass_pipeline = kwargs.get("pass_pipeline", []) + pass_pipeline.append(PassPlugin(plugin_name, pass_name, *flags, **valued_options)) + kwargs["pass_pipeline"] = pass_pipeline + return qnode(*args, **kwargs) + + return qnode_call + + return decorator + + +def merge_rotations(qnode=None): """ Specify that the ``-merge-rotations`` MLIR compiler pass for merging roations (peephole) will be applied. @@ -375,70 +435,41 @@ def circuit(x: float): >>> circuit(0.54) Array(0.5965506257017892, dtype=float64) """ - if not isinstance(fn, qml.QNode): - raise TypeError(f"A QNode is expected, got the classical function {fn}") - - funcname = fn.__name__ - wrapped_qnode_function = fn.func - uniquer = str(_rename_to_unique()) - - def wrapper(*args, **kwrags): - if EvaluationContext.is_tracing(): - apply_registered_pass_p.bind( - pass_name="merge-rotations", - options=f"func-name={funcname}" + "_merge_rotations" + uniquer, - ) - return wrapped_qnode_function(*args, **kwrags) - - fn_clone = copy.copy(fn) - fn_clone.func = wrapper - fn_clone.__name__ = funcname + "_merge_rotations" + uniquer - - return fn_clone + if not isinstance(qnode, qml.QNode): + raise TypeError(f"A QNode is expected, got the classical function {qnode}") + clone = copy.copy(qnode) + clone.__name__ += "_merge_rotations" -## IMPL and helpers ## -# pylint: disable=missing-function-docstring -class _PipelineNameUniquer: - def __init__(self, i): - self.i = i + @functools.wraps(clone) + def wrapper(*args, **kwargs): + pass_pipeline = kwargs.pop("pass_pipeline", tuple()) + pass_pipeline += (Pass("merge-rotations"),) + kwargs["pass_pipeline"] = pass_pipeline + return clone(*args, **kwargs) - def get(self): - self.i += 1 - return self.i - - def reset(self): - self.i = -1 - - -PipelineNameUniquer = _PipelineNameUniquer(-1) - - -def _rename_to_unique(): - return PipelineNameUniquer.get() + return wrapper def _API_name_to_pass_name(): - return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotations"} + return { + "cancel_inverses": "remove-chained-self-inverse", + "merge_rotations": "merge-rotations", + "ions_decomposition": "ions-decomposition", + } -def _inject_transform_named_sequence(): - """ - Inject a transform_named_sequence jax primitive. +def ions_decomposition(qnode=None): # pragma: nocover + """Apply decomposition pass at the MLIR level""" - This must be called when preprocessing the traced function in QJIT.capture(), - since to invoke -apply-transform-sequence, a transform_named_sequence primitive - must be in the jaxpr. 
- """ + if not isinstance(qnode, qml.QNode): + raise TypeError(f"A QNode is expected, got the classical function {qnode}") - transform_named_sequence_p.bind() + @functools.wraps(qnode) + def wrapper(*args, **kwargs): + pass_pipeline = kwargs.pop("pass_pipeline", tuple()) + pass_pipeline += (Pass("ions-decomposition"),) + kwargs["pass_pipeline"] = pass_pipeline + return qnode(*args, **kwargs) - -def add_mlir_quantum_decomposition(f, device): - """When called it adds the MLIR decomposition pass thanks to the transform dialect.""" - # TODO: make this non related to the name of the device - if device.original_device.name == "oqd.cloud": - apply_registered_pass_p.bind( - pass_name="ions-decomposition", - options=f"func-name={f.__name__}", - ) + return wrapper diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 46efd76162..4336010600 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -31,7 +31,8 @@ from functools import partial from io import TextIOWrapper from operator import is_not -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from catalyst.utils.exceptions import CompileError @@ -74,6 +75,8 @@ class CompileOptions: A dictionary that specifies the quantum circuit transformation pass pipeline order, and optionally arguments for each pass in the pipeline. Default is None. + pass_plugins (Optional[Set[Path]]): List of paths to pass plugins. + dialect_plugins (Optional[Set[Path]]): List of paths to dialect plugins. """ verbose: Optional[bool] = False @@ -93,6 +96,8 @@ class CompileOptions: seed: Optional[int] = None experimental_capture: Optional[bool] = False circuit_transform_pipeline: Optional[dict[str, dict[str, str]]] = None + pass_plugins: Optional[Set[Path]] = None + dialect_plugins: Optional[Set[Path]] = None def __post_init__(self): # Check that async runs must not be seeded @@ -121,6 +126,10 @@ def __post_init__(self): self.static_argnums = (static_argnums,) elif isinstance(static_argnums, Iterable): self.static_argnums = tuple(static_argnums) + if self.pass_plugins is None: + self.pass_plugins = set() + if self.dialect_plugins is None: + self.dialect_plugins = set() def __deepcopy__(self, memo): """Make a deep copy of all fields of a CompileOptions object except the logfile, which is @@ -159,7 +168,9 @@ def get_enforce_runtime_invariants_stage(_options: CompileOptions) -> List[str]: # Split multiple tapes enforces that invariant. "split-multiple-tapes", # Run the transform sequence defined in the MLIR module - "apply-transform-sequence", + "builtin.module(apply-transform-sequence)", + # Lower the static custom ops to regular custom ops with dynamic parameters. + "static-custom-lowering", # Nested modules are something that will be used in the future # for making device specific transformations. 
# Since at the moment, nothing in the runtime is using them diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index d0f41fe01b..da59f07ab2 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -49,7 +49,7 @@ from catalyst.jax_primitives import quantum_kernel_p from catalyst.jax_tracer import Function, trace_quantum_function from catalyst.logging import debug_logger -from catalyst.passes import pipeline +from catalyst.passes import dictionary_to_tuple_of_passes from catalyst.tracing.contexts import EvaluationContext from catalyst.tracing.type_signatures import filter_static_args from catalyst.utils.exceptions import CompileError @@ -105,11 +105,8 @@ def __call__(self, *args, **kwargs): assert isinstance(self, qml.QNode) # Update the qnode with peephole pipeline - if "pass_pipeline" in kwargs.keys(): - pass_pipeline = kwargs["pass_pipeline"] - if not hasattr(self, "_peephole_transformed"): - self = pipeline(pass_pipeline=pass_pipeline)(self) - kwargs.pop("pass_pipeline") + pass_pipeline = kwargs.pop("pass_pipeline", tuple()) + pass_pipeline = dictionary_to_tuple_of_passes(pass_pipeline) # Mid-circuit measurement configuration/execution dynamic_one_shot_called = getattr(self, "_dynamic_one_shot_called", False) @@ -150,7 +147,9 @@ def _eval_quantum(*args, **kwargs): ) dynamic_args = filter_static_args(args, static_argnums) args_flat = tree_flatten((dynamic_args, kwargs))[0] - res_flat = quantum_kernel_p.bind(flattened_fun, *args_flat, qnode=self) + res_flat = quantum_kernel_p.bind( + flattened_fun, *args_flat, qnode=self, pipeline=pass_pipeline + ) return tree_unflatten(out_tree_promise(), res_flat)[0] diff --git a/frontend/catalyst/third_party/cuda/__init__.py b/frontend/catalyst/third_party/cuda/__init__.py index f2aa65e9f0..b8d6de164e 100644 --- a/frontend/catalyst/third_party/cuda/__init__.py +++ b/frontend/catalyst/third_party/cuda/__init__.py @@ -131,7 +131,7 @@ class BaseCudaInstructionSet(qml.devices.QubitDevice): "PauliX", "PauliZ", ] - config = Path(__file__).parent / "cuda_quantum.toml" + config_filepath = Path(__file__).parent / "cuda_quantum.toml" def __init__(self, shots=None, wires=None): _check_version_compatibility() diff --git a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py index dc73891466..adc6545473 100644 --- a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py @@ -428,10 +428,17 @@ def change_instruction(ctx, eqn): op = params["op"] cuda_inst_name = from_catalyst_to_cuda[op] qubits_len = params["qubits_len"] + static_params = params.get("static_params") # Now, we can map to the correct op # For now just assume rx - cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len) + cuda_inst( + ctx.kernel, + *qubits_or_params, + inst=cuda_inst_name, + qubits_len=qubits_len, + static_params=static_params, + ) # Finally determine how many are qubits. qubits = qubits_or_params[:qubits_len] @@ -576,13 +583,8 @@ def change_expval(ctx, eqn): invals = _map(ctx.read, eqn.invars) obs = invals[0] - # Params: - # * shots: Shots - shots = eqn.params["shots"] - shots = shots if shots is not None else -1 - # To obtain expval, we first obtain an observe object. - observe_results = cudaq_observe(ctx.kernel, obs, shots) + observe_results = cudaq_observe(ctx.kernel, obs) # And then we call expectation on that object. 
result = cudaq_expectation(observe_results) outvariables = [ctx.new_variable()] @@ -848,7 +850,8 @@ def cudaq_backend_info(device, _capabilities) -> BackendInfo: # We could also pass abstract arguments here in *args # the same way we do so in Catalyst. # But I think that is redundant now given make_jaxpr2 - jaxpr, _, out_treedef = trace_to_jaxpr(func, static_args, abs_axes, args, {}) + jaxpr, _, out_treedef, plugins = trace_to_jaxpr(func, static_args, abs_axes, args, {}) + assert not plugins, "Plugins are not compatible with CUDA integration" # TODO(@erick-xanadu): # What about static_args? diff --git a/frontend/catalyst/third_party/cuda/cuda_quantum.toml b/frontend/catalyst/third_party/cuda/cuda_quantum.toml index 0f7274ac09..9115776ebd 100644 --- a/frontend/catalyst/third_party/cuda/cuda_quantum.toml +++ b/frontend/catalyst/third_party/cuda/cuda_quantum.toml @@ -1,59 +1,70 @@ -schema = 2 - -# The union of all gate types listed in this section must match what -# the device considers "supported" through PennyLane's device API. -[operators.gates.native] - -CNOT = { properties = [ "invertible" ] } -CY = { properties = [ "invertible" ] } -CZ = { properties = [ "invertible" ] } -CRX = { properties = [ "invertible" ] } -CRY = { properties = [ "invertible" ] } -CRZ = { properties = [ "invertible" ] } -PauliX = { properties = [ "invertible" ] } -PauliY = { properties = [ "invertible" ] } -PauliZ = { properties = [ "invertible" ] } -Hadamard = { properties = [ "invertible" ] } -S = { properties = [ "invertible" ] } -T = { properties = [ "invertible" ] } -RX = { properties = [ "invertible" ] } -RY = { properties = [ "invertible" ] } -RZ = { properties = [ "invertible" ] } -SWAP = { properties = [ "invertible" ] } -CSWAP = { properties = [ "invertible" ] } - -# Operators that should be decomposed according to the algorithm used -# by PennyLane's device API. -# Optional, since gates not listed in this list will typically be decomposed by -# default, but can be useful to express a deviation from this device's regular -# strategy in PennyLane. -# Everything else should be decomposed. -[operators.gates.decomp] - -[operators.gates.matrix] +schema = 3 + +# The set of all gate types supported at the runtime execution interface of the +# device, i.e., what is supported by the `execute` method of the Device API. +# The gate definition has the following format: +# +# GATE = { properties = [ PROPS ], conditions = [ CONDS ] } +# +# where PROPS and CONS are zero or more comma separated quoted strings. +# +# PROPS: zero or more comma-separated quoted strings: +# - "controllable": if a controlled version of this gate is supported. +# - "invertible": if the adjoint of this operation is supported. +# - "differentiable": if device gradient is supported for this gate. +# CONDS: zero or more comma-separated quoted strings: +# - "analytic" or "finiteshots": if this operation is only supported in +# either analytic execution or with shots, respectively. +# - "terms-commute": if this composite operator is only supported +# given that its terms commute. Only relevant for Prod, SProd, Sum, +# LinearCombination, and Hamiltonian. 
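# --- Illustrative sketch (not part of the patch) ---
# Reading a gate entry in the schema-3 format described above, using Python's tomllib
# (available in Python 3.11+; older versions can use the tomli backport). The RX entry
# is an example value.
import tomllib

doc = tomllib.loads("""
schema = 3
[operators.gates]
RX = { properties = ["invertible"] }
""")
assert doc["schema"] == 3
assert "invertible" in doc["operators"]["gates"]["RX"]["properties"]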
+# +[operators.gates] + +CNOT = { properties = ["invertible"] } +CY = { properties = ["invertible"] } +CZ = { properties = ["invertible"] } +CRX = { properties = ["invertible"] } +CRY = { properties = ["invertible"] } +CRZ = { properties = ["invertible"] } +PauliX = { properties = ["invertible"] } +PauliY = { properties = ["invertible"] } +PauliZ = { properties = ["invertible"] } +Hadamard = { properties = ["invertible"] } +S = { properties = ["invertible"] } +T = { properties = ["invertible"] } +RX = { properties = ["invertible"] } +RY = { properties = ["invertible"] } +RZ = { properties = ["invertible"] } +SWAP = { properties = ["invertible"] } +CSWAP = { properties = ["invertible"] } # Observables supported natively by the device [operators.observables] -PauliX = {} -PauliZ = {} -Sum = {} + +PauliX = { } +PauliZ = { } +Sum = { } [measurement_processes] -Expval = {} -State = { condition = [ "analytic" ] } -Sample = { condition = [ "finiteshots" ] } -Counts = { condition = [ "finiteshots" ] } +ExpectationMP = {} +StateMP = { conditions = ["analytic"] } +SampleMP = { conditions = ["finiteshots"] } +CountsMP = { conditions = ["finiteshots"] } +# Additional support that the device may provide. All accepted fields and their +# default values are listed below. Any fields missing from the TOML file will be +# set to their default values. [compilation] -# If the device is compatible with qjit + +# Whether the device is compatible with qjit. qjit_compatible = true -# If the device requires run time generation of the quantum circuit. +# Whether the device requires run time generation of the quantum circuit. runtime_code_generation = false -# Technically limited support -mid_circuit_measurement = true - -# This field is currently unchecked but it is reserved for the purpose of -# determining if the device supports dynamic qubit allocation/deallocation. +# The methods of handling mid-circuit measurements that the device supports, e.g., +# "one-shot", "device", "tree-traversal", etc. An empty list indicates that the device +# does not support mid-circuit measurements. +supported_mcm_methods = [ "device", "one-shot" ] +# Whether the device supports dynamic qubit allocation/deallocation. dynamic_qubit_management = false - diff --git a/frontend/catalyst/third_party/cuda/primitives/__init__.py b/frontend/catalyst/third_party/cuda/primitives/__init__.py index e1679d3bcb..5c00100308 100644 --- a/frontend/catalyst/third_party/cuda/primitives/__init__.py +++ b/frontend/catalyst/third_party/cuda/primitives/__init__.py @@ -301,27 +301,33 @@ def make_primitive_for_gate(): kernel_gate_p = jax.core.Primitive("kernel_inst") kernel_gate_p.multiple_results = True - def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1): + def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Convenience. Quantum operations in CUDA-quantum return no values. But JAXPR expects return values. We can just say that multiple_results = True and return an empty tuple. 
""" - kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len) + kernel_gate_p.bind( + kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params + ) return tuple() @kernel_gate_p.def_impl - def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1): + def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Concrete implementation.""" assert inst and qubits_len > 0 + if static_params is None: + static_params = [] method = getattr(cudaq.Kernel, inst) targets = qubits_or_params[:qubits_len] params = qubits_or_params[qubits_len:] + if not params: + params = static_params method(kernel, *params, *targets) return tuple() @kernel_gate_p.def_abstract_eval - def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1): + def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Abstract evaluation.""" return tuple() diff --git a/frontend/catalyst/third_party/oqc/oqc_device.py b/frontend/catalyst/third_party/oqc/oqc_device.py index b61a4c39bd..7377b936f6 100644 --- a/frontend/catalyst/third_party/oqc/oqc_device.py +++ b/frontend/catalyst/third_party/oqc/oqc_device.py @@ -36,7 +36,7 @@ class OQCDevice(Device): """The OQC device allows to access the hardware devices from OQC using Catalyst.""" - config = get_lib_path("oqc_runtime", "OQC_LIB_DIR") + "/backend" + "/oqc.toml" + config_filepath = get_lib_path("oqc_runtime", "OQC_LIB_DIR") + "/backend" + "/oqc.toml" @staticmethod def get_c_interface(): diff --git a/frontend/catalyst/third_party/oqc/src/CMakeLists.txt b/frontend/catalyst/third_party/oqc/src/CMakeLists.txt index d71ef0f6ac..56d55c084d 100644 --- a/frontend/catalyst/third_party/oqc/src/CMakeLists.txt +++ b/frontend/catalyst/third_party/oqc/src/CMakeLists.txt @@ -10,31 +10,39 @@ set(oqc_backend_dir "${OQC_BUILD_DIR}/backend") set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) -# Avoid warning raised by pybind11 on newer cmake versions. PYBIND11_FINDPYTHON=ON caused issues. -if (${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.27") - cmake_policy(SET CMP0148 OLD) +# nanobind suggests including these lines to configure CMake to perform an optimized release build +# by default unless another build type is specified. Without this addition, binding code may run +# slowly and produce large binaries. +# See https://nanobind.readthedocs.io/en/latest/building.html#preliminaries +if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() -include(FetchContent) - -function(fetch_pybind11) - find_package(pybind11 CONFIG) - if (pybind11_FOUND) - message(STATUS, "FOUND pybind11") - else() - message(STATUS "Could not find existing pybind11-dev package. Building from source.") - set(CMAKE_POLICY_DEFAULT_CMP0127 NEW) # To suppress pybind11 CMP0127 warning +# Locate Python +# The optional component is only used for the C++ test suite (to spin up its own interpreter), +# and requires libpython.so to be available on the system. 
+find_package(Python REQUIRED + COMPONENTS Interpreter Development.Module + OPTIONAL_COMPONENTS Development.Embed Development.SABIModule +) - FetchContent_Declare(pybind11 - GIT_REPOSITORY https://github.com/pybind/pybind11.git - GIT_TAG v2.10.1 - ) +# Locate nanobind +execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import nanobind; print(nanobind.cmake_dir())" + OUTPUT_VARIABLE nanobind_DIR OUTPUT_STRIP_TRAILING_WHITESPACE +) +find_package(nanobind CONFIG REQUIRED) - FetchContent_MakeAvailable(pybind11) - endif() -endfunction() +# Create the Python `oqc_python_module` module +# Target the stable ABI for Python 3.12+, which reduces the number of binary wheels that must be +# built (`STABLE_ABI` does nothing on older Python versions). +nanobind_add_module(oqc_python_module STABLE_ABI oqc_python_module.cpp) -fetch_pybind11() +# Use a consistant suffix ".so" rather than, e.g. ".abi3.so" (when using the Stable ABI) or +# ".cpython-3xx-darwin.so". Doing so simplifies the process to locate it when calling +# `dlopen(OQC_PY)` in frontend/catalyst/third_party/oqc/src/OQCRunner.hpp. +set_target_properties(oqc_python_module PROPERTIES SUFFIX ".so") message(STATUS "Building the OQC device.") @@ -43,7 +51,7 @@ add_library(rtd_oqc SHARED OQCDevice.cpp) target_include_directories(rtd_oqc PUBLIC . ${runtime_includes} ${backend_includes} - ) +) set(OQC_LIBRARIES rtd_oqc @@ -52,7 +60,6 @@ set(OQC_LIBRARIES set_target_properties(rtd_oqc PROPERTIES BUILD_RPATH "$ORIGIN/../utils") target_link_directories(rtd_oqc PRIVATE ${runtime_lib}) -pybind11_add_module(oqc_python_module SHARED oqc_python_module.cpp) target_include_directories(oqc_python_module PRIVATE ${runtime_includes}) add_dependencies(rtd_oqc oqc_python_module) diff --git a/frontend/catalyst/third_party/oqc/src/Makefile b/frontend/catalyst/third_party/oqc/src/Makefile index 66580e2fcf..c354167562 100644 --- a/frontend/catalyst/third_party/oqc/src/Makefile +++ b/frontend/catalyst/third_party/oqc/src/Makefile @@ -16,8 +16,7 @@ configure: -DCMAKE_C_COMPILER=$(C_COMPILER) \ -DCMAKE_CXX_COMPILER=$(CXX_COMPILER) \ -DRUNTIME_BUILD_DIR=$(RT_BUILD_DIR) \ - -DPYTHON_EXECUTABLE=$(PYTHON) \ - -Dpybind11_DIR=$(shell $(PYTHON) -c "import pybind11; print(pybind11.get_cmake_dir())") + -DPython_EXECUTABLE=$(PYTHON) $(OQC_BUILD_DIR)/librtd_oqc.so: configure cmake --build $(OQC_BUILD_DIR) --target rtd_oqc -j$(NPROC) diff --git a/frontend/catalyst/third_party/oqc/src/oqc.toml b/frontend/catalyst/third_party/oqc/src/oqc.toml index 434bb40913..d1eb1e1165 100644 --- a/frontend/catalyst/third_party/oqc/src/oqc.toml +++ b/frontend/catalyst/third_party/oqc/src/oqc.toml @@ -1,8 +1,25 @@ -schema = 2 +schema = 3 -# The union of all gate types listed in this section must match what -# the device considers "supported" through PennyLane's device API. -[operators.gates.native] +# The set of all gate types supported at the runtime execution interface of the +# device, i.e., what is supported by the `execute` method of the Device API. +# The gate definition has the following format: +# +# GATE = { properties = [ PROPS ], conditions = [ CONDS ] } +# +# where PROPS and CONS are zero or more comma separated quoted strings. +# +# PROPS: zero or more comma-separated quoted strings: +# - "controllable": if a controlled version of this gate is supported. +# - "invertible": if the adjoint of this operation is supported. +# - "differentiable": if device gradient is supported for this gate. 
+# CONDS: zero or more comma-separated quoted strings: +# - "analytic" or "finiteshots": if this operation is only supported in +# either analytic execution or with shots, respectively. +# - "terms-commute": if this composite operator is only supported +# given that its terms commute. Only relevant for Prod, SProd, Sum, +# LinearCombination, and Hamiltonian. +# +[operators.gates] Identity = { } Hadamard = { } @@ -28,38 +45,26 @@ U1 = { } U2 = { } U3 = { } -# Operators that should be decomposed according to the algorithm used -# by PennyLane's device API. -# Optional, since gates not listed in this list will typically be decomposed by -# default, but can be useful to express a deviation from this device's regular -# strategy in PennyLane. - -[operators.gates.decomp] - -# Gates which should be translated to QubitUnitary -[operators.gates.matrix] - [operators.observables] -# Observables supported by the device [measurement_processes] -Counts = { condition = [ "finiteshots" ] } - +CountsMP = { conditions = ["finiteshots"] } +# Additional support that the device may provide. All accepted fields and their +# default values are listed below. Any fields missing from the TOML file will be +# set to their default values. [compilation] -# If the device is compatible with qjit + +# Whether the device is compatible with qjit. qjit_compatible = true -# If the device requires run time generation of the quantum circuit. +# Whether the device requires run time generation of the quantum circuit. runtime_code_generation = true - -# If the device supports mid circuit measurements natively -mid_circuit_measurement = false - -# This field is currently unchecked but it is reserved for the purpose of -# determining if the device supports dynamic qubit allocation/deallocation. +# The methods of handling mid-circuit measurements that the device supports, e.g., +# "one-shot", "device", "tree-traversal", etc. An empty list indicates that the device +# does not support mid-circuit measurements. +supported_mcm_methods = [ ] +# Whether the device supports dynamic qubit allocation/deallocation. dynamic_qubit_management = false - -# whether the device can support non-commuting measurements together -# in a single execution +# Whether simultaneous measurements of non-commuting observables is supported. non_commuting_observables = false diff --git a/frontend/catalyst/third_party/oqc/src/oqc_python_module.cpp b/frontend/catalyst/third_party/oqc/src/oqc_python_module.cpp index 7dfab5d62c..0b9e4ab521 100644 --- a/frontend/catalyst/third_party/oqc/src/oqc_python_module.cpp +++ b/frontend/catalyst/third_party/oqc/src/oqc_python_module.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
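# --- Illustrative sketch (not part of the patch) ---
# Rough Python-level picture of the contract used by the nanobind counts() binding
# below: a Python program is executed in the __main__ namespace with
# circuit/device/kwargs/shots/msg provided as locals, and `counts`/`msg` are read back
# afterwards. The program body here is a stand-in, not the real OQC client code.
import __main__

program = 'counts = {"00": shots}\nmsg = ""'
local_vars = {"circuit": "<qasm>", "device": "<target>", "kwargs": "", "shots": 100, "msg": ""}
exec(program, __main__.__dict__, local_vars)
if local_vars["msg"]:
    raise RuntimeError(local_vars["msg"])
print(local_vars["counts"])  # {'00': 100}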
-#include -#include #include +#include + +#include +#include +#include #include "Exception.hpp" @@ -46,28 +49,34 @@ except Exception as e: [[gnu::visibility("default")]] void counts(const char *_circuit, const char *_device, size_t shots, size_t num_qubits, const char *_kwargs, void *_vector) { - namespace py = pybind11; - using namespace py::literals; + namespace nb = nanobind; + using namespace nb::literals; - py::gil_scoped_acquire lock; + nb::gil_scoped_acquire lock; - auto locals = py::dict("circuit"_a = _circuit, "device"_a = _device, "kwargs"_a = _kwargs, - "shots"_a = shots, "msg"_a = ""); + nb::dict locals; + locals["circuit"] = _circuit; + locals["device"] = _device; + locals["kwargs"] = _kwargs; + locals["shots"] = shots; + locals["msg"] = ""; - py::exec(program, py::globals(), locals); + // Evaluate in scope of main module + nb::object scope = nb::module_::import_("__main__").attr("__dict__"); + nb::exec(nb::str(program.c_str()), scope, locals); - auto &&msg = locals["msg"].cast(); + auto msg = nb::cast(locals["msg"]); RT_FAIL_IF(!msg.empty(), msg.c_str()); - py::dict results = locals["counts"]; + nb::dict results = locals["counts"]; std::vector *counts_value = reinterpret_cast *>(_vector); for (auto item : results) { auto key = item.first; auto value = item.second; - counts_value->push_back(value.cast()); + counts_value->push_back(nb::cast(value)); } return; } -PYBIND11_MODULE(oqc_python_module, m) { m.doc() = "oqc"; } +NB_MODULE(oqc_python_module, m) { m.doc() = "oqc"; } diff --git a/frontend/catalyst/third_party/oqc/src/tests/CMakeLists.txt b/frontend/catalyst/third_party/oqc/src/tests/CMakeLists.txt index fcbaff86a6..79e485e900 100644 --- a/frontend/catalyst/third_party/oqc/src/tests/CMakeLists.txt +++ b/frontend/catalyst/third_party/oqc/src/tests/CMakeLists.txt @@ -1,9 +1,3 @@ -cmake_minimum_required(VERSION 3.20) - -project(oqc_runtime_tests) - -set(CMAKE_CXX_STANDARD 20) - Include(FetchContent) FetchContent_Declare( @@ -14,8 +8,6 @@ FetchContent_Declare( FetchContent_MakeAvailable(Catch2) -fetch_pybind11() - # Required for catch_discover_tests(). list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/contrib) @@ -36,20 +28,26 @@ endif() target_include_directories(runner_tests_oqc PRIVATE ${OQC_LIBRARIES} - ) +) target_link_directories(runner_tests_oqc PRIVATE ${runtime_lib}) -# To avoid link to libpython, we use pybind11::module interface library. -target_compile_definitions(runner_tests_oqc PUBLIC INITIALIZE_PYTHON) + +# Locate PyBind11 +execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())" + OUTPUT_VARIABLE pybind11_DIR OUTPUT_STRIP_TRAILING_WHITESPACE +) +find_package(pybind11 CONFIG REQUIRED) + target_link_libraries(runner_tests_oqc PRIVATE Catch2::Catch2 pybind11::embed ${OQC_LIBRARIES} - ) +) target_sources(runner_tests_oqc PRIVATE Test_OpenQASM2Builder.cpp Test_OQCDevice.cpp - ) +) catch_discover_tests(runner_tests_oqc) diff --git a/frontend/catalyst/third_party/oqd/__init__.py b/frontend/catalyst/third_party/oqd/__init__.py new file mode 100644 index 0000000000..c928646676 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule contains classes for the OQD device and its properties. +""" + +from .oqd_device import OQDDevice +from .oqd_database_managers import OQDDeviceDatabase, OQDQubitDatabase, OQDBeamDatabase + +__all__ = ["OQDDeviceDatabase", "OQDQubitDatabase", "OQDBeamDatabase", "OQDDevice"] diff --git a/frontend/catalyst/third_party/oqd/oqd_database_managers.py b/frontend/catalyst/third_party/oqd/oqd_database_managers.py new file mode 100644 index 0000000000..96bd058aef --- /dev/null +++ b/frontend/catalyst/third_party/oqd/oqd_database_managers.py @@ -0,0 +1,470 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +OQD Device Properties +~~~~~~~~~~~~~~~~~~~~~ + +This module defines the classes that represent the properties of an Open Quantum Design (OQD) +trapped-ion quantum computer device and the methods for loading them from their respective databases +and configuration files. +""" + +from collections.abc import Collection +from dataclasses import dataclass, field +from numbers import Number +from os import PathLike +from typing import Union, Collection + +from catalyst.utils.toml_utils import load_toml, safe_eval + + +SUPPORTED_SCHEMAS = ["v0.1"] + + +@dataclass +class OQDDeviceParameter: + """A class to represent a device parameter for an OQD trapped-ion experiment workflow. + + Attributes: + description: A short description of the device parameter. + stage: The stage in the trapped-ion workflow that the parameter applies to. The ordered + list of stages in the OQD experiment workflow is ['Loading', 'Trapping', + 'Initialization', 'Experiment', 'Detection']. + process: The process within the given stage of the trapped-ion workflow that the parameter + applies to. + equation: The equation that describes how the parameter is computed, if applicable. + value: The parameter value. + unit: The unit associated with the parameter value, if applicable. + """ + + # TODO: Some parameters, mainly laser frequencies, are expressed as the sum of a nominal + # wavelength and a frequency offset, e.g. '493 nm - 7.506 GHz', which is equal to 493.0061 nm. + # We need to be able to support this. + # Furthermore, some parameters may either have one value or have two simultaneous values, e.g. + # 'Doppler cooling laser frequency', which has values '493 nm - 7.506 GHz' and + # '493 nm + 4.259 GHz'. We need to be able to support this as well. 
+ + name: str + description: str = "" + stage: str = "" + process: str = "" + equation: str = "" + value: Union[int, float] = None + unit: str = "" + + @classmethod + def from_dict(cls, name: str, params: dict) -> "OQDDeviceParameter": + """Creates an OQDDeviceParameter object from a dictionary. + + Args: + name: The name of the device parameter. + params: A dictionary containing the device parameter parameters, typically as parsed + from a TOML document. + + Returns: + OQDDeviceParameter: The OQDDeviceParameter object. + """ + return cls( + name=name, + description=params["description"], + stage=params["stage"], + process=params["process"], + equation=params["equation"], + value=params["value"], + unit=params["unit"], + ) + + +@dataclass +class OQDDeviceDatabase: + """A database class to represent the properties of an OQD device. + + The properties of the device include hardware specification, parameters relating to the + experimental apparatus, and generally any other parameters needed by the compiler. Physical + constants, such as the energy levels of the ion(s) being used, are handled separately. + + Attributes: + parameters: A dictionary of OQD device parameters. + """ + + parameters: dict[str, OQDDeviceParameter] = field(default_factory=dict) + + @classmethod + def from_toml(cls, filepath_or_buffer: Union[str, PathLike]): + """Loads an OQDDeviceDatabase object from a TOML file or string. + + Args: + filepath_or_buffer: The path to the TOML file or a TOML document string. + """ + try: + document = load_toml(filepath_or_buffer) + + except Exception as e: + raise ValueError( + "Failed to load TOML document when creating OQDDeviceDatabase" + ) from e + + properties = cls._parse_toml_document(document) + return properties + + @classmethod + def _parse_toml_document(cls, document: dict): + """Parses a TOML document and returns an OQDDeviceDatabase object.""" + _check_oqd_config_schema(document) + cls._check_required_keys(document) + return cls( + parameters={ + name: OQDDeviceParameter.from_dict(name, level) + for name, level in document["parameters"].items() + } + ) + + @staticmethod + def _check_required_keys(document: dict): + assert ( + "parameters" in document + ), "TOML document for OQD device parameters must contain key 'parameters'" + + +@dataclass +class OQDIonLevelParameters: + """A class to represent ion energy levels for an OQD trapped-ion experiment workflow.""" + + # pylint: disable=too-many-instance-attributes + name: str + principal: float + spin: float + orbital: float + nuclear: float + spin_orbital: float + spin_orbital_nuclear: float + spin_orbital_nuclear_magnetization: float + energy: Union[float, str] + + @classmethod + def from_dict(cls, name: str, params: dict) -> "OQDIonLevelParameters": + """Creates an OQDIonLevelParameters object from a dictionary. + + Args: + name: The name of the level, e.g. 'downstate', 'upstate', 'estate', etc. + params: A dictionary containing the level parameters, including the relevant quantum + numbers and the level energy, typically as parsed from a TOML document. + + Returns: + OQDIonLevelParameters: The OQDIonLevelParameters object.
+ """ + return cls( + name=name, + principal=params["principal"], + spin=params["spin"], + orbital=params["orbital"], + nuclear=params["nuclear"], + spin_orbital=params["spin_orbital"], + spin_orbital_nuclear=params["spin_orbital_nuclear"], + spin_orbital_nuclear_magnetization=params["spin_orbital_nuclear_magnetization"], + energy=_parse_value_or_expression_as_float(params["energy"]), + ) + + +@dataclass +class OQDIonTransitionParameters: + """A class to represent a specific transition between ion energy levels for an OQD trapped-ion + experiment workflow.""" + + name: str + level1: str + level2: str + einsteinA: float + + @classmethod + def from_dict(cls, name: str, params: dict) -> "OQDIonTransitionParameters": + """Creates an OQDIonTransitionParameters object from a dictionary. + + Args: + name: The name of the transition as '_', e.g. 'downstate_upstate', + 'upstate_downstate', etc. + params: A dictionary containing the transition parameters, typically as parsed from a + TOML document. + + Returns: + OQDIonTransitionParameters: The OQDIonTransitionParameters object. + """ + return cls( + name=name, + level1=params["level1"], + level2=params["level2"], + einsteinA=params["einsteinA"], + ) + + +@dataclass +class OQDIonParameters: + """A class to represent an ion used in an OQD trapped-ion experiment workflow.""" + + mass: float + charge: int + position: list[int] + levels: dict[str, OQDIonLevelParameters] + transitions: dict[str, OQDIonTransitionParameters] + + @classmethod + def from_dict(cls, params: dict) -> "OQDIonParameters": + """Creates an OQDIonParameters object from a dictionary. + + Args: + params: A dictionary containing the ion parameters, typically as parsed from a TOML + document. + + Returns: + OQDIonParameters: The OQDIonParameters object. + """ + return cls( + mass=params["mass"], + charge=params["charge"], + position=params["position"], + levels={ + name: OQDIonLevelParameters.from_dict(name, level) + for name, level in params["levels"].items() + }, + transitions={ + name: OQDIonTransitionParameters.from_dict(name, transition) + for name, transition in params["transitions"].items() + }, + ) + + +@dataclass +class OQDPhononParameters: + """A class to represent a phonon mode for an OQD trapped-ion experiment workflow.""" + + energy: Union[float, str] + eigenvector: list[int] + + @classmethod + def from_dict(cls, params: dict) -> "OQDPhononParameters": + """Creates an OQDPhononParameters object from a dictionary. + + Args: + params: A dictionary containing the phonon mode parameters, typically as parsed from a + TOML document. + + Returns: + OQDPhononParameters: The OQDPhononParameters object. + """ + return cls( + energy=_parse_value_or_expression_as_float(params["energy"]), + eigenvector=params["eigenvector"], + ) + + +@dataclass +class OQDQubitDatabase: + """A database class to represent the qubit parameters for an OQD trapped-ion experiment workflow.""" # pylint: disable=line-too-long + + ion_parameters: dict[str, OQDIonParameters] + phonon_parameters: dict[str, OQDPhononParameters] + + @classmethod + def from_toml( + cls, + filepath_or_buffer: Union[str, PathLike], + ion_species_filter: Union[str, Collection[str]] = None, + phonon_mode_filter: Union[str, Collection[str]] = None, + ) -> "OQDQubitDatabase": + """Loads an OQDQubitDatabase object from a TOML file or string. + + Args: + filepath_or_buffer: The path to the TOML file or a TOML document string. + ion_species_filter (optional): A list of ion species to include in the + OQDQubitDatabase object. 
If None, all ion species are included. + phonon_mode_filter (optional): A list of phonon modes to include in the + OQDQubitDatabase object. If None, all phonon modes are included. + """ + try: + document = load_toml(filepath_or_buffer) + + except Exception as e: + raise ValueError("Failed to load TOML document when creating OQDQubitDatabase") from e + + _check_oqd_config_schema(document) + cls._check_required_keys(document) + + # Collect the ion properties + apply_ion_species_filter = ion_species_filter is not None + if apply_ion_species_filter: + ion_species_filter = _string_or_collection_of_strings_to_set(ion_species_filter) + + _ion_properties = {} + for ion_species in document["ions"]: + if apply_ion_species_filter and ion_species not in ion_species_filter: + continue + _ion_properties[ion_species] = OQDIonParameters.from_dict(document["ions"][ion_species]) + + # Collect the phonon properties + apply_phonon_mode_filter = phonon_mode_filter is not None + if apply_phonon_mode_filter: + phonon_mode_filter = _string_or_collection_of_strings_to_set(phonon_mode_filter) + + _phonon_properties = {} + for phonon_mode in document["phonons"]: + if apply_phonon_mode_filter and phonon_mode not in phonon_mode_filter: + continue + _phonon_properties[phonon_mode] = OQDPhononParameters.from_dict( + document["phonons"][phonon_mode] + ) + + return cls( + ion_parameters=_ion_properties, + phonon_parameters=_phonon_properties, + ) + + @staticmethod + def _check_required_keys(document: dict): + assert "ions" in document, "TOML document for OQD qubit parameters must contain key 'ions'" + assert ( + "phonons" in document + ), "TOML document for OQD qubit parameters must contain key 'phonons'" + + +@dataclass +class OQDBeamParameters: + """A class to represent the beam parameters for an OQD trapped-ion experiment workflow.""" + + transition: str + rabi: Union[float, str] + detuning: float + phase: float + polarization: float + wavevector: float + + @classmethod + def from_dict(cls, params: dict) -> "OQDBeamParameters": + """Creates an OQDBeamParameters object from a dictionary. + + Args: + params: A dictionary containing the beam parameters, typically as parsed from a TOML + document. + + Returns: + OQDBeamParameters: The OQDBeamParameters object. + """ + return cls( + transition=params["transition"], + rabi=_parse_value_or_expression_as_float(params["rabi"]), + detuning=_parse_value_or_expression_as_float(params["detuning"]), + phase=_parse_value_or_expression_as_float(params["phase"]), + polarization=_parse_value_or_expression_as_float(params["polarization"]), + wavevector=params["wavevector"], + ) + + +@dataclass +class OQDBeamDatabase: + """A database class to represent the beam parameters for an OQD trapped-ion experiment workflow.""" # pylint: disable=line-too-long + + beam_parameters: dict[str, OQDBeamParameters] + + @classmethod + def from_toml(cls, filepath_or_buffer: Union[str, PathLike]): + """Loads an OQDBeamDatabase object from a TOML file or string. + + Args: + filepath_or_buffer: The path to the TOML file or a TOML document string. 
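For reference, a minimal sketch of loading beam parameters from a TOML string, assuming the catalyst.third_party.oqd import path added in this patch; the beam values are illustrative placeholders. Note that numeric fields such as rabi and detuning may be written as arithmetic expressions, which are evaluated with safe_eval via _parse_value_or_expression_as_float.

# Illustrative sketch only: the beam entry below is made up, not calibrated hardware data.
from catalyst.third_party.oqd import OQDBeamDatabase  # import path assumed from this patch

beam_toml = """
oqd_config_schema = "v0.1"

[beams.beam0]
transition = "downstate_upstate"
rabi = "2 * math.pi * 1.0e6"   # expression string, evaluated safely (not with eval())
detuning = 0.0
phase = 0.0
polarization = 1.0
wavevector = 1.0
"""

beam_db = OQDBeamDatabase.from_toml(beam_toml)
beam = beam_db.beam_parameters["beam0"]
print(beam.rabi)  # ~6283185.3, the evaluated float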
+ """ + try: + document = load_toml(filepath_or_buffer) + + except Exception as e: + raise ValueError("Failed to load TOML document when creating OQDBeamDatabase") from e + + _check_oqd_config_schema(document) + cls._check_required_keys(document) + + return cls( + beam_parameters={ + beam: OQDBeamParameters.from_dict(document["beams"][beam]) + for beam in document["beams"] + } + ) + + @staticmethod + def _check_required_keys(document: dict): + assert "beams" in document, "TOML document for OQD beam parameters must contain key 'beams'" + + +def _parse_value_or_expression_as_float(input_: Union[Number, str]): + """Parses a numeric value, or an expression that can be evaluated to a numeric value, and return + as a float. + + Args: + input_: The numeric value or expression to be evaluated. + + Returns: + float: The original value, or the evaluated expression, as a float. + + Raises: + ValueError: If the expression is invalid. + TypeError: If the input is not a number or string. + """ + if isinstance(input_, Number): + return float(input_) + + elif isinstance(input_, str): + try: + result = float(safe_eval(input_)) + except Exception as e: + raise ValueError(f"Invalid expression: '{input_}'") from e + + return result + + else: + raise TypeError(f"Expected a number or string, but got {type(input_)}") + + +def _check_oqd_config_schema(document: dict): + """Checks that the TOML document has the correct schema.""" + + assert "oqd_config_schema" in document, "TOML document must contain key 'oqd_config_schema'" + + schema = document["oqd_config_schema"] + assert schema in SUPPORTED_SCHEMAS, ( + f"Unsupported OQD TOML config schema '{schema}'; " + f"supported schemas are {SUPPORTED_SCHEMAS}" + ) + + +def _string_or_collection_of_strings_to_set(input_: Union[str, Collection[str]]) -> set[str]: + """Converts a string or a collection of strings to a set of strings. + + Args: + input (Union[str, Collection[str]]): The input string or collection of strings. + + Raises: + TypeError: If the input is not a string or a collection of strings. + + Returns: + set[str]: The set of strings. + """ + if isinstance(input_, str): + return {input_} + + elif isinstance(input_, Collection): + assert all( + isinstance(item, str) for item in input_ + ), "All items in collection must be strings" + return set(input_) + + else: + raise TypeError(f"Expected a string or a collection of strings, but got {type(input_)}") diff --git a/frontend/catalyst/third_party/oqd/oqd_device.py b/frontend/catalyst/third_party/oqd/oqd_device.py new file mode 100644 index 0000000000..7f275b7551 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/oqd_device.py @@ -0,0 +1,72 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +OQD Device +~~~~~~~~~~ + +This module defines the classes that represent an Open Quantum Design (OQD) +trapped-ion quantum computer device. 
+""" +from typing import Optional + +from pennylane.devices import Device, ExecutionConfig +from pennylane.transforms.core import TransformProgram +from catalyst.compiler import get_lib_path + +BACKENDS = ["default"] + + +class OQDDevice(Device): + """The OQD device allows access to the hardware devices from OQD using Catalyst.""" + + config_filepath = get_lib_path("oqd_runtime", "OQD_LIB_DIR") + "/backend" + "/oqd.toml" + + def __init__(self, wires, backend, shots, **kwargs): + self._backend = backend + _check_backend(backend=backend) + super().__init__(wires=wires, shots=shots, **kwargs) + + @property + def backend(self): + """Backend property of the device.""" + return self._backend + + def preprocess( + self, + execution_config: Optional[ExecutionConfig] = None, + ): + """Device preprocessing function. + + This function defines the device transform program to be applied and an updated device + configuration. + + TODO: This function is boilerplate only + """ + if execution_config is None: + execution_config = ExecutionConfig() + + transform_program = TransformProgram() + + return transform_program, execution_config + + def execute(self, circuits, execution_config): + """Python execution is not supported.""" + raise NotImplementedError("The OQD device only supports Catalyst.") + + +def _check_backend(backend): + """Helper function to check the backend.""" + if backend not in BACKENDS: + raise ValueError(f"The backend {backend} is not supported. Valid devices are {BACKENDS}") diff --git a/frontend/catalyst/third_party/oqd/src/CMakeLists.txt b/frontend/catalyst/third_party/oqd/src/CMakeLists.txt new file mode 100644 index 0000000000..2eee000b30 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.20) + +project(catalyst_oqd) + +set(runtime_includes "${PROJECT_SOURCE_DIR}/../../../../../runtime/include") +set(backend_includes "${PROJECT_SOURCE_DIR}/../../../../../runtime/lib/backend/common") +set(runtime_lib "${RUNTIME_BUILD_DIR}/lib") +set(oqd_backend_dir "${OQD_BUILD_DIR}/backend") + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +message(STATUS "Building the OQD device.") + +add_library(rtd_oqd SHARED OQDDevice.cpp) + +target_include_directories(rtd_oqd PUBLIC . 
+ ${runtime_includes} + ${backend_includes} +) + +set(OQD_LIBRARIES + rtd_oqd +) + +set_target_properties(rtd_oqd PROPERTIES BUILD_RPATH "$ORIGIN/../utils") +target_link_directories(rtd_oqd PRIVATE ${runtime_lib}) + +file(COPY ${PROJECT_SOURCE_DIR}/oqd.toml DESTINATION ./backend) + +add_subdirectory(tests) diff --git a/frontend/catalyst/third_party/oqd/src/Makefile b/frontend/catalyst/third_party/oqd/src/Makefile new file mode 100644 index 0000000000..6a7d8a4a2c --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/Makefile @@ -0,0 +1,37 @@ +PYTHON?=$(shell which python3) +C_COMPILER?=$(shell which clang) +CXX_COMPILER?=$(shell which clang++) +NPROC?=$(shell python3 -c "import os; print(os.cpu_count())") + +MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) +MK_DIR := $(dir $(MK_ABSPATH)) +OQD_BUILD_DIR?=$(MK_DIR)/build +RT_BUILD_DIR?=$(MK_DIR)/../../../../../runtime/build + +.PHONY: configure +configure: + @echo "Configure OQD Runtime" + + cmake -G Ninja -B $(OQD_BUILD_DIR) \ + -DCMAKE_C_COMPILER=$(C_COMPILER) \ + -DCMAKE_CXX_COMPILER=$(CXX_COMPILER) \ + -DRUNTIME_BUILD_DIR=$(RT_BUILD_DIR) + +$(OQD_BUILD_DIR)/librtd_oqd.so: configure + cmake --build $(OQD_BUILD_DIR) --target rtd_oqd -j$(NPROC) + +.PHONY: oqd +oqd: $(OQD_BUILD_DIR)/librtd_oqd.so + +$(OQD_BUILD_DIR)/tests/runner_tests_oqd: configure + cmake --build $(OQD_BUILD_DIR) --target runner_tests_oqd -j$(NPROC) + +.PHONY: test +test: $(OQD_BUILD_DIR)/tests/runner_tests_oqd + @echo "test the Catalyst runtime test suite" + $(OQD_BUILD_DIR)/tests/runner_tests_oqd + +.PHONY: clean +clean: + @echo "clean build files" + rm -rf $(OQD_BUILD_DIR) diff --git a/frontend/catalyst/third_party/oqd/src/OQDDevice.cpp b/frontend/catalyst/third_party/oqd/src/OQDDevice.cpp new file mode 100644 index 0000000000..fdea8c65f1 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/OQDDevice.cpp @@ -0,0 +1,130 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
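For reference, a short sketch of how the OQDDevice Python frontend class added earlier in this patch is expected to be used; the wire and shot counts are arbitrary, the import path is assumed from this patch, and circuit execution itself requires Catalyst's qjit since execute() raises NotImplementedError.

# Illustrative sketch only: constructor arguments are arbitrary; execution requires qjit.
from catalyst.third_party.oqd import OQDDevice  # import path assumed from this patch

dev = OQDDevice(wires=2, backend="default", shots=100)
print(dev.backend)  # "default"

# Only the backends listed in BACKENDS are accepted.
try:
    OQDDevice(wires=2, backend="not-a-backend", shots=100)
except ValueError as e:
    print(e)

# Python-side execution is deliberately unsupported:
# dev.execute(...) raises NotImplementedError("The OQD device only supports Catalyst.")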
+ +#include "OQDDevice.hpp" + +namespace Catalyst::Runtime::Device { + +auto OQDDevice::AllocateQubits(size_t num_qubits) -> std::vector +{ + RT_FAIL("Unsupported functionality"); +} + +void OQDDevice::ReleaseAllQubits() { RT_FAIL("Unsupported functionality"); } + +void OQDDevice::ReleaseQubit([[maybe_unused]] QubitIdType q) +{ + RT_FAIL("Unsupported functionality"); +} + +auto OQDDevice::GetNumQubits() const -> size_t { RT_FAIL("Unsupported functionality"); } + +void OQDDevice::StartTapeRecording() +{ + RT_FAIL_IF(tape_recording, "Cannot re-activate the cache manager"); + tape_recording = true; + cache_manager.Reset(); +} + +void OQDDevice::StopTapeRecording() +{ + RT_FAIL_IF(!tape_recording, "Cannot stop an already stopped cache manager"); + tape_recording = false; +} + +void OQDDevice::SetDeviceShots([[maybe_unused]] size_t shots) { device_shots = shots; } + +auto OQDDevice::GetDeviceShots() const -> size_t { return device_shots; } + +auto OQDDevice::Zero() const -> Result { return const_cast(&GLOBAL_RESULT_FALSE_CONST); } + +auto OQDDevice::One() const -> Result { return const_cast(&GLOBAL_RESULT_TRUE_CONST); } + +void OQDDevice::NamedOperation(const std::string &name, const std::vector ¶ms, + const std::vector &wires, bool inverse, + const std::vector &controlled_wires, + const std::vector &controlled_values) +{ + RT_FAIL("Unsupported functionality"); +} + +void OQDDevice::PartialCounts(DataView &eigvals, DataView &counts, + const std::vector &wires, size_t shots) +{ + RT_FAIL("Unsupported functionality"); +} + +auto OQDDevice::AllocateQubit() -> QubitIdType { RT_FAIL("Unsupported functionality"); } +void OQDDevice::PrintState() { RT_FAIL("Unsupported functionality"); } + +void OQDDevice::Counts(DataView &eigvals, DataView &counts, size_t shots) +{ + RT_FAIL("Unsupported functionality"); +} + +auto OQDDevice::Measure([[maybe_unused]] QubitIdType wire, std::optional postselect) + -> Result +{ + RT_FAIL("Unsupported functionality"); +} + +ObsIdType OQDDevice::Observable(ObsId, const std::vector> &, + const std::vector &) +{ + RT_FAIL("Unsupported functionality"); +} + +ObsIdType OQDDevice::TensorObservable(const std::vector &) +{ + RT_FAIL("Unsupported functionality"); +}; + +ObsIdType OQDDevice::HamiltonianObservable(const std::vector &, + const std::vector &) +{ + RT_FAIL("Unsupported functionality"); +} + +void OQDDevice::MatrixOperation(const std::vector> &, + const std::vector &, bool, + const std::vector &, const std::vector &) +{ + RT_FAIL("Unsupported functionality"); +} + +double OQDDevice::Expval(ObsIdType) { RT_FAIL("Unsupported functionality"); }; +double OQDDevice::Var(ObsIdType) { RT_FAIL("Unsupported functionality"); }; +void OQDDevice::State(DataView, 1> &) +{ + RT_FAIL("Unsupported functionality"); +}; +void OQDDevice::Probs(DataView &) { RT_FAIL("Unsupported functionality"); }; +void OQDDevice::PartialProbs(DataView &, const std::vector &) +{ + RT_FAIL("Unsupported functionality"); +}; +void OQDDevice::Sample(DataView &, size_t) { RT_FAIL("Unsupported functionality"); }; +void OQDDevice::PartialSample(DataView &, const std::vector &, size_t) +{ + RT_FAIL("Unsupported functionality"); +} + +void OQDDevice::Gradient(std::vector> &, const std::vector &) +{ + RT_FAIL("Unsupported functionality"); +} + +} // namespace Catalyst::Runtime::Device + +GENERATE_DEVICE_FACTORY(oqd, Catalyst::Runtime::Device::OQDDevice); diff --git a/frontend/catalyst/third_party/oqd/src/OQDDevice.hpp b/frontend/catalyst/third_party/oqd/src/OQDDevice.hpp new file mode 100644 index 
0000000000..4129ad3cfa --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/OQDDevice.hpp @@ -0,0 +1,74 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "QuantumDevice.hpp" + +// catalyst/runtime/lib/backend/common/ + +#include "CacheManager.hpp" +#include "QubitManager.hpp" +#include "Utils.hpp" + +#include "OQDRunner.hpp" + +namespace Catalyst::Runtime::Device { +class OQDDevice final : public Catalyst::Runtime::QuantumDevice { + private: + // static constants for RESULT values + static constexpr bool GLOBAL_RESULT_TRUE_CONST{true}; + static constexpr bool GLOBAL_RESULT_FALSE_CONST{false}; + + Catalyst::Runtime::QubitManager qubit_manager{}; + std::unique_ptr runner; + + Catalyst::Runtime::CacheManager> cache_manager{}; + bool tape_recording{false}; + size_t device_shots; + + std::unordered_map device_kwargs; + + inline auto getDeviceWires(const std::vector &wires) -> std::vector + { + std::vector res; + res.reserve(wires.size()); + std::transform(wires.begin(), wires.end(), std::back_inserter(res), + [this](auto w) { return qubit_manager.getDeviceId(w); }); + return res; + } + + public: + explicit OQDDevice(const std::string &kwargs = "{device_type : oqd, backend : default}") + { + device_kwargs = Catalyst::Runtime::parse_kwargs(kwargs); + device_shots = device_kwargs.contains("shots") + ? static_cast(std::stoll(device_kwargs["shots"])) + : 0; + runner = std::make_unique(); + } + ~OQDDevice() = default; + + QUANTUM_DEVICE_DEL_DECLARATIONS(OQDDevice); + + QUANTUM_DEVICE_RT_DECLARATIONS; + QUANTUM_DEVICE_QIS_DECLARATIONS; +}; +} // namespace Catalyst::Runtime::Device diff --git a/frontend/catalyst/third_party/oqd/src/OQDRunner.hpp b/frontend/catalyst/third_party/oqd/src/OQDRunner.hpp new file mode 100644 index 0000000000..092326877b --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/OQDRunner.hpp @@ -0,0 +1,115 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "Exception.hpp" + +namespace Catalyst::Runtime::Device { + +/** + * The OQD circuit runner interface. 
+ */ +struct OQDRunnerBase { + explicit OQDRunnerBase() = default; + virtual ~OQDRunnerBase() = default; + + [[nodiscard]] virtual auto runCircuit([[maybe_unused]] const std::string &circuit, + [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, + [[maybe_unused]] const std::string &kwargs = "") const + -> std::string + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto + Probs([[maybe_unused]] const std::string &circuit, [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, [[maybe_unused]] size_t num_qubits, + [[maybe_unused]] const std::string &kwargs = "") const -> std::vector + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto + Sample([[maybe_unused]] const std::string &circuit, [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, [[maybe_unused]] size_t num_qubits, + [[maybe_unused]] const std::string &kwargs = "") const -> std::vector + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto + Counts([[maybe_unused]] const std::string &circuit, [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, [[maybe_unused]] size_t num_qubits, + [[maybe_unused]] const std::string &kwargs = "") const -> std::vector + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto + Expval([[maybe_unused]] const std::string &circuit, [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, [[maybe_unused]] const std::string &kwargs = "") const + -> double + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto Var([[maybe_unused]] const std::string &circuit, + [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, + [[maybe_unused]] const std::string &kwargs = "") const -> double + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto + State([[maybe_unused]] const std::string &circuit, [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, [[maybe_unused]] size_t num_qubits, + [[maybe_unused]] const std::string &kwargs = "") const + -> std::vector> + { + RT_FAIL("Not implemented method"); + return {}; + } + [[nodiscard]] virtual auto Gradient([[maybe_unused]] const std::string &circuit, + [[maybe_unused]] const std::string &device, + [[maybe_unused]] size_t shots, + [[maybe_unused]] size_t num_qubits, + [[maybe_unused]] const std::string &kwargs = "") const + -> std::vector + { + RT_FAIL("Not implemented method"); + return {}; + } +}; + +/** + * The OQD circuit runner to execute a circuit on OQD devices. + */ +struct OQDRunner : public OQDRunnerBase { + [[nodiscard]] auto Counts(const std::string &circuit, const std::string &device, size_t shots, + size_t num_qubits, const std::string &kwargs = "") const + -> std::vector + { + RT_FAIL("Not implemented method"); + return {}; + } +}; + +} // namespace Catalyst::Runtime::Device diff --git a/frontend/catalyst/third_party/oqd/src/oqd.toml b/frontend/catalyst/third_party/oqd/src/oqd.toml new file mode 100644 index 0000000000..3fc91beafb --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/oqd.toml @@ -0,0 +1,73 @@ +# Note: This file is a template and is only meant for testing. +# The actual parameters will be filled in later. + +schema = 3 + +# The set of all gate types supported at the runtime execution interface of the +# device, i.e., what is supported by the `execute` method of the Device API. 
+# The gate definition has the following format: +# +# GATE = { properties = [ PROPS ], conditions = [ CONDS ] } +# +# where PROPS and CONDS are zero or more comma-separated quoted strings. +# +# PROPS: zero or more comma-separated quoted strings: +# - "controllable": if a controlled version of this gate is supported. +# - "invertible": if the adjoint of this operation is supported. +# - "differentiable": if device gradient is supported for this gate. +# CONDS: zero or more comma-separated quoted strings: +# - "analytic" or "finiteshots": if this operation is only supported in +# either analytic execution or with shots, respectively. +# - "terms-commute": if this composite operator is only supported +# given that its terms commute. Only relevant for Prod, SProd, Sum, +# LinearCombination, and Hamiltonian. +# +[operators.gates] + +Identity = { } +Hadamard = { } +PauliX = { } +PauliY = { } +PauliZ = { } +S = { } +T = { } +CNOT = { } +Toffoli = { } +CY = { } +CZ = { } +SWAP = { } +CSWAP = { } +RX = { } +RY = { } +RZ = { } +CRX = { } +CRY = { } +CRZ = { } +PhaseShift = { } +U1 = { } +U2 = { } +U3 = { } + +[operators.observables] + +[measurement_processes] + +CountsMP = { conditions = ["finiteshots"] } + +# Additional support that the device may provide. All accepted fields and their +# default values are listed below. Any fields missing from the TOML file will be +# set to their default values. +[compilation] + +# Whether the device is compatible with qjit. +qjit_compatible = true +# Whether the device requires run time generation of the quantum circuit. +runtime_code_generation = true +# The methods of handling mid-circuit measurements that the device supports, e.g., +# "one-shot", "device", "tree-traversal", etc. An empty list indicates that the device +# does not support mid-circuit measurements. +supported_mcm_methods = [ ] +# Whether the device supports dynamic qubit allocation/deallocation. +dynamic_qubit_management = false +# Whether simultaneous measurements of non-commuting observables are supported. +non_commuting_observables = false diff --git a/frontend/catalyst/third_party/oqd/src/tests/CMakeLists.txt b/frontend/catalyst/third_party/oqd/src/tests/CMakeLists.txt new file mode 100644 index 0000000000..a8eefd2386 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/tests/CMakeLists.txt @@ -0,0 +1,35 @@ +Include(FetchContent) + +FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v2.13.9 +) + +FetchContent_MakeAvailable(Catch2) + +# Required for catch_discover_tests(). +list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/contrib) + +# Modify `ctest` to only run the supported subset of tests. +include(CTest) +include(Catch) + +add_executable(runner_tests_oqd runner_main.cpp) + +target_include_directories(runner_tests_oqd PRIVATE + ${OQD_LIBRARIES} +) + +target_link_directories(runner_tests_oqd PRIVATE ${runtime_lib}) + +target_link_libraries(runner_tests_oqd PRIVATE + Catch2::Catch2 + ${OQD_LIBRARIES} +) + +target_sources(runner_tests_oqd PRIVATE + Test_OQDDevice.cpp +) + +catch_discover_tests(runner_tests_oqd) diff --git a/frontend/catalyst/third_party/oqd/src/tests/Test_OQDDevice.cpp b/frontend/catalyst/third_party/oqd/src/tests/Test_OQDDevice.cpp new file mode 100644 index 0000000000..0f6d6dcafd --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/tests/Test_OQDDevice.cpp @@ -0,0 +1,31 @@ +// Copyright 2023 Xanadu Quantum Technologies Inc.
+ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "OQDDevice.cpp" + +#include + +using namespace Catalyst::Runtime::Device; + +TEST_CASE("Test the OQDDevice constructor", "[oqd]") +{ + auto device = OQDDevice("{shots : 100}"); + + REQUIRE_THROWS_WITH(device.GetNumQubits(), Catch::Contains("Unsupported functionality")); + REQUIRE_THROWS_WITH(device.PrintState(), Catch::Contains("Unsupported functionality")); + REQUIRE_THROWS_WITH(device.AllocateQubit(), Catch::Contains("Unsupported functionality")); + REQUIRE_THROWS_WITH(device.Measure(0), Catch::Contains("Unsupported functionality")); + REQUIRE_THROWS_WITH(device.Expval(0), Catch::Contains("Unsupported functionality")); + REQUIRE_THROWS_WITH(device.Var(0), Catch::Contains("Unsupported functionality")); +} diff --git a/frontend/catalyst/third_party/oqd/src/tests/runner_main.cpp b/frontend/catalyst/third_party/oqd/src/tests/runner_main.cpp new file mode 100644 index 0000000000..434049dc18 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/src/tests/runner_main.cpp @@ -0,0 +1,16 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
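For reference, a sketch of inspecting the template capability file oqd.toml shown above with the load_toml helper added later in this patch (catalyst/utils/toml_utils.py); it assumes the OQD runtime has been built so that OQDDevice.config_filepath resolves to an installed copy of that file.

# Illustrative sketch only: assumes the OQD runtime build has copied oqd.toml to its backend dir.
from catalyst.third_party.oqd import OQDDevice          # import paths assumed from this patch
from catalyst.utils.toml_utils import load_toml

config = load_toml(OQDDevice.config_filepath)           # .../backend/oqd.toml

print(config["schema"])                                  # 3
print(sorted(config["operators"]["gates"]))              # CNOT, CRX, ..., RX, RY, RZ, ...
print(config["measurement_processes"]["CountsMP"])       # {'conditions': ['finiteshots']}
print(config["compilation"]["qjit_compatible"])          # True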
+ +#define CATCH_CONFIG_MAIN +#include diff --git a/frontend/catalyst/tracing/contexts.py b/frontend/catalyst/tracing/contexts.py index 5d5031b4b5..1c067b1795 100644 --- a/frontend/catalyst/tracing/contexts.py +++ b/frontend/catalyst/tracing/contexts.py @@ -20,7 +20,8 @@ from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import ContextManager, Dict, List, Optional, Tuple +from pathlib import Path +from typing import ContextManager, Dict, List, Optional, Set, Tuple from jax._src.core import MainTrace as JaxMainTrace from jax._src.core import cur_sublevel, new_base_main @@ -174,6 +175,7 @@ class EvaluationContext: """ _tracing_stack: List[Tuple[EvaluationMode, Optional[JaxTracingContext]]] = [] + _mlir_plugins: Set[Path] = set() @debug_logger_init def __init__(self, mode: EvaluationMode): @@ -184,6 +186,20 @@ def __init__(self, mode: EvaluationMode): self.mode = mode self.ctx = None + @classmethod + def add_plugin(cls, plugin: Path): + """Add an MLIR plugin to the set of MLIR plugins encountered in the + program""" + cls._mlir_plugins.add(plugin) + + @classmethod + def get_plugins(cls): + """Get and reset all plugins encountered during the trace of the + program""" + retval = cls._mlir_plugins + cls._mlir_plugins = set() + return retval + @classmethod @contextmanager def _create_tracing_context(cls, mode) -> ContextManager[JaxTracingContext]: diff --git a/frontend/catalyst/utils/CMakeLists.txt b/frontend/catalyst/utils/CMakeLists.txt index 5dd037086c..51ec6d2350 100644 --- a/frontend/catalyst/utils/CMakeLists.txt +++ b/frontend/catalyst/utils/CMakeLists.txt @@ -1,9 +1,9 @@ -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) +cmake_minimum_required(VERSION 3.26) + +project(catalyst_frontend) -find_package(Python 3 - REQUIRED COMPONENTS Interpreter Development.Module - OPTIONAL_COMPONENTS Development.SABIModule) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) # nanobind suggests including these lines to configure CMake to perform an optimized release build # by default unless another build type is specified. 
Without this addition, binding code may run @@ -14,19 +14,16 @@ if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") endif() -# Detect the installed nanobind package and import it into CMake -execute_process( - COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir - OUTPUT_VARIABLE nanobind_ROOT OUTPUT_STRIP_TRAILING_WHITESPACE) - -find_package(nanobind CONFIG REQUIRED) - -# Get the NumPy include directory +# Locate Python & nanobind +find_package(Python REQUIRED + COMPONENTS Interpreter Development.Module NumPy + OPTIONAL_COMPONENTS Development.SABIModule +) execute_process( - COMMAND "${Python_EXECUTABLE}" -c "import numpy; print(numpy.get_include())" - OUTPUT_VARIABLE NUMPY_INCLUDE_DIR - OUTPUT_STRIP_TRAILING_WHITESPACE + COMMAND "${Python_EXECUTABLE}" -c "import nanobind; print(nanobind.cmake_dir())" + OUTPUT_VARIABLE nanobind_DIR OUTPUT_STRIP_TRAILING_WHITESPACE ) +find_package(nanobind CONFIG REQUIRED) # Source file list for `wrapper` module set(WRAPPER_SRC_FILES @@ -39,7 +36,7 @@ set(WRAPPER_SRC_FILES nanobind_add_module(wrapper STABLE_ABI ${WRAPPER_SRC_FILES}) # Add the NumPy include directory to the library's include paths -target_include_directories(wrapper PRIVATE ${NUMPY_INCLUDE_DIR}) +target_include_directories(wrapper PRIVATE ${Python_NumPy_INCLUDE_DIRS}) # Use suffix ".so" rather than ".abi3.so" for library file using Stable ABI # This is necessary for compatibility with setuptools build extensions diff --git a/frontend/catalyst/utils/c_template.py b/frontend/catalyst/utils/c_template.py index 3f2ed628ee..2319e53587 100644 --- a/frontend/catalyst/utils/c_template.py +++ b/frontend/catalyst/utils/c_template.py @@ -270,7 +270,8 @@ def _get_sizes(array): @staticmethod def _get_strides(array): - strides = [str(stride // 8) for stride in array.strides] + # Numpy uses units of bytes for their strides, but memrefs use number of elements. + strides = [str(stride // array.itemsize) for stride in array.strides] strides_str = ",".join(strides) return strides_str diff --git a/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp index b3875101be..fe98badc5e 100644 --- a/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp +++ b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp @@ -35,126 +35,140 @@ #include "lapack_kernels.hpp" +// With SciPy OpenBLAS, symbols are now prefixed with scipy_ when using scipy>=1.14 or +// scipy-openblas32, whereas this is not the case for the reference implementation used on M1 mac. +#if defined(__APPLE__) && defined(__arm64__) +#define SYM_PREFIX +#else +#define SYM_PREFIX scipy_ +#endif + +// The CONCAT macro and its helper CONCAT_ are required here since macro arguments are not expanded +// around a ## preprocessing token. See http://port70.net/%7Ensz/c/c11/n1570.html#6.10.3.1. +#define CONCAT_(X, Y) X##Y +#define CONCAT(X, Y) CONCAT_(X, Y) +#define GET_SYMBOL(X) CONCAT(SYM_PREFIX, X) + // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but // a C++ user should link against LAPACK directly. This is needed when using // JAX-generated HLO from C++. 
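Regarding the stride fix in c_template.py above: NumPy reports strides in bytes while memrefs count elements, so dividing by the hard-coded 8 was only correct for 8-byte dtypes. A quick NumPy check (sketch) illustrates the corrected logic.

# Sketch: NumPy strides are in bytes; dividing by itemsize yields element strides.
import numpy as np

a64 = np.zeros((3, 4), dtype=np.float64)
a32 = np.zeros((3, 4), dtype=np.float32)

print(a64.strides, a64.itemsize)  # (32, 8), 8  -> element strides (4, 1)
print(a32.strides, a32.itemsize)  # (16, 4), 4  -> element strides (4, 1)

def element_strides(array):
    # Mirrors the corrected _get_strides logic: divide byte strides by the dtype's itemsize.
    return tuple(s // array.itemsize for s in array.strides)

assert element_strides(a64) == element_strides(a32) == (4, 1)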
extern "C" { -jax::RealTrsm::FnType cblas_strsm; -jax::RealTrsm::FnType cblas_dtrsm; -jax::ComplexTrsm>::FnType cblas_ctrsm; -jax::ComplexTrsm>::FnType cblas_ztrsm; - -jax::Getrf::FnType LAPACKE_sgetrf; -jax::Getrf::FnType LAPACKE_dgetrf; -jax::Getrf>::FnType LAPACKE_cgetrf; -jax::Getrf>::FnType LAPACKE_zgetrf; - -jax::Geqrf::FnType LAPACKE_sgeqrf; -jax::Geqrf::FnType LAPACKE_dgeqrf; -jax::Geqrf>::FnType LAPACKE_cgeqrf; -jax::Geqrf>::FnType LAPACKE_zgeqrf; - -jax::Orgqr::FnType LAPACKE_sorgqr; -jax::Orgqr::FnType LAPACKE_dorgqr; -jax::Orgqr>::FnType LAPACKE_cungqr; -jax::Orgqr>::FnType LAPACKE_zungqr; - -jax::Potrf::FnType LAPACKE_spotrf; -jax::Potrf::FnType LAPACKE_dpotrf; -jax::Potrf>::FnType LAPACKE_cpotrf; -jax::Potrf>::FnType LAPACKE_zpotrf; - -jax::RealGesdd::FnType LAPACKE_sgesdd; -jax::RealGesdd::FnType LAPACKE_dgesdd; -jax::ComplexGesdd>::FnType LAPACKE_cgesdd; -jax::ComplexGesdd>::FnType LAPACKE_zgesdd; - -jax::RealSyevd::FnType LAPACKE_ssyevd; -jax::RealSyevd::FnType LAPACKE_dsyevd; -jax::ComplexHeevd>::FnType LAPACKE_cheevd; -jax::ComplexHeevd>::FnType LAPACKE_zheevd; - -jax::RealGeev::FnType LAPACKE_sgeev; -jax::RealGeev::FnType LAPACKE_dgeev; -jax::ComplexGeev>::FnType LAPACKE_cgeev; -jax::ComplexGeev>::FnType LAPACKE_zgeev; - -jax::RealGees::FnType LAPACKE_sgees; -jax::RealGees::FnType LAPACKE_dgees; -jax::ComplexGees>::FnType LAPACKE_cgees; -jax::ComplexGees>::FnType LAPACKE_zgees; - -jax::Gehrd::FnType LAPACKE_sgehrd; -jax::Gehrd::FnType LAPACKE_dgehrd; -jax::Gehrd>::FnType LAPACKE_cgehrd; -jax::Gehrd>::FnType LAPACKE_zgehrd; - -jax::Sytrd::FnType LAPACKE_ssytrd; -jax::Sytrd::FnType LAPACKE_dsytrd; -jax::Sytrd>::FnType LAPACKE_chetrd; -jax::Sytrd>::FnType LAPACKE_zhetrd; +jax::RealTrsm::FnType GET_SYMBOL(cblas_strsm); +jax::RealTrsm::FnType GET_SYMBOL(cblas_dtrsm); +jax::ComplexTrsm>::FnType GET_SYMBOL(cblas_ctrsm); +jax::ComplexTrsm>::FnType GET_SYMBOL(cblas_ztrsm); + +jax::Getrf::FnType GET_SYMBOL(LAPACKE_sgetrf); +jax::Getrf::FnType GET_SYMBOL(LAPACKE_dgetrf); +jax::Getrf>::FnType GET_SYMBOL(LAPACKE_cgetrf); +jax::Getrf>::FnType GET_SYMBOL(LAPACKE_zgetrf); + +jax::Geqrf::FnType GET_SYMBOL(LAPACKE_sgeqrf); +jax::Geqrf::FnType GET_SYMBOL(LAPACKE_dgeqrf); +jax::Geqrf>::FnType GET_SYMBOL(LAPACKE_cgeqrf); +jax::Geqrf>::FnType GET_SYMBOL(LAPACKE_zgeqrf); + +jax::Orgqr::FnType GET_SYMBOL(LAPACKE_sorgqr); +jax::Orgqr::FnType GET_SYMBOL(LAPACKE_dorgqr); +jax::Orgqr>::FnType GET_SYMBOL(LAPACKE_cungqr); +jax::Orgqr>::FnType GET_SYMBOL(LAPACKE_zungqr); + +jax::Potrf::FnType GET_SYMBOL(LAPACKE_spotrf); +jax::Potrf::FnType GET_SYMBOL(LAPACKE_dpotrf); +jax::Potrf>::FnType GET_SYMBOL(LAPACKE_cpotrf); +jax::Potrf>::FnType GET_SYMBOL(LAPACKE_zpotrf); + +jax::RealGesdd::FnType GET_SYMBOL(LAPACKE_sgesdd); +jax::RealGesdd::FnType GET_SYMBOL(LAPACKE_dgesdd); +jax::ComplexGesdd>::FnType GET_SYMBOL(LAPACKE_cgesdd); +jax::ComplexGesdd>::FnType GET_SYMBOL(LAPACKE_zgesdd); + +jax::RealSyevd::FnType GET_SYMBOL(LAPACKE_ssyevd); +jax::RealSyevd::FnType GET_SYMBOL(LAPACKE_dsyevd); +jax::ComplexHeevd>::FnType GET_SYMBOL(LAPACKE_cheevd); +jax::ComplexHeevd>::FnType GET_SYMBOL(LAPACKE_zheevd); + +jax::RealGeev::FnType GET_SYMBOL(LAPACKE_sgeev); +jax::RealGeev::FnType GET_SYMBOL(LAPACKE_dgeev); +jax::ComplexGeev>::FnType GET_SYMBOL(LAPACKE_cgeev); +jax::ComplexGeev>::FnType GET_SYMBOL(LAPACKE_zgeev); + +jax::RealGees::FnType GET_SYMBOL(LAPACKE_sgees); +jax::RealGees::FnType GET_SYMBOL(LAPACKE_dgees); +jax::ComplexGees>::FnType GET_SYMBOL(LAPACKE_cgees); +jax::ComplexGees>::FnType 
GET_SYMBOL(LAPACKE_zgees); + +jax::Gehrd::FnType GET_SYMBOL(LAPACKE_sgehrd); +jax::Gehrd::FnType GET_SYMBOL(LAPACKE_dgehrd); +jax::Gehrd>::FnType GET_SYMBOL(LAPACKE_cgehrd); +jax::Gehrd>::FnType GET_SYMBOL(LAPACKE_zgehrd); + +jax::Sytrd::FnType GET_SYMBOL(LAPACKE_ssytrd); +jax::Sytrd::FnType GET_SYMBOL(LAPACKE_dsytrd); +jax::Sytrd>::FnType GET_SYMBOL(LAPACKE_chetrd); +jax::Sytrd>::FnType GET_SYMBOL(LAPACKE_zhetrd); } // extern "C" namespace jax { static auto init = []() -> int { - RealTrsm::fn = cblas_strsm; - RealTrsm::fn = cblas_dtrsm; - ComplexTrsm>::fn = cblas_ctrsm; - ComplexTrsm>::fn = cblas_ztrsm; - - Getrf::fn = LAPACKE_sgetrf; - Getrf::fn = LAPACKE_dgetrf; - Getrf>::fn = LAPACKE_cgetrf; - Getrf>::fn = LAPACKE_zgetrf; - - Geqrf::fn = LAPACKE_sgeqrf; - Geqrf::fn = LAPACKE_dgeqrf; - Geqrf>::fn = LAPACKE_cgeqrf; - Geqrf>::fn = LAPACKE_zgeqrf; - - Orgqr::fn = LAPACKE_sorgqr; - Orgqr::fn = LAPACKE_dorgqr; - Orgqr>::fn = LAPACKE_cungqr; - Orgqr>::fn = LAPACKE_zungqr; - - Potrf::fn = LAPACKE_spotrf; - Potrf::fn = LAPACKE_dpotrf; - Potrf>::fn = LAPACKE_cpotrf; - Potrf>::fn = LAPACKE_zpotrf; - - RealGesdd::fn = LAPACKE_sgesdd; - RealGesdd::fn = LAPACKE_dgesdd; - ComplexGesdd>::fn = LAPACKE_cgesdd; - ComplexGesdd>::fn = LAPACKE_zgesdd; - - RealSyevd::fn = LAPACKE_ssyevd; - RealSyevd::fn = LAPACKE_dsyevd; - ComplexHeevd>::fn = LAPACKE_cheevd; - ComplexHeevd>::fn = LAPACKE_zheevd; - - RealGeev::fn = LAPACKE_sgeev; - RealGeev::fn = LAPACKE_dgeev; - ComplexGeev>::fn = LAPACKE_cgeev; - ComplexGeev>::fn = LAPACKE_zgeev; - - RealGees::fn = LAPACKE_sgees; - RealGees::fn = LAPACKE_dgees; - ComplexGees>::fn = LAPACKE_cgees; - ComplexGees>::fn = LAPACKE_zgees; - - Gehrd::fn = LAPACKE_sgehrd; - Gehrd::fn = LAPACKE_dgehrd; - Gehrd>::fn = LAPACKE_cgehrd; - Gehrd>::fn = LAPACKE_zgehrd; - - Sytrd::fn = LAPACKE_ssytrd; - Sytrd::fn = LAPACKE_dsytrd; - Sytrd>::fn = LAPACKE_chetrd; - Sytrd>::fn = LAPACKE_zhetrd; + RealTrsm::fn = GET_SYMBOL(cblas_strsm); + RealTrsm::fn = GET_SYMBOL(cblas_dtrsm); + ComplexTrsm>::fn = GET_SYMBOL(cblas_ctrsm); + ComplexTrsm>::fn = GET_SYMBOL(cblas_ztrsm); + + Getrf::fn = GET_SYMBOL(LAPACKE_sgetrf); + Getrf::fn = GET_SYMBOL(LAPACKE_dgetrf); + Getrf>::fn = GET_SYMBOL(LAPACKE_cgetrf); + Getrf>::fn = GET_SYMBOL(LAPACKE_zgetrf); + + Geqrf::fn = GET_SYMBOL(LAPACKE_sgeqrf); + Geqrf::fn = GET_SYMBOL(LAPACKE_dgeqrf); + Geqrf>::fn = GET_SYMBOL(LAPACKE_cgeqrf); + Geqrf>::fn = GET_SYMBOL(LAPACKE_zgeqrf); + + Orgqr::fn = GET_SYMBOL(LAPACKE_sorgqr); + Orgqr::fn = GET_SYMBOL(LAPACKE_dorgqr); + Orgqr>::fn = GET_SYMBOL(LAPACKE_cungqr); + Orgqr>::fn = GET_SYMBOL(LAPACKE_zungqr); + + Potrf::fn = GET_SYMBOL(LAPACKE_spotrf); + Potrf::fn = GET_SYMBOL(LAPACKE_dpotrf); + Potrf>::fn = GET_SYMBOL(LAPACKE_cpotrf); + Potrf>::fn = GET_SYMBOL(LAPACKE_zpotrf); + + RealGesdd::fn = GET_SYMBOL(LAPACKE_sgesdd); + RealGesdd::fn = GET_SYMBOL(LAPACKE_dgesdd); + ComplexGesdd>::fn = GET_SYMBOL(LAPACKE_cgesdd); + ComplexGesdd>::fn = GET_SYMBOL(LAPACKE_zgesdd); + + RealSyevd::fn = GET_SYMBOL(LAPACKE_ssyevd); + RealSyevd::fn = GET_SYMBOL(LAPACKE_dsyevd); + ComplexHeevd>::fn = GET_SYMBOL(LAPACKE_cheevd); + ComplexHeevd>::fn = GET_SYMBOL(LAPACKE_zheevd); + + RealGeev::fn = GET_SYMBOL(LAPACKE_sgeev); + RealGeev::fn = GET_SYMBOL(LAPACKE_dgeev); + ComplexGeev>::fn = GET_SYMBOL(LAPACKE_cgeev); + ComplexGeev>::fn = GET_SYMBOL(LAPACKE_zgeev); + + RealGees::fn = GET_SYMBOL(LAPACKE_sgees); + RealGees::fn = GET_SYMBOL(LAPACKE_dgees); + ComplexGees>::fn = GET_SYMBOL(LAPACKE_cgees); + ComplexGees>::fn = GET_SYMBOL(LAPACKE_zgees); + + 
Gehrd::fn = GET_SYMBOL(LAPACKE_sgehrd); + Gehrd::fn = GET_SYMBOL(LAPACKE_dgehrd); + Gehrd>::fn = GET_SYMBOL(LAPACKE_cgehrd); + Gehrd>::fn = GET_SYMBOL(LAPACKE_zgehrd); + + Sytrd::fn = GET_SYMBOL(LAPACKE_ssytrd); + Sytrd::fn = GET_SYMBOL(LAPACKE_dsytrd); + Sytrd>::fn = GET_SYMBOL(LAPACKE_chetrd); + Sytrd>::fn = GET_SYMBOL(LAPACKE_zhetrd); return 0; }(); diff --git a/frontend/catalyst/utils/runtime_environment.py b/frontend/catalyst/utils/runtime_environment.py index bc61f13bcf..27df433049 100644 --- a/frontend/catalyst/utils/runtime_environment.py +++ b/frontend/catalyst/utils/runtime_environment.py @@ -28,6 +28,7 @@ "runtime": os.path.join(package_root, "../../../runtime/build/lib"), "enzyme": os.path.join(package_root, "../../../mlir/Enzyme/build/Enzyme"), "oqc_runtime": os.path.join(package_root, "../../catalyst/third_party/oqc/src/build"), + "oqd_runtime": os.path.join(package_root, "../../catalyst/third_party/oqd/src/build"), } DEFAULT_BIN_PATHS = { diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py deleted file mode 100644 index 14ffe81b6e..0000000000 --- a/frontend/catalyst/utils/toml.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Module for abstracting which toml_load to use. -""" - -import importlib.util -from dataclasses import dataclass, field -from functools import reduce -from itertools import repeat -from typing import Any, Dict, List, Set - -from catalyst.utils.exceptions import CompileError - -# TODO: -# Once Python version 3.11 is the oldest supported Python version, we can remove tomlkit -# and rely exclusively on tomllib. - -# New in version 3.11 -# https://docs.python.org/3/library/tomllib.html -tomllib = importlib.util.find_spec("tomllib") -tomlkit = importlib.util.find_spec("tomlkit") -# We need at least one of these to make sure we can read toml files. 
-if tomllib is None and tomlkit is None: # pragma: nocover - raise ImportError("Either tomllib or tomlkit need to be installed.") - -# Give preference to tomllib -if tomllib: # pragma: nocover - from tomllib import load as toml_load - - TOMLDocument = Any - TOMLException = Exception -else: # pragma: nocover - from tomlkit import TOMLDocument - from tomlkit import load as toml_load - from tomlkit.exceptions import TOMLKitError as TOMLException - -ALL_SUPPORTED_SCHEMAS = [2] - - -def read_toml_file(toml_file: str) -> TOMLDocument: - """Helper function opening toml file properly and reading it into a document""" - with open(toml_file, "rb") as f: - config = toml_load(f) - return config - - -@dataclass -class OperationProperties: - """Capabilities of a single operation""" - - invertible: bool = False - controllable: bool = False - differentiable: bool = False - - -def _intersect_properties(a: OperationProperties, b: OperationProperties) -> OperationProperties: - """Calculate the intersection of OperationProperties""" - return OperationProperties( - invertible=a.invertible and b.invertible, - controllable=a.controllable and b.controllable, - differentiable=a.differentiable and b.differentiable, - ) - - -@dataclass -class DeviceCapabilities: # pylint: disable=too-many-instance-attributes - """Quantum device capabilities""" - - native_ops: Dict[str, OperationProperties] = field(default_factory=dict) - to_decomp_ops: Dict[str, OperationProperties] = field(default_factory=dict) - to_matrix_ops: Dict[str, OperationProperties] = field(default_factory=dict) - native_obs: Dict[str, OperationProperties] = field(default_factory=dict) - measurement_processes: Set[str] = field(default_factory=set) - qjit_compatible_flag: bool = False - mid_circuit_measurement_flag: bool = False - runtime_code_generation_flag: bool = False - dynamic_qubit_management_flag: bool = False - non_commuting_observables_flag: bool = False - initial_state_prep_flag: bool = False - options: Dict[str, bool] = field(default_factory=dict) - - -def intersect_operations( - a: Dict[str, OperationProperties], b: Dict[str, OperationProperties] -) -> Dict[str, OperationProperties]: - """Intersects two sets of oepration properties""" - return {k: _intersect_properties(a[k], b[k]) for k in (a.keys() & b.keys())} - - -@dataclass -class ProgramFeatures: - """Program features, obtained from the user""" - - shots_present: bool - - -def _get_compilation_flag(config: TOMLDocument, flag_name: str) -> bool: - """Get the flag in the toml document 'compilation' section.""" - return bool(config.get("compilation", {}).get(flag_name, False)) - - -def _get_options(config: TOMLDocument) -> Dict[str, str]: - """Get custom options sections""" - return {str(k): str(v) for k, v in config.get("options", {}).items()} - - -def _parse_toml_section( - config: TOMLDocument, path: List[str], program_features: ProgramFeatures -) -> Dict[str, dict]: - """Parses the section of toml config file specified by `path`. Filters-out gates which don't - match condition. 
For now the only condition we support is `shots_present`.""" - gates = {} - analytic = "analytic" - finiteshots = "finiteshots" - try: - iterable = reduce(lambda x, y: x[y], path, config) - except TOMLException as _: # pylint: disable=broad-exception-caught - return {} - gen = iterable.items() if hasattr(iterable, "items") else zip(iterable, repeat({})) - for g, values in gen: - unknown_attrs = set(values) - {"condition", "properties"} - if len(unknown_attrs) > 0: - raise CompileError( - f"Configuration for gate '{str(g)}' has unknown attributes: {list(unknown_attrs)}" - ) - properties = values.get("properties", {}) - unknown_props = set(properties) - {"invertible", "controllable", "differentiable"} - if len(unknown_props) > 0: - raise CompileError( - f"Configuration for gate '{str(g)}' has unknown properties: {list(unknown_props)}" - ) - if "condition" in values: - # TODO: do not filter here. Parse the condition and then filter on demand instead. - conditions = values["condition"] - unknown_conditions = set(conditions) - {analytic, finiteshots} - if len(unknown_conditions) > 0: - raise CompileError( - f"Configuration for gate '{str(g)}' has unknown conditions: " - f"{list(unknown_conditions)}" - ) - if all(c in conditions for c in [analytic, finiteshots]): - raise CompileError( - f"Configuration for gate '{g}' can not contain both " - f"`{finiteshots}` and `{analytic}` conditions simultaniosly" - ) - if analytic in conditions and not program_features.shots_present: - gates[g] = values - elif finiteshots in conditions and program_features.shots_present: - gates[g] = values - else: - gates[g] = values - return gates - - -def _get_observables(config: TOMLDocument, program_features: ProgramFeatures) -> Dict[str, dict]: - """Override the set of supported observables.""" - return _parse_toml_section(config, ["operators", "observables"], program_features) - - -def _get_measurement_processes( - config: TOMLDocument, program_features: ProgramFeatures -) -> Dict[str, dict]: - """Get the measurements processes from the `native` section of the config.""" - return _parse_toml_section(config, ["measurement_processes"], program_features) - - -def _get_native_ops(config: TOMLDocument, program_features: ProgramFeatures) -> Dict[str, dict]: - """Get the gates from the `native` section of the config.""" - return _parse_toml_section(config, ["operators", "gates", "native"], program_features) - - -def _get_decomposable_gates( - config: TOMLDocument, program_features: ProgramFeatures -) -> Dict[str, dict]: - """Get gates that will be decomposed according to PL's decomposition rules. - - Args: - config (TOMLDocument): Configuration dictionary - """ - return _parse_toml_section(config, ["operators", "gates", "decomp"], program_features) - - -def _get_matrix_decomposable_gates( - config: TOMLDocument, program_features: ProgramFeatures -) -> Dict[str, dict]: - """Get gates that will be decomposed to QubitUnitary. 
- - Args: - config (TOMLDocument): Configuration dictionary - """ - return _parse_toml_section(config, ["operators", "gates", "matrix"], program_features) - - -def _get_operation_properties(config_props: dict) -> OperationProperties: - """Load operation properties from config""" - properties = config_props.get("properties", {}) - return OperationProperties( - invertible="invertible" in properties, - controllable="controllable" in properties, - differentiable="differentiable" in properties, - ) - - -def load_device_capabilities( - config: TOMLDocument, program_features: ProgramFeatures -) -> DeviceCapabilities: - """Load device capabilities from device config""" - - schema = int(config["schema"]) - assert schema in ALL_SUPPORTED_SCHEMAS, f"Unsupported config schema {schema}" - - native_gate_props = {} - for g, props in _get_native_ops(config, program_features).items(): - native_gate_props[g] = _get_operation_properties(props) - - matrix_decomp_props = {} - for g, props in _get_matrix_decomposable_gates(config, program_features).items(): - matrix_decomp_props[g] = _get_operation_properties(props) - - decomp_props = {} - for g, props in _get_decomposable_gates(config, program_features).items(): - decomp_props[g] = _get_operation_properties(props) - - observable_props = {} - for g, props in _get_observables(config, program_features).items(): - observable_props[g] = _get_operation_properties(props) - - measurements_props = set() - for g, props in _get_measurement_processes(config, program_features).items(): - measurements_props.add(g) - - return DeviceCapabilities( - native_ops=native_gate_props, - to_decomp_ops=decomp_props, - to_matrix_ops=matrix_decomp_props, - native_obs=observable_props, - measurement_processes=measurements_props, - qjit_compatible_flag=_get_compilation_flag(config, "qjit_compatible"), - mid_circuit_measurement_flag=_get_compilation_flag(config, "mid_circuit_measurement"), - runtime_code_generation_flag=_get_compilation_flag(config, "runtime_code_generation"), - dynamic_qubit_management_flag=_get_compilation_flag(config, "dynamic_qubit_management"), - non_commuting_observables_flag=_get_compilation_flag(config, "non_commuting_observables"), - initial_state_prep_flag=_get_compilation_flag(config, "initial_state_prep"), - options=_get_options(config), - ) diff --git a/frontend/catalyst/utils/toml_utils.py b/frontend/catalyst/utils/toml_utils.py new file mode 100644 index 0000000000..07fc2857d1 --- /dev/null +++ b/frontend/catalyst/utils/toml_utils.py @@ -0,0 +1,183 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +TOML Utilities +~~~~~~~~~~~~~~ + +A collection of utility functions for working with TOML configuration files. 
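For reference, a short sketch of the input forms accepted by the load_toml helper defined below: a file path, a plain TOML document string, or an io.StringIO buffer.

# Sketch: load_toml accepts a file path, a TOML document string, or an io.StringIO buffer.
import io
from catalyst.utils.toml_utils import load_toml  # module added by this patch

doc = 'schema = 3\n\n[operators.gates]\nRX = { }\n'

assert load_toml(doc) == load_toml(io.StringIO(doc))  # string and buffer forms agree

# A path-like argument is read from disk instead, e.g.:
# load_toml("frontend/catalyst/third_party/oqd/src/oqd.toml")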
+""" + +import ast +import io +import math +import operator +import os +import sys +from os import PathLike +from pathlib import Path +from typing import Union + +if sys.version_info >= (3, 11): + import tomllib as toml # pragma: no cover +else: + import tomli as toml # pragma: no cover + + +# Alias for supported TOML input types +TomlInput = Union[str, io.StringIO, os.PathLike] + + +def load_toml(filepath_or_buffer: TomlInput) -> dict: + """Loads a TOML document from a file or string and returns the parsed dict. + + Args: + filepath_or_buffer: The path to the TOML file or a TOML document string. + + Returns: + dict: The parsed TOML document. + + Raises: + TypeError: If the input type is not supported. + """ + if isinstance(filepath_or_buffer, io.StringIO): + document = _load_toml_from_string(filepath_or_buffer.getvalue()) + + elif isinstance(filepath_or_buffer, (str, Path)) and os.path.isfile(filepath_or_buffer): + document = _load_toml_from_file(filepath_or_buffer) # pragma: no cover + + elif isinstance(filepath_or_buffer, str): + document = _load_toml_from_string(filepath_or_buffer) + + else: + raise TypeError("Input must be a string, io.StringIO, or a path-like object.") + + return document + + +def _load_toml_from_string(contents: str) -> dict: + """Loads a TOML string and returns the parsed dict.""" + return toml.loads(contents) + + +def _load_toml_from_file(filepath: PathLike) -> dict: + """Loads a TOML file and returns the parsed dict.""" + with open(filepath, "rb") as f: # pragma: no cover + return toml.load(f) + + +# Supported safe_eval operators and their corresponding functions +OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} + + +def safe_eval(expr: str) -> float: + """ + Safely evaluate a mathematical expression. + + Mathematical constants and functions from the math module are supported. + + The usage of `safe_eval` should be preferred wherever possible over Python's builtin `eval()` + function, whose ability to perform arbitrary code execution of a user's input makes it + inherently unsafe. The functionality of `safe_eval` is deliberately limited to basic + mathematical operations to prevent malicious code, intentional or not, from being evaluated. + + Parameters: + expr (str): The arithmetic expression to evaluate. + + Returns: + float: The result of the evaluated expression. + + Raises: + ValueError: If the expression is invalid or contains unsupported elements. 
+ + Examples: + + >>> safe_eval("1 + 1e-1") + 1.1 + >>> safe_eval("1 + (2 * math.sin(math.pi / 2))") + 3.0 + """ + + def _eval(node): + if isinstance(node, ast.BinOp): # Binary operations (e.g., 1 + 2) + return _eval_binop(node) + + elif isinstance(node, ast.UnaryOp): # Unary operations (e.g., -1) + return _eval_unaryop(node) + + elif isinstance(node, ast.Call): # Function calls (e.g., math.sin(0.5)) + return _eval_call(node) + + elif isinstance(node, ast.Attribute): # Accessing attributes (e.g., math.pi) + return _eval_attr(node) + + elif isinstance(node, ast.Constant): # Python 3.8+ literal + return node.value + + else: + raise ValueError(f"Unsupported expression type: {type(node)}") + + def _eval_binop(node): + left = _eval(node.left) + right = _eval(node.right) + op_type = type(node.op) + if op_type in OPERATORS: + return OPERATORS[op_type](left, right) + else: + raise ValueError(f"Unsupported operator: {op_type}") + + def _eval_unaryop(node): + operand = _eval(node.operand) + op_type = type(node.op) + if op_type in OPERATORS: + return OPERATORS[op_type](operand) + else: + raise ValueError(f"Unsupported unary operator: {op_type}") + + def _eval_call(node): + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + module = node.func.value.id + func = node.func.attr + if module == "math" and hasattr(math, func): + args = [_eval(arg) for arg in node.args] + return getattr(math, func)(*args) + else: + raise ValueError(f"Unsupported function: {module}.{func}") + else: + raise ValueError("Unsupported function call structure") + + def _eval_attr(node): + if isinstance(node.value, ast.Name): + module = node.value.id + attr = node.attr + if module == "math" and hasattr(math, attr): + return getattr(math, attr) + else: + raise ValueError(f"Unsupported attribute: {module}.{attr}") + else: + raise ValueError("Unsupported attribute structure") + + try: + parsed_expr = ast.parse(expr, mode="eval") + return _eval(parsed_expr.body) + except Exception as e: + raise ValueError(f"Invalid expression: '{expr}'") from e diff --git a/frontend/catalyst/utils/wrapper.cpp b/frontend/catalyst/utils/wrapper.cpp index 6b8e9b0401..8326ef4ec0 100644 --- a/frontend/catalyst/utils/wrapper.cpp +++ b/frontend/catalyst/utils/wrapper.cpp @@ -190,8 +190,8 @@ nb::list move_returns(void *memref_array_ptr, nb::object result_desc, nb::object // Decrement reference counts. // The final ref count of `new_array` should be 2: one for the `returns` list and one for // the `numpy_arrays` dict. - Py_DECREF(pyLong); - Py_DECREF(new_array); + Py_DecRef(pyLong); + Py_DecRef(new_array); } return returns; } diff --git a/frontend/test/custom_device/custom_device.toml b/frontend/test/custom_device/custom_device.toml index 67cf8276d6..750944dfd6 100644 --- a/frontend/test/custom_device/custom_device.toml +++ b/frontend/test/custom_device/custom_device.toml @@ -1,52 +1,67 @@ -schema = 2 +schema = 3 + +# The set of all gate types supported at the runtime execution interface of the +# device, i.e., what is supported by the `execute` method of the Device API. +# The gate definition has the following format: +# +# GATE = { properties = [ PROPS ], conditions = [ CONDS ] } +# +# where PROPS and CONS are zero or more comma separated quoted strings. +# +# PROPS: zero or more comma-separated quoted strings: +# - "controllable": if a controlled version of this gate is supported. +# - "invertible": if the adjoint of this operation is supported. +# - "differentiable": if device gradient is supported for this gate. 
+# CONDS: zero or more comma-separated quoted strings: +# - "analytic" or "finiteshots": if this operation is only supported in +# either analytic execution or with shots, respectively. +# - "terms-commute": if this composite operator is only supported +# given that its terms commute. Only relevant for Prod, SProd, Sum, +# LinearCombination, and Hamiltonian. +# +[operators.gates] + +QubitUnitary = { properties = [ "controllable", "invertible", "differentiable" ] } +PauliX = { properties = [ "controllable", "invertible", "differentiable" ] } +PauliY = { properties = [ "controllable", "invertible", "differentiable" ] } +PauliZ = { properties = [ "controllable", "invertible", "differentiable" ] } +MultiRZ = { properties = [ "controllable", "invertible", "differentiable" ] } +Hadamard = { properties = [ "controllable", "invertible", "differentiable" ] } +S = { properties = [ "controllable", "invertible", "differentiable" ] } +T = { properties = [ "controllable", "invertible", "differentiable" ] } +CNOT = { properties = [ "invertible", "differentiable" ] } +SWAP = { properties = [ "controllable", "invertible", "differentiable" ] } +CSWAP = { properties = [ "invertible", "differentiable" ] } +Toffoli = { properties = [ "invertible", "differentiable" ] } +CY = { properties = [ "invertible", "differentiable" ] } +CZ = { properties = [ "invertible", "differentiable" ] } +PhaseShift = { properties = [ "controllable", "invertible", "differentiable" ] } +ControlledPhaseShift = { properties = [ "invertible", "differentiable" ] } +RX = { properties = [ "controllable", "invertible", "differentiable" ] } +RY = { properties = [ "controllable", "invertible", "differentiable" ] } +RZ = { properties = [ "controllable", "invertible", "differentiable" ] } +Rot = { properties = [ "controllable", "invertible", "differentiable" ] } +CRX = { properties = [ "invertible", "differentiable" ] } +CRY = { properties = [ "invertible", "differentiable" ] } +CRZ = { properties = [ "invertible", "differentiable" ] } +CRot = { properties = [ "invertible" ] } +Identity = { properties = [ "invertible", "differentiable" ] } +IsingXX = { properties = [ "controllable", "invertible", "differentiable" ] } +IsingYY = { properties = [ "controllable", "invertible", "differentiable" ] } +IsingZZ = { properties = [ "controllable", "invertible", "differentiable" ] } +IsingXY = { properties = [ "controllable", "invertible", "differentiable" ] } +GlobalPhase = { properties = [ "controllable", "invertible", "differentiable" ] } +BlockEncode = { properties = [ "invertible", "differentiable" ] } +SingleExcitation = { properties = [ "controllable", "invertible", "differentiable" ] } +SingleExcitationPlus = { properties = [ "controllable", "invertible", "differentiable" ] } +SingleExcitationMinus = { properties = [ "controllable", "invertible", "differentiable" ] } +DoubleExcitation = { properties = [ "controllable", "invertible", "differentiable" ] } +DoubleExcitationPlus = { properties = [ "controllable", "invertible", "differentiable" ] } +DoubleExcitationMinus = { properties = [ "controllable", "invertible", "differentiable" ] } + +# Operations supported by the execution in Python but not directly supported by the backend +[pennylane.operators.gates] -[operators.gates.native] - -QubitUnitary = { properties = [ "invertible", "controllable", "differentiable" ] } -PauliX = { properties = [ "controllable", "invertible", "differentiable" ] } -PauliY = { properties = [ "controllable", "invertible", "differentiable" ] } -PauliZ = { properties = [ 
"controllable", "invertible", "differentiable" ] } -MultiRZ = { properties = [ "controllable", "invertible", "differentiable" ] } -Hadamard = { properties = [ "controllable", "invertible", "differentiable" ] } -S = { properties = [ "controllable", "invertible", "differentiable" ] } -T = { properties = [ "controllable", "invertible", "differentiable" ] } -CNOT = { properties = [ "invertible", "differentiable" ] } -SWAP = { properties = [ "controllable", "invertible", "differentiable" ] } -CSWAP = { properties = [ "invertible", "differentiable" ] } -Toffoli = { properties = [ "invertible", "differentiable" ] } -CY = { properties = [ "invertible", "differentiable" ] } -CZ = { properties = [ "invertible", "differentiable" ] } -PhaseShift = { properties = [ "controllable", "invertible", "differentiable" ] } -ControlledPhaseShift = { properties = [ "invertible", "differentiable" ] } -RX = { properties = [ "controllable", "invertible", "differentiable" ] } -RY = { properties = [ "controllable", "invertible", "differentiable" ] } -RZ = { properties = [ "controllable", "invertible", "differentiable" ] } -Rot = { properties = [ "controllable", "invertible", "differentiable" ] } -CRX = { properties = [ "invertible", "differentiable" ] } -CRY = { properties = [ "invertible", "differentiable" ] } -CRZ = { properties = [ "invertible", "differentiable" ] } -CRot = { properties = [ "invertible" ] } -Identity = { properties = [ "invertible", "differentiable" ] } -IsingXX = { properties = [ "controllable", "invertible", "differentiable" ] } -IsingYY = { properties = [ "controllable", "invertible", "differentiable" ] } -IsingZZ = { properties = [ "controllable", "invertible", "differentiable" ] } -IsingXY = { properties = [ "controllable", "invertible", "differentiable" ] } -GlobalPhase = { properties = [ "controllable", "invertible", "differentiable" ] } -BlockEncode = { properties = [ "invertible", "differentiable" ] } -SingleExcitation = { properties = [ "invertible", "controllable", "differentiable"] } -SingleExcitationPlus = { properties = [ "invertible", "controllable", "differentiable"] } -SingleExcitationMinus = { properties = [ "invertible", "controllable", "differentiable"] } -DoubleExcitation = { properties = [ "invertible", "controllable", "differentiable"] } -DoubleExcitationPlus = { properties = [ "invertible", "controllable", "differentiable"] } -DoubleExcitationMinus = { properties = [ "invertible", "controllable", "differentiable"] } - -[operators.gates.decomp] - -# Operators that should be decomposed according to the algorithm used -# by PennyLane's device API. -# Optional, since gates not listed in this list will typically be decomposed by -# default, but can be useful to express a deviation from this device's regular -# strategy in PennyLane. 
SX = {} ISWAP = {} PSWAP = {} @@ -61,38 +76,32 @@ QubitSum = {} OrbitalRotation = {} QFT = {} ECR = {} - -# Gates which should be translated to QubitUnitary -[operators.gates.matrix] - DiagonalQubitUnitary = {} - # Observables supported by the device [operators.observables] -PauliX = { properties = [ "differentiable" ] } -PauliY = { properties = [ "differentiable" ] } -PauliZ = { properties = [ "differentiable" ] } -Hadamard = { properties = [ "differentiable" ] } -Hermitian = { properties = [ "differentiable" ] } -Identity = { properties = [ "differentiable" ] } -Projector = {} -SparseHamiltonian = { properties = [ "differentiable" ] } -Hamiltonian = { properties = [ "differentiable" ] } -Sum = { properties = [ "differentiable" ] } -SProd = { properties = [ "differentiable" ] } -Prod = { properties = [ "differentiable" ] } -Exp = { properties = [ "differentiable" ] } +PauliX = { properties = ["differentiable"] } +PauliY = { properties = ["differentiable"] } +PauliZ = { properties = ["differentiable"] } +Hadamard = { properties = ["differentiable"] } +Hermitian = { properties = ["differentiable"] } +Identity = { properties = ["differentiable"] } +SparseHamiltonian = { properties = ["differentiable"] } +Sum = { properties = ["differentiable"] } +SProd = { properties = ["differentiable"] } +Prod = { properties = ["differentiable"] } +Exp = { properties = ["differentiable"] } +Projector = {} [measurement_processes] -Expval = {} -Var = {} -Probs = {} -State = { condition = [ "analytic" ] } -Sample = { condition = [ "finiteshots" ] } -Counts = { condition = [ "finiteshots" ] } +ExpectationMP = { } +VarianceMP = { } +ProbabilityMP = { } +StateMP = { conditions = ["analytic"] } +SampleMP = { conditions = ["finiteshots"] } +CountsMP = { conditions = ["finiteshots"] } [compilation] @@ -100,17 +109,11 @@ Counts = { condition = [ "finiteshots" ] } qjit_compatible = true # If the device requires run time generation of the quantum circuit. runtime_code_generation = false -# If the device supports mid circuit measurements natively -mid_circuit_measurement = true -# This field is currently unchecked but it is reserved for the purpose of +# If the device supports mid-circuit measurements natively +supported_mcm_methods = [ "one-shot" ] +# This field is currently unchecked, but it is reserved for the purpose of # determining if the device supports dynamic qubit allocation/deallocation. dynamic_qubit_management = false - # whether the device can support non-commuting measurements together # in a single execution non_commuting_observables = true - -[options] - -option1 = "_option1" -option2 = "_option2" diff --git a/frontend/test/lit/test_autograph.py b/frontend/test/lit/test_autograph.py index a0975b8da2..a69c8f9f73 100644 --- a/frontend/test/lit/test_autograph.py +++ b/frontend/test/lit/test_autograph.py @@ -632,7 +632,7 @@ def h(): @qjit(autograph=True) def disable_autograph_context_manager_jax(): """Checks that Autograph is disabled for a given context.""" - # CHECK: { lambda ; . let transform_named_sequence in (36.4,) } + # CHECK: { lambda ; . let in (36.4,) } x = 0.4 with disable_autograph: x += h() diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index a84d375a04..522a453013 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -1,4 +1,16 @@ # Copyright 2022-2023 Xanadu Quantum Technologies Inc. 
+import os +import pathlib +import platform +from copy import deepcopy + +import jax +import pennylane as qml +from pennylane.devices.capabilities import OperatorProperties + +from catalyst import measure, qjit +from catalyst.compiler import get_lib_path +from catalyst.device import get_device_capabilities # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,16 +27,9 @@ # RUN: %PYTHON %s | FileCheck %s # pylint: disable=line-too-long -import platform -from copy import deepcopy -import jax -import pennylane as qml - -from catalyst import measure, qjit -from catalyst.compiler import get_lib_path -from catalyst.device import get_device_capabilities -from catalyst.utils.toml import OperationProperties +TEST_PATH = os.path.dirname(__file__) +CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml") def get_custom_device_without(num_wires, discards=frozenset(), force_matrix=frozenset()): @@ -34,30 +39,18 @@ class CustomDevice(qml.devices.Device): """Custom Gate Set Device""" name = "Custom Device" - pennylane_requires = "0.35.0" - version = "0.0.2" - author = "Tester" - - lightning_device = qml.device("lightning.qubit", wires=0) + config_filepath = CONFIG_CUSTOM_DEVICE - config = None - backend_name = "default" - backend_lib = "default" - backend_kwargs = {} + _to_matrix_ops = {} def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - lightning_capabilities = get_device_capabilities(self.lightning_device) - custom_capabilities = deepcopy(lightning_capabilities) + self.qjit_capabilities = deepcopy(get_device_capabilities(self)) for gate in discards: - custom_capabilities.native_ops.pop(gate, None) - custom_capabilities.to_decomp_ops.pop(gate, None) - custom_capabilities.to_matrix_ops.pop(gate, None) + self.qjit_capabilities.operations.pop(gate, None) for gate in force_matrix: - custom_capabilities.native_ops.pop(gate, None) - custom_capabilities.to_decomp_ops.pop(gate, None) - custom_capabilities.to_matrix_ops[gate] = OperationProperties(False, False, False) - self.qjit_capabilities = custom_capabilities + self.qjit_capabilities.operations.pop(gate, None) + self._to_matrix_ops[gate] = OperatorProperties(False, False, False) def apply(self, operations, **kwargs): """Unused""" @@ -140,10 +133,8 @@ def test_decompose_s(): @qml.qnode(dev) # CHECK-LABEL: public @jit_decompose_s def decompose_s(): - # CHECK-NOT: name="S" - # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 # CHECK-NOT: name = "S" - # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) + # CHECK: {{%.+}} = quantum.static_custom "PhaseShift" [1.570796e+00] # CHECK-NOT: name = "S" qml.S(wires=0) return measure(wires=0) diff --git a/frontend/test/lit/test_device_api.py b/frontend/test/lit/test_device_api.py index e929bf82ab..d5ef002adb 100644 --- a/frontend/test/lit/test_device_api.py +++ b/frontend/test/lit/test_device_api.py @@ -14,6 +14,8 @@ # RUN: %PYTHON %s | FileCheck %s +# pylint: disable=line-too-long + """Test for the device API. 
""" import os @@ -36,7 +38,7 @@ class CustomDevice(Device): """A custom device that does nothing.""" - config = CONFIG_CUSTOM_DEVICE + config_filepath = CONFIG_CUSTOM_DEVICE def __init__(self, wires, shots=1024): super().__init__(wires=wires, shots=shots) @@ -70,7 +72,8 @@ def preprocess(self, execution_config: Optional[ExecutionConfig] = None): def test_circuit(): """Test a circuit compilation to MLIR when using the new device API.""" - # CHECK: quantum.device["[[PATH:.*]]librtd_null_qubit.{{so|dylib}}", "Custom", "{'shots': 2048}"] + # CHECK: [[shots:%.+]] = arith.constant 2048 : i64 + # CHECK: quantum.device shots([[shots]]) ["[[PATH:.*]]librtd_null_qubit.{{so|dylib}}", "Custom", "{'shots': 2048}"] dev = CustomDevice(wires=2, shots=2048) @qjit(target="mlir") @@ -95,7 +98,8 @@ def test_preprocess(): using the new device API. TODO: we need to readd the two check-not once we accept the device preprocessing.""" - # CHECK: quantum.device["[[PATH:.*]]librtd_null_qubit.{{so|dylib}}", "Custom", "{'shots': 2048}"] + # CHECK: [[shots:%.+]] = arith.constant 2048 : i64 + # CHECK: quantum.device shots([[shots]]) ["[[PATH:.*]]librtd_null_qubit.{{so|dylib}}", "Custom", "{'shots': 2048}"] dev = CustomDevice(wires=2, shots=2048) @qjit(target="mlir") diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py index 0aec735804..bd586c7c62 100644 --- a/frontend/test/lit/test_if_else.py +++ b/frontend/test/lit/test_if_else.py @@ -94,7 +94,7 @@ def circuit_single_gate(n: int): # CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]] # CHECK: [[qreg_out1:%.+]] = scf.if [[b6]] # CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]] - # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]] + # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.static_custom "RX" [3.140000e+00] [[q4]] # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]] # CHECK: scf.yield [[qreg_3]] # CHECK: else diff --git a/frontend/test/lit/test_instrumentation_console.py b/frontend/test/lit/test_instrumentation_console.py index 3fb5242f92..aec33e984d 100644 --- a/frontend/test/lit/test_instrumentation_console.py +++ b/frontend/test/lit/test_instrumentation_console.py @@ -64,10 +64,6 @@ def circuit(weights): # COM: Do not check detailed output from the first call to the compiler, e.g. 0_Canonicalize, # COM: as we may want to remove that in the future. # CHECK: [DIAGNOSTICS] > Total generate_ir -# CHECK-NEXT: [DIAGNOSTICS] Running parseMLIRSource -# CHECK-SAME: walltime: {{[0-9\.]+}} ms{{\s*}} cputime: {{[0-9\.]+}} ms{{\s*}} programsize: {{[0-9]+}} lines -# CHECK-NEXT: [DIAGNOSTICS] Running {{[a-zA-Z]+}}Pass -# CHECK-SAME: walltime: {{[0-9\.]+}} ms{{\s*}} cputime: {{[0-9\.]+}} ms{{\s*}} programsize: {{[0-9]+}} lines # COM: Check for "compile" exactly, otherwise we could match things like "compileObjFile". # CHECK: [DIAGNOSTICS] > Total compile{{ }} # CHECK-NEXT: [DIAGNOSTICS] Running device_init @@ -79,6 +75,11 @@ def circuit(weights): # CHECK-NEXT: [DIAGNOSTICS] Running device_release # CHECK-SAME: walltime: {{[0-9\.]+}} ms{{\s*}} cputime: {{[0-9\.]+}} ms # CHECK: [DIAGNOSTICS] > Total run +# COM: As the output below is generated by catalyst-cli, checking the correct order may cause flaky results. 
+# CHECK: [DIAGNOSTICS] Running parseMLIRSource +# CHECK-SAME: walltime: {{[0-9\.]+}} ms{{\s*}} cputime: {{[0-9\.]+}} ms{{\s*}} programsize: {{[0-9]+}} lines +# CHECK: [DIAGNOSTICS] Running {{[a-zA-Z]+}}Pass +# CHECK-SAME: walltime: {{[0-9\.]+}} ms{{\s*}} cputime: {{[0-9\.]+}} ms{{\s*}} programsize: {{[0-9]+}} lines with instrumentation(circuit.__name__, filename=None, detailed=True): qjit(circuit)(weights) diff --git a/frontend/test/lit/test_measurements.py b/frontend/test/lit/test_measurements.py index 916d721688..21e4e300a0 100644 --- a/frontend/test/lit/test_measurements.py +++ b/frontend/test/lit/test_measurements.py @@ -14,10 +14,13 @@ # RUN: %PYTHON %s | FileCheck %s +import jax import numpy as np import pennylane as qml from catalyst import CompileError, qjit +from catalyst.jax_extras.tracing import bind_flexible_primitive +from catalyst.jax_primitives import compbasis_p, counts_p, sample_p # TODO: NOTE: # The tests sample1 and sample2 below used to pass, before verification steps were added in the @@ -35,7 +38,7 @@ def sample1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ] @@ -51,7 +54,7 @@ def sample2(x: float, y: float): qml.RX(x, wires=0) # COM: CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs1:%.+]] = quantum.namedobs [[q1]][ PauliX] @@ -68,20 +71,64 @@ def sample2(x: float, y: float): # CHECK-LABEL: public @sample3( @qjit(target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000)) +# CHECK: [[shots:%.+]] = arith.constant 1000 : i64 +# CHECK: quantum.device shots([[shots]]) [{{.+}}] def sample3(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] - # CHECK: quantum.sample [[obs]] {shots = 1000 : i64} : tensor<1000x2xf64> + # CHECK: quantum.sample [[obs]] : tensor<1000x2xf64> return qml.sample() print(sample3.mlir) + +# CHECK-LABEL: public @test_sample_static( +@qjit +@qml.qnode( + qml.device("null.qubit", wires=1) +) # SampleOp is only legal if there is a device in the same scope +def test_sample_static(): + """Test that the sample primitive can be correctly compiled to mlir.""" + obs = compbasis_p.bind() + return bind_flexible_primitive(sample_p, {"shots": 5}, obs, num_qubits=0) + + +# CHECK: [[obs:%.+]] = quantum.compbasis : !quantum.obs +# CHECK: [[sample:%.+]] = quantum.sample [[obs]] : tensor<5x0xf64> +# CHECK: return [[sample]] : tensor<5x0xf64> +print(test_sample_static.mlir) + + +# TODO: convert the device to have a dynamic shots value when core PennyLane device supports it +# CHECK-LABEL: public @test_sample_dynamic( +@qjit +@qml.qnode( + qml.device("null.qubit", wires=1) +) # SampleOp is only legal if there is a device in the same scope +def test_sample_dynamic(shots: int): + """Test that the sample primitive with dynamic shape can be correctly compiled to mlir.""" + obs = compbasis_p.bind() + x = shots + 1 + sample = bind_flexible_primitive(sample_p, {"shots": x}, obs, num_qubits=0) + return sample + jax.numpy.zeros((x, 0)) + + +# CHECK: [[one:%.+]] = stablehlo.constant dense<1> : 
tensor +# CHECK: [[obs:%.+]] = quantum.compbasis : !quantum.obs +# CHECK: [[plusOne:%.+]] = stablehlo.add %arg0, [[one]] : tensor +# CHECK: [[sample:%.+]] = quantum.sample [[obs]] : tensor +# CHECK: [[zeroVec:%.+]] = stablehlo.dynamic_broadcast_in_dim {{.+}} -> tensor +# CHECK: [[outVecSum:%.+]] = stablehlo.add [[sample]], [[zeroVec]] : tensor +# CHECK: return [[plusOne]], [[outVecSum]] : tensor, tensor +print(test_sample_dynamic.mlir) + + # TODO: NOTE: # The tests below used to pass before the compiler driver (in the case of counts2) and device # preprocessing verification (in the case of counts1). Now that the validation is run, the circuits @@ -98,7 +145,7 @@ def sample3(x: float, y: float): def counts1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ] @@ -113,7 +160,7 @@ def counts2(x: float, y: float): qml.RX(x, wires=0) # COM: CHECK: [[q1:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RY" qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RZ" + # COM: CHECK: [[q0:%.+]] = "quantum.static_custom"({{%.+}}, {{%.+}}) {gate_name = "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs1:%.+]] = "quantum.namedobs"([[q1]]) {type = #quantum} @@ -130,28 +177,58 @@ def counts2(x: float, y: float): # CHECK-LABEL: public @counts3( @qjit(target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000)) +# CHECK: [[shots:%.+]] = arith.constant 1000 : i64 +# CHECK: quantum.device shots([[shots]]) [{{.+}}] def counts3(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] - # CHECK: quantum.counts [[obs]] {shots = 1000 : i64} : tensor<4xf64>, tensor<4xi64> + # CHECK: quantum.counts [[obs]] : tensor<4xf64>, tensor<4xi64> return qml.counts() print(counts3.mlir) +# CHECK-LABEL: public @jit_test_counts_static( +@qjit +def test_counts_static(): + """Test that the counts primitive can be correctly compiled to mlir.""" + obs = compbasis_p.bind() + return bind_flexible_primitive(counts_p, {"shots": 5}, obs, shape=(1,)) + + +# CHECK: [[obs:%.+]] = quantum.compbasis : !quantum.obs +# CHECK: [[eigvals:%.+]], [[counts:%.+]] = quantum.counts [[obs]] : tensor<1xf64>, tensor<1xi64> +# CHECK: return [[eigvals]], [[counts]] : tensor<1xf64>, tensor<1xi64> +print(test_counts_static.mlir) + + +# CHECK-LABEL: public @jit_test_counts_dynamic( +@qjit +def test_counts_dynamic(shots: int): + """Test that the counts primitive with dynamic shape can be correctly compiled to mlir.""" + obs = compbasis_p.bind() + return bind_flexible_primitive(counts_p, {"shots": shots}, obs, shape=(1,)) + + +# CHECK: [[obs:%.+]] = quantum.compbasis : !quantum.obs +# CHECK: [[eigvals:%.+]], [[counts:%.+]] = quantum.counts [[obs]] : tensor<1xf64>, tensor<1xi64> +# CHECK: return [[eigvals]], [[counts]] : tensor<1xf64>, tensor<1xi64> +print(test_counts_dynamic.mlir) + + # CHECK-LABEL: public @expval1( @qjit(target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=2)) def expval1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = 
quantum.namedobs [[q0]][ PauliX] @@ -170,7 +247,7 @@ def expval2(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) # CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -227,7 +304,7 @@ def expval5(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) B = np.array( @@ -257,7 +334,7 @@ def expval5(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) coeffs = np.array([0.2, -0.543]) @@ -349,7 +426,7 @@ def expval9(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) # CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -371,7 +448,7 @@ def expval10(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) B = np.array( @@ -399,7 +476,7 @@ def expval10(x: float, y: float): def var1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -442,7 +519,7 @@ def probs1(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # qml.probs() # unsupported by PennyLane @@ -463,7 +540,7 @@ def state1(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # qml.state(wires=[0]) # unsupported by PennyLane diff --git a/frontend/test/lit/test_mlir_decomposition.py b/frontend/test/lit/test_mlir_decomposition.py index e0cd59a96c..2493edc6b6 100644 --- a/frontend/test/lit/test_mlir_decomposition.py +++ b/frontend/test/lit/test_mlir_decomposition.py @@ -33,6 +33,7 @@ from lit_util_printers import print_jaxpr from pennylane.devices import NullQubit +import catalyst from catalyst import qjit from catalyst.debug import get_compilation_stage from catalyst.utils.runtime_environment import get_lib_path @@ -58,7 +59,7 @@ class CustomDevice(NullQubit): name = "oqd.cloud" - config = CONFIG_CUSTOM_DEVICE + config_filepath = CONFIG_CUSTOM_DEVICE @staticmethod def get_c_interface(): @@ -78,6 +79,7 @@ def test_decomposition_lowering(): """ @qjit(keep_intermediate=True) + @catalyst.passes.ions_decomposition @qml.qnode(CustomDevice(2)) def test_decomposition_lowering_workflow(x): qml.RX(x, wires=[0]) @@ -85,11 +87,6 @@ def test_decomposition_lowering_workflow(x): qml.Hadamard(wires=[1]) return qml.expval(qml.PauliY(wires=0)) - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=test_decomposition_lowering_workflow - # CHECK: pass_name=ions-decomposition - 
# CHECK: ]
    print_jaxpr(test_decomposition_lowering_workflow, 1.2)
    # CHECK: quantum.custom "RX"
    # CHECK-NOT: quantum.custom "Hadamard"
diff --git a/frontend/test/lit/test_mlir_plugin.py b/frontend/test/lit/test_mlir_plugin.py
new file mode 100644
index 0000000000..39b0c58e04
--- /dev/null
+++ b/frontend/test/lit/test_mlir_plugin.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+"""
+This test makes sure that we can use plugins from the compiler.
+
+Given the standalone-plugin in the MLIR repository, can we verify that
+it works when loading it from Python?
+
+This test uses a lot of machinery that is not exposed to the user.
+However, testing the standalone-plugin (as written in the LLVM repository)
+is impractical. The standalone-plugin rewrites symbols with the name
+`bar` to symbols with the name `foo`. However, since the standalone
+plugin is meant to be more of an example, it does not modify
+the uses of the symbol `bar` to change them to `foo`.
+
+What this practically means for this test is that if we were to write
+something like the following:
+
+    ```python
+    import pennylane as qml
+    from pathlib import Path
+
+    from catalyst.passes import apply_pass
+
+    plugin = Path("./mlir/standalone/build/lib/StandalonePlugin.so")
+
+    @apply_pass("standalone-switch-bar-foo")
+    @qml.qnode(qml.device("lightning.qubit", wires=1))
+    def bar():
+        return qml.state()
+
+    @qml.qjit(keep_intermediate=True, verbose=True, pass_plugins=[plugin], dialect_plugins=[plugin])
+    def module():
+        return bar()
+
+    print(module.mlir)
+    ```
+
+This would succeed in generating correct MLIR during the lowering from JAXPR to MLIR.
+However, after the `standalone-switch-bar-foo` pass, the verifier would fail
+because it would see call sites to `@bar` but no definition of `@bar`.
+
+As such, this test is perhaps a bit more limited. It tests neither the
+apply_pass interface nor the pass_plugins option directly. Instead, it tests
+that the `standalone-switch-bar-foo` pass can be executed using lower-level
+APIs, like the Compiler class and CompileOptions.
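+
+As a rough sketch (mirroring the code at the bottom of this file rather than a
+user-facing API), the lower-level invocation looks like this:
+
+    ```python
+    options = CompileOptions(
+        pipelines=[("run_only_plugin", ["builtin.module(apply-transform-sequence)"])],
+        lower_to_llvm=False,
+        pass_plugins=[plugin],
+        dialect_plugins=[plugin],
+    )
+    _, mlir_string = Compiler(options).run_from_ir(mlir_module, "test", workspace)
+    ```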
+""" + +import platform +from pathlib import Path + +from catalyst.compiler import CompileOptions, Compiler +from catalyst.utils.filesystem import WorkspaceManager +from catalyst.utils.runtime_environment import get_bin_path + +mlir_module = """ +module @module { + module @module_qnode { + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + %0 = transform.apply_registered_pass "standalone-switch-bar-foo" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + transform.yield + } + } + // CHECK: func.func private @foo() + func.func private @bar() -> (tensor) { + %c = stablehlo.constant dense<0> : tensor + return %c : tensor + } + } +} +""" + +ext = "so" if platform.system() == "Linux" else "dylib" +plugin_path = get_bin_path("cli", "CATALYST_BIN_DIR") + f"/../lib/StandalonePlugin.{ext}" +plugin = Path(plugin_path) +custom_pipeline = [("run_only_plugin", ["builtin.module(apply-transform-sequence)"])] +options = CompileOptions( + pipelines=custom_pipeline, lower_to_llvm=False, pass_plugins=[plugin], dialect_plugins=[plugin] +) +workspace = WorkspaceManager.get_or_create_workspace("test", None) +custom_compiler = Compiler(options) +_, mlir_string = custom_compiler.run_from_ir(mlir_module, "test", workspace) +print(mlir_string) diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py index 4c7c3f665d..69d607bb18 100644 --- a/frontend/test/lit/test_peephole_optimizations.py +++ b/frontend/test/lit/test_peephole_optimizations.py @@ -46,35 +46,6 @@ def flush_peephole_opted_mlir_to_iostream(QJIT): shutil.rmtree(QJIT.__name__) -# -# General lowering tests -# - - -def test_transform_named_sequence_injection(): - """ - Test the transform.with_named_sequence jax primitive and mlir operation are - always generated for qjit. 
- """ - - @qjit - def func(): - return - - # CHECK: transform_named_sequence - print_jaxpr(func) - - # CHECK: module @func { - # CHECK: module attributes { - # CHECK-SAME: transform.with_named_sequence - # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: transform.yield - print_mlir(func) - - -test_transform_named_sequence_injection() - - # # pipeline # @@ -98,25 +69,17 @@ def test_pipeline_lowering_workflow(x): qml.Hadamard(wires=[1]) return qml.expval(qml.PauliY(wires=0)) - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 - # CHECK: pass_name=merge-rotations - # CHECK: ] + # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) print_jaxpr(test_pipeline_lowering_workflow, 1.2) # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(test_pipeline_lowering_workflow, 1.2) - # CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_transformed0_0( - # CHECK: func.func public @test_pipeline_lowering_workflow_transformed0_0( + # CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_transformed_0( + # CHECK: func.func public @test_pipeline_lowering_workflow_transformed_0( # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -144,40 +107,37 @@ def f(x): qml.Hadamard(wires=[1]) return qml.expval(qml.PauliY(wires=0)) - f_pipeline = pipeline(pass_pipeline=my_pipeline)(f) + f_pipeline = pipeline(my_pipeline)(f) @qjit(keep_intermediate=True) def test_pipeline_lowering_keep_original_workflow(x): return f(x), f_pipeline(x) - # CHECK: transform_named_sequence - # CHECK: call_jaxpr= - # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: call_jaxpr= - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_transformed0 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_transformed0 - # CHECK: pass_name=merge-rotations - # CHECK: ] + # COM: this if for f(x) + # CHECK: quantum_kernel + # COM: this if for f_pipeline(x) + # COM: Unfortunately, we don't have a nice repr for qnode + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) print_jaxpr(test_pipeline_lowering_keep_original_workflow, 1.2) + # COM: This is the one that is unchanged + # CHECK: transform.named_sequence @__transform_main + # COM: This is the one that is changed # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} 
{options = "func-name=f_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=f_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(test_pipeline_lowering_keep_original_workflow, 1.2) # CHECK: func.func public @jit_test_pipeline_lowering_keep_original_workflow # CHECK: {{%.+}} = call @f_0( - # CHECK: {{%.+}} = call @f_transformed0_0( + # CHECK: {{%.+}} = call @f_transformed_0( # CHECK: func.func public @f_0( # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit - # CHECK: func.func public @f_transformed0_0( + # CHECK: func.func public @f_transformed_0( # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -216,41 +176,29 @@ def h(x): return g(1.2), h(1.2) - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_transformed0 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_transformed0 - # CHECK: pass_name=merge-rotations - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=h_transformed1 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=h_transformed1 - # CHECK: pass_name=merge-rotations - # CHECK: ] + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) print_jaxpr(global_wf) # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed1"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_transformed1"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(global_wf) # CHECK: func.func public @jit_global_wf() - # CHECK {{%.+}} = call @g_transformed0( - # CHECK {{%.+}} = call @h_transformed1( - # CHECK: func.func public @g_transformed0 + # CHECK {{%.+}} = call @g_0( + # CHECK {{%.+}} = call @h_0( + # CHECK: func.func public @g_0( # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} 
= quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit - # CHECK: func.func public @h_transformed1 + # CHECK: func.func public @h_0( # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -284,7 +232,7 @@ def g(x): qml.Hadamard(wires=[1]) return qml.expval(qml.PauliY(wires=0)) - @pipeline(pass_pipeline=local_pipeline) + @pipeline(local_pipeline) @qml.qnode(qml.device("lightning.qubit", wires=2)) def h(x): qml.RX(x, wires=[0]) @@ -294,38 +242,29 @@ def h(x): return g(1.2), h(1.2) - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_transformed1 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_transformed1 - # CHECK: pass_name=merge-rotations - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=h_transformed0 - # CHECK-NOT: pass_name=remove-chained-self-inverse - # CHECK: pass_name=merge-rotations - # CHECK: ] + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse, merge-rotations) + # CHECK: quantum_kernel + # CHECK: pipeline=(merge-rotations,) print_jaxpr(global_wf) # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed1"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_transformed1"} - # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_transformed0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} + # CHECK: transform.named_sequence @__transform_main + # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(global_wf) # CHECK: func.func public @jit_global_wf() - # CHECK {{%.+}} = call @g_transformed1( - # CHECK {{%.+}} = call @h_transformed0( - # CHECK: func.func public @g_transformed1 + # CHECK {{%.+}} = call @g( + # CHECK {{%.+}} = call @h_transformed( + # CHECK: func.func public @g_0 # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit - # CHECK: func.func public @h_transformed0 + # CHECK: func.func public @h_transformed_0 # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -377,25 +316,19 @@ def h(x: float): _h = h(xx) return _f, _g, _h - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_cancel_inverses0 - # CHECK: 
pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_cancel_inverses1 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] - # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK-NOT: options=func-name=h_cancel_inverses - # CHECK-NOT: pass_name=remove-chained-self-inverse + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse,) + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse,) print_jaxpr(test_cancel_inverses_tracing_and_lowering_workflow, 1.1) # CHECK: module @test_cancel_inverses_tracing_and_lowering_workflow # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_cancel_inverses1"} - # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_cancel_inverses"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NEXT: transform.yield + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} + # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(test_cancel_inverses_tracing_and_lowering_workflow, 1.1) @@ -421,16 +354,13 @@ def test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow(xx: float): _f = f(xx) return _f - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_cancel_inverses0 - # CHECK: pass_name=remove-chained-self-inverse - # CHECK: ] + # CHECK: quantum_kernel + # CHECK: pipeline=(remove-chained-self-inverse,) print_jaxpr(test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow, 1.1) # CHECK: module @test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses0"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(test_cancel_inverses_tracing_and_lowering_outside_qjit_workflow, 1.1) @@ -515,7 +445,7 @@ def f(x: float): # CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow0 # CHECK: {{%.+}} = call @f_0({{%.+}}) - # CHECK-NOT: {{%.+}} = call @f_cancel_inverses0_0({{%.+}}) + # CHECK-NOT: {{%.+}} = call @f_cancel_inverses_0({{%.+}}) # CHECK-LABEL: public @f_0({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -529,9 +459,9 @@ def test_cancel_inverses_keep_original_workflow0(): flush_peephole_opted_mlir_to_iostream(test_cancel_inverses_keep_original_workflow0) # CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow1 - # CHECK: {{%.+}} = call @f_cancel_inverses0_0({{%.+}}) + # CHECK: {{%.+}} = call @f_cancel_inverses_0({{%.+}}) # CHECK-NOT: {{%.+}} = call @f_0({{%.+}}) - # CHECK-LABEL: public @f_cancel_inverses0_0({{%.+}}) + # CHECK-LABEL: public @f_cancel_inverses_0({{%.+}}) 
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -545,12 +475,12 @@ def test_cancel_inverses_keep_original_workflow1(): # CHECK-LABEL: public @jit_test_cancel_inverses_keep_original_workflow2 # CHECK: {{%.+}} = call @f_0({{%.+}}) - # CHECK: {{%.+}} = call @f_cancel_inverses0_0({{%.+}}) + # CHECK: {{%.+}} = call @f_cancel_inverses_0({{%.+}}) # CHECK-LABEL: public @f_0({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NEXT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit - # CHECK-LABEL: public @f_cancel_inverses0_0({{%.+}}) + # CHECK-LABEL: public @f_cancel_inverses_0({{%.+}}) # CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit # CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit @@ -606,25 +536,20 @@ def h(x: float): _h = h(xx) return _f, _g, _h - # CHECK: transform_named_sequence - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=f_merge_rotations0 - # CHECK: pass_name=merge-rotations - # CHECK: ] - # CHECK: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK: options=func-name=g_merge_rotations1 - # CHECK: pass_name=merge-rotations - # CHECK: ] - # CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[ - # CHECK-NOT: options=func-name=h_merge_rotations - # CHECK-NOT: pass_name=merge-rotations + # CHECK: quantum_kernel + # CHECK: pipeline=(merge-rotations,) + # CHECK: quantum_kernel + # CHECK: pipeline=(merge-rotations,) + # CHECK: quantum_kernel print_jaxpr(test_merge_rotations_tracing_and_lowering_workflow, 1.1) # CHECK: module @test_merge_rotations_tracing_and_lowering_workflow # CHECK: transform.named_sequence @__transform_main - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=f_merge_rotations0"} - # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_merge_rotations1"} - # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_merge_rotations"} + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} + # CHECK-NEXT: transform.yield + # CHECK: transform.named_sequence @__transform_main + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} + # CHECK-NOT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} # CHECK-NEXT: transform.yield print_mlir(test_merge_rotations_tracing_and_lowering_workflow, 1.1) diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index 5852757fec..730692c855 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -15,16 +15,21 @@ # RUN: %PYTHON %s | FileCheck %s """ Test the lowering cases involving quantum control """ +import os +import pathlib import platform from copy import deepcopy import jax.numpy as jnp import pennylane as qml +from pennylane.devices.capabilities import OperatorProperties from catalyst import qjit from catalyst.compiler import get_lib_path from catalyst.device import get_device_capabilities -from catalyst.utils.toml import OperationProperties + 
+TEST_PATH = os.path.dirname(__file__) +CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml") def get_custom_qjit_device(num_wires, discards, additions): @@ -34,24 +39,14 @@ class CustomDevice(qml.devices.Device): """Custom Gate Set Device""" name = "lightning.qubit" - pennylane_requires = "0.35.0" - version = "0.0.2" - author = "Tester" - - lightning_device = qml.device("lightning.qubit", wires=0) - - backend_name = "default" - backend_lib = "default" - backend_kwargs = {} + config_filepath = CONFIG_CUSTOM_DEVICE def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) - lightning_capabilities = get_device_capabilities(self.lightning_device) - custom_capabilities = deepcopy(lightning_capabilities) + self.qjit_capabilities = get_device_capabilities(self) for gate in discards: - custom_capabilities.native_ops.pop(gate) - custom_capabilities.native_ops.update(additions) - self.qjit_capabilities = custom_capabilities + self.qjit_capabilities.operations.pop(gate, None) + self.qjit_capabilities.operations.update(additions) @staticmethod def get_c_interface(): @@ -97,13 +92,13 @@ def named_controlled(): def test_native_controlled_custom(): """Test native control of a custom operation.""" - dev = get_custom_qjit_device(3, set(), {"Rot": OperationProperties(True, True, False)}) + dev = get_custom_qjit_device(3, set(), {"Rot": OperatorProperties(True, True, False)}) @qjit(target="mlir") @qml.qnode(dev) # CHECK-LABEL: public @jit_native_controlled def native_controlled(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.static_custom "Rot" # CHECK-SAME: ctrls # CHECK-SAME: ctrlvals(%true, %true) qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) @@ -147,7 +142,7 @@ def native_controlled_unitary(): def test_native_controlled_multirz(): """Test native control of the multirz operation.""" - dev = get_custom_qjit_device(3, set(), {"MultiRZ": OperationProperties(True, True, True)}) + dev = get_custom_qjit_device(3, set(), {"MultiRZ": OperatorProperties(True, True, True)}) @qjit(target="mlir") @qml.qnode(dev) diff --git a/frontend/test/lit/test_split_multiple_tapes.py b/frontend/test/lit/test_split_multiple_tapes.py index 12468bacfb..1e45e4f656 100644 --- a/frontend/test/lit/test_split_multiple_tapes.py +++ b/frontend/test/lit/test_split_multiple_tapes.py @@ -64,7 +64,7 @@ def circuit_twotapes(x): # CHECK: circuit_twotapes # CHECK: call_jaxpr={ lambda ; - # CHECK-NEXT: qdevice[ + # CHECK: qdevice[ # CHECK: ] # CHECK: qdealloc # CHECK-NEXT: qdevice[ @@ -74,10 +74,10 @@ def circuit_twotapes(x): print_jaxpr(circuit_twotapes, [0.1, 0.2]) # CHECK: circuit_twotapes - # CHECK: quantum.device[ + # CHECK: quantum.device # CHECK: quantum.dealloc # CHECK-NEXT: quantum.device_release - # CHECK-NEXT: quantum.device[ + # CHECK-NEXT: quantum.device # CHECK: quantum.dealloc # CHECK-NEXT: quantum.device_release # CHECK-NEXT: {{%.+}} = stablehlo.add {{%.+}}, {{%.+}} : tensor diff --git a/frontend/test/lit/test_static_circuit.py b/frontend/test/lit/test_static_circuit.py new file mode 100644 index 0000000000..40de0c68d3 --- /dev/null +++ b/frontend/test/lit/test_static_circuit.py @@ -0,0 +1,68 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+"""
+Test quantum circuits with static (known at compile time) specifications.
+"""
+
+import pennylane as qml
+
+from catalyst import qjit
+
+
+def test_static_params():
+    """Test operations with static params."""
+
+    @qjit(target="mlir")
+    @qml.qnode(qml.device("lightning.qubit", wires=4))
+    def circuit():
+        x = 3.14
+        y = 0.6
+        qml.Rot(x, y, x + y, wires=0)
+
+        qml.RX(x, wires=0)
+        qml.RY(y, wires=1)
+        qml.RZ(x, wires=2)
+
+        qml.IsingXX(x, wires=[0, 1])
+        qml.IsingXX(y, wires=[1, 2])
+        qml.IsingZZ(x, wires=[0, 1])
+
+        qml.CRX(x, wires=[0, 1])
+        qml.CRY(x, wires=[0, 1])
+        qml.CRZ(x, wires=[0, 1])
+
+        return qml.state()
+
+    print(circuit.mlir)
+
+
+# CHECK-LABEL: public @jit_circuit
+# CHECK: %[[REG:.*]] = quantum.alloc( 4) : !quantum.reg
+# CHECK: %[[BIT1:.*]] = quantum.extract %[[REG]][ 0] : !quantum.reg -> !quantum.bit
+# CHECK: %[[ROT:.*]] = quantum.static_custom "Rot"
+# CHECK: %[[RX:.*]] = quantum.static_custom "RX"
+# CHECK: %[[BIT1:.*]] = quantum.extract %[[REG]][ 1]
+# CHECK: %[[RY1:.*]] = quantum.static_custom "RY"
+# CHECK: %[[XX1:.*]] = quantum.static_custom "IsingXX"
+# CHECK: %[[BIT2:.*]] = quantum.extract %[[REG]][ 2]
+# CHECK: %[[RZ:.*]] = quantum.static_custom "RZ"
+# CHECK: %[[XX2:.*]] = quantum.static_custom "IsingXX"
+# CHECK: %[[ZZ:.*]] = quantum.static_custom "IsingZZ"
+# CHECK: %[[CRX:.*]] = quantum.static_custom "CRX"
+# CHECK: %[[CRY:.*]] = quantum.static_custom "CRY"
+# CHECK: %[[CRZ:.*]] = quantum.static_custom "CRZ"
+test_static_params()
diff --git a/frontend/test/pytest/conftest.py b/frontend/test/pytest/conftest.py
index 3d757ae4ea..d265f9e239 100644
--- a/frontend/test/pytest/conftest.py
+++ b/frontend/test/pytest/conftest.py
@@ -17,6 +17,22 @@
 import os
 import pathlib
+from tempfile import TemporaryDirectory
+from textwrap import dedent
+
+import pytest
 TEST_PATH = os.path.dirname(__file__)
 CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../custom_device/custom_device.toml")
+
+
+@pytest.fixture(scope="function")
+def create_temporary_toml_file(request) -> str:
+    """Create a temporary TOML file with the given content."""
+    content = request.param
+    with TemporaryDirectory() as temp_dir:
+        toml_file = os.path.join(temp_dir, "test.toml")
+        with open(toml_file, "w", encoding="utf-8") as f:
+            f.write(dedent(content))
+        request.node.toml_file = toml_file
+        yield
diff --git a/frontend/test/pytest/device/test_decomposition.py b/frontend/test/pytest/device/test_decomposition.py
index 4b0378ae55..5159fe01f6 100644
--- a/frontend/test/pytest/device/test_decomposition.py
+++ b/frontend/test/pytest/device/test_decomposition.py
@@ -14,18 +14,21 @@
 """Unit test module for catalyst/device/decomposition.py"""
+import os
+import pathlib
 import platform
-from copy import deepcopy
 import numpy as np
 import pennylane as qml
 import pytest
+from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties
 from catalyst import CompileError, ctrl, qjit
 from catalyst.compiler import get_lib_path
-from catalyst.device import get_device_capabilities
 from catalyst.device.decomposition import catalyst_decomposer
-from 
catalyst.utils.toml import DeviceCapabilities, OperationProperties + +TEST_PATH = os.path.dirname(__file__) +CONFIG_CUSTOM_DEVICE = pathlib.Path(f"{TEST_PATH}/../../custom_device/custom_device.toml") class TestGateAliases: @@ -71,7 +74,7 @@ def test_control_aliases(self, gate, base): """Test the decomposition of specialized control operations.""" capabilities = DeviceCapabilities( - native_ops={base.__name__: OperationProperties(controllable=True)} + operations={base.__name__: OperatorProperties(controllable=True)} ) decomp = catalyst_decomposer(gate, capabilities) @@ -80,17 +83,51 @@ def test_control_aliases(self, gate, base): assert type(decomp[0].base) is base +class NoUnitaryDevice(qml.devices.Device): + """Custom device used for testing purposes.""" + + config_filepath = CONFIG_CUSTOM_DEVICE + + def __init__(self, shots=None, wires=None): + super().__init__(wires=wires, shots=shots) + self.capabilities.operations.pop("QubitUnitary") + self.qjit_capabilities = self.capabilities + + def apply(self, operations, **kwargs): + """Unused""" + raise RuntimeError("Only C/C++ interface is defined") + + @staticmethod + def get_c_interface(): + """Returns a tuple consisting of the device name, and + the location to the shared object with the C/C++ device implementation. + """ + system_extension = ".dylib" if platform.system() == "Darwin" else ".so" + lib_path = ( + get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/librtd_null_qubit" + system_extension + ) + return "NullQubit", lib_path + + def execute(self, circuits, execution_config): + """Execution.""" + return circuits, execution_config + + class TestControlledDecomposition: """Test behaviour around the decomposition of the `Controlled` class.""" def test_no_matrix(self, backend): """Test that controlling an operation without a matrix method raises an error.""" + dev = qml.device(backend, wires=4) class OpWithNoMatrix(qml.operation.Operation): + """Op without a matrix""" + num_wires = qml.operation.AnyWires def matrix(self): + """matrix undefined""" raise NotImplementedError() @qml.qnode(dev) @@ -105,9 +142,12 @@ def test_no_unitary_support(self): """Test that unknown controlled operations without QubitUnitary support raise an error.""" class UnknownOp(qml.operation.Operation): + """An unknown operation""" + num_wires = qml.operation.AnyWires def matrix(self): + """The matrix""" return np.array( [ [1.0, 0.0, 0.0, 0.0], @@ -118,7 +158,7 @@ def matrix(self): dtype=np.complex128, ) - dev = get_custom_device_without(4, {"QubitUnitary"}) + dev = NoUnitaryDevice(4, wires=4) @qml.qnode(dev) def f(): @@ -129,59 +169,5 @@ def f(): qjit(f, target="jaxpr") -def get_custom_device_without(num_wires, discards=frozenset(), force_matrix=frozenset()): - """Generate a custom device without gates in discards.""" - - class CustomDevice(qml.devices.Device): - """Custom Gate Set Device""" - - name = "Custom Device" - pennylane_requires = "0.35.0" - version = "0.0.2" - author = "Tester" - - lightning_device = qml.device("lightning.qubit", wires=0) - - config = None - backend_name = "default" - backend_lib = "default" - backend_kwargs = {} - - def __init__(self, shots=None, wires=None): - super().__init__(wires=wires, shots=shots) - lightning_capabilities = get_device_capabilities(self.lightning_device) - custom_capabilities = deepcopy(lightning_capabilities) - for gate in discards: - custom_capabilities.native_ops.pop(gate, None) - custom_capabilities.to_decomp_ops.pop(gate, None) - custom_capabilities.to_matrix_ops.pop(gate, None) - for gate in force_matrix: - 
custom_capabilities.native_ops.pop(gate, None) - custom_capabilities.to_decomp_ops.pop(gate, None) - custom_capabilities.to_matrix_ops[gate] = OperationProperties(False, False, False) - self.qjit_capabilities = custom_capabilities - - def apply(self, operations, **kwargs): - """Unused""" - raise RuntimeError("Only C/C++ interface is defined") - - @staticmethod - def get_c_interface(): - """Returns a tuple consisting of the device name, and - the location to the shared object with the C/C++ device implementation. - """ - system_extension = ".dylib" if platform.system() == "Darwin" else ".so" - lib_path = ( - get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/librtd_null_qubit" + system_extension - ) - return "NullQubit", lib_path - - def execute(self, circuits, execution_config): - """Execution.""" - return circuits, execution_config - - return CustomDevice(wires=num_wires) - - if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index 85c2542949..e4cb13e3ba 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -2182,7 +2182,7 @@ def g(x): expected = g(x) return result, expected - result, expected = workflow(np.array([5, 3, 4])) + result, expected = workflow(np.array([5.0, 3.0, 4.0])) assert jnp.allclose(result, jnp.array([2.5, 1.5, 2])) assert jnp.allclose(result, expected) diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index 70ca5b9081..400ba89d51 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -33,7 +33,7 @@ def test_basic_cond_to_jaxpr(self): expected = dedent( """ - { lambda ; a:i64[]. let transform_named_sequence + { lambda ; a:i64[]. let b:bool[] = eq a 5 c:i64[] = cond[ branch_jaxprs=[{ lambda ; a:i64[] b_:i64[]. let c:i64[] = integer_pow[y=2] a in (c,) }, diff --git a/frontend/test/pytest/test_config_functions.py b/frontend/test/pytest/test_config_functions.py deleted file mode 100644 index dca8f03f00..0000000000 --- a/frontend/test/pytest/test_config_functions.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
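# The TOML-parsing tests being deleted here (frontend/test/pytest/test_config_functions.py,
# continuing in the next hunks) are superseded by PennyLane's DeviceCapabilities loader, which
# the new tests in this PR use instead. A rough sketch of the replacement pattern; the schema-3
# field names below are assumptions based on those new tests, not lines from this PR.
from pathlib import Path
from tempfile import mkdtemp
from textwrap import dedent

from pennylane.devices.capabilities import DeviceCapabilities

_toml_file = Path(mkdtemp()) / "test.toml"
_toml_file.write_text(
    dedent(
        """
        schema = 3

        [operators.gates]
        PauliX = { properties = ["controllable"] }
        PauliY = { }

        [operators.observables]
        PauliZ = { }
        """
    ),
    encoding="utf-8",
)

_capabilities = DeviceCapabilities.from_toml_file(str(_toml_file))
assert {"PauliX", "PauliY"}.issubset(_capabilities.operations)
assert _capabilities.operations["PauliX"].controllable
assert "PauliZ" in _capabilities.observables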
- -"""Unit tests for device toml config parsing and validation.""" - -from os.path import join -from tempfile import TemporaryDirectory -from textwrap import dedent - -import pytest - -from catalyst.utils.exceptions import CompileError -from catalyst.utils.toml import ( - ALL_SUPPORTED_SCHEMAS, - DeviceCapabilities, - ProgramFeatures, - TOMLDocument, - load_device_capabilities, - read_toml_file, -) - - -def get_test_config(config_text: str) -> TOMLDocument: - """Parse test config into the TOMLDocument structure""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write(config_text) - config = read_toml_file(toml_file) - return config - - -def get_test_device_capabilities( - program_features: ProgramFeatures, config_text: str -) -> DeviceCapabilities: - """Parse test config into the DeviceCapabilities structure""" - config = get_test_config(config_text) - device_capabilities = load_device_capabilities(config, program_features) - return device_capabilities - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_get_observables(schema): - """Test observables are properly obtained.""" - device_capabilities = get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.observables] - PauliX = {{ }} - """ - ), - ) - assert len(device_capabilities.native_obs) == 1 - assert "PauliX" in device_capabilities.native_obs - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_get_native_ops(schema): - """Test native gates are properly obtained from the toml.""" - device_capabilities = get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.gates.native] - PauliX = {{ properties = [ 'controllable' ] }} - PauliY = {{ }} - """ - ), - ) - - assert len(device_capabilities.native_ops) == 2 - assert {"PauliX", "PauliY"}.issubset(device_capabilities.native_ops) - assert device_capabilities.native_ops["PauliX"].controllable - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_get_native_ops_optional_shots(schema): - """Test native gates are properly obtained from the toml.""" - device_capabilities = get_test_device_capabilities( - ProgramFeatures(True), - dedent( - f""" - schema = {schema} - [operators.gates.native] - PauliX = {{ condition = ['finiteshots'] }} - PauliY = {{ condition = ['analytic'] }} - """ - ), - ) - assert "PauliX" in device_capabilities.native_ops - assert "PauliY" not in device_capabilities.native_ops - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_get_native_ops_optional_noshots(schema): - """Test native gates are properly obtained from the toml.""" - device_capabilities = get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.gates.native] - PauliX = {{ condition = ['finiteshots'] }} - PauliY = {{ condition = ['analytic'] }} - """ - ), - ) - assert "PauliX" not in device_capabilities.native_ops - assert "PauliY" in device_capabilities.native_ops - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_get_decomp_gates(schema): - """Test native decomposition gates are properly obtained from the toml.""" - device_capabilities = get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.gates] - decomp = ["PauliX", "PauliY"] - """ - ), - ) - - assert "PauliX" in device_capabilities.to_decomp_ops - assert "PauliY" in 
device_capabilities.to_decomp_ops - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_get_matrix_decomposable_gates(schema): - """Test native matrix gates are properly obtained from the toml.""" - device_capabilities = get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.gates.matrix] - PauliZ = {{}} - """ - ), - ) - - assert "PauliZ" in device_capabilities.to_matrix_ops - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_config_invalid_attr(schema): - """Check the gate condition handling logic""" - with pytest.raises( - CompileError, match="Configuration for gate 'TestGate' has unknown attributes" - ): - get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.gates.native] - TestGate = {{ unknown_attribute = 33 }} - """ - ), - ) - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_config_invalid_condition_unknown(schema): - """Check the gate condition handling logic""" - with pytest.raises( - CompileError, match="Configuration for gate 'TestGate' has unknown conditions" - ): - get_test_device_capabilities( - ProgramFeatures(True), - dedent( - f""" - schema = {schema} - [operators.gates.native] - TestGate = {{ condition = ["unknown", "analytic"] }} - """ - ), - ) - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_config_invalid_property_unknown(schema): - """Check the gate condition handling logic""" - with pytest.raises( - CompileError, match="Configuration for gate 'TestGate' has unknown properties" - ): - get_test_device_capabilities( - ProgramFeatures(True), - dedent( - f""" - schema = {schema} - [operators.gates.native] - TestGate = {{ properties = ["unknown", "invertible"] }} - """ - ), - ) - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_config_invalid_condition_duplicate_true(schema): - """Check the gate condition handling logic for True""" - with pytest.raises(CompileError, match="Configuration for gate 'TestGate'"): - get_test_device_capabilities( - ProgramFeatures(True), - dedent( - f""" - schema = {schema} - [operators.gates.native] - TestGate = {{ condition = ["finiteshots", "analytic"] }} - """ - ), - ) - - -@pytest.mark.parametrize("schema", ALL_SUPPORTED_SCHEMAS) -def test_config_invalid_condition_duplicate_false(schema): - """Check the gate condition handling logic for False""" - with pytest.raises(CompileError, match="Configuration for gate 'TestGate'"): - get_test_device_capabilities( - ProgramFeatures(False), - dedent( - f""" - schema = {schema} - [operators.gates.native] - TestGate = {{ condition = ["finiteshots", "analytic"] }} - """ - ), - ) - - -@pytest.mark.parametrize("schema", [1, 999]) -def test_config_unsupported_schema(schema): - """Test unsupported schema version.""" - program_features = ProgramFeatures(False) - config_text = dedent( - f""" - schema = {schema} - """ - ) - - with pytest.raises(AssertionError, match="Unsupported config schema"): - get_test_device_capabilities(program_features, config_text) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/frontend/test/pytest/test_custom_devices.py b/frontend/test/pytest/test_custom_devices.py index c88d41c998..5807d3097f 100644 --- a/frontend/test/pytest/test_custom_devices.py +++ b/frontend/test/pytest/test_custom_devices.py @@ -21,133 +21,27 @@ from catalyst import measure, qjit from catalyst.compiler import get_lib_path -from catalyst.device import extract_backend_info, 
get_device_capabilities +from catalyst.device import QJITDevice, extract_backend_info, get_device_capabilities from catalyst.utils.exceptions import CompileError -# These have to match the ones in the configuration file. -OPERATIONS = [ - "QubitUnitary", - "PauliX", - "PauliY", - "PauliZ", - "MultiRZ", - "Hadamard", - "S", - "T", - "CNOT", - "SWAP", - "CSWAP", - "Toffoli", - "CY", - "CZ", - "PhaseShift", - "ControlledPhaseShift", - "RX", - "RY", - "RZ", - "Rot", - "CRX", - "CRY", - "CRZ", - "CRot", - "Identity", - "IsingXX", - "IsingYY", - "IsingZZ", - "IsingXY", - "SX", - "ISWAP", - "PSWAP", - "SISWAP", - "SQISW", - "BasisState", - "StatePrep", - "ControlledQubitUnitary", - "DiagonalQubitUnitary", - "SingleExcitation", - "SingleExcitationPlus", - "SingleExcitationMinus", - "DoubleExcitation", - "DoubleExcitationPlus", - "DoubleExcitationMinus", - "QubitCarry", - "QubitSum", - "OrbitalRotation", - "QFT", - "ECR", - "Adjoint(S)", - "Adjoint(T)", - "Adjoint(SX)", - "Adjoint(ISWAP)", - "Adjoint(SISWAP)", - "MultiControlledX", - "SISWAP", - "ControlledPhaseShift", - "C(QubitUnitary)", - "C(PauliY)", - "C(RY)", - "C(PauliX)", - "C(RX)", - "C(IsingXX)", - "C(Hadamard)", - "C(SWAP)", - "C(IsingYY)", - "C(S)", - "C(MultiRZ)", - "C(PhaseShift)", - "C(T)", - "C(IsingXY)", - "C(PauliZ)", - "C(Rot)", - "C(IsingZZ)", - "C(RZ)", - "C(SingleExcitationPlus)", - "C(GlobalPhase)", - "C(DoubleExcitationPlus)", - "C(SingleExcitationMinus)", - "C(DoubleExcitation)", - "GlobalPhase", - "C(SingleExcitation)", - "C(DoubleExcitationMinus)", - "BlockEncode", -] -OBSERVABLES = [ - "PauliX", - "PauliY", - "PauliZ", - "Hadamard", - "Hermitian", - "Identity", - "Projector", - "SparseHamiltonian", - "Sum", - "SProd", - "Prod", - "Exp", -] - RUNTIME_LIB_PATH = get_lib_path("runtime", "RUNTIME_LIB_DIR") def test_custom_device_load(): """Test that custom device can run using Catalyst.""" - class CustomDevice(qml.devices.QubitDevice): + class CustomDevice(qml.devices.Device): """Custom device""" - name = "Custom Device" - short_name = "custom.device" - pennylane_requires = "0.33.0" - version = "0.0.1" - author = "Dummy" - - operations = OPERATIONS - observables = OBSERVABLES - config = CONFIG_CUSTOM_DEVICE + name = "custom.device" + config_filepath = CONFIG_CUSTOM_DEVICE - def __init__(self, shots=None, wires=None): + def __init__(self, shots=None, wires=None, options1=42, options2=38): super().__init__(wires=wires, shots=shots) - self._option1 = 42 + self.device_kwargs = { + "option1": options1, + "option2": options2, + } def apply(self, operations, **kwargs): """Unused""" @@ -164,11 +58,15 @@ def get_c_interface(): ) return "NullQubit", lib_path + def execute(self, circuits, execution_config): + """Execution.""" + raise NotImplementedError + device = CustomDevice(wires=1) capabilities = get_device_capabilities(device) backend_info = extract_backend_info(device, capabilities) assert backend_info.kwargs["option1"] == 42 - assert "option2" not in backend_info.kwargs + assert backend_info.kwargs["option2"] == 38 @qjit @qml.qnode(device) @@ -178,24 +76,17 @@ def f(): has been implemented to always return True.""" return measure(0) - assert True == f() + assert f() == True def test_custom_device_bad_directory(): """Test that custom device error.""" - class CustomDevice(qml.devices.QubitDevice): + class CustomDevice(qml.devices.Device): """Custom Device""" - name = "Custom Qubit" - short_name = "custom.device" - pennylane_requires = "0.33.0" - version = "0.0.1" - author = "Dummy" - - operations = OPERATIONS - observables = 
OBSERVABLES - config = CONFIG_CUSTOM_DEVICE + name = "custom.device" + config_filepath = CONFIG_CUSTOM_DEVICE def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) @@ -209,9 +100,12 @@ def get_c_interface(): """Returns a tuple consisting of the device name, and the location to the shared object with the C/C++ device implementation. """ - return "CustomDevice", "this-file-does-not-exist.so" + def execute(self, circuits, execution_config): + """Execution.""" + raise NotImplementedError + with pytest.raises( CompileError, match="Device at this-file-does-not-exist.so cannot be found!" ): @@ -225,18 +119,11 @@ def f(): def test_custom_device_no_c_interface(): """Test that custom device error.""" - class CustomDevice(qml.devices.QubitDevice): + class CustomDevice(qml.devices.Device): """Custom Device""" - name = "Custom Qubit" - short_name = "custom.device" - pennylane_requires = "0.33.0" - version = "0.0.1" - author = "Dummy" - - operations = OPERATIONS - observables = OBSERVABLES - config = CONFIG_CUSTOM_DEVICE + name = "custom.device" + config_filepath = CONFIG_CUSTOM_DEVICE def __init__(self, shots=None, wires=None): super().__init__(wires=wires, shots=shots) @@ -245,6 +132,10 @@ def apply(self, operations, **kwargs): """Unused.""" raise RuntimeError("Custom device") + def execute(self, circuits, execution_config): + """Execution.""" + raise NotImplementedError + with pytest.raises( CompileError, match="The custom.device device does not provide C interface for compilation." ): @@ -253,3 +144,46 @@ def apply(self, operations, **kwargs): @qml.qnode(CustomDevice(wires=1)) def f(): return measure(0) + + +def test_error_raised_no_unitary_support_for_matrix_ops(): + """Tests that an error is raised when a device specifies _to_matrix_ops but does not support + the QubitUnitary operation""" + + class CustomDevice(qml.devices.Device): + """Custom device for testing.""" + + name = "custom.device" + config_filepath = CONFIG_CUSTOM_DEVICE + + _to_matrix_ops = { + "DiagonalQubitUnitary": qml.devices.capabilities.OperatorProperties(), + "BlockEncode": qml.devices.capabilities.OperatorProperties(), + } + + def __init__(self, wires, shots=None, **kwargs): + del self.capabilities.operations["QubitUnitary"] + assert not self.capabilities.supports_operation("QubitUnitary") + self.qjit_capabilities = self.capabilities + super().__init__(wires=wires, shots=shots) + + @staticmethod + def get_c_interface(): + """Returns a tuple consisting of the device name, and + the location to the shared object with the C/C++ device implementation. + """ + system_extension = ".dylib" if platform.system() == "Darwin" else ".so" + lib_path = ( + get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/librtd_null_qubit" + system_extension + ) + return "NullQubit", lib_path + + def execute(self, circuits, execution_config): + """Execution.""" + return (0,) + + with pytest.raises( + CompileError, + match="The device that specifies to_matrix_ops must support QubitUnitary.", + ): + QJITDevice(CustomDevice(wires=2, shots=2048)) diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index da442d17a6..8136d04492 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -13,17 +13,14 @@ # limitations under the License. """Test for the device API. 
""" +import platform + import pennylane as qml import pytest from pennylane.devices import NullQubit from catalyst import qjit -from catalyst.device import ( - QJITDevice, - get_device_capabilities, - get_device_toml_config, - qjit_device, -) +from catalyst.device import QJITDevice, get_device_capabilities, qjit_device from catalyst.tracing.contexts import EvaluationContext, EvaluationMode # pylint:disable = protected-access,attribute-defined-outside-init @@ -99,14 +96,8 @@ def test_qjit_device_measurements(shots, mocker): spy = mocker.spy(qjit_device, "get_device_capabilities") dev = qml.device("lightning.qubit", wires=2, shots=shots) - state_measurements = {"State"} - finite_shot_measurements = {"Counts", "Sample"} - - config = get_device_toml_config(dev) - all_measurements = set(config["measurement_processes"]) - - assert state_measurements.issubset(all_measurements) - assert finite_shot_measurements.issubset(all_measurements) + state_measurements = {"StateMP"} + finite_shot_measurements = {"CountsMP", "SampleMP"} dev_capabilities = get_device_capabilities(dev) expected_measurements = dev_capabilities.measurement_processes diff --git a/frontend/test/pytest/test_from_plxpr.py b/frontend/test/pytest/test_from_plxpr.py index 2a9b3c7379..a76f2598c8 100644 --- a/frontend/test/pytest/test_from_plxpr.py +++ b/frontend/test/pytest/test_from_plxpr.py @@ -25,6 +25,7 @@ # needs to be below the importorskip calls # pylint: disable=wrong-import-position from catalyst.from_plxpr import from_plxpr +from catalyst.jax_primitives import get_call_jaxpr def catalyst_execute_jaxpr(jaxpr): @@ -215,8 +216,8 @@ def circuit(U): qjit_obj = qml.qjit(circuit) qjit_obj(x) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) def test_globalphase(self): @@ -241,8 +242,8 @@ def circuit(phi): qjit_obj(0.5) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) def test_expval(self): @@ -269,8 +270,8 @@ def circuit(x): qjit_obj = qml.qjit(circuit) qjit_obj(0.5) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) @@ -301,8 +302,8 @@ def circuit(x): qjit_obj = qml.qjit(circuit) qjit_obj(0.5) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) @@ -340,8 +341,8 @@ def circuit(phi): qjit_obj = qml.qjit(circuit) qjit_obj(phi) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) # confused by the weak_types error here compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) @@ -375,8 +376,8 @@ def circuit(x): qjit_obj = qml.qjit(circuit) qjit_obj(x) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = 
converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) @@ -404,6 +405,39 @@ def circuit(): expected = np.transpose(np.vstack([np.ones(50), np.zeros(50)])) assert qml.math.allclose(catalyst_res[0], expected) + qjit_obj = qml.qjit(circuit) + qjit_obj() + catalxpr = qjit_obj.jaxpr + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) + + compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) + + @pytest.mark.xfail(reason="CountsMP returns a dictionary, which is not compatible with capture") + def test_counts(self): + """Test comparison and execution of a jaxpr returning counts.""" + + dev = qml.device("lightning.qubit", wires=2, shots=50) + + @qml.qnode(dev) + def circuit(): + qml.X(0) + return qml.counts() + + qml.capture.enable() + plxpr = jax.make_jaxpr(circuit)() + qml.capture.disable() + + converted = from_plxpr(plxpr)() + + assert converted.eqns[0].primitive == catalyst.jax_primitives.quantum_kernel_p + assert converted.eqns[0].params["qnode"] is circuit + + catalyst_res = catalyst_execute_jaxpr(converted)() + assert len(catalyst_res) == 1 + expected = np.transpose(np.vstack([np.ones(50), np.zeros(50)])) + assert qml.math.allclose(catalyst_res[0], expected) + qjit_obj = qml.qjit(circuit) qjit_obj() catalxpr = qjit_obj.jaxpr @@ -449,8 +483,8 @@ def circuit(x, y, z): qjit_obj = qml.qjit(circuit) qjit_obj(x, y, z) catalxpr = qjit_obj.jaxpr - call_jaxpr_pl = converted.eqns[0].params["call_jaxpr"] - call_jaxpr_c = catalxpr.eqns[1].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(catalxpr) compare_call_jaxprs(call_jaxpr_pl, call_jaxpr_c) @@ -494,8 +528,8 @@ def workflow(z): qjit_obj = qml.qjit(workflow) qjit_obj(0.5) - call_jaxpr_pl = converted.eqns[1].params["call_jaxpr"] - call_jaxpr_c = qjit_obj.jaxpr.eqns[2].params["call_jaxpr"] + call_jaxpr_pl = get_call_jaxpr(converted) + call_jaxpr_c = get_call_jaxpr(qjit_obj.jaxpr) # qubit extraction and classical equations in a slightly different order # thus cant check specific equations and have to discard comparing counts diff --git a/frontend/test/pytest/test_jax_linalg.py b/frontend/test/pytest/test_jax_linalg.py index cf4118f7ea..246a73bff5 100644 --- a/frontend/test/pytest/test_jax_linalg.py +++ b/frontend/test/pytest/test_jax_linalg.py @@ -1070,3 +1070,7 @@ def f(X): assert jnp.allclose(U_obs, U_exp) assert jnp.allclose(S_obs, S_exp) assert jnp.allclose(Vt_obs, Vt_exp) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_jax_linalg_in_circuit.py b/frontend/test/pytest/test_jax_linalg_in_circuit.py index 16a9d2226c..a099bf822b 100644 --- a/frontend/test/pytest/test_jax_linalg_in_circuit.py +++ b/frontend/test/pytest/test_jax_linalg_in_circuit.py @@ -18,6 +18,7 @@ import numpy as np import pennylane as qml +import pytest from jax import numpy as jnp from jax import scipy as jsp @@ -46,3 +47,7 @@ def circuit_rot(): res = circuit_expm() expected = circuit_rot() # expected = [0,1] assert np.allclose(res, expected) + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_jax_primitives.py b/frontend/test/pytest/test_jax_primitives.py index cdeb4b51b8..44d5995bae 100644 --- a/frontend/test/pytest/test_jax_primitives.py +++ b/frontend/test/pytest/test_jax_primitives.py @@ -26,10 +26,10 @@ from 
jax.interpreters.mlir import ir_constant, make_ir_context from catalyst.jax_primitives import ( - _get_call_jaxpr, _qextract_lowering, _qinsert_lowering, extract_scalar, + get_call_jaxpr, safe_cast_to_f64, ) @@ -136,14 +136,14 @@ def test_scalar_extraction_error(self, test_input): def test_get_call_jaxpr(): - """Test _get_call_jaxpr raises AsserionError if no function primitive exists.""" + """Test get_call_jaxpr raises AsserionError if no function primitive exists.""" def f(x): return x * x jaxpr = make_jaxpr(f)(2.0) with pytest.raises(AssertionError, match="No call_jaxpr found in the JAXPR"): - _ = _get_call_jaxpr(jaxpr) + _ = get_call_jaxpr(jaxpr) if __name__ == "__main__": diff --git a/frontend/test/pytest/test_loops.py b/frontend/test/pytest/test_loops.py index 38e8597017..b5004bed24 100644 --- a/frontend/test/pytest/test_loops.py +++ b/frontend/test/pytest/test_loops.py @@ -33,7 +33,6 @@ def test_while_loop(self): expected = dedent( """ { lambda ; a:f64[]. let - transform_named_sequence b:i64[] c:f64[] = while_loop[ body_jaxpr={ lambda ; d:i64[] e:f64[]. let f:i64[] = add d 1 in (f, e) } body_nconsts=0 @@ -62,7 +61,6 @@ def test_for_loop(self): expected = dedent( """ { lambda ; a:f64[] b:i64[]. let - transform_named_sequence c:i64[] d:f64[] = for_loop[ apply_reverse_transform=False body_jaxpr={ lambda ; e:i64[] f:i64[] g:f64[]. let diff --git a/frontend/test/pytest/test_measurement_primitives.py b/frontend/test/pytest/test_measurement_primitives.py index 41229d2dc2..df0ca05468 100644 --- a/frontend/test/pytest/test_measurement_primitives.py +++ b/frontend/test/pytest/test_measurement_primitives.py @@ -14,44 +14,95 @@ """ This file contains a couple of tests for the capture of measurement primitives into jaxpr. """ + +# pylint: disable=line-too-long + import jax +import pennylane as qml -from catalyst.jax_primitives import ( - compbasis_p, - counts_p, - expval_p, - probs_p, - sample_p, - state_p, - var_p, -) +import catalyst +from catalyst.debug import get_compilation_stage, replace_ir +from catalyst.jax_primitives import compbasis_p, expval_p, probs_p, state_p, var_p -def test_sample(): - """Test that the sample primitive can be captured into jaxpr.""" +def test_dynamic_sample_backend_functionality(): + """Test that a `sample` program with dynamic shots can be executed correctly.""" - def f(): - obs = compbasis_p.bind() - return sample_p.bind(obs, shots=5, shape=(5, 0)) + @catalyst.qjit(keep_intermediate=True) + def workflow_dyn_sample(shots): # pylint: disable=unused-argument + # qml.device still needs concrete shots + device = qml.device("lightning.qubit", wires=1, shots=10) - jaxpr = jax.make_jaxpr(f)().jaxpr - assert jaxpr.eqns[1].primitive == sample_p - assert jaxpr.eqns[1].params == {"shape": (5, 0), "shots": 5} - assert jaxpr.eqns[1].outvars[0].aval.shape == (5, 0) + @qml.qnode(device) + def circuit(): + qml.RX(1.5, 0) + return qml.sample() + return circuit() -def test_counts(): - """Test that the counts primitive can be captured by jaxpr.""" + workflow_dyn_sample(10) + old_ir = get_compilation_stage(workflow_dyn_sample, "mlir") + workflow_dyn_sample.workspace.cleanup() - def f(): - obs = compbasis_p.bind() - return counts_p.bind(obs, shots=5, shape=(1,)) + new_ir = old_ir.replace( + "catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<10x1xi64>", + "catalyst.launch_kernel @module_circuit::@circuit(%arg0) : (tensor) -> tensor", + ) + new_ir = new_ir.replace( + "func.func public @circuit() -> tensor<10x1xi64>", + "func.func public @circuit(%arg0: tensor) -> 
tensor", + ) + new_ir = new_ir.replace( + "quantum.device shots(%extracted) [", + """%shots = tensor.extract %arg0[] : tensor + quantum.device shots(%shots) [""", + ) + new_ir = new_ir.replace("tensor<10x1x", "tensor", + "catalyst.launch_kernel @module_circuit::@circuit(%arg0) : (tensor) ->", + ) + new_ir = new_ir.replace( + "func.func public @circuit() ->", "func.func public @circuit(%arg0: tensor) ->" + ) + new_ir = new_ir.replace( + "quantum.device shots(%extracted) [", + """%shots = tensor.extract %arg0[] : tensor + quantum.device shots(%shots) [""", + ) + + replace_ir(workflow_dyn_counts, "mlir", new_ir) + res = workflow_dyn_counts(4000) + assert res[1][0] + res[1][1] == 4000 + + workflow_dyn_counts.workspace.cleanup() def test_expval(): @@ -59,11 +110,11 @@ def test_expval(): def f(): obs = compbasis_p.bind() - return expval_p.bind(obs, shots=5, shape=(1,)) + return expval_p.bind(obs, shape=(1,)) jaxpr = jax.make_jaxpr(f)() assert jaxpr.eqns[1].primitive == expval_p - assert jaxpr.eqns[1].params == {"shape": (1,), "shots": 5} + assert jaxpr.eqns[1].params == {"shape": (1,)} assert jaxpr.eqns[1].outvars[0].aval.shape == () @@ -72,11 +123,11 @@ def test_var(): def f(): obs = compbasis_p.bind() - return var_p.bind(obs, shots=5, shape=(1,)) + return var_p.bind(obs, shape=(1,)) jaxpr = jax.make_jaxpr(f)() assert jaxpr.eqns[1].primitive == var_p - assert jaxpr.eqns[1].params == {"shape": (1,), "shots": 5} + assert jaxpr.eqns[1].params == {"shape": (1,)} assert jaxpr.eqns[1].outvars[0].aval.shape == () diff --git a/frontend/test/pytest/test_measurement_transforms.py b/frontend/test/pytest/test_measurement_transforms.py index 517edf130f..997f82a947 100644 --- a/frontend/test/pytest/test_measurement_transforms.py +++ b/frontend/test/pytest/test_measurement_transforms.py @@ -15,7 +15,6 @@ """ # pylint: disable=unused-argument import os -import pathlib # pylint: disable=unused-argument import platform @@ -28,16 +27,16 @@ import pytest from conftest import CONFIG_CUSTOM_DEVICE from pennylane.devices import Device +from pennylane.devices.capabilities import OperatorProperties from pennylane.transforms import split_non_commuting, split_to_single_terms from catalyst.compiler import get_lib_path -from catalyst.device import QJITDevice, get_device_capabilities, get_device_toml_config +from catalyst.device import QJITDevice, get_device_capabilities from catalyst.device.decomposition import ( measurements_from_counts, measurements_from_samples, ) from catalyst.tracing.contexts import EvaluationContext, EvaluationMode -from catalyst.utils.toml import OperationProperties # pylint: disable=attribute-defined-outside-init @@ -45,15 +44,13 @@ class CustomDevice(Device): """A Custom Device following the new API.""" - config = CONFIG_CUSTOM_DEVICE + config_filepath = CONFIG_CUSTOM_DEVICE + + _to_matrix_ops = {"BlockEncode": OperatorProperties(False, False, False)} def __init__(self, wires, shots=1024): - print(pathlib.Path(__file__).parent.parent.parent.parent) super().__init__(wires=wires, shots=shots) - dummy_capabilities = get_device_capabilities(self) - dummy_capabilities.native_ops.pop("BlockEncode") - dummy_capabilities.to_matrix_ops["BlockEncode"] = OperationProperties(False, False, False) - self.capabilities = dummy_capabilities + self.capabilities.operations.pop("BlockEncode") @staticmethod def get_c_interface(): @@ -75,7 +72,7 @@ def execute(self, circuits, execution_config): class CustomDeviceLimitedMPs(Device): """A Custom Device from the device API without wires.""" - config = CONFIG_CUSTOM_DEVICE 
+ config_filepath = CONFIG_CUSTOM_DEVICE def __init__(self, wires, shots=1024, allow_counts=False, allow_samples=False): self.allow_samples = allow_samples @@ -101,21 +98,21 @@ def execute(self, circuits, execution_config): return circuits, execution_config def __enter__(self, *args, **kwargs): - toml_file_path = self.config + toml_file_path = self.config_filepath with open(toml_file_path, mode="r", encoding="UTF-8") as f: toml_contents = f.readlines() updated_toml_contents = [] for line in toml_contents: - if "Expval" in line: + if "ExpectationMP" in line: continue - if "Var" in line: + if "VarianceMP" in line: continue - if "Probs" in line: + if "ProbabilityMP" in line: continue - if "Sample" in line and not self.allow_samples: + if "SampleMP" in line and not self.allow_samples: continue - if "Counts" in line and not self.allow_counts: + if "CountsMP" in line and not self.allow_counts: continue updated_toml_contents.append(line) @@ -124,12 +121,12 @@ def __enter__(self, *args, **kwargs): self.toml_file.writelines(updated_toml_contents) self.toml_file.close() # close for now without deleting - self.config = self.toml_file.name + self.config_filepath = self.toml_file.name return self def __exit__(self, *args, **kwargs): os.unlink(self.toml_file.name) - self.config = None + self.config_filepath = None class TestMeasurementTransforms: @@ -235,8 +232,8 @@ def basic_circuit(theta: float): @pytest.mark.parametrize( "unsupported_measurement, measurement_transform, target_measurement", [ - ("Sample", measurements_from_counts, "counts"), - ("Counts", measurements_from_samples, "sample"), + ("SampleMP", measurements_from_counts, "counts"), + ("CountsMP", measurements_from_samples, "sample"), (None, measurements_from_samples, "sample"), ], ) @@ -250,12 +247,14 @@ def test_measurement_from_readout_integration_if_no_observables_supported( dev = qml.device("lightning.qubit", wires=4, shots=100) - config = get_device_toml_config(dev) - config["operators"]["observables"] = {} + config = get_device_capabilities(dev) + config.observables = {} if unsupported_measurement: - del config["measurement_processes"][unsupported_measurement] + del config.measurement_processes[unsupported_measurement] - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): # transform is added to transform program qjit_dev = QJITDevice(dev) @@ -350,10 +349,12 @@ def circuit(): return qml.expval(qml.X(0)), qml.var(qml.Y(1)) # modify config to indicate no observables supported - config = get_device_toml_config(dev) - config["operators"]["observables"] = {} + config = get_device_capabilities(dev) + config.observables = {} - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): with pytest.raises( RuntimeError, match="The device does not support observables or sample/counts" ): @@ -619,14 +620,16 @@ def circuit(theta: float): expected_result = circuit(1.2) - config = get_device_toml_config(dev) + config = get_device_capabilities(dev) for obs in unsupported_obs: - del config["operators"]["observables"][obs] + del config.observables[obs] spy = mocker.spy(QJITDevice, "preprocess") # mock TOML file output to indicate some observables are not supported - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + with patch( + 
"catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): jitted_circuit = qml.qjit(circuit) transform_program, _ = spy.spy_return @@ -663,12 +666,14 @@ def circuit(): for obs in unsupported_obs: assert f"{obs}] : !quantum.obs" in mlir - config = get_device_toml_config(dev) + config = get_device_capabilities(dev) for obs in unsupported_obs: - del config["operators"]["observables"][obs] + del config.observables[obs] # mock TOML file output to indicate some observables are not supported - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): mlir = qml.qjit(circuit, target="mlir").mlir for obs in unsupported_obs: @@ -684,16 +689,18 @@ def test_split_non_commuting_is_added_for_partial_diagonalization( dev = qml.device("lightning.qubit", wires=4, shots=1000) - config = get_device_toml_config(dev) + config = get_device_capabilities(dev) - del config["operators"]["observables"]["Hadamard"] - config["compilation"]["non_commuting_observables"] = non_commuting_flag + del config.observables["Hadamard"] + config.non_commuting_observables = non_commuting_flag - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): qjit_dev = QJITDevice(dev) # dev1 supports non-commuting observables and sum observables - no splitting - assert qjit_dev.capabilities.non_commuting_observables_flag is non_commuting_flag + assert qjit_dev.capabilities.non_commuting_observables is non_commuting_flag # Check the preprocess with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: @@ -711,16 +718,18 @@ def test_split_non_commuting_is_added_for_full_diagonalization( dev = qml.device("lightning.qubit", wires=4, shots=1000) - config = get_device_toml_config(dev) + config = get_device_capabilities(dev) - config["operators"]["observables"] = {} - config["compilation"]["non_commuting_observables"] = non_commuting_flag + config.observables = {} + config.non_commuting_observables = non_commuting_flag - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): qjit_dev = QJITDevice(dev) # dev1 supports non-commuting observables and sum observables - no splitting - assert qjit_dev.capabilities.non_commuting_observables_flag is non_commuting_flag + assert qjit_dev.capabilities.non_commuting_observables is non_commuting_flag # Check the preprocess with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: @@ -737,21 +746,21 @@ def test_measurements_are_split(self, mocker): # dev1 supports non-commuting observables and sum observables - no splitting qjit_dev1 = QJITDevice(dev) - assert "Sum" in qjit_dev1.capabilities.native_obs - assert qjit_dev1.capabilities.non_commuting_observables_flag is True + assert "Sum" in qjit_dev1.capabilities.observables + assert qjit_dev1.capabilities.non_commuting_observables is True # dev2 supports non-commuting observables but NOT sums - split_to_single_terms qjit_dev2 = QJITDevice(dev) - del qjit_dev2.capabilities.native_obs["Sum"] + del qjit_dev2.capabilities.observables["Sum"] # dev3 supports does not support non-commuting observables OR sums - split_non_commuting qjit_dev3 = QJITDevice(dev) - del qjit_dev3.capabilities.native_obs["Sum"] - 
qjit_dev3.capabilities.non_commuting_observables_flag = False + del qjit_dev3.capabilities.observables["Sum"] + qjit_dev3.capabilities.non_commuting_observables = False # dev4 supports sums but NOT non-commuting observables - split_non_commuting qjit_dev4 = QJITDevice(dev) - qjit_dev4.capabilities.non_commuting_observables_flag = False + qjit_dev4.capabilities.non_commuting_observables = False # Check the preprocess with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: @@ -794,12 +803,14 @@ def unjitted_circuit(theta: float): expected_result = unjitted_circuit(1.2) - config = get_device_toml_config(dev) + config = get_device_capabilities(dev) spy = mocker.spy(QJITDevice, "preprocess") # mock TOML file output to indicate non-commuting observables are supported - config["compilation"]["non_commuting_observables"] = True - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + config.non_commuting_observables = True + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): jitted_circuit = qml.qjit(unjitted_circuit) assert len(jitted_circuit(1.2)) == len(expected_result) == 2 assert np.allclose(jitted_circuit(1.2), expected_result) @@ -808,8 +819,10 @@ def unjitted_circuit(theta: float): assert split_non_commuting not in transform_program # mock TOML file output to indicate non-commuting observables are NOT supported - config["compilation"]["non_commuting_observables"] = False - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + config.non_commuting_observables = False + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): jitted_circuit = qml.qjit(unjitted_circuit) assert len(jitted_circuit(1.2)) == len(expected_result) == 2 assert np.allclose(jitted_circuit(1.2), unjitted_circuit(1.2)) @@ -832,14 +845,14 @@ def unjitted_circuit(theta: float): expected_result = unjitted_circuit(1.2) - config = get_device_toml_config(dev) + config = get_device_capabilities(dev) spy = mocker.spy(QJITDevice, "preprocess") # make sure non_commuting_observables_flag is True - otherwise we use # split_non_commuting instead of split_to_single_terms - assert config["compilation"]["non_commuting_observables"] is True + assert config.non_commuting_observables is True # make sure the testing device does in fact support sum observables - assert "Sum" in config["operators"]["observables"] + assert "Sum" in config.observables # test case where transform should not be applied jitted_circuit = qml.qjit(unjitted_circuit) @@ -850,8 +863,10 @@ def unjitted_circuit(theta: float): assert split_to_single_terms not in transform_program # mock TOML file output to indicate non-commuting observables are NOT supported - del config["operators"]["observables"]["Sum"] - with patch("catalyst.device.qjit_device.get_device_toml_config", Mock(return_value=config)): + del config.observables["Sum"] + with patch( + "catalyst.device.qjit_device.get_device_capabilities", Mock(return_value=config) + ): jitted_circuit = qml.qjit(unjitted_circuit) assert len(jitted_circuit(1.2)) == len(expected_result) == 2 assert np.allclose(jitted_circuit(1.2), unjitted_circuit(1.2)) diff --git a/frontend/test/pytest/test_measurements_results.py b/frontend/test/pytest/test_measurements_results.py index 325cf5ecad..e69973d1a8 100644 --- a/frontend/test/pytest/test_measurements_results.py +++ b/frontend/test/pytest/test_measurements_results.py @@ -1,23 +1,29 @@ # Copyright 
2022-2023 Xanadu Quantum Technologies Inc. - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform +from copy import deepcopy + import numpy as np import pennylane as qml import pytest +from conftest import CONFIG_CUSTOM_DEVICE from jax import numpy as jnp -from catalyst import qjit +from catalyst import CompileError, qjit +from catalyst.device import get_device_capabilities +from catalyst.utils.runtime_environment import get_lib_path # pylint: disable=too-many-lines @@ -1000,17 +1006,50 @@ def circuit(x, y): assert np.allclose(expected, result) +class CustomDevice(qml.devices.Device): + """Custom Gate Set Device""" + + name = "Custom Device" + config_filepath = CONFIG_CUSTOM_DEVICE + + _to_matrix_ops = {} + + def __init__(self, shots=None, wires=None): + super().__init__(wires=wires, shots=shots) + self.qjit_capabilities = deepcopy(get_device_capabilities(self)) + self.qjit_capabilities.measurement_processes["DensityMatrixMP"] = [] + + def apply(self, operations, **kwargs): + """Unused""" + raise RuntimeError("Only C/C++ interface is defined") + + @staticmethod + def get_c_interface(): + """Returns a tuple consisting of the device name, and + the location to the shared object with the C/C++ device implementation. + """ + system_extension = ".dylib" if platform.system() == "Darwin" else ".so" + lib_path = ( + get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/librtd_null_qubit" + system_extension + ) + return "NullQubit", lib_path + + def execute(self, circuits, execution_config): + """Execution.""" + return circuits, execution_config + + class TestDensityMatrixMP: """Tests for density_matrix""" - def test_error(self, backend): + def test_error(self): """Test that tracing density matrix produces an error""" - err_msg = "Measurement .* is not implemented" - with pytest.raises(NotImplementedError, match=err_msg): + err_msg = "DensityMatrixMP is not a supported measurement process" + with pytest.raises(CompileError, match=err_msg): @qml.qjit - @qml.qnode(qml.device(backend, wires=1)) + @qml.qnode(CustomDevice(wires=1)) def circuit(): return qml.density_matrix([0]) diff --git a/frontend/test/pytest/test_measurements_shots_results.py b/frontend/test/pytest/test_measurements_shots_results.py index 241242c1fd..31d79c19a3 100644 --- a/frontend/test/pytest/test_measurements_shots_results.py +++ b/frontend/test/pytest/test_measurements_shots_results.py @@ -41,7 +41,7 @@ def circuit(): return qml.expval(qml.Identity(wires=0)), qml.expval(qml.Identity(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliz(self, backend, tol_stochastic): @@ -61,7 +61,7 @@ def circuit(): return qml.expval(qml.PauliZ(wires=0)), qml.expval(qml.PauliZ(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_paulix(self, backend, tol_stochastic): @@ -81,7 +81,7 @@ def 
circuit(): return qml.expval(qml.PauliX(wires=0)), qml.expval(qml.PauliX(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliy(self, backend, tol_stochastic): @@ -101,7 +101,7 @@ def circuit(): return qml.expval(qml.PauliY(wires=0)), qml.expval(qml.PauliY(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_hadamard(self, backend, tol_stochastic): @@ -121,14 +121,16 @@ def circuit(): return qml.expval(qml.Hadamard(wires=0)), qml.expval(qml.Hadamard(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) - def test_hermitian(self, backend): + def test_hermitian(self, backend, tol_stochastic): """Test expval Hermitian observables with shots.""" + n_wires = 3 + n_shots = 10000 + dev = qml.device(backend, wires=n_wires, shots=n_shots) - @qjit - @qml.qnode(qml.device(backend, wires=3, shots=10000)) + @qml.qnode(dev) def circuit(x, y): qml.RX(x, wires=0) qml.RX(y, wires=1) @@ -136,13 +138,12 @@ def circuit(x, y): A = np.array( [[complex(1.0, 0.0), complex(2.0, 0.0)], [complex(2.0, 0.0), complex(1.0, 0.0)]] ) - return qml.expval(qml.Hermitian(A, wires=2) + qml.PauliX(0) + qml.Hermitian(A, wires=1)) + return qml.expval(qml.Hermitian(A, wires=2)) - with pytest.raises( - RuntimeError, - match="Hermitian observables with shot measurement are not supported", - ): - circuit(np.pi / 4, np.pi / 4) + expected = circuit(np.pi / 4, np.pi / 4) + result = qjit(circuit, seed=37)(np.pi / 4, np.pi / 4) + + assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_paulix_pauliy(self, backend, tol_stochastic): """Test that a tensor product involving PauliX and PauliY works correctly""" @@ -164,7 +165,7 @@ def circuit(): return qml.expval(qml.PauliX(wires=0) @ qml.PauliY(wires=2)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliz_pauliy_prod(self, backend, tol_stochastic): @@ -183,7 +184,7 @@ def circuit(theta, phi, varphi): return qml.expval(qml.PauliX(2) @ qml.PauliY(1) @ qml.PauliZ(0)) expected = circuit(0.432, 0.123, -0.543) - result = qjit(circuit)(0.432, 0.123, -0.543) + result = qjit(circuit, seed=37)(0.432, 0.123, -0.543) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliz_hamiltonian(self, backend, tol_stochastic): @@ -204,7 +205,7 @@ def circuit(theta, phi, varphi): ) expected = circuit(0.432, 0.123, -0.543) - result = qjit(circuit)(0.432, 0.123, -0.543) + result = qjit(circuit, seed=37)(0.432, 0.123, -0.543) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_prod_hamiltonian(self, backend, tol_stochastic): @@ -225,7 +226,7 @@ def circuit(theta, phi, varphi): ) expected = circuit(0.432, 0.123, -0.543) - result = qjit(circuit)(0.432, 0.123, -0.543) + result = qjit(circuit, seed=37)(0.432, 0.123, -0.543) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) @@ -249,7 +250,7 @@ def circuit(): return qml.var(qml.Identity(wires=0)), qml.var(qml.Identity(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, 
expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliz(self, backend, tol_stochastic): @@ -269,7 +270,7 @@ def circuit(): return qml.var(qml.PauliZ(wires=0)), qml.var(qml.PauliZ(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_paulix(self, backend, tol_stochastic): @@ -289,7 +290,7 @@ def circuit(): return qml.var(qml.PauliX(wires=0)), qml.var(qml.PauliX(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliy(self, backend, tol_stochastic): @@ -309,7 +310,7 @@ def circuit(): return qml.var(qml.PauliY(wires=0)), qml.var(qml.PauliY(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_hadamard(self, backend, tol_stochastic): @@ -329,14 +330,16 @@ def circuit(): return qml.var(qml.Hadamard(wires=0)), qml.var(qml.Hadamard(wires=1)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) - def test_hermitian_shots(self, backend): + def test_hermitian_shots(self, backend, tol_stochastic): """Test var Hermitian observables with shots.""" + n_wires = 3 + n_shots = 10000 + dev = qml.device(backend, wires=n_wires, shots=n_shots) - @qjit - @qml.qnode(qml.device(backend, wires=3, shots=10000)) + @qml.qnode(dev) def circuit(x, y): qml.RX(x, wires=0) qml.RX(y, wires=1) @@ -344,13 +347,12 @@ def circuit(x, y): A = np.array( [[complex(1.0, 0.0), complex(2.0, 0.0)], [complex(2.0, 0.0), complex(1.0, 0.0)]] ) - return qml.var(qml.Hermitian(A, wires=2) + qml.PauliX(0) + qml.Hermitian(A, wires=1)) + return qml.var(qml.Hermitian(A, wires=2)) - with pytest.raises( - RuntimeError, - match="Hermitian observables with shot measurement are not supported", - ): - circuit(np.pi / 4, np.pi / 4) + expected = circuit(np.pi / 4, np.pi / 4) + result = qjit(circuit, seed=37)(np.pi / 4, np.pi / 4) + + assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_paulix_pauliy(self, backend, tol_stochastic): """Test that a tensor product involving PauliX and PauliY works correctly""" @@ -372,7 +374,7 @@ def circuit(): return qml.var(qml.PauliX(wires=0) @ qml.PauliY(wires=2)) expected = circuit() - result = qjit(circuit)() + result = qjit(circuit, seed=37)() assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_hadamard_pauliy_prod(self, backend, tol_stochastic): @@ -391,7 +393,7 @@ def circuit(theta, phi, varphi): return qml.var(qml.Hadamard(wires=1) @ qml.PauliY(wires=2)) expected = circuit(0.432, 0.123, -0.543) - result = qjit(circuit)(0.432, 0.123, -0.543) + result = qjit(circuit, seed=37)(0.432, 0.123, -0.543) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliz_pauliy_prod(self, backend, tol_stochastic): @@ -410,7 +412,7 @@ def circuit(theta, phi, varphi): return qml.var(qml.PauliX(2) @ qml.PauliY(1) @ qml.PauliZ(0)) expected = circuit(0.432, 0.123, -0.543) - result = qjit(circuit)(0.432, 0.123, -0.543) + result = qjit(circuit, seed=37)(0.432, 0.123, -0.543) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_pauliz_hamiltonian(self, backend): @@ -460,7 +462,7 @@ def 
circuit(theta): return qml.probs() expected = circuit(0.432) - result = qjit(circuit)(0.432) + result = qjit(circuit, seed=37)(0.432) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) def test_probs_wire(self, backend, tol_stochastic): @@ -477,7 +479,7 @@ def circuit(theta): return qml.probs(wires=[0]) expected = circuit(0.432) - result = qjit(circuit)(0.432) + result = qjit(circuit, seed=37)(0.432) assert np.allclose(result, expected, atol=tol_stochastic, rtol=tol_stochastic) diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 59970cd028..769215f38f 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -189,7 +189,7 @@ def classical_func(): TypeError, match="A QNode is expected, got the classical function", ): - pipeline()(classical_func) + pipeline({})(classical_func) with pytest.raises( TypeError, diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index 00f13bf0aa..6e59184d1c 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -13,18 +13,15 @@ # limitations under the License. """Test for the device preprocessing. """ -import pathlib import platform from dataclasses import replace -from os.path import join -from tempfile import TemporaryDirectory -from textwrap import dedent import numpy as np import pennylane as qml import pytest from conftest import CONFIG_CUSTOM_DEVICE from pennylane.devices import Device, NullQubit +from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties from pennylane.tape import QuantumScript from catalyst import CompileError, ctrl @@ -38,41 +35,13 @@ ) from catalyst.api_extensions.quantum_operators import HybridAdjoint, adjoint from catalyst.compiler import get_lib_path -from catalyst.device import get_device_capabilities from catalyst.device.decomposition import catalyst_decompose, decompose_ops_to_unitary from catalyst.jax_tracer import HybridOpRegion from catalyst.tracing.contexts import EvaluationContext, EvaluationMode -from catalyst.utils.toml import ( - DeviceCapabilities, - OperationProperties, - ProgramFeatures, - TOMLDocument, - load_device_capabilities, - read_toml_file, -) # pylint: disable=unused-argument -def get_test_config(config_text: str) -> TOMLDocument: - """Parse test config into the TOMLDocument structure""" - with TemporaryDirectory() as d: - toml_file = join(d, "test.toml") - with open(toml_file, "w", encoding="utf-8") as f: - f.write(config_text) - config = read_toml_file(toml_file) - return config - - -def get_test_device_capabilities( - program_features: ProgramFeatures, config_text: str -) -> DeviceCapabilities: - """Parse test config into the DeviceCapabilities structure""" - config = get_test_config(config_text) - device_capabilities = load_device_capabilities(config, program_features) - return device_capabilities - - class OtherHadamard(qml.Hadamard): """A version of the Hadamard operator that won't be recognized by the QJit device, and will need to be decomposed""" @@ -110,15 +79,17 @@ def decomposition(self): class CustomDevice(Device): """A dummy device from the device API.""" - config = CONFIG_CUSTOM_DEVICE + config_filepath = CONFIG_CUSTOM_DEVICE + + _to_matrix_ops = { + "DiagonalQubitUnitary": OperatorProperties(), + "BlockEncode": OperatorProperties(), + } def __init__(self, wires, shots=1024): - 
print(pathlib.Path(__file__).parent.parent.parent.parent) super().__init__(wires=wires, shots=shots) - dummy_capabilities = get_device_capabilities(self) - dummy_capabilities.native_ops.pop("BlockEncode") - dummy_capabilities.to_matrix_ops["BlockEncode"] = OperationProperties(False, False, False) - self.qjit_capabilities = dummy_capabilities + self.capabilities.operations.pop("BlockEncode") + self.qjit_capabilities = self.capabilities @staticmethod def get_c_interface(): @@ -235,36 +206,30 @@ def f(): (cond_op, Cond, 2), ] -capabilities = get_test_device_capabilities( - ProgramFeatures(False), - dedent( - """ - schema = 2 - [operators.gates.native] - PauliX = { } - PauliZ = { } - RX = { } - RY = { } - RZ = { } - CNOT = { } - HybridAdjoint = { } - ForLoop = { } - WhileLoop = { } - Cond = { } - QubitUnitary = { } - - [operators.gates.matrix] - S = { } - """ - ), -) +TEST_DEVICE_CONFIG_TEXT = """ +schema = 3 +[operators.gates] +PauliX = { } +PauliZ = { } +RX = { } +RY = { } +RZ = { } +CNOT = { } +HybridAdjoint = { } +ForLoop = { } +WhileLoop = { } +Cond = { } +QubitUnitary = { } +""" class TestPreprocessHybridOp: """Test that the operators on the tapes nested inside HybridOps are also decomposed""" + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) @pytest.mark.parametrize("op, op_class, num_regions", HYBRID_OPS) - def test_hybrid_op_decomposition(self, op, op_class, num_regions): + def test_hybrid_op_decomposition(self, op, op_class, num_regions, request): """Tests that for a tape containing a HybridOp that contains unsupported Operators, the unsupported Operators are decomposed""" @@ -272,6 +237,9 @@ def test_hybrid_op_decomposition(self, op, op_class, num_regions): for region in op.regions: region.trace = None + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + # create and decompose the tape tape = QuantumScript([op, qml.X(0), qml.Hadamard(3)]) with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: @@ -287,7 +255,7 @@ def test_hybrid_op_decomposition(self, op, op_class, num_regions): # the HybridOp on the original tape is unmodified, i.e. continues to contain ops # not in `expected_ops`. 
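The updated preprocessing tests construct capabilities with `DeviceCapabilities.from_toml_file` and attach `to_matrix_ops` by hand. A minimal sketch of that construction, assuming the same minimal schema-3 document shape as `TEST_DEVICE_CONFIG_TEXT` (the temporary file name is illustrative):

import os
import tempfile

from pennylane.devices.capabilities import DeviceCapabilities, OperatorProperties

CONFIG_TEXT = """
schema = 3
[operators.gates]
PauliX = { }
RZ = { }
CNOT = { }
QubitUnitary = { }
"""

with tempfile.TemporaryDirectory() as d:
    toml_file = os.path.join(d, "device.toml")
    with open(toml_file, "w", encoding="utf-8") as f:
        f.write(CONFIG_TEXT)
    capabilities = DeviceCapabilities.from_toml_file(toml_file)

# Operations listed in `to_matrix_ops` are lowered to QubitUnitary by
# catalyst_decompose rather than decomposed gate-by-gate.
capabilities.to_matrix_ops = {"S": OperatorProperties()}

assert "CNOT" in capabilities.operations
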
The post-decomposition HybridOp tape does not - expected_ops = capabilities.native_ops + expected_ops = capabilities.operations for i in range(num_regions): if old_op.regions[i].quantum_tape: assert not np.all( @@ -433,7 +401,9 @@ def loop_rx(x): assert np.isclose(res, expected_res) assert final_phi > 2.0 - def test_decomposition_of_nested_HybridOp(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_decomposition_of_nested_HybridOp(self, request): """Tests that HybridOps with HybridOps nested inside them are still decomposed correctly""" # make a weird nested op @@ -455,6 +425,9 @@ def test_decomposition_of_nested_HybridOp(self): region.trace = None for_loop_op.regions[0].trace = None + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + # do the decomposition and get the new tape with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: (new_tape,), _ = catalyst_decompose(tape, ctx, capabilities) @@ -478,7 +451,7 @@ def test_decomposition_of_nested_HybridOp(self): ) # unsupported ops in the subtape decomposed (original tapes contained Hadamard) for subtape in cond_subtapes: - assert np.all([op.name in capabilities.native_ops for op in subtape.operations]) + assert np.all([op.name in capabilities.operations for op in subtape.operations]) assert "Hadamard" not in [op.name for op in subtape.operations] assert "RZ" in [op.name for op in subtape.operations] @@ -489,11 +462,16 @@ def test_decomposition_of_nested_HybridOp(self): assert "Hadamard" not in [op.name for op in adj_subtape.operations] assert "RZ" in [op.name for op in adj_subtape.operations] - def test_controlled_decomposes_to_unitary_listed(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_controlled_decomposes_to_unitary_listed(self, request): """Test that a PennyLane toml-listed operation is decomposed to a QubitUnitary""" tape = qml.tape.QuantumScript([qml.PauliX(0), qml.S(0)]) + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: (new_tape,), _ = catalyst_decompose(tape, ctx, capabilities) @@ -501,11 +479,16 @@ def test_controlled_decomposes_to_unitary_listed(self): assert isinstance(new_tape.operations[0], qml.PauliX) assert isinstance(new_tape.operations[1], qml.QubitUnitary) - def test_controlled_decomposes_to_unitary_controlled(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_controlled_decomposes_to_unitary_controlled(self, request): """Test that a PennyLane controlled operation is decomposed to a QubitUnitary""" tape = qml.tape.QuantumScript([qml.ctrl(qml.RX(1.23, 0), 1)]) + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: (new_tape,), _ = catalyst_decompose(tape, ctx, capabilities) @@ -515,7 +498,9 @@ def test_controlled_decomposes_to_unitary_controlled(self): assert isinstance(new_op, qml.QubitUnitary) assert np.allclose(new_op.matrix(), 
tape.operations[0].matrix()) - def test_error_for_pennylane_midmeasure_decompose(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_error_for_pennylane_midmeasure_decompose(self, request): """Test that an error is raised in decompose if a PennyLane mid-circuit measurement is encountered""" @@ -526,13 +511,18 @@ def test_error_for_pennylane_midmeasure_decompose(self): ops, measurements = qml.queuing.process_queue(q) tape = qml.tape.QuantumScript(ops, measurements) + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + with pytest.raises( CompileError, match="Must use 'measure' from Catalyst instead of PennyLane." ): with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: _ = catalyst_decompose(tape, ctx, capabilities) - def test_error_for_pennylane_midmeasure_decompose_nested(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_error_for_pennylane_midmeasure_decompose_nested(self, request): """Test that an error is raised in decompose if a PennyLane mid-circuit measurement is encountered""" @@ -549,26 +539,36 @@ def test_error_for_pennylane_midmeasure_decompose_nested(self): tape = qml.tape.QuantumScript([adjoint_op, qml.Y(1)], []) + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + with pytest.raises( CompileError, match="Must use 'measure' from Catalyst instead of PennyLane." ): with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: _ = catalyst_decompose(tape, ctx, capabilities) - def test_unsupported_op_with_no_decomposition_raises_error(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_unsupported_op_with_no_decomposition_raises_error(self, request): """Test that an unsupported operator that doesn't provide a decomposition raises a CompileError""" tape = qml.tape.QuantumScript([qml.Y(0)]) + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + with pytest.raises( CompileError, match="not supported with catalyst on this device and does not provide a decomposition", ): with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: - _ = catalyst_decompose(tape, ctx, replace(capabilities, native_ops={})) + _ = catalyst_decompose(tape, ctx, replace(capabilities, operations={})) - def test_decompose_to_matrix_raises_error(self): + @pytest.mark.usefixtures("create_temporary_toml_file") + @pytest.mark.parametrize("create_temporary_toml_file", [TEST_DEVICE_CONFIG_TEXT], indirect=True) + def test_decompose_to_matrix_raises_error(self, request): """Test that _decompose_to_matrix raises a CompileError if the operator has no matrix""" class NoMatrixMultiControlledX(qml.MultiControlledX): @@ -580,12 +580,15 @@ def matrix(self): tape = qml.tape.QuantumScript([NoMatrixMultiControlledX(wires=[0, 1, 2, 3])]) + capabilities = DeviceCapabilities.from_toml_file(request.node.toml_file) + setattr(capabilities, "to_matrix_ops", {"S": OperatorProperties()}) + with pytest.raises(CompileError, match="could not be decomposed, it might be unsupported"): 
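The mid-circuit-measurement tests above pin the "Must use 'measure' from Catalyst instead of PennyLane." error. For reference, a small sketch of the supported pattern, using Catalyst's `measure` inside a compiled QNode (the device and angle are illustrative):

import pennylane as qml
from catalyst import measure, qjit

dev = qml.device("lightning.qubit", wires=2)

@qjit
@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    m = measure(wires=0)  # Catalyst mid-circuit measurement, not qml.measure
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

circuit(0.3)
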
with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: _ = catalyst_decompose( tape, ctx, - replace(capabilities, native_ops={"QubitUnitary": OperationProperties()}), + replace(capabilities, operations={"QubitUnitary": OperatorProperties()}), ) diff --git a/frontend/test/pytest/test_seeded_qjit.py b/frontend/test/pytest/test_seeded_qjit.py index c516398b1c..d23d21e9ed 100644 --- a/frontend/test/pytest/test_seeded_qjit.py +++ b/frontend/test/pytest/test_seeded_qjit.py @@ -47,8 +47,6 @@ def _(): "seed", [ 42, - 37, - 1337, 2**32 - 1, 0, ], @@ -91,7 +89,7 @@ def cfun0(): return circuit(), circuit(), circuit(), circuit() # Calls to qjits with the same seed should return the same results - for _ in range(5): + for _ in range(3): results0 = workflow() results1 = workflow() results2 = workflow1() @@ -103,8 +101,6 @@ def cfun0(): "seed", [ 42, - 37, - 1337, 2**32 - 1, 0, ], @@ -142,7 +138,150 @@ def circuit(): return circuit(), circuit(), circuit(), circuit() # Calls to qjits with the same seed should return the same samples - for _ in range(5): + for _ in range(3): + results0 = workflow() + results1 = workflow() + results2 = workflow1() + assert np.allclose(results0, results1) + assert np.allclose(results0, results2) + + +@pytest.mark.parametrize( + "seed", + [ + 42, + 2**32 - 1, + 0, + ], +) +@pytest.mark.parametrize("shots", [10]) +def test_seeded_probs(seed, shots, backend): + """Test that different calls to qjits with the same seed produce the same probs results""" + + if backend not in ["lightning.qubit", "lightning.kokkos", "lightning.gpu"]: + pytest.skip( + "Probs seeding is only supported on lightning.qubit, lightning.kokkos and lightning.gpu" + ) + + dev = qml.device(backend, wires=2, shots=shots) + + @qjit(seed=seed) + def workflow(): + @qml.qnode(dev) + def circuit(): + qml.RY(0.7, wires=0) + qml.PauliZ(0) + return qml.probs() + + return circuit(), circuit(), circuit(), circuit() + + @qjit(seed=seed) + def workflow1(): + @qml.qnode(dev) + def circuit(): + qml.RY(0.7, wires=0) + qml.PauliZ(0) + return qml.probs() + + return circuit(), circuit(), circuit(), circuit() + + # Calls to qjits with the same seed should return the same samples + for _ in range(3): + results0 = workflow() + results1 = workflow() + results2 = workflow1() + assert np.allclose(results0, results1) + assert np.allclose(results0, results2) + + +@pytest.mark.parametrize( + "seed", + [ + 42, + 2**32 - 1, + 0, + ], +) +@pytest.mark.parametrize("shots", [10]) +def test_seeded_expval(seed, shots, backend): + """Test that different calls to qjits with the same seed produce the same expval results""" + + if backend not in ["lightning.qubit", "lightning.kokkos", "lightning.gpu"]: + pytest.skip( + "Expval seeding is only supported on lightning.qubit, lightning.kokkos" + "and lightning.gpu" + ) + + dev = qml.device(backend, wires=2, shots=shots) + + @qjit(seed=seed) + def workflow(): + @qml.qnode(dev) + def circuit(): + qml.RY(0.7, wires=0) + return qml.expval(qml.PauliZ(0)) + + return circuit(), circuit(), circuit(), circuit() + + @qjit(seed=seed) + def workflow1(): + @qml.qnode(dev) + def circuit(): + qml.RY(0.7, wires=0) + qml.PauliZ(0) + return qml.expval(qml.PauliZ(0)) + + return circuit(), circuit(), circuit(), circuit() + + # Calls to qjits with the same seed should return the same samples + for _ in range(3): + results0 = workflow() + results1 = workflow() + results2 = workflow1() + assert np.allclose(results0, results1) + assert np.allclose(results0, results2) + + +@pytest.mark.parametrize( + "seed", + [ + 
42, + 2**32 - 1, + 0, + ], +) +@pytest.mark.parametrize("shots", [10]) +def test_seeded_var(seed, shots, backend): + """Test that different calls to qjits with the same seed produce the same var results""" + + if backend not in ["lightning.qubit", "lightning.kokkos", "lightning.gpu"]: + pytest.skip( + "Var seeding is only supported on lightning.qubit, lightning.kokkos and lightning.gpu" + ) + + dev = qml.device(backend, wires=2, shots=shots) + + @qjit(seed=seed) + def workflow(): + @qml.qnode(dev) + def circuit(): + qml.RY(0.7, wires=0) + return qml.var(qml.PauliZ(0)) + + return circuit(), circuit(), circuit(), circuit() + + @qjit(seed=seed) + def workflow1(): + @qml.qnode(dev) + def circuit(): + qml.RY(0.7, wires=0) + qml.PauliZ(0) + return qml.var(qml.PauliZ(0)) + + return circuit(), circuit(), circuit(), circuit() + + # Calls to qjits with the same seed should return the same samples + for _ in range(3): results0 = workflow() results1 = workflow() results2 = workflow1() diff --git a/frontend/test/pytest/test_shotvector.py b/frontend/test/pytest/test_shotvector.py index cc8afa3a69..d982b2acde 100644 --- a/frontend/test/pytest/test_shotvector.py +++ b/frontend/test/pytest/test_shotvector.py @@ -94,7 +94,7 @@ def test_shot_vector_with_different_measurement(self): with pytest.raises( NotImplementedError, match=re.escape( - "Measurement expval is not supported a shot-vector. Use qml.sample() instead." + "Measurement ExpectationMP is not supported a shot-vector. Use qml.sample() instead." ), ): @@ -109,7 +109,7 @@ def circuit(): with pytest.raises( NotImplementedError, match=re.escape( - "Measurement var is not supported a shot-vector. Use qml.sample() instead." + "Measurement VarianceMP is not supported a shot-vector. Use qml.sample() instead." ), ): @@ -124,7 +124,7 @@ def circuit(): with pytest.raises( NotImplementedError, match=re.escape( - "Measurement probs is not supported a shot-vector. Use qml.sample() instead." + "Measurement ProbabilityMP is not supported a shot-vector. Use qml.sample() instead." ), ): diff --git a/frontend/test/pytest/test_toml_utils.py b/frontend/test/pytest/test_toml_utils.py new file mode 100644 index 0000000000..657015a124 --- /dev/null +++ b/frontend/test/pytest/test_toml_utils.py @@ -0,0 +1,256 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the toml_utils module. +""" + +import io +import math + +import pytest + +from catalyst.utils.toml_utils import load_toml, safe_eval + + +class TestLoadToml: + """Test suite for the load_toml function. + + Note that we do not test loading TOML documents from files to avoid extraneous TOML test files + and disk read operations. 
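The tests below exercise `load_toml` and `safe_eval` from the new `catalyst.utils.toml_utils` module. A short sketch of how the two compose, assuming a TOML field whose value is stored as an arithmetic expression string (the `[beam]` table and its field names are illustrative):

import math

from catalyst.utils.toml_utils import load_toml, safe_eval

document = load_toml('[beam]\nrabi = "math.pi / 2"\ndetuning = "1.1e6"')

rabi = safe_eval(document["beam"]["rabi"])
detuning = safe_eval(document["beam"]["detuning"])

assert math.isclose(rabi, math.pi / 2)
assert math.isclose(detuning, 1.1e6)
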
+ """ + + def test_load_toml_from_StringIO(self): + """Test load_toml on a valid TOML document loaded as a StringIO.""" + with io.StringIO("[test]\nkey = 1") as f: + assert load_toml(f) == {"test": {"key": 1}} + + def test_load_toml_from_string(self): + """Test load_toml on a valid TOML document string.""" + assert load_toml("[test]\nkey = 1") == {"test": {"key": 1}} + + def test_load_toml_invalid(self): + """Test load_toml on an invalid TOML file input.""" + with pytest.raises( + TypeError, match="Input must be a string, io.StringIO, or a path-like object." + ): + load_toml(1) + + +class TestSafeEval: + """Test suite for the safe_eval function""" + + def test_safe_eval_single_value(self): + """Test safe_eval on expressions containing single numeric values of various types.""" + assert safe_eval("1") == 1 + assert type(safe_eval("1")) == int + + assert math.isclose(safe_eval("1.0"), 1.0) + assert type(safe_eval("1.0")) == float + + assert safe_eval("+1") == 1 + assert type(safe_eval("+1")) == int + + assert math.isclose(safe_eval("+1.0"), 1.0) + assert type(safe_eval("+1.0")) == float + + assert safe_eval("-1") == -1 + assert type(safe_eval("-1")) == int + + assert math.isclose(safe_eval("-1.0"), -1.0) + assert type(safe_eval("-1.0")) == float + + assert math.isclose(safe_eval("1e3"), 1e3) + assert math.isclose(safe_eval("1e-3"), 1e-3) + + def test_safe_eval_addition(self): + """Test safe_eval on expressions containing addition operations.""" + assert safe_eval("1 + 1") == 2 + assert type(safe_eval("1 + 1")) == int + + assert math.isclose(safe_eval("1.0 + 1.0"), 2.0) + assert type(safe_eval("1.0 + 1.0")) == float + + def test_safe_eval_subtraction(self): + """Test safe_eval on expressions containing subtraction operations.""" + assert safe_eval("2 - 1") == 1 + assert type(safe_eval("2 - 1")) == int + + assert math.isclose(safe_eval("2.0 - 1.0"), 1.0) + assert type(safe_eval("2.0 - 1.0")) == float + + assert safe_eval("1 - 2") == -1 + assert math.isclose(safe_eval("1.0 - 2.0"), -1.0) + + def test_safe_eval_multiplication(self): + """Test safe_eval on expressions containing multiplication operations.""" + assert safe_eval("2 * 2") == 4 + assert type(safe_eval("2 * 2")) == int + + assert math.isclose(safe_eval("2.0 * 2.0"), 4.0) + assert type(safe_eval("2.0 * 2.0")) == float + + assert safe_eval("2 * -2") == -4 + assert math.isclose(safe_eval("2.0 * -2.0"), -4.0) + + def test_safe_eval_division(self): + """Test safe_eval on expressions containing division operations.""" + assert math.isclose(safe_eval("4 / 2"), 2.0) + assert type(safe_eval("4 / 2")) == float + + assert math.isclose(safe_eval("4.0 / 2.0"), 2.0) + assert type(safe_eval("4.0 / 2.0")) == float + + assert math.isclose(safe_eval("4.0 / -2.0"), -2.0) + assert math.isclose(safe_eval("-4.0 / 2.0"), -2.0) + + assert math.isclose(safe_eval("1.0 / 2.0"), 0.5) + + def test_safe_eval_exponentiation(self): + """Test safe_eval on expressions containing exponentiation operations.""" + assert math.isclose(safe_eval("2 ** 2"), 4) + assert type(safe_eval("2 ** 2")) == int + + assert math.isclose(safe_eval("2.0 ** 2"), 4.0) + assert type(safe_eval("2.0 ** 2")) == float + + def test_safe_eval_math_constants(self): + """Test safe_eval on expressions containing constants from the math module.""" + assert math.isclose(safe_eval("math.pi"), math.pi) + assert math.isclose(safe_eval("math.e"), math.e) + assert math.isinf(safe_eval("math.inf")) + assert math.isnan(safe_eval("math.nan")) + + def test_safe_eval_math_functions(self): + """Test safe_eval on 
expressions containing functions from the math module.""" + assert math.isclose(safe_eval("math.sin(0.5)"), math.sin(0.5)) + assert math.isclose(safe_eval("math.cos(0.5)"), math.cos(0.5)) + assert math.isclose(safe_eval("math.tan(0.5)"), math.tan(0.5)) + assert math.isclose(safe_eval("math.asin(0.5)"), math.asin(0.5)) + assert math.isclose(safe_eval("math.acos(0.5)"), math.acos(0.5)) + assert math.isclose(safe_eval("math.atan(0.5)"), math.atan(0.5)) + assert math.isclose(safe_eval("math.sinh(0.5)"), math.sinh(0.5)) + assert math.isclose(safe_eval("math.cosh(0.5)"), math.cosh(0.5)) + assert math.isclose(safe_eval("math.tanh(0.5)"), math.tanh(0.5)) + assert math.isclose(safe_eval("math.asinh(0.5)"), math.asinh(0.5)) + assert math.isclose(safe_eval("math.acosh(1.5)"), math.acosh(1.5)) + assert math.isclose(safe_eval("math.atanh(0.5)"), math.atanh(0.5)) + assert math.isclose(safe_eval("math.log(0.5)"), math.log(0.5)) + assert math.isclose(safe_eval("math.log10(0.5)"), math.log10(0.5)) + assert math.isclose(safe_eval("math.log2(0.5)"), math.log2(0.5)) + + def test_safe_eval_complex_numbers(self): + """Test safe_eval on expressions containing complex numbers.""" + assert safe_eval("1 + 1j") == 1 + 1j + assert safe_eval("(1 + 1j) * (1 - 1j)") == 2 + 0j + + def test_safe_eval_long_expr(self): + """Test safe_eval on reasonably long, non-trivial expressions.""" + assert math.isclose( + safe_eval("(1.602e-19) ** 2 / (4 * math.pi * 8.854e-12 * 1.054e-34 * 2.998e8)"), + 1 / 137.036, + rel_tol=1e-3, + ) + + def test_safe_eval_invalid(self): + """Test that safe_eval raises a ValueError on invalid expressions.""" + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 +") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 + (2 + 3") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("* 2") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 + x") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("~1") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 + math.pii") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 + math.sin") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 + math.sinn(2)") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("1 + (3).real") + + # The following operations are not supported, but we may want to add support in the future: + @pytest.mark.xfail(reason="safe_eval does not support the modulo operator") + def test_safe_eval_modulo(self): + """Test safe_eval on expressions containing modulo operations.""" + assert safe_eval("2 % 2") == 0 + + @pytest.mark.xfail(reason="safe_eval does not support the floor-division operator") + def test_safe_eval_floor_division(self): + """Test safe_eval on expressions containing floor-division operations.""" + assert safe_eval("4 // 2") == 2 + + @pytest.mark.xfail(reason="safe_eval does not support the left bit-shift operator") + def test_safe_eval_left_bit_shift(self): + """Test safe_eval on expressions containing left bit-shift operations.""" + assert safe_eval("8 << 1") == 16 + + @pytest.mark.xfail(reason="safe_eval does not support the right bit-shift operator") + def test_safe_eval_right_bit_shift(self): + """Test safe_eval on expressions containing right bit-shift operations.""" + assert safe_eval("8 >> 1") == 4 + + @pytest.mark.xfail(reason="safe_eval does not support the bitwise-and operator") + def 
test_safe_eval_bitwise_and(self): + """Test safe_eval on expressions containing bitwise-and operations.""" + assert safe_eval("8 & 1") == 0 + + @pytest.mark.xfail(reason="safe_eval does not support the bitwise-or operator") + def test_safe_eval_bitwise_or(self): + """Test safe_eval on expressions containing bitwise-or operations.""" + assert safe_eval("8 | 1") == 9 + + @pytest.mark.xfail(reason="safe_eval does not support the bitwise-xor operator") + def test_safe_eval_bitwise_xor(self): + """Test safe_eval on expressions containing bitwise-xor operations.""" + assert safe_eval("8 ^ 9") == 1 + + @pytest.mark.xfail(reason="safe_eval does not support the invert operator") + def test_safe_eval_invert(self): + """Test safe_eval on expressions containing invert operations.""" + assert safe_eval("~1") == -2 + + # The following operations should never be supported by safe_eval: + def test_safe_eval_unsupported_operations(self): + """Test that safe_eval raises a ValueError on valid expressions that contain unsupported + operations. + """ + with pytest.raises(ValueError, match="Invalid expression"): + # `eval("os.listdir('/')")` is valid (assuming the os module is accessible), but the + # equivalent `safe_eval` call is not. + safe_eval("os.listdir('/')") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("sys.version") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("x = 42") + + with pytest.raises(ValueError, match="Invalid expression"): + safe_eval("print(42)") + + +if __name__ == "__main__": + pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_verification.py b/frontend/test/pytest/test_verification.py index 31cede054f..6d1a979bae 100644 --- a/frontend/test/pytest/test_verification.py +++ b/frontend/test/pytest/test_verification.py @@ -20,6 +20,7 @@ import pennylane as qml import pytest +from pennylane.devices.capabilities import OperatorProperties from pennylane.measurements import ExpectationMP, VarianceMP from pennylane.ops import Adjoint, Controlled @@ -38,7 +39,6 @@ from catalyst.device import get_device_capabilities from catalyst.device.qjit_device import RUNTIME_OPERATIONS, get_qjit_device_capabilities from catalyst.device.verification import validate_measurements -from catalyst.utils.toml import OperationProperties # pylint: disable = unused-argument, unnecessary-lambda-assignment, unnecessary-lambda @@ -71,15 +71,15 @@ def __init__(self, shots=None, wires=None): lightning_capabilities = get_device_capabilities(lightning_device) custom_capabilities = deepcopy(lightning_capabilities) for gate in native_gates: - custom_capabilities.native_ops[gate] = OperationProperties(True, True, True) + custom_capabilities.operations[gate] = OperatorProperties(True, True, True) for gate in non_differentiable_gates: - custom_capabilities.native_ops[gate].differentiable = False + custom_capabilities.operations[gate].differentiable = False for gate in non_invertible_gates: - custom_capabilities.native_ops[gate].invertible = False + custom_capabilities.operations[gate].invertible = False for gate in non_controllable_gates: - custom_capabilities.native_ops[gate].controllable = False + custom_capabilities.operations[gate].controllable = False for obs in non_differentiable_obs: - custom_capabilities.native_obs[obs].differentiable = False + custom_capabilities.observables[obs].differentiable = False self.qjit_capabilities = custom_capabilities @staticmethod @@ -299,7 +299,7 @@ def f(x: float): return qml.expval(qml.PauliX(0)) 
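The verification tests now flip `OperatorProperties` flags directly on `capabilities.operations` and `capabilities.observables`. A minimal sketch of that pattern on a copied capabilities object, assuming `lightning.qubit` is installed (the particular gate and observable chosen are illustrative):

from copy import deepcopy

import pennylane as qml
from pennylane.devices.capabilities import OperatorProperties

from catalyst.device import get_device_capabilities

dev = qml.device("lightning.qubit", wires=2)
capabilities = deepcopy(get_device_capabilities(dev))

# Declare a gate as supported but neither invertible nor controllable.
capabilities.operations["RX"] = OperatorProperties(
    invertible=False, controllable=False, differentiable=True
)
# Declare an observable as non-differentiable.
capabilities.observables["PauliZ"].differentiable = False
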
runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) - runtime_ops_with_qctrl["HybridCtrl"] = OperationProperties( + runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) @@ -407,7 +407,7 @@ def f(x: float): return qml.expval(qml.PauliX(0)) runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) - runtime_ops_with_qctrl["HybridCtrl"] = OperationProperties( + runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) @@ -450,7 +450,7 @@ def f(x: float): return qml.expval(qml.PauliX(0)) runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) - runtime_ops_with_qctrl["HybridCtrl"] = OperationProperties( + runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) @@ -594,7 +594,7 @@ def test_arithmetic_ops_validation(self, obs, obs_type, backend): # all good validate_measurements(tape, qjit_capabilities, dev.name, dev.shots) - del qjit_capabilities.native_obs[obs_type] + del qjit_capabilities.observables[obs_type] with pytest.raises(CompileError, match="not supported as an observable"): validate_measurements(tape, qjit_capabilities, dev.name, dev.shots) @@ -605,12 +605,8 @@ def test_non_qjit_observables_raise_error(self, backend): dev = qml.device(backend, wires=1) dev_capabilities = get_device_capabilities(dev) - dev_capabilities.native_obs.update( - { - "PauliX2": OperationProperties( - invertible=True, controllable=True, differentiable=True - ) - } + dev_capabilities.observables.update( + {"PauliX2": OperatorProperties(invertible=True, controllable=True, differentiable=True)} ) qjit_capabilities = get_qjit_device_capabilities(dev_capabilities) @@ -664,7 +660,7 @@ def test_validate_measurements_works_on_measurement_processes(self, measurement, tape = qml.tape.QuantumScript([], measurements=[measurement]) qjit_capabilities = get_device_capabilities(dev) - qjit_capabilities.measurement_processes.remove("Expval") + qjit_capabilities.measurement_processes.pop("ExpectationMP") with pytest.raises(CompileError, match=msg): validate_measurements(tape, qjit_capabilities, dev.name, dev.shots) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index d738f5151e..c0b8dfd6cb 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -69,12 +69,6 @@ include_directories(${LLVM_INCLUDE_DIRS} link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) -if(APPLE) - set(CMAKE_CXX_VISIBILITY_PRESET hidden) -else() - set(CMAKE_CXX_VISIBILITY_PRESET protected) -endif() - add_subdirectory(include) add_subdirectory(lib) add_subdirectory(tools) diff --git a/mlir/Makefile b/mlir/Makefile index 31fdd905c3..aebb37294c 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -14,13 +14,14 @@ ENABLE_ASAN?=OFF BUILD_TYPE?=Release TARGET_FILE=$(MK_DIR)/mlir-hlo/mhlo/transforms/CMakeLists.txt PATCH_FILE=$(MK_DIR)/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch +LLVM_EXTERNAL_LIT ?= $(LLVM_BUILD_DIR)/bin/llvm-lit ifeq ($(shell uname), Darwin) DEFAULT_ENABLE_LLD := OFF -SYMBOL_VISIBILITY := hidden +SYMBOL_VISIBILITY := default else DEFAULT_ENABLE_LLD := ON -SYMBOL_VISIBILITY := protected +SYMBOL_VISIBILITY := default endif ENABLE_LLD?=$(DEFAULT_ENABLE_LLD) @@ -49,7 +50,7 @@ help: @echo " format [version=?] 
to apply C++ formatter; use with 'version={version}' to run clang-format-{version} instead of clang-format" .PHONY: all -all: llvm mhlo enzyme dialects +all: llvm mhlo enzyme dialects plugin .PHONY: llvm llvm: @@ -120,6 +121,23 @@ enzyme: cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-19 +.PHONY: plugin +plugin: + [ -f $(MK_DIR)/standalone ] || cp -r $(MK_DIR)/llvm-project/mlir/examples/standalone . + @if patch -p0 --dry-run -N < $(MK_DIR)/patches/test-plugin-with-catalyst.patch > /dev/null 2>&1; then \ + patch -p0 < $(MK_DIR)/patches/test-plugin-with-catalyst.patch; \ + fi + cmake -B standalone/build -G Ninja \ + -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ + -DLLVM_EXTERNAL_LIT=$(LLVM_EXTERNAL_LIT) \ + -DCATALYST_TOOLS_DIR=$(DIALECTS_BUILD_DIR)/bin \ + -DPython3_EXECUTABLE=$(PYTHON) \ + -DPython3_NumPy_INCLUDE_DIRS=$$($(PYTHON) -c "import numpy as np; print(np.get_include())") \ + standalone + cmake --build standalone/build --target check-standalone + mkdir -p $(DIALECTS_BUILD_DIR)/lib + cp standalone/build/lib/StandalonePlugin.* $(DIALECTS_BUILD_DIR)/lib + .PHONY: dialects dialects: @@ -153,8 +171,8 @@ test: @echo "test the Catalyst MLIR dialects test suite" cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects -.PHONY: clean clean-dialects clean-enzyme clean-mhlo -clean: clean-dialects clean-llvm clean-mhlo clean-enzyme +.PHONY: clean clean-dialects clean-enzyme clean-mhlo clean-plugin +clean: clean-dialects clean-llvm clean-mhlo clean-enzyme clean-plugin clean-dialects: @echo "clean catalyst dialect build files" @@ -172,6 +190,11 @@ clean-enzyme: @echo "clean enzyme build files" rm -rf $(ENZYME_BUILD_DIR) +clean-plugin: + @echo "clean plugin" + rm -rf standalone/build + rm -rf $(DIALECTS_BUILD_DIR)/lib/StandalonePlugin.* + .PHONY: format format: ifdef check diff --git a/mlir/include/CAPI/Dialects.h b/mlir/include/CAPI/Dialects.h index c3261ac666..5824feab17 100644 --- a/mlir/include/CAPI/Dialects.h +++ b/mlir/include/CAPI/Dialects.h @@ -24,6 +24,7 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Quantum, quantum); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Gradient, gradient); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Mitigation, mitigation); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Catalyst, catalyst); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Ion, ion); #ifdef __cplusplus } diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index 960dd55d3d..989dd97d9a 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Catalyst) add_subdirectory(Quantum) add_subdirectory(Gradient) +add_subdirectory(Ion) add_subdirectory(Mitigation) add_subdirectory(Test) diff --git a/mlir/include/Ion/CMakeLists.txt b/mlir/include/Ion/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/mlir/include/Ion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/Ion/IR/CMakeLists.txt b/mlir/include/Ion/IR/CMakeLists.txt new file mode 100644 index 0000000000..0d558ed465 --- /dev/null +++ b/mlir/include/Ion/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect(IonOps ion) +add_mlir_interface(IonInterfaces) +add_mlir_doc(IonDialect IonDialect -gen-dialect-doc Ion/) +add_mlir_doc(IonOps IonOps Ion/ -gen-op-doc) +add_mlir_doc(IonInterfaces IonInterfaces Ion/ -gen-op-interface-docs) + +set(LLVM_TARGET_DEFINITIONS IonOps.td) +mlir_tablegen(IonEnums.h.inc -gen-enum-decls) +mlir_tablegen(IonEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(IonAttributes.h.inc -gen-attrdef-decls 
-attrdefs-dialect=ion) +mlir_tablegen(IonAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ion) +add_public_tablegen_target(MLIRIonEnumsIncGen) diff --git a/mlir/include/Ion/IR/IonDialect.h b/mlir/include/Ion/IR/IonDialect.h new file mode 100644 index 0000000000..e1be7d55e5 --- /dev/null +++ b/mlir/include/Ion/IR/IonDialect.h @@ -0,0 +1,33 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" + +#include "Quantum/IR/QuantumDialect.h" + +//===----------------------------------------------------------------------===// +// Ion dialect declarations. +//===----------------------------------------------------------------------===// + +#include "Ion/IR/IonOpsDialect.h.inc" + +//===----------------------------------------------------------------------===// +// Ion type declarations. +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "Ion/IR/IonOpsTypes.h.inc" diff --git a/mlir/include/Ion/IR/IonDialect.td b/mlir/include/Ion/IR/IonDialect.td new file mode 100644 index 0000000000..9943c8598b --- /dev/null +++ b/mlir/include/Ion/IR/IonDialect.td @@ -0,0 +1,64 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ION_DIALECT +#define ION_DIALECT + +include "mlir/IR/OpBase.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// Ion dialect. +//===----------------------------------------------------------------------===// + +def Ion_Dialect : Dialect { + let summary = "A trapped ions dialect with value semantics."; + let description = [{ + The ion dialect extends core MLIR with the necessary types and operations to form + the IR for trapped ions quantum computers. + }]; + + /// This is the namespace of the dialect in MLIR, which is used as a prefix for types and ops. + let name = "ion"; + + /// This is the C++ namespace that the dialect, and all sub-components, get placed in. + let cppNamespace = "::catalyst::ion"; + + /// Use the default type printing/parsing hooks, otherwise we would explicitly define them. + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// Ion dialect types. 
+//===----------------------------------------------------------------------===// + +class Ion_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def IonType : Ion_Type<"Ion", "ion"> { + let summary = "A value-semantic ion."; +} + +//===----------------------------------------------------------------------===// +// Ion dialect base operation. +//===----------------------------------------------------------------------===// + +class Ion_Op traits = []> : + Op; + +#endif // ION_DIALECT diff --git a/mlir/include/Ion/IR/IonInterfaces.h b/mlir/include/Ion/IR/IonInterfaces.h new file mode 100644 index 0000000000..b6c92d0504 --- /dev/null +++ b/mlir/include/Ion/IR/IonInterfaces.h @@ -0,0 +1,23 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/OpDefinition.h" + +//===----------------------------------------------------------------------===// +// Ion interface declarations. +//===----------------------------------------------------------------------===// + +#include "Ion/IR/IonInterfaces.h.inc" diff --git a/mlir/include/Ion/IR/IonInterfaces.td b/mlir/include/Ion/IR/IonInterfaces.td new file mode 100644 index 0000000000..e9d642104d --- /dev/null +++ b/mlir/include/Ion/IR/IonInterfaces.td @@ -0,0 +1,21 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ION_INTERFACES +#define ION_INTERFACES + +include "mlir/IR/OpBase.td" + + +#endif // ION_INTERFACES diff --git a/mlir/include/Ion/IR/IonOps.h b/mlir/include/Ion/IR/IonOps.h new file mode 100644 index 0000000000..2d66e34c80 --- /dev/null +++ b/mlir/include/Ion/IR/IonOps.h @@ -0,0 +1,45 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" + +#include "Ion/IR/IonInterfaces.h" + +//===----------------------------------------------------------------------===// +// Ion trait declarations. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Ion ops declarations. +//===----------------------------------------------------------------------===// + +#include "Ion/IR/IonDialect.h" +// #include "Ion/IR/IonEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "Ion/IR/IonAttributes.h.inc" +#define GET_OP_CLASSES +#include "Ion/IR/IonOps.h.inc" diff --git a/mlir/include/Ion/IR/IonOps.td b/mlir/include/Ion/IR/IonOps.td new file mode 100644 index 0000000000..cc18087501 --- /dev/null +++ b/mlir/include/Ion/IR/IonOps.td @@ -0,0 +1,214 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ION_OPS +#define ION_OPS + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" + +include "Ion/IR/IonDialect.td" +include "Ion/IR/IonInterfaces.td" +include "Quantum/IR/QuantumDialect.td" + +//===----------------------------------------------------------------------===// +// Ion dialect enums. +//===----------------------------------------------------------------------===// + + +//===----------------------------------------------------------------------===// +// Ion dialect traits. +//===----------------------------------------------------------------------===// + + +//===----------------------------------------------------------------------===// +// Ion dialect attributes. 
+//===----------------------------------------------------------------------===// + +class Ion_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +def LevelAttr : Ion_Attr<"Level", "level"> { + let summary = "A class to represent an atomic level."; + + let parameters = (ins + "mlir::IntegerAttr":$principal, + "mlir::FloatAttr":$spin, + "mlir::FloatAttr":$orbital, + "mlir::FloatAttr":$nuclear, + "mlir::FloatAttr":$spin_orbital, + "mlir::FloatAttr":$spin_orbital_nuclear, + "mlir::FloatAttr":$spin_orbital_nuclear_magnetization, + "mlir::FloatAttr":$energy + ); + + + let builders = [ + AttrBuilderWithInferredContext<(ins + "mlir::IntegerAttr":$principal, + "mlir::FloatAttr":$spin, + "mlir::FloatAttr":$orbital, + "mlir::FloatAttr":$nuclear, + "mlir::FloatAttr":$spin_orbital, + "mlir::FloatAttr":$spin_orbital_nuclear, + "mlir::FloatAttr":$spin_orbital_nuclear_magnetization, + "mlir::FloatAttr":$energy), [{ + return $_get(principal.getContext(), principal, spin, orbital, nuclear, spin_orbital, spin_orbital_nuclear, spin_orbital_nuclear_magnetization, energy); + }]> + ]; + + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TransitionAttr : Ion_Attr<"Transition", "transition"> { + let summary = "A class to represent a atomic transition between two levels."; + + let parameters = (ins + "LevelAttr":$level_0, + "LevelAttr":$level_1, + "mlir::FloatAttr":$einstein_a + ); + + let builders = [ + AttrBuilderWithInferredContext<(ins "LevelAttr":$level_0, + "LevelAttr":$level_1, + "mlir::FloatAttr":$einstein_a), [{ + return $_get(einstein_a.getContext(), level_0, level_1, einstein_a); + }]> + ]; + + let assemblyFormat = "`<` struct(params) `>`"; +} + +def BeamAttr : Ion_Attr<"Beam", "beam"> { + let summary = "A class to represent a laser beam."; + + let parameters = (ins + OptionalParameter<"mlir::IntegerAttr">:$transition_index, + "mlir::FloatAttr":$rabi, + "mlir::FloatAttr":$detuning, + "mlir::DenseIntElementsAttr": $polarization, + "mlir::DenseIntElementsAttr": $wavevector + ); + + + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LevelArrayAttr : TypedArrayAttrBase { + let constBuilderCall = ?; +} + + +def TransitionArrayAttr : TypedArrayAttrBase { + let constBuilderCall = ?; +} + +//===----------------------------------------------------------------------===// +// Ion dialect operations. 
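`LevelAttr`, `TransitionAttr`, and `BeamAttr` each carry a fixed set of physical parameters. A hedged Python mirror of those records as plain dataclasses (the class and field types here are illustrative; in the dialect they are stored as MLIR attributes):

from dataclasses import dataclass
from typing import List

@dataclass
class Level:
    principal: int
    spin: float
    orbital: float
    nuclear: float
    spin_orbital: float
    spin_orbital_nuclear: float
    spin_orbital_nuclear_magnetization: float
    energy: float

@dataclass
class Transition:
    level_0: Level
    level_1: Level
    einstein_a: float

@dataclass
class Beam:
    transition_index: int
    rabi: float
    detuning: float
    polarization: List[int]
    wavevector: List[int]
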
+//===----------------------------------------------------------------------===// + +def IonOp : Ion_Op<"ion"> { + let summary = "A class to represent an ion."; + + let arguments = (ins + Builtin_StringAttr:$name, + Builtin_FloatAttr:$mass, + Builtin_FloatAttr:$charge, + AnyIntElementsAttr: $position, + LevelArrayAttr: $levels, + TransitionArrayAttr: $transitions + ); + + let results = (outs + IonType:$out_ion + ); + + let assemblyFormat = [{ + attr-dict `:` type($out_ion) + }]; +} + + +def PulseOp : Ion_Op<"pulse"> { + let summary = "Represent a pulse (a laser beam and some time)."; + + let arguments = (ins + AnyFloat: $time, + QubitType: $in_qubit, + BeamAttr: $beam, + Builtin_FloatAttr: $phase + ); + + let assemblyFormat = [{ + `(` $time `:` type($time) `)` $in_qubit attr-dict + }]; +} + + +def ParallelProtocolOp : Ion_Op<"parallelprotocol", [SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Represent a parallel protocol of pulses."; + + let arguments = (ins + Variadic: $in_qubits + ); + + let results = (outs Variadic:$out_qubits); + let regions = (region SizedRegion<1>:$region); + + let builders = [ + OpBuilder<(ins + CArg<"mlir::ValueRange", "std::nullopt">:$in_qubits, + CArg<"llvm::function_ref", + "nullptr">)> + ]; + + let extraClassDeclaration = [{ + using BodyBuilderFn = + llvm::function_ref; + + }]; + + let assemblyFormat = [{ + `(` $in_qubits `)` attr-dict `:` type($out_qubits) $region + }]; +} + +def YieldOp : Ion_Op<"yield", [Pure, ReturnLike, Terminator, ParentOneOf<["ParallelProtocolOp"]>]> { + let summary = "Return results from parallel protocol regions"; + + let arguments = (ins + Variadic:$results + ); + + let assemblyFormat = [{ + attr-dict ($results^ `:` type($results))? + }]; + + let builders = [ + OpBuilder<(ins), [{ /* nothing to do */ }]> + ]; +} + + +#endif // ION_OPS diff --git a/mlir/include/Ion/Transforms/CMakeLists.txt b/mlir/include/Ion/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..ed8fe77f42 --- /dev/null +++ b/mlir/include/Ion/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name ion) +add_public_tablegen_target(MLIRIonPassIncGen) +add_mlir_doc(Passes IonPasses ./ -gen-pass-doc) diff --git a/mlir/include/Ion/Transforms/Passes.h b/mlir/include/Ion/Transforms/Passes.h new file mode 100644 index 0000000000..46b06e0428 --- /dev/null +++ b/mlir/include/Ion/Transforms/Passes.h @@ -0,0 +1,25 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "mlir/Pass/Pass.h" + +namespace catalyst { + +std::unique_ptr createQuantumToIonPass(); + +} // namespace catalyst diff --git a/mlir/include/Ion/Transforms/Passes.td b/mlir/include/Ion/Transforms/Passes.td new file mode 100644 index 0000000000..5af78d6558 --- /dev/null +++ b/mlir/include/Ion/Transforms/Passes.td @@ -0,0 +1,46 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. 
+ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ION_PASSES +#define ION_PASSES + +include "mlir/Pass/PassBase.td" + +def QuantumToIonPass : Pass<"quantum-to-ion"> { + let summary = "Lower a quantum circuit to an ionic pulse program."; + + let options = [ + Option<"DeviceTomlLoc", "device-toml-loc", + "std::string", /*default=*/"\"\"", + "Toml file location for the ion hardware device parameters.">, + Option<"QubitTomlLoc", "qubit-toml-loc", + "std::string", /*default=*/"\"\"", + "Toml file location for the ion hardware qubit parameters.">, + Option<"Gate2PulseDecompTomlLoc", "gate-to-pulse-toml-loc", + "std::string", /*default=*/"\"\"", + "Toml file location for the ion hardware gate-to-pulse decomposition parameters.">, + Option<"LoadIon", "load-ion", + "bool", /*default=*/"true", + "Whether to load the physical parameters for the ion (e.g. mass, charge, spin) into the IR.">, + ]; + + let dependentDialects = [ + "quantum::QuantumDialect", + "ion::IonDialect" + ]; + + let constructor = "catalyst::createQuantumToIonPass()"; +} + +#endif // ION_PASSES diff --git a/mlir/include/Ion/Transforms/Patterns.h b/mlir/include/Ion/Transforms/Patterns.h new file mode 100644 index 0000000000..676988e555 --- /dev/null +++ b/mlir/include/Ion/Transforms/Patterns.h @@ -0,0 +1,29 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "Ion/Transforms/oqd_database_managers.hpp" + +namespace catalyst { +namespace ion { + +void populateQuantumToIonPatterns(mlir::RewritePatternSet &, const OQDDatabaseManager &); + +} // namespace ion +} // namespace catalyst diff --git a/mlir/include/Ion/Transforms/oqd_database_managers.hpp b/mlir/include/Ion/Transforms/oqd_database_managers.hpp new file mode 100644 index 0000000000..5bbaa53fe0 --- /dev/null +++ b/mlir/include/Ion/Transforms/oqd_database_managers.hpp @@ -0,0 +1,206 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <string>
+#include <vector>
+
+#include <toml++/toml.hpp>
+
+#include "oqd_database_types.hpp"
+#include "oqd_database_utils.hpp"
+
+namespace catalyst {
+namespace ion {
+
+class OQDDatabaseManager {
+  public:
+    OQDDatabaseManager(const std::string &DeviceTomlLoc, const std::string &QubitTomlLoc,
+                       const std::string &Gate2PulseDecompTomlLoc)
+    {
+        sourceTomlDevice = toml::parse_file(DeviceTomlLoc);
+        sourceTomlQubit = toml::parse_file(QubitTomlLoc);
+        sourceTomlGateDecomposition = toml::parse_file(Gate2PulseDecompTomlLoc);
+
+        assert(sourceTomlDevice && "Parsing of device toml failed!");
+        assert(sourceTomlQubit && "Parsing of qubit toml failed!");
+        assert(sourceTomlGateDecomposition && "Parsing of gate decomposition toml failed!");
+
+        loadBeams1Params();
+        loadBeams2Params();
+
+        loadPhononParams();
+
+        loadIonParams();
+    }
+
+    const std::vector<Beam> &getBeams1Params() const { return beams1; }
+    const std::vector<Beam> &getBeams2Params() const { return beams2; }
+
+    const std::vector<PhononMode> &getPhononParams() const { return phonons; }
+
+    const std::map<std::string, Ion> &getIonParams() const { return ions; }
+
+  private:
+    toml::parse_result sourceTomlDevice;
+    toml::parse_result sourceTomlQubit;
+    toml::parse_result sourceTomlGateDecomposition;
+
+    std::vector<Beam> beams1;
+    std::vector<Beam> beams2;
+
+    std::vector<PhononMode> phonons;
+
+    std::map<std::string, Ion> ions;
+
+    void loadBeams1Params() { loadBeamsParamsImpl("beams1"); }
+    void loadBeams2Params() { loadBeamsParamsImpl("beams2"); }
+
+    void loadBeamsParamsImpl(const std::string &mode)
+    {
+        // Read in the gate decomposition beam parameters from the toml file.
+        // The toml contains a list of beams, where each beam has the following fields:
+        //     rabi = 4.4
+        //     detuning = 5.5
+        //     polarization = [6,7]
+        //     wavevector = [8,9]
+
+        toml::node_view<toml::node> beamsToml = sourceTomlGateDecomposition[mode];
+        size_t numBeams = beamsToml.as_array()->size();
+
+        std::vector<Beam> *collector;
+        if (mode == "beams1") {
+            collector = &beams1;
+        }
+        else if (mode == "beams2") {
+            collector = &beams2;
+        }
+        else {
+            assert(false && "Invalid beam mode. Only single-qubit gates and 2-qubit gates are "
+                            "supported for decomposition onto beams.");
+        }
+
+        for (size_t i = 0; i < numBeams; i++) {
+            auto beam = beamsToml[i];
+            double rabi = beam["rabi"].as_floating_point()->get();
+            double detuning = beam["detuning"].as_floating_point()->get();
+            std::vector<int64_t> polarization =
+                tomlArray2StdVector<int64_t>(*(beam["polarization"].as_array()));
+            std::vector<int64_t> wavevector =
+                tomlArray2StdVector<int64_t>(*(beam["wavevector"].as_array()));
+
+            collector->push_back(Beam(rabi, detuning, polarization, wavevector));
+        }
+    }
+
+    void loadPhononParams()
+    {
+        toml::node_view<toml::node> phononsToml = sourceTomlGateDecomposition["phonons"];
+        size_t numPhononModes = phononsToml.as_array()->size();
+
+        auto parseSingleDirection = [](auto direction) {
+            double energy = direction["energy"].as_floating_point()->get();
+            std::vector<double> eigenvector =
+                tomlArray2StdVector<double>(*(direction["eigenvector"].as_array()));
+            return Phonon(energy, eigenvector);
+        };
+
+        for (size_t i = 0; i < numPhononModes; i++) {
+            auto phononMode = phononsToml[i];
+
+            Phonon COM_x = parseSingleDirection(phononMode["COM_x"]);
+            Phonon COM_y = parseSingleDirection(phononMode["COM_y"]);
+            Phonon COM_z = parseSingleDirection(phononMode["COM_z"]);
+
+            phonons.push_back(PhononMode(COM_x, COM_y, COM_z));
+        }
+    }
+
+    void loadIonParams()
+    {
+        toml::node_view<toml::node> ionsToml = sourceTomlQubit["ions"];
+
+        auto parseSingleLevel = [](auto level) {
+            int64_t principal = level["principal"].as_integer()->get();
+
+            std::vector<std::string> properties{"spin",
+                                                "orbital",
+                                                "nuclear",
+                                                "spin_orbital",
+                                                "spin_orbital_nuclear",
+                                                "spin_orbital_nuclear_magnetization",
+                                                "energy"};
+            std::vector<double> propertiesData(properties.size());
+
+            std::transform(properties.begin(), properties.end(), propertiesData.begin(),
+                           [&level](const std::string &name) {
+                               return level[name].as_floating_point()->get();
+                           });
+
+            return Level(principal, propertiesData[0], propertiesData[1], propertiesData[2],
+                         propertiesData[3], propertiesData[4], propertiesData[5],
+                         propertiesData[6]);
+        };
+
+        auto parseSingleTransition = [](const auto &transition_entry,
+                                        const std::vector<Level> &allLevels) {
+            // FIXME: `allLevels` is hardcoded as {downstate, upstate, estate}
+            // Not super important, as the ion species is extremely unlikely to change, so
+            // hardcoding is fine
+
+            double einstein_a = transition_entry["einstein_a"].as_floating_point()->get();
+            std::string level1 = transition_entry["level1"].as_string()->get();
+            std::string level2 = transition_entry["level2"].as_string()->get();
+
+            std::map<std::string, size_t> levelEncodings{
+                {"downstate", 0}, {"upstate", 1}, {"estate", 2}};
+            assert((levelEncodings.count(level1) & levelEncodings.count(level2)) &&
+                   "Only \"downstate\", \"upstate\" and \"estate\" are allowed in the atom's "
+                   "transition levels.");
+
+            return Transition(allLevels[levelEncodings[level1]], allLevels[levelEncodings[level2]],
+                              einstein_a);
+        };
+
+        for (auto &ion_it : *(ionsToml.as_table())) {
+            std::string name(ion_it.first.str());
+            toml::table *data = ion_it.second.as_table();
+
+            double mass = data->at_path("mass").as_floating_point()->get();
+            double charge = data->at_path("charge").as_floating_point()->get();
+
+            std::vector<double> position =
+                tomlArray2StdVector<double>(*(data->at_path("position").as_array()));
+
+            Level downstate = parseSingleLevel(data->at_path("levels")["downstate"]);
+            Level upstate = parseSingleLevel(data->at_path("levels")["upstate"]);
+            Level estate = parseSingleLevel(data->at_path("levels")["estate"]);
+            std::vector<Level> levels{downstate, upstate, estate};
+
+            std::vector<Transition> transitions;
+            auto *transitionsTable = data->at_path("transitions").as_table();
+            for (auto &transition : *transitionsTable) {
+                transitions.push_back(
+                    parseSingleTransition(*(transition.second.as_table()), levels));
+            }
+
+            Ion ion(name, mass, charge, position, levels, transitions);
+            ions.insert({name, ion});
+        }
+    }
+};
+
+} // namespace ion
+} // namespace catalyst
diff --git a/mlir/include/Ion/Transforms/oqd_database_types.hpp b/mlir/include/Ion/Transforms/oqd_database_types.hpp
new file mode 100644
index 0000000000..c5b7af0dbf
--- /dev/null
+++ b/mlir/include/Ion/Transforms/oqd_database_types.hpp
@@ -0,0 +1,108 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace {
+
+//
+// Calibrated parameters
+//
+
+struct Beam {
+    // This struct contains the calibrated beam parameters.
+    double rabi, detuning;
+    std::vector<int64_t> polarization, wavevector;
+
+    Beam(double _rabi, double _detuning, std::vector<int64_t> _polarization,
+         std::vector<int64_t> _wavevector)
+        : rabi(_rabi), detuning(_detuning), polarization(_polarization), wavevector(_wavevector)
+    {
+    }
+};
+
+struct Phonon {
+    // This struct contains the calibrated phonon parameters on one axis.
+    double energy;
+    std::vector<double> eigenvector;
+
+    Phonon(double _energy, std::vector<double> _eigenvector)
+        : energy(_energy), eigenvector(_eigenvector)
+    {
+    }
+};
+
+struct PhononMode {
+    // This struct contains the calibrated phonon parameters for one ion.
+    Phonon COM_x;
+    Phonon COM_y;
+    Phonon COM_z;
+
+    PhononMode(Phonon x, Phonon y, Phonon z) : COM_x(x), COM_y(y), COM_z(z) {}
+};
+
+//
+// Innate atomic parameters
+//
+
+struct Level {
+    // This class represents an atomic level.
+    // It contains the innate properties of the qubit.
+    int64_t principal;
+    double spin, orbital, nuclear, spin_orbital, spin_orbital_nuclear,
+        spin_orbital_nuclear_magnetization, energy;
+
+    Level(int64_t _principal, double _spin, double _orbital, double _nuclear, double _spin_orbital,
+          double _spin_orbital_nuclear, double _spin_orbital_nuclear_magnetization, double _energy)
+        : principal(_principal), spin(_spin), orbital(_orbital), nuclear(_nuclear),
+          spin_orbital(_spin_orbital), spin_orbital_nuclear(_spin_orbital_nuclear),
+          spin_orbital_nuclear_magnetization(_spin_orbital_nuclear_magnetization), energy(_energy)
+    {
+    }
+};
+
+struct Transition {
+    // This class represents a transition between two atomic levels.
+    // It contains the innate properties of the qubit.
+    Level level_0, level_1;
+    double einstein_a;
+
+    Transition(Level _level_0, Level _level_1, double _einstein_a)
+        : level_0(_level_0), level_1(_level_1), einstein_a(_einstein_a)
+    {
+    }
+};
+
+struct Ion {
+    // This class represents an ion.
+    // It contains the innate properties of the qubit.
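+    // Populated by OQDDatabaseManager::loadIonParams from the qubit TOML database:
+    // the atom's mass, charge and position, its levels (downstate, upstate, estate),
+    // and the transitions between those levels.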
+    std::string name;
+    double mass, charge;
+    std::vector<double> position;
+    std::vector<Level> levels;
+    std::vector<Transition> transitions;
+
+    Ion(std::string _name, double _mass, double _charge, std::vector<double> _position,
+        std::vector<Level> _levels, std::vector<Transition> _transitions)
+        : name(_name), mass(_mass), charge(_charge), position(_position), levels(_levels),
+          transitions(_transitions)
+    {
+    }
+};
+
+} // namespace
diff --git a/mlir/include/Ion/Transforms/oqd_database_utils.hpp b/mlir/include/Ion/Transforms/oqd_database_utils.hpp
new file mode 100644
index 0000000000..4491d4e4c4
--- /dev/null
+++ b/mlir/include/Ion/Transforms/oqd_database_utils.hpp
@@ -0,0 +1,48 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+#include <vector>
+
+#include <toml++/toml.hpp>
+
+namespace {
+
+template <typename T> std::vector<T> tomlArray2StdVector(const toml::array &arr)
+{
+    // A toml node can contain toml objects of arbitrary types, even other toml nodes,
+    // i.e. toml nodes are similar to pytrees.
+    // Therefore, toml++ does not provide a simple "toml array to std vector" converter.
+    //
+    // For a "leaf" array node, whose contents are simple values,
+    // such a utility comes in handy.
+
+    std::vector<T> vec;
+
+    if constexpr (std::is_same_v<T, int64_t>) {
+        for (const auto &elem : arr) {
+            vec.push_back(elem.as_integer()->get());
+        }
+    }
+    else if constexpr (std::is_same_v<T, double>) {
+        for (const auto &elem : arr) {
+            vec.push_back(elem.as_floating_point()->get());
+        }
+    }
+
+    return vec;
+}
+} // namespace
diff --git a/mlir/include/Quantum/IR/QuantumInterfaces.td b/mlir/include/Quantum/IR/QuantumInterfaces.td
index 145e47ae46..dc1390c636 100644
--- a/mlir/include/Quantum/IR/QuantumInterfaces.td
+++ b/mlir/include/Quantum/IR/QuantumInterfaces.td
@@ -164,6 +164,23 @@ def QuantumGate : OpInterface<"QuantumGate", [QuantumOperation]> {
     }];
 }
 
+def StaticGate : OpInterface<"StaticGate", [QuantumGate]> {
+    let description = [{
+        This interface provides a generic way to interact with quantum instructions whose
+        parameters are static, i.e. known at compile time. These parameters are specified
+        by a set of constant literals in the form of an array attribute.
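+
+        For example, a lowering pattern can read the compile-time parameters of any
+        `StaticGate` without inspecting SSA operands (a minimal sketch; the surrounding
+        pass context and variable names are assumptions, not part of this interface):
+
+        ```c++
+        // `op` is an mlir::Operation* visited by a pass or rewrite pattern.
+        if (auto staticGate = llvm::dyn_cast<catalyst::quantum::StaticGate>(op)) {
+            auto params = staticGate.getAllParams(); // compile-time gate parameters
+            // ... use `params` to construct the lowered operation ...
+        }
+        ```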
+    }];
+
+    let cppNamespace = "::catalyst::quantum";
+
+    let methods = [
+        InterfaceMethod<
+            "Return the static gate parameters.",
+            "mlir::DenseF64ArrayAttr", "getAllParams"
+        >,
+    ];
+}
+
 def ParametrizedGate : OpInterface<"ParametrizedGate", [QuantumGate]> {
     let description = [{
         This interface provides a generic way to interact with parametrized
diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td
index ba51fd4e31..7c8fa3d4f2 100644
--- a/mlir/include/Quantum/IR/QuantumOps.td
+++ b/mlir/include/Quantum/IR/QuantumOps.td
@@ -83,14 +83,16 @@ def DeviceInitOp : Quantum_Op<"device"> {
     let summary = "Initialize a quantum device.";
 
     let arguments = (ins
+        Optional<I64>:$shots,
         StrAttr:$lib,
         StrAttr:$name,
         StrAttr:$kwargs
     );
 
     let assemblyFormat = [{
-        `[` $lib `,` $name `,` $kwargs `]` attr-dict
+        (`shots` `(` $shots^ `)`)? `[` $lib `,` $name `,` $kwargs `]` attr-dict
     }];
+
 }
 
 def DeviceReleaseOp : Quantum_Op<"device_release"> {
@@ -375,6 +377,74 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect,
         Variadic<QubitType>:$out_ctrl_qubits
     );
 
+    let builders = [
+        OpBuilder<
+            // Convenience builder for a gate with parameters and controls.
+            // Note that number of out_qubits = number of in_qubits,
+            // and number of out_ctrl_qubits = number of in_ctrl_qubits.
+            (ins
+                "llvm::StringRef":$gate,
+                "mlir::ValueRange":$in_qubits,
+                "mlir::ValueRange":$in_ctrl_qubits,
+                "mlir::ValueRange":$in_ctrl_values,
+                "mlir::ValueRange":$params,
+                CArg<"bool", "false">:$adjoint
+            ), [{
+                CustomOp::build($_builder, $_state,
+                    /*out_qubits=*/ mlir::TypeRange(in_qubits),
+                    /*out_ctrl_qubits=*/ mlir::TypeRange(in_ctrl_qubits),
+                    /*params=*/ params,
+                    /*in_qubits=*/ in_qubits,
+                    /*gate_name=*/ $_builder.getStringAttr(gate),
+                    /*(optional) adjoint=*/ nullptr,
+                    /*in_ctrl_qubits=*/ in_ctrl_qubits,
+                    /*in_ctrl_values=*/ in_ctrl_values
+                );
+
+                if (adjoint) {
+                    $_state.addAttribute("adjoint", $_builder.getUnitAttr());
+                }
+            }]>,
+
+        OpBuilder<
+            // Convenience builder for a gate with parameters and no controls.
+            (ins
+                "llvm::StringRef":$gate,
+                "mlir::ValueRange":$in_qubits,
+                "mlir::ValueRange":$params,
+                CArg<"bool", "false">:$adjoint
+            ), [{
+                CustomOp::build($_builder, $_state,
+                    gate, in_qubits, mlir::ValueRange(), mlir::ValueRange(),
+                    params, adjoint);
+            }]>,
+
+        OpBuilder<
+            // Convenience builder for a gate with controls and no parameters.
+            (ins
+                "llvm::StringRef":$gate,
+                "mlir::ValueRange":$in_qubits,
+                "mlir::ValueRange":$in_ctrl_qubits,
+                "mlir::ValueRange":$in_ctrl_values,
+                CArg<"bool", "false">:$adjoint
+            ), [{
+                CustomOp::build($_builder, $_state,
+                    gate, in_qubits, in_ctrl_qubits, in_ctrl_values,
+                    mlir::ValueRange(), adjoint);
+            }]>,
+
+        OpBuilder<
+            // Convenience builder for a gate with no parameters and no controls.
+            (ins
+                "llvm::StringRef":$gate,
+                "mlir::ValueRange":$in_qubits,
+                CArg<"bool", "false">:$adjoint
+            ), [{
+                CustomOp::build($_builder, $_state,
+                    gate, in_qubits, mlir::ValueRange(), adjoint);
+            }]>,
+    ];
+
     let assemblyFormat = [{
         $gate_name `(` $params `)` $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
         ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits)
         (`ctrls` type($out_ctrl_qubits)^ )?
 }];
 
@@ -387,6 +457,45 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect,
     let hasCanonicalizeMethod = 1;
 }
 
+def StaticCustomOp : UnitaryGate_Op<"static_custom", [NoMemoryEffect,
+                                                      AttrSizedOperandSegments,
+                                                      AttrSizedResultSegments]> {
+    let summary = "A generic quantum gate with static parameters in the form of a DenseF64ArrayAttr.";
+    let description = [{
+        This operation represents a quantum gate whose parameters are defined statically as a
+        DenseF64ArrayAttr rather than passed dynamically as operands. This is useful for gates
+        whose parameters are known at compile time.
+    }];
+
+    let arguments = (ins
+        DenseF64ArrayAttr:$static_params,
+        Variadic<QubitType>:$in_qubits,
+        StrAttr:$gate_name,
+        OptionalAttr<UnitAttr>:$adjoint,
+        Variadic<QubitType>:$in_ctrl_qubits,
+        Variadic<I1>:$in_ctrl_values
+    );
+
+    let results = (outs
+        Variadic<QubitType>:$out_qubits,
+        Variadic<QubitType>:$out_ctrl_qubits
+    );
+
+    let assemblyFormat = [{
+        $gate_name $static_params $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )?
+        ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits)
+        (`ctrls` type($out_ctrl_qubits)^ )?
+    }];
+
+    let extraClassDeclaration = extraBaseClassDeclaration # [{
+        llvm::ArrayRef<double> getAllParams() {
+            return getStaticParams();
+        }
+    }];
+    let hasCanonicalizeMethod = 1;
+}
+
+
 def GlobalPhaseOp : UnitaryGate_Op<"gphase", [DifferentiableGate,
                                               AttrSizedOperandSegments]> {
     let summary = "Global Phase.";
     let description = [{
@@ -733,10 +842,11 @@ def SampleOp : Measurement_Op<"sample"> {
     let summary = "Sample eigenvalues from the given observable for the current state";
     let description = [{
         The `quantum.sample` operation represents the measurement process of sampling eigenvalues
-        from an observable on the current quantum state. Given the nature of the operation, an
-        attribute specifying the shot number, i.e. the number of samples to draw, must be specified.
-        The only SSA argument is an observable that must be defined by an operation in the local
-        scope.
+        from an observable on the current quantum state.
+        The only SSA argument is an observable that must be defined by an operation in the
+        local scope.
+        The number of samples to draw is determined by the shots argument of the device
+        initialization operation in the local scope.
 
         Note that the return value type depends on the type of observable provided.
         Computational basis samples are returned as a 2D array of shape (shot number, number of qubits), with all
@@ -745,13 +855,14 @@ def SampleOp : Measurement_Op<"sample"> {
     Example:
 
     ```mlir
-    func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit)
+    func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit, %shots: i64)
     {
+        quantum.device shots(%shots) ["rtd_lightning.so", "lightning.qubit", "{my_attr: my_attr_value}"]
        %obs1 = quantum.compbasis %q0, %q1 : !quantum.obs
-        %samples = quantum.samples %obs1 {shots=1000} : tensor<1000xf64>
+        %samples = quantum.sample %obs1 : tensor<?x2xf64>
 
        %obs2 = quantum.pauli %q0[3], %q1[1] : !quantum.obs
-        %samples2 = quantum.samples %obs2 {shots=1000} : tensor<1000x2xf64>
+        %samples2 = quantum.sample %obs2 : tensor<?xf64>
 
        func.return
     }
@@ -765,8 +876,7 @@ def SampleOp : Measurement_Op<"sample"> {
             MemRefRankOf<[F64], [1]>,
             MemRefRankOf<[F64], [2]>
         ]>
-        >:$in_data,
-        I64Attr:$shots
+        >:$in_data
     );
 
     let results = (outs
@@ -796,10 +906,10 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe
     let description = [{
         The `quantum.counts` operation represents the measurement process of sampling eigenvalues
         from an observable on the current quantum state and counting the frequency of each
-        eigenvalue. Given the nature of the operation, an attribute specifying the shot number,
-        i.e. the number of samples to draw, must be specified.
-        The only SSA argument is an observable that must be defined by an operation in the local
-        scope.
+        eigenvalue.
+        The only SSA argument is an observable that must be defined by an operation in the local scope.
+        The number of samples to draw is determined by the shots argument of the device
+        initialization operation in the local scope.
 
     Note that the "counts dictionary" is returned as two separate arrays of the same length,
     one array for the eigenvalues, and one for count of each eigenvalue. When operating in the
@@ -809,13 +919,14 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe
     Example:
 
    ```mlir
-    func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit)
+    func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit, %shots: i64)
     {
+        quantum.device shots(%shots) ["rtd_lightning.so", "lightning.qubit", "{my_attr: my_attr_value}"]
        %obs = quantum.compbasis %q0, %q1 : !quantum.obs
-        %counts = quantum.counts %obs {shots=1000} : tensor<4xf64>, tensor<4xi64>
+        %counts = quantum.counts %obs : tensor<4xf64>, tensor<4xi64>
 
        %obs2 = quantum.pauli %q0[3], %q1[1] : !quantum.obs
-        %counts2 = quantum.counts %obs2 {shots=1000} : tensor<2xf64>, tensor<2xi64>
+        %counts2 = quantum.counts %obs2 : tensor<2xf64>, tensor<2xi64>
 
        func.return
     }
@@ -825,8 +936,7 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe
     let arguments = (ins
         ObservableType:$obs,
         Optional<MemRefRankOf<[F64], [1]>>:$in_eigvals,
-        Optional<MemRefRankOf<[I64], [1]>>:$in_counts,
-        I64Attr:$shots
+        Optional<MemRefRankOf<[I64], [1]>>:$in_counts
     );
 
     let results = (outs
@@ -854,9 +964,9 @@ def ExpvalOp : Measurement_Op<"expval"> {
     let description = [{
         The `quantum.expval` operation represents the measurement process of computing the
        expectation value of an observable on the current quantum state. While this quantity can
-        be computed analytically on simulators, an optional attribute specifiying the shot number,
-        i.e. the number of samples to draw, can be specified for hardware execution or shot noise
-        simulation.
+        be computed analytically on simulators, hardware execution and shot-noise simulation
+        use the shots attached to the device initialization operation in the local
+        scope.
 
         The only SSA argument is an observable that must be defined by an operation in the local
         scope.
@@ -874,8 +984,7 @@ def ExpvalOp : Measurement_Op<"expval"> {
     }];
 
     let arguments = (ins
-        ObservableType:$obs,
-        OptionalAttr<I64Attr>:$shots
+        ObservableType:$obs
     );
 
     let results = (outs
@@ -891,9 +1000,9 @@ def VarianceOp : Measurement_Op<"var"> {
     let summary = "Compute the variance of the given observable for the current state";
     let description = [{
         The `quantum.var` operation represents the measurement process of computing the variance of
-        an observable on the current quantum state. While this quantity can be computed analytically
-        on simulators, an optional attribute specifiying the shot number, i.e. the number of samples
-        to draw, can be specified for hardware execution or shot noise simulation.
+        an observable on the current quantum state. While this quantity can be computed
+        analytically on simulators, hardware execution and shot-noise simulation use the
+        shots attached to the device initialization operation in the local scope.
 
         The only SSA argument is an observable that must be defined by an operation in the local
         scope.
@@ -911,8 +1020,7 @@ def VarianceOp : Measurement_Op<"var"> {
     }];
 
     let arguments = (ins
-        ObservableType:$obs,
-        OptionalAttr<I64Attr>:$shots
+        ObservableType:$obs
     );
 
     let results = (outs
diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h
index 65a930bb17..d733a108d7 100644
--- a/mlir/include/Quantum/Transforms/Passes.h
+++ b/mlir/include/Quantum/Transforms/Passes.h
@@ -29,6 +29,9 @@ std::unique_ptr<mlir::Pass> createRemoveChainedSelfInversePass();
 std::unique_ptr<mlir::Pass> createAnnotateFunctionPass();
 std::unique_ptr<mlir::Pass> createSplitMultipleTapesPass();
 std::unique_ptr<mlir::Pass> createMergeRotationsPass();
+std::unique_ptr<mlir::Pass> createDisentangleCNOTPass();
+std::unique_ptr<mlir::Pass> createDisentangleSWAPPass();
 std::unique_ptr<mlir::Pass> createIonsDecompositionPass();
+std::unique_ptr<mlir::Pass> createStaticCustomLoweringPass();
 
 } // namespace catalyst
diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td
index 8d47a24212..168aed6a00 100644
--- a/mlir/include/Quantum/Transforms/Passes.td
+++ b/mlir/include/Quantum/Transforms/Passes.td
@@ -88,40 +88,54 @@ def SplitMultipleTapesPass : Pass<"split-multiple-tapes"> {
     let constructor = "catalyst::createSplitMultipleTapesPass()";
 }
 
-// ----- Quantum circuit transformation passes begin ----- //
-// For example, automatic compiler peephole opts, etc.
+def StaticCustomLoweringPass : Pass<"static-custom-lowering"> {
+    let summary = "Lower static custom ops to regular custom ops with dynamic parameters.";
 
-class QuantumCircuitTransformationPassBase {
-    list