Skip to content

Commit

Permalink
Cache the module and compute_pipeline objects after first run for pot…
Browse files Browse the repository at this point in the history
…ential reuse (#1005)

* Cache the module and compute_pipeline objects after first run for potential reuse

* Remove use of 'this' and replace with 'gg'

* Make clippy happy
  • Loading branch information
dfellis authored Dec 11, 2024
1 parent 1af94aa commit d538837
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
28 changes: 18 additions & 10 deletions alan_std.js
Original file line number Diff line number Diff line change
Expand Up @@ -687,21 +687,29 @@ export class GPGPU {
this.entrypoint = entrypoint ?? "main";
this.buffers = buffers;
this.workgroupSizes = workgroupSizes;
this.module = undefined;
this.computePipeline = undefined;
}
}

export async function gpuRun(gg) {
let g = await gpu();
let module = g.device.createShaderModule({
code: gg.source,
});
let computePipeline = g.device.createComputePipeline({
layout: "auto",
compute: {
entryPoint: gg.entrypoint,
module,
},
});
if (!gg.module) {
gg.module = g.device.createShaderModule({
code: gg.source,
});
}
let module = gg.module;
if (!gg.computePipeline) {
gg.computePipeline = g.device.createComputePipeline({
layout: "auto",
compute: {
entryPoint: gg.entryPoint,
module,
},
});
}
let computePipeline = gg.computePipeline;
let encoder = g.device.createCommandEncoder();
let cpass = encoder.beginComputePass();
cpass.setPipeline(computePipeline);
Expand Down
40 changes: 25 additions & 15 deletions alan_std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,8 @@ pub struct GPGPU {
pub entrypoint: String,
pub buffers: Vec<Vec<GBuffer>>,
pub workgroup_sizes: [i64; 3],
pub module: Option<wgpu::ShaderModule>,
pub compute_pipeline: Option<wgpu::ComputePipeline>,
}

impl GPGPU {
Expand All @@ -995,26 +997,34 @@ impl GPGPU {
entrypoint: "main".to_string(),
buffers,
workgroup_sizes,
module: None,
compute_pipeline: None,
}
}
}

pub fn gpu_run(gg: &GPGPU) {
pub fn gpu_run(gg: &mut GPGPU) {
let g = gpu();
let module = g.device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(&gg.source)),
});
let compute_pipeline = g
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
if gg.module.is_none() {
gg.module = Some(g.device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
layout: None,
module: &module,
entry_point: Some(&gg.entrypoint),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None, // TODO: Might be worthwhile
});
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(&gg.source)),
}));
}
let module = gg.module.as_ref().unwrap();
if gg.compute_pipeline.is_none() {
gg.compute_pipeline = Some(g.device.create_compute_pipeline(
&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module,
entry_point: Some(&gg.entrypoint),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
},
));
}
let compute_pipeline = gg.compute_pipeline.as_ref().unwrap();
let mut bind_groups = Vec::new();
let mut encoder = g
.device
Expand All @@ -1024,7 +1034,7 @@ pub fn gpu_run(gg: &GPGPU) {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&compute_pipeline);
cpass.set_pipeline(compute_pipeline);
for i in 0..gg.buffers.len() {
let bind_group_layout = compute_pipeline.get_bind_group_layout(i.try_into().unwrap());
let bind_group_buffers = &gg.buffers[i];
Expand Down

0 comments on commit d538837

Please sign in to comment.