From aa3ccbd85cd56f4fc21c7aad9975f92abeb69b50 Mon Sep 17 00:00:00 2001 From: Charlotte McElwain Date: Tue, 10 Sep 2024 10:37:31 -0700 Subject: [PATCH] Move instanced to custom draw cmd. --- bevy_nannou_draw/src/draw/instanced.rs | 310 +++++++++++++++++- bevy_nannou_draw/src/render.rs | 31 +- .../shaders/particle_mouse_material.wgsl | 7 +- examples/compute/particle_mouse.rs | 2 +- 4 files changed, 327 insertions(+), 23 deletions(-) diff --git a/bevy_nannou_draw/src/draw/instanced.rs b/bevy_nannou_draw/src/draw/instanced.rs index c546f3adb..ab381e2bb 100644 --- a/bevy_nannou_draw/src/draw/instanced.rs +++ b/bevy_nannou_draw/src/draw/instanced.rs @@ -3,26 +3,30 @@ use crate::draw::drawing::Drawing; use crate::draw::primitive::Primitive; use crate::draw::{Draw, DrawCommand}; +use crate::render::{PreparedShaderModel, ShaderModel}; use bevy::core_pipeline::core_3d::Opaque3dBinKey; -use bevy::pbr::{MaterialPipeline, MaterialPipelineKey, PreparedMaterial, SetMaterialBindGroup}; +use bevy::pbr::{MaterialPipeline, MaterialPipelineKey, PreparedMaterial}; +use bevy::render::extract_component::ExtractComponentPlugin; +use bevy::render::extract_instances::ExtractedInstances; use bevy::render::mesh::allocator::MeshAllocator; use bevy::render::mesh::RenderMeshBufferInfo; use bevy::render::render_asset::prepare_assets; use bevy::render::render_phase::{BinnedRenderPhaseType, ViewBinnedRenderPhases}; +use bevy::render::storage::{GpuShaderStorageBuffer, ShaderStorageBuffer}; +use bevy::render::view; +use bevy::render::view::VisibilitySystems; use bevy::{ core_pipeline::core_3d::Opaque3d, ecs::system::{lifetimeless::*, SystemParamItem}, - pbr::{ - MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup, - }, + pbr::{MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshViewBindGroup}, prelude::*, render::{ extract_component::ExtractComponent, mesh::{MeshVertexBufferLayoutRef, RenderMesh}, render_asset::RenderAssets, render_phase::{ - AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand, - RenderCommandResult, SetItemPipeline, TrackedRenderPass, + AddRenderCommand, DrawFunctions, PhaseItem, RenderCommand, RenderCommandResult, + SetItemPipeline, TrackedRenderPass, }, render_resource::*, renderer::RenderDevice, @@ -97,3 +101,297 @@ where .push(Some(DrawCommand::Instanced(primitive, range))); } } + +#[derive(Component, ExtractComponent, Clone)] +pub struct InstancedMesh; + +#[derive(Component, ExtractComponent, Clone)] +pub struct InstanceRange(pub Range); + +pub struct InstancedMaterialPlugin(PhantomData); + +impl Default for InstancedMaterialPlugin +where + M: Default, +{ + fn default() -> Self { + InstancedMaterialPlugin(PhantomData) + } +} + +impl Plugin for InstancedMaterialPlugin +where + SM: ShaderModel, + SM::Data: PartialEq + Eq + Hash + Clone, +{ + fn build(&self, app: &mut App) { + app.add_plugins(( + ExtractComponentPlugin::::default(), + ExtractComponentPlugin::::default(), + )) + .add_systems( + PostUpdate, + view::check_visibility::> + .in_set(VisibilitySystems::CheckVisibility), + ); + + app.sub_app_mut(RenderApp) + .add_render_command::>() + .init_resource::>>() + .add_systems( + Render, + (queue_indirect:: + .after(prepare_assets::>) + .in_set(RenderSet::QueueMeshes),), + ); + } + + fn finish(&self, app: &mut App) { + app.sub_app_mut(RenderApp) + .init_resource::>(); + } +} + +#[allow(clippy::too_many_arguments)] +fn queue_indirect( + draw_functions: Res>, + custom_pipeline: Res>, + mut pipelines: ResMut>>, + pipeline_cache: Res, + meshes: Res>, + ( + render_mesh_instances, + instanced_meshes, + mut phases, + mut views, + shader_models, + extracted_instances, + ): ( + Res, + Query>, + ResMut>, + Query<(Entity, &ExtractedView, &Msaa)>, + Res>>, + Res>>, + ), +) where + SM: ShaderModel, + SM::Data: PartialEq + Eq + Hash + Clone, +{ + let drawn_function = draw_functions.read().id::>(); + + for (view_entity, view, msaa) in &mut views { + let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples()); + let Some(phase) = phases.get_mut(&view_entity) else { + continue; + }; + + let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr); + for (entity) in &instanced_meshes { + let Some(shader_model) = extracted_instances.get(&entity) else { + continue; + }; + let shader_model = shader_models.get(*shader_model).unwrap(); + let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(entity) else { + continue; + }; + let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else { + continue; + }; + let mesh_key = + view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology()); + let key = MaterialPipelineKey { + mesh_key, + bind_group_data: shader_model.key.clone(), + }; + let pipeline = pipelines + .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout) + .unwrap(); + phase.add( + Opaque3dBinKey { + draw_function: drawn_function, + pipeline, + asset_id: AssetId::::invalid().untyped(), + material_bind_group_id: None, + lightmap_image: None, + }, + entity, + BinnedRenderPhaseType::NonMesh, + ); + } + } +} + +#[derive(Resource)] +struct InstancedPipeline { + mesh_pipeline: MeshPipeline, + shader_model_layout: BindGroupLayout, + vertex_shader: Option>, + fragment_shader: Option>, + marker: PhantomData, +} + +impl FromWorld for InstancedPipeline { + fn from_world(world: &mut World) -> Self { + let asset_server = world.resource::(); + let render_device = world.resource::(); + + InstancedPipeline { + mesh_pipeline: world.resource::().clone(), + shader_model_layout: SM::bind_group_layout(render_device), + vertex_shader: match ::vertex_shader() { + ShaderRef::Default => None, + ShaderRef::Handle(handle) => Some(handle), + ShaderRef::Path(path) => Some(asset_server.load(path)), + }, + fragment_shader: match ::fragment_shader() { + ShaderRef::Default => None, + ShaderRef::Handle(handle) => Some(handle), + ShaderRef::Path(path) => Some(asset_server.load(path)), + }, + marker: PhantomData, + } + } +} + +impl SpecializedMeshPipeline for InstancedPipeline +where + SM::Data: PartialEq + Eq + Hash + Clone, +{ + type Key = MaterialPipelineKey; + + fn specialize( + &self, + key: Self::Key, + layout: &MeshVertexBufferLayoutRef, + ) -> Result { + let mut descriptor = self.mesh_pipeline.specialize(key.mesh_key, layout)?; + if let Some(vertex_shader) = &self.vertex_shader { + descriptor.vertex.shader = vertex_shader.clone(); + } + + if let Some(fragment_shader) = &self.fragment_shader { + descriptor.fragment.as_mut().unwrap().shader = fragment_shader.clone(); + } + + descriptor + .layout + .insert(2, self.shader_model_layout.clone()); + + let pipeline = MaterialPipeline { + mesh_pipeline: self.mesh_pipeline.clone(), + material_layout: self.shader_model_layout.clone(), + vertex_shader: self.vertex_shader.clone(), + fragment_shader: self.fragment_shader.clone(), + marker: Default::default(), + }; + SM::specialize(&pipeline, &mut descriptor, layout, key)?; + Ok(descriptor) + } +} + +type DrawInstancedMaterial = ( + SetItemPipeline, + SetMeshViewBindGroup<0>, + SetShaderModelBindGroup, + DrawMeshInstanced, +); + +struct SetShaderModelBindGroup(PhantomData); +impl RenderCommand

