Skip to content

Commit

Permalink
enh : add progress callback (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier authored May 10, 2024
1 parent cbba964 commit e4af686
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 76 deletions.
79 changes: 26 additions & 53 deletions bark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,39 +163,6 @@ struct bark_context {
bark_statistics stats;
};

// Single-line textual progress bar written to stdout.
//
// Each call to print() redraws the bar in place (the line starts with '\r'),
// e.g. "name: [=====>          ] (25%)". Progress is accumulated through
// update() and measured against the total given at construction.
class BarkProgressBar {
   public:
    // func_name:       label printed in front of the bar.
    // needed_progress: total amount of progress that corresponds to 100%.
    BarkProgressBar(std::string func_name, double needed_progress)
        : func_name(func_name), needed_progress(needed_progress) {}

    // Adds `new_progress` to the accumulated progress and recomputes how many
    // filler characters the bar should show.
    void update(double new_progress) {
        current_progress += new_progress;
        amount_of_filler = (int)(completed_fraction() * (double)pbar_length);
    }

    // Redraws the bar on the current stdout line and flushes stdout.
    void print() {
        printf("\r%s: %s", func_name.c_str(), initial_part.c_str());
        for (int a = 0; a < amount_of_filler; a++) {
            printf("%s", pbar_filler.c_str());
        }
        printf("%s", pbar_updater.c_str());
        for (int b = 0; b < pbar_length - amount_of_filler; b++) {
            printf(" ");
        }
        printf("%s (%d%%)", last_part.c_str(), (int)(100 * completed_fraction()));
        fflush(stdout);
    }

    // Customisable pieces of the rendered bar.
    std::string initial_part = "[", last_part = "]";
    std::string pbar_filler = "=", pbar_updater = ">";

   private:
    // Completed share of the work, clamped to [0, 1].
    // A non-positive (or NaN) total yields 0 so that the division below can
    // never produce inf/NaN, whose conversion to int would be undefined.
    double completed_fraction() const {
        if (!(needed_progress > 0)) {
            return 0.0;
        }
        const double fraction = current_progress / needed_progress;
        if (!(fraction > 0.0)) {
            return 0.0;
        }
        return fraction < 1.0 ? fraction : 1.0;
    }

    std::string func_name;
    double needed_progress, current_progress = 0;
    // amount_of_filler starts at 0 so print() is well-defined before the
    // first update().
    int amount_of_filler = 0, pbar_length = 50;
};

template <typename T>
static void read_safe(std::ifstream& fin, T& dest) {
fin.read((char*)&dest, sizeof(T));
Expand Down Expand Up @@ -1207,20 +1174,18 @@ static bool bark_load_model_from_file(
return true;
}

struct bark_context* bark_load_model(const char* model_path, bark_verbosity_level verbosity, uint32_t seed) {
struct bark_context* bark_load_model(const char* model_path, struct bark_context_params params, uint32_t seed) {
int64_t t_load_start_us = ggml_time_us();

struct bark_context* bctx = new bark_context();

bctx->text_model = bark_model();
std::string model_path_str(model_path);
if (!bark_load_model_from_file(model_path_str, bctx, verbosity)) {
if (!bark_load_model_from_file(model_path_str, bctx, params.verbosity)) {
fprintf(stderr, "%s: failed to load model weights from '%s'\n", __func__, model_path);
return nullptr;
}

bark_context_params params = bark_context_default_params();
params.verbosity = verbosity;
bctx->rng = std::mt19937(seed);
bctx->params = params;
bctx->stats.t_load_us = ggml_time_us() - t_load_start_us;
Expand Down Expand Up @@ -1724,8 +1689,6 @@ static bool bark_eval_text_encoder(struct bark_context* bctx, int n_threads) {
int32_t semantic_vocab_size = params.semantic_vocab_size;
int32_t semantic_pad_token = params.semantic_pad_token;

BarkProgressBar progress(std::string("Generating semantic tokens"), n_steps_text_encoder);

auto& model = bctx->text_model.semantic_model;
auto& allocr = bctx->allocr;
auto& hparams = model.hparams;
Expand All @@ -1742,6 +1705,13 @@ static bool bark_eval_text_encoder(struct bark_context* bctx, int n_threads) {
int n_past = 0;

for (int i = 0; i < n_steps_text_encoder; i++) {
if (params.progress_callback) {
const int progress_cur = 100*(i+1)/n_steps_text_encoder;

params.progress_callback(
bctx, bark_encoding_step::SEMANTIC, progress_cur, params.progress_callback_user_data);
}

if (!bark_eval_encoder_internal(model, allocr, input, logits, &n_past, true, n_threads)) {
fprintf(stderr, "%s: Could not generate token\n", __func__);
return false;
Expand All @@ -1761,9 +1731,6 @@ static bool bark_eval_text_encoder(struct bark_context* bctx, int n_threads) {

input.push_back(next);
output.push_back(next);

progress.update(1);
progress.print();
}

bctx->semantic_tokens = output;
Expand Down Expand Up @@ -1859,8 +1826,6 @@ static bool bark_eval_coarse_encoder(struct bark_context* bctx, int n_threads) {
assert(n_steps > 0);
assert(n_steps % n_coarse_codebooks == 0);

BarkProgressBar progress(std::string("Generating coarse tokens"), n_steps);

int n_window_steps = ceilf(static_cast<float>(n_steps) / sliding_window_size);

int step_idx = 0;
Expand Down Expand Up @@ -1894,6 +1859,13 @@ static bool bark_eval_coarse_encoder(struct bark_context* bctx, int n_threads) {
continue;
}

if (params.progress_callback) {
const int progress_cur = 100*(step_idx+1)/n_steps;

params.progress_callback(
bctx, bark_encoding_step::COARSE, progress_cur, params.progress_callback_user_data);
}

if (!bark_eval_encoder_internal(model, allocr, input_in, logits, &n_past, false, n_threads)) {
fprintf(stderr, "%s: Could not generate token\n", __func__);
return false;
Expand All @@ -1918,9 +1890,6 @@ static bool bark_eval_coarse_encoder(struct bark_context* bctx, int n_threads) {
out.push_back(next);

step_idx += 1;

progress.update(1);
progress.print();
}
}

Expand Down Expand Up @@ -2079,8 +2048,6 @@ static bool bark_eval_fine_encoder(struct bark_context* bctx, int n_threads) {

bark_codes in_arr = input; // [seq_length, n_codes]

BarkProgressBar progress(std::string("Generating fine tokens"), n_loops * (n_fine_codebooks - n_coarse));

for (int n = 0; n < n_loops; n++) {
int start_idx = std::min(n * 512, (int)in_arr.size() - 1024);
int start_fill_idx = std::min(n * 512, (int)in_arr.size() - 512);
Expand All @@ -2095,6 +2062,13 @@ static bool bark_eval_fine_encoder(struct bark_context* bctx, int n_threads) {
}

for (int nn = n_coarse; nn < n_fine_codebooks; nn++) {
if (params.progress_callback) {
const int progress_cur = 100*(n*(n_fine_codebooks-n_coarse)+(nn-n_coarse+1))/(n_loops*(n_fine_codebooks-n_coarse));

params.progress_callback(
bctx, bark_encoding_step::FINE, progress_cur, params.progress_callback_user_data);
}

if (!bark_eval_fine_encoder_internal(bctx, in_buffer, logits, nn, n_threads)) {
fprintf(stderr, "%s: Could not generate token\n", __func__);
return false;
Expand All @@ -2111,9 +2085,6 @@ static bool bark_eval_fine_encoder(struct bark_context* bctx, int n_threads) {

in_buffer[nn * 1024 + rel_start_fill_idx + i] = next;
}

progress.update(1);
progress.print();
}

// transfer over info into model_in
Expand Down Expand Up @@ -2307,6 +2278,8 @@ struct bark_context_params bark_context_default_params() {
/*.n_coarse_codebooks =*/2,
/*.n_fine_codebooks =*/8,
/*.codebook_size =*/1024,
/*.progress_callback =*/nullptr,
/*.progress_callback_user_data =*/nullptr,
};

return result;
Expand Down Expand Up @@ -2378,7 +2351,7 @@ bool bark_model_weights_quantize(std::ifstream& fin, std::ofstream& fout, ggml_f
return true;
}

bool bark_model_quantize(const char* fname_inp, const char* fname_out, ggml_ftype ftype) {
bool bark_model_quantize(const char* fname_inp, const char* fname_out, enum ggml_ftype ftype) {
printf("%s: loading model from '%s'\n", __func__, fname_inp);

std::string fname_inp_str(fname_inp);
Expand Down
18 changes: 15 additions & 3 deletions bark.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ extern "C" {
HIGH = 2,
};

// Identifies which generation stage a progress callback invocation refers to.
enum bark_encoding_step {
// Reported from bark_eval_text_encoder (semantic token generation).
SEMANTIC = 0,
// Reported from bark_eval_coarse_encoder (coarse token generation).
COARSE = 1,
// Reported from bark_eval_fine_encoder (fine token generation).
FINE = 2,
};

struct bark_context;
struct bark_model;

Expand All @@ -39,6 +45,8 @@ extern "C" {
// Define the GPT architecture for the 3 encoders
struct gpt_model;

// Signature of the optional progress callback stored in bark_context_params.
//
// bctx      - the context that is currently generating.
// step      - the encoder stage reporting progress.
// progress  - integer percentage of that stage, computed by the encoders as
//             100 * steps_done / steps_total.
// user_data - the progress_callback_user_data pointer from the context
//             parameters, passed through unchanged.
//
// The encoders invoke the callback at the start of each generation step, and
// only when progress_callback is non-null.
typedef void (*bark_progress_callback)(struct bark_context * bctx, enum bark_encoding_step step, int progress, void * user_data);

struct bark_statistics {
// Time to load model weights
int64_t t_load_us;
Expand Down Expand Up @@ -116,6 +124,10 @@ extern "C" {
int32_t n_fine_codebooks;
// Dimension of the codes
int32_t codebook_size;

// called on each progress update
bark_progress_callback progress_callback;
void * progress_callback_user_data;
};

/**
Expand All @@ -126,16 +138,16 @@ extern "C" {
struct bark_context_params bark_context_default_params(void);

/**
* Loads a BARK model from the specified file path with the given parameters.
* Loads a Bark model from the specified file path with the given parameters.
*
* @param model_path The directory path of the bark model to load.
* @param verbosity The verbosity level when loading the model.
* @param params The parameters to use for the Bark model.
* @param seed The seed to use for random number generation.
* @return A pointer to the loaded bark model context.
*/
struct bark_context *bark_load_model(
const char *model_path,
enum bark_verbosity_level verbosity,
struct bark_context_params params,
uint32_t seed);

/**
Expand Down
4 changes: 4 additions & 0 deletions examples/bark.swiftui/bark.swiftui.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
E05A4F8C2BEE16590000CD31 /* BarkState.swift in Sources */ = {isa = PBXBuildFile; fileRef = E05A4F8B2BEE16590000CD31 /* BarkState.swift */; };
E05A4F952BEE284C0000CD31 /* AudioPlayer.swift in Sources */ = {isa = PBXBuildFile; fileRef = E05A4F942BEE284C0000CD31 /* AudioPlayer.swift */; };
E05A4F9A2BEE3DF60000CD31 /* bark in Frameworks */ = {isa = PBXBuildFile; productRef = E05A4F992BEE3DF60000CD31 /* bark */; };
E061784F2BEE99AE00497E2F /* ProgressData.swift in Sources */ = {isa = PBXBuildFile; fileRef = E061784E2BEE99AE00497E2F /* ProgressData.swift */; };
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
Expand All @@ -28,6 +29,7 @@
E05A4F942BEE284C0000CD31 /* AudioPlayer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AudioPlayer.swift; sourceTree = "<group>"; };
E05A4F972BEE3CB90000CD31 /* bark_swift_package */ = {isa = PBXFileReference; lastKnownFileType = wrapper; path = bark_swift_package; sourceTree = "<group>"; };
E05A4F982BEE3CCB0000CD31 /* bark-swiftui-Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = "bark-swiftui-Info.plist"; sourceTree = SOURCE_ROOT; };
E061784E2BEE99AE00497E2F /* ProgressData.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgressData.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -110,6 +112,7 @@
isa = PBXGroup;
children = (
E05A4F892BEE15BA0000CD31 /* LibBark.swift */,
E061784E2BEE99AE00497E2F /* ProgressData.swift */,
);
path = Bindings;
sourceTree = "<group>";
Expand Down Expand Up @@ -191,6 +194,7 @@
E05A4F782BEE15150000CD31 /* bark_swiftuiApp.swift in Sources */,
E05A4F952BEE284C0000CD31 /* AudioPlayer.swift in Sources */,
E05A4F8A2BEE15BA0000CD31 /* LibBark.swift in Sources */,
E061784F2BEE99AE00497E2F /* ProgressData.swift in Sources */,
E05A4F8C2BEE16590000CD31 /* BarkState.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down
6 changes: 5 additions & 1 deletion examples/bark.swiftui/bark.swiftui/Bindings/LibBark.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ actor BarkContext {
}

static func createContext(path: String, seed: Int) throws -> BarkContext {
let context = bark_load_model(path, bark_verbosity_level(0), UInt32(seed))
var context_params = bark_context_default_params()
context_params.verbosity = bark_verbosity_level(0)
context_params.progress_callback = cCallbackBridge

let context = bark_load_model(path, context_params, UInt32(seed))
if let context {
return BarkContext(context: context)
} else {
Expand Down
39 changes: 39 additions & 0 deletions examples/bark.swiftui/bark.swiftui/Bindings/ProgressData.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//
// ProgressData.swift
// bark.swiftui
//
// Created by Pierre-Antoine BANNIER on 10/05/2024.
//

import Foundation
import Combine
import bark

/// Process-wide observable store for the generation progress displayed by the UI.
///
/// Access it through `ProgressData.shared`; the initializer is private so no
/// other instance can be created outside this file.
class ProgressData: ObservableObject {
    /// Title of the step currently being reported.
    @Published var stepTitle: String = "Progress..."

    /// Completion of the current step (the callback bridge stores percent / 100).
    @Published var progress: Float = 0.0

    /// The single shared instance observed by the views.
    static let shared: ProgressData = .init()

    private init() {}
}

/// Progress callback passed to the Bark C API via `bark_context_params.progress_callback`.
///
/// Converts the integer percentage into a 0...1 fraction, maps the encoding
/// step to a human-readable title and publishes both on `ProgressData.shared`.
/// The update is dispatched to the main queue because the published values
/// drive SwiftUI views.
///
/// - Parameters:
///   - bctx: The Bark context reporting progress (unused).
///   - step: The encoding stage being reported.
///   - progress: Percentage of the stage that is complete.
///   - userData: Opaque user pointer supplied with the callback (unused).
func cCallbackBridge(bctx: OpaquePointer?, step: bark_encoding_step, progress: Int32, userData: UnsafeMutableRawPointer?) {
    DispatchQueue.main.async {
        let title: String
        switch step.rawValue {
        case 0:
            title = "Semantic tokens (1/3)"
        case 1:
            title = "Coarse tokens (2/3)"
        default:
            // Any other step value is presented as the fine stage.
            title = "Fine tokens (3/3)"
        }

        let shared = ProgressData.shared
        shared.progress = Float(progress) / 100.0
        shared.stepTitle = title
    }
}
5 changes: 1 addition & 4 deletions examples/bark.swiftui/bark.swiftui/Models/BarkState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import Foundation
import AVFoundation

let kSampleRate: Double = 44100;
let kNumChannels: UInt32 = 1;


@MainActor
class BarkState: NSObject, ObservableObject {
Expand Down Expand Up @@ -38,7 +35,6 @@ class BarkState: NSObject, ObservableObject {
canGenerate = true
} catch {
print(error.localizedDescription)
messageLog += "\(error.localizedDescription)"
}
}

Expand All @@ -49,6 +45,7 @@ class BarkState: NSObject, ObservableObject {
messageLog += "Loaded model \(modelUrl.lastPathComponent)"
} else {
messageLog += "Could not locate model\n"
throw LoadError.couldNotLocateModel
}
}

Expand Down
12 changes: 12 additions & 0 deletions examples/bark.swiftui/bark.swiftui/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import SwiftUI

struct ContentView: View {
@StateObject var barkState = BarkState()
@ObservedObject var progressData = ProgressData.shared
@State private var textInput: String = ""

var body: some View {
Expand Down Expand Up @@ -41,6 +42,17 @@ struct ContentView: View {
.disabled(!barkState.canGenerate)
}

HStack {
Text(verbatim: progressData.stepTitle)

Spacer()

ProgressView(value: progressData.progress)
.frame(width: 150)
}
.frame(maxWidth: .infinity, alignment: .leading)
.padding()

ScrollView {
Text(verbatim: barkState.messageLog)
.frame(maxWidth: .infinity, alignment: .leading)
Expand Down
Loading

0 comments on commit e4af686

Please sign in to comment.