Skip to content

Commit

Permalink
webgpu: support diag operator (#7177)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao authored Dec 16, 2022
1 parent 3b95c19 commit d8b08c9
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 1 deletion.
51 changes: 51 additions & 0 deletions tfjs-backend-webgpu/src/diag_webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class DiagProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
variableNames = ['x'];
workgroupSize: [number, number, number] = [64, 1, 1];
size = true;

constructor(size: number) {
this.outputShape = [size, size];
this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);

this.shaderKey = 'diag';
}

getUserCode(): string {
const userCode = `
${main('index')} {
if (index < uniforms.size) {
let coords = getOutputCoords();
let value = select(0.0, getX(coords[0]), coords[0] == coords[1]);
setOutputAtIndex(index, value);
}
}
`;
return userCode;
}
}
49 changes: 49 additions & 0 deletions tfjs-backend-webgpu/src/kernels/Diag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/**
* @license
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Diag, DiagInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {DiagProgram} from '../diag_webgpu';
import {reshape} from './Reshape';

export function diag(args: {inputs: DiagInputs, backend: WebGPUBackend}):
TensorInfo {
const {inputs, backend} = args;
const {x} = inputs;

const outShape = [...x.shape, ...x.shape];
const xSize = util.sizeFromShape(x.shape);

const flat = reshape({inputs: {x}, backend, attrs: {shape: [xSize]}});

const program = new DiagProgram(xSize);
const res = backend.runWebGPUProgram(program, [flat], flat.dtype);

const out = reshape({inputs: {x: res}, backend, attrs: {shape: outShape}});

backend.disposeData(flat.dataId);
backend.disposeData(res.dataId);

return out;
}

export const diagConfig: KernelConfig = {
kernelName: Diag,
backendName: 'webgpu',
kernelFunc: diag as unknown as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import {cumsumConfig} from './kernels/Cumsum';
import {denseBincountConfig} from './kernels/DenseBincount';
import {depthToSpaceConfig} from './kernels/DepthToSpace';
import {depthwiseConv2dNativeConfig} from './kernels/DepthwiseConv2dNative';
import {diagConfig} from './kernels/Diag';
import {dilation2DConfig} from './kernels/Dilation2D';
import {einsumConfig} from './kernels/Einsum';
import {eluConfig} from './kernels/Elu';
Expand Down Expand Up @@ -187,6 +188,7 @@ const kernelConfigs: KernelConfig[] = [
denseBincountConfig,
depthToSpaceConfig,
depthwiseConv2dNativeConfig,
diagConfig,
dilation2DConfig,
einsumConfig,
eluConfig,
Expand Down
1 change: 0 additions & 1 deletion tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ const TEST_FILTERS: TestFilter[] = [
'conv1d gradients', // Conv2DBackpropFilter
'conv3d ',
'conv3dTranspose ',
'diag ',
'maxPool3d ',
'maxPool3dBackprop ',
'maxPoolBackprop ',
Expand Down

0 comments on commit d8b08c9

Please sign in to comment.