Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contraction f16, bf16, f32_f16, f32_bf16, f64_f32 #158

Merged
merged 12 commits into from
Dec 11, 2023
1 change: 1 addition & 0 deletions library/include/hiptensor/internal/hiptensor_utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <iostream>

#include "../hiptensor_types.hpp"
#include "types_ext.hpp"

#ifndef CHECK_HIP_ERROR
#define CHECK_HIP_ERROR(expression) \
Expand Down
48 changes: 25 additions & 23 deletions library/src/contraction/contraction_cpu_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,33 @@
#include "contraction_cpu_reference_impl.hpp"
#include "contraction_cpu_reference_instances.hpp"

hiptensorStatus_t hiptensorContractionReference(void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ns_ks_lengths,
std::vector<size_t> const& b_ns_ks_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace)
hiptensorStatus_t hiptensorContractionReference(const hiptensorContractionPlan_t* plan,
void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ns_ks_lengths,
std::vector<size_t> const& b_ns_ks_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace)
{
auto& instances = hiptensor::ContractionCpuReferenceInstances::instance();
auto& instances = hiptensor::ContractionCpuReferenceInstances::instance();
auto computeType = plan->mContractionDesc.mComputeType;
auto candidates
= (C == nullptr)
? instances->allSolutions().query(typeA, typeB, hiptensor::NONE_TYPE, typeD)
: instances->allSolutions().query(typeA, typeB, typeC, typeD);
= (C == nullptr) ? instances->allSolutions().query(
typeA, typeB, hiptensor::NONE_TYPE, typeD, computeType)
: instances->allSolutions().query(typeA, typeB, typeC, typeD, computeType);

auto toCKVec
= [](auto& inputVec) { return std::vector<ck::index_t>(inputVec.begin(), inputVec.end()); };
Expand Down
39 changes: 20 additions & 19 deletions library/src/contraction/contraction_cpu_reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,25 @@

#include <hiptensor/hiptensor.hpp>

hiptensorStatus_t hiptensorContractionReference(void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ks_ns_lengths,
std::vector<size_t> const& b_ks_ns_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace);
hiptensorStatus_t hiptensorContractionReference(const hiptensorContractionPlan_t* plan,
void const* alpha,
void const* A,
void const* B,
void const* beta,
void const* C,
void* D,
std::vector<size_t> const& a_ms_ks_lengths,
std::vector<size_t> const& a_ms_ks_strides,
std::vector<size_t> const& b_ks_ns_lengths,
std::vector<size_t> const& b_ks_ns_strides,
std::vector<size_t> const& c_ms_ns_lengths,
std::vector<size_t> const& c_ms_ns_strides,
std::vector<size_t> const& d_ms_ns_lengths,
std::vector<size_t> const& d_ms_ns_strides,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
void* workspace);

#endif // HIPTENSOR_CONTRACTION_CPU_REFERENCE_HPP
60 changes: 39 additions & 21 deletions library/src/contraction/contraction_cpu_reference_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,25 @@
namespace hiptensor
{
// hardcoded for NumDimM == NumDimN == NumDimK == 2
//
// ck::bhalf_t is a ushort, so bhalf_t * bhalf_t cannot be performed.
// CK does not use ck::bhalf_t as AccDataType, but we still add this
// guard here to reject it explicitly.
template <
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2 && DsDataType::Size() <= 1,
typename ComputeDataType = ADataType,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2 && DsDataType::Size() <= 1
&& !std::is_same_v<AccDataType, ck::bhalf_t>,
bool>
= false>
struct ReferenceContraction_M2_N2_K2
Expand All @@ -70,7 +76,8 @@ namespace hiptensor
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
Expand Down Expand Up @@ -150,7 +157,7 @@ namespace hiptensor
};

auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
auto accum = static_cast<AccDataType>(0);
AccDataType accum = 0;

auto K0 = arg.mA_ms_ks_lengths[2];
auto K1 = arg.mA_ms_ks_lengths[3];
Expand All @@ -164,16 +171,19 @@ namespace hiptensor
auto indexB
= offset(std::vector<size_t>{n0, n1, k0, k1}, arg.mB_ns_ks_strides);

ADataType valA;
BDataType valB;
AccDataType valA;
AccDataType valB;

// Element-wise ops
arg.mOpA(valA, ((ADataType*)arg.mA)[indexA]);
arg.mOpB(valB, ((BDataType*)arg.mB)[indexB]);
arg.mOpA(
valA,
ck::type_convert<ComputeDataType>(((ADataType*)arg.mA)[indexA]));
arg.mOpB(
valB,
ck::type_convert<ComputeDataType>(((BDataType*)arg.mB)[indexB]));

// Mult / accum
accum
+= static_cast<AccDataType>(valA) * static_cast<AccDataType>(valB);
accum += valA * valB;
}
}

Expand All @@ -182,15 +192,17 @@ namespace hiptensor
if constexpr(std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::Scale>)
{
arg.mOpCDE(((EDataType*)arg.mE)[indexE], accum);
arg.mOpCDE(((EDataType*)arg.mE)[indexE],
ck::type_convert<EDataType>(accum));
}
else // bilinear
{
// NumDTensor will be 1 due to SFINAE of this class
auto indexD
= offset(std::vector<size_t>{m0, m1, n0, n1}, arg.mD_ms_ns_strides[0]);
arg.mOpCDE(
((EDataType*)arg.mE)[indexE], accum, ((EDataType*)(arg.mD[0]))[indexD]);
arg.mOpCDE(((EDataType*)arg.mE)[indexE],
ck::type_convert<EDataType>(accum),
((EDataType*)(arg.mD[0]))[indexD]);
}
};

Expand Down Expand Up @@ -319,23 +331,25 @@ namespace hiptensor
ck::index_t NumDimsK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AccumDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType>
struct MetaTraits<ReferenceContraction_M2_N2_K2<NumDimsM,
NumDimsN,
NumDimsK,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
AccumDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>
CDEElementwiseOperation,
ComputeDataType>>
: public MetaTraits<
ck::tensor_operation::device::DeviceContractionMultipleD<NumDimsM,
NumDimsN,
Expand All @@ -346,7 +360,8 @@ namespace hiptensor
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>
CDEElementwiseOperation,
ComputeDataType>>
{
};

Expand All @@ -355,24 +370,27 @@ namespace hiptensor
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
auto enumerateReferenceSolutions()
{
using ReferenceOp = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
CDEElementwiseOperation,
ComputeDataType>;

auto solution = std::make_unique<ContractionSolutionImpl<ReferenceOp>>(
std::make_unique<ReferenceOp>());
Expand Down
Loading