Skip to content

Commit

Permalink
[batch] Dont use Batch identity on workers (#12611)
Browse files Browse the repository at this point in the history
* [batch] Dont use Batch identity on workers

* no longer create one fs per JVM container

* make the worker own its own task manager
  • Loading branch information
daniel-goldstein authored Jan 27, 2023
1 parent b7da591 commit 9f03cf4
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 148 deletions.
3 changes: 0 additions & 3 deletions batch/batch/cloud/azure/worker/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,3 @@ async def create(self, labels=None):

async def delete(self):
raise NotImplementedError

async def close(self):
raise NotImplementedError
16 changes: 12 additions & 4 deletions batch/batch/cloud/azure/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import aiohttp

from gear.cloud_config import get_azure_config
from hailtop import httpx
from hailtop.aiocloud import aioazure
from hailtop.utils import request_retry_transient_errors, time_msecs

from ....worker.worker_api import CloudWorkerAPI
Expand All @@ -14,6 +16,8 @@


class AzureWorkerAPI(CloudWorkerAPI):
nameserver_ip = '168.63.129.16'

@staticmethod
def from_env():
subscription_id = os.environ['SUBSCRIPTION_ID']
Expand All @@ -26,14 +30,18 @@ def __init__(self, subscription_id: str, resource_group: str, acr_url: str):
self.subscription_id = subscription_id
self.resource_group = resource_group
self.acr_refresh_token = AcrRefreshToken(acr_url, AadAccessToken())

@property
def nameserver_ip(self):
return '168.63.129.16'
self.azure_credentials = aioazure.AzureCredentials.default_credentials()

def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> AzureDisk:
return AzureDisk(disk_name, instance_name, size_in_gb, mount_path)

def get_cloud_async_fs(self) -> aioazure.AzureAsyncFS:
return aioazure.AzureAsyncFS(credentials=self.azure_credentials)

def get_compute_client(self) -> aioazure.AzureComputeClient:
azure_config = get_azure_config()
return aioazure.AzureComputeClient(azure_config.subscription_id, azure_config.resource_group)

def user_credentials(self, credentials: Dict[str, bytes]) -> AzureUserCredentials:
return AzureUserCredentials(credentials)

Expand Down
18 changes: 11 additions & 7 deletions batch/batch/cloud/gcp/worker/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,29 @@


class GCPDisk(CloudDisk):
def __init__(self, name: str, zone: str, project: str, instance_name: str, size_in_gb: int, mount_path: str):
def __init__(
self,
name: str,
zone: str,
project: str,
instance_name: str,
size_in_gb: int,
mount_path: str,
compute_client: aiogoogle.GoogleComputeClient, # BORROWED
):
assert size_in_gb >= 10
# disk name must be 63 characters or less
# https://cloud.google.com/compute/docs/reference/rest/v1/disks#resource:-disk
# under the information for the name field
assert len(name) <= 63

self.compute_client = aiogoogle.GoogleComputeClient(
project, credentials=aiogoogle.GoogleCredentials.from_file('/worker-key.json')
)
self.name = name
self.zone = zone
self.project = project
self.instance_name = instance_name
self.size_in_gb = size_in_gb
self.mount_path = mount_path
self.compute_client = compute_client

self._created = False
self._attached = False
Expand All @@ -49,9 +56,6 @@ async def delete(self):
finally:
await self._delete()

async def close(self):
await self.compute_client.close()

async def _unmount(self):
if self._attached:
await retry_all_errors_n_times(
Expand Down
24 changes: 17 additions & 7 deletions batch/batch/cloud/gcp/worker/worker_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import aiohttp

from hailtop import httpx
from hailtop.aiocloud import aiogoogle
from hailtop.utils import request_retry_transient_errors

from ....worker.worker_api import CloudWorkerAPI
Expand All @@ -13,19 +14,21 @@


class GCPWorkerAPI(CloudWorkerAPI):
nameserver_ip = '169.254.169.254'

# async because GoogleSession must be created inside a running event loop
@staticmethod
def from_env():
async def from_env() -> 'GCPWorkerAPI':
project = os.environ['PROJECT']
zone = os.environ['ZONE'].rsplit('/', 1)[1]
return GCPWorkerAPI(project, zone)
session = aiogoogle.GoogleSession()
return GCPWorkerAPI(project, zone, session)

def __init__(self, project: str, zone: str):
def __init__(self, project: str, zone: str, session: aiogoogle.GoogleSession):
self.project = project
self.zone = zone

@property
def nameserver_ip(self):
return '169.254.169.254'
self._google_session = session
self._compute_client = aiogoogle.GoogleComputeClient(project, session=session)

def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> GCPDisk:
return GCPDisk(
Expand All @@ -35,8 +38,15 @@ def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount
name=disk_name,
size_in_gb=size_in_gb,
mount_path=mount_path,
compute_client=self._compute_client,
)

def get_cloud_async_fs(self) -> aiogoogle.GoogleStorageAsyncFS:
return aiogoogle.GoogleStorageAsyncFS(session=self._google_session)

def get_compute_client(self) -> aiogoogle.GoogleComputeClient:
return self._compute_client

def user_credentials(self, credentials: Dict[str, bytes]) -> GCPUserCredentials:
return GCPUserCredentials(credentials)

Expand Down
5 changes: 3 additions & 2 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ async def get_gsa_key(request, instance): # pylint: disable=unused-argument
return await asyncio.shield(get_gsa_key_1(instance))


# deprecated
@routes.get('/api/v1alpha/instances/credentials')
@activating_instances_only
async def get_credentials(request, instance): # pylint: disable=unused-argument
Expand Down Expand Up @@ -1325,12 +1326,12 @@ async def on_startup(app):
app['cancel_running_state_changed'] = asyncio.Event()
app['async_worker_pool'] = AsyncWorkerPool(100, queue_size=100)

credentials_file = '/gsa-key/key.json'
fs = get_cloud_async_fs(credentials_file=credentials_file)
fs = get_cloud_async_fs()
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)

inst_coll_configs = await InstanceCollectionConfigs.create(db)

credentials_file = '/gsa-key/key.json'
app['driver'] = await get_cloud_driver(
app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs, credentials_file, task_manager
)
Expand Down
2 changes: 1 addition & 1 deletion batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2806,7 +2806,7 @@ async def on_startup(app):
assert max(regions.values()) < 64, str(regions)
app['regions'] = regions

fs = get_cloud_async_fs(credentials_file='/gsa-key/key.json')
fs = get_cloud_async_fs()
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)

app['inst_coll_configs'] = await InstanceCollectionConfigs.create(db)
Expand Down
5 changes: 0 additions & 5 deletions batch/batch/worker/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ async def __aenter__(self, labels=None):

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.delete()
await self.close()

@abc.abstractmethod
async def create(self, labels=None):
Expand All @@ -22,7 +21,3 @@ async def create(self, labels=None):
@abc.abstractmethod
async def delete(self):
pass

@abc.abstractmethod
async def close(self):
pass
Loading

0 comments on commit 9f03cf4

Please sign in to comment.