Skip to content

Commit

Permalink
Add tvm-sys
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 6, 2020
1 parent 7eb2451 commit dd2bcd0
Show file tree
Hide file tree
Showing 10 changed files with 1,324 additions and 0 deletions.
34 changes: 34 additions & 0 deletions rust/tvm-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[package]
name = "tvm-sys"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
edition = "2018"

[features]
bindings = []

[dependencies]
thiserror = "^1.0"
anyhow = "^1.0"
ndarray = "0.12"

[build-dependencies]
bindgen = "0.51"
117 changes: 117 additions & 0 deletions rust/tvm-sys/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

extern crate bindgen;

use std::path::PathBuf;

// extern crate cmake;

use std::env;
// use std::path::Path;
// use std::process::Command;
// use cmake::Config;

// fn main() {
// if !Path::new("tvm/.git").exists() {
// let _ = Command::new("git")
// .args(&["submodule", "update", "--recursive", "--init"])
// .status();
// }

// let dst = Config::new("tvm")
// .very_verbose(true)
// .build();

// // let dst = dst.join("build");

// let out_dir = env::var("OUT_DIR").unwrap();

// println!("{}", out_dir);
// // let _ = Command::new("mv")
// // .args(&[format!("{}/build/libtvm.dylib", dst.display()), out_dir])
// // .status();

// println!("cargo:rustc-link-search=native={}/lib", dst.display());
// // TODO(@jroesch): hack for dylib behavior
// for lib in &[/* "tvm", */ "tvm_runtime", /* "tvm_topi" */] {
// // let src = format!("{}/lib/lib{}.dylib", out_dir, lib);
// // let dst = format!("{}/../../../deps", out_dir);
// // let _ = Command::new("mv")
// // .args(&[src, dst])
// // .status();
// println!("cargo:rustc-link-lib=dylib={}", lib);
// }
// // "-Wl,-rpath,/scratch/library/"
// println!("cargo:rustc-env=TVM_HOME={}/build", dst.display());
// // panic!("");
// // cc::Build::new()
// // .cpp(true)
// // .flag("-std=c++11")
// // .flag("-Wno-ignored-qualifiers")
// // .flag("-Wno-unused-parameter")
// // .include("/Users/jroesch/Git/tvm/include")
// // .include("/Users/jroesch/Git/tvm/3rdparty/dmlc-core/include")
// // .include("/Users/jroesch/Git/tvm/3rdparty/dlpack/include")
// // .include("/Users/jroesch/Git/tvm/3rdparty/HalideIR/src")
// // .file("tvm_wrapper.cc")
// // .compile("tvm_ffi");
// // println!("cargo:rustc-link-lib=dylib=tvm");
// // println!("cargo:rustc-link-search=/Users/jroesch/Git/tvm/build");
// }

fn main() {
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.canonicalize()
.unwrap();
crate_dir
.parent()
.unwrap()
.parent()
.unwrap()
.to_str()
.unwrap()
.to_string()
});

if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
// println!("cargo:rustc-link-lib=dylib=tvm_runtime");
// TODO: move to core
// println!("cargo:rustc-link-lib=dylib=tvm_runtime");
println!("cargo:rustc-link-lib=dylib=tvm");
println!("cargo:rustc-link-search={}/build", tvm_home);
}

// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
.clang_arg(format!("-I{}/include/", tvm_home))
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.generate()
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
.expect("can not write the bindings!");
}
62 changes: 62 additions & 0 deletions rust/tvm-sys/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::{
mem,
os::raw::{c_int, c_void},
};

use crate::ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};

/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
}
}
}
};
}

impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
64 changes: 64 additions & 0 deletions rust/tvm-sys/src/byte_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::os::raw::c_char;

use crate::ffi::TVMByteArray;

/// A struct holding TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello";
/// let barr = tvm_sys::ByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ```
pub type ByteArray = TVMByteArray;

impl ByteArray {
/// Gets the underlying byte-array
pub fn data(&self) -> &'static [u8] {
unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) }
}

/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.size
}

/// Converts the underlying byte-array to `Vec<u8>`
pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec()
}

pub fn is_empty(&self) -> bool {
self.len() == 0
}
}

// Needs AsRef for Vec
impl<T: AsRef<[u8]>> From<T> for ByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
ByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = ByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
let v = b"hello";
let barr = ByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
}
}
Loading

0 comments on commit dd2bcd0

Please sign in to comment.