Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

native runtime: implement memory bounds check #751

Merged
merged 4 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions runtimes/native/src/runtime.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "runtime.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "apu.h"
Expand Down Expand Up @@ -40,6 +41,51 @@ static Memory* memory;
static w4_Disk* disk;
static bool firstFrame;

static void panic(const char *msg)
{
/* REVISIT: it's cleaner to raise a wasm trap */
fprintf(stderr, "fatal error in host function: %s\n", msg);
exit(1);
}

static void out_of_bounds_access(void)
{
panic("out of bands memory access");
}

static uint32_t mul_u32_with_overflow_check(uint32_t a, uint32_t b)
{
uint32_t c = a * b;
if (c / a != b) {
panic("integer overflow");
}
return c;
}

static void bounds_check(const void *sp, size_t sz)
{
const void *memory_sp = (const void *)memory;
const void *memory_ep = (const uint8_t *)memory_sp + (1 << 16);
const void *ep = (const uint8_t *)sp + sz;
if (ep < sp || sp < memory_sp || memory_ep < ep) {
out_of_bounds_access();
}
}

static void bounds_check_cstr(const char *p)
{
const void *memory_sp = (const void *)memory;
const void *memory_ep = (const uint8_t *)memory_sp + (1 << 16);
if (p < memory_sp || memory_ep <= p) {
out_of_bounds_access();
}
while (*p++ != 0) {
if (memory_ep <= p) {
out_of_bounds_access();
}
}
}

