Skip to content

Commit

Permalink
[Autoscaler][gcp] parallel terminate nodes (#34455)
Browse files Browse the repository at this point in the history
Why are these changes needed?
ray down takes a lot of time when using GCPNodeProvider as stated in #26239 because GCPNodeProvider uses the serial implementation of terminate_nodes from parent class NodeProvider and also uses a coarse lock in its terminate_node which prevents executing it in a concurrent fashion (not really sure coz I'm new to this).

add threadpoolexecutor in GCPNodeProvider.terminate_nodes for parallelization execution of terminate_node
use fine-grained locks which assign one RLock per node_id
add unit_tests
why not go with the suggestions(batch apis and non-blocking version of terminate_node) mentioned in #26239?
As a novice, I think both solutions would break Liskov Substitute Principle, and also for those who already used terminate_node(s) would need to add await.

Related issue number
#26239

---------

Signed-off-by: Chen-Chen Yeh <ge96noj@mytum.de>
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
  • Loading branch information
Dan-Yeh and Chen-Chen Yeh authored Apr 21, 2023
1 parent 5c7520f commit 46fc663
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
44 changes: 30 additions & 14 deletions python/ray/autoscaler/_private/gcp/node_provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
import logging
import time
from functools import wraps
Expand Down Expand Up @@ -179,23 +180,38 @@ def create_node(self, base_config: dict, tags: dict, count: int) -> Dict[str, di
) # type: List[Tuple[dict, str]]
return {instance_id: result for result, instance_id in results}

def _thread_unsafe_terminate_node(self, node_id: str):
# Assumes the global lock is held for the duration of this operation.
# The lock may be held by a different thread if in `terminate_nodes()` case.
logger.info("NodeProvider: {}: Terminating node".format(node_id))
resource = self._get_resource_depending_on_node_name(node_id)
try:
result = resource.delete_instance(
node_id=node_id,
)
except googleapiclient.errors.HttpError as http_error:
if http_error.resp.status == 404:
logger.warning(
f"Tried to delete the node with id {node_id} "
"but it was already gone."
)
else:
raise http_error from None
return result

@_retry
def terminate_node(self, node_id: str):
with self.lock:
resource = self._get_resource_depending_on_node_name(node_id)
try:
result = resource.delete_instance(
node_id=node_id,
)
except googleapiclient.errors.HttpError as http_error:
if http_error.resp.status == 404:
logger.warning(
f"Tried to delete the node with id {node_id} "
"but it was already gone."
)
else:
raise http_error from None
return result
self._thread_unsafe_terminate_node(node_id)

def terminate_nodes(self, node_ids: List[str]):
if not node_ids:
return None

with self.lock, concurrent.futures.ThreadPoolExecutor() as executor:
result = executor.map(self._thread_unsafe_terminate_node, node_ids)

return list(result)

@_retry
def _get_node(self, node_id: str) -> GCPNode:
Expand Down
27 changes: 27 additions & 0 deletions python/ray/tests/gcp/test_gcp_node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ def __init__(self, provider_config: dict, cluster_name: str):
assert create_node_return_value == expected_return_value


def test_terminate_nodes():
mock_node_config = {"machineType": "n2-standard-8"}
node_type = GCPNodeType.COMPUTE.value
id1, id2 = f"instance-id1-{node_type}", f"instance-id2-{node_type}"
terminate_node_ids = [id1, id2]
mock_resource = MagicMock()
mock_resource.create_instances.return_value = [
({"dict": 1}, id1),
({"dict": 2}, id2),
]
mock_resource.delete_instance.return_value = "test"
expected_terminate_nodes_result_len = 2

def __init__(self, provider_config: dict, cluster_name: str):
self.lock = RLock()
self.cached_nodes: Dict[str, GCPNode] = {}
self.resources: Dict[GCPNodeType, GCPResource] = {}
self.resources[GCPNodeType.COMPUTE] = mock_resource

with patch.object(GCPNodeProvider, "__init__", __init__):
node_provider = GCPNodeProvider({}, "")
node_provider.create_node(mock_node_config, {}, 1)
create_results = node_provider.terminate_nodes(terminate_node_ids)

assert len(create_results) == expected_terminate_nodes_result_len


@pytest.mark.parametrize(
"test_case",
[
Expand Down

0 comments on commit 46fc663

Please sign in to comment.