Skip to content

Commit

Permalink
llama : assert input batch with pooling enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Mar 4, 2024
1 parent c23c554 commit 1af2d06
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8097,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
float * data = (float *) lctx.inp_mean->data;

float * data = (float *) lctx.inp_mean->data;
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));

std::vector<uint64_t> sum(n_tokens, 0);
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];

GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");

sum[seq_id] += 1;
}

Expand All @@ -8127,10 +8130,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));

uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));

for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];

GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");

if (pos == 0) {
data[seq_id] = i;
}
Expand Down

0 comments on commit 1af2d06

Please sign in to comment.