diff --git a/.flake8 b/.flake8 index 277e04b7..e46e8e07 100644 --- a/.flake8 +++ b/.flake8 @@ -15,7 +15,7 @@ # limitations under the License. [flake8] -ignore = E203, E231, E266, E501, W503, ANN101, ANN401 +ignore = E203, E231, E266, E501, W503, ANN101, ANN102, ANN401 exclude = # Exclude generated code. **/proto/** diff --git a/README.md b/README.md index 5920f47d..be36ff99 100644 --- a/README.md +++ b/README.md @@ -381,12 +381,11 @@ to your instance's private IP. To change this, such as connecting to AlloyDB over a public IP address, set the `ip_type` keyword argument when initializing a `Connector()` or when calling `connector.connect()`. -Possible values for `ip_type` are `IPTypes.PRIVATE` (default value), and -`IPTypes.PUBLIC`. +Possible values for `ip_type` are `"PRIVATE"` (default value), and `"PUBLIC"`. Example: ```python -from google.cloud.alloydb.connector import Connector, IPTypes +from google.cloud.alloydb.connector import Connector import sqlalchemy @@ -401,7 +400,7 @@ def getconn(): user="my-user", password="my-password", db="my-db-name", - ip_type=IPTypes.PUBLIC, # use public IP + ip_type="PUBLIC", # use public IP ) # create connection pool diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 2fb5f506..f8b6d480 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -47,8 +47,8 @@ class AsyncConnector: alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". enable_iam_auth (bool): Enables automatic IAM database authentication. - ip_type (IPTypes): Default IP type for all AlloyDB connections. - Defaults to IPTypes.PRIVATE for private IP connections. + ip_type (str | IPTypes): Default IP type for all AlloyDB connections. + Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections. """ def __init__( @@ -57,7 +57,7 @@ def __init__( quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", enable_iam_auth: bool = False, - ip_type: IPTypes = IPTypes.PRIVATE, + ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, ) -> None: self._instances: Dict[str, Instance] = {} @@ -65,6 +65,9 @@ def __init__( self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint self._enable_iam_auth = enable_iam_auth + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes(ip_type.upper()) self._ip_type = ip_type self._user_agent = user_agent # initialize credentials @@ -144,7 +147,10 @@ async def connect( kwargs.pop("port", None) # get connection info for AlloyDB instance - ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type) + ip_type: str | IPTypes = kwargs.pop("ip_type", self._ip_type) + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes(ip_type.upper()) ip_address, context = await instance.connection_info(ip_type) # callable to be used for auto IAM authn diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index f7b0822b..f215af27 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -57,8 +57,8 @@ class Connector: alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". enable_iam_auth (bool): Enables automatic IAM database authentication. - ip_type (IPTypes): Default IP type for all AlloyDB connections. - Defaults to IPTypes.PRIVATE for private IP connections. + ip_type (str | IPTypes): Default IP type for all AlloyDB connections. + Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections. """ def __init__( @@ -67,7 +67,7 @@ def __init__( quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", enable_iam_auth: bool = False, - ip_type: IPTypes = IPTypes.PRIVATE, + ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, ) -> None: # create event loop and start it in background thread @@ -79,6 +79,9 @@ def __init__( self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint self._enable_iam_auth = enable_iam_auth + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes(ip_type.upper()) self._ip_type = ip_type self._user_agent = user_agent # initialize credentials @@ -171,7 +174,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> kwargs.pop("port", None) # get connection info for AlloyDB instance - ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type) + ip_type: IPTypes | str = kwargs.pop("ip_type", self._ip_type) + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes(ip_type.upper()) ip_address, context = await instance.connection_info(ip_type) # synchronous drivers are blocking and run using executor diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index a6ea3671..d2d38b79 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -49,6 +49,13 @@ class IPTypes(Enum): PUBLIC: str = "PUBLIC" PRIVATE: str = "PRIVATE" + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError( + f"Incorrect value for ip_type, got '{value}'. Want one of: " + f"{', '.join([repr(m.value) for m in cls])}." + ) + def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]: # should take form "projects//locations//clusters//instances/" diff --git a/tests/system/test_asyncpg_public_ip.py b/tests/system/test_asyncpg_public_ip.py index f4fe93cd..a297812b 100644 --- a/tests/system/test_asyncpg_public_ip.py +++ b/tests/system/test_asyncpg_public_ip.py @@ -22,7 +22,6 @@ import sqlalchemy.ext.asyncio from google.cloud.alloydb.connector import AsyncConnector -from google.cloud.alloydb.connector import IPTypes async def create_sqlalchemy_engine( @@ -70,7 +69,7 @@ async def getconn() -> asyncpg.Connection: user=user, password=password, db=db, - ip_type=IPTypes.PUBLIC, + ip_type="PUBLIC", ) return conn diff --git a/tests/system/test_pg8000_public_ip.py b/tests/system/test_pg8000_public_ip.py index a9782e11..5fe8752f 100644 --- a/tests/system/test_pg8000_public_ip.py +++ b/tests/system/test_pg8000_public_ip.py @@ -21,7 +21,6 @@ import sqlalchemy from google.cloud.alloydb.connector import Connector -from google.cloud.alloydb.connector import IPTypes def create_sqlalchemy_engine( @@ -70,7 +69,7 @@ def getconn() -> pg8000.dbapi.Connection: user=user, password=password, db=db, - ip_type=IPTypes.PUBLIC, + ip_type="PUBLIC", ) return conn diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index f1c19578..52ddd77b 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from typing import Union from mock import patch from mocks import FakeAlloyDBClient @@ -21,6 +22,7 @@ import pytest from google.cloud.alloydb.connector import AsyncConnector +from google.cloud.alloydb.connector import IPTypes ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com" @@ -40,6 +42,58 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None: await connector.close() +@pytest.mark.parametrize( + "ip_type, expected", + [ + ( + "private", + IPTypes.PRIVATE, + ), + ( + "PRIVATE", + IPTypes.PRIVATE, + ), + ( + IPTypes.PRIVATE, + IPTypes.PRIVATE, + ), + ( + "public", + IPTypes.PUBLIC, + ), + ( + "PUBLIC", + IPTypes.PUBLIC, + ), + ( + IPTypes.PUBLIC, + IPTypes.PUBLIC, + ), + ], +) +async def test_AsyncConnector_init_ip_type( + ip_type: Union[str, IPTypes], expected: IPTypes, credentials: FakeCredentials +) -> None: + """ + Test to check whether the __init__ method of AsyncConnector + properly sets ip_type. + """ + connector = AsyncConnector(credentials=credentials, ip_type=ip_type) + assert connector._ip_type == expected + connector.close() + + +async def test_AsyncConnector_init_bad_ip_type(credentials: FakeCredentials) -> None: + """Test that AsyncConnector errors due to bad ip_type str.""" + bad_ip_type = "BAD-IP-TYPE" + with pytest.raises(ValueError) as exc_info: + AsyncConnector(ip_type=bad_ip_type, credentials=credentials) + assert ( + exc_info.value.args[0] + == f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'." + ) + + @pytest.mark.asyncio async def test_AsyncConnector_context_manager( credentials: FakeCredentials, @@ -202,3 +256,25 @@ def test_synchronous_init(credentials: FakeCredentials) -> None: """ connector = AsyncConnector(credentials) assert connector._keys is None + + +async def test_async_connect_bad_ip_type( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: + """Test that AyncConnector.connect errors due to bad ip_type str.""" + async with AsyncConnector(credentials=credentials) as connector: + connector._client = fake_client + bad_ip_type = "BAD-IP-TYPE" + with pytest.raises(ValueError) as exc_info: + await connector.connect( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ip_type=bad_ip_type, + ) + assert ( + exc_info.value.args[0] + == f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'." + ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a0761245..99318321 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -14,6 +14,7 @@ import asyncio from threading import Thread +from typing import Union from mock import patch from mocks import FakeAlloyDBClient @@ -21,6 +22,7 @@ import pytest from google.cloud.alloydb.connector import Connector +from google.cloud.alloydb.connector import IPTypes def test_Connector_init(credentials: FakeCredentials) -> None: @@ -36,6 +38,58 @@ def test_Connector_init(credentials: FakeCredentials) -> None: connector.close() +def test_Connector_init_bad_ip_type(credentials: FakeCredentials) -> None: + """Test that Connector errors due to bad ip_type str.""" + bad_ip_type = "BAD-IP-TYPE" + with pytest.raises(ValueError) as exc_info: + Connector(ip_type=bad_ip_type, credentials=credentials) + assert ( + exc_info.value.args[0] + == f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'." + ) + + +@pytest.mark.parametrize( + "ip_type, expected", + [ + ( + "private", + IPTypes.PRIVATE, + ), + ( + "PRIVATE", + IPTypes.PRIVATE, + ), + ( + IPTypes.PRIVATE, + IPTypes.PRIVATE, + ), + ( + "public", + IPTypes.PUBLIC, + ), + ( + "PUBLIC", + IPTypes.PUBLIC, + ), + ( + IPTypes.PUBLIC, + IPTypes.PUBLIC, + ), + ], +) +def test_Connector_init_ip_type( + ip_type: Union[str, IPTypes], expected: IPTypes, credentials: FakeCredentials +) -> None: + """ + Test to check whether the __init__ method of Connector + properly sets ip_type. + """ + connector = Connector(credentials=credentials, ip_type=ip_type) + assert connector._ip_type == expected + connector.close() + + def test_Connector_context_manager(credentials: FakeCredentials) -> None: """ Test to check whether the __init__ method of Connector @@ -84,6 +138,28 @@ def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) - assert connection is True +def test_connect_bad_ip_type( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: + """Test that Connector.connect errors due to bad ip_type str.""" + with Connector(credentials=credentials) as connector: + connector._client = fake_client + bad_ip_type = "BAD-IP-TYPE" + with pytest.raises(ValueError) as exc_info: + connector.connect( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + "pg8000", + user="test-user", + password="test-password", + db="test-db", + ip_type=bad_ip_type, + ) + assert ( + exc_info.value.args[0] + == f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'." + ) + + def test_connect_unsupported_driver(credentials: FakeCredentials) -> None: """ Test that connector.connect errors with unsupported database driver.