Skip to content

Commit

Permalink
[pytorch] Fix PtNDArray interoperability with other engines (#1270)
Browse files Browse the repository at this point in the history
Change-Id: If0d602c897937b25a9d1b1808568a9bb4fbc582e
  • Loading branch information
frankfliu authored Oct 5, 2021
1 parent 9aaaa66 commit 8a431ba
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public DlrNDArray from(NDArray array) {
if (array instanceof DlrNDArray) {
if (array == null || array instanceof DlrNDArray) {
return (DlrNDArray) array;
}
return (DlrNDArray) create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public XgbNDArray from(NDArray array) {
if (array instanceof XgbNDArray) {
if (array == null || array instanceof XgbNDArray) {
return (XgbNDArray) array;
}
return (XgbNDArray) create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public MxNDArray from(NDArray array) {
if (array instanceof MxNDArray) {
if (array == null || array instanceof MxNDArray) {
return (MxNDArray) array;
}
MxNDArray ret = create(array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public OrtNDArray from(NDArray array) {
if (array instanceof OrtNDArray) {
if (array == null || array instanceof OrtNDArray) {
return (OrtNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public PpNDArray from(NDArray array) {
if (array instanceof PpNDArray) {
if (array == null || array instanceof PpNDArray) {
return (PpNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class PtNDArrayEx implements NDArrayEx {

private NDArrayIndexer indexer;
private PtNDArray array;
private PtNDManager manager;

/**
* Constructs a {@code PtNDArrayEx} given a {@link NDArray}.
Expand All @@ -38,13 +39,14 @@ public class PtNDArrayEx implements NDArrayEx {
*/
PtNDArrayEx(PtNDArray parent) {
this.array = parent;
indexer = new PtNDArrayIndexer(array.getManager());
this.manager = array.getManager();
indexer = new PtNDArrayIndexer(manager);
}

/** {@inheritDoc} */
@Override
public PtNDArray rdiv(Number n) {
return rdiv(array.getManager().create(n));
return rdiv(manager.create(n));
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -276,10 +278,10 @@ public void adamUpdate(
boolean lazyUpdate) {
// TODO: Lazy update not used
JniUtils.adamUpdate(
(PtNDArray) inputs.get(0),
(PtNDArray) inputs.get(1),
(PtNDArray) inputs.get(2),
(PtNDArray) inputs.get(3),
manager.from(inputs.get(0)),
manager.from(inputs.get(1)),
manager.from(inputs.get(2)),
manager.from(inputs.get(3)),
learningRate,
weightDecay,
rescaleGrad,
Expand All @@ -288,7 +290,7 @@ public void adamUpdate(
beta2,
epsilon);
// call zero-grad
JniUtils.zeroGrad((PtNDArray) weights.singletonOrThrow());
JniUtils.zeroGrad(manager.from(weights.singletonOrThrow()));
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -333,16 +335,16 @@ public void sgdUpdate(
boolean lazyUpdate) {
// TODO: Lazy update not used
JniUtils.sgdUpdate(
(PtNDArray) inputs.get(0),
(PtNDArray) inputs.get(1),
(momentum == 0f) ? null : (PtNDArray) inputs.get(2),
manager.from(inputs.get(0)),
manager.from(inputs.get(1)),
(momentum == 0f) ? null : manager.from(inputs.get(2)),
learningRate,
weightDecay,
rescaleGrad,
clipGrad,
momentum);
// call zero-grad
JniUtils.zeroGrad((PtNDArray) weights.singletonOrThrow());
JniUtils.zeroGrad(manager.from(weights.singletonOrThrow()));
}

/** {@inheritDoc} */
Expand All @@ -357,9 +359,9 @@ public NDList convolution(
int groups) {
return new NDList(
JniUtils.convolution(
(PtNDArray) input,
(PtNDArray) weight,
(PtNDArray) bias,
manager.from(input),
manager.from(weight),
manager.from(bias),
stride,
padding,
dilation,
Expand All @@ -383,7 +385,8 @@ public NDList deconvolution(
/** {@inheritDoc} */
@Override
public NDList linear(NDArray input, NDArray weight, NDArray bias) {
return new NDList(JniUtils.linear((PtNDArray) input, (PtNDArray) weight, (PtNDArray) bias));
return new NDList(
JniUtils.linear(manager.from(input), manager.from(weight), manager.from(bias)));
}

/** {@inheritDoc} */
Expand All @@ -394,8 +397,8 @@ public NDList embedding(NDArray input, NDArray weight, SparseFormat sparseFormat
}
return new NDList(
JniUtils.embedding(
(PtNDArray) input,
(PtNDArray) weight,
manager.from(input),
manager.from(weight),
sparseFormat.equals(SparseFormat.COO)));
}

Expand All @@ -408,7 +411,7 @@ public NDList prelu(NDArray input, NDArray alpha) {
/** {@inheritDoc} */
@Override
public NDList dropout(NDArray input, float rate, boolean training) {
return new NDList(JniUtils.dropout((PtNDArray) input, rate, training));
return new NDList(JniUtils.dropout(manager.from(input), rate, training));
}

/** {@inheritDoc} */
Expand All @@ -417,10 +420,10 @@ public NDList layerNorm(
NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
return new NDList(
JniUtils.layerNorm(
(PtNDArray) input,
manager.from(input),
normalizedShape,
(PtNDArray) gamma,
(PtNDArray) beta,
manager.from(gamma),
manager.from(beta),
eps));
}
/** {@inheritDoc} */
Expand All @@ -440,11 +443,11 @@ public NDList batchNorm(
if (axis == -1) {
return new NDList(
JniUtils.batchNorm(
(PtNDArray) input,
(PtNDArray) runningMean,
(PtNDArray) runningVar,
(PtNDArray) gamma,
(PtNDArray) beta,
manager.from(input),
manager.from(runningMean),
manager.from(runningVar),
manager.from(gamma),
manager.from(beta),
training,
// momentum is defined differently in PyTorch
1f - momentum,
Expand All @@ -457,11 +460,11 @@ public NDList batchNorm(
result = result.swapAxes(1, axis);
result =
JniUtils.batchNorm(
(PtNDArray) result,
(PtNDArray) runningMean,
(PtNDArray) runningVar,
(PtNDArray) gamma,
(PtNDArray) beta,
manager.from(result),
manager.from(runningMean),
manager.from(runningVar),
manager.from(gamma),
manager.from(beta),
training,
// momentum is defined differently in PyTorch
1f - momentum,
Expand All @@ -487,8 +490,8 @@ public NDList rnn(
boolean bidirectional,
boolean batchFirst) {
return JniUtils.rnn(
(PtNDArray) input,
(PtNDArray) state,
manager.from(input),
manager.from(state),
params,
hasBiases,
numLayers,
Expand All @@ -512,8 +515,8 @@ public NDList gru(
boolean bidirectional,
boolean batchFirst) {
return JniUtils.gru(
(PtNDArray) input,
(PtNDArray) state,
manager.from(input),
manager.from(state),
params,
hasBiases,
numLayers,
Expand All @@ -536,7 +539,7 @@ public NDList lstm(
boolean bidirectional,
boolean batchFirst) {
return JniUtils.lstm(
(PtNDArray) input,
manager.from(input),
states,
params,
hasBiases,
Expand All @@ -551,7 +554,7 @@ public NDList lstm(
@Override
public PtNDArray resize(int width, int height, int interpolation) {
// create subManager to help close intermediate NDArray
try (NDManager subManager = array.getManager().newSubManager()) {
try (NDManager subManager = manager.newSubManager()) {
array.attach(subManager);
NDArray result = array;
if (result.isEmpty()) {
Expand All @@ -567,7 +570,7 @@ public PtNDArray resize(int width, int height, int interpolation) {
result = result.transpose(0, 3, 1, 2);
result =
JniUtils.interpolate(
(PtNDArray) result,
manager.from(result),
new long[] {height, width},
getInterpolationMode(interpolation),
false)
Expand Down Expand Up @@ -626,15 +629,18 @@ public PtNDArray where(NDArray condition, NDArray other) {
throw new UnsupportedOperationException(
"condition and self shape mismatch, broadcast is not supported");
}
return JniUtils.where((PtNDArray) condition, array, (PtNDArray) other);
return JniUtils.where(manager.from(condition), array, manager.from(other));
}

/** {@inheritDoc} */
@Override
public PtNDArray stack(NDList arrays, int axis) {
PtNDArray[] srcArray = new PtNDArray[arrays.size() + 1];
srcArray[0] = array;
System.arraycopy(arrays.toArray(new NDArray[0]), 0, srcArray, 1, arrays.size());
int i = 1;
for (NDArray arr : arrays) {
srcArray[i++] = manager.from(arr);
}
return JniUtils.stack(srcArray, axis);
}

Expand All @@ -645,7 +651,10 @@ public PtNDArray concat(NDList list, int axis) {

PtNDArray[] srcArray = new PtNDArray[list.size() + 1];
srcArray[0] = array;
System.arraycopy(list.toArray(new NDArray[0]), 0, srcArray, 1, list.size());
int i = 1;
for (NDArray arr : list) {
srcArray[i++] = manager.from(arr);
}
return JniUtils.cat(srcArray, axis);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public PtNDArray from(NDArray array) {
if (array instanceof PtNDArray) {
if (array == null || array instanceof PtNDArray) {
return (PtNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public TfNDArray from(NDArray array) {
if (array instanceof TfNDArray) {
if (array == null || array instanceof TfNDArray) {
return (TfNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public TrtNDArray from(NDArray array) {
if (array instanceof TrtNDArray) {
if (array == null || array instanceof TrtNDArray) {
return (TrtNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public TfLiteNDArray from(NDArray array) {
if (array instanceof TfLiteNDArray) {
if (array == null || array instanceof TfLiteNDArray) {
return (TfLiteNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down

0 comments on commit 8a431ba

Please sign in to comment.