diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index ce17c594..29ff510a 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -87,6 +87,8 @@ public class PPLTool implements Tool { private String contextPrompt; + private Boolean execute; + private PPLModelType pplModelType; private static Gson gson = new Gson(); @@ -120,7 +122,7 @@ public static PPLModelType from(String value) { } - public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType) { + public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType, boolean execute) { this.client = client; this.modelId = modelId; this.pplModelType = PPLModelType.from(pplModelType); @@ -129,6 +131,7 @@ public PPLTool(Client client, String modelId, String contextPrompt, String pplMo } else { this.contextPrompt = contextPrompt; } + this.execute = execute; } @Override @@ -171,6 +174,10 @@ public void run(Map parameters, ActionListener listener) ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); Map dataAsMap = (Map) modelTensor.getDataAsMap(); String ppl = parseOutput(dataAsMap.get("response"), indexName); + if (!this.execute) { + listener.onResponse((T) ppl); + return; + } JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); @@ -255,7 +262,8 @@ public PPLTool create(Map map) { client, (String) map.get("model_id"), (String) map.getOrDefault("prompt", ""), - (String) map.getOrDefault("model_type", "") + (String) map.getOrDefault("model_type", ""), + (boolean) map.getOrDefault("execute", true) ); } diff --git a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java index 680586d0..129c2411 100644 --- a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -138,6 +138,19 @@ public void testTool() { } + @Test + public void testTool_with_WithoutExecution() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "claude", "execute", false)); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + assertEquals("source=demo| head 1", executePPLResult); + }, e -> { log.info(e); })); + + } + @Test public void testTool_with_DefaultPrompt() { PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude"));