diff --git a/Cargo.toml b/Cargo.toml
index 88b2963..53825ba 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ license-file = "LICENSE"
 homepage = "https://github.com/groovybits/rsllm/wiki"
 repository = "https://github.com/groovybits/rsllm"
 authors = ["Chris Kennedy"]
-version = "0.5.5"
+version = "0.5.6"
 edition = "2021"
 
 [lib]
diff --git a/scripts/twitch.sh b/scripts/twitch.sh
index 69986da..5153769 100755
--- a/scripts/twitch.sh
+++ b/scripts/twitch.sh
@@ -10,12 +10,13 @@
 # === CONFIGURATION ===
 BUILD_TYPE=release
-MODEL=gemma
-MODEL_ID=2b-it
+#MODEL=gemma
+MODEL=mistral
+MODEL_ID=7b-it
 MAX_TOKENS=200
 ALIGNMENT=right
 TEMPERATURE=1.0
-CONTINUOUS=1
+CONTINUOUS=0
 POLL_INTERVAL=0
 PIPELINE_CONCURRENCY=3
 TWITCH_LLM_CONCURRENCY=3
@@ -23,14 +24,17 @@ CONTEXT_SIZE=4000
 SUBTITLES=1
 DAEMON=1
 KEEP_HISTORY=1
+QUANTIZED=0
 GREETING="Hi I'm Alice, ask me a question!"
 TWITCH_CHAT_HISTORY=8
+TWITCH_MODEL=mistral
 # === END OF CONFIGURATION ===
 #
 #
 SUBTITLE_CMD=
 DAEMON_CMD=
 NO_HISTORY_CMD=
+QUANTIZED_CMD=
 if [ "$SUBTITLES" == 1 ]; then
     SUBTITLE_CMD="--subtitles"
 fi
@@ -43,6 +47,9 @@ fi
 if [ "$KEEP_HISTORY" == 0 ]; then
     NO_HISTORY_CMD="--no-history"
 fi
+if [ "$QUANTIZED" == 1 ]; then
+    QUANTIZED_CMD="--quantized"
+fi
 
 SYSTEM_PROMPT="You are not an AI bot, you are Alice who is a real person and you run Alice's AI Wonderland twitch streaming channel, you chat with twitch users in your twitch channel who ask you questions and have discussions. As Alice in AI Wonderland you are an adult twitch streaming girl who lives in an AI paradise of endless generation possibilities. Always talk in first person as yourself, Alice. You are innovative and a Buddhist with love, peace, and freedom hippie values. Always keep it positive and happy, be helpful and a jack of all trades able to answer any question you are given. If you don't know for sure, just make up something funny. Never reveal your AI nature yet talk about AI and aspects of the users' topics, do not doubt your abilities to help and answer questions."
 PROMPT="create a story based on an anime about Alice, an adult twitch streaming girl who lives in AI Wonderland. Have it vary off the title 'Alice in AI Wonderland' with a random plotline you create based on classic anime characters appearing in the wonderland. Alice's AI Wonderland is a happy fun show where Alice goes through experiences similar to Alice in Wonderland, where she grows small or large depending on what she eats. Add in AI technology twists. Have it fully formatted like a transcript with the character speaking parts mostly speaking in first person, minimal narration. Create a whole episode full length with classic anime characters with Alice the main character of AI Wonderland."
@@ -56,6 +63,7 @@ DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \
     --twitch-client \
     --twitch-chat-history $TWITCH_CHAT_HISTORY \
     --twitch-llm-concurrency $TWITCH_LLM_CONCURRENCY \
+    --twitch-model $TWITCH_MODEL \
     --sd-image \
     --ndi-audio \
     --ndi-images \
@@ -71,4 +79,5 @@ DYLD_LIBRARY_PATH=`pwd`:/usr/local/lib:$DYLD_LIBRARY_PATH \
     $DAEMON_CMD \
     $CONTINUOUS_CMD \
     $NO_HISTORY_CMD \
+    $QUANTIZED_CMD \
     --max-tokens $MAX_TOKENS $@
diff --git a/src/args.rs b/src/args.rs
index 2bf86be..5cf2f68 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -4,7 +4,7 @@ use clap::Parser;
 #[derive(Parser, Debug, Clone)]
 #[clap(
     author = "Chris Kennedy",
-    version = "0.5.5",
+    version = "0.5.6",
     about = "Rust AI Stream Analyzer Twitch Bot"
 )]
 pub struct Args {
@@ -703,4 +703,13 @@ pub struct Args {
         help = "Twitch Prompt."
     )]
     pub twitch_prompt: String,
+
+    /// Twitch model - LLM to use, gemma or mistral for now
+    #[clap(
+        long,
+        env = "TWITCH_MODEL",
+        default_value = "gemma",
+        help = "Twitch LLM model."
+    )]
+    pub twitch_model: String,
 }
diff --git a/src/twitch_client.rs b/src/twitch_client.rs
index b179feb..6a15266 100644
--- a/src/twitch_client.rs
+++ b/src/twitch_client.rs
@@ -1,5 +1,6 @@
 use crate::args::Args;
 use crate::candle_gemma::gemma;
