Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to override default values from config #58

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chatbot_ui/app/run_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import gradio as gr

api_key = os.getenv("OPENAI_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1")

client = OpenAI(api_key=api_key)
client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT)

def predict(message, history):
history_openai_format = []
Expand Down
3 changes: 3 additions & 0 deletions demos/function_calling/bolt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ default_prompt_endpoint: "127.0.0.1"
load_balancing: "round_robin"
timeout_ms: 5000

overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.6

# should not be here
embedding_provider:
Expand Down
13 changes: 12 additions & 1 deletion envoyfilter/src/filter_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use open_message_format_embeddings::models::{
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, PromptTarget};
use public_types::configuration::{Configuration, Overrides, PromptTarget};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
Expand Down Expand Up @@ -45,6 +45,7 @@ pub struct FilterContext {
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
overrides: Rc<Option<Overrides>>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}

Expand All @@ -63,6 +64,7 @@ impl FilterContext {
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
overrides: Rc::new(None),
}
}

Expand Down Expand Up @@ -212,6 +214,14 @@ impl RootContext for FilterContext {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();

if let Some(overrides_config) = self
.config
.as_mut()
.and_then(|config| config.overrides.as_mut())
{
self.overrides = Rc::new(Some(std::mem::take(overrides_config)));
}

for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
Expand All @@ -237,6 +247,7 @@ impl RootContext for FilterContext {
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.overrides),
)))
}

Expand Down
23 changes: 19 additions & 4 deletions envoyfilter/src/stream_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{PromptTarget, PromptType};
use public_types::configuration::{Overrides, PromptTarget, PromptType};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
Expand All @@ -50,6 +50,7 @@ pub struct StreamContext {
pub context_id: u32,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
pub overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
host_header: Option<String>,
ratelimit_selector: Option<Header>,
Expand All @@ -63,6 +64,7 @@ impl StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
overrides: Rc<Option<Overrides>>,
) -> Self {
StreamContext {
context_id,
Expand All @@ -74,6 +76,7 @@ impl StreamContext {
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
overrides,
}
}
fn save_host_header(&mut self) {
Expand Down Expand Up @@ -263,7 +266,7 @@ impl StreamContext {
+ callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3;

debug!(
"similarity score: {}, intent score: {}, description embedding score: {}",
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}",
prompt_target_similarity_score,
zeroshot_intent_response.predicted_class_score,
callout_context.similarity_scores.as_ref().unwrap()[0].1
Expand All @@ -286,16 +289,28 @@ impl StreamContext {
info!("no assistant message found, probably first interaction");
}

// get prompt target similarity thresold from overrides
let prompt_target_intent_matching_threshold = match self.overrides.as_ref() {
Some(overrides) => match overrides.prompt_target_intent_matching_threshold {
Some(threshold) => threshold,
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
},
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
};

// check to ensure that the prompt target similarity score is above the threshold
if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
&& !bolt_assistant
{
// if bolt fc responded to the user message, then we don't need to check the similarity score
// it may be that bolt fc is handling the conversation for parameter collection
if bolt_assistant {
info!("bolt assistant is handling the conversation");
} else {
info!(
"prompt target below threshold: {}, continue conversation with user",
"prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user",
prompt_target_similarity_score,
prompt_target_intent_matching_threshold
);
self.resume_http_request();
return;
Expand Down
6 changes: 6 additions & 0 deletions public_types/src/configuration.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub default_prompt_endpoint: String,
pub load_balancing: LoadBalancing,
pub timeout_ms: u64,
pub overrides: Option<Overrides>,
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub prompt_guards: Option<PromptGuard>,
Expand Down