diff --git a/tfjs-backend-webgpu/src/diag_webgpu.ts b/tfjs-backend-webgpu/src/diag_webgpu.ts new file mode 100644 index 00000000000..2896967e129 --- /dev/null +++ b/tfjs-backend-webgpu/src/diag_webgpu.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch, flatDispatchLayout} from './webgpu_util'; + +export class DiagProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + variableNames = ['x']; + workgroupSize: [number, number, number] = [64, 1, 1]; + size = true; + + constructor(size: number) { + this.outputShape = [size, size]; + this.dispatchLayout = flatDispatchLayout(this.outputShape); + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + + this.shaderKey = 'diag'; + } + + getUserCode(): string { + const userCode = ` + ${main('index')} { + if (index < uniforms.size) { + let coords = getOutputCoords(); + let value = select(0.0, getX(coords[0]), coords[0] == coords[1]); + setOutputAtIndex(index, value); + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/kernels/Diag.ts b/tfjs-backend-webgpu/src/kernels/Diag.ts new file mode 100644 index 00000000000..776d42eac05 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Diag.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Diag, DiagInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {DiagProgram} from '../diag_webgpu'; +import {reshape} from './Reshape'; + +export function diag(args: {inputs: DiagInputs, backend: WebGPUBackend}): + TensorInfo { + const {inputs, backend} = args; + const {x} = inputs; + + const outShape = [...x.shape, ...x.shape]; + const xSize = util.sizeFromShape(x.shape); + + const flat = reshape({inputs: {x}, backend, attrs: {shape: [xSize]}}); + + const program = new DiagProgram(xSize); + const res = backend.runWebGPUProgram(program, [flat], flat.dtype); + + const out = reshape({inputs: {x: res}, backend, attrs: {shape: outShape}}); + + backend.disposeData(flat.dataId); + backend.disposeData(res.dataId); + + return out; +} + +export const diagConfig: KernelConfig = { + kernelName: Diag, + backendName: 'webgpu', + kernelFunc: diag as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index b2ee12f81f4..dd11fff7992 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -50,6 +50,7 @@ import {cumsumConfig} from './kernels/Cumsum'; import {denseBincountConfig} from './kernels/DenseBincount'; import {depthToSpaceConfig} from './kernels/DepthToSpace'; import {depthwiseConv2dNativeConfig} from './kernels/DepthwiseConv2dNative'; +import {diagConfig} from './kernels/Diag'; import {dilation2DConfig} from './kernels/Dilation2D'; import {einsumConfig} from './kernels/Einsum'; import {eluConfig} from './kernels/Elu'; @@ -187,6 +188,7 @@ const kernelConfigs: KernelConfig[] = [ denseBincountConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig, + diagConfig, dilation2DConfig, einsumConfig, eluConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 6f7d586a933..718ff486b36 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -259,7 +259,6 @@ const TEST_FILTERS: TestFilter[] = [ 'conv1d gradients', // Conv2DBackpropFilter 'conv3d ', 'conv3dTranspose ', - 'diag ', 'maxPool3d ', 'maxPool3dBackprop ', 'maxPoolBackprop ',