tensorflow-lite: use multiple tpu
koush committed Mar 16, 2023
1 parent 1c8ff24 commit 38ba31c
Showing 5 changed files with 72 additions and 45 deletions.
4 changes: 2 additions & 2 deletions plugins/tensorflow-lite/package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion plugins/tensorflow-lite/package.json
@@ -44,5 +44,5 @@
   "devDependencies": {
     "@scrypted/sdk": "file:../../sdk"
   },
-  "version": "0.0.112"
+  "version": "0.0.113"
 }
12 changes: 6 additions & 6 deletions plugins/tensorflow-lite/src/detect/__init__.py
@@ -209,11 +209,11 @@ def run_detection_gstsample(self, detection_session: DetectionSession, gst_sampl
     async def run_detection_videoframe(self, videoFrame: scrypted_sdk.VideoFrame) -> ObjectsDetected:
         pass

-    def run_detection_avframe(self, detection_session: DetectionSession, avframe, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
+    async def run_detection_avframe(self, detection_session: DetectionSession, avframe, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
         pil: Image.Image = avframe.to_image()
-        return self.run_detection_image(detection_session, pil, settings, src_size, convert_to_src_size)
+        return await self.run_detection_image(detection_session, pil, settings, src_size, convert_to_src_size)

-    def run_detection_image(self, detection_session: DetectionSession, image: Image.Image, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
+    async def run_detection_image(self, detection_session: DetectionSession, image: Image.Image, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
         pass

     def run_detection_crop(self, detection_session: DetectionSession, sample: Any, settings: Any, src_size, convert_to_src_size, bounding_box: Tuple[float, float, float, float]) -> ObjectsDetected:
@@ -335,7 +335,7 @@ def convert_to_src_size(point, normalize = False):
             finally:
                 detection_session.running = False
         else:
-            return self.run_detection_jpeg(detection_session, bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg')), settings)
+            return await self.run_detection_jpeg(detection_session, bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg')), settings)

         if not create:
             # a detection session may have been created, but not started
@@ -479,7 +479,7 @@ async def redetect(boundingBox: Tuple[float, float, float, float]):
             if not current_data:
                 raise Exception('no sample')

-            detection_result = self.run_detection_crop(
+            detection_result = await self.run_detection_crop(
                 detection_session, current_data, detection_session.settings, current_src_size, current_convert_to_src_size, boundingBox)

             return detection_result['detections']
@@ -493,7 +493,7 @@ async def user_callback(sample, src_size, convert_to_src_size):
                     first_frame = False
                     print("first frame received", detection_session.id)

-                detection_result, data = run_detection(
+                detection_result, data = await run_detection(
                     detection_session, sample, detection_session.settings, src_size, convert_to_src_size)
                 if detection_result:
                     detection_result['running'] = True
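The changes here are mechanical, but they illustrate the rule the whole commit follows: once `detect_once` in the tflite plugin awaits work on a thread pool, every caller up the chain (`run_detection_image`, `run_detection_jpeg`, `run_detection_avframe`, `run_detection_crop`) must itself become a coroutine and `await` the call. A minimal sketch of why the conversion propagates, using hypothetical names:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)

def blocking_inference(frame):
    # stand-in for interpreter.invoke(), which ties up a worker thread
    return {'detections': [], 'frame': frame}

async def detect_once(frame):
    # offloading keeps the event loop responsive, but makes this a coroutine...
    return await asyncio.get_event_loop().run_in_executor(
        executor, blocking_inference, frame)

async def run_detection_image(frame):
    # ...so every transitive caller must also become async and await it
    return await detect_once(frame)

print(asyncio.run(run_detection_image('frame0')))
```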
30 changes: 15 additions & 15 deletions plugins/tensorflow-lite/src/predict/__init__.py
@@ -250,13 +250,13 @@ def create_detection_result(self, objs: List[Prediction], size, allowList, conve
         # print(detection_result)
         return detection_result

-    def run_detection_jpeg(self, detection_session: PredictSession, image_bytes: bytes, settings: Any) -> ObjectsDetected:
+    async def run_detection_jpeg(self, detection_session: PredictSession, image_bytes: bytes, settings: Any) -> ObjectsDetected:
         stream = io.BytesIO(image_bytes)
         image = Image.open(stream)
         if image.mode == 'RGBA':
             image = image.convert('RGB')

-        detections, _ = self.run_detection_image(detection_session, image, settings, image.size)
+        detections, _ = await self.run_detection_image(detection_session, image, settings, image.size)
         return detections

     def get_detection_input_size(self, src_size):
@@ -269,7 +269,7 @@ def get_detection_input_size(self, src_size):
     def get_input_size(self) -> Tuple[int, int]:
         pass

-    def detect_once(self, input: Image.Image, settings: Any, src_size, cvss) -> ObjectsDetected:
+    async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss) -> ObjectsDetected:
         pass

     async def run_detection_videoframe(self, videoFrame: scrypted_sdk.VideoFrame, settings: Any) -> ObjectsDetected:
@@ -288,7 +288,7 @@ def cvss(point, normalize=False):
         })
         image = Image.frombuffer('RGB', (w, h), data)
         try:
-            ret = self.detect_once(image, settings, src_size, cvss)
+            ret = await self.detect_once(image, settings, src_size, cvss)
             return ret
         finally:
             image.close()
@@ -339,9 +339,9 @@ def cvss1(point, normalize=False):
         def cvss2(point, normalize=False):
             return point[0] / s + ow, point[1] / s + oh, True

-        ret1 = self.detect_once(first, settings, src_size, cvss1)
+        ret1 = await self.detect_once(first, settings, src_size, cvss1)
         first.close()
-        ret2 = self.detect_once(second, settings, src_size, cvss2)
+        ret2 = await self.detect_once(second, settings, src_size, cvss2)
         second.close()

         two_intersect = intersect_rect(Rectangle(*first_crop), Rectangle(*second_crop))
@@ -374,7 +374,7 @@ def is_same_detection_middle(d1: ObjectDetectionResult, d2: ObjectDetectionResul
         ret['detections'] = dedupe_detections(ret1['detections'] + ret2['detections'], is_same_detection=is_same_detection_middle)
         return ret

-    def run_detection_image(self, detection_session: PredictSession, image: Image.Image, settings: Any, src_size, convert_to_src_size: Any = None, multipass_crop: Tuple[float, float, float, float] = None):
+    async def run_detection_image(self, detection_session: PredictSession, image: Image.Image, settings: Any, src_size, convert_to_src_size: Any = None, multipass_crop: Tuple[float, float, float, float] = None):
         (w, h) = self.get_input_size() or image.size
         (iw, ih) = image.size

@@ -448,7 +448,7 @@ def cvss(point, normalize=False):
                 converted = convert_to_src_size(unscaled, normalize) if convert_to_src_size else (unscaled[0], unscaled[1], True)
                 return converted

-            ret = self.detect_once(input, settings, src_size, cvss)
+            ret = await self.detect_once(input, settings, src_size, cvss)
             input.close()
             detection_session.processed = detection_session.processed + 1
             return ret, RawImage(image)
@@ -461,7 +461,7 @@ def cvss(point, normalize=False):
             converted = convert_to_src_size(point, normalize) if convert_to_src_size else (point[0], point[1], True)
             return converted

-        ret = self.detect_once(image, settings, src_size, cvss)
+        ret = await self.detect_once(image, settings, src_size, cvss)
         if detection_session:
             detection_session.processed = detection_session.processed + 1
         else:
@@ -483,11 +483,11 @@ def cvss2(point, normalize=False):
             converted = convert_to_src_size(unscaled, normalize) if convert_to_src_size else (unscaled[0], unscaled[1], True)
             return converted

-        ret1 = self.detect_once(first, settings, src_size, cvss1)
+        ret1 = await self.detect_once(first, settings, src_size, cvss1)
         first.close()
         if detection_session:
             detection_session.processed = detection_session.processed + 1
-        ret2 = self.detect_once(second, settings, src_size, cvss2)
+        ret2 = await self.detect_once(second, settings, src_size, cvss2)
         if detection_session:
             detection_session.processed = detection_session.processed + 1
         second.close()
@@ -576,11 +576,11 @@ def track(self, detection_session: PredictSession, ret: ObjectsDetected):
                 # print('untracked %s: %s' % (d['className'], d['score']))


-    def run_detection_crop(self, detection_session: DetectionSession, sample: RawImage, settings: Any, src_size, convert_to_src_size, bounding_box: Tuple[float, float, float, float]) -> ObjectsDetected:
-        (ret, _) = self.run_detection_image(detection_session, sample.image, settings, src_size, convert_to_src_size, bounding_box)
+    async def run_detection_crop(self, detection_session: DetectionSession, sample: RawImage, settings: Any, src_size, convert_to_src_size, bounding_box: Tuple[float, float, float, float]) -> ObjectsDetected:
+        (ret, _) = await self.run_detection_image(detection_session, sample.image, settings, src_size, convert_to_src_size, bounding_box)
         return ret

-    def run_detection_gstsample(self, detection_session: PredictSession, gstsample, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Image.Image]:
+    async def run_detection_gstsample(self, detection_session: PredictSession, gstsample, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Image.Image]:
         caps = gstsample.get_caps()
         # can't trust the width value, compute the stride
         height = caps.get_structure(0).get_value('height')
@@ -604,7 +604,7 @@ def run_detection_gstsample(self, detection_session: PredictSession, gstsample,
             gst_buffer.unmap(info)

         try:
-            return self.run_detection_image(detection_session, image, settings, src_size, convert_to_src_size)
+            return await self.run_detection_image(detection_session, image, settings, src_size, convert_to_src_size)
         except:
             image.close()
             traceback.print_exc()
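For context on the `ret1`/`ret2` hunks: wide frames are detected as two overlapping square crops whose results are merged, and `dedupe_detections` drops objects found twice in the overlap computed by `intersect_rect`. The plugin's own helpers are not shown in this diff; the sketch below is a hypothetical reconstruction of what such a merge can look like, not the plugin's actual implementation:

```python
from collections import namedtuple

Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

def intersect_rect(a, b):
    # overlap of two crop rectangles, or None when they are disjoint
    xmin, ymin = max(a.xmin, b.xmin), max(a.ymin, b.ymin)
    xmax, ymax = min(a.xmax, b.xmax), min(a.ymax, b.ymax)
    if xmin < xmax and ymin < ymax:
        return Rectangle(xmin, ymin, xmax, ymax)
    return None

def dedupe_detections(detections, is_same_detection):
    # keep the first of any pair the comparator judges to be the same object
    kept = []
    for d in detections:
        if not any(is_same_detection(d, k) for k in kept):
            kept.append(d)
    return kept

# merged = dedupe_detections(ret1['detections'] + ret2['detections'],
#                            is_same_detection=is_same_detection_middle)
```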
69 changes: 48 additions & 21 deletions plugins/tensorflow-lite/src/tflite/__init__.py
@@ -3,7 +3,6 @@
 from .common import *
 from PIL import Image
 from pycoral.adapters import detect
-from pycoral.adapters.common import input_size
 loaded_py_coral = False
 try:
     from pycoral.utils.edgetpu import list_edge_tpus
@@ -19,6 +18,9 @@
 from scrypted_sdk.types import Setting
 from typing import Any, Tuple
 from predict import PredictPlugin
+import concurrent.futures
+import queue
+import asyncio

 def parse_label_contents(contents: str):
     lines = contents.splitlines()
@@ -41,6 +43,9 @@ def __init__(self, nativeId: str | None = None):
         labels_contents = scrypted_sdk.zip.open(
             'fs/coco_labels.txt').read().decode('utf8')
         self.labels = parse_label_contents(labels_contents)
+        self.interpreters = queue.Queue()
+        self.interpreter_count = 0
+
         try:
             edge_tpus = list_edge_tpus()
             print('edge tpus', edge_tpus)
@@ -53,7 +58,21 @@
                 'fs/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite').read()
             # face_model = scrypted_sdk.zip.open(
             #     'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
-            self.interpreter = make_interpreter(model)
+            for idx, edge_tpu in enumerate(edge_tpus):
+                try:
+                    interpreter = make_interpreter(model, ":%s" % idx)
+                    interpreter.allocate_tensors()
+                    _, height, width, channels = interpreter.get_input_details()[
+                        0]['shape']
+                    self.input_details = int(width), int(height), int(channels)
+                    self.interpreters.put(interpreter)
+                    self.interpreter_count = self.interpreter_count + 1
+                    print('added tpu %s' % (edge_tpu))
+                except Exception as e:
+                    print('unable to use Coral Edge TPU', e)
+
+            if not self.interpreter_count:
+                raise Exception('all tpus failed to load')
             # self.face_interpreter = make_interpreter(face_model)
         except Exception as e:
             print('unable to use Coral Edge TPU', e)
@@ -62,10 +81,16 @@
                 'fs/mobilenet_ssd_v2_coco_quant_postprocess.tflite').read()
             # face_model = scrypted_sdk.zip.open(
             #     'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
-            self.interpreter = tflite.Interpreter(model_content=model)
+            interpreter = tflite.Interpreter(model_content=model)
+            interpreter.allocate_tensors()
+            _, height, width, channels = interpreter.get_input_details()[
+                0]['shape']
+            self.input_details = int(width), int(height), int(channels)
+            self.interpreters.put(interpreter)
+            self.interpreter_count = self.interpreter_count + 1
             # self.face_interpreter = make_interpreter(face_model)
-        self.interpreter.allocate_tensors()
-        self.mutex = threading.Lock()
+
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.interpreter_count, thread_name_prefix="tflite", )

     async def getSettings(self) -> list[Setting]:
         ret = await super().getSettings()
@@ -83,30 +108,32 @@ async def getSettings(self) -> list[Setting]:

     # width, height, channels
     def get_input_details(self) -> Tuple[int, int, int]:
-        with self.mutex:
-            _, height, width, channels = self.interpreter.get_input_details()[
-                0]['shape']
-            return int(width), int(height), int(channels)
+        return self.input_details

     def get_input_size(self) -> Tuple[int, int]:
-        w, h = input_size(self.interpreter)
-        return int(w), int(h)
+        return self.input_details[0:2]

-    def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
-        try:
-            with self.mutex:
+    async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
+        def predict():
+            interpreter = self.interpreters.get()
+            try:
                 common.set_input(
-                    self.interpreter, input)
+                    interpreter, input)
                 scale = (1, 1)
                 # _, scale = common.set_resized_input(
                 #     self.interpreter, cropped.size, lambda size: cropped.resize(size, Image.ANTIALIAS))
-                self.interpreter.invoke()
+                interpreter.invoke()
                 objs = detect.get_objects(
-                    self.interpreter, score_threshold=.2, image_scale=scale)
-        except:
-            print('tensorflow-lite encountered an error while detecting. requesting plugin restart.')
-            self.requestRestart()
-            raise e
+                    interpreter, score_threshold=.2, image_scale=scale)
+                return objs
+            except:
+                print('tensorflow-lite encountered an error while detecting. requesting plugin restart.')
+                self.requestRestart()
+                raise
+            finally:
+                self.interpreters.put(interpreter)
+
+        objs = await asyncio.get_event_loop().run_in_executor(self.executor, predict)

         allowList = settings.get('allowList', None) if settings else None
         ret = self.create_detection_result(objs, src_size, allowList, cvss)
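The heart of the commit is the checkout pattern above: one interpreter per detected Edge TPU (`make_interpreter(model, ":%s" % idx)` selects a device by index), idle interpreters held in a `queue.Queue`, and a `ThreadPoolExecutor` sized to the interpreter count so inference runs off the asyncio loop. A condensed sketch of the pattern, with a hypothetical `run_inference` standing in for the pycoral `set_input`/`invoke`/`get_objects` sequence:

```python
import asyncio
import concurrent.futures
import queue

def run_inference(interpreter, input):
    # hypothetical stand-in for common.set_input / interpreter.invoke /
    # detect.get_objects against a real tflite interpreter
    return []

class InterpreterPool:
    def __init__(self, interpreters):
        # idle interpreters wait here; Queue.get() blocks a worker
        # thread until one is free
        self.interpreters = queue.Queue()
        for interpreter in interpreters:
            self.interpreters.put(interpreter)
        # one worker per interpreter: any scheduled thread is guaranteed
        # a free device, so get() never blocks in practice
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.interpreters.qsize(), thread_name_prefix='tflite')

    async def detect_once(self, input):
        def predict():
            interpreter = self.interpreters.get()
            try:
                return run_inference(interpreter, input)
            finally:
                # return the interpreter even if inference raised,
                # so one bad frame can't leak a TPU
                self.interpreters.put(interpreter)
        return await asyncio.get_event_loop().run_in_executor(
            self.executor, predict)
```

Sizing the pool to the interpreter count makes the queue a checkout ledger rather than a wait point, and the `finally` block guarantees the device goes back into rotation whether inference succeeds or raises.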
