Change TensorShape to typically not allocate heap memory #9542

Merged · 31 commits · Nov 8, 2021

Commits
59f3045
TensorShape no longer uses std::vector
RyanUnderhill Oct 26, 2021
c9f916a
Merge branch 'master' of https://github.com/Microsoft/onnxruntime int…
RyanUnderhill Oct 26, 2021
36d9f6a
Clang Format
RyanUnderhill Oct 26, 2021
48ab110
Remove comment
RyanUnderhill Oct 26, 2021
a9f2330
Use std::size instead of _countof
RyanUnderhill Oct 26, 2021
9717adb
Use std::size vs _countof
RyanUnderhill Oct 26, 2021
06095c3
Fix some build breaks
RyanUnderhill Oct 26, 2021
f7e6717
Switch over Training Code
RyanUnderhill Oct 26, 2021
554925b
Add unit test
RyanUnderhill Oct 27, 2021
7dfff97
Fix breaks
RyanUnderhill Oct 27, 2021
b1352a5
Build Breaks
RyanUnderhill Oct 27, 2021
1b258ee
Build Fixes
RyanUnderhill Oct 27, 2021
e7cc62d
Build Fixes
RyanUnderhill Oct 27, 2021
8cace97
Build Fixes
RyanUnderhill Oct 27, 2021
6828641
Build Fixes
RyanUnderhill Oct 27, 2021
fe303ea
Build Fixes
RyanUnderhill Oct 27, 2021
766d5f6
Bring back code for existing buffers, plus code review feedback.
RyanUnderhill Oct 28, 2021
dfdadd6
Code review feedback
RyanUnderhill Oct 29, 2021
453b725
Merge branch 'master' of https://github.com/Microsoft/onnxruntime int…
RyanUnderhill Oct 29, 2021
c411d89
Merge with master
RyanUnderhill Oct 29, 2021
93ea6f9
Revert back gsl::span change to api.h, have to keep allocating vector…
RyanUnderhill Oct 29, 2021
0c7a774
Revert parts of transpose_optimizer.cc
RyanUnderhill Oct 29, 2021
8ceef80
Fix rocm
RyanUnderhill Oct 29, 2021
c16631d
Fix rocm
RyanUnderhill Oct 29, 2021
58aaec8
Merge branch 'master' of https://github.com/Microsoft/onnxruntime int…
RyanUnderhill Nov 4, 2021
622de71
Fix ROCM break due to master merge
RyanUnderhill Nov 5, 2021
19d6b07
Merge branch 'master' of https://github.com/Microsoft/onnxruntime int…
RyanUnderhill Nov 5, 2021
d22e227
Bump minimal build size up
RyanUnderhill Nov 5, 2021
7d4efe6
Code review feedback
RyanUnderhill Nov 5, 2021
3524689
Merge with master
RyanUnderhill Nov 5, 2021
d7e0e64
Increase minimal build size by 1608 bytes
RyanUnderhill Nov 5, 2021
79 changes: 37 additions & 42 deletions include/onnxruntime/core/framework/tensor_shape.h
@@ -7,6 +7,7 @@
#include <algorithm>
#include <string>
#include <cstring>
#include <gsl/gsl>
#include "onnxruntime_config.h"

namespace onnxruntime {
@@ -16,60 +17,45 @@ namespace onnxruntime {
#pragma GCC diagnostic ignored "-Wnull-dereference"
#endif
#endif
class TensorShape : private std::vector<int64_t> {
// TODO - Use a custom STL allocator to avoid heap allocations in the common case.
class TensorShape {
// We use negative numbers for unknown symbolic dimension. Each negative
// number represents a unique symbolic dimension.
// Private inheritance is used to prevent ambiguity of element versus dimension size
public:
TensorShape() = default;

TensorShape(const TensorShape& /*other*/) = default;
TensorShape& operator=(const TensorShape& /*other*/) = default;
TensorShape(const TensorShape& other) : TensorShape(other.GetDims()) {}
TensorShape& operator=(const TensorShape& other);

TensorShape(TensorShape&& /*other*/) = default;
TensorShape& operator=(TensorShape&& /*other*/) = default;
TensorShape(TensorShape&& other) { operator=(std::move(other)); }
TensorShape& operator=(TensorShape&& other);

TensorShape(const std::vector<int64_t>& dims) : std::vector<int64_t>(dims) {}
TensorShape(gsl::span<const int64_t> dims);
TensorShape(const std::vector<int64_t>& dims) : TensorShape(gsl::make_span(dims)) {}
TensorShape(const std::initializer_list<int64_t>& dims) : TensorShape(gsl::make_span(dims)) {}
TensorShape(const int64_t* dimension_sizes, size_t dimension_count) : TensorShape(gsl::span<const int64_t>(dimension_sizes, dimension_count)) {}
TensorShape(const std::vector<int64_t>& dims, size_t start, size_t end) : TensorShape(gsl::span<const int64_t>(&dims[start], end - start)) {}

TensorShape(std::vector<int64_t>&& dims) : std::vector<int64_t>(std::move(dims)) {}

TensorShape(const std::initializer_list<int64_t>& dims) : std::vector<int64_t>(dims) {}

TensorShape(const int64_t* dimension_sizes, size_t dimension_count);

TensorShape(const std::vector<int64_t>& dims, size_t start, size_t end);
// Create a TensorShape that points to an existing buffer internally. As no copy is made, 'data' must remain valid for the life of the TensorShape
static const TensorShape FromExistingBuffer(const std::vector<int64_t>& data) { return TensorShape(External{}, gsl::span<int64_t>(const_cast<int64_t*>(data.data()), data.size())); }
Contributor: nit: can use a new line for the function body

Member Author: Changed, hopefully it'll pass the build tests now... hitting random breaks unrelated to my change.


/**
Return the dimension specified by <idx>.
*/
const int64_t& operator[](size_t idx) const {
return std::vector<int64_t>::operator[](static_cast<int>(idx));
}

int64_t& operator[](size_t idx) {
return std::vector<int64_t>::operator[](static_cast<int>(idx));
}

bool operator==(const TensorShape& other) const noexcept {
auto thisVector = static_cast<const std::vector<int64_t>*>(this);
auto otherVector = static_cast<const std::vector<int64_t>*>(&other);
return *thisVector == *otherVector;
}
int64_t operator[](size_t idx) const { return values_[idx]; }
int64_t& operator[](size_t idx) { return values_[idx]; }

bool operator!=(const TensorShape& other) const noexcept {
return !(*this == other);
}
bool operator==(const TensorShape& other) const noexcept { return GetDims() == other.GetDims(); }
bool operator!=(const TensorShape& other) const noexcept { return GetDims() != other.GetDims(); }

size_t NumDimensions() const noexcept {
return size();
return values_.size();
}

/**
Copy dims into an array with given size
*/
void CopyDims(int64_t* dims, size_t num_dims) const {
memcpy(dims, data(), sizeof(value_type) * std::min(num_dims, NumDimensions()));
memcpy(dims, values_.begin(), sizeof(int64_t) * std::min(num_dims, NumDimensions()));
}

/**
@@ -78,13 +64,14 @@ class TensorShape : private std::vector<int64_t> {
and this function does no checks to ensure that
*/
void CopyDims(int64_t* dims, size_t start_dim, size_t num_dims) const {
memcpy(dims, data() + start_dim, sizeof(value_type) * std::min(num_dims, NumDimensions() - start_dim));
memcpy(dims, values_.begin() + start_dim, sizeof(int64_t) * std::min(num_dims, NumDimensions() - start_dim));
}

/**
Return underlying vector representation.
*/
const std::vector<int64_t>& GetDims() const { return *this; }
gsl::span<const int64_t> GetDims() const { return values_; }
std::vector<int64_t> GetDimsAsVector() const { return std::vector<int64_t>(values_.begin(), values_.end()); }

/**
* Return the total number of elements. Returns 1 for an empty (rank 0) TensorShape.
@@ -116,7 +103,7 @@
/**
Return a new TensorShape of the dimensions from dimstart to end.
*/
TensorShape Slice(size_t dimstart) const;
TensorShape Slice(size_t dimstart) const { return Slice(dimstart, values_.size()); }

/**
output dimensions nicely formatted
@@ -134,14 +121,22 @@ class TensorShape : private std::vector<int64_t> {
empty shape or 1D shape (1) is regarded as scalar tensor
*/
bool IsScalar() const {
size_t len = size();
return len == 0 || (len == 1 && operator[](0) == 1);
size_t len = values_.size();
return len == 0 || (len == 1 && values_[0] == 1);
}

static const TensorShape& ReinterpretBaseType(const std::vector<int64_t>& dimensions) {
static_assert(sizeof(TensorShape) == sizeof(std::vector<int64_t>), "Size of TensorShape prevents safe casting from vector");
return *static_cast<const TensorShape*>(&dimensions);
}
private:

struct External {};
TensorShape(External, gsl::span<int64_t> buffer) : values_{buffer} {}

void Allocate(size_t size);

gsl::span<int64_t> values_;
int64_t small_buffer_[5];
std::unique_ptr<int64_t[]> allocated_buffer_;
Member: I feel that we could use std::pmr::vector with a pool allocator sitting on top of the buffer, with upstream new_delete. That way we perfectly simulate small value optimization, with the first bytes taken from the inline buffer and bigger buffers coming from the heap. This way we would get rid of the unique_ptr and all the logic associated with it.

Member: Furthermore, we could abstract all the logic of manipulating mutable shapes inside TensorShape based on pmr and enjoy correctness and the absence of heap allocation.

Member Author: We talked about it outside of GitHub; here's a summary:

std::pmr::vector would be useful in all of the code that currently uses TensorShape::GetDimsAsVector, to take advantage of a small-size-optimized vector. As TensorShape isn't a vector (no runtime adding/removing of elements), it isn't needed there (and would take up more space).

Sadly there's no way to remove the unique_ptr member, as without it there is no way to know whether the memory is the small block, allocated memory, or external memory from TensorShape::FromExistingBuffer.
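For orientation, here is a minimal sketch of the storage scheme this diff introduces. It is a hypothetical simplification (the member names and the rank-5 inline buffer mirror the diff above; everything else, including the class name SmallShape, is assumed), not the actual onnxruntime implementation: a gsl::span points at a small inline buffer in the common case, at a heap allocation only for higher-rank shapes, and at caller-owned memory for FromExistingBuffer.

```cpp
// Hypothetical, simplified sketch of the small-buffer scheme -- not the real class.
#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <gsl/gsl>

class SmallShape {
 public:
  SmallShape() = default;

  explicit SmallShape(gsl::span<const int64_t> dims) {
    Allocate(dims.size());
    if (!dims.empty())
      std::memcpy(values_.data(), dims.data(), dims.size_bytes());
  }

  // Copies re-materialize the dims into the new object's own storage.
  SmallShape(const SmallShape& other) : SmallShape(other.GetDims()) {}
  SmallShape& operator=(const SmallShape&) = delete;  // omitted for brevity

  // No copy is made: 'data' must outlive the returned object.
  static SmallShape FromExistingBuffer(gsl::span<int64_t> data) {
    return SmallShape(External{}, data);  // prvalue return (C++17), so no copy runs
  }

  gsl::span<const int64_t> GetDims() const { return values_; }

 private:
  struct External {};
  SmallShape(External, gsl::span<int64_t> buffer) : values_{buffer} {}

  void Allocate(std::size_t size) {
    if (size <= std::size(small_buffer_)) {
      values_ = gsl::span<int64_t>(small_buffer_, size);  // common case: no heap
    } else {
      allocated_buffer_ = std::make_unique<int64_t[]>(size);  // rare case: heap
      values_ = gsl::span<int64_t>(allocated_buffer_.get(), size);
    }
  }

  gsl::span<int64_t> values_;                    // points at whichever buffer is active
  int64_t small_buffer_[5];                      // inline storage for rank <= 5
  std::unique_ptr<int64_t[]> allocated_buffer_;  // set only when rank > 5
};
```

As in the reply above, the unique_ptr in this sketch both owns the heap block and helps distinguish the heap case from the inline and external cases.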


friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate
};
#ifdef __GNUC__
#pragma GCC diagnostic pop
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -423,7 +423,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
past,
extra_add_qk));

const auto& shape = input->Shape().GetDims();
const auto shape = input->Shape().GetDims();
const int batch_size = static_cast<int>(shape[0]);
const int sequence_length = static_cast<int>(shape[1]);
const int input_hidden_size = static_cast<int>(shape[2]);
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -61,7 +61,7 @@ class AttentionCPUBase : public AttentionBase {
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));

const int32_t* mask_index_data = mask_index != nullptr ? mask_index->template Data<int32_t>() : nullptr;
const std::vector<int64_t>* mask_index_dims = mask_index != nullptr ? &(mask_index->Shape().GetDims()) : nullptr;
gsl::span<const int64_t> mask_index_dims = mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span<const int64_t>{};
const T* past_data = past != nullptr ? past->template Data<T>() : nullptr;
T* present_data = present != nullptr ? present->template MutableData<T>() : nullptr;

@@ -97,7 +97,7 @@
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxSxH
const int32_t* mask_index, // mask index. nullptr if no mask or its size is B
const std::vector<int64_t>* mask_index_dims, // mask index shape
gsl::span<const int64_t> mask_index_dims, // mask index shape
T* mask_data, // buffer for mask data. It is nullptr if mask_index is nullptr and not unidirectional, otherwise its shape is BxSxS*
bool has_unidirectional, // has unidirectional mask
int batch_size, // batch size of self-attention
10 changes: 5 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -62,7 +62,7 @@ inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPoo

template <typename T>
void PrepareMask(const int32_t* mask_index,
const std::vector<int64_t>* mask_index_dims,
gsl::span<const int64_t> mask_index_dims,
T* mask_data,
bool is_unidirectional,
int batch_size,
@@ -74,12 +74,12 @@ void PrepareMask(const int32_t* mask_index,
T* p_mask = mask_data;

// 4D mask in Megatron GPT2 is currently not support in CPU kernel
if (nullptr != mask_index_dims && mask_index_dims->size() == 4) {
if (nullptr != mask_index && mask_index_dims.size() == 4) {
ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported");
}

// For 3D mask, convert values 0 to -10000.0, and 1 to 0.0, then apply unidirectional mask if any.
if (nullptr != mask_index_dims && mask_index_dims->size() == 3) {
if (nullptr != mask_index && mask_index_dims.size() == 3) {
for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(-10000.0f);
}
@@ -98,8 +98,8 @@ void PrepareMask(const int32_t* mask_index,
return;
}

bool is_raw_attention_mask = (nullptr != mask_index_dims && mask_index_dims->size() == 2);
bool has_mask_start_position = (nullptr != mask_index_dims && mask_index_dims->size() == 1 && static_cast<int>(mask_index_dims->at(0)) == 2 * batch_size);
bool is_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() == 2);
bool has_mask_start_position = (nullptr != mask_index && mask_index_dims.size() == 1 && static_cast<int>(mask_index_dims.at(0)) == 2 * batch_size);

for (int b_i = 0; b_i < batch_size; b_i++) {
// TODO: mask_index can be used in softmax to save some calculation.
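Stepping back from this hunk: across the attention code in this PR, a nullable `const std::vector<int64_t>*` parameter becomes a `gsl::span<const int64_t>` that is simply empty when there is no mask, and the null checks move to the raw `mask_index` pointer. A hedged sketch of the pattern (hypothetical function, not one of the kernels above):

```cpp
#include <cstdint>
#include <gsl/gsl>

// Sketch only: shows the new calling convention, not a real attention kernel.
void ConsumeMask(const int32_t* mask_index, gsl::span<const int64_t> mask_index_dims) {
  // Before: nullptr != mask_index_dims && mask_index_dims->size() == 2
  // After: "no mask" is an empty span, so no pointer dereference is needed.
  const bool is_raw_attention_mask = (mask_index != nullptr && mask_index_dims.size() == 2);
  const bool has_1d_index = (mask_index != nullptr && mask_index_dims.size() == 1);
  (void)is_raw_attention_mask;
  (void)has_1d_index;
}

// Call site, mirroring attention_cpu_base.h in this diff:
//   gsl::span<const int64_t> mask_index_dims =
//       mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span<const int64_t>{};
```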
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/expand_dims.h
@@ -30,7 +30,7 @@ class ExpandDims final : public OpKernel {
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const TensorShape& X_shape = X->Shape();

std::vector<int64_t> expanded_shape(X_shape.GetDims());
std::vector<int64_t> expanded_shape(X_shape.GetDimsAsVector());

Review comment: Allocation and copy into heap buffer, the thing this PR is trying to avoid.

Member Author: Yeah, to fix further things we'll need a small-block-optimized vector that allows modification, as a few lines later this code exists:

  expanded_shape.insert(expanded_shape.begin() + axis, 1);

We'd need to replace all instances of code like this with a replacement for std::vector that doesn't allocate for small sizes. This would be easy to do incrementally in a later change (by searching for all of the GetDimsAsVector() calls).
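A short usage sketch of the distinction discussed here (hypothetical call sites, assuming only the accessors shown in this diff and the usual include path): GetDims() returns a non-owning gsl::span for read-only access, while GetDimsAsVector() pays for a heap copy when the dimensions need to be mutated, as in the insert above.

```cpp
#include <cstdint>
#include <vector>
#include <gsl/gsl>
#include "core/framework/tensor_shape.h"  // assumed include path

using onnxruntime::TensorShape;

// Read-only consumers can stay on the span and avoid any allocation.
int64_t FirstDimOrOne(const TensorShape& shape) {
  gsl::span<const int64_t> dims = shape.GetDims();  // non-owning view, no copy
  return dims.empty() ? 1 : dims[0];
}

// Mutation (as in ExpandDims above) still needs a real vector, hence GetDimsAsVector().
TensorShape ExpandAtAxis(const TensorShape& shape, int64_t axis) {
  std::vector<int64_t> expanded = shape.GetDimsAsVector();  // heap copy, mutable
  expanded.insert(expanded.begin() + axis, 1);              // assumes 0 <= axis <= rank
  return TensorShape(expanded);                             // TensorShape(const std::vector<int64_t>&)
}
```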

int64_t X_NumDims = X_shape.Size();
ORT_ENFORCE(axis <= X_NumDims && axis >= -X_NumDims,
"Axis must be within range [", -X_NumDims, ", ", X_NumDims, "].", " Axis is ", axis);
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/qlinear_global_average_pool.cc
@@ -81,7 +81,7 @@ Status QLinearGlobalAveragePool::Compute(OpKernelContext* context) const {
int64_t image_size = std::accumulate(x_shape.cbegin() + spatial_dim_start, x_shape.cbegin() + spatial_dim_end,
1LL, std::multiplies<int64_t>());

std::vector<int64_t> output_dims(x_shape);
std::vector<int64_t> output_dims(x_shape.begin(), x_shape.end());
std::transform(x_shape.cbegin() + spatial_dim_start, x_shape.cbegin() + spatial_dim_end,
output_dims.begin() + spatial_dim_start, [](const int64_t&) { return int64_t{1}; });
Tensor& Y = *context->Output(0, output_dims);
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/qlinear_pool.cc
@@ -541,7 +541,7 @@ Status QLinearAveragePool::Compute(OpKernelContext* context) const {
std::vector<int64_t> kernel_shape = pool_attrs_.kernel_shape;

if (channels_last_) {
std::vector<int64_t> x_dims = x_shape.GetDims();
std::vector<int64_t> x_dims = x_shape.GetDimsAsVector();
SwitchDimsNchwNhwc(x_dims, false);
x_shape = TensorShape(x_dims);
}
20 changes: 10 additions & 10 deletions onnxruntime/contrib_ops/cpu/tokenizer.cc
@@ -21,14 +21,14 @@ class Tokenizer final : public OpKernel {

private:
Status CharTokenize(OpKernelContext* context, size_t N, size_t C,
const std::vector<int64_t>& input_dims) const;
gsl::span<const int64_t> input_dims) const;

Status SeparatorExpressionTokenizer(OpKernelContext* context, size_t N, size_t C,
const std::vector<int64_t>& input_dims) const;
gsl::span<const int64_t> input_dims) const;

Status TokenExpression(OpKernelContext* ctx,
size_t N, size_t C,
const std::vector<int64_t>& input_dims) const;
gsl::span<const int64_t> input_dims) const;

bool mark_{false};
std::string pad_value_;
@@ -114,7 +114,7 @@ Tokenizer::Tokenizer(const OpKernelInfo& info) : OpKernel(info) {
}

Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
const std::vector<int64_t>& input_dims) const {
gsl::span<const int64_t> input_dims) const {
// With char tokenzation we get as many tokens as the number of
// utf8 characters in the string. So for every string we calculate its character(utf8) length
// add padding and add start/end test separators if necessary
@@ -137,7 +137,7 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,
++curr_input;
}

std::vector<int64_t> output_dims(input_dims);
std::vector<int64_t> output_dims(input_dims.begin(), input_dims.end());
// Check if we have no output due to apparently empty strings input.
if (max_tokens == 0) {
output_dims.push_back(0);
@@ -193,7 +193,7 @@ Status Tokenizer::CharTokenize(OpKernelContext* ctx, size_t N, size_t C,

Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx,
size_t N, size_t C,
const std::vector<int64_t>& input_dims) const {
gsl::span<const int64_t> input_dims) const {
using namespace re2;
std::vector<std::vector<StringPiece>> rows;
rows.reserve(N * C);
@@ -276,7 +276,7 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx,
++curr_input;
}

std::vector<int64_t> output_dims(input_dims);
std::vector<int64_t> output_dims(input_dims.begin(), input_dims.end());
// Check if we have no output due to either empty input
// everything is a separator
if (max_tokens == 0) {
@@ -334,7 +334,7 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx,

Status Tokenizer::TokenExpression(OpKernelContext* ctx,
size_t N, size_t C,
const std::vector<int64_t>& input_dims) const {
gsl::span<const int64_t> input_dims) const {
using namespace re2;
// Represents a token that will be output after
// first is the index, second is the size;
@@ -400,7 +400,7 @@ Status Tokenizer::TokenExpression(OpKernelContext* ctx,
}

// Check for empty output
std::vector<int64_t> output_dims(input_dims);
std::vector<int64_t> output_dims(input_dims.begin(), input_dims.end());
// Check if we have no output due to either empty input
// everything is a separator
if (max_tokens == 0) {
@@ -468,7 +468,7 @@ Status Tokenizer::Compute(OpKernelContext* ctx) const {
}

auto& input_shape = X->Shape();
auto& input_dims = input_shape.GetDims();
auto input_dims = input_shape.GetDims();
size_t N = 0;
size_t C = 0;
if (input_dims.size() == 1) {
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -99,7 +99,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
Stream(),
reinterpret_cast<const CudaT*>(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
nullptr == mask_index ? nullptr : &(mask_index->Shape().GetDims()),
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
output->template MutableData<T>(),
batch_size,
sequence_length,
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -65,7 +65,7 @@ bool QkvToContext(
const cudaDeviceProp& prop, cublasHandle_t& cublas, cudaStream_t stream,
const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size,
const T* input, T* output, T* workspace,
const int* mask_index, const std::vector<int64_t>* mask_index_dims,
const int* mask_index, gsl::span<const int64_t> mask_index_dims,
bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool use_persistent_softmax) {
const int all_sequence_length = past_sequence_length + sequence_length;
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
@@ -106,7 +106,7 @@
}

// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
bool use_raw_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() >= 2);
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);

// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
@@ -126,8 +126,8 @@

// apply softmax and store result P to scratch2: BxNxSxS*
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims->size());
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims->at(3) : 0;
const int mask_dimension = static_cast<int>(mask_index_dims.size());
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0;

T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected.
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, extra_add_qk, scratch1, scratch2,
@@ -136,9 +136,9 @@
return false;
}
} else if (nullptr != mask_index) { // 1d mask index
ORT_ENFORCE(nullptr != mask_index_dims && mask_index_dims->size() == 1);
ORT_ENFORCE(mask_index_dims.size() == 1);
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions.
const int* mask_start = (mask_index_dims->at(0) > batch_size) ? mask_index + batch_size : nullptr;
const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr;
if (!ComputeSoftmaxWithMask1D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
return false;
}
@@ -164,7 +164,7 @@ bool LaunchAttentionKernel(
cudaStream_t stream,
const void* input,
const int* mask_index,
const std::vector<int64_t>* mask_index_dims,
gsl::span<const int64_t> mask_index_dims,
void* output,
const int batch_size,
const int sequence_length,
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -23,7 +23,7 @@ bool LaunchAttentionKernel(
cudaStream_t stream, // cuda stream
const void* input, // Input tensor
const int* mask_index, // Attention mask raw data or index (end position of each sequence, or end positions and start positions). NULL means no mask.
const std::vector<int64_t>* mask_index_dims, // Mask index shape
gsl::span<const int64_t> mask_index_dims, // Mask index shape
void* output, // Output tensor
int batch_size, // Batch size (B)
int sequence_length, // Sequence length (S)
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/fused_conv.cc
@@ -93,7 +93,7 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
if (Base::s_.post_slicing_required) {
ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
Base::s_.y_dims, Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
}
return Status::OK();
}