This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

WIP: Bloom Inference #85

Closed
wants to merge 24 commits
204 changes: 202 additions & 2 deletions ggml-sys/ggml/ggml.c
@@ -2644,7 +2644,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_FF",
};

static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2688,7 +2688,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_ff(x)",
};

static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");

//
// ggml object
@@ -4709,6 +4709,37 @@ struct ggml_tensor * ggml_rope(
return result;
}


// ggml_alibi

struct ggml_tensor * ggml_alibi(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past,
        int                   n_head) {
    GGML_ASSERT(n_past >= 0);
    bool is_node = false;

    if (a->grad) {
        GGML_ASSERT(false); // TODO: implement backward
        is_node = true;
    }

    // TODO: when implement backward, fix this:
    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
    struct ggml_tensor * result = ggml_view_tensor(ctx, a);

    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
    ((int32_t *) b->data)[0] = n_past;
    ((int32_t *) b->data)[1] = n_head;
    ((int32_t *) b->data)[2] = 0; // "mode": read by the compute kernels but currently unused; initialized to avoid an uninitialized read

    result->op   = GGML_OP_ALIBI;
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src0 = a;
    result->src1 = b;

    return result;
}

// ggml_conv_1d_1s

struct ggml_tensor * ggml_conv_1d_1s(
@@ -7192,6 +7223,163 @@ static void ggml_compute_forward_soft_max(
}
}

// ggml_compute_forward_alibi

static void ggml_compute_forward_alibi_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    assert(params->ith == 0);
    assert(src1->type == GGML_TYPE_I32);
    assert(ggml_nelements(src1) == 3);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int n_past = ((int32_t *) src1->data)[0];
    const int n_head = ((int32_t *) src1->data)[1];
    const int mode   = ((int32_t *) src1->data)[2];

    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
    const int ne1 = src0->ne[1]; // seq_len_without_past
    const int ne2 = src0->ne[2]; // n_head -> this is k
    const int ne3 = src0->ne[3]; // 1 -> bsz

    const int n       = ggml_nrows(src0);
    const int ne2_ne3 = n/ne1; // ne2*ne3

    const int nb0 = src0->nb[0];
    const int nb1 = src0->nb[1];
    const int nb2 = src0->nb[2];
    const int nb3 = src0->nb[3];

    // printf("\nne0: %d, ne1: %d, ne2: %d, ne3: %d", ne0, ne1, ne2, ne3);
    // printf("\nn_past = %d, ne2 = %d", n_past, ne2);

    assert(nb0 == sizeof(float));
    assert(ne1 + n_past == ne0);

    // add alibi to src0 (KQ_scaled)
    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
    const float m0 = pow(2.0, -8.0 / n_heads_log2_floor);
    const float m1 = pow(2.0, -4.0 / n_heads_log2_floor);

    for (int i = 0; i < ne0; i++) {
        for (int j = 0; j < ne1; j++) {
            for (int k = 0; k < ne2_ne3; k++) {
                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
                float *  dst_data = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);

                // TODO: k*nb2 or k*nb3

                float m_k;
                if (k < n_heads_log2_floor) {
                    m_k = pow(m0, k + 1);
                } else {
                    m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
                }
                //TODO: optimize
                dst_data[0] = (j+1) * m_k + src[0];
            }
        }
    }
}


static void ggml_compute_forward_alibi_f16(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    assert(params->ith == 0);
    assert(src1->type == GGML_TYPE_I32);
    assert(ggml_nelements(src1) == 3);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int n_past = ((int32_t *) src1->data)[0];
    const int n_head = ((int32_t *) src1->data)[1];
    const int mode   = ((int32_t *) src1->data)[2];

    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
    const int ne1 = src0->ne[1]; // seq_len_without_past
    const int ne2 = src0->ne[2]; // n_head -> this is k
    const int ne3 = src0->ne[3]; // 1 -> bsz

    const int n       = ggml_nrows(src0);
    const int ne2_ne3 = n/ne1; // ne2*ne3

    const int nb0 = src0->nb[0];
    const int nb1 = src0->nb[1];
    const int nb2 = src0->nb[2];
    const int nb3 = src0->nb[3];

    // printf("\nne0: %d, ne1: %d, ne2: %d, ne3: %d", ne0, ne1, ne2, ne3);
    // printf("\nn_past = %d, ne2 = %d", n_past, ne2);

    assert(nb0 == sizeof(ggml_fp16_t));
    assert(ne1 + n_past == ne0);

    // add alibi to src0 (KQ_scaled)
    // slopes are computed in fp32; ggml_fp16_t is a storage-only type
    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
    const float m0 = pow(2.0, -8.0 / n_heads_log2_floor);
    const float m1 = pow(2.0, -4.0 / n_heads_log2_floor);

    for (int i = 0; i < ne0; i++) {
        for (int j = 0; j < ne1; j++) {
            for (int k = 0; k < ne2_ne3; k++) {
                ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
                ggml_fp16_t *  dst_data = (ggml_fp16_t *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);

                // TODO: k*nb2 or k*nb3

                float m_k;
                if (k < n_heads_log2_floor) {
                    m_k = pow(m0, k + 1);
                } else {
                    m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
                }
                //TODO: optimize
                // convert through fp32 for the arithmetic, then store back as fp16
                dst_data[0] = GGML_FP32_TO_FP16((j+1) * m_k + GGML_FP16_TO_FP32(src[0]));
            }
        }
    }
}

static void ggml_compute_forward_alibi(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    switch (src0->type) {
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
            } break;
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
        case GGML_TYPE_COUNT:
            {
                GGML_ASSERT(false);
            } break;
    }
}

// ggml_compute_forward_rope

static void ggml_compute_forward_rope_f32(
@@ -8691,6 +8879,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
            } break;
        case GGML_OP_ALIBI:
            {
                ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
            } break;
        case GGML_OP_CONV_1D_1S:
            {
                ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -8881,6 +9073,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_ALIBI:
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_SILU:
            {
                GGML_ASSERT(false); // TODO: not implemented
@@ -9387,6 +9583,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                {
                    node->n_tasks = 1;
                } break;
            case GGML_OP_ALIBI:
                {
                    node->n_tasks = 1; //TODO
                } break;
            case GGML_OP_CONV_1D_1S:
            case GGML_OP_CONV_1D_2S:
                {
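For reference, the bias the new kernels add follows the ALiBi slope rule: with n = 2^floor(log2(n_head)), head k gets slope m_k = m0^(k+1) for k < n (where m0 = 2^(-8/n)) and m_k = m1^(2(k-n)+1) otherwise (where m1 = 2^(-4/n)), and every element of row j has (j+1)*m_k added to it, exactly as in ggml_compute_forward_alibi_f32 above. The sketch below is not part of the diff; it only restates that rule in safe Rust (function names are illustrative) so it can be checked in isolation.

/// Slope for attention head `k` (0-based) out of `n_head` heads, following the
/// same rule as the C kernel above: the first 2^floor(log2(n_head)) heads use
/// powers of m0 = 2^(-8/n), the remaining heads use odd powers of m1 = 2^(-4/n).
fn alibi_slope(k: usize, n_head: usize) -> f32 {
    let n_heads_log2_floor = 1usize << (n_head as f32).log2().floor() as usize;
    let m0 = 2f32.powf(-8.0 / n_heads_log2_floor as f32);
    let m1 = 2f32.powf(-4.0 / n_heads_log2_floor as f32);
    if k < n_heads_log2_floor {
        m0.powi(k as i32 + 1)
    } else {
        m1.powi(2 * (k - n_heads_log2_floor) as i32 + 1)
    }
}

/// Adds the ALiBi bias to one row of attention scores, mirroring
/// `dst_data[0] = (j+1) * m_k + src[0]` in the kernel: every score in
/// (query) row `j` for head `k` gets `(j + 1) * m_k` added to it.
fn add_alibi_bias(scores_row: &mut [f32], j: usize, k: usize, n_head: usize) {
    let m_k = alibi_slope(k, n_head);
    for s in scores_row.iter_mut() {
        *s += (j as f32 + 1.0) * m_k;
    }
}

fn main() {
    // Example: 4 heads; the slopes should come out as 2^-2, 2^-4, 2^-6, 2^-8.
    for k in 0..4 {
        println!("head {k}: slope = {}", alibi_slope(k, 4));
    }
    let mut row = vec![0.0f32; 3];
    add_alibi_bias(&mut row, 1, 0, 4); // row j=1, head k=0
    println!("{row:?}"); // each score gets (1+1) * 0.25 = 0.5 added
}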
11 changes: 11 additions & 0 deletions ggml-sys/ggml/ggml.h
@@ -244,6 +244,7 @@ enum ggml_op {
GGML_OP_DIAG_MASK_INF,
GGML_OP_SOFT_MAX,
GGML_OP_ROPE,
GGML_OP_ALIBI,
GGML_OP_CONV_1D_1S,
GGML_OP_CONV_1D_2S,

@@ -599,6 +600,16 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode);

// alibi position embedding
// in-place, returns view(a)
struct ggml_tensor * ggml_alibi(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past,
        int                   n_head);

// padding = 1
// TODO: we don't support extra parameters for now
// that's why we are hard-coding the stride, padding, and dilation
11 changes: 11 additions & 0 deletions ggml-sys/src/lib.rs
@@ -899,6 +899,17 @@ extern "C" {
filename: *const ::std::os::raw::c_char,
);
}


extern "C" {
pub fn ggml_alibi(
ctx: *mut ggml_context,
a: *mut ggml_tensor,
n_past: ::std::os::raw::c_int,
n_head: ::std::os::raw::c_int,
) -> *mut ggml_tensor;
}

pub const ggml_opt_type_GGML_OPT_ADAM: ggml_opt_type = 0;
pub const ggml_opt_type_GGML_OPT_LBFGS: ggml_opt_type = 1;
pub type ggml_opt_type = ::std::os::raw::c_uint;
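As a quick illustration of the raw binding above (hypothetical, not part of this change), a downstream caller passes the context, the scaled-scores tensor, and the two ints straight through; the safe wrapper added to ggml/src/lib.rs below is the intended entry point.

use std::os::raw::c_int;

/// Hypothetical direct use of the raw binding from a downstream crate.
/// `ctx` and `kq_scaled` must be valid pointers obtained from ggml elsewhere.
unsafe fn apply_alibi_raw(
    ctx: *mut ggml_sys::ggml_context,
    kq_scaled: *mut ggml_sys::ggml_tensor,
    n_past: c_int,
    n_head: c_int,
) -> *mut ggml_sys::ggml_tensor {
    // ggml_alibi records GGML_OP_ALIBI and returns an in-place view of `kq_scaled`.
    ggml_sys::ggml_alibi(ctx, kq_scaled, n_past, n_head)
}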
42 changes: 42 additions & 0 deletions ggml/src/lib.rs
@@ -203,6 +203,28 @@ impl Context {
self.new_tensor_raw(tensor)
}

    /// Creates a 2D view over `a`.
    pub fn op_view_2d(
        &self,
        a: &Tensor,
        ne0: usize,
        ne1: usize,
        nb1: usize,
        offset: usize,
    ) -> Tensor {
        let tensor = unsafe {
            ggml_sys::ggml_view_2d(
                self.ptr.as_ptr(),
                a.ptr.as_ptr(),
                usize_to_i64(ne0),
                usize_to_i64(ne1),
                nb1,
                offset,
            )
        };
        self.new_tensor_raw(tensor)
    }

    /// Copies `a` to `b` and returns `b`.
    pub fn op_cpy(&self, a: &Tensor, b: &Tensor) -> Tensor {
        let tensor =
@@ -271,6 +293,26 @@ impl Context {
    pub fn used_mem(&self) -> usize {
        unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) }
    }

    /// Applies the ALiBi (Attention with Linear Biases) position bias to `a`,
    /// in place, and returns a view over it.
    pub fn op_alibi(&self, a: &Tensor, n_past: usize, n_head: usize) -> Tensor {
        let tensor = unsafe {
            ggml_sys::ggml_alibi(
                self.ptr.as_ptr(),
                a.ptr.as_ptr(),
                usize_to_i32(n_past),
                usize_to_i32(n_head),
            )
        };

        self.new_tensor_raw(tensor)
    }

    /// Applies the GELU (Gaussian Error Linear Units) activation to `a`.
    pub fn op_gelu(&self, a: &Tensor) -> Tensor {
        let tensor = unsafe { ggml_sys::ggml_gelu(self.ptr.as_ptr(), a.ptr.as_ptr()) };
        self.new_tensor_raw(tensor)
    }
}

impl Drop for Context {
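To show where the new wrapper is meant to slot in, here is a rough, hypothetical sketch of a BLOOM-style attention step: ALiBi takes the place RoPE has in the LLaMA path and is applied to the scaled K·Q scores before masking and softmax. The `op_mul_mat`, `op_scale`, `op_diag_mask_inf` and `op_soft_max` methods are assumed to already exist on `Context` as in the LLaMA path (they are not part of this diff), and the tensor arguments are placeholders for values built elsewhere in the model.

// Not part of this diff: a sketch of scaled attention scores with ALiBi.
fn scaled_scores(
    ctx: &ggml::Context,
    k: &ggml::Tensor,
    q: &ggml::Tensor,
    scale: &ggml::Tensor,
    n_past: usize,
    n_head: usize,
) -> ggml::Tensor {
    let kq = ctx.op_mul_mat(k, q);
    let kq_scaled = ctx.op_scale(&kq, scale);
    // in-place: returns a view over kq_scaled with the ALiBi bias op recorded
    let kq_alibi = ctx.op_alibi(&kq_scaled, n_past, n_head);
    let kq_masked = ctx.op_diag_mask_inf(&kq_alibi, n_past);
    ctx.op_soft_max(&kq_masked)
}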
9 changes: 6 additions & 3 deletions llama-cli/src/cli_args.rs
@@ -1,8 +1,7 @@
use std::path::PathBuf;

use clap::Parser;
use llama_rs::TokenBias;
use llama_rs::common::token::TokenBias;
use once_cell::sync::Lazy;
use std::path::PathBuf;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@@ -29,6 +28,10 @@ pub struct Args {
    #[arg(long, short = 'R', default_value_t = false)]
    pub repl: bool,

    /// Run in BLOOM mode
    #[arg(long, short = 'B', default_value_t = false)]
    pub bloom: bool,

    /// Sets the number of threads to use
    #[arg(long, short = 't')]
    pub num_threads: Option<usize>,
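Downstream, the new flag would presumably just select the BLOOM code path instead of the LLaMA one. A minimal, hypothetical sketch (none of these names appear in this diff; `Args` is reduced to the one field used here and the two run functions are stubs):

struct Args {
    bloom: bool,
}

fn run_bloom(_args: &Args) { /* BLOOM path: ALiBi position bias, no RoPE */ }
fn run_llama(_args: &Args) { /* default LLaMA path */ }

fn main() {
    let args = Args { bloom: true };
    if args.bloom {
        run_bloom(&args);
    } else {
        run_llama(&args);
    }
}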