examples/main: basic multimodal support ported from llava-cli
The <image> keyword in the prompt is replaced with the image embedding.
Nekotekina committed Feb 27, 2024
1 parent d7747aa commit 5dc5a79
Showing 5 changed files with 75 additions and 20 deletions.
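
A usage sketch (not part of this diff): the patch reads params.mmproj and params.image, which the common argument parser already fills from the llava-style --mmproj and --image flags, so a multimodal run of main could look like the command below. The model paths and the prompt wording are illustrative assumptions, not values taken from the commit.

./main -m models/llava-v1.5-7b/ggml-model-q4_k.gguf \
    --mmproj models/llava-v1.5-7b/mmproj-model-f16.gguf \
    --image ./photo.jpg \
    -p "USER: <image> Describe the image in detail. ASSISTANT:"

The <image> marker in the prompt is the point where the image embedding is spliced into the evaluated sequence.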
6 changes: 4 additions & 2 deletions Makefile
@@ -668,9 +668,11 @@ clean:
# Helper function that replaces .c, .cpp, and .cu file endings with .o:
GET_OBJ_FILE = $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1))))

main: examples/main/main.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
main: examples/main/main.cpp examples/llava/clip.h examples/llava/clip.cpp examples/llava/llava.h examples/llava/llava.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='
@echo
2 changes: 1 addition & 1 deletion build.zig
@@ -125,7 +125,7 @@ pub fn build(b: *std.build.Builder) !void {
const clip = make.obj("clip", "examples/llava/clip.cpp");
const llava = make.obj("llava", "examples/llava/llava.cpp");

_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, sampling, console, grammar_parser });
_ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, sampling, console, grammar_parser, clip, llava });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
3 changes: 3 additions & 0 deletions examples/llava/clip.cpp
@@ -15,6 +15,9 @@
#include "ggml-metal.h"
#endif

#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
#pragma GCC diagnostic ignored "-Wcast-qual"
#endif
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

2 changes: 1 addition & 1 deletion examples/main/CMakeLists.txt
@@ -1,5 +1,5 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
82 changes: 66 additions & 16 deletions examples/main/main.cpp
@@ -1,5 +1,7 @@
#include "common.h"

#include "../llava/clip.h"
#include "../llava/llava.h"
#include "console.h"
#include "llama.h"

@@ -191,6 +193,8 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
clip_ctx * ctx_clip = nullptr;
llava_image_embed * image_embed = nullptr;
g_model = &model;
g_ctx = &ctx;

@@ -207,6 +211,27 @@ int main(int argc, char ** argv) {
return 1;
}

if (!params.image.empty() && params.mmproj.empty()) {
LOG_TEE("%s: error: image specified without mmproj\n", __func__);
return 1;
}

if (!params.mmproj.empty()) {
ctx_clip = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
if (!ctx_clip) {
LOG_TEE("%s: error: failed to load mmproj (CLIP)\n", __func__);
return 1;
}

if (!params.image.empty()) {
image_embed = llava_image_embed_make_with_filename(ctx_clip, params.n_threads, params.image.c_str());
if (!image_embed) {
LOG_TEE("%s: error: failed to load image\n", __func__);
return 1;
}
}
}

const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
@@ -249,13 +274,22 @@ int main(int argc, char ** argv) {
LOG("add_bos: %d\n", add_bos);

std::vector<llama_token> embd_inp;
int embd_img_pos = -1;

if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
if (params.chatml) {
params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
}
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
const auto epos = params.prompt.find("<image>");
if (epos + 1 && image_embed) {
embd_inp = ::llama_tokenize(ctx, params.prompt.substr(0, epos), add_bos, true);
embd_img_pos = embd_inp.size();
auto end = ::llama_tokenize(ctx, params.prompt.substr(epos + 7), false, true);
embd_inp.insert(embd_inp.end(), end.begin(), end.end());
} else {
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
}
} else {
LOG("use session tokens\n");
embd_inp = session_tokens;
@@ -450,9 +484,13 @@ int main(int argc, char ** argv) {
}
}
}

LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
// Fix n_keep for keeping embedded image
if (params.n_keep > add_bos && embd_img_pos >= 0)
params.n_keep += image_embed->n_image_pos;

// group-attention state
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
@@ -650,26 +688,36 @@ int main(int argc, char ** argv) {
}
}

for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
auto decode_tokens = [&](int start, int count) -> void {
if (count == -1)
count = embd.size() - start;
for (int i = start; i < count; i += params.n_batch) {
int n_eval = count - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}

LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0));

n_past += n_eval;
n_past += n_eval;

LOG("n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.n_print > 0 && n_past % params.n_print == 0) {
LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
LOG("n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.n_print > 0 && n_past % params.n_print == 0) {
LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
}
}
};

if (embd_img_pos >= 0) {
decode_tokens(0, embd_img_pos);
llava_eval_image_embed(ctx, image_embed, params.n_batch, &n_past);
decode_tokens(embd_img_pos, -1);
embd_img_pos = -1;
} else {
decode_tokens(0, embd.size());
}

if (!embd.empty() && !path_session.empty()) {
@@ -923,6 +971,8 @@ int main(int argc, char ** argv) {
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

if (ctx_guidance) { llama_free(ctx_guidance); }
if (image_embed) llava_image_embed_free(image_embed);
if (ctx_clip) clip_free(ctx_clip);
llama_free(ctx);
llama_free_model(model);

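
A note on the prompt-splitting logic added above: std::string::find returns std::string::npos when "<image>" is absent, and because npos is the largest size_t value, epos + 1 wraps around to 0, so the check if (epos + 1 && image_embed) behaves like epos != std::string::npos && image_embed. The minimal, dependency-free C++ sketch below mirrors that split under hypothetical names (split_on_image_marker and the demo strings are not from the commit); in main.cpp the two halves are tokenized separately and llava_eval_image_embed is evaluated between them.

#include <iostream>
#include <string>

// Split a prompt around the first "<image>" marker, mirroring the logic
// added to examples/main/main.cpp. Returns false when the marker is absent,
// in which case the caller tokenizes the whole prompt unchanged.
static bool split_on_image_marker(const std::string & prompt,
                                  std::string & before,
                                  std::string & after) {
    static const std::string marker = "<image>";
    const auto epos = prompt.find(marker);
    if (epos == std::string::npos) {
        return false;
    }
    before = prompt.substr(0, epos);              // tokenized with add_bos
    after  = prompt.substr(epos + marker.size()); // tokenized without BOS
    return true;
}

int main() {
    std::string before, after;
    if (split_on_image_marker("USER: <image> Describe this. ASSISTANT:", before, after)) {
        // In main.cpp the image embedding is decoded between the two halves:
        //   decode_tokens(0, embd_img_pos);
        //   llava_eval_image_embed(ctx, image_embed, params.n_batch, &n_past);
        //   decode_tokens(embd_img_pos, -1);
        std::cout << "before: [" << before << "]\n";
        std::cout << "after:  [" << after  << "]\n";
    }
    return 0;
}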
