From 22ceda493ec99a844d1e01004ada06bbb4a5490e Mon Sep 17 00:00:00 2001 From: Chen Xi Date: Sat, 25 Nov 2023 09:37:28 +0800 Subject: [PATCH] [LLM Runtime] Add jblas split weight interface and support jblas models (#639) * [LLM Runtime] Add jblas split weight interface and support jblas models Signed-off-by: Clark Chin --- .../jblas/jblas/jit_blas_weight_compression.h | 2 +- .../graph/core/layers/jblas_common.cpp | 121 ++++++++++++++++++ .../llm/runtime/graph/core/ne_jblas.h | 3 + .../llm/runtime/graph/models/gptj/gptj.cpp | 37 +++--- .../llm/runtime/graph/models/llama/llama.cpp | 10 +- .../graph/models/model_utils/model_files.h | 77 ++++++++--- 6 files changed, 211 insertions(+), 39 deletions(-) diff --git a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h index 082e2b62b8e..d3163303ba1 100644 --- a/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h +++ b/intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas_weight_compression.h @@ -181,7 +181,7 @@ class WeightS8ScaleFp32 { auto Tscales = utils::amalloc(ssize); auto Tzps = utils::amalloc(ptr->mIsAsym ? ssize : 0); quantizeWeight(N, K, B, ldb, ptr->mBlockSize, tmpq, Tscales, Tzps); - packQWeight(N, K, tmpq, ldb, Tscales, Tzps, stor); + packQWeight(N, K, tmpq, N, Tscales, Tzps, stor); utils::afree(tmpq); utils::afree(Tscales); utils::afree(Tzps); diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.cpp index 9baa3e7c6e8..9f7e4e199a2 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.cpp @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "jblas_common.hpp" +#include "jblas/jit_blas_weight_compression.h" using namespace jblas; +using namespace ne_jblas; void jblas_init() { GetCPUDevice(); @@ -34,3 +36,122 @@ int jblas_set_threads(int _nth) { jblas::utils::parallel::CpuDevice::getInstance()->setThreads(_nth); return jblas::utils::parallel::CpuDevice::getInstance()->getThreads(); } + +template