Skip to content

Commit

Permalink
fix some bugs in pytorch based examples (#2009)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Sep 13, 2022
1 parent 93ec732 commit 5073966
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,13 @@ public BigGANTranslator(float truncation) {
/** {@inheritDoc} */
@Override
public Image[] processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.get(0).addi(1).muli(128).clip(0, 255);

NDArray output = list.get(0).duplicate().addi(1).muli(128).clip(0, 255);
int sampleSize = (int) output.getShape().get(0);
Image[] images = new Image[sampleSize];

for (int i = 0; i < sampleSize; ++i) {
images[i] = ImageFactory.getInstance().fromNDArray(output.get(i));
}

return images;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public NDList processInput(TranslatorContext ctx, Image input) {
/** {@inheritDoc} */
@Override
public Image processOutput(TranslatorContext ctx, NDList list) {
NDArray output = list.get(0).addi(1).muli(128);
NDArray output = list.get(0).duplicate().addi(1).muli(128);
return ImageFactory.getInstance().fromNDArray(output.squeeze());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
Image input = ImageFactory.getInstance().fromFile(Paths.get(imagePath));
Image output = transfer(input, artist);

logger.info("Using PyTorch Engine. " + artist + " painting generated.");
logger.info(
"Using PyTorch Engine. {} painting generated. Image saved in build/output/cyclegan",
artist);
save(output, artist.toString(), "build/output/cyclegan/");
}

Expand Down

0 comments on commit 5073966

Please sign in to comment.