Skip to content

Commit

Permalink
python[patch]: pass Runnable to evaluate (#1204)
Browse files Browse the repository at this point in the history
Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
  • Loading branch information
baskaryan and hinthornw authored Nov 13, 2024
1 parent 5d5cace commit 62e66ca
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 33 deletions.
14 changes: 10 additions & 4 deletions python/langsmith/evaluation/_arunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
_ExperimentManagerMixin,
_extract_feedback_keys,
_ForwardResults,
_is_langchain_runnable,
_load_examples_map,
_load_experiment,
_load_tqdm,
Expand Down Expand Up @@ -379,8 +380,10 @@ async def _aevaluate(
blocking: bool = True,
experiment: Optional[Union[schemas.TracerSession, str, uuid.UUID]] = None,
) -> AsyncExperimentResults:
is_async_target = asyncio.iscoroutinefunction(target) or (
hasattr(target, "__aiter__") and asyncio.iscoroutine(target.__aiter__())
is_async_target = (
asyncio.iscoroutinefunction(target)
or (hasattr(target, "__aiter__") and asyncio.iscoroutine(target.__aiter__()))
or _is_langchain_runnable(target)
)
client = client or rt.get_cached_client()
runs = None if is_async_target else cast(Iterable[schemas.Run], target)
Expand Down Expand Up @@ -940,7 +943,7 @@ def _get_run(r: run_trees.RunTree) -> None:
def _ensure_async_traceable(
target: ATARGET_T,
) -> rh.SupportsLangsmithExtra[[dict], Awaitable]:
if not asyncio.iscoroutinefunction(target):
if not asyncio.iscoroutinefunction(target) and not _is_langchain_runnable(target):
if callable(target):
raise ValueError(
"Target must be an async function. For sync functions, use evaluate."
Expand All @@ -961,7 +964,10 @@ def _ensure_async_traceable(
)
if rh.is_traceable_function(target):
return target # type: ignore
return rh.traceable(name="AsyncTarget")(target)
else:
if _is_langchain_runnable(target):
target = target.ainvoke # type: ignore[attr-defined]
return rh.traceable(name="AsyncTarget")(target)


def _aresolve_data(
Expand Down
41 changes: 29 additions & 12 deletions python/langsmith/evaluation/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

if TYPE_CHECKING:
import pandas as pd
from langchain_core.runnables import Runnable

DataFrame = pd.DataFrame
else:
Expand Down Expand Up @@ -96,7 +97,7 @@


def evaluate(
target: TARGET_T,
target: Union[TARGET_T, Runnable],
/,
data: DATA_T,
evaluators: Optional[Sequence[EVALUATOR_T]] = None,
Expand Down Expand Up @@ -878,12 +879,12 @@ def _print_comparative_experiment_start(
)


def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run]]) -> bool:
return callable(target) or (hasattr(target, "invoke") and callable(target.invoke))
def _is_callable(target: Union[TARGET_T, Iterable[schemas.Run], Runnable]) -> bool:
return callable(target) or _is_langchain_runnable(target)


def _evaluate(
target: Union[TARGET_T, Iterable[schemas.Run]],
target: Union[TARGET_T, Iterable[schemas.Run], Runnable],
/,
data: DATA_T,
evaluators: Optional[Sequence[EVALUATOR_T]] = None,
Expand Down Expand Up @@ -1664,12 +1665,13 @@ def _resolve_data(


def _ensure_traceable(
target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict],
target: TARGET_T | rh.SupportsLangsmithExtra[[dict], dict] | Runnable,
) -> rh.SupportsLangsmithExtra[[dict], dict]:
"""Ensure the target function is traceable."""
if not callable(target):
if not _is_callable(target):
raise ValueError(
"Target must be a callable function. For example:\n\n"
"Target must be a callable function or a langchain/langgraph object. For "
"example:\n\n"
"def predict(inputs: dict) -> dict:\n"
" # do work, like chain.invoke(inputs)\n"
" return {...}\n\n"
Expand All @@ -1679,9 +1681,11 @@ def _ensure_traceable(
")"
)
if rh.is_traceable_function(target):
fn = target
fn: rh.SupportsLangsmithExtra[[dict], dict] = target
else:
fn = rh.traceable(name="Target")(target)
if _is_langchain_runnable(target):
target = target.invoke # type: ignore[union-attr]
fn = rh.traceable(name="Target")(cast(Callable, target))
return fn


Expand Down Expand Up @@ -1709,9 +1713,8 @@ def _resolve_experiment(
return experiment_, runs
# If we have runs, that means the experiment was already started.
if runs is not None:
if runs is not None:
runs_, runs = itertools.tee(runs)
first_run = next(runs_)
runs_, runs = itertools.tee(runs)
first_run = next(runs_)
experiment_ = client.read_project(project_id=first_run.session_id)
if not experiment_.name:
raise ValueError("Experiment name not found for provided runs.")
Expand Down Expand Up @@ -1923,3 +1926,17 @@ def _flatten_experiment_results(
}
for x in results[start:end]
]


@functools.lru_cache(maxsize=1)
def _import_langchain_runnable() -> Optional[type]:
try:
from langchain_core.runnables import Runnable

return Runnable
except ImportError:
return None


def _is_langchain_runnable(o: Any) -> bool:
return bool((Runnable := _import_langchain_runnable()) and isinstance(o, Runnable))
6 changes: 2 additions & 4 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def tracing_context(
get_run_tree_context = get_current_run_tree


def is_traceable_function(
func: Callable[P, R],
) -> TypeGuard[SupportsLangsmithExtra[P, R]]:
def is_traceable_function(func: Any) -> TypeGuard[SupportsLangsmithExtra[P, R]]:
"""Check if a function is @traceable decorated."""
return (
_is_traceable_function(func)
Expand Down Expand Up @@ -1445,7 +1443,7 @@ def _handle_container_end(
LOGGER.warning(f"Unable to process trace outputs: {repr(e)}")


def _is_traceable_function(func: Callable) -> bool:
def _is_traceable_function(func: Any) -> bool:
return getattr(func, "__langsmith_traceable__", False)


Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.1.142"
version = "0.1.143"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <support@langchain.dev>"]
license = "MIT"
Expand Down
43 changes: 31 additions & 12 deletions python/tests/unit_tests/evaluation/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def request(self, verb: str, endpoint: str, *args, **kwargs):
res = MagicMock()
res.json.return_value = {
"runs": [
r for r in self.runs.values() if "reference_example_id" in r
r
for r in self.runs.values()
if r["trace_id"] == r["id"] and r.get("reference_example_id")
]
}
return res
Expand Down Expand Up @@ -120,7 +122,8 @@ def _wait_until(condition: Callable, timeout: int = 8):

@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.parametrize("blocking", [False, True])
def test_evaluate_results(blocking: bool) -> None:
@pytest.mark.parametrize("as_runnable", [False, True])
def test_evaluate_results(blocking: bool, as_runnable: bool) -> None:
session = mock.Mock()
ds_name = "my-dataset"
ds_id = "00886375-eb2a-4038-9032-efff60309896"
Expand Down Expand Up @@ -180,6 +183,15 @@ def predict(inputs: dict) -> dict:
ordering_of_stuff.append("predict")
return {"output": inputs["in"] + 1}

if as_runnable:
try:
from langchain_core.runnables import RunnableLambda
except ImportError:
pytest.skip("langchain-core not installed.")
return
else:
predict = RunnableLambda(predict)

def score_value_first(run, example):
ordering_of_stuff.append("evaluate")
return {"score": 0.3}
Expand Down Expand Up @@ -263,26 +275,24 @@ async def my_other_func(inputs: dict, other_val: int):
with pytest.raises(ValueError, match=match):
evaluate(functools.partial(my_other_func, other_val=3), data="foo")

if sys.version_info < (3, 10):
return
try:
from langchain_core.runnables import RunnableLambda
except ImportError:
pytest.skip("langchain-core not installed.")

@RunnableLambda
def foo(inputs: dict):
return "bar"

with pytest.raises(ValueError, match=match):
evaluate(foo.ainvoke, data="foo")
if sys.version_info < (3, 10):
return
with pytest.raises(ValueError, match=match):
evaluate(functools.partial(foo.ainvoke, inputs={"foo": "bar"}), data="foo")
evaluate(
functools.partial(RunnableLambda(my_func).ainvoke, inputs={"foo": "bar"}),
data="foo",
)


@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.parametrize("blocking", [False, True])
async def test_aevaluate_results(blocking: bool) -> None:
@pytest.mark.parametrize("as_runnable", [False, True])
async def test_aevaluate_results(blocking: bool, as_runnable: bool) -> None:
session = mock.Mock()
ds_name = "my-dataset"
ds_id = "00886375-eb2a-4038-9032-efff60309896"
Expand Down Expand Up @@ -343,6 +353,15 @@ async def predict(inputs: dict) -> dict:
ordering_of_stuff.append("predict")
return {"output": inputs["in"] + 1}

if as_runnable:
try:
from langchain_core.runnables import RunnableLambda
except ImportError:
pytest.skip("langchain-core not installed.")
return
else:
predict = RunnableLambda(predict)

async def score_value_first(run, example):
ordering_of_stuff.append("evaluate")
return {"score": 0.3}
Expand Down

0 comments on commit 62e66ca

Please sign in to comment.