From 8d205b68e313aaf37afb1436d74b86cffd7c262e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 16 Jul 2023 18:39:08 +0200 Subject: [PATCH 1/5] Adding a parallelism bench. --- gemm/Cargo.toml | 1 + gemm/benches/bench.rs | 116 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/gemm/Cargo.toml b/gemm/Cargo.toml index 00eca3f..de19b0a 100644 --- a/gemm/Cargo.toml +++ b/gemm/Cargo.toml @@ -37,6 +37,7 @@ rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0" rayon = "1.7" +num_cpus = "1.16.0" [[bench]] name = "bench" diff --git a/gemm/benches/bench.rs b/gemm/benches/bench.rs index 04ceeaf..95d0112 100644 --- a/gemm/benches/bench.rs +++ b/gemm/benches/bench.rs @@ -256,6 +256,112 @@ pub fn criterion_benchmark(c: &mut Criterion) { } } +pub fn criterion_benchmark_parallelism(c: &mut Criterion) { + let mnks = vec![(6, 768 * 3, 768)]; + // let mut push = |m, n, k| { + // mnks.push((m, n, k)); + // }; + // push(64, 64, 64); + // push(8192, 8192, 8192); + // push(4096, 4096, 4096); + // push(1024, 1024, 1024); + // push(896, 128, 128); + // push(512, 256, 256); + // push(448, 448, 128); + // push(256, 256, 256); + // push(256, 32, 256); + // push(52, 52, 256); + // push(48, 48, 256); + // push(63, 1, 10); + // push(63, 2, 10); + // push(63, 3, 10); + // push(63, 4, 10); + + // push(1024, 1, 1024); + // push(1024, 2, 1024); + // push(1024, 3, 1024); + // push(1024, 4, 1024); + // + let n_cpus = num_cpus::get(); + + for (m, n, k) in mnks.iter().copied() { + let a_vec = vec![0.0_f32; m * k]; + let b_vec = vec![0.0_f32; k * n]; + let mut c_vec = vec![0.0_f32; m * n]; + + for (dst_label, dst_cs, dst_rs) in [("n", m, 1), ("t", 1, n)] { + for (lhs_label, lhs_cs, lhs_rs) in [("n", m, 1), ("t", 1, k)] { + for (rhs_label, rhs_cs, rhs_rs) in [("n", k, 1), ("t", 1, n)] { + c.bench_function( + &format!( + "parallelism-{}-f32-{}{}{}-gemm-{}×{}×{}", + n_cpus, dst_label, lhs_label, rhs_label, m, n, k + ), + |b| { + b.iter(|| unsafe { + gemm( + m, + n, + k, + c_vec.as_mut_ptr(), + dst_cs as isize, + dst_rs as isize, + true, + a_vec.as_ptr(), + lhs_cs as isize, + lhs_rs as isize, + b_vec.as_ptr(), + rhs_cs as isize, + rhs_rs as isize, + 0.0_f32, + 0.0_f32, + false, + false, + false, + gemm::Parallelism::Rayon(n_cpus), + ) + }) + }, + ); + c.bench_function( + &format!( + "parallelism-none-f32-{}{}{}-gemm-{}×{}×{}", + dst_label, lhs_label, rhs_label, m, n, k + ), + |b| { + b.iter(|| unsafe { + gemm( + m, + n, + k, + c_vec.as_mut_ptr(), + dst_cs as isize, + dst_rs as isize, + true, + a_vec.as_ptr(), + lhs_cs as isize, + lhs_rs as isize, + b_vec.as_ptr(), + rhs_cs as isize, + rhs_rs as isize, + 0.0_f32, + 0.0_f32, + false, + false, + false, + gemm::Parallelism::None, + ) + }) + }, + ); + } + } + } + + } + +} + criterion_group!( name = benches; config = Criterion::default() @@ -264,4 +370,12 @@ criterion_group!( .sample_size(10); targets = criterion_benchmark ); -criterion_main!(benches); +criterion_group!( + name = benches_parallelism; + config = Criterion::default() + .warm_up_time(Duration::from_secs(1)) + .measurement_time(Duration::from_secs(2)) + .sample_size(10); + targets = criterion_benchmark_parallelism +); +criterion_main!(benches, benches_parallelism); From 0da21288e47f1db87f97ed0b3e273c0103b8656a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 20 Jul 2023 07:42:13 +0000 Subject: [PATCH 2/5] Fixing large multi-threading (-40% improvement for `parallelism` benchmark) --- gemm-common/src/gemm.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/gemm-common/src/gemm.rs b/gemm-common/src/gemm.rs index 779482a..b61151d 100644 --- a/gemm-common/src/gemm.rs +++ b/gemm-common/src/gemm.rs @@ -97,7 +97,7 @@ impl Conj for c64 { } } -pub const DEFAULT_THREADING_THRESHOLD: usize = 48 * 48 * 256; +pub const DEFAULT_THREADING_THRESHOLD: usize = 64 * 64 * 256; pub const DEFAULT_RHS_PACKING_THRESHOLD: usize = 128; pub const DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD: usize = 8; pub const DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD: usize = 16; @@ -345,17 +345,17 @@ pub unsafe fn gemm_basic_generic< let n_threads = match parallelism { Parallelism::None => 1, - Parallelism::Rayon(n_threads) => { + Parallelism::Rayon(max_threads) => { let threading_threshold = get_threading_threshold(); - if m * n_chunk * k_chunk <= threading_threshold { - 1 + + let max_threads = if n_threads == 0 { + rayon::current_num_threads() } else { - if n_threads == 0 { - rayon::current_num_threads() - } else { - n_threads - } - } + max_threads + }; + let total_work = m * n_chunk * k_chunk; + let n_threads = std::cmp::max(1, std::cmp::min(max_threads, (total_work - threading_threshold + 1) / threading_threshold)); + n_threads } }; From c1a5b31c8b5b673639001b1e7d67be5d5c66e59e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 20 Jul 2023 07:44:10 +0000 Subject: [PATCH 3/5] Fix. --- gemm-common/src/gemm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemm-common/src/gemm.rs b/gemm-common/src/gemm.rs index b61151d..8ad8d5b 100644 --- a/gemm-common/src/gemm.rs +++ b/gemm-common/src/gemm.rs @@ -348,7 +348,7 @@ pub unsafe fn gemm_basic_generic< Parallelism::Rayon(max_threads) => { let threading_threshold = get_threading_threshold(); - let max_threads = if n_threads == 0 { + let max_threads = if max_threads == 0 { rayon::current_num_threads() } else { max_threads From 76ea6bd6076671b22ff481a6052246e496d36e76 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Jul 2023 13:28:24 +0200 Subject: [PATCH 4/5] Fix tests. --- gemm-common/src/gemm.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gemm-common/src/gemm.rs b/gemm-common/src/gemm.rs index 8ad8d5b..3ec6f16 100644 --- a/gemm-common/src/gemm.rs +++ b/gemm-common/src/gemm.rs @@ -97,7 +97,7 @@ impl Conj for c64 { } } -pub const DEFAULT_THREADING_THRESHOLD: usize = 64 * 64 * 256; +pub const DEFAULT_THREADING_THRESHOLD: usize = 48 * 48 * 256; pub const DEFAULT_RHS_PACKING_THRESHOLD: usize = 128; pub const DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD: usize = 8; pub const DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD: usize = 16; @@ -354,7 +354,10 @@ pub unsafe fn gemm_basic_generic< max_threads }; let total_work = m * n_chunk * k_chunk; - let n_threads = std::cmp::max(1, std::cmp::min(max_threads, (total_work - threading_threshold + 1) / threading_threshold)); + let n_threads = if total_work > threading_threshold{ + std::cmp::max(1, std::cmp::min(max_threads, (total_work - threading_threshold + 1) / threading_threshold)) + }else{1} + ; n_threads } }; From b11ea6f7ebfe7f6c817256815e173754b6898a61 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Jul 2023 13:30:08 +0200 Subject: [PATCH 5/5] Format. --- gemm-common/src/gemm.rs | 15 +++++++++++---- gemm/benches/bench.rs | 2 -- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/gemm-common/src/gemm.rs b/gemm-common/src/gemm.rs index 3ec6f16..66c32be 100644 --- a/gemm-common/src/gemm.rs +++ b/gemm-common/src/gemm.rs @@ -354,10 +354,17 @@ pub unsafe fn gemm_basic_generic< max_threads }; let total_work = m * n_chunk * k_chunk; - let n_threads = if total_work > threading_threshold{ - std::cmp::max(1, std::cmp::min(max_threads, (total_work - threading_threshold + 1) / threading_threshold)) - }else{1} - ; + let n_threads = if total_work > threading_threshold { + std::cmp::max( + 1, + std::cmp::min( + max_threads, + (total_work - threading_threshold + 1) / threading_threshold, + ), + ) + } else { + 1 + }; n_threads } }; diff --git a/gemm/benches/bench.rs b/gemm/benches/bench.rs index 95d0112..415cf26 100644 --- a/gemm/benches/bench.rs +++ b/gemm/benches/bench.rs @@ -357,9 +357,7 @@ pub fn criterion_benchmark_parallelism(c: &mut Criterion) { } } } - } - } criterion_group!(