Skip to content

Commit

Permalink
SPI: fixes + fake bulk transfers (commaai#1150)
Browse files Browse the repository at this point in the history
* check spi checkusm

* ugh, fix control handler

* fake bulk xfer

* cleanup

* one more

* unused

* fix linter

* some typing

Co-authored-by: Comma Device <device@comma.ai>
  • Loading branch information
2 people authored and sunnyhaibin committed Nov 12, 2022
1 parent 071b087 commit 016615e
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 24 deletions.
1 change: 1 addition & 0 deletions board/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//#define DEBUG_USB
//#define DEBUG_SPI
//#define DEBUG_FAULTS
//#define DEBUG_COMMS

#define DEEPSLEEP_WAKEUP_DELAY 3U

Expand Down
8 changes: 8 additions & 0 deletions board/main_comms.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ int comms_control_handler(ControlPacket_t *req, uint8_t *resp) {
unsigned int resp_len = 0;
uart_ring *ur = NULL;
timestamp_t t;

#ifdef DEBUG_COMMS
puts("raw control request: "); hexdump(req, sizeof(ControlPacket_t)); puts("\n");
puts("- request "); puth(req->request); puts("\n");
puts("- param1 "); puth(req->param1); puts("\n");
puts("- param2 "); puth(req->param2); puts("\n");
#endif

switch (req->request) {
// **** 0xa0: get rtc time
case 0xa0:
Expand Down
2 changes: 1 addition & 1 deletion board/stm32fx/llspi.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void DMA2_Stream2_IRQ_Handler(void) {
bool reponse_ack = false;
if (check_checksum(spi_buf_rx + SPI_HEADER_SIZE, spi_data_len_mosi + 1)) {
if (spi_endpoint == 0U) {
if (spi_data_len_mosi >= 8U) {
if (spi_data_len_mosi >= sizeof(ControlPacket_t)) {
response_len = comms_control_handler((ControlPacket_t *)(spi_buf_rx + SPI_HEADER_SIZE), spi_buf_tx + 3);
reponse_ack = true;
} else {
Expand Down
8 changes: 7 additions & 1 deletion python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import datetime
import traceback
import warnings
import logging
from functools import wraps
from typing import Optional
from itertools import accumulate
Expand All @@ -21,6 +22,11 @@

__version__ = '0.0.10'

# setup logging
LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
logging.basicConfig(level=LOGLEVEL, format='%(message)s')


BASEDIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")

DEBUG = os.getenv("PANDADEBUG") is not None
Expand Down Expand Up @@ -50,7 +56,7 @@ def pack_can_buffer(arr):
snds.append(b'')
idx += 1

#Apply counter to each 64 byte packet
# Apply counter to each 64 byte packet
for idx in range(len(snds)):
tx = b''
counter = 0
Expand Down
75 changes: 53 additions & 22 deletions python/spi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import math
import struct
import spidev
import logging
from functools import reduce
from typing import List

# Constants
SYNC = 0x5A
Expand All @@ -11,64 +14,92 @@

MAX_RETRY_COUNT = 5

USB_MAX_SIZE = 0x40

# This mimics the handle given by libusb1 for easy interoperability
class SpiHandle:
def __init__(self):
self.spi = spidev.SpiDev()
self.spi = spidev.SpiDev() # pylint: disable=c-extension-no-member
self.spi.open(0, 0)

self.spi.max_speed_hz = 30000000

# helpers
def _transfer(self, endpoint, data, max_rx_len=1000):
for _ in range(MAX_RETRY_COUNT):
def _calc_checksum(self, data: List[int]) -> int:
cksum = CHECKSUM_START
for b in data:
cksum ^= b
return cksum

def _transfer(self, endpoint: int, data, max_rx_len: int = 1000) -> bytes:
logging.debug("starting transfer: endpoint=%d, max_rx_len=%d", endpoint, max_rx_len)
logging.debug("==============================================")

for n in range(MAX_RETRY_COUNT):
logging.debug("\ntry #%d", n+1)
try:
logging.debug("- send header")
packet = struct.pack("<BBHH", SYNC, endpoint, len(data), max_rx_len)
packet += bytes([reduce(lambda x, y: x^y, packet) ^ CHECKSUM_START])
self.spi.xfer2(packet)

logging.debug("- waiting for ACK")
# TODO: add timeout?
dat = b"\x00"
while dat[0] not in [HACK, NACK]:
dat = self.spi.xfer2(b"\x12")

if dat[0] == NACK:
raise Exception("Got NACK response for header")

packet = bytes(data)
packet += bytes([reduce(lambda x, y: x^y, packet) ^ CHECKSUM_START])
# send data
logging.debug("- sending data")
packet = bytes([*data, self._calc_checksum(data)])
self.spi.xfer2(packet)

logging.debug("- waiting for ACK")
dat = b"\x00"
while dat[0] not in [DACK, NACK]:
dat = self.spi.xfer2(b"\xab")

if dat[0] == NACK:
raise Exception("Got NACK response for data")

response_len = struct.unpack("<H", bytes(self.spi.xfer2(b"\x00" * 2)))[0]
# get response length, then response
response_len_bytes = bytes(self.spi.xfer2(b"\x00" * 2))
response_len = struct.unpack("<H", response_len_bytes)[0]

logging.debug("- receiving response")
dat = bytes(self.spi.xfer2(b"\x00" * (response_len + 1)))
# TODO: verify CRC
dat = dat[:-1]
if self._calc_checksum([DACK, *response_len_bytes, *dat]) != 0:
raise Exception("SPI got bad checksum")

return dat
return dat[:-1]
except Exception:
pass
logging.exception("SPI transfer failed, %d retries left", n)
raise Exception(f"SPI transaction failed {MAX_RETRY_COUNT} times")

# libusb1 functions
def close(self):
self.spi.close()

def controlWrite(self, request_type, request, value, index, data, timeout=0):
return self._transfer(0, struct.pack("<HHHH", request, value, index, 0))

def controlRead(self, request_type, request, value, index, length, timeout=0):
return self._transfer(0, struct.pack("<HHHH", request, value, index, length))

# TODO: implement these
def bulkWrite(self, endpoint, data, timeout=0):
pass

def bulkRead(self, endpoint, data, timeout=0):
pass
def controlWrite(self, request_type: int, request: int, value: int, index: int, data, timeout: int = 0):
return self._transfer(0, struct.pack("<BHHH", request, value, index, 0))

def controlRead(self, request_type: int, request: int, value: int, index: int, length: int, timeout: int = 0):
return self._transfer(0, struct.pack("<BHHH", request, value, index, length))

# TODO: implement these properly
def bulkWrite(self, endpoint: int, data: List[int], timeout: int = 0) -> int:
for x in range(math.ceil(len(data) / USB_MAX_SIZE)):
self._transfer(endpoint, data[USB_MAX_SIZE*x:USB_MAX_SIZE*(x+1)])
return len(data)

def bulkRead(self, endpoint: int, length: int, timeout: int = 0) -> bytes:
ret: List[int] = []
for _ in range(math.ceil(length / USB_MAX_SIZE)):
d = self._transfer(endpoint, [], max_rx_len=USB_MAX_SIZE)
ret += d
if len(d) < USB_MAX_SIZE:
break
return bytes(ret)

0 comments on commit 016615e

Please sign in to comment.