-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmqrepl.py
354 lines (324 loc) · 14.1 KB
/
mqrepl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# MQTT Repl for MicroPython
# Copyright © 2020 by Thorsten von Eicken.
#
# Requires mqtt_async for asyncio-based MQTT.
import io
import os
import sys
import time
import struct
import gc
from uasyncio import Loop as loop
import uhashlib as hashlib
import ubinascii as binascii
import logging
from micropython import const
# Module logger; uncomment the setLevel line below to get per-message debug output.
log = logging.getLogger(__name__)
# log.setLevel(logging.DEBUG)

# Base MQTT topic prefix. MQRepl's constructor rebinds this module global (via `global TOPIC`),
# so all command/reply topics are built from the configured prefix.
TOPIC = "esp32/test/mqb/"  # typ. overridden in MQRepl's constructor (exported to other modules)
PKTLEN = 1400  # data bytes that reasonably fit into a TCP packet
BUFLEN = PKTLEN * 2  # good number of data bytes to stream files
# Error text for commands that must arrive as exactly one message (seq 0 with the last flag).
ERR_SINGLEMSG = "only single message supported"

# The esp32 Partition API is only used for OTA, which _do_ota restricts to sys.platform ==
# "esp32"; presumably the esp32 module is not importable elsewhere, hence the guard.
if sys.platform == "esp32":
    from esp32 import Partition

BLOCKLEN = const(4096)  # data bytes in a flash block
# OTA manages a MicroPython firmware update over-the-air.
# It assumes that there are two "app" partitions in the partition table and updates the one
# that is not currently running. When the update is complete, it sets the new partition as
# the next one to boot. It does not reset/restart, use machine.reset() explicitly.
class OTA:
    def __init__(self):
        # Partition that is not currently running; it receives the new firmware image.
        self.part = Partition(Partition.RUNNING).get_next_update()
        self.sha = hashlib.sha256()  # running hash over all received data
        self.seq = 0  # next expected message sequence number (None once finished)
        self.block = 0  # index of the next flash block to write
        self.buf = bytearray(BLOCKLEN)  # staging buffer holding one flash block
        self.buflen = 0  # number of valid bytes currently staged in buf

    # handle processes one message with a chunk of data in msg. The sequence number seq needs
    # to increment sequentially and the last call needs to have last==True as well as the
    # sha set to the hashlib.sha256(entire_data).hexdigest().
    # Returns:
    #   - None for a duplicate message or an ordinary chunk,
    #   - b"SEQ <n>" every 8th message (seq multiple of 8) as an acknowledgement,
    #   - "OK" (from finish) after the last message when the SHA matches.
    # Raises ValueError if no update is in progress, a message was skipped, or the SHA
    # does not match.
    def handle(self, sha, msg, seq, last):
        if self.seq is None:
            raise ValueError("missing first message")
        if seq < self.seq:
            # Already processed: a redelivered (duplicate) message, safe to ignore.
            log.warning("Duplicate OTA message seq=%d", seq)
            return None
        if seq > self.seq:
            # A message was skipped; continuing would produce a corrupt image.
            raise ValueError("message missing")
        self.seq += 1
        self.sha.update(msg)
        self._append(msg)
        if last:
            self._flush()
            return self.finish(sha)
        if (seq & 7) == 0:
            # log.info("Sending ACK {}".format(seq))
            return "SEQ {}".format(seq).encode()
        return None

    # _append copies msg into the block buffer, writing each block to flash as soon as it is
    # full. It loops so that a message spanning one or more block boundaries is handled without
    # ever growing buf beyond BLOCKLEN (and without allocating a new buffer).
    def _append(self, msg):
        off = 0
        remaining = len(msg)
        while remaining > 0:
            n = min(BLOCKLEN - self.buflen, remaining)
            self.buf[self.buflen : self.buflen + n] = msg[off : off + n]
            self.buflen += n
            off += n
            remaining -= n
            if self.buflen == BLOCKLEN:
                self.part.writeblocks(self.block, self.buf)
                self.block += 1
                self.buflen = 0

    # _flush writes a partially filled final block, padding the unused tail with 0xFF.
    def _flush(self):
        if self.buflen > 0:
            for i in range(self.buflen, BLOCKLEN):
                self.buf[i] = 0xFF  # erased flash is ff
            self.part.writeblocks(self.block, self.buf)
            self.block += 1
            self.buflen = 0

    # finish verifies the hex SHA-256 digest sent by the client against the hash of all data
    # received and, if it matches, marks the updated partition as the next boot partition.
    # Raises ValueError on mismatch. Any further handle() call raises "missing first message".
    def finish(self, check_sha):
        del self.buf  # release the block buffer; no more data is accepted
        self.seq = None
        calc_sha = binascii.hexlify(self.sha.digest())
        check_sha = check_sha.encode()
        if calc_sha != check_sha:
            raise ValueError("SHA mismatch calc:{} check={}".format(calc_sha, check_sha))
        self.part.set_boot()
        return "OK"
