Skip to content

Commit bb20d57

Browse files
authored
Add: expose files as a simulated function call for JIT actions (#8643)
* Add: expose files as a simulated function call for JIT actions * Review fdbk * Small pimp of the debug conversation page
1 parent 6f1ef4a commit bb20d57

File tree

8 files changed

+189
-143
lines changed

8 files changed

+189
-143
lines changed

front/admin/cli.ts

+5-12
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import parseArgs from "minimist";
88

99
import { getConversation } from "@app/lib/api/assistant/conversation";
1010
import { renderConversationForModel } from "@app/lib/api/assistant/generation";
11-
import { getTextContentFromMessage } from "@app/lib/api/assistant/utils";
11+
import {
12+
getTextContentFromMessage,
13+
getTextRepresentationFromMessages,
14+
} from "@app/lib/api/assistant/utils";
1215
import config from "@app/lib/api/config";
1316
import { getDataSources } from "@app/lib/api/data_sources";
1417
import { garbageCollectGoogleDriveDocument } from "@app/lib/api/poke/plugins/data_sources/garbage_collect_google_drive_document";
@@ -347,17 +350,7 @@ const conversation = async (command: string, args: parseArgs.ParsedArgs) => {
347350
const messages = renderedConvo.modelConversation.messages;
348351

349352
const tokenCountRes = await tokenCountForTexts(
350-
[
351-
...messages.map((m) => {
352-
let text = `${m.role} ${"name" in m ? m.name : ""} ${getTextContentFromMessage(m)}`;
353-
if ("function_calls" in m) {
354-
text += m.function_calls
355-
.map((f) => `${f.name} ${f.arguments}`)
356-
.join(" ");
357-
}
358-
return text;
359-
}),
360-
],
353+
getTextRepresentationFromMessages(messages),
361354
model
362355
);
363356
if (tokenCountRes.isErr()) {

front/lib/api/assistant/generation.ts

+5-13
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ import {
3232
isJITActionsEnabled,
3333
renderConversationForModelJIT,
3434
} from "@app/lib/api/assistant/jit_actions";
35-
import { getTextContentFromMessage } from "@app/lib/api/assistant/utils";
35+
import {
36+
getTextContentFromMessage,
37+
getTextRepresentationFromMessages,
38+
} from "@app/lib/api/assistant/utils";
3639
import { getVisualizationPrompt } from "@app/lib/api/assistant/visualization";
3740
import type { Authenticator } from "@app/lib/auth";
3841
import { renderContentFragmentForModel } from "@app/lib/resources/content_fragment_resource";
@@ -386,18 +389,7 @@ async function renderConversationForModelMultiActions({
386389

387390
// Compute in parallel the token count for each message and the prompt.
388391
const res = await tokenCountForTexts(
389-
[
390-
prompt,
391-
...messages.map((m) => {
392-
let text = `${m.role} ${"name" in m ? m.name : ""} ${getTextContentFromMessage(m)}`;
393-
if ("function_calls" in m) {
394-
text += m.function_calls
395-
.map((f) => `${f.name} ${f.arguments}`)
396-
.join(" ");
397-
}
398-
return text;
399-
}),
400-
],
392+
[prompt, ...getTextRepresentationFromMessages(messages)],
401393
model
402394
);
403395

front/lib/api/assistant/jit_actions.ts

+85-13
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@ import type {
1212
import {
1313
assertNever,
1414
Err,
15+
getTablesQueryResultsFileAttachment,
1516
isAgentMessageType,
1617
isContentFragmentMessageTypeModel,
1718
isContentFragmentType,
1819
isDevelopment,
20+
isTablesQueryActionType,
1921
isTextContent,
2022
isUserMessageType,
2123
Ok,
2224
removeNulls,
2325
} from "@dust-tt/types";
2426

25-
import { getTextContentFromMessage } from "@app/lib/api/assistant/utils";
27+
import {
28+
getTextContentFromMessage,
29+
getTextRepresentationFromMessages,
30+
} from "@app/lib/api/assistant/utils";
2631
import type { Authenticator } from "@app/lib/auth";
2732
import { getFeatureFlags } from "@app/lib/auth";
2833
import { renderContentFragmentForModel } from "@app/lib/resources/content_fragment_resource";
@@ -217,20 +222,47 @@ export async function renderConversationForModelJIT({
217222
}
218223
}
219224

225+
// If we have messages...
226+
if (messages.length > 0) {
227+
const { filesAsXML, hasFiles } = listConversationFiles({
228+
conversation,
229+
});
230+
231+
// ... and files, we simulate a function call to list the files at the end of the conversation.
232+
if (hasFiles) {
233+
const randomCallId = "tool_" + Math.random().toString(36).substring(7);
234+
const functionName = "list_conversation_files";
235+
236+
const simulatedAgentMessages = [
237+
// 1. We add a message from the agent, asking to use the files listing function
238+
{
239+
role: "assistant",
240+
function_calls: [
241+
{
242+
id: randomCallId,
243+
name: functionName,
244+
arguments: "{}",
245+
},
246+
],
247+
} as AssistantFunctionCallMessageTypeModel,
248+
249+
// 2. We add a message with the resulting files listing
250+
{
251+
function_call_id: randomCallId,
252+
role: "function",
253+
name: functionName,
254+
content: filesAsXML,
255+
} as FunctionMessageTypeModel,
256+
];
257+
258+
// Append the simulated messages to the end of the conversation.
259+
messages.push(...simulatedAgentMessages);
260+
}
261+
}
262+
220263
// Compute in parallel the token count for each message and the prompt.
221264
const res = await tokenCountForTexts(
222-
[
223-
prompt,
224-
...messages.map((m) => {
225-
let text = `${m.role} ${"name" in m ? m.name : ""} ${getTextContentFromMessage(m)}`;
226-
if ("function_calls" in m) {
227-
text += m.function_calls
228-
.map((f) => `${f.name} ${f.arguments}`)
229-
.join(" ");
230-
}
231-
return text;
232-
}),
233-
],
265+
[prompt, ...getTextRepresentationFromMessages(messages)],
234266
model
235267
);
236268

@@ -370,3 +402,43 @@ export async function renderConversationForModelJIT({
370402
tokensUsed,
371403
});
372404
}
405+
406+
/**
 * Lists the files referenced by a conversation as an XML snippet.
 *
 * Two sources contribute entries, in conversation order:
 * - content fragments that carry a `fileId` (fragments without one are skipped);
 * - tables query actions found on agent messages, rendered through
 *   `getTablesQueryResultsFileAttachment` with the snippet left out.
 *
 * @returns `filesAsXML`, the entries wrapped in a `<files>` element, and
 * `hasFiles`, which is true when at least one entry was collected.
 */
function listConversationFiles({
  conversation,
}: {
  conversation: ConversationType;
}) {
  const entries = conversation.content.flat(1).flatMap((message): string[] => {
    if (isContentFragmentType(message)) {
      // NOTE(review): title and contentType are interpolated as-is, without
      // XML attribute escaping — confirm this is acceptable for the consumer.
      return message.fileId
        ? [
            `<file id="${message.fileId}" name="${message.title}" type="${message.contentType}" />`,
          ]
        : [];
    }
    if (!isAgentMessageType(message)) {
      return [];
    }

    const fromActions: string[] = [];
    for (const action of message.actions) {
      if (!isTablesQueryActionType(action)) {
        continue;
      }
      const rendered = getTablesQueryResultsFileAttachment({
        resultsFileId: action.resultsFileId,
        resultsFileSnippet: action.resultsFileSnippet,
        output: action.output,
        includeSnippet: false,
      });
      // Actions for which the helper yields nothing contribute no entry.
      if (rendered) {
        fromActions.push(rendered);
      }
    }
    return fromActions;
  });

  // The <files> wrapper is produced even when no entry was collected;
  // hasFiles tells the caller whether the listing is worth using.
  return {
    filesAsXML: `<files>\n${entries.join("\n")}\n</files>`,
    hasFiles: entries.length > 0,
  };
}

front/lib/api/assistant/utils.ts

+17
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,20 @@ export function getTextContentFromMessage(
2424
})
2525
.join("\n");
2626
}
27+
28+
/**
 * Builds one plain-text string per message, used to compute token counts.
 *
 * Each string is made of the message role, its `name` when the message has
 * one (empty otherwise) and its text content, separated by single spaces.
 * For messages carrying `function_calls`, every call is rendered as
 * "<name> <arguments>", the calls are joined with spaces, and the result is
 * appended directly after the text content.
 *
 * @param messages - The model messages to render.
 * @returns A new array with one string per input message, in the same order.
 */
export function getTextRepresentationFromMessages(
  messages: ModelMessageTypeMultiActions[]
): string[] {
  return messages.map((message) => {
    const name = "name" in message ? message.name : "";
    const base = `${message.role} ${name} ${getTextContentFromMessage(message)}`;

    if (!("function_calls" in message)) {
      return base;
    }

    const renderedCalls = message.function_calls
      .map((call) => `${call.name} ${call.arguments}`)
      .join(" ");
    return base + renderedCalls;
  });
}

front/lib/api/assistant/visualization.ts

+52-92
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@ import {
77
removeNulls,
88
} from "@dust-tt/types";
99
import _ from "lodash";
10-
import * as readline from "readline"; // Add this line
11-
import type { Readable } from "stream";
1210

11+
import { isJITActionsEnabled } from "@app/lib/api/assistant/jit_actions";
1312
import type { Authenticator } from "@app/lib/auth";
1413
import { FileResource } from "@app/lib/resources/file_resource";
1514

@@ -20,104 +19,65 @@ export async function getVisualizationPrompt({
2019
auth: Authenticator;
2120
conversation: ConversationType;
2221
}) {
23-
const readFirstFiveLines = (inputStream: Readable): Promise<string[]> => {
24-
return new Promise((resolve, reject) => {
25-
const rl: readline.Interface = readline.createInterface({
26-
input: inputStream,
27-
crlfDelay: Infinity,
28-
});
22+
const isJITEnabled = await isJITActionsEnabled(auth);
2923

30-
let lineCount: number = 0;
31-
const lines: string[] = [];
32-
33-
rl.on("line", (line: string) => {
34-
lines.push(line);
35-
lineCount++;
36-
if (lineCount === 5) {
37-
rl.close();
38-
}
39-
});
40-
41-
rl.on("close", () => {
42-
resolve(lines);
43-
});
44-
45-
rl.on("error", (err: Error) => {
46-
reject(err);
47-
});
48-
});
49-
};
50-
51-
const contentFragmentMessages: Array<ContentFragmentType> = [];
52-
for (const m of conversation.content.flat(1)) {
53-
if (isContentFragmentType(m)) {
54-
contentFragmentMessages.push(m);
55-
}
56-
}
57-
const contentFragmentFileBySid = _.keyBy(
58-
await FileResource.fetchByIds(
59-
auth,
60-
removeNulls(contentFragmentMessages.map((m) => m.fileId))
61-
),
62-
"sId"
63-
);
64-
65-
const contentFragmentTextByMessageId: Record<string, string[]> = {};
66-
for (const m of contentFragmentMessages) {
67-
if (!m.fileId || !m.contentType.startsWith("text/")) {
68-
continue;
69-
}
70-
71-
const file = contentFragmentFileBySid[m.fileId];
72-
if (!file) {
73-
continue;
74-
}
75-
const readStream = file.getReadStream({
76-
auth,
77-
version: "original",
78-
});
79-
contentFragmentTextByMessageId[m.sId] =
80-
await readFirstFiveLines(readStream);
81-
}
82-
83-
let prompt = visualizationSystemPrompt.trim() + "\n\n";
84-
85-
const fileAttachments: string[] = [];
86-
for (const m of conversation.content.flat(1)) {
87-
if (isContentFragmentType(m)) {
88-
if (!m.fileId || !contentFragmentFileBySid[m.fileId]) {
89-
continue;
24+
// When JIT is enabled, we return the visualization prompt directly without listing the files as the files will be made available to the model via another mechanism (simulated function call).
25+
if (isJITEnabled) {
26+
return visualizationSystemPrompt.trim();
27+
} else {
28+
const contentFragmentMessages: Array<ContentFragmentType> = [];
29+
for (const m of conversation.content.flat(1)) {
30+
if (isContentFragmentType(m)) {
31+
contentFragmentMessages.push(m);
9032
}
91-
fileAttachments.push(
92-
`<file id="${m.fileId}" name="${m.title}" type="${m.contentType}" />`
93-
);
94-
} else if (isAgentMessageType(m)) {
95-
for (const a of m.actions) {
96-
if (isTablesQueryActionType(a)) {
97-
const attachment = getTablesQueryResultsFileAttachment({
98-
resultsFileId: a.resultsFileId,
99-
resultsFileSnippet: a.resultsFileSnippet,
100-
output: a.output,
101-
includeSnippet: false,
102-
});
103-
if (attachment) {
104-
fileAttachments.push(attachment);
33+
}
34+
const contentFragmentFileBySid = _.keyBy(
35+
await FileResource.fetchByIds(
36+
auth,
37+
removeNulls(contentFragmentMessages.map((m) => m.fileId))
38+
),
39+
"sId"
40+
);
41+
42+
let prompt = visualizationSystemPrompt.trim() + "\n\n";
43+
44+
const fileAttachments: string[] = [];
45+
for (const m of conversation.content.flat(1)) {
46+
if (isContentFragmentType(m)) {
47+
if (!m.fileId || !contentFragmentFileBySid[m.fileId]) {
48+
continue;
49+
}
50+
fileAttachments.push(
51+
`<file id="${m.fileId}" name="${m.title}" type="${m.contentType}" />`
52+
);
53+
} else if (isAgentMessageType(m)) {
54+
for (const a of m.actions) {
55+
if (isTablesQueryActionType(a)) {
56+
const attachment = getTablesQueryResultsFileAttachment({
57+
resultsFileId: a.resultsFileId,
58+
resultsFileSnippet: a.resultsFileSnippet,
59+
output: a.output,
60+
includeSnippet: false,
61+
});
62+
if (attachment) {
63+
fileAttachments.push(attachment);
64+
}
10565
}
10666
}
10767
}
10868
}
109-
}
11069

111-
if (fileAttachments.length > 0) {
112-
prompt +=
113-
"Files accessible to the :::visualization directive environment:\n";
114-
prompt += fileAttachments.join("\n");
115-
} else {
116-
prompt +=
117-
"No files are currently accessible to the :::visualization directive environment in this conversation.";
118-
}
70+
if (fileAttachments.length > 0) {
71+
prompt +=
72+
"Files accessible to the :::visualization directive environment:\n";
73+
prompt += fileAttachments.join("\n");
74+
} else {
75+
prompt +=
76+
"No files are currently accessible to the :::visualization directive environment in this conversation.";
77+
}
11978

120-
return prompt;
79+
return prompt;
80+
}
12181
}
12282

12383
export const visualizationSystemPrompt = `\

0 commit comments

Comments
 (0)