Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Adds porting of network configuration to generated base job templates
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Mar 6, 2024
1 parent c24bf14 commit 0f35768
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 9 deletions.
19 changes: 19 additions & 0 deletions prefect_aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
from prefect.utilities.pydantic import JsonPatch
from pydantic import VERSION as PYDANTIC_VERSION

from prefect_aws.utilities import assemble_document_for_patches

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import Field, root_validator, validator
else:
Expand Down Expand Up @@ -739,6 +741,23 @@ async def generate_work_pool_base_job_template(self) -> dict:
)

if self.task_customizations:
network_config_patches = JsonPatch(
[
patch
for patch in self.task_customizations
if "networkConfiguration" in patch["path"]
]
)
minimal_network_config = assemble_document_for_patches(
network_config_patches
)
if minimal_network_config:
minimal_network_config_with_patches = network_config_patches.apply(
minimal_network_config
)
base_job_template["variables"]["properties"]["network_configuration"][
"default"
] = minimal_network_config_with_patches["networkConfiguration"]
try:
base_job_template["job_configuration"]["task_run_request"] = (
self.task_customizations.apply(
Expand Down
81 changes: 81 additions & 0 deletions prefect_aws/utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for working with AWS services."""

from typing import Dict, List, Union

from prefect.utilities.collections import visit_collection


Expand Down Expand Up @@ -33,3 +35,82 @@ def make_hashable(item):
collection, visit_fn=make_hashable, return_data=True
)
return hash(hashable_collection)


def ensure_path_exists(doc: Union[Dict, List], path: List[str]):
"""
Ensures the path exists in the document, creating empty dictionaries or lists as
needed.
Args:
doc: The current level of the document or sub-document.
path: The remaining path parts to ensure exist.
"""
if not path:
return
current_path = path.pop(0)
# Check if the next path part exists and is a digit
next_path_is_digit = path and path[0].isdigit()

# Determine if the current path is for an array or an object
if isinstance(doc, list): # Path is for an array index
current_path = int(current_path)
# Ensure the current level of the document is a list and long enough

while len(doc) <= current_path:
doc.append({})
next_level = doc[current_path]
else: # Path is for an object
if current_path not in doc or (
next_path_is_digit and not isinstance(doc.get(current_path), list)
):
doc[current_path] = [] if next_path_is_digit else {}
next_level = doc[current_path]

ensure_path_exists(next_level, path)


def assemble_document_for_patches(patches):
"""
Assembles an initial document that can successfully accept the given JSON Patch
operations.
Args:
patches: A list of JSON Patch operations.
Returns:
An initial document structured to accept the patches.
Example:
```python
patches = [
{"op": "replace", "path": "/name", "value": "Jane"},
{"op": "add", "path": "/contact/address", "value": "123 Main St"},
{"op": "remove", "path": "/age"}
]
initial_document = assemble_document_for_patches(patches)
#output
{
"name": {},
"contact": {},
"age": {}
}
```
"""
document = {}

for patch in patches:
operation = patch["op"]
path = patch["path"].lstrip("/").split("/")

if operation == "add":
# Ensure all but the last element of the path exists
ensure_path_exists(document, path[:-1])
elif operation in ["remove", "replace"]:
# For remove adn replace, the entire path should exist
ensure_path_exists(document, path)

return document
28 changes: 20 additions & 8 deletions tests/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,15 @@ def base_job_template_with_defaults(default_base_job_template, aws_credentials):
base_job_template_with_defaults["variables"]["properties"][
"auto_deregister_task_definition"
]["default"] = False
base_job_template_with_defaults["variables"]["properties"]["network_configuration"][
"default"
] = {
"awsvpcConfiguration": {
"subnets": ["subnet-***"],
"assignPublicIp": "DISABLED",
"securityGroups": ["sg-***"],
}
}
return base_job_template_with_defaults


Expand Down Expand Up @@ -2188,10 +2197,20 @@ async def test_generate_work_pool_base_job_template(
cpu=2048,
memory=4096,
task_customizations=[
{
"op": "replace",
"path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp",
"value": "DISABLED",
},
{
"op": "add",
"path": "/networkConfiguration/awsvpcConfiguration/subnets",
"value": ["subnet-***"],
},
{
"op": "add",
"path": "/networkConfiguration/awsvpcConfiguration/securityGroups",
"value": ["sg-d72e9599956a084f5"],
"value": ["sg-***"],
},
],
family="test-family",
Expand Down Expand Up @@ -2229,10 +2248,3 @@ async def test_generate_work_pool_base_job_template(
template = await job.generate_work_pool_base_job_template()

assert template == expected_template

if job_config == "custom":
assert (
"Unable to apply task customizations to the base job template."
"You may need to update the template manually."
in caplog.text
)
59 changes: 58 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest

from prefect_aws.utilities import hash_collection
from prefect_aws.utilities import (
assemble_document_for_patches,
ensure_path_exists,
hash_collection,
)


class TestHashCollection:
Expand Down Expand Up @@ -32,3 +36,56 @@ def test_unhashable_structure(self):
assert hash_collection(typically_unhashable_structure) == hash_collection(
typically_unhashable_structure
), "Unhashable structure hashing failed after transformation"


class TestAssembleDocumentForPatches:
def test_initial_document(self):
patches = [
{"op": "replace", "path": "/name", "value": "Jane"},
{"op": "add", "path": "/contact/address", "value": "123 Main St"},
{"op": "remove", "path": "/age"},
]

initial_document = assemble_document_for_patches(patches)

expected_document = {"name": {}, "contact": {}, "age": {}}

assert initial_document == expected_document, "Initial document assembly failed"


class TestEnsurePathExists:
def test_existing_path(self):
doc = {"key1": {"subkey1": "value1"}}
path = ["key1", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {
"key1": {"subkey1": "value1"}
}, "Existing path modification failed"

def test_new_path_object(self):
doc = {}
path = ["key1", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {"key1": {"subkey1": {}}}, "New path creation for object failed"

def test_new_path_array(self):
doc = {}
path = ["key1", "0"]
ensure_path_exists(doc, path)
assert doc == {"key1": [{}]}, "New path creation for array failed"

def test_existing_path_array(self):
doc = {"key1": [{"subkey1": "value1"}]}
path = ["key1", "0", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {
"key1": [{"subkey1": "value1"}]
}, "Existing path modification for array failed"

def test_existing_path_array_index_out_of_range(self):
doc = {"key1": []}
path = ["key1", "0", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {
"key1": [{"subkey1": {}}]
}, "Existing path modification for array index out of range failed"

0 comments on commit 0f35768

Please sign in to comment.