Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InceptionResnetV2 feature for Query-by-Example #243

Merged
merged 32 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
54a9b4f
Implemented visual-text co-embedding extraction of full video segment…
Spiess Dec 8, 2021
a67249e
bumping version number
silvanheller Dec 8, 2021
8e714c9
Added ThreadLocalObjectCache to reuse objects in a thread-safe manner
lucaro Dec 8, 2021
8f4cfcf
Implemented early return to conform to codebase guidelines.
Spiess Dec 9, 2021
ee7819d
Refactored visual-text co-embedding to load buffered image only once …
Spiess Dec 9, 2021
0a9b5fc
Merge pull request #238 from vitrivr/feature/video-coembedding
Spiess Dec 9, 2021
5d21c8e
Merge branch 'dev' of github.com:vitrivr/cineast into dev
sauterl Dec 13, 2021
3713d0a
Removed unused and empty JavaDoc.
Spiess Dec 16, 2021
9dd6759
Removed unused commented code.
Spiess Dec 16, 2021
1b0e445
Fixed JavaDoc errors.
Spiess Dec 16, 2021
c713422
Raised source and target language compatibility to Java 9.
Spiess Dec 16, 2021
f681f54
Updated log4j *again* to 2.16
lucaro Dec 17, 2021
455e8a5
Raised source and target language compatibility to Java 11 (LTS).
Spiess Dec 17, 2021
69747dc
Merge pull request #240 from vitrivr/log4j-2.16
silvanheller Dec 19, 2021
350e1d6
Reformatted all code to adhere to the project code-style.
Spiess Dec 22, 2021
8427d5c
Refactored InceptionResnetV2 feature encoding into its own separate f…
Spiess Jan 3, 2022
7805260
bumping version number
silvanheller Dec 8, 2021
d10d914
Added ThreadLocalObjectCache to reuse objects in a thread-safe manner
lucaro Dec 8, 2021
201edae
Implemented visual-text co-embedding extraction of full video segment…
Spiess Dec 8, 2021
1ecf39a
Implemented early return to conform to codebase guidelines.
Spiess Dec 9, 2021
f774eb9
Refactored visual-text co-embedding to load buffered image only once …
Spiess Dec 9, 2021
e442ec5
Removed unused and empty JavaDoc.
Spiess Dec 16, 2021
8a2ff57
Removed unused commented code.
Spiess Dec 16, 2021
0d9e982
Fixed JavaDoc errors.
Spiess Dec 16, 2021
6a2c5cd
Raised source and target language compatibility to Java 9.
Spiess Dec 16, 2021
4b2e801
Raised source and target language compatibility to Java 11 (LTS).
Spiess Dec 17, 2021
ef8d793
Reformatted all code to adhere to the project code-style.
Spiess Dec 22, 2021
16dc283
Merge branch 'dev' into feature/inceptionresnetv2-similarity
Spiess Jan 4, 2022
8d4d3a1
Merge branch 'master' into feature/inceptionresnetv2-similarity
silvanheller Jan 4, 2022
e9599fb
Refactored InceptionResnetV2 image preprocessing.
Spiess Jan 4, 2022
2577d35
Updated InceptionResnetV2 feature to aggregate video frame features i…
Spiess Jan 4, 2022
1db1dc2
Replaced check for collection size > 0 with isEmpty in InceptionResne…
Spiess Jan 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ allprojects {
group = 'org.vitrivr'

/* Our current version, on dev branch this should always be release+1-SNAPSHOT */
version = '3.6.1'
version = '3.6.2'

apply plugin: 'java-library'
apply plugin: 'maven-publish'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package org.vitrivr.cineast.core.features;

import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;
import org.vitrivr.cineast.core.config.QueryConfig;
import org.vitrivr.cineast.core.config.ReadableQueryConfig;
import org.vitrivr.cineast.core.config.ReadableQueryConfig.Distance;
import org.vitrivr.cineast.core.data.FloatVectorImpl;
import org.vitrivr.cineast.core.data.frames.VideoFrame;
import org.vitrivr.cineast.core.data.score.ScoreElement;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.features.abstracts.AbstractFeatureModule;

public class InceptionResnetV2 extends AbstractFeatureModule {

/**
* Length of the float array that {@code encodeImage} reads from the model output; also passed twice to the superclass constructor.
*/
public static final int ENCODING_SIZE = 1536;
/**
* Name handed to the superclass constructor (presumably the storage entity of this feature; confirm against AbstractFeatureModule).
*/
private static final String TABLE_NAME = "features_inceptionresnetv2";
/**
* Distance that both {@code getSimilar} overloads set on their cloned query configuration.
*/
private static final Distance DISTANCE = ReadableQueryConfig.Distance.euclidean;

private static final Logger LOGGER = LogManager.getLogger();

/**
* Image dimensions to which {@code encodeImage} rescales its input before building the input tensor.
*/
public static final int IMAGE_WIDTH = 299;
public static final int IMAGE_HEIGHT = 299;

/**
* Path from which {@code initializeModel} loads the saved model bundle.
*/
private static final String MODEL_PATH = "resources/VisualTextCoEmbedding/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop";

/**
* Keys used by {@code encodeImage}: INPUT for the image tensor in the model's argument map, OUTPUT for the encoding in the model's result map.
*/
public static final String INPUT = "input_1";
public static final String OUTPUT = "global_average_pooling2d";

/**
* InceptionResNetV2 pretrained on ImageNet: https://storage.googleapis.com/tensorflow/keras-applications/inception_resnet_v2/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5
* <p>
* Stays {@code null} until {@code initializeModel} loads it on first use; the loaded bundle is shared through this static field.
*/
private static SavedModelBundle model;

/**
 * Creates the feature module by passing {@code TABLE_NAME} and, twice, {@code ENCODING_SIZE} to the superclass constructor.
 */
public InceptionResnetV2() {
  super(TABLE_NAME, ENCODING_SIZE, ENCODING_SIZE);
}

@Override
public void processSegment(SegmentContainer shot) {
// Return if already processed
if (phandler.idExists(shot.getId())) {
return;
}

// Case: segment contains image
if (shot.getMostRepresentativeFrame() != VideoFrame.EMPTY_VIDEO_FRAME) {
BufferedImage image = shot.getMostRepresentativeFrame().getImage().getBufferedImage();

lucaro marked this conversation as resolved.
Show resolved Hide resolved
if (image != null) {
float[] encodingArray = encodeImage(image);
this.persist(shot.getId(), new FloatVectorImpl(encodingArray));
}

// Insert return here if additional cases are added!
}
}

/**
 * Query-by-example: encodes the most representative frame of the query container and looks up similar entries.
 * <p>
 * The lookup runs on a clone of the given configuration whose distance is set to {@code DISTANCE}.
 *
 * @param sc query container providing the example image
 * @param qc query configuration to clone
 * @return the result of the vector lookup, or an empty list if no usable image is available
 */
@Override
public List<ScoreElement> getSimilar(SegmentContainer sc, ReadableQueryConfig qc) {
  if (sc.getMostRepresentativeFrame() == VideoFrame.EMPTY_VIDEO_FRAME) {
    LOGGER.error("Could not get similar because no image was provided.");
    return new ArrayList<>();
  }

  BufferedImage queryImage = sc.getMostRepresentativeFrame().getImage().getBufferedImage();
  if (queryImage == null) {
    LOGGER.error("Could not get similar because image could not be converted to BufferedImage.");
    return new ArrayList<>();
  }

  // Work on a copy so the correct distance function is used without touching the caller's config
  QueryConfig euclideanConfig = QueryConfig.clone(qc);
  euclideanConfig.setDistance(DISTANCE);

  return getSimilar(encodeImage(queryImage), euclideanConfig);
}

/**
 * Delegates to the superclass lookup for an existing segment, using a clone of the given configuration whose distance is set to {@code DISTANCE}.
 *
 * @param segmentId id of the segment to find neighbours for
 * @param qc query configuration to clone
 * @return whatever the superclass lookup returns
 */
@Override
public List<ScoreElement> getSimilar(String segmentId, ReadableQueryConfig qc) {
  QueryConfig euclideanConfig = QueryConfig.clone(qc);
  // Ensure the correct distance function is used
  euclideanConfig.setDistance(DISTANCE);
  return super.getSimilar(segmentId, euclideanConfig);
}

/**
 * Returns the shared model bundle, loading it first if that has not happened yet.
 *
 * @return the statically held {@code SavedModelBundle}
 */
public static SavedModelBundle getModel() {
  initializeModel();
  return model;
}

/**
 * Encodes the given image using InceptionResnetV2.
 * <p>
 * The image is rescaled to {@code IMAGE_WIDTH} x {@code IMAGE_HEIGHT} if necessary, split into R, G and B values, mapped through {@code preprocessInput} and fed to the model.
 *
 * @param image image to encode
 * @return Intermediary encoding of length {@code ENCODING_SIZE}, not yet embedded.
 */
public static float[] encodeImage(BufferedImage image) {
  initializeModel();

  BufferedImage input = image;
  boolean hasModelSize = input.getWidth() == IMAGE_WIDTH && input.getHeight() == IMAGE_HEIGHT;
  if (!hasModelSize) {
    input = rescale(input, IMAGE_WIDTH, IMAGE_HEIGHT);
  }

  int[] packedPixels = input.getRGB(0, 0, IMAGE_WIDTH, IMAGE_HEIGHT, null, 0, IMAGE_WIDTH);
  float[] modelInput = preprocessInput(colorsToRGB(packedPixels));

  try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, IMAGE_WIDTH, IMAGE_HEIGHT, 3), DataBuffers.of(modelInput))) {
    Map<String, Tensor> arguments = new HashMap<>();
    arguments.put(INPUT, imageTensor);

    // The output tensor is closed as soon as its values have been copied into the result array
    try (TFloat32 encoding = (TFloat32) model.call(arguments).get(OUTPUT)) {
      float[] result = new float[ENCODING_SIZE];
      encoding.read(DataBuffers.of(result));
      return result;
    }
  }
}

