diff --git a/Makefile b/Makefile
index 068f6ed028460b..b6c2b0ca000f67 100644
--- a/Makefile
+++ b/Makefile
@@ -668,9 +668,11 @@ clean:
 # Helper function that replaces .c, .cpp, and .cu file endings with .o:
 GET_OBJ_FILE = $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1))))
 
-main: examples/main/main.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
+main: examples/main/main.cpp examples/llava/clip.h examples/llava/clip.cpp examples/llava/llava.h examples/llava/llava.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
+	$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='
 	@echo
diff --git a/build.zig b/build.zig
index c0af454dc9e922..29b315ee781629 100644
--- a/build.zig
+++ b/build.zig
@@ -125,7 +125,7 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo, sampling, console, grammar_parser, clip, llava });
     _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
     _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
     _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, buildinfo });
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index ef9e4ba7a6b5aa..ee46398e254556 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -15,6 +15,9 @@
 #include "ggml-metal.h"
 #endif
 
+#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#endif
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 
diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt
index d532980b76da83..f078adecebc447 100644
--- a/examples/main/CMakeLists.txt
+++ b/examples/main/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(TARGET main)
 add_executable(${TARGET} main.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index c3124e8ff134de..97ad173e69b7d3 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -1,5 +1,7 @@
 #include "common.h"
+#include "../llava/clip.h"
+#include "../llava/llava.h"
 
 #include "console.h"
 #include "llama.h"
@@ -191,6 +193,8 @@ int main(int argc, char ** argv) {
 
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    clip_ctx * ctx_clip = nullptr;
+    llava_image_embed * image_embed = nullptr;
     g_model = &model;
     g_ctx = &ctx;
@@ -207,6 +211,27 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!params.image.empty() && params.mmproj.empty()) {
+        LOG_TEE("%s: error: image specified without mmproj\n", __func__);
+        return 1;
+    }
+
+    if (!params.mmproj.empty()) {
+        ctx_clip = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
+        if (!ctx_clip) {
+            LOG_TEE("%s: error: failed to load mmproj (CLIP)\n", __func__);
+            return 1;
+        }
+
+        if (!params.image.empty()) {
+            image_embed = llava_image_embed_make_with_filename(ctx_clip, params.n_threads, params.image.c_str());
+            if (!image_embed) {
+                LOG_TEE("%s: error: failed to load image\n", __func__);
+                return 1;
+            }
+        }
+    }
+
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
     LOG("n_ctx: %d\n", n_ctx);
@@ -249,13 +274,22 @@ int main(int argc, char ** argv) {
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
+    int embd_img_pos = -1;
 
     if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        const auto epos = params.prompt.find("<image>"); // locate the "<image>" placeholder
+        if (epos != std::string::npos && image_embed) {
+            embd_inp = ::llama_tokenize(ctx, params.prompt.substr(0, epos), add_bos, true);
+            embd_img_pos = embd_inp.size();
+            auto end = ::llama_tokenize(ctx, params.prompt.substr(epos + 7), false, true); // 7 == strlen("<image>")
+            embd_inp.insert(embd_inp.end(), end.begin(), end.end());
+        } else {
+            embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        }
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -450,9 +484,13 @@ int main(int argc, char ** argv) {
             }
         }
     }
+
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    // Adjust n_keep so the embedded image positions are kept as well
+    if (params.n_keep > add_bos && embd_img_pos >= 0)
+        params.n_keep += image_embed->n_image_pos;
 
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
@@ -650,26 +688,42 @@ int main(int argc, char ** argv) {
             }
         }
 
-        for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
-            int n_eval = (int) embd.size() - i;
-            if (n_eval > params.n_batch) {
-                n_eval = params.n_batch;
-            }
+        // Decode embd[start, end) in n_batch-sized chunks; end == -1 means up to embd.size()
+        auto decode_tokens = [&](int start, int end) -> bool {
+            if (end == -1)
+                end = (int) embd.size();
+            for (int i = start; i < end; i += params.n_batch) {
+                int n_eval = end - i;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
 
-            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
-                LOG_TEE("%s : failed to eval\n", __func__);
-                return 1;
-            }
+                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
+                    LOG_TEE("%s : failed to eval\n", __func__);
+                    return false;
+                }
 
-            n_past += n_eval;
+                n_past += n_eval;
 
-            LOG("n_past = %d\n", n_past);
-            // Display total tokens alongside total time
-            if (params.n_print > 0 && n_past % params.n_print == 0) {
-                LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                LOG("n_past = %d\n", n_past);
+                // Display total tokens alongside total time
+                if (params.n_print > 0 && n_past % params.n_print == 0) {
+                    LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                }
             }
+            return true;
+        };
+
+        if (embd_img_pos >= 0) {
+            // Decode the prompt prefix, inject the image embedding, then decode the rest
+            if (!decode_tokens(0, embd_img_pos)) { return 1; }
+            if (!llava_eval_image_embed(ctx, image_embed, params.n_batch, &n_past)) { return 1; }
+            if (!decode_tokens(embd_img_pos, -1)) { return 1; }
+            embd_img_pos = -1;
+        } else {
+            if (!decode_tokens(0, (int) embd.size())) { return 1; }
         }
 
         if (!embd.empty() && !path_session.empty()) {
@@ -923,6 +977,8 @@ int main(int argc, char ** argv) {
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
     if (ctx_guidance) { llama_free(ctx_guidance); }
+    if (image_embed) { llava_image_embed_free(image_embed); }
+    if (ctx_clip) { clip_free(ctx_clip); }
     llama_free(ctx);
     llama_free_model(model);
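
For reviewers, the image path that this patch wires into `main` reduces to the following standalone sketch. It uses only the clip/llava calls that appear in the diff (`clip_model_load`, `llava_image_embed_make_with_filename`, `llava_eval_image_embed`, `llava_image_embed_free`, `clip_free`); the helper name `eval_image`, the file paths, and the surrounding setup are illustrative assumptions, not part of the patch.

```cpp
#include "../llava/clip.h"
#include "../llava/llava.h"
#include "llama.h"

// Sketch only: encode one image and feed its embedding into an existing
// llama_context, advancing *n_past past the image positions.
static bool eval_image(llama_context * ctx, int n_threads, int n_batch, int * n_past) {
    // Load the multimodal projector (the --mmproj GGUF in the patch)
    clip_ctx * ctx_clip = clip_model_load("mmproj-model-f16.gguf", /*verbosity=*/ 1);
    if (!ctx_clip) {
        return false;
    }

    // Encode the image file into a sequence of embedding vectors
    llava_image_embed * embed =
        llava_image_embed_make_with_filename(ctx_clip, n_threads, "input.jpg");
    if (!embed) {
        clip_free(ctx_clip);
        return false;
    }

    // Decode the embeddings in n_batch-sized chunks; *n_past advances by
    // embed->n_image_pos, as if that many tokens had been evaluated
    const bool ok = llava_eval_image_embed(ctx, embed, n_batch, n_past);

    llava_image_embed_free(embed);
    clip_free(ctx_clip);
    return ok;
}
```

In the patched `main`, this step sits between decoding the prompt prefix (everything before the `<image>` tag) and the suffix, so an invocation like `./main -m model.gguf --mmproj mmproj.gguf --image input.jpg -p "USER: <image>\nDescribe the image.\nASSISTANT:"` decodes prefix tokens, then the image embedding, then suffix tokens (flag names assume the `--mmproj`/`--image` options already parsed by common for the llava example).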