Skip to content

Commit

Permalink
follow strict tensorflow alignment requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Sep 20, 2024
1 parent 4647efc commit 96cca6c
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,18 @@ NB_MODULE(test_ndarray_ext, m) {
});

m.def("ret_tensorflow", []() {
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
struct alignas(256) Buf {
float f[8];
};
Buf *buf = new Buf({ 1, 2, 3, 4, 5, 6, 7, 8 });
size_t shape[2] = { 2, 4 };

nb::capsule deleter(f, [](void *data) noexcept {
nb::capsule deleter(buf, [](void *data) noexcept {
destruct_count++;
delete[] (float *) data;
delete[] (Buf *) data;
});

return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(f, 2, shape,
return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(buf->f, 2, shape,
deleter);
});

Expand Down

0 comments on commit 96cca6c

Please sign in to comment.