Skip to content

Commit

Permalink
feat: build Deepdetect + pytorch MPS on Apple platforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed Aug 11, 2023
1 parent 05096fd commit aa8822d
Show file tree
Hide file tree
Showing 24 changed files with 262 additions and 90 deletions.
178 changes: 145 additions & 33 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ option(USE_FAISS "use FAISS as indexer" ON)
option(BUILD_SPDLOG "build SPDLOG instead of using system library" ON)
option(BUILD_PROTOBUF "build PROTOBUF instead of using system library" ON)
option(USE_BOOST_BACKTRACE "use boost backtrace" ON)
option(USE_OPENMP "use OpenMP" ON)

if (USE_CAFFE)
add_definitions(-DUSE_CAFFE)
Expand Down Expand Up @@ -92,8 +93,37 @@ if (NOT EXISTS ${CMAKE_BINARY_DIR}/src)
COMMAND bash -c "mkdir ${CMAKE_BINARY_DIR}/src")
endif()

set(CMAKE_CXX_FLAGS "-g -O2 -Wall -Wextra -fopenmp -fPIC -std=c++14 -Wl,--no-as-needed -ltcmalloc_minimal")
set(CMAKE_CXX_STANDARD 14)
if (USE_OPENMP)
if (APPLE)
set(FOPENMP "-Xclang -fopenmp")
find_library(OPENMP_LIBRARIES REQUIRED
NAMES omp libomp
HINTS /opt/homebrew/Cellar/libomp/*/lib/)

message(STATUS "OpenMP location: ${OPENMP_LIBRARIES}")
find_path(OPENMP_INCLUDE_DIRS REQUIRED
NAMES omp.h
HINTS /opt/homebrew/Cellar/libomp/*/include/)
message(STATUS "OpenMP header ${OPENMP_INCLUDE_DIRS}")
include_directories(SYSTEM ${OPENMP_INCLUDE_DIRS})
else()
set(FOPENMP "-fopenmp")
endif()
else()
set(FOPENMP "")
endif()

if (APPLE)
set(CMAKE_CXX_FLAGS "-g -O2 -Wall -Wextra ${FOPENMP} -fPIC -std=c++17")

message(STATUS "Apple build: CUDA is disabled")
set(USE_CUDNN OFF)
set(USE_CUDA_CV OFF)
else()
set(CMAKE_CXX_FLAGS "-g -O2 -Wall -Wextra ${FOPENMP} -fPIC -std=c++14 -Wl,--no-as-needed -ltcmalloc_minimal")
endif()

set(CMAKE_CXX_STANDARD 17)
add_definitions("-DUSE_OPENCV" "-DUSE_LMDB")

if(WARNING)
Expand Down Expand Up @@ -126,6 +156,25 @@ if (USE_JSON_API)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_JSON_API")
endif()

# VARIANT

if (APPLE)
# Variant is not on brew
set(VARIANT_VERSION v1.2.0)
message(STATUS "Download Variant library ${VARIANT_VERSION}")
ExternalProject_Add(
variant
PREFIX variant
UPDATE_DISCONNECTED 1
URL https://github.com/mapbox/variant/archive/refs/tags/${VARIANT_VERSION}.tar.gz
URL_HASH SHA256=7059f4420d504c4bc96f8a462a0f6d029c5be914ba55cc030a0a773366dd7bc8
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)
set(VARIANT_INCLUDE_DIRS "${CMAKE_BINARY_DIR}/variant/src/variant/include/")
endif()

# PROTOBUF

# use protobuf same version as pytorch one
Expand All @@ -149,7 +198,13 @@ if (BUILD_PROTOBUF)
set(PROTOBUF_INCLUDE_DIR ${CMAKE_BINARY_DIR}/protobuf/src/protobuf/src)
set(PROTOBUF_LIB_DIR ${CMAKE_BINARY_DIR}/protobuf/src/protobuf-build/)
set(PROTOBUF_PROTOC ${PROTOBUF_LIB_DIR}/protoc)
set(PROTOBUF_LIB_DEPS ${PROTOBUF_LIB_DIR}/libprotobuf.so ${PROTOBUF_LIB_DIR}/libprotoc.so ${PROTOBUF_LIB_DIR}/libprotobuf-lite.so)

