Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sync API and require pre-allocated output buffers #174

Merged
merged 4 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions explainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,17 @@ const A = builder.input('A', operandType);
const B = builder.input('B', operandType);
const C = builder.add(builder.mul(A, constant), B);
// 2. Compile it into an executable.
const graph = await builder.build({'C': C});
const graph = builder.build({'C': C});
// 3. Bind inputs to the graph and execute for the result.
const bufferA = new Float32Array(4).fill(1.0);
const bufferB = new Float32Array(4).fill(0.8);
const inputs = {'A': {data: bufferA}, 'B': {data: bufferB}};
const outputs = await graph.compute(inputs);
const bufferC = new Float32Array(4);
const inputs = {'A': bufferA, 'B': bufferB};
const outputs = {'C': bufferC};
graph.compute(inputs, outputs);
// The computed result of [[1, 1], [1, 1]] is in the buffer associated with
// the output operand.
console.log('Output shape: ' + outputs.C.dimensions);
console.log('Output value: ' + outputs.C.data);
console.log('Output value: ' + bufferC);
```

Check it out in [WebNN Code Editor](https://webmachinelearning.github.io/webnn-samples/code/?example=mul_add.js).
Expand Down Expand Up @@ -102,7 +103,7 @@ export class NSNet2 {
this.hiddenSize = 400;
}

async load(baseUrl, batchSize, frames) {
async build(baseUrl, batchSize, frames) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to remove async keyword if build method is sync.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed, because buildConstantByNpy in this function uses fetch to download .npy files from network which is async.

const context = navigator.ml.createContext();
const builder = new MLGraphBuilder(context);
// Create constants by loading pre-trained data from .npy files.
Expand Down Expand Up @@ -138,20 +139,21 @@ export class NSNet2 {
const relu163 = builder.relu(builder.add(builder.matmul(transpose159, weight215), biasFcOut0));
const relu167 = builder.relu(builder.add(builder.matmul(relu163, weight216), biasFcOut2));
const output = builder.sigmoid(builder.add(builder.matmul(relu167, weight217), biasFcOut4));
this.builder = builder;
this.graph = builder.build({'output': output, 'gru94': gru94, 'gru157': gru157});
}

async build() {
this.graph = await this.builder.build({output, gru94, gru157});
}

async compute(inputBuffer, initialState92Buffer, initialState155Buffer) {
compute(inputBuffer, initialState92Buffer, initialState155Buffer, outputBuffer, gru94Buffer, gru157Buffer) {
const inputs = {
input: {data: inputBuffer},
initialState92: {data: initialState92Buffer},
initialState155: {data: initialState155Buffer},
'input': inputBuffer,
'initialState92': initialState92Buffer,
'initialState155': initialState155Buffer,
};
const outputs = {
'output': outputBuffer,
'gru94': gru94Buffer,
'gru157': gru157Buffer
};
return await this.graph.compute(inputs);
return this.graph.compute(inputs, outputs);
}
}
```
Expand Down
Loading