diff --git a/Cargo.lock b/Cargo.lock
index dfe3dd55..2d6d4bbc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3118,7 +3118,6 @@ version = "0.3.0"
dependencies = [
"hf-hub",
"metrics",
- "rayon",
"text-embeddings-backend",
"thiserror",
"tokenizers",
diff --git a/Dockerfile b/Dockerfile
index a638c006..f557c1d3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -34,7 +34,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
tee /etc/apt/sources.list.d/oneAPI.list
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
- intel-oneapi-mkl-devel \
+ intel-oneapi-mkl-devel=2024.0.0-49656 \
build-essential \
&& rm -rf /var/lib/apt/lists/*
@@ -74,10 +74,8 @@ COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx.so.2 /usr/local/lib/libmkl_vml_avx.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx.so.2 /usr/local/lib/libmkl_avx.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so
diff --git a/README.md b/README.md
index b6db560c..61b71393 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,10 @@
-A blazing fast inference solution for text embeddings models.
+A blazing fast inference solution for text embeddings models.
-Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on an Nvidia A10 with a sequence length of 512 tokens:
+Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on an Nvidia A10 with a sequence
+length of 512 tokens:
@@ -27,33 +28,37 @@ Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1
## Table of contents
- [Get Started](#get-started)
- - [Supported Models](#supported-models)
- - [Docker](#docker)
- - [Docker Images](#docker-images)
- - [API Documentation](#api-documentation)
- - [Using a private or gated model](#using-a-private-or-gated-model)
- - [Distributed Tracing](#distributed-tracing)
+ - [Supported Models](#supported-models)
+ - [Docker](#docker)
+ - [Docker Images](#docker-images)
+ - [API Documentation](#api-documentation)
+ - [Using a private or gated model](#using-a-private-or-gated-model)
+ - [Using Sequence Classification models](#using-sequence-classification-models)
+ - [Distributed Tracing](#distributed-tracing)
- [Local Install](#local-install)
- [Docker Build](#docker-build)
-Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings models. TEI enables
-high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5. TEI implements many features
-such as:
+Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence
+classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding,
+Ember, GTE and E5. TEI implements many features such as:
* No model graph compilation step
* Small docker images and fast boot times. Get ready for true serverless!
* Token based dynamic batching
* Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention),
-[Candle](https://github.com/huggingface/candle) and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api)
+ [Candle](https://github.com/huggingface/candle)
+ and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api)
* [Safetensors](https://github.com/huggingface/safetensors) weight loading
* Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
-
## Get Started
### Supported Models
-You can use any JinaBERT model with Alibi or absolute positions or any BERT, CamemBERT, RoBERTa, or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`.
+#### Text Embeddings
+
+You can use any JinaBERT model with Alibi or absolute positions or any BERT, CamemBERT, RoBERTa, or XLM-RoBERTa model
+with absolute positions in `text-embeddings-inference`.
**Support for other model types will be added in the future.**
@@ -73,8 +78,20 @@ Examples of supported models:
| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
| N/A | JinaBERT | [jinaai/jina-embeddings-v2-small-en](https://hf.co/jinaai/jina-embeddings-v2-small-en) |
+You can explore the list of best performing text embeddings
+models [here](https://huggingface.co/spaces/mteb/leaderboard).
+
+#### Sequence Classification and Re-Ranking
+
+`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models.
+
+Examples of supported sequence classification models:
-You can explore the list of best performing text embeddings models [here](https://huggingface.co/spaces/mteb/leaderboard).
+| Task | Model Type | Model ID | Revision |
+|--------------------|-------------|---------------------------------------------------------------------------------------------|-------------|
+| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | `refs/pr/4` |
+| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | `refs/pr/5` |
+| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | |
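+
+See [Using Sequence Classification models](#using-sequence-classification-models) below for deployment and usage
+examples.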
### Docker
@@ -95,7 +112,8 @@ curl 127.0.0.1:8080/embed \
-H 'Content-Type: application/json'
```
-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+**Note:** To use GPUs, you need to install
+the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
We also recommend using NVIDIA drivers with CUDA version 12.0 or higher.
To see all options to serve your models:
@@ -130,20 +148,18 @@ Options:
--dtype
The dtype to be forced upon the model
-
- If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures
[env: DTYPE=]
[possible values: float16, float32]
--pooling
- Optionally control the pooling method.
-
- If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
- configuration.
-
+ Optionally control the pooling method for embedding models.
+
+ If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
+ configuration.
+
If `pooling` is set, it will override the model pooling configuration
-
+
[env: POOLING=]
[possible values: cls, mean]
@@ -241,7 +257,8 @@ You can turn Flash Attention v1 ON by using the `USE_FLASH_ATTENTION=True` envir
### API documentation
You can consult the OpenAPI documentation of the `text-embeddings-inference` REST API using the `/docs` route.
-The Swagger UI is also available at: [https://huggingface.github.io/text-embeddings-inference](https://huggingface.github.io/text-embeddings-inference).
+The Swagger UI is also available
+at: [https://huggingface.github.io/text-embeddings-inference](https://huggingface.github.io/text-embeddings-inference).
### Using a private or gated model
@@ -264,6 +281,48 @@ token=
docker run --gpus all -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model
```
+### Using Sequence Classification models
+
+`text-embeddings-inference` v0.4.0 added support for CamemBERT, RoBERTa and XLM-RoBERTa Sequence Classification models.
+See [this blogpost](https://blog.llamaindex.ai/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83) by
+the LlamaIndex team to understand how you can use Sequence Classification models in your RAG pipeline to improve
+downstream performance.
+
+```shell
+model=BAAI/bge-reranker-large
+revision=refs/pr/4
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model --revision $revision
+```
+
+You can then score the similarity between a pair of inputs with:
+
+```bash
+curl 127.0.0.1:8080/predict \
+ -X POST \
+ -d '{"inputs":["What is Deep Learning?", "Deep learning is..."], "raw_scores": true}' \
+ -H 'Content-Type: application/json'
+```
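+
+The `raw_scores` flag requests the raw logits of the model rather than normalized scores.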
+
+You can also use classic Sequence Classification models like `SamLowe/roberta-base-go_emotions`:
+
+```shell
+model=SamLowe/roberta-base-go_emotions
+volume=$PWD/data
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model
+```
+
+Once you have deployed the model, you can use the `predict` endpoint to get the emotions most associated with an input:
+
+```bash
+curl 127.0.0.1:8080/predict \
+ -X POST \
+ -d '{"inputs":"I like you."}' \
+ -H 'Content-Type: application/json'
+```
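+
+The response contains one score per label in the model's `id2label` mapping.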
+
### Distributed Tracing
`text-embeddings-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
@@ -290,7 +349,7 @@ cargo install --path router -F candle -F mkl
cargo install --path router -F candle -F accelerate
```
-You can now launch Text Embeddings Inference on CPU with:
+You can now launch Text Embeddings Inference on CPU with:
```shell
model=BAAI/bge-large-en-v1.5
@@ -309,7 +368,8 @@ sudo apt-get install libssl-dev gcc -y
GPUs with Cuda compute capabilities < 7.5 are not supported (V100, Titan V, GTX 1000 series, ...).
-Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.0 or higher.
+Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.0 or
+higher.
You also need to add the nvidia binaries to your path:
```shell
diff --git a/backends/candle/src/alibi.rs b/backends/candle/src/alibi.rs
index 1b49941b..1888d05e 100644
--- a/backends/candle/src/alibi.rs
+++ b/backends/candle/src/alibi.rs
@@ -16,7 +16,7 @@
use candle::{DType, Device, Result, Tensor};
fn get_slopes_power_of_2(n: usize) -> Vec<f64> {
- let start: f64 = 2_f64.powf(-2_f64.powf(-((n as f64).log2() - 3_f64)));
+ let start: f64 = 2_f64.powf(-(2_f64.powf(-((n as f64).log2() - 3_f64))));
(0..n).map(|i| start * start.powi(i as i32)).collect()
}
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 9c6d2aad..de8e22ec 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -10,21 +10,23 @@ mod models;
use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP};
#[cfg(feature = "cuda")]
use crate::models::FlashBertModel;
-use crate::models::{
- BertModel, EmbeddingModel, JinaBertModel, PositionEmbeddingType, QuantBertModel,
-};
+use crate::models::{BertModel, JinaBertModel, Model, PositionEmbeddingType};
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::Config;
use std::path::PathBuf;
-use text_embeddings_backend_core::{BackendError, Batch, Embedding, EmbeddingBackend, Pool};
+use text_embeddings_backend_core::{Backend, BackendError, Batch, Embedding, ModelType};
pub struct CandleBackend {
- model: Box<dyn EmbeddingModel + Send>,
+ model: Box<dyn Model + Send>,
}
impl CandleBackend {
- pub fn new(model_path: PathBuf, dtype: String, pool: Pool) -> Result<Self, BackendError> {
+ pub fn new(
+ model_path: PathBuf,
+ dtype: String,
+ model_type: ModelType,
+ ) -> Result<Self, BackendError> {
// Load config
let config: String = std::fs::read_to_string(model_path.join("config.json"))
.map_err(|err| BackendError::Start(err.to_string()))?;
@@ -49,49 +51,39 @@ impl CandleBackend {
)));
}
- let model: Box<dyn EmbeddingModel + Send> = match device {
- Device::Cpu => {
- if &dtype == "float32" || &dtype == "float16" {
- let dtype = if &dtype == "float32" {
- DType::F32
- } else {
- DType::F16
- };
-
- let safetensors_path = model_path.join("model.safetensors");
- let vb = if safetensors_path.exists() {
- unsafe {
- VarBuilder::from_mmaped_safetensors(
- &[model_path.join("model.safetensors")],
- dtype,
- &device,
- )
- }
- } else {
- VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
- }
- .s()?;
-
- if config.position_embedding_type == PositionEmbeddingType::Alibi {
- tracing::info!("Starting JinaBert model on CPU");
- Box::new(JinaBertModel::load(vb, &config, pool).s()?)
- } else {
- tracing::info!("Starting Bert model on CPU");
- Box::new(BertModel::load(vb, &config, pool).s()?)
- }
- } else if &dtype == "q6k" {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
- model_path.join("ggml-model-q6k.bin"),
- )
- .map_err(|err| BackendError::Start(err.to_string()))?;
- tracing::info!("vb");
+ // Get candle dtype
+ let dtype = if &dtype == "float32" {
+ Ok(DType::F32)
+ } else if &dtype == "float16" {
+ Ok(DType::F16)
+ } else {
+ Err(BackendError::Start(format!(
+ "DType {dtype} is not supported"
+ )))
+ }?;
+
+ let safetensors_path = model_path.join("model.safetensors");
+ let vb = if safetensors_path.exists() {
+ unsafe {
+ VarBuilder::from_mmaped_safetensors(
+ &[model_path.join("model.safetensors")],
+ dtype,
+ &device,
+ )
+ }
+ } else {
+ VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
+ }
+ .s()?;
- tracing::info!("Starting QuantBert model on CPU");
- Box::new(QuantBertModel::load(vb, &config, pool).s()?)
+ let model: Box<dyn Model + Send> = match device {
+ Device::Cpu => {
+ if config.position_embedding_type == PositionEmbeddingType::Alibi {
+ tracing::info!("Starting JinaBert model on CPU");
+ Box::new(JinaBertModel::load(vb, &config, model_type).s()?)
} else {
- return Err(BackendError::Start(format!(
- "dtype {dtype} is not supported"
- )));
+ tracing::info!("Starting Bert model on CPU");
+ Box::new(BertModel::load(vb, &config, model_type).s()?)
}
}
Device::Cuda(_) => {
@@ -101,31 +93,6 @@ impl CandleBackend {
));
#[cfg(feature = "cuda")]
{
- // Get candle dtype
- let dtype = if &dtype == "float32" {
- Ok(DType::F32)
- } else if &dtype == "float16" {
- Ok(DType::F16)
- } else {
- Err(BackendError::Start(format!(
- "DType {dtype} is not supported"
- )))
- }?;
-
- let safetensors_path = model_path.join("model.safetensors");
- let vb = if safetensors_path.exists() {
- unsafe {
- VarBuilder::from_mmaped_safetensors(
- &[model_path.join("model.safetensors")],
- dtype,
- &device,
- )
- }
- } else {
- VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
- }
- .s()?;
-
if incompatible_compute_cap() {
return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP)));
}
@@ -138,13 +105,13 @@ impl CandleBackend {
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
tracing::info!("Starting FlashBert model on Cuda");
- Box::new(FlashBertModel::load(vb, &config, pool).s()?)
+ Box::new(FlashBertModel::load(vb, &config, model_type).s()?)
} else if config.position_embedding_type == PositionEmbeddingType::Alibi {
tracing::info!("Starting JinaBert model on Cuda");
- Box::new(JinaBertModel::load(vb, &config, pool).s()?)
+ Box::new(JinaBertModel::load(vb, &config, model_type).s()?)
} else {
tracing::info!("Starting Bert model on Cuda");
- Box::new(BertModel::load(vb, &config, pool).s()?)
+ Box::new(BertModel::load(vb, &config, model_type).s()?)
}
}
}
@@ -154,7 +121,7 @@ impl CandleBackend {
}
}
-impl EmbeddingBackend for CandleBackend {
+impl Backend for CandleBackend {
fn health(&self) -> Result<(), BackendError> {
Ok(())
}
@@ -165,8 +132,10 @@ impl EmbeddingBackend for CandleBackend {
Ok(results)
}
- fn max_batch_size(&self) -> Option<usize> {
- None
+ fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError> {
+ let results = self.model.predict(batch).e()?;
+ let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
+ Ok(results)
}
}
diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs
index c2bb9c28..f47112a9 100644
--- a/backends/candle/src/models.rs
+++ b/backends/candle/src/models.rs
@@ -5,10 +5,8 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
mod bert;
-mod bert_quant;
pub use bert::{BertModel, Config, PositionEmbeddingType};
-pub use bert_quant::QuantBertModel;
use candle::{Result, Tensor};
pub use jina::JinaBertModel;
use text_embeddings_backend_core::Batch;
@@ -20,6 +18,12 @@ mod jina;
#[cfg(feature = "cuda")]
pub use flash_bert::FlashBertModel;
-pub(crate) trait EmbeddingModel {
- fn embed(&self, batch: Batch) -> Result<Tensor>;
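+/// Common interface for all Candle models. `embed` and `predict` default to an error so
+/// each model only has to implement the task it actually supports.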
+pub(crate) trait Model {
+ fn embed(&self, _batch: Batch) -> Result<Tensor> {
+ candle::bail!("`embed` is not implemented for this model");
+ }
+
+ fn predict(&self, _batch: Batch) -> Result<Tensor> {
+ candle::bail!("`predict` is not implemented for this model");
+ }
}
diff --git a/backends/candle/src/models/bert.rs b/backends/candle/src/models/bert.rs
index a2e318ce..30e5c9e2 100644
--- a/backends/candle/src/models/bert.rs
+++ b/backends/candle/src/models/bert.rs
@@ -1,10 +1,10 @@
use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT};
-use crate::models::EmbeddingModel;
+use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
-use text_embeddings_backend_core::{Batch, Pool};
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -357,10 +357,54 @@ impl BertEncoder {
}
}
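+/// RoBERTa-style sequence classification head: a dense projection with a tanh activation
+/// followed by a linear projection to the `id2label` classes.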
+struct BertClassificationHead {
+ intermediate: Linear,
+ output: Linear,
+ span: tracing::Span,
+}
+
+impl BertClassificationHead {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let n_classes = match &config.id2label {
+ None => candle::bail!("`id2label` must be set for classifier models"),
+ Some(id2label) => id2label.len(),
+ };
+
+ let intermediate_weight = vb
+ .pp("dense")
+ .get((config.hidden_size, config.hidden_size), "weight")?;
+ let intermediate_bias = vb.pp("dense").get(config.hidden_size, "bias")?;
+ let intermediate = Linear::new(intermediate_weight, Some(intermediate_bias), None);
+
+ let output_weight = vb
+ .pp("out_proj")
+ .get((n_classes, config.hidden_size), "weight")?;
+ let output_bias = vb.pp("out_proj").get(n_classes, "bias")?;
+ let output = Linear::new(output_weight, Some(output_bias), None);
+
+ Ok(Self {
+ intermediate,
+ output,
+ span: tracing::span!(tracing::Level::TRACE, "classifier"),
+ })
+ }
+
+ pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+
+ let hidden_states = self.intermediate.forward(hidden_states)?;
+ let hidden_states = hidden_states.tanh()?;
+ let hidden_states = self.output.forward(&hidden_states)?;
+
+ Ok(hidden_states)
+ }
+}
+
pub struct BertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
pool: Pool,
+ classifier: Option<BertClassificationHead>,
num_attention_heads: usize,
@@ -371,12 +415,26 @@ pub struct BertModel {
}
impl BertModel {
- pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result<Self> {
+ pub fn load(vb: VarBuilder, config: &Config, model_type: ModelType) -> Result<Self> {
// Check position embedding type
if config.position_embedding_type != PositionEmbeddingType::Absolute {
candle::bail!("Bert only supports absolute position embeddings")
}
+ let (pool, classifier) = match model_type {
+ // Classifier models always use CLS pooling
+ ModelType::Classifier => {
+ if config.model_type == Some("bert".to_string()) {
+ candle::bail!("`classifier` model type is not supported for Bert");
+ }
+ (
+ Pool::Cls,
+ Some(BertClassificationHead::load(vb.pp("classifier"), config)?),
+ )
+ }
+ ModelType::Embedding(pool) => (pool, None),
+ };
+
// Check pool type
if pool != Pool::Mean && pool != Pool::Cls {
candle::bail!("Pool type {pool:?} is not supported");
@@ -396,8 +454,8 @@ impl BertModel {
) {
(embeddings, encoder)
} else if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(vb.pp("bert.embeddings"), config),
- BertEncoder::load(vb.pp("bert.encoder"), config),
+ BertEmbeddings::load(vb.pp("roberta.embeddings"), config),
+ BertEncoder::load(vb.pp("roberta.encoder"), config),
) {
(embeddings, encoder)
} else {
@@ -410,6 +468,7 @@ impl BertModel {
embeddings,
encoder,
pool,
+ classifier,
num_attention_heads: config.num_attention_heads,
device: vb.device().clone(),
dtype: vb.dtype(),
@@ -558,8 +617,18 @@ impl BertModel {
}
}
-impl EmbeddingModel for BertModel {
+impl Model for BertModel {
fn embed(&self, batch: Batch) -> Result<Tensor> {
self.forward(batch)
}
+
+ fn predict(&self, batch: Batch) -> Result<Tensor> {
+ match &self.classifier {
+ None => candle::bail!("`predict` is not implemented for this model"),
+ Some(classifier) => {
+ let hidden_states = self.forward(batch)?;
+ classifier.forward(&hidden_states)
+ }
+ }
+ }
}
diff --git a/backends/candle/src/models/bert_quant.rs b/backends/candle/src/models/bert_quant.rs
deleted file mode 100644
index ef9b2ef6..00000000
--- a/backends/candle/src/models/bert_quant.rs
+++ /dev/null
@@ -1,459 +0,0 @@
-use crate::layers::HiddenAct;
-use crate::models::bert::{Config, PositionEmbeddingType};
-use crate::models::EmbeddingModel;
-use candle::quantized::QMatMul;
-use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-use candle_nn::ops::softmax;
-use candle_nn::Embedding;
-use candle_transformers::quantized_var_builder::VarBuilder;
-use text_embeddings_backend_core::{Batch, Pool};
-
-#[derive(Debug)]
-struct LayerNorm {
- weight: Tensor,
- bias: Tensor,
- epsilon: f64,
- span: tracing::Span,
-}
-
-impl LayerNorm {
- pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- Ok(Self {
- weight: vb
- .get(config.hidden_size, "weight")
- .or_else(|_| vb.get(config.hidden_size, "gamma"))?
- .dequantize(vb.device())?,
- bias: vb
- .get(config.hidden_size, "bias")
- .or_else(|_| vb.get(config.hidden_size, "beta"))?
- .dequantize(vb.device())?,
- epsilon: config.layer_norm_eps,
- span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
- })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let x_dtype = x.dtype();
- let internal_dtype = match x_dtype {
- DType::F16 | DType::BF16 => DType::F32,
- d => d,
- };
- let hidden_size = x.dim(D::Minus1)?;
- let x = x.to_dtype(internal_dtype)?;
- let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let x = x.broadcast_sub(&mean_x)?;
- let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let x_normed = x.broadcast_div(&(norm_x + self.epsilon)?.sqrt()?)?;
- let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
- x.broadcast_add(&self.bias)
- }
-}
-
-#[derive(Debug)]
-pub struct Linear {
- weight: QMatMul,
- bias: Option<Tensor>,
- act: Option<HiddenAct>,
- span: tracing::Span,
-}
-
-impl Linear {
- pub fn new(weight: QMatMul, bias: Option<Tensor>, act: Option<HiddenAct>) -> Self {
- let span = tracing::span!(tracing::Level::TRACE, "linear");
-
- Self {
- weight,
- bias,
- act,
- span,
- }
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let x = x.apply(&self.weight)?;
- let x = match &self.bias {
- None => Ok(x),
- Some(bias) => x.broadcast_add(bias),
- }?;
- if let Some(act) = &self.act {
- match act {
- HiddenAct::Gelu => x.gelu(),
- HiddenAct::Relu => x.relu(),
- }
- } else {
- Ok(x)
- }
- }
-}
-
-#[derive(Debug)]
-struct BertEmbeddings {
- word_embeddings: Embedding,
- token_type_embeddings: Embedding,
- position_embeddings: Embedding,
- layer_norm: LayerNorm,
- span: tracing::Span,
-}
-
-impl BertEmbeddings {
- pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- if config.position_embedding_type != PositionEmbeddingType::Absolute {
- candle::bail!("FlashBert only supports absolute position embeddings");
- }
-
- Ok(Self {
- word_embeddings: Embedding::new(
- vb.pp("word_embeddings")
- .get((config.vocab_size, config.hidden_size), "weight")?
- .dequantize(vb.device())?,
- config.hidden_size,
- ),
- token_type_embeddings: Embedding::new(
- vb.pp("token_type_embeddings")
- .get((config.type_vocab_size, config.hidden_size), "weight")?
- .dequantize(vb.device())?,
- config.hidden_size,
- ),
- position_embeddings: Embedding::new(
- vb.pp("position_embeddings")
- .get(
- (config.max_position_embeddings, config.hidden_size),
- "weight",
- )?
- .dequantize(vb.device())?,
- config.hidden_size,
- ),
- layer_norm: LayerNorm::load(vb.pp("LayerNorm"), config)?,
- span: tracing::span!(tracing::Level::TRACE, "embeddings"),
- })
- }
-
- fn forward(
- &self,
- input_ids: &Tensor,
- token_type_ids: &Tensor,
- position_ids: &Tensor,
- ) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let input_embeddings = self.word_embeddings.forward(input_ids)?;
- let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
- let position_embeddings = self.position_embeddings.forward(position_ids)?;
-
- let embeddings = input_embeddings
- .add(&token_type_embeddings)?
- .add(&position_embeddings)?;
- let embeddings = self.layer_norm.forward(&embeddings)?;
-
- Ok(embeddings)
- }
-}
-
-struct BertAttention {
- q_linear: Linear,
- k_linear: Linear,
- v_linear: Linear,
-
- dense: Linear,
- layer_norm: LayerNorm,
-
- num_attention_heads: usize,
- attention_head_size: usize,
- softmax_scale: f64,
-
- span: tracing::Span,
-}
-
-impl BertAttention {
- pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let attention_head_size = config.hidden_size / config.num_attention_heads;
- let all_head_size = config.num_attention_heads * attention_head_size;
- let hidden_size = config.hidden_size;
-
- let query_weight = vb
- .pp("self.query")
- .get((all_head_size, hidden_size), "weight")?;
- let query_bias = vb
- .pp("self.query")
- .get(all_head_size, "bias")?
- .dequantize(vb.device())?;
- let q_linear = Linear::new(QMatMul::from_arc(query_weight)?, Some(query_bias), None);
-
- let key_weight = vb
- .pp("self.key")
- .get((all_head_size, hidden_size), "weight")?;
- let key_bias = vb
- .pp("self.key")
- .get(all_head_size, "bias")?
- .dequantize(vb.device())?;
- let k_linear = Linear::new(QMatMul::from_arc(key_weight)?, Some(key_bias), None);
-
- let value_weight = vb
- .pp("self.value")
- .get((all_head_size, hidden_size), "weight")?;
- let value_bias = vb
- .pp("self.value")
- .get(all_head_size, "bias")?
- .dequantize(vb.device())?;
- let v_linear = Linear::new(QMatMul::from_arc(value_weight)?, Some(value_bias), None);
-
- let dense_weight = vb
- .pp("output")
- .pp("dense")
- .get((hidden_size, hidden_size), "weight")?;
- let dense_bias = vb
- .pp("output")
- .pp("dense")
- .get(hidden_size, "bias")?
- .dequantize(vb.device())?;
-
- let dense = Linear::new(QMatMul::from_arc(dense_weight)?, Some(dense_bias), None);
-
- let layer_norm = LayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?;
-
- let softmax_scale = 1. / (attention_head_size as f64).sqrt();
-
- Ok(Self {
- q_linear,
- k_linear,
- v_linear,
- dense,
- layer_norm,
- num_attention_heads: config.num_attention_heads,
- attention_head_size,
- softmax_scale,
- span: tracing::span!(tracing::Level::TRACE, "attention"),
- })
- }
-
- fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
- let mut new_x_shape = xs.dims().to_vec();
- new_x_shape.pop();
- new_x_shape.push(self.num_attention_heads);
- new_x_shape.push(self.attention_head_size);
- let xs = xs
- .reshape(new_x_shape.as_slice())?
- .unsqueeze(0)?
- .transpose(1, 2)?;
- xs.contiguous()
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let residual = hidden_states.clone();
-
- let query_layer = self.q_linear.forward(hidden_states)?;
- let key_layer = self.k_linear.forward(hidden_states)?;
- let value_layer = self.v_linear.forward(hidden_states)?;
-
- let query_layer = self.transpose_for_scores(&query_layer)?;
- let key_layer = self.transpose_for_scores(&key_layer)?;
- let value_layer = self.transpose_for_scores(&value_layer)?;
-
- let attention_scores = query_layer.matmul(&key_layer.t()?)?;
- let attention_scores = (attention_scores * self.softmax_scale)?;
- let attention_probs = softmax(&attention_scores, D::Minus1)?;
-
- let context_layer = attention_probs.matmul(&value_layer)?;
- let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
- let context_layer = context_layer.flatten_from(D::Minus2)?.squeeze(0)?;
-
- let hidden_states = self.dense.forward(&context_layer)?.add(&residual)?;
- let hidden_states = self.layer_norm.forward(&hidden_states)?;
-
- Ok(hidden_states)
- }
-}
-
-struct BertLayer {
- attention: BertAttention,
- intermediate: Linear,
- output: Linear,
- layer_norm: LayerNorm,
- span: tracing::Span,
-}
-
-impl BertLayer {
- pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let attention = BertAttention::load(vb.pp("attention"), config)?;
-
- let intermediate_weight = vb
- .pp("intermediate")
- .pp("dense")
- .get((config.intermediate_size, config.hidden_size), "weight")?;
- let intermediate_bias = vb
- .pp("intermediate")
- .pp("dense")
- .get(config.intermediate_size, "bias")?
- .dequantize(vb.device())?;
- let intermediate = Linear::new(
- QMatMul::from_arc(intermediate_weight)?,
- Some(intermediate_bias),
- Some(config.hidden_act.clone()),
- );
-
- let output_weight = vb
- .pp("output")
- .pp("dense")
- .get((config.hidden_size, config.intermediate_size), "weight")?;
- let output_bias = vb
- .pp("output")
- .pp("dense")
- .get(config.hidden_size, "bias")?
- .dequantize(vb.device())?;
- let output = Linear::new(QMatMul::from_arc(output_weight)?, Some(output_bias), None);
-
- let layer_norm = LayerNorm::load(vb.pp("output").pp("LayerNorm"), config)?;
-
- Ok(Self {
- attention,
- intermediate,
- output,
- layer_norm,
- span: tracing::span!(tracing::Level::TRACE, "layer"),
- })
- }
-
- pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let hidden_states = self.attention.forward(hidden_states)?;
- let residual = hidden_states.clone();
-
- let hidden_states = self.intermediate.forward(&hidden_states)?;
- let hidden_states = self.output.forward(&hidden_states)?.add(&residual)?;
- let hidden_states = self.layer_norm.forward(&hidden_states)?;
-
- Ok(hidden_states)
- }
-}
-
-struct BertEncoder {
- layers: Vec<BertLayer>,
- span: tracing::Span,
-}
-
-impl BertEncoder {
- pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let layers = (0..config.num_hidden_layers)
- .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
- .collect::<Result<Vec<_>>>()?;
- let span = tracing::span!(tracing::Level::TRACE, "encoder");
-
- Ok(BertEncoder { layers, span })
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let mut hidden_states = hidden_states.clone();
-
- // Use a loop rather than a fold as it's easier to modify when adding debug/...
- for layer in self.layers.iter() {
- hidden_states = layer.forward(&hidden_states)?
- }
-
- Ok(hidden_states)
- }
-}
-
-pub struct QuantBertModel {
- embeddings: BertEmbeddings,
- encoder: BertEncoder,
- pool: Pool,
- pub device: Device,
-
- span: tracing::Span,
-}
-
-impl QuantBertModel {
- pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result<Self> {
- match vb.device() {
- Device::Cpu => {}
- _ => candle::bail!("Bert requires CPU"),
- }
-
- // Check position embedding type
- if config.position_embedding_type != PositionEmbeddingType::Absolute {
- candle::bail!("FlashBert only supports absolute position embeddings")
- }
-
- // Check pool type
- if pool != Pool::Mean && pool != Pool::Cls {
- candle::bail!("Pool type {pool:?} is not supported");
- }
-
- let (embeddings, encoder) = match (
- BertEmbeddings::load(vb.pp("embeddings"), config),
- BertEncoder::load(vb.pp("encoder"), config),
- ) {
- (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
- (Err(err), _) | (_, Err(err)) => {
- let model_type = config.model_type.clone().unwrap_or("bert".to_string());
-
- if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
- BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
- ) {
- (embeddings, encoder)
- } else if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(vb.pp("bert.embeddings"), config),
- BertEncoder::load(vb.pp("bert.encoder"), config),
- ) {
- (embeddings, encoder)
- } else {
- return Err(err);
- }
- }
- };
-
- Ok(Self {
- embeddings,
- encoder,
- pool,
- device: vb.device().clone(),
- span: tracing::span!(tracing::Level::TRACE, "model"),
- })
- }
-
- pub fn forward(&self, batch: Batch) -> Result<Tensor> {
- let _enter = self.span.enter();
-
- let batch_size = batch.cumulative_seq_lengths.len() - 1;
- if batch_size > 1 {
- candle::bail!("Bert CPU only support batch_size == 1");
- }
- let shape = batch.input_ids.len();
-
- // Create Cuda tensors
- let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
- let type_ids = Tensor::from_vec(batch.token_type_ids, shape, &self.device)?;
- let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
-
- let embedding_output = self
- .embeddings
- .forward(&input_ids, &type_ids, &position_ids)?;
-
- let outputs = self.encoder.forward(&embedding_output)?;
-
- let results = match self.pool {
- // CLS pooling
- Pool::Cls => outputs.i(0..1)?,
- // Mean pooling
- Pool::Mean => (outputs.sum_keepdim(0)? / (batch.max_length as f64))?,
- };
-
- Ok(results)
- }
-}
-
-impl EmbeddingModel for QuantBertModel {
- fn embed(&self, batch: Batch) -> Result<Tensor> {
- self.forward(batch)
- }
-}
diff --git a/backends/candle/src/models/flash_bert.rs b/backends/candle/src/models/flash_bert.rs
index c945965d..34622d61 100644
--- a/backends/candle/src/models/flash_bert.rs
+++ b/backends/candle/src/models/flash_bert.rs
@@ -1,10 +1,10 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{LayerNorm, Linear};
use crate::models::bert::{Config, PositionEmbeddingType};
-use crate::models::EmbeddingModel;
+use crate::models::Model;
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
-use text_embeddings_backend_core::{Batch, Pool};
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
#[derive(Debug)]
struct BertEmbeddings {
@@ -270,17 +270,61 @@ impl BertEncoder {
}
}
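+/// Same RoBERTa-style classification head as in `bert.rs`, applied on top of the
+/// flash-attention encoder output.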
+struct BertClassificationHead {
+ intermediate: Linear,
+ output: Linear,
+ span: tracing::Span,
+}
+
+impl BertClassificationHead {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let n_classes = match &config.id2label {
+ None => candle::bail!("`id2label` must be set for classifier models"),
+ Some(id2label) => id2label.len(),
+ };
+
+ let intermediate_weight = vb
+ .pp("dense")
+ .get((config.hidden_size, config.hidden_size), "weight")?;
+ let intermediate_bias = vb.pp("dense").get(config.hidden_size, "bias")?;
+ let intermediate = Linear::new(intermediate_weight, Some(intermediate_bias), None);
+
+ let output_weight = vb
+ .pp("out_proj")
+ .get((n_classes, config.hidden_size), "weight")?;
+ let output_bias = vb.pp("out_proj").get(n_classes, "bias")?;
+ let output = Linear::new(output_weight, Some(output_bias), None);
+
+ Ok(Self {
+ intermediate,
+ output,
+ span: tracing::span!(tracing::Level::TRACE, "classifier"),
+ })
+ }
+
+ pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+
+ let hidden_states = self.intermediate.forward(hidden_states)?;
+ let hidden_states = hidden_states.tanh()?;
+ let hidden_states = self.output.forward(&hidden_states)?;
+
+ Ok(hidden_states)
+ }
+}
+
pub struct FlashBertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
pool: Pool,
+ classifier: Option<BertClassificationHead>,
pub device: Device,
span: tracing::Span,
}
impl FlashBertModel {
- pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result<Self> {
+ pub fn load(vb: VarBuilder, config: &Config, model_type: ModelType) -> Result<Self> {
match vb.device() {
Device::Cuda(_) => {}
_ => candle::bail!("FlashBert requires Cuda"),
@@ -295,6 +339,20 @@ impl FlashBertModel {
candle::bail!("FlashBert only supports absolute position embeddings")
}
+ let (pool, classifier) = match model_type {
+ // Classifier models always use CLS pooling
+ ModelType::Classifier => {
+ if config.model_type == Some("bert".to_string()) {
+ candle::bail!("`classifier` model type is not supported for Bert");
+ }
+ (
+ Pool::Cls,
+ Some(BertClassificationHead::load(vb.pp("classifier"), config)?),
+ )
+ }
+ ModelType::Embedding(pool) => (pool, None),
+ };
+
// Check pool type
if pool != Pool::Mean && pool != Pool::Cls {
candle::bail!("Pool type {pool:?} is not supported");
@@ -314,8 +372,8 @@ impl FlashBertModel {
) {
(embeddings, encoder)
} else if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(vb.pp("bert.embeddings"), config),
- BertEncoder::load(vb.pp("bert.encoder"), config),
+ BertEmbeddings::load(vb.pp("roberta.embeddings"), config),
+ BertEncoder::load(vb.pp("roberta.encoder"), config),
) {
(embeddings, encoder)
} else {
@@ -328,6 +386,7 @@ impl FlashBertModel {
embeddings,
encoder,
pool,
+ classifier,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
@@ -387,8 +446,18 @@ impl FlashBertModel {
}
}
-impl EmbeddingModel for FlashBertModel {
+impl Model for FlashBertModel {
fn embed(&self, batch: Batch) -> Result<Tensor> {
self.forward(batch)
}
+
+ fn predict(&self, batch: Batch) -> Result<Tensor> {
+ match &self.classifier {
+ None => candle::bail!("`predict` is not implemented for this model"),
+ Some(classifier) => {
+ let hidden_states = self.forward(batch)?;
+ classifier.forward(&hidden_states)
+ }
+ }
+ }
}
diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs
index e89205e2..419c7131 100644
--- a/backends/candle/src/models/jina.rs
+++ b/backends/candle/src/models/jina.rs
@@ -1,10 +1,10 @@
use crate::alibi::build_alibi_tensor;
use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT};
-use crate::models::EmbeddingModel;
+use crate::models::Model;
use crate::models::{Config, PositionEmbeddingType};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
-use text_embeddings_backend_core::{Batch, Pool};
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
#[derive(Debug)]
struct BertEmbeddings {
@@ -350,17 +350,25 @@ pub struct JinaBertModel {
}
impl JinaBertModel {
- pub fn load(vb: VarBuilder, config: &Config, pool: Pool) -> Result<Self> {
+ pub fn load(vb: VarBuilder, config: &Config, model_type: ModelType) -> Result<Self> {
let alibi = match config.position_embedding_type {
PositionEmbeddingType::Alibi => Some(build_alibi_tensor(
config.max_position_embeddings,
config.num_attention_heads,
- &vb.device(),
+ vb.device(),
vb.dtype(),
)?),
PositionEmbeddingType::Absolute => None,
};
+ let pool = match model_type {
+ // Classifier models always use CLS pooling
+ ModelType::Classifier => {
+ candle::bail!("`classifier` model type is not supported for Jina")
+ }
+ ModelType::Embedding(pool) => pool,
+ };
+
// Check pool type
if pool != Pool::Mean && pool != Pool::Cls {
candle::bail!("Pool type {pool:?} is not supported");
@@ -586,7 +594,7 @@ impl JinaBertModel {
}
}
-impl EmbeddingModel for JinaBertModel {
+impl Model for JinaBertModel {
fn embed(&self, batch: Batch) -> Result<Tensor> {
self.forward(batch)
}
diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs
index a934e24a..d9afe68f 100644
--- a/backends/core/src/lib.rs
+++ b/backends/core/src/lib.rs
@@ -14,18 +14,25 @@ pub struct Batch {
pub type Embedding = Vec<f32>;
-pub trait EmbeddingBackend {
+pub trait Backend {
fn health(&self) -> Result<(), BackendError>;
-
- fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError>;
-
fn max_batch_size(&self) -> Option<usize> {
None
}
+
+ fn embed(&self, batch: Batch) -> Result<Vec<Embedding>, BackendError>;
+
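+ /// Run a sequence classification model on the batch and return one vector of raw
+ /// scores per input sequence.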
+ fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError>;
+}
+
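+/// The type of model served by a backend: either a sequence classification model or an
+/// embedding model with its pooling strategy.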
+#[derive(Debug, PartialEq, Clone)]
+pub enum ModelType {
+ Classifier,
+ Embedding(Pool),
}
-#[derive(Debug, PartialEq)]
-#[cfg_attr(feature = "clap", derive(Clone, ValueEnum))]
+#[derive(Debug, PartialEq, Clone)]
+#[cfg_attr(feature = "clap", derive(ValueEnum))]
pub enum Pool {
Cls,
Mean,
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index 69b633a6..ec3395dd 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -2,7 +2,7 @@ mod logging;
mod management;
use backend_grpc_client::Client;
-use text_embeddings_backend_core::{BackendError, Batch, Embedding, EmbeddingBackend, Pool};
+use text_embeddings_backend_core::{Backend, BackendError, Batch, Embedding, ModelType, Pool};
use tokio::runtime::Runtime;
pub struct PythonBackend {
@@ -15,10 +15,19 @@ impl PythonBackend {
pub fn new(
model_path: String,
dtype: String,
- pool: Pool,
+ model_type: ModelType,
uds_path: String,
otlp_endpoint: Option<String>,
) -> Result<Self, BackendError> {
+ let pool = match model_type {
+ ModelType::Classifier => {
+ return Err(BackendError::Start(
+ "`classifier` model type is not supported".to_string(),
+ ))
+ }
+ ModelType::Embedding(pool) => pool,
+ };
+
if pool != Pool::Cls {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
@@ -44,7 +53,7 @@ impl PythonBackend {
}
}
-impl EmbeddingBackend for PythonBackend {
+impl Backend for PythonBackend {
fn health(&self) -> Result<(), BackendError> {
if self
.tokio_runtime
@@ -69,4 +78,10 @@ impl EmbeddingBackend for PythonBackend {
.map_err(|err| BackendError::Inference(err.to_string()))?;
Ok(results.into_iter().map(|r| r.values).collect())
}
+
+ fn predict(&self, _batch: Batch) -> Result<Vec<Vec<f32>>, BackendError> {
+ Err(BackendError::Inference(
+ "`predict` is not implemented".to_string(),
+ ))
+ }
}
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index 49669538..566f9800 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -3,12 +3,12 @@ mod dtype;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
-use text_embeddings_backend_core::EmbeddingBackend;
+use text_embeddings_backend_core::Backend as CoreBackend;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
pub use crate::dtype::DType;
-pub use text_embeddings_backend_core::{BackendError, Batch, Embedding, Pool};
+pub use text_embeddings_backend_core::{BackendError, Batch, Embedding, ModelType, Pool};
#[cfg(feature = "candle")]
use text_embeddings_backend_candle::CandleBackend;
@@ -23,19 +23,26 @@ pub struct Backend {
/// Health status
health: Arc,
pub max_batch_size: Option<usize>,
+ pub model_type: ModelType,
}
impl Backend {
pub fn new(
model_path: PathBuf,
dtype: DType,
- pool: Pool,
+ model_type: ModelType,
uds_path: String,
otlp_endpoint: Option<String>,
) -> Result<Self, BackendError> {
let (backend_sender, backend_receiver) = flume::unbounded();
- let backend = init_backend(model_path, dtype, pool, uds_path, otlp_endpoint)?;
+ let backend = init_backend(
+ model_path,
+ dtype,
+ model_type.clone(),
+ uds_path,
+ otlp_endpoint,
+ )?;
let max_batch_size = backend.max_batch_size();
tokio::task::spawn_blocking(move || backend_blocking_task(backend, backend_receiver));
@@ -44,6 +51,7 @@ impl Backend {
backend_sender,
health: Arc::new(AtomicBool::new(false)),
max_batch_size,
+ model_type,
})
}
@@ -62,7 +70,7 @@ impl Backend {
)
} else {
// The backend is un-healthy or only just started. Do a more advanced health check
- // by sending an embedding request.
+ // by calling the model forward on a test batch
let batch = Batch {
input_ids: vec![0],
@@ -71,13 +79,10 @@ impl Backend {
cumulative_seq_lengths: vec![0, 1],
max_length: 1,
};
- let (sender, receiver) = oneshot::channel();
- self.backend_sender
- .send(BackendCommand::Embed(batch, Span::current(), sender))
- .expect("No backend receiver. This is a bug.");
- receiver.await.expect(
- "Backend blocking task dropped the sender without sending a response. This is a bug.",
- ).map(|_| ())
+ match &self.model_type {
+ ModelType::Classifier => self.predict(batch).await.map(|_| ()),
+ ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
+ }
};
// Update health
@@ -100,22 +105,38 @@ impl Backend {
self.health.store(result.is_ok(), Ordering::SeqCst);
result
}
+
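+ /// Send a classification batch to the backend task, wait for the raw scores and
+ /// update the backend health status with the result.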
+ #[instrument(skip_all)]
+ pub async fn predict(&self, batch: Batch) -> Result<Vec<Vec<f32>>, BackendError> {
+ let (sender, receiver) = oneshot::channel();
+
+ self.backend_sender
+ .send(BackendCommand::Predict(batch, Span::current(), sender))
+ .expect("No backend receiver. This is a bug.");
+ let result = receiver.await.expect(
+ "Backend blocking task dropped the sender without sending a response. This is a bug.",
+ );
+
+ // Update health
+ self.health.store(result.is_ok(), Ordering::SeqCst);
+ result
+ }
}
#[allow(unused)]
fn init_backend(
model_path: PathBuf,
dtype: DType,
- pool: Pool,
+ model_type: ModelType,
uds_path: String,
otlp_endpoint: Option<String>,
-) -> Result<Box<dyn EmbeddingBackend + Send>, BackendError> {
+) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
if cfg!(feature = "candle") {
#[cfg(feature = "candle")]
return Ok(Box::new(CandleBackend::new(
model_path,
dtype.to_string(),
- pool,
+ model_type,
)?));
} else if cfg!(feature = "python") {
#[cfg(feature = "python")]
@@ -127,7 +148,7 @@ fn init_backend(
PythonBackend::new(
model_path.to_str().unwrap().to_string(),
dtype.to_string(),
- pool,
+ model_type,
uds_path,
otlp_endpoint,
)
@@ -141,7 +162,7 @@ fn init_backend(
}
fn backend_blocking_task(
- backend: Box<dyn EmbeddingBackend + Send>,
+ backend: Box<dyn CoreBackend + Send>,
command_receiver: flume::Receiver<BackendCommand>,
) {
while let Ok(cmd) = command_receiver.recv() {
@@ -154,6 +175,10 @@ fn backend_blocking_task(
let _span = span.entered();
let _ = sender.send(backend.embed(batch));
}
+ BackendCommand::Predict(batch, span, sender) => {
+ let _span = span.entered();
+ let _ = sender.send(backend.predict(batch));
+ }
}
}
}
@@ -165,4 +190,9 @@ enum BackendCommand {
Span,
oneshot::Sender<Result<Vec<Embedding>, BackendError>>,
),
+ Predict(
+ Batch,
+ Span,
+ oneshot::Sender<Result<Vec<Vec<f32>>, BackendError>>,
+ ),
}
diff --git a/core/Cargo.toml b/core/Cargo.toml
index f5e66d0f..c6e62184 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -8,7 +8,6 @@ homepage.workspace = true
[dependencies]
hf-hub = { version = "^0.3.0", features = ["tokio"] }
metrics = "^0.21"
-rayon = "^1.8"
text-embeddings-backend = { path = "../backends" }
thiserror = "^1.0"
tokenizers = { version = "^0.14.1", default-features=false, features=["onig", "esaxx_fast"] }
diff --git a/core/src/download.rs b/core/src/download.rs
index 31894d50..9a0e234a 100644
--- a/core/src/download.rs
+++ b/core/src/download.rs
pub async fn download_artifacts(api: &ApiRepo) -> Result<PathBuf, ApiError> {
tracing::info!("Starting download");
+ api.get("config.json").await?;
+ api.get("tokenizer.json").await?;
+
let model_root = match api.get("model.safetensors").await {
Ok(p) => p,
Err(_) => {
@@ -19,8 +22,6 @@ pub async fn download_artifacts(api: &ApiRepo) -> Result {
.parent()
.unwrap()
.to_path_buf();
- api.get("config.json").await?;
- api.get("tokenizer.json").await?;
tracing::info!("Model artifacts downloaded in {:?}", start.elapsed());
Ok(model_root)
diff --git a/core/src/infer.rs b/core/src/infer.rs
index 649d3293..0d34f368 100644
--- a/core/src/infer.rs
+++ b/core/src/infer.rs
@@ -1,10 +1,9 @@
use crate::queue::{Entry, Metadata, NextBatch, Queue};
-use crate::tokenization::Tokenization;
+use crate::tokenization::{EncodingInput, Tokenization};
use crate::TextEmbeddingsError;
-use rayon::prelude::*;
use std::sync::Arc;
use std::time::{Duration, Instant};
-use text_embeddings_backend::{Backend, Embedding};
+use text_embeddings_backend::{Backend, BackendError, ModelType};
use tokio::sync::{mpsc, oneshot, Notify, OwnedSemaphorePermit, Semaphore};
use tracing::{instrument, Span};
@@ -45,7 +44,7 @@ impl Infer {
));
// Create embed task to communicate with backend
- tokio::spawn(embed_task(backend.clone(), embed_receiver));
+ tokio::spawn(backend_task(backend.clone(), embed_receiver));
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
@@ -81,21 +80,30 @@ impl Infer {
.expect("Semaphore has been closed. This is a bug.")
}
- #[instrument(skip(self))]
- pub async fn embed(
+ #[instrument(skip(self, _permit))]
+ pub async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
&self,
- inputs: String,
+ inputs: I,
truncate: bool,
normalize: bool,
- permit: OwnedSemaphorePermit,
+ _permit: OwnedSemaphorePermit,
+ ) -> Result<InferResponse, TextEmbeddingsError> {
+ if self.is_classifier() {
+ metrics::increment_counter!("te_request_failure", "err" => "model_type");
+ let message = "model is not an embedding model".to_string();
+ tracing::error!("{message}");
+ return Err(TextEmbeddingsError::Backend(BackendError::Inference(
+ message,
+ )));
+ }
+
let start_time = Instant::now();
metrics::increment_counter!("te_embed_count");
// Tokenization
let encoding = self
.tokenization
- .encode(inputs, truncate)
+ .encode(inputs.into(), truncate)
.await
.map_err(|err| {
metrics::increment_counter!("te_request_failure", "err" => "tokenization");
@@ -114,14 +122,13 @@ impl Infer {
tokenization: start_time.elapsed(),
queue_time: Instant::now(),
prompt_tokens: encoding.input_ids.len(),
- normalize,
},
encoding,
});
self.notify_batching_task.notify_one();
- let response = response_rx
+ let mut response = response_rx
.await
.expect(
"Infer batching task dropped the sender without sending a response. This is a bug.",
@@ -132,6 +139,23 @@ impl Infer {
err
})?;
+ if normalize {
+ // Normalize embedding
+ let scale = (1.0
+ / response
+ .results
+ .iter()
+ .map(|v| {
+ let v = *v as f64;
+ v * v
+ })
+ .sum::<f64>()
+ .sqrt()) as f32;
+ for v in response.results.iter_mut() {
+ *v *= scale;
+ }
+ }
+
// Timings
let total_time = start_time.elapsed();
@@ -151,6 +175,88 @@ impl Infer {
Ok(response)
}
+ #[instrument(skip(self, _permit))]
+ pub async fn predict<I: Into<EncodingInput> + std::fmt::Debug>(
+ &self,
+ inputs: I,
+ truncate: bool,
+ _permit: OwnedSemaphorePermit,
+ ) -> Result<InferResponse, TextEmbeddingsError> {
+ if !self.is_classifier() {
+ metrics::increment_counter!("te_request_failure", "err" => "model_type");
+ let message = "model is not a classifier model".to_string();
+ tracing::error!("{message}");
+ return Err(TextEmbeddingsError::Backend(BackendError::Inference(
+ message,
+ )));
+ }
+
+ let start_time = Instant::now();
+ metrics::increment_counter!("te_predict_count");
+
+ // Tokenization
+ let encoding = self
+ .tokenization
+ .encode(inputs.into(), truncate)
+ .await
+ .map_err(|err| {
+ metrics::increment_counter!("te_request_failure", "err" => "tokenization");
+ tracing::error!("{err}");
+ err
+ })?;
+
+ // MPSC channel to communicate with the background batching task
+ let (response_tx, response_rx) = oneshot::channel();
+
+ // Append the request to the queue
+ self.queue.append(Entry {
+ metadata: Metadata {
+ response_tx,
+ span: Span::current(),
+ tokenization: start_time.elapsed(),
+ queue_time: Instant::now(),
+ prompt_tokens: encoding.input_ids.len(),
+ },
+ encoding,
+ });
+
+ self.notify_batching_task.notify_one();
+
+ let response = response_rx
+ .await
+ .expect(
+ "Infer batching task dropped the sender without sending a response. This is a bug.",
+ )
+ .map_err(|err| {
+ metrics::increment_counter!("te_request_failure", "err" => "inference");
+ tracing::error!("{err}");
+ err
+ })?;
+
+ // Timings
+ let total_time = start_time.elapsed();
+
+ // Metrics
+ metrics::increment_counter!("te_predict_success");
+ metrics::histogram!("te_predict_duration", total_time.as_secs_f64());
+ metrics::histogram!(
+ "te_predict_tokenization_duration",
+ response.tokenization.as_secs_f64()
+ );
+ metrics::histogram!("te_predict_queue_duration", response.queue.as_secs_f64());
+ metrics::histogram!(
+ "te_predict_inference_duration",
+ response.inference.as_secs_f64()
+ );
+
+ Ok(response)
+ }
+
+ #[instrument(skip(self))]
+ pub fn is_classifier(&self) -> bool {
+ matches!(self.backend.model_type, ModelType::Classifier)
+ }
+
#[instrument(skip(self))]
pub async fn health(&self) -> bool {
self.backend.health().await.is_ok()
@@ -177,36 +283,23 @@ async fn batching_task(
}
#[instrument(skip_all)]
-async fn embed_task(
+async fn backend_task(
backend: Backend,
mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>,
) {
while let Some((batch, _callback)) = embed_receiver.recv().await {
let inference_start = Instant::now();
- let results = backend.embed(batch.1).await;
+ let results = match &backend.model_type {
+ ModelType::Classifier => backend.predict(batch.1).await,
+ ModelType::Embedding(_) => backend.embed(batch.1).await,
+ };
// Handle sending responses in another thread to avoid starving the backend
tokio::task::spawn_blocking(move || match results {
Ok(embeddings) => {
- batch.0.into_par_iter().zip(embeddings).for_each(|(m, e)| {
- let e = match m.normalize {
- // Normalize embedding
- true => {
- let scale = (1.0
- / e.iter()
- .map(|v| {
- let v = *v as f64;
- v * v
- })
- .sum::<f64>()
- .sqrt()) as f32;
- e.into_iter().map(|v| v * scale).collect()
- }
- false => e,
- };
-
+ batch.0.into_iter().zip(embeddings).for_each(|(m, e)| {
let _ = m.response_tx.send(Ok(InferResponse {
- embeddings: e,
+ results: e,
prompt_tokens: m.prompt_tokens,
tokenization: m.tokenization,
queue: inference_start - m.queue_time,
@@ -225,7 +318,7 @@ async fn embed_task(
#[derive(Debug)]
pub struct InferResponse {
- pub embeddings: Embedding,
+ pub results: Vec<f32>,
pub prompt_tokens: usize,
pub tokenization: Duration,
pub queue: Duration,
diff --git a/core/src/queue.rs b/core/src/queue.rs
index 38ad513d..5262fe29 100644
--- a/core/src/queue.rs
+++ b/core/src/queue.rs
@@ -29,8 +29,6 @@ pub struct Metadata {
pub queue_time: Instant,
/// Number of tokens in the prompt
pub prompt_tokens: usize,
- /// Normalize the embeddings
- pub normalize: bool,
}
/// Request Queue
diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index c70aa55e..a320bd9a 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -1,7 +1,7 @@
/// Payload tokenization logic
use crate::TextEmbeddingsError;
use tokenizers::tokenizer::Tokenizer;
-use tokenizers::TruncationDirection;
+use tokenizers::{EncodeInput, TruncationDirection, TruncationParams, TruncationStrategy};
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};
@@ -61,7 +61,7 @@ impl Tokenization {
#[instrument(skip_all)]
pub async fn encode(
&self,
- inputs: String,
+ inputs: EncodingInput,
truncate: bool,
) -> Result<Encoding, TextEmbeddingsError> {
// Check if inputs is empty
@@ -81,13 +81,13 @@ impl Tokenization {
// Await on response channel
// Unwrap is safe here
- Ok(response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")?)
+ response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
}
}
/// Start tokenization workers
fn tokenizer_worker(
- tokenizer: Tokenizer,
+ mut tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
@@ -103,7 +103,7 @@ fn tokenizer_worker(
truncate,
max_input_length,
position_offset,
- &tokenizer,
+ &mut tokenizer,
));
}
})
@@ -112,26 +112,34 @@ fn tokenizer_worker(
/// Get input length and optionally truncate it
fn encode_input(
- inputs: String,
+ inputs: EncodingInput,
truncate: bool,
max_input_length: usize,
position_offset: usize,
- tokenizer: &Tokenizer,
+ tokenizer: &mut Tokenizer,
) -> Result<Encoding, TextEmbeddingsError> {
- // Get the number of tokens in the input
- let mut encoding = tokenizer.encode(inputs.clone(), true)?;
-
- let mut seq_len = encoding.len();
+ // Default truncation params
+ let truncate_params = truncate.then_some(TruncationParams {
+ direction: TruncationDirection::Right,
+ max_length: max_input_length,
+ strategy: TruncationStrategy::LongestFirst,
+ stride: 0,
+ });
+
+ let inputs: EncodeInput = match inputs {
+ EncodingInput::Single(s) => s.into(),
+ EncodingInput::Dual(s1, s2) => (s1, s2).into(),
+ };
+
+ let encoding = tokenizer
+ .with_truncation(truncate_params)?
+ .encode(inputs, true)?;
+ let seq_len = encoding.len();
if seq_len > max_input_length {
- if truncate {
- encoding.truncate(max_input_length, 0, TruncationDirection::Right);
- seq_len = max_input_length;
- } else {
- return Err(TextEmbeddingsError::Validation(format!(
- "`inputs` must have less than {max_input_length} tokens. Given: {seq_len}"
- )));
- }
+ return Err(TextEmbeddingsError::Validation(format!(
+ "`inputs` must have less than {max_input_length} tokens. Given: {seq_len}"
+ )));
}
metrics::histogram!("te_request_input_length", seq_len as f64);
@@ -151,8 +159,29 @@ pub struct Encoding {
pub position_ids: Vec<u32>,
}
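+/// Input to the tokenizer: either a single text or a pair of texts
+/// (e.g. a query/passage pair for re-ranking models)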
+#[derive(Debug)]
+pub enum EncodingInput {
+ Single(String),
+ Dual(String, String),
+}
+
+impl EncodingInput {
+ fn is_empty(&self) -> bool {
+ match self {
+ EncodingInput::Single(s) => s.is_empty(),
+ EncodingInput::Dual(s1, s2) => s1.is_empty() && s2.is_empty(),
+ }
+ }
+}
+
+impl From<String> for EncodingInput {
+ fn from(value: String) -> Self {
+ Self::Single(value)
+ }
+}
+
type TokenizerRequest = (
- String,
+ EncodingInput,
bool,
oneshot::Sender<Result<Encoding, TextEmbeddingsError>>,
Span,
diff --git a/docs/openapi.json b/docs/openapi.json
index 0daa94da..51b65348 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -17,8 +17,8 @@
"tags": [
"Text Embeddings Inference"
],
- "summary": "Get Embeddings",
- "description": "Get Embeddings",
+ "summary": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.",
+ "description": "Get Embeddings. Returns a 424 status code if the model is not an embedding model.",
"operationId": "embed",
"requestBody": {
"content": {
@@ -100,6 +100,94 @@
}
}
},
+ "/embeddings": {
+ "post": {
+ "tags": [
+ "Text Embeddings Inference"
+ ],
+ "summary": "OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.",
+ "description": "OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.",
+ "operationId": "openai_embed",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/OpenAICompatRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Embeddings",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/OpenAICompatResponse"
+ }
+ }
+ }
+ },
+ "413": {
+ "description": "Batch size error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ },
+ "example": {
+ "message": "Batch size error",
+ "type": "validation"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Tokenization error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ },
+ "example": {
+ "message": "Tokenization error",
+ "type": "tokenizer"
+ }
+ }
+ }
+ },
+ "424": {
+ "description": "Embedding Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ },
+ "example": {
+ "message": "Inference failed",
+ "type": "backend"
+ }
+ }
+ }
+ },
+ "429": {
+ "description": "Model is overloaded",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ },
+ "example": {
+ "message": "Model is overloaded",
+ "type": "overloaded"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
"/health": {
"get": {
"tags": [
@@ -173,19 +261,19 @@
}
}
},
- "/openai": {
+ "/predict": {
"post": {
"tags": [
"Text Embeddings Inference"
],
- "summary": "OpenAI compatible route",
- "description": "OpenAI compatible route",
- "operationId": "openai_embed",
+ "summary": "Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model",
+ "description": "Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model",
+ "operationId": "predict",
"requestBody": {
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAICompatRequest"
+ "$ref": "#/components/schemas/PredictRequest"
}
}
},
@@ -193,11 +281,11 @@
},
"responses": {
"200": {
- "description": "Embeddings",
+ "description": "Predictions",
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAICompatResponse"
+ "$ref": "#/components/schemas/PredictResponse"
}
}
}
@@ -207,11 +295,11 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ "$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "message": "Batch size error",
- "type": "validation"
+ "error": "Batch size error",
+ "error_type": "validation"
}
}
}
@@ -221,25 +309,25 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ "$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "message": "Tokenization error",
- "type": "tokenizer"
+ "error": "Tokenization error",
+ "error_type": "tokenizer"
}
}
}
},
"424": {
- "description": "Embedding Error",
+ "description": "Prediction Error",
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ "$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "message": "Inference failed",
- "type": "backend"
+ "error": "Inference failed",
+ "error_type": "backend"
}
}
}
@@ -249,11 +337,11 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+ "$ref": "#/components/schemas/ErrorResponse"
},
"example": {
- "message": "Model is overloaded",
- "type": "overloaded"
+ "error": "Model is overloaded",
+ "error_type": "overloaded"
}
}
}
@@ -264,6 +352,34 @@
},
"components": {
"schemas": {
+ "ClassifierModel": {
+ "type": "object",
+ "required": [
+ "id2label",
+ "label2id"
+ ],
+ "properties": {
+ "id2label": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ },
+ "example": {
+ "0": "LABEL"
+ }
+ },
+ "label2id": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "integer",
+ "minimum": 0
+ },
+ "example": {
+ "LABEL": "0"
+ }
+ }
+ }
+ },
"EmbedRequest": {
"type": "object",
"required": [
@@ -273,6 +389,11 @@
"inputs": {
"$ref": "#/components/schemas/Input"
},
+ "normalize": {
+ "type": "boolean",
+ "default": "true",
+ "example": "true"
+ },
"truncate": {
"type": "boolean",
"default": "false",
@@ -297,6 +418,18 @@
]
]
},
+ "EmbeddingModel": {
+ "type": "object",
+ "required": [
+ "pooling"
+ ],
+ "properties": {
+ "pooling": {
+ "type": "string",
+ "example": "cls"
+ }
+ }
+ },
"ErrorResponse": {
"type": "object",
"required": [
@@ -327,7 +460,7 @@
"required": [
"model_id",
"model_dtype",
- "model_pooling",
+ "model_type",
"max_concurrent_requests",
"max_input_length",
"max_batch_tokens",
@@ -378,15 +511,14 @@
"description": "Model info",
"example": "thenlper/gte-base"
},
- "model_pooling": {
- "type": "string",
- "example": "cls"
- },
"model_sha": {
"type": "string",
"example": "fca14538aa9956a46526bd1d0d11d69e19b5a101",
"nullable": true
},
+ "model_type": {
+ "$ref": "#/components/schemas/ModelType"
+ },
"sha": {
"type": "string",
"example": "null",
@@ -417,6 +549,32 @@
}
]
},
+ "ModelType": {
+ "oneOf": [
+ {
+ "type": "object",
+ "required": [
+ "classifier"
+ ],
+ "properties": {
+ "classifier": {
+ "$ref": "#/components/schemas/ClassifierModel"
+ }
+ }
+ },
+ {
+ "type": "object",
+ "required": [
+ "embedding"
+ ],
+ "properties": {
+ "embedding": {
+ "$ref": "#/components/schemas/EmbeddingModel"
+ }
+ }
+ }
+ ]
+ },
"OpenAICompatEmbedding": {
"type": "object",
"required": [
@@ -536,6 +694,67 @@
"minimum": 0
}
}
+ },
+ "PredictRequest": {
+ "type": "object",
+ "required": [
+ "inputs"
+ ],
+ "properties": {
+ "inputs": {
+ "$ref": "#/components/schemas/Sequence"
+ },
+ "raw_scores": {
+ "type": "boolean",
+ "default": "false",
+ "example": "false"
+ },
+ "truncate": {
+ "type": "boolean",
+ "default": "false",
+ "example": "false"
+ }
+ }
+ },
+ "PredictResponse": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Prediction"
+ }
+ },
+ "Prediction": {
+ "type": "object",
+ "required": [
+ "score",
+ "label"
+ ],
+ "properties": {
+ "label": {
+ "type": "string",
+ "example": "admiration"
+ },
+ "score": {
+ "type": "number",
+ "format": "float",
+ "example": "0.5"
+ }
+ }
+ },
+ "Sequence": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "array",
+ "items": {
+ "type": "string"
+ },
+ "description": "",
+ "maxItems": 2,
+ "minItems": 2
+ }
+ ]
}
}
},
diff --git a/docs/source/en/cli_arguments.md b/docs/source/en/cli_arguments.md
index 17cc6679..2db8d492 100644
--- a/docs/source/en/cli_arguments.md
+++ b/docs/source/en/cli_arguments.md
@@ -48,20 +48,18 @@ Options:
--dtype
The dtype to be forced upon the model
-
- If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures
[env: DTYPE=]
- [possible values: float16]
+ [possible values: float16, float32]
--pooling
- Optionally control the pooling method.
-
- If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
- configuration.
-
+ Optionally control the pooling method for embedding models.
+
+ If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
+ configuration.
+
If `pooling` is set, it will override the model pooling configuration
-
+
[env: POOLING=]
[possible values: cls, mean]
diff --git a/docs/source/en/quick_tour.md b/docs/source/en/quick_tour.md
index dbd382c4..c55de18f 100644
--- a/docs/source/en/quick_tour.md
+++ b/docs/source/en/quick_tour.md
@@ -16,6 +16,8 @@ rendered properly in your Markdown viewer.
# Quick Tour
+## Text Embeddings
+
The easiest way to get started with TEI is to use one of the official Docker containers
(see [Supported models and hardware](supported_models) to choose the right container).
@@ -50,3 +52,47 @@ curl 127.0.0.1:8080/embed \
-d '{"inputs":"What is Deep Learning?"}' \
-H 'Content-Type: application/json'
```
+
+## Sequence Classification
+
+TEI can also be used to deploy Sequence Classification models.
+See [this blogpost](https://blog.llamaindex.ai/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83) by
+the LlamaIndex team to understand how you can use Sequence Classification models in your RAG pipeline to improve
+downstream performance.
+
+Let's say you want to use `BAAI/bge-reranker-large`:
+
+```shell
+model=BAAI/bge-reranker-large
+revision=refs/pr/4
+volume=$PWD/data
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model --revision $revision
+```
+
+Once you have deployed a model, you can use the `predict` endpoint to rank the similarity between a pair of inputs:
+
+```bash
+curl 127.0.0.1:8080/predict \
+ -X POST \
+ -d '{"inputs":["What is Deep Learning?", "Deep learning is..."], "raw_scores": true}' \
+ -H 'Content-Type: application/json'
+```
+
+You can also use classic Sequence Classification models like `SamLowe/roberta-base-go_emotions`:
+
+```shell
+model=SamLowe/roberta-base-go_emotions
+volume=$PWD/data
+
+docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.3.0 --model-id $model
+```
+
+Once you have deployed the model, you can use the `predict` endpoint to get the emotions most associated with an input:
+
+```bash
+curl 127.0.0.1:8080/predict \
+ -X POST \
+ -d '{"inputs":"I like you."}' \
+ -H 'Content-Type: application/json'
+```
diff --git a/docs/source/en/supported_models.md b/docs/source/en/supported_models.md
index 733fff7d..ddb0ebee 100644
--- a/docs/source/en/supported_models.md
+++ b/docs/source/en/supported_models.md
@@ -16,29 +16,46 @@ rendered properly in your Markdown viewer.
# Supported models and hardware
-## Supported models
-
-Text Embeddings Inference currently supports BERT, CamemBERT, and XLM-RoBERTa models with absolute positions.
We are continually expanding our support for other model types and plan to include them in future updates.
+## Supported embeddings models
+
+Text Embeddings Inference currently supports BERT, CamemBERT, and XLM-RoBERTa models with absolute positions, and the
+JinaBERT model with Alibi positions.
+
Below are some examples of the currently supported models:
-| MTEB Rank | Model Type | Model ID |
-|-----------|--------------|--------------------------------------------------------------------------------|
-| 1 | Bert | [BAAI/bge-large-en-v1.5](https://hf.co/BAAI/bge-large-en-v1.5) |
-| 2 | | [BAAI/bge-base-en-v1.5](https://hf.co/BAAI/bge-base-en-v1.5) |
-| 3 | | [llmrails/ember-v1](https://hf.co/llmrails/ember-v1) |
-| 4 | | [thenlper/gte-large](https://hf.co/thenlper/gte-large) |
-| 5 | | [thenlper/gte-base](https://hf.co/thenlper/gte-base) |
-| 6 | | [intfloat/e5-large-v2](https://hf.co/intfloat/e5-large-v2) |
-| 7 | | [BAAI/bge-small-en-v1.5](https://hf.co/BAAI/bge-small-en-v1.5) |
-| 10 | | [intfloat/e5-base-v2](https://hf.co/intfloat/e5-base-v2) |
-| 11 | XLM-RoBERTa | [intfloat/multilingual-e5-large](https://hf.co/intfloat/multilingual-e5-large) |
+
+| MTEB Rank | Model Type | Model ID |
+|-----------|-------------|----------------------------------------------------------------------------------------|
+| 1 | Bert | [BAAI/bge-large-en-v1.5](https://hf.co/BAAI/bge-large-en-v1.5) |
+| 2 | | [BAAI/bge-base-en-v1.5](https://hf.co/BAAI/bge-base-en-v1.5) |
+| 3 | | [llmrails/ember-v1](https://hf.co/llmrails/ember-v1) |
+| 4 | | [thenlper/gte-large](https://hf.co/thenlper/gte-large) |
+| 5 | | [thenlper/gte-base](https://hf.co/thenlper/gte-base) |
+| 6 | | [intfloat/e5-large-v2](https://hf.co/intfloat/e5-large-v2) |
+| 7 | | [BAAI/bge-small-en-v1.5](https://hf.co/BAAI/bge-small-en-v1.5) |
+| 10 | | [intfloat/e5-base-v2](https://hf.co/intfloat/e5-base-v2) |
+| 11 | XLM-RoBERTa | [intfloat/multilingual-e5-large](https://hf.co/intfloat/multilingual-e5-large) |
+| N/A | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
+| N/A | JinaBERT | [jinaai/jina-embeddings-v2-small-en](https://hf.co/jinaai/jina-embeddings-v2-small-en) |
To explore the list of best performing text embeddings models, visit the
[Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
+## Supported sequence classification models
+
+Text Embeddings Inference currently supports RoBERTa, CamemBERT, and XLM-RoBERTa Sequence Classification models with absolute positions.
+
+Below are some examples of the currently supported models:
+
+| Task | Model Type | Model ID | Revision |
+|--------------------|-------------|---------------------------------------------------------------------------------------------|-------------|
+| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | `refs/pr/4` |
+| Re-Ranking | XLM-RoBERTa | [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | `refs/pr/5` |
+| Sentiment Analysis | RoBERTa | [SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) | |
+
## Supported hardware
Text Embeddings Inference can be used on CPU, Turing (T4, RTX 2000 series, ...), Ampere 80 (A100, A30),
diff --git a/load_tests/load.js b/load_tests/load.js
index d4bef350..7a945cce 100644
--- a/load_tests/load.js
+++ b/load_tests/load.js
@@ -41,7 +41,7 @@ export default function () {
});
const headers = {'Content-Type': 'application/json'};
- const res = http.post(`http://${host}/embed`, payload, {
+ const res = http.post(`http://${host}`, payload, {
headers, timeout: '20m'
});
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 8dc7d859..f9683ebf 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -2,8 +2,31 @@
pub mod server;
use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use text_embeddings_core::tokenization::EncodingInput;
use utoipa::ToSchema;
+#[derive(Clone, Debug, Serialize, ToSchema)]
+pub struct EmbeddingModel {
+ #[schema(example = "cls")]
+ pub pooling: String,
+}
+
+#[derive(Clone, Debug, Serialize, ToSchema)]
+pub struct ClassifierModel {
+ #[schema(example = json!({"0": "LABEL"}))]
+    pub id2label: HashMap<String, String>,
+ #[schema(example = json!({"LABEL": "0"}))]
+    pub label2id: HashMap<String, usize>,
+}
+
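+/// Serialized in `/info` as either `{"classifier": {...}}` or `{"embedding": {...}}`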
+#[derive(Clone, Debug, Serialize, ToSchema)]
+#[serde(rename_all = "lowercase")]
+pub enum ModelType {
+ Classifier(ClassifierModel),
+ Embedding(EmbeddingModel),
+}
+
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
/// Model info
@@ -13,8 +36,7 @@ pub struct Info {
pub model_sha: Option<String>,
#[schema(example = "float16")]
pub model_dtype: String,
- #[schema(example = "cls")]
- pub model_pooling: String,
+ pub model_type: ModelType,
/// Router Parameters
#[schema(example = "128")]
pub max_concurrent_requests: usize,
@@ -37,6 +59,53 @@ pub struct Info {
pub docker_label: Option<&'static str>,
}
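+/// Input for `/predict`: deserializes from either a single JSON string or a
+/// two-element array of strings (e.g. a query/passage pair for re-rankers)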
+#[derive(Deserialize, ToSchema, Debug)]
+#[serde(untagged)]
+pub(crate) enum Sequence {
+ Single(String),
+ Pair(String, String),
+}
+
+impl Sequence {
+ pub(crate) fn count_chars(&self) -> usize {
+ match self {
+ Sequence::Single(s) => s.chars().count(),
+ Sequence::Pair(s1, s2) => s1.chars().count() + s2.chars().count(),
+ }
+ }
+}
+
+impl From<Sequence> for EncodingInput {
+ fn from(value: Sequence) -> Self {
+ match value {
+ Sequence::Single(s) => Self::Single(s),
+ Sequence::Pair(s1, s2) => Self::Dual(s1, s2),
+ }
+ }
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct PredictRequest {
+ pub inputs: Sequence,
+ #[serde(default)]
+ #[schema(default = "false", example = "false")]
+ pub truncate: bool,
+ #[serde(default)]
+ #[schema(default = "false", example = "false")]
+ pub raw_scores: bool,
+}
+
+#[derive(Serialize, ToSchema)]
+pub(crate) struct Prediction {
+ #[schema(example = "0.5")]
+ score: f32,
+ #[schema(example = "admiration")]
+ label: String,
+}
+
+#[derive(Serialize, ToSchema)]
+pub(crate) struct PredictResponse(Vec<Prediction>);
+
#[derive(Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum Input {
diff --git a/router/src/main.rs b/router/src/main.rs
index 9b0d1c78..7ed4a579 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -9,16 +9,16 @@ use opentelemetry::sdk::{trace, Resource};
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use serde::Deserialize;
+use std::collections::HashMap;
use std::fs;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_embeddings_backend::DType;
-use text_embeddings_backend::{Backend, Pool};
use text_embeddings_core::download::{download_artifacts, download_pool_config};
use text_embeddings_core::infer::Infer;
use text_embeddings_core::queue::Queue;
use text_embeddings_core::tokenization::Tokenization;
-use text_embeddings_router::{server, Info};
+use text_embeddings_router::{server, ClassifierModel, EmbeddingModel, Info, ModelType};
use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
@@ -54,14 +54,14 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<DType>,
- /// Optionally control the pooling method.
+ /// Optionally control the pooling method for embedding models.
///
/// If `pooling` is not set, the pooling configuration will be parsed from the
/// model `1_Pooling/config.json` configuration.
///
/// If `pooling` is set, it will override the model pooling configuration
#[clap(long, env, value_enum)]
-    pooling: Option<Pool>,
+    pooling: Option<text_embeddings_backend::Pool>,
/// The maximum amount of concurrent requests for this particular deployment.
/// Having a low limit will refuse clients requests instead of having them
@@ -127,10 +127,13 @@ struct Args {
#[derive(Debug, Deserialize)]
pub struct ModelConfig {
+    pub architectures: Vec<String>,
pub model_type: String,
#[serde(alias = "n_positions")]
pub max_position_embeddings: usize,
pub pad_token_id: usize,
+    pub id2label: Option<HashMap<String, String>>,
+    pub label2id: Option<HashMap<String, usize>>,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -173,7 +176,8 @@ async fn main() -> Result<()> {
// Optionally download the pooling config.
if args.pooling.is_none() {
- download_pool_config(&api_repo).await.context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
+        // If a pooling config exists, download it
+ let _ = download_pool_config(&api_repo).await;
}
// Download model from the Hub
@@ -188,23 +192,62 @@ async fn main() -> Result<()> {
let config: ModelConfig =
serde_json::from_str(&config).context("Failed to parse `config.json`")?;
- // Set pooling
- let pool = match args.pooling {
- Some(pool) => pool,
- None => {
- // Load pooling config
- let config_path = model_root.join("1_Pooling/config.json");
- let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
- let config: PoolConfig =
- serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?;
- if config.pooling_mode_cls_token {
- Pool::Cls
- } else if config.pooling_mode_mean_tokens {
- Pool::Mean
- } else {
- return Err(anyhow!("Pooling config {config:?} is not supported"));
+ // Set model type from config
+ let backend_model_type = {
+ // Check if the model is a classifier
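+        // (architectures ending in "Classification", e.g. "XLMRobertaForSequenceClassification")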
+ let mut classifier = false;
+ for arch in &config.architectures {
+ if arch.ends_with("Classification") {
+ classifier = true;
+ break;
}
}
+
+ if classifier {
+ if args.pooling.is_some() {
+ tracing::warn!(
+ "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
+ );
+ }
+ text_embeddings_backend::ModelType::Classifier
+ } else {
+ // Set pooling
+ let pool = match args.pooling {
+ Some(pool) => pool,
+ None => {
+ // Load pooling config
+ let config_path = model_root.join("1_Pooling/config.json");
+ let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
+ let config: PoolConfig = serde_json::from_str(&config)
+ .context("Failed to parse `1_Pooling/config.json`")?;
+ if config.pooling_mode_cls_token {
+ text_embeddings_backend::Pool::Cls
+ } else if config.pooling_mode_mean_tokens {
+ text_embeddings_backend::Pool::Mean
+ } else {
+ return Err(anyhow!("Pooling config {config:?} is not supported"));
+ }
+ }
+ };
+ text_embeddings_backend::ModelType::Embedding(pool)
+ }
+ };
+
+ // Info model type
+ let model_type = match &backend_model_type {
+ text_embeddings_backend::ModelType::Classifier => ModelType::Classifier(ClassifierModel {
+ id2label: config
+ .id2label
+ .context("`config.json` does not contain `id2label`")?,
+ label2id: config
+ .label2id
+ .context("`config.json` does not contain `label2id`")?,
+ }),
+ text_embeddings_backend::ModelType::Embedding(pool) => {
+ ModelType::Embedding(EmbeddingModel {
+ pooling: pool.to_string(),
+ })
+ }
};
// Load tokenizer
@@ -251,10 +294,10 @@ async fn main() -> Result<()> {
// Create backend
tracing::info!("Starting model backend");
- let backend = Backend::new(
+ let backend = text_embeddings_backend::Backend::new(
model_root,
dtype.clone(),
- pool.clone(),
+ backend_model_type,
args.uds_path,
args.otlp_endpoint,
)
@@ -285,7 +328,7 @@ async fn main() -> Result<()> {
model_id: args.model_id,
model_sha: args.revision,
model_dtype: dtype.to_string(),
- model_pooling: pool.to_string(),
+ model_type,
max_concurrent_requests: args.max_concurrent_requests,
max_input_length: config.max_position_embeddings,
max_batch_tokens: args.max_batch_tokens,
diff --git a/router/src/server.rs b/router/src/server.rs
index 3284bb99..d0d5b6cc 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,7 +1,8 @@
/// HTTP Server logic
use crate::{
- EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Info, Input, OpenAICompatEmbedding,
- OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
+ ClassifierModel, EmbedRequest, EmbedResponse, EmbeddingModel, ErrorResponse, ErrorType, Info,
+ Input, ModelType, OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest,
+ OpenAICompatResponse, OpenAICompatUsage, PredictRequest, PredictResponse, Prediction, Sequence,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -54,7 +55,156 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+/// Get Predictions. Returns a 424 status code if the model is not a Sequence Classification model
+async fn predict(
+    infer: Extension<Infer>,
+    info: Extension<Info>,
+    Json(req): Json<PredictRequest>,
+) -> Result<(HeaderMap, Json<PredictResponse>), (StatusCode, Json<ErrorResponse>)> {
+ let span = tracing::Span::current();
+ let start_time = Instant::now();
+
+ metrics::increment_counter!("te_request_count", "method" => "single");
+
+ let compute_chars = req.inputs.count_chars();
+
+ let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
+ let mut response = infer
+ .predict(req.inputs, req.truncate, permit)
+ .await
+ .map_err(ErrorResponse::from)?;
+
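+    // Only classifier models carry a label mapping; for an embedding model the
+    // `infer.predict` call above is expected to have already failed with a 424.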
+ let id2label = match &info.model_type {
+ ModelType::Classifier(classifier) => &classifier.id2label,
+ _ => panic!(),
+ };
+
+    let mut predictions: Vec<Prediction> = {
+ if !req.raw_scores {
+ // Softmax
+ if response.results.len() > 1 {
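+                // Subtract the largest logit before exponentiating to keep the values in range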
+ let max = *response
+ .results
+ .iter()
+                    .max_by(|x, y| x.partial_cmp(y).unwrap())
+ .unwrap();
+
+ let mut den = 0.0;
+ for v in response.results.iter_mut() {
+ *v = (*v - max).exp();
+ den += *v;
+ }
+ for v in response.results.iter_mut() {
+ *v /= den;
+ }
+ }
+ // Sigmoid
+ else {
+ response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp());
+ }
+ }
+
+ // Map score to label
+ response
+ .results
+ .into_iter()
+ .enumerate()
+ .map(|(i, s)| Prediction {
+ score: s,
+ label: id2label.get(&i.to_string()).unwrap().clone(),
+ })
+ .collect()
+ };
+ // Reverse sort
+ predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap());
+ predictions.reverse();
+
+ metrics::increment_counter!("te_request_success", "method" => "predict");
+
+ let compute_tokens = response.prompt_tokens;
+ let tokenization_time = response.tokenization;
+ let queue_time = response.queue;
+ let inference_time = response.inference;
+
+ let total_time = start_time.elapsed();
+
+ // Tracing metadata
+ span.record("total_time", format!("{total_time:?}"));
+ span.record("tokenization_time", format!("{tokenization_time:?}"));
+ span.record("queue_time", format!("{queue_time:?}"));
+ span.record("inference_time", format!("{inference_time:?}"));
+
+ // Headers
+ let mut headers = HeaderMap::new();
+ headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+ headers.insert(
+ "x-compute-time",
+ total_time.as_millis().to_string().parse().unwrap(),
+ );
+ headers.insert(
+ "x-compute-characters",
+ compute_chars.to_string().parse().unwrap(),
+ );
+ headers.insert(
+ "x-compute-tokens",
+ compute_tokens.to_string().parse().unwrap(),
+ );
+ headers.insert(
+ "x-total-time",
+ total_time.as_millis().to_string().parse().unwrap(),
+ );
+ headers.insert(
+ "x-tokenization-time",
+ tokenization_time.as_millis().to_string().parse().unwrap(),
+ );
+ headers.insert(
+ "x-queue-time",
+ queue_time.as_millis().to_string().parse().unwrap(),
+ );
+ headers.insert(
+ "x-inference-time",
+ inference_time.as_millis().to_string().parse().unwrap(),
+ );
+
+ // Metrics
+ metrics::histogram!("te_request_duration", total_time.as_secs_f64());
+ metrics::histogram!(
+ "te_request_tokenization_duration",
+ tokenization_time.as_secs_f64()
+ );
+ metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64());
+ metrics::histogram!(
+ "te_request_inference_duration",
+ inference_time.as_secs_f64()
+ );
+
+ tracing::info!("Success");
+
+ Ok((headers, Json(PredictResponse(predictions))))
+}
+
+/// Get Embeddings. Returns a 424 status code if the model is not an embedding model.
#[utoipa::path(
post,
tag = "Text Embeddings Inference",
@@ -105,7 +255,7 @@ async fn embed(
response.tokenization,
response.queue,
response.inference,
- EmbedResponse(vec![response.embeddings]),
+ EmbedResponse(vec![response.results]),
)
}
Input::Batch(inputs) => {
@@ -157,7 +307,7 @@ async fn embed(
total_queue_time += r.queue.as_nanos() as u64;
total_inference_time += r.inference.as_nanos() as u64;
total_compute_tokens += r.prompt_tokens;
- embeddings.push(r.embeddings);
+ embeddings.push(r.results);
}
let batch_size = batch_size as u64;
@@ -220,7 +370,7 @@ async fn embed(
"te_request_tokenization_duration",
tokenization_time.as_secs_f64()
);
- metrics::histogram!("e_request_queue_duration", queue_time.as_secs_f64());
+ metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64());
metrics::histogram!(
"te_request_inference_duration",
inference_time.as_secs_f64()
@@ -231,11 +381,11 @@ async fn embed(
Ok((headers, Json(response)))
}
-/// OpenAI compatible route
+/// OpenAI compatible route. Returns a 424 status code if the model is not an embedding model.
#[utoipa::path(
post,
tag = "Text Embeddings Inference",
-path = "/openai",
+path = "/embeddings",
request_body = OpenAICompatRequest,
responses(
(status = 200, description = "Embeddings", body = OpenAICompatResponse),
@@ -285,7 +435,7 @@ async fn openai_embed(
response.inference,
vec![OpenAICompatEmbedding {
object: "embedding",
- embedding: response.embeddings,
+ embedding: response.results,
index: 0,
}],
)
@@ -339,7 +489,7 @@ async fn openai_embed(
total_compute_tokens += r.prompt_tokens;
embeddings.push(OpenAICompatEmbedding {
object: "embedding",
- embedding: r.embeddings,
+ embedding: r.results,
index: i,
});
}
@@ -404,7 +554,7 @@ async fn openai_embed(
"te_request_tokenization_duration",
tokenization_time.as_secs_f64()
);
- metrics::histogram!("e_request_queue_duration", queue_time.as_secs_f64());
+ metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64());
metrics::histogram!(
"te_request_inference_duration",
inference_time.as_secs_f64()
@@ -448,14 +598,22 @@ pub async fn run(
paths(
get_model_info,
health,
+ predict,
embed,
openai_embed,
metrics,
),
components(
schemas(
+ Sequence,
Input,
Info,
+ ModelType,
+ ClassifierModel,
+ EmbeddingModel,
+ PredictRequest,
+ Prediction,
+ PredictResponse,
OpenAICompatRequest,
OpenAICompatEmbedding,
OpenAICompatUsage,
@@ -531,13 +689,11 @@ pub async fn run(
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
// Base routes
- .route("/", post(embed))
.route("/info", get(get_model_info))
.route("/embed", post(embed))
+ .route("/predict", post(predict))
// OpenAI compat route
- .route("/openai", post(openai_embed))
- // AWS Sagemaker route
- .route("/invocations", post(embed))
+ .route("/embeddings", post(openai_embed))
// Base Health route
.route("/health", get(health))
// Inference API health route
@@ -545,7 +701,23 @@ pub async fn run(
// AWS Sagemaker health route
.route("/ping", get(health))
// Prometheus metrics route
- .route("/metrics", get(metrics))
+ .route("/metrics", get(metrics));
+
+ // Set default routes
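+    // `/` and the AWS Sagemaker `/invocations` route dispatch to `predict` for
+    // classifiers and to `embed` for embedding models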
+ let app = match infer.is_classifier() {
+ true => {
+ app.route("/", post(predict))
+ // AWS Sagemaker route
+ .route("/invocations", post(predict))
+ }
+ false => {
+ app.route("/", post(embed))
+ // AWS Sagemaker route
+ .route("/invocations", post(embed))
+ }
+ };
+
+ let app = app
.layer(Extension(infer))
.layer(Extension(info))
.layer(Extension(prom_handle.clone()))