if (APPLE)
set(PROTOBUF_LIB_DEPS ${PROTOBUF_LIB_DIR}/libprotobuf.dylib ${PROTOBUF_LIB_DIR}/libprotoc.dylib ${PROTOBUF_LIB_DIR}/libprotobuf-lite.dylib)
else()
set(PROTOBUF_LIB_DEPS ${PROTOBUF_LIB_DIR}/libprotobuf.so ${PROTOBUF_LIB_DIR}/libprotoc.so ${PROTOBUF_LIB_DIR}/libprotobuf-lite.so)
endif()

link_directories(${PROTOBUF_LIB_DIR})
include_directories(SYSTEM ${PROTOBUF_INCLUDE_DIR})
else()
Expand All @@ -163,6 +218,17 @@ else()
set(PROTOBUF_INCLUDE_DIR ${Protobuf_INCLUDE_DIRS})
set(PROTOBUF_LIB_DEPS protobuf::libprotobuf)
add_custom_target(protobuf)

# absl for protobuf 22+
if (${Protobuf_VERSION} VERSION_GREATER_EQUAL 4.22.0)
message(STATUS "Add Abseil dependency required by Protobuf V${Protobuf_VERSION}")
if(NOT TARGET absl::strings)
find_package(absl CONFIG)
endif()
set_target_properties(protobuf::libprotobuf PROPERTIES
INTERFACE_LINK_LIBRARIES "absl::absl_check;absl::absl_log;absl::algorithm;absl::base;absl::bind_front;absl::bits;absl::btree;absl::cleanup;absl::cord;absl::core_headers;absl::debugging;absl::die_if_null;absl::dynamic_annotations;absl::flags;absl::flat_hash_map;absl::flat_hash_set;absl::function_ref;absl::hash;absl::layout;absl::log_initialize;absl::log_severity;absl::memory;absl::node_hash_map;absl::node_hash_set;absl::optional;absl::span;absl::status;absl::statusor;absl::strings;absl::synchronization;absl::time;absl::type_traits;absl::utility;absl::variant"
)
endif()
endif()

