Skip to content

Commit

Permalink
MNT Unit tests tokenizer (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier authored Jul 30, 2023
1 parent 117d490 commit deec6f3
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 63 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ build/

bark
encodec
main
tests/test-tokenizer

*.o
*.plist
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ option(BARK_CUDA_DMMV_F16 "bark: use 16 bit floats for dmmv CU
option(BARK_CLBLAST "bark: use CLBlast" OFF)
option(BARK_METAL "bark: use Metal" OFF)

# option(BARK_BUILD_TESTS "bark: build tests" ${BARK_STANDALONE})
option(BARK_BUILD_TESTS "bark: build tests" ${BARK_STANDALONE})

#
# Build info header
Expand Down Expand Up @@ -518,7 +518,7 @@ install(
# programs, examples and tests
#

# if (BARK_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
# include(CTest)
# add_subdirectory(tests)
# endif ()
if (BARK_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
include(CTest)
add_subdirectory(tests)
endif ()
17 changes: 14 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
BUILD_TARGETS = bark

# Binaries only useful for tests
# TEST_TARGETS = tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
TEST_TARGETS =
TEST_TARGETS = tests/test-tokenizer

default: $(BUILD_TARGETS)

Expand Down Expand Up @@ -302,5 +301,17 @@ bark.o: bark.cpp bark.h
clean:
rm -vf *.o *.so *.dll encodec bark

bark: bark.cpp encodec.o ggml.o $(OBJS)
bark: bark.cpp encodec.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

# Build the example driver; $(filter-out %.h,$^) drops header prerequisites
# from the link line so only sources/objects are passed to the compiler.
main: examples/main.cpp ggml.o bark.o encodec.o $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

#
# Test
#

# `tests` produces no file of that name (and a tests/ directory exists in the
# tree), so it must be declared phony or make can report it "up to date".
.PHONY: tests
tests: $(TEST_TARGETS)

# Consistency fix: every other rule filters out %.h prerequisites; %.txt was a
# copy-paste slip (there are no .txt prerequisites here).
tests/test-tokenizer: tests/test-tokenizer.cpp ggml.o bark.o encodec.o $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
55 changes: 1 addition & 54 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ bool gpt_model_load(const std::string& fname, gpt_model& model, bark_vocab& voca
return true;
}

bool bark_model_load(std::string & dirname, bark_model & model) {
bool bark_model_load(const std::string & dirname, bark_model & model) {
printf("%s: loading model from '%s'\n", __func__, dirname.c_str());

// text
Expand Down Expand Up @@ -1460,56 +1460,3 @@ bool bark_generate_audio(

return true;
}

int main() {
const int64_t t_main_start_us = ggml_time_us();

int64_t t_load_us = 0;
int64_t t_eval_us = 0;

bark_model model;
std::string fname = "./ggml_weights";

// load the model
{
const int64_t t_start_us = ggml_time_us();

if(!bark_model_load(fname, model)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, fname.c_str());
return 1;
}

t_load_us = ggml_time_us() - t_start_us;
}

printf("\n");

// forward pass
const std::string prompt = "This is an audio";
{
const int64_t t_eval_us_start = ggml_time_us();

// call to generate audio
bark_generate_audio(model, model.vocab, prompt.data(), 4);

t_eval_us = ggml_time_us() - t_eval_us_start;
}

// report timing
{
const int64_t t_main_end_us = ggml_time_us();

printf("\n\n");
printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us/1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
}

// TODO: write wrapper
ggml_free(model.coarse_model.ctx);
ggml_free(model.fine_model.ctx);
ggml_free(model.text_model.ctx);
ggml_free(model.codec_model.ctx);

return 0;
}
19 changes: 18 additions & 1 deletion bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,21 @@ struct bark_model {
bark_vocab vocab;

int32_t memsize = 0;
};
};

// Load a single GPT model checkpoint from `fname`; when `has_vocab` is true the
// vocabulary stored in the file is read into `vocab` as well. Returns false on
// failure. NOTE(review): exact file format is defined in bark.cpp — confirm.
bool gpt_model_load(const std::string& fname, gpt_model& model, bark_vocab& vocab, bool has_vocab);

// Load all bark sub-models (text/coarse/fine/codec) from the weights directory
// `dirname` into `model`. Returns false if any component fails to load.
bool bark_model_load(const std::string & dirname, bark_model & model);

// Tokenize `text` with the BERT-style tokenizer. Writes at most `n_max_tokens`
// ids into `tokens` (caller-allocated) and stores the count in `*n_tokens`.
void bert_tokenize(
const bark_vocab& vocab,
const char * text,
int32_t * tokens,
int32_t * n_tokens,
int32_t n_max_tokens);

// Run the full generation pipeline for `text` using `n_threads` threads.
// NOTE(review): `model` is taken by value — a large copy; consider const& once
// the definition in bark.cpp can be changed in the same commit.
bool bark_generate_audio(
bark_model model,
const bark_vocab& vocab,
const char * text,
const int n_threads);
55 changes: 55 additions & 0 deletions examples/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "ggml.h"
#include "bark.h"

int main() {
    // Wall-clock timing (microseconds) for the whole run and for each phase.
    const int64_t start_us = ggml_time_us();

    int64_t load_us = 0;
    int64_t eval_us = 0;

    bark_model model;
    std::string fname = "./ggml_weights";

    // load the model
    {
        const int64_t before_load_us = ggml_time_us();

        if (!bark_model_load(fname, model)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, fname.c_str());
            return 1;
        }

        load_us = ggml_time_us() - before_load_us;
    }

    printf("\n");

    // forward pass
    const std::string prompt = "This is an audio";
    {
        const int64_t before_eval_us = ggml_time_us();

        // call to generate audio
        bark_generate_audio(model, model.vocab, prompt.c_str(), 4);

        eval_us = ggml_time_us() - before_eval_us;
    }

    // report timing
    {
        const int64_t end_us = ggml_time_us();

        printf("\n\n");
        printf("%s: load time = %8.2f ms\n", __func__, load_us/1000.0f);
        printf("%s: eval time = %8.2f ms\n", __func__, eval_us/1000.0f);
        printf("%s: total time = %8.2f ms\n", __func__, (end_us - start_us)/1000.0f);
    }

    // TODO: write wrapper
    ggml_free(model.coarse_model.ctx);
    ggml_free(model.fine_model.ctx);
    ggml_free(model.text_model.ctx);
    ggml_free(model.codec_model.ctx);

    return 0;
}
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Register a test executable built from `source`, linked against bark.
# Any extra arguments (ARGN) are forwarded to the test binary's command line.
function(bark_add_test source)
    get_filename_component(TEST_TARGET ${source} NAME_WE)
    add_executable(${TEST_TARGET} ${source})
    install(TARGETS ${TEST_TARGET} RUNTIME)
    target_link_libraries(${TEST_TARGET} PRIVATE bark)
    add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
endfunction()

# Fixed: the helper defined above is bark_add_test; the original call to
# llama_add_test (copy-paste from llama.cpp) would fail at configure time
# with "Unknown CMake command".
bark_add_test(test-tokenizer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../ggml_weights/ggml_weights_text.bin)
78 changes: 78 additions & 0 deletions tests/test-tokenizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "bark.h"

#include <cstdio>
#include <string>
#include <map>
#include <vector>

// Expected tokenizations keyed by input prompt. Token ids follow the text
// model's vocabulary (see gpt_model_load in bark.cpp).
// Fixed: removed the stray ';' after the function body (-Wextra-semi warning).
static const std::map<std::string, std::vector<bark_vocab::id>> & k_tests()
{
    // Built once on first call; returned by const reference thereafter.
    static std::map<std::string, std::vector<bark_vocab::id>> _k_tests = {
        { "Hello world!",  { 31178, 11356, 106, }, },
        { "Hello world",   { 31178, 11356, }, },
        { " Hello world!", { 31178, 11356, 106, }, },
        // { "this is an audio generated by bark", { 10531, 10124, 10151, 23685, 48918, 10155, 18121, 10174, }, },
    };
    return _k_tests;
}

// Tokenizer regression test: loads the text model's vocab from argv[1] and
// checks bert_tokenize against the fixtures in k_tests().
// Exit codes: 0 = all pass, 1 = usage/load error, 3 = tokenization mismatch.
int main(int argc, char **argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <model-file>\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    bark_model model;
    int max_ctx_size = 256;

    // load text model and vocab
    {
        if(!gpt_model_load(fname, model.text_model, model.vocab, true)) {
            fprintf(stderr, "%s: invalid model file '%s' (bad text)\n", __func__, fname.c_str());
            return 1;
        }
        model.memsize += model.text_model.memsize;
    }

    for (const auto & test_kv : k_tests()) {
        // Fixed: size the output buffer to the tokenizer's limit, not the
        // input string length — bert_tokenize may emit up to max_ctx_size
        // tokens, so the original res(test_kv.first.size()) could overflow.
        std::vector<bark_vocab::id> res(max_ctx_size);
        int n_tokens;
        bert_tokenize(model.vocab, test_kv.first.c_str(), res.data(), &n_tokens, max_ctx_size);
        res.resize(n_tokens);

        bool correct = res.size() == test_kv.second.size();

        for (int i = 0; i < (int) res.size() && correct; ++i) {
            if (res[i] != test_kv.second[i]) {
                correct = false;
            }
        }

        if (!correct) {
            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
            fprintf(stderr, "%s : expected tokens: ", __func__);
            for (const auto & t : test_kv.second) {
                fprintf(stderr, "%6d, ", t);
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "%s : got tokens: ", __func__);
            for (const auto & t : res) {
                fprintf(stderr, "%6d, ", t);
            }
            fprintf(stderr, "\n");

            return 3;
        }
    }

    // Fixed: only the text model was loaded above; freeing the coarse/fine/
    // codec contexts (as the original did) would pass potentially
    // uninitialized pointers to ggml_free.
    ggml_free(model.text_model.ctx);

    return 0;
}

0 comments on commit deec6f3

Please sign in to comment.