Allow new completions endpoint to reuse slots
The llamafiler server also supports GPU now.
jart committed Nov 2, 2024
1 parent ee4b51a commit 24a4b87
Showing 22 changed files with 429 additions and 157 deletions.
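
By way of illustration, an invocation using the new flags might look like the following (the -m/--model flag and the model filename are assumed here and are not part of this commit):

    llamafiler -m model.gguf --ctx-size 8k --slots 4

That would maintain four completion slots, each with an 8192-token context, and per the commit message GPU support is no longer forcibly disabled at startup.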
6 changes: 3 additions & 3 deletions llamafile/chatbot.cpp
@@ -205,7 +205,7 @@ static void on_context(const std::vector<std::string> &args) {
}

static void on_clear(const std::vector<std::string> &args) {
- llama_kv_cache_seq_rm(g_ctx, 0, g_system_prompt_tokens, tokens_used() - g_system_prompt_tokens);
+ llama_kv_cache_seq_rm(g_ctx, 0, g_system_prompt_tokens, -1);
g_history.resize(g_system_prompt_tokens);
g_stack.clear();
fix_stacks();
@@ -228,7 +228,7 @@ static void on_pop(const std::vector<std::string> &args) {
}
printf(BOLD "%12d" RESET " restored " FAINT "(%s)" RESET "\n", g_stack.back(),
describe_position(g_stack.back()).c_str());
- llama_kv_cache_seq_rm(g_ctx, 0, g_stack.back(), tokens_used() - g_stack.back());
+ llama_kv_cache_seq_rm(g_ctx, 0, g_stack.back(), -1);
g_history.resize(g_stack.back());
g_stack.pop_back();
fix_stacks();
@@ -244,7 +244,7 @@ static void on_undo(const std::vector<std::string> &args) {
}
printf(FAINT "restoring conversation to: %s" RESET "\n",
describe_position(g_undo.back()).c_str());
- llama_kv_cache_seq_rm(g_ctx, 0, g_undo.back(), tokens_used() - g_undo.back());
+ llama_kv_cache_seq_rm(g_ctx, 0, g_undo.back(), -1);
g_history.resize(g_undo.back());
g_undo.pop_back();
fix_stacks();
12 changes: 11 additions & 1 deletion llamafile/flags.cpp
@@ -237,9 +237,19 @@ void llamafile_get_flags(int argc, char **argv) {
// model flags

if (!strcmp(flag, "-c") || !strcmp(flag, "--ctx-size")) {
+ char *ep;
if (i == argc)
missing("--ctx-size");
- FLAG_ctx_size = atoi(argv[i++]);
+ FLAG_ctx_size = strtol(argv[i++], &ep, 10);
+ if (*ep == 'k')
+ FLAG_ctx_size *= 1024;
continue;
}

if (!strcmp(flag, "-s") || !strcmp(flag, "--slots")) {
if (i == argc)
missing("--slots");
FLAG_slots = atoi(argv[i++]);
continue;
}

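As the hunk above shows, --ctx-size now accepts a trailing lowercase k that multiplies the value by 1024, so -c 8k is equivalent to -c 8192. A minimal standalone sketch of that parsing, using an illustrative parse_ctx_size helper rather than the actual flags.cpp code:

    #include <stdio.h>
    #include <stdlib.h>

    // Parses a token count such as "8192" or "8k"; a trailing 'k' multiplies by 1024.
    static long parse_ctx_size(const char *s) {
        char *ep;
        long n = strtol(s, &ep, 10);
        if (*ep == 'k')
            n *= 1024;
        return n;
    }

    int main() {
        printf("%ld\n", parse_ctx_size("8k"));   // prints 8192
        printf("%ld\n", parse_ctx_size("128k")); // prints 131072
        return 0;
    }

Note that, as in the diff, only a lowercase 'k' suffix is recognized.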
2 changes: 1 addition & 1 deletion llamafile/server/cleanup.cpp
@@ -30,7 +30,7 @@ cleanup_float_vector(void* arg)
void
cleanup_token_vector(void* arg)
{
- delete (std::vector<llama_token>*)arg;
+ delete (std::vector<int>*)arg;
}

void
12 changes: 8 additions & 4 deletions llamafile/server/client.cpp
@@ -32,6 +32,7 @@
#include "llamafile/version.h"

#include "log.h"
#include "server.h"
#include "time.h"
#include "tokenbucket.h"
#include "worker.h"
@@ -54,8 +55,11 @@ on_http_cancel(Client* client)

static ThreadLocal<Client> g_http_cancel(on_http_cancel);

- Client::Client()
- : cleanups(nullptr), ibuf(FLAG_http_ibuf_size), obuf(FLAG_http_obuf_size)
+ Client::Client(llama_model* model)
+ : model_(model)
+ , cleanups(nullptr)
+ , ibuf(FLAG_http_ibuf_size)
+ , obuf(FLAG_http_obuf_size)
{
InitHttpMessage(&msg, 0);
url.params.p = nullptr;
@@ -213,11 +217,11 @@ Client::transport()
}

if (get_header("X-Priority") == "batch") {
- worker->deprioritize();
+ worker_->deprioritize();
} else if (!effective_ip_trusted) {
if (tokenbucket_acquire(client_ip) > FLAG_token_burst) {
SLOG("deprioritizing");
- worker->deprioritize();
+ worker_->deprioritize();
}
}

8 changes: 6 additions & 2 deletions llamafile/server/client.h
@@ -36,7 +36,9 @@
#define HeaderEqualCase(H, S) \
SlicesEqualCase(S, strlen(S), HeaderData(H), HeaderLength(H))

struct Slot;
struct Worker;
struct llama_model;
struct TokenizeParams;
struct EmbeddingParams;
struct V1ChatCompletionParams;
@@ -58,7 +60,9 @@ struct Client
bool close_connection = false;
bool should_send_error_if_canceled;
size_t unread = 0;
- Worker* worker;
+ Worker* worker_; // borrowed
+ Slot* slot_ = nullptr; // owned or null
+ llama_model* model_; // borrowed
timespec message_started;
HttpMessage msg;
Url url = {};
@@ -69,7 +73,7 @@
Buffer ibuf;
Buffer obuf;

- Client();
+ explicit Client(llama_model*);

void run();
int close();
9 changes: 4 additions & 5 deletions llamafile/server/embedding.cpp
@@ -28,7 +28,6 @@
#include "fastjson.h"
#include "json.h"
#include "log.h"
#include "model.h"
#include "utils.h"

struct EmbeddingParams
@@ -161,7 +160,7 @@ Client::embedding()
// turn text into tokens
auto toks = new std::vector<llama_token>(params->prompt.size() + 16);
defer_cleanup(cleanup_token_vector, toks);
- int count = llama_tokenize(g_model,
+ int count = llama_tokenize(model_,
params->prompt.data(),
params->prompt.size(),
&(*toks)[0],
@@ -178,7 +177,7 @@ Client::embedding()
return send_error(400, "completely empty prompt disallowed");

// truncate if exceeds model context size
- const int n_ctx_train = llama_n_ctx_train(g_model);
+ const int n_ctx_train = llama_n_ctx_train(model_);
if (count > n_ctx_train)
count = n_ctx_train;

@@ -200,15 +199,15 @@
cparams.type_k = GGML_TYPE_F16;
cparams.type_v = GGML_TYPE_F16;
cparams.flash_attn = FLAG_flash_attn;
- llama_context* ctx = llama_new_context_with_model(g_model, cparams);
+ llama_context* ctx = llama_new_context_with_model(model_, cparams);
if (!ctx) {
SLOG("llama_new_context_with_model failed");
return send_error(500);
}
defer_cleanup(cleanup_llama_context, ctx);

// initialize batch
- const int n_embd = llama_n_embd(g_model);
+ const int n_embd = llama_n_embd(model_);
llama_batch* batch = new llama_batch;
*batch = llama_batch_init(count, 0, 1);
defer_cleanup(cleanup_llama_batch, batch);
18 changes: 18 additions & 0 deletions llamafile/server/main.1
@@ -33,6 +33,24 @@ Specifies the local [HOST:]PORT on which the HTTP server should listen.
By default this is 0.0.0.0:8080 which means llamafiler will bind to port
8080 on every locally available IPv4 network interface. This option may
currently only be specified once.
.It Fl c Ar TOKENS , Fl Fl ctx-size Ar TOKENS
Specifies context size. This specifies how long a completion can get
before it runs out of space. It defaults to 8k which means 8192 tokens.
Many models support a larger context size, like 128k, but that'll need
much more RAM or VRAM per slot. If this value is larger than the trained
context size of the model, it'll be tuned down to the maximum. If this
value is 0 or negative, the maximum number of tokens will be used.
.It Fl s Ar COUNT , Fl Fl slots Ar COUNT
Specifies how many slots to maintain. This defaults to 1. Slots are used
by chat completions requests. When such a request comes in, the client
needs to take control of a slot. When the completion is finished, the
slot is relinquished back to the server. HTTP clients will wait for a
slot to be relinquished if none are available. Tuning this parameter to
nicely fit available RAM or VRAM can help you manage your server
resources, and control how much completion parallelism can happen.
Please note that
.Fl Fl ctx-size
has a strong influence on how many slots can be created.
.It Fl Fl url-prefix Ar URLPREFIX
Specifies a URL prefix (subdirectory) under which the HTTP server will
make the API accessible, e.g. /lamafiler. Useful when running llamafiler
21 changes: 21 additions & 0 deletions llamafile/server/main.1.asc
@@ -35,6 +35,27 @@
will bind to port 8080 on every locally available IPv4 network
interface. This option may currently only be specified once.

-c TOKENS, --ctx-size TOKENS
Specifies context size. This specifies how long a completion
can get before it runs out of space. It defaults to 8k which
means 8192 tokens. Many models support a larger context size,
like 128k, but that'll need much more RAM or VRAM per slot. If
this value is larger than the trained context size of the
model, it'll be tuned down to the maximum. If this value is 0
or negative, the maximum number of tokens will be used.

-s COUNT, --slots COUNT
Specifies how many slots to maintain. This defaults to 1. Slots
are used by chat completions requests. When such a request
comes in, the client needs to take control of a slot. When the
completion is finished, the slot is relinquished back to the
server. HTTP clients will wait for a slot to be relinquished if
none are available. Tuning this parameter to nicely fit avail‐
able RAM or VRAM can help you manage your server resources, and
control how much completion parallelism can happen. Please
note that --ctx-size has a strong influence on how many slots
can be created.

--url-prefix URLPREFIX
Specifies a URL prefix (subdirectory) under which the HTTP
server will make the API accessible, e.g. /lamafiler. Useful
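The slot discipline documented above (a fixed pool created at startup, each completion request taking control of a slot and relinquishing it when finished, other clients waiting until one frees up) is essentially a counted resource pool. Here is a minimal sketch of that idea, with hypothetical SlotPool/take/give names rather than the actual Slots class from llamafile/server/slots.h:

    #include <condition_variable>
    #include <mutex>
    #include <vector>

    struct Slot { /* would hold a llama_context, its KV cache, etc. */ };

    class SlotPool {
      public:
        explicit SlotPool(int count) {
            for (int i = 0; i < count; ++i)
                free_.push_back(new Slot);
        }

        ~SlotPool() {
            for (Slot* s : free_) // assumes every slot was given back
                delete s;
        }

        // Blocks until a slot is available, then lends it to the caller.
        Slot* take() {
            std::unique_lock<std::mutex> lock(mu_);
            cv_.wait(lock, [&] { return !free_.empty(); });
            Slot* s = free_.back();
            free_.pop_back();
            return s;
        }

        // Relinquishes a slot back to the pool and wakes one waiter.
        void give(Slot* s) {
            std::lock_guard<std::mutex> lock(mu_);
            free_.push_back(s);
            cv_.notify_one();
        }

      private:
        std::mutex mu_;
        std::condition_variable cv_;
        std::vector<Slot*> free_;
    };

A completions handler would call take() before prefilling the prompt and give() once the response finishes streaming, which is how --slots bounds completion parallelism. Since each slot presumably carries its own --ctx-size worth of KV cache, memory scales roughly as slots times ctx-size (for example, -s 4 -c 8k is about 4 × 8192 = 32768 tokens of KV cache), which is the interaction the paragraph above warns about.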
22 changes: 16 additions & 6 deletions llamafile/server/main.cpp
@@ -24,15 +24,14 @@
#include "llamafile/version.h"

#include "log.h"
#include "model.h"
#include "server.h"
#include "signals.h"
#include "slots.h"
#include "time.h"
#include "tokenbucket.h"
#include "utils.h"

Server* g_server;
- llama_model* g_model;
std::string g_url_prefix;

int
@@ -42,7 +41,6 @@ main(int argc, char* argv[])
mallopt(M_GRANULARITY, 2 * 1024 * 1024);
mallopt(M_MMAP_THRESHOLD, 16 * 1024 * 1024);
mallopt(M_TRIM_THRESHOLD, 128 * 1024 * 1024);
- FLAG_gpu = LLAMAFILE_GPU_DISABLE;
llamafile_check_cpu();
ShowCrashReports();

@@ -88,15 +86,26 @@ main(int argc, char* argv[])
.use_mlock = false,
.check_tensors = false,
};
- g_model = llama_load_model_from_file(FLAG_model, mparams);
+ llama_model* model = llama_load_model_from_file(FLAG_model, mparams);
if (!model) {
fprintf(stderr, "%s: failed to load model\n", FLAG_model);
exit(1);
}

// create slots
Slots* slots = new Slots(model);
if (!slots->start(FLAG_slots)) {
SLOG("no slots could be created");
exit(1);
}

// create server
if (FLAG_workers <= 0)
FLAG_workers = __get_cpu_count() + 4;
if (FLAG_workers <= 0)
FLAG_workers = 16;
set_thread_name("server");
- g_server = new Server(create_listening_socket(FLAG_listen));
+ g_server = new Server(create_listening_socket(FLAG_listen), slots, model);
for (int i = 0; i < FLAG_workers; ++i)
npassert(!g_server->spawn());

@@ -122,7 +131,8 @@
g_server->shutdown();
g_server->close();
delete g_server;
- llama_free_model(g_model);
+ delete slots;
+ llama_free_model(model);
tokenbucket_destroy();
time_destroy();
SLOG("exit");
8 changes: 5 additions & 3 deletions llamafile/server/server.cpp
@@ -16,6 +16,7 @@
// limitations under the License.

#include "server.h"
#include "slots.h"

#include <assert.h>
#include <netinet/in.h>
@@ -32,7 +33,8 @@
#include "server.h"
#include "worker.h"

- Server::Server(int fd) : fd(fd)
+ Server::Server(int fd, Slots* slots, llama_model* model)
+ : fd(fd), slots_(slots), model_(model)
{
}

@@ -102,12 +104,12 @@ Server::spawn()
errno_t err;
Worker* worker;
pthread_attr_t attr;
- worker = new Worker(this);
+ worker = new Worker(this, model_);
pthread_attr_init(&attr);
pthread_attr_setguardsize(&attr, sysconf(_SC_PAGESIZE));
pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED);
pthread_attr_setsigaltstacksize_np(&attr, sysconf(_SC_MINSIGSTKSZ) + 16384);
- if ((err = pthread_create(&worker->th, &attr, worker_thread, worker)))
+ if ((err = pthread_create(&worker->th_, &attr, worker_thread, worker)))
delete worker;
pthread_attr_destroy(&attr);
return err;
7 changes: 6 additions & 1 deletion llamafile/server/server.h
@@ -21,9 +21,12 @@
#include <pthread.h>
#include <string>

struct Slots;
struct llama_model;

struct Server
{
- Server(int);
+ Server(int, Slots*, llama_model*);
~Server();

int accept(unsigned*);
@@ -38,6 +41,8 @@ struct Server
void wait();

int fd;
Slots* slots_;
llama_model* model_;
Dll* idle_workers = nullptr;
Dll* active_workers = nullptr;
pthread_cond_t cond_ = PTHREAD_COND_INITIALIZER;
(Diffs for the remaining 11 changed files were not loaded.)