-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo_mistral_stream.rs
66 lines (52 loc) · 1.7 KB
/
demo_mistral_stream.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use futures::TryStreamExt;
//use crate::bedrock::model_info::{ModelInfo, ModelName};
use crate::bedrock::models::mistral::MistralClient;
use crate::bedrock::models::mistral::MistralOptions;
use crate::bedrock::models::mistral::MistralRequestBuilder;
/// Streams a completion from a Mistral model via [`MistralClient`] and prints each chunk.
///
/// The client is built from the `bedrock` profile in the `us-west-2` region. `prompt` is
/// sent with fixed sampling parameters (`max_tokens = 200`, `temperature = 0.5`,
/// `top_p = 0.9`, `top_k = 100`), and every streamed chunk is pretty-printed as JSON on
/// stdout.
///
/// This is a demo helper: failures are printed as `Error: …` on stdout instead of being
/// returned.
///
/// # Arguments
/// * `model_id` - Model identifier forwarded to `generate_with_stream`.
/// * `prompt` - Prompt text, passed verbatim to [`MistralRequestBuilder::new`].
pub async fn demo_mistra_with_stream(model_id: &str, prompt: &str) {
    let mistral_options = MistralOptions::new()
        .profile_name("bedrock")
        .region("us-west-2");
    let client = MistralClient::new(mistral_options).await;

    let request = MistralRequestBuilder::new(prompt.to_owned())
        .max_tokens(200)
        .temperature(0.5)
        .top_p(0.9)
        .top_k(100)
        .build();

    let response_stream = match client
        .generate_with_stream(model_id.to_string(), &request)
        .await
    {
        Ok(stream) => stream,
        Err(e) => {
            println!("Error: {:?}", e);
            return;
        }
    };

    // Consume the stream and print each chunk. A chunk that cannot be serialized is
    // reported and skipped; an error yielded by the stream stops consumption and is
    // reported below instead of panicking.
    let outcome = response_stream
        .try_for_each(|chunk| async move {
            match serde_json::to_string_pretty(&chunk) {
                // Display (`{}`) keeps the pretty-printed JSON readable; Debug (`{:?}`)
                // would quote the string and escape every newline.
                Ok(json) => println!("{}", json),
                Err(e) => println!("Error: {:?}", e),
            }
            Ok(())
        })
        .await;

    if let Err(e) = outcome {
        println!("Error: {:?}", e);
    }
}
// Tests for the streaming demo.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bedrock::model_info::{ModelInfo, ModelName};

    /// Runs the streaming demo against Mixtral 8x7B Instruct with a short question.
    ///
    /// NOTE(review): the demo builds its client from the `bedrock` profile in `us-west-2`,
    /// so this presumably needs AWS credentials and network access — confirm before
    /// running in CI.
    #[tokio::test]
    async fn test_demo_chat_mistral_with_stream() {
        let mixtral = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x);
        demo_mistra_with_stream(
            &mixtral,
            "<s>[INST] What is the capital of France ?[/INST]",
        )
        .await;
    }
}