diff --git a/frontend/catalyst/qjit_device.py b/frontend/catalyst/qjit_device.py index f025ffe9c4..e52e8d487b 100644 --- a/frontend/catalyst/qjit_device.py +++ b/frontend/catalyst/qjit_device.py @@ -293,6 +293,11 @@ def observables(self) -> Set[str]: """Get the device observables""" return pennylane_operation_set(self.capabilities.native_obs) + @property + def measurement_processes(self) -> Set[str]: + """Get the device measurement processes""" + return self.capabilities.measurement_processes + def preprocess( self, ctx, diff --git a/frontend/catalyst/utils/toml.py b/frontend/catalyst/utils/toml.py index 9c47547c8a..e629e822dd 100644 --- a/frontend/catalyst/utils/toml.py +++ b/frontend/catalyst/utils/toml.py @@ -74,13 +74,14 @@ def intersect_properties(a: OperationProperties, b: OperationProperties) -> Oper @dataclass -class DeviceCapabilities: +class DeviceCapabilities: # pylint: disable=too-many-instance-attributes """Quantum device capabilities""" native_ops: Dict[str, OperationProperties] to_decomp_ops: Dict[str, OperationProperties] to_matrix_ops: Dict[str, OperationProperties] native_obs: Dict[str, OperationProperties] + measurement_processes: Set[str] mid_circuit_measurement_flag: bool runtime_code_generation_flag: bool dynamic_qubit_management_flag: bool @@ -178,6 +179,21 @@ def get_observables(config: TOMLDocument, program_features: ProgramFeatures) -> return parse_toml_section(config, ["operators", "observables"], program_features) +def get_measurement_processes( + config: TOMLDocument, program_features: ProgramFeatures +) -> Dict[str, dict]: + """Get the measurements processes from the `native` section of the config.""" + + schema = int(config["schema"]) + if schema == 1: + shots_string = "finiteshots" if program_features.shots_present else "exactshots" + return parse_toml_section(config, ["measurement_processes", shots_string], program_features) + if schema == 2: + return parse_toml_section(config, ["measurement_processes"], program_features) + + raise CompileError(f"Unsupported config schema {schema}") + + def get_native_ops(config: TOMLDocument, program_features: ProgramFeatures) -> Dict[str, dict]: """Get the gates from the `native` section of the config.""" @@ -334,6 +350,10 @@ def get_device_capabilities( for g, props in get_observables(config, program_features).items(): observable_props[g] = get_operation_properties(props) + measurements_props = set() + for g, props in get_measurement_processes(config, program_features).items(): + measurements_props.add(g) + if schema == 1: patch_schema1_collections( config, @@ -349,6 +369,7 @@ def get_device_capabilities( to_decomp_ops=decomp_props, to_matrix_ops=matrix_decomp_props, native_obs=observable_props, + measurement_processes=measurements_props, mid_circuit_measurement_flag=check_compilation_flag(config, "mid_circuit_measurement"), runtime_code_generation_flag=check_compilation_flag(config, "runtime_code_generation"), dynamic_qubit_management_flag=check_compilation_flag(config, "dynamic_qubit_management"),