Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop null for reshape op #192

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion face_recognition/facenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export class FaceNetNhwc {
}

async buildFullyConnected_(input) {
input = this.builder_.reshape(input, [1, null]);
input = this.builder_.reshape(input, [1, 1792]);
const weights = await buildConstantByNpy(this.builder_,
`${this.weightsUrl_}/Bottleneck_kernel_transpose.npy`);
const bias = await buildConstantByNpy(this.builder_,
Expand Down
2 changes: 1 addition & 1 deletion facial_landmark_detection/face_landmark_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export class FaceLandmarkNchw {
if (reshapeSize !== undefined) {
gemm = this.builder_.gemm(this.builder_.reshape(
this.builder_.transpose(await input, {permutation: [0, 2, 3, 1]}),
[null, reshapeSize]), await weights, options);
[1, reshapeSize]), await weights, options);
} else {
gemm = this.builder_.gemm(await input, await weights, options);
}
Expand Down
2 changes: 1 addition & 1 deletion facial_landmark_detection/face_landmark_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export class FaceLandmarkNhwc {
let fc;
if (reshapeSize !== undefined) {
fc = this.builder_.gemm(this.builder_.reshape(
await input, [null, reshapeSize]), await weights, options);
await input, [1, reshapeSize]), await weights, options);
} else {
fc = this.builder_.gemm(await input, await weights, options);
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ export class MobileNetV2Nchw {

const conv3 = await this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(conv3);
const reshape = this.builder_.reshape(pool, [1, null]);
const reshape = this.builder_.reshape(pool, [1, 1280]);
const gemm = await this.buildGemm_(reshape, '104');
return this.builder_.softmax(gemm);
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/mobilenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ export class MobileNetV2Nhwc {
conv3, {windowDimensions: [7, 7], layout: 'nhwc'});
const conv4 = await this.buildConv_(
averagePool2d, '222', 'Logits_Conv2d_1c_1x1_Conv2D', false, {autoPad, filterLayout});
const reshape = this.builder_.reshape(conv4, [1, null]);
const reshape = this.builder_.reshape(conv4, [1, 1001]);
return this.builder_.softmax(reshape);
}

Expand Down
4 changes: 2 additions & 2 deletions image_classification/resnet50v2_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export class ResNet50V2Nchw {
const biasName = prefix + '_bias.npy';
const bias = await buildConstantByNpy(this.builder_, biasName);
const options =
{c: this.builder_.reshape(bias, [1, null]), bTranspose: true};
{c: this.builder_.reshape(bias, [1, 1000]), bTranspose: true};
return this.builder_.gemm(input, weight, options);
}

Expand Down Expand Up @@ -148,7 +148,7 @@ export class ResNet50V2Nchw {

const bn3 = await this.buildBatchNorm_(bottleneck16, '2', '');
const pool2 = await this.builder_.averagePool2d(bn3);
const reshape = this.builder_.reshape(pool2, [1, null]);
const reshape = this.builder_.reshape(pool2, [1, 2048]);
const gemm = await this.buildGemm_(reshape, '0');
return this.builder_.softmax(gemm);
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/resnet50v2_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export class ResNet50V2Nhwc {
const mean = this.builder_.averagePool2d(fusedBn, {layout});
const conv2 = await this.buildConv_(
mean, ['', '', 'logits'], {autoPad}, false);
const reshape = this.builder_.reshape(conv2, [1, null]);
const reshape = this.builder_.reshape(conv2, [1, 1001]);
return this.builder_.softmax(reshape);
}

Expand Down
2 changes: 1 addition & 1 deletion image_classification/squeezenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export class SqueezeNetNchw {
const conv25 = await this.buildConv_(fire7, 'conv25');
const pool3 = this.builder_.averagePool2d(
conv25, {windowDimensions: [13, 13], strides: [13, 13]});
const reshape0 = this.builder_.reshape(pool3, [1, null]);
const reshape0 = this.builder_.reshape(pool3, [1, 1000]);
return this.builder_.softmax(reshape0);
}

Expand Down
2 changes: 1 addition & 1 deletion image_classification/squeezenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class SqueezeNetNhwc {
const conv10 = await this.buildConv_(fire9, 'conv10');
const averagePool2d = this.builder_.averagePool2d(
conv10, {windowDimensions: [13, 13], layout});
const reshape = this.builder_.reshape(averagePool2d, [1, null]);
const reshape = this.builder_.reshape(averagePool2d, [1, 1001]);
return this.builder_.softmax(reshape);
}

Expand Down
4 changes: 2 additions & 2 deletions lenet/lenet.js
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ export class LeNet {
this.builder_.maxPool2d(add2, {windowDimensions: pool2WindowShape,
strides: pool2Strides});

const reshape1Shape = [1, null];
const reshape1Shape = [1, 800];
const reshape1 = this.builder_.reshape(pool2, reshape1Shape);

// skip the new shape, 2 int64 values
Expand All @@ -100,7 +100,7 @@ export class LeNet {

const relu = this.builder_.relu(add3);

const reshape2Shape = [1, null];
const reshape2Shape = [1, 500];
const reshape2 = this.builder_.reshape(relu, reshape2Shape);

const matmul2Shape = [10, 500];
Expand Down
6 changes: 3 additions & 3 deletions rnnoise/rnnoise.js
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export class RNNoise {
const vadGruYTransposed = this.builder_.transpose(
vadGruY, {permutation: [2, 0, 1, 3]});
const vadGruTranspose1 = this.builder_.reshape(
vadGruYTransposed, [null, this.frames_, this.vadGruHiddenSize]);
vadGruYTransposed, [1, this.frames_, this.vadGruHiddenSize]);
const concatenate1 = this.builder_.concat(
[inputDenseTanh0, vadGruTranspose1, input], 2);
const noiseGruX = this.builder_.transpose(
Expand Down Expand Up @@ -112,7 +112,7 @@ export class RNNoise {
const noiseGruYTransposed = this.builder_.transpose(
noiseGruY, {permutation: [2, 0, 1, 3]});
const noiseGruTranspose1 = this.builder_.reshape(
noiseGruYTransposed, [null, this.frames_, this.noiseGruHiddenSize]);
noiseGruYTransposed, [1, this.frames_, this.noiseGruHiddenSize]);
const concatenate2 = this.builder_.concat(
[vadGruTranspose1, noiseGruTranspose1, input], 2);
const denoiseGruX = this.builder_.transpose(
Expand Down Expand Up @@ -140,7 +140,7 @@ export class RNNoise {
const denoiseGruYTransposed = this.builder_.transpose(
denoiseGruY, {permutation: [2, 0, 1, 3]});
const denoiseGruTranspose1 = this.builder_.reshape(
denoiseGruYTransposed, [null, this.frames_, this.denoiseGruHiddenSize]);
denoiseGruYTransposed, [1, this.frames_, this.denoiseGruHiddenSize]);
const denoiseOutput0 = this.builder_.matmul(
denoiseGruTranspose1, denoiseOutputKernel0);
const biasedTensorName = this.builder_.add(
Expand Down