Skip to content

Commit

Permalink
fix: task type sagemaker (#338)
Browse files Browse the repository at this point in the history
* fix: task type sagemaker

* style: black, isort

* chore: constrain pyright

* fix: setup

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
  • Loading branch information
zanussbaum and ellipsis-dev[bot] authored Oct 23, 2024
1 parent d67643c commit 609009f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 23 deletions.
2 changes: 1 addition & 1 deletion examples/sagemaker/run-nomic-embed-text.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
}
],
"source": [
"response = embed_text(texts, endpoint_name, region_name=region_name, batch_size=32, dimensionality=128)\n",
"response = embed_text(texts, endpoint_name, region_name=region_name, batch_size=32, dimensionality=128, task_type=\"search_document\")\n",
"embeddings = response[\"embeddings\"]\n",
"np.array(embeddings).shape"
]
Expand Down
29 changes: 8 additions & 21 deletions nomic/aws/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,6 @@ def parse_sagemaker_response(response):
return resp["embeddings"]


def preprocess_texts(texts: List[str], task_type: str = "search_document"):
"""
Preprocess a list of texts for embedding using a sagemaker model.
Args:
texts: List of texts to be embedded.
task_type: The task type to use when embedding. One of `search_query`, `search_document`, `classification`, `clustering`
Returns:
List of texts formatted for sagemaker embedding.
"""
assert task_type in [
"search_query",
"search_document",
"classification",
"clustering",
], f"Invalid task type: {task_type}"
return [f"{task_type}: {text}" for text in texts]


def batch_transform_text(
s3_input_path: str,
s3_output_path: str,
Expand Down Expand Up @@ -157,7 +137,13 @@ def embed_text(
logger.warning("No texts to embed.")
return None

texts = preprocess_texts(texts, task_type)
assert task_type in [
"search_query",
"search_document",
"classification",
"clustering",
], f"Invalid task type: {task_type}"

assert dimensionality in (
64,
128,
Expand All @@ -175,6 +161,7 @@ def embed_text(
"texts": texts[i : i + batch_size],
"binary": binary,
"dimensionality": dimensionality,
"task_type": task_type,
}
)
response = client.invoke_endpoint(EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"pylint",
"pytest",
"isort",
"pyright",
"pyright<=1.1.377",
"myst-parser",
"mkdocs-material",
"mkautodoc",
Expand Down

0 comments on commit 609009f

Please sign in to comment.