// TODO: Move image util related functions to a dedicated util class

/**
* Rescales a buffered image using bilinear interpolation.
*/
public static BufferedImage rescale(BufferedImage image, int width, int height) {
BufferedImage scaledImage = new BufferedImage(width, height, image.getType());
Spiess marked this conversation as resolved.
Show resolved Hide resolved

AffineTransform affineTransform = AffineTransform.getScaleInstance((double) width / image.getWidth(), (double) height / image.getHeight());
// The OpenCV resize with which the training data was scaled defaults to bilinear interpolation
AffineTransformOp transformOp = new AffineTransformOp(affineTransform, AffineTransformOp.TYPE_BILINEAR);
scaledImage = transformOp.filter(image, scaledImage);

return scaledImage;
}

/**
 * Preprocesses input in a way equivalent to that performed in the Python TensorFlow library.
 * <p>
 * Each value is divided by 127.5 and reduced by 1, which maps [0,255] to [-1, 1].
 *
 * @param colors colour values to map
 * @return a new array of the same length holding the mapped values
 */
public static float[] preprocessInput(int[] colors) {
  final float halfRange = 127.5f;
  float[] scaled = new float[colors.length];

  int position = 0;
  for (int value : colors) {
    scaled[position++] = value / halfRange - 1f;
  }

  return scaled;
}