# LogWriter is a helper class that sends text line-wise to a logger (logging module).
# class LogWriter(io.IOBase):
# def __init__(self, logger, level):
# self._logger = logger
# self._level = level
# self._wbuf = b""
#
# def write(self, buf):
# if buf == b"":
# return
# lines = (self._wbuf + buf).split(b"\n")
# self._wbuf = lines[-1]
# for l in lines[:-1]:
# self._logger(self._level, l)
# MQRepl implements REPL-like functionality over MQTT. It receives command messages, performs
# the commands, and sends a response back.
# Command topics have the general form .../cmd/<cmd>/<id>[/<filename>], where <cmd> is the name
# of the command, <id> is a random ID that ties together the request and response topics of one
# command invocation, and <filename> is a filesystem path where appropriate. Responses have the
# general form .../reply/<kind>/<id>, where <kind> is "out" or "err" and <id> matches the
# request.
# The payloads contain file data, command text, or response text. Each payload is prefixed with
# a 2-byte header which contains a sequence number (to detect duplicates) and a last-message
# flag.
# All multi-message sequences must be sent using QoS=1 to ensure in-order delivery.
# A non-obvious trick: at start-up MQRepl ignores all command messages flagged as MQTT
# duplicates until the first non-duplicate message arrives. A redelivered message may have
# gone unacknowledged precisely because it crashed the device, and it would likely do so again.
class MQRepl:
    # mqclient is the MQTT client used to publish replies; topic is the base topic prefix,
    # which is stored in the module-level TOPIC so other modules see the configured value.
    def __init__(self, mqclient, topic):
        import __main__

        global TOPIC
        self._ota = None  # OTA in progress
        self._put_fd = None  # file descr for PUT in progress
        self._put_seq = None  # next expected PUT message seq number
        self._ndup = False  # set true when 1st non-dup msg is received
        self._globals = __main__.GLOBALS()  # namespace used by _do_eval
        TOPIC = topic
        self.mqclient = mqclient

    # _ttypub publishes console output to TOPIC + "ttyout" (no-op without a client).
    async def _ttypub(self, buf):
        if self.mqclient:
            await self.mqclient.publish(TOPIC + "ttyout", buf, qos=1, sync=False)

    # start registers the message callback and subscribes to all command topics.
    async def start(self, mqtt):
        mqtt.on_msg(self._msg_cb)
        topic = TOPIC + "cmd/#"
        await mqtt.client.subscribe(topic, qos=1)
        log.info("Subscribed to %s", topic)

    # Handlers for commands. Each takes (fname, msg, seq, last) and returns None (no reply),
    # a str/bytes reply, or a stream with a read() method whose content is sent back.

    # _do_eval runs the single-message command text cmd through the interpreter. If it compiles
    # as an expression, the repr() of its value is returned. Otherwise it is compiled as
    # statements and executed with console output redirected via os.dupterm, and the captured
    # output is returned.
    def _do_eval(self, fname, cmd, seq, last):
        if seq != 0 or not last:
            raise ValueError(ERR_SINGLEMSG)
        cmd = str(cmd, "utf-8")
        log.debug("eval %s", cmd)
        # try to eval: only a *compile* SyntaxError means "not an expression"; a SyntaxError
        # raised while evaluating a valid expression must not cause the code to run again
        try:
            op = compile(cmd, "<eval>", "eval")
        except SyntaxError:
            op = None
        if op is not None:
            return repr(eval(op, self._globals, None))
        # try to exec
        outbuf = io.BytesIO(BUFLEN)  # FIXME: make this variable-sized with a max
        old_term = os.dupterm(outbuf)
        try:
            op = compile(cmd, "<exec>", "exec")
            exec(op, self._globals, None)
            time.sleep_ms(11)  # necessary to capture all output?
            return outbuf.getvalue()
        finally:
            os.dupterm(old_term)

    # _do_get opens the file fname and returns it as a stream so it can be sent back
    def _do_get(self, fname, msg, seq, last):
        if seq != 0 or not last:
            raise ValueError(ERR_SINGLEMSG)
        log.debug("opening %s", fname)
        return open(fname, "rb")

    # _do_put opens the file fname for writing on seq 0 and appends each in-order message to
    # it. Duplicates (seq already seen) are ignored; a gap raises ValueError. Returns "OK"
    # after the last message has been written and the file closed.
    # FIXME: properly guard against concurrent PUTs
    def _do_put(self, fname, msg, seq, last):
        if seq == 0:
            if self._put_fd is not None:
                self._put_fd.close()
            self._put_fd = open(fname, "wb")
            self._put_seq = 1  # next seq expected
        elif self._put_seq is None:
            raise ValueError("missing first message")
        elif seq < self._put_seq:
            # "duplicate message"
            return None
        elif seq > self._put_seq:
            raise ValueError("message missing: {} vs. {}".format(seq, self._put_seq))
        else:
            self._put_seq += 1
        self._put_fd.write(msg)
        if last:
            self._put_fd.close()
            self._put_fd = None
            self._put_seq = None
            return "OK"

    # _do_ota uploads a new firmware over-the-air and activates it for the next boot.
    # The fname passed in must be the hex sha256 of the firmware. Raises ValueError when
    # not running on an esp32 or when a mid-stream message arrives with no update in progress.
    def _do_ota(self, fname, msg, seq, last):
        if sys.platform != "esp32":
            raise ValueError("N/A")
        if seq == 0:
            self._ota = OTA()
        elif self._ota is None:
            # report instead of silently dropping data (the sender would otherwise hang)
            raise ValueError("missing first message")
        try:
            return self._ota.handle(fname, msg, seq, last)
        finally:
            if last:
                self._ota = None  # drop OTA state and its block buffer once the update ends
            gc.collect()  # needed!

    # Helpers

    # _send_stream repeatedly calls read() on the stream until EOF and publishes the data it
    # gets to the specified topic. Each packet has the std 2-byte header; the final packet is
    # header-only with the last-message bit set. The stream is always closed, even on error.
    async def _send_stream(self, topic, stream):
        try:
            buf = bytearray(BUFLEN + 2)
            buf[2:] = stream.read(BUFLEN)  # slice assignment resizes buf to the data read
            seq = 0
            while True:
                last = len(buf) == 2
                struct.pack_into("!H", buf, 0, last << 15 | seq)
                log.debug("pub %d -> %s", len(buf), topic)
                await self.mqclient.publish(topic, buf, qos=1, sync=last)
                if last:
                    return None
                buf[2:] = stream.read(BUFLEN)
                seq += 1
        finally:
            stream.close()

    # Callback handlers

    # _msg_cb handles the arrival of an MQTT message.
    # The first two bytes of each message contain a binary (big endian) sequence number with the
    # top bit set for the last message in the sequence.
    def _msg_cb(self, topic, msg, retained, qos, dup):
        topic = str(topic, "utf-8")
        lt = len(TOPIC)
        if not topic.startswith(TOPIC) or topic[lt : lt + 4] != "cmd/" or len(msg) < 2:
            return
        if dup and not self._ndup:
            return  # skip initial dup msgs
        self._ndup = True
        # expect topic: TOPIC/cmd/<cmd>/<id>[/<filename>]
        parts = topic[lt + 4 :].split("/", 2)
        if len(parts) < 2:
            return
        cmd, ident = parts[0], parts[1]
        name = parts[2] if len(parts) > 2 else None
        rtopic = TOPIC + "reply/out/" + ident
        errtopic = TOPIC + "reply/err/" + ident
        # check cmd
        handler = getattr(self, "_do_" + cmd, None)
        if handler is None:
            loop.create_task(
                self.mqclient.publish(errtopic, "Command '" + cmd + "' not supported", qos=1)
            )
            return
        # parse message header (first two bytes)
        seq = ((msg[0] & 0x7F) << 8) | msg[1]
        last = (msg[0] & 0x80) != 0
        msg = memoryview(msg)[2:]
        # logging: if something is being streamed to us and we try to send a log message back
        # for each inbound message we end up losing log messages because we can't get them out
        # as fast as new ones arrive. This always happens during OTA. Hence we stop logging
        # every message...
        if seq < 4 or last or (seq & 0xF) == 0:
            log.info(
                "Dispatch %s, msglen=%d seq=%d last=%s id=%s dup=%s",
                cmd,
                len(msg),
                seq,
                last,
                ident,
                dup,
            )
        # dispatch to command function
        try:
            t0 = time.ticks_ms()
            resp = handler(name, msg, seq, last)
            log.debug("took %dms", time.ticks_diff(time.ticks_ms(), t0))
            # send response back, which may require reading a stream
            if resp is None:
                pass
            elif callable(getattr(resp, "read", None)):
                loop.create_task(self._send_stream(rtopic, resp))
            else:
                if isinstance(resp, str):
                    # explicit encode: bytes + str only works by accident on MicroPython
                    resp = resp.encode()
                log.debug("pub %d -> %s", len(resp), rtopic)
                loop.create_task(self.mqclient.publish(rtopic, b"\xff\xff" + resp, qos=1))
        except ValueError as e:
            detail = e.args[0] if e.args else ""  # a bare ValueError() has no args
            buf = "MQRepl protocol error {}: {}".format(cmd, detail)
            loop.create_task(self.mqclient.publish(errtopic, buf, qos=1))
        except Exception as e:
            log.warning("Exception in %s: %s", cmd, e)
            # if this is a memory error just return, logging more runs out of memory again...
            if isinstance(e, MemoryError):
                return
            errbuf = io.BytesIO(PKTLEN)
            sys.print_exception(e, errbuf)
            errbuf = errbuf.getvalue()
            loop.create_task(self.mqclient.publish(errtopic, errbuf, qos=1))
# start is the module entry point: it builds an MQRepl bound to the MQTT client and the
# configured topic prefix (config["prefix"]), and registers the repl's async start coroutine
# to run on MQTT initialization.
def start(mqtt, config):
    client = mqtt.client
    prefix = config["prefix"]
    repl = MQRepl(client, prefix)
    init_coro = repl.start(mqtt)
    mqtt.on_init(init_coro)