Get CUDA and Metal GPU working in whisperfile
jart committed Jul 31, 2024
1 parent 94e9629 commit 0849f32
Showing 5 changed files with 83 additions and 25 deletions.
8 changes: 8 additions & 0 deletions llamafile/metal.c
@@ -74,6 +74,7 @@ static struct Metal {
typeof(ggml_backend_metal_set_n_cb) *backend_set_n_cb;
typeof(ggml_backend_metal_log_set_callback) *log_set_callback;
typeof(ggml_backend_reg_metal_init) *reg_init;
typeof(ggml_backend_metal_supports_family) *supports_family;
} ggml_metal;

static const char *Dlerror(void) {
@@ -217,6 +218,7 @@ static bool LinkMetal(const char *dso) {
ok &= !!(ggml_metal.backend_set_n_cb = cosmo_dlsym(lib, "ggml_backend_metal_set_n_cb"));
ok &= !!(ggml_metal.log_set_callback = cosmo_dlsym(lib, "ggml_backend_metal_log_set_callback"));
ok &= !!(ggml_metal.reg_init = cosmo_dlsym(lib, "ggml_backend_reg_metal_init"));
ok &= !!(ggml_metal.supports_family = cosmo_dlsym(lib, "ggml_backend_metal_supports_family"));
if (!ok) {
tinylog(Dlerror(), ": not all symbols could be imported\n", NULL);
return false;
@@ -318,3 +320,9 @@ ggml_backend_t ggml_backend_reg_metal_init(const char *params, void *user_data)
return 0;
return ggml_metal.reg_init(params, user_data);
}

bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
if (!llamafile_has_metal())
return 0;
return ggml_metal.supports_family(backend, family);
}
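
The new `ggml_backend_metal_supports_family` export follows the same lazy-binding pattern llamafile/metal.c already uses for its other Metal entry points: a function pointer resolved with `cosmo_dlsym()` when the Metal DSO is linked, fronted by a wrapper that returns a safe default whenever `llamafile_has_metal()` reports the backend is unavailable. A minimal sketch of that pattern (the stand-in `ggml_backend_t` typedef and the `resolve_supports_family` helper below are illustrative, not part of the actual file):

```c
#include <stdbool.h>
#include <stddef.h>

typedef void *ggml_backend_t;                            // stand-in for the real ggml handle type
extern void *cosmo_dlsym(void *lib, const char *name);   // Cosmopolitan's runtime symbol lookup
extern bool llamafile_has_metal(void);                   // true once the Metal DSO has been linked

// Table of resolved entry points, mirroring the `ggml_metal` struct in metal.c.
static struct {
    bool (*supports_family)(ggml_backend_t backend, int family);
} ggml_metal;

// Resolve the symbol once the DSO is open, as LinkMetal() does for every entry point.
static bool resolve_supports_family(void *lib) {
    ggml_metal.supports_family = (bool (*)(ggml_backend_t, int))
        cosmo_dlsym(lib, "ggml_backend_metal_supports_family");
    return ggml_metal.supports_family != NULL;   // a missing symbol disables the backend
}

// Public wrapper: safe to call even when no Metal library is available.
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
    if (!llamafile_has_metal())
        return false;
    return ggml_metal.supports_family(backend, family);
}
```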
42 changes: 35 additions & 7 deletions whisper.cpp/main.cpp
@@ -14,6 +14,9 @@
#include <vector>
#include <cstring>

#include "llamafile/llamafile.h"
#include "llamafile/debug.h"

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
@@ -28,11 +31,11 @@ static void replace_all(std::string & s, const std::string & search, const std::
}
}

int32_t get_num_physical_cores();
int cpu_get_num_math();

// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, get_num_physical_cores());
int32_t n_threads = cpu_get_num_math();
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
@@ -72,7 +75,6 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
@@ -122,6 +124,33 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
continue;
}

if (arg == "--log-disable") {
FLAG_log_disable = true;
} else if (arg == "--trap") {
FLAG_trap = true;
FLAG_unsecure = true; // for better backtraces
llamafile_trapping_enabled(+1);
} else if (arg == "--unsecure") {
FLAG_unsecure = true;
} else if (arg == "--nocompile") {
FLAG_nocompile = true;
} else if (arg == "--recompile") {
FLAG_recompile = true;
} else if (arg == "--tinyblas") {
FLAG_tinyblas = true; // undocumented
} else if (arg == "--gpu") {
if (++i >= argc) {
fprintf(stderr, "error: missing --gpu flag value\n");
exit(1);
}
FLAG_gpu = llamafile_gpu_parse(argv[i]);
if (FLAG_gpu == LLAMAFILE_GPU_ERROR) {
fprintf(stderr, "error: invalid --gpu flag value: %s\n", argv[i]);
exit(1);
}
return true;
} else

if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
@@ -157,7 +186,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; FLAG_log_disable = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
@@ -170,7 +199,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-ng" || arg == "--no-gpu") { FLAG_gpu = LLAMAFILE_GPU_DISABLE; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
@@ -236,7 +265,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", FLAG_gpu == LLAMAFILE_GPU_DISABLE ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
@@ -983,7 +1012,6 @@ int main(int argc, char ** argv) {

struct whisper_context_params cparams = whisper_context_default_params();

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

if (!params.dtw.empty()) {
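
In main.cpp the per-context `use_gpu` boolean gives way to llamafile's global GPU selection: `--gpu VALUE` is parsed with `llamafile_gpu_parse()` into `FLAG_gpu`, `-ng`/`--no-gpu` sets `FLAG_gpu = LLAMAFILE_GPU_DISABLE`, and later code asks `llamafile_has_cuda()`/`llamafile_has_metal()` whether a backend is actually usable. A hedged sketch of that flow follows; the numeric constant values and the `describe_backend` helper are illustrative assumptions, only the identifiers that appear in this diff are real:

```c
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

// Declarations as implied by the diff; the real ones live in llamafile/llamafile.h.
extern int  FLAG_gpu;
extern int  llamafile_gpu_parse(const char *value);
extern bool llamafile_has_cuda(void);
extern bool llamafile_has_metal(void);

// Illustrative constant values -- the real ones are defined by llamafile, not shown here.
#define LLAMAFILE_GPU_ERROR   (-2)
#define LLAMAFILE_GPU_DISABLE (-1)

// Mirrors the new --gpu handling in whisper_params_parse().
static void set_gpu_from_flag(const char *value) {
    FLAG_gpu = llamafile_gpu_parse(value);
    if (FLAG_gpu == LLAMAFILE_GPU_ERROR) {
        fprintf(stderr, "error: invalid --gpu flag value: %s\n", value);
        exit(1);
    }
}

// Downstream code no longer reads a per-context use_gpu boolean; it probes llamafile.
static const char *describe_backend(void) {
    if (llamafile_has_cuda())  return "CUDA";
    if (llamafile_has_metal()) return "Metal";
    return "CPU";
}
```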
20 changes: 17 additions & 3 deletions whisper.cpp/server.cpp
@@ -79,7 +79,6 @@ struct whisper_params {
bool print_realtime = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
bool flash_attn = false;

std::string language = "en";
@@ -184,7 +183,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-ng" || arg == "--no-gpu") { FLAG_gpu = LLAMAFILE_GPU_DISABLE; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
@@ -194,6 +193,22 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
else if ( arg == "--recompile") { FLAG_recompile = true; }
else if ( arg == "--nocompile") { FLAG_nocompile = true; }
else if ( arg == "--tinyblas") { FLAG_tinyblas = true; }
else if ( arg == "--unsecure") { FLAG_unsecure = true; }

else if (arg == "--gpu") {
if (++i >= argc) {
fprintf(stderr, "error: missing --gpu flag value\n");
exit(1);
}
FLAG_gpu = llamafile_gpu_parse(argv[i]);
if (FLAG_gpu == LLAMAFILE_GPU_ERROR) {
fprintf(stderr, "error: invalid --gpu flag value: %s\n", argv[i]);
exit(1);
}
return true;
}

else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams);
@@ -515,7 +530,6 @@ int main(int argc, char ** argv) {
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();

cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;

if (!params.dtw.empty()) {
37 changes: 23 additions & 14 deletions whisper.cpp/whisper.cpp
@@ -2,6 +2,9 @@
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#include "whisper.h"

#define GGML_USE_CUDA
#define GGML_USE_METAL

#ifdef GGML_USE_METAL
#include "llama.cpp/ggml-metal.h"
#endif
@@ -15,6 +18,8 @@
#include "llama.cpp/ggml-alloc.h"
#include "llama.cpp/ggml-backend.h"

#include "llamafile/llamafile.h"

#include "whisper-mel.hpp"

#include <atomic>
@@ -208,7 +213,8 @@ static bool ggml_graph_compute_helper(
// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
// general-purpose kernels
//
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) {
int pad = 32;
// use padding only if dimension 0 is at least 8 times larger than the padding
// else we won't get much benefit from the optimization
const int n_pad_req = 8;
@@ -231,7 +237,7 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
// TODO: check if other platforms can benefit from this optimization
// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
#if defined(GGML_USE_METAL)
#define ggml_mul_mat ggml_mul_mat_pad
#define ggml_mul_mat (llamafile_has_metal() ? ggml_mul_mat_pad : ggml_mul_mat)
#endif

// available whisper models
@@ -1067,18 +1073,18 @@ static void whisper_kv_cache_seq_cp(
}

static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
if (!wctx.params.flash_attn) {
return 1u;
}

#ifdef GGML_USE_METAL
if (wctx.params.use_gpu) {
if (llamafile_has_metal()) {
return 32u;
}
#endif

#ifdef GGML_USE_CUDA
if (wctx.params.use_gpu) {
if (llamafile_has_cuda()) {
return 256u;
}
#endif
@@ -1221,7 +1227,7 @@ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & pa
ggml_backend_t result = NULL;

#ifdef GGML_USE_CUDA
if (params.use_gpu) {
if (llamafile_has_cuda()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
result = ggml_backend_cuda_init(params.gpu_device);
if (!result) {
@@ -1231,7 +1237,7 @@ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & pa
#endif

#ifdef GGML_USE_METAL
if (params.use_gpu) {
if (!result && llamafile_has_metal()) {
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
result = ggml_backend_metal_init();
@@ -1299,14 +1305,14 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
ggml_backend_buffer_type_t result = nullptr;

params.use_gpu || (result = ggml_backend_cpu_buffer_type());

#ifdef GGML_USE_CUDA
result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
if (!result && llamafile_has_cuda())
result = ggml_backend_cuda_buffer_type(params.gpu_device);
#endif

#ifdef GGML_USE_METAL
result || (result = ggml_backend_metal_buffer_type());
if (!result && llamafile_has_metal())
result = ggml_backend_metal_buffer_type();
#endif

#ifdef GGML_USE_SYCL
@@ -1317,7 +1323,8 @@ static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_cont
result || (result = ggml_backend_vk_buffer_type(params.gpu_device));
#endif

result || (result = ggml_backend_cpu_buffer_type());
if (!result)
result = ggml_backend_cpu_buffer_type();

return result;
}
@@ -3585,7 +3592,6 @@ int whisper_ctx_init_openvino_encoder(

struct whisper_context_params whisper_context_default_params() {
struct whisper_context_params result = {
/*.use_gpu =*/ true,
/*.flash_attn =*/ false,
/*.gpu_device =*/ 0,

@@ -3690,7 +3696,8 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
params.dtw_token_timestamps = false;
}

WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
WHISPER_LOG_INFO("%s: cuda gpu = %d\n", __func__, llamafile_has_cuda());
WHISPER_LOG_INFO("%s: metal gpu = %d\n", __func__, llamafile_has_metal());
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
@@ -7444,6 +7451,8 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
if (FLAG_log_disable)
return;
fputs(text, stderr);
fflush(stderr);
}
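
With `use_gpu` gone, whisper.cpp decides on a backend by probing at runtime and falling through CUDA, then Metal, then CPU; the same `!result && llamafile_has_*()` pattern appears in both `whisper_backend_init_gpu()` and `whisper_default_buffer_type()`. A simplified consolidation of that fallback chain (assumes whisper.cpp's existing ggml and llamafile includes; `whisper_pick_buffer_type` is an illustrative name, not a function added by this commit):

```c
// Simplified view of the buffer-type fallback after this commit: CUDA, then
// Metal, then CPU. Mirrors whisper_default_buffer_type() in the diff above.
static ggml_backend_buffer_type_t
whisper_pick_buffer_type(const struct whisper_context_params * params) {
    ggml_backend_buffer_type_t result = NULL;

#ifdef GGML_USE_CUDA
    if (!result && llamafile_has_cuda())
        result = ggml_backend_cuda_buffer_type(params->gpu_device);   // as in the diff
#endif

#ifdef GGML_USE_METAL
    if (!result && llamafile_has_metal())
        result = ggml_backend_metal_buffer_type();                    // as in the diff
#endif

    if (!result)
        result = ggml_backend_cpu_buffer_type();   // CPU buffers are always available

    return result;
}
```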
1 change: 0 additions & 1 deletion whisper.cpp/whisper.h
@@ -114,7 +114,6 @@ extern "C" {
} whisper_aheads;

struct whisper_context_params {
bool use_gpu;
bool flash_attn;
int gpu_device; // CUDA device

