add image analysis w/ tensorflow #318

Merged: 5 commits, Jul 5, 2019
33 changes: 33 additions & 0 deletions src/main/python/tf/detect.py
@@ -0,0 +1,33 @@
import os
import sys
from util.init import *
from model.object_detection import *
PYAUT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PYAUT_DIR)

from aut.common import WebArchive
from pyspark.sql import DataFrame


if __name__ == "__main__":
    # initialization
    args = get_args()
    sys.path.append(args.spark)
    conf, sc, sql_context = init_spark(args.master, args.aut_jar)
    zip_model_module(PYAUT_DIR)
    sc.addPyFile(os.path.join(PYAUT_DIR, "tf", "model.zip"))
    if args.img_model == "ssd":
        detector = SSD(sc, sql_context, args)

    # preprocessing raw images
    arc = WebArchive(sc, sql_context, args.web_archive)
    df = DataFrame(arc.loader.extractImages(arc.path), sql_context)
    filter_size = tuple(args.filter_size)
    print("height >= %d and width >= %d" % filter_size)
    preprocessed = df.filter("height >= %d and width >= %d" % filter_size)

    # detection
    model_broadcast = detector.broadcast()
    detect_udf = detector.get_detect_udf(model_broadcast)
    res = preprocessed.select("url", detect_udf(col("bytes")).alias("prediction"), "bytes")
    res.write.json(args.output_path)
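The driver writes one JSON record per image to --output_path, with the url, the raw bytes, and a prediction field holding the score and class arrays produced by the SSD graph. A minimal sketch (not part of this PR; the path assumes the default --output_path) of peeking at that output:

# Sketch: inspect one line of the JSON written by res.write.json (placeholder path).
import json
import glob

part = glob.glob("warc_res/part-*")[0]
with open(part) as f:
    rec = json.loads(f.readline())
scores, classes = rec["prediction"]  # two parallel float arrays per image
print(rec["url"], list(zip(classes, scores))[:5])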
17 changes: 17 additions & 0 deletions src/main/python/tf/extract_images.py
@@ -0,0 +1,17 @@
import numpy as np
import argparse
from model.object_detection import SSDExtractor


def get_args():
    parser = argparse.ArgumentParser(description='Extracting images from model output.')
    parser.add_argument('--res_dir', help='Path of result (model output) directory.')
    parser.add_argument('--output_dir', help='Path of extracted image file output directory.')
    parser.add_argument('--threshold', type=float, help='Threshold of detection confidence scores.')
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    extractor = SSDExtractor(args.res_dir, args.output_dir)
    extractor.extract_and_save(class_ids="all", threshold=args.threshold)
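Running this against the detection output produces one sub-directory per detected class under --output_dir, with file names derived from the source URL by url_parse(). A small sketch (illustrative only, the path is a placeholder) for summarizing what was extracted:

# Sketch: count extracted images per class directory (placeholder path).
import os

output_dir = "extracted_images"
for cls in sorted(os.listdir(output_dir)):
    count = len(os.listdir(os.path.join(output_dir, cls)))
    print("%s: %d images" % (cls, count))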
Empty file.
109 changes: 109 additions & 0 deletions src/main/python/tf/model/object_detection.py
@@ -0,0 +1,109 @@
import pickle
import os
import json
import numpy as np
from .preprocess import *
from pyspark.sql.functions import pandas_udf, PandasUDFType, col
from pyspark.sql.types import ArrayType, FloatType
import tensorflow as tf
import pandas as pd


PKG_DIR = os.path.dirname(os.path.abspath(__file__))


class ImageExtractor:
    def __init__(self, res_dir, output_dir):
        self.res_dir = res_dir
        self.output_dir = output_dir

    def _extract_and_save(self, rec, class_ids, threshold):
        raise NotImplementedError("Please overwrite this method.")

    def extract_and_save(self, class_ids, threshold):
        if class_ids == "all":
            class_ids = list(self.cate_dict.keys())

        # one output sub-directory per requested class
        for idx in class_ids:
            cls = self.cate_dict[idx]
            check_dir(self.output_dir + "/%s/" % cls, create=True)

        for fname in os.listdir(self.res_dir):
            if fname.startswith("part-"):
                print("Extracting:", self.res_dir + "/" + fname)
                with open(self.res_dir + "/" + fname) as f:
                    for line in f:
                        rec = json.loads(line)
                        self._extract_and_save(rec, class_ids, threshold)


class SSD:
    def __init__(self, sc, sql_context, args):
        self.sc = sc
        self.sql_context = sql_context
        self.category = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt" % PKG_DIR)
        self.checkpoint = "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb" % PKG_DIR
        self.args = args
        with tf.io.gfile.GFile(self.checkpoint, 'rb') as f:
            model_params = f.read()
        self.model_params = model_params

    def broadcast(self):
        # ship the serialized frozen graph to every executor once
        return self.sc.broadcast(self.model_params)

    def get_detect_udf(self, model_broadcast):
        def batch_proc(bytes_batch):
            with tf.Graph().as_default() as g:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(model_broadcast.value)
                tf.import_graph_def(graph_def, name='')
                image_tensor = g.get_tensor_by_name('image_tensor:0')
                detection_scores = g.get_tensor_by_name('detection_scores:0')
                detection_classes = g.get_tensor_by_name('detection_classes:0')

                with tf.Session().as_default() as sess:
                    result = []
                    image_size = (640, 640)
                    images = np.array([img2np(b, image_size) for b in bytes_batch])
                    res = sess.run([detection_scores, detection_classes],
                                   feed_dict={image_tensor: images})
                    for i in range(res[0].shape[0]):
                        result.append([res[0][i], res[1][i]])
                    return pd.Series(result)
        return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(batch_proc)


class SSDExtractor(ImageExtractor):
    def __init__(self, res_dir, output_dir):
        super().__init__(res_dir, output_dir)
        self.cate_dict = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt" % PKG_DIR)

    def _extract_and_save(self, rec, class_ids, threshold):
        pred = rec['prediction']
        scores = np.array(pred[0])
        classes = np.array(pred[1])
        valid_classes = np.unique(classes[scores >= threshold])
        if valid_classes.shape[0] > 0:
            if class_ids != "all":
                inter = list(set(valid_classes).intersection(set(class_ids)))
                if len(inter) > 0:
                    valid_classes = np.array(inter)
                else:
                    valid_classes = None
        else:
            valid_classes = None

        if valid_classes is not None:
            for cls_idx in valid_classes:
                cls = self.cate_dict[cls_idx]
                try:
                    img = str2img(rec["bytes"])
                    img.save(self.output_dir + "/%s/" % cls + url_parse(rec["url"]))
                except Exception:
                    fname = self.output_dir + "/%s/" % cls + url_parse(rec["url"])
                    print("Failed to save:", fname)

61 changes: 61 additions & 0 deletions src/main/python/tf/model/preprocess.py
@@ -0,0 +1,61 @@
from PIL import Image
import io
import base64
import os
import numpy as np
import re


def str2img(byte_str):
    # decode a base64-encoded image payload into a PIL Image
    return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, 'utf-8'))))


def img2np(byte_str, resize=None):
    try:
        image = str2img(byte_str)
        img = image.convert("RGB")
        if resize is not None:
            img = img.resize(resize, Image.BILINEAR)
        img = np.array(img).astype(np.uint8)
        img_shape = np.shape(img)

        if len(img_shape) == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img_shape[-1] >= 3:
            img = img[:, :, :3]

        return img

    except Exception:
        # fall back to a black placeholder if the payload cannot be decoded
        if resize is not None:
            return np.zeros((resize[0], resize[1], 3))
        else:
            return np.zeros((1, 1, 3))


def url_parse(url):
    # drop the scheme and replace every "/" with "%%%%" so the URL can be used
    # as a file name
    return url.split("://")[1].replace("/", "%%%%")


def check_dir(path, create=False):
    if os.path.exists(path):
        return True
    else:
        if create:
            os.makedirs(path, exist_ok=True)
        return False


def load_cate_dict_from_pbtxt(path, key="id", value="display_name"):
    # build {id: display_name} from a TF object-detection label-map .pbtxt
    cate_dict = {}
    with open(path) as f:
        for line in f:
            entry = line.strip().split(":")
            if len(entry) > 1:
                if entry[0] == key:
                    cur_key = int(entry[1])
                if entry[0] == value:
                    cur_cate = re.findall(r'"(.*?)"', entry[1])[0]
                    cate_dict[cur_key] = cur_cate
    return cate_dict
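A quick illustration (not in the PR, and assuming the helpers above are imported) of what two of them do with typical inputs:

# url_parse makes the URL safe to use as a file name.
print(url_parse("http://example.com/img/cat.jpg"))
# -> example.com%%%%img%%%%cat.jpg

# img2np never raises: an undecodable payload comes back as a black placeholder
# of the requested size, so one bad image does not sink a whole Spark batch.
print(img2np("not-a-valid-image", resize=(640, 640)).shape)
# -> (640, 640, 3)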

Empty file.
46 changes: 46 additions & 0 deletions src/main/python/tf/util/init.py
@@ -0,0 +1,46 @@
import argparse
import os
import zipfile
from pyspark import SparkConf, SparkContext, SQLContext
import re


def init_spark(master, aut_jar):
    conf = SparkConf()
    conf.set("spark.jars", aut_jar)
    conf_path = os.path.dirname(os.path.abspath(__file__)) + "/spark.conf"
    conf_dict = read_conf(conf_path)
    for item, value in conf_dict.items():
        conf.set(item, value)
    sc = SparkContext(master, "aut image analysis", conf=conf)
    sql_context = SQLContext(sc)
    return conf, sc, sql_context


def get_args():
    parser = argparse.ArgumentParser(description='PySpark for Web Archive Image Retrieval.')
    parser.add_argument('--web_archive', help='Path to warcs.', default='/tuna1/scratch/nruest/geocites/warcs')
    parser.add_argument('--aut_jar', help='Path to compiled aut jar.', default='aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar')
    parser.add_argument('--spark', help='Path to Apache Spark.', default='spark-2.3.2-bin-hadoop2.7/bin')
    parser.add_argument('--master', help='Apache Spark master IP address and port.', default='spark://127.0.1.1:7077')
    parser.add_argument('--img_model', help='Model for image processing.', default='ssd')
    parser.add_argument('--filter_size', nargs='+', type=int, help='Filter out images smaller than filter_size.', default=[640, 640])
    parser.add_argument('--output_path', help='Path to image model output.', default='warc_res')
    return parser.parse_args()


def zip_model_module(PYAUT_DIR):
    # bundle the model package so sc.addPyFile can ship it to the executors;
    # the context manager ensures the archive is finalized before it is read
    with zipfile.ZipFile(os.path.join(PYAUT_DIR, "tf", "model.zip"), "w") as zf:
        zf.write(os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"), os.path.join("model", "__init__.py"))
        zf.write(os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"), os.path.join("model", "object_detection.py"))
        zf.write(os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"), os.path.join("model", "preprocess.py"))


def read_conf(conf_path):
    conf_dict = {}
    with open(conf_path) as f:
        for line in f:
            conf = re.findall(r'\S+', line.strip())
            conf_dict[conf[0]] = conf[1]
    return conf_dict

7 changes: 7 additions & 0 deletions src/main/python/tf/util/spark.conf
@@ -0,0 +1,7 @@
spark.sql.execution.arrow.enabled true
spark.sql.execution.arrow.maxRecordsPerBatch 320
spark.executor.memory 16G
spark.cores.max 48
spark.executor.cores 6
spark.driver.memory 64G
spark.task.cpus 6
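These defaults assume a sizeable cluster (48 cores overall, 16 G executors, 64 G driver). For a single-machine smoke test of the pipeline, a scaled-down variant along these lines might be dropped in instead (illustrative values only, not part of the PR):

spark.sql.execution.arrow.enabled true
spark.sql.execution.arrow.maxRecordsPerBatch 32
spark.executor.memory 4G
spark.cores.max 4
spark.executor.cores 2
spark.driver.memory 4G
spark.task.cpus 2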