-
Notifications
You must be signed in to change notification settings - Fork 10.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This new mode works by first loading the model then listening for TCP connections on a port. When a connection is received, arguments will be parsed using a simple protocol: - First the number of arguments will be read followed by a newline character. - Then each argument will be read, separated by the 0 byte. - With this we build an argument vector, similar to what is passed to the program entry point. We pass this to gpt_params_parse. Finally `llama_main` will be executed with the input/output streams connected to the socket. Signed-off-by: Thiago Padilha <thiago@padilha.cc>
- Loading branch information
Showing
9 changed files
with
337 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/usr/bin/env bash
#
# Minimal interactive client for the TCP server mode.
#
# Connects to 127.0.0.1:$PORT, sends the argument vector using the server's
# simple protocol, then relays the terminal to/from the socket.
#
# Environment overrides: PORT, PROMPT, RPROMPT, N_PREDICT, REPEAT_PENALTY.

PORT=${PORT:-8080}
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
User:"}"
RPROMPT="${RPROMPT:-"User:"}"
N_PREDICT="${N_PREDICT:-"4096"}"
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"

# Open connection to the chat server. Bail out right away if it cannot be
# reached instead of carrying on with an invalid file descriptor.
if ! exec 3<>"/dev/tcp/127.0.0.1/${PORT}"; then
    echo "error: could not connect to 127.0.0.1:${PORT}" >&2
    exit 1
fi

# Arguments handed to the server, in the order it should parse them.
args=(
    -n "$N_PREDICT"
    --repeat_penalty "$REPEAT_PENALTY"
    --color
    -i
    -r "$RPROMPT"
    -p "$PROMPT"
)

# Pass the arguments. The protocol is really simple:
# 1. Pass the number of arguments followed by a linefeed
# 2. Pass the arguments, with each being followed by a NUL (0x00) byte
#
# The count is derived from the array so it cannot drift out of sync with the
# list above. '%b' expands backslash escapes inside the values, matching the
# `echo -e` behaviour this script has always had.
{
    printf '%d\n' "${#args[@]}"
    printf '%b\0' "${args[@]}"
} >&3

trap exit TERM

# When we have passed the arguments, start printing socket data to the screen.
# This is done in a background job because we also want to send data when
# running in interactive mode.
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
cat >&3
wait
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env bash
#
# Start ./main in TCP server mode.
#
# Environment overrides:
#   PORT  - port to listen on (default 8080)
#   MODEL - path to the model file (default models/7B/ggml-model-q4_0.bin)

PORT=${PORT:-8080}
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}

# Quote both values so a model path containing spaces or glob characters is
# passed to ./main as a single argument.
./main -l "${PORT}" -m "${MODEL}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
#include "tcp_server.h" | ||
|
||
#include <iostream> | ||
|
||
#include <stdarg.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
#include <stdbool.h> | ||
#include <string.h> | ||
#include <errno.h> | ||
|
||
#include <signal.h> | ||
#include <unistd.h> | ||
#include <sys/wait.h> | ||
|
||
#include <sys/types.h> | ||
#include <sys/socket.h> | ||
#include <netdb.h> | ||
|
||
// Read-only std::istream backed by a POSIX file descriptor.
//
// The stream takes over the descriptor it is given: the destructor close()s
// it, so the caller must not rely on the descriptor afterwards.
class PosixStream : public std::istream {
public:
    // The base class only receives the address of `buf` here; `buf` itself is
    // constructed right after, before any read can happen.
    PosixStream(int fd) : std::istream(&buf), buf(fd) {}
    ~PosixStream() { close(buf.get_fd()); }

private:
    // Stream buffer that refills a fixed-size array with read(2).
    class PosixStreamBuf : public std::streambuf {
    public:
        explicit PosixStreamBuf(int fd) : fd(fd) {}
        int get_fd() const { return fd; }

    protected:
        // Refill the get area when it is exhausted.
        // Returns the next character without consuming it, or eof() on
        // end-of-file or on a read error.
        int_type underflow() override {
            if (gptr() < egptr()) {
                return traits_type::to_int_type(*gptr());
            }

            // A read interrupted by a signal is not end-of-stream: retry it
            // rather than reporting EOF to the consumer.
            ssize_t num_read;
            do {
                num_read = ::read(fd, buffer, BUFFER_SIZE);
            } while (num_read < 0 && errno == EINTR);

            if (num_read <= 0) {
                return traits_type::eof();
            }

            setg(buffer, buffer, buffer + num_read);
            return traits_type::to_int_type(*gptr());
        }

