Skip to content

Commit

Permalink
[µTVM] Add TVMPlatformGenerateRandom, a non-cryptographic random numb…
Browse files Browse the repository at this point in the history
…er generator. (apache#7266)

* [uTVM] Add TVMPlatformGenerateRandom, and use with Session nonce.

 * This change is preparation to support autotuning in microTVM. It
   also cleans up a loose end in the microTVM RPC server
   implementation.
 * Randomness is needed in two places of the CRT:
    1. to initialize the Session nonce, which provides a more robust
       way to detect reboots and ensure that messages are not confused
       across them.
    2. to fill input tensors when timing AutoTVM operators (once
       AutoTVM support lands in the next PR).

 * This change adds TVMPlatformGenerateRandom, a platform function for
   generating non-cryptographic random data, to service those needs.
  • Loading branch information
areusch authored and electriclilies committed Feb 18, 2021
1 parent 0ac44fc commit 433e4b6
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 17 deletions.
19 changes: 19 additions & 0 deletions include/tvm/runtime/crt/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,25 @@ tvm_crt_error_t TVMPlatformTimerStart();
*/
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds);

/*! \brief Fill a buffer with random data.
*
* Cryptographically-secure random data is NOT required. This function is intended for use
* cases such as filling autotuning input tensors and choosing the nonce used for microTVM RPC.
*
* This function does not need to be implemented for inference tasks. It is used only by
* AutoTVM and the RPC server. When not implemented, an internal weak-linked stub is provided.
*
* Please take care that across successive resets, this function returns different sequences of
* values. If e.g. the random number generator is seeded with the same value, it may make it
* difficult for a host to detect device resets during autotuning or host-driven inference.
*
* \param buffer Pointer to the 0th byte to write with random data. `num_bytes` of random data
* should be written here.
* \param num_bytes Number of bytes to write.
* \return kTvmErrorNoError if successful; a descriptive error code otherwise.
*/
tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
10 changes: 6 additions & 4 deletions include/tvm/runtime/crt/rpc_common/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ class Session {
/*! \brief An invalid nonce value that typically indicates an unknown nonce. */
static constexpr const uint8_t kInvalidNonce = 0;

Session(uint8_t initial_session_nonce, Framer* framer, FrameBuffer* receive_buffer,
MessageReceivedFunc message_received_func, void* message_received_func_context)
: local_nonce_{initial_session_nonce},
Session(Framer* framer, FrameBuffer* receive_buffer, MessageReceivedFunc message_received_func,
void* message_received_func_context)
: local_nonce_{kInvalidNonce},
session_id_{0},
state_{State::kReset},
receiver_{this},
Expand All @@ -99,9 +99,11 @@ class Session {

/*!
* \brief Send a session terminate message, usually done at startup to interrupt a hanging remote.
* \param initial_session_nonce Initial nonce that should be used on the first session start
* message. Callers should ensure this is different across device resets.
* \return kTvmErrorNoError on success, or an error code otherwise.
*/
tvm_crt_error_t Initialize();
tvm_crt_error_t Initialize(uint8_t initial_session_nonce);

/*!
* \brief Terminate any previously-established session.
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,8 @@ release_and_return : {
}
return err;
}

// Default implementation, overridden by the platform runtime.
__attribute__((weak)) tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
return kTvmErrorFunctionCallNotImplemented;
}
15 changes: 15 additions & 0 deletions src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief main entry point for host subprocess-based CRT
*/
#include <inttypes.h>
#include <time.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/memory.h>
Expand Down Expand Up @@ -93,6 +94,20 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
g_utvm_timer_running = 0;
return kTvmErrorNoError;
}

static_assert(RAND_MAX >= (1 << 8), "RAND_MAX is smaller than acceptable");
unsigned int random_seed = 0;
tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
if (random_seed == 0) {
random_seed = (unsigned int)time(NULL);
}
for (size_t i = 0; i < num_bytes; ++i) {
int random = rand_r(&random_seed);
buffer[i] = (uint8_t)random;
}

return kTvmErrorNoError;
}
}

uint8_t memory[512 * 1024];
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/crt/utvm_rpc_common/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ tvm_crt_error_t Session::StartSession() {
return to_return;
}

tvm_crt_error_t Session::Initialize() { return TerminateSession(); }
tvm_crt_error_t Session::Initialize(uint8_t initial_session_nonce) {
local_nonce_ = initial_session_nonce;
return TerminateSession();
}