void w4_runtimeInit (uint8_t* memoryBytes, w4_Disk* diskBytes) {
memory = (Memory*)memoryBytes;
disk = diskBytes;
Expand Down Expand Up @@ -83,6 +129,9 @@ void w4_runtimeBlitSub (const uint8_t* sprite, int x, int y, int width, int heig
bool flipX = (flags & 2);
bool flipY = (flags & 4);
bool rotate = (flags & 8);
uint32_t bpp = (int)bpp2 + 1;
uint32_t nbits = mul_u32_with_overflow_check(mul_u32_with_overflow_check(width, height), bpp);
bounds_check(sprite, nbits / 8);
Comment on lines +132 to +134
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the bounds check here, but do we necessarily need to do an overflow checked multiply? I think we can live with possible overflows, as long as it doesn't allow out of bounds access somehow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to implement bounds check that way, we need to make this module (runtime.c) know the implementation details of w4_framebufferBlit. or, implement bounds check in w4_framebufferBlit itself.

IMO, it's simpler to perform overflow checks here.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. Passing huge values for width/height here that result in an overflow here would fool the bounds check and result in OOB in w4_framebufferBlit. Nice catch!

w4_framebufferBlit(sprite, x, y, width, height, srcX, srcY, stride, bpp2, flipX, flipY, rotate);
}

Expand Down Expand Up @@ -112,16 +161,19 @@ void w4_runtimeRect (int x, int y, int width, int height) {
}

void w4_runtimeText (const uint8_t* str, int x, int y) {
bounds_check_cstr(str);
// printf("text: %s, %d, %d\n", str, x, y);
w4_framebufferText(str, x, y);
}

void w4_runtimeTextUtf8 (const uint8_t* str, int byteLength, int x, int y) {
bounds_check(str, byteLength);
// printf("textUtf8: %p, %d, %d, %d\n", str, byteLength, x, y);
w4_framebufferTextUtf8(str, byteLength, x, y);
}

void w4_runtimeTextUtf16 (const uint16_t* str, int byteLength, int x, int y) {
bounds_check(str, byteLength);
// printf("textUtf16: %p, %d, %d, %d\n", str, byteLength, x, y);
w4_framebufferTextUtf16(str, byteLength, x, y);
}
Expand All @@ -132,6 +184,7 @@ void w4_runtimeTone (int frequency, int duration, int volume, int flags) {
}

int w4_runtimeDiskr (uint8_t* dest, int size) {
bounds_check(dest, size);
if (!disk) {
return 0;
}
Expand All @@ -144,6 +197,7 @@ int w4_runtimeDiskr (uint8_t* dest, int size) {
}

int w4_runtimeDiskw (const uint8_t* src, int size) {
bounds_check(src, size);
if (!disk) {
return 0;
}
Expand All @@ -157,20 +211,24 @@ int w4_runtimeDiskw (const uint8_t* src, int size) {
}

void w4_runtimeTrace (const uint8_t* str) {
bounds_check_cstr(str);
puts(str);
}

void w4_runtimeTraceUtf8 (const uint8_t* str, int byteLength) {
bounds_check(str, byteLength);
printf("%.*s\n", byteLength, str);
}

void w4_runtimeTraceUtf16 (const uint16_t* str, int byteLength) {
bounds_check(str, byteLength);
printf("TODO: traceUtf16: %p, %d\n", str, byteLength);
}

void w4_runtimeTracef (const uint8_t* str, const void* stack) {
const uint8_t* argPtr = stack;
uint32_t strPtr;
bounds_check_cstr(str);
for (; *str != 0; ++str) {
if (*str == '%') {
const uint8_t sym = *(++str);
Expand All @@ -181,27 +239,30 @@ void w4_runtimeTracef (const uint8_t* str, const void* stack) {
putc('%', stdout);
break;
case 'c':
bounds_check(argPtr, 4);
putc(*(char*)argPtr, stdout);
argPtr += 4;
break;
case 'd':
bounds_check(argPtr, 4);
printf("%d", *(int32_t*)argPtr);
argPtr += 4;
break;
case 'x':
bounds_check(argPtr, 4);
printf("%x", *(uint32_t*)argPtr);
argPtr += 4;
break;
case 's':
bounds_check(argPtr, 4);
strPtr = *(uint32_t*)argPtr;
argPtr += 4;
if (strPtr > 0 && strPtr < sizeof(*memory)) {
printf("%.*s", (int)(sizeof(*memory) - strPtr), (char*)memory + strPtr);
} else {
printf("<invalid memory>");
}
const char *strPtr_host = (const char *)memory + strPtr;
bounds_check_cstr(strPtr_host);
printf("%s", strPtr_host);
break;
case 'f':
bounds_check(argPtr, 8);
printf("%lg", *(double*)argPtr);
argPtr += 8;
break;
Expand Down
13 changes: 13 additions & 0 deletions test/integer_overflow/blit_1bpp_overflow.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
(module
(func $blit (import "env" "blit") (param i32 i32 i32 i32 i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 0
i32.const 0
i32.const 0
i32.const 0x10
i32.const 0x10000001
i32.const 0
call $blit
)
)
13 changes: 13 additions & 0 deletions test/integer_overflow/blit_2bpp_overflow.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
(module
(func $blit (import "env" "blit") (param i32 i32 i32 i32 i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 0
i32.const 0
i32.const 0
i32.const 0x8
i32.const 0x10000001
i32.const 1
call $blit
)
)
13 changes: 13 additions & 0 deletions test/out_of_bands_memory_access/blit_1bpp.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
(module
(func $blit (import "env" "blit") (param i32 i32 i32 i32 i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 65529
i32.const 0
i32.const 0
i32.const 8
i32.const 8
i32.const 0
call $blit
)
)
13 changes: 13 additions & 0 deletions test/out_of_bands_memory_access/blit_2bpp.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
(module
(func $blit (import "env" "blit") (param i32 i32 i32 i32 i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 65521
i32.const 0
i32.const 0
i32.const 8
i32.const 8
i32.const 1
call $blit
)
)
10 changes: 10 additions & 0 deletions test/out_of_bands_memory_access/tracef_double.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
(module
(func $tracef (import "env" "tracef") (param i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 50000
i32.const 65529
call $tracef
)
(data (i32.const 50000) "%f\00")
)
10 changes: 10 additions & 0 deletions test/out_of_bands_memory_access/tracef_fmt.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
(module
(func $tracef (import "env" "tracef") (param i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 65533
i32.const 0
call $tracef
)
(data (i32.const 65533) "xxx")
)
15 changes: 15 additions & 0 deletions test/out_of_bands_memory_access/tracef_str.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
(module
(func $tracef (import "env" "tracef") (param i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 40000
i32.const 65533
i32.store

i32.const 50000
i32.const 40000
call $tracef
)
(data (i32.const 50000) "%s\00")
(data (i32.const 65533) "xxx")
)
10 changes: 10 additions & 0 deletions test/out_of_bands_memory_access/tracef_strptr.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
(module
(func $tracef (import "env" "tracef") (param i32 i32))
(memory (import "env" "memory") 1 1)
(func (export "update")
i32.const 50000
i32.const 65533
call $tracef
)
(data (i32.const 50000) "%s\00")
)