    private:
        static const int BUFFER_SIZE = 1024;
        int fd;
        char buffer[BUFFER_SIZE];
    };

    PosixStreamBuf buf;
};
|
||
// Report a fatal error and terminate the process.
//
// `msg` is a printf-style format string consumed together with the variadic
// arguments. The formatted text plus a trailing newline goes to stderr, then
// the process exits with status 1. This function does not return.
void die(const char *msg, ...)
{
    va_list args;
    va_start(args, msg);
    vfprintf(stderr, msg, args);
    va_end(args);

    fputs("\n", stderr);
    exit(1);
}
|
||
static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) { | ||
bool done = false; | ||
uint8_t *buf = *param_buf; | ||
size_t bufsize = *param_buf_size; | ||
size_t bufpos = 0; | ||
while (!done) { | ||
if (bufpos == bufsize) { | ||
bufsize += 1024; | ||
buf = (uint8_t *)realloc(buf, bufsize); | ||
if (!buf) { | ||
die("failed to allocate memory"); | ||
} | ||
} | ||
|
||
int c = fgetc(instream); | ||
if (c == EOF) { | ||
die("unexpected EOF client socket"); | ||
} | ||
buf[bufpos++] = (uint8_t)c; | ||
if (c == 0) { | ||
// done reading argument | ||
break; | ||
} | ||
} | ||
*param_buf = buf; | ||
*param_buf_size = bufsize; | ||
return strdup((char *)buf); | ||
} | ||
|
||
static int read_arguments(int argc, char **argv, FILE *instream) { | ||
int i = 1; | ||
size_t param_buf_size = 0; | ||
uint8_t *param_buf = nullptr; | ||
|
||
for (i = 1; i < argc; i++) { | ||
argv[i] = read_argument(¶m_buf, ¶m_buf_size, instream); | ||
} | ||
|
||
free(param_buf); | ||
return i; | ||
} | ||
|
||
// Serve a single client connection on `sock_fd`.
//
// Protocol: the client first sends the number of arguments as decimal text
// followed by a newline, then that many arguments, each terminated by a NUL
// byte. The arguments are assembled into an argv-style vector (slot 0 is a
// placeholder for the program name), parsed into `params` with
// gpt_params_parse(), and llama_main() is then run with its input and output
// connected to the socket.
//
// Returns 1 when the request cannot be set up or parsed, otherwise whatever
// llama_main() returns.
static int serve_model(
        gpt_params params,
        gpt_vocab vocab,
        llama_model model,
        int64_t t_load_us,
        int64_t t_main_start_us,
        int sock_fd)
{
    // The count comes from the network: bound it before using it as an
    // allocation size. 1024 is far more than a command line needs.
    const int max_client_args = 1024;

    FILE *instream = fdopen(sock_fd, "r");
    FILE *outstream = fdopen(sock_fd, "w");
    if (!instream || !outstream) {
        fprintf(stderr, "fdopen error: %s\n", strerror(errno));
        return 1;
    }
    // Unbuffered, presumably so stdio does not read ahead and swallow bytes
    // that the PosixStream below reads straight from the descriptor.
    setvbuf(instream, NULL, _IONBF, 0);

    // start by reading the parameter count
    // (the trailing "\n" in the format also skips any whitespace that follows)
    int argc = 0;
    if (fscanf(instream, "%d\n", &argc) != 1) {
        fprintf(outstream, "Error: First line must be argument count\n");
        fflush(outstream);
        return 1;
    }
    if (argc < 0 || argc > max_client_args) {
        fprintf(outstream, "Error: Invalid argument count\n");
        fflush(outstream);
        return 1;
    }

    argc += 1; // add one extra argument to emulate the program command line
    char **argv = static_cast<char **>(malloc(argc * sizeof *argv));
    if (!argv) {
        fprintf(outstream, "Error: Out of memory\n");
        fflush(outstream);
        return 1;
    }
    // NOTE(review): slot 0 stays null; confirm gpt_params_parse never
    // dereferences argv[0] (e.g. when printing usage).
    argv[0] = nullptr;
    if (read_arguments(argc, argv, instream) != argc) {
        fprintf(outstream, "Error: Failed to read arguments\n");
        fflush(outstream);
        free(argv);
        return 1;
    }

    const bool parsed = gpt_params_parse(argc, argv, params);

    // The argument vector is no longer needed, whether parsing worked or not.
    for (int i = 1; i < argc; i++) {
        free(argv[i]);
    }
    free(argv);

    if (!parsed) {
        fprintf(outstream, "Error: Failed to parse parameters\n");
        fflush(outstream);
        return 1;
    }

    // tcp_is close()s sock_fd when it goes out of scope. The FILE streams are
    // not fclose()d here.
    PosixStream tcp_is(sock_fd);

    return llama_main(params, vocab, model, t_load_us, t_main_start_us, tcp_is, outstream, outstream);
}
|
||
// Run the TCP server loop.
//
// Binds a listening socket to 127.0.0.1 on params.listen_port, then accepts
// connections forever. Each connection is handled in a forked child by
// serve_model(); the child returns serve_model()'s result from this function.
// When accept() fails for good, the parent closes the listening socket, sends
// SIGTERM to its process group, waits for the children and returns 0.
//
// Setup failures (address lookup, setsockopt, bind, listen) call die().
int listen_tcp(
        gpt_params params,
        gpt_vocab vocab,
        llama_model model,
        int64_t t_main_start_us,
        int64_t t_load_us) {
    int listen_fd = -1;
    int status;
    int bind_errno = 0;
    struct addrinfo hints;
    struct addrinfo *servinfo, *p;
    int yes = 1;

    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;

    // This should only ever listen on a loopback address. Access from outside
    // should be proxied via nginx or similar software
    status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
    if (status) {
        die("getaddrinfo error: %s", gai_strerror(status));
    }

    // bind to the first addrinfo we can from the getaddrinfo results
    for (p = servinfo; p != NULL; p = p->ai_next) {
        listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
        if (listen_fd == -1) {
            bind_errno = errno;
            perror("server: socket");
            continue;
        }

        if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes)) {
            die("setsockopt error: %s", strerror(errno));
        }

        if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
            break;
        }

        // Remember and report the bind error before close() can overwrite errno.
        bind_errno = errno;
        perror("server: bind");
        close(listen_fd);
    }

    const bool bound = (p != NULL);
    freeaddrinfo(servinfo);

    if (!bound) {
        die("failed to bind: %s", strerror(bind_errno));
    }

    if (listen(listen_fd, 20)) {
        die("listen error: %s", strerror(errno));
    }
    // Don't track child processes, so ignore SIGCHLD to prevent zombies
    signal(SIGCHLD, SIG_IGN);

    for (;;) {
        struct sockaddr_in client_addr = {};
        // accept() expects the size of the address buffer on input.
        socklen_t client_addr_len = sizeof client_addr;

        int sock_fd = accept(listen_fd,
                             (struct sockaddr *)&client_addr,
                             &client_addr_len);
        if (sock_fd < 0) {
            // An interrupted call or a connection aborted by the peer is not
            // a reason to shut the whole server down.
            if (errno == EINTR || errno == ECONNABORTED) {
                continue;
            }
            fprintf(stderr, "accept error: %s\n", strerror(errno));
            break;
        }

        const pid_t child = fork();
        if (child == 0) {
            // close the listen_fd since we won't use it in the child
            close(listen_fd);
            // serve_model takes (t_load_us, t_main_start_us), in that order
            int ret = serve_model(params, vocab, model, t_load_us, t_main_start_us, sock_fd);
            // serve_model may already have closed the socket; closing it
            // again is then a harmless no-op error.
            close(sock_fd);
            return ret;
        }

        if (child < 0) {
            fprintf(stderr, "fork error: %s\n", strerror(errno));
        }
        // close the client since we won't use it in the server
        close(sock_fd);
    }
    close(listen_fd);

    // ignore SIGTERM since we'll send it to the group
    signal(SIGTERM, SIG_IGN);
    // tell children to exit
    kill(0, SIGTERM);
    // wait for children to terminate
    wait(&status);
    return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#pragma once | ||
|
||
#include "utils.h" | ||
#include "llama.h" | ||
|
||
// Run the TCP server: listen on 127.0.0.1 at params.listen_port and fork one | ||
// child process per accepted connection to serve the already-loaded model. | ||
// In a child this returns the result of serving that connection; in the | ||
// parent it returns 0 once the accept loop has ended. | ||
int listen_tcp( | ||
gpt_params params, | ||
gpt_vocab vocab, | ||
llama_model model, | ||
int64_t t_main_start_us, | ||
int64_t t_load_us); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters