diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 2030f9c33a..d41573acab 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -38,6 +38,9 @@ function(ff_set_cxx_properties target) CXX_STANDARD_REQUIRED YES CXX_EXTENSIONS NO ) + target_compile_options(${target} + PRIVATE $<$:> # add C++ compile flags here + ) endfunction() function(ff_add_library) diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake index ccf9914ac2..12062958cd 100644 --- a/cmake/nccl.cmake +++ b/cmake/nccl.cmake @@ -89,7 +89,7 @@ else() BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/deps/nccl/lib/libnccl${LIBEXT} INSTALL_COMMAND "" CONFIGURE_COMMAND "" - BUILD_COMMAND make src.build "${NCCL_BUILD_NVCC_GENCODE} --disable-warnings" "CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}" "BUILDDIR=${CMAKE_BINARY_DIR}/deps/nccl" "CXX=${CMAKE_CXX_COMPILER}" CC="${CMAKE_CC_COMPILER}" "CXXFLAGS+=-w" + BUILD_COMMAND make src.build "${NCCL_BUILD_NVCC_GENCODE}" "CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}" "BUILDDIR=${CMAKE_BINARY_DIR}/deps/nccl" "CXX=${CMAKE_CXX_COMPILER} -w" CC="${CMAKE_CC_COMPILER}" BUILD_IN_SOURCE 1 ) diff --git a/deps/fmt b/deps/fmt index a33701196a..f5e54359df 160000 --- a/deps/fmt +++ b/deps/fmt @@ -1 +1 @@ -Subproject commit a33701196adfad74917046096bf5a2aa0ab0bb50 +Subproject commit f5e54359df4c26b6230fc61d38aa294581393084 diff --git a/lib/pcg/include/pcg/layer.h b/lib/pcg/include/pcg/layer.h index 1017036e69..6e9415a8fb 100644 --- a/lib/pcg/include/pcg/layer.h +++ b/lib/pcg/include/pcg/layer.h @@ -23,8 +23,10 @@ VISITABLE_STRUCT(::FlexFlow::Layer, attrs, name); MAKE_VISIT_HASHABLE(::FlexFlow::Layer); namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); -static_assert(is_fmtable::value, "Layer must be fmtable"); + +FF_VISIT_FMTABLE(Layer); +CHECK_FMTABLE(Layer); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers.decl b/lib/utils/include/utils/containers.decl.h similarity index 98% rename from lib/utils/include/utils/containers.decl rename to lib/utils/include/utils/containers.decl.h index 960ce06011..8ad65a4488 100644 --- a/lib/utils/include/utils/containers.decl +++ b/lib/utils/include/utils/containers.decl.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_H +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H #include "utils/bidict.h" #include "utils/invoke.h" diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 7edde2f788..331a423cb0 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL #include "bidict.h" -#include "containers.decl" +#include "containers.decl.h" #include "invoke.h" #include "required_core.h" #include "type_traits_core.h" diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h new file mode 100644 index 0000000000..d27174f474 --- /dev/null +++ b/lib/utils/include/utils/exception.decl.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_DECL_H +#define _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_DECL_H + +#include "utils/fmt.decl.h" +#include + +namespace FlexFlow { + +#ifdef FF_REQUIRE_IMPLEMENTED +#define NOT_IMPLEMENTED() static_assert(false, "Function not yet implemented"); +#else +#define NOT_IMPLEMENTED() throw not_implemented(); +#endif + +class not_implemented : public std::logic_error { +public: + not_implemented(); +}; + +template +std::runtime_error mk_runtime_error(fmt::format_string fmt_str, + T &&...args); +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index e25a79faab..fd3a0b7ee0 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -1,22 +1,12 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_H #define _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_H +#include "utils/exception.decl.h" #include "utils/fmt.h" #include namespace FlexFlow { -#ifdef FF_REQUIRE_IMPLEMENTED -#define NOT_IMPLEMENTED() static_assert(false, "Function not yet implemented"); -#else -#define NOT_IMPLEMENTED() throw not_implemented(); -#endif - -class not_implemented : public std::logic_error { -public: - not_implemented() : std::logic_error("Function not yet implemented"){}; -}; - template std::runtime_error mk_runtime_error(fmt::format_string fmt_str, T &&...args) { diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h new file mode 100644 index 0000000000..367a712b87 --- /dev/null +++ b/lib/utils/include/utils/fmt.decl.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { + +template +using is_fmtable = ::fmt::is_formattable; + +template +struct already_has_ostream_operator; + +template +typename std::enable_if::value, + std::ostream &>::type + operator<<(std::ostream &s, T const &t); + +} // namespace FlexFlow + +namespace fmt { + +template +struct formatter<::std::unordered_set> : formatter<::std::string> { + template + auto format(::std::unordered_set const &m, FormatContext &ctx) + -> decltype(ctx.out()); +}; + +template +struct formatter<::std::vector> : formatter<::std::string> { + template + auto format(::std::vector const &m, FormatContext &ctx) + -> decltype(ctx.out()); +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 23231a6cd8..218f72d8af 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,22 +1,14 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "fmt/core.h" -#include "fmt/format.h" -#include "utils/containers.decl" +#include "utils/containers.decl.h" +#include "utils/fmt.decl.h" +#include "utils/test_types.h" #include "utils/type_traits_core.h" -#include namespace FlexFlow { -template -struct is_fmtable : std::false_type {}; - -template -struct is_fmtable()))>> - : std::true_type {}; - -template +template struct already_has_ostream_operator : std::false_type {}; template <> @@ -43,12 +35,18 @@ operator<<(std::ostream &s, T const &t) { } */ +#define CHECK_FMTABLE(...) \ + static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " must be fmtable"); + // This will not template typename std::enable_if::value, std::ostream &>::type operator<<(std::ostream &s, T const &t) { - std::string result = fmt::format("{}", t); + CHECK_FMTABLE(T); + + std::string result = fmt::to_string(t); return s << result; } @@ -63,17 +61,30 @@ typename std::enable_if::value, namespace fmt { template -struct formatter<::std::unordered_set> : formatter<::std::string> { - template - auto format(::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - std::string result = - join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { - return fmt::to_string(t); - }); - return formatter::format(result, ctx); - } -}; +template +auto formatter<::std::unordered_set>::format( + ::std::unordered_set const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = join_strings( + m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + return formatter::format(result, ctx); +} + +template +template +auto formatter<::std::vector>::format(::std::vector const &m, + FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + std::string result = join_strings( + m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + return formatter::format(result, ctx); +} + +CHECK_FMTABLE(std::vector); +CHECK_FMTABLE(std::unordered_set); } // namespace fmt diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 97968ad485..9a655ae072 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -11,10 +11,8 @@ namespace FlexFlow { template struct cow_ptr_t { - // static_assert(is_clonable::value, - // "cow_ptr_t requires the type to have a clone() method"); // - // TODO: - // https://github.com/flexflow/FlexFlow/issues/909#issue-1833470024 + static_assert(is_clonable::value, + "cow_ptr_t requires the type to have a clone() method"); cow_ptr_t(std::shared_ptr const &ptr) : ptr(ptr) {} cow_ptr_t(std::shared_ptr &&ptr) : ptr(std::move(ptr)) {} diff --git a/lib/utils/include/utils/graph/digraph.h b/lib/utils/include/utils/graph/digraph.h index fc6072d857..c5b37e86cb 100644 --- a/lib/utils/include/utils/graph/digraph.h +++ b/lib/utils/include/utils/graph/digraph.h @@ -37,7 +37,7 @@ struct DiGraphView { unsafe_create_without_ownership(IDiGraphView const &graphView); private: - DiGraphView(std::shared_ptr ptr) : ptr(ptr) {} + DiGraphView(std::shared_ptr ptr); friend struct GraphInternal; diff --git a/lib/utils/include/utils/graph/internal.h b/lib/utils/include/utils/graph/internal.h index 8ecb34bc3d..064e046aca 100644 --- a/lib/utils/include/utils/graph/internal.h +++ b/lib/utils/include/utils/graph/internal.h @@ -3,7 +3,7 @@ #include "utils/graph/digraph.h" #include "utils/graph/digraph_interfaces.h" -#include "utils/graph/labelled/labelled_open.decl" +#include "utils/graph/labelled/labelled_open.decl.h" #include "utils/graph/labelled/labelled_open_interfaces.h" #include "utils/graph/multidigraph.h" #include "utils/graph/multidigraph_interfaces.h" diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.decl b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h similarity index 96% rename from lib/utils/include/utils/graph/labelled/labelled_open.decl rename to lib/utils/include/utils/graph/labelled/labelled_open.decl.h index 11b00a68e0..cf095fe3c7 100644 --- a/lib/utils/include/utils/graph/labelled/labelled_open.decl +++ b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H #include "labelled_open_interfaces.h" #include "node_labelled.h" @@ -63,7 +63,11 @@ struct LabelledOpenMultiDiGraph { operator OpenMultiDiGraphView() const; friend void swap(LabelledOpenMultiDiGraph &lhs, - LabelledOpenMultiDiGraph &rhs); + LabelledOpenMultiDiGraph &rhs) { + using std::swap; + + swap(lhs.ptr, rhs.ptr); + } Node add_node(NodeLabel const &l); NodeLabel &at(Node const &n); diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.h b/lib/utils/include/utils/graph/labelled/labelled_open.h index a5c60a37fb..3713a1f8d2 100644 --- a/lib/utils/include/utils/graph/labelled/labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/labelled_open.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H -#include "labelled_open.decl" +#include "labelled_open.decl.h" #include "labelled_open_interfaces.h" #include "node_labelled.h" #include "utils/graph/internal.h" @@ -69,14 +69,6 @@ LabelledOpenMultiDiGraph::operator OpenMultiDiGraphView() const { return GraphInternal::create_open_multidigraph_view(this->ptr); } -template -void swap(LabelledOpenMultiDiGraph &lhs, - LabelledOpenMultiDiGraph &rhs) { - using std::swap; - - swap(lhs.ptr, rhs.ptr); -} - template Node LabelledOpenMultiDiGraph::add_node( NodeLabel const &l) { diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index 1b56ba7461..77bd3aedea 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -30,7 +30,7 @@ struct MultiDiGraphView { unsafe_create_without_ownership(IMultiDiGraphView const &); private: - MultiDiGraphView(std::shared_ptr ptr) : ptr(ptr) {} + MultiDiGraphView(std::shared_ptr ptr); friend struct GraphInternal; diff --git a/lib/utils/include/utils/graph/node.h b/lib/utils/include/utils/graph/node.h index 8f5e9cde09..708eb09491 100644 --- a/lib/utils/include/utils/graph/node.h +++ b/lib/utils/include/utils/graph/node.h @@ -61,7 +61,7 @@ struct GraphView { } private: - GraphView(std::shared_ptr ptr) : ptr(ptr) {} + GraphView(std::shared_ptr ptr); friend struct GraphInternal; diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index 6c8cf88eb2..5b2c86eccf 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -34,8 +34,7 @@ struct OpenMultiDiGraphView { } private: - OpenMultiDiGraphView(std::shared_ptr ptr) - : ptr(ptr) {} + OpenMultiDiGraphView(std::shared_ptr ptr); friend struct GraphInternal; diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 6f7314a6c7..c748741a75 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_QUERY_SET_H #include "utils/bidict.h" -#include "utils/containers.decl" +#include "utils/containers.decl.h" #include "utils/exception.h" #include "utils/optional.h" #include @@ -12,7 +12,7 @@ namespace FlexFlow { template struct query_set { query_set() = delete; - query_set(T const &query) : query({query}) {} + query_set(T const &t) : query(std::unordered_set{t}) {} query_set(std::unordered_set const &query) : query(query) {} diff --git a/lib/utils/include/utils/graph/undirected.h b/lib/utils/include/utils/graph/undirected.h index a1c8dba226..fa021399f8 100644 --- a/lib/utils/include/utils/graph/undirected.h +++ b/lib/utils/include/utils/graph/undirected.h @@ -79,8 +79,7 @@ struct UndirectedGraphView { unsafe_create_without_ownership(IUndirectedGraphView const &); private: - UndirectedGraphView(std::shared_ptr ptr) - : ptr(ptr) {} + UndirectedGraphView(std::shared_ptr ptr); friend struct GraphInternal; diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 4286a7ec6b..aab4a18552 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -1,8 +1,10 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_STACK_STRING_H #define _FLEXFLOW_UTILS_INCLUDE_STACK_STRING_H +#include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" +#include "utils/type_traits.h" #include #include @@ -50,6 +52,11 @@ struct stack_basic_string { friend struct std::hash; + friend fmt::basic_string_view + format_as(stack_basic_string const &s) { + return {s.contents.data(), s.length()}; + } + private: stack_vector contents; }; @@ -71,19 +78,6 @@ struct hash<::FlexFlow::stack_basic_string> { } // namespace std -namespace fmt { - -template -struct formatter<::FlexFlow::stack_string> : formatter<::std::string> { - template - auto format(::FlexFlow::stack_string const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - return formatter::format(static_cast(m), ctx); - } -}; - -} // namespace fmt - namespace FlexFlow { static_assert(is_default_constructible>::value, @@ -102,8 +96,9 @@ static_assert(is_neq_comparable>::value, "stack_string must support !="); static_assert(is_lt_comparable>::value, "stack_string must support <"); -static_assert(is_hashable>::value, - "stack_string must be hashable"); +CHECK_WELL_BEHAVED_VALUE_TYPE(stack_string<1>); +CHECK_HASHABLE(stack_string<1>); +CHECK_FMTABLE(stack_string<1>); } // namespace FlexFlow diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 46ec0d21ef..977244dd62 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -5,6 +5,7 @@ #include "hash-utils.h" #include "optional.h" #include "utils/fmt.h" +#include "utils/test_types.h" #include "utils/type_traits.h" #include #include @@ -14,6 +15,23 @@ namespace FlexFlow { template struct stack_vector { +private: + using element_type = + conditional_t::value, T, optional>; + + static T const &get_value(T const &t) { + return t; + } + static T const &get_value(optional const &t) { + return t.value(); + } + static T &get_value(T &t) { + return t; + } + static T &get_value(optional &t) { + return t.value(); + } + public: stack_vector() = default; @@ -26,8 +44,8 @@ struct stack_vector { } } - operator std::vector() { - return {this->begin(), this->end()}; + operator std::vector() const { + return {this->cbegin(), this->cend()}; } void push_back(T const &t) { @@ -45,22 +63,22 @@ struct stack_vector { T const &back() const { assert(this->m_size >= 1); - return this->contents[this->m_size - 1].value(); + return get_value(this->contents[this->m_size - 1]); } T &back() { assert(this->m_size >= 1); - return this->contents[this->m_size - 1].value(); + return get_value(this->contents[this->m_size - 1]); } T const &at(std::size_t idx) const { assert(idx < MAXSIZE); - return this->contents[idx].value(); + return get_value(this->contents[idx]); } T &at(std::size_t idx) { assert(idx < MAXSIZE); - return this->contents[idx].value(); + return get_value(this->contents[idx]); } T const &operator[](std::size_t idx) const { @@ -79,12 +97,12 @@ struct stack_vector { using reference = typename std::conditional::type; using pointer = typename std::conditional::type; - typename std::conditional const *, optional *>:: + typename std::conditional:: type ptr; Iterator(typename std::conditional const *, - optional *>::type ptr) + element_type const *, + element_type *>::type ptr) : ptr(ptr) {} template const &rhs) : ptr(rhs.ptr) {} reference operator*() const { - return ptr->value(); + return get_value(*ptr); } pointer operator->() const { - return &ptr->value(); + return &get_value(*ptr); } Iterator &operator++() { @@ -151,7 +169,7 @@ struct stack_vector { } reference operator[](difference_type const &diff) const { - return this->ptr[diff].value(); + return get_value(this->ptr[diff]); } bool operator<(Iterator const &rhs) const { @@ -181,7 +199,7 @@ struct stack_vector { using const_reference = T const &; iterator begin() { - optional *ptr = this->contents.data(); + element_type *ptr = this->contents.data(); return iterator(ptr); } @@ -190,7 +208,7 @@ struct stack_vector { } const_iterator cbegin() const { - optional const *ptr = this->contents.data(); + element_type const *ptr = this->contents.data(); return const_iterator(ptr); } @@ -267,9 +285,19 @@ struct stack_vector { return (this->m_size == 0); } + T const *data() const { + return this->contents.data(); + } + + friend std::vector format_as(stack_vector const &v) { + CHECK_FMTABLE(std::vector); + + return static_cast>(v); + } + private: std::size_t m_size = 0; - std::array, MAXSIZE> contents; + std::array contents; static_assert( implies, is_equal_comparable>::value, @@ -281,6 +309,8 @@ struct stack_vector { implies, is_lt_comparable>::value, ""); }; +CHECK_FMTABLE(stack_vector); + } // namespace FlexFlow namespace std { @@ -298,25 +328,4 @@ struct hash<::FlexFlow::stack_vector> { } // namespace std -namespace fmt { - -template -struct formatter<::FlexFlow::stack_vector> - : formatter<::std::string> { - template - auto format(::FlexFlow::stack_vector const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - std::string result = - "[" + - join_strings(m.cbegin(), - m.cend(), - ", ", - [](T const &t) { return fmt::to_string(t); }) + - "]"; - return formatter::format(result, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h index 6002d763a6..2cac547bb6 100644 --- a/lib/utils/include/utils/test_types.h +++ b/lib/utils/include/utils/test_types.h @@ -7,13 +7,22 @@ namespace FlexFlow { namespace test_types { -enum capability { HASHABLE, EQ, CMP, DEFAULT_CONSTRUCTIBLE, COPYABLE }; +enum capability { + HASHABLE, + EQ, + CMP, + DEFAULT_CONSTRUCTIBLE, + COPYABLE, + PLUS, + PLUSEQ, + FMT +}; template struct capability_implies : std::false_type {}; template <> -struct capability_implies : std::true_type {}; +struct capability_implies : std::true_type {}; template struct capability_implies : std::true_type {}; @@ -34,52 +43,76 @@ struct test_type_t { template using supports = conjunction...>; - template ::value, - bool>::type = true> + template ::value, bool>::type = true> test_type_t(); - template ::value, - bool>::type = true> + template ::value, bool>::type = true> test_type_t() = delete; - template < - typename std::enable_if::value, bool>::type = true> + template ::value, bool>::type = true> test_type_t(test_type_t const &); - template < - typename std::enable_if::value, bool>::type = true> + template ::value, bool>::type = true> test_type_t(test_type_t const &) = delete; - typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type operator==(test_type_t const &) const; - typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type operator!=(test_type_t const &) const; - typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type operator<(test_type_t const &) const; - typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type operator>(test_type_t const &) const; - typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type operator<=(test_type_t const &) const; - typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type operator>=(test_type_t const &) const; + + template + typename std::enable_if::value, test_type_t>::type + operator+(test_type_t const &); + + template + typename std::enable_if::value, test_type_t>::type + operator+=(test_type_t const &); }; +template +enable_if_t::value, std::string> + format_as(test_type_t); + using no_eq = test_type_t<>; using eq = test_type_t; using cmp = test_type_t; using hash_cmp = test_type_t; +using plusable = test_type_t; +using fmtable = test_type_t; } // namespace test_types } // namespace FlexFlow namespace std { -template <::FlexFlow::test_types::capability... CAPABILITIES> +template < + ::FlexFlow::test_types:: + capability... CAPABILITIES> //, typename = typename + // std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, + // bool>::type> struct hash<::FlexFlow::test_types::test_type_t> { typename std::enable_if< ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 074fbfded4..202e62b5ad 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_TUPLE_H #include "utils/any.h" -#include "utils/exception.h" +#include "utils/exception.decl.h" #include #include #include diff --git a/lib/utils/src/exception.cc b/lib/utils/src/exception.cc new file mode 100644 index 0000000000..7dccdc3074 --- /dev/null +++ b/lib/utils/src/exception.cc @@ -0,0 +1,8 @@ +#include "utils/exception.h" + +namespace FlexFlow { + +not_implemented::not_implemented() + : std::logic_error("Function not yet implemented"){}; + +} diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index fdbc03bc92..341005bd08 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -102,4 +102,6 @@ DiGraph::operator DiGraphView() const { return GraphInternal::create_digraphview(this->ptr.get()); } +DiGraphView::DiGraphView(std::shared_ptr ptr) : ptr(ptr) {} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 51c6c41074..f21e35fbe7 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -102,6 +102,9 @@ std::unordered_set return this->ptr->query_nodes(q); } +MultiDiGraphView::MultiDiGraphView(std::shared_ptr ptr) + : ptr(ptr) {} + std::unordered_set MultiDiGraphView::query_edges(MultiDiEdgeQuery const &q) const { return this->ptr->query_edges(q); diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index e27ac36a7d..0740cde3eb 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -30,6 +30,8 @@ std::unordered_set GraphView::query_nodes(NodeQuery const &g) const { return this->ptr->query_nodes(g); } +GraphView::GraphView(std::shared_ptr ptr) : ptr(ptr) {} + // Set the shared_ptr's destructor to a nop so that effectively there is no // ownership GraphView diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index c249f309e7..0a41c45fa0 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -27,6 +27,10 @@ OpenMultiDiGraphView::operator MultiDiGraphView() const { return as_multidigraph(*this); } +OpenMultiDiGraphView::OpenMultiDiGraphView( + std::shared_ptr ptr) + : ptr(ptr) {} + OpenMultiDiGraph::OpenMultiDiGraph(OpenMultiDiGraph const &other) : ptr(other.ptr) {} diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index 695284207b..9aeeae9b63 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -67,6 +67,10 @@ UndirectedGraph::operator UndirectedGraphView() const { return GraphInternal::create_undirectedgraphview(this->ptr.get()); } +UndirectedGraphView::UndirectedGraphView( + std::shared_ptr ptr) + : ptr(ptr) {} + std::unordered_set UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { return this->ptr->query_edges(q); diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index ff7c7cff20..8089621850 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -1,5 +1,5 @@ #include "doctest/doctest.h" -#include "utils/containers.h" +#include "utils/containers.decl.h" #include #include #include diff --git a/lib/utils/test/src/test_depulicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc similarity index 100% rename from lib/utils/test/src/test_depulicated_priority_queue.cc rename to lib/utils/test/src/test_deduplicated_priority_queue.cc diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index b324b833b6..c6f2003ee4 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -1,6 +1,5 @@ #include "test/utils/all.h" #include "test/utils/rapidcheck/visitable.h" -#include "utils/containers.decl" #include "utils/containers.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/undirected.h"