file(
Expand Down Expand Up @@ -275,7 +341,7 @@ else()
endif()

# CUDA validation
if (NOT CUDA_FOUND)
if (NOT APPLE AND NOT CUDA_FOUND)
if (USE_TENSORRT)
message(FATAL, "USE_TENSORRT=ON needs CUDA installed")
endif()
Expand Down Expand Up @@ -380,7 +446,6 @@ if (CUDA_FOUND)
message(STATUS "CUDA_ARCH=${CUDA_ARCH}")
else()
set(CUDA_LIB_DEPS "")
add_definitions(-DCPU_ONLY)
endif()

if (USE_DLIB)
Expand Down Expand Up @@ -770,8 +835,9 @@ endif()

# Torch
if (USE_TORCH)

set(CMAKE_EXE_LINKER_FLAGS "-Wl,--no-as-needed")
if (!APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-Wl,--no-as-needed")
endif()

set(PYTORCH_PATCHES_PATH ${CMAKE_BINARY_DIR}/patches/pytorch)

Expand All @@ -780,6 +846,9 @@ if (USE_TORCH)
${PYTORCH_PATCHES_PATH}/pytorch_113_new_logger.patch
${PYTORCH_PATCHES_PATH}/pytorch_19_use_new_logger.patch
)
if (APPLE)
list(APPEND PYTORCH_PATCHES ${PYTORCH_PATCHES_PATH}/pytorch_113_apple_includes.patch)
endif()

message(STATUS "Configuring libtorch")
add_definitions(-DUSE_TORCH)
Expand All @@ -788,37 +857,66 @@ if (USE_TORCH)
set(PYTORCH_COMMIT v2.0.1)
set(PYTORCH_COMPLETE ${CMAKE_BINARY_DIR}/CMakeFiles/pytorch-complete)

if(USE_TORCH_CPU_ONLY)
set(PYTORCH_USE_CUDA 0)
else()
set(PYTORCH_USE_CUDA 1)
endif()
if (APPLE)
if(USE_TORCH_CPU_ONLY)
set(PYTORCH_USE_MPS 0)
else()
set(PYTORCH_USE_MPS 1)
add_definitions("-DUSE_MPS")
endif()

file(WRITE ${CMAKE_BINARY_DIR}/build_pytorch.sh "set -x\n${CMAKE_COMMAND} -E env PATH=${PROTOBUF_LIB_DIR}:$ENV{PATH} BUILD_CUSTOM_PROTOBUF=0 GLIBCXX_USE_CXX11_ABI=1 BUILD_TEST=0 USE_CUDA=${PYTORCH_USE_CUDA} BUILD_CAFFE2=1 BUILD_CAFFE2_OPS=1 BUILD_CAFFE2_MOBILE=0 USE_DDLOG=1 USE_TENSORRT=0 CAFFE2_LINK_LOCAL_PROTOBUF=0 \"CMAKE_CXX_FLAGS=-isystem ${SPDLOG_INCLUDE_DIR} -isystem ${PROTOBUF_INCLUDE_DIR}\" \"CMAKE_CUDA_FLAGS=-isystem ${SPDLOG_INCLUDE_DIR} -isystem ${PROTOBUF_INCLUDE_DIR}\" TORCH_CUDA_ARCH_LIST=\"${CUDA_ARCH}\" \"CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc\" CMAKE_PREFIX_PATH=${PROTOBUF_LIB_DIR}/cmake MAX_JOBS=8 python3 ../pytorch/tools/build_libtorch.py")
ExternalProject_Add(
pytorch
PREFIX pytorch
GIT_REPOSITORY https://github.com/pytorch/pytorch.git
GIT_TAG ${PYTORCH_COMMIT}
GIT_CONFIG advice.detachedHead=false
UPDATE_DISCONNECTED 1
PATCH_COMMAND "" test -f ${PYTORCH_COMPLETE} && echo Skipping || git apply ${PYTORCH_PATCHES} && echo Applying ${PYTORCH_PATCHES}
CONFIGURE_COMMAND "" cd ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/third_party/fmt/ && git checkout 7.1.0
BUILD_COMMAND ""
COMMAND test -f ${PYTORCH_COMPLETE} && echo Skipping || sh ${CMAKE_BINARY_DIR}/build_pytorch.sh
INSTALL_COMMAND ""
DEPENDS spdlog protobuf
)
message(STATUS "PROTOBUF_PROTOC_EXECUTABLE=${PROTOBUF_PROTOC} PROTOBUF_INCLUDE_DIRS=${PROTOBUF_INCLUDE_DIR} PROTOBUF_LIBRARIES=${PROTOBUF_LIB_DEPS}")
ExternalProject_Add(
pytorch
PREFIX pytorch
GIT_REPOSITORY https://github.com/pytorch/pytorch.git
GIT_TAG ${PYTORCH_COMMIT}
GIT_CONFIG advice.detachedHead=false
UPDATE_DISCONNECTED 1
PATCH_COMMAND "" test -f ${PYTORCH_COMPLETE} && echo Skipping || git apply ${PYTORCH_PATCHES} && echo Applying ${PYTORCH_PATCHES}
CONFIGURE_COMMAND "" cd ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/third_party/fmt/ && git checkout 7.1.0
BUILD_COMMAND ""
COMMAND test -f ${PYTORCH_COMPLETE} && echo Skipping || ${CMAKE_COMMAND} -E env PATH=${PROTOBUF_LIB_DIR}:${PROTOBUF_INCLUDE_DIR}:$ENV{PATH} BUILD_CUSTOM_PROTOBUF=0 BUILD_TEST=0 USE_MPS=${PYTORCH_USE_MPS} USE_DDLOG=1 USE_TENSORRT=0 "CMAKE_CXX_FLAGS=-isystem ${SPDLOG_INCLUDE_DIR} -isystem ${PROTOBUF_INCLUDE_DIR}" "CMAKE_CUDA_FLAGS=-isystem ${SPDLOG_INCLUDE_DIR} -isystem ${PROTOBUF_INCLUDE_DIR}" CMAKE_PREFIX_PATH="${PROTOBUF_LIB_DIR}/cmake" MAX_JOBS=8 python3 ../pytorch/tools/build_libtorch.py
INSTALL_COMMAND ""
DEPENDS spdlog protobuf
)
set(TORCH_BINARY_LOCATION ${CMAKE_BINARY_DIR}/pytorch/src/pytorch-build/build)
else()
if(USE_TORCH_CPU_ONLY)
set(PYTORCH_USE_CUDA 0)
else()
set(PYTORCH_USE_CUDA 1)
endif()

file(WRITE ${CMAKE_BINARY_DIR}/build_pytorch.sh "set -x\n${CMAKE_COMMAND} -E env PATH=${PROTOBUF_LIB_DIR}:$ENV{PATH} BUILD_CUSTOM_PROTOBUF=0 GLIBCXX_USE_CXX11_ABI=1 BUILD_TEST=0 USE_CUDA=${PYTORCH_USE_CUDA} BUILD_CAFFE2=1 BUILD_CAFFE2_OPS=1 BUILD_CAFFE2_MOBILE=0 USE_DDLOG=1 USE_TENSORRT=0 CAFFE2_LINK_LOCAL_PROTOBUF=0 \"CMAKE_CXX_FLAGS=-isystem ${SPDLOG_INCLUDE_DIR} -isystem ${PROTOBUF_INCLUDE_DIR}\" \"CMAKE_CUDA_FLAGS=-isystem ${SPDLOG_INCLUDE_DIR} -isystem ${PROTOBUF_INCLUDE_DIR}\" TORCH_CUDA_ARCH_LIST=\"${CUDA_ARCH}\" \"CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc\" CMAKE_PREFIX_PATH=${PROTOBUF_LIB_DIR}/cmake MAX_JOBS=8 python3 ../pytorch/tools/build_libtorch.py")
ExternalProject_Add(
pytorch
PREFIX pytorch
GIT_REPOSITORY https://github.com/pytorch/pytorch.git
GIT_TAG ${PYTORCH_COMMIT}
GIT_CONFIG advice.detachedHead=false
UPDATE_DISCONNECTED 1
PATCH_COMMAND "" test -f ${PYTORCH_COMPLETE} && echo Skipping || git apply ${PYTORCH_PATCHES} && echo Applying ${PYTORCH_PATCHES}
CONFIGURE_COMMAND "" cd ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/third_party/fmt/ && git checkout 7.1.0
BUILD_COMMAND ""
COMMAND test -f ${PYTORCH_COMPLETE} && echo Skipping || sh ${CMAKE_BINARY_DIR}/build_pytorch.sh
INSTALL_COMMAND ""
DEPENDS spdlog protobuf
)
endif()
set(TORCH_LOCATION ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/torch)
endif()

set(TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libtorch.so ${TORCH_LOCATION}/lib/libtorch_cpu.so ${TORCH_LOCATION}/lib/libc10.so -llmdb)

if (USE_TORCH_CPU_ONLY)
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libtorch_cpu.so iomp5)
if (APPLE)
set(TORCH_LIB_DEPS ${TORCH_BINARY_LOCATION}/lib/libtorch.dylib ${TORCH_BINARY_LOCATION}/lib/libtorch_cpu.dylib ${TORCH_BINARY_LOCATION}/lib/libc10.dylib ${TORCH_BINARY_LOCATION}/lib/libtorch_global_deps.dylib -llmdb)
else()
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libc10_cuda.so ${TORCH_LOCATION}/lib/libtorch_cuda.so)
set(TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libtorch.so ${TORCH_LOCATION}/lib/libtorch_cpu.so ${TORCH_LOCATION}/lib/libc10.so -llmdb)

