Skip to content

Commit

Permalink
Try a debug print approach
Browse files Browse the repository at this point in the history
  • Loading branch information
syl20bnr committed Jan 5, 2024
1 parent 1b54109 commit 71086c2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ jobs:
# run: cargo test -p burn-wgpu --color=always -- --color=always --test-threads 1
# run: cargo test tests::module -p burn-wgpu --color=always -- --color=always --test-threads 1
# run: cargo test tests::matmul -p burn-wgpu --color=always -- --color=always --test-threads 1
run: cargo test can_run_kernel -p burn-wgpu --color=always -- --color=always --test-threads 1
run: cargo test can_run_kernel -p burn-wgpu --color=always -- --color=always --test-threads 1 --nocapture

# - name: Run cargo clippy for stable version
# if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
Expand Down
9 changes: 7 additions & 2 deletions burn-compute/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ where
where
Init: Fn() -> ComputeClient<Server, Channel>,
{
println!("dbg 4_1");
let mut clients = self.clients.lock();

println!("dbg 4_2");
if clients.is_none() {
Self::register_inner(device, init(), &mut clients);
}

println!("dbg 4_3");
match clients.deref_mut() {
Some(clients) => match clients.get(device) {
Some(client) => client.clone(),
Expand Down Expand Up @@ -68,16 +71,18 @@ where
client: ComputeClient<Server, Channel>,
clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
) {
println!("dbg 8_1");
if clients.is_none() {
*clients = Some(HashMap::new());
}

println!("dbg 8_2");
if let Some(clients) = clients {
if clients.contains_key(device) {
panic!("Client already created for device {:?}", device);
}

println!("dbg 8_3");
clients.insert(device.clone(), client);
println!("dbg 8_4");
}
}
}
18 changes: 17 additions & 1 deletion burn-wgpu/src/compute/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ static COMPUTE: Compute<WgpuDevice, WgpuServer<MemoryManagement>, Channel> = Com

/// Get the [compute client](ComputeClient) for the given [device](WgpuDevice).
pub fn compute_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Server, Channel> {
println!("dbg 3_1");
let device = Arc::new(device);

println!("dbg 3_2");
COMPUTE.client(&device, move || {
println!("dbg 5_1");
pollster::block_on(create_client::<G>(&device))
})
}
Expand All @@ -42,14 +44,17 @@ pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice) {
}

async fn create_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Server, Channel> {
println!("dbg 6_1");
let (device_wgpu, queue, info) = select_device::<G>(device).await;

println!("dbg 6_2");
log::info!(
"Created wgpu compute server on device {:?} => {:?}",
device,
info
);

println!("dbg 6_3");
// TODO: Support a way to modify max_tasks without std.
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
Ok(value) => value
Expand All @@ -58,31 +63,41 @@ async fn create_client<G: GraphicsApi>(device: &WgpuDevice) -> ComputeClient<Ser
Err(_) => 64, // 64 tasks by default
};

println!("dbg 6_4");
let device = Arc::new(device_wgpu);
println!("dbg 6_5");
let storage = WgpuStorage::new(device.clone());
println!("dbg 6_6");
let memory_management = SimpleMemoryManagement::new(
storage,
DeallocStrategy::new_period_tick(max_tasks * 2),
SliceStrategy::Ratio(0.8),
);
println!("dbg 6_7");
let server = WgpuServer::new(memory_management, device, queue, max_tasks);
println!("dbg 6_8");
let channel = Channel::new(server);

println!("dbg 6_9");
ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new())))
}

/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
pub async fn select_device<G: GraphicsApi>(
device: &WgpuDevice,
) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) {
println!("dbg 7_1");
#[cfg(target_family = "wasm")]
let adapter = select_adapter::<G>(device).await;

println!("dbg 7_2");
#[cfg(not(target_family = "wasm"))]
let adapter = select_adapter::<G>(device);

println!("dbg 7_3");
let limits = adapter.limits();

println!("dbg 7_4");
let (device, queue) = adapter
.request_device(
&DeviceDescriptor {
Expand All @@ -102,6 +117,7 @@ pub async fn select_device<G: GraphicsApi>(
})
.unwrap();

dbg!((&device, &queue, adapter.get_info()));
(device, queue, adapter.get_info())
}

Expand Down
12 changes: 6 additions & 6 deletions burn-wgpu/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ mod tests {

#[test]
fn can_run_kernel() {
println!("dbg 1_1");
binary!(
operator: |elem: Elem| Operator::Add {
lhs: Variable::Input(0, elem),
Expand All @@ -92,9 +93,9 @@ mod tests {
elem_in: f32,
elem_out: f32
);

println!("dbg 1_2");
let client = compute_client::<AutoGraphicsApi>(&WgpuDevice::default());

println!("dbg 1_3");
let lhs: Vec<f32> = vec![0., 1., 2., 3., 4., 5., 6., 7.];
let rhs: Vec<f32> = vec![10., 11., 12., 6., 7., 3., 1., 0.];
let info: Vec<u32> = vec![1, 1, 8, 1, 8, 1, 8];
Expand All @@ -103,16 +104,15 @@ mod tests {
let rhs = client.create(bytemuck::cast_slice(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * 8);
let info = client.create(bytemuck::cast_slice(&info));

type Kernel =
KernelSettings<Ops<f32, f32>, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>;
let kernel = Box::new(StaticKernel::<Kernel>::new(WorkGroup::new(1, 1, 1)));

println!("dbg 1_4");
client.execute(kernel, &[&lhs, &rhs, &out, &info]);

println!("dbg 1_5");
let data = client.read(&out).read_sync().unwrap();
let output: &[f32] = bytemuck::cast_slice(&data);

println!("dbg 1_6");
assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]);
}
}
2 changes: 2 additions & 0 deletions burn-wgpu/src/graphics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl GraphicsApi for WebGpu {

impl GraphicsApi for AutoGraphicsApi {
fn backend() -> wgpu::Backend {
println!("dbg 2_1");
// Allow overriding AutoGraphicsApi backend with ENV var in std test environments
#[cfg(not(no_std))]
#[cfg(test)]
Expand All @@ -98,6 +99,7 @@ impl GraphicsApi for AutoGraphicsApi {
}
}

println!("dbg 2_2");
// In a no_std environment or if the environment variable is not set
#[cfg(target_os = "macos")]
return wgpu::Backend::Metal;
Expand Down

0 comments on commit 71086c2

Please sign in to comment.