diff --git a/Cargo.lock b/Cargo.lock index dfe3dd55..2d6d4bbc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3118,7 +3118,6 @@ version = "0.3.0" dependencies = [ "hf-hub", "metrics", - "rayon", "text-embeddings-backend", "thiserror", "tokenizers", diff --git a/Dockerfile b/Dockerfile index a638c006..f557c1d3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,7 +34,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO tee /etc/apt/sources.list.d/oneAPI.list RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - intel-oneapi-mkl-devel \ + intel-oneapi-mkl-devel=2024.0.0-49656 \ build-essential \ && rm -rf /var/lib/apt/lists/* @@ -74,10 +74,8 @@ COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2 COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2 COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx.so.2 /usr/local/lib/libmkl_vml_avx.so.2 COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2 COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2 -COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx.so.2 /usr/local/lib/libmkl_avx.so.2 COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2 COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2 COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so diff --git a/README.md b/README.md index b6db560c..61b71393 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,10 @@ Swagger API documentation -A blazing fast inference solution for text embeddings models. +A blazing fast inference solution for text embeddings models. -Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on an Nvidia A10 with a sequence length of 512 tokens: +Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on an Nvidia A10 with a sequence +length of 512 tokens:

@@ -27,33 +28,37 @@ Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1 ## Table of contents - [Get Started](#get-started) - - [Supported Models](#supported-models) - - [Docker](#docker) - - [Docker Images](#docker-images) - - [API Documentation](#api-documentation) - - [Using a private or gated model](#using-a-private-or-gated-model) - - [Distributed Tracing](#distributed-tracing) + - [Supported Models](#supported-models) + - [Docker](#docker) + - [Docker Images](#docker-images) + - [API Documentation](#api-documentation) + - [Using a private or gated model](#using-a-private-or-gated-model) + - [Using Sequence Classification models](#using-sequence-classification-models) + - [Distributed Tracing](#distributed-tracing) - [Local Install](#local-install) - [Docker Build](#docker-build) -Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings models. TEI enables -high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5. TEI implements many features -such as: +Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence +classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, +Ember, GTE and E5. TEI implements many features such as: * No model graph compilation step * Small docker images and fast boot times. Get ready for true serverless! * Token based dynamic batching * Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention), -[Candle](https://github.com/huggingface/candle) and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api) + [Candle](https://github.com/huggingface/candle) + and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api) * [Safetensors](https://github.com/huggingface/safetensors) weight loading * Production ready (distributed tracing with Open Telemetry, Prometheus metrics) - ## Get Started ### Supported Models -You can use any JinaBERT model with Alibi or absolute positions or any BERT, CamemBERT, RoBERTa, or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`. +#### Text Embeddings + +You can use any JinaBERT model with Alibi or absolute positions or any BERT, CamemBERT, RoBERTa, or XLM-RoBERTa model +with absolute positions in `text-embeddings-inference`. **Support for other model types will be added in the future.** @@ -73,8 +78,20 @@ Examples of supported models: | N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) | | N/A | JinaBERT | [jinaai/jina-embeddings-v2-small-en](https://hf.co/jinaai/jina-embeddings-v2-small-en) | +You can explore the list of best performing text embeddings +models [here](https://huggingface.co/spaces/mteb/leaderboard). + +#### Sequence Classification and Re-Ranking + +`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models. + +Examples of supported sequence classification models: -You can explore the list of best performing text embeddings models [here](https://huggingface.co/spaces/mteb/leaderboard). 
+| Task | Model Type | Model ID | Revision | +|--------------------|-------------|---------------------------------------------------------------------------------------------|-------------| +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | `refs/pr/4` | +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | `refs/pr/5` | +| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | | ### Docker @@ -95,7 +112,8 @@ curl 127.0.0.1:8080/embed \ -H 'Content-Type: application/json' ``` -**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). +**Note:** To use GPUs, you need to install +the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.0 or higher. To see all options to serve your models: @@ -130,20 +148,18 @@ Options: --dtype The dtype to be forced upon the model - - If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures [env: DTYPE=] [possible values: float16, float32] --pooling - Optionally control the pooling method. - - If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` - configuration. - + Optionally control the pooling method for embedding models. + + If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` + configuration. + If `pooling` is set, it will override the model pooling configuration - + [env: POOLING=] [possible values: cls, mean] @@ -241,7 +257,8 @@ You can turn Flash Attention v1 ON by using the `USE_FLASH_ATTENTION=True` envir ### API documentation You can consult the OpenAPI documentation of the `text-embeddings-inference` REST API using the `/docs` route. -The Swagger UI is also available at: [https://huggingface.github.io/text-embeddings-inference](https://huggingface.github.io/text-embeddings-inference). +The Swagger UI is also available +at: [https://huggingface.github.io/text-embeddings-inference](https://huggingface.github.io/text-embeddings-inference). ### Using a private or gated model @@ -264,6 +281,48 @@ token= docker run --gpus all -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model ``` +### Using Sequence Classification models + +`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models. +See [this blogpost](https://blog.llamaindex.ai/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83) by +the LlamaIndex team to understand how you can use Sequence Classification models in your RAG pipeline to improve +downstream performance. 
+ +```shell +model=BAAI/bge-reranker-large +revision=refs/pr/4 +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model --revision $revision +``` + +And then you can rank the similarity between a pair of inputs with: + +```bash +curl 127.0.0.1:8080/predict \ + -X POST \ + -d '{"inputs":["What is Deep Learning?", "Deep learning is..."], "raw_scores": true}' \ + -H 'Content-Type: application/json' +``` + +You can also use classic Sequence Classification models like `SamLowe/roberta-base-go_emotions`: + +```shell +model=SamLowe/roberta-base-go_emotions +volume=$PWD/data + +docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model +``` + +Once you have deployed the model you can use the `predict` endpoint to get the emotions most associated with an input: + +```bash +curl 127.0.0.1:8080/predict \ + -X POST \ + -d '{"inputs":"I like you."}' \ + -H 'Content-Type: application/json' +``` + ### Distributed Tracing `text-embeddings-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature @@ -290,7 +349,7 @@ cargo install --path router -F candle -F mkl cargo install --path router -F candle -F accelerate ``` -You can now launch Text Embeddings Inference on CPU with: +You can now launch Text Embeddings Inference on CPU with: ```shell model=BAAI/bge-large-en-v1.5 @@ -309,7 +368,8 @@ sudo apt-get install libssl-dev gcc -y GPUs with Cuda compute capabilities < 7.5 are not supported (V100, Titan V, GTX 1000 series, ...). -Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.0 or higher. +Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.0 or +higher. 
You also need to add the nvidia binaries to your path: ```shell diff --git a/backends/candle/src/alibi.rs b/backends/candle/src/alibi.rs index 1b49941b..1888d05e 100644 --- a/backends/candle/src/alibi.rs +++ b/backends/candle/src/alibi.rs @@ -16,7 +16,7 @@ use candle::{DType, Device, Result, Tensor}; fn get_slopes_power_of_2(n: usize) -> Vec { - let start: f64 = 2_f64.powf(-2_f64.powf(-((n as f64).log2() - 3_f64))); + let start: f64 = 2_f64.powf(-(2_f64.powf(-((n as f64).log2() - 3_f64)))); (0..n).map(|i| start * start.powi(i as i32)).collect() } diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 9c6d2aad..de8e22ec 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -10,21 +10,23 @@ mod models; use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP}; #[cfg(feature = "cuda")] use crate::models::FlashBertModel; -use crate::models::{ - BertModel, EmbeddingModel, JinaBertModel, PositionEmbeddingType, QuantBertModel, -}; +use crate::models::{BertModel, JinaBertModel, Model, PositionEmbeddingType}; use candle::{DType, Device}; use candle_nn::VarBuilder; use models::Config; use std::path::PathBuf; -use text_embeddings_backend_core::{BackendError, Batch, Embedding, EmbeddingBackend, Pool}; +use text_embeddings_backend_core::{Backend, BackendError, Batch, Embedding, ModelType}; pub struct CandleBackend { - model: Box, + model: Box, } impl CandleBackend { - pub fn new(model_path: PathBuf, dtype: String, pool: Pool) -> Result { + pub fn new( + model_path: PathBuf, + dtype: String, + model_type: ModelType, + ) -> Result { // Load config let config: String = std::fs::read_to_string(model_path.join("config.json")) .map_err(|err| BackendError::Start(err.to_string()))?; @@ -49,49 +51,39 @@ impl CandleBackend { ))); } - let model: Box = match device { - Device::Cpu => { - if &dtype == "float32" || &dtype == "float16" { - let dtype = if &dtype == "float32" { - DType::F32 - } else { - DType::F16 - }; - - let safetensors_path = model_path.join("model.safetensors"); - let vb = if safetensors_path.exists() { - unsafe { - VarBuilder::from_mmaped_safetensors( - &[model_path.join("model.safetensors")], - dtype, - &device, - ) - } - } else { - VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device) - } - .s()?; - - if config.position_embedding_type == PositionEmbeddingType::Alibi { - tracing::info!("Starting JinaBert model on CPU"); - Box::new(JinaBertModel::load(vb, &config, pool).s()?) - } else { - tracing::info!("Starting Bert model on CPU"); - Box::new(BertModel::load(vb, &config, pool).s()?) - } - } else if &dtype == "q6k" { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( - model_path.join("ggml-model-q6k.bin"), - ) - .map_err(|err| BackendError::Start(err.to_string()))?; - tracing::info!("vb"); + // Get candle dtype + let dtype = if &dtype == "float32" { + Ok(DType::F32) + } else if &dtype == "float16" { + Ok(DType::F16) + } else { + Err(BackendError::Start(format!( + "DType {dtype} is not supported" + ))) + }?; + + let safetensors_path = model_path.join("model.safetensors"); + let vb = if safetensors_path.exists() { + unsafe { + VarBuilder::from_mmaped_safetensors( + &[model_path.join("model.safetensors")], + dtype, + &device, + ) + } + } else { + VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device) + } + .s()?; - tracing::info!("Starting QuantBert model on CPU"); - Box::new(QuantBertModel::load(vb, &config, pool).s()?) 
+ let model: Box = match device { + Device::Cpu => { + if config.position_embedding_type == PositionEmbeddingType::Alibi { + tracing::info!("Starting JinaBert model on CPU"); + Box::new(JinaBertModel::load(vb, &config, model_type).s()?) } else { - return Err(BackendError::Start(format!( - "dtype {dtype} is not supported" - ))); + tracing::info!("Starting Bert model on CPU"); + Box::new(BertModel::load(vb, &config, model_type).s()?) } } Device::Cuda(_) => { @@ -101,31 +93,6 @@ impl CandleBackend { )); #[cfg(feature = "cuda")] { - // Get candle dtype - let dtype = if &dtype == "float32" { - Ok(DType::F32) - } else if &dtype == "float16" { - Ok(DType::F16) - } else { - Err(BackendError::Start(format!( - "DType {dtype} is not supported" - ))) - }?; - - let safetensors_path = model_path.join("model.safetensors"); - let vb = if safetensors_path.exists() { - unsafe { - VarBuilder::from_mmaped_safetensors( - &[model_path.join("model.safetensors")], - dtype, - &device, - ) - } - } else { - VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device) - } - .s()?; - if incompatible_compute_cap() { return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP))); } @@ -138,13 +105,13 @@ impl CandleBackend { && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" { tracing::info!("Starting FlashBert model on Cuda"); - Box::new(FlashBertModel::load(vb, &config, pool).s()?) + Box::new(FlashBertModel::load(vb, &config, model_type).s()?) } else if config.position_embedding_type == PositionEmbeddingType::Alibi { tracing::info!("Starting JinaBert model on Cuda"); - Box::new(JinaBertModel::load(vb, &config, pool).s()?) + Box::new(JinaBertModel::load(vb, &config, model_type).s()?) } else { tracing::info!("Starting Bert model on Cuda"); - Box::new(BertModel::load(vb, &config, pool).s()?) + Box::new(BertModel::load(vb, &config, model_type).s()?) 
} } } @@ -154,7 +121,7 @@ impl CandleBackend { } } -impl EmbeddingBackend for CandleBackend { +impl Backend for CandleBackend { fn health(&self) -> Result<(), BackendError> { Ok(()) } @@ -165,8 +132,10 @@ impl EmbeddingBackend for CandleBackend { Ok(results) } - fn max_batch_size(&self) -> Option { - None + fn predict(&self, batch: Batch) -> Result>, BackendError> { + let results = self.model.predict(batch).e()?; + let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?; + Ok(results) } } diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs index c2bb9c28..f47112a9 100644 --- a/backends/candle/src/models.rs +++ b/backends/candle/src/models.rs @@ -5,10 +5,8 @@ extern crate intel_mkl_src; extern crate accelerate_src; mod bert; -mod bert_quant; pub use bert::{BertModel, Config, PositionEmbeddingType}; -pub use bert_quant::QuantBertModel; use candle::{Result, Tensor}; pub use jina::JinaBertModel; use text_embeddings_backend_core::Batch; @@ -20,6 +18,12 @@ mod jina; #[cfg(feature = "cuda")] pub use flash_bert::FlashBertModel; -pub(crate) trait EmbeddingModel { - fn embed(&self, batch: Batch) -> Result; +pub(crate) trait Model { + fn embed(&self, _batch: Batch) -> Result { + candle::bail!("`embed` is not implemented for this model"); + } + + fn predict(&self, _batch: Batch) -> Result { + candle::bail!("`predict` is not implemented for this model"); + } } diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs index a2e318ce..30e5c9e2 100644 --- a/backends/candle/src/models/bert.rs +++ b/backends/candle/src/models/bert.rs @@ -1,10 +1,10 @@ use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT}; -use crate::models::EmbeddingModel; +use crate::models::Model; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; -use text_embeddings_backend_core::{Batch, Pool}; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -357,10 +357,54 @@ impl BertEncoder { } } +struct BertClassificationHead { + intermediate: Linear, + output: Linear, + span: tracing::Span, +} + +impl BertClassificationHead { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + let intermediate_weight = vb + .pp("dense") + .get((config.hidden_size, config.hidden_size), "weight")?; + let intermediate_bias = vb.pp("dense").get(config.hidden_size, "bias")?; + let intermediate = Linear::new(intermediate_weight, Some(intermediate_bias), None); + + let output_weight = vb + .pp("out_proj") + .get((n_classes, config.hidden_size), "weight")?; + let output_bias = vb.pp("out_proj").get(n_classes, "bias")?; + let output = Linear::new(output_weight, Some(output_bias), None); + + Ok(Self { + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "classifier"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let hidden_states = self.intermediate.forward(hidden_states)?; + let hidden_states = hidden_states.tanh()?; + let hidden_states = self.output.forward(&hidden_states)?; + + Ok(hidden_states) + } +} + pub struct BertModel { embeddings: 
BertEmbeddings, encoder: BertEncoder, pool: Pool, + classifier: Option, num_attention_heads: usize, @@ -371,12 +415,26 @@ pub struct BertModel { } impl BertModel { - pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result { + pub fn load(vb: VarBuilder, config: &Config, model_type: ModelType) -> Result { // Check position embedding type if config.position_embedding_type != PositionEmbeddingType::Absolute { candle::bail!("Bert only supports absolute position embeddings") } + let (pool, classifier) = match model_type { + // Classifier models always use CLS pooling + ModelType::Classifier => { + if config.model_type == Some("bert".to_string()) { + candle::bail!("`classifier` model type is not supported for Bert"); + } + ( + Pool::Cls, + Some(BertClassificationHead::load(vb.pp("classifier"), config)?), + ) + } + ModelType::Embedding(pool) => (pool, None), + }; + // Check pool type if pool != Pool::Mean && pool != Pool::Cls { candle::bail!("Pool type {pool:?} is not supported"); @@ -396,8 +454,8 @@ impl BertModel { ) { (embeddings, encoder) } else if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp("bert.embeddings"), config), - BertEncoder::load(vb.pp("bert.encoder"), config), + BertEmbeddings::load(vb.pp("roberta.embeddings"), config), + BertEncoder::load(vb.pp("roberta.encoder"), config), ) { (embeddings, encoder) } else { @@ -410,6 +468,7 @@ impl BertModel { embeddings, encoder, pool, + classifier, num_attention_heads: config.num_attention_heads, device: vb.device().clone(), dtype: vb.dtype(), @@ -558,8 +617,18 @@ impl BertModel { } } -impl EmbeddingModel for BertModel { +impl Model for BertModel { fn embed(&self, batch: Batch) -> Result { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let hidden_states = self.forward(batch)?; + classifier.forward(&hidden_states) + } + } + } } diff --git a/backends/candle/src/models/bert_quant.rs b/backends/candle/src/models/bert_quant.rs deleted file mode 100644 index ef9b2ef6..00000000 --- a/backends/candle/src/models/bert_quant.rs +++ /dev/null @@ -1,459 +0,0 @@ -use crate::layers::HiddenAct; -use crate::models::bert::{Config, PositionEmbeddingType}; -use crate::models::EmbeddingModel; -use candle::quantized::QMatMul; -use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::ops::softmax; -use candle_nn::Embedding; -use candle_transformers::quantized_var_builder::VarBuilder; -use text_embeddings_backend_core::{Batch, Pool}; - -#[derive(Debug)] -struct LayerNorm { - weight: Tensor, - bias: Tensor, - epsilon: f64, - span: tracing::Span, -} - -impl LayerNorm { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - Ok(Self { - weight: vb - .get(config.hidden_size, "weight") - .or_else(|_| vb.get(config.hidden_size, "gamma"))? - .dequantize(vb.device())?, - bias: vb - .get(config.hidden_size, "bias") - .or_else(|_| vb.get(config.hidden_size, "beta"))? - .dequantize(vb.device())?, - epsilon: config.layer_norm_eps, - span: tracing::span!(tracing::Level::TRACE, "layer-norm"), - }) - } - - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - - let x_dtype = x.dtype(); - let internal_dtype = match x_dtype { - DType::F16 | DType::BF16 => DType::F32, - d => d, - }; - let hidden_size = x.dim(D::Minus1)?; - let x = x.to_dtype(internal_dtype)?; - let mean_x = (x.sum_keepdim(D::Minus1)? 
/ hidden_size as f64)?; - let x = x.broadcast_sub(&mean_x)?; - let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; - let x_normed = x.broadcast_div(&(norm_x + self.epsilon)?.sqrt()?)?; - let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; - x.broadcast_add(&self.bias) - } -} - -#[derive(Debug)] -pub struct Linear { - weight: QMatMul, - bias: Option, - act: Option, - span: tracing::Span, -} - -impl Linear { - pub fn new(weight: QMatMul, bias: Option, act: Option) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - - Self { - weight, - bias, - act, - span, - } - } - - pub fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - - let x = x.apply(&self.weight)?; - let x = match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - }?; - if let Some(act) = &self.act { - match act { - HiddenAct::Gelu => x.gelu(), - HiddenAct::Relu => x.relu(), - } - } else { - Ok(x) - } - } -} - -#[derive(Debug)] -struct BertEmbeddings { - word_embeddings: Embedding, - token_type_embeddings: Embedding, - position_embeddings: Embedding, - layer_norm: LayerNorm, - span: tracing::Span, -} - -impl BertEmbeddings { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - if config.position_embedding_type != PositionEmbeddingType::Absolute { - candle::bail!("FlashBert only supports absolute position embeddings"); - } - - Ok(Self { - word_embeddings: Embedding::new( - vb.pp("word_embeddings") - .get((config.vocab_size, config.hidden_size), "weight")? - .dequantize(vb.device())?, - config.hidden_size, - ), - token_type_embeddings: Embedding::new( - vb.pp("token_type_embeddings") - .get((config.type_vocab_size, config.hidden_size), "weight")? - .dequantize(vb.device())?, - config.hidden_size, - ), - position_embeddings: Embedding::new( - vb.pp("position_embeddings") - .get( - (config.max_position_embeddings, config.hidden_size), - "weight", - )? - .dequantize(vb.device())?, - config.hidden_size, - ), - layer_norm: LayerNorm::load(vb.pp("LayerNorm"), config)?, - span: tracing::span!(tracing::Level::TRACE, "embeddings"), - }) - } - - fn forward( - &self, - input_ids: &Tensor, - token_type_ids: &Tensor, - position_ids: &Tensor, - ) -> Result { - let _enter = self.span.enter(); - - let input_embeddings = self.word_embeddings.forward(input_ids)?; - let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; - let position_embeddings = self.position_embeddings.forward(position_ids)?; - - let embeddings = input_embeddings - .add(&token_type_embeddings)? - .add(&position_embeddings)?; - let embeddings = self.layer_norm.forward(&embeddings)?; - - Ok(embeddings) - } -} - -struct BertAttention { - q_linear: Linear, - k_linear: Linear, - v_linear: Linear, - - dense: Linear, - layer_norm: LayerNorm, - - num_attention_heads: usize, - attention_head_size: usize, - softmax_scale: f64, - - span: tracing::Span, -} - -impl BertAttention { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - let attention_head_size = config.hidden_size / config.num_attention_heads; - let all_head_size = config.num_attention_heads * attention_head_size; - let hidden_size = config.hidden_size; - - let query_weight = vb - .pp("self.query") - .get((all_head_size, hidden_size), "weight")?; - let query_bias = vb - .pp("self.query") - .get(all_head_size, "bias")? 
- .dequantize(vb.device())?; - let q_linear = Linear::new(QMatMul::from_arc(query_weight)?, Some(query_bias), None); - - let key_weight = vb - .pp("self.key") - .get((all_head_size, hidden_size), "weight")?; - let key_bias = vb - .pp("self.key") - .get(all_head_size, "bias")? - .dequantize(vb.device())?; - let k_linear = Linear::new(QMatMul::from_arc(key_weight)?, Some(key_bias), None); - - let value_weight = vb - .pp("self.value") - .get((all_head_size, hidden_size), "weight")?; - let value_bias = vb - .pp("self.value") - .get(all_head_size, "bias")? - .dequantize(vb.device())?; - let v_linear = Linear::new(QMatMul::from_arc(value_weight)?, Some(value_bias), None); - - let dense_weight = vb - .pp("output") - .pp("dense") - .get((hidden_size, hidden_size), "weight")?; - let dense_bias = vb - .pp("output") - .pp("dense") - .get(hidden_size, "bias")? - .dequantize(vb.device())?; - - let dense = Linear::new(QMatMul::from_arc(dense_weight)?, Some(dense_bias), None); - - let layer_norm = LayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?; - - let softmax_scale = 1. / (attention_head_size as f64).sqrt(); - - Ok(Self { - q_linear, - k_linear, - v_linear, - dense, - layer_norm, - num_attention_heads: config.num_attention_heads, - attention_head_size, - softmax_scale, - span: tracing::span!(tracing::Level::TRACE, "attention"), - }) - } - - fn transpose_for_scores(&self, xs: &Tensor) -> Result { - let mut new_x_shape = xs.dims().to_vec(); - new_x_shape.pop(); - new_x_shape.push(self.num_attention_heads); - new_x_shape.push(self.attention_head_size); - let xs = xs - .reshape(new_x_shape.as_slice())? - .unsqueeze(0)? - .transpose(1, 2)?; - xs.contiguous() - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let residual = hidden_states.clone(); - - let query_layer = self.q_linear.forward(hidden_states)?; - let key_layer = self.k_linear.forward(hidden_states)?; - let value_layer = self.v_linear.forward(hidden_states)?; - - let query_layer = self.transpose_for_scores(&query_layer)?; - let key_layer = self.transpose_for_scores(&key_layer)?; - let value_layer = self.transpose_for_scores(&value_layer)?; - - let attention_scores = query_layer.matmul(&key_layer.t()?)?; - let attention_scores = (attention_scores * self.softmax_scale)?; - let attention_probs = softmax(&attention_scores, D::Minus1)?; - - let context_layer = attention_probs.matmul(&value_layer)?; - let context_layer = context_layer.transpose(1, 2)?.contiguous()?; - let context_layer = context_layer.flatten_from(D::Minus2)?.squeeze(0)?; - - let hidden_states = self.dense.forward(&context_layer)?.add(&residual)?; - let hidden_states = self.layer_norm.forward(&hidden_states)?; - - Ok(hidden_states) - } -} - -struct BertLayer { - attention: BertAttention, - intermediate: Linear, - output: Linear, - layer_norm: LayerNorm, - span: tracing::Span, -} - -impl BertLayer { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - let attention = BertAttention::load(vb.pp("attention"), config)?; - - let intermediate_weight = vb - .pp("intermediate") - .pp("dense") - .get((config.intermediate_size, config.hidden_size), "weight")?; - let intermediate_bias = vb - .pp("intermediate") - .pp("dense") - .get(config.intermediate_size, "bias")? 
- .dequantize(vb.device())?; - let intermediate = Linear::new( - QMatMul::from_arc(intermediate_weight)?, - Some(intermediate_bias), - Some(config.hidden_act.clone()), - ); - - let output_weight = vb - .pp("output") - .pp("dense") - .get((config.hidden_size, config.intermediate_size), "weight")?; - let output_bias = vb - .pp("output") - .pp("dense") - .get(config.hidden_size, "bias")? - .dequantize(vb.device())?; - let output = Linear::new(QMatMul::from_arc(output_weight)?, Some(output_bias), None); - - let layer_norm = LayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?; - - Ok(Self { - attention, - intermediate, - output, - layer_norm, - span: tracing::span!(tracing::Level::TRACE, "layer"), - }) - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let hidden_states = self.attention.forward(hidden_states)?; - let residual = hidden_states.clone(); - - let hidden_states = self.intermediate.forward(&hidden_states)?; - let hidden_states = self.output.forward(&hidden_states)?.add(&residual)?; - let hidden_states = self.layer_norm.forward(&hidden_states)?; - - Ok(hidden_states) - } -} - -struct BertEncoder { - layers: Vec, - span: tracing::Span, -} - -impl BertEncoder { - pub fn load(vb: VarBuilder, config: &Config) -> Result { - let layers = (0..config.num_hidden_layers) - .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config)) - .collect::>>()?; - let span = tracing::span!(tracing::Level::TRACE, "encoder"); - - Ok(BertEncoder { layers, span }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let mut hidden_states = hidden_states.clone(); - - // Use a loop rather than a fold as it's easier to modify when adding debug/... - for layer in self.layers.iter() { - hidden_states = layer.forward(&hidden_states)? 
- } - - Ok(hidden_states) - } -} - -pub struct QuantBertModel { - embeddings: BertEmbeddings, - encoder: BertEncoder, - pool: Pool, - pub device: Device, - - span: tracing::Span, -} - -impl QuantBertModel { - pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result { - match vb.device() { - Device::Cpu => {} - _ => candle::bail!("Bert requires CPU"), - } - - // Check position embedding type - if config.position_embedding_type != PositionEmbeddingType::Absolute { - candle::bail!("FlashBert only supports absolute position embeddings") - } - - // Check pool type - if pool != Pool::Mean && pool != Pool::Cls { - candle::bail!("Pool type {pool:?} is not supported"); - } - - let (embeddings, encoder) = match ( - BertEmbeddings::load(vb.pp("embeddings"), config), - BertEncoder::load(vb.pp("encoder"), config), - ) { - (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), - (Err(err), _) | (_, Err(err)) => { - let model_type = config.model_type.clone().unwrap_or("bert".to_string()); - - if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config), - BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config), - ) { - (embeddings, encoder) - } else if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp("bert.embeddings"), config), - BertEncoder::load(vb.pp("bert.encoder"), config), - ) { - (embeddings, encoder) - } else { - return Err(err); - } - } - }; - - Ok(Self { - embeddings, - encoder, - pool, - device: vb.device().clone(), - span: tracing::span!(tracing::Level::TRACE, "model"), - }) - } - - pub fn forward(&self, batch: Batch) -> Result { - let _enter = self.span.enter(); - - let batch_size = batch.cumulative_seq_lengths.len() - 1; - if batch_size > 1 { - candle::bail!("Bert CPU only support batch_size == 1"); - } - let shape = batch.input_ids.len(); - - // Create Cuda tensors - let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; - let type_ids = Tensor::from_vec(batch.token_type_ids, shape, &self.device)?; - let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; - - let embedding_output = self - .embeddings - .forward(&input_ids, &type_ids, &position_ids)?; - - let outputs = self.encoder.forward(&embedding_output)?; - - let results = match self.pool { - // CLS pooling - Pool::Cls => outputs.i(0..1)?, - // Mean pooling - Pool::Mean => (outputs.sum_keepdim(0)? 
/ (batch.max_length as f64))?, - }; - - Ok(results) - } -} - -impl EmbeddingModel for QuantBertModel { - fn embed(&self, batch: Batch) -> Result { - self.forward(batch) - } -} diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs index c945965d..34622d61 100644 --- a/backends/candle/src/models/flash_bert.rs +++ b/backends/candle/src/models/flash_bert.rs @@ -1,10 +1,10 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{LayerNorm, Linear}; use crate::models::bert::{Config, PositionEmbeddingType}; -use crate::models::EmbeddingModel; +use crate::models::Model; use candle::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; -use text_embeddings_backend_core::{Batch, Pool}; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug)] struct BertEmbeddings { @@ -270,17 +270,61 @@ impl BertEncoder { } } +struct BertClassificationHead { + intermediate: Linear, + output: Linear, + span: tracing::Span, +} + +impl BertClassificationHead { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let n_classes = match &config.id2label { + None => candle::bail!("`id2label` must be set for classifier models"), + Some(id2label) => id2label.len(), + }; + + let intermediate_weight = vb + .pp("dense") + .get((config.hidden_size, config.hidden_size), "weight")?; + let intermediate_bias = vb.pp("dense").get(config.hidden_size, "bias")?; + let intermediate = Linear::new(intermediate_weight, Some(intermediate_bias), None); + + let output_weight = vb + .pp("out_proj") + .get((n_classes, config.hidden_size), "weight")?; + let output_bias = vb.pp("out_proj").get(n_classes, "bias")?; + let output = Linear::new(output_weight, Some(output_bias), None); + + Ok(Self { + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "classifier"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let hidden_states = self.intermediate.forward(hidden_states)?; + let hidden_states = hidden_states.tanh()?; + let hidden_states = self.output.forward(&hidden_states)?; + + Ok(hidden_states) + } +} + pub struct FlashBertModel { embeddings: BertEmbeddings, encoder: BertEncoder, pool: Pool, + classifier: Option, pub device: Device, span: tracing::Span, } impl FlashBertModel { - pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result { + pub fn load(vb: VarBuilder, config: &Config, model_type: ModelType) -> Result { match vb.device() { Device::Cuda(_) => {} _ => candle::bail!("FlashBert requires Cuda"), @@ -295,6 +339,20 @@ impl FlashBertModel { candle::bail!("FlashBert only supports absolute position embeddings") } + let (pool, classifier) = match model_type { + // Classifier models always use CLS pooling + ModelType::Classifier => { + if config.model_type == Some("bert".to_string()) { + candle::bail!("`classifier` model type is not supported for Bert"); + } + ( + Pool::Cls, + Some(BertClassificationHead::load(vb.pp("classifier"), config)?), + ) + } + ModelType::Embedding(pool) => (pool, None), + }; + // Check pool type if pool != Pool::Mean && pool != Pool::Cls { candle::bail!("Pool type {pool:?} is not supported"); @@ -314,8 +372,8 @@ impl FlashBertModel { ) { (embeddings, encoder) } else if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp("bert.embeddings"), config), - BertEncoder::load(vb.pp("bert.encoder"), config), + BertEmbeddings::load(vb.pp("roberta.embeddings"), config), + BertEncoder::load(vb.pp("roberta.encoder"), config), 
) { (embeddings, encoder) } else { @@ -328,6 +386,7 @@ impl FlashBertModel { embeddings, encoder, pool, + classifier, device: vb.device().clone(), span: tracing::span!(tracing::Level::TRACE, "model"), }) @@ -387,8 +446,18 @@ impl FlashBertModel { } } -impl EmbeddingModel for FlashBertModel { +impl Model for FlashBertModel { fn embed(&self, batch: Batch) -> Result { self.forward(batch) } + + fn predict(&self, batch: Batch) -> Result { + match &self.classifier { + None => candle::bail!("`predict` is not implemented for this model"), + Some(classifier) => { + let hidden_states = self.forward(batch)?; + classifier.forward(&hidden_states) + } + } + } } diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index e89205e2..419c7131 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -1,10 +1,10 @@ use crate::alibi::build_alibi_tensor; use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT}; -use crate::models::EmbeddingModel; +use crate::models::Model; use crate::models::{Config, PositionEmbeddingType}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; -use text_embeddings_backend_core::{Batch, Pool}; +use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug)] struct BertEmbeddings { @@ -350,17 +350,25 @@ pub struct JinaBertModel { } impl JinaBertModel { - pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result { + pub fn load(vb: VarBuilder, config: &Config, model_type: ModelType) -> Result { let alibi = match config.position_embedding_type { PositionEmbeddingType::Alibi => Some(build_alibi_tensor( config.max_position_embeddings, config.num_attention_heads, - &vb.device(), + vb.device(), vb.dtype(), )?), PositionEmbeddingType::Absolute => None, }; + let pool = match model_type { + // Classifier models always use CLS pooling + ModelType::Classifier => { + candle::bail!("`classifier` model type is not supported for Jina") + } + ModelType::Embedding(pool) => pool, + }; + // Check pool type if pool != Pool::Mean && pool != Pool::Cls { candle::bail!("Pool type {pool:?} is not supported"); @@ -586,7 +594,7 @@ impl JinaBertModel { } } -impl EmbeddingModel for JinaBertModel { +impl Model for JinaBertModel { fn embed(&self, batch: Batch) -> Result { self.forward(batch) } diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index a934e24a..d9afe68f 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -14,18 +14,25 @@ pub struct Batch { pub type Embedding = Vec; -pub trait EmbeddingBackend { +pub trait Backend { fn health(&self) -> Result<(), BackendError>; - - fn embed(&self, batch: Batch) -> Result, BackendError>; - fn max_batch_size(&self) -> Option { None } + + fn embed(&self, batch: Batch) -> Result, BackendError>; + + fn predict(&self, batch: Batch) -> Result>, BackendError>; +} + +#[derive(Debug, PartialEq, Clone)] +pub enum ModelType { + Classifier, + Embedding(Pool), } -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "clap", derive(Clone, ValueEnum))] +#[derive(Debug, PartialEq, Clone)] +#[cfg_attr(feature = "clap", derive(ValueEnum))] pub enum Pool { Cls, Mean, diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 69b633a6..ec3395dd 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -2,7 +2,7 @@ mod logging; mod management; use backend_grpc_client::Client; -use text_embeddings_backend_core::{BackendError, Batch, Embedding, EmbeddingBackend, Pool}; +use 
text_embeddings_backend_core::{Backend, BackendError, Batch, Embedding, ModelType, Pool}; use tokio::runtime::Runtime; pub struct PythonBackend { @@ -15,10 +15,19 @@ impl PythonBackend { pub fn new( model_path: String, dtype: String, - pool: Pool, + model_type: ModelType, uds_path: String, otlp_endpoint: Option, ) -> Result { + let pool = match model_type { + ModelType::Classifier => { + return Err(BackendError::Start( + "`classifier` model type is not supported".to_string(), + )) + } + ModelType::Embedding(pool) => pool, + }; + if pool != Pool::Cls { return Err(BackendError::Start(format!("{pool:?} is not supported"))); } @@ -44,7 +53,7 @@ impl PythonBackend { } } -impl EmbeddingBackend for PythonBackend { +impl Backend for PythonBackend { fn health(&self) -> Result<(), BackendError> { if self .tokio_runtime @@ -69,4 +78,10 @@ impl EmbeddingBackend for PythonBackend { .map_err(|err| BackendError::Inference(err.to_string()))?; Ok(results.into_iter().map(|r| r.values).collect()) } + + fn predict(&self, _batch: Batch) -> Result>, BackendError> { + Err(BackendError::Inference( + "`predict` is not implemented".to_string(), + )) + } } diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 49669538..566f9800 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -3,12 +3,12 @@ mod dtype; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use text_embeddings_backend_core::EmbeddingBackend; +use text_embeddings_backend_core::Backend as CoreBackend; use tokio::sync::oneshot; use tracing::{instrument, Span}; pub use crate::dtype::DType; -pub use text_embeddings_backend_core::{BackendError, Batch, Embedding, Pool}; +pub use text_embeddings_backend_core::{BackendError, Batch, Embedding, ModelType, Pool}; #[cfg(feature = "candle")] use text_embeddings_backend_candle::CandleBackend; @@ -23,19 +23,26 @@ pub struct Backend { /// Health status health: Arc, pub max_batch_size: Option, + pub model_type: ModelType, } impl Backend { pub fn new( model_path: PathBuf, dtype: DType, - pool: Pool, + model_type: ModelType, uds_path: String, otlp_endpoint: Option, ) -> Result { let (backend_sender, backend_receiver) = flume::unbounded(); - let backend = init_backend(model_path, dtype, pool, uds_path, otlp_endpoint)?; + let backend = init_backend( + model_path, + dtype, + model_type.clone(), + uds_path, + otlp_endpoint, + )?; let max_batch_size = backend.max_batch_size(); tokio::task::spawn_blocking(move || backend_blocking_task(backend, backend_receiver)); @@ -44,6 +51,7 @@ impl Backend { backend_sender, health: Arc::new(AtomicBool::new(false)), max_batch_size, + model_type, }) } @@ -62,7 +70,7 @@ impl Backend { ) } else { // The backend is un-healthy or only just started. Do a more advanced health check - // by sending an embedding request. + // by calling the model forward on a test batch let batch = Batch { input_ids: vec![0], @@ -71,13 +79,10 @@ impl Backend { cumulative_seq_lengths: vec![0, 1], max_length: 1, }; - let (sender, receiver) = oneshot::channel(); - self.backend_sender - .send(BackendCommand::Embed(batch, Span::current(), sender)) - .expect("No backend receiver. This is a bug."); - receiver.await.expect( - "Backend blocking task dropped the sender without sending a response. 
This is a bug.", - ).map(|_| ()) + match &self.model_type { + ModelType::Classifier => self.predict(batch).await.map(|_| ()), + ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), + } }; // Update health @@ -100,22 +105,38 @@ impl Backend { self.health.store(result.is_ok(), Ordering::SeqCst); result } + + #[instrument(skip_all)] + pub async fn predict(&self, batch: Batch) -> Result>, BackendError> { + let (sender, receiver) = oneshot::channel(); + + self.backend_sender + .send(BackendCommand::Predict(batch, Span::current(), sender)) + .expect("No backend receiver. This is a bug."); + let result = receiver.await.expect( + "Backend blocking task dropped the sender without send a response. This is a bug.", + ); + + // Update health + self.health.store(result.is_ok(), Ordering::SeqCst); + result + } } #[allow(unused)] fn init_backend( model_path: PathBuf, dtype: DType, - pool: Pool, + model_type: ModelType, uds_path: String, otlp_endpoint: Option, -) -> Result, BackendError> { +) -> Result, BackendError> { if cfg!(feature = "candle") { #[cfg(feature = "candle")] return Ok(Box::new(CandleBackend::new( model_path, dtype.to_string(), - pool, + model_type, )?)); } else if cfg!(feature = "python") { #[cfg(feature = "python")] @@ -127,7 +148,7 @@ fn init_backend( PythonBackend::new( model_path.to_str().unwrap().to_string(), dtype.to_string(), - pool, + model_type, uds_path, otlp_endpoint, ) @@ -141,7 +162,7 @@ fn init_backend( } fn backend_blocking_task( - backend: Box, + backend: Box, command_receiver: flume::Receiver, ) { while let Ok(cmd) = command_receiver.recv() { @@ -154,6 +175,10 @@ fn backend_blocking_task( let _span = span.entered(); let _ = sender.send(backend.embed(batch)); } + BackendCommand::Predict(batch, span, sender) => { + let _span = span.entered(); + let _ = sender.send(backend.predict(batch)); + } } } } @@ -165,4 +190,9 @@ enum BackendCommand { Span, oneshot::Sender, BackendError>>, ), + Predict( + Batch, + Span, + oneshot::Sender>, BackendError>>, + ), } diff --git a/core/Cargo.toml b/core/Cargo.toml index f5e66d0f..c6e62184 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -8,7 +8,6 @@ homepage.workspace = true [dependencies] hf-hub = { version = "^0.3.0", features = ["tokio"] } metrics = "^0.21" -rayon = "^1.8" text-embeddings-backend = { path = "../backends" } thiserror = "^1.0" tokenizers = { version = "^0.14.1", default-features=false, features=["onig", "esaxx_fast"] } diff --git a/core/src/download.rs b/core/src/download.rs index 31894d50..9a0e234a 100644 --- a/core/src/download.rs +++ b/core/src/download.rs @@ -8,6 +8,9 @@ pub async fn download_artifacts(api: &ApiRepo) -> Result { tracing::info!("Starting download"); + api.get("config.json").await?; + api.get("tokenizer.json").await?; + let model_root = match api.get("model.safetensors").await { Ok(p) => p, Err(_) => { @@ -19,8 +22,6 @@ pub async fn download_artifacts(api: &ApiRepo) -> Result { .parent() .unwrap() .to_path_buf(); - api.get("config.json").await?; - api.get("tokenizer.json").await?; tracing::info!("Model artifacts downloaded in {:?}", start.elapsed()); Ok(model_root) diff --git a/core/src/infer.rs b/core/src/infer.rs index 649d3293..0d34f368 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -1,10 +1,9 @@ use crate::queue::{Entry, Metadata, NextBatch, Queue}; -use crate::tokenization::Tokenization; +use crate::tokenization::{EncodingInput, Tokenization}; use crate::TextEmbeddingsError; -use rayon::prelude::*; use std::sync::Arc; use std::time::{Duration, Instant}; -use 
text_embeddings_backend::{Backend, Embedding}; +use text_embeddings_backend::{Backend, BackendError, ModelType}; use tokio::sync::{mpsc, oneshot, Notify, OwnedSemaphorePermit, Semaphore}; use tracing::{instrument, Span}; @@ -45,7 +44,7 @@ impl Infer { )); // Create embed task to communicate with backend - tokio::spawn(embed_task(backend.clone(), embed_receiver)); + tokio::spawn(backend_task(backend.clone(), embed_receiver)); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); @@ -81,21 +80,30 @@ impl Infer { .expect("Semaphore has been closed. This is a bug.") } - #[instrument(skip(self))] - pub async fn embed( + #[instrument(skip(self, _permit))] + pub async fn embed + std::fmt::Debug>( &self, - inputs: String, + inputs: I, truncate: bool, normalize: bool, - permit: OwnedSemaphorePermit, + _permit: OwnedSemaphorePermit, ) -> Result { + if self.is_classifier() { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not an embedding model".to_string(); + tracing::error!("{message}"); + return Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))); + } + let start_time = Instant::now(); metrics::increment_counter!("te_embed_count"); // Tokenization let encoding = self .tokenization - .encode(inputs, truncate) + .encode(inputs.into(), truncate) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); @@ -114,14 +122,13 @@ impl Infer { tokenization: start_time.elapsed(), queue_time: Instant::now(), prompt_tokens: encoding.input_ids.len(), - normalize, }, encoding, }); self.notify_batching_task.notify_one(); - let response = response_rx + let mut response = response_rx .await .expect( "Infer batching task dropped the sender without sending a response. 
This is a bug.", @@ -132,6 +139,23 @@ impl Infer { err })?; + if normalize { + // Normalize embedding + let scale = (1.0 + / response + .results + .iter() + .map(|v| { + let v = *v as f64; + v * v + }) + .sum::() + .sqrt()) as f32; + for v in response.results.iter_mut() { + *v *= scale; + } + } + // Timings let total_time = start_time.elapsed(); @@ -151,6 +175,88 @@ impl Infer { Ok(response) } + #[instrument(skip(self, _permit))] + pub async fn predict + std::fmt::Debug>( + &self, + inputs: I, + truncate: bool, + _permit: OwnedSemaphorePermit, + ) -> Result { + if !self.is_classifier() { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a classifier model".to_string(); + // tracing::error!("{message}"); + return Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))); + } + + let start_time = Instant::now(); + metrics::increment_counter!("te_predict_count"); + + // Tokenization + let encoding = self + .tokenization + .encode(inputs.into(), truncate) + .await + .map_err(|err| { + metrics::increment_counter!("te_request_failure", "err" => "tokenization"); + tracing::error!("{err}"); + err + })?; + + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = oneshot::channel(); + + // Append the request to the queue + self.queue.append(Entry { + metadata: Metadata { + response_tx, + span: Span::current(), + tokenization: start_time.elapsed(), + queue_time: Instant::now(), + prompt_tokens: encoding.input_ids.len(), + }, + encoding, + }); + + self.notify_batching_task.notify_one(); + + let response = response_rx + .await + .expect( + "Infer batching task dropped the sender without sending a response. This is a bug.", + ) + .map_err(|err| { + metrics::increment_counter!("te_request_failure", "err" => "inference"); + tracing::error!("{err}"); + err + })?; + + // Timings + let total_time = start_time.elapsed(); + + // Metrics + metrics::increment_counter!("te_predict_success"); + metrics::histogram!("te_predict_duration", total_time.as_secs_f64()); + metrics::histogram!( + "te_predict_tokenization_duration", + response.tokenization.as_secs_f64() + ); + metrics::histogram!("te_predict_queue_duration", response.queue.as_secs_f64()); + metrics::histogram!( + "te_predict_inference_duration", + response.inference.as_secs_f64() + ); + + Ok(response) + } + + #[instrument(skip(self))] + pub fn is_classifier(&self) -> bool { + matches!(self.backend.model_type, ModelType::Classifier) + } + #[instrument(skip(self))] pub async fn health(&self) -> bool { self.backend.health().await.is_ok() @@ -177,36 +283,23 @@ async fn batching_task( } #[instrument(skip_all)] -async fn embed_task( +async fn backend_task( backend: Backend, mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>, ) { while let Some((batch, _callback)) = embed_receiver.recv().await { let inference_start = Instant::now(); - let results = backend.embed(batch.1).await; + let results = match &backend.model_type { + ModelType::Classifier => backend.predict(batch.1).await, + ModelType::Embedding(_) => backend.embed(batch.1).await, + }; // Handle sending responses in another thread to avoid starving the backend tokio::task::spawn_blocking(move || match results { Ok(embeddings) => { - batch.0.into_par_iter().zip(embeddings).for_each(|(m, e)| { - let e = match m.normalize { - // Normalize embedding - true => { - let scale = (1.0 - / e.iter() - .map(|v| { - let v = *v as f64; - v * v - }) - .sum::() - .sqrt()) as f32; - 
e.into_iter().map(|v| v * scale).collect() - } - false => e, - }; - + batch.0.into_iter().zip(embeddings).for_each(|(m, e)| { let _ = m.response_tx.send(Ok(InferResponse { - embeddings: e, + results: e, prompt_tokens: m.prompt_tokens, tokenization: m.tokenization, queue: inference_start - m.queue_time, @@ -225,7 +318,7 @@ async fn embed_task( #[derive(Debug)] pub struct InferResponse { - pub embeddings: Embedding, + pub results: Vec, pub prompt_tokens: usize, pub tokenization: Duration, pub queue: Duration, diff --git a/core/src/queue.rs b/core/src/queue.rs index 38ad513d..5262fe29 100644 --- a/core/src/queue.rs +++ b/core/src/queue.rs @@ -29,8 +29,6 @@ pub struct Metadata { pub queue_time: Instant, /// Number of tokens in the prompt pub prompt_tokens: usize, - /// Normalize the embeddings - pub normalize: bool, } /// Request Queue diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index c70aa55e..a320bd9a 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -1,7 +1,7 @@ /// Payload tokenization logic use crate::TextEmbeddingsError; use tokenizers::tokenizer::Tokenizer; -use tokenizers::TruncationDirection; +use tokenizers::{EncodeInput, TruncationDirection, TruncationParams, TruncationStrategy}; use tokio::sync::{mpsc, oneshot}; use tracing::{instrument, Span}; @@ -61,7 +61,7 @@ impl Tokenization { #[instrument(skip_all)] pub async fn encode( &self, - inputs: String, + inputs: EncodingInput, truncate: bool, ) -> Result { // Check if inputs is empty @@ -81,13 +81,13 @@ impl Tokenization { // Await on response channel // Unwrap is safe here - Ok(response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")?) + response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.") } } /// Start tokenization workers fn tokenizer_worker( - tokenizer: Tokenizer, + mut tokenizer: Tokenizer, max_input_length: usize, position_offset: usize, mut receiver: mpsc::UnboundedReceiver, @@ -103,7 +103,7 @@ fn tokenizer_worker( truncate, max_input_length, position_offset, - &tokenizer, + &mut tokenizer, )); } }) @@ -112,26 +112,34 @@ fn tokenizer_worker( /// Get input length and optionally truncate it fn encode_input( - inputs: String, + inputs: EncodingInput, truncate: bool, max_input_length: usize, position_offset: usize, - tokenizer: &Tokenizer, + tokenizer: &mut Tokenizer, ) -> Result { - // Get the number of tokens in the input - let mut encoding = tokenizer.encode(inputs.clone(), true)?; - - let mut seq_len = encoding.len(); + // Default truncation params + let truncate_params = truncate.then_some(TruncationParams { + direction: TruncationDirection::Right, + max_length: max_input_length, + strategy: TruncationStrategy::LongestFirst, + stride: 0, + }); + + let inputs: EncodeInput = match inputs { + EncodingInput::Single(s) => s.into(), + EncodingInput::Dual(s1, s2) => (s1, s2).into(), + }; + + let encoding = tokenizer + .with_truncation(truncate_params)? + .encode(inputs, true)?; + let seq_len = encoding.len(); if seq_len > max_input_length { - if truncate { - encoding.truncate(max_input_length, 0, TruncationDirection::Right); - seq_len = max_input_length; - } else { - return Err(TextEmbeddingsError::Validation(format!( - "`inputs` must have less than {max_input_length} tokens. Given: {seq_len}" - ))); - } + return Err(TextEmbeddingsError::Validation(format!( + "`inputs` must have less than {max_input_length} tokens. 
Given: {seq_len}" + ))); } metrics::histogram!("te_request_input_length", seq_len as f64); @@ -151,8 +159,29 @@ pub struct Encoding { pub position_ids: Vec, } +#[derive(Debug)] +pub enum EncodingInput { + Single(String), + Dual(String, String), +} + +impl EncodingInput { + fn is_empty(&self) -> bool { + match self { + EncodingInput::Single(s) => s.is_empty(), + EncodingInput::Dual(s1, s2) => s1.is_empty() && s2.is_empty(), + } + } +} + +impl From for EncodingInput { + fn from(value: String) -> Self { + Self::Single(value) + } +} + type TokenizerRequest = ( - String, + EncodingInput, bool, oneshot::Sender>, Span, diff --git a/docs/openapi.json b/docs/openapi.json index 0daa94da..51b65348 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -17,8 +17,8 @@ "tags": [ "Text Embeddings Inference" ], - "summary": "Get Embeddings", - "description": "Get Embeddings", + "summary": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.", + "description": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.", "operationId": "embed", "requestBody": { "content": { @@ -100,6 +100,94 @@ } } }, + "/embeddings": { + "post": { + "tags": [ + "Text Embeddings Inference" + ], + "summary": "OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.", + "description": "OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.", + "operationId": "openai_embed", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Embeddings", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatResponse" + } + } + } + }, + "413": { + "description": "Batch size error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatErrorResponse" + }, + "example": { + "message": "Batch size error", + "type": "validation" + } + } + } + }, + "422": { + "description": "Tokenization error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatErrorResponse" + }, + "example": { + "message": "Tokenization error", + "type": "tokenizer" + } + } + } + }, + "424": { + "description": "Embedding Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatErrorResponse" + }, + "example": { + "message": "Inference failed", + "type": "backend" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAICompatErrorResponse" + }, + "example": { + "message": "Model is overloaded", + "type": "overloaded" + } + } + } + } + } + } + }, "/health": { "get": { "tags": [ @@ -173,19 +261,19 @@ } } }, - "/openai": { + "/predict": { "post": { "tags": [ "Text Embeddings Inference" ], - "summary": "OpenAI compatible route", - "description": "OpenAI compatible route", - "operationId": "openai_embed", + "summary": "Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model", + "description": "Get Predictions. 
Returns a 424 status code if the model is not a Sequence Classification model", + "operationId": "predict", "requestBody": { "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAICompatRequest" + "$ref": "#/components/schemas/PredictRequest" } } }, @@ -193,11 +281,11 @@ }, "responses": { "200": { - "description": "Embeddings", + "description": "Predictions", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAICompatResponse" + "$ref": "#/components/schemas/PredictResponse" } } } @@ -207,11 +295,11 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" + "$ref": "#/components/schemas/ErrorResponse" }, "example": { - "message": "Batch size error", - "type": "validation" + "error": "Batch size error", + "error_type": "validation" } } } @@ -221,25 +309,25 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" + "$ref": "#/components/schemas/ErrorResponse" }, "example": { - "message": "Tokenization error", - "type": "tokenizer" + "error": "Tokenization error", + "error_type": "tokenizer" } } } }, "424": { - "description": "Embedding Error", + "description": "Prediction Error", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" + "$ref": "#/components/schemas/ErrorResponse" }, "example": { - "message": "Inference failed", - "type": "backend" + "error": "Inference failed", + "error_type": "backend" } } } @@ -249,11 +337,11 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAICompatErrorResponse" + "$ref": "#/components/schemas/ErrorResponse" }, "example": { - "message": "Model is overloaded", - "type": "overloaded" + "error": "Model is overloaded", + "error_type": "overloaded" } } } @@ -264,6 +352,34 @@ }, "components": { "schemas": { + "ClassifierModel": { + "type": "object", + "required": [ + "id2label", + "label2id" + ], + "properties": { + "id2label": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "example": { + "0": "LABEL" + } + }, + "label2id": { + "type": "object", + "additionalProperties": { + "type": "integer", + "minimum": 0 + }, + "example": { + "LABEL": "0" + } + } + } + }, "EmbedRequest": { "type": "object", "required": [ @@ -273,6 +389,11 @@ "inputs": { "$ref": "#/components/schemas/Input" }, + "normalize": { + "type": "boolean", + "default": "true", + "example": "true" + }, "truncate": { "type": "boolean", "default": "false", @@ -297,6 +418,18 @@ ] ] }, + "EmbeddingModel": { + "type": "object", + "required": [ + "pooling" + ], + "properties": { + "pooling": { + "type": "string", + "example": "cls" + } + } + }, "ErrorResponse": { "type": "object", "required": [ @@ -327,7 +460,7 @@ "required": [ "model_id", "model_dtype", - "model_pooling", + "model_type", "max_concurrent_requests", "max_input_length", "max_batch_tokens", @@ -378,15 +511,14 @@ "description": "Model info", "example": "thenlper/gte-base" }, - "model_pooling": { - "type": "string", - "example": "cls" - }, "model_sha": { "type": "string", "example": "fca14538aa9956a46526bd1d0d11d69e19b5a101", "nullable": true }, + "model_type": { + "$ref": "#/components/schemas/ModelType" + }, "sha": { "type": "string", "example": "null", @@ -417,6 +549,32 @@ } ] }, + "ModelType": { + "oneOf": [ + { + "type": "object", + "required": [ + "classifier" + ], + "properties": { + "classifier": { + "$ref": "#/components/schemas/ClassifierModel" + } + } + }, 
+ { + "type": "object", + "required": [ + "embedding" + ], + "properties": { + "embedding": { + "$ref": "#/components/schemas/EmbeddingModel" + } + } + } + ] + }, "OpenAICompatEmbedding": { "type": "object", "required": [ @@ -536,6 +694,67 @@ "minimum": 0 } } + }, + "PredictRequest": { + "type": "object", + "required": [ + "inputs" + ], + "properties": { + "inputs": { + "$ref": "#/components/schemas/Sequence" + }, + "raw_scores": { + "type": "boolean", + "default": "false", + "example": "false" + }, + "truncate": { + "type": "boolean", + "default": "false", + "example": "false" + } + } + }, + "PredictResponse": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Prediction" + } + }, + "Prediction": { + "type": "object", + "required": [ + "score", + "label" + ], + "properties": { + "label": { + "type": "string", + "example": "admiration" + }, + "score": { + "type": "number", + "format": "float", + "example": "0.5" + } + } + }, + "Sequence": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + }, + "description": "", + "maxItems": 2, + "minItems": 2 + } + ] } } }, diff --git a/docs/source/en/cli_arguments.md b/docs/source/en/cli_arguments.md index 17cc6679..2db8d492 100644 --- a/docs/source/en/cli_arguments.md +++ b/docs/source/en/cli_arguments.md @@ -48,20 +48,18 @@ Options: --dtype The dtype to be forced upon the model - - If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures [env: DTYPE=] - [possible values: float16] + [possible values: float16, float32] --pooling - Optionally control the pooling method. - - If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` - configuration. - + Optionally control the pooling method for embedding models. + + If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` + configuration. + If `pooling` is set, it will override the model pooling configuration - + [env: POOLING=] [possible values: cls, mean] diff --git a/docs/source/en/quick_tour.md b/docs/source/en/quick_tour.md index dbd382c4..c55de18f 100644 --- a/docs/source/en/quick_tour.md +++ b/docs/source/en/quick_tour.md @@ -16,6 +16,8 @@ rendered properly in your Markdown viewer. # Quick Tour +## Text Embeddings + The easiest way to get started with TEI is to use one of the official Docker containers (see [Supported models and hardware](supported_models) to choose the right container). @@ -50,3 +52,47 @@ curl 127.0.0.1:8080/embed \ -d '{"inputs":"What is Deep Learning?"}' \ -H 'Content-Type: application/json' ``` + +## Sequence Classification + +TEI can also be used to deploy Sequence Classification models. +See [this blogpost](https://blog.llamaindex.ai/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83) by +the LlamaIndex team to understand how you can use Sequence Classification models in your RAG pipeline to improve +downstream performance. 
+
+Let's say you want to use `BAAI/bge-reranker-large`:
+
+```shell
+model=BAAI/bge-reranker-large
+revision=refs/pr/4
+volume=$PWD/data
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model --revision $revision
+```
+
+Once you have deployed a model you can use the `predict` endpoint and rank the similarity between a pair of inputs:
+
+```bash
+curl 127.0.0.1:8080/predict \
+    -X POST \
+    -d '{"inputs":["What is Deep Learning?", "Deep learning is..."], "raw_scores": true}' \
+    -H 'Content-Type: application/json'
+```
+
+You can also use classic Sequence Classification models like `SamLowe/roberta-base-go_emotions`:
+
+```shell
+model=SamLowe/roberta-base-go_emotions
+volume=$PWD/data
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model
+```
+
+Once you have deployed the model you can use the `predict` endpoint to get the emotions most associated with an input:
+
+```bash
+curl 127.0.0.1:8080/predict \
+    -X POST \
+    -d '{"inputs":"I like you."}' \
+    -H 'Content-Type: application/json'
+```
diff --git a/docs/source/en/supported_models.md b/docs/source/en/supported_models.md
index 733fff7d..ddb0ebee 100644
--- a/docs/source/en/supported_models.md
+++ b/docs/source/en/supported_models.md
@@ -16,29 +16,46 @@ rendered properly in your Markdown viewer.
 
 # Supported models and hardware
 
-## Supported models
-
-Text Embeddings Inference currently supports BERT, CamemBERT, and XLM-RoBERTa models with absolute positions. We are continually expanding our support for other model types and plan to include them in future updates.
+## Supported embeddings models
+
+Text Embeddings Inference currently supports BERT, CamemBERT, RoBERTa, and XLM-RoBERTa models with absolute positions,
+and JinaBERT models with Alibi positions.
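For embedding models, the pooling method is taken from the `--pooling` flag if given, and otherwise parsed from the model's `1_Pooling/config.json`, as in the `router/src/main.rs` changes further down. The following is a condensed, illustrative sketch of that resolution, assuming the `anyhow`, `serde`, and `serde_json` crates; the standalone `Pool` enum and `resolve_pooling` function are stand-ins for the backend types, not the actual router API.

```rust
// Condensed sketch of how the router picks a pooling method for embedding models.
// `Pool` stands in for `text_embeddings_backend::Pool`.
use std::{fs, path::Path};

use anyhow::{anyhow, Context, Result};
use serde::Deserialize;

#[derive(Debug, Clone, Copy)]
enum Pool {
    Cls,
    Mean,
}

#[derive(Debug, Deserialize)]
struct PoolConfig {
    pooling_mode_cls_token: bool,
    pooling_mode_mean_tokens: bool,
}

fn resolve_pooling(cli_pooling: Option<Pool>, model_root: &Path) -> Result<Pool> {
    if let Some(pool) = cli_pooling {
        // `--pooling` always overrides the model configuration.
        return Ok(pool);
    }
    let config_path = model_root.join("1_Pooling/config.json");
    let raw = fs::read_to_string(config_path)
        .context("`--pooling` is not set and `1_Pooling/config.json` was not found")?;
    let config: PoolConfig =
        serde_json::from_str(&raw).context("Failed to parse `1_Pooling/config.json`")?;
    if config.pooling_mode_cls_token {
        Ok(Pool::Cls)
    } else if config.pooling_mode_mean_tokens {
        Ok(Pool::Mean)
    } else {
        Err(anyhow!("Pooling config {config:?} is not supported"))
    }
}
```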
+
+Below are some examples of the currently supported models:
-| MTEB Rank | Model Type  | Model ID                                                                        |
-|-----------|-------------|---------------------------------------------------------------------------------|
-| 1         | Bert        | [BAAI/bge-large-en-v1.5](https://hf.co/BAAI/bge-large-en-v1.5)                  |
-| 2         |             | [BAAI/bge-base-en-v1.5](https://hf.co/BAAI/bge-base-en-v1.5)                    |
-| 3         |             | [llmrails/ember-v1](https://hf.co/llmrails/ember-v1)                            |
-| 4         |             | [thenlper/gte-large](https://hf.co/thenlper/gte-large)                          |
-| 5         |             | [thenlper/gte-base](https://hf.co/thenlper/gte-base)                            |
-| 6         |             | [intfloat/e5-large-v2](https://hf.co/intfloat/e5-large-v2)                      |
-| 7         |             | [BAAI/bge-small-en-v1.5](https://hf.co/BAAI/bge-small-en-v1.5)                  |
-| 10        |             | [intfloat/e5-base-v2](https://hf.co/intfloat/e5-base-v2)                        |
-| 11        | XLM-RoBERTa | [intfloat/multilingual-e5-large](https://hf.co/intfloat/multilingual-e5-large)  |
+
+| MTEB Rank | Model Type  | Model ID                                                                                |
+|-----------|-------------|-----------------------------------------------------------------------------------------|
+| 1         | Bert        | [BAAI/bge-large-en-v1.5](https://hf.co/BAAI/bge-large-en-v1.5)                          |
+| 2         |             | [BAAI/bge-base-en-v1.5](https://hf.co/BAAI/bge-base-en-v1.5)                            |
+| 3         |             | [llmrails/ember-v1](https://hf.co/llmrails/ember-v1)                                    |
+| 4         |             | [thenlper/gte-large](https://hf.co/thenlper/gte-large)                                  |
+| 5         |             | [thenlper/gte-base](https://hf.co/thenlper/gte-base)                                    |
+| 6         |             | [intfloat/e5-large-v2](https://hf.co/intfloat/e5-large-v2)                              |
+| 7         |             | [BAAI/bge-small-en-v1.5](https://hf.co/BAAI/bge-small-en-v1.5)                          |
+| 10        |             | [intfloat/e5-base-v2](https://hf.co/intfloat/e5-base-v2)                                |
+| 11        | XLM-RoBERTa | [intfloat/multilingual-e5-large](https://hf.co/intfloat/multilingual-e5-large)          |
+| N/A       | JinaBERT    | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en)    |
+| N/A       | JinaBERT    | [jinaai/jina-embeddings-v2-small-en](https://hf.co/jinaai/jina-embeddings-v2-small-en)  |
 
 To explore the list of best performing text embeddings models, visit the
 [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
 
+## Supported sequence classification models
+
+Text Embeddings Inference currently supports CamemBERT, RoBERTa, and XLM-RoBERTa Sequence Classification models with absolute positions.
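Whether a checkpoint is served as a classifier or as an embedding model is decided from its `config.json`: if any entry in `architectures` ends in `Classification`, the router treats it as a sequence classification model, as in the `router/src/main.rs` changes further down. Below is a condensed, illustrative sketch of that check, assuming the `serde` and `serde_json` crates; the `ModelConfig` struct here keeps only the fields the check needs, and `DetectedModelType` is a stand-in name.

```rust
// Condensed sketch of the model-type detection performed at router startup.
// Only the `config.json` fields needed for the check are modelled here.
use std::collections::HashMap;

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelConfig {
    architectures: Vec<String>,
    // Present for classifiers; the router later uses it to map class indices to labels.
    id2label: Option<HashMap<String, String>>,
}

#[derive(Debug, PartialEq)]
enum DetectedModelType {
    Classifier,
    Embedding,
}

fn detect_model_type(raw_config_json: &str) -> serde_json::Result<DetectedModelType> {
    let config: ModelConfig = serde_json::from_str(raw_config_json)?;
    // e.g. "XLMRobertaForSequenceClassification" or "RobertaForSequenceClassification"
    let classifier = config
        .architectures
        .iter()
        .any(|arch| arch.ends_with("Classification"));
    Ok(if classifier {
        DetectedModelType::Classifier
    } else {
        DetectedModelType::Embedding
    })
}
```

Once a classifier is detected, the default `/` and `/invocations` routes are wired to prediction instead of embedding, as the `router/src/server.rs` changes at the end of this diff show.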
+ +Below are some examples of the currently supported models: + +| Task | Model Type | Model ID | Revision | +|--------------------|-------------|---------------------------------------------------------------------------------------------|-------------| +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | `refs/pr/4` | +| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | `refs/pr/5` | +| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | | + ## Supported hardware Text Embeddings Inference supports can be used on CPU, Turing (T4, RTX 2000 series, ...), Ampere 80 (A100, A30), diff --git a/load_tests/load.js b/load_tests/load.js index d4bef350..7a945cce 100644 --- a/load_tests/load.js +++ b/load_tests/load.js @@ -41,7 +41,7 @@ export default function () { }); const headers = {'Content-Type': 'application/json'}; - const res = http.post(`http://${host}/embed`, payload, { + const res = http.post(`http://${host}`, payload, { headers, timeout: '20m' }); diff --git a/router/src/lib.rs b/router/src/lib.rs index 8dc7d859..f9683ebf 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -2,8 +2,31 @@ pub mod server; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use text_embeddings_core::tokenization::EncodingInput; use utoipa::ToSchema; +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct EmbeddingModel { + #[schema(example = "cls")] + pub pooling: String, +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct ClassifierModel { + #[schema(example = json!({"0": "LABEL"}))] + pub id2label: HashMap, + #[schema(example = json!({"LABEL": "0"}))] + pub label2id: HashMap, +} + +#[derive(Clone, Debug, Serialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum ModelType { + Classifier(ClassifierModel), + Embedding(EmbeddingModel), +} + #[derive(Clone, Debug, Serialize, ToSchema)] pub struct Info { /// Model info @@ -13,8 +36,7 @@ pub struct Info { pub model_sha: Option, #[schema(example = "float16")] pub model_dtype: String, - #[schema(example = "cls")] - pub model_pooling: String, + pub model_type: ModelType, /// Router Parameters #[schema(example = "128")] pub max_concurrent_requests: usize, @@ -37,6 +59,53 @@ pub struct Info { pub docker_label: Option<&'static str>, } +#[derive(Deserialize, ToSchema, Debug)] +#[serde(untagged)] +pub(crate) enum Sequence { + Single(String), + Pair(String, String), +} + +impl Sequence { + pub(crate) fn count_chars(&self) -> usize { + match self { + Sequence::Single(s) => s.chars().count(), + Sequence::Pair(s1, s2) => s1.chars().count() + s2.chars().count(), + } + } +} + +impl From for EncodingInput { + fn from(value: Sequence) -> Self { + match value { + Sequence::Single(s) => Self::Single(s), + Sequence::Pair(s1, s2) => Self::Dual(s1, s2), + } + } +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct PredictRequest { + pub inputs: Sequence, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub raw_scores: bool, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct Prediction { + #[schema(example = "0.5")] + score: f32, + #[schema(example = "admiration")] + label: String, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct PredictResponse(Vec); + #[derive(Deserialize, ToSchema)] #[serde(untagged)] pub(crate) enum Input { diff 
--git a/router/src/main.rs b/router/src/main.rs index 9b0d1c78..7ed4a579 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -9,16 +9,16 @@ use opentelemetry::sdk::{trace, Resource}; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; use serde::Deserialize; +use std::collections::HashMap; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::Path; use text_embeddings_backend::DType; -use text_embeddings_backend::{Backend, Pool}; use text_embeddings_core::download::{download_artifacts, download_pool_config}; use text_embeddings_core::infer::Infer; use text_embeddings_core::queue::Queue; use text_embeddings_core::tokenization::Tokenization; -use text_embeddings_router::{server, Info}; +use text_embeddings_router::{server, ClassifierModel, EmbeddingModel, Info, ModelType}; use tokenizers::Tokenizer; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; @@ -54,14 +54,14 @@ struct Args { #[clap(long, env, value_enum)] dtype: Option, - /// Optionally control the pooling method. + /// Optionally control the pooling method for embedding models. /// /// If `pooling` is not set, the pooling configuration will be parsed from the /// model `1_Pooling/config.json` configuration. /// /// If `pooling` is set, it will override the model pooling configuration #[clap(long, env, value_enum)] - pooling: Option, + pooling: Option, /// The maximum amount of concurrent requests for this particular deployment. /// Having a low limit will refuse clients requests instead of having them @@ -127,10 +127,13 @@ struct Args { #[derive(Debug, Deserialize)] pub struct ModelConfig { + pub architectures: Vec, pub model_type: String, #[serde(alias = "n_positions")] pub max_position_embeddings: usize, pub pad_token_id: usize, + pub id2label: Option>, + pub label2id: Option>, } #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -173,7 +176,8 @@ async fn main() -> Result<()> { // Optionally download the pooling config. if args.pooling.is_none() { - download_pool_config(&api_repo).await.context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?; + // If a pooling config exist, download it + let _ = download_pool_config(&api_repo).await; } // Download model from the Hub @@ -188,23 +192,62 @@ async fn main() -> Result<()> { let config: ModelConfig = serde_json::from_str(&config).context("Failed to parse `config.json`")?; - // Set pooling - let pool = match args.pooling { - Some(pool) => pool, - None => { - // Load pooling config - let config_path = model_root.join("1_Pooling/config.json"); - let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?; - let config: PoolConfig = - serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?; - if config.pooling_mode_cls_token { - Pool::Cls - } else if config.pooling_mode_mean_tokens { - Pool::Mean - } else { - return Err(anyhow!("Pooling config {config:?} is not supported")); + // Set model type from config + let backend_model_type = { + // Check if the model is a classifier + let mut classifier = false; + for arch in &config.architectures { + if arch.ends_with("Classification") { + classifier = true; + break; } } + + if classifier { + if args.pooling.is_some() { + tracing::warn!( + "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg." 
+ ); + } + text_embeddings_backend::ModelType::Classifier + } else { + // Set pooling + let pool = match args.pooling { + Some(pool) => pool, + None => { + // Load pooling config + let config_path = model_root.join("1_Pooling/config.json"); + let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?; + let config: PoolConfig = serde_json::from_str(&config) + .context("Failed to parse `1_Pooling/config.json`")?; + if config.pooling_mode_cls_token { + text_embeddings_backend::Pool::Cls + } else if config.pooling_mode_mean_tokens { + text_embeddings_backend::Pool::Mean + } else { + return Err(anyhow!("Pooling config {config:?} is not supported")); + } + } + }; + text_embeddings_backend::ModelType::Embedding(pool) + } + }; + + // Info model type + let model_type = match &backend_model_type { + text_embeddings_backend::ModelType::Classifier => ModelType::Classifier(ClassifierModel { + id2label: config + .id2label + .context("`config.json` does not contain `id2label`")?, + label2id: config + .label2id + .context("`config.json` does not contain `label2id`")?, + }), + text_embeddings_backend::ModelType::Embedding(pool) => { + ModelType::Embedding(EmbeddingModel { + pooling: pool.to_string(), + }) + } }; // Load tokenizer @@ -251,10 +294,10 @@ async fn main() -> Result<()> { // Create backend tracing::info!("Starting model backend"); - let backend = Backend::new( + let backend = text_embeddings_backend::Backend::new( model_root, dtype.clone(), - pool.clone(), + backend_model_type, args.uds_path, args.otlp_endpoint, ) @@ -285,7 +328,7 @@ async fn main() -> Result<()> { model_id: args.model_id, model_sha: args.revision, model_dtype: dtype.to_string(), - model_pooling: pool.to_string(), + model_type, max_concurrent_requests: args.max_concurrent_requests, max_input_length: config.max_position_embeddings, max_batch_tokens: args.max_batch_tokens, diff --git a/router/src/server.rs b/router/src/server.rs index 3284bb99..d0d5b6cc 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,7 +1,8 @@ /// HTTP Server logic use crate::{ - EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Info, Input, OpenAICompatEmbedding, - OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, + ClassifierModel, EmbedRequest, EmbedResponse, EmbeddingModel, ErrorResponse, ErrorType, Info, + Input, ModelType, OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, + OpenAICompatResponse, OpenAICompatUsage, PredictRequest, PredictResponse, Prediction, Sequence, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -54,7 +55,156 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json, + info: Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = req.inputs.count_chars(); + + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let mut response = infer + .predict(req.inputs, req.truncate, permit) + .await + .map_err(ErrorResponse::from)?; + + let id2label = match &info.model_type { + ModelType::Classifier(classifier) => &classifier.id2label, + _ => panic!(), + }; + + let mut predictions: Vec = { + if !req.raw_scores { + // Softmax + if response.results.len() > 1 { + let max = *response + 
.results + .iter() + .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap()) + .unwrap(); + + let mut den = 0.0; + for v in response.results.iter_mut() { + *v = (*v - max).exp(); + den += *v; + } + for v in response.results.iter_mut() { + *v /= den; + } + } + // Sigmoid + else { + response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp()); + } + } + + // Map score to label + response + .results + .into_iter() + .enumerate() + .map(|(i, s)| Prediction { + score: s, + label: id2label.get(&i.to_string()).unwrap().clone(), + }) + .collect() + }; + // Reverse sort + predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + predictions.reverse(); + + metrics::increment_counter!("te_request_success", "method" => "predict"); + + let compute_tokens = response.prompt_tokens; + let tokenization_time = response.tokenization; + let queue_time = response.queue; + let inference_time = response.inference; + + let total_time = start_time.elapsed(); + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("tokenization_time", format!("{tokenization_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + + // Headers + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-characters", + compute_chars.to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-tokens", + compute_tokens.to_string().parse().unwrap(), + ); + headers.insert( + "x-total-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-tokenization-time", + tokenization_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-queue-time", + queue_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-inference-time", + inference_time.as_millis().to_string().parse().unwrap(), + ); + + // Metrics + metrics::histogram!("te_request_duration", total_time.as_secs_f64()); + metrics::histogram!( + "te_request_tokenization_duration", + tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + inference_time.as_secs_f64() + ); + + tracing::info!("Success"); + + Ok((headers, Json(PredictResponse(predictions)))) +} + +/// Get Embeddings. Returns a 424 status code if the model is not an embedding model. 
#[utoipa::path( post, tag = "Text Embeddings Inference", @@ -105,7 +255,7 @@ async fn embed( response.tokenization, response.queue, response.inference, - EmbedResponse(vec![response.embeddings]), + EmbedResponse(vec![response.results]), ) } Input::Batch(inputs) => { @@ -157,7 +307,7 @@ async fn embed( total_queue_time += r.queue.as_nanos() as u64; total_inference_time += r.inference.as_nanos() as u64; total_compute_tokens += r.prompt_tokens; - embeddings.push(r.embeddings); + embeddings.push(r.results); } let batch_size = batch_size as u64; @@ -220,7 +370,7 @@ async fn embed( "te_request_tokenization_duration", tokenization_time.as_secs_f64() ); - metrics::histogram!("e_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); metrics::histogram!( "te_request_inference_duration", inference_time.as_secs_f64() @@ -231,11 +381,11 @@ async fn embed( Ok((headers, Json(response))) } -/// OpenAI compatible route +/// OpenAI compatible route. Returns a 424 status code if the model is not an embedding model. #[utoipa::path( post, tag = "Text Embeddings Inference", -path = "/openai", +path = "/embeddings", request_body = OpenAICompatRequest, responses( (status = 200, description = "Embeddings", body = OpenAICompatResponse), @@ -285,7 +435,7 @@ async fn openai_embed( response.inference, vec![OpenAICompatEmbedding { object: "embedding", - embedding: response.embeddings, + embedding: response.results, index: 0, }], ) @@ -339,7 +489,7 @@ async fn openai_embed( total_compute_tokens += r.prompt_tokens; embeddings.push(OpenAICompatEmbedding { object: "embedding", - embedding: r.embeddings, + embedding: r.results, index: i, }); } @@ -404,7 +554,7 @@ async fn openai_embed( "te_request_tokenization_duration", tokenization_time.as_secs_f64() ); - metrics::histogram!("e_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); metrics::histogram!( "te_request_inference_duration", inference_time.as_secs_f64() @@ -448,14 +598,22 @@ pub async fn run( paths( get_model_info, health, + predict, embed, openai_embed, metrics, ), components( schemas( + Sequence, Input, Info, + ModelType, + ClassifierModel, + EmbeddingModel, + PredictRequest, + Prediction, + PredictResponse, OpenAICompatRequest, OpenAICompatEmbedding, OpenAICompatUsage, @@ -531,13 +689,11 @@ pub async fn run( let app = Router::new() .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) // Base routes - .route("/", post(embed)) .route("/info", get(get_model_info)) .route("/embed", post(embed)) + .route("/predict", post(predict)) // OpenAI compat route - .route("/openai", post(openai_embed)) - // AWS Sagemaker route - .route("/invocations", post(embed)) + .route("/embeddings", post(openai_embed)) // Base Health route .route("/health", get(health)) // Inference API health route @@ -545,7 +701,23 @@ pub async fn run( // AWS Sagemaker health route .route("/ping", get(health)) // Prometheus metrics route - .route("/metrics", get(metrics)) + .route("/metrics", get(metrics)); + + // Set default routes + let app = match infer.is_classifier() { + true => { + app.route("/", post(predict)) + // AWS Sagemaker route + .route("/invocations", post(predict)) + } + false => { + app.route("/", post(embed)) + // AWS Sagemaker route + .route("/invocations", post(embed)) + } + }; + + let app = app .layer(Extension(infer)) .layer(Extension(info)) .layer(Extension(prom_handle.clone()))