if (USE_TORCH_CPU_ONLY)
list(APPEND TORCH_LIB_DEPS iomp5)
else()
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libc10_cuda.so ${TORCH_LOCATION}/lib/libtorch_cuda.so)
endif()
endif()

set(TORCH_INC_DIR ${TORCH_LOCATION}/include/ ${TORCH_LOCATION}/include/torch/csrc/api/include/ ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/torch/include/torch/csrc/api/include ${TORCH_LOCATION}/.. ${CMAKE_BINARY_DIR}/src)
Expand Down Expand Up @@ -1119,6 +1217,15 @@ if (USE_HTTP_SERVER)
endif()

# main library, main & tests
if (APPLE)
include_directories("/opt/homebrew/include/")
include_directories("/opt/homebrew/opt/libarchive/include/")
include_directories("/opt/homebrew/include/eigen3")
include_directories("/opt/homebrew/include/utf8cpp/")

link_directories("/opt/homebrew/lib/")
endif()

include_directories("${PROJECT_SOURCE_DIR}/src")
# add the binary tree to the search path for include files
# so that we will find dd_config.h
Expand All @@ -1143,11 +1250,14 @@ include_directories(
${CMAKE_SOURCE_DIR}/backends/dlib
${CMAKE_SOURCE_DIR}/backends/tsne
)
include_directories(${VARIANT_INCLUDE_DIRS})

