Skip to content

Commit

Permalink
Add embedding ops aten (pytorch#1129)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1129

Adds embedding ops for aten

Reviewed By: digantdesai, Jack-Khuu

Differential Revision: D64477035
  • Loading branch information
metascroy authored and facebook-github-bot committed Oct 21, 2024
1 parent 85ec209 commit b90c25b
Show file tree
Hide file tree
Showing 11 changed files with 664 additions and 10 deletions.
2 changes: 2 additions & 0 deletions torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
# Defines target torchao_kernels_aarch64
add_subdirectory(kernels/cpu/aarch64)
add_subdirectory(ops/linear_8bit_act_xbit_weight)
add_subdirectory(ops/embedding_xbit)

add_library(torchao_ops_aten SHARED)
target_link_libraries(
torchao_ops_aten PRIVATE
torchao_ops_linear_8bit_act_xbit_weight_aten
torchao_ops_embedding_xbit_aten
)
install(
TARGETS torchao_ops_aten
Expand Down
19 changes: 19 additions & 0 deletions torchao/experimental/ops/embedding_xbit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake)

find_package(Torch REQUIRED)
add_library(torchao_ops_embedding_xbit_aten OBJECT
op_embedding_xbit_aten.cpp
)
target_link_torchao_parallel_backend(torchao_ops_embedding_xbit_aten "aten_openmp")
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
target_include_directories(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_embedding_xbit_aten PRIVATE USE_ATEN=1)
268 changes: 268 additions & 0 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
#endif // defined(__aarch64__) || defined(__ARM_NEON)

#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <torchao/experimental/ops/parallel.h>

template <int weight_nbit>
void check_embedding_inputs(
const Tensor& packed_weight_qvals,
int num_embeddings,
int embedding_dim,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices,
int& group_size) {
TORCHAO_CHECK(
packed_weight_qvals.dim() == 1, "packed_weight_qvals must be 1D");
#ifdef USE_ATEN
TORCHAO_CHECK(
packed_weight_qvals.dtype() == torch::kInt8,
"packed_weight_qvals must be byte");
#endif // USE_ATEN
TORCHAO_CHECK(
(embedding_dim * weight_nbit) % 8 == 0,
"embedding_dim * weight_nbit must be a multiple of 8");
int packed_embedding_dim = (embedding_dim * weight_nbit) / 8;
TORCHAO_CHECK(
packed_weight_qvals.size(0) ==
(torchao::ops::PackedWeightsHeader::size() +
(num_embeddings * packed_embedding_dim)),
"packed_weight_qvals is not the correct size");

// Check header
auto header = torchao::ops::PackedWeightsHeader::read(
packed_weight_qvals.const_data_ptr());
TORCHAO_CHECK(
header ==
torchao::ops::embedding_xbit::get_packed_weights_header_universal(
weight_nbit,
/*min_value_chunk_size=*/32,
/*max_value_chunk_size=*/128),
"packed_weights are not compatible with the kernel");

#ifdef USE_ATEN
TORCHAO_CHECK(
weight_scales.dtype() == torch::kFloat32,
"weight_scales must be float32");
#endif // USE_ATEN
TORCHAO_CHECK(weight_scales.dim() == 2, "weight_scales must be 2D");
TORCHAO_CHECK(
weight_scales.size(0) == num_embeddings,
"weight_scales must be same shape as packed_weight_qvals in dim0 (num_embeddings)");
int num_groups = weight_scales.size(1);
TORCHAO_CHECK(
num_groups >= 1, "weight_scales must be at least 1 in dim1 (num_groups)");
TORCHAO_CHECK(
embedding_dim % num_groups == 0,
"embedding_dim must be a multiple of num_groups");
group_size = embedding_dim / num_groups;
TORCHAO_CHECK(group_size % 32 == 0, "group_size must be a multiple of 32");

#ifdef USE_ATEN
TORCHAO_CHECK(
weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8");
#endif // USE_ATEN
TORCHAO_CHECK(weight_zeros.dim() == 2, "weight_zeros must be 2D");
TORCHAO_CHECK(
weight_zeros.size(0) == weight_scales.size(0) &&
weight_zeros.size(1) == weight_scales.size(1),
"zeros must be same shape as scales");

TORCHAO_CHECK(indices.dim() == 1, "indices must be 1D");
TORCHAO_CHECK(
(indices.dtype() == Tensor_dtype_kInt32) ||
(indices.dtype() == Tensor_dtype_kInt64),
"indices must be int32 or int64");
}

#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
template <int weight_nbit>
Tensor embedding_out_cpu(
const Tensor& packed_weight_qvals,
// TODO(T200095131): convert to
// int64_t when supported by AOTI
// Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& num_embeddings_tensor,
const Tensor& embedding_dim_tensor,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices,
Tensor& out) {
int num_embeddings = num_embeddings_tensor.size(1);
int embedding_dim = embedding_dim_tensor.size(1);
int group_size;
check_embedding_inputs<weight_nbit>(
packed_weight_qvals,
num_embeddings,
embedding_dim,
weight_scales,
weight_zeros,
indices,
group_size);

int num_out = indices.size(0);
const int8_t* weight_zeros_ptr = weight_zeros.const_data_ptr<int8_t>();

#ifdef USE_ATEN
TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
out.resize_({num_out, embedding_dim});
#endif // USE_ATEN

#ifdef USE_EXECUTORCH
TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
TORCHAO_CHECK(out.size(1) == embedding_dim, "out shape is incorrect");
#endif // USE_EXECUTORCH

const int32_t* index32_ptr = nullptr;
const int64_t* index64_ptr = nullptr;
if (indices.dtype() == Tensor_dtype_kInt32) {
index32_ptr = indices.const_data_ptr<int32_t>();
} else {
TORCHAO_CHECK(
indices.dtype() == Tensor_dtype_kInt64,
"indices must be int32 or int64");
index64_ptr = indices.const_data_ptr<int64_t>();
}
torchao::parallel_1d(0, num_out, [&](int64_t idx) {
int index = -1;
if (index32_ptr != nullptr) {
index = index32_ptr[idx];
} else {
index = index64_ptr[idx];
}
TORCHAO_CHECK(index >= 0 && index < num_embeddings, "index out of bounds");
#if defined(__aarch64__) || defined(__ARM_NEON)
torchao::kernels::cpu::aarch64::embedding::embedding<weight_nbit>(
out.mutable_data_ptr<float>() + idx * embedding_dim,
embedding_dim,
group_size,
packed_weight_qvals.const_data_ptr<int8_t>() +
torchao::ops::PackedWeightsHeader::size(),
weight_scales.const_data_ptr<float>(),
weight_zeros_ptr,
index);
#else
TORCHAO_CHECK(false, "Unsupported platform");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
});

return out;
}
#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)

#ifdef USE_ATEN
template <int weight_nbit>
Tensor embedding_cpu(
const Tensor& packed_weight_qvals,
// TODO(T200095131): convert to
// int64_t when supported by AOTI
// Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& num_embeddings_tensor,
const Tensor& embedding_dim_tensor,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices) {
Tensor output_tensor = torch::empty({}, torch::kFloat32);
embedding_out_cpu<weight_nbit>(
packed_weight_qvals,
num_embeddings_tensor,
embedding_dim_tensor,
weight_scales,
weight_zeros,
indices,
output_tensor);
return output_tensor;
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor embedding_meta(
const Tensor& packed_weight_qvals,
// TODO(T200095131): convert to
// int64_t when supported by AOTI
// Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& num_embeddings_tensor,
const Tensor& embedding_dim_tensor,
const Tensor& weight_scales,
const Tensor& weight_zeros,
const Tensor& indices) {
int embedding_dim = embedding_dim_tensor.size(1);
int num_out = indices.size(0);
return torch::empty({num_out, embedding_dim}).to("meta");
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");
int num_embeddings = weight_qvals.size(0);
int embedding_dim = weight_qvals.size(1);
TORCHAO_CHECK(
embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
int packed_embedding_dim = embedding_dim * weight_nbit / 8;
TORCHAO_CHECK(
weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");

auto out = torch::empty(
torchao::ops::PackedWeightsHeader::size() +
(num_embeddings * packed_embedding_dim))
.to(torch::kInt8);

auto header =
torchao::ops::embedding_xbit::get_packed_weights_header_universal(
weight_nbit,
/*min_value_chunk_size=*/32,
/*max_value_chunk_size=*/128);
header.write(out.mutable_data_ptr());

torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
#if defined(__aarch64__) || defined(__ARM_NEON)
torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
weight_nbit>(
out.mutable_data_ptr<int8_t>() +
torchao::ops::PackedWeightsHeader::size(),
embedding_dim,
weight_qvals.const_data_ptr<int8_t>(),
idx);
#else
TORCHAO_CHECK(false, "Unsupported platform");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
});

return out;
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor pack_embedding_meta(const Tensor& weight_qvals) {
TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");
int num_embeddings = weight_qvals.size(0);
int embedding_dim = weight_qvals.size(1);
TORCHAO_CHECK(
embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
int packed_embedding_dim = embedding_dim * weight_nbit / 8;
return torch::empty(
torchao::ops::PackedWeightsHeader::size() +
(num_embeddings * packed_embedding_dim))
.to("meta");
}
#endif // USE_ATEN
56 changes: 56 additions & 0 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h>

#define DEFINE_OP(weight_nbit) \
m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \
m.def( \
"_embedding_" #weight_nbit \
"bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
m.def( \
"_embedding_" #weight_nbit \
"bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)");

#define DEFINE_CPU_IMPL(weight_nbit) \
m.impl( \
"_pack_embedding_" #weight_nbit "bit", \
&pack_embedding_cpu<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit", &embedding_cpu<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit.out", &embedding_out_cpu<weight_nbit>);

#define DEFINE_META_IMPL(weight_nbit) \
m.impl( \
"_pack_embedding_" #weight_nbit "bit", \
&pack_embedding_meta<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>);

TORCH_LIBRARY_FRAGMENT(torchao, m) {
DEFINE_OP(1);
DEFINE_OP(2);
DEFINE_OP(3);
DEFINE_OP(4);
DEFINE_OP(5);
DEFINE_OP(6);
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
DEFINE_CPU_IMPL(1);
DEFINE_CPU_IMPL(2);
DEFINE_CPU_IMPL(3);
DEFINE_CPU_IMPL(4);
DEFINE_CPU_IMPL(5);
DEFINE_CPU_IMPL(6);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
DEFINE_META_IMPL(1);
DEFINE_META_IMPL(2);
DEFINE_META_IMPL(3);
DEFINE_META_IMPL(4);
DEFINE_META_IMPL(5);
DEFINE_META_IMPL(6);
}
36 changes: 36 additions & 0 deletions torchao/experimental/ops/embedding_xbit/packed_weights_header.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::embedding_xbit {

inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
int weight_nbit,
int min_value_chunk_size,
int max_value_chunk_size,
int version = 1) {
return torchao::ops::PackedWeightsHeader(
torchao::ops::PackedWeightsFormat::embedding_xbit_universal,
{version,
weight_nbit,
min_value_chunk_size,
max_value_chunk_size,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0});
}

} // namespace torchao::ops::embedding_xbit
Loading

0 comments on commit b90c25b

Please sign in to comment.