Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle disabled ruff rules P->U #591

Merged
merged 3 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable, Mapping
from datetime import datetime, timedelta
import json
import logging
import os
from pathlib import Path
import shutil
from typing import Any, Generic, Literal, TypeVar
from collections.abc import Awaitable, Callable, Mapping

import aiohttp
from aiohttp import ClientSession
from atomicwrites import atomic_write
from jose import jwt

Expand All @@ -21,9 +20,9 @@
from .cloudhooks import Cloudhooks
from .const import (
CONFIG_DIR,
MODE_DEV,
DEFAULT_SERVERS,
DEFAULT_VALUES,
MODE_DEV,
STATE_CONNECTED,
)
from .google_report_state import GoogleReportState
Expand Down Expand Up @@ -127,7 +126,7 @@ def is_connected(self) -> bool:
return self.iot.state == STATE_CONNECTED

@property
def websession(self) -> aiohttp.ClientSession:
def websession(self) -> ClientSession:
"""Return websession for connections."""
return self.client.websession

Expand All @@ -141,7 +140,7 @@ def expiration_date(self) -> datetime:
"""Return the subscription expiration as a UTC datetime object."""
if (parsed_date := parse_date(self.claims["custom:sub-exp"])) is None:
raise ValueError(
f"Invalid expiration date ({self.claims['custom:sub-exp']})"
f"Invalid expiration date ({self.claims['custom:sub-exp']})",
)
return datetime.combine(parsed_date, datetime.min.time()).replace(tzinfo=UTC)

Expand All @@ -161,7 +160,10 @@ def user_info_path(self) -> Path:
return self.path(f"{self.mode}_auth.json")

