Skip to content

Commit de4031c

Browse files
authored
DM: more precision running on DSP + e2e outputs (commaai#23900)
* update cereal * run but not use * log distraction type * regression scaling * clean up naming * add calib buf * add to header * fake model * no calib model * adjust threshs * 018a305f * fix bn * tweak1 * tweak2 * 0ff2/666 * tweak3 * t4 * t5 * fix out of bound * skip when replaying old segments * update ref * fix onnxmodel * get calib * update model replay refs * up ref
1 parent 7deba69 commit de4031c

19 files changed

+145
-62
lines changed

cereal

models/dmonitoring_model.current

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
4e19be90-bd5b-485d-b79a-2462f7f1b49e
2-
08f7ec37b78228cd1cb750b6ddb9c6ba1769e911
1+
0ff292a8-3a9c-47e7-9134-2c3b0f69a1fe
2+
d7c2883ee58d7b757e588bdf13ec45891caf397c

models/dmonitoring_model.onnx

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:895ee32e2a1c77496e015270db475eef65034b25331f2859bac0ccf702f64298
3-
size 3294407
2+
oid sha256:36cdea3c4b03f91cb243e4fa51f52e350e4906afcc5014345190e0c1d2bfaf25
3+
size 3792414

models/dmonitoring_model_q.dlc

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:6e4ac870984d11cd8e86cda4a63e3321fde837bacf4a055a27b7c8ba34facfe2
3-
size 916079
2+
oid sha256:41901abff0e16e6c404627234cca68001fbd21f6961709a729a84a8d5e3cc56d
3+
size 1146001

selfdrive/modeld/dmonitoringmodeld.cc

+11-1
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,25 @@ ExitHandler do_exit;
1313

1414
void run_model(DMonitoringModelState &model, VisionIpcClient &vipc_client) {
1515
PubMaster pm({"driverState"});
16+
SubMaster sm({"liveCalibration"});
17+
float calib[CALIB_LEN] = {0};
1618
double last = 0;
1719

1820
while (!do_exit) {
1921
VisionIpcBufExtra extra = {};
2022
VisionBuf *buf = vipc_client.recv(&extra);
2123
if (buf == nullptr) continue;
2224

25+
sm.update(0);
26+
if (sm.updated("liveCalibration")) {
27+
auto calib_msg = sm["liveCalibration"].getLiveCalibration().getRpyCalib();
28+
for (int i = 0; i < CALIB_LEN; i++) {
29+
calib[i] = calib_msg[i];
30+
}
31+
}
32+
2333
double t1 = millis_since_boot();
24-
DMonitoringResult res = dmonitoring_eval_frame(&model, buf->addr, buf->width, buf->height);
34+
DMonitoringResult res = dmonitoring_eval_frame(&model, buf->addr, buf->width, buf->height, calib);
2535
double t2 = millis_since_boot();
2636

2737
// send dm packet

selfdrive/modeld/models/commonmodel.cc

-4
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,3 @@ void softmax(const float* input, float* output, size_t len) {
7070
float sigmoid(float input) {
7171
return 1 / (1 + expf(-input));
7272
}
73-
74-
float softplus(float input) {
75-
return log1p(expf(input));
76-
}

selfdrive/modeld/models/commonmodel.h

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
const bool send_raw_pred = getenv("SEND_RAW_PRED") != NULL;
2020

2121
void softmax(const float* input, float* output, size_t len);
22-
float softplus(float input);
2322
float sigmoid(float input);
2423

2524
class ModelFrame {

selfdrive/modeld/models/dmonitoring.cc

+29-16
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ void dmonitoring_init(DMonitoringModelState* s) {
4343
#else
4444
s->m = new SNPEModel("../../models/dmonitoring_model_q.dlc", &s->output[0], OUTPUT_SIZE, USE_DSP_RUNTIME);
4545
#endif
46+
47+
s->m->addCalib(s->calib, CALIB_LEN);
4648
}
4749

4850
static inline auto get_yuv_buf(std::vector<uint8_t> &buf, const int width, int height) {
@@ -65,7 +67,7 @@ void crop_yuv(uint8_t *raw, int width, int height, uint8_t *y, uint8_t *u, uint8
6567
}
6668
}
6769

68-
DMonitoringResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_buf, int width, int height) {
70+
DMonitoringResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_buf, int width, int height, float *calib) {
6971
Rect crop_rect;
7072
if (width == TICI_CAM_WIDTH) {
7173
const int cropped_height = tici_dm_crop::width / 1.33;
@@ -167,29 +169,38 @@ DMonitoringResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_
167169

168170
double t1 = millis_since_boot();
169171
s->m->addImage(net_input_buf, yuv_buf_len);
172+
for (int i = 0; i < CALIB_LEN; i++) {
173+
s->calib[i] = calib[i];
174+
}
170175
s->m->execute();
171176
double t2 = millis_since_boot();
172177

173178
DMonitoringResult ret = {0};
174179
for (int i = 0; i < 3; ++i) {
175-
ret.face_orientation[i] = s->output[i];
176-
ret.face_orientation_meta[i] = softplus(s->output[6 + i]);
180+
ret.face_orientation[i] = s->output[i] * REG_SCALE;
181+
ret.face_orientation_meta[i] = exp(s->output[6 + i]);
182+
}
183+
for (int i = 0; i < 2; ++i) {
184+
ret.face_position[i] = s->output[3 + i] * REG_SCALE;
185+
ret.face_position_meta[i] = exp(s->output[9 + i]);
186+
}
187+
for (int i = 0; i < 4; ++i) {
188+
ret.ready_prob[i] = sigmoid(s->output[39 + i]);
177189
}
178190
for (int i = 0; i < 2; ++i) {
179-
ret.face_position[i] = s->output[3 + i];
180-
ret.face_position_meta[i] = softplus(s->output[9 + i]);
191+
ret.not_ready_prob[i] = sigmoid(s->output[43 + i]);
181192
}
182-
ret.face_prob = s->output[12];
183-
ret.left_eye_prob = s->output[21];
184-
ret.right_eye_prob = s->output[30];
185-
ret.left_blink_prob = s->output[31];
186-
ret.right_blink_prob = s->output[32];
187-
ret.sg_prob = s->output[33];
188-
ret.poor_vision = s->output[34];
189-
ret.partial_face = s->output[35];
190-
ret.distracted_pose = s->output[36];
191-
ret.distracted_eyes = s->output[37];
192-
ret.occluded_prob = s->output[38];
193+
ret.face_prob = sigmoid(s->output[12]);
194+
ret.left_eye_prob = sigmoid(s->output[21]);
195+
ret.right_eye_prob = sigmoid(s->output[30]);
196+
ret.left_blink_prob = sigmoid(s->output[31]);
197+
ret.right_blink_prob = sigmoid(s->output[32]);
198+
ret.sg_prob = sigmoid(s->output[33]);
199+
ret.poor_vision = sigmoid(s->output[34]);
200+
ret.partial_face = sigmoid(s->output[35]);
201+
ret.distracted_pose = sigmoid(s->output[36]);
202+
ret.distracted_eyes = sigmoid(s->output[37]);
203+
ret.occluded_prob = sigmoid(s->output[38]);
193204
ret.dsp_execution_time = (t2 - t1) / 1000.;
194205
return ret;
195206
}
@@ -217,6 +228,8 @@ void dmonitoring_publish(PubMaster &pm, uint32_t frame_id, const DMonitoringResu
217228
framed.setDistractedPose(res.distracted_pose);
218229
framed.setDistractedEyes(res.distracted_eyes);
219230
framed.setOccludedProb(res.occluded_prob);
231+
framed.setReadyProb(res.ready_prob);
232+
framed.setNotReadyProb(res.not_ready_prob);
220233
if (send_raw_pred) {
221234
framed.setRawPredictions(raw_pred.asBytes());
222235
}

selfdrive/modeld/models/dmonitoring.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
#include "selfdrive/modeld/models/commonmodel.h"
88
#include "selfdrive/modeld/runners/run.h"
99

10-
#define OUTPUT_SIZE 39
10+
#define CALIB_LEN 3
11+
12+
#define OUTPUT_SIZE 45
13+
#define REG_SCALE 0.25f
1114

1215
typedef struct DMonitoringResult {
1316
float face_orientation[3];
@@ -25,6 +28,8 @@ typedef struct DMonitoringResult {
2528
float distracted_pose;
2629
float distracted_eyes;
2730
float occluded_prob;
31+
float ready_prob[4];
32+
float not_ready_prob[2];
2833
float dsp_execution_time;
2934
} DMonitoringResult;
3035

@@ -36,11 +41,12 @@ typedef struct DMonitoringModelState {
3641
std::vector<uint8_t> cropped_buf;
3742
std::vector<uint8_t> premirror_cropped_buf;
3843
std::vector<float> net_input_buf;
44+
float calib[CALIB_LEN];
3945
float tensor[UINT8_MAX + 1];
4046
} DMonitoringModelState;
4147

4248
void dmonitoring_init(DMonitoringModelState* s);
43-
DMonitoringResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_buf, int width, int height);
49+
DMonitoringResult dmonitoring_eval_frame(DMonitoringModelState* s, void* stream_buf, int width, int height, float *calib);
4450
void dmonitoring_publish(PubMaster &pm, uint32_t frame_id, const DMonitoringResult &res, float execution_time, kj::ArrayPtr<const float> raw_pred);
4551
void dmonitoring_free(DMonitoringModelState* s);
4652

selfdrive/modeld/runners/onnxmodel.cc

+8
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ void ONNXModel::addTrafficConvention(float *state, int state_size) {
100100
traffic_convention_size = state_size;
101101
}
102102

103+
void ONNXModel::addCalib(float *state, int state_size) {
104+
calib_input_buf = state;
105+
calib_size = state_size;
106+
}
107+
103108
void ONNXModel::addImage(float *image_buf, int buf_size) {
104109
image_input_buf = image_buf;
105110
image_buf_size = buf_size;
@@ -124,6 +129,9 @@ void ONNXModel::execute() {
124129
if (traffic_convention_input_buf != NULL) {
125130
pwrite(traffic_convention_input_buf, traffic_convention_size);
126131
}
132+
if (calib_input_buf != NULL) {
133+
pwrite(calib_input_buf, calib_size);
134+
}
127135
if (rnn_input_buf != NULL) {
128136
pwrite(rnn_input_buf, rnn_state_size);
129137
}

selfdrive/modeld/runners/onnxmodel.h

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class ONNXModel : public RunModel {
1111
void addRecurrent(float *state, int state_size);
1212
void addDesire(float *state, int state_size);
1313
void addTrafficConvention(float *state, int state_size);
14+
void addCalib(float *state, int state_size);
1415
void addImage(float *image_buf, int buf_size);
1516
void addExtra(float *image_buf, int buf_size);
1617
void execute();
@@ -26,6 +27,8 @@ class ONNXModel : public RunModel {
2627
int desire_state_size;
2728
float *traffic_convention_input_buf = NULL;
2829
int traffic_convention_size;
30+
float *calib_input_buf = NULL;
31+
int calib_size;
2932
float *image_input_buf = NULL;
3033
int image_buf_size;
3134
float *extra_input_buf = NULL;

selfdrive/modeld/runners/runmodel.h

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class RunModel {
55
virtual void addRecurrent(float *state, int state_size) {}
66
virtual void addDesire(float *state, int state_size) {}
77
virtual void addTrafficConvention(float *state, int state_size) {}
8+
virtual void addCalib(float *state, int state_size) {}
89
virtual void addImage(float *image_buf, int buf_size) {}
910
virtual void addExtra(float *image_buf, int buf_size) {}
1011
virtual void execute() {}

selfdrive/modeld/runners/snpemodel.cc

+5
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ void SNPEModel::addDesire(float *state, int state_size) {
141141
desireBuffer = this->addExtra(state, state_size, 1);
142142
}
143143

144+
void SNPEModel::addCalib(float *state, int state_size) {
145+
calib = state;
146+
calibBuffer = this->addExtra(state, state_size, 1);
147+
}
148+
144149
void SNPEModel::addImage(float *image_buf, int buf_size) {
145150
input = image_buf;
146151
input_size = buf_size;

selfdrive/modeld/runners/snpemodel.h

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class SNPEModel : public RunModel {
2525
SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra = false);
2626
void addRecurrent(float *state, int state_size);
2727
void addTrafficConvention(float *state, int state_size);
28+
void addCalib(float *state, int state_size);
2829
void addDesire(float *state, int state_size);
2930
void addImage(float *image_buf, int buf_size);
3031
void addExtra(float *image_buf, int buf_size);
@@ -71,4 +72,6 @@ class SNPEModel : public RunModel {
7172
std::unique_ptr<zdl::DlSystem::IUserBuffer> trafficConventionBuffer;
7273
float *desire;
7374
std::unique_ptr<zdl::DlSystem::IUserBuffer> desireBuffer;
75+
float *calib;
76+
std::unique_ptr<zdl::DlSystem::IUserBuffer> calibBuffer;
7477
};

selfdrive/monitoring/dmonitoringd.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,23 @@ def dmonitoringd_thread(sm=None, pm=None):
5151

5252
# Get data from dmonitoringmodeld
5353
events = Events()
54-
driver_status.get_pose(sm['driverState'], sm['liveCalibration'].rpyCalib, sm['carState'].vEgo, sm['controlsState'].enabled)
54+
driver_status.update_states(sm['driverState'], sm['liveCalibration'].rpyCalib, sm['carState'].vEgo, sm['controlsState'].enabled)
5555

5656
# Block engaging after max number of distrations
5757
if driver_status.terminal_alert_cnt >= driver_status.settings._MAX_TERMINAL_ALERTS or \
5858
driver_status.terminal_time >= driver_status.settings._MAX_TERMINAL_DURATION:
5959
events.add(car.CarEvent.EventName.tooDistracted)
6060

6161
# Update events from driver state
62-
driver_status.update(events, driver_engaged, sm['controlsState'].enabled, sm['carState'].standstill)
62+
driver_status.update_events(events, driver_engaged, sm['controlsState'].enabled, sm['carState'].standstill)
6363

6464
# build driverMonitoringState packet
6565
dat = messaging.new_message('driverMonitoringState')
6666
dat.driverMonitoringState = {
6767
"events": events.to_msg(),
6868
"faceDetected": driver_status.face_detected,
6969
"isDistracted": driver_status.driver_distracted,
70+
"distractedType": sum(driver_status.distracted_types),
7071
"awarenessStatus": driver_status.awareness,
7172
"posePitchOffset": driver_status.pose.pitch_offseter.filtered_stat.mean(),
7273
"posePitchValidCount": driver_status.pose.pitch_offseter.filtered_stat.n,

0 commit comments

Comments
 (0)