diff --git a/library/include/hiptensor/hiptensor_types.hpp b/library/include/hiptensor/hiptensor_types.hpp index a4e1b749..90f24e26 100644 --- a/library/include/hiptensor/hiptensor_types.hpp +++ b/library/include/hiptensor/hiptensor_types.hpp @@ -159,7 +159,7 @@ typedef enum //! Log selection messages HIPTENSOR_LOG_LEVEL_HEURISTICS_TRACE = 8, //! Log a trace of API calls - HIPTENSOR_LOG_LEVEL_API_TRACE = 16 + HIPTENSOR_LOG_LEVEL_API_TRACE = 16, } hiptensorLogLevel_t; diff --git a/library/src/contraction/contraction_selection.cpp b/library/src/contraction/contraction_selection.cpp index c7983c69..9e2f6570 100644 --- a/library/src/contraction/contraction_selection.cpp +++ b/library/src/contraction/contraction_selection.cpp @@ -33,6 +33,7 @@ #endif #include "contraction_selection.hpp" +#include "logger.hpp" #include "performance.hpp" #include "util.hpp" @@ -149,6 +150,24 @@ namespace hiptensor static_cast(bytes) / static_cast(1.E6) / time // BW }; + using hiptensor::Logger; + auto& logger = Logger::instance(); + + // Log brute force timings for actor critic training + if(logger->getLogMask() & HIPTENSOR_LOG_LEVEL_HEURISTICS_TRACE) + { + // Log Kernel performances access + char msg[256]; + snprintf(msg, + sizeof(msg), + "KernelId: %lu, KernelName: %s, AvgTime: %0.3f ms", + solution->uid(), + solution->kernelName().c_str(), + time); + + logger->logHeuristics("BRUTE_FORCE_KERNEL_PERF", msg); + } + if(metrics > bestMetrics) { bestSolution = solution; @@ -189,27 +208,60 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 11124293857315312720ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 2317674114976786230ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 2317674114976786230ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 2317674114976786230ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 12241437837959333440ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 12241437837959333440ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 11152060091307708334ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -237,27 +289,69 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 1953020431947874122ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 58303249112943560ull; + } + // m1n1k1 + else if(d1 == 1) + // if (d1 == 1 || (d2 == 1 && (a_ms_ks_lengths[3] == 1 || b_ns_ks_lengths[3] == 1))) + { + unique_id = 58303249112943560ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 2303552229010777601ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 58303249112943560ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 58303249112943560ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 58303249112943560ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 2303552229010777601ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -285,27 +379,60 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 14895098881714635802ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 9967477699864925937ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 14071475272156866885ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 14071475272156866885ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 15452087623356707112ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 15452087623356707112ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 8307633941691601884ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -333,27 +460,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 8517235228581081946ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 16299024124514902126ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 378062791888302715ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 76527422265261696ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 378062791888302715ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 378062791888302715ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 378062791888302715ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 378062791888302715ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -376,27 +544,60 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 17313709378682913599ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 17141562253969597117ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 17141562253969597117ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 17141562253969597117ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 17141562253969597117ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 17141562253969597117ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 6384780398804323250ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -419,27 +620,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 14397647188602189900ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 8251132190088736039ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 2897979232477761524ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 2897979232477761524ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 2897979232477761524ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 2897979232477761524ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 2897979232477761524ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 2897979232477761524ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -462,27 +704,60 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 8339198051871565944ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 4373449368168185126ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 4373449368168185126ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 2008216990064456310ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 4373449368168185126ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 13613206280884761703ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 15116758930810193332ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -510,27 +785,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 2724417728984064737ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 8067958629699904967ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 8116863550692548667ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 8116863550692548667ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 8116863550692548667ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 8116863550692548667ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 8116863550692548667ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 8116863550692548667ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -553,27 +869,60 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 5943247903036531691ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 5794367356792942822ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 17939389824758640014ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 10640128726648594287ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 5794367356792942822ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 5794367356792942822ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 13933081369664111675ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -596,27 +945,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 17972447156160297755ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 14915761978535949477ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 14915761978535949477ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 14915761978535949477ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 14915761978535949477ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 14915761978535949477ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 14915761978535949477ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 14915761978535949477ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -639,28 +1029,61 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 3893144338697524749ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 18207091374964962208ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 16948282955506101335ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 16870758234615651290ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 15355329505248522280ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 14642257549075851915ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 14642257549075851915ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -683,26 +1106,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 15165261158317928321ull; + + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 11269655469469274301ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 2143493311543532856ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 2143493311543532856ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 2143493311543532856ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 2143493311543532856ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 2143493311543532856ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 2143493311543532856ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -725,28 +1190,61 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 14511729289005214097ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 3879892272436099392ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 8021137963958390646ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 3248584345341330494ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 3879892272436099392ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 3879892272436099392ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 7950787545240972863ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -769,27 +1267,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 3636246152928348445ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 2054609181761357786ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 14145390177844245465ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 14145390177844245465ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 14145390177844245465ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 14145390177844245465ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 14145390177844245465ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 14145390177844245465ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -817,28 +1356,61 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 5711776907278244209ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 1688099565795560288ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 4348837698146370003ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 1688099565795560288ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 1688099565795560288ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 1688099565795560288ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 4363356859752806590ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -866,27 +1438,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 355777364055884033ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 15330878641001915472ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 11537900932066889768ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 8338926107119209426ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 11537900932066889768ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 11537900932066889768ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 11537900932066889768ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 11537900932066889768ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -914,28 +1527,61 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 3085227716611397774ull; + // m1n1k1 + if(d1 == 1) + { + unique_id = 10254320286859648634ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 15705829219230515535ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 12959721676360111684ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 10254320286859648634ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 10254320286859648634ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 10254320286859648634ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -963,27 +1609,68 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize) { - int d1 = a_ms_ks_lengths[0]; - int d2 = a_ms_ks_lengths[1]; - int d3 = b_ns_ks_lengths[0]; - int d4 = b_ns_ks_lengths[1]; - int d5 = a_ms_ks_lengths[2]; - int d6 = a_ms_ks_lengths[3]; + int d1 = a_ms_ks_strides[1]; + int d2 = a_ms_ks_strides[3]; + int d3 = a_ms_ks_strides[5]; + int d4 = a_ms_ks_strides[7]; + int d5 = a_ms_ks_strides[9]; + int d6 = a_ms_ks_strides[11]; size_t unique_id = 0; - unique_id = 2196983681630807584ull; + bool dim1 = std::count(a_ms_ks_lengths.cbegin(), a_ms_ks_lengths.cend(), 1) + || std::count(b_ns_ks_lengths.cbegin(), b_ns_ks_lengths.cend(), 1); + + // rank2 dim1 case + if(d2 == 1 && dim1) + { + unique_id = 14051358583041094215ull; + } + // m1n1k1 + else if(d1 == 1) + { + unique_id = 8503926755447648324ull; + } + // m2n2k2 + else if(d2 == 1) + { + unique_id = 8503926755447648324ull; + } + // m3n3k3 + else if(d3 == 1) + { + unique_id = 8503926755447648324ull; + } + // m4n4k4 + else if(d4 == 1) + { + unique_id = 8503926755447648324ull; + } + // m5n5k5 + else if(d5 == 1) + { + unique_id = 8503926755447648324ull; + } + // m6n6k6 + else if(d6 == 1) + { + unique_id = 8503926755447648324ull; + } if(auto candidate = candidates.find(unique_id); candidate != candidates.end()) { @@ -1003,15 +1690,19 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, hiptensorComputeType_t computeType, const uint64_t workspaceSize) { @@ -1028,15 +1719,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_16F && typeB == HIP_R_16F && typeD == HIP_R_16F && typeE == HIP_R_16F @@ -1052,15 +1747,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_16BF && typeB == HIP_R_16BF && typeD == NONE_TYPE @@ -1076,15 +1775,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_16BF && typeB == HIP_R_16BF && typeD == HIP_R_16BF @@ -1100,15 +1803,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_32F && typeB == HIP_R_32F && typeD == NONE_TYPE && typeE == HIP_R_32F @@ -1124,15 +1831,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_32F && typeB == HIP_R_32F && typeD == HIP_R_32F && typeE == HIP_R_32F @@ -1148,15 +1859,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_32F && typeB == HIP_R_32F && typeD == NONE_TYPE && typeE == HIP_R_32F @@ -1172,15 +1887,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_32F && typeB == HIP_R_32F && typeD == HIP_R_32F && typeE == HIP_R_32F @@ -1196,15 +1915,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_32F && typeB == HIP_R_32F && typeD == NONE_TYPE && typeE == HIP_R_32F @@ -1220,15 +1943,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_32F && typeB == HIP_R_32F && typeD == HIP_R_32F && typeE == HIP_R_32F @@ -1244,15 +1971,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_64F && typeB == HIP_R_64F && typeD == NONE_TYPE && typeE == HIP_R_64F @@ -1268,15 +1999,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_64F && typeB == HIP_R_64F && typeD == HIP_R_64F && typeE == HIP_R_64F @@ -1292,15 +2027,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_64F && typeB == HIP_R_64F && typeD == NONE_TYPE && typeE == HIP_R_64F @@ -1316,15 +2055,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_R_64F && typeB == HIP_R_64F && typeD == HIP_R_64F && typeE == HIP_R_64F @@ -1340,15 +2083,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_C_32F && typeB == HIP_C_32F && typeD == NONE_TYPE && typeE == HIP_C_32F @@ -1364,15 +2111,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_C_32F && typeB == HIP_C_32F && typeD == HIP_C_32F && typeE == HIP_C_32F @@ -1388,15 +2139,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_C_64F && typeB == HIP_C_64F && typeD == NONE_TYPE && typeE == HIP_C_64F @@ -1412,15 +2167,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } else if(typeA == HIP_C_64F && typeB == HIP_C_64F && typeD == HIP_C_64F && typeE == HIP_C_64F @@ -1436,15 +2195,19 @@ namespace hiptensor typeA, a_ms_ks_lengths, a_ms_ks_strides, + a_ms_ks_modes, typeB, b_ns_ks_lengths, b_ns_ks_strides, + b_ns_ks_modes, typeD, d_ms_ns_lengths, d_ms_ns_strides, + d_ms_ns_modes, typeE, e_ms_ns_lengths, e_ms_ns_strides, + e_ms_ns_modes, workspaceSize); } return HIPTENSOR_STATUS_EXECUTION_FAILED; diff --git a/library/src/contraction/contraction_selection.hpp b/library/src/contraction/contraction_selection.hpp index eebbbd53..6c0be5f6 100644 --- a/library/src/contraction/contraction_selection.hpp +++ b/library/src/contraction/contraction_selection.hpp @@ -70,15 +70,19 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, const uint64_t workspaceSize); }; @@ -88,15 +92,19 @@ namespace hiptensor hipDataType typeA, std::vector const& a_ms_ks_lengths, std::vector const& a_ms_ks_strides, + std::vector const& a_ms_ks_modes, hipDataType typeB, std::vector const& b_ns_ks_lengths, std::vector const& b_ns_ks_strides, + std::vector const& b_ns_ks_modes, hipDataType typeD, std::vector const& d_ms_ns_lengths, std::vector const& d_ms_ns_strides, + std::vector const& d_ms_ns_modes, hipDataType typeE, std::vector const& e_ms_ns_lengths, std::vector const& e_ms_ns_strides, + std::vector const& e_ms_ns_modes, hiptensorComputeType_t computeType, const uint64_t workspaceSize); diff --git a/library/src/contraction/contraction_solution.cpp b/library/src/contraction/contraction_solution.cpp index c3ce6b41..583843ee 100644 --- a/library/src/contraction/contraction_solution.cpp +++ b/library/src/contraction/contraction_solution.cpp @@ -246,16 +246,15 @@ namespace hiptensor e_ms_ns_strides, e_ms_ns_modes, workspacePtr)) - { return {HIPTENSOR_STATUS_INTERNAL_ERROR, -1.0f}; } if(this->workspaceSize() > workspaceSize) { + resetInvokerArgs(); return {HIPTENSOR_STATUS_INSUFFICIENT_WORKSPACE, -1.0f}; } - auto time = mInvokerPtr->Run(mInvokerArgPtr.get(), streamConfig); resetInvokerArgs(); diff --git a/library/src/contraction/contraction_solution_impl.hpp b/library/src/contraction/contraction_solution_impl.hpp index 806eaf80..5f330735 100644 --- a/library/src/contraction/contraction_solution_impl.hpp +++ b/library/src/contraction/contraction_solution_impl.hpp @@ -194,7 +194,12 @@ namespace hiptensor // Arg test Base::mValid = deviceOp->IsSupportedArgument(Base::mInvokerArgPtr.get()); - return mValid; + if(!Base::mValid) + { + resetArgs(); + } + + return Base::mValid; } }; @@ -324,6 +329,11 @@ namespace hiptensor // Arg test Base::mValid = deviceOp->IsSupportedArgument(Base::mInvokerArgPtr.get()); + if(!Base::mValid) + { + resetArgs(); + } + return Base::mValid; } }; diff --git a/library/src/contraction/hiptensor_contraction.cpp b/library/src/contraction/hiptensor_contraction.cpp index 65ba3b8a..21c28757 100644 --- a/library/src/contraction/hiptensor_contraction.cpp +++ b/library/src/contraction/hiptensor_contraction.cpp @@ -551,15 +551,19 @@ hiptensorStatus_t hiptensorInitContractionPlan(const hiptensorHandle_t* ADataType, desc->mTensorDesc[0].mLengths, desc->mTensorDesc[0].mStrides, + desc->mTensorMode[0], BDataType, desc->mTensorDesc[1].mLengths, desc->mTensorDesc[1].mStrides, + desc->mTensorMode[1], DDataType, desc->mTensorDesc[2].mLengths, desc->mTensorDesc[2].mStrides, + desc->mTensorMode[2], EDataType, desc->mTensorDesc[3].mLengths, desc->mTensorDesc[3].mStrides, + desc->mTensorMode[2], desc->mComputeType, workspaceSize); } diff --git a/library/src/include/logger.hpp b/library/src/include/logger.hpp index 3d2cbc36..16f431ff 100644 --- a/library/src/include/logger.hpp +++ b/library/src/include/logger.hpp @@ -57,7 +57,7 @@ namespace hiptensor LOG_LEVEL_PERF_TRACE = 2, LOG_LEVEL_PERF_HINT = 4, LOG_LEVEL_HEURISTICS_TRACE = 8, - LOG_LEVEL_API_TRACE = 16 + LOG_LEVEL_API_TRACE = 16, }; // For static initialization diff --git a/test/01_contraction/configs/bilinear_test_params_rank1.yaml b/test/01_contraction/configs/bilinear_test_params_rank1.yaml index 5d2c99f1..2a003fe6 100644 --- a/test/01_contraction/configs/bilinear_test_params_rank1.yaml +++ b/test/01_contraction/configs/bilinear_test_params_rank1.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/bilinear_test_params_rank2.yaml b/test/01_contraction/configs/bilinear_test_params_rank2.yaml index 403fdd1c..a14a5e77 100644 --- a/test/01_contraction/configs/bilinear_test_params_rank2.yaml +++ b/test/01_contraction/configs/bilinear_test_params_rank2.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/bilinear_test_params_rank3.yaml b/test/01_contraction/configs/bilinear_test_params_rank3.yaml index e022918d..d90aaf7f 100644 --- a/test/01_contraction/configs/bilinear_test_params_rank3.yaml +++ b/test/01_contraction/configs/bilinear_test_params_rank3.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/bilinear_test_params_rank4.yaml b/test/01_contraction/configs/bilinear_test_params_rank4.yaml index bd1b3aa4..77a346da 100644 --- a/test/01_contraction/configs/bilinear_test_params_rank4.yaml +++ b/test/01_contraction/configs/bilinear_test_params_rank4.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/bilinear_test_params_rank5.yaml b/test/01_contraction/configs/bilinear_test_params_rank5.yaml index e6ccb4bc..c0e49263 100644 --- a/test/01_contraction/configs/bilinear_test_params_rank5.yaml +++ b/test/01_contraction/configs/bilinear_test_params_rank5.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/bilinear_test_params_rank6.yaml b/test/01_contraction/configs/bilinear_test_params_rank6.yaml index 0714627a..ab3d110f 100644 --- a/test/01_contraction/configs/bilinear_test_params_rank6.yaml +++ b/test/01_contraction/configs/bilinear_test_params_rank6.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_bilinear_test_params_rank1.yaml b/test/01_contraction/configs/complex_bilinear_test_params_rank1.yaml index 6558f4a5..62a26b85 100644 --- a/test/01_contraction/configs/complex_bilinear_test_params_rank1.yaml +++ b/test/01_contraction/configs/complex_bilinear_test_params_rank1.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_bilinear_test_params_rank2.yaml b/test/01_contraction/configs/complex_bilinear_test_params_rank2.yaml index 7569bb22..3fe41c26 100644 --- a/test/01_contraction/configs/complex_bilinear_test_params_rank2.yaml +++ b/test/01_contraction/configs/complex_bilinear_test_params_rank2.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_bilinear_test_params_rank3.yaml b/test/01_contraction/configs/complex_bilinear_test_params_rank3.yaml index 8e679b33..d80f24dc 100644 --- a/test/01_contraction/configs/complex_bilinear_test_params_rank3.yaml +++ b/test/01_contraction/configs/complex_bilinear_test_params_rank3.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_bilinear_test_params_rank4.yaml b/test/01_contraction/configs/complex_bilinear_test_params_rank4.yaml index 9d03bcb2..134b7de7 100644 --- a/test/01_contraction/configs/complex_bilinear_test_params_rank4.yaml +++ b/test/01_contraction/configs/complex_bilinear_test_params_rank4.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_bilinear_test_params_rank5.yaml b/test/01_contraction/configs/complex_bilinear_test_params_rank5.yaml index 46b23acf..bbe5a912 100644 --- a/test/01_contraction/configs/complex_bilinear_test_params_rank5.yaml +++ b/test/01_contraction/configs/complex_bilinear_test_params_rank5.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_bilinear_test_params_rank6.yaml b/test/01_contraction/configs/complex_bilinear_test_params_rank6.yaml index 6de2bbf8..11236f86 100644 --- a/test/01_contraction/configs/complex_bilinear_test_params_rank6.yaml +++ b/test/01_contraction/configs/complex_bilinear_test_params_rank6.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_scale_test_params_rank1.yaml b/test/01_contraction/configs/complex_scale_test_params_rank1.yaml index d43001fe..65563868 100644 --- a/test/01_contraction/configs/complex_scale_test_params_rank1.yaml +++ b/test/01_contraction/configs/complex_scale_test_params_rank1.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_scale_test_params_rank2.yaml b/test/01_contraction/configs/complex_scale_test_params_rank2.yaml index df352955..9031f322 100644 --- a/test/01_contraction/configs/complex_scale_test_params_rank2.yaml +++ b/test/01_contraction/configs/complex_scale_test_params_rank2.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_scale_test_params_rank3.yaml b/test/01_contraction/configs/complex_scale_test_params_rank3.yaml index 1d343e89..bef116ec 100644 --- a/test/01_contraction/configs/complex_scale_test_params_rank3.yaml +++ b/test/01_contraction/configs/complex_scale_test_params_rank3.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_scale_test_params_rank4.yaml b/test/01_contraction/configs/complex_scale_test_params_rank4.yaml index 56f1b7e7..1c334f4c 100644 --- a/test/01_contraction/configs/complex_scale_test_params_rank4.yaml +++ b/test/01_contraction/configs/complex_scale_test_params_rank4.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_scale_test_params_rank5.yaml b/test/01_contraction/configs/complex_scale_test_params_rank5.yaml index 8cc52fdf..7873ea1c 100644 --- a/test/01_contraction/configs/complex_scale_test_params_rank5.yaml +++ b/test/01_contraction/configs/complex_scale_test_params_rank5.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/complex_scale_test_params_rank6.yaml b/test/01_contraction/configs/complex_scale_test_params_rank6.yaml index 2edd8b82..fe25f401 100644 --- a/test/01_contraction/configs/complex_scale_test_params_rank6.yaml +++ b/test/01_contraction/configs/complex_scale_test_params_rank6.yaml @@ -6,6 +6,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/scale_test_params_rank1.yaml b/test/01_contraction/configs/scale_test_params_rank1.yaml index f33d66e4..5e91faa2 100644 --- a/test/01_contraction/configs/scale_test_params_rank1.yaml +++ b/test/01_contraction/configs/scale_test_params_rank1.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/scale_test_params_rank2.yaml b/test/01_contraction/configs/scale_test_params_rank2.yaml index 2a3bd827..747b3faa 100644 --- a/test/01_contraction/configs/scale_test_params_rank2.yaml +++ b/test/01_contraction/configs/scale_test_params_rank2.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/scale_test_params_rank3.yaml b/test/01_contraction/configs/scale_test_params_rank3.yaml index 00cfecce..e33db922 100644 --- a/test/01_contraction/configs/scale_test_params_rank3.yaml +++ b/test/01_contraction/configs/scale_test_params_rank3.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/scale_test_params_rank4.yaml b/test/01_contraction/configs/scale_test_params_rank4.yaml index 2953bec4..cca5ae2e 100644 --- a/test/01_contraction/configs/scale_test_params_rank4.yaml +++ b/test/01_contraction/configs/scale_test_params_rank4.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/scale_test_params_rank5.yaml b/test/01_contraction/configs/scale_test_params_rank5.yaml index f5345ff3..a87efa04 100644 --- a/test/01_contraction/configs/scale_test_params_rank5.yaml +++ b/test/01_contraction/configs/scale_test_params_rank5.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/configs/scale_test_params_rank6.yaml b/test/01_contraction/configs/scale_test_params_rank6.yaml index 31254753..3d6b17a1 100644 --- a/test/01_contraction/configs/scale_test_params_rank6.yaml +++ b/test/01_contraction/configs/scale_test_params_rank6.yaml @@ -11,6 +11,7 @@ Tensor Data Types: Algorithm Types: - HIPTENSOR_ALGO_DEFAULT - HIPTENSOR_ALGO_DEFAULT_PATIENT + - HIPTENSOR_ALGO_ACTOR_CRITIC Operators: - HIPTENSOR_OP_IDENTITY Worksize Prefs: diff --git a/test/01_contraction/contraction_test.cpp b/test/01_contraction/contraction_test.cpp index 08eac199..631e6cd4 100644 --- a/test/01_contraction/contraction_test.cpp +++ b/test/01_contraction/contraction_test.cpp @@ -193,13 +193,13 @@ namespace hiptensor } size_t elementsA = std::accumulate(a_ms_ks_lengths.begin(), - a_ms_ks_lengths.end(), - size_t{1}, - std::multiplies()); + a_ms_ks_lengths.end(), + size_t{1}, + std::multiplies()); size_t elementsB = std::accumulate(b_ns_ks_lengths.begin(), - b_ns_ks_lengths.end(), - size_t{1}, - std::multiplies()); + b_ns_ks_lengths.end(), + size_t{1}, + std::multiplies()); size_t elementsCD = std::accumulate(cd_ms_ns_lengths.begin(), cd_ms_ns_lengths.end(), size_t{1}, @@ -262,11 +262,13 @@ namespace hiptensor if(ADataType == HIP_R_16F && BDataType == HIP_R_16F && DDataType == HIP_R_16F) { // Initialize matrix data on device - fillLaunchKernel<_Float16>((_Float16*)resource->deviceA().get(), elementsA, seed - 1); + fillLaunchKernel<_Float16>( + (_Float16*)resource->deviceA().get(), elementsA, seed - 1); fillLaunchKernel<_Float16>((_Float16*)resource->deviceB().get(), elementsB, seed); if(CDataType == HIP_R_16F) { - fillLaunchKernel<_Float16>((_Float16*)resource->deviceC().get(), elementsCD, seed + 1); + fillLaunchKernel<_Float16>( + (_Float16*)resource->deviceC().get(), elementsCD, seed + 1); } fillValLaunchKernel<_Float16>((_Float16*)resource->deviceD().get(), elementsCD, @@ -275,12 +277,14 @@ namespace hiptensor else if(ADataType == HIP_R_16BF && BDataType == HIP_R_16BF && DDataType == HIP_R_16BF) { // Initialize matrix data on device - fillLaunchKernel((hip_bfloat16*)resource->deviceA().get(), elementsA, seed - 1); - fillLaunchKernel((hip_bfloat16*)resource->deviceB().get(), elementsB, seed); + fillLaunchKernel( + (hip_bfloat16*)resource->deviceA().get(), elementsA, seed - 1); + fillLaunchKernel( + (hip_bfloat16*)resource->deviceB().get(), elementsB, seed); if(CDataType == HIP_R_16BF) { - fillLaunchKernel((hip_bfloat16*)resource->deviceC().get(), - elementsCD, seed + 1); + fillLaunchKernel( + (hip_bfloat16*)resource->deviceC().get(), elementsCD, seed + 1); } fillValLaunchKernel( (hip_bfloat16*)resource->deviceD().get(), @@ -294,7 +298,8 @@ namespace hiptensor fillLaunchKernel((float*)resource->deviceB().get(), elementsB, seed); if(CDataType == HIP_R_32F) { - fillLaunchKernel((float*)resource->deviceC().get(), elementsCD, seed + 1); + fillLaunchKernel( + (float*)resource->deviceC().get(), elementsCD, seed + 1); } fillValLaunchKernel((float*)resource->deviceD().get(), elementsCD, @@ -307,7 +312,8 @@ namespace hiptensor fillLaunchKernel((double*)resource->deviceB().get(), elementsB, seed); if(CDataType == HIP_R_64F) { - fillLaunchKernel((double*)resource->deviceC().get(), elementsCD, seed + 1); + fillLaunchKernel( + (double*)resource->deviceC().get(), elementsCD, seed + 1); } fillValLaunchKernel((double*)resource->deviceD().get(), elementsCD, @@ -316,14 +322,14 @@ namespace hiptensor else if(ADataType == HIP_C_32F && BDataType == HIP_C_32F && DDataType == HIP_C_32F) { // Initialize matrix data on device - fillLaunchKernel((hipFloatComplex*)resource->deviceA().get(), - elementsA, seed - 1); - fillLaunchKernel((hipFloatComplex*)resource->deviceB().get(), - elementsB, seed); + fillLaunchKernel( + (hipFloatComplex*)resource->deviceA().get(), elementsA, seed - 1); + fillLaunchKernel( + (hipFloatComplex*)resource->deviceB().get(), elementsB, seed); if(CDataType == HIP_C_32F) { - fillLaunchKernel((hipFloatComplex*)resource->deviceC().get(), - elementsCD, seed + 1); + fillLaunchKernel( + (hipFloatComplex*)resource->deviceC().get(), elementsCD, seed + 1); } fillValLaunchKernel( (hipFloatComplex*)resource->deviceD().get(), @@ -333,14 +339,14 @@ namespace hiptensor else if(ADataType == HIP_C_64F && BDataType == HIP_C_64F && DDataType == HIP_C_64F) { // Initialize matrix data on device - fillLaunchKernel((hipDoubleComplex*)resource->deviceA().get(), - elementsA, seed - 1); - fillLaunchKernel((hipDoubleComplex*)resource->deviceB().get(), - elementsB, seed); + fillLaunchKernel( + (hipDoubleComplex*)resource->deviceA().get(), elementsA, seed - 1); + fillLaunchKernel( + (hipDoubleComplex*)resource->deviceB().get(), elementsB, seed); if(CDataType == HIP_C_64F) { - fillLaunchKernel((hipDoubleComplex*)resource->deviceC().get(), - elementsCD, seed + 1); + fillLaunchKernel( + (hipDoubleComplex*)resource->deviceC().get(), elementsCD, seed + 1); } fillValLaunchKernel( (hipDoubleComplex*)resource->deviceD().get(), @@ -404,16 +410,24 @@ namespace hiptensor } } - void ContractionTest::reportResults(std::ostream& stream, - hipDataType DDataType, - bool omitSkipped, - bool omitFailed, - bool omitPassed) const + void ContractionTest::reportResults(std::ostream& stream, + hipDataType DDataType, + hiptensorComputeType_t computeType, + bool omitSkipped, + bool omitFailed, + bool omitPassed) const { // Conditionally print outputs if((mRunFlag || !omitSkipped) && (mValidationResult || !omitFailed) && (!mValidationResult || !omitPassed)) { + if(mPrintTypes) + { + ContractionTest::sAPILogBuff + << "TypeA/B/C/D: " << hipTypeToString(DDataType) + << ", ComputeType: " << computeTypeToString(computeType) << std::endl; + } + stream << ContractionTest::sAPILogBuff.str(); if(mPrintElements) @@ -423,13 +437,13 @@ namespace hiptensor int size = hipDataTypeSize(DDataType); size_t elementsA = std::accumulate(a_ms_ks.mLengths.begin(), - a_ms_ks.mLengths.end(), - size_t{1}, - std::multiplies()); + a_ms_ks.mLengths.end(), + size_t{1}, + std::multiplies()); size_t elementsB = std::accumulate(b_ns_ks.mLengths.begin(), - b_ns_ks.mLengths.end(), - size_t{1}, - std::multiplies()); + b_ns_ks.mLengths.end(), + size_t{1}, + std::multiplies()); size_t elementsCD = std::accumulate(d_ms_ns.mLengths.begin(), d_ms_ns.mLengths.end(), size_t{1}, @@ -438,6 +452,8 @@ namespace hiptensor auto D = resource->allocHost(elementsCD * size); resource->copyData(D, resource->deviceD(), elementsCD * size); + auto& references = resource->hostD(); + if(DDataType == HIP_R_16F) { stream << "Tensor A elements:\n"; @@ -458,6 +474,11 @@ namespace hiptensor stream << "Tensor D elements:\n"; hiptensorPrintArrayElements<_Float16>(stream, (_Float16*)D.get(), elementsCD); stream << std::endl; + + stream << "Tensor reference elements:\n"; + hiptensorPrintArrayElements<_Float16>( + stream, (_Float16*)references.get(), elementsCD); + stream << std::endl; } else if(DDataType == HIP_R_16BF) { @@ -480,6 +501,11 @@ namespace hiptensor hiptensorPrintArrayElements( stream, (hip_bfloat16*)D.get(), elementsCD); stream << std::endl; + + stream << "Tensor reference elements:\n"; + hiptensorPrintArrayElements( + stream, (hip_bfloat16*)references.get(), elementsCD); + stream << std::endl; } else if(DDataType == HIP_R_32F) { @@ -501,6 +527,11 @@ namespace hiptensor stream << "Tensor D elements:\n"; hiptensorPrintArrayElements(stream, (float*)D.get(), elementsCD); stream << std::endl; + + stream << "Tensor reference elements:\n"; + hiptensorPrintArrayElements( + stream, (float*)references.get(), elementsCD); + stream << std::endl; } else if(DDataType == HIP_R_64F) { @@ -522,6 +553,11 @@ namespace hiptensor stream << "Tensor D elements:\n"; hiptensorPrintArrayElements(stream, (double*)D.get(), elementsCD); stream << std::endl; + + stream << "Tensor reference elements:\n"; + hiptensorPrintArrayElements( + stream, (double*)references.get(), elementsCD); + stream << std::endl; } else if(DDataType == HIP_C_32F) { @@ -544,6 +580,11 @@ namespace hiptensor hiptensorPrintArrayElements( stream, (hipFloatComplex*)D.get(), elementsCD); stream << std::endl; + + stream << "Tensor reference elements:\n"; + hiptensorPrintArrayElements( + stream, (hipFloatComplex*)references.get(), elementsCD); + stream << std::endl; } else if(DDataType == HIP_C_64F) { @@ -566,6 +607,11 @@ namespace hiptensor hiptensorPrintArrayElements( stream, (hipDoubleComplex*)D.get(), elementsCD); stream << std::endl; + + stream << "Tensor reference elements:\n"; + hiptensorPrintArrayElements( + stream, (hipDoubleComplex*)references.get(), elementsCD); + stream << std::endl; } } } @@ -660,22 +706,22 @@ namespace hiptensor // Compute tolerance based on compute type auto dimension = a_ms_ks.mLengths.size() / 2; - auto nelems_k = std::accumulate(a_ms_ks.mLengths.begin() + dimension, + auto nelems_k = std::accumulate(a_ms_ks.mLengths.begin() + dimension, a_ms_ks.mLengths.end(), size_t{1}, std::multiplies()); - auto eps = getEpsilon(computeType == HIPTENSOR_COMPUTE_64F ? HIPTENSOR_COMPUTE_64F - : HIPTENSOR_COMPUTE_32F); + auto eps = getEpsilon(computeType == HIPTENSOR_COMPUTE_64F ? HIPTENSOR_COMPUTE_64F + : HIPTENSOR_COMPUTE_32F); double tolerance = 2 * nelems_k * eps; // use the same default tolerance value as CK - if (computeType == HIPTENSOR_COMPUTE_16BF || DDataType == HIP_R_16BF) + if(computeType == HIPTENSOR_COMPUTE_16BF || DDataType == HIP_R_16BF) { const double epsilon = std::pow(2, -7); tolerance += epsilon * 2; } - else if (computeType == HIPTENSOR_COMPUTE_16F || DDataType == HIP_R_16F) + else if(computeType == HIPTENSOR_COMPUTE_16F || DDataType == HIP_R_16F) { const double epsilon = std::pow(2, -10); tolerance += epsilon * 2; @@ -728,6 +774,7 @@ namespace hiptensor { reportResults(std::cout, DDataType, + computeType, loggingOptions->omitSkipped(), loggingOptions->omitFailed(), loggingOptions->omitPassed()); @@ -737,6 +784,7 @@ namespace hiptensor { reportResults(loggingOptions->ostream().fstream(), DDataType, + computeType, loggingOptions->omitSkipped(), loggingOptions->omitFailed(), loggingOptions->omitPassed()); diff --git a/test/01_contraction/contraction_test.hpp b/test/01_contraction/contraction_test.hpp index 2b50d220..f0fb93c8 100644 --- a/test/01_contraction/contraction_test.hpp +++ b/test/01_contraction/contraction_test.hpp @@ -95,11 +95,12 @@ namespace hiptensor void Warmup() {} void RunKernel(); - void reportResults(std::ostream& stream, - hipDataType DDataType, - bool omitSkipped, - bool omitFailed, - bool omitPassed) const; + void reportResults(std::ostream& stream, + hipDataType DDataType, + hiptensorComputeType_t computeType, + bool omitSkipped, + bool omitFailed, + bool omitPassed) const; protected: // Workspace items @@ -117,6 +118,7 @@ namespace hiptensor bool mRunFlag = true; bool mValidationResult = false; bool mPrintElements = false; + bool mPrintTypes = false; double mMaxRelativeError; // Output buffer