[GPU]: Added stub for SearchSorted
pkowalc1 committed Oct 31, 2024
1 parent 9f6826a commit 7835bb3
Showing 16 changed files with 482 additions and 1 deletion.
@@ -272,6 +272,7 @@ REGISTER_FACTORY(v13, BitwiseXor);
REGISTER_FACTORY(v15, ROIAlignRotated);
REGISTER_FACTORY(v15, BitwiseRightShift);
REGISTER_FACTORY(v15, BitwiseLeftShift);
REGISTER_FACTORY(v15, SearchSorted);

// --------------------------- Supported internal ops --------------------------- //
REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal);
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <algorithm>
#include <vector>

#include "openvino/op/util/attr_types.hpp"
#include "primitive.hpp"

namespace cldnn {

struct search_sorted : public primitive_base<search_sorted> {
CLDNN_DECLARE_PRIMITIVE(search_sorted)

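/// @brief Empty constructor kept for program (de)serialization of the stub primitive.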
search_sorted() : primitive_base("", {}) {}

size_t hash() const override {
size_t seed = primitive::hash();
return seed;
}

bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<search_sorted>::save(ob);
}

void load(BinaryInputBuffer& ib) override {
primitive_base<search_sorted>::load(ib);
}
};
} // namespace cldnn
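A possible follow-up (not in this stub commit): a minimal sketch of how the primitive could carry the single attribute of ov::op::v15::SearchSorted, right_mode. The member name, constructor shape, and serialization order here are assumptions, not the plugin's final API.

struct search_sorted : public primitive_base<search_sorted> {
    CLDNN_DECLARE_PRIMITIVE(search_sorted)

    search_sorted() : primitive_base("", {}) {}

    // inputs = {sorted_sequence, values}
    search_sorted(const primitive_id& id, const std::vector<input_info>& inputs, bool right_mode)
        : primitive_base(id, inputs),
          right_mode(right_mode) {}

    // When true, return the rightmost suitable index instead of the leftmost one.
    bool right_mode = false;

    size_t hash() const override {
        return hash_combine(primitive::hash(), right_mode);
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;
        return right_mode == downcast<const search_sorted>(rhs).right_mode;
    }

    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<search_sorted>::save(ob);
        ob << right_mode;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<search_sorted>::load(ib);
        ib >> right_mode;
    }
};

A topology would then create it as search_sorted("search_sorted", {input_info("sorted_sequence"), input_info("values")}, /*right_mode=*/false), with all identifiers illustrative.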
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp
@@ -162,6 +162,7 @@ REGISTER_OCL(unique_count);
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);
REGISTER_OCL(rope);
REGISTER_OCL(search_sorted);

#undef REGISTER_OCL

65 changes: 65 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/search_sorted.cpp
@@ -0,0 +1,65 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"

#include "search_sorted_inst.h"
#include "search_sorted/search_sorted_kernel_selector.h"
#include "search_sorted/search_sorted_kernel_base.h"

namespace cldnn {
namespace ocl {

struct search_sorted_impl : typed_primitive_impl_ocl<search_sorted> {
using parent = typed_primitive_impl_ocl<search_sorted>;
using parent::parent;
using kernel_selector_t = kernel_selector::search_sorted_kernel_selector;
using kernel_params_t = kernel_selector::search_sorted_params;

DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::search_sorted_impl)

std::unique_ptr<primitive_impl> clone() const override {
return make_unique<search_sorted_impl>(*this);
}

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
const auto& primitive = impl_param.typed_desc<search_sorted>();
auto params = get_default_params<kernel_selector::search_sorted_params>(impl_param);

// TODO: populate SearchSorted-specific kernel parameters (e.g. the op's right_mode
// attribute) once the primitive and search_sorted_params carry them.

return params;
}
};

namespace detail {

attach_search_sorted_impl::attach_search_sorted_impl() {
implementation_map<search_sorted>::add(impl_types::ocl, typed_primitive_impl_ocl<search_sorted>::create<search_sorted_impl>, {
std::make_tuple(data_types::i8, format::bfyx),
std::make_tuple(data_types::u8, format::bfyx),
std::make_tuple(data_types::i32, format::bfyx),
std::make_tuple(data_types::i64, format::bfyx),
std::make_tuple(data_types::f32, format::bfyx),
std::make_tuple(data_types::f16, format::bfyx),
std::make_tuple(data_types::i8, format::bfzyx),
std::make_tuple(data_types::u8, format::bfzyx),
std::make_tuple(data_types::i32, format::bfzyx),
std::make_tuple(data_types::i64, format::bfzyx),
std::make_tuple(data_types::f32, format::bfzyx),
std::make_tuple(data_types::f16, format::bfzyx),
});
}

} // namespace detail
} // namespace ocl
} // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::search_sorted_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::search_sorted)
@@ -214,3 +214,4 @@ REGISTER_DEFAULT_IMPLS(unique_count, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(unique_gather, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(scaled_dot_product_attention, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(rope, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(search_sorted, OCL_S, OCL_D);
53 changes: 53 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/search_sorted_inst.h
@@ -0,0 +1,53 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <intel_gpu/primitives/search_sorted.hpp>

#include "primitive_inst.h"

namespace cldnn {

template <>
struct typed_program_node<search_sorted> : public typed_program_node_base<search_sorted> {
using parent = typed_program_node_base<search_sorted>;
typed_program_node(const std::shared_ptr<search_sorted> prim, program& prog) : parent(prim, prog) {}

public:
using parent::parent;
program_node& input() const {
return get_dependency(0);
}
std::vector<size_t> get_shape_infer_dependencies() const override {
return {};
}
};

using search_sorted_node = typed_program_node<search_sorted>;

template <>
class typed_primitive_inst<search_sorted> : public typed_primitive_inst_base<search_sorted> {
using parent = typed_primitive_inst_base<search_sorted>;
using parent::parent;

public:
template <typename ShapeType>
static std::vector<layout> calc_output_layouts(search_sorted_node const& node,
kernel_impl_params const& impl_param);
static layout calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(search_sorted_node const& node);

public:
typed_primitive_inst(network& network, search_sorted_node const& desc);
// SearchSorted has two inputs: sorted_sequence (dependency 0) and values (dependency 1).
memory::ptr sorted_sequence_memory() const {
return dep_memory_ptr(0);
}
memory::ptr values_memory() const {
return dep_memory_ptr(1);
}
};

using search_sorted_inst = typed_primitive_inst<search_sorted>;

} // namespace cldnn
47 changes: 47 additions & 0 deletions src/plugins/intel_gpu/src/graph/search_sorted.cpp
@@ -0,0 +1,47 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <search_sorted_inst.h>
#include "primitive_type_base.h"
#include <sstream>
#include <json_object.h>
#include "openvino/core/enum_names.hpp"
#include "search_sorted_shape_inference.hpp"

namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(search_sorted)

search_sorted_inst::typed_primitive_inst(network& network, search_sorted_node const& node)
: parent(network, node) {}

layout search_sorted_inst::calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param) {
auto primitive = impl_param.typed_desc<search_sorted>();
auto input_layout = impl_param.get_input_layout(0);
return layout();
}

template<typename ShapeType>
std::vector<layout> search_sorted_inst::calc_output_layouts(search_sorted_node const& node, kernel_impl_params const& impl_param) {
return std::vector<layout>();
}

std::string search_sorted_inst::to_string(search_sorted_node const& node) {
auto node_info = node.desc_to_json();
json_composite search_sorted_info;
// TODO: report SearchSorted-specific fields (input ids, right_mode) once the primitive carries them.
node_info->add("search_sorted info", search_sorted_info);
std::stringstream primitive_description;
node_info->dump(primitive_description);
return primitive_description.str();
}

} // namespace cldnn
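A possible follow-up (not in this stub commit): a minimal sketch of how the calc_output_layouts stub above could be completed. It relies on SearchSorted producing one index tensor shaped like its second ("values") input; i64 indices and reuse of the values input's format are assumptions here.

template <typename ShapeType>
std::vector<layout> search_sorted_inst::calc_output_layouts(search_sorted_node const& node,
                                                            kernel_impl_params const& impl_param) {
    // Output indices have the same shape as the "values" input (dependency 1).
    const auto values_layout = impl_param.get_input_layout(1);
    return {layout(values_layout.get_partial_shape(), data_types::i64, values_layout.format)};
}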
@@ -0,0 +1,11 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/fetch_data.cl"

KERNEL(search_sorted_ref)(const __global INPUT0_TYPE* input,
__global OUTPUT_TYPE* output)
{
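    // Intentionally left empty in this stub; the reference implementation is expected to
    // binary-search each element of the values input within the sorted sequence and write
    // the resulting index to the output.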

}
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/src/kernel_selector/common_types.h
@@ -101,7 +101,8 @@ enum class KernelType {
RMS,
SWIGLU,
ROPE,
DYNAMIC_QUANTIZE
DYNAMIC_QUANTIZE,
SEARCH_SORTED
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,64 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "search_sorted_kernel_base.h"
#include <vector>
#include "kernel_selector_utils.h"

namespace kernel_selector {
JitConstants SearchSortedKernelBase::GetJitConstants(const search_sorted_params& params) const {
JitConstants jit = MakeBaseParamsJitConstants(params);

// TODO: emit SearchSorted-specific JIT constants (e.g. the op's right_mode attribute)
// once search_sorted_params carries them.

return jit;
}
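A possible follow-up (not in this stub commit): once search_sorted_params carries the op's right_mode attribute, forwarding it to the kernel would typically be a single extra constant inside GetJitConstants above (member name assumed):

    jit.AddConstant(MakeJitConstant("RIGHT_MODE", params.right_mode));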

SearchSortedKernelBase::DispatchData SearchSortedKernelBase::SetDefault(const search_sorted_params& params) {
const auto& input = params.inputs[0];
auto in_layout = params.inputs[0].GetLayout();
auto out_layout = params.outputs[0].GetLayout();
std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws;

DispatchData dispatchData;
if (params.outputs[0].GetDims().size() == 5) {
dispatchData.gws = { input.Batch().v, input.Feature().v * input.Z().v, input.Y().v * input.X().v };
dims_by_gws = {{ Tensor::DataChannelName::BATCH },
{ Tensor::DataChannelName::Z, Tensor::DataChannelName::FEATURE },
{ Tensor::DataChannelName::X, Tensor::DataChannelName::Y }};
} else {
dispatchData.gws = { input.Batch().v, input.Feature().v, input.Y().v * input.X().v };
dims_by_gws = {{ Tensor::DataChannelName::BATCH },
{ Tensor::DataChannelName::FEATURE },
{ Tensor::DataChannelName::X, Tensor::DataChannelName::Y }};
}
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo, in_layout, out_layout, dims_by_gws);

return dispatchData;
}

KernelsData SearchSortedKernelBase::GetCommonKernelsData(const Params& params) const {
assert(params.GetType() == KernelType::SEARCH_SORTED);

const auto& prim_params =
static_cast<const search_sorted_params&>(params);

auto dispatchData = SetDefault(prim_params);
KernelData k_data = KernelData::Default<search_sorted_params>(params);

auto cldnn_jit = GetJitConstants(prim_params);
auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params);
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);

auto& kernel = k_data.kernels[0];
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point);

return {k_data};
}
} // namespace kernel_selector
@@ -0,0 +1,37 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "kernel_base_opencl.h"
#include "kernel_selector_params.h"

namespace kernel_selector {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// search_sorted
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct search_sorted_params : public base_params {
search_sorted_params() : base_params(KernelType::SEARCH_SORTED) {}
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// SearchSortedKernelBase
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class SearchSortedKernelBase : public KernelBaseOpenCL {
public:
using KernelBaseOpenCL::KernelBaseOpenCL;

using DispatchData = CommonDispatchData;

protected:
JitConstants GetJitConstants(const search_sorted_params& params) const;
static DispatchData SetDefault(const search_sorted_params& params);
KernelsData GetCommonKernelsData(const Params& params) const;
};
} // namespace kernel_selector
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "search_sorted_kernel_ref.h"

namespace kernel_selector {
ParamsKey SearchSortedKernelRef::GetSupportedKey() const {
ParamsKey k;

k.EnableInputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8);
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::INT64);
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::F16);

k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableOutputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT64);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);

k.EnableInputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);

k.EnableOutputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfzyx);

k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();

return k;
}

KernelsData SearchSortedKernelRef::GetKernelsData(const Params& params) const {
return GetCommonKernelsData(params);
}

KernelsPriority SearchSortedKernelRef::GetKernelsPriority(const Params& /*params*/) const {
return FORCE_PRIORITY_9;
}
} // namespace kernel_selector