Skip to content

Commit

Permalink
fix type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel King committed Oct 26, 2021
1 parent 0f45dda commit 8e995d6
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 37 deletions.
3 changes: 2 additions & 1 deletion batch/batch/cloud/gcp/driver/create_instance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
import os
import logging
import base64
Expand Down Expand Up @@ -61,7 +62,7 @@ def create_instance_config(
)
assert unreserved_disk_storage_gb >= 0

vm_config = {
vm_config: Any = {
'name': machine_name,
'machineType': f'projects/{project}/zones/{zone}/machineTypes/{machine_type}',
'labels': {'role': 'batch2-agent', 'namespace': DEFAULT_NAMESPACE},
Expand Down
7 changes: 4 additions & 3 deletions batch/batch/cloud/gcp/driver/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Set
from typing import Optional, Set, Dict, Any, List

from hailtop import aiotools
from hailtop.aiocloud import aiogoogle
Expand Down Expand Up @@ -72,8 +72,8 @@ def __init__(self, db: Database, machine_name_prefix: str, task_manager: aiotool
self.namespace = namespace

self.zone_success_rate = ZoneSuccessRate()
self.region_info = None
self.zones = []
self.region_info: Optional[Dict[str, Dict[str, Any]]] = None
self.zones: List[str] = []

async def shutdown(self):
try:
Expand All @@ -91,6 +91,7 @@ def get_zone(self, cores: int, worker_local_ssd_data_disk: bool, worker_pd_ssd_d
global_live_total_cores_mcpu = self.inst_coll_manager.global_live_total_cores_mcpu
if global_live_total_cores_mcpu // 1000 < 1_000:
return self.resource_manager.default_location
assert self.region_info is not None # FIXME: this reveals a race condition: if update_region_quotas does not run before we get_zone this will fail
return get_zone(self.region_info, self.zone_success_rate, cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb)

async def process_activity_logs(self):
Expand Down
7 changes: 4 additions & 3 deletions batch/batch/cloud/gcp/driver/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
log = logging.getLogger('resource_manager')


def parse_gcp_timestamp(timestamp: Optional[str]) -> Optional[float]:
def parse_gcp_timestamp(timestamp: Optional[str]) -> Optional[int]:
if timestamp is None:
return None
return dateutil.parser.isoparse(timestamp).timestamp() * 1000
return int(dateutil.parser.isoparse(timestamp).timestamp() * 1000 + 0.5)


class GCPResourceManager(CloudResourceManager):
class GCPResourceManager(CloudResourceManager[GCPInstanceConfig]):
def __init__(self, driver: 'GCPDriver', compute_client: aiogoogle.GoogleComputeClient, default_location: str):
self.driver = driver
self.compute_client = compute_client
Expand Down Expand Up @@ -93,6 +93,7 @@ def prepare_vm(self,

zone = location
if zone is None:
assert cores is not None, (cores, zone)
zone = self.driver.get_zone(cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb)
if zone is None:
return None
Expand Down
8 changes: 6 additions & 2 deletions batch/batch/cloud/gcp/driver/zones.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ def compute_zone_weights(region_info: Dict[str, Dict[str, Any]], worker_cores: i
return weights


def get_zone(region_info: Dict[str, Dict[str, Any]], zone_success_rate: ZoneSuccessRate, worker_cores: int,
worker_local_ssd_data_disk: bool, worker_pd_ssd_data_disk_size_gb: int) -> Optional[str]:
def get_zone(region_info: Dict[str, Dict[str, Any]],
zone_success_rate: ZoneSuccessRate,
worker_cores: int,
worker_local_ssd_data_disk: bool,
worker_pd_ssd_data_disk_size_gb: int
) -> Optional[str]:
zone_weights = compute_zone_weights(region_info, worker_cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb)

if not zone_weights:
Expand Down
17 changes: 14 additions & 3 deletions batch/batch/cloud/gcp/instance_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, cast
from typing_extensions import Literal
from mypy_extensions import TypedDict
import re
Expand All @@ -16,11 +16,15 @@

def parse_machine_type_str(name: str) -> Dict[str, str]:
    """Split a machine type name into the named groups of MACHINE_TYPE_REGEX.

    The whole of ``name`` has to match the pattern; a partial match is
    rejected.

    Returns:
        A mapping from each named group of the pattern to the text it
        captured.

    Raises:
        ValueError: if ``name`` does not fully match the pattern.
    """
    parsed = MACHINE_TYPE_REGEX.fullmatch(name)
    if parsed is not None:
        return parsed.groupdict()
    raise ValueError(f'invalid machine type string: {name}')


def parse_disk_type(name: str) -> Dict[str, str]:
    """Split a disk type name into the named groups of DISK_TYPE_REGEX.

    The whole of ``name`` has to match the pattern; a partial match is
    rejected.

    Returns:
        A mapping from each named group of the pattern to the text it
        captured.

    Raises:
        ValueError: if ``name`` does not fully match the pattern.
    """
    parsed = DISK_TYPE_REGEX.fullmatch(name)
    if parsed is not None:
        return parsed.groupdict()
    raise ValueError(f'invalid disk type string: {name}')


Expand All @@ -41,9 +45,16 @@ def parse_disk_type(name: str) -> Dict[str, str]:
# vm_config: Dict[str, Any]


# Runtime counterpart of the DiskType literal below; the two list the same
# names and must be kept in sync by hand.
disk_type_strs = {'pd-ssd', 'pd-standard', 'local-ssd'}
DiskType = Literal['pd-ssd', 'pd-standard', 'local-ssd']


def assert_valid_disk_type(disk_type: str) -> DiskType:
    """Check that ``disk_type`` is a known disk type and narrow its static type.

    Returns:
        The same string, typed as ``DiskType``.

    Raises:
        ValueError: if ``disk_type`` is not a member of ``disk_type_strs``.
    """
    if disk_type not in disk_type_strs:
        raise ValueError(f'invalid disk type: {disk_type}')
    # cast() does nothing at runtime; it only tells the type checker that the
    # membership test above has narrowed the plain str to DiskType.
    return cast(DiskType, disk_type)


class Disk(TypedDict):
boot: bool
project: Optional[str]
Expand All @@ -64,7 +75,7 @@ def from_vm_config(vm_config: Dict[str, Any], job_private: bool = False) -> 'GCP
for disk_config in vm_config['disks']:
params = disk_config['initializeParams']
disk_info = parse_disk_type(params['diskType'])
disk_type = disk_info['disk_type']
disk_type = assert_valid_disk_type(disk_info['disk_type'])

if disk_type == 'local-ssd':
disk_size = 375
Expand Down Expand Up @@ -199,7 +210,7 @@ def resources(self, cpu_in_mcpu: int, memory_in_bytes: int, storage_in_gib: int)
# storage is in units of MiB
resources.append({'name': 'disk/pd-ssd/1', 'quantity': storage_in_gib * 1024})

quantities = defaultdict(lambda: 0)
quantities: Dict[str, int] = defaultdict(lambda: 0)
for disk in self.disks:
name = f'disk/{disk["type"]}/1'
# the factors of 1024 cancel between GiB -> MiB and fraction_1024 -> fraction
Expand Down
2 changes: 2 additions & 0 deletions batch/batch/cloud/gcp/resource_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

def gcp_machine_type_to_dict(machine_type: str) -> Optional[Dict[str, Any]]:
    """Parse a machine type name into the named groups of MACHINE_TYPE_REGEX.

    Unlike a raising parser, a name that does not match is reported by
    returning ``None``.

    Returns:
        A mapping from each named group of the pattern to the text it
        captured when the whole of ``machine_type`` matches, otherwise
        ``None``.
    """
    parsed = MACHINE_TYPE_REGEX.fullmatch(machine_type)
    return None if parsed is None else parsed.groupdict()


Expand Down
21 changes: 12 additions & 9 deletions batch/batch/cloud/worker_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from typing import Dict, TYPE_CHECKING
from typing import Dict

from .gcp.worker.disk import GCPDisk
from .gcp.worker.credentials import GCPUserCredentials
from .gcp.worker.disk import GCPDisk
from .gcp.instance_config import GCPInstanceConfig

if TYPE_CHECKING:
from ..worker.disk import CloudDisk # pylint: disable=cyclic-import
from ..worker.credentials import CloudUserCredentials # pylint: disable=cyclic-import
from ..instance_config import InstanceConfig # pylint: disable=cyclic-import
from ..worker.credentials import CloudUserCredentials
from ..worker.disk import CloudDisk
from ..instance_config import InstanceConfig


def get_cloud_disk(instance_name: str, disk_name: str, size_in_gb: int, mount_path: str,
instance_config: 'InstanceConfig') -> 'CloudDisk':
def get_cloud_disk(instance_name: str,
disk_name: str,
size_in_gb: int,
mount_path: str,
instance_config: InstanceConfig
) -> CloudDisk:
cloud = instance_config.cloud
assert cloud == 'gcp'
assert isinstance(instance_config, GCPInstanceConfig)
Expand All @@ -26,6 +29,6 @@ def get_cloud_disk(instance_name: str, disk_name: str, size_in_gb: int, mount_pa
return disk


def get_user_credentials(cloud: str, credentials: Dict[str, bytes]) -> 'CloudUserCredentials':
def get_user_credentials(cloud: str, credentials: Dict[str, bytes]) -> CloudUserCredentials:
assert cloud == 'gcp'
return GCPUserCredentials(credentials)
5 changes: 2 additions & 3 deletions batch/batch/driver/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import humanize
import base64
import json
from typing import Optional

from hailtop.utils import time_msecs, time_msecs_str, retry_transient_errors
from hailtop import httpx
Expand Down Expand Up @@ -101,10 +100,10 @@ def __init__(
free_cores_mcpu,
time_created,
failed_request_count,
last_updated,
last_updated: int,
ip_address,
version,
instance_config: Optional[InstanceConfig],
instance_config: InstanceConfig,
):
self.db: Database = app['db']
self.client_session: httpx.ClientSession = app['client_session']
Expand Down
19 changes: 10 additions & 9 deletions batch/batch/driver/resource_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Optional, Any, TypeVar, Generic
import abc
import logging
from typing import TYPE_CHECKING, Optional, Any

if TYPE_CHECKING:
from ..instance_config import InstanceConfig
from .instance import Instance # pylint: disable=cyclic-import
from .instance import Instance


log = logging.getLogger('compute_manager')
Expand All @@ -26,7 +24,10 @@ def __init__(self, state: str, full_spec: Any, last_state_change_timestamp_msecs
self.last_state_change_timestamp_msecs = last_state_change_timestamp_msecs


class CloudResourceManager(abc.ABC):
T = TypeVar('T')


class CloudResourceManager(Generic[T]):
default_location: str

@abc.abstractmethod
Expand All @@ -44,17 +45,17 @@ def prepare_vm(self,
worker_type: Optional[str] = None,
cores: Optional[int] = None,
location: Optional[str] = None,
) -> Optional['InstanceConfig']:
) -> Optional[T]:
pass

@abc.abstractmethod
async def create_vm(self, instance_config: 'InstanceConfig'):
async def create_vm(self, instance_config: T):
pass

@abc.abstractmethod
async def delete_vm(self, instance: 'Instance'):
async def delete_vm(self, instance: Instance):
pass

@abc.abstractmethod
async def get_vm_state(self, instance: 'Instance') -> VMState:
async def get_vm_state(self, instance: Instance) -> VMState:
pass
10 changes: 8 additions & 2 deletions batch/batch/instance_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ class InstanceConfig(abc.ABC):
data_disk_size_gb: int
job_private: bool
worker_type: str
machine_type: str
location: str
config: Dict[str, Any]

@property
def machine_type(self) -> str:
raise NotImplementedError

@property
def location(self) -> str:
raise NotImplementedError

@abc.abstractmethod
def resources(self, cpu_in_mcpu, memory_in_bytes, storage_in_gib):
pass
Expand Down
5 changes: 4 additions & 1 deletion batch/batch/worker/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ class CloudUserCredentials(abc.ABC):
cloud_env_name: str
hail_env_name: str
username: str
password: str

@property
def password(self) -> str:
raise NotImplementedError

@property
def mount_path(self):
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def credentials_host_file_path(self):
@staticmethod
def create(batch_id,
user,
credentials: Optional[Dict[str, str]],
credentials: CloudUserCredentials,
job_spec: dict,
format_version: BatchFormatVersion,
task_manager: aiotools.BackgroundTaskManager,
Expand Down

0 comments on commit 8e995d6

Please sign in to comment.