Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix type issues #9

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion batch/batch/cloud/gcp/driver/create_instance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
import os
import logging
import base64
Expand Down Expand Up @@ -61,7 +62,7 @@ def create_instance_config(
)
assert unreserved_disk_storage_gb >= 0

vm_config = {
vm_config: Any = {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an escape hatch. I accidentally fell down a rabbit hole of cleaning up the instance config so that it was both easier to read and easier to type, but it turned into a rather large change, so I abandoned it.

'name': machine_name,
'machineType': f'projects/{project}/zones/{zone}/machineTypes/{machine_type}',
'labels': {'role': 'batch2-agent', 'namespace': DEFAULT_NAMESPACE},
Expand Down
7 changes: 4 additions & 3 deletions batch/batch/cloud/gcp/driver/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Set
from typing import Optional, Set, Dict, Any, List

from hailtop import aiotools
from hailtop.aiocloud import aiogoogle
Expand Down Expand Up @@ -72,8 +72,8 @@ def __init__(self, db: Database, machine_name_prefix: str, task_manager: aiotool
self.namespace = namespace

self.zone_success_rate = ZoneSuccessRate()
self.region_info = None
self.zones = []
self.region_info: Optional[Dict[str, Dict[str, Any]]] = None
self.zones: List[str] = []

async def shutdown(self):
try:
Expand All @@ -91,6 +91,7 @@ def get_zone(self, cores: int, worker_local_ssd_data_disk: bool, worker_pd_ssd_d
global_live_total_cores_mcpu = self.inst_coll_manager.global_live_total_cores_mcpu
if global_live_total_cores_mcpu // 1000 < 1_000:
return self.resource_manager.default_location
assert self.region_info is not None # FIXME: this reveals a race condition: if update_region_quotas does not run before we get_zone this will fail
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should really be fixed by using an async factory method to ensure `region_info` is never None.

return get_zone(self.region_info, self.zone_success_rate, cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb)

async def process_activity_logs(self):
Expand Down
7 changes: 4 additions & 3 deletions batch/batch/cloud/gcp/driver/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
log = logging.getLogger('resource_manager')


def parse_gcp_timestamp(timestamp: Optional[str]) -> Optional[float]:
def parse_gcp_timestamp(timestamp: Optional[str]) -> Optional[int]:
if timestamp is None:
return None
return dateutil.parser.isoparse(timestamp).timestamp() * 1000
return int(dateutil.parser.isoparse(timestamp).timestamp() * 1000 + 0.5)


class GCPResourceManager(CloudResourceManager):
class GCPResourceManager(CloudResourceManager[GCPInstanceConfig]):
def __init__(self, driver: 'GCPDriver', compute_client: aiogoogle.GoogleComputeClient, default_location: str):
self.driver = driver
self.compute_client = compute_client
Expand Down Expand Up @@ -93,6 +93,7 @@ def prepare_vm(self,

zone = location
if zone is None:
assert cores is not None, (cores, zone)
zone = self.driver.get_zone(cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb)
if zone is None:
return None
Expand Down
8 changes: 6 additions & 2 deletions batch/batch/cloud/gcp/driver/zones.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ def compute_zone_weights(region_info: Dict[str, Dict[str, Any]], worker_cores: i
return weights


def get_zone(region_info: Dict[str, Dict[str, Any]], zone_success_rate: ZoneSuccessRate, worker_cores: int,
worker_local_ssd_data_disk: bool, worker_pd_ssd_data_disk_size_gb: int) -> Optional[str]:
def get_zone(region_info: Dict[str, Dict[str, Any]],
zone_success_rate: ZoneSuccessRate,
worker_cores: int,
worker_local_ssd_data_disk: bool,
worker_pd_ssd_data_disk_size_gb: int
) -> Optional[str]:
zone_weights = compute_zone_weights(region_info, worker_cores, worker_local_ssd_data_disk, worker_pd_ssd_data_disk_size_gb)

if not zone_weights:
Expand Down
17 changes: 14 additions & 3 deletions batch/batch/cloud/gcp/instance_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, cast
from typing_extensions import Literal
from mypy_extensions import TypedDict
import re
Expand All @@ -16,11 +16,15 @@

def parse_machine_type_str(name: str) -> Dict[str, str]:
    """Return the named groups captured by MACHINE_TYPE_REGEX for ``name``.

    Raises:
        ValueError: if ``name`` does not fully match MACHINE_TYPE_REGEX.
    """
    parsed = MACHINE_TYPE_REGEX.fullmatch(name)
    if parsed is not None:
        return parsed.groupdict()
    raise ValueError(f'invalid machine type string: {name}')


def parse_disk_type(name: str) -> Dict[str, str]:
    """Return the named groups captured by DISK_TYPE_REGEX for ``name``.

    Raises:
        ValueError: if ``name`` does not fully match DISK_TYPE_REGEX.
    """
    parsed = DISK_TYPE_REGEX.fullmatch(name)
    # Fail loudly on malformed input rather than dereferencing None below.
    if parsed is None:
        raise ValueError(f'invalid disk type string: {name}')
    fields = parsed.groupdict()
    return fields


Expand All @@ -41,9 +45,16 @@ def parse_disk_type(name: str) -> Dict[str, str]:
# vm_config: Dict[str, Any]


# Runtime set of accepted disk type names; keep in sync with DiskType below.
disk_type_strs = {'pd-ssd', 'pd-standard', 'local-ssd'}
# Static (type-checker) view of the same three names.
DiskType = Literal['pd-ssd', 'pd-standard', 'local-ssd']


def assert_valid_disk_type(disk_type: str) -> DiskType:
    """Narrow a plain string to ``DiskType``.

    Returns ``disk_type`` unchanged when it is a member of ``disk_type_strs``;
    the ``cast`` only informs the type checker and does nothing at runtime.

    Raises:
        ValueError: if ``disk_type`` is not in ``disk_type_strs``.
    """
    if disk_type not in disk_type_strs:
        raise ValueError(f'invalid disk type: {disk_type}')
    return cast(DiskType, disk_type)


class Disk(TypedDict):
boot: bool
project: Optional[str]
Expand All @@ -64,7 +75,7 @@ def from_vm_config(vm_config: Dict[str, Any], job_private: bool = False) -> 'GCP
for disk_config in vm_config['disks']:
params = disk_config['initializeParams']
disk_info = parse_disk_type(params['diskType'])
disk_type = disk_info['disk_type']
disk_type = assert_valid_disk_type(disk_info['disk_type'])

if disk_type == 'local-ssd':
disk_size = 375
Expand Down Expand Up @@ -199,7 +210,7 @@ def resources(self, cpu_in_mcpu: int, memory_in_bytes: int, storage_in_gib: int)
# storage is in units of MiB
resources.append({'name': 'disk/pd-ssd/1', 'quantity': storage_in_gib * 1024})

quantities = defaultdict(lambda: 0)
quantities: Dict[str, int] = defaultdict(lambda: 0)
for disk in self.disks:
name = f'disk/{disk["type"]}/1'
# the factors of 1024 cancel between GiB -> MiB and fraction_1024 -> fraction
Expand Down
2 changes: 2 additions & 0 deletions batch/batch/cloud/gcp/resource_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

def gcp_machine_type_to_dict(machine_type: str) -> Optional[Dict[str, Any]]:
    """Return the named groups MACHINE_TYPE_REGEX captures from ``machine_type``.

    Returns None, rather than raising, when ``machine_type`` does not fully
    match MACHINE_TYPE_REGEX.
    """
    parsed = MACHINE_TYPE_REGEX.fullmatch(machine_type)
    return None if parsed is None else parsed.groupdict()


Expand Down
21 changes: 12 additions & 9 deletions batch/batch/cloud/worker_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from typing import Dict, TYPE_CHECKING
from typing import Dict

from .gcp.worker.disk import GCPDisk
from .gcp.worker.credentials import GCPUserCredentials
from .gcp.worker.disk import GCPDisk
from .gcp.instance_config import GCPInstanceConfig

if TYPE_CHECKING:
from ..worker.disk import CloudDisk # pylint: disable=cyclic-import
from ..worker.credentials import CloudUserCredentials # pylint: disable=cyclic-import
from ..instance_config import InstanceConfig # pylint: disable=cyclic-import
from ..worker.credentials import CloudUserCredentials
from ..worker.disk import CloudDisk
from ..instance_config import InstanceConfig


def get_cloud_disk(instance_name: str, disk_name: str, size_in_gb: int, mount_path: str,
instance_config: 'InstanceConfig') -> 'CloudDisk':
def get_cloud_disk(instance_name: str,
disk_name: str,
size_in_gb: int,
mount_path: str,
instance_config: InstanceConfig
) -> CloudDisk:
cloud = instance_config.cloud
assert cloud == 'gcp'
assert isinstance(instance_config, GCPInstanceConfig)
Expand All @@ -26,6 +29,6 @@ def get_cloud_disk(instance_name: str, disk_name: str, size_in_gb: int, mount_pa
return disk


def get_user_credentials(cloud: str, credentials: Dict[str, bytes]) -> 'CloudUserCredentials':
def get_user_credentials(cloud: str, credentials: Dict[str, bytes]) -> CloudUserCredentials:
assert cloud == 'gcp'
return GCPUserCredentials(credentials)
5 changes: 2 additions & 3 deletions batch/batch/driver/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import humanize
import base64
import json
from typing import Optional

from hailtop.utils import time_msecs, time_msecs_str, retry_transient_errors
from hailtop import httpx
Expand Down Expand Up @@ -101,10 +100,10 @@ def __init__(
free_cores_mcpu,
time_created,
failed_request_count,
last_updated,
last_updated: int,
ip_address,
version,
instance_config: Optional[InstanceConfig],
instance_config: InstanceConfig,
):
self.db: Database = app['db']
self.client_session: httpx.ClientSession = app['client_session']
Expand Down
19 changes: 10 additions & 9 deletions batch/batch/driver/resource_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Optional, Any, TypeVar, Generic
import abc
import logging
from typing import TYPE_CHECKING, Optional, Any

if TYPE_CHECKING:
from ..instance_config import InstanceConfig
from .instance import Instance # pylint: disable=cyclic-import
from .instance import Instance


log = logging.getLogger('compute_manager')
Expand All @@ -26,7 +24,10 @@ def __init__(self, state: str, full_spec: Any, last_state_change_timestamp_msecs
self.last_state_change_timestamp_msecs = last_state_change_timestamp_msecs


class CloudResourceManager(abc.ABC):
T = TypeVar('T')


class CloudResourceManager(Generic[T]):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite right. The structure we have is a bit complicated and prepare_vm and create_vm really shouldn't be two separate functions.

default_location: str

@abc.abstractmethod
Expand All @@ -44,17 +45,17 @@ def prepare_vm(self,
worker_type: Optional[str] = None,
cores: Optional[int] = None,
location: Optional[str] = None,
) -> Optional['InstanceConfig']:
) -> Optional[T]:
pass

@abc.abstractmethod
async def create_vm(self, instance_config: 'InstanceConfig'):
async def create_vm(self, instance_config: T):
pass

@abc.abstractmethod
async def delete_vm(self, instance: 'Instance'):
async def delete_vm(self, instance: Instance):
pass

@abc.abstractmethod
async def get_vm_state(self, instance: 'Instance') -> VMState:
async def get_vm_state(self, instance: Instance) -> VMState:
pass
10 changes: 8 additions & 2 deletions batch/batch/instance_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ class InstanceConfig(abc.ABC):
data_disk_size_gb: int
job_private: bool
worker_type: str
machine_type: str
location: str
config: Dict[str, Any]

@property
def machine_type(self) -> str:
raise NotImplementedError

@property
def location(self) -> str:
raise NotImplementedError

@abc.abstractmethod
def resources(self, cpu_in_mcpu, memory_in_bytes, storage_in_gib):
pass
Expand Down
5 changes: 4 additions & 1 deletion batch/batch/worker/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ class CloudUserCredentials(abc.ABC):
cloud_env_name: str
hail_env_name: str
username: str
password: str

@property
def password(self) -> str:
raise NotImplementedError

@property
def mount_path(self):
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def credentials_host_file_path(self):
@staticmethod
def create(batch_id,
user,
credentials: Optional[Dict[str, str]],
credentials: CloudUserCredentials,
job_spec: dict,
format_version: BatchFormatVersion,
task_manager: aiotools.BackgroundTaskManager,
Expand Down