Skip to content

Commit de27005

Browse files
Add support for EIM files with shared memory, add support for setting thresholds (#31)
* Add support for shared memory inference * Cleanup * Remove set-thresholds.py * Add set_threshold example * Remove debug statement * image: missing importing sys package * runner: small simplification * set-thresholds: make the example interactive * image: more optimized image conversion Previous versions were ~10 slower than the new one. --------- Co-authored-by: Mateusz Majchrzycki <mateusz@edgeimpulse.com>
1 parent a07463b commit de27005

File tree

12 files changed

+207
-39
lines changed

12 files changed

+207
-39
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ build
88
.vscode
99
*.eim
1010
.act-secrets
11+
*.png

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ This library lets you run machine learning models and collect sensor data on Lin
3232
$ pip3 install -r requirements.txt
3333
```
3434
35-
For the computer vision examples you'll want `opencv-python>=4.5.1.48`
35+
For the computer vision examples you'll want `opencv-python>=4.5.1.48,<5`
3636
Note on macOS on apple silicon, you will need to use a later version,
3737
4.10.0.84 tested and installs cleanly
3838

edge_impulse_linux/image.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
#!/usr/bin/env python
22

33
import numpy as np
4-
import cv2
4+
import sys
5+
try:
6+
import cv2
7+
except ImportError:
8+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
9+
exit(1)
10+
511
from edge_impulse_linux.runner import ImpulseRunner
612
import math
7-
import psutil
813

914
class ImageImpulseRunner(ImpulseRunner):
1015
def __init__(self, model_path: str):
@@ -44,7 +49,7 @@ def classify(self, data):
4449

4550
# This returns images in RGB format (not BGR)
4651
def get_frames(self, videoDeviceId = 0):
47-
if psutil.OSX or psutil.MACOS:
52+
if sys.platform == "darwin":
4853
print('Make sure to grant the this script access to your webcam.')
4954
print('If your webcam is not responding, try running "tccutil reset Camera" to reset the camera access privileges.')
5055

@@ -57,7 +62,7 @@ def get_frames(self, videoDeviceId = 0):
5762

5863
# This returns images in RGB format (not BGR)
5964
def classifier(self, videoDeviceId = 0):
60-
if psutil.OSX or psutil.MACOS:
65+
if sys.platform == "darwin":
6166
print('Make sure to grant the this script access to your webcam.')
6267
print('If your webcam is not responding, try running "tccutil reset Camera" to reset the camera access privileges.')
6368

@@ -137,9 +142,7 @@ def get_features_from_image_auto_studio_settings(self, img):
137142
raise Exception(
138143
'Runner has not initialized, please call init() first')
139144
if self.resizeMode == 'not-reported':
140-
raise Exception(
141-
'Model file "' + self._model_path + '" does not report the image resize mode\n'
142-
'Please update the model file via edge-impulse-linux-runner --download')
145+
self.resizeMode = 'squash'
143146
return get_features_from_image_with_studio_mode(img, self.resizeMode, self.dim[0], self.dim[1], self.isGrayscale)
144147

145148

@@ -233,17 +236,10 @@ def get_features_from_image_with_studio_mode(img, mode, output_width, output_hei
233236

234237
if is_grayscale:
235238
resized_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2GRAY)
236-
pixels = np.array(resized_img).flatten().tolist()
237-
238-
for p in pixels:
239-
features.append((p << 16) + (p << 8) + p)
239+
features = (resized_img.astype(np.uint32) * 0x010101).flatten().tolist()
240240
else:
241-
pixels = np.array(resized_img).flatten().tolist()
242-
243-
for ix in range(0, len(pixels), 3):
244-
r = pixels[ix + 0]
245-
g = pixels[ix + 1]
246-
b = pixels[ix + 2]
247-
features.append((r << 16) + (g << 8) + b)
241+
# Use numpy's vectorized operations for RGB feature encoding
242+
pixels = resized_img.astype(np.uint32)
243+
features = ((pixels[..., 0] << 16) | (pixels[..., 1] << 8) | pixels[..., 2]).flatten().tolist()
248244

249245
return features, resized_img

edge_impulse_linux/runner.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import signal
77
import socket
88
import json
9-
9+
from multiprocessing import shared_memory, resource_tracker
10+
import numpy as np
1011

1112
def now():
1213
return round(time.time() * 1000)
1314

14-
1515
class ImpulseRunner:
1616
def __init__(self, model_path: str):
1717
self._model_path = model_path
@@ -20,6 +20,8 @@ def __init__(self, model_path: str):
2020
self._client = None
2121
self._ix = 0
2222
self._debug = False
23+
self._hello_resp = None
24+
self._shm = None
2325

2426
def init(self, debug=False):
2527
if not os.path.exists(self._model_path):
@@ -50,27 +52,71 @@ def init(self, debug=False):
5052
self._client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
5153
self._client.connect(socket_path)
5254

53-
return self.hello()
55+
hello_resp = self._hello_resp = self.hello()
56+
57+
if ('features_shm' in hello_resp.keys()):
58+
shm_name = hello_resp['features_shm']['name']
59+
# python does not want the leading slash
60+
shm_name = shm_name.lstrip('/')
61+
shm = shared_memory.SharedMemory(name=shm_name)
62+
self._shm = {
63+
'shm': shm,
64+
'type': hello_resp['features_shm']['type'],
65+
'elements': hello_resp['features_shm']['elements'],
66+
'array': np.ndarray((hello_resp['features_shm']['elements'],), dtype=np.float32, buffer=shm.buf)
67+
}
68+
69+
return self._hello_resp
70+
71+
def __del__(self):
72+
self.stop()
5473

5574
def stop(self):
56-
if self._tempdir:
75+
if self._tempdir is not None:
5776
shutil.rmtree(self._tempdir)
77+
self._tempdir = None
5878

59-
if self._client:
79+
if self._client is not None:
6080
self._client.close()
81+
self._client = None
6182

62-
if self._runner:
83+
if self._runner is not None:
6384
os.kill(self._runner.pid, signal.SIGINT)
6485
# todo: in Node we send a SIGHUP after 0.5sec if process has not died, can we do this somehow here too?
86+
self._runner = None
87+
88+
if self._shm is not None:
89+
self._shm['shm'].close()
90+
resource_tracker.unregister(self._shm['shm']._name, "shared_memory")
91+
self._shm = None
6592

6693
def hello(self):
6794
msg = {"hello": 1}
6895
return self.send_msg(msg)
6996

7097
def classify(self, data):
71-
msg = {"classify": data}
98+
if self._shm:
99+
self._shm['array'][:] = data
100+
101+
msg = {
102+
"classify_shm": {
103+
"elements": len(data),
104+
}
105+
}
106+
else:
107+
msg = {"classify": data}
108+
72109
if self._debug:
73110
msg["debug"] = True
111+
112+
send_resp = self.send_msg(msg)
113+
return send_resp
114+
115+
def set_threshold(self, obj):
116+
if not 'id' in obj:
117+
raise Exception('set_threshold requires an object with an "id" field')
118+
119+
msg = { 'set_threshold': obj }
74120
return self.send_msg(msg)
75121

76122
def send_msg(self, msg):

examples/image/classify-full-frame.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2)
44

5-
import cv2
5+
try:
6+
import cv2
7+
except ImportError:
8+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
9+
exit(1)
610
import os
711
import sys, getopt
812
import signal

examples/image/classify-image.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2) # noqa: F401
44

5-
import cv2
5+
try:
6+
import cv2
7+
except ImportError:
8+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
9+
exit(1)
610
import os
711
import sys
812
import getopt

examples/image/classify-video.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2)
44

5-
import cv2
5+
try:
6+
import cv2
7+
except ImportError:
8+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
9+
exit(1)
610
import os
711
import time
812
import sys, getopt
9-
import numpy as np
1013
from edge_impulse_linux.image import ImageImpulseRunner
1114

1215
runner = None

examples/image/classify.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2)
44

5-
import cv2
5+
try:
6+
import cv2
7+
except ImportError:
8+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
9+
exit(1)
610
import os
711
import sys, getopt
812
import signal

examples/image/resize_demo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import numpy as np
2-
import cv2
2+
try:
3+
import cv2
4+
except ImportError:
5+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
6+
exit(1)
37
from edge_impulse_linux.image import get_features_from_image_with_studio_mode
48

59

examples/image/set-thresholds.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#!/usr/bin/env python
2+
3+
import device_patches # Device specific patches for Jetson Nano (needs to be before importing cv2) # noqa: F401
4+
5+
try:
6+
import cv2
7+
except ImportError:
8+
print('Missing OpenCV, install via `pip3 install "opencv-python>=4.5.1.48,<5"`')
9+
exit(1)
10+
import os
11+
import sys
12+
import getopt
13+
import json
14+
from edge_impulse_linux.image import ImageImpulseRunner
15+
16+
runner = None
17+
18+
def help():
19+
print('python set-thresholds.py <path_to_model.eim> <path_to_image.jpg>')
20+
21+
def main(argv):
22+
try:
23+
opts, args = getopt.getopt(argv, "h", ["--help"])
24+
except getopt.GetoptError:
25+
help()
26+
sys.exit(2)
27+
28+
for opt, arg in opts:
29+
if opt in ('-h', '--help'):
30+
help()
31+
sys.exit()
32+
33+
if len(args) != 2:
34+
help()
35+
sys.exit(2)
36+
37+
model = args[0]
38+
39+
dir_path = os.path.dirname(os.path.realpath(__file__))
40+
modelfile = os.path.join(dir_path, model)
41+
42+
print('MODEL: ' + modelfile)
43+
44+
with ImageImpulseRunner(modelfile) as runner:
45+
try:
46+
model_info = runner.init()
47+
# model_info = runner.init(debug=True) # to get debug print out
48+
49+
print('Loaded runner for "' + model_info['project']['owner'] + ' / ' + model_info['project']['name'] + '"')
50+
if not 'thresholds' in model_info['model_parameters']:
51+
print('This model does not expose any thresholds, build a new Linux deployment (.eim file) to get configurable thresholds')
52+
exit(1)
53+
54+
print('Thresholds:')
55+
for threshold in model_info['model_parameters']['thresholds']:
56+
print(' -', json.dumps(threshold))
57+
58+
# Example output for an object detection model:
59+
# Thresholds:
60+
# - {"id": 3, "min_score": 0.20000000298023224, "type": "object_detection"}
61+
62+
img = cv2.imread(args[1])
63+
if img is None:
64+
print('Failed to load image', args[1])
65+
exit(1)
66+
67+
# imread returns images in BGR format, so we need to convert to RGB
68+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
69+
# this mode uses the same settings used in studio to crop and resize the input
70+
features, cropped = runner.get_features_from_image_auto_studio_settings(img)
71+
72+
print("Which threshold would you like to change? (id)")
73+
while True:
74+
try:
75+
threshold_id = int(input('Enter threshold ID: '))
76+
if threshold_id not in [t['id'] for t in model_info['model_parameters']['thresholds']]:
77+
print('Invalid threshold ID, try again')
78+
continue
79+
break
80+
except ValueError:
81+
print('Invalid input, please enter a number')
82+
83+
print("Enter a new threshold value (between 0.0 and 1.0):")
84+
while True:
85+
try:
86+
new_threshold = float(input('New threshold value: '))
87+
if new_threshold < 0.0 or new_threshold > 1.0:
88+
print('Invalid threshold value, must be between 0.0 and 1.0')
89+
continue
90+
break
91+
except ValueError:
92+
print('Invalid input, please enter a number')
93+
94+
# dynamically override the thresold from 0.2 -> 0.8
95+
runner.set_threshold({
96+
'id': threshold_id,
97+
'min_score': new_threshold,
98+
})
99+
100+
res = runner.classify(features)
101+
print('classify response', json.dumps(res, indent=4))
102+
103+
finally:
104+
if (runner):
105+
runner.stop()
106+
107+
if __name__ == "__main__":
108+
main(sys.argv[1:])

0 commit comments

Comments
 (0)