+use crate::candle_mistral::mistral;
 use anyhow::Result;
 use std::io::Write;
 use std::sync::atomic::{AtomicBool, Ordering};
@@ -97,10 +98,70 @@ async fn on_msg(
     // LLM Thread
     let (external_sender, mut external_receiver) = tokio::sync::mpsc::channel::<String>(100);
     let max_tokens = 120;
-    let temperature = 1.0;
-    let quantized = true;
+    let temperature = 0.8;
+    let quantized = false;
     let max_messages = args.twitch_chat_history;
 
+    let system_start_token = if args.twitch_model == "gemma" {
+        "<start_of_turn>"
+    } else {
+        "<<SYS>>"
+    };
+
+    let system_end_token = if args.twitch_model == "gemma" {
+        "<end_of_turn>"
+    } else {
+        "<</SYS>>"
+    };
+
+    let assistant_start_token = if args.twitch_model == "gemma" {
+        "<start_of_turn>"
+    } else {
+        ""
+    };
+
+    let assistant_end_token = if args.twitch_model == "gemma" {
+        "<end_of_turn>"
+    } else {
+        ""
+    };
+
+    let start_token = if args.twitch_model == "gemma" {
+        "<start_of_turn>"
+    } else {
+        "[INST]"
+    };
+
+    let end_token = if args.twitch_model == "gemma" {
+        "<end_of_turn>"
+    } else {
+        "[/INST]"
+    };
+
+    let bos_token = if args.twitch_model == "gemma" {
+        ""
+    } else {
+        "<s>"
+    };
+
+    let eos_token = if args.twitch_model == "gemma" {
+        ""
+    } else {
+        "</s>"
+    };
+
+    let user_name = if args.twitch_model == "gemma" {
+        "user"
+    } else {
+        ""
+    };
+
+    let assistant_name = if args.twitch_model == "gemma" {
+        "model"
+    } else {
+        ""
+    };
+
     // Truncate the chat_messages array to max_messages
     if chat_messages.len() > max_messages {
         chat_messages.truncate(max_messages);
@@ -115,27 +176,66 @@ async fn on_msg(
     // Send message to the AI through mpsc channels format to model specs
     let msg_text = format!(
-        "<start_of_turn>model {}<end_of_turn>{}<start_of_turn>user twitch chat user {} asked {}<end_of_turn><start_of_turn>model ",
+        "{}{}{} {}{}{}{}{}{}{}{} twitch chat user {} asked {}{}{}{} ",
+        bos_token,
+        system_start_token,
+        assistant_name,
         args.twitch_prompt.clone(),
+        system_end_token,
+        eos_token,
+        bos_token,
         chat_messages_history,
+        bos_token,
+        start_token,
+        user_name,
         msg.sender().name(),
-        msg.text().to_string()
+        msg.text().to_string(),
+        end_token,
+        assistant_start_token,
+        assistant_name,
     );
 
     // Log the message text being sent to the LLM
     println!("\nTwitch sending msg_text:\n{}\n", msg_text);
 
-    let llm_thread = tokio::spawn(async move {
-        if let Err(e) = gemma(
-            msg_text,
-            max_tokens,
-            temperature,
-            quantized,
-            Some("2b-it".to_string()),
-            external_sender,
-        ) {
-            eprintln!("Error running twitch gemma: {}", e);
-        }
-    });
+    let llm_thread = if args.twitch_model == "gemma" {
+        tokio::spawn(async move {
+            if let Err(e) = gemma(
+                msg_text,
+                max_tokens,
+                temperature,
+                quantized,
+                Some("2b-it".to_string()),
+                external_sender,
+            ) {
+                eprintln!("Error running twitch gemma: {}", e);
+            }
+        })
+    } else if args.twitch_model == "mistral" {
+        tokio::spawn(async move {
+            if let Err(e) = mistral(
+                msg_text,
+                max_tokens,
+                temperature,
+                quantized,
+                Some("auto".to_string()),
+                external_sender,
+            ) {
+                eprintln!("Error running twitch mistral: {}", e);
+            }
+        })
+    } else {
+        // print message and error out
+        eprintln!(
+            "Error: Invalid model specified for twitch chat {}",
+            args.twitch_model
+        );
+        tokio::spawn(async move {
+            external_sender
chat".to_string()) + .await + .unwrap(); + }) + }; // thread token collection and wait for it to finish let token_thread = tokio::spawn(async move { @@ -172,10 +272,18 @@ async fn on_msg( // add message to the chat_messages history of strings let full_message = format!( - "user {} asked {}model {}", + "{}{}{} {} asked {}{}{}{} {}{}{}", + bos_token, + start_token, + user_name, msg.sender().name(), msg.text().to_string(), - answer.clone() + end_token, + assistant_start_token, + assistant_name, + answer.clone(), + assistant_end_token, + eos_token ); chat_messages.push(full_message);