Fix build failure on GPU #1279

Merged 1 commit on Oct 8, 2021
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/Device.java
@@ -81,6 +81,15 @@ public int getDeviceId() {
return deviceId;
}

/**
 * Returns whether the {@code Device} is a GPU.
 *
 * @return {@code true} if the {@code Device} is a GPU
 */
public boolean isGpu() {
return Type.GPU.equals(deviceType);
}

/** {@inheritDoc} */
@Override
public String toString() {
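With the new helper, call sites can drop the verbose device-type comparison. A minimal before/after sketch at a hypothetical call site (not part of this diff):

// before (hypothetical caller)
if (Device.Type.GPU.equals(device.getDeviceType())) {
    // GPU-specific path
}

// after
if (device.isGpu()) {
    // GPU-specific path
}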
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/util/cuda/CudaUtils.java
@@ -128,7 +128,7 @@ public static String getComputeCapability(int device) {
* @throws IllegalArgumentException if {@link Device} is not GPU device or does not exist
*/
public static MemoryUsage getGpuMemory(Device device) {
if (!Device.Type.GPU.equals(device.getDeviceType())) {
if (!device.isGpu()) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}

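A small usage sketch of the guarded call, assuming the existing DJL helpers Device.gpu() and CudaUtils.getGpuCount(); the class below is illustrative and not part of this PR:

import ai.djl.Device;
import ai.djl.util.cuda.CudaUtils;
import java.lang.management.MemoryUsage;

public final class GpuMemoryProbe {
    public static void main(String[] args) {
        if (CudaUtils.getGpuCount() > 0) {
            Device gpu = Device.gpu(); // defaults to GPU 0
            MemoryUsage usage = CudaUtils.getGpuMemory(gpu);
            System.out.println(usage.getCommitted() + " / " + usage.getMax() + " bytes");
        }
        // Passing a CPU device would now fail fast with IllegalArgumentException
        // via the device.isGpu() check above.
    }
}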
1 change: 0 additions & 1 deletion api/src/test/java/ai/djl/util/PlatformTest.java
@@ -40,7 +40,6 @@ public void testPlatform() throws IOException {
Assert.assertEquals(system.getClassifier(), "linux-x86_64");
Assert.assertEquals(system.getOsPrefix(), "linux");
Assert.assertEquals(system.getOsArch(), "x86_64");
Assert.assertNull(system.getCudaArch());
Contributor: Why do we remove this?

Contributor (author): CudaArch is inherited from the system-detected value, so on a GPU machine it won't be null.


url = createPropertyFile("version=1.8.0\nplaceholder=true");
Platform platform = Platform.fromUrl(url);
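If one wanted to keep the check on CPU-only machines instead of deleting it, a conditional assertion along these lines would work. This is a hypothetical alternative, not what the PR does, and it assumes ai.djl.util.cuda.CudaUtils is imported in the test:

if (CudaUtils.getGpuCount() == 0) {
    // only on machines without CUDA is the detected arch guaranteed to be null
    Assert.assertNull(system.getCudaArch());
}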
@@ -196,7 +196,7 @@ public static int getGpuCount() {
}

public static long[] getGpuMemory(Device device) {
if (!Device.Type.GPU.equals(device.getDeviceType())) {
if (!device.isGpu()) {
throw new IllegalArgumentException("Only GPU device is allowed.");
}

@@ -71,7 +71,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
try {
Device device = manager.getDevice();
OrtSession session;
if (Device.Type.GPU.equals(device.getDeviceType())) {
if (device.isGpu()) {
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(manager.getDevice().getDeviceId());
session = env.createSession(modelFile.toString(), sessionOptions);
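For context, the full branch around this hunk roughly takes the following shape; the CPU fallback below is assumed from the surrounding code and is not shown in the diff:

OrtSession session;
if (device.isGpu()) {
    OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
    // attach the CUDA execution provider, pinned to the manager's GPU id
    sessionOptions.addCUDA(device.getDeviceId());
    session = env.createSession(modelFile.toString(), sessionOptions);
} else {
    // assumed CPU path: default execution provider
    session = env.createSession(modelFile.toString());
}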
@@ -215,7 +215,7 @@ public void set(Buffer data) {
if (data.isDirect() && data instanceof ByteBuffer) {
// If NDArray is on the GPU, it is native code responsibility to control the data life
// cycle
if (!Device.Type.GPU.equals(getDevice().getDeviceType())) {
if (!getDevice().isGpu()) {
dataRef = (ByteBuffer) data;
}
JniUtils.set(this, (ByteBuffer) data);
@@ -227,7 +227,7 @@ public void set(Buffer data) {
BaseNDManager.copyBuffer(data, buf);

// If NDArray is on the GPU, it is native code responsibility to control the data life cycle
if (!Device.Type.GPU.equals(getDevice().getDeviceType())) {
if (!getDevice().isGpu()) {
dataRef = buf;
}
JniUtils.set(this, buf);
@@ -141,7 +141,7 @@ public static PtNDArray createNdFromByteBuffer(
new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
false);

if (layout == 1 || layout == 2 || Device.Type.GPU.equals(device.getDeviceType())) {
if (layout == 1 || layout == 2 || device.isGpu()) {
// MKLDNN & COO & GPU device will explicitly make a copy in native code
// so we don't want to hold a reference on Java side
return new PtNDArray(manager, handle);
@@ -191,6 +191,10 @@ public ByteBuffer toByteBuffer() {
/** {@inheritDoc} */
@Override
public void set(Buffer data) {
if (getDevice().isGpu()) {
// TODO: Implement set for GPU
throw new UnsupportedOperationException("GPU Tensor cannot be modified after creation");
}
int size = Math.toIntExact(getShape().size());
BaseNDManager.validateBufferSize(data, getDataType(), size);
if (data instanceof ByteBuffer) {
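Given this restriction, callers that may run on a GPU should avoid in-place updates. A minimal guarded-update sketch, where manager is an NDManager; this mirrors the integration test further down rather than code in this file:

NDArray array = manager.zeros(new Shape(2));
float[] values = {2f, 3f};
if (array.getDevice().isGpu()) {
    // GPU-backed TensorFlow tensors cannot be modified after creation,
    // so create a new array instead of mutating in place
    array = manager.create(values);
} else {
    array.set(values);
}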
@@ -69,7 +69,8 @@ public NDArray create(Shape shape, DataType dataType) {
// initialize with scalar 0
return create(0f).toType(dataType, false);
}
TFE_TensorHandle handle = JavacppUtils.createEmptyTFETensor(shape, dataType);
TFE_TensorHandle handle =
JavacppUtils.createEmptyTFETensor(shape, dataType, getEagerSession(), device);
return new TfNDArray(this, handle);
}

@@ -84,12 +85,15 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
BaseNDManager.validateBufferSize(data, dataType, size);
if (data.isDirect() && data instanceof ByteBuffer) {
TFE_TensorHandle handle =
JavacppUtils.createTFETensorFromByteBuffer((ByteBuffer) data, shape, dataType);
JavacppUtils.createTFETensorFromByteBuffer(
(ByteBuffer) data, shape, dataType, getEagerSession(), device);
return new TfNDArray(this, handle);
}
ByteBuffer buf = allocateDirect(size * dataType.getNumOfBytes());
copyBuffer(data, buf);
TFE_TensorHandle handle = JavacppUtils.createTFETensorFromByteBuffer(buf, shape, dataType);
TFE_TensorHandle handle =
JavacppUtils.createTFETensorFromByteBuffer(
buf, shape, dataType, getEagerSession(), device);
return new TfNDArray(this, handle);
}

@@ -248,7 +248,7 @@ public static Shape getShape(TFE_TensorHandle handle) {
}
}

public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
private static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
int dType = TfDataType.toTf(dataType);
long[] dims = shape.getShape();
long numBytes = dataType.getNumOfBytes() * shape.size();
@@ -260,12 +260,16 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {
}

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) {
public static TFE_TensorHandle createEmptyTFETensor(
Shape shape, DataType dataType, TFE_Context eagerSessionHandle, Device device) {
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = createEmptyTFTensor(shape, dataType);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
status.throwExceptionIfNotOK();
if (device.isGpu()) {
return toDevice(handle, eagerSessionHandle, device);
}
return handle.retainReference();
}
}
@@ -303,7 +307,11 @@ public static Pair<TF_Tensor, TFE_TensorHandle> createStringTensor(

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensorFromByteBuffer(
ByteBuffer buf, Shape shape, DataType dataType) {
ByteBuffer buf,
Shape shape,
DataType dataType,
TFE_Context eagerSessionHandle,
Device device) {
int dType = TfDataType.toTf(dataType);
long[] dims = shape.getShape();
long numBytes;
Expand All @@ -320,6 +328,9 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer(
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
status.throwExceptionIfNotOK();
if (device.isGpu()) {
return toDevice(handle, eagerSessionHandle, device);
}
return handle.retainReference();
}
}
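The toDevice helper called above is not part of this hunk. One plausible shape for it, sketched here as an assumption based on the TensorFlow eager C API's TFE_TensorHandleCopyToDevice; the actual implementation in the PR may differ:

@SuppressWarnings({"unchecked", "try"})
private static TFE_TensorHandle toDevice(
        TFE_TensorHandle handle, TFE_Context eagerSessionHandle, Device device) {
    // Hypothetical sketch; TensorFlow device names follow the "/device:GPU:<id>" pattern.
    String deviceName = "/device:GPU:" + device.getDeviceId();
    try (PointerScope ignored = new PointerScope()) {
        TF_Status status = TF_Status.newStatus();
        TFE_TensorHandle newHandle =
                tensorflow.TFE_TensorHandleCopyToDevice(
                        handle, eagerSessionHandle, deviceName, status);
        status.throwExceptionIfNotOK();
        return newHandle.retainReference();
    }
}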
@@ -42,9 +42,14 @@ public void testNDArray() {
array.toStringArray()[1].getBytes(StandardCharsets.UTF_8), buf2.array());

array = manager.zeros(new Shape(2));
final NDArray b = array;
float[] expected = {2, 3};
array.set(expected);
Assert.assertEquals(array.toFloatArray(), expected);
if (array.getDevice().isGpu()) {
Assert.assertThrows(UnsupportedOperationException.class, () -> b.set(expected));
} else {
array.set(expected);
Assert.assertEquals(array.toFloatArray(), expected);
}

Assert.assertThrows(
IllegalArgumentException.class,
@@ -130,7 +130,7 @@ public TrtNDManager newBaseManager() {
public TrtNDManager newBaseManager(Device device) {
// Only support GPU for now
device = device == null ? defaultDevice() : device;
if (!Device.Type.GPU.equals(device.getDeviceType())) {
if (!device.isGpu()) {
throw new IllegalArgumentException("TensorRT only support GPU");
}
return TrtNDManager.getSystemManager().newSubManager(device);
@@ -12,7 +12,6 @@
*/
package ai.djl.tensorrt.engine;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
@@ -31,7 +30,7 @@ public void testNDArray() {
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
try (NDManager manager = TrtNDManager.getSystemManager().newSubManager()) {
@@ -12,7 +12,6 @@
*/
package ai.djl.tensorrt.integration;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
@@ -50,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
Criteria<float[], float[]> criteria =
@@ -76,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
List<String> synset =
@@ -113,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, TranslateException
} catch (Exception ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) {
if (!engine.defaultDevice().isGpu()) {
throw new SkipException("TensorRT only support GPU.");
}
Criteria<float[], float[]> criteria =
@@ -98,8 +98,7 @@ public class Arguments {
threads = Integer.parseInt(cmd.getOptionValue("threads"));
Engine eng = Engine.getEngine(engine);
Device[] devices = eng.getDevices(maxGpus);
String deviceType = devices[0].getDeviceType();
if (Device.Type.GPU.equals(deviceType)) {
if (devices[0].isGpu()) {
// one thread per GPU
if (threads <= 0) {
threads = devices.length;