Skip to content

Commit

Permalink
LLM stop should be able to stop everything or specific endpoints.
Browse files Browse the repository at this point in the history
  • Loading branch information
pm3310 committed Jan 20, 2024
1 parent d135fc2 commit 94fd3d3
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 8 deletions.
52 changes: 50 additions & 2 deletions sagify/commands/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,34 @@ def start(


@click.command()
@click.option(
'--all',
is_flag=True,
show_default=True,
default=False,
help='Start infrastructure for all services.'
)
@click.option(
'--chat-completions',
is_flag=True,
show_default=True,
default=False,
help='Start infrastructure for chat completions.'
)
@click.option(
'--image-creations',
is_flag=True,
show_default=True,
default=False,
help='Start infrastructure for image creations.'
)
@click.option(
'--embeddings',
is_flag=True,
show_default=True,
default=False,
help='Start infrastructure for embeddings.'
)
@click.option(
u"--aws-profile",
required=True,
Expand All @@ -216,7 +244,16 @@ def start(
required=False,
help="Optional external id used when using an IAM role"
)
def stop(aws_profile, aws_region, iam_role_arn, external_id):
def stop(
all,
chat_completions,
image_creations,
embeddings,
aws_profile,
aws_region,
iam_role_arn,
external_id
):
"""
Command to stop LLM infrastructure
"""
Expand All @@ -228,7 +265,18 @@ def stop(aws_profile, aws_region, iam_role_arn, external_id):
with open('.sagify_llm_infra.json', 'r') as f:
llm_infra_config = json.load(f)

for _endpoint in ['chat_completions_endpoint', 'image_creations_endpoint', 'embeddings_endpoint']:
endpoints_to_stop = []
if all:
endpoints_to_stop = ['chat_completions_endpoint', 'image_creations_endpoint', 'embeddings_endpoint']
else:
if chat_completions:
endpoints_to_stop.append('chat_completions_endpoint')
if image_creations:
endpoints_to_stop.append('image_creations_endpoint')
if embeddings:
endpoints_to_stop.append('embeddings_endpoint')

for _endpoint in endpoints_to_stop:
if llm_infra_config[_endpoint]:
try:
sagemaker_client.shutdown_endpoint(llm_infra_config[_endpoint])
Expand Down
114 changes: 108 additions & 6 deletions tests/commands/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,11 @@ def test_start_embeddings_only(self):


class TestLlmStop(object):
def test_stop_happy_case(self):
def test_stop_all_happy_case(self):
runner = CliRunner()
with patch(
'sagify.commands.llm.sagemaker.SageMakerClient'
) as mocked_sagemaker_client:
# from unittest.mock import MagicMock

# mocked_sagemaker_client.return_value = MagicMock()
with runner.isolated_filesystem():
with open('.sagify_llm_infra.json', 'w') as f:
json.dump({
Expand All @@ -233,6 +230,7 @@ def test_stop_happy_case(self):
cli=cli,
args=[
'llm', 'stop',
'--all',
'--aws-region', 'us-east-1',
'--aws-profile', 'sagemaker-production',
'--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
Expand All @@ -254,12 +252,116 @@ def test_stop_happy_case(self):

assert result.exit_code == 0

def test_stop_chat_completions_only(self):
    """`llm stop --chat-completions` must shut down only the chat completions endpoint."""
    cli_runner = CliRunner()
    with patch('sagify.commands.llm.sagemaker.SageMakerClient') as sm_client_mock, \
            cli_runner.isolated_filesystem():
        # Seed the infra config file the stop command reads endpoint names from.
        infra_config = {
            'chat_completions_endpoint': 'endpoint1',
            'image_creations_endpoint': 'endpoint2',
            'embeddings_endpoint': 'endpoint3'
        }
        with open('.sagify_llm_infra.json', 'w') as config_file:
            json.dump(infra_config, config_file)

        invoke_args = [
            'llm', 'stop',
            '--chat-completions',
            '--aws-region', 'us-east-1',
            '--aws-profile', 'sagemaker-production',
            '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
            '--external-id', '123456'
        ]
        result = cli_runner.invoke(cli=cli, args=invoke_args)

        # The SageMaker client must be built from the CLI-provided profile/region/role.
        sm_client_mock.assert_called_with(
            'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
        )
        shutdown_mock = sm_client_mock.return_value.shutdown_endpoint
        # Exactly one endpoint — the chat completions one — gets torn down.
        assert shutdown_mock.call_count == 1
        shutdown_mock.assert_called_with('endpoint1')

        assert result.exit_code == 0

def test_stop_image_creations_only(self):
    """`llm stop --image-creations` must shut down only the image creations endpoint."""
    cli_runner = CliRunner()
    with patch('sagify.commands.llm.sagemaker.SageMakerClient') as sm_client_mock, \
            cli_runner.isolated_filesystem():
        # Seed the infra config file the stop command reads endpoint names from.
        infra_config = {
            'chat_completions_endpoint': 'endpoint1',
            'image_creations_endpoint': 'endpoint2',
            'embeddings_endpoint': 'endpoint3'
        }
        with open('.sagify_llm_infra.json', 'w') as config_file:
            json.dump(infra_config, config_file)

        invoke_args = [
            'llm', 'stop',
            '--image-creations',
            '--aws-region', 'us-east-1',
            '--aws-profile', 'sagemaker-production',
            '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
            '--external-id', '123456'
        ]
        result = cli_runner.invoke(cli=cli, args=invoke_args)

        # The SageMaker client must be built from the CLI-provided profile/region/role.
        sm_client_mock.assert_called_with(
            'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
        )
        shutdown_mock = sm_client_mock.return_value.shutdown_endpoint
        # Exactly one endpoint — the image creations one — gets torn down.
        assert shutdown_mock.call_count == 1
        shutdown_mock.assert_called_with('endpoint2')

        assert result.exit_code == 0

def test_stop_embeddings_only(self):
    """`llm stop --embeddings` must shut down only the embeddings endpoint."""
    cli_runner = CliRunner()
    with patch('sagify.commands.llm.sagemaker.SageMakerClient') as sm_client_mock, \
            cli_runner.isolated_filesystem():
        # Seed the infra config file the stop command reads endpoint names from.
        infra_config = {
            'chat_completions_endpoint': 'endpoint1',
            'image_creations_endpoint': 'endpoint2',
            'embeddings_endpoint': 'endpoint3'
        }
        with open('.sagify_llm_infra.json', 'w') as config_file:
            json.dump(infra_config, config_file)

        invoke_args = [
            'llm', 'stop',
            '--embeddings',
            '--aws-region', 'us-east-1',
            '--aws-profile', 'sagemaker-production',
            '--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
            '--external-id', '123456'
        ]
        result = cli_runner.invoke(cli=cli, args=invoke_args)

        # The SageMaker client must be built from the CLI-provided profile/region/role.
        sm_client_mock.assert_called_with(
            'sagemaker-production', 'us-east-1', 'arn:aws:iam::123456789012:role/MyRole', '123456'
        )
        shutdown_mock = sm_client_mock.return_value.shutdown_endpoint
        # Exactly one endpoint — the embeddings one — gets torn down.
        assert shutdown_mock.call_count == 1
        shutdown_mock.assert_called_with('endpoint3')

        assert result.exit_code == 0

def test_stop_missing_config_file(self):
runner = CliRunner()
with patch(
'sagify.commands.llm.sagemaker.SageMakerClient'
) as mocked_sagemaker_client:
# mocked_sagemaker_client.return_value = MagicMock()
with runner.isolated_filesystem():
result = runner.invoke(
cli=cli,
Expand All @@ -280,7 +382,6 @@ def test_stop_endpoint_shutdown_error(self):
with patch(
'sagify.commands.llm.sagemaker.SageMakerClient'
) as mocked_sagemaker_client:
# mocked_sagemaker_client.return_value = MagicMock()
mocked_sagemaker_client.return_value.shutdown_endpoint.side_effect = Exception('Endpoint shutdown error')
with runner.isolated_filesystem():
with open('.sagify_llm_infra.json', 'w') as f:
Expand All @@ -294,6 +395,7 @@ def test_stop_endpoint_shutdown_error(self):
cli=cli,
args=[
'llm', 'stop',
'--all',
'--aws-region', 'us-east-1',
'--aws-profile', 'sagemaker-production',
'--iam-role-arn', 'arn:aws:iam::123456789012:role/MyRole',
Expand Down

0 comments on commit 94fd3d3

Please sign in to comment.