Skip to content

Commit

Permalink
add changes to generated client
Browse files Browse the repository at this point in the history
  • Loading branch information
senecameeks committed Oct 3, 2024
1 parent 1b1c201 commit a3c4c71
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 32 deletions.
51 changes: 35 additions & 16 deletions cirq-google/cirq_google/cloud/quantum_v1alpha1/types/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,48 @@ class QuantumJob(proto.Message):

class DeviceConfigSelector(proto.Message):
r"""-
This message has `oneof`_ fields (mutually exclusive fields).
For each oneof, at most one member field can be set at the same time.
Setting any member of the oneof automatically clears all other
members.
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
Attributes:
run_name (str):
-
This field is a member of `oneof`_ ``top_level_identifier``.
snapshot_id (str):
-
This field is a member of `oneof`_ ``top_level_identifier``.
config_alias (str):
-
"""

run_name = proto.Field(proto.STRING, number=1)
config_alias = proto.Field(proto.STRING, number=2)
run_name: str = proto.Field(proto.STRING, number=1, oneof='top_level_identifier')
snapshot_id: str = proto.Field(proto.STRING, number=3, oneof='top_level_identifier')
config_alias: str = proto.Field(proto.STRING, number=2)


class DeviceConfigKey(proto.Message):
r"""-
This message has `oneof`_ fields (mutually exclusive fields).
For each oneof, at most one member field can be set at the same time.
Setting any member of the oneof automatically clears all other
members.
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
Attributes:
run (str):
-
This field is a member of `oneof`_ ``top_level_identifier``.
snapshot_id (str):
-
This field is a member of `oneof`_ ``top_level_identifier``.
config_alias (str):
-
"""

run: str = proto.Field(proto.STRING, number=1, oneof='top_level_identifier')
snapshot_id: str = proto.Field(proto.STRING, number=3, oneof='top_level_identifier')
config_alias: str = proto.Field(proto.STRING, number=2)


class SchedulingConfig(proto.Message):
Expand Down Expand Up @@ -611,18 +644,4 @@ class QuantumReservation(proto.Message):
whitelisted_users = proto.RepeatedField(proto.STRING, number=5)


class DeviceConfigKey(proto.Message):
r"""-
Attributes:
run (str):
-
config_alias (str):
-
"""

run = proto.Field(proto.STRING, number=1)
config_alias = proto.Field(proto.STRING, number=2)


