Skip to content

Commit

Permalink
Refactor for simplicity, fix experiment creation
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Jan 15, 2025
1 parent d07f42e commit 644aca5
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 109 deletions.
2 changes: 1 addition & 1 deletion js/src/run_trees.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ export class RunTree implements BaseRun {
};
}

static getSharedClient(): Client {
private static getSharedClient(): Client {
if (!RunTree.sharedClient) {
RunTree.sharedClient = new Client();
}
Expand Down
2 changes: 2 additions & 0 deletions js/src/utils/jestlike/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { isTracingEnabled } from "../../env.js";
import { EvaluationResult } from "../../evaluation/evaluator.js";
import { RunTree } from "../../run_trees.js";

export const DEFAULT_TEST_CLIENT = new Client();

export type TestWrapperAsyncLocalStorageData = {
enableTestTracking?: boolean;
dataset?: Dataset;
Expand Down
212 changes: 105 additions & 107 deletions js/src/utils/jestlike/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import * as path from "node:path";
import * as fs from "node:fs/promises";

import { getCurrentRunTree, traceable } from "../../traceable.js";
import { RunTree } from "../../run_trees.js";
import { KVMap, TracerSession } from "../../schemas.js";
import { randomName } from "../../evaluation/_random_name.js";
import { Client, CreateProjectParams } from "../../client.js";
Expand All @@ -25,6 +24,7 @@ import {
_logTestFeedback,
syncExamplePromises,
trackingEnabled,
DEFAULT_TEST_CLIENT,
} from "./globals.js";
import { wrapExpect } from "./vendor/chain.js";
import { EvaluationResult } from "../../evaluation/evaluator.js";
Expand Down Expand Up @@ -269,7 +269,7 @@ export function generateWrapperFromJestlikeMethods(
enableTestTracking?: boolean;
} & Partial<Omit<CreateProjectParams, "referenceDatasetId">>
) {
const client = experimentConfig?.client ?? RunTree.getSharedClient();
const client = experimentConfig?.client ?? DEFAULT_TEST_CLIENT;
return method(datasetName, () => {
const suiteUuid = v4();
const context = {
Expand Down Expand Up @@ -365,8 +365,72 @@ export function generateWrapperFromJestlikeMethods(
const testInput: I = inputs;
const testOutput: O = expected;
const testFeedback: EvaluationResult[] = [];
let testReturnValue: unknown;
const onFeedbackLogged = (feedback: EvaluationResult) =>
testFeedback.push(feedback);
let loggedOutput: Record<string, unknown> | undefined;
const setLoggedOutput = (value: Record<string, unknown>) => {
if (loggedOutput !== undefined) {
console.warn(
`[WARN]: New "logOutput()" call will override output set by previous "logOutput()" call.`
);
}
loggedOutput = value;
};
let exampleId: string;
const runTestFn = async () => {
const testContext =
testWrapperAsyncLocalStorageInstance.getStore();
if (testContext === undefined) {
throw new Error(
"Could not identify test context. Please contact us for help."
);
}
try {
const res = await testFn({
inputs: testInput,
expected: testOutput,
});
_logTestFeedback({
exampleId,
feedback: { key: "pass", score: true },
context: testContext,
runTree: trackingEnabled(testContext)
? getCurrentRunTree()
: undefined,
client: testContext.client,
});
if (res != null) {
if (loggedOutput !== undefined) {
console.warn(
`[WARN]: Returned value from test function will override output set by previous "logOutput()" call.`
);
}
loggedOutput =
typeof res === "object"
? (res as Record<string, unknown>)
: { result: res };
}
return loggedOutput;
} catch (e: any) {
_logTestFeedback({
exampleId,
feedback: { key: "pass", score: false },
context: testContext,
runTree: trackingEnabled(testContext)
? getCurrentRunTree()
: undefined,
client: testContext.client,
});
const rawError = e;
const strippedErrorMessage = e.message.replace(
STRIP_ANSI_REGEX,
""
);
const langsmithFriendlyError = new Error(strippedErrorMessage);
(langsmithFriendlyError as any).rawJestError = rawError;
throw langsmithFriendlyError;
}
};
try {
if (trackingEnabled(context)) {
const missingFields = [];
Expand All @@ -388,14 +452,16 @@ export function generateWrapperFromJestlikeMethods(
)} while syncing to LangSmith. Please contact us for help.`
);
}
const testClient = client;
const exampleId = getExampleId(dataset.name, inputs, expected);
exampleId = getExampleId(dataset.name, inputs, expected);

// Create or update the example in the background
// TODO: Create or update the example in the background
// Currently run end time has to be after example modified time
// for examples to render properly, so we must modify the example
// first before running the test.
if (syncExamplePromises.get(exampleId) === undefined) {
syncExamplePromises.set(
exampleId,
syncExample({
await syncExample({
client,
exampleId,
datasetId: dataset.id,
Expand All @@ -407,140 +473,72 @@ export function generateWrapperFromJestlikeMethods(
);
}

// .enterWith is OK here
testWrapperAsyncLocalStorageInstance.enterWith({
...context,
currentExample: {
inputs,
outputs: expected,
id: exampleId,
},
client: testClient,
});

const traceableOptions = {
reference_example_id: exampleId,
project_name: project!.name,
metadata: {
...config?.metadata,
},
client: testClient,
client,
tracingEnabled: true,
name,
};

// Pass inputs into traceable so tracing works correctly but
// provide both to the user-defined test function
const tracedFunction = traceable(
async (_: I) => {
const testContext =
testWrapperAsyncLocalStorageInstance.getStore();
if (testContext === undefined) {
throw new Error(
"Could not identify test context. Please contact us for help."
);
}
try {
const res =
await testWrapperAsyncLocalStorageInstance.run(
{
...testContext,
setLoggedOutput: (value) => {
if (loggedOutput !== undefined) {
console.warn(
`[WARN]: New "logOutput()" call will override output set by previous "logOutput()" call.`
);
}
loggedOutput = value;
},
onFeedbackLogged: (feedback) =>
testFeedback.push(feedback),
},
async () => {
return testFn({
inputs: testInput,
expected: testOutput,
});
}
);
_logTestFeedback({
exampleId,
feedback: { key: "pass", score: true },
context: testContext,
runTree: getCurrentRunTree(),
client: testClient,
});
return res;
} catch (e: any) {
_logTestFeedback({
exampleId,
feedback: { key: "pass", score: false },
context: testContext,
runTree: getCurrentRunTree(),
client: testClient,
});
const rawError = e;
const strippedErrorMessage = e.message.replace(
STRIP_ANSI_REGEX,
""
);
const langsmithFriendlyError = new Error(
strippedErrorMessage
);
(langsmithFriendlyError as any).rawJestError = rawError;
throw langsmithFriendlyError;
}
async () => {
return testWrapperAsyncLocalStorageInstance.run(
{
...context,
currentExample: {
inputs,
outputs: expected,
id: exampleId,
},
setLoggedOutput,
onFeedbackLogged,
},
runTestFn
);
},
{ ...traceableOptions, ...config }
{
...traceableOptions,
...config,
}
);
try {
testReturnValue = await (tracedFunction as any)(testInput);
await tracedFunction(testInput);
} catch (e: any) {
// Extract raw Jest error from LangSmith formatted one and throw
if (e.rawJestError !== undefined) {
throw e.rawJestError;
}
throw e;
}
} else {
testReturnValue =
try {
await testWrapperAsyncLocalStorageInstance.run(
{
...context,
currentExample: {
inputs: testInput,
outputs: testOutput,
},
setLoggedOutput: (value) => {
if (loggedOutput !== undefined) {
console.warn(
`[WARN]: New "logOutput()" call will override output set by previous "logOutput()" call.`
);
}
loggedOutput = value;
},
onFeedbackLogged: (feedback) =>
testFeedback.push(feedback),
setLoggedOutput,
onFeedbackLogged,
},
async () => {
return testFn({
inputs: testInput,
expected: testOutput,
});
}
);
}
} finally {
if (testReturnValue != null) {
if (loggedOutput !== undefined) {
console.warn(
`[WARN]: Returned value from test function will override output set by previous "logOutput()" call.`
runTestFn
);
} catch (e: any) {
// Extract raw Jest error from LangSmith formatted one and throw
if (e.rawJestError !== undefined) {
throw e.rawJestError;
}
throw e;
}
loggedOutput =
typeof testReturnValue === "object"
? (testReturnValue as Record<string, unknown>)
: { result: testReturnValue };
}
} finally {
await fs.mkdir(path.dirname(resultsPath), { recursive: true });
await fs.writeFile(
resultsPath,
Expand Down
9 changes: 8 additions & 1 deletion js/src/utils/jestlike/reporter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ import { STRIP_ANSI_REGEX } from "./index.js";

const FEEDBACK_COLLAPSE_THRESHOLD = 64;

const RESERVED_KEYS = ["Name", "Result", "Inputs", "Expected", "Actual"];
const RESERVED_KEYS = [
"Name",
"Result",
"Inputs",
"Expected",
"Actual",
"pass",
];

function formatTestName(name: string, duration: number) {
if (duration != null) {
Expand Down

0 comments on commit 644aca5

Please sign in to comment.