Skip to content

Commit

Permalink
Add DJL Paddle OCR Example (#568)
Browse files Browse the repository at this point in the history
* add ocr init commit

* add final model fixes

Change-Id: I26c0210886f005c0bf107d6b68e573922e6aafe5

* ocr

Change-Id: Icdcb59000bb980efde03ac7b7ed33667779dbaeb

* clean up

Change-Id: Iba56203868da68ab0a4ef6e19485787234a9c87a
  • Loading branch information
lanking520 authored Feb 2, 2021
1 parent 34ffc84 commit 2ba4fb7
Show file tree
Hide file tree
Showing 6 changed files with 461 additions and 66 deletions.
71 changes: 44 additions & 27 deletions api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
Expand Down Expand Up @@ -92,36 +93,52 @@ public Image fromImage(Object image) {
@Override
public Image fromNDArray(NDArray array) {
Shape shape = array.getShape();
if (shape.dimension() != 3) {
throw new IllegalArgumentException("Shape should only have three dimension follow CHW");
}
if (array.getDataType() != DataType.UINT8 && array.getDataType() != DataType.INT8) {
if (shape.dimension() == 4) {
throw new UnsupportedOperationException("Batch is not supported");
} else if (shape.get(0) == 1 || shape.get(2) == 1) {
throw new UnsupportedOperationException("Grayscale image is not supported");
} else if (array.getDataType() != DataType.UINT8 && array.getDataType() != DataType.INT8) {
throw new IllegalArgumentException("Datatype should be INT8 or UINT8");
}
if (shape.get(0) == 1) {
throw new UnsupportedOperationException("Grayscale image is not supported");
} else if (shape.get(0) != 3) {
throw new IllegalArgumentException(
"First dimension should be number of channel with value 1 or 3");
if (NDImageUtils.isCHW(shape)) {
int height = (int) shape.get(1);
int width = (int) shape.get(2);
int imageArea = width * height;
BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
int[] raw = array.toUint8Array();
IntStream.range(0, imageArea)
.parallel()
.forEach(
ele -> {
int x = ele % width;
int y = ele / width;
int red = raw[ele] & 0xFF;
int green = raw[ele + imageArea] & 0xFF;
int blue = raw[ele + imageArea * 2] & 0xFF;
int rgb = (red << 16) | (green << 8) | blue;
image.setRGB(x, y, rgb);
});
return new BufferedImageWrapper(image);
} else {
int height = (int) shape.get(0);
int width = (int) shape.get(1);
int imageArea = width * height;
BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
int[] raw = array.toUint8Array();
IntStream.range(0, imageArea)
.parallel()
.forEach(
ele -> {
int x = ele % width;
int y = ele / width;
int red = raw[ele * 3] & 0xFF;
int green = raw[ele * 3 + 1] & 0xFF;
int blue = raw[ele * 3 + 2] & 0xFF;
int rgb = (red << 16) | (green << 8) | blue;
image.setRGB(x, y, rgb);
});
return new BufferedImageWrapper(image);
}
int height = (int) shape.get(1);
int width = (int) shape.get(2);
int imageArea = width * height;
BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
int[] raw = array.toUint8Array();
IntStream.range(0, imageArea)
.parallel()
.forEach(
ele -> {
int x = ele % width;
int y = ele / width;
int red = raw[ele] & 0xFF;
int green = raw[ele + imageArea] & 0xFF;
int blue = raw[ele + imageArea * 2] & 0xFF;
int rgb = (red << 16) | (green << 8) | blue;
image.setRGB(x, y, rgb);
});
return new BufferedImageWrapper(image);
}

protected void save(BufferedImage image, OutputStream os, String type) throws IOException {
Expand Down
8 changes: 7 additions & 1 deletion api/src/main/java/ai/djl/modality/cv/util/NDImageUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,13 @@ public static NDArray randomColorJitter(
return image.getNDArrayInternal().randomColorJitter(brightness, contrast, saturation, hue);
}

private static boolean isCHW(Shape shape) {
/**
* Check if the shape of the image follows CHW/NCHW.
*
* @param shape the shape of the image
* @return true for (N)CHW, false for (N)HWC
*/
public static boolean isCHW(Shape shape) {
if (shape.dimension() < 3) {
throw new IllegalArgumentException(
"Not a valid image shape, require at least three dimensions");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,23 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I

private String[] findModelFile(Path dir) {
String[] paths = new String[2];
Path modelFile = dir.resolve("model");
if (Files.isRegularFile(modelFile)) {
paths[0] = modelFile.toString();
Path paramFile = dir.resolve("params");
if (Files.isRegularFile(paramFile)) {
paths[1] = paramFile.toString();
} else {
paths[0] = dir.toString();
String[][] patterns = {
{"model", "params"},
{"__model__", "__params__"},
{"inference.pdmodel", "inference.pdiparams"}
};
for (String[] pattern : patterns) {
Path modelFile = dir.resolve(pattern[0]);
if (Files.isRegularFile(modelFile)) {
paths[0] = modelFile.toString();
Path paramFile = dir.resolve(pattern[1]);
if (Files.isRegularFile(paramFile)) {
paths[1] = paramFile.toString();
} else {
paths[0] = dir.toString();
}
return paths;
}
return paths;
}

modelFile = dir.resolve("__model__");
if (Files.isRegularFile(modelFile)) {
paths[0] = modelFile.toString();
Path paramFile = dir.resolve("__params__");
if (Files.isRegularFile(paramFile)) {
paths[1] = paramFile.toString();
} else {
paths[0] = dir.toString();
}
return paths;
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ public class PpNDArray extends NativeResource<Long> implements NDArrayAdapter {
private PpNDManager manager;
private Shape shape;
private DataType dataType;
// TODO: we cannot close the inference NDArray, should remove after JNI
private boolean notClose;

/**
* Constructs an PpNDArray from a native handle (internal. Use {@link NDManager} instead).
Expand All @@ -44,18 +42,6 @@ public PpNDArray(PpNDManager manager, long handle) {
manager.attach(getUid(), this);
}

/**
* Constructs an PpNDArray from a native handle (internal. Use {@link NDManager} instead).
*
* @param manager the manager to attach the new array to
* @param handle the pointer to the native MxNDArray memory
* @param notClose not close the NDArray (inference use only)
*/
public PpNDArray(PpNDManager manager, long handle, boolean notClose) {
this(manager, handle);
this.notClose = notClose;
}

/**
* Constructs an PaddlePaddle NDArray from a {@link PpNDManager} (internal. Use {@link
* NDManager} instead).
Expand Down Expand Up @@ -157,9 +143,6 @@ public String toString() {
/** {@inheritDoc} */
@Override
public void close() {
if (notClose) {
return;
}
Long pointer = handle.getAndSet(null);
if (pointer != null) {
JniUtils.deleteNd(pointer);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.paddlepaddle.zoo;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.stream.Collectors;

public class BoundFinder {

private final int[] deltaX = {0, 1, -1, 0};
private final int[] deltaY = {1, 0, 0, -1};
private List<List<Point>> pointsCollection;
private int width;
private int height;

public BoundFinder(boolean[][] grid) {
pointsCollection = new ArrayList<>();
width = grid.length;
height = grid[0].length;
boolean[][] visited = new boolean[width][height];
// get all points connections
for (int i = 0; i < width; i++) {
for (int j = 0; j < height; j++) {
if (grid[i][j] && !visited[i][j]) {
pointsCollection.add(bfs(grid, i, j, visited));
}
}
}
}

public List<List<Point>> getPoints() {
return pointsCollection;
}

public List<BoundingBox> getBoxes() {
return pointsCollection
.stream()
.parallel()
.map(
points -> {
double[] minMax = {Integer.MAX_VALUE, Integer.MAX_VALUE, -1, -1};
points.forEach(
p -> {
minMax[0] = Math.min(minMax[0], p.getX());
minMax[1] = Math.min(minMax[1], p.getY());
minMax[2] = Math.max(minMax[2], p.getX());
minMax[3] = Math.max(minMax[3], p.getY());
});
return new Rectangle(
minMax[1],
minMax[0],
minMax[3] - minMax[1],
minMax[2] - minMax[0]);
})
.filter(rect -> rect.getWidth() * width > 5.0 && rect.getHeight() * height > 5.0)
.collect(Collectors.toList());
}

private List<Point> bfs(boolean[][] grid, int x, int y, boolean[][] visited) {
Queue<Point> queue = new ArrayDeque<>();
queue.offer(new Point(x, y));
visited[x][y] = true;

List<Point> points = new ArrayList<>();
while (!queue.isEmpty()) {
Point point = queue.poll();
points.add(new Point(point.getX() / width, point.getY() / height));
for (int direction = 0; direction < 4; direction++) {
int newX = (int) point.getX() + deltaX[direction];
int newY = (int) point.getY() + deltaY[direction];
if (!isVaild(grid, newX, newY, visited)) {
continue;
}
queue.offer(new Point(newX, newY));
visited[newX][newY] = true;
}
}
return points;
}

private boolean isVaild(boolean[][] grid, int x, int y, boolean[][] visited) {
if (x < 0 || x >= width || y < 0 || y >= height) {
return false;
}
if (visited[x][y]) {
return false;
}
return grid[x][y];
}
}
Loading

0 comments on commit 2ba4fb7

Please sign in to comment.