Support llamacpp (#21)
* refactor: use role in the chat

* feat: add support for llama.cpp

* update changelog

* update Readme

* support api key from env variable
pythops authored Jan 31, 2024
1 parent cead45f commit fe47cfe
Showing 11 changed files with 264 additions and 39 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [0.11] TBA

### Added

- Support for [llama.cpp](https://github.com/ggerganov/llama.cpp)

## [0.10] 27/01/2024

### Added
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

29 changes: 15 additions & 14 deletions Cargo.toml
@@ -10,24 +10,25 @@ homepage = "https://github.com/pythops/tenere"
repository = "https://github.com/pythops/tenere"

[dependencies]
ansi-to-tui = "3"
arboard = "3"
async-trait = "0.1"
bat = "0.24"
clap = { version = "4", features = ["derive", "cargo"] }
crossterm = { version = "0.27", features = ["event-stream"] }
ratatui = { version = "0.25", features = ["all-widgets"] }
tui-textarea = { version = "0.4" }
unicode-width = "0.1"
dirs = "5"
futures = "0.3"
reqwest = { version = "0.11", default-features = false, features = [
"json",
"rustls-tls",
] }
serde_json = "1"
ansi-to-tui = "3"

clap = { version = "4", features = ["derive", "cargo"] }
toml = { version = "0.8" }
serde = { version = "1", features = ["derive"] }
dirs = "5"
ratatui = { version = "0.25", features = ["all-widgets"] }
regex = "1"
bat = "0.24"
arboard = "3"
async-trait = "0.1"
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
strum = "0.26"
strum_macros = "0.26"
tokio = { version = "1", features = ["full"] }
toml = { version = "0.8" }
tui-textarea = { version = "0.4" }
unicode-width = "0.1"
49 changes: 43 additions & 6 deletions README.md
@@ -24,7 +24,7 @@

- [x] ChatGPT
- [ ] ollama (todo)
- [ ] llama.cpp (todo)
- [x] llama.cpp (in the `master` branch)

<br>

@@ -78,11 +78,11 @@ Tenere can be configured using a TOML configuration file. The file should be loc
Here are the available general settings:

- `archive_file_name`: the file name where the chat will be saved. By default it is set to `tenere.archive`
- `model`: the llm model name. Currently only `chatgpt` is supported.
- `llm`: the LLM backend to use. Possible values are `chatgpt` and `llamacpp`.

```toml
archive_file_name = "tenere.archive"
model = "chatgpt"
llm = "chatgpt"
```
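
For example, to point Tenere at llama.cpp instead, the same file might look like this (a sketch; the `[llamacpp]` table is described in the llama.cpp section below):

```toml
archive_file_name = "tenere.archive"
llm = "llamacpp"

[llamacpp]
url = "http://localhost:8080/v1/chat/completions"
```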

### Key bindings
@@ -104,15 +104,17 @@ save_chat = 's'
## Chatgpt

To use Tenere's chat functionality, you'll need to provide an API key for OpenAI. There are two ways to do this:
To use `chatgpt` as the backend, you'll need to provide an API key for OpenAI. There are two ways to do this:

1. Set an environment variable with your API key:
Set an environment variable with your API key:

```shell
export OPENAI_API_KEY="YOUR KEY HERE"
```

2. Include your API key in the configuration file:
Or

Include your API key in the configuration file:

```toml
[chatgpt]
```

@@ -123,6 +125,33 @@ url = "https://api.openai.com/v1/chat/completions"

The default model is set to `gpt-3.5-turbo`. Check out the [OpenAI documentation](https://platform.openai.com/docs/models/gpt-3-5) for more info.
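
If you want a different model, it can presumably be overridden in the same `[chatgpt]` table; the `model` field name here is an assumption based on the default described above:

```toml
[chatgpt]
openai_api_key = "Your API Key here"
# Assumed field name; check your config if the override does not take effect
model = "gpt-4"
url = "https://api.openai.com/v1/chat/completions"
```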

## llama.cpp

To use `llama.cpp` as the backend, you'll need to provide the URL that points to the server:

```toml
[llamacpp]
url = "http://localhost:8080/v1/chat/completions"
```
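
The URL above assumes a llama.cpp server is already running locally. A minimal launch sketch (the binary name, model path, and flags are assumptions; check the llama.cpp server documentation for your build):

```shell
# Assumed invocation of the llama.cpp HTTP server; adjust the model path and port
./server -m ./models/model.gguf --host 127.0.0.1 --port 8080
```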

If you configure the server with an API key, then you need to provide it as well:

Set an environment variable with your API key:

```shell
export LLAMACPP_API_KEY="YOUR KEY HERE"
```

Or

Include your API key in the configuration file:

```toml
[llamacpp]
url = "http://localhost:8080/v1/chat/completions"
api_key = "Your API Key here"
```
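
To check that Tenere can reach the endpoint, a quick smoke test with `curl` (a sketch; the `Authorization` header is only needed when the server was started with an API key):

```shell
# Hypothetical sanity check of the configured endpoint and API key
curl http://localhost:8080/v1/chat/completions \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer Your API Key here" \
    -d '{"messages": [{"role": "user", "content": "Hello"}]}'
```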

<br>

## ⌨️ Key bindings
@@ -248,3 +277,11 @@ There are 3 modes like vim: `Normal`, `Visual` and `Insert`.
## ⚖️ License

AGPLv3

8 changes: 4 additions & 4 deletions src/chatgpt.rs
@@ -8,7 +8,7 @@ use regex::Regex;
use tokio::sync::mpsc::UnboundedSender;

use crate::config::ChatGPTConfig;
use crate::llm::{LLMAnswer, LLM};
use crate::llm::{LLMAnswer, LLMRole, LLM};
use reqwest::header::HeaderMap;
use serde_json::{json, Value};
use std;
@@ -56,10 +56,10 @@ impl LLM for ChatGPT {
self.messages = Vec::new();
}

fn append_chat_msg(&mut self, chat: String) {
fn append_chat_msg(&mut self, msg: String, role: LLMRole) {
let mut conv: HashMap<String, String> = HashMap::new();
conv.insert("role".to_string(), "user".to_string());
conv.insert("content".to_string(), chat);
conv.insert("role".to_string(), role.to_string());
conv.insert("content".to_string(), msg);
self.messages.push(conv);
}

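For context, `append_chat_msg` now takes an `LLMRole` alongside the message, and `role.to_string()` is what ends up in the JSON payload. The enum itself is not shown in this diff; a minimal sketch of what it could look like, assuming the newly added `strum`/`strum_macros` dependencies supply the `Display` impl:

```rust
// Hypothetical sketch of the LLMRole enum used above (not part of this diff).
// strum's Display derive makes role.to_string() yield "user", "assistant", etc.
use strum_macros::Display;

#[derive(Display, Clone, Copy, Debug)]
#[strum(serialize_all = "lowercase")]
pub enum LLMRole {
    ASSISTANT,
    SYSTEM,
    USER,
}
```
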
19 changes: 18 additions & 1 deletion src/config.rs
@@ -13,10 +13,12 @@ pub struct Config {
pub key_bindings: KeyBindings,

#[serde(default = "default_llm_backend")]
pub model: LLMBackend,
pub llm: LLMBackend,

#[serde(default)]
pub chatgpt: ChatGPTConfig,

pub llamacpp: Option<LLamacppConfig>,
}

pub fn default_archive_file_name() -> String {
@@ -27,6 +29,7 @@ pub fn default_llm_backend() -> LLMBackend {
LLMBackend::ChatGPT
}

// ChatGPT
#[derive(Deserialize, Debug, Clone)]
pub struct ChatGPTConfig {
pub openai_api_key: Option<String>,
@@ -58,6 +61,14 @@ impl ChatGPTConfig {
}
}

// LLamacpp

#[derive(Deserialize, Debug, Clone)]
pub struct LLamacppConfig {
pub url: String,
pub api_key: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct KeyBindings {
#[serde(default = "KeyBindings::default_show_help")]
@@ -119,6 +130,12 @@ impl Config {

let config = std::fs::read_to_string(conf_path).unwrap_or_default();
let app_config: Config = toml::from_str(&config).unwrap();

if app_config.llm == LLMBackend::LLamacpp && app_config.llamacpp.is_none() {
eprintln!("Config for LLamacpp is not provided");
std::process::exit(1)
}

app_config
}
}
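
The commit message also mentions reading the API key from an environment variable. That resolution is not visible in the hunks above; a self-contained sketch of the assumed precedence (environment variable over the config file), mirroring how `ChatGPTConfig` handles `OPENAI_API_KEY`:

```rust
// Hypothetical helper (not shown in this diff): prefer LLAMACPP_API_KEY from the
// environment and fall back to the api_key value from the [llamacpp] config table.
fn resolve_llamacpp_api_key(config_api_key: Option<String>) -> Option<String> {
    std::env::var("LLAMACPP_API_KEY").ok().or(config_api_key)
}
```
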
6 changes: 3 additions & 3 deletions src/handler.rs
@@ -1,4 +1,4 @@
use crate::llm::LLMAnswer;
use crate::llm::{LLMAnswer, LLMRole};
use crate::{chat::Chat, prompt::Mode};

use crate::{
@@ -20,7 +20,7 @@ use tokio::sync::mpsc::UnboundedSender;
pub async fn handle_key_events(
key_event: KeyEvent,
app: &mut App<'_>,
llm: Arc<Mutex<impl LLM + 'static>>,
llm: Arc<Mutex<Box<dyn LLM + 'static>>>,
sender: UnboundedSender<Event>,
) -> AppResult<()> {
match key_event.code {
@@ -256,7 +256,7 @@ pub async fn handle_key_events(
let llm = llm.clone();
{
let mut llm = llm.lock().await;
llm.append_chat_msg(user_input.into());
llm.append_chat_msg(user_input.into(), LLMRole::USER);
}

app.spinner.active = true;
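The change from `impl LLM` to `Box<dyn LLM>` in `handle_key_events` is what allows the backend to be chosen at runtime from the config rather than being fixed at compile time. A small, self-contained illustration of that pattern (types and names here are illustrative, not tenere's actual code):

```rust
// Standalone illustration of trait-object dispatch; not tenere's real types.
trait LLM {
    fn name(&self) -> &'static str;
}

struct ChatGPT;
struct LLamacpp;

impl LLM for ChatGPT {
    fn name(&self) -> &'static str {
        "chatgpt"
    }
}

impl LLM for LLamacpp {
    fn name(&self) -> &'static str {
        "llamacpp"
    }
}

fn main() {
    // In tenere this choice would come from the `llm` field of the config.
    let backend = "llamacpp";
    let llm: Box<dyn LLM> = match backend {
        "llamacpp" => Box::new(LLamacpp),
        _ => Box::new(ChatGPT),
    };
    println!("using {}", llm.name());
}
```
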
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -29,3 +29,5 @@ pub mod help;
pub mod history;

pub mod chat;

pub mod llamacpp;
