-
Notifications
You must be signed in to change notification settings - Fork 978
/
completion.rs
68 lines (57 loc) · 2.19 KB
/
completion.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use ollama_rs::{
generation::{completion::request::GenerationRequest, options::GenerationOptions},
Ollama,
};
use tabby_common::config::HttpModelConfig;
use tabby_inference::{CompletionOptions, CompletionStream};
use tracing::error;
use crate::model::OllamaModelExt;
/// Completion backend that forwards generation requests to an Ollama server.
pub struct OllamaCompletion {
    /// Connection to Ollama API
    connection: Ollama,
    /// Model name, <model>
    model: String,
}
#[async_trait]
impl CompletionStream for OllamaCompletion {
    /// Stream completion text for `prompt` from the configured Ollama model.
    ///
    /// On failure to start generation, logs the error and returns an empty
    /// stream rather than propagating it (the trait returns a plain
    /// `BoxStream<String>`, so there is no error channel to the caller).
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        // FIXME: options.presence_penalty is not used
        let ollama_options = GenerationOptions::default()
            .num_ctx(options.max_input_length as u32)
            .num_predict(options.max_decoding_tokens)
            .seed(options.seed as i32)
            .repeat_last_n(0)
            .temperature(options.sampling_temperature);
        let request = GenerationRequest::new(self.model.to_owned(), prompt.to_owned())
            // Pass the prompt through verbatim instead of the model's own template.
            .template("{{ .Prompt }}".to_string())
            .options(ollama_options);
        // Why this function returns not Result?
        match self.connection.generate_stream(request).await {
            Ok(stream) => {
                let tabby_stream = stream! {
                    for await response in stream {
                        match response {
                            Ok(parts) => {
                                for part in parts {
                                    yield part.response
                                }
                            }
                            Err(err) => {
                                // Previously this was `response.unwrap()`, which
                                // panicked the consuming task on any mid-stream
                                // error. Log and end the stream instead, matching
                                // the graceful handling of the startup error below.
                                error!("Failed to read completion response: {:?}", err);
                                break;
                            }
                        }
                    }
                };
                tabby_stream.boxed()
            }
            Err(err) => {
                error!("Failed to generate completion: {}", err);
                futures::stream::empty().boxed()
            }
        }
    }
}
/// Build an Ollama-backed [`CompletionStream`] from the HTTP model config.
///
/// # Panics
///
/// Panics if the configured API endpoint is not a valid URL, or if a model
/// cannot be selected from the server — startup is fail-fast by design.
pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
    let connection = Ollama::try_new(config.api_endpoint.to_owned())
        .expect("Failed to create connection to Ollama, URL invalid");
    // `expect` instead of bare `unwrap`: a contextless panic here made
    // misconfigured model names hard to diagnose.
    let model = connection
        .select_model_or_default(config)
        .await
        .expect("Failed to select model from Ollama server");
    Arc::new(OllamaCompletion { connection, model })
}