add_subdirectory(src)

set(COMMON_INCLUDE_DIRS
${VARIANT_INCLUDE_DIRS}
${Boost_INCLUDE_DIRS}
${OPENMP_INCLUDE_DIRS}
${OATPP_INCLUDE_DIRS}
${HTTP_INCLUDE_DIR}
${PROTOBUF_INCLUDE_DIR}
Expand Down Expand Up @@ -1184,7 +1294,7 @@ set(COMMON_LINK_DIRS
)

set(COMMON_LINK_LIBS
ddetect ${DLIB_LIB_DEPS} ${TENSORRT_LIBS} ${CUDA_LIB_DEPS} glog gflags ${OpenCV_LIBS} curlpp curl ${Boost_LIBRARIES} archive
ddetect ${DLIB_LIB_DEPS} ${TENSORRT_LIBS} ${CUDA_LIB_DEPS} gflags ${OpenCV_LIBS} curlpp curl ${Boost_LIBRARIES} archive ${OPENMP_LIBRARIES}
${PROTOBUF_LIB_DEPS}
${HTTP_LIB_DEPS}
${SPDLOG_LIB_DEPS}
Expand Down Expand Up @@ -1278,6 +1388,7 @@ message(STATUS "USE_COMMAND_LINE: ${USE_COMMAND_LINE}")
message(STATUS "USE_JSON_API: ${USE_JSON_API}")
message(STATUS "USE_HTTP_SERVER: ${USE_HTTP_SERVER}")
message(STATUS "USE_HTTP_SERVER_OATPP: ${USE_HTTP_SERVER_OATPP}")
message(STATUS "USE_CPU_ONLY: ${USE_CPU_ONLY}")
message(STATUS "BUILD_SPDLOG: ${BUILD_SPDLOG}")
message(STATUS "BUILD_PROTOBUF: ${BUILD_PROTOBUF}")
message(STATUS "BUILD_TESTS: ${BUILD_TESTS}")
Expand Down Expand Up @@ -1306,5 +1417,6 @@ message(STATUS "USE_XGBOOST: ${USE_XGBOOST}")
message(STATUS "USE_XGBOOST_CPU_ONLY: ${USE_XGBOOST_CPU_ONLY}")
message(STATUS "USE_TSNE: ${USE_TSNE}")
message(STATUS "USE_BOOST_BACKTRACE: ${USE_BOOST_BACKTRACE}")
message(STATUS "USE_OPENMP: ${USE_OPENMP}")
message(STATUS "USE_CUDA_CV: ${USE_CUDA_CV}")
message(STATUS "OPENCV_VERSION: ${OPENCV_VERSION}")
5 changes: 5 additions & 0 deletions main/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
include_directories(${COMMON_INCLUDE_DIRS})
link_directories(${COMMON_LINK_DIRS})

if (APPLE)
link_directories("/opt/homebrew/lib/")
endif()

if (USE_COMMAND_LINE OR USE_HTTP_SERVER OR USE_HTTP_SERVER_OATPP)
add_executable (dede dede.cc)
add_dependencies(dede protobuf)
Expand Down
14 changes: 14 additions & 0 deletions patches/pytorch/pytorch_113_apple_includes.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 471fc8a8d3d..09026bdb2d7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1004,6 +1004,9 @@ if(NOT APPLE AND UNIX)
list(APPEND Caffe2_DEPENDENCY_LIBS dl)
endif()

+# Mac OS
+include_directories(BEFORE /opt/homebrew/include/)
+
# Prefix path to Caffe2 headers.
# If a directory containing installed Caffe2 headers was inadvertently
# added to the list of include directories, prefixing
12 changes: 6 additions & 6 deletions patches/pytorch/pytorch_113_new_logger.patch
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
diff --git a/c10/util/logging_is_dd_log.h b/c10/util/logging_is_dd_log.h
new file mode 100644
index 0000000000..8f72861a08
index 00000000000..c1c0c24defa
--- /dev/null
+++ b/c10/util/logging_is_dd_log.h
@@ -0,0 +1,61 @@
Expand All @@ -13,8 +13,8 @@ index 0000000000..8f72861a08
+#include "caffe/llogging.h"
+
+#ifndef VLOG
+#define DEBUG "none"
+static const std::string VLogLevels[4] = {DEBUG, DEBUG, DEBUG, DEBUG};
+#define DEBUG_STR "none"
+static const std::string VLogLevels[4] = {DEBUG_STR, DEBUG_STR, DEBUG_STR, DEBUG_STR};
+#define VLOG(n) LOG(VLogLevels[n])
+#endif
+
Expand Down Expand Up @@ -67,7 +67,7 @@ index 0000000000..8f72861a08
+#endif
diff --git a/caffe/llogging.h b/caffe/llogging.h
new file mode 100644
index 0000000000..720bda66a7
index 00000000000..548b0febecc
--- /dev/null
+++ b/caffe/llogging.h
@@ -0,0 +1,315 @@
Expand Down Expand Up @@ -157,7 +157,7 @@ index 0000000000..720bda66a7
+#ifdef CAFFE_THROW_ON_ERROR
+#include <sstream>
+#define SSTR(x) \
+ dynamic_cast<std::ostringstream &>((std::ostringstream() << std::dec << x)) \
+ dynamic_cast<std::ostringstream &&>((std::ostringstream() << std::dec << x)) \
+ .str()
+class CaffeErrorException : public std::exception
+{
Expand All @@ -168,7 +168,7 @@ index 0000000000..720bda66a7
+ ~CaffeErrorException() throw()
+ {
+ }
+ const char *what() const throw()
+ const char *what() const throw() override
+ {
+ return _s.c_str();
+ }
Expand Down
6 changes: 3 additions & 3 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ endif()

add_library(ddetect ${ddetect_SOURCES})
add_dependencies(ddetect protobuf caffe_pb_h)
if (APPLE)
add_dependencies(ddetect variant)
endif()
if (BUILD_SPDLOG)
add_dependencies(ddetect spdlog)
endif()
Expand Down Expand Up @@ -183,6 +186,3 @@ if (USE_HTTP_SERVER_OATPP)
add_dependencies(ddetect oatpp-swagger)
endif()
endif()
if (BUILD_SPDLOG)
add_dependencies(ddetect spdlog)
endif()
2 changes: 1 addition & 1 deletion src/backends/torch/native/templates/crnn_head.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
namespace dd
{
void CRNNHeadImpl::get_params(const APIData &ad_params,
const std::vector<long int> &input_dims,
const std::vector<int64_t> &input_dims,
int output_size)
{
if (ad_params.has("timesteps"))
Expand Down
Loading

0 comments on commit aa8822d

Please sign in to comment.