__all__ = tuple(sorted(__protobuf__.manifest))
4 changes: 2 additions & 2 deletions cirq-google/cirq_google/engine/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ async def create_job_async(
raise ValueError('priority must be between 0 and 1000')
if not processor_id:
raise ValueError('Must specify a processor id when creating a job.')
if bool(run_name) ^ bool(device_config_name):
if (bool(run_name) or bool(snapshot_id)) ^ bool(device_config_name):
raise ValueError('Cannot specify only one of `run_name` and `device_config_name`')

# Create job.
Expand Down Expand Up @@ -793,7 +793,7 @@ def run_job_over_stream(
raise ValueError('priority must be between 0 and 1000')
if not processor_id:
raise ValueError('Must specify a processor id when creating a job.')
if bool(run_name) ^ bool(device_config_name):
if (bool(run_name) or bool(snapshot_id)) ^ bool(device_config_name):
raise ValueError('Cannot specify only one of `run_name` and `device_config_name`')

project_name = _project_name(project_id)
Expand Down
65 changes: 53 additions & 12 deletions cirq-google/cirq_google/engine/engine_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,7 @@ def test_create_job_with_invalid_processor_and_device_config_arguments_throws(
@pytest.mark.parametrize('processor_id', [('processor0'), ('processor0')])
@pytest.mark.parametrize(
'run_name, snapshot_id, device_config_name',
[
('RUN_NAME', None, 'CONFIG_NAME'),
('', None, ''),
(None, 'SNAPSHOT_ID', 'CONFIG_NAME'),
('', '', ''),
],
[('RUN_NAME', None, 'CONFIG_NAME'), ('', None, ''), ('', '', '')],
)
def test_create_job_with_run_name_and_device_config_name(
client_constructor, processor_id, run_name, snapshot_id, device_config_name
Expand Down Expand Up @@ -577,6 +572,52 @@ def test_create_job_with_run_name_and_device_config_name(
)


@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
@pytest.mark.parametrize('processor_id', [('processor0'), ('processor0')])
@pytest.mark.parametrize(
'run_name, snapshot_id, device_config_name', [(None, 'SNAPSHOT_ID', 'CONFIG_NAME')]
)
def test_create_job_with_snapshot_id_and_device_config_name(
client_constructor, processor_id, run_name, snapshot_id, device_config_name
):
grpc_client = _setup_client_mock(client_constructor)
result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0')
grpc_client.create_quantum_job.return_value = result
run_context = any_pb2.Any()
client = EngineClient()

assert client.create_job(
project_id='proj',
program_id='prog',
job_id='job0',
processor_id=processor_id,
run_name=run_name,
snapshot_id=snapshot_id,
device_config_name=device_config_name,
run_context=run_context,
priority=10,
) == ('job0', result)
grpc_client.create_quantum_job.assert_called_with(
quantum.CreateQuantumJobRequest(
parent='projects/proj/programs/prog',
quantum_job=quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job0',
run_context=run_context,
scheduling_config=quantum.SchedulingConfig(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
snapshot_id=snapshot_id, config_alias=device_config_name
),
),
),
),
)
)


@pytest.mark.parametrize(
'run_job_kwargs, expected_submit_args',
[
Expand Down Expand Up @@ -609,7 +650,7 @@ def test_create_job_with_run_name_and_device_config_name(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(),
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
),
),
description='A job',
Expand Down Expand Up @@ -643,7 +684,7 @@ def test_create_job_with_run_name_and_device_config_name(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(),
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
),
),
description='A job',
Expand Down Expand Up @@ -674,7 +715,7 @@ def test_create_job_with_run_name_and_device_config_name(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(),
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
),
),
description='A job',
Expand Down Expand Up @@ -711,7 +752,7 @@ def test_create_job_with_run_name_and_device_config_name(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(),
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
),
),
description='A job',
Expand Down Expand Up @@ -746,7 +787,7 @@ def test_create_job_with_run_name_and_device_config_name(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(),
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
),
),
),
Expand Down Expand Up @@ -778,7 +819,7 @@ def test_create_job_with_run_name_and_device_config_name(
scheduling_config=quantum.SchedulingConfig(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(),
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
)
),
),
Expand Down
3 changes: 3 additions & 0 deletions cirq-google/cirq_google/engine/engine_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def get_sampler(
device_config_name: An identifier used to select the processor configuration
utilized to run the job. A configuration identifies the set of
available qubits, couplers, and supported gates in the processor.
snapshot_id: A unique identifier for an immutable snapshot reference.
A snapshot contains a collection of device configurations for the
processor.
Returns:
A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler`
that will send circuits to the Quantum Computing Service
Expand Down
4 changes: 2 additions & 2 deletions cirq-google/cirq_google/engine/processor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(
):
"""Inits ProcessorSampler.
Either both (`run_name` or `snapshot_id`) and `device_config_name` must be set, or neither of
them must be set. If none of them are set, a default internal device configuration
Either both (`run_name` or `snapshot_id`) and `device_config_name` must be set, or neither
of them must be set. If none of them are set, a default internal device configuration
will be used.
Args:
Expand Down
4 changes: 4 additions & 0 deletions cirq-google/cirq_google/engine/processor_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_run_batch(run_name, device_config_name):
params=params1,
repetitions=5,
run_name=run_name,
snapshot_id=None,
device_config_name=device_config_name,
),
mock.call().results_async(),
Expand All @@ -71,6 +72,7 @@ def test_run_batch(run_name, device_config_name):
params=params2,
repetitions=5,
run_name=run_name,
snapshot_id=None,
device_config_name=device_config_name,
),
mock.call().results_async(),
Expand Down Expand Up @@ -100,6 +102,7 @@ def test_run_batch_identical_repetitions(run_name, device_config_name):
params=params1,
repetitions=5,
run_name=run_name,
snapshot_id=None,
device_config_name=device_config_name,
),
mock.call().results_async(),
Expand All @@ -108,6 +111,7 @@ def test_run_batch_identical_repetitions(run_name, device_config_name):
params=params2,
repetitions=5,
run_name=run_name,
snapshot_id=None,
device_config_name=device_config_name,
),
mock.call().results_async(),
Expand Down

0 comments on commit a3c4c71

Please sign in to comment.