Skip to content

Commit

Permalink
native runtime: implement memory bounds check
Browse files Browse the repository at this point in the history
Fixes:
#750
#709
  • Loading branch information
yamt committed Aug 14, 2024
1 parent 49e6303 commit 0be3938
Showing 1 changed file with 66 additions and 5 deletions.
71 changes: 66 additions & 5 deletions runtimes/native/src/runtime.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "runtime.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "apu.h"
Expand Down Expand Up @@ -40,6 +41,51 @@ static Memory* memory;
static w4_Disk* disk;
static bool firstFrame;

static void panic(const char *msg)
{
/* REVISIT: it's cleaner to raise a wasm trap */
fprintf(stderr, "fatal error in host function: %s\n", msg);
exit(1);
}

static void out_of_bounds_access(void)
{
panic("out of bands memory access");
}

static uint32_t mul_u32_with_overflow_check(uint32_t a, uint32_t b)
{
uint32_t c = a * b;
if (c / a != b) {
panic("integer overflow");
}
return c;
}

static void bounds_check(const void *sp, size_t sz)
{
const void *memory_sp = (const void *)memory;
const void *memory_ep = memory_sp + (1 << 16);
const void *ep = sp + sz;
if (ep <= sp || sp < memory_sp || memory_ep < ep) {
out_of_bounds_access();
}
}

static void bounds_check_cstr(const char *p)
{
const void *memory_sp = (const void *)memory;
const void *memory_ep = memory_sp + (1 << 16);
if (p < memory_sp || memory_ep <= p) {
out_of_bounds_access();
}
while (*p++ != 0) {
if (memory_ep <= p) {
out_of_bounds_access();
}
}
}

void w4_runtimeInit (uint8_t* memoryBytes, w4_Disk* diskBytes) {
memory = (Memory*)memoryBytes;
disk = diskBytes;
Expand Down Expand Up @@ -83,6 +129,9 @@ void w4_runtimeBlitSub (const uint8_t* sprite, int x, int y, int width, int heig
bool flipX = (flags & 2);
bool flipY = (flags & 4);
bool rotate = (flags & 8);
uint32_t bpp = (int)bpp2 + 1;
uint32_t nbits = mul_u32_with_overflow_check(mul_u32_with_overflow_check(width, height), bpp);
bounds_check(sprite, nbits / 8);
w4_framebufferBlit(sprite, x, y, width, height, srcX, srcY, stride, bpp2, flipX, flipY, rotate);
}

Expand Down Expand Up @@ -112,16 +161,19 @@ void w4_runtimeRect (int x, int y, int width, int height) {
}

void w4_runtimeText (const uint8_t* str, int x, int y) {
bounds_check_cstr(str);
// printf("text: %s, %d, %d\n", str, x, y);
w4_framebufferText(str, x, y);
}

void w4_runtimeTextUtf8 (const uint8_t* str, int byteLength, int x, int y) {
bounds_check(str, byteLength);
// printf("textUtf8: %p, %d, %d, %d\n", str, byteLength, x, y);
w4_framebufferTextUtf8(str, byteLength, x, y);
}

void w4_runtimeTextUtf16 (const uint16_t* str, int byteLength, int x, int y) {
bounds_check(str, byteLength);
// printf("textUtf16: %p, %d, %d, %d\n", str, byteLength, x, y);
w4_framebufferTextUtf16(str, byteLength, x, y);
}
Expand All @@ -132,6 +184,7 @@ void w4_runtimeTone (int frequency, int duration, int volume, int flags) {
}

int w4_runtimeDiskr (uint8_t* dest, int size) {
bounds_check(dest, size);
if (!disk) {
return 0;
}
Expand All @@ -144,6 +197,7 @@ int w4_runtimeDiskr (uint8_t* dest, int size) {
}

int w4_runtimeDiskw (const uint8_t* src, int size) {
bounds_check(src, size);
if (!disk) {
return 0;
}
Expand All @@ -157,20 +211,24 @@ int w4_runtimeDiskw (const uint8_t* src, int size) {
}

void w4_runtimeTrace (const uint8_t* str) {
bounds_check_cstr(str);
puts(str);
}

void w4_runtimeTraceUtf8 (const uint8_t* str, int byteLength) {
bounds_check(str, byteLength);
printf("%.*s\n", byteLength, str);
}

void w4_runtimeTraceUtf16 (const uint16_t* str, int byteLength) {
bounds_check(str, byteLength);
printf("TODO: traceUtf16: %p, %d\n", str, byteLength);
}

void w4_runtimeTracef (const uint8_t* str, const void* stack) {
const uint8_t* argPtr = stack;
uint32_t strPtr;
bounds_check_cstr(str);
for (; *str != 0; ++str) {
if (*str == '%') {
const uint8_t sym = *(++str);
Expand All @@ -181,27 +239,30 @@ void w4_runtimeTracef (const uint8_t* str, const void* stack) {
putc('%', stdout);
break;
case 'c':
bounds_check(argPtr, 4);
putc(*(char*)argPtr, stdout);
argPtr += 4;
break;
case 'd':
bounds_check(argPtr, 4);
printf("%d", *(int32_t*)argPtr);
argPtr += 4;
break;
case 'x':
bounds_check(argPtr, 4);
printf("%x", *(uint32_t*)argPtr);
argPtr += 4;
break;
case 's':
bounds_check(argPtr, 4);
strPtr = *(uint32_t*)argPtr;
argPtr += 4;
if (strPtr > 0 && strPtr < sizeof(*memory)) {
printf("%.*s", (int)(sizeof(*memory) - strPtr), (char*)memory + strPtr);
} else {
printf("<invalid memory>");
}
const char *strPtr_host = (const char *)memory + strPtr;
bounds_check_cstr(strPtr_host);
printf("%s", strPtr_host);
break;
case 'f':
bounds_check(argPtr, 8);
printf("%lg", *(double*)argPtr);
argPtr += 8;
break;
Expand Down

0 comments on commit 0be3938

Please sign in to comment.