Fixed save_imatrix to match old behaviour for MoE #7099

Merged
merged 6 commits on May 8, 2024

Changes from 2 commits
21 changes: 16 additions & 5 deletions examples/imatrix/imatrix.cpp
@@ -19,6 +19,7 @@

 struct Stats {
     std::vector<float> values;
+    std::vector<int> counts;
     int ncall = 0;
 };

@@ -120,13 +121,14 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *

         auto & e = m_stats[wname];

-        ++e.ncall;
-        // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
-        // using the following line, we can correct for that if needed by replacing the line above with:
-        //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
+        // We select top-k experts, so the number of calls for the expert tensors will be k times larger.
+        // NOTE: this will trigger the "if (e.ncall > m_last_call)" save conditional on the first active expert.
+        // The commented-out "if (idx == t->src[0]->ne[0] - 1) ++e.ncall;" doesn't work.
+        if (((int32_t *) t->op_params)[0] == 0) ++e.ncall;
slaren (Member):

In the past only one expert was evaluated per mul_mat_id, and op_params was used to store the expert being evaluated, but that's no longer the case. op_params is not used anymore in mul_mat_id, so this condition doesn't really do anything, op_params will always be zero so it's always true.
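
To spell that out (an illustrative reconstruction based on the description above, not a new change): before #6505, op_params carried the index of the expert being evaluated, so a guard on it could count once per token batch; now the field is never written and stays zero, so the same guard is vacuously true:

// pre-#6505 (hypothetical reconstruction): one mul_mat_id per expert,
// with op_params[0] holding the expert index, so this counted once
// per token batch:
//if (((int32_t *) t->op_params)[0] == 0) ++e.ncall;
// post-#6505: op_params is never written and stays zero, so the same
// guard always passes and ncall is incremented on every call:
if (((int32_t *) t->op_params)[0] == 0) ++e.ncall;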

jukofyork (Contributor Author), May 6, 2024:

Ah, I didn't test it but the old if (idx == 0) did work.

What test can be done to check that the callback is the last one for the MoE?

slaren (Member):

So in effect, this change does nothing; e.ncall is increased unconditionally, as it was before. I think that increasing ncall unconditionally here is the correct thing to do, since the count is later corrected in save_imatrix with your change.

slaren (Member), May 6, 2024:

> What test can be done to check that the callback is the last one for the MoE?

Currently there is only one call to mul_mat_id regardless of the number of experts being used. This was changed in #6505.
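
A sketch of what that implies for the collector (assuming the post-#6505 behaviour described above; accumulate, expert_used, and the sizes are hypothetical names): with a single callback covering every selected expert, there is no "last expert" call to detect, so the fix is to count samples per element instead:

// Sketch only: one callback per mul_mat_id sees all top-k experts, so we
// track how many samples each element actually received; save_imatrix can
// then divide by counts[i] rather than guessing from ncall.
#include <vector>

struct Stats {
    std::vector<float> values;
    std::vector<int>   counts;
    int ncall = 0;
};

void accumulate(Stats & e, const float * x, int ne0, int n_as, const bool * expert_used) {
    ++e.ncall;                                     // once per mul_mat_id call
    e.values.resize((size_t)ne0 * n_as, 0.0f);
    e.counts.resize((size_t)ne0 * n_as, 0);
    for (int ex = 0; ex < n_as; ++ex) {
        if (!expert_used[ex]) continue;            // inactive experts receive no samples
        for (int j = 0; j < ne0; ++j) {
            // (real code uses the rows routed to each expert; reusing x
            // here keeps the sketch short)
            e.values[(size_t)ex*ne0 + j] += x[j]*x[j];
            e.counts[(size_t)ex*ne0 + j]++;        // per-element sample count
        }
    }
}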

jukofyork (Contributor Author):

> So in effect, this change does nothing; e.ncall is increased unconditionally, as it was before. I think that increasing ncall unconditionally here is the correct thing to do, since the count is later corrected in save_imatrix with your change.

Yeah, it will still work:

(p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall)

Basically, we divide down to get the actual mean, based on how many values were actually accumulated into an element values[i], and then multiply back up so the stored value can be used both for the weighted combination with other imatrix files and by quantize, which gets back the original mean when it divides by the ncall stored in the file.

So having the weighted combination scaled up by num-top-k won't affect either of these.
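
A minimal worked example of that round trip (hypothetical numbers, not from this run):

// An expert element that received 25 samples over ncall = 50 chunks
// (with top-k routing, counts[i] != ncall in general):
float sum    = 100.0f;                 // e.values[i], accumulated x*x
int   count  = 25;                     // e.counts[i], samples actually added
int   ncall  = 50;                     // chunks seen
float stored = (sum / count) * ncall;  // written to the file: 4.0f * 50 = 200.0f
// quantize divides by the ncall stored in the file and recovers the mean:
float mean   = stored / ncall;         // 4.0f == sum / count

And since each file stores mean * ncall, two imatrix files for the same tensor can still be combined by summing the stored values and ncalls: (mean_a*ncall_a + mean_b*ncall_b) / (ncall_a + ncall_b) is the correctly weighted mean.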

But this will still cause it to hit the every-10-chunks save threshold too often (here after every ~5 real chunks, since with top-2 routing the expert tensors' ncall grows twice as fast):

[1]3.4990,[2]2.7563,[3]2.8327,[4]2.8365,
save_imatrix: stored collected data after 10 chunks in wizard-lm-2:8x22b-f16.imatrix
[5]3.2415,[6]3.1667,[7]2.9011,[8]3.2475,[9]3.2100,
save_imatrix: stored collected data after 20 chunks in wizard-lm-2:8x22b-f16.imatrix
[10]3.5357,[11]3.7258,[12]3.6469,[13]3.9192,[14]4.2641,
save_imatrix: stored collected data after 30 chunks in wizard-lm-2:8x22b-f16.imatrix
[15]4.4561,[16]4.7251,[17]4.8591,[18]5.0424,[19]5.1595,

vs

[1]6.8864,[2]5.5590,[3]4.6385,[4]5.2093,[5]5.6050,[6]4.6732,[7]4.7876,[8]5.3775,[9]5.6677,
save_imatrix: stored collected data after 10 chunks in dbrx:16x12b-instruct-f16.imatrix
[10]5.4960,[11]5.8453,[12]6.4653,[13]6.7705,[14]7.1977,[15]7.3001,[16]7.4528,[17]7.6426,[18]7.2825,[19]7.3690,
save_imatrix: stored collected data after 20 chunks in dbrx:16x12b-instruct-f16.imatrix
[20]7.4835,[21]7.8310,[22]7.9035,[23]7.7323,[24]7.6813,[25]7.4121,[26]7.3496,[27]7.3934,[28]7.8041,[29]7.9666,
save_imatrix: stored collected data after 30 chunks in dbrx:16x12b-instruct-f16.imatrix
[30]8.1926,[31]8.3989,[32]8.6105,[33]8.7318,[34]8.8261,[35]8.8406,[36]8.8695,[37]9.0027,[38]9.0287,[39]8.9052,
save_imatrix: stored collected data after 40 chunks in dbrx:16x12b-instruct-f16.imatrix

and with the debug output enabled in quantize, it will print an ncall that is num-top-k times larger for the expert tensors:

load_imatrix: loaded data (size = 172032, ncall =    364) for 'blk.38.ffn_down_exps.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.38.ffn_gate_inp.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.38.attn_output.weight'
load_imatrix: loaded data (size = 172032, ncall =    364) for 'blk.37.ffn_down_exps.weight'
load_imatrix: loaded data (size =  98304, ncall =    364) for 'blk.37.ffn_gate_exps.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.37.ffn_gate_inp.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.37.attn_output.weight'

vs

load_imatrix: loaded data (size = 172032, ncall =     91) for 'blk.38.ffn_down_exps.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.38.ffn_gate_inp.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.38.attn_output.weight'
load_imatrix: loaded data (size = 172032, ncall =     91) for 'blk.37.ffn_down_exps.weight'
load_imatrix: loaded data (size =  98304, ncall =     91) for 'blk.37.ffn_gate_exps.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.37.ffn_gate_inp.weight'
load_imatrix: loaded data (size =   6144, ncall =     91) for 'blk.37.attn_output.weight'
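
(Note the ratio: 364 = 4 × 91, which matches dbrx routing top-4 of its 16 experts, so before the fix ncall for the expert tensors is inflated by exactly num-top-k.)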

slaren (Member):

So if I understand correctly, you based these changes on a build from before #6505 was merged, and the results showing a higher ncall for the MoE tensors are from a build without #6505, correct?

slaren (Member), May 6, 2024:

I tested this PR with ncall increased unconditionally with mixtral and it seems to produce the expected results:

compute_imatrix: computing over 50 chunks with batch_size 512
compute_imatrix: 1.46 seconds per pass - ETA 1.22 minutes
[1]3.3282,[2]5.5064,[3]5.7696,[4]6.0597,[5]6.6383,[6]6.4067,[7]6.0626,[8]6.1729,[9]6.3318,
save_imatrix: stored collected data after 10 chunks in imatrix.dat
[10]5.8754,[11]5.6783,[12]5.8278,[13]5.8804,[14]5.7391,[15]5.9534,[16]5.9483,[17]5.9110,[18]6.0203,[19]5.9764,
save_imatrix: stored collected data after 20 chunks in imatrix.dat
[20]5.9101,[21]5.8586,[22]5.8696,[23]5.9431,[24]5.9631,[25]6.0114,[26]6.0204,[27]5.9588,[28]5.7325,[29]5.7142,
save_imatrix: stored collected data after 30 chunks in imatrix.dat
[30]5.6387,[31]5.5779,[32]5.4650,[33]5.4179,[34]5.3390,[35]5.2645,[36]5.2147,[37]5.1724,[38]5.1585,[39]5.1434,
save_imatrix: stored collected data after 40 chunks in imatrix.dat
[40]5.1864,[41]5.1752,[42]5.1467,[43]5.0827,[44]5.0719,[45]5.0194,[46]5.0461,[47]5.0968,[48]5.1533,[49]5.1977,
save_imatrix: stored collected data after 50 chunks in imatrix.dat
[50]5.1661,
Final estimate: PPL = 5.1661 +/- 0.10175

save_imatrix: stored collected data after 50 chunks in imatrix.dat
load_imatrix: loaded data (size = 114688, ncall =     50) for 'blk.31.ffn_down_exps.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.31.attn_k.weight'
load_imatrix: loaded data (size = 114688, ncall =     50) for 'blk.30.ffn_down_exps.weight'
load_imatrix: loaded data (size =  32768, ncall =     50) for 'blk.30.ffn_up_exps.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.30.ffn_gate_inp.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.30.attn_v.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.30.attn_k.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.30.attn_q.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.29.ffn_gate_inp.weight'
load_imatrix: loaded data (size =  32768, ncall =     50) for 'blk.31.ffn_up_exps.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.29.attn_v.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.29.attn_k.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.29.attn_q.weight'
load_imatrix: loaded data (size = 114688, ncall =     50) for 'blk.28.ffn_down_exps.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.28.attn_v.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.28.attn_q.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.31.attn_q.weight'
load_imatrix: loaded data (size = 114688, ncall =     50) for 'blk.27.ffn_down_exps.weight'
load_imatrix: loaded data (size =  32768, ncall =     50) for 'blk.27.ffn_gate_exps.weight'
load_imatrix: loaded data (size =  32768, ncall =     50) for 'blk.27.ffn_up_exps.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.27.ffn_gate_inp.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.27.attn_v.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.27.attn_k.weight'
load_imatrix: loaded data (size =   4096, ncall =     50) for 'blk.27.attn_output.weight'

jukofyork (Contributor Author):

> So if I understand correctly, you based these changes on a build from before #6505 was merged, and the results showing a higher ncall for the MoE tensors are from a build without #6505, correct?

Yeah, I used the build from right after the dbrx PR was pushed, as I originally went down this path after having lots of trouble quantizing it.

jukofyork (Contributor Author):

> I tested this PR with ncall increased unconditionally with mixtral and it seems to produce the expected results: […]

Yeah, it looks like that can just be left unconditional then.

It's probably worth trying to re-quantize mixtral with and without these fixes too, just in case something else has changed since then.

slaren (Member):

> It's probably worth trying to re-quantize mixtral with and without these fixes too, just in case something else has changed since then.

I will give it a try with a low number of chunks, but I don't have enough VRAM to create an imatrix for mixtral with the full wiki.train.raw in a reasonable amount of time.


         if (e.values.empty()) {
             e.values.resize(src1->ne[0]*n_as, 0);
+            e.counts.resize(src1->ne[0]*n_as, 0);
         }
         else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
             fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
@@ -153,6 +155,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *

                 for (int j = 0; j < (int)src1->ne[0]; ++j) {
                     e.values[e_start + j] += x[j]*x[j];
+                    e.counts[e_start + j]++;
                 }
             }
         }
@@ -170,6 +173,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         auto& e = m_stats[wname];
         if (e.values.empty()) {
             e.values.resize(src1->ne[0], 0);
+            e.counts.resize(src1->ne[0], 0);
         }
         else if (e.values.size() != (size_t)src1->ne[0]) {
             fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
@@ -183,6 +187,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         const float * x = data + row * src1->ne[0];
         for (int j = 0; j < (int)src1->ne[0]; ++j) {
             e.values[j] += x[j]*x[j];
+            e.counts[j]++;
         }
     }
     if (e.ncall > m_last_call) {
@@ -222,7 +227,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co
         out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
         int nval = p.second.values.size();
         out.write((const char *) &nval, sizeof(nval));
-        if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
+        if (nval > 0) {
+            std::vector<float> tmp(nval);
+            for (int i = 0; i < nval; i++) {
+                tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
+            }
+            out.write((const char*)tmp.data(), nval*sizeof(float));
+        }
     }

     // Write the number of call the matrix was computed with