Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
- Change model parameter type for repetition_penalty to double
- Return UNKNOWN for null sentiment string
  • Loading branch information
pberlandier committed Sep 16, 2024
1 parent d196f47 commit 317b24a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
3 changes: 2 additions & 1 deletion customer-review-rules/bom/watsonxai-model.bom
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ public class WatsonxAIRunner
public readonly string output;
public string projectId;
public string promptTemplate;
public int repetitionPenalty;
public double repetitionPenalty;
public readonly boolean valid;
public WatsonxAIRunner();
public void addVariable(string arg1, string arg2);
public void clearVariables();
public string normalize(string arg);
public void runInference(string arg1, string arg2, boolean arg3);
public void runInference();
}
Expand Down
3 changes: 3 additions & 0 deletions customer-review-xom/src/com/acme/Sentiment.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ public enum Sentiment {
POSITIVE, NEGATIVE, UNKNOWN;

public static Sentiment map(String value) {
if ( value == null ) {
return UNKNOWN;
}
switch (value) {
case "negative": return NEGATIVE;
case "positive": return POSITIVE;
Expand Down
30 changes: 17 additions & 13 deletions watsonx-helper/src/main/java/com/ibm/odm/WatsonxAIRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import com.mashape.unirest.http.Unirest;

/**
* This class allows the preparation and invocation of a **watsonx.ai** inference.
* It holds the definition of a template payload for REST API end-point call
* along with the model parameters to run the inference, and also the prompt template with
* the target variable values.
* This class allows the preparation and invocation of a **watsonx.ai**
* inference. It holds the definition of a template payload for REST API
* end-point call along with the model parameters to run the inference, and also
* the prompt template with the target variable values.
*
* @author PIERREBerlandier
*
Expand All @@ -31,20 +31,21 @@ public class WatsonxAIRunner {
private boolean exception = false;
private String message;
//
// Model parameters with a default value.
// Model parameters: you can provide a default value for these parameters or set
// their value as part of the ODM ruleflow using the WatsonxAIRunner setters.
//
private String projectId = "<project-id>";
private String modelId = "google/flan-ul2";
private String projectId = "<your-watsonx-project-id>";
private String modelId = "<your-favorite-llm>";
private String decodingMethod = "greedy";
private int maxNewTokens = 5;
private int minNewTokens = 0;
private int repetitionPenalty = 1;
private double repetitionPenalty = 1.0;
//
// API end-point payload template.
//
private static String inputTemplate = "{\r\n" + " \"input\": \"%s\",\r\n" + " \"parameters\": {\r\n"
+ " \"decoding_method\": \"%s\",\r\n" + " \"max_new_tokens\": %d,\r\n" + " \"min_new_tokens\": %d,\r\n"
+ " \"stop_sequences\": [],\r\n" + " \"repetition_penalty\": %d\r\n" + " },\r\n"
+ " \"stop_sequences\": [],\r\n" + " \"repetition_penalty\": %f\r\n" + " },\r\n"
+ " \"model_id\": \"%s\",\r\n" + " \"project_id\": \"%s\"\r\n" + "}";

public String getDecodingMethod() {
Expand All @@ -71,11 +72,11 @@ public void setMinNewTokens(int minNewTokens) {
this.minNewTokens = minNewTokens;
}

public int getRepetitionPenalty() {
public double getRepetitionPenalty() {
return repetitionPenalty;
}

public void setRepetitionPenalty(int repetitionPenalty) {
public void setRepetitionPenalty(double repetitionPenalty) {
this.repetitionPenalty = repetitionPenalty;
}

Expand Down Expand Up @@ -149,7 +150,8 @@ public String normalize(String response) {

/**
* Runs the web service call with Unirest.
* @param url URL for the deployed watsonx.ai instance
*
* @param url URL for the deployed watsonx.ai instance
* @param apiKey
* @param normalize If true, applies some post-processing on the output string.
*/
Expand All @@ -170,13 +172,15 @@ public void runInference(String url, String apiKey, boolean normalize) {
output = normalize(output);
}
} catch (Exception e) {
output = null;
exception = true;
message = e.getMessage();
}
}

/**
* Gets a bearer token for the web service execution.
*
* @param apiKey
* @return
* @throws Exception
Expand Down

0 comments on commit 317b24a

Please sign in to comment.