From d538837d2c5ffe879a8349e676f8ac9e3a514cc6 Mon Sep 17 00:00:00 2001 From: David Ellis Date: Wed, 11 Dec 2024 16:19:02 -0600 Subject: [PATCH] Cache the module and compute_pipeline objects after first run for potential reuse (#1005) * Cache the module and compute_pipeline objects after first run for potential reuse * Remove use of 'this' and replace with 'gg' * Make clippy happy --- alan_std.js | 28 ++++++++++++++++++---------- alan_std/src/lib.rs | 40 +++++++++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/alan_std.js b/alan_std.js index 4d629da1e..5e9b10281 100644 --- a/alan_std.js +++ b/alan_std.js @@ -687,21 +687,29 @@ export class GPGPU { this.entrypoint = entrypoint ?? "main"; this.buffers = buffers; this.workgroupSizes = workgroupSizes; + this.module = undefined; + this.computePipeline = undefined; } } export async function gpuRun(gg) { let g = await gpu(); - let module = g.device.createShaderModule({ - code: gg.source, - }); - let computePipeline = g.device.createComputePipeline({ - layout: "auto", - compute: { - entryPoint: gg.entrypoint, - module, - }, - }); + if (!gg.module) { + gg.module = g.device.createShaderModule({ + code: gg.source, + }); + } + let module = gg.module; + if (!gg.computePipeline) { + gg.computePipeline = g.device.createComputePipeline({ + layout: "auto", + compute: { + entryPoint: gg.entryPoint, + module, + }, + }); + } + let computePipeline = gg.computePipeline; let encoder = g.device.createCommandEncoder(); let cpass = encoder.beginComputePass(); cpass.setPipeline(computePipeline); diff --git a/alan_std/src/lib.rs b/alan_std/src/lib.rs index c7e64c645..d8ca011f0 100644 --- a/alan_std/src/lib.rs +++ b/alan_std/src/lib.rs @@ -986,6 +986,8 @@ pub struct GPGPU { pub entrypoint: String, pub buffers: Vec>, pub workgroup_sizes: [i64; 3], + pub module: Option, + pub compute_pipeline: Option, } impl GPGPU { @@ -995,26 +997,34 @@ impl GPGPU { entrypoint: "main".to_string(), buffers, workgroup_sizes, + module: None, + compute_pipeline: None, } } } -pub fn gpu_run(gg: &GPGPU) { +pub fn gpu_run(gg: &mut GPGPU) { let g = gpu(); - let module = g.device.create_shader_module(wgpu::ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(&gg.source)), - }); - let compute_pipeline = g - .device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + if gg.module.is_none() { + gg.module = Some(g.device.create_shader_module(wgpu::ShaderModuleDescriptor { label: None, - layout: None, - module: &module, - entry_point: Some(&gg.entrypoint), - compilation_options: wgpu::PipelineCompilationOptions::default(), - cache: None, // TODO: Might be worthwhile - }); + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(&gg.source)), + })); + } + let module = gg.module.as_ref().unwrap(); + if gg.compute_pipeline.is_none() { + gg.compute_pipeline = Some(g.device.create_compute_pipeline( + &wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module, + entry_point: Some(&gg.entrypoint), + compilation_options: wgpu::PipelineCompilationOptions::default(), + cache: None, + }, + )); + } + let compute_pipeline = gg.compute_pipeline.as_ref().unwrap(); let mut bind_groups = Vec::new(); let mut encoder = g .device @@ -1024,7 +1034,7 @@ pub fn gpu_run(gg: &GPGPU) { label: None, timestamp_writes: None, }); - cpass.set_pipeline(&compute_pipeline); + cpass.set_pipeline(compute_pipeline); for i in 0..gg.buffers.len() { let bind_group_layout = compute_pipeline.get_bind_group_layout(i.try_into().unwrap()); let bind_group_buffers = &gg.buffers[i];