Skip to content

Commit

Permalink
Lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alukach committed Sep 10, 2024
1 parent ea39600 commit 3518fc4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 43 deletions.
13 changes: 6 additions & 7 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,13 +2069,13 @@ def __repr__(self) -> str:


class AwsSystemsManagerParameterStoreSettingsSource(EnvSettingsSource):
_ssm_client: "SSMClient" # type: ignore # TODO: type client
_ssm_client: "SSMClient" # type: ignore
_ssm_path: str

def __init__(
self,
settings_cls: type[BaseSettings],
ssm_client: "SSMClient", # type: ignore # TODO: type client
ssm_client: "SSMClient", # type: ignore
ssm_path: str = "/",
case_sensitive: bool | None = None,
env_prefix: str | None = None,
Expand Down Expand Up @@ -2106,14 +2106,13 @@ def _load_env_vars(self) -> Mapping[str, Optional[str]]:
try:
for page in response_iterator:
for parameter in page["Parameters"]:
key = (
Path(parameter["Name"]).relative_to(self._ssm_path).as_posix()
)

name = Path(parameter["Name"])
key = name.relative_to(self._ssm_path).as_posix()

if not self.case_sensitive:
first_key, *rest = key.split(self.env_nested_delimiter)
key = self.env_nested_delimiter.join([first_key.lower(), *rest])

output[key] = parameter["Value"]

except self._ssm_client.exceptions.ClientError as e:
Expand Down
71 changes: 35 additions & 36 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,21 +245,21 @@ def test___call__case_sensitive(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias="Password")
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias="SqlServerUser")
sql_server: SqlServer = Field(..., alias="SqlServer")
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

mock_parameters = [
{"Name": "/my/path/SqlServerUser", "Value": "SecretValue"},
{"Name": "/my/path/SqlServer/Password", "Value": "SecretValue"},
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
Expand All @@ -268,32 +268,31 @@ class AwsSettings(BaseSettings):
obj = AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path="/my/path",
ssm_path='/my/path',
case_sensitive=True,
)

settings = obj()
print(f"{settings=}")

assert settings["SqlServerUser"] == "SecretValue"
assert settings["SqlServer"]["Password"] == "SecretValue"
assert settings['SqlServerUser'] == 'SecretValue'
assert settings['SqlServer']['Password'] == 'SecretValue'

def test___call__case_insensitive(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias="Password")
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias="SqlServerUser")
sql_server: SqlServer = Field(..., alias="SqlServer")
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

mock_parameters = [
{"Name": "/my/path/SqlServerUser", "Value": "SecretValue"},
{"Name": "/my/path/SqlServer/Password", "Value": "SecretValue"},
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]
Expand All @@ -310,14 +309,14 @@ class AwsSettings(BaseSettings):
)
settings = obj()

assert settings["SqlServerUser"] == "SecretValue"
assert settings["SqlServer"]["Password"] == "SecretValue"
assert settings['SqlServerUser'] == 'SecretValue'
assert settings['SqlServer']['Password'] == 'SecretValue'

def test_aws_ssm_settings_source(self, mocker: MockerFixture) -> None:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
mock_parameters = [
{"Name": "/my/path/SqlServerUser", "Value": "SecretValue"},
{"Name": "/my/path/SqlServer/Password", "Value": "SecretValue"},
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]
Expand All @@ -327,14 +326,14 @@ def test_aws_ssm_settings_source(self, mocker: MockerFixture) -> None:
client_mock.exceptions.ClientError = Exception

class SqlServer(BaseModel):
password: str = Field(..., alias="Password")
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias="SqlServerUser")
sql_server: SqlServer = Field(..., alias="SqlServer")
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

@classmethod
def settings_customise_sources(
Expand All @@ -355,32 +354,32 @@ def settings_customise_sources(

settings = AwsSettings() # type: ignore

assert settings.SqlServerUser == "SecretValue"
assert settings.sql_server_user == "SecretValue"
assert settings.sql_server.password == "SecretValue"
assert settings.SqlServerUser == 'SecretValue'
assert settings.sql_server_user == 'SecretValue'
assert settings.sql_server.password == 'SecretValue'

def test_aws_ssm_settings_source__delimiter(self, mocker: MockerFixture) -> None:
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
mock_parameters = [
{"Name": "/my/path/SqlServerUser", "Value": "SecretValue"},
{"Name": "/my/path/SqlServer__Password", "Value": "SecretValue"},
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
{'Name': '/my/path/SqlServer__Password', 'Value': 'SecretValue'},
]
paginator_mock = mocker.Mock()
paginator_mock.paginate.return_value = [{"Parameters": mock_parameters}]
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]

client_mock = mocker.Mock()
client_mock.get_paginator.return_value = paginator_mock
client_mock.exceptions.ClientError = Exception

class SqlServer(BaseModel):
password: str = Field(..., alias="Password")
password: str = Field(..., alias='Password')

class AwsSettings(BaseSettings):
"""AWS settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias="SqlServerUser")
sql_server: SqlServer = Field(..., alias="SqlServer")
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')

@classmethod
def settings_customise_sources(
Expand All @@ -395,13 +394,13 @@ def settings_customise_sources(
AwsSystemsManagerParameterStoreSettingsSource(
settings_cls=AwsSettings,
ssm_client=client_mock,
ssm_path="/my/path",
env_nested_delimiter="__",
ssm_path='/my/path',
env_nested_delimiter='__',
),
)

settings = AwsSettings() # type: ignore

assert settings.SqlServerUser == "SecretValue"
assert settings.sql_server_user == "SecretValue"
assert settings.sql_server.password == "SecretValue"
assert settings.SqlServerUser == 'SecretValue'
assert settings.sql_server_user == 'SecretValue'
assert settings.sql_server.password == 'SecretValue'

0 comments on commit 3518fc4

Please sign in to comment.