async def update_token(
self, id_token: str, access_token: str, refresh_token: str | None = None
self,
id_token: str,
access_token: str,
refresh_token: str | None = None,
) -> asyncio.Task | None:
"""Update the id and access token."""
self.id_token = id_token
Expand All @@ -185,7 +187,8 @@ async def update_token(
return None

def register_on_initialized(
self, on_initialized_cb: Callable[[], Awaitable[None]]
self,
on_initialized_cb: Callable[[], Awaitable[None]],
) -> None:
"""Register an async on_initialized callback.

Expand Down Expand Up @@ -249,7 +252,7 @@ def _remove_data(self) -> None:
base_path = self.path()

# Recursively remove .cloud
if os.path.isdir(base_path):
if base_path.is_dir():
shutil.rmtree(base_path)

# Guard against .cloud not being a directory
Expand All @@ -271,7 +274,7 @@ def _write_user_info(self) -> None:
"refresh_token": self.refresh_token,
},
indent=4,
)
),
)
self.user_info_path.chmod(0o600)

Expand All @@ -287,11 +290,11 @@ def load_config() -> None | dict[str, Any]:

if not self.user_info_path.exists():
return None

try:
content: dict[str, Any] = json.loads(
self.user_info_path.read_text(encoding="utf-8")
self.user_info_path.read_text(encoding="utf-8"),
)
return content
except (ValueError, OSError) as err:
path = self.user_info_path.relative_to(self.client.base_path)
self.client.user_message(
Expand All @@ -300,10 +303,14 @@ def load_config() -> None | dict[str, Any]:
f"Unable to load authentication from {path}. [Please login again](/config/cloud)",
)
_LOGGER.warning(
"Error loading cloud authentication info from %s: %s", path, err
"Error loading cloud authentication info from %s: %s",
path,
err,
)
return None

return content

info = await self.run_executor(load_config)
if info is None:
# No previous token data
Expand Down
6 changes: 3 additions & 3 deletions hass_nabucasa/account_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def async_get_authorize_url(self) -> str:
_LOGGER.debug("Opening connection for %s", self.service)

self._client = await self.cloud.client.websession.ws_connect(
f"https://{self.cloud.account_link_server}/v1"
f"https://{self.cloud.account_link_server}/v1",
)
await self._client.send_json({"service": self.service})

Expand Down Expand Up @@ -101,7 +101,7 @@ async def _get_response(self) -> dict[str, Any]:

if "error" in response:
if response["error"] == ERR_TIMEOUT:
raise TimeoutError()
raise TimeoutError

raise AccountLinkException(response["error"])

Expand Down Expand Up @@ -131,7 +131,7 @@ async def async_fetch_available_services(
"""Fetch available services."""

resp = await cloud.client.websession.get(
f"https://{cloud.account_link_server}/services"
f"https://{cloud.account_link_server}/services",
)
resp.raise_for_status()
content: list[dict[str, Any]] = await resp.json()
Expand Down
61 changes: 35 additions & 26 deletions hass_nabucasa/acme.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
from typing import TYPE_CHECKING
import urllib

import OpenSSL
from acme import challenges, client, crypto_util, errors, messages
import async_timeout
from atomicwrites import atomic_write
import attr
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from cryptography.x509.extensions import SubjectAlternativeName
from cryptography.x509.oid import NameOID
import josepy as jose
import OpenSSL

from . import cloud_api

Expand Down Expand Up @@ -134,7 +134,7 @@ def common_name(self) -> str | None:
if not self._x509:
return None
return str(
self._x509.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value
self._x509.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value,
)

@property
Expand All @@ -144,7 +144,7 @@ def alternative_names(self) -> list[str] | None:
return None

alternative_names = self._x509.extensions.get_extension_for_class(
SubjectAlternativeName
SubjectAlternativeName,
).value
return [str(entry.value) for entry in alternative_names]

Expand Down Expand Up @@ -182,7 +182,8 @@ def _load_account_key(self) -> None:
else:
_LOGGER.debug("Create new RSA keyfile: %s", self.path_account_key)
key = rsa.generate_private_key(
public_exponent=65537, key_size=ACCOUNT_KEY_SIZE
public_exponent=65537,
key_size=ACCOUNT_KEY_SIZE,
)

# Store it to file
Expand All @@ -203,7 +204,7 @@ def _create_client(self) -> None:
if self.path_registration_info.exists():
_LOGGER.info("Load exists ACME registration")
regr = messages.RegistrationResource.json_loads(
self.path_registration_info.read_text(encoding="utf-8")
self.path_registration_info.read_text(encoding="utf-8"),
)

acme_url = urllib.parse.urlparse(self._acme_server)
Expand Down Expand Up @@ -255,15 +256,17 @@ def _create_client(self) -> None:
)
regr = self._acme_client.new_account(
messages.NewRegistration.from_data(
email=self._email, terms_of_service_agreed=True
)
email=self._email,
terms_of_service_agreed=True,
),
)
except errors.Error as err:
raise AcmeClientError(f"Can't register to ACME server: {err}") from err

# Store registration info
self.path_registration_info.write_text(
regr.json_dumps_pretty(), encoding="utf-8"
regr.json_dumps_pretty(),
encoding="utf-8",
)
self.path_registration_info.chmod(0o600)

Expand All @@ -280,10 +283,10 @@ def _create_order(self, csr_pem: bytes) -> messages.OrderResource:
and err.detail == "JWS verification error"
):
raise AcmeJWSVerificationError(
f"JWS verification failed: {err}"
f"JWS verification failed: {err}",
) from None
raise AcmeChallengeError(
f"Can't order a new ACME challenge: {err}"
f"Can't order a new ACME challenge: {err}",
) from None

def _start_challenge(self, order: messages.OrderResource) -> list[ChallengeHandler]:
Expand All @@ -307,14 +310,14 @@ def _start_challenge(self, order: messages.OrderResource) -> list[ChallengeHandl
for dns_challenge in dns_challenges:
try:
response, validation = dns_challenge.response_and_validation(
self._account_jwk
self._account_jwk,
)
except errors.Error as err:
raise AcmeChallengeError(
f"Can't validate the new ACME challenge: {err}"
f"Can't validate the new ACME challenge: {err}",
) from None
handlers.append(
ChallengeHandler(dns_challenge, order, response, validation)
ChallengeHandler(dns_challenge, order, response, validation),
)

return handlers
Expand All @@ -338,14 +341,16 @@ def _finish_challenge(self, order: messages.OrderResource) -> None:
try:
order = self._acme_client.poll_authorizations(order, deadline)
order = self._acme_client.finalize_order(
order, deadline, fetch_alternative_chains=True
order,
deadline,
fetch_alternative_chains=True,
)
except errors.Error as err:
raise AcmeChallengeError(f"Wait of ACME challenge fails: {err}") from err
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception while finalizing order")
raise AcmeChallengeError(
"Unexpected exception while finalizing order"
"Unexpected exception while finalizing order",
) from None

# Cleanup the old stuff
Expand Down Expand Up @@ -385,8 +390,9 @@ def _revoke_certificate(self) -> None:

fullchain = jose.ComparableX509(
OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM, self.path_fullchain.read_bytes()
)
OpenSSL.crypto.FILETYPE_PEM,
self.path_fullchain.read_bytes(),
),
)

_LOGGER.info("Revoke certificate")
Expand All @@ -397,10 +403,10 @@ def _revoke_certificate(self) -> None:
pass
except errors.Error as err:
# Ignore errors where certificate did not exist
if "No such certificate" in str(err):
if "No such certificate" in str(err): # noqa: SIM114
pass
# Ignore errors where certificate has expired
elif "Certificate is expired" in str(err):
elif "Certificate is expired" in str(err): # noqa: SIM114
pass
# Ignore errors where unrecognized issuer (happens dev/prod switch)
elif "Certificate from unrecognized issuer" in str(err):
Expand All @@ -415,7 +421,7 @@ def _deactivate_account(self) -> None:

_LOGGER.info("Load exists ACME registration")
regr = messages.RegistrationResource.json_loads(
self.path_registration_info.read_text(encoding="utf-8")
self.path_registration_info.read_text(encoding="utf-8"),
)

try:
Expand Down Expand Up @@ -446,7 +452,8 @@ async def issue_certificate(self) -> None:
csr = await self.cloud.run_executor(self._generate_csr)
order = await self.cloud.run_executor(self._create_order, csr)
dns_challenges: list[ChallengeHandler] = await self.cloud.run_executor(
self._start_challenge, order
self._start_challenge,
order,
)

try:
Expand All @@ -455,18 +462,19 @@ async def issue_certificate(self) -> None:
try:
async with async_timeout.timeout(30):
resp = await cloud_api.async_remote_challenge_txt(
self.cloud, challenge.validation
self.cloud,
challenge.validation,
)
assert resp.status in (200, 201)
except (TimeoutError, AssertionError):
raise AcmeNabuCasaError(
"Can't set challenge token to NabuCasa DNS!"
"Can't set challenge token to NabuCasa DNS!",
) from None

# Answer challenge
try:
_LOGGER.info(
"Waiting 60 seconds for publishing DNS to ACME provider"
"Waiting 60 seconds for publishing DNS to ACME provider",
)
await asyncio.sleep(60)
await self.cloud.run_executor(self._answer_challenge, challenge)
Expand All @@ -483,7 +491,8 @@ async def issue_certificate(self) -> None:
async with async_timeout.timeout(30):
# We only need to cleanup for the last entry
await cloud_api.async_remote_challenge_cleanup(
self.cloud, dns_challenges[-1].validation
self.cloud,
dns_challenges[-1].validation,
)
except TimeoutError:
_LOGGER.error("Failed to clean up challenge from NabuCasa DNS!")
Expand Down
Loading