diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index aff51d99c3e..f2adaf88085 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -81,6 +81,15 @@ public int getDeviceId() { return deviceId; } + /** + * Returns if the {@code Device} is GPU. + * + * @return if the {@code Device} is GPU. + */ + public boolean isGpu() { + return Type.GPU.equals(deviceType); + } + /** {@inheritDoc} */ @Override public String toString() { diff --git a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java index 8e27720ead5..1736742a3e9 100644 --- a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java +++ b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java @@ -128,7 +128,7 @@ public static String getComputeCapability(int device) { * @throws IllegalArgumentException if {@link Device} is not GPU device or does not exist */ public static MemoryUsage getGpuMemory(Device device) { - if (!Device.Type.GPU.equals(device.getDeviceType())) { + if (!device.isGpu()) { throw new IllegalArgumentException("Only GPU device is allowed."); } diff --git a/api/src/test/java/ai/djl/util/PlatformTest.java b/api/src/test/java/ai/djl/util/PlatformTest.java index e6164138c5e..3026280b9d3 100644 --- a/api/src/test/java/ai/djl/util/PlatformTest.java +++ b/api/src/test/java/ai/djl/util/PlatformTest.java @@ -40,7 +40,6 @@ public void testPlatform() throws IOException { Assert.assertEquals(system.getClassifier(), "linux-x86_64"); Assert.assertEquals(system.getOsPrefix(), "linux"); Assert.assertEquals(system.getOsArch(), "x86_64"); - Assert.assertNull(system.getCudaArch()); url = createPropertyFile("version=1.8.0\nplaceholder=true"); Platform platform = Platform.fromUrl(url); diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java index 8b69297a7a6..4702dee4efb 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java @@ -196,7 +196,7 @@ public static int getGpuCount() { } public static long[] getGpuMemory(Device device) { - if (!Device.Type.GPU.equals(device.getDeviceType())) { + if (!device.isGpu()) { throw new IllegalArgumentException("Only GPU device is allowed."); } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java index a6bd7237f0d..cb788a6617b 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtModel.java @@ -71,7 +71,7 @@ public void load(Path modelPath, String prefix, Map options) try { Device device = manager.getDevice(); OrtSession session; - if (Device.Type.GPU.equals(device.getDeviceType())) { + if (device.isGpu()) { OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); sessionOptions.addCUDA(manager.getDevice().getDeviceId()); session = env.createSession(modelFile.toString(), sessionOptions); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a317642f7a7..430892115be 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -215,7 +215,7 @@ public void set(Buffer data) { if (data.isDirect() && data instanceof ByteBuffer) { // If NDArray is on the GPU, it is native code responsibility to control the data life // cycle - if (!Device.Type.GPU.equals(getDevice().getDeviceType())) { + if (!getDevice().isGpu()) { dataRef = (ByteBuffer) data; } JniUtils.set(this, (ByteBuffer) data); @@ -227,7 +227,7 @@ public void set(Buffer data) { BaseNDManager.copyBuffer(data, buf); // If NDArray is on the GPU, it is native code responsibility to control the data life cycle - if (!Device.Type.GPU.equals(getDevice().getDeviceType())) { + if (!getDevice().isGpu()) { dataRef = buf; } JniUtils.set(this, buf); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index db613bae3c8..bad70044bc9 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -141,7 +141,7 @@ public static PtNDArray createNdFromByteBuffer( new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false); - if (layout == 1 || layout == 2 || Device.Type.GPU.equals(device.getDeviceType())) { + if (layout == 1 || layout == 2 || device.isGpu()) { // MKLDNN & COO & GPU device will explicitly make a copy in native code // so we don't want to hold a reference on Java side return new PtNDArray(manager, handle); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index d315ff5afb0..3d5c8f09c9d 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -191,6 +191,10 @@ public ByteBuffer toByteBuffer() { /** {@inheritDoc} */ @Override public void set(Buffer data) { + if (getDevice().isGpu()) { + // TODO: Implement set for GPU + throw new UnsupportedOperationException("GPU Tensor cannot be modified after creation"); + } int size = Math.toIntExact(getShape().size()); BaseNDManager.validateBufferSize(data, getDataType(), size); if (data instanceof ByteBuffer) { diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index cc923e63955..c5c26535b8d 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -69,7 +69,8 @@ public NDArray create(Shape shape, DataType dataType) { // initialize with scalar 0 return create(0f).toType(dataType, false); } - TFE_TensorHandle handle = JavacppUtils.createEmptyTFETensor(shape, dataType); + TFE_TensorHandle handle = + JavacppUtils.createEmptyTFETensor(shape, dataType, getEagerSession(), device); return new TfNDArray(this, handle); } @@ -84,12 +85,15 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) { BaseNDManager.validateBufferSize(data, dataType, size); if (data.isDirect() && data instanceof ByteBuffer) { TFE_TensorHandle handle = - JavacppUtils.createTFETensorFromByteBuffer((ByteBuffer) data, shape, dataType); + JavacppUtils.createTFETensorFromByteBuffer( + (ByteBuffer) data, shape, dataType, getEagerSession(), device); return new TfNDArray(this, handle); } ByteBuffer buf = allocateDirect(size * dataType.getNumOfBytes()); copyBuffer(data, buf); - TFE_TensorHandle handle = JavacppUtils.createTFETensorFromByteBuffer(buf, shape, dataType); + TFE_TensorHandle handle = + JavacppUtils.createTFETensorFromByteBuffer( + buf, shape, dataType, getEagerSession(), device); return new TfNDArray(this, handle); } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java index cba1d1c272a..e57b802eff5 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java @@ -248,7 +248,7 @@ public static Shape getShape(TFE_TensorHandle handle) { } } - public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) { + private static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) { int dType = TfDataType.toTf(dataType); long[] dims = shape.getShape(); long numBytes = dataType.getNumOfBytes() * shape.size(); @@ -260,12 +260,16 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) { } @SuppressWarnings({"unchecked", "try"}) - public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) { + public static TFE_TensorHandle createEmptyTFETensor( + Shape shape, DataType dataType, TFE_Context eagerSessionHandle, Device device) { try (PointerScope ignored = new PointerScope()) { TF_Tensor tensor = createEmptyTFTensor(shape, dataType); TF_Status status = TF_Status.newStatus(); TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status); status.throwExceptionIfNotOK(); + if (device.isGpu()) { + return toDevice(handle, eagerSessionHandle, device); + } return handle.retainReference(); } } @@ -303,7 +307,11 @@ public static Pair createStringTensor( @SuppressWarnings({"unchecked", "try"}) public static TFE_TensorHandle createTFETensorFromByteBuffer( - ByteBuffer buf, Shape shape, DataType dataType) { + ByteBuffer buf, + Shape shape, + DataType dataType, + TFE_Context eagerSessionHandle, + Device device) { int dType = TfDataType.toTf(dataType); long[] dims = shape.getShape(); long numBytes; @@ -320,6 +328,9 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer( TF_Status status = TF_Status.newStatus(); TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status); status.throwExceptionIfNotOK(); + if (device.isGpu()) { + return toDevice(handle, eagerSessionHandle, device); + } return handle.retainReference(); } } diff --git a/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/engine/TfNDManagerTest.java b/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/engine/TfNDManagerTest.java index 4ea91004f4f..b947cf56d21 100644 --- a/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/engine/TfNDManagerTest.java +++ b/engines/tensorflow/tensorflow-engine/src/test/java/ai/djl/tensorflow/engine/TfNDManagerTest.java @@ -42,9 +42,14 @@ public void testNDArray() { array.toStringArray()[1].getBytes(StandardCharsets.UTF_8), buf2.array()); array = manager.zeros(new Shape(2)); + final NDArray b = array; float[] expected = {2, 3}; - array.set(expected); - Assert.assertEquals(array.toFloatArray(), expected); + if (array.getDevice().isGpu()) { + Assert.assertThrows(UnsupportedOperationException.class, () -> b.set(expected)); + } else { + array.set(expected); + Assert.assertEquals(array.toFloatArray(), expected); + } Assert.assertThrows( IllegalArgumentException.class, diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java index a739ff56640..343abac7e8c 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java @@ -130,7 +130,7 @@ public TrtNDManager newBaseManager() { public TrtNDManager newBaseManager(Device device) { // Only support GPU for now device = device == null ? defaultDevice() : device; - if (!Device.Type.GPU.equals(device.getDeviceType())) { + if (!device.isGpu()) { throw new IllegalArgumentException("TensorRT only support GPU"); } return TrtNDManager.getSystemManager().newSubManager(device); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 832caafbd93..d52b450741d 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -12,7 +12,6 @@ */ package ai.djl.tensorrt.engine; -import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; @@ -31,7 +30,7 @@ public void testNDArray() { } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } - if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) { + if (!engine.defaultDevice().isGpu()) { throw new SkipException("TensorRT only support GPU."); } try (NDManager manager = TrtNDManager.getSystemManager().newSubManager()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 1d795644956..36bb702032d 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -12,7 +12,6 @@ */ package ai.djl.tensorrt.integration; -import ai.djl.Device; import ai.djl.ModelException; import ai.djl.engine.Engine; import ai.djl.inference.Predictor; @@ -50,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } - if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) { + if (!engine.defaultDevice().isGpu()) { throw new SkipException("TensorRT only support GPU."); } Criteria criteria = @@ -76,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } - if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) { + if (!engine.defaultDevice().isGpu()) { throw new SkipException("TensorRT only support GPU."); } List synset = @@ -113,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } - if (!Device.Type.GPU.equals(engine.defaultDevice().getDeviceType())) { + if (!engine.defaultDevice().isGpu()) { throw new SkipException("TensorRT only support GPU."); } Criteria criteria = diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java index b519d1d9b7d..46a8f0b538d 100644 --- a/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java +++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java @@ -98,8 +98,7 @@ public class Arguments { threads = Integer.parseInt(cmd.getOptionValue("threads")); Engine eng = Engine.getEngine(engine); Device[] devices = eng.getDevices(maxGpus); - String deviceType = devices[0].getDeviceType(); - if (Device.Type.GPU.equals(deviceType)) { + if (devices[0].isGpu()) { // one thread per GPU if (threads <= 0) { threads = devices.length;