diff --git a/batch/batch/cloud/gcp/driver/create_instance.py b/batch/batch/cloud/gcp/driver/create_instance.py index 1fa2dc64c43..a09f77d0897 100644 --- a/batch/batch/cloud/gcp/driver/create_instance.py +++ b/batch/batch/cloud/gcp/driver/create_instance.py @@ -1,3 +1,4 @@ +from typing import Any import os import logging import base64 @@ -61,7 +62,7 @@ def create_instance_config( ) assert unreserved_disk_storage_gb >= 0 - vm_config = { + vm_config: Any = { 'name': machine_name, 'machineType': f'projects/{project}/zones/{zone}/machineTypes/{machine_type}', 'labels': {'role': 'batch2-agent', 'namespace': DEFAULT_NAMESPACE}, diff --git a/batch/batch/cloud/gcp/driver/driver.py b/batch/batch/cloud/gcp/driver/driver.py index 1149e672bdd..7d8683d4c3c 100644 --- a/batch/batch/cloud/gcp/driver/driver.py +++ b/batch/batch/cloud/gcp/driver/driver.py @@ -1,4 +1,4 @@ -from typing import Optional, Set +from typing import Optional, Set, Dict, Any, List from hailtop import aiotools from hailtop.aiocloud import aiogoogle @@ -72,8 +72,8 @@ def __init__(self, db: Database, machine_name_prefix: str, task_manager: aiotool self.namespace = namespace self.zone_success_rate = ZoneSuccessRate() - self.region_info = None - self.zones = [] + self.region_info: Optional[Dict[str, Dict[str, Any]]] = None + self.zones: List[str] = [] async def shutdown(self): try: @@ -91,6 +91,7 @@ def get_zone(self, cores: int, worker_local_ssd_data_disk: bool, worker_pd_ssd_d global_live_total_cores_mcpu = self.inst_coll_manager.global_live_total_cores_mcpu if global_live_total_cores_mcpu // 1000 < 1_000: return self.resource_manager.default_location + assert self.region_info is not None # FIXME: this reveals a race condition: if update_region_quotas does not run before we get_zone this will fail return get_zone(self.region_info, self.zone_success_rate, cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb) async def process_activity_logs(self): diff --git 
a/batch/batch/cloud/gcp/driver/resource_manager.py b/batch/batch/cloud/gcp/driver/resource_manager.py index 0d24698657b..f093ccccc7b 100644 --- a/batch/batch/cloud/gcp/driver/resource_manager.py +++ b/batch/batch/cloud/gcp/driver/resource_manager.py @@ -21,13 +21,13 @@ log = logging.getLogger('resource_manager') -def parse_gcp_timestamp(timestamp: Optional[str]) -> Optional[float]: +def parse_gcp_timestamp(timestamp: Optional[str]) -> Optional[int]: if timestamp is None: return None - return dateutil.parser.isoparse(timestamp).timestamp() * 1000 + return int(dateutil.parser.isoparse(timestamp).timestamp() * 1000 + 0.5) -class GCPResourceManager(CloudResourceManager): +class GCPResourceManager(CloudResourceManager[GCPInstanceConfig]): def __init__(self, driver: 'GCPDriver', compute_client: aiogoogle.GoogleComputeClient, default_location: str): self.driver = driver self.compute_client = compute_client @@ -93,6 +93,7 @@ def prepare_vm(self, zone = location if zone is None: + assert cores is not None, (cores, zone) zone = self.driver.get_zone(cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb) if zone is None: return None diff --git a/batch/batch/cloud/gcp/driver/zones.py b/batch/batch/cloud/gcp/driver/zones.py index e2105a36da5..8be4dab7213 100644 --- a/batch/batch/cloud/gcp/driver/zones.py +++ b/batch/batch/cloud/gcp/driver/zones.py @@ -79,8 +79,12 @@ def compute_zone_weights(region_info: Dict[str, Dict[str, Any]], worker_cores: i return weights -def get_zone(region_info: Dict[str, Dict[str, Any]], zone_success_rate: ZoneSuccessRate, worker_cores: int, - worker_local_ssd_data_disk: bool, worker_pd_ssd_data_disk_size_gb: int) -> Optional[str]: +def get_zone(region_info: Dict[str, Dict[str, Any]], + zone_success_rate: ZoneSuccessRate, + worker_cores: int, + worker_local_ssd_data_disk: bool, + worker_pd_ssd_data_disk_size_gb: int + ) -> Optional[str]: zone_weights = compute_zone_weights(region_info, worker_cores, worker_local_ssd_data_disk, 
worker_pd_ssd_data_disk_size_gb) if not zone_weights: diff --git a/batch/batch/cloud/gcp/instance_config.py b/batch/batch/cloud/gcp/instance_config.py index 95d82c43593..87914df3e41 100644 --- a/batch/batch/cloud/gcp/instance_config.py +++ b/batch/batch/cloud/gcp/instance_config.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, cast from typing_extensions import Literal from mypy_extensions import TypedDict import re @@ -16,11 +16,15 @@ def parse_machine_type_str(name: str) -> Dict[str, str]: match = MACHINE_TYPE_REGEX.fullmatch(name) + if match is None: + raise ValueError(f'invalid machine type string: {name}') return match.groupdict() def parse_disk_type(name: str) -> Dict[str, str]: match = DISK_TYPE_REGEX.fullmatch(name) + if match is None: + raise ValueError(f'invalid disk type string: {name}') return match.groupdict() @@ -41,9 +45,16 @@ def parse_disk_type(name: str) -> Dict[str, str]: # vm_config: Dict[str, Any] +disk_type_strs = {'pd-ssd', 'pd-standard', 'local-ssd'} DiskType = Literal['pd-ssd', 'pd-standard', 'local-ssd'] +def assert_valid_disk_type(disk_type: str) -> DiskType: + if disk_type in disk_type_strs: + return cast(DiskType, disk_type) + raise ValueError(f'invalid disk type: {disk_type}') + + class Disk(TypedDict): boot: bool project: Optional[str] @@ -64,7 +75,7 @@ def from_vm_config(vm_config: Dict[str, Any], job_private: bool = False) -> 'GCP for disk_config in vm_config['disks']: params = disk_config['initializeParams'] disk_info = parse_disk_type(params['diskType']) - disk_type = disk_info['disk_type'] + disk_type = assert_valid_disk_type(disk_info['disk_type']) if disk_type == 'local-ssd': disk_size = 375 @@ -199,7 +210,7 @@ def resources(self, cpu_in_mcpu: int, memory_in_bytes: int, storage_in_gib: int) # storage is in units of MiB resources.append({'name': 'disk/pd-ssd/1', 'quantity': storage_in_gib * 1024}) - quantities = defaultdict(lambda: 0) + quantities: Dict[str, int] = 
defaultdict(lambda: 0) for disk in self.disks: name = f'disk/{disk["type"]}/1' # the factors of 1024 cancel between GiB -> MiB and fraction_1024 -> fraction diff --git a/batch/batch/cloud/gcp/resource_utils.py b/batch/batch/cloud/gcp/resource_utils.py index 9e0286b1248..d8a7f944280 100644 --- a/batch/batch/cloud/gcp/resource_utils.py +++ b/batch/batch/cloud/gcp/resource_utils.py @@ -29,6 +29,8 @@ def gcp_machine_type_to_dict(machine_type: str) -> Optional[Dict[str, Any]]: match = MACHINE_TYPE_REGEX.fullmatch(machine_type) + if match is None: + return None return match.groupdict() diff --git a/batch/batch/cloud/worker_utils.py b/batch/batch/cloud/worker_utils.py index eb08b713042..9d609abdc36 100644 --- a/batch/batch/cloud/worker_utils.py +++ b/batch/batch/cloud/worker_utils.py @@ -1,17 +1,20 @@ -from typing import Dict, TYPE_CHECKING +from typing import Dict -from .gcp.worker.disk import GCPDisk from .gcp.worker.credentials import GCPUserCredentials +from .gcp.worker.disk import GCPDisk from .gcp.instance_config import GCPInstanceConfig -if TYPE_CHECKING: - from ..worker.disk import CloudDisk # pylint: disable=cyclic-import - from ..worker.credentials import CloudUserCredentials # pylint: disable=cyclic-import - from ..instance_config import InstanceConfig # pylint: disable=cyclic-import +from ..worker.credentials import CloudUserCredentials +from ..worker.disk import CloudDisk +from ..instance_config import InstanceConfig -def get_cloud_disk(instance_name: str, disk_name: str, size_in_gb: int, mount_path: str, - instance_config: 'InstanceConfig') -> 'CloudDisk': +def get_cloud_disk(instance_name: str, + disk_name: str, + size_in_gb: int, + mount_path: str, + instance_config: InstanceConfig + ) -> CloudDisk: cloud = instance_config.cloud assert cloud == 'gcp' assert isinstance(instance_config, GCPInstanceConfig) @@ -26,6 +29,6 @@ def get_cloud_disk(instance_name: str, disk_name: str, size_in_gb: int, mount_pa return disk -def get_user_credentials(cloud: str, 
credentials: Dict[str, bytes]) -> 'CloudUserCredentials': +def get_user_credentials(cloud: str, credentials: Dict[str, bytes]) -> CloudUserCredentials: assert cloud == 'gcp' return GCPUserCredentials(credentials) diff --git a/batch/batch/driver/instance.py b/batch/batch/driver/instance.py index 62b24a89fd8..10cb91b9480 100644 --- a/batch/batch/driver/instance.py +++ b/batch/batch/driver/instance.py @@ -5,7 +5,6 @@ import humanize import base64 import json -from typing import Optional from hailtop.utils import time_msecs, time_msecs_str, retry_transient_errors from hailtop import httpx @@ -101,10 +100,10 @@ def __init__( free_cores_mcpu, time_created, failed_request_count, - last_updated, + last_updated: int, ip_address, version, - instance_config: Optional[InstanceConfig], + instance_config: InstanceConfig, ): self.db: Database = app['db'] self.client_session: httpx.ClientSession = app['client_session'] diff --git a/batch/batch/driver/resource_manager.py b/batch/batch/driver/resource_manager.py index d04545676f3..4244460d6e8 100644 --- a/batch/batch/driver/resource_manager.py +++ b/batch/batch/driver/resource_manager.py @@ -1,10 +1,8 @@ +from typing import Optional, Any, TypeVar, Generic import abc import logging -from typing import TYPE_CHECKING, Optional, Any -if TYPE_CHECKING: - from ..instance_config import InstanceConfig - from .instance import Instance # pylint: disable=cyclic-import +from .instance import Instance log = logging.getLogger('compute_manager') @@ -26,7 +24,10 @@ def __init__(self, state: str, full_spec: Any, last_state_change_timestamp_msecs self.last_state_change_timestamp_msecs = last_state_change_timestamp_msecs -class CloudResourceManager(abc.ABC): +T = TypeVar('T') + + +class CloudResourceManager(abc.ABC, Generic[T]): default_location: str @abc.abstractmethod @@ -44,17 +45,17 @@ def prepare_vm(self, worker_type: Optional[str] = None, cores: Optional[int] = None, location: Optional[str] = None, - ) -> Optional['InstanceConfig']: + ) -> Optional[T]: 
pass @abc.abstractmethod - async def create_vm(self, instance_config: 'InstanceConfig'): + async def create_vm(self, instance_config: T): pass @abc.abstractmethod - async def delete_vm(self, instance: 'Instance'): + async def delete_vm(self, instance: Instance): pass @abc.abstractmethod - async def get_vm_state(self, instance: 'Instance') -> VMState: + async def get_vm_state(self, instance: Instance) -> VMState: pass diff --git a/batch/batch/instance_config.py b/batch/batch/instance_config.py index 0c72af628e0..e5b19e32883 100644 --- a/batch/batch/instance_config.py +++ b/batch/batch/instance_config.py @@ -17,10 +17,16 @@ class InstanceConfig(abc.ABC): data_disk_size_gb: int job_private: bool worker_type: str - machine_type: str - location: str config: Dict[str, Any] + @property + def machine_type(self) -> str: + raise NotImplementedError + + @property + def location(self) -> str: + raise NotImplementedError + @abc.abstractmethod def resources(self, cpu_in_mcpu, memory_in_bytes, storage_in_gib): pass diff --git a/batch/batch/worker/credentials.py b/batch/batch/worker/credentials.py index 780916f7173..eba5cc6b8e1 100644 --- a/batch/batch/worker/credentials.py +++ b/batch/batch/worker/credentials.py @@ -9,7 +9,10 @@ class CloudUserCredentials(abc.ABC): cloud_env_name: str hail_env_name: str username: str - password: str + + @property + def password(self) -> str: + raise NotImplementedError @property def mount_path(self): diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index 5a1f7d59912..27fef415c5b 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -1076,7 +1076,7 @@ def credentials_host_file_path(self): @staticmethod def create(batch_id, user, - credentials: Optional[Dict[str, str]], + credentials: CloudUserCredentials, job_spec: dict, format_version: BatchFormatVersion, task_manager: aiotools.BackgroundTaskManager,