Skip to content

Commit

Permalink
[GPU] KV-cache compression support (openvinotoolkit#27114)
Browse files Browse the repository at this point in the history
### Details:
This PR enables KV-cache compression support
Currently, it supports only combinations of the following
configurations:
* Data types: INT8_SYM / INT8_ASYM
* Modes: per-token (quantization of `num_heads * head_size` in a single
group) / per-token-per-head (quantization of each `head_size` group for
each head per token)

### Tickets:
 - *ticket-id*
  • Loading branch information
sshlyapn authored Oct 28, 2024
1 parent 2486a7f commit 0d113d9
Show file tree
Hide file tree
Showing 69 changed files with 3,447 additions and 343 deletions.
68 changes: 57 additions & 11 deletions src/common/transformations/include/ov_ops/dynamic_quantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,75 @@ namespace internal {
/// \brief Operator performing Dynamic Quantize
class TRANSFORMATIONS_API DynamicQuantize : public ov::op::Op {
public:
    OPENVINO_OP("DynamicQuantize", "ie_internal_opset");

    /**
     * @brief Configuration for the type of quantization applied to the data:
     * - Symmetric: Quantization where the zero point is fixed at zero, and the range is symmetric around zero.
     * - Asymmetric: Quantization where the zero point is not fixed at zero.
     */
    enum class QuantizationType { Symmetric, Asymmetric };

    /**
     * @brief Configuration for how Activations, Scales and Zero Points will be stored in output buffers:
     * - Planar: Activations, Scales, and Zero Points are stored in independent buffers.
     * - InterleavedScalesZP: Activations are stored in an independent buffer, while Scales and Zero Points (if any) are
     * combined in a separate buffer.
     */
    enum class OutputStorageType { Planar, InterleavedScalesZP, /* InterleavedActivationsScalesZP */ };

    /// \brief Structure that specifies attributes of the dynamic quantization operation
    struct Attributes {
        QuantizationType quantization_type = QuantizationType::Symmetric;
        element::Type quantization_dt = element::undefined;  // data type of the quantized output
        element::Type scale_dt = element::undefined;         // data type of the scales output
        element::Type zp_dt = element::undefined;            // data type of the zero points output (Asymmetric only)

        // Per-dimension group sizes; UINT64_MAX means the whole dimension forms a single group
        std::vector<uint64_t> group_sizes = {};
        // Optional dimensions order for scales/zero-points outputs; identity order when empty
        std::vector<uint64_t> scales_zp_output_order = {};
        OutputStorageType output_storage_type = OutputStorageType::Planar;
    };

    DynamicQuantize() = default;
    /// \brief Constructs a DynamicQuantize operation.
    ///
    /// \param data Input tensor with data
    /// \param attrs Dynamic quantization configuration
    DynamicQuantize(const Output<Node>& data, const Attributes& attrs);

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    const Attributes& get_attrs() const {
        return m_attrs;
    }

    void set_attrs(Attributes attrs) {
        m_attrs = std::move(attrs);
    }

    const std::vector<uint64_t>& get_group_sizes() const {
        return m_attrs.group_sizes;
    }

    QuantizationType get_quantization_type() const {
        return m_attrs.quantization_type;
    }

    OutputStorageType get_output_storage_type() const {
        return m_attrs.output_storage_type;
    }

    const std::vector<uint64_t>& get_scales_zp_output_order() const {
        return m_attrs.scales_zp_output_order;
    }

    /// \brief Computes output shapes (quantized data, scales and, when planar asymmetric, zero points).
    static std::vector<ov::PartialShape> shape_infer(const DynamicQuantize* op,
                                                     const std::vector<ov::PartialShape>& input_shapes);

protected:
    Attributes m_attrs;
};

} // namespace internal
Expand Down
91 changes: 71 additions & 20 deletions src/common/transformations/src/ov_ops/dynamic_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,113 @@
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/variadic_split.hpp"
#include "variadic_split_shape_inference.hpp"
#include "openvino/util/common_util.hpp"

namespace ov {
namespace op {
namespace internal {

DynamicQuantize::DynamicQuantize(const Output<Node>& data, const Attributes& attrs) : Op({data}), m_attrs(attrs) {
    // An empty order means identity: scales/zp layout matches the data layout
    if (m_attrs.scales_zp_output_order.empty()) {
        m_attrs.scales_zp_output_order.resize(data.get_partial_shape().size());
        std::iota(m_attrs.scales_zp_output_order.begin(), m_attrs.scales_zp_output_order.end(), 0);
    }

    OPENVINO_ASSERT(data.get_partial_shape().rank() == m_attrs.group_sizes.size(),
                    "DQ input rank should be same as the rank of group_size ",
                    data.get_tensor_ptr()->get_partial_shape().rank(),
                    " / ",
                    m_attrs.group_sizes.size());

    OPENVINO_ASSERT(data.get_partial_shape().size() == m_attrs.scales_zp_output_order.size(),
                    "DQ input rank should be same as the rank of scales and zero points output order)");

    // Outputs: [quantized data, scales] plus a dedicated zero-points output
    // only for asymmetric quantization with planar storage (otherwise zp is
    // interleaved with the scales in output 1)
    size_t outputs_number = 2;
    if (m_attrs.quantization_type == QuantizationType::Asymmetric &&
        m_attrs.output_storage_type == OutputStorageType::Planar)
        outputs_number = 3;

    OPENVINO_ASSERT(
        (m_attrs.output_storage_type == OutputStorageType::Planar) ||
            (m_attrs.quantization_type == QuantizationType::Asymmetric && m_attrs.scale_dt == m_attrs.zp_dt),
        "Scales and Zero Points should have the same data type to be stored in the single buffer");

    set_output_size(outputs_number);
    validate_and_infer_types();
}

void DynamicQuantize::validate_and_infer_types() {
    std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0)};

    auto out_shapes = shape_infer(this, input_shapes);
    set_output_type(0, m_attrs.quantization_dt, out_shapes[0]);
    set_output_type(1, m_attrs.scale_dt, out_shapes[1]);

    // Separate zero-points output exists only for planar asymmetric mode
    if (m_attrs.quantization_type == QuantizationType::Asymmetric &&
        m_attrs.output_storage_type == OutputStorageType::Planar)
        set_output_type(2, m_attrs.zp_dt, out_shapes[2]);
}

std::shared_ptr<Node> DynamicQuantize::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    check_new_args_count(this, new_args);
    // Attributes are copied as-is; only the data input is replaced
    return std::make_shared<DynamicQuantize>(new_args.at(0), m_attrs);
}

std::vector<ov::PartialShape> DynamicQuantize::shape_infer(const DynamicQuantize* op,
const std::vector<ov::PartialShape>& input_shapes,
const std::vector<uint64_t>& group_sizes) {
const std::vector<ov::PartialShape>& input_shapes) {
std::vector<ov::PartialShape> out_shapes;
out_shapes.push_back(input_shapes[0]);

auto scale_shape = input_shapes[0];
const auto& group_sizes = op->m_attrs.group_sizes;
OPENVINO_ASSERT(scale_shape.size() == group_sizes.size(),
"Scale_shape and group_size are supposed to have same rank: ",
scale_shape.size(),
" / ",
group_sizes.size());
for (size_t i = 0; i < scale_shape.size(); i++) {
if (scale_shape[i].is_dynamic())
if (scale_shape[i].is_dynamic() || scale_shape[i] == 0)
continue;

if (group_sizes[i] == UINT64_MAX)
if (group_sizes[i] == UINT64_MAX) {
scale_shape[i] = 1;
else {
scale_shape[i] /= group_sizes[i]; // if group_size is larger than shape, scale_shape will be 1
scale_shape[i] = std::max(static_cast<int>(scale_shape[i].get_length()), 1);
} else {
scale_shape[i] = ov::util::ceil_div(scale_shape[i].get_length(), static_cast<int64_t>(group_sizes[i]));
}
}
out_shapes.push_back(scale_shape);

// Add zero points shape, same as the scales
if (op->m_attrs.quantization_type == QuantizationType::Asymmetric &&
op->m_attrs.output_storage_type == OutputStorageType::Planar)
out_shapes.push_back(scale_shape);

auto transpose_shape = [](const ov::PartialShape& shape, const std::vector<uint64_t>& scales_zp_output_order) {
auto transposed_shape = shape;
for (size_t i = 0; i < scales_zp_output_order.size(); i++) {
OPENVINO_ASSERT(scales_zp_output_order[i] < transposed_shape.size());
transposed_shape[i] = shape[scales_zp_output_order[i]];
}

return transposed_shape;
};

// Transpose scales and zero points shapes
const auto& scales_zp_output_order = op->m_attrs.scales_zp_output_order;
for (size_t i = 1; i < out_shapes.size(); i++) {
out_shapes[i] = transpose_shape(out_shapes[i], scales_zp_output_order);
}

if (op->m_attrs.quantization_type == QuantizationType::Asymmetric &&
op->m_attrs.output_storage_type != OutputStorageType::Planar) {
// Currently scales and zero points are supposed to be combined over the last dimension only
const auto combine_axis = scales_zp_output_order.empty() ? out_shapes[1].size() - 1
: scales_zp_output_order[out_shapes[1].size() - 1];
OPENVINO_ASSERT(group_sizes[combine_axis] != 1);

out_shapes[1][combine_axis] *= 2; // [scale, zero_point] pairs
}

return out_shapes;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct kernel_impl_params final {
optional_layout weights_zero_points_layout = optional_layout();
optional_layout activations_zero_points_layout = optional_layout();
optional_layout compensation_layout = optional_layout();
optional_layout state_layout = optional_layout();
std::vector<layout> state_layouts;

std::map<size_t, memory::ptr> memory_deps = {};
size_t primary_input_idx = 0;
Expand Down
11 changes: 11 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ class IndirectSDPA : public ov::intel_gpu::op::SDPA {
const std::vector<int64_t>& order_out,
const ov::element::Type output_type = ov::element::undefined);

IndirectSDPA(const OutputVector& data_inputs,
const ov::Output<Node>& beam_table,
const bool is_causal,
const int64_t indirect_axis,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const QuantizationAttribute& quantization_attribute,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;
void validate_and_infer_types() override;

Expand Down
17 changes: 13 additions & 4 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "openvino/op/op.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/op/util/variable_extension.hpp"
#include "ov_ops/dynamic_quantize.hpp"

namespace ov {
namespace intel_gpu {
Expand All @@ -22,16 +23,16 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::undefined);

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;
Expand All @@ -53,14 +54,22 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {

bool get_indirect() const { return m_indirect; }

private:
protected:
KVCache(const OutputVector& inputs,
const std::shared_ptr<ov::op::util::Variable>& past_values,
bool indirect,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::undefined);

int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;

ov::element::Type m_output_type;
};

std::vector<ov::PartialShape> shape_infer(const KVCache* op, std::vector<ov::PartialShape> input_shapes);
std::vector<ov::PartialShape> shape_infer(const KVCache* op, const std::vector<ov::PartialShape>& input_shapes);

} // namespace op
} // namespace intel_gpu
Expand Down
56 changes: 56 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/op/kv_cache.hpp"
#include "ov_ops/dynamic_quantize.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

/// \brief Operator that implements Key-Values cache subgraph for large language models.
/// This operation updates data of the corresponding Variable
class KVCacheCompressed : public ov::intel_gpu::op::KVCache {
public:
    OPENVINO_OP("KVCacheCompressed", "gpu_opset");

    using QuantizationAttrs = ov::op::internal::DynamicQuantize::Attributes;

    KVCacheCompressed() = default;

    KVCacheCompressed(const OutputVector& inputs,
                      const std::shared_ptr<ov::op::util::Variable>& past_values,
                      int64_t concat_axis,
                      int64_t gather_axis,
                      const QuantizationAttrs& quantization_attrs,
                      const ov::element::Type output_type = ov::element::undefined);

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    bool get_kv_compressed() const { return m_compressed; }

    /// \brief True when scales and zero points share a single interleaved buffer,
    ///        which is only the case for asymmetric, non-planar storage.
    bool get_combine_scales_and_zp() const {
        return m_quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric &&
               m_quantization_attrs.output_storage_type != ov::op::internal::DynamicQuantize::OutputStorageType::Planar;
    }

    QuantizationAttrs get_quantization_attrs() const { return m_quantization_attrs; }
    void set_quantization_attrs(QuantizationAttrs attrs) { m_quantization_attrs = std::move(attrs); }

    std::vector<uint64_t> get_scales_zp_output_order() const { return m_quantization_attrs.scales_zp_output_order; }

private:
    // Default member initializer so a default-constructed op has a defined value;
    // previously left uninitialized, making get_kv_compressed() read an
    // indeterminate bool after KVCacheCompressed() = default.
    bool m_compressed = true;
    QuantizationAttrs m_quantization_attrs = {};
};

std::vector<ov::PartialShape> shape_infer(const KVCacheCompressed* op,
const std::vector<ov::PartialShape>& input_shapes);

} // namespace op
} // namespace intel_gpu
} // namespace ov
7 changes: 7 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/read_value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ class ReadValue : public ov::op::Op, public ov::op::util::VariableExtension {
bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
void validate_and_infer_types(size_t output_idx, const ov::op::util::VariableInfo& variable_info);

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::string get_variable_id() const override {
OPENVINO_ASSERT(m_variable, "Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id;
}

protected:
    /// \brief Constructs a ReadValue whose inputs are the variable initializing
    ///        subgraph outputs, bound to the given variable.
    ///        Used by derived ops; performs no validation itself.
    ReadValue(const std::vector<Output<Node>>& variable_initializers, const std::shared_ptr<ov::op::util::Variable>& variable)
        : Op(variable_initializers) {
        m_variable = variable;
    }
};

} // namespace op
Expand Down
Loading

0 comments on commit 0d113d9

Please sign in to comment.