/**
 * Unpacks an array of packed ARGB integers into an array three times as long, in which every input pixel contributes its R, G and B value (in that order) as separate integers. The alpha byte is discarded.
 *
 * @param colors packed ARGB pixels
 * @return R, G, B values of all pixels, each in the range [0,255]
 */
public static int[] colorsToRGB(int[] colors) {
  int[] channels = new int[colors.length * 3];

  int next = 0;
  for (int argb : colors) {
    channels[next++] = (argb >> 16) & 0xFF; // r
    channels[next++] = (argb >> 8) & 0xFF; // g
    channels[next++] = argb & 0xFF; // b
  }

  return channels;
}

/**
 * Loads the model bundle from {@code MODEL_PATH} into the static {@code model} field unless it is already set.
 * <p>
 * NOTE(review): there is no synchronization here; if several threads make the first call concurrently, the bundle may be loaded more than once. Confirm whether callers can race.
 */
private static void initializeModel() {
  if (model != null) {
    return;
  }
  model = SavedModelBundle.load(MODEL_PATH);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package org.vitrivr.cineast.core.features;

import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import static org.vitrivr.cineast.core.features.InceptionResnetV2.ENCODING_SIZE;
import static org.vitrivr.cineast.core.features.InceptionResnetV2.IMAGE_HEIGHT;
import static org.vitrivr.cineast.core.features.InceptionResnetV2.IMAGE_WIDTH;
import static org.vitrivr.cineast.core.features.InceptionResnetV2.colorsToRGB;
import static org.vitrivr.cineast.core.features.InceptionResnetV2.rescale;

import java.awt.image.BufferedImage;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -31,23 +35,15 @@
public class VisualTextCoEmbedding extends AbstractFeatureModule {

private static final int EMBEDDING_SIZE = 256;
private static final int ENCODING_SIZE = 1536;
private static final String TABLE_NAME = "features_visualtextcoembedding";
private static final Distance DISTANCE = ReadableQueryConfig.Distance.euclidean;

/**
* Required dimensions of visual embedding model.
*/
private static final int IMAGE_WIDTH = 299;
private static final int IMAGE_HEIGHT = 299;

/**
* Resource paths.
*/
private static final String RESOURCE_PATH = "resources/VisualTextCoEmbedding/";
private static final String TEXT_EMBEDDING_MODEL = "universal-sentence-encoder_4";
private static final String TEXT_CO_EMBEDDING_MODEL = "text-co-embedding";
private static final String VISUAL_EMBEDDING_MODEL = "inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop";
private static final String VISUAL_CO_EMBEDDING_MODEL = "visual-co-embedding";

/**
Expand All @@ -57,8 +53,6 @@ public class VisualTextCoEmbedding extends AbstractFeatureModule {
private static final String TEXT_EMBEDDING_OUTPUT = "outputs";
private static final String TEXT_CO_EMBEDDING_INPUT = "textual_features";
private static final String TEXT_CO_EMBEDDING_OUTPUT = "l2_norm";
private static final String VISUAL_EMBEDDING_INPUT = "input_1";
private static final String VISUAL_EMBEDDING_OUTPUT = "global_average_pooling2d";
private static final String VISUAL_CO_EMBEDDING_INPUT = "visual_features";
private static final String VISUAL_CO_EMBEDDING_OUTPUT = "l2_norm";

Expand Down Expand Up @@ -153,7 +147,7 @@ private void initializeTextEmbedding() {

private void initializeVisualEmbedding() {
if (visualEmbedding == null) {
visualEmbedding = SavedModelBundle.load(RESOURCE_PATH + VISUAL_EMBEDDING_MODEL);
visualEmbedding = InceptionResnetV2.getModel();
}
if (visualCoEmbedding == null) {
visualCoEmbedding = SavedModelBundle.load(RESOURCE_PATH + VISUAL_CO_EMBEDDING_MODEL);
Expand Down Expand Up @@ -197,15 +191,15 @@ private float[] embedImage(BufferedImage image) {
}
int[] colors = image.getRGB(0, 0, IMAGE_WIDTH, IMAGE_HEIGHT, null, 0, IMAGE_WIDTH);
int[] rgb = colorsToRGB(colors);
float[] processedColors = preprocessInput(rgb);
float[] processedColors = InceptionResnetV2.preprocessInput(rgb);

try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, IMAGE_WIDTH, IMAGE_HEIGHT, 3), DataBuffers.of(processedColors))) {
HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(VISUAL_EMBEDDING_INPUT, imageTensor);
inputMap.put(InceptionResnetV2.INPUT, imageTensor);

Map<String, Tensor> resultMap = visualEmbedding.call(inputMap);

try (TFloat32 intermediaryEmbedding = (TFloat32) resultMap.get(VISUAL_EMBEDDING_OUTPUT)) {
try (TFloat32 intermediaryEmbedding = (TFloat32) resultMap.get(InceptionResnetV2.OUTPUT)) {

inputMap.clear();
inputMap.put(VISUAL_CO_EMBEDDING_INPUT, intermediaryEmbedding);
Expand All @@ -227,7 +221,7 @@ private float[] embedImage(BufferedImage image) {
private float[] embedVideo(List<MultiImage> frames) {
initializeVisualEmbedding();

List<float[]> encodings = frames.stream().map(image -> encodeImage(image.getBufferedImage())).collect(Collectors.toList());
List<float[]> encodings = frames.stream().map(image -> InceptionResnetV2.encodeImage(image.getBufferedImage())).collect(Collectors.toList());

// Sum
float[] meanEncoding = encodings.stream().reduce(new float[ENCODING_SIZE], (encoding0, encoding1) -> {
Expand Down Expand Up @@ -262,83 +256,4 @@ private float[] embedVideo(List<MultiImage> frames) {
}
}
}

/**
* Encodes the given image using the encoding network.
* <p>
* Visual embedding must already be initialized.
*
* @return Intermediary encoding, not yet embedded.
*/
private float[] encodeImage(BufferedImage image) {
if (image.getWidth() != IMAGE_WIDTH || image.getHeight() != IMAGE_HEIGHT) {
image = rescale(image, IMAGE_WIDTH, IMAGE_HEIGHT);
}
int[] colors = image.getRGB(0, 0, IMAGE_WIDTH, IMAGE_HEIGHT, null, 0, IMAGE_WIDTH);
int[] rgb = colorsToRGB(colors);
float[] processedColors = preprocessInput(rgb);

try (TFloat32 imageTensor = TFloat32.tensorOf(Shape.of(1, IMAGE_WIDTH, IMAGE_HEIGHT, 3), DataBuffers.of(processedColors))) {
HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(VISUAL_EMBEDDING_INPUT, imageTensor);

Map<String, Tensor> resultMap = visualEmbedding.call(inputMap);

try (TFloat32 encoding = (TFloat32) resultMap.get(VISUAL_EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[ENCODING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
encoding.read(floatBuffer);

return embeddingArray;
}
}
}

/**
* Preprocesses input in a way equivalent to that performed in the Python TensorFlow library.
* <p>
* Maps all values from [0,255] to [-1, 1].
*/
private static float[] preprocessInput(int[] colors) {
// x /= 127.5
// x -= 1.
float[] processedColors = new float[colors.length];
for (int i = 0; i < colors.length; i++) {
processedColors[i] = (colors[i] / 127.5f) - 1;
}

return processedColors;
}

/**
* Converts an integer colors array storing ARGB values in each integer into an integer array where each integer stores R, G or B value.
*/
private static int[] colorsToRGB(int[] colors) {
int[] rgb = new int[colors.length * 3];

for (int i = 0; i < colors.length; i++) {
// Start index for rgb array
int j = i * 3;
rgb[j] = (colors[i] >> 16) & 0xFF; // r
rgb[j + 1] = (colors[i] >> 8) & 0xFF; // g
rgb[j + 2] = colors[i] & 0xFF; // b
}

return rgb;
}

/**
* Rescales a buffered image using bilinear interpolation.
*/
private static BufferedImage rescale(BufferedImage image, int width, int height) {
BufferedImage scaledImage = new BufferedImage(width, height, image.getType());

AffineTransform affineTransform = AffineTransform.getScaleInstance((double) width / image.getWidth(), (double) height / image.getHeight());
// The OpenCV resize with which the training data was scaled defaults to bilinear interpolation
AffineTransformOp transformOp = new AffineTransformOp(affineTransform, AffineTransformOp.TYPE_BILINEAR);
scaledImage = transformOp.filter(image, scaledImage);

return scaledImage;
}
}