Skip to content

Commit

Permalink
Compute shaders can read / write index and vertex buffers (#6226)
Browse files Browse the repository at this point in the history
* Compute shaders can read / write index and vertex buffers

* thumbs

* support for read-only storage

---------

Co-authored-by: Martin Valigursky <mvaligursky@snapchat.com>
  • Loading branch information
mvaligursky and Martin Valigursky authored Apr 5, 2024
1 parent e97f3ff commit d66c121
Show file tree
Hide file tree
Showing 14 changed files with 341 additions and 23 deletions.
53 changes: 53 additions & 0 deletions examples/src/examples/compute/vertex-update/config.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* @type {import('../../../../types.mjs').ExampleConfig}
*/
export default {
HIDDEN: true,
WEBGPU_REQUIRED: true,
FILES: {
'compute-shader.wgsl': /* wgsl */ `
struct ub_compute {
count: u32, // number of vertices
positionOffset: u32, // offset of the vertex positions in the vertex buffer
normalOffset: u32, // offset of the vertex normals in the vertex buffer
time: f32 // time
}
// uniform buffer
@group(0) @binding(0) var<uniform> ubCompute : ub_compute;
// vertex buffer
@group(0) @binding(1) var<storage, read_write> vertices: array<f32>;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_invocation_id: vec3u) {
// vertex index - ignore if out of bounds (as they get batched into groups of 64)
let index = global_invocation_id.x;
if (index >= ubCompute.count) { return; }
// read the position from the vertex buffer
let positionOffset = ubCompute.positionOffset + index * 3;
var position = vec3f(vertices[positionOffset], vertices[positionOffset + 1], vertices[positionOffset + 2]);
// read normal
let normalOffset = ubCompute.normalOffset + index * 3;
let normal = vec3f(vertices[normalOffset], vertices[normalOffset + 1], vertices[normalOffset + 2]);
// generate position from the normal by offsetting (0,0,0) by normal * strength
let strength = vec3f(
1.0 + sin(ubCompute.time + 10 * position.y) * 0.1,
1.0 + cos(ubCompute.time + 5 * position.x) * 0.1,
1.0 + sin(ubCompute.time + 2 * position.z) * 0.2
);
position = normal * strength;
// write the position back to the vertex buffer
vertices[positionOffset + 0] = position.x;
vertices[positionOffset + 1] = position.y;
vertices[positionOffset + 2] = position.z;
}
`
}
};
171 changes: 171 additions & 0 deletions examples/src/examples/compute/vertex-update/example.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import * as pc from 'playcanvas';
import { deviceType, rootPath } from '@examples/utils';
import files from '@examples/files';

const canvas = document.getElementById('application-canvas');
if (!(canvas instanceof HTMLCanvasElement)) {
throw new Error('No canvas found');
}

const assets = {
color: new pc.Asset('color', 'texture', { url: rootPath + '/static/assets/textures/seaside-rocks01-color.jpg' }),
normal: new pc.Asset('normal', 'texture', { url: rootPath + '/static/assets/textures/seaside-rocks01-normal.jpg' }),
gloss: new pc.Asset('gloss', 'texture', { url: rootPath + '/static/assets/textures/seaside-rocks01-gloss.jpg' }),
orbit: new pc.Asset('script', 'script', { url: rootPath + '/static/scripts/camera/orbit-camera.js' }),
helipad: new pc.Asset(
'helipad-env-atlas',
'texture',
{ url: rootPath + '/static/assets/cubemaps/table-mountain-env-atlas.png' },
{ type: pc.TEXTURETYPE_RGBP, mipmaps: false }
)
};

const gfxOptions = {
deviceTypes: [deviceType],
glslangUrl: rootPath + '/static/lib/glslang/glslang.js',
twgslUrl: rootPath + '/static/lib/twgsl/twgsl.js'
};

const device = await pc.createGraphicsDevice(canvas, gfxOptions);
const createOptions = new pc.AppOptions();
createOptions.graphicsDevice = device;
createOptions.mouse = new pc.Mouse(document.body);
createOptions.touch = new pc.TouchDevice(document.body);

createOptions.componentSystems = [
pc.RenderComponentSystem,
pc.CameraComponentSystem,
pc.LightComponentSystem,
pc.ScriptComponentSystem
];

createOptions.resourceHandlers = [
pc.TextureHandler,
pc.ScriptHandler
];

const app = new pc.AppBase(canvas);
app.init(createOptions);

// Set the canvas to fill the window and automatically change resolution to be the same as the canvas size
app.setCanvasFillMode(pc.FILLMODE_FILL_WINDOW);
app.setCanvasResolution(pc.RESOLUTION_AUTO);

// Ensure canvas is resized when window changes size
const resize = () => app.resizeCanvas();
window.addEventListener('resize', resize);
app.on('destroy', () => {
window.removeEventListener('resize', resize);
});

const assetListLoader = new pc.AssetListLoader(Object.values(assets), app.assets);
assetListLoader.load(() => {
app.start();

// setup skydome
app.scene.skyboxMip = 2;
app.scene.exposure = 2;
app.scene.envAtlas = assets.helipad.resource;

// sphere material
const material = new pc.StandardMaterial();
material.diffuseMap = assets.color.resource;
material.normalMap = assets.normal.resource;
material.glossMap = assets.gloss.resource;
material.update();

// sphere mesh and entity
const entity = new pc.Entity('Sphere');
app.root.addChild(entity);

// create hight resolution sphere
const mesh = pc.createSphere(app.graphicsDevice, {
radius: 1,
latitudeBands: 100,
longitudeBands: 100,
storageVertex: true // allow vertex buffer to be accessible by compute shader
});

// Add a render component with the mesh
entity.addComponent('render', {
meshInstances: [new pc.MeshInstance(mesh, material)]
});
app.root.addChild(entity);

// Create an orbit camera
const cameraEntity = new pc.Entity();
cameraEntity.addComponent('camera', {
clearColor: new pc.Color(0.4, 0.45, 0.5)
});
cameraEntity.translate(0, 0, 5);

// add orbit camera script with a mouse and a touch support
cameraEntity.addComponent('script');
cameraEntity.script.create("orbitCamera", {
attributes: {
inertiaFactor: 0.2,
focusEntity: entity
}
});
cameraEntity.script.create("orbitCameraInputMouse");
cameraEntity.script.create("orbitCameraInputTouch");
app.root.addChild(cameraEntity);

// a compute shader that will modify the vertex buffer of the mesh every frame
const shader = device.supportsCompute ? new pc.Shader(device, {
name: 'ComputeShader',
shaderLanguage: pc.SHADERLANGUAGE_WGSL,
cshader: files['compute-shader.wgsl'],

// format of a uniform buffer used by the compute shader
computeUniformBufferFormat: new pc.UniformBufferFormat(device, [
new pc.UniformFormat('count', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('positionOffset', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('normalOffset', pc.UNIFORMTYPE_UINT),
new pc.UniformFormat('time', pc.UNIFORMTYPE_FLOAT)
]),

// format of a bind group, providing resources for the compute shader
computeBindGroupFormat: new pc.BindGroupFormat(device, [
// a uniform buffer we provided format for
new pc.BindUniformBufferFormat(pc.UNIFORM_BUFFER_DEFAULT_SLOT_NAME, pc.SHADERSTAGE_COMPUTE),
// the vertex buffer we want to modify
new pc.BindStorageBufferFormat('vb', pc.SHADERSTAGE_COMPUTE)
])
}) : null;

// information about the vertex buffer format - offset of position and normal attributes
// Note: data is stored non-interleaved, positions together, normals together, so no need
// to worry about stride
const format = mesh.vertexBuffer.format;
const positionElement = format.elements.find(e => e.name === pc.SEMANTIC_POSITION);
const normalElement = format.elements.find(e => e.name === pc.SEMANTIC_NORMAL);

// create an instance of the compute shader, and provide it the mesh vertex buffer
const compute = new pc.Compute(device, shader, 'ComputeModifyVB');
compute.setParameter('vb', mesh.vertexBuffer);
compute.setParameter('count', mesh.vertexBuffer.numVertices);
compute.setParameter('positionOffset', positionElement?.offset / 4); // number of floats offset
compute.setParameter('normalOffset', normalElement?.offset / 4); // number of floats offset

let time = 0;
app.on('update', function (dt) {
time += dt;
if (entity) {

// update non-constant parameters each frame
compute.setParameter('time', time);

// set up both dispatches
compute.setupDispatch(mesh.vertexBuffer.numVertices);

// dispatch the compute shader
device.computeDispatch([compute]);

// solid / wireframe
entity.render.renderStyle = Math.floor(time * 0.5) % 2 ? pc.RENDERSTYLE_WIREFRAME : pc.RENDERSTYLE_SOLID;
}
});
});

export { app };
Binary file not shown.
Binary file added examples/thumbnails/compute_vertex-update_small.webp
Binary file not shown.
6 changes: 6 additions & 0 deletions src/platform/graphics/bind-group-format.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class BindUniformBufferFormat extends BindBaseFormat {
* @ignore
*/
class BindStorageBufferFormat extends BindBaseFormat {
constructor(name, visibility, readOnly = false) {
super(name, visibility);

// whether the buffer is read-only
this.readOnly = readOnly;
}
}

/**
Expand Down
4 changes: 2 additions & 2 deletions src/platform/graphics/compute.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Compute {
* Sets a shader parameter on a compute instance.
*
* @param {string} name - The name of the parameter to set.
* @param {number|number[]|Float32Array|import('./texture.js').Texture|import('./storage-buffer.js').StorageBuffer} value
* @param {number|number[]|Float32Array|import('./texture.js').Texture|import('./storage-buffer.js').StorageBuffer|import('./vertex-buffer.js').VertexBuffer|import('./index-buffer.js').IndexBuffer} value
* - The value for the specified parameter.
*/
setParameter(name, value) {
Expand All @@ -91,7 +91,7 @@ class Compute {
* Returns the value of a shader parameter from the compute instance.
*
* @param {string} name - The name of the parameter to get.
* @returns {number|number[]|Float32Array|import('./texture.js').Texture|import('./storage-buffer.js').StorageBuffer|undefined}
* @returns {number|number[]|Float32Array|import('./texture.js').Texture|import('./storage-buffer.js').StorageBuffer|import('./vertex-buffer.js').VertexBuffer|import('./index-buffer.js').IndexBuffer|undefined}
* The value of the specified parameter.
*/
getParameter(name) {
Expand Down
7 changes: 5 additions & 2 deletions src/platform/graphics/index-buffer.js
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class IndexBuffer {
* Defaults to {@link BUFFER_STATIC}.
* @param {ArrayBuffer} [initialData] - Initial data. If left unspecified, the index buffer
* will be initialized to zeros.
* @param {object} [options] - Object for passing optional arguments.
* @param {boolean} [options.storage] - Defines if the index buffer can be used as a storage
* buffer by a compute shader. Defaults to false. Only supported on WebGPU.
* @example
* // Create an index buffer holding 3 16-bit indices. The buffer is marked as
* // static, hinting that the buffer will never be modified.
Expand All @@ -45,7 +48,7 @@ class IndexBuffer {
* pc.BUFFER_STATIC,
* indices);
*/
constructor(graphicsDevice, format, numIndices, usage = BUFFER_STATIC, initialData) {
constructor(graphicsDevice, format, numIndices, usage = BUFFER_STATIC, initialData, options) {
// By default, index buffers are static (better for performance since buffer data can be cached in VRAM)
this.device = graphicsDevice;
this.format = format;
Expand All @@ -54,7 +57,7 @@ class IndexBuffer {

this.id = id++;

this.impl = graphicsDevice.createIndexBufferImpl(this);
this.impl = graphicsDevice.createIndexBufferImpl(this, options);

// Allocate the storage
const bytesPerIndex = typedArrayIndexFormatsByteSize[format];
Expand Down
7 changes: 5 additions & 2 deletions src/platform/graphics/vertex-buffer.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ class VertexBuffer {
* @param {number} numVertices - The number of vertices that this vertex buffer will hold.
* @param {number} [usage] - The usage type of the vertex buffer (see BUFFER_*). Defaults to BUFFER_STATIC.
* @param {ArrayBuffer} [initialData] - Initial data.
* @param {object} [options] - Object for passing optional arguments.
* @param {boolean} [options.storage] - Defines if the vertex buffer can be used as a storage
* buffer by a compute shader. Defaults to false. Only supported on WebGPU.
*/
constructor(graphicsDevice, format, numVertices, usage = BUFFER_STATIC, initialData) {
constructor(graphicsDevice, format, numVertices, usage = BUFFER_STATIC, initialData, options) {
// By default, vertex buffers are static (better for performance since buffer data can be cached in VRAM)
this.device = graphicsDevice;
this.format = format;
Expand All @@ -31,7 +34,7 @@ class VertexBuffer {

this.id = id++;

this.impl = graphicsDevice.createVertexBufferImpl(this, format);
this.impl = graphicsDevice.createVertexBufferImpl(this, format, options);

// Calculate the size. If format contains verticesByteSize (non-interleaved format), use it
this.numBytes = format.verticesByteSize ? format.verticesByteSize : format.size * numVertices;
Expand Down
2 changes: 1 addition & 1 deletion src/platform/graphics/webgpu/webgpu-bind-group-format.js
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class WebgpuBindGroupFormat {
// storage buffers
bindGroupFormat.storageBufferFormats.forEach((bufferFormat) => {

const readOnly = false;
const readOnly = bufferFormat.readOnly;
const visibility = WebgpuUtils.shaderStage(bufferFormat.visibility);
key += `#${bufferFormat.slot}SB:${visibility}-${readOnly ? 'ro' : 'rw'}`;

Expand Down
8 changes: 4 additions & 4 deletions src/platform/graphics/webgpu/webgpu-graphics-device.js
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,12 @@ class WebgpuGraphicsDevice extends GraphicsDevice {
return new WebgpuUniformBuffer(uniformBuffer);
}

createVertexBufferImpl(vertexBuffer, format) {
return new WebgpuVertexBuffer(vertexBuffer, format);
createVertexBufferImpl(vertexBuffer, format, options) {
return new WebgpuVertexBuffer(vertexBuffer, format, options);
}

createIndexBufferImpl(indexBuffer) {
return new WebgpuIndexBuffer(indexBuffer);
createIndexBufferImpl(indexBuffer, options) {
return new WebgpuIndexBuffer(indexBuffer, options);
}

createShaderImpl(shader) {
Expand Down
6 changes: 3 additions & 3 deletions src/platform/graphics/webgpu/webgpu-index-buffer.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Debug } from '../../../core/debug.js';
import { INDEXFORMAT_UINT8, INDEXFORMAT_UINT16, BUFFERUSAGE_INDEX } from '../constants.js';
import { INDEXFORMAT_UINT8, INDEXFORMAT_UINT16, BUFFERUSAGE_INDEX, BUFFERUSAGE_STORAGE } from '../constants.js';
import { WebgpuBuffer } from "./webgpu-buffer.js";

/**
Expand All @@ -10,8 +10,8 @@ import { WebgpuBuffer } from "./webgpu-buffer.js";
class WebgpuIndexBuffer extends WebgpuBuffer {
format = null;

constructor(indexBuffer) {
super(BUFFERUSAGE_INDEX);
constructor(indexBuffer, options) {
super(BUFFERUSAGE_INDEX | (options?.storage ? BUFFERUSAGE_STORAGE : 0));

Debug.assert(indexBuffer.format !== INDEXFORMAT_UINT8, "WebGPU does not support 8-bit index buffer format");
this.format = indexBuffer.format === INDEXFORMAT_UINT16 ? "uint16" : "uint32";
Expand Down
6 changes: 3 additions & 3 deletions src/platform/graphics/webgpu/webgpu-vertex-buffer.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BUFFERUSAGE_VERTEX } from "../constants.js";
import { BUFFERUSAGE_STORAGE, BUFFERUSAGE_VERTEX } from "../constants.js";
import { WebgpuBuffer } from "./webgpu-buffer.js";

/**
Expand All @@ -7,8 +7,8 @@ import { WebgpuBuffer } from "./webgpu-buffer.js";
* @ignore
*/
class WebgpuVertexBuffer extends WebgpuBuffer {
constructor(vertexBuffer, format) {
super(BUFFERUSAGE_VERTEX);
constructor(vertexBuffer, format, options) {
super(BUFFERUSAGE_VERTEX | (options?.storage ? BUFFERUSAGE_STORAGE : 0));
}

unlock(vertexBuffer) {
Expand Down
Loading

0 comments on commit d66c121

Please sign in to comment.