From deec6f39e7230d10d5f1b6eac77e11d1c50a2942 Mon Sep 17 00:00:00 2001
From: PAB
Date: Sun, 30 Jul 2023 18:04:07 +0200
Subject: [PATCH] MNT Unit tests tokenizer (#18)

---
 .gitignore               |  2 ++
 CMakeLists.txt           | 10 +++---
 Makefile                 | 17 +++++++--
 bark.cpp                 | 55 +---------------------------
 bark.h                   | 19 +++++++++-
 examples/main.cpp        | 55 ++++++++++++++++++++++++++++
 tests/CMakeLists.txt     |  9 +++++
 tests/test-tokenizer.cpp | 78 ++++++++++++++++++++++++++++++++++++++++
 8 files changed, 182 insertions(+), 63 deletions(-)
 create mode 100644 examples/main.cpp
 create mode 100644 tests/CMakeLists.txt
 create mode 100644 tests/test-tokenizer.cpp

diff --git a/.gitignore b/.gitignore
index 247f7a3..27b390a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,8 @@ build/
 bark
 encodec
+main
+tests/test-tokenizer
 
 *.o
 *.plist
 
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 09ceed0..16c084e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -75,7 +75,7 @@ option(BARK_CUDA_DMMV_F16 "bark: use 16 bit floats for dmmv CU
 option(BARK_CLBLAST "bark: use CLBlast" OFF)
 option(BARK_METAL "bark: use Metal" OFF)
-# option(BARK_BUILD_TESTS "bark: build tests" ${BARK_STANDALONE})
+option(BARK_BUILD_TESTS "bark: build tests" ${BARK_STANDALONE})
 
 #
 # Build info header
 #
@@ -518,7 +518,7 @@ install(
 #
 # programs, examples and tests
 #
-# if (BARK_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
-#     include(CTest)
-#     add_subdirectory(tests)
-# endif ()
+if (BARK_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+    include(CTest)
+    add_subdirectory(tests)
+endif ()
diff --git a/Makefile b/Makefile
index 81454ad..6fcfc16 100644
--- a/Makefile
+++ b/Makefile
@@ -2,8 +2,7 @@ BUILD_TARGETS = bark
 
 # Binaries only useful for tests
-# TEST_TARGETS = tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
-TEST_TARGETS =
+TEST_TARGETS = tests/test-tokenizer
 
 default: $(BUILD_TARGETS)
 
@@ -302,5 +301,17 @@ bark.o: bark.cpp bark.h
 clean:
 	rm -vf *.o *.so *.dll encodec bark
 
-bark: bark.cpp encodec.o ggml.o $(OBJS)
+bark: bark.cpp encodec.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
+main: examples/main.cpp ggml.o bark.o encodec.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
+#
+# Test
+#
+
+tests: $(TEST_TARGETS)
+
+tests/test-tokenizer: tests/test-tokenizer.cpp ggml.o bark.o encodec.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
diff --git a/bark.cpp b/bark.cpp
index f678d33..b8fd270 100644
--- a/bark.cpp
+++ b/bark.cpp
@@ -341,7 +341,7 @@ bool gpt_model_load(const std::string& fname, gpt_model& model, bark_vocab& voca
     return true;
 }
 
-bool bark_model_load(std::string & dirname, bark_model & model) {
+bool bark_model_load(const std::string & dirname, bark_model & model) {
     printf("%s: loading model from '%s'\n", __func__, dirname.c_str());
 
     // text
@@ -1460,56 +1460,3 @@ bool bark_generate_audio(
 
     return true;
 }
-
-int main() {
-    const int64_t t_main_start_us = ggml_time_us();
-
-    int64_t t_load_us = 0;
-    int64_t t_eval_us = 0;
-
-    bark_model model;
-    std::string fname = "./ggml_weights";
-
-    // load the model
-    {
-        const int64_t t_start_us = ggml_time_us();
-
-        if(!bark_model_load(fname, model)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, fname.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-    }
-
-    printf("\n");
-
-    // forward pass
-    const std::string prompt = "This is an audio";
-    {
-        const int64_t t_eval_us_start = ggml_time_us();
-
-        // call to generate audio
-        bark_generate_audio(model, model.vocab, prompt.data(), 4);
-
-        t_eval_us = ggml_time_us() - t_eval_us_start;
-    }
-
-    // report timing
-    {
-        const int64_t t_main_end_us = ggml_time_us();
-
-        printf("\n\n");
-        printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us/1000.0f);
-        printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
-    }
-
-    // TODO: write wrapper
-    ggml_free(model.coarse_model.ctx);
-    ggml_free(model.fine_model.ctx);
-    ggml_free(model.text_model.ctx);
-    ggml_free(model.codec_model.ctx);
-
-    return 0;
-}
\ No newline at end of file
diff --git a/bark.h b/bark.h
index c7e0d89..960c6d2 100644
--- a/bark.h
+++ b/bark.h
@@ -110,4 +110,21 @@ struct bark_model {
     bark_vocab vocab;
 
     int32_t memsize = 0;
-};
\ No newline at end of file
+};
+
+bool gpt_model_load(const std::string& fname, gpt_model& model, bark_vocab& vocab, bool has_vocab);
+
+bool bark_model_load(const std::string & dirname, bark_model & model);
+
+void bert_tokenize(
+        const bark_vocab& vocab,
+        const char * text,
+        int32_t * tokens,
+        int32_t * n_tokens,
+        int32_t n_max_tokens);
+
+bool bark_generate_audio(
+        bark_model model,
+        const bark_vocab& vocab,
+        const char * text,
+        const int n_threads);
diff --git a/examples/main.cpp b/examples/main.cpp
new file mode 100644
index 0000000..af2b089
--- /dev/null
+++ b/examples/main.cpp
@@ -0,0 +1,55 @@
+#include "ggml.h"
+#include "bark.h"
+
+int main() {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_load_us = 0;
+    int64_t t_eval_us = 0;
+
+    bark_model model;
+    std::string fname = "./ggml_weights";
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if(!bark_model_load(fname, model)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, fname.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    printf("\n");
+
+    // forward pass
+    const std::string prompt = "This is an audio";
+    {
+        const int64_t t_eval_us_start = ggml_time_us();
+
+        // call to generate audio
+        bark_generate_audio(model, model.vocab, prompt.data(), 4);
+
+        t_eval_us = ggml_time_us() - t_eval_us_start;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us/1000.0f);
+        printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    // TODO: write wrapper
+    ggml_free(model.coarse_model.ctx);
+    ggml_free(model.fine_model.ctx);
+    ggml_free(model.text_model.ctx);
+    ggml_free(model.codec_model.ctx);
+
+    return 0;
+}
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
new file mode 100644
index 0000000..c9f2095
--- /dev/null
+++ b/tests/CMakeLists.txt
@@ -0,0 +1,9 @@
+function(bark_add_test source)
+    get_filename_component(TEST_TARGET ${source} NAME_WE)
+    add_executable(${TEST_TARGET} ${source})
+    install(TARGETS ${TEST_TARGET} RUNTIME)
+    target_link_libraries(${TEST_TARGET} PRIVATE bark)
+    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
+endfunction()
+
+bark_add_test(test-tokenizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_weights_text.bin)
diff --git a/tests/test-tokenizer.cpp b/tests/test-tokenizer.cpp
new file mode 100644
index 0000000..3fcf04f
--- /dev/null
+++ b/tests/test-tokenizer.cpp
@@ -0,0 +1,78 @@
+#include "bark.h"
+
+#include <cstdio>
+#include <string>
+#include <map>
+#include <vector>
+
+static const std::map<std::string, std::vector<bark_vocab::id>> & k_tests()
+{
+    static std::map<std::string, std::vector<bark_vocab::id>> _k_tests = {
+        { "Hello world!", { 31178, 11356, 106, }, },
+        { "Hello world", { 31178, 11356, }, },
+        { " Hello world!", { 31178, 11356, 106, }, },
+        // { "this is an audio generated by bark", { 10531, 10124, 10151, 23685, 48918, 10155, 18121, 10174, }, },
+    };
+    return _k_tests;
+};
+
+int main(int argc, char **argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    bark_model model;
+    int max_ctx_size = 256;
+
+    // load text model and vocab
+    {
+        if(!gpt_model_load(fname, model.text_model, model.vocab, true)) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad text)\n", __func__, fname.c_str());
+            return 1;
+        }
+        model.memsize += model.text_model.memsize;
+    }
+
+    for (const auto & test_kv : k_tests()) {
+        std::vector<bark_vocab::id> res(test_kv.first.size());
+        int n_tokens;
+        bert_tokenize(model.vocab, test_kv.first.c_str(), res.data(), &n_tokens, max_ctx_size);
+        res.resize(n_tokens);
+
+        bool correct = res.size() == test_kv.second.size();
+
+        for (int i = 0; i < (int) res.size() && correct; ++i) {
+            if (res[i] != test_kv.second[i]) {
+                correct = false;
+            }
+        }
+
+        if (!correct) {
+            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
+            fprintf(stderr, "%s : expected tokens: ", __func__);
+            for (const auto & t : test_kv.second) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s : got tokens: ", __func__);
+            for (const auto & t : res) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+
+            return 3;
+        }
+    }
+
+    ggml_free(model.coarse_model.ctx);
+    ggml_free(model.fine_model.ctx);
+    ggml_free(model.text_model.ctx);
+    ggml_free(model.codec_model.ctx);
+
+    return 0;
+}