Merge pull request #257 from populationgenomics/upstream-106
Upstream 107
lgruen authored Dec 16, 2022
2 parents f6126c4 + fe50208 commit e48bcf2
Showing 305 changed files with 8,741 additions and 2,445 deletions.
4 changes: 0 additions & 4 deletions auth/Dockerfile
@@ -1,9 +1,5 @@
FROM {{ service_base_image.image }}

RUN hail-pip-install \
google-auth-oauthlib==0.4.6 \
google-auth==1.25.0

COPY auth/setup.py auth/MANIFEST.in /auth/
COPY auth/auth /auth/auth/
RUN hail-pip-install /auth && rm -rf /auth
70 changes: 6 additions & 64 deletions batch/batch/cloud/azure/driver/billing_manager.py
@@ -2,15 +2,10 @@
from collections import namedtuple
from typing import Dict, List

from gear import Database, transaction
from gear import Database
from hailtop.aiocloud import aioazure

from ....driver.billing_manager import (
CloudBillingManager,
ProductVersions,
product_version_to_resource,
refresh_product_versions_from_db,
)
from ....driver.billing_manager import CloudBillingManager, ProductVersions, refresh_product_versions_from_db
from .pricing import AzureVMPrice, fetch_prices

log = logging.getLogger('billing_manager')
@@ -51,64 +46,11 @@ async def get_spot_billing_price(self, machine_type: str, location: str) -> float:
return self.vm_price_cache[vm_identifier].cost_per_hour

async def refresh_resources_from_retail_prices(self):
log.info('refreshing resources from retail prices')
resource_updates = []
product_version_updates = []
vm_cache_updates = {}
prices = await fetch_prices(self.pricing_client, self.regions)

for price in await fetch_prices(self.pricing_client, self.regions):
product = price.product
latest_product_version = price.version
latest_resource_rate = price.rate

resource_name = product_version_to_resource(product, latest_product_version)

current_product_version = self.product_versions.latest_version(product)
current_resource_rate = self.resource_rates.get(resource_name)

if current_resource_rate is None:
resource_updates.append((resource_name, latest_resource_rate))
elif abs(current_resource_rate - latest_resource_rate) > 1e-20:
log.error(
f'resource {resource_name} does not have the latest rate in the database for '
f'version {current_product_version}: {current_resource_rate} vs {latest_resource_rate}; '
f'did the vm price change without a version change?'
)
continue

if price.is_current_price() and (
current_product_version is None or current_product_version != latest_product_version
):
product_version_updates.append((product, latest_product_version))
await self._refresh_resources_from_retail_prices(prices)

for price in prices:
if isinstance(price, AzureVMPrice):
vm_identifier = AzureVMIdentifier(price.machine_type, price.preemptible, price.region)
vm_cache_updates[vm_identifier] = price

@transaction(self.db)
async def insert_or_update(tx):
if resource_updates:
await tx.execute_many(
'''
INSERT INTO `resources` (resource, rate)
VALUES (%s, %s)
''',
resource_updates,
)

if product_version_updates:
await tx.execute_many(
'''
INSERT INTO `latest_product_versions` (product, version)
VALUES (%s, %s)
ON DUPLICATE KEY UPDATE version = VALUES(version)
''',
product_version_updates,
)

await insert_or_update() # pylint: disable=no-value-for-parameter

if resource_updates or product_version_updates:
await self.refresh_resources()

self.vm_price_cache.update(vm_cache_updates)
self.vm_price_cache[vm_identifier] = price
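
Note: the Azure-specific refresh logic deleted above now delegates to a shared `_refresh_resources_from_retail_prices` helper on `CloudBillingManager` (in batch/batch/driver/billing_manager.py), which is not part of this excerpt. The sketch below reconstructs what that helper presumably does from the deleted Azure code; the SQL statements, `product_version_to_resource`, and the update conditions come from the removed lines, while the surrounding structure is an assumption rather than the actual upstream implementation.

```python
# Sketch only: assumed shape of the shared helper in batch/batch/driver/billing_manager.py,
# reconstructed from the Azure-specific code deleted above.
# product_version_to_resource, self.db, self.product_versions, self.resource_rates and
# self.refresh_resources are defined elsewhere in the real module/class.
import logging
from typing import List, Tuple

from gear import transaction

log = logging.getLogger('billing_manager')


class CloudBillingManager:  # simplified stand-in for the real base class
    async def _refresh_resources_from_retail_prices(self, prices: List['Price']):
        resource_updates: List[Tuple[str, float]] = []
        product_version_updates: List[Tuple[str, str]] = []

        for price in prices:
            product = price.product
            latest_product_version = price.version
            latest_resource_rate = price.rate
            resource_name = product_version_to_resource(product, latest_product_version)

            current_product_version = self.product_versions.latest_version(product)
            current_resource_rate = self.resource_rates.get(resource_name)

            if current_resource_rate is None:
                resource_updates.append((resource_name, latest_resource_rate))
            elif abs(current_resource_rate - latest_resource_rate) > 1e-20:
                # the rate changed without a version change; keep the database value
                log.error(f'resource {resource_name} does not have the latest rate in the database')
                continue

            if price.is_current_price() and (
                current_product_version is None or current_product_version != latest_product_version
            ):
                product_version_updates.append((product, latest_product_version))

        @transaction(self.db)
        async def insert_or_update(tx):
            if resource_updates:
                await tx.execute_many(
                    'INSERT INTO `resources` (resource, rate) VALUES (%s, %s)',
                    resource_updates,
                )
            if product_version_updates:
                await tx.execute_many(
                    '''
                    INSERT INTO `latest_product_versions` (product, version)
                    VALUES (%s, %s)
                    ON DUPLICATE KEY UPDATE version = VALUES(version)
                    ''',
                    product_version_updates,
                )

        await insert_or_update()  # pylint: disable=no-value-for-parameter

        if resource_updates or product_version_updates:
            await self.refresh_resources()
```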
5 changes: 5 additions & 0 deletions batch/batch/cloud/azure/driver/create_instance.py
@@ -43,6 +43,8 @@ def create_vm_config(
) -> dict:
_, cores = azure_machine_type_to_worker_type_and_cores(machine_type)

region = instance_config.region_for(location)

if max_price is not None and not preemptible:
raise ValueError(f'max price given for a nonpreemptible machine {max_price}')

@@ -188,6 +190,7 @@ def create_vm_config(
BATCH_WORKER_IMAGE=$(jq -r '.batch_worker_image' userdata)
DOCKER_ROOT_IMAGE=$(jq -r '.docker_root_image' userdata)
DOCKER_PREFIX=$(jq -r '.docker_prefix' userdata)
REGION=$(jq -r '.region' userdata)
INTERNAL_GATEWAY_IP=$(jq -r '.internal_ip' userdata)
@@ -235,6 +238,7 @@ def create_vm_config(
-e SUBSCRIPTION_ID=$SUBSCRIPTION_ID \
-e RESOURCE_GROUP=$RESOURCE_GROUP \
-e LOCATION=$LOCATION \
-e REGION=$REGION \
-e DOCKER_PREFIX=$DOCKER_PREFIX \
-e DOCKER_ROOT_IMAGE=$DOCKER_ROOT_IMAGE \
-e INSTANCE_CONFIG=$INSTANCE_CONFIG \
@@ -288,6 +292,7 @@ def create_vm_config(
'instance_id': file_store.instance_id,
'max_idle_time_msecs': max_idle_time_msecs,
'instance_config': base64.b64encode(json.dumps(instance_config.to_dict()).encode()).decode(),
'region': region,
}
user_data_str = base64.b64encode(json.dumps(user_data).encode('utf-8')).decode('utf-8')

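Note: both cloud drivers now derive a region from the instance config (`instance_config.region_for(location)` here, `region_for(zone)` in the GCP file below) and thread it through the VM's userdata into the worker environment. The helper itself is not in this diff; a hedged guess at its behaviour:

```python
# Hypothetical behaviour of the region_for helpers assumed by this diff; the real
# implementations live on the per-cloud InstanceConfig classes and may differ.

def azure_region_for(location: str) -> str:
    # An Azure 'location' (e.g. 'eastus') already names a region.
    return location


def gcp_region_for(zone: str) -> str:
    # A GCP zone such as 'us-central1-a' drops its final suffix to give 'us-central1'.
    return zone.rsplit('-', 1)[0]


assert gcp_region_for('us-central1-a') == 'us-central1'
```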
33 changes: 5 additions & 28 deletions batch/batch/cloud/azure/driver/pricing.py
@@ -1,44 +1,21 @@
import abc
import asyncio
import logging
from typing import Dict, List, Optional

import dateutil.parser

from hailtop.aiocloud import aioazure
from hailtop.utils import flatten, grouped, time_msecs
from hailtop.utils import flatten, grouped
from hailtop.utils.rates import rate_gib_month_to_mib_msec, rate_instance_hour_to_fraction_msec

from ....driver.pricing import Price
from ..resource_utils import azure_disk_name_to_storage_gib, azure_valid_machine_types
from ..resources import AzureStaticSizedDiskResource, AzureVMResource

log = logging.getLogger('pricing')


class AzurePrice(abc.ABC):
region: str
effective_start_date: int
effective_end_date: Optional[int]
time_updated: int

def is_current_price(self):
now = time_msecs()
return now >= self.effective_start_date and (self.effective_end_date is None or now <= self.effective_end_date)

@property
def version(self) -> str:
return f'{self.effective_start_date}'

@property
def product(self):
raise NotImplementedError

@property
def rate(self):
raise NotImplementedError


class AzureVMPrice(AzurePrice):
class AzureVMPrice(Price):
def __init__(
self,
machine_type: str,
@@ -64,7 +41,7 @@ def rate(self):
return rate_instance_hour_to_fraction_msec(self.cost_per_hour, 1024)


class AzureDiskPrice(AzurePrice):
class AzureDiskPrice(Price):
def __init__(
self,
disk_name: str,
@@ -172,7 +149,7 @@ async def managed_disk_prices_by_region(
return prices


async def fetch_prices(pricing_client: aioazure.AzurePricingClient, regions: List[str]) -> List[AzurePrice]:
async def fetch_prices(pricing_client: aioazure.AzurePricingClient, regions: List[str]) -> List[Price]:
# Azure seems to have a limit on how long the OData filter request can be so we split the query into smaller groups
vm_coros = [
vm_prices_by_region(pricing_client, region, machine_types)
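Note: `AzureVMPrice` and `AzureDiskPrice` now derive from a cloud-agnostic `Price` class imported from `....driver.pricing` (batch/batch/driver/pricing.py), which absorbs the deleted `AzurePrice` base class. The shared class is not shown in this diff; the sketch below assumes it keeps the fields and methods of the removed Azure version.

```python
# Sketch only: assumed shape of the shared Price base class in
# batch/batch/driver/pricing.py, mirroring the AzurePrice class deleted above.
import abc
from typing import Optional

from hailtop.utils import time_msecs


class Price(abc.ABC):
    region: str
    effective_start_date: int
    effective_end_date: Optional[int]
    time_updated: int

    def is_current_price(self) -> bool:
        now = time_msecs()
        return now >= self.effective_start_date and (
            self.effective_end_date is None or now <= self.effective_end_date
        )

    @property
    def version(self) -> str:
        return f'{self.effective_start_date}'

    @property
    def product(self) -> str:
        raise NotImplementedError

    @property
    def rate(self) -> float:
        raise NotImplementedError
```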
26 changes: 21 additions & 5 deletions batch/batch/cloud/gcp/driver/billing_manager.py
@@ -1,22 +1,38 @@
import logging
from typing import Dict
from typing import Dict, List

from gear import Database
from hailtop.aiocloud import aiogoogle

from ....driver.billing_manager import CloudBillingManager, ProductVersions, refresh_product_versions_from_db
from .pricing import fetch_prices

log = logging.getLogger('resource_manager')
log = logging.getLogger('billing_manager')


class GCPBillingManager(CloudBillingManager):
@staticmethod
async def create(db: Database):
async def create(db: Database, billing_client: aiogoogle.GoogleBillingClient, regions: List[str]):
product_versions_dict = await refresh_product_versions_from_db(db)
bm = GCPBillingManager(db, product_versions_dict)
bm = GCPBillingManager(db, product_versions_dict, billing_client, regions)
await bm.refresh_resources()
await bm.refresh_resources_from_retail_prices()
return bm

def __init__(self, db: Database, product_versions_dict: Dict[str, str]):
def __init__(
self,
db: Database,
product_versions_dict: Dict[str, str],
billing_client: aiogoogle.GoogleBillingClient,
regions: List[str],
):
self.db = db
self.resource_rates: Dict[str, float] = {}
self.product_versions = ProductVersions(product_versions_dict)
self.billing_client = billing_client
self.regions = regions
self.currency_code = 'USD'

async def refresh_resources_from_retail_prices(self):
prices = [price async for price in fetch_prices(self.billing_client, self.regions, self.currency_code)]
await self._refresh_resources_from_retail_prices(prices)
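
Note: unlike the Azure driver, the GCP manager consumes `fetch_prices` as an async generator, collecting its results with an `async for` comprehension before handing the list to the shared helper. A self-contained toy example of that pattern (it does not touch the real `GoogleBillingClient`):

```python
# Toy illustration of the `[x async for x in ...]` pattern used above.
import asyncio
from typing import AsyncIterator


async def fake_prices() -> AsyncIterator[float]:
    for rate in (0.01, 0.02, 0.03):
        await asyncio.sleep(0)  # stand-in for a paged billing-catalog request
        yield rate


async def main():
    prices = [price async for price in fake_prices()]
    print(prices)  # [0.01, 0.02, 0.03]


asyncio.run(main())
```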
7 changes: 7 additions & 0 deletions batch/batch/cloud/gcp/driver/create_instance.py
@@ -41,6 +41,8 @@ def create_vm_config(
) -> dict:
_, cores = gcp_machine_type_to_worker_type_and_cores(machine_type)

region = instance_config.region_for(zone)

if local_ssd_data_disk:
worker_data_disk = {
'type': 'SCRATCH',
@@ -242,6 +244,8 @@ def scheduling() -> dict:
INSTANCE_ID=$(curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance_id")
INSTANCE_CONFIG=$(curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance_config")
MAX_IDLE_TIME_MSECS=$(curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/max_idle_time_msecs")
REGION=$(curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/region")
NAME=$(curl -s http://metadata.google.internal/computeMetadata/v1/instance/name -H 'Metadata-Flavor: Google')
ZONE=$(curl -s http://metadata.google.internal/computeMetadata/v1/instance/zone -H 'Metadata-Flavor: Google')
@@ -287,6 +291,8 @@ def scheduling() -> dict:
-e INSTANCE_ID=$INSTANCE_ID \
-e PROJECT=$PROJECT \
-e ZONE=$ZONE \
-e REGION=$REGION \
-e DOCKER_PREFIX=$DOCKER_PREFIX \
-e DOCKER_ROOT_IMAGE=$DOCKER_ROOT_IMAGE \
-e INSTANCE_CONFIG=$INSTANCE_CONFIG \
-e DOCKER_PREFIX=$DOCKER_PREFIX \
@@ -348,6 +354,7 @@ def scheduling() -> dict:
{'key': 'batch_logs_storage_uri', 'value': file_store.batch_logs_storage_uri},
{'key': 'instance_id', 'value': file_store.instance_id},
{'key': 'max_idle_time_msecs', 'value': max_idle_time_msecs},
{'key': 'region', 'value': region},
{
'key': 'instance_config',
'value': base64.b64encode(json.dumps(instance_config.to_dict()).encode()).decode(),
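Note: as on Azure, the region computed from `instance_config.region_for(zone)` is written into instance metadata, read back by the startup script, and exported into the worker container as `REGION`. A hypothetical worker-side read (the worker code itself is not part of this diff):

```python
# Hypothetical: how worker code could consume the environment variable
# injected via `-e REGION=$REGION` above.
import os

REGION = os.environ['REGION']  # e.g. 'us-central1' on GCP, 'eastus' on Azure
```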
6 changes: 4 additions & 2 deletions batch/batch/cloud/gcp/driver/driver.py
@@ -63,8 +63,10 @@ async def create(
rate_limit=RateLimit(10, 60),
)

billing_client = aiogoogle.GoogleBillingClient(credentials_file=credentials_file)

zone_monitor = await ZoneMonitor.create(compute_client, regions, zone)
billing_manager = await GCPBillingManager.create(db)
billing_manager = await GCPBillingManager.create(db, billing_client, regions)
inst_coll_manager = InstanceCollectionManager(db, machine_name_prefix, zone_monitor, region, regions)
resource_manager = GCPResourceManager(project, compute_client, billing_manager)

@@ -111,7 +113,7 @@ async def create(
task_manager.ensure_future(periodically_call(15, driver.process_activity_logs))
task_manager.ensure_future(periodically_call(60, zone_monitor.update_region_quotas))
task_manager.ensure_future(periodically_call(60, driver.delete_orphaned_disks))
task_manager.ensure_future(periodically_call(300, billing_manager.refresh_resources))
task_manager.ensure_future(periodically_call(300, billing_manager.refresh_resources_from_retail_prices))

return driver

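Note: the GCP driver now constructs a `GoogleBillingClient` and swaps the five-minute background task from `refresh_resources` to `refresh_resources_from_retail_prices`, so retail prices are periodically re-fetched from the billing API rather than only re-read from the database. `periodically_call` comes from `hailtop.utils` and is not shown here; a rough approximation of what such a helper does:

```python
# Rough approximation only; hailtop.utils.periodically_call is not part of this
# diff and its real behaviour (error handling, jitter, cancellation) may differ.
import asyncio


async def periodically_call(interval_seconds: float, f, *args):
    while True:
        await f(*args)
        await asyncio.sleep(interval_seconds)
```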