+ for SetShaderModelBindGroup +{ + type Param = ( + SRes>>, + SRes>>, + ); + type ViewQuery = (); + type ItemQuery = (); + + #[inline] + fn render<'w>( + item: &P, + _view: (), + _item_query: Option<()>, + (models, instances): SystemParamItem<'w, '_, Self::Param>, + pass: &mut TrackedRenderPass<'w>, + ) -> RenderCommandResult { + let models = models.into_inner(); + let instances = instances.into_inner(); + + let Some(asset_id) = instances.get(&item.entity()) else { + return RenderCommandResult::Skip; + }; + let Some(material) = models.get(*asset_id) else { + return RenderCommandResult::Skip; + }; + pass.set_bind_group(I, &material.bind_group, &[]); + RenderCommandResult::Success + } +} + +struct DrawMeshInstanced; +impl RenderCommand

for DrawMeshInstanced { + type Param = ( + SRes>, + SRes, + SRes, + SRes>, + ); + type ViewQuery = (); + type ItemQuery = Read; + + #[inline] + fn render<'w>( + item: &P, + _view: (), + instance_range: Option<&'w InstanceRange>, + (meshes, render_mesh_instances, mesh_allocator, ssbos): SystemParamItem< + 'w, + '_, + Self::Param, + >, + pass: &mut TrackedRenderPass<'w>, + ) -> RenderCommandResult { + let mesh_allocator = mesh_allocator.into_inner(); + + let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.entity()) + else { + return RenderCommandResult::Skip; + }; + let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else { + return RenderCommandResult::Skip; + }; + let Some(instance_range) = instance_range else { + return RenderCommandResult::Skip; + }; + let Some(vertex_buffer_slice) = + mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id) + else { + return RenderCommandResult::Skip; + }; + + pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..)); + + match &gpu_mesh.buffer_info { + RenderMeshBufferInfo::Indexed { index_format, .. } => { + let Some(index_buffer_slice) = + mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id) + else { + return RenderCommandResult::Skip; + }; + + pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format); + pass.draw_indexed( + index_buffer_slice.range.clone(), + 0, + instance_range.0.clone(), + ); + } + RenderMeshBufferInfo::NonIndexed => { + pass.draw(vertex_buffer_slice.range.clone(), instance_range.0.clone()); + } + } + RenderCommandResult::Success + } +} diff --git a/bevy_nannou_draw/src/render.rs b/bevy_nannou_draw/src/render.rs index eb55f693c..c1a37c3fc 100644 --- a/bevy_nannou_draw/src/render.rs +++ b/bevy_nannou_draw/src/render.rs @@ -30,6 +30,7 @@ use bevy::window::WindowRef; use lyon::lyon_tessellation::{FillTessellator, StrokeTessellator}; use crate::draw::indirect::{IndirectMaterialPlugin, IndirectMesh}; +use crate::draw::instanced::{InstanceRange, InstancedMaterialPlugin, InstancedMesh}; use crate::draw::mesh::MeshExt; use crate::draw::render::{RenderContext, RenderPrimitive}; use crate::draw::{DrawCommand, DrawContext}; @@ -90,6 +91,7 @@ where app.add_plugins(( RenderAssetPlugin::>::default(), IndirectMaterialPlugin::::default(), + InstancedMaterialPlugin::::default(), )) .add_systems(PostUpdate, update_material::.after(update_draw_mesh)); } @@ -318,21 +320,20 @@ fn update_draw_mesh( prim.render_primitive(ctxt, &mut mesh); let mesh = meshes.add(mesh); let mat_id = last_mat.expect("No material set for instanced draw command"); - // TODO: off by one??? - for _ in range.start..range.end - 1 { - commands.spawn(( - UntypedMaterialId(mat_id), - mesh.clone(), - Transform::default(), - GlobalTransform::default(), - Visibility::default(), - InheritedVisibility::default(), - ViewVisibility::default(), - NannouMesh, - NoFrustumCulling, - window_layers.clone(), - )); - } + commands.spawn(( + InstancedMesh, + InstanceRange(range), + UntypedMaterialId(mat_id), + mesh.clone(), + Transform::default(), + GlobalTransform::default(), + Visibility::default(), + InheritedVisibility::default(), + ViewVisibility::default(), + NannouMesh, + NoFrustumCulling, + window_layers.clone(), + )); } DrawCommand::Indirect(prim, indirect_buffer) => { // Info required during rendering. diff --git a/examples/assets/shaders/particle_mouse_material.wgsl b/examples/assets/shaders/particle_mouse_material.wgsl index d9bd669b6..6ad0a6400 100644 --- a/examples/assets/shaders/particle_mouse_material.wgsl +++ b/examples/assets/shaders/particle_mouse_material.wgsl @@ -22,7 +22,12 @@ fn vertex(vertex: Vertex) -> VertexOutput { let particle = particles[vertex.instance_index]; var out: VertexOutput; out.clip_position = mesh_position_local_to_clip( - get_world_from_local(vertex.instance_index), + mat4x4( + vec4(1.0, 0.0, 0.0, 0.0), + vec4(0.0, 1.0, 0.0, 0.0), + vec4(0.0, 0.0, 1.0, 0.0), + vec4(0.0, 0.0, 0.0, 1.0) + ), vec4(vertex.position, 1.0) ) + vec4(particle.position, 0.0, 0.0); out.color = particle.color; diff --git a/examples/compute/particle_mouse.rs b/examples/compute/particle_mouse.rs index b65535759..986c33a74 100644 --- a/examples/compute/particle_mouse.rs +++ b/examples/compute/particle_mouse.rs @@ -3,7 +3,7 @@ use nannou::prelude::bevy_render::storage::ShaderStorageBuffer; use nannou::prelude::*; use std::sync::Arc; -const NUM_PARTICLES: u32 = 1000; +const NUM_PARTICLES: u32 = 100000; const WORKGROUP_SIZE: u32 = 64; fn main() {