From 2bd8d51cc3ddf9d6633039e161b3da082e4b5114 Mon Sep 17 00:00:00 2001 From: Florian Spiess Date: Wed, 5 Jan 2022 14:20:26 +0100 Subject: [PATCH] InceptionResnetV2 feature for Query-by-Example (#245) Refactored InceptionResnetV2 feature encoding into its own separate feature from inside VisualTextCoEmbedding. This allows it to be used as a feature for query-by-example that is less abstracted towards semantic content than the visual-text co-embedding. Authored by @Spiess Former-commit-id: 2ea3d63d3d64d8da1c45dcdde0170be79ceadf18 --- build.gradle | 2 +- .../core/features/InceptionResnetV2.java | 255 ++++++++++++++++++ .../core/features/VisualTextCoEmbedding.java | 142 ++-------- 3 files changed, 271 insertions(+), 128 deletions(-) create mode 100644 cineast-core/src/main/java/org/vitrivr/cineast/core/features/InceptionResnetV2.java diff --git a/build.gradle b/build.gradle index ea5be48b9..07b8a50b9 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ allprojects { group = 'org.vitrivr' /* Our current version, on dev branch this should always be release+1-SNAPSHOT */ - version = '3.6.1' + version = '3.6.2' apply plugin: 'java-library' apply plugin: 'maven-publish' diff --git a/cineast-core/src/main/java/org/vitrivr/cineast/core/features/InceptionResnetV2.java b/cineast-core/src/main/java/org/vitrivr/cineast/core/features/InceptionResnetV2.java new file mode 100644 index 000000000..4294b3475 --- /dev/null +++ b/cineast-core/src/main/java/org/vitrivr/cineast/core/features/InceptionResnetV2.java @@ -0,0 +1,255 @@ +package org.vitrivr.cineast.core.features; + +import java.awt.image.BufferedImage; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import net.coobird.thumbnailator.Thumbnails; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.types.TFloat32; +import org.vitrivr.cineast.core.config.QueryConfig; +import org.vitrivr.cineast.core.config.ReadableQueryConfig; +import org.vitrivr.cineast.core.config.ReadableQueryConfig.Distance; +import org.vitrivr.cineast.core.data.FloatVectorImpl; +import org.vitrivr.cineast.core.data.frames.VideoFrame; +import org.vitrivr.cineast.core.data.raw.images.MultiImage; +import org.vitrivr.cineast.core.data.score.ScoreElement; +import org.vitrivr.cineast.core.data.segments.SegmentContainer; +import org.vitrivr.cineast.core.features.abstracts.AbstractFeatureModule; + +public class InceptionResnetV2 extends AbstractFeatureModule { + + public static final int ENCODING_SIZE = 1536; + private static final String TABLE_NAME = "features_inceptionresnetv2"; + private static final Distance DISTANCE = Distance.manhattan; + + private static final Logger LOGGER = LogManager.getLogger(); + + /** + * Required dimensions of visual embedding model. + */ + public static final int IMAGE_WIDTH = 299; + public static final int IMAGE_HEIGHT = 299; + + /** + * Resource paths. + */ + private static final String MODEL_PATH = "resources/VisualTextCoEmbedding/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop"; + + /** + * Model input and output names. + */ + public static final String INPUT = "input_1"; + public static final String OUTPUT = "global_average_pooling2d"; + + /** + * InceptionResNetV2 pretrained on ImageNet: https://storage.googleapis.com/tensorflow/keras-applications/inception_resnet_v2/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5 + */ + private static SavedModelBundle model; + + public InceptionResnetV2() { + super(TABLE_NAME, ENCODING_SIZE, ENCODING_SIZE); + } + + @Override + public void processSegment(SegmentContainer sc) { + // Return if already processed + if (phandler.idExists(sc.getId())) { + return; + } + + // Case: segment contains video frames + if (!sc.getVideoFrames().isEmpty() && sc.getVideoFrames().get(0) != VideoFrame.EMPTY_VIDEO_FRAME) { + List frames = sc.getVideoFrames().stream() + .map(VideoFrame::getImage) + .collect(Collectors.toList()); + + float[] encodingArray = encodeVideo(frames); + this.persist(sc.getId(), new FloatVectorImpl(encodingArray)); + + return; + } + + // Case: segment contains image + if (sc.getMostRepresentativeFrame() != VideoFrame.EMPTY_VIDEO_FRAME) { + BufferedImage image = sc.getMostRepresentativeFrame().getImage().getBufferedImage(); + + if (image != null) { + float[] encodingArray = encodeImage(image); + this.persist(sc.getId(), new FloatVectorImpl(encodingArray)); + } + + // Insert return here if additional cases are added! + } + } + + @Override + public List getSimilar(SegmentContainer sc, ReadableQueryConfig qc) { + float[] encodingArray = null; + + if (!sc.getVideoFrames().isEmpty() && sc.getVideoFrames().get(0) != VideoFrame.EMPTY_VIDEO_FRAME) { + // Case: segment contains video frames + List frames = sc.getVideoFrames().stream() + .map(VideoFrame::getImage) + .collect(Collectors.toList()); + + encodingArray = encodeVideo(frames); + } else if (sc.getMostRepresentativeFrame() != VideoFrame.EMPTY_VIDEO_FRAME) { + // Case: segment contains image + BufferedImage image = sc.getMostRepresentativeFrame().getImage().getBufferedImage(); + + if (image != null) { + encodingArray = encodeImage(image); + } else { + LOGGER.error("Could not get similar because image could not be converted to BufferedImage."); + } + } + + if (encodingArray == null) { + LOGGER.error("Could not get similar because no acceptable modality was provided."); + return new ArrayList<>(); + } + + // Ensure the correct distance function is used + QueryConfig queryConfig = QueryConfig.clone(qc); + queryConfig.setDistance(DISTANCE); + + return getSimilar(encodingArray, queryConfig); + } + + @Override + public List getSimilar(String segmentId, ReadableQueryConfig qc) { + // Ensure the correct distance function is used + QueryConfig queryConfig = QueryConfig.clone(qc); + queryConfig.setDistance(DISTANCE); + + return super.getSimilar(segmentId, queryConfig); + } + + public static SavedModelBundle getModel() { + initializeModel(); + + return model; + } + + /** + * Encodes the given image using InceptionResnetV2. + * + * @return Intermediary encoding, not yet embedded. + */ + public static float[] encodeImage(BufferedImage image) { + initializeModel(); + + float[] processedColors = preprocessImage(image); + + try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, IMAGE_WIDTH, IMAGE_HEIGHT, 3), DataBuffers.of(processedColors))) { + HashMap inputMap = new HashMap<>(); + inputMap.put(INPUT, imageTensor); + + Map resultMap = model.call(inputMap); + + try (TFloat32 encoding = (TFloat32) resultMap.get(OUTPUT)) { + + float[] embeddingArray = new float[ENCODING_SIZE]; + FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray); + encoding.read(floatBuffer); + + return embeddingArray; + } + } + } + + /** + * Encodes each frame of the given video using InceptionResnetV2 and returns the mean encoding as float array. + * + * @param frames List of frames in the video or shot to be encoded. + * @return Mean of frame encodings as float array. + */ + public static float[] encodeVideo(List frames) { + List encodings = frames.stream().map(image -> encodeImage(image.getBufferedImage())).collect(Collectors.toList()); + + // Sum + float[] meanEncoding = encodings.stream().reduce(new float[ENCODING_SIZE], (encoding0, encoding1) -> { + float[] tempSum = new float[ENCODING_SIZE]; + + for (int i = 0; i < ENCODING_SIZE; i++) { + tempSum[i] = encoding0[i] + encoding1[i]; + } + + return tempSum; + }); + + // Calculate mean + for (int i = 0; i < ENCODING_SIZE; i++) { + meanEncoding[i] /= encodings.size(); + } + + return meanEncoding; + } + + /** + * Preprocesses the image, so it can be used as input to the InceptionResnetV2. Involves rescaling, remapping and converting the image to a float array. + * + * @return Float array representation of the input image. + */ + public static float[] preprocessImage(BufferedImage image) { + if (image.getWidth() != IMAGE_WIDTH || image.getHeight() != IMAGE_HEIGHT) { + try { + image = Thumbnails.of(image).forceSize(IMAGE_WIDTH, IMAGE_HEIGHT).asBufferedImage(); + } catch (IOException e) { + LOGGER.error("Could not resize image", e); + } + } + int[] colors = image.getRGB(0, 0, IMAGE_WIDTH, IMAGE_HEIGHT, null, 0, IMAGE_WIDTH); + int[] rgb = colorsToRGB(colors); + return preprocessInput(rgb); + } + + /** + * Preprocesses input in a way equivalent to that performed in the Python TensorFlow library. + *

+ * Maps all values from [0,255] to [-1, 1]. + */ + private static float[] preprocessInput(int[] colors) { + // x /= 127.5 + // x -= 1. + float[] processedColors = new float[colors.length]; + for (int i = 0; i < colors.length; i++) { + processedColors[i] = (colors[i] / 127.5f) - 1; + } + + return processedColors; + } + + /** + * Converts an integer colors array storing ARGB values in each integer into an integer array where each integer stores R, G or B value. + */ + private static int[] colorsToRGB(int[] colors) { + int[] rgb = new int[colors.length * 3]; + + for (int i = 0; i < colors.length; i++) { + // Start index for rgb array + int j = i * 3; + rgb[j] = (colors[i] >> 16) & 0xFF; // r + rgb[j + 1] = (colors[i] >> 8) & 0xFF; // g + rgb[j + 2] = colors[i] & 0xFF; // b + } + + return rgb; + } + + private static void initializeModel() { + if (model == null) { + model = SavedModelBundle.load(MODEL_PATH); + } + } +} diff --git a/cineast-core/src/main/java/org/vitrivr/cineast/core/features/VisualTextCoEmbedding.java b/cineast-core/src/main/java/org/vitrivr/cineast/core/features/VisualTextCoEmbedding.java index 0486366c0..9f6941773 100644 --- a/cineast-core/src/main/java/org/vitrivr/cineast/core/features/VisualTextCoEmbedding.java +++ b/cineast-core/src/main/java/org/vitrivr/cineast/core/features/VisualTextCoEmbedding.java @@ -1,7 +1,5 @@ package org.vitrivr.cineast.core.features; -import java.awt.geom.AffineTransform; -import java.awt.image.AffineTransformOp; import java.awt.image.BufferedImage; import java.util.HashMap; import java.util.List; @@ -31,23 +29,15 @@ public class VisualTextCoEmbedding extends AbstractFeatureModule { private static final int EMBEDDING_SIZE = 256; - private static final int ENCODING_SIZE = 1536; private static final String TABLE_NAME = "features_visualtextcoembedding"; private static final Distance DISTANCE = ReadableQueryConfig.Distance.euclidean; - /** - * Required dimensions of visual embedding model. - */ - private static final int IMAGE_WIDTH = 299; - private static final int IMAGE_HEIGHT = 299; - /** * Resource paths. */ private static final String RESOURCE_PATH = "resources/VisualTextCoEmbedding/"; private static final String TEXT_EMBEDDING_MODEL = "universal-sentence-encoder_4"; private static final String TEXT_CO_EMBEDDING_MODEL = "text-co-embedding"; - private static final String VISUAL_EMBEDDING_MODEL = "inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop"; private static final String VISUAL_CO_EMBEDDING_MODEL = "visual-co-embedding"; /** @@ -57,8 +47,6 @@ public class VisualTextCoEmbedding extends AbstractFeatureModule { private static final String TEXT_EMBEDDING_OUTPUT = "outputs"; private static final String TEXT_CO_EMBEDDING_INPUT = "textual_features"; private static final String TEXT_CO_EMBEDDING_OUTPUT = "l2_norm"; - private static final String VISUAL_EMBEDDING_INPUT = "input_1"; - private static final String VISUAL_EMBEDDING_OUTPUT = "global_average_pooling2d"; private static final String VISUAL_CO_EMBEDDING_INPUT = "visual_features"; private static final String VISUAL_CO_EMBEDDING_OUTPUT = "l2_norm"; @@ -89,31 +77,31 @@ public VisualTextCoEmbedding() { } @Override - public void processSegment(SegmentContainer shot) { + public void processSegment(SegmentContainer sc) { // Return if already processed - if (phandler.idExists(shot.getId())) { + if (phandler.idExists(sc.getId())) { return; } // Case: segment contains video frames - if (!(shot.getVideoFrames().size() > 0 && shot.getVideoFrames().get(0) == VideoFrame.EMPTY_VIDEO_FRAME)) { - List frames = shot.getVideoFrames().stream() + if (!sc.getVideoFrames().isEmpty() && sc.getVideoFrames().get(0) != VideoFrame.EMPTY_VIDEO_FRAME) { + List frames = sc.getVideoFrames().stream() .map(VideoFrame::getImage) .collect(Collectors.toList()); float[] embeddingArray = embedVideo(frames); - this.persist(shot.getId(), new FloatVectorImpl(embeddingArray)); + this.persist(sc.getId(), new FloatVectorImpl(embeddingArray)); return; } // Case: segment contains image - if (shot.getMostRepresentativeFrame() != VideoFrame.EMPTY_VIDEO_FRAME) { - BufferedImage image = shot.getMostRepresentativeFrame().getImage().getBufferedImage(); + if (sc.getMostRepresentativeFrame() != VideoFrame.EMPTY_VIDEO_FRAME) { + BufferedImage image = sc.getMostRepresentativeFrame().getImage().getBufferedImage(); if (image != null) { float[] embeddingArray = embedImage(image); - this.persist(shot.getId(), new FloatVectorImpl(embeddingArray)); + this.persist(sc.getId(), new FloatVectorImpl(embeddingArray)); } // Insert return here if additional cases are added! @@ -153,7 +141,7 @@ private void initializeTextEmbedding() { private void initializeVisualEmbedding() { if (visualEmbedding == null) { - visualEmbedding = SavedModelBundle.load(RESOURCE_PATH + VISUAL_EMBEDDING_MODEL); + visualEmbedding = InceptionResnetV2.getModel(); } if (visualCoEmbedding == null) { visualCoEmbedding = SavedModelBundle.load(RESOURCE_PATH + VISUAL_CO_EMBEDDING_MODEL); @@ -192,20 +180,15 @@ private float[] embedText(String text) { private float[] embedImage(BufferedImage image) { initializeVisualEmbedding(); - if (image.getWidth() != IMAGE_WIDTH || image.getHeight() != IMAGE_HEIGHT) { - image = rescale(image, IMAGE_WIDTH, IMAGE_HEIGHT); - } - int[] colors = image.getRGB(0, 0, IMAGE_WIDTH, IMAGE_HEIGHT, null, 0, IMAGE_WIDTH); - int[] rgb = colorsToRGB(colors); - float[] processedColors = preprocessInput(rgb); + float[] processedColors = InceptionResnetV2.preprocessImage(image); - try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, IMAGE_WIDTH, IMAGE_HEIGHT, 3), DataBuffers.of(processedColors))) { + try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, InceptionResnetV2.IMAGE_WIDTH, InceptionResnetV2.IMAGE_HEIGHT, 3), DataBuffers.of(processedColors))) { HashMap inputMap = new HashMap<>(); - inputMap.put(VISUAL_EMBEDDING_INPUT, imageTensor); + inputMap.put(InceptionResnetV2.INPUT, imageTensor); Map resultMap = visualEmbedding.call(inputMap); - try (TFloat32 intermediaryEmbedding = (TFloat32) resultMap.get(VISUAL_EMBEDDING_OUTPUT)) { + try (TFloat32 intermediaryEmbedding = (TFloat32) resultMap.get(InceptionResnetV2.OUTPUT)) { inputMap.clear(); inputMap.put(VISUAL_CO_EMBEDDING_INPUT, intermediaryEmbedding); @@ -227,25 +210,9 @@ private float[] embedImage(BufferedImage image) { private float[] embedVideo(List frames) { initializeVisualEmbedding(); - List encodings = frames.stream().map(image -> encodeImage(image.getBufferedImage())).collect(Collectors.toList()); - - // Sum - float[] meanEncoding = encodings.stream().reduce(new float[ENCODING_SIZE], (encoding0, encoding1) -> { - float[] tempSum = new float[ENCODING_SIZE]; - - for (int i = 0; i < ENCODING_SIZE; i++) { - tempSum[i] = encoding0[i] + encoding1[i]; - } - - return tempSum; - }); - - // Calculate mean - for (int i = 0; i < ENCODING_SIZE; i++) { - meanEncoding[i] /= encodings.size(); - } + float[] meanEncoding = InceptionResnetV2.encodeVideo(frames); - try (TFloat32 encoding = TFloat32.tensorOf(Shape.of(1, ENCODING_SIZE), DataBuffers.of(meanEncoding))) { + try (TFloat32 encoding = TFloat32.tensorOf(Shape.of(1, InceptionResnetV2.ENCODING_SIZE), DataBuffers.of(meanEncoding))) { HashMap inputMap = new HashMap<>(); inputMap.put(VISUAL_CO_EMBEDDING_INPUT, encoding); @@ -262,83 +229,4 @@ private float[] embedVideo(List frames) { } } } - - /** - * Encodes the given image using the encoding network. - *

- * Visual embedding must already be initialized. - * - * @return Intermediary encoding, not yet embedded. - */ - private float[] encodeImage(BufferedImage image) { - if (image.getWidth() != IMAGE_WIDTH || image.getHeight() != IMAGE_HEIGHT) { - image = rescale(image, IMAGE_WIDTH, IMAGE_HEIGHT); - } - int[] colors = image.getRGB(0, 0, IMAGE_WIDTH, IMAGE_HEIGHT, null, 0, IMAGE_WIDTH); - int[] rgb = colorsToRGB(colors); - float[] processedColors = preprocessInput(rgb); - - try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, IMAGE_WIDTH, IMAGE_HEIGHT, 3), DataBuffers.of(processedColors))) { - HashMap inputMap = new HashMap<>(); - inputMap.put(VISUAL_EMBEDDING_INPUT, imageTensor); - - Map resultMap = visualEmbedding.call(inputMap); - - try (TFloat32 encoding = (TFloat32) resultMap.get(VISUAL_EMBEDDING_OUTPUT)) { - - float[] embeddingArray = new float[ENCODING_SIZE]; - FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray); - encoding.read(floatBuffer); - - return embeddingArray; - } - } - } - - /** - * Preprocesses input in a way equivalent to that performed in the Python TensorFlow library. - *

- * Maps all values from [0,255] to [-1, 1]. - */ - private static float[] preprocessInput(int[] colors) { - // x /= 127.5 - // x -= 1. - float[] processedColors = new float[colors.length]; - for (int i = 0; i < colors.length; i++) { - processedColors[i] = (colors[i] / 127.5f) - 1; - } - - return processedColors; - } - - /** - * Converts an integer colors array storing ARGB values in each integer into an integer array where each integer stores R, G or B value. - */ - private static int[] colorsToRGB(int[] colors) { - int[] rgb = new int[colors.length * 3]; - - for (int i = 0; i < colors.length; i++) { - // Start index for rgb array - int j = i * 3; - rgb[j] = (colors[i] >> 16) & 0xFF; // r - rgb[j + 1] = (colors[i] >> 8) & 0xFF; // g - rgb[j + 2] = colors[i] & 0xFF; // b - } - - return rgb; - } - - /** - * Rescales a buffered image using bilinear interpolation. - */ - private static BufferedImage rescale(BufferedImage image, int width, int height) { - BufferedImage scaledImage = new BufferedImage(width, height, image.getType()); - - AffineTransform affineTransform = AffineTransform.getScaleInstance((double) width / image.getWidth(), (double) height / image.getHeight()); - // The OpenCV resize with which the training data was scaled defaults to bilinear interpolation - AffineTransformOp transformOp = new AffineTransformOp(affineTransform, AffineTransformOp.TYPE_BILINEAR); - scaledImage = transformOp.filter(image, scaledImage); - - return scaledImage; - } }