tvm_crt_error_t Session::TerminateSession() {
SetSessionId(0, 0);
Expand Down
12 changes: 9 additions & 3 deletions src/runtime/crt/utvm_rpc_server/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,21 @@ class MicroRPCServer {
utvm_rpc_channel_write_t write_func, void* write_func_ctx)
: receive_buffer_{receive_storage, receive_storage_size_bytes},
framer_{&send_stream_},
session_{0xa5, &framer_, &receive_buffer_, &HandleCompleteMessageCb, this},
session_{&framer_, &receive_buffer_, &HandleCompleteMessageCb, this},
io_{&session_, &receive_buffer_},
unframer_{session_.Receiver()},
rpc_server_{&io_},
is_running_{true} {}

void* operator new(size_t count, void* ptr) { return ptr; }

void Initialize() { CHECK_EQ(kTvmErrorNoError, session_.Initialize(), "rpc server init"); }
void Initialize() {
uint8_t initial_session_nonce = Session::kInvalidNonce;
tvm_crt_error_t error =
TVMPlatformGenerateRandom(&initial_session_nonce, sizeof(initial_session_nonce));
CHECK_EQ(kTvmErrorNoError, error, "generating random session id");
CHECK_EQ(kTvmErrorNoError, session_.Initialize(initial_session_nonce), "rpc server init");
}

/*! \brief Process one message from the receive buffer, if possible.
*
Expand Down Expand Up @@ -242,7 +248,7 @@ void TVMLogf(const char* format, ...) {
} else {
tvm::runtime::micro_rpc::SerialWriteStream write_stream;
tvm::runtime::micro_rpc::Framer framer{&write_stream};
tvm::runtime::micro_rpc::Session session{0xa5, &framer, nullptr, nullptr, nullptr};
tvm::runtime::micro_rpc::Session session{&framer, nullptr, nullptr, nullptr};
tvm_crt_error_t err =
session.SendMessage(tvm::runtime::micro_rpc::MessageType::kLog,
reinterpret_cast<uint8_t*>(log_buffer), num_bytes_logged);
Expand Down
30 changes: 27 additions & 3 deletions src/runtime/micro/micro_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class MicroTransportChannel : public RPCChannel {
write_stream_{fsend, session_start_timeout},
framer_{&write_stream_},
receive_buffer_{new uint8_t[TVM_CRT_MAX_PACKET_SIZE_BYTES], TVM_CRT_MAX_PACKET_SIZE_BYTES},
session_{0x5c, &framer_, &receive_buffer_, &HandleMessageReceivedCb, this},
session_{&framer_, &receive_buffer_, &HandleMessageReceivedCb, this},
unframer_{session_.Receiver()},
did_receive_message_{false},
frecv_{frecv},
Expand Down Expand Up @@ -161,13 +161,35 @@ class MicroTransportChannel : public RPCChannel {
}
}

static constexpr const int kNumRandRetries = 10;
static std::atomic<unsigned int> random_seed;

inline uint8_t GenerateRandomNonce() {
// NOTE: this is bad concurrent programming but in practice we don't really expect race
// conditions here, and even if they occur we don't particularly care whether a competing
// process computes a different random seed. This value is just chosen pseudo-randomly to
// form an initial distinct session id. Here we just want to protect against bad loads causing
// confusion.
unsigned int seed = random_seed.load();
if (seed == 0) {
seed = (unsigned int)time(NULL);
}
uint8_t initial_nonce = 0;
for (int i = 0; i < kNumRandRetries && initial_nonce == 0; ++i) {
initial_nonce = rand_r(&seed);
}
random_seed.store(seed);
ICHECK_NE(initial_nonce, 0) << "rand() does not seem to be producing random values";
return initial_nonce;
}

bool StartSessionInternal() {
using ::std::chrono::duration_cast;
using ::std::chrono::microseconds;
using ::std::chrono::steady_clock;

steady_clock::time_point start_time = steady_clock::now();
ICHECK_EQ(kTvmErrorNoError, session_.Initialize());
ICHECK_EQ(kTvmErrorNoError, session_.Initialize(GenerateRandomNonce()));
ICHECK_EQ(kTvmErrorNoError, session_.StartSession());

if (session_start_timeout_ == microseconds::zero() &&
Expand Down Expand Up @@ -198,7 +220,7 @@ class MicroTransportChannel : public RPCChannel {
}
end_time += session_start_retry_timeout_;

ICHECK_EQ(kTvmErrorNoError, session_.Initialize());
ICHECK_EQ(kTvmErrorNoError, session_.Initialize(GenerateRandomNonce()));
ICHECK_EQ(kTvmErrorNoError, session_.StartSession());
}

Expand Down Expand Up @@ -365,6 +387,8 @@ class MicroTransportChannel : public RPCChannel {
std::string pending_chunk_;
};

std::atomic<unsigned int> MicroTransportChannel::random_seed{0};

TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* rv) {
MicroTransportChannel* micro_channel =
new MicroTransportChannel(args[1], args[2], ::std::chrono::microseconds(uint64_t(args[3])),
Expand Down
14 changes: 8 additions & 6 deletions tests/crt/session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ class TestSession {
TestSession(uint8_t initial_nonce)
: framer{&framer_write_stream},
receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)},
sess{initial_nonce, &framer, &receive_buffer, TestSessionMessageReceivedThunk, this},
unframer{sess.Receiver()} {}
sess{&framer, &receive_buffer, TestSessionMessageReceivedThunk, this},
unframer{sess.Receiver()},
initial_nonce{initial_nonce} {}

void WriteTo(TestSession* other) {
auto framer_buffer = framer_write_stream.BufferContents();
Expand Down Expand Up @@ -84,6 +85,7 @@ class TestSession {
FrameBuffer receive_buffer;
Session sess;
Unframer unframer;
uint8_t initial_nonce;
};

#define EXPECT_FRAMED_PACKET(session, expected) \
Expand Down Expand Up @@ -126,14 +128,14 @@ class SessionTest : public ::testing::Test {

TEST_F(SessionTest, NormalExchange) {
tvm_crt_error_t err;
err = alice_.sess.Initialize();
err = alice_.sess.Initialize(alice_.initial_nonce);
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(alice_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
"fw");
alice_.WriteTo(&bob_);

err = bob_.sess.Initialize();
err = bob_.sess.Initialize(bob_.initial_nonce);
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(bob_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
Expand Down Expand Up @@ -212,14 +214,14 @@ static constexpr const char kBobStartPacket[] = "\xff\xfd\x04\0\0\0f\0\0\x01`\xa

TEST_F(SessionTest, DoubleStart) {
tvm_crt_error_t err;
err = alice_.sess.Initialize();
err = alice_.sess.Initialize(alice_.initial_nonce);
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(alice_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
"fw");
alice_.WriteTo(&bob_);

err = bob_.sess.Initialize();
err = bob_.sess.Initialize(bob_.initial_nonce);
EXPECT_EQ(kTvmErrorNoError, err);
EXPECT_FRAMED_PACKET(bob_,
"\xfe\xff\xfd\x03\0\0\0\0\0\x02"
Expand Down
4 changes: 4 additions & 0 deletions tests/micro/qemu/zephyr-runtime/prj.conf
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ CONFIG_FPU=y

# For TVMPlatformAbort().
CONFIG_REBOOT=y

# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random.
CONFIG_TEST_RANDOM_GENERATOR=y
CONFIG_TIMER_RANDOM_GENERATOR=y
21 changes: 21 additions & 0 deletions tests/micro/qemu/zephyr-runtime/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <drivers/uart.h>
#include <kernel.h>
#include <power/reboot.h>
#include <random/rand32.h>
#include <stdio.h>
#include <sys/printk.h>
#include <sys/ring_buffer.h>
Expand Down Expand Up @@ -161,6 +162,26 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
return kTvmErrorNoError;
}

tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
uint32_t random; // one unit of random data.

// Fill parts of `buffer` which are as large as `random`.
size_t num_full_blocks = num_bytes / sizeof(random);
for (int i = 0; i < num_full_blocks; ++i) {
random = sys_rand32_get();
memcpy(&buffer[i * sizeof(random)], &random, sizeof(random));
}

// Fill any leftover tail which is smaller than `random`.
size_t num_tail_bytes = num_bytes % sizeof(random);
if (num_tail_bytes > 0) {
random = sys_rand32_get();
memcpy(&buffer[num_bytes - num_tail_bytes], &random, num_tail_bytes);
}

return kTvmErrorNoError;
}

#define RING_BUF_SIZE 512
struct uart_rx_buf_t {
struct ring_buf buf;
Expand Down

0 comments on commit 433e4b6

Please sign in to comment.