diff --git a/src/types.rs b/src/types.rs index a1a87f6..83f83af 100644 --- a/src/types.rs +++ b/src/types.rs @@ -468,7 +468,9 @@ pub mod bigarray { #[cfg(all(feature = "bigarray-ext", not(feature = "no-std")))] pub(crate) mod bigarray_ext { - use ndarray::{ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Dimension}; + use ndarray::{ + ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Dimension, Shape, ShapeBuilder, + }; use core::{marker::PhantomData, mem, ptr, slice}; @@ -488,13 +490,17 @@ pub(crate) mod bigarray_ext { /// Returns array view pub fn view(&self) -> ArrayView2 { let ba = unsafe { self.0.custom_ptr_val::() }; - unsafe { ArrayView2::from_shape_ptr(self.shape(), (*ba).data as *const T) } + unsafe { + ArrayView2::from_shape_ptr(build_shape(ba, self.shape()), (*ba).data as *const T) + } } /// Returns mutable array view pub fn view_mut(&mut self) -> ArrayViewMut2 { let ba = unsafe { self.0.custom_ptr_val::() }; - unsafe { ArrayViewMut2::from_shape_ptr(self.shape(), (*ba).data as *mut T) } + unsafe { + ArrayViewMut2::from_shape_ptr(build_shape(ba, self.shape()), (*ba).data as *mut T) + } } /// Returns the shape of `self` @@ -568,13 +574,17 @@ pub(crate) mod bigarray_ext { /// Returns array view pub fn view(&self) -> ArrayView3 { let ba = unsafe { self.0.custom_ptr_val::() }; - unsafe { ArrayView3::from_shape_ptr(self.shape(), (*ba).data as *const T) } + unsafe { + ArrayView3::from_shape_ptr(build_shape(ba, self.shape()), (*ba).data as *const T) + } } /// Returns mutable array view pub fn view_mut(&mut self) -> ArrayViewMut3 { let ba = unsafe { self.0.custom_ptr_val::() }; - unsafe { ArrayViewMut3::from_shape_ptr(self.shape(), (*ba).data as *mut T) } + unsafe { + ArrayViewMut3::from_shape_ptr(build_shape(ba, self.shape()), (*ba).data as *mut T) + } } /// Returns the shape of `self` @@ -638,4 +648,15 @@ pub(crate) mod bigarray_ext { array } } + + fn build_shape( + ba: *const bigarray::Bigarray, + shape: S, + ) -> Shape<::Dim> { + if unsafe { (*ba).is_fortran() } { + shape.f() + } else { + shape.into_shape() + } + } } diff --git a/sys/src/bigarray.rs b/sys/src/bigarray.rs index 94a9885..aa00fff 100644 --- a/sys/src/bigarray.rs +++ b/sys/src/bigarray.rs @@ -21,6 +21,13 @@ pub struct Bigarray { pub dim: [Intnat; 0], } +impl Bigarray { + /// Returns true if array is Fortran contiguous + pub fn is_fortran(&self) -> bool { + (self.flags & Layout::FORTRAN_LAYOUT as isize) != 0 + } +} + #[allow(non_camel_case_types)] pub enum Managed { EXTERNAL = 0, /* Data is not allocated by OCaml */ @@ -47,6 +54,12 @@ pub enum Kind { KIND_MASK = 0xFF, /* Mask for kind in flags field */ } +#[allow(non_camel_case_types)] +pub enum Layout { + C_LAYOUT = 0, /* Row major, indices start at 0 */ + FORTRAN_LAYOUT = 0x100, /* Column major, indices start at 1 */ +} + extern "C" { pub fn malloc(size: usize) -> Data; pub fn caml_ba_alloc(flags: i32, num_dims: i32, data: Data, dim: *const i32) -> Value;