Skip to content

Commit

Permalink
Update functional.py (#419)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
  • Loading branch information
YadlaMani and michaelfeil authored Nov 16, 2024
1 parent 764917a commit d1fa1c1
Showing 1 changed file with 66 additions and 40 deletions.
106 changes: 66 additions & 40 deletions infra/modal/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,21 @@ def _get_array(self):

@build()
async def download_model(self):
print(f"downloading models {self.model_id} ...")
self._get_array()
try:
print(f"downloading models {self.model_id} ...")
self._get_array()
except Exception as e:
print(f"Error downloading model: {e}")

@enter()
async def enter(self):
print("Starting the engine array ...")
self.engine_array = self._get_array()
await self.engine_array.astart()
print("engine array started!")
try:
print("Starting the engine array ...")
self.engine_array = self._get_array()
await self.engine_array.astart()
print("engine array started!")
except Exception as e:
print(f"Error starting the engine array: {e}")


@app.cls(gpu="any", allow_concurrent_inputs=500)
Expand All @@ -49,27 +55,43 @@ def __init__(self, model_id: tuple[str]) -> None:

@method()
async def embed(self, sentences: list[str], model: str | int = 0):
engine = self.engine_array[model]
embeddings, usage = await engine.embed(sentences=sentences)
return embeddings
try:
engine = self.engine_array[model]
embeddings, usage = await engine.embed(sentences=sentences)
return embeddings
except Exception as e:
print(f"Error embedding sentences: {e}")
return None

@method()
async def image_embed(self, urls: list[str], model: str | int = 0):
engine = self.engine_array[model]
embeddings, usage = await engine.image_embed(images=urls)
return embeddings
try:
engine = self.engine_array[model]
embeddings, usage = await engine.image_embed(images=urls)
return embeddings
except Exception as e:
print(f"Error embedding images: {e}")
return None

@method()
async def rerank(self, query: str, docs: list[str], model: str | int = 0):
engine = self.engine_array[model]
rankings, usage = await engine.rerank(query=query, docs=docs)
return rankings
try:
engine = self.engine_array[model]
rankings, usage = await engine.rerank(query=query, docs=docs)
return rankings
except Exception as e:
print(f"Error reranking documents: {e}")
return None

@method()
async def classify(self, sentences: list[str], model: str | int = 0):
engine = self.engine_array[model]
classes, usage = await engine.classify(sentences=sentences)
return classes
try:
engine = self.engine_array[model]
classes, usage = await engine.classify(sentences=sentences)
return classes
except Exception as e:
print(f"Error classifying sentences: {e}")
return None


@app.local_entrypoint()
Expand All @@ -81,28 +103,32 @@ def main():
"philschmid/tiny-bert-sst2-distilled",
)
deployment = InfinityModal(model_id=model_id)
embeddings_1 = deployment.embed.remote(sentences=["hello world"], model=model_id[1])
embeddings_2 = deployment.image_embed.remote(
urls=["http://images.cocodataset.org/val2017/000000039769.jpg"],
model=model_id[0],
)

try:
embeddings_1 = deployment.embed.remote(sentences=["hello world"], model=model_id[1])
embeddings_2 = deployment.image_embed.remote(
urls=["http://images.cocodataset.org/val2017/000000039769.jpg"],
model=model_id[0],
)

rerankings_1 = deployment.rerank.remote(
query="Where is Paris?",
docs=["Paris is the capital of France.", "Berlin is a city in Europe."],
model=model_id[2],
)
rerankings_1 = deployment.rerank.remote(
query="Where is Paris?",
docs=["Paris is the capital of France.", "Berlin is a city in Europe."],
model=model_id[2],
)

classifications_1 = deployment.classify.remote(
sentences=["I feel great today!"], model=model_id[3]
)
classifications_1 = deployment.classify.remote(
sentences=["I feel great today!"], model=model_id[3]
)

print(
"Success, all tasks submitted! Embeddings:",
embeddings_1[0].shape,
embeddings_2[0].shape,
"Rerankings:",
rerankings_1,
"Classifications:",
classifications_1,
)
print(
"Success, all tasks submitted! Embeddings:",
embeddings_1[0].shape if embeddings_1 else "N/A",
embeddings_2[0].shape if embeddings_2 else "N/A",
"Rerankings:",
rerankings_1 if rerankings_1 else "N/A",
"Classifications:",
classifications_1 if classifications_1 else "N/A",
)
except Exception as e:
print(f"Error in main entrypoint: {e}")

0 comments on commit d1fa1c1

Please sign in to comment.