diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 5bb8bf656c..a1d6a60cb0 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -18,9 +18,12 @@ import decimal import math import time +import base64 from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.api_core import datetime_helpers from google.cloud._helpers import _date_from_iso8601_date @@ -204,6 +207,12 @@ def _make_value_pb(value): return Value(null_value="NULL_VALUE") else: return Value(string_value=value) + if isinstance(value, Message): + value = value.SerializeToString() + if value is None: + return Value(null_value="NULL_VALUE") + else: + return Value(string_value=base64.b64encode(value)) raise ValueError("Unknown type: %s" % (value,)) @@ -232,7 +241,7 @@ def _make_list_value_pbs(values): return [_make_list_value_pb(row) for row in values] -def _parse_value_pb(value_pb, field_type): +def _parse_value_pb(value_pb, field_type, field_name, column_info=None): """Convert a Value protobuf to cell data. :type value_pb: :class:`~google.protobuf.struct_pb2.Value` @@ -241,6 +250,18 @@ def _parse_value_pb(value_pb, field_type): :type field_type: :class:`~google.cloud.spanner_v1.types.Type` :param field_type: type code for the value + :type field_name: str + :param field_name: column name + + :type column_info: dict + :param column_info: (Optional) dict of column name and column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. 
+ :rtype: varies on field_type :returns: value extracted from value_pb :raises ValueError: if unknown type is passed @@ -273,18 +294,38 @@ def _parse_value_pb(value_pb, field_type): return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value) elif type_code == TypeCode.ARRAY: return [ - _parse_value_pb(item_pb, field_type.array_element_type) + _parse_value_pb( + item_pb, field_type.array_element_type, field_name, column_info + ) for item_pb in value_pb.list_value.values ] elif type_code == TypeCode.STRUCT: return [ - _parse_value_pb(item_pb, field_type.struct_type.fields[i].type_) + _parse_value_pb( + item_pb, field_type.struct_type.fields[i].type_, field_name, column_info + ) for (i, item_pb) in enumerate(value_pb.list_value.values) ] elif type_code == TypeCode.NUMERIC: return decimal.Decimal(value_pb.string_value) elif type_code == TypeCode.JSON: return JsonObject.from_str(value_pb.string_value) + elif type_code == TypeCode.PROTO: + bytes_value = base64.b64decode(value_pb.string_value) + if column_info is not None and column_info.get(field_name) is not None: + default_proto_message = column_info.get(field_name) + if isinstance(default_proto_message, Message): + proto_message = type(default_proto_message)() + proto_message.ParseFromString(bytes_value) + return proto_message + return bytes_value + elif type_code == TypeCode.ENUM: + int_value = int(value_pb.string_value) + if column_info is not None and column_info.get(field_name) is not None: + proto_enum = column_info.get(field_name) + if isinstance(proto_enum, EnumTypeWrapper): + return proto_enum.Name(int_value) + return int_value else: raise ValueError("Unknown type: %s" % (field_type,)) @@ -305,7 +346,7 @@ def _parse_list_value_pbs(rows, row_type): for row in rows: row_data = [] for value_pb, field in zip(row.values, row_type.fields): - row_data.append(_parse_value_pb(value_pb, field.type_)) + row_data.append(_parse_value_pb(value_pb, field.type_, field.name)) result.append(row_data) return result diff --git a/google/cloud/spanner_v1/data_types.py b/google/cloud/spanner_v1/data_types.py index fca0fcf982..130603afa9 100644 --- a/google/cloud/spanner_v1/data_types.py +++ b/google/cloud/spanner_v1/data_types.py @@ -15,6 +15,10 @@ """Custom data types for spanner.""" import json +import types + +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper class JsonObject(dict): @@ -71,3 +75,109 @@ def serialize(self): return json.dumps(self._array_value, sort_keys=True, separators=(",", ":")) return json.dumps(self, sort_keys=True, separators=(",", ":")) + + +def _proto_message(bytes_val, proto_message_object): + """Helper for :func:`get_proto_message`. + parses serialized protocol buffer bytes data into proto message. + + Args: + bytes_val (bytes): bytes object. + proto_message_object (Message): Message object for parsing + + Returns: + Message: parses serialized protocol buffer data into this message. + + Raises: + ValueError: if the input proto_message_object is not of type Message + """ + if isinstance(bytes_val, types.NoneType): + return None + + if not isinstance(bytes_val, bytes): + raise ValueError("Expected input bytes_val to be a string") + + proto_message = proto_message_object.__deepcopy__() + proto_message.ParseFromString(bytes_val) + return proto_message + + +def _proto_enum(int_val, proto_enum_object): + """Helper for :func:`get_proto_enum`. + parses int value into string containing the name of an enum value. + + Args: + int_val (int): integer value. 
+ proto_enum_object (EnumTypeWrapper): Enum object. + + Returns: + str: string containing the name of an enum value. + + Raises: + ValueError: if the input proto_enum_object is not of type EnumTypeWrapper + """ + if isinstance(int_val, types.NoneType): + return None + + if not isinstance(int_val, int): + raise ValueError("Expected input int_val to be a integer") + + return proto_enum_object.Name(int_val) + + +def get_proto_message(bytes_string, proto_message_object): + """parses serialized protocol buffer bytes' data or its list into proto message or list of proto message. + + Args: + bytes_string (bytes or list[bytes]): bytes object. + proto_message_object (Message): Message object for parsing + + Returns: + Message or list[Message]: parses serialized protocol buffer data into this message. + + Raises: + ValueError: if the input proto_message_object is not of type Message + """ + if isinstance(bytes_string, types.NoneType): + return None + + if not isinstance(proto_message_object, Message): + raise ValueError("Input proto_message_object should be of type Message") + + if not isinstance(bytes_string, (bytes, list)): + raise ValueError( + "Expected input bytes_string to be a string or list of strings" + ) + + if isinstance(bytes_string, list): + return [_proto_message(item, proto_message_object) for item in bytes_string] + + return _proto_message(bytes_string, proto_message_object) + + +def get_proto_enum(int_value, proto_enum_object): + """parses int or list of int values into enum or list of enum values. + + Args: + int_value (int or list[int]): list of integer value. + proto_enum_object (EnumTypeWrapper): Enum object. + + Returns: + str or list[str]: list of strings containing the name of enum value. + + Raises: + ValueError: if the input int_list is not of type list + """ + if isinstance(int_value, types.NoneType): + return None + + if not isinstance(proto_enum_object, EnumTypeWrapper): + raise ValueError("Input proto_enum_object should be of type EnumTypeWrapper") + + if not isinstance(int_value, (int, list)): + raise ValueError("Expected input int_value to be a integer or list of integers") + + if isinstance(int_value, list): + return [_proto_enum(item, proto_enum_object) for item in int_value] + + return _proto_enum(int_value, proto_enum_object) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 650b4fda4c..356bec413c 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -137,6 +137,9 @@ class Database(object): :type enable_drop_protection: boolean :param enable_drop_protection: (Optional) Represents whether the database has drop protection enabled or not. + :type proto_descriptors: bytes + :param proto_descriptors: (Optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE + statements in 'ddl_statements' above. 
""" _spanner_api = None @@ -152,6 +155,7 @@ def __init__( database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, database_role=None, enable_drop_protection=False, + proto_descriptors=None, ): self.database_id = database_id self._instance = instance @@ -173,6 +177,7 @@ def __init__( self._enable_drop_protection = enable_drop_protection self._reconciling = False self._directed_read_options = self._instance._client.directed_read_options + self._proto_descriptors = proto_descriptors if pool is None: pool = BurstyPool(database_role=database_role) @@ -382,6 +387,14 @@ def enable_drop_protection(self): def enable_drop_protection(self, value): self._enable_drop_protection = value + @property + def proto_descriptors(self): + """Proto Descriptors for this database. + :rtype: bytes + :returns: bytes representing the proto descriptors for this database + """ + return self._proto_descriptors + @property def logger(self): """Logger used by the database. @@ -465,6 +478,7 @@ def create(self): extra_statements=list(self._ddl_statements), encryption_config=self._encryption_config, database_dialect=self._database_dialect, + proto_descriptors=self._proto_descriptors, ) future = api.create_database(request=request, metadata=metadata) return future @@ -501,6 +515,7 @@ def reload(self): metadata = _metadata_with_prefix(self.name) response = api.get_database_ddl(database=self.name, metadata=metadata) self._ddl_statements = tuple(response.statements) + self._proto_descriptors = response.proto_descriptors response = api.get_database(name=self.name, metadata=metadata) self._state = DatabasePB.State(response.state) self._create_time = response.create_time @@ -514,7 +529,7 @@ def reload(self): self._enable_drop_protection = response.enable_drop_protection self._reconciling = response.reconciling - def update_ddl(self, ddl_statements, operation_id=""): + def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None): """Update DDL for this database. Apply any configured schema from :attr:`ddl_statements`. @@ -526,6 +541,8 @@ def update_ddl(self, ddl_statements, operation_id=""): :param ddl_statements: a list of DDL statements to use on this database :type operation_id: str :param operation_id: (optional) a string ID for the long-running operation + :type proto_descriptors: bytes + :param proto_descriptors: (optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE statements :rtype: :class:`google.api_core.operation.Operation` :returns: an operation instance @@ -539,6 +556,7 @@ def update_ddl(self, ddl_statements, operation_id=""): database=self.name, statements=ddl_statements, operation_id=operation_id, + proto_descriptors=proto_descriptors, ) future = api.update_database_ddl(request=request, metadata=metadata) diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 26627fb9b1..a67e0e630b 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -435,6 +435,7 @@ def database( enable_drop_protection=False, # should be only set for tests if tests want to use interceptors enable_interceptors_in_tests=False, + proto_descriptors=None, ): """Factory to create a database within this instance. @@ -478,9 +479,14 @@ def database( :param enable_interceptors_in_tests: (Optional) should only be set to True for tests if the tests want to use interceptors. + :type proto_descriptors: bytes + :param proto_descriptors: (Optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE + statements in 'ddl_statements' above. 
+ :rtype: :class:`~google.cloud.spanner_v1.database.Database` :returns: a database owned by this instance. """ + if not enable_interceptors_in_tests: return Database( database_id, @@ -492,6 +498,7 @@ def database( database_dialect=database_dialect, database_role=database_role, enable_drop_protection=enable_drop_protection, + proto_descriptors=proto_descriptors, ) else: return TestDatabase( diff --git a/google/cloud/spanner_v1/param_types.py b/google/cloud/spanner_v1/param_types.py index 3499c5b337..5416a26d61 100644 --- a/google/cloud/spanner_v1/param_types.py +++ b/google/cloud/spanner_v1/param_types.py @@ -18,6 +18,8 @@ from google.cloud.spanner_v1 import TypeAnnotationCode from google.cloud.spanner_v1 import TypeCode from google.cloud.spanner_v1 import StructType +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper # Scalar parameter types @@ -73,3 +75,35 @@ def Struct(fields): :returns: the appropriate struct-type protobuf """ return Type(code=TypeCode.STRUCT, struct_type=StructType(fields=fields)) + + +def ProtoMessage(proto_message_object): + """Construct a proto message type description protobuf. + + :type proto_message_object: :class:`google.protobuf.message.Message` + :param proto_message_object: the proto message instance + + :rtype: :class:`type_pb2.Type` + :returns: the appropriate proto-message-type protobuf + """ + if not isinstance(proto_message_object, Message): + raise ValueError("Expected input object of type Proto Message.") + return Type( + code=TypeCode.PROTO, proto_type_fqn=proto_message_object.DESCRIPTOR.full_name + ) + + +def ProtoEnum(proto_enum_object): + """Construct a proto enum type description protobuf. + + :type proto_enum_object: :class:`google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper` + :param proto_enum_object: the proto enum instance + + :rtype: :class:`type_pb2.Type` + :returns: the appropriate proto-enum-type protobuf + """ + if not isinstance(proto_enum_object, EnumTypeWrapper): + raise ValueError("Expected input object of type Proto Enum") + return Type( + code=TypeCode.ENUM, proto_type_fqn=proto_enum_object.DESCRIPTOR.full_name + ) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index d0a44f6856..52994e58e2 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -228,7 +228,7 @@ def snapshot(self, **kw): return Snapshot(self, **kw) - def read(self, table, columns, keyset, index="", limit=0): + def read(self, table, columns, keyset, index="", limit=0, column_info=None): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -247,10 +247,21 @@ def read(self, table, columns, keyset, index="", limit=0): :type limit: int :param limit: (Optional) maximum number of rows to return + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. 
""" - return self.snapshot().read(table, columns, keyset, index, limit) + return self.snapshot().read( + table, columns, keyset, index, limit, column_info=column_info + ) def execute_sql( self, @@ -262,6 +273,7 @@ def execute_sql( request_options=None, retry=method.DEFAULT, timeout=method.DEFAULT, + column_info=None, ): """Perform an ``ExecuteStreamingSql`` API request. @@ -301,6 +313,15 @@ def execute_sql( :type timeout: float :param timeout: (Optional) The timeout for this request. + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ @@ -313,6 +334,7 @@ def execute_sql( request_options=request_options, retry=retry, timeout=timeout, + column_info=column_info, ) def batch(self): diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 2b6e1ce924..3bc1a746bd 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -177,6 +177,7 @@ def read( *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + column_info=None, ): """Perform a ``StreamingRead`` API request for rows in a table. @@ -231,6 +232,15 @@ def read( for all ReadRequests and ExecuteSqlRequests that indicates which replicas or regions should be used for non-transactional reads or queries. + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. @@ -303,9 +313,11 @@ def read( ) self._read_request_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet( + iterator, source=self, column_info=column_info + ) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) else: iterator = _restart_on_unavailable( restart, @@ -319,9 +331,9 @@ def read( self._read_request_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, source=self, column_info=column_info) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) def execute_sql( self, @@ -336,6 +348,7 @@ def execute_sql( timeout=gapic_v1.method.DEFAULT, data_boost_enabled=False, directed_read_options=None, + column_info=None, ): """Perform an ``ExecuteStreamingSql`` API request. 
@@ -399,6 +412,15 @@ def execute_sql( for all ReadRequests and ExecuteSqlRequests that indicates which replicas or regions should be used for non-transactional reads or queries. + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. + :raises ValueError: for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. @@ -471,11 +493,15 @@ def execute_sql( if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - return self._get_streamed_result_set(restart, request, trace_attributes) + return self._get_streamed_result_set( + restart, request, trace_attributes, column_info + ) else: - return self._get_streamed_result_set(restart, request, trace_attributes) + return self._get_streamed_result_set( + restart, request, trace_attributes, column_info + ) - def _get_streamed_result_set(self, restart, request, trace_attributes): + def _get_streamed_result_set(self, restart, request, trace_attributes, column_info): iterator = _restart_on_unavailable( restart, request, @@ -488,9 +514,9 @@ def _get_streamed_result_set(self, restart, request, trace_attributes): self._execute_sql_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, source=self, column_info=column_info) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) def partition_read( self, diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index d2c2b6216f..03acc9010a 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -37,7 +37,7 @@ class StreamedResultSet(object): :param source: Snapshot from which the result set was fetched. """ - def __init__(self, response_iterator, source=None): + def __init__(self, response_iterator, source=None, column_info=None): self._response_iterator = response_iterator self._rows = [] # Fully-processed rows self._metadata = None # Until set from first PRS @@ -45,6 +45,7 @@ def __init__(self, response_iterator, source=None): self._current_row = [] # Accumulated values for incomplete row self._pending_chunk = None # Incomplete value self._source = source # Source snapshot + self._column_info = column_info # Column information @property def fields(self): @@ -99,10 +100,15 @@ def _merge_values(self, values): :param values: non-chunked values from partial result set. 
""" field_types = [field.type_ for field in self.fields] + field_names = [field.name for field in self.fields] width = len(field_types) index = len(self._current_row) for value in values: - self._current_row.append(_parse_value_pb(value, field_types[index])) + self._current_row.append( + _parse_value_pb( + value, field_types[index], field_names[index], self._column_info + ) + ) index += 1 if index == width: self._rows.append(self._current_row) diff --git a/noxfile.py b/noxfile.py index 9b71c55a7a..ea452e3e93 100644 --- a/noxfile.py +++ b/noxfile.py @@ -313,7 +313,7 @@ def cover(session): test runs (not system test runs), and then erases coverage data. """ session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=99") + session.run("coverage", "report", "--show-missing", "--fail-under=98") session.run("coverage", "erase") diff --git a/owlbot.py b/owlbot.py index 2785c226ec..4ef3686ce8 100644 --- a/owlbot.py +++ b/owlbot.py @@ -126,7 +126,7 @@ def get_staging_dirs( templated_files = common.py_library( microgenerator=True, samples=True, - cov_level=99, + cov_level=98, split_system_tests=True, system_test_extras=["tracing"], ) diff --git a/samples/samples/conftest.py b/samples/samples/conftest.py index 9f0b7d12a0..9810a41d45 100644 --- a/samples/samples/conftest.py +++ b/samples/samples/conftest.py @@ -109,6 +109,17 @@ def multi_region_instance_config(spanner_client): return "{}/instanceConfigs/{}".format(spanner_client.project_name, "nam3") +@pytest.fixture(scope="module") +def proto_descriptor_file(): + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + file = open(filename, "rb") + yield file.read() + file.close() + + @pytest.fixture(scope="module") def sample_instance( spanner_client, @@ -188,6 +199,29 @@ def database_id(): return "my-database-id" +@pytest.fixture(scope="module") +def proto_columns_database( + spanner_client, + sample_instance, + proto_columns_database_id, + proto_columns_database_ddl, + database_dialect, +): + if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: + sample_database = sample_instance.database( + proto_columns_database_id, + ddl_statements=proto_columns_database_ddl, + ) + + if not sample_database.exists(): + operation = sample_database.create() + operation.result(OPERATION_TIMEOUT_SECONDS) + + yield sample_database + + sample_database.drop() + + @pytest.fixture(scope="module") def bit_reverse_sequence_database_id(): """Id for the database used in bit reverse sequence samples. 
diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index a5f8d8653f..e7c76685d3 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -33,6 +33,7 @@ from google.cloud.spanner_v1 import DirectedReadOptions, param_types from google.cloud.spanner_v1.data_types import JsonObject from google.protobuf import field_mask_pb2 # type: ignore +from testdata import singer_pb2 OPERATION_TIMEOUT_SECONDS = 240 @@ -3144,6 +3145,241 @@ def create_instance_with_autoscaling_config(instance_id): # [END spanner_create_instance_with_autoscaling_config] +def add_proto_type_columns(instance_id, database_id): + # [START spanner_add_proto_type_columns] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + """Adds a new Proto Message column and Proto Enum column to the Singers table.""" + + import os + from google.cloud.spanner_admin_database_v1.types import spanner_database_admin + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + + spanner_client = spanner.Client() + database_admin_api = spanner_client.database_admin_api + + proto_descriptor_file = open(filename, "rb") + proto_descriptor = proto_descriptor_file.read() + + request = spanner_database_admin.UpdateDatabaseDdlRequest( + database=database_admin_api.database_path( + spanner_client.project, instance_id, database_id + ), + statements=[ + """CREATE PROTO BUNDLE ( + examples.spanner.music.SingerInfo, + examples.spanner.music.Genre, + )""", + "ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo", + "ALTER TABLE Singers ADD COLUMN SingerInfoArray ARRAY", + "ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre", + "ALTER TABLE Singers ADD COLUMN SingerGenreArray ARRAY", + ], + proto_descriptors=proto_descriptor, + ) + + operation = database_admin_api.update_database_ddl(request) + + print("Waiting for operation to complete...") + operation.result(OPERATION_TIMEOUT_SECONDS) + proto_descriptor_file.close() + + print( + 'Altered table "Singers" on database {} on instance {} with proto descriptors.'.format( + database_id, instance_id + ) + ) + # [END spanner_add_proto_type_columns] + + +def update_data_with_proto_types(instance_id, database_id): + # [START spanner_update_data_with_proto_types] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + """Updates Singers tables in the database with the ProtoMessage + and ProtoEnum column. + + This updates the `SingerInfo`, `SingerInfoArray`, `SingerGenre` and + `SingerGenreArray` columns which must be created before + running this sample. 
You can add the column by running the + `add_proto_type_columns` sample or by running this DDL statement + against your database: + + ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo\n + ALTER TABLE Singers ADD COLUMN SingerInfoArray ARRAY\n + ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre\n + ALTER TABLE Singers ADD COLUMN SingerGenreArray ARRAY\n + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 2 + singer_info.birth_date = "February" + singer_info.nationality = "Country2" + singer_info.genre = singer_pb2.Genre.FOLK + + singer_info_array = [singer_info] + + singer_genre_array = [singer_pb2.Genre.FOLK] + + with database.batch() as batch: + batch.update( + table="Singers", + columns=( + "SingerId", + "SingerInfo", + "SingerInfoArray", + "SingerGenre", + "SingerGenreArray", + ), + values=[ + ( + 2, + singer_info, + singer_info_array, + singer_pb2.Genre.FOLK, + singer_genre_array, + ), + (3, None, None, None, None), + ], + ) + + print("Data updated.") + # [END spanner_update_data_with_proto_types] + + +def update_data_with_proto_types_with_dml(instance_id, database_id): + # [START spanner_update_data_with_proto_types_with_dml] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + """Updates Singers tables in the database with the ProtoMessage + and ProtoEnum column. + + This updates the `SingerInfo`, `SingerInfoArray`, `SingerGenre` and `SingerGenreArray` columns which must be created before + running this sample. You can add the column by running the + `add_proto_type_columns` sample or by running this DDL statement + against your database: + + ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo\n + ALTER TABLE Singers ADD COLUMN SingerInfoArray ARRAY\n + ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre\n + ALTER TABLE Singers ADD COLUMN SingerGenreArray ARRAY\n + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 1 + singer_info.birth_date = "January" + singer_info.nationality = "Country1" + singer_info.genre = singer_pb2.Genre.ROCK + + singer_info_array = [singer_info, None] + + singer_genre_array = [singer_pb2.Genre.ROCK, None] + + def update_singers_with_proto_types(transaction): + row_ct = transaction.execute_update( + "UPDATE Singers " + "SET SingerInfo = @singerInfo, SingerInfoArray=@singerInfoArray, " + "SingerGenre=@singerGenre, SingerGenreArray=@singerGenreArray " + "WHERE SingerId = 1", + params={ + "singerInfo": singer_info, + "singerInfoArray": singer_info_array, + "singerGenre": singer_pb2.Genre.ROCK, + "singerGenreArray": singer_genre_array, + }, + param_types={ + "singerInfo": param_types.ProtoMessage(singer_info), + "singerInfoArray": param_types.Array( + param_types.ProtoMessage(singer_info) + ), + "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), + "singerGenreArray": param_types.Array( + param_types.ProtoEnum(singer_pb2.Genre) + ), + }, + ) + + print("{} record(s) updated.".format(row_ct)) + + database.run_in_transaction(update_singers_with_proto_types) + + def update_singers_with_proto_field(transaction): + row_ct = transaction.execute_update( + "UPDATE Singers " + "SET SingerInfo.nationality = @singerNationality " + "WHERE SingerId = 
1", + params={ + "singerNationality": "Country2", + }, + param_types={ + "singerNationality": param_types.STRING, + }, + ) + + print("{} record(s) updated.".format(row_ct)) + + database.run_in_transaction(update_singers_with_proto_field) + # [END spanner_update_data_with_proto_types_with_dml] + + +def query_data_with_proto_types_parameter(instance_id, database_id): + # [START spanner_query_with_proto_types_parameter] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql( + "SELECT SingerId, SingerInfo, SingerInfo.nationality, SingerInfoArray, " + "SingerGenre, SingerGenreArray FROM Singers " + "WHERE SingerInfo.Nationality=@country " + "and SingerGenre=@singerGenre", + params={ + "country": "Country2", + "singerGenre": singer_pb2.Genre.FOLK, + }, + param_types={ + "country": param_types.STRING, + "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), + }, + # column_info is an optional parameter and is used to deserialize + # the proto message and enum object back from bytearray and + # int respectively. + # If column_info is not passed for proto messages and enums, then + # the data types for these columns will be bytes and int + # respectively. + column_info={ + "SingerInfo": singer_pb2.SingerInfo(), + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerGenreArray": singer_pb2.Genre, + }, + ) + + for row in results: + print( + "SingerId: {}, SingerInfo: {}, SingerInfoNationality: {}, " + "SingerInfoArray: {}, SingerGenre: {}, SingerGenreArray: {}".format( + *row + ) + ) + # [END spanner_query_with_proto_types_parameter] + + if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -3288,6 +3524,18 @@ def create_instance_with_autoscaling_config(instance_id): subparsers.add_parser( "set_custom_timeout_and_retry", help=set_custom_timeout_and_retry.__doc__ ) + subparsers.add_parser("add_proto_type_columns", help=add_proto_type_columns.__doc__) + subparsers.add_parser( + "update_data_with_proto_types", help=update_data_with_proto_types.__doc__ + ) + subparsers.add_parser( + "update_data_with_proto_types_with_dml", + help=update_data_with_proto_types_with_dml.__doc__, + ) + subparsers.add_parser( + "query_data_with_proto_types_parameter", + help=query_data_with_proto_types_parameter.__doc__, + ) args = parser.parse_args() @@ -3427,3 +3675,11 @@ def create_instance_with_autoscaling_config(instance_id): set_custom_timeout_and_retry(args.instance_id, args.database_id) elif args.command == "create_instance_with_autoscaling_config": create_instance_with_autoscaling_config(args.instance_id) + elif args.command == "add_proto_type_columns": + add_proto_type_columns(args.instance_id, args.database_id) + elif args.command == "update_data_with_proto_types": + update_data_with_proto_types(args.instance_id, args.database_id) + elif args.command == "update_data_with_proto_types_with_dml": + update_data_with_proto_types_with_dml(args.instance_id, args.database_id) + elif args.command == "query_data_with_proto_types_parameter": + query_data_with_proto_types_parameter(args.instance_id, args.database_id) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index b19784d453..909305a65a 100644 --- 
a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -44,6 +44,14 @@ INTERLEAVE IN PARENT Singers ON DELETE CASCADE """ +CREATE_TABLE_SINGERS_ = """\ +CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + ) PRIMARY KEY (SingerId) +""" + retry_429 = RetryErrors(exceptions.ResourceExhausted, delay=15) @@ -94,6 +102,11 @@ def default_leader_database_id(): return f"leader_db_{uuid.uuid4().hex[:10]}" +@pytest.fixture(scope="module") +def proto_columns_database_id(): + return f"test-db-proto-{uuid.uuid4().hex[:10]}" + + @pytest.fixture(scope="module") def database_ddl(): """Sequence of DDL statements used to set up the database. @@ -103,6 +116,15 @@ def database_ddl(): return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS] +@pytest.fixture(scope="module") +def proto_columns_database_ddl(): + """Sequence of DDL statements used to set up the database for proto columns. + + Sample testcase modules can override as needed. + """ + return [CREATE_TABLE_SINGERS_, CREATE_TABLE_ALBUMS] + + @pytest.fixture(scope="module") def default_leader(): """Default leader for multi-region instances.""" @@ -885,3 +907,44 @@ def test_set_custom_timeout_and_retry(capsys, instance_id, sample_database): snippets.set_custom_timeout_and_retry(instance_id, sample_database.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, AlbumId: 1, AlbumTitle: Total Junk" in out + + +@pytest.mark.dependency( + name="add_proto_types_column", +) +def test_add_proto_types_column(capsys, instance_id, proto_columns_database): + snippets.add_proto_type_columns(instance_id, proto_columns_database.database_id) + out, _ = capsys.readouterr() + assert 'Altered table "Singers" on database ' in out + + snippets.insert_data(instance_id, proto_columns_database.database_id) + + +@pytest.mark.dependency( + name="update_data_with_proto_message", depends=["add_proto_types_column"] +) +def test_update_data_with_proto_types(capsys, instance_id, proto_columns_database): + snippets.update_data_with_proto_types( + instance_id, proto_columns_database.database_id + ) + out, _ = capsys.readouterr() + assert "Data updated" in out + + snippets.update_data_with_proto_types_with_dml( + instance_id, proto_columns_database.database_id + ) + out, _ = capsys.readouterr() + assert "1 record(s) updated." 
in out + + +@pytest.mark.dependency( + depends=["add_proto_types_column", "update_data_with_proto_message"] +) +def test_query_data_with_proto_types_parameter( + capsys, instance_id, proto_columns_database +): + snippets.query_data_with_proto_types_parameter( + instance_id, proto_columns_database.database_id + ) + out, _ = capsys.readouterr() + assert "SingerId: 2, SingerInfo: singer_id: 2" in out diff --git a/samples/samples/testdata/README.md b/samples/samples/testdata/README.md new file mode 100644 index 0000000000..b4ff1b649b --- /dev/null +++ b/samples/samples/testdata/README.md @@ -0,0 +1,5 @@ +#### To generate singer_pb2.py and descriptos.pb file from singer.proto using `protoc` +```shell +cd samples/samples +protoc --proto_path=testdata/ --include_imports --descriptor_set_out=testdata/descriptors.pb --python_out=testdata/ testdata/singer.proto +``` diff --git a/samples/samples/testdata/descriptors.pb b/samples/samples/testdata/descriptors.pb new file mode 100644 index 0000000000..d4c018f3a3 Binary files /dev/null and b/samples/samples/testdata/descriptors.pb differ diff --git a/samples/samples/testdata/singer.proto b/samples/samples/testdata/singer.proto new file mode 100644 index 0000000000..60276440d7 --- /dev/null +++ b/samples/samples/testdata/singer.proto @@ -0,0 +1,17 @@ +syntax = "proto2"; + +package examples.spanner.music; + +message SingerInfo { + optional int64 singer_id = 1; + optional string birth_date = 2; + optional string nationality = 3; + optional Genre genre = 4; +} + +enum Genre { + POP = 0; + JAZZ = 1; + FOLK = 2; + ROCK = 3; +} diff --git a/samples/samples/testdata/singer_pb2.py b/samples/samples/testdata/singer_pb2.py new file mode 100644 index 0000000000..b29049c79a --- /dev/null +++ b/samples/samples/testdata/singer_pb2.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: singer.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csinger.proto\x12\x16\x65xamples.spanner.music\"v\n\nSingerInfo\x12\x11\n\tsinger_id\x18\x01 \x01(\x03\x12\x12\n\nbirth_date\x18\x02 \x01(\t\x12\x13\n\x0bnationality\x18\x03 \x01(\t\x12,\n\x05genre\x18\x04 \x01(\x0e\x32\x1d.examples.spanner.music.Genre*.\n\x05Genre\x12\x07\n\x03POP\x10\x00\x12\x08\n\x04JAZZ\x10\x01\x12\x08\n\x04\x46OLK\x10\x02\x12\x08\n\x04ROCK\x10\x03') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'singer_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _GENRE._serialized_start=160 + _GENRE._serialized_end=206 + _SINGERINFO._serialized_start=40 + _SINGERINFO._serialized_end=158 +# @@protoc_insertion_point(module_scope) diff --git a/setup.py b/setup.py index ca44093157..95ff029bc6 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ "proto-plus >= 1.22.0, <2.0.0dev", "sqlparse >= 0.4.4", "proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'", - "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", + "protobuf>=3.20.2,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", "grpc-interceptor >= 0.15.4", ] extras = { diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index b0162a8987..20170203f5 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -13,6 +13,6 @@ sqlparse==0.4.4 opentelemetry-api==1.1.0 opentelemetry-sdk==1.1.0 opentelemetry-instrumentation==0.20b0 -protobuf==3.19.5 +protobuf==3.20.2 deprecated==1.2.14 grpc-interceptor==0.15.4 diff --git a/tests/_fixtures.py b/tests/_fixtures.py index b6f4108490..7a80adc00a 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ -28,6 +28,10 @@ phone_number STRING(1024) ) PRIMARY KEY (contact_id, phone_type), INTERLEAVE IN PARENT contacts ON DELETE CASCADE; +CREATE PROTO BUNDLE ( + examples.spanner.music.SingerInfo, + examples.spanner.music.Genre, + ); CREATE TABLE all_types ( pkey INT64 NOT NULL, int_value INT64, @@ -48,6 +52,10 @@ numeric_array ARRAY, json_value JSON, json_array ARRAY, + proto_message_value examples.spanner.music.SingerInfo, + proto_message_array ARRAY, + proto_enum_value examples.spanner.music.Genre, + proto_enum_array ARRAY, ) PRIMARY KEY (pkey); CREATE TABLE counters ( @@ -96,6 +104,10 @@ phone_number STRING(1024) ) PRIMARY KEY (contact_id, phone_type), INTERLEAVE IN PARENT contacts ON DELETE CASCADE; +CREATE PROTO BUNDLE ( + examples.spanner.music.SingerInfo, + examples.spanner.music.Genre, + ); CREATE TABLE all_types ( pkey INT64 NOT NULL, int_value INT64, @@ -185,8 +197,22 @@ ); """ +PROTO_COLUMNS_DDL = """\ +CREATE TABLE singers ( + singer_id INT64 NOT NULL, + first_name STRING(1024), + last_name STRING(1024), + singer_info examples.spanner.music.SingerInfo, + singer_genre examples.spanner.music.Genre, ) + PRIMARY KEY (singer_id); +CREATE INDEX SingerByGenre ON singers(singer_genre) STORING (first_name, last_name); +""" + DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(";") if stmt.strip()] 
EMULATOR_DDL_STATEMENTS = [ stmt.strip() for stmt in EMULATOR_DDL.split(";") if stmt.strip() ] PG_DDL_STATEMENTS = [stmt.strip() for stmt in PG_DDL.split(";") if stmt.strip()] +PROTO_COLUMNS_DDL_STATEMENTS = [ + stmt.strip() for stmt in PROTO_COLUMNS_DDL.split(";") if stmt.strip() +] diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index 60926b216e..b62d453512 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -65,6 +65,8 @@ ) ) +PROTO_COLUMNS_DDL_STATEMENTS = _fixtures.PROTO_COLUMNS_DDL_STATEMENTS + retry_true = retry.RetryResult(operator.truth) retry_false = retry.RetryResult(operator.not_) diff --git a/tests/system/_sample_data.py b/tests/system/_sample_data.py index d9c269c27f..41f41c9fe5 100644 --- a/tests/system/_sample_data.py +++ b/tests/system/_sample_data.py @@ -18,7 +18,7 @@ from google.api_core import datetime_helpers from google.cloud._helpers import UTC from google.cloud import spanner_v1 - +from samples.samples.testdata import singer_pb2 TABLE = "contacts" COLUMNS = ("contact_id", "first_name", "last_name", "email") @@ -41,6 +41,31 @@ COUNTERS_TABLE = "counters" COUNTERS_COLUMNS = ("name", "value") +SINGERS_PROTO_TABLE = "singers" +SINGERS_PROTO_COLUMNS = ( + "singer_id", + "first_name", + "last_name", + "singer_info", + "singer_genre", +) +SINGER_INFO_1 = singer_pb2.SingerInfo() +SINGER_GENRE_1 = singer_pb2.Genre.ROCK +SINGER_INFO_1.singer_id = 1 +SINGER_INFO_1.birth_date = "January" +SINGER_INFO_1.nationality = "Country1" +SINGER_INFO_1.genre = SINGER_GENRE_1 +SINGER_INFO_2 = singer_pb2.SingerInfo() +SINGER_GENRE_2 = singer_pb2.Genre.FOLK +SINGER_INFO_2.singer_id = 2 +SINGER_INFO_2.birth_date = "February" +SINGER_INFO_2.nationality = "Country2" +SINGER_INFO_2.genre = SINGER_GENRE_2 +SINGERS_PROTO_ROW_DATA = ( + (1, "Singer1", "Singer1", SINGER_INFO_1, SINGER_GENRE_1), + (2, "Singer2", "Singer2", SINGER_INFO_2, SINGER_GENRE_2), +) + def _assert_timestamp(value, nano_value): assert isinstance(value, datetime.datetime) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index b297d1f2ad..bf939cfa99 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -74,6 +74,17 @@ def database_dialect(): ) +@pytest.fixture(scope="session") +def proto_descriptor_file(): + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + file = open(filename, "rb") + yield file.read() + file.close() + + @pytest.fixture(scope="session") def spanner_client(): if _helpers.USE_EMULATOR: @@ -176,7 +187,9 @@ def shared_instance( @pytest.fixture(scope="session") -def shared_database(shared_instance, database_operation_timeout, database_dialect): +def shared_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_database") pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: @@ -197,6 +210,7 @@ def shared_database(shared_instance, database_operation_timeout, database_dialec ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = database.create() operation.result(database_operation_timeout) # raises on failure / timeout. 
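Outside the test fixtures, the same `proto_descriptors` argument is what a user passes when creating a database whose DDL references proto types. A hedged sketch (ids, the descriptor path, and the table definition are illustrative; `descriptors.pb` is the `protoc --descriptor_set_out` output described in the testdata README):

```python
from google.cloud import spanner

client = spanner.Client()
instance = client.instance("my-instance")

# Compiled FileDescriptorSet for the types named in CREATE PROTO BUNDLE.
with open("testdata/descriptors.pb", "rb") as f:
    proto_descriptors = f.read()

database = instance.database(
    "proto-columns-db",
    ddl_statements=[
        """CREATE PROTO BUNDLE (
            examples.spanner.music.SingerInfo,
            examples.spanner.music.Genre,
        )""",
        """CREATE TABLE Singers (
            SingerId    INT64 NOT NULL,
            SingerInfo  examples.spanner.music.SingerInfo,
            SingerGenre examples.spanner.music.Genre,
        ) PRIMARY KEY (SingerId)""",
    ],
    proto_descriptors=proto_descriptors,
)
operation = database.create()
operation.result(300)  # long-running operation; timeout in seconds
```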
diff --git a/tests/system/test_backup_api.py b/tests/system/test_backup_api.py index dc80653786..6ffc74283e 100644 --- a/tests/system/test_backup_api.py +++ b/tests/system/test_backup_api.py @@ -94,7 +94,9 @@ def database_version_time(shared_database): @pytest.fixture(scope="session") -def second_database(shared_instance, database_operation_timeout, database_dialect): +def second_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_database2") pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: @@ -115,6 +117,7 @@ def second_database(shared_instance, database_operation_timeout, database_dialec ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = database.create() operation.result(database_operation_timeout) # raises on failure / timeout. diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index fbaee7476d..244fccd069 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -92,7 +92,11 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) def test_database_binding_of_fixed_size_pool( - not_emulator, shared_instance, databases_to_delete, not_postgres + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, + proto_descriptor_file, ): temp_db_id = _helpers.unique_id("fixed_size_db", separator="_") temp_db = shared_instance.database(temp_db_id) @@ -106,7 +110,9 @@ def test_database_binding_of_fixed_size_pool( "CREATE ROLE parent", "GRANT SELECT ON TABLE contacts TO ROLE parent", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. pool = FixedSizePool( @@ -119,7 +125,11 @@ def test_database_binding_of_fixed_size_pool( def test_database_binding_of_pinging_pool( - not_emulator, shared_instance, databases_to_delete, not_postgres + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, + proto_descriptor_file, ): temp_db_id = _helpers.unique_id("binding_db", separator="_") temp_db = shared_instance.database(temp_db_id) @@ -133,7 +143,9 @@ def test_database_binding_of_pinging_pool( "CREATE ROLE parent", "GRANT SELECT ON TABLE contacts TO ROLE parent", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. 
pool = PingingPool( @@ -307,7 +319,7 @@ def test_table_not_found(shared_instance): def test_update_ddl_w_operation_id( - shared_instance, databases_to_delete, database_dialect + shared_instance, databases_to_delete, database_dialect, proto_descriptor_file ): # We used to have: # @pytest.mark.skip( @@ -325,7 +337,11 @@ def test_update_ddl_w_operation_id( # random but shortish always start with letter operation_id = f"a{str(uuid.uuid4())[:8]}" - operation = temp_db.update_ddl(_helpers.DDL_STATEMENTS, operation_id=operation_id) + operation = temp_db.update_ddl( + _helpers.DDL_STATEMENTS, + operation_id=operation_id, + proto_descriptors=proto_descriptor_file, + ) assert operation_id == operation.operation.name.split("/")[-1] @@ -341,6 +357,7 @@ def test_update_ddl_w_pitr_invalid( not_postgres, shared_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") @@ -358,7 +375,7 @@ def test_update_ddl_w_pitr_invalid( f" SET OPTIONS (version_retention_period = '{retention_period}')" ] with pytest.raises(exceptions.InvalidArgument): - temp_db.update_ddl(ddl_statements) + temp_db.update_ddl(ddl_statements, proto_descriptors=proto_descriptor_file) def test_update_ddl_w_pitr_success( @@ -366,6 +383,7 @@ def test_update_ddl_w_pitr_success( not_postgres, shared_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") @@ -382,7 +400,9 @@ def test_update_ddl_w_pitr_success( f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (version_retention_period = '{retention_period}')" ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. temp_db.reload() @@ -395,6 +415,7 @@ def test_update_ddl_w_default_leader_success( not_postgres, multiregion_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool( labels={"testcase": "update_database_ddl_default_leader"}, @@ -414,7 +435,9 @@ def test_update_ddl_w_default_leader_success( f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (default_leader = '{default_leader}')" ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. temp_db.reload() @@ -423,7 +446,11 @@ def test_update_ddl_w_default_leader_success( def test_create_role_grant_access_success( - not_emulator, shared_instance, databases_to_delete, database_dialect + not_emulator, + shared_instance, + databases_to_delete, + database_dialect, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -448,7 +475,9 @@ def test_create_role_grant_access_success( f"GRANT SELECT ON TABLE contacts TO {creator_role_parent}", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. # Perform select with orphan role on table contacts. 
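The same flow works for schema changes on an existing database through `Database.update_ddl`, which is what these tests exercise; a sketch with placeholder ids and paths (the admin-API variant appears in the `add_proto_type_columns` snippet earlier in this change):

```python
from google.cloud import spanner

database = spanner.Client().instance("my-instance").database("my-database")

with open("testdata/descriptors.pb", "rb") as f:
    descriptors = f.read()

operation = database.update_ddl(
    [
        """CREATE PROTO BUNDLE (
            examples.spanner.music.SingerInfo,
            examples.spanner.music.Genre,
        )""",
        "ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo",
        "ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre",
    ],
    proto_descriptors=descriptors,
)
operation.result(300)  # wait for the schema change to complete
```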
@@ -483,7 +512,11 @@ def test_create_role_grant_access_success(


 def test_list_database_role_success(
-    not_emulator, shared_instance, databases_to_delete, database_dialect
+    not_emulator,
+    shared_instance,
+    databases_to_delete,
+    database_dialect,
+    proto_descriptor_file,
 ):
     creator_role_parent = _helpers.unique_id("role_parent", separator="_")
     creator_role_orphan = _helpers.unique_id("role_orphan", separator="_")
@@ -500,7 +533,9 @@ def test_list_database_role_success(
         f"CREATE ROLE {creator_role_parent}",
         f"CREATE ROLE {creator_role_orphan}",
     ]
-    operation = temp_db.update_ddl(ddl_statements)
+    operation = temp_db.update_ddl(
+        ddl_statements, proto_descriptors=proto_descriptor_file
+    )
     operation.result(DBAPI_OPERATION_TIMEOUT)  # raises on failure / timeout.

     # List database roles.
@@ -859,3 +894,30 @@ def _unit_of_work(transaction, test):

     rows = list(after.execute_sql(sd.SQL))
     sd._check_rows_data(rows)
+
+
+def test_create_table_with_proto_columns(
+    not_emulator,
+    not_postgres,
+    shared_instance,
+    databases_to_delete,
+    proto_descriptor_file,
+):
+    proto_cols_db_id = _helpers.unique_id("proto-columns")
+    extra_ddl = [
+        "CREATE PROTO BUNDLE (examples.spanner.music.SingerInfo, examples.spanner.music.Genre,)"
+    ]
+
+    proto_cols_database = shared_instance.database(
+        proto_cols_db_id,
+        ddl_statements=extra_ddl + _helpers.PROTO_COLUMNS_DDL_STATEMENTS,
+        proto_descriptors=proto_descriptor_file,
+    )
+    operation = proto_cols_database.create()
+    operation.result(DBAPI_OPERATION_TIMEOUT)  # raises on failure / timeout.
+
+    databases_to_delete.append(proto_cols_database)
+
+    proto_cols_database.reload()
+    assert proto_cols_database.proto_descriptors is not None
+    assert any("PROTO BUNDLE" in stmt for stmt in proto_cols_database.ddl_statements)
diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py
index 5cba7441a4..bbe6000aba 100644
--- a/tests/system/test_session_api.py
+++ b/tests/system/test_session_api.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import base64
 import collections
 import datetime
 import decimal
@@ -29,6 +29,7 @@
 from google.cloud.spanner_admin_database_v1 import DatabaseDialect
 from google.cloud._helpers import UTC
 from google.cloud.spanner_v1.data_types import JsonObject
+from samples.samples.testdata import singer_pb2
 from tests import _helpers as ot_helpers
 from . import _helpers
 from . import _sample_data
@@ -57,6 +58,8 @@
 JSON_2 = JsonObject(
     {"sample_object": {"name": "Anamika", "id": 2635}},
 )
+SINGER_INFO = _sample_data.SINGER_INFO_1
+SINGER_GENRE = _sample_data.SINGER_GENRE_1

 COUNTERS_TABLE = "counters"
 COUNTERS_COLUMNS = ("name", "value")
@@ -81,9 +84,13 @@
     "numeric_array",
     "json_value",
     "json_array",
+    "proto_message_value",
+    "proto_message_array",
+    "proto_enum_value",
+    "proto_enum_array",
 )

-EMULATOR_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:-4]
+EMULATOR_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:-8]
 # ToDo: Clean up generation of POSTGRES_ALL_TYPES_COLUMNS
 POSTGRES_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:17] + (
     "jsonb_value",
@@ -122,6 +129,8 @@
     AllTypesRowData(pkey=109, numeric_value=NUMERIC_1),
     AllTypesRowData(pkey=110, json_value=JSON_1),
     AllTypesRowData(pkey=111, json_value=JsonObject([JSON_1, JSON_2])),
+    AllTypesRowData(pkey=112, proto_message_value=SINGER_INFO),
+    AllTypesRowData(pkey=113, proto_enum_value=SINGER_GENRE),
     # empty array values
     AllTypesRowData(pkey=201, int_array=[]),
     AllTypesRowData(pkey=202, bool_array=[]),
@@ -132,6 +141,8 @@
     AllTypesRowData(pkey=207, timestamp_array=[]),
     AllTypesRowData(pkey=208, numeric_array=[]),
     AllTypesRowData(pkey=209, json_array=[]),
+    AllTypesRowData(pkey=210, proto_message_array=[]),
+    AllTypesRowData(pkey=211, proto_enum_array=[]),
     # non-empty array values, including nulls
     AllTypesRowData(pkey=301, int_array=[123, 456, None]),
     AllTypesRowData(pkey=302, bool_array=[True, False, None]),
@@ -144,6 +155,8 @@
     AllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
     AllTypesRowData(pkey=308, numeric_array=[NUMERIC_1, NUMERIC_2, None]),
     AllTypesRowData(pkey=309, json_array=[JSON_1, JSON_2, None]),
+    AllTypesRowData(pkey=310, proto_message_array=[SINGER_INFO, None]),
+    AllTypesRowData(pkey=311, proto_enum_array=[SINGER_GENRE, None]),
 )
 EMULATOR_ALL_TYPES_ROWDATA = (
     # all nulls
@@ -234,9 +247,16 @@
 ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS
 ALL_TYPES_ROWDATA = LIVE_ALL_TYPES_ROWDATA

+COLUMN_INFO = {
+    "proto_message_value": singer_pb2.SingerInfo(),
+    "proto_message_array": singer_pb2.SingerInfo(),
+}
+

 @pytest.fixture(scope="session")
-def sessions_database(shared_instance, database_operation_timeout, database_dialect):
+def sessions_database(
+    shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file
+):
     database_name = _helpers.unique_id("test_sessions", separator="_")
     pool = spanner_v1.BurstyPool(labels={"testcase": "session_api"})
@@ -258,6 +278,7 @@ def sessions_database(shared_instance, database_operation_timeout, database_dial
         database_name,
         ddl_statements=_helpers.DDL_STATEMENTS,
         pool=pool,
+        proto_descriptors=proto_descriptor_file,
     )

     operation = sessions_database.create()
@@ -471,7 +492,11 @@ def test_batch_insert_then_read_all_datatypes(sessions_database):
         batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA)

     with sessions_database.snapshot(read_timestamp=batch.committed) as snapshot:
-        rows = list(snapshot.read(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, sd.ALL))
+        rows = list(
+            snapshot.read(
+                ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, sd.ALL, column_info=COLUMN_INFO
+            )
+        )

     sd._check_rows_data(rows, expected=ALL_TYPES_ROWDATA)

@@ -1358,6 +1383,20 @@ def _unit_of_work(transaction):

     return committed


+def _set_up_proto_table(database):
+    sd = _sample_data
+
+    def _unit_of_work(transaction):
+        transaction.delete(sd.SINGERS_PROTO_TABLE, sd.ALL)
+        transaction.insert(
+            sd.SINGERS_PROTO_TABLE, sd.SINGERS_PROTO_COLUMNS, sd.SINGERS_PROTO_ROW_DATA
+        )
+
+    committed = database.run_in_transaction(_unit_of_work)
+
+    return committed
+
+
 def test_read_with_single_keys_index(sessions_database):
     # [START spanner_test_single_key_index_read]
     sd = _sample_data
@@ -1505,7 +1544,11 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database):


 def test_read_w_index(
-    shared_instance, database_operation_timeout, databases_to_delete, database_dialect
+    shared_instance,
+    database_operation_timeout,
+    databases_to_delete,
+    database_dialect,
+    proto_descriptor_file,
 ):
     # Indexed reads cannot return non-indexed columns
     sd = _sample_data
@@ -1533,9 +1576,12 @@ def test_read_w_index(
     else:
         temp_db = shared_instance.database(
             _helpers.unique_id("test_read", separator="_"),
-            ddl_statements=_helpers.DDL_STATEMENTS + extra_ddl,
+            ddl_statements=_helpers.DDL_STATEMENTS
+            + extra_ddl
+            + _helpers.PROTO_COLUMNS_DDL_STATEMENTS,
             pool=pool,
             database_dialect=database_dialect,
+            proto_descriptors=proto_descriptor_file,
         )
     operation = temp_db.create()
     operation.result(database_operation_timeout)  # raises on failure / timeout.
@@ -1551,6 +1597,28 @@ def test_read_w_index(
     expected = list(reversed([(row[0], row[2]) for row in _row_data(row_count)]))
     sd._check_rows_data(rows, expected)

+    # Test indexes on proto column types
+    if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL:
+        # Indexed reads cannot return non-indexed columns
+        my_columns = (
+            sd.SINGERS_PROTO_COLUMNS[0],
+            sd.SINGERS_PROTO_COLUMNS[1],
+            sd.SINGERS_PROTO_COLUMNS[4],
+        )
+        committed = _set_up_proto_table(temp_db)
+        with temp_db.snapshot(read_timestamp=committed) as snapshot:
+            rows = list(
+                snapshot.read(
+                    sd.SINGERS_PROTO_TABLE,
+                    my_columns,
+                    spanner_v1.KeySet(keys=[[singer_pb2.Genre.ROCK]]),
+                    index="SingerByGenre",
+                )
+            )
+            row = sd.SINGERS_PROTO_ROW_DATA[0]
+            expected = list([(row[0], row[1], row[4])])
+            sd._check_rows_data(rows, expected)
+

 def test_read_w_single_key(sessions_database):
     # [START spanner_test_single_key_read]
@@ -1980,12 +2048,17 @@ def _check_sql_results(
     expected=None,
     order=True,
     recurse_into_lists=True,
+    column_info=None,
 ):
     if order and "ORDER" not in sql:
         sql += " ORDER BY pkey"

     with database.snapshot() as snapshot:
-        rows = list(snapshot.execute_sql(sql, params=params, param_types=param_types))
+        rows = list(
+            snapshot.execute_sql(
+                sql, params=params, param_types=param_types, column_info=column_info
+            )
+        )

     _sample_data._check_rows_data(
         rows, expected=expected, recurse_into_lists=recurse_into_lists
@@ -2079,32 +2152,39 @@ def _bind_test_helper(
     array_value,
     expected_array_value=None,
     recurse_into_lists=True,
+    column_info=None,
+    expected_single_value=None,
 ):
     database.snapshot(multi_use=True)

     key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "v"
     placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}"

+    if expected_single_value is None:
+        expected_single_value = single_value
+
     # Bind a non-null
     _check_sql_results(
         database,
-        sql=f"SELECT {placeholder}",
+        sql=f"SELECT {placeholder} as column",
         params={key: single_value},
         param_types={key: param_type},
-        expected=[(single_value,)],
+        expected=[(expected_single_value,)],
         order=False,
         recurse_into_lists=recurse_into_lists,
+        column_info=column_info,
     )

     # Bind a null
     _check_sql_results(
         database,
-        sql=f"SELECT {placeholder}",
+        sql=f"SELECT {placeholder} as column",
         params={key: None},
         param_types={key: param_type},
        expected=[(None,)],
         order=False,
         recurse_into_lists=recurse_into_lists,
+        column_info=column_info,
     )

     # Bind an array of
@@ -2118,34 +2198,37 @@ def _bind_test_helper(

     _check_sql_results(
         database,
-        sql=f"SELECT {placeholder}",
+        sql=f"SELECT {placeholder} as column",
         params={key: array_value},
         param_types={key: array_type},
         expected=[(expected_array_value,)],
         order=False,
         recurse_into_lists=recurse_into_lists,
+        column_info=column_info,
     )

     # Bind an empty array of
     _check_sql_results(
         database,
-        sql=f"SELECT {placeholder}",
+        sql=f"SELECT {placeholder} as column",
         params={key: []},
         param_types={key: array_type},
         expected=[([],)],
         order=False,
         recurse_into_lists=recurse_into_lists,
+        column_info=column_info,
     )

     # Bind a null array of
     _check_sql_results(
         database,
-        sql=f"SELECT {placeholder}",
+        sql=f"SELECT {placeholder} as column",
         params={key: None},
         param_types={key: array_type},
         expected=[(None,)],
         order=False,
         recurse_into_lists=recurse_into_lists,
+        column_info=column_info,
     )
@@ -2565,6 +2648,80 @@ def test_execute_sql_w_query_param_struct(sessions_database, not_postgres):
     )


+def test_execute_sql_w_proto_message_bindings(
+    not_emulator, not_postgres, sessions_database, database_dialect
+):
+    singer_info = _sample_data.SINGER_INFO_1
+    singer_info_bytes = base64.b64encode(singer_info.SerializeToString())
+
+    _bind_test_helper(
+        sessions_database,
+        database_dialect,
+        spanner_v1.param_types.ProtoMessage(singer_info),
+        singer_info,
+        [singer_info, None],
+        column_info={"column": singer_pb2.SingerInfo()},
+    )
+
+    # Tests compatibility between proto message and bytes column types
+    _bind_test_helper(
+        sessions_database,
+        database_dialect,
+        spanner_v1.param_types.ProtoMessage(singer_info),
+        singer_info_bytes,
+        [singer_info_bytes, None],
+        expected_single_value=singer_info,
+        expected_array_value=[singer_info, None],
+        column_info={"column": singer_pb2.SingerInfo()},
+    )
+
+    # Tests compatibility between proto message and bytes column types
+    _bind_test_helper(
+        sessions_database,
+        database_dialect,
+        spanner_v1.param_types.BYTES,
+        singer_info,
+        [singer_info, None],
+        expected_single_value=singer_info_bytes,
+        expected_array_value=[singer_info_bytes, None],
+    )
+
+
+def test_execute_sql_w_proto_enum_bindings(
+    not_emulator, not_postgres, sessions_database, database_dialect
+):
+    singer_genre = _sample_data.SINGER_GENRE_1
+
+    _bind_test_helper(
+        sessions_database,
+        database_dialect,
+        spanner_v1.param_types.ProtoEnum(singer_pb2.Genre),
+        singer_genre,
+        [singer_genre, None],
+    )
+
+    # Tests compatibility between proto enum and int64 column types
+    _bind_test_helper(
+        sessions_database,
+        database_dialect,
+        spanner_v1.param_types.ProtoEnum(singer_pb2.Genre),
+        3,
+        [3, None],
+        expected_single_value="ROCK",
+        expected_array_value=["ROCK", None],
+        column_info={"column": singer_pb2.Genre},
+    )
+
+    # Tests compatibility between proto enum and int64 column types
+    _bind_test_helper(
+        sessions_database,
+        database_dialect,
+        spanner_v1.param_types.INT64,
+        singer_genre,
+        [singer_genre, None],
+    )
+
+
 def test_execute_sql_returning_transfinite_floats(sessions_database, not_postgres):
     with sessions_database.snapshot(multi_use=True) as snapshot:
         # Query returning -inf, +inf, NaN as column values
diff --git a/tests/system/testdata/descriptors.pb b/tests/system/testdata/descriptors.pb
new file mode 100644
index 0000000000..d4c018f3a3
Binary files /dev/null and b/tests/system/testdata/descriptors.pb differ
diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py
index 5e759baf31..11adec6ac9 100644
--- a/tests/unit/test__helpers.py
+++ b/tests/unit/test__helpers.py
@@ -353,6 +353,25 @@ def test_w_json_None(self):
         value_pb = self._callFUT(value)
         self.assertTrue(value_pb.HasField("null_value"))

+    def test_w_proto_message(self):
+        from google.protobuf.struct_pb2 import Value
+        import base64
+        from samples.samples.testdata import singer_pb2
+
+        singer_info = singer_pb2.SingerInfo()
+        expected = Value(string_value=base64.b64encode(singer_info.SerializeToString()))
+        value_pb = self._callFUT(singer_info)
+        self.assertIsInstance(value_pb, Value)
+        self.assertEqual(value_pb, expected)
+
+    def test_w_proto_enum(self):
+        from google.protobuf.struct_pb2 import Value
+        from samples.samples.testdata import singer_pb2
+
+        value_pb = self._callFUT(singer_pb2.Genre.ROCK)
+        self.assertIsInstance(value_pb, Value)
+        self.assertEqual(value_pb.string_value, "3")
+

 class Test_make_list_value_pb(unittest.TestCase):
     def _callFUT(self, *args, **kw):
@@ -434,9 +453,10 @@ def test_w_null(self):
         from google.cloud.spanner_v1 import TypeCode

         field_type = Type(code=TypeCode.STRING)
+        field_name = "null_column"
         value_pb = Value(null_value=NULL_VALUE)

-        self.assertEqual(self._callFUT(value_pb, field_type), None)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), None)

     def test_w_string(self):
         from google.protobuf.struct_pb2 import Value
@@ -445,9 +465,10 @@ def test_w_string(self):
         VALUE = "Value"
         field_type = Type(code=TypeCode.STRING)
+        field_name = "string_column"
         value_pb = Value(string_value=VALUE)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_bytes(self):
         from google.protobuf.struct_pb2 import Value
@@ -456,9 +477,10 @@ def test_w_bytes(self):
         VALUE = b"Value"
         field_type = Type(code=TypeCode.BYTES)
+        field_name = "bytes_column"
         value_pb = Value(string_value=VALUE)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_bool(self):
         from google.protobuf.struct_pb2 import Value
@@ -467,9 +489,10 @@ def test_w_bool(self):
         VALUE = True
         field_type = Type(code=TypeCode.BOOL)
+        field_name = "bool_column"
         value_pb = Value(bool_value=VALUE)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_int(self):
         from google.protobuf.struct_pb2 import Value
@@ -478,9 +501,10 @@ def test_w_int(self):
         VALUE = 12345
         field_type = Type(code=TypeCode.INT64)
+        field_name = "int_column"
         value_pb = Value(string_value=str(VALUE))

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_float(self):
         from google.protobuf.struct_pb2 import Value
@@ -489,9 +513,10 @@ def test_w_float(self):
         VALUE = 3.14159
         field_type = Type(code=TypeCode.FLOAT64)
+        field_name = "float_column"
         value_pb = Value(number_value=VALUE)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_float_str(self):
         from google.protobuf.struct_pb2 import Value
@@ -500,10 +525,13 @@ def test_w_float_str(self):
         VALUE = "3.14159"
         field_type = Type(code=TypeCode.FLOAT64)
+        field_name = "float_str_column"
         value_pb = Value(string_value=VALUE)
         expected_value = 3.14159

-        self.assertEqual(self._callFUT(value_pb, field_type), expected_value)
+        self.assertEqual(
+            self._callFUT(value_pb, field_type, field_name), expected_value
+        )

     def test_w_float32(self):
         from google.cloud.spanner_v1 import Type, TypeCode
@@ -511,9 +539,10 @@
         VALUE = 3.14159
         field_type = Type(code=TypeCode.FLOAT32)
+        field_name = "float32_column"
         value_pb = Value(number_value=VALUE)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_float32_str(self):
         from google.cloud.spanner_v1 import Type, TypeCode
@@ -521,10 +550,13 @@
         VALUE = "3.14159"
         field_type = Type(code=TypeCode.FLOAT32)
+        field_name = "float32_str_column"
         value_pb = Value(string_value=VALUE)
         expected_value = 3.14159

-        self.assertEqual(self._callFUT(value_pb, field_type), expected_value)
+        self.assertEqual(
+            self._callFUT(value_pb, field_type, field_name), expected_value
+        )

     def test_w_date(self):
         import datetime
@@ -534,9 +566,10 @@
         VALUE = datetime.date.today()
         field_type = Type(code=TypeCode.DATE)
+        field_name = "date_column"
         value_pb = Value(string_value=VALUE.isoformat())

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_timestamp_wo_nanos(self):
         import datetime
@@ -549,9 +582,10 @@
             2016, 12, 20, 21, 13, 47, microsecond=123456, tzinfo=datetime.timezone.utc
         )
         field_type = Type(code=TypeCode.TIMESTAMP)
+        field_name = "nanos_column"
         value_pb = Value(string_value=datetime_helpers.to_rfc3339(value))

-        parsed = self._callFUT(value_pb, field_type)
+        parsed = self._callFUT(value_pb, field_type, field_name)
         self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds)
         self.assertEqual(parsed, value)

@@ -566,9 +600,10 @@ def test_w_timestamp_w_nanos(self):
             2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc
         )
         field_type = Type(code=TypeCode.TIMESTAMP)
+        field_name = "timestamp_column"
         value_pb = Value(string_value=datetime_helpers.to_rfc3339(value))

-        parsed = self._callFUT(value_pb, field_type)
+        parsed = self._callFUT(value_pb, field_type, field_name)
         self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds)
         self.assertEqual(parsed, value)

@@ -580,9 +615,10 @@ def test_w_array_empty(self):
         field_type = Type(
             code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64)
         )
+        field_name = "array_empty_column"
         value_pb = Value(list_value=ListValue(values=[]))

-        self.assertEqual(self._callFUT(value_pb, field_type), [])
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), [])

     def test_w_array_non_empty(self):
         from google.protobuf.struct_pb2 import Value, ListValue
@@ -592,13 +628,14 @@
         field_type = Type(
             code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64)
         )
+        field_name = "array_non_empty_column"
         VALUES = [32, 19, 5]
         values_pb = ListValue(
             values=[Value(string_value=str(value)) for value in VALUES]
         )
         value_pb = Value(list_value=values_pb)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUES)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUES)

     def test_w_struct(self):
         from google.protobuf.struct_pb2 import Value
@@ -615,9 +652,10 @@
             ]
         )
         field_type = Type(code=TypeCode.STRUCT, struct_type=struct_type_pb)
+        field_name = "struct_column"
         value_pb = Value(list_value=_make_list_value_pb(VALUES))

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUES)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUES)

     def test_w_numeric(self):
         import decimal
@@ -627,9 +665,10 @@ def test_w_numeric(self):
         VALUE = decimal.Decimal("99999999999999999999999999999.999999999")
         field_type = Type(code=TypeCode.NUMERIC)
+        field_name = "numeric_column"
         value_pb = Value(string_value=str(VALUE))

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

     def test_w_json(self):
         import json
@@ -641,9 +680,10 @@
         str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":"))

         field_type = Type(code=TypeCode.JSON)
+        field_name = "json_column"
         value_pb = Value(string_value=str_repr)

-        self.assertEqual(self._callFUT(value_pb, field_type), VALUE)
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE)

         VALUE = None
         str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":"))
@@ -651,7 +691,7 @@ def test_w_json(self):
         field_type = Type(code=TypeCode.JSON)
         value_pb = Value(string_value=str_repr)

-        self.assertEqual(self._callFUT(value_pb, field_type), {})
+        self.assertEqual(self._callFUT(value_pb, field_type, field_name), {})

     def test_w_unknown_type(self):
         from google.protobuf.struct_pb2 import Value
@@ -659,10 +699,44 @@
         from google.cloud.spanner_v1 import Type
         from google.cloud.spanner_v1 import TypeCode

         field_type = Type(code=TypeCode.TYPE_CODE_UNSPECIFIED)
+        field_name = "unknown_column"
         value_pb = Value(string_value="Borked")

         with self.assertRaises(ValueError):
-            self._callFUT(value_pb, field_type)
+            self._callFUT(value_pb, field_type, field_name)
+
+    def test_w_proto_message(self):
+        from google.protobuf.struct_pb2 import Value
+        from google.cloud.spanner_v1 import Type
+        from google.cloud.spanner_v1 import TypeCode
+        import base64
+        from samples.samples.testdata import singer_pb2
+
+        VALUE = singer_pb2.SingerInfo()
+        field_type = Type(code=TypeCode.PROTO)
+        field_name = "proto_message_column"
+        value_pb = Value(string_value=base64.b64encode(VALUE.SerializeToString()))
+        column_info = {"proto_message_column": singer_pb2.SingerInfo()}
+
+        self.assertEqual(
+            self._callFUT(value_pb, field_type, field_name, column_info), VALUE
+        )
+
+    def test_w_proto_enum(self):
+        from google.protobuf.struct_pb2 import Value
+        from google.cloud.spanner_v1 import Type
+        from google.cloud.spanner_v1 import TypeCode
+        from samples.samples.testdata import singer_pb2
+
+        VALUE = "ROCK"
+        field_type = Type(code=TypeCode.ENUM)
+        field_name = "proto_enum_column"
+        value_pb = Value(string_value=str(singer_pb2.Genre.ROCK))
+        column_info = {"proto_enum_column": singer_pb2.Genre}
+
+        self.assertEqual(
+            self._callFUT(value_pb, field_type, field_name, column_info), VALUE
+        )


 class Test_parse_list_value_pbs(unittest.TestCase):
diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py
index 6bcacd379b..ec2983ff7e 100644
--- a/tests/unit/test_database.py
+++ b/tests/unit/test_database.py
@@ -220,6 +220,13 @@ def test_ctor_w_directed_read_options(self):
         self.assertIs(database._instance, instance)
         self.assertEqual(database._directed_read_options, DIRECTED_READ_OPTIONS)

+    def test_ctor_w_proto_descriptors(self):
+        instance = _Instance(self.INSTANCE_NAME)
+        database = self._make_one(self.DATABASE_ID, instance, proto_descriptors=b"")
+        self.assertEqual(database.database_id, self.DATABASE_ID)
+        self.assertIs(database._instance, instance)
+        self.assertEqual(database._proto_descriptors, b"")
+
     def test_from_pb_bad_database_name(self):
         from google.cloud.spanner_admin_database_v1 import Database

@@ -385,6 +392,14 @@ def test_default_leader(self):
         default_leader = database._default_leader = "us-east4"
         self.assertEqual(database.default_leader, default_leader)

+    def test_proto_descriptors(self):
+        instance = _Instance(self.INSTANCE_NAME)
+        pool = _Pool()
+        database = self._make_one(
+            self.DATABASE_ID, instance, pool=pool, proto_descriptors=b""
+        )
+        self.assertEqual(database.proto_descriptors, b"")
+
     def test_spanner_api_property_w_scopeless_creds(self):
         client = _Client()
         client_info = client._client_info = mock.Mock()
@@ -659,6 +674,41 @@ def test_create_success_w_encryption_config_dict(self):
             metadata=[("google-cloud-resource-prefix", database.name)],
         )

+    def test_create_success_w_proto_descriptors(self):
+        from tests._fixtures import DDL_STATEMENTS
+        from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
+
+        op_future = object()
+        client = _Client()
+        api = client.database_admin_api = self._make_database_admin_api()
+        api.create_database.return_value = op_future
+        instance = _Instance(self.INSTANCE_NAME, client=client)
+        pool = _Pool()
+        proto_descriptors = b""
+        database = self._make_one(
+            self.DATABASE_ID,
+            instance,
+            ddl_statements=DDL_STATEMENTS,
+            pool=pool,
+            proto_descriptors=proto_descriptors,
+        )
+
+        future = database.create()
+
+        self.assertIs(future, op_future)
+
+        expected_request = CreateDatabaseRequest(
+            parent=self.INSTANCE_NAME,
+            create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
+            extra_statements=DDL_STATEMENTS,
+            proto_descriptors=proto_descriptors,
+        )
+
+        api.create_database.assert_called_once_with(
+            request=expected_request,
+            metadata=[("google-cloud-resource-prefix", database.name)],
+        )
+
     def test_exists_grpc_error(self):
         from google.api_core.exceptions import Unknown

@@ -944,6 +994,34 @@ def test_update_success(self):
             metadata=[("google-cloud-resource-prefix", database.name)],
         )

+    def test_update_ddl_w_proto_descriptors(self):
+        from tests._fixtures import DDL_STATEMENTS
+        from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
+
+        op_future = object()
+        client = _Client()
+        api = client.database_admin_api = self._make_database_admin_api()
+        api.update_database_ddl.return_value = op_future
+        instance = _Instance(self.INSTANCE_NAME, client=client)
+        pool = _Pool()
+        database = self._make_one(self.DATABASE_ID, instance, pool=pool)
+
+        future = database.update_ddl(DDL_STATEMENTS, proto_descriptors=b"")
+
+        self.assertIs(future, op_future)
+
+        expected_request = UpdateDatabaseDdlRequest(
+            database=self.DATABASE_NAME,
+            statements=DDL_STATEMENTS,
+            operation_id="",
+            proto_descriptors=b"",
+        )
+
+        api.update_database_ddl.assert_called_once_with(
+            request=expected_request,
+            metadata=[("google-cloud-resource-prefix", database.name)],
+        )
+
     def test_drop_grpc_error(self):
         from google.api_core.exceptions import Unknown
diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py
index f42bbe1db9..1bfafb37fe 100644
--- a/tests/unit/test_instance.py
+++ b/tests/unit/test_instance.py
@@ -556,6 +556,7 @@ def test_database_factory_explicit(self):
         pool = _Pool()
         logger = mock.create_autospec(Logger, instance=True)
         encryption_config = {"kms_key_name": "kms_key_name"}
+        proto_descriptors = b""

         database = instance.database(
             DATABASE_ID,
@@ -564,6 +565,7 @@ def test_database_factory_explicit(self):
             logger=logger,
             encryption_config=encryption_config,
             database_role=DATABASE_ROLE,
+            proto_descriptors=proto_descriptors,
         )

         self.assertIsInstance(database, Database)
@@ -575,6 +577,7 @@ def test_database_factory_explicit(self):
         self.assertIs(pool._bound, database)
         self.assertIs(database._encryption_config, encryption_config)
         self.assertIs(database.database_role, DATABASE_ROLE)
+        self.assertIs(database._proto_descriptors, proto_descriptors)

     def test_list_databases(self):
         from google.cloud.spanner_admin_database_v1 import Database as DatabasePB
diff --git a/tests/unit/test_param_types.py b/tests/unit/test_param_types.py
index 827f08658d..a7069543c8 100644
--- a/tests/unit/test_param_types.py
+++ b/tests/unit/test_param_types.py
@@ -87,3 +87,37 @@ def test_it(self):
         found = param_types.PG_OID

         self.assertEqual(found, expected)
+
+
+class Test_ProtoMessageParamType(unittest.TestCase):
+    def test_it(self):
+        from google.cloud.spanner_v1 import Type
+        from google.cloud.spanner_v1 import TypeCode
+        from google.cloud.spanner_v1 import param_types
+        from samples.samples.testdata import singer_pb2
+
+        singer_info = singer_pb2.SingerInfo()
+        expected = Type(
+            code=TypeCode.PROTO, proto_type_fqn=singer_info.DESCRIPTOR.full_name
+        )
+
+        found = param_types.ProtoMessage(singer_info)
+
+        self.assertEqual(found, expected)
+
+
+class Test_ProtoEnumParamType(unittest.TestCase):
+    def test_it(self):
+        from google.cloud.spanner_v1 import Type
+        from google.cloud.spanner_v1 import TypeCode
+        from google.cloud.spanner_v1 import param_types
+        from samples.samples.testdata import singer_pb2
+
+        singer_genre = singer_pb2.Genre
+        expected = Type(
+            code=TypeCode.ENUM, proto_type_fqn=singer_genre.DESCRIPTOR.full_name
+        )
+
+        found = param_types.ProtoEnum(singer_genre)
+
+        self.assertEqual(found, expected)
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 0bb02ebdc7..917e875f22 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -643,7 +643,12 @@ def test_read(self):
         self.assertIs(found, snapshot().read.return_value)

         snapshot().read.assert_called_once_with(
-            TABLE_NAME, COLUMNS, KEYSET, INDEX, LIMIT
+            TABLE_NAME,
+            COLUMNS,
+            KEYSET,
+            INDEX,
+            LIMIT,
+            column_info=None,
         )

     def test_execute_sql_not_created(self):
@@ -674,6 +679,7 @@ def test_execute_sql_defaults(self):
             request_options=None,
             timeout=google.api_core.gapic_v1.method.DEFAULT,
             retry=google.api_core.gapic_v1.method.DEFAULT,
+            column_info=None,
         )

     def test_execute_sql_non_default_retry(self):
@@ -704,6 +710,7 @@
             request_options=None,
             timeout=None,
             retry=None,
+            column_info=None,
         )

     def test_execute_sql_explicit(self):
@@ -732,6 +739,7 @@
             request_options=None,
             timeout=google.api_core.gapic_v1.method.DEFAULT,
             retry=google.api_core.gapic_v1.method.DEFAULT,
+            column_info=None,
         )

     def test_batch_not_created(self):
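Note (not part of the change set): the tests above exercise the new ProtoMessage/ProtoEnum param types and the `column_info` argument end to end. The following is a minimal illustrative sketch of how the two fit together from client code, assuming an already-created `database` whose schema includes the proto bundle and reusing the generated `singer_pb2` module from the test data; the SQL text, column aliases, and the `query_proto_columns` helper name are placeholders, not part of the library.

# Illustrative sketch only: bind a proto message and a proto enum as query
# parameters, and ask the client to deserialize the result columns.
from google.cloud import spanner_v1
from samples.samples.testdata import singer_pb2


def query_proto_columns(database, singer_info):
    sql = "SELECT @info AS singer_info, @genre AS singer_genre"
    params = {"info": singer_info, "genre": singer_pb2.Genre.ROCK}
    types = {
        # New param types introduced by this change.
        "info": spanner_v1.param_types.ProtoMessage(singer_info),
        "genre": spanner_v1.param_types.ProtoEnum(singer_pb2.Genre),
    }
    # column_info maps result column names to deserialization hints: a Message
    # instance for PROTO columns, an EnumTypeWrapper for ENUM columns.  Without
    # it, PROTO columns come back as serialized bytes and ENUM columns as ints.
    column_info = {
        "singer_info": singer_pb2.SingerInfo(),
        "singer_genre": singer_pb2.Genre,
    }
    with database.snapshot() as snapshot:
        return list(
            snapshot.execute_sql(
                sql, params=params, param_types=types, column_info=column_info
            )
        )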