Merge remote-tracking branch 'origin/master' into fb-dia-1511
matt-bernstein committed Oct 24, 2024
2 parents 312c916 + d6d39e5 commit 72000b0
Showing 9 changed files with 301 additions and 43 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/build_pypi.yml
@@ -62,11 +62,15 @@ jobs:
version=$(sed "s/^v//g" <<< ${PROVIDED_VERSION})
sed -i "s/^version[ ]*=.*/version = \"${version}\"/g" ${{ env.PYTHON_VERSION_FILE }}
- name: Set up poetry
uses: snok/install-poetry@v1
- name: "Install poetry"
run: pipx install poetry

- name: "Set up Python"
id: setup_python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: true
cache: 'poetry'

- name: Install dependencies
run: |
24 changes: 16 additions & 8 deletions .github/workflows/docs.yml
@@ -21,11 +21,15 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Set up poetry
uses: snok/install-poetry@v1
- name: "Install poetry"
run: pipx install poetry

- name: "Set up Python"
id: setup_python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: true
python-version: '3.11'
cache: 'poetry'

- name: Install Python dependencies
run: |
@@ -45,11 +49,15 @@
steps:
- uses: actions/checkout@v4

- name: Set up poetry
uses: snok/install-poetry@v1
- name: "Install poetry"
run: pipx install poetry

- name: "Set up Python"
id: setup_python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: true
python-version: '3.11'
cache: 'poetry'

- name: Install Python dependencies
run: |
14 changes: 10 additions & 4 deletions .github/workflows/follow-merge-upstream-repo-sync.yml
@@ -21,7 +21,7 @@ jobs:
github.event.client_payload.event_action == 'opened' ||
github.event.client_payload.event_action == 'synchronize' ||
github.event.client_payload.event_action == 'merged'
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: hmarr/debug-action@v3.0.0

@@ -87,9 +87,15 @@ jobs:
our_files: "pyproject.toml poetry.lock web/package.json web/yarn.lock"
working_directory: "${{ env.UPSTREAM_REPO_WORKDIR }}"

- name: "Poetry: Set up"
if: steps.details.outputs.poetry
uses: Gr1N/setup-poetry@v9
- name: "Install poetry"
run: pipx install poetry

- name: "Set up Python"
id: setup_python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'poetry'

- name: Commit submodule
shell: bash
33 changes: 14 additions & 19 deletions .github/workflows/tests.yml
@@ -32,11 +32,15 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Set up poetry
uses: snok/install-poetry@v1
- name: "Install poetry"
run: pipx install poetry

- name: "Set up Python"
id: setup_python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: true
cache: 'poetry'

- name: Install Python dependencies
run: poetry install
@@ -60,23 +64,15 @@
steps:
- uses: actions/checkout@v4

- name: Set up python
uses: actions/setup-python@v5
with:
python-version: "${{ matrix.python-version }}"
- name: "Install poetry"
run: pipx install poetry

- name: Set up poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true

- name: Load cached venv
id: cached-pip-wheels
uses: actions/cache@v4
- name: "Set up Python ${{ matrix.python-version }}"
id: setup_python
uses: actions/setup-python@v5
with:
path: ~/.cache
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
python-version: '${{ matrix.python-version }}'
cache: 'poetry'

- name: Install dependencies
run: poetry install --no-interaction --no-root
@@ -87,7 +83,6 @@
- name: Run tests with coverage
if: matrix.os == 'ubuntu-latest'
run: |
source $VENV
poetry run pytest tests/ -vv ${{ matrix.python-version == '3.11' && '--cov=. --cov-report=xml' || '' }}
- name: Upload to Codecov
4 changes: 2 additions & 2 deletions adala/agents/base.py
@@ -41,7 +41,7 @@ class Agent(BaseModel, ABC):
Attributes:
environment (Environment): The environment with which the agent interacts.
skills (Union[SkillSet, List[Skill]]): The skills possessed by the agent.
skills (SkillSet): The skills possessed by the agent.
memory (LongTermMemory, optional): The agent's long-term memory. Defaults to None.
runtimes (Dict[str, Runtime], optional): The runtimes available to the agent. Defaults to predefined runtimes.
default_runtime (str): The default runtime used by the agent. Defaults to 'openai'.
@@ -58,7 +58,7 @@ class Agent(BaseModel, ABC):
"""

environment: Optional[SerializeAsAny[Union[Environment, AsyncEnvironment]]] = None
skills: SerializeAsAny[Union[Skill, SkillSet]]
skills: SerializeAsAny[SkillSet]

memory: Memory = Field(default=None)
runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
70 changes: 68 additions & 2 deletions adala/runtimes/_litellm.py
@@ -13,6 +13,7 @@
import instructor
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
import traceback
from adala.runtimes.base import CostEstimate
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.internal_data import InternalDataFrame
from adala.utils.parse import (
@@ -122,7 +123,6 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:


class InstructorClientMixin:

def _from_litellm(self, **kwargs):
return instructor.from_litellm(litellm.completion, **kwargs)

@@ -139,7 +139,6 @@ def is_custom_openai_endpoint(self) -> bool:


class InstructorAsyncClientMixin(InstructorClientMixin):

def _from_litellm(self, **kwargs):
return instructor.from_litellm(litellm.acompletion, **kwargs)

@@ -527,6 +526,73 @@ async def record_to_record(
# Extract the single row from the output DataFrame and convert it to a dictionary
return output_df.iloc[0].to_dict()

@staticmethod
def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
user_tokens = litellm.token_counter(model=model, text=string)
# FIXME surprisingly difficult to get function call tokens, and doesn't add a ton of value, so hard-coding until something like litellm supports doing this for us.
# currently seems like we'd need to scrape the instructor logs to get the function call info, then use (at best) an openai-specific 3rd party lib to get a token estimate from that.
system_tokens = 56 + (6 * len(output_fields))
return user_tokens + system_tokens

@staticmethod
def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
max_tokens = litellm.get_model_info(
model=model, custom_llm_provider="openai"
).get("max_tokens", None)
if not max_tokens:
raise ValueError
# extremely rough heuristic, from testing on some anecdotal examples
n_outputs = len(output_fields) if output_fields else 1
return min(max_tokens, 4 * n_outputs)

@classmethod
def _estimate_cost(
cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
):
prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
completion_tokens = cls._get_completion_tokens(model, output_fields)
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
total_cost = prompt_cost + completion_cost

return prompt_cost, completion_cost, total_cost

def get_cost_estimate(
self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
) -> CostEstimate:
try:
user_prompts = [
prompt.format(**substitution) for substitution in substitutions
]
cumulative_prompt_cost = 0
cumulative_completion_cost = 0
cumulative_total_cost = 0
for user_prompt in user_prompts:
prompt_cost, completion_cost, total_cost = self._estimate_cost(
user_prompt=user_prompt,
model=self.model,
output_fields=output_fields,
)
cumulative_prompt_cost += prompt_cost
cumulative_completion_cost += completion_cost
cumulative_total_cost += total_cost
return CostEstimate(
prompt_cost_usd=cumulative_prompt_cost,
completion_cost_usd=cumulative_completion_cost,
total_cost_usd=cumulative_total_cost,
)

except Exception as e:
logger.error("Failed to estimate cost: %s", e)
return CostEstimate(
is_error=True,
error_type=type(e).__name__,
error_message=str(e),
)


class LiteLLMVisionRuntime(LiteLLMChatRuntime):
"""
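The cost-estimation helpers added to this file rely only on litellm's local tokenizer and pricing tables, so producing an estimate never issues an API call. Below is a minimal standalone sketch of the same arithmetic; the model name, prompt, and output fields are illustrative assumptions, not values from this commit:

import litellm

# Hypothetical inputs, for illustration only.
model = "gpt-4o-mini"
output_fields = ["label", "rationale"]
user_prompt = "Classify the sentiment of: I love this product"

# Prompt tokens: the counted text plus the hard-coded allowance for the
# function-call schema (56 + 6 per output field, as in _get_prompt_tokens).
prompt_tokens = litellm.token_counter(model=model, text=user_prompt)
prompt_tokens += 56 + 6 * len(output_fields)

# Completion tokens: a rough 4 tokens per output field, capped at the model
# limit, mirroring _get_completion_tokens.
max_tokens = litellm.get_model_info(model=model, custom_llm_provider="openai").get("max_tokens") or 4096
completion_tokens = min(max_tokens, 4 * len(output_fields))

prompt_cost, completion_cost = litellm.cost_per_token(
    model=model,
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
print(f"estimated total cost: ${prompt_cost + completion_cost:.6f}")

The runtime's get_cost_estimate method wraps this per-prompt arithmetic: it formats the prompt with each substitution dict and sums the per-record costs into a single CostEstimate, returning an error-flagged estimate if anything fails.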
49 changes: 44 additions & 5 deletions adala/runtimes/base.py
@@ -1,17 +1,51 @@
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Type

from tqdm import tqdm
from abc import ABC, abstractmethod
from pydantic import BaseModel, model_validator, Field
from typing import List, Dict, Optional, Tuple, Any, Callable, ClassVar, Type
from adala.utils.internal_data import InternalDataFrame, InternalSeries
from adala.utils.internal_data import InternalDataFrame
from adala.utils.registry import BaseModelInRegistry
from pandarallel import pandarallel
from pydantic import BaseModel, Field, model_validator
from tqdm import tqdm

logger = logging.getLogger(__name__)
tqdm.pandas()


class CostEstimate(BaseModel):
prompt_cost_usd: Optional[float] = None
completion_cost_usd: Optional[float] = None
total_cost_usd: Optional[float] = None
is_error: bool = False
error_type: Optional[str] = None
error_message: Optional[str] = None

def __add__(self, other: "CostEstimate") -> "CostEstimate":
# if either has an error, it takes precedence
if self.is_error:
return self
if other.is_error:
return other

def _safe_add(lhs: Optional[float], rhs: Optional[float]) -> Optional[float]:
if lhs is None and rhs is None:
return None
_lhs = lhs or 0.0
_rhs = rhs or 0.0
return _lhs + _rhs

prompt_cost_usd = _safe_add(self.prompt_cost_usd, other.prompt_cost_usd)
completion_cost_usd = _safe_add(
self.completion_cost_usd, other.completion_cost_usd
)
total_cost_usd = _safe_add(self.total_cost_usd, other.total_cost_usd)
return CostEstimate(
prompt_cost_usd=prompt_cost_usd,
completion_cost_usd=completion_cost_usd,
total_cost_usd=total_cost_usd,
)


class Runtime(BaseModelInRegistry):
"""
Base class representing a generic runtime environment.
@@ -191,6 +225,11 @@ def record_to_batch(
response_model=response_model,
)

def get_cost_estimate(
self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
) -> CostEstimate:
raise NotImplementedError("This runtime does not support cost estimates")


class AsyncRuntime(Runtime):
"""Async version of runtime that uses asyncio to process batch of records."""
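The addition operator makes it straightforward to aggregate per-batch estimates. A short illustration, where the dollar amounts are invented but the fields and error-precedence behaviour are those of the CostEstimate model defined above:

# Illustrative numbers only; fields and __add__ semantics come from
# adala/runtimes/base.py as changed in this commit.
from adala.runtimes.base import CostEstimate

a = CostEstimate(prompt_cost_usd=0.0012, completion_cost_usd=0.0004, total_cost_usd=0.0016)
b = CostEstimate(prompt_cost_usd=0.0009, completion_cost_usd=0.0003, total_cost_usd=0.0012)
failed = CostEstimate(is_error=True, error_type="ValueError", error_message="unknown model")

combined = a + b
print(combined.total_cost_usd)         # ~0.0028

# An error estimate on either side takes precedence over any numeric result.
print((combined + failed).error_type)  # ValueError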