diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 48a3d77662..bebe4f3daf 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -1303,3 +1303,16 @@ impl GetFloatInfo for DType {
         Ok(finfo)
     }
 }
+
+/// AvgPool1d with no padding.
+pub struct AvgPool1d {
+    pub kernel_size: usize,
+    pub stride: usize,
+}
+
+impl Module for AvgPool1d {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.unsqueeze(2)?
+            .avg_pool2d_with_stride((1, self.kernel_size), (1, self.stride))?.squeeze(2)
+    }
+}
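Note (not part of the patch): the 1D pool is expressed as a 2D pool over a dummy height dimension. A minimal shape sketch, assuming a (batch, channels, length) input:

    use candle_core::{Device, Result, Tensor};

    fn avg_pool1d_sketch() -> Result<()> {
        // (batch = 1, channels = 2, length = 6), kernel_size = stride = 2,
        // mirroring what AvgPool1d::forward does above.
        let xs = Tensor::arange(0f32, 12., &Device::Cpu)?.reshape((1, 2, 6))?;
        let pooled = xs
            .unsqueeze(2)?                            // (1, 2, 1, 6)
            .avg_pool2d_with_stride((1, 2), (1, 2))?  // (1, 2, 1, 3)
            .squeeze(2)?;                             // (1, 2, 3)
        assert_eq!(pooled.dims(), &[1, 2, 3]);
        Ok(())
    }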
diff --git a/mistralrs-core/src/vision_models/common/mod.rs b/mistralrs-core/src/vision_models/common/mod.rs
new file mode 100644
index 0000000000..4349cd3231
--- /dev/null
+++ b/mistralrs-core/src/vision_models/common/mod.rs
@@ -0,0 +1,3 @@
+pub(crate) mod siglip;
+pub(crate) mod whisper;
+pub(crate) mod whisper_feature_extractor;
diff --git a/mistralrs-core/src/vision_models/siglip.rs b/mistralrs-core/src/vision_models/common/siglip.rs
similarity index 100%
rename from mistralrs-core/src/vision_models/siglip.rs
rename to mistralrs-core/src/vision_models/common/siglip.rs
diff --git a/mistralrs-core/src/vision_models/common/whisper.rs b/mistralrs-core/src/vision_models/common/whisper.rs
new file mode 100644
index 0000000000..8774915cc8
--- /dev/null
+++ b/mistralrs-core/src/vision_models/common/whisper.rs
@@ -0,0 +1,204 @@
+use candle_core::{DType, IndexOp, Result, Tensor};
+use candle_nn::{
+    Activation, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, Module, VarBuilder,
+};
+
+use crate::{
+    attention::SdpaParams,
+    layers::{clamp_for_f16, Sdpa},
+};
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct WhisperEncoderConfig {
+    pub num_mel_bins: usize,
+    pub encoder_layers: usize,
+    pub encoder_attention_heads: usize,
+    pub encoder_ffn_dim: usize,
+    pub activation_function: Activation,
+    pub d_model: usize,
+    pub max_source_positions: usize,
+}
+
+pub struct WhisperAttention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    num_heads: usize,
+    head_dim: usize,
+}
+
+impl WhisperAttention {
+    fn new(cfg: &WhisperEncoderConfig, vb: VarBuilder) -> Result<Self> {
+        Ok(Self {
+            q_proj: candle_nn::linear(cfg.d_model, cfg.d_model, vb.pp("q_proj"))?,
+            k_proj: candle_nn::linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("k_proj"))?,
+            v_proj: candle_nn::linear(cfg.d_model, cfg.d_model, vb.pp("v_proj"))?,
+            o_proj: candle_nn::linear(cfg.d_model, cfg.d_model, vb.pp("o_proj"))?,
+            num_heads: cfg.encoder_attention_heads,
+            head_dim: cfg.d_model / cfg.encoder_attention_heads,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let mut q = self.q_proj.forward(xs)?;
+        let mut k = self.k_proj.forward(xs)?;
+        let mut v = self.v_proj.forward(xs)?;
+
+        // Q, K and V share the same sequence length: encoder-only, no caching.
+        let (bs, q_sq, _) = q.dims3()?;
+
+        q = q
+            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        k = k
+            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        v = v
+            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?;
+
+        let attn_output = Sdpa
+            .run_attention(
+                &q.contiguous()?,
+                &k.contiguous()?,
+                &v.contiguous()?,
+                attention_mask,
+                None,
+                &SdpaParams {
+                    n_kv_groups: 1,
+                    use_flash_attn: false,
+                    sliding_window: None,
+                    softcap: None,
+                    softmax_scale: 1. / (self.head_dim as f32).sqrt(),
+                },
+            )?
+            .transpose(1, 2)?
+            .contiguous()?
+            .reshape((bs, q_sq, ()))?;
+
+        self.o_proj.forward(&attn_output)
+    }
+}
+
+pub struct WhisperEncoderLayer {
+    attn: WhisperAttention,
+    self_attn_layer_norm: LayerNorm,
+    act: Activation,
+    fc1: Linear,
+    fc2: Linear,
+    final_layer_norm: LayerNorm,
+}
+
+impl WhisperEncoderLayer {
+    fn new(cfg: &WhisperEncoderConfig, vb: VarBuilder) -> Result<Self> {
+        Ok(Self {
+            self_attn_layer_norm: candle_nn::layer_norm(
+                cfg.d_model,
+                1e-6,
+                vb.pp("self_attn_layer_norm"),
+            )?,
+            final_layer_norm: candle_nn::layer_norm(cfg.d_model, 1e-6, vb.pp("final_layer_norm"))?,
+            fc1: candle_nn::linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?,
+            fc2: candle_nn::linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?,
+            act: cfg.activation_function.clone(),
+            attn: WhisperAttention::new(cfg, vb.pp("self_attn"))?,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs.clone();
+        let mut xs = self.self_attn_layer_norm.forward(xs)?;
+        xs = self.attn.forward(&xs, attention_mask)?;
+        xs = (residual + xs)?;
+
+        let residual = xs.clone();
+        xs = self.final_layer_norm.forward(&xs)?;
+        xs = self.fc1.forward(&xs)?.apply(&self.act)?;
+        xs = self.fc2.forward(&xs)?;
+        xs = (residual + xs)?;
+
+        if xs.dtype() == DType::F16 {
+            xs = clamp_for_f16(&xs)?;
+        }
+
+        Ok(xs)
+    }
+}
+
+pub struct WhisperEncoder {
+    conv1: Conv1d,
+    conv2: Conv1d,
+    embed_positions: Embedding,
+    layer_norm: LayerNorm,
+    layers: Vec<WhisperEncoderLayer>,
+}
+
+impl WhisperEncoder {
+    pub fn new(cfg: &WhisperEncoderConfig, vb: VarBuilder) -> Result<Self> {
+        let conv1 = candle_nn::conv1d(
+            cfg.num_mel_bins,
+            cfg.d_model,
+            3,
+            Conv1dConfig {
+                padding: 1,
+                ..Default::default()
+            },
+            vb.pp("conv1"),
+        )?;
+        let conv2 = candle_nn::conv1d(
+            cfg.d_model,
+            cfg.d_model,
+            3,
+            Conv1dConfig {
+                stride: 2,
+                padding: 1,
+                ..Default::default()
+            },
+            vb.pp("conv2"),
+        )?;
+        let embed_positions = candle_nn::embedding(
+            cfg.max_source_positions,
+            cfg.d_model,
+            vb.pp("embed_positions"),
+        )?;
+        let layer_norm = candle_nn::layer_norm(cfg.d_model, 1e-6, vb.pp("layer_norm"))?;
+
+        let vb_l = vb.pp("layers");
+        let mut layers = Vec::new();
+        for i in 0..cfg.encoder_layers {
+            layers.push(WhisperEncoderLayer::new(cfg, vb_l.pp(i))?);
+        }
+
+        Ok(Self {
+            conv1,
+            conv2,
+            embed_positions,
+            layer_norm,
+            layers,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let mut xs = self.conv1.forward(xs)?.gelu()?;
+        xs = self.conv2.forward(&xs)?.gelu()?;
+
+        xs = xs.permute((0, 2, 1))?;
+
+        let mut embed_pos = self.embed_positions.embeddings().clone();
+        // No cache because there is no streaming.
+        embed_pos = embed_pos.i((..xs.dim(1)?, ..))?;
+
+        xs = xs.broadcast_add(&embed_pos)?;
+
+        for layer in &self.layers {
+            xs = layer.forward(&xs, attention_mask)?;
+        }
+
+        self.layer_norm.forward(&xs)
+    }
+
+    pub fn dtype(&self) -> DType {
+        self.conv1.weight().dtype()
+    }
+}
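Note (not part of the patch): conv1 keeps the mel frame count while conv2 (kernel 3, stride 2, padding 1) roughly halves it, which is what `max_source_positions` has to cover. A sketch of that bookkeeping:

    /// 3_000 mel frames (one padded 30 s chunk) -> 1_500 encoder positions.
    fn encoder_positions(mel_frames: usize) -> usize {
        // output length of a kernel-3 / stride-2 / padding-1 conv
        (mel_frames - 1) / 2 + 1
    }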
diff --git a/mistralrs-core/src/vision_models/common/whisper_feature_extractor.rs b/mistralrs-core/src/vision_models/common/whisper_feature_extractor.rs
new file mode 100644
index 0000000000..c83cd44313
--- /dev/null
+++ b/mistralrs-core/src/vision_models/common/whisper_feature_extractor.rs
@@ -0,0 +1,21 @@
+pub struct WhisperFeatureExtractorConfig {
+    pub feature_size: usize,
+    pub sampling_rate: usize,
+    pub hop_length: usize,
+    pub chunk_length: usize,
+    pub n_fft: usize,
+    pub padding_value: f64,
+}
+
+impl Default for WhisperFeatureExtractorConfig {
+    fn default() -> Self {
+        Self {
+            feature_size: 80,
+            sampling_rate: 16000,
+            hop_length: 160,
+            chunk_length: 30,
+            n_fft: 400,
+            padding_value: 0.0,
+        }
+    }
+}
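Note (not part of the patch): the extractor itself is still unimplemented (the processor below fills `audio_features` with `todo!()`). As a sketch, the defaults above imply one (80, 3000) mel spectrogram per padded 30 s chunk, batched into the (bs, 80, 3000) tensor the Whisper encoder expects:

    /// Hypothetical helper: expected mel shape for one padded chunk under the defaults above.
    fn expected_chunk_shape(cfg: &WhisperFeatureExtractorConfig) -> (usize, usize) {
        let n_frames = cfg.chunk_length * cfg.sampling_rate / cfg.hop_length; // 30 * 16_000 / 160 = 3_000
        (cfg.feature_size, n_frames)                                          // (80, 3_000)
    }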
diff --git a/mistralrs-core/src/vision_models/minicpmo/config.rs b/mistralrs-core/src/vision_models/minicpmo/config.rs
index f010cfa7b2..c0970d67f1 100644
--- a/mistralrs-core/src/vision_models/minicpmo/config.rs
+++ b/mistralrs-core/src/vision_models/minicpmo/config.rs
@@ -1,10 +1,15 @@
-use crate::{models::qwen2, vision_models::siglip};
+use crate::{
+    models::qwen2,
+    vision_models::common::{siglip, whisper},
+};
 
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct MiniCpmOConfig {
     #[serde(flatten)]
     pub text_config: qwen2::Config,
     pub vision_config: siglip::SiglipVisionConfig,
+    pub audio_config: whisper::WhisperEncoderConfig,
     pub vision_batch_size: usize,
     pub query_num: usize,
+    pub audio_pool_step: usize,
 }
diff --git a/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs b/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs
index b06c7b91fd..150858d1d3 100644
--- a/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs
@@ -18,7 +18,9 @@ use crate::{
         InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
     },
     sequence::Sequence,
-    vision_models::ModelInputs,
+    vision_models::{
+        common::whisper_feature_extractor::WhisperFeatureExtractorConfig, ModelInputs,
+    },
 };
 
 use crate::vision_models::{
@@ -42,6 +44,8 @@ const DEFAULT_SLICE_END_TOKEN: &str = "</slice>";
 const DEFAULT_UNK_TOKEN: &str = "<unk>";
 const DEFAULT_USE_IMAGE_ID: bool = false;
 const DEFAULT_SLICE_MODE: bool = true;
+const AUDIO_START_ID: &str = "<|audio_start|>";
+const AUDIO_END_ID: &str = "<|audio_end|>";
 
 pub struct MiniCpmOImageProcessor {
     config: PreProcessorConfig,
@@ -85,6 +89,20 @@ impl Processor for MiniCpmOProcessor {
     }
 }
 
+// TODO: chunk input?
+fn get_audio_placeholder(audio_lens: usize) -> String {
+    let pool_step = 2;
+    let feature_lens = audio_lens.div_ceil(WhisperFeatureExtractorConfig::default().hop_length);
+
+    let feature_lens = (feature_lens - 1) / 2 + 1;
+    let output_lens = (feature_lens - pool_step) / pool_step + 1;
+
+    format!(
+        "{AUDIO_START_ID}{}{AUDIO_END_ID}",
+        "<unk>".repeat(output_lens)
+    )
+}
+
 impl InputsProcessor for MiniCpmOImageProcessor {
     fn get_type(&self) -> InputsProcessorType {
         InputsProcessorType::Vision
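Note (not part of the patch): a worked example of the placeholder arithmetic; `pool_step` is hard-coded to 2 here while the model reads `audio_pool_step` from the config, so the two must agree. One second of 16 kHz audio reserves 25 `<unk>` tokens:

    #[test]
    fn audio_placeholder_len() {
        // 16_000 samples at hop_length = 160:
        //   feature_lens  = ceil(16_000 / 160) = 100 mel frames
        //   after conv2   = (100 - 1) / 2 + 1  = 50 encoder states
        //   after pooling = (50 - 2) / 2 + 1   = 25 pooled embeddings
        let placeholder = get_audio_placeholder(16_000);
        assert_eq!(placeholder.matches("<unk>").count(), 25);
        assert!(placeholder.starts_with(AUDIO_START_ID) && placeholder.ends_with(AUDIO_END_ID));
    }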
@@ -178,8 +196,9 @@
             .iter()
             .all(|seq| seq.images().is_some_and(|images| !images.is_empty()));
 
-        let (new_input, pixel_values_all, image_bound, tgt_sizes) = if has_images {
+        let (new_input, pixel_values_all, image_bound, audio_bound, tgt_sizes) = if has_images {
             const IMAGE_TAG: &str = "(<image>./</image>)";
+            const AUDIO_TAG: &str = "(<audio>./</audio>)";
             const IMAGE_PATTERN: &str = r"\(<image>./</image>\)";
             const AUDIO_PATTERN: &str = r"\(<audio>./</audio>\)";
 
@@ -191,8 +210,10 @@
             let mut tgt_sizes_accum = Vec::new();
             let mut input_ids_accum = Vec::new();
             let mut image_bounds_accum = Vec::new();
+            let mut audio_bounds_accum = Vec::new();
 
             for seq in input_seqs.iter_mut() {
+                // IMAGE
                 let PreprocessedImages {
                     pixel_values: _,
                     pixel_attention_mask: _,
@@ -222,6 +243,46 @@
                 let tgt_sizes = tgt_sizes.unwrap();
                 let image_sizes_all = image_sizes_all.unwrap();
 
+                // AUDIO
+                let mut audios = vec![vec![0f32; 16000]];
+                let mut audio_parts = vec![0];
+
+                let mut audio_features = Vec::new();
+                let mut audio_feature_lens = Vec::new();
+                let mut audio_placeholders = Vec::new();
+
+                assert_eq!(audios.len(), audio_parts.len());
+
+                for a in &audios {
+                    audio_placeholders.push(get_audio_placeholder(a.len()));
+                }
+
+                let mut cur_audio = Vec::new();
+                let mut merge_audio = Vec::new();
+                for ((aid, audio), _part) in audios.iter().enumerate().zip(&audio_parts) {
+                    if aid == 0 || audio_parts[aid] == audio_parts[aid - 1] {
+                        cur_audio.push(audio);
+                    } else {
+                        let collect = cur_audio.into_iter().flat_map(|a| a).collect::<Vec<_>>();
+                        merge_audio.push(audio);
+                        cur_audio = vec![audio];
+                    }
+                }
+
+                // If the audio exceeds 30 seconds, split it into 30-second chunks.
+                let sampling_rate = WhisperFeatureExtractorConfig::default().sampling_rate;
+                let max_audio_inp_len = 30 * sampling_rate;
+                let mut final_merge_audio = Vec::new();
+                for audio in merge_audio {
+                    if audio.len() <= max_audio_inp_len {
+                        final_merge_audio.push(audio.to_vec());
+                    } else {
+                        for chunk in audio.chunks(max_audio_inp_len) {
+                            final_merge_audio.push(chunk.to_vec());
+                        }
+                    }
+                }
+
                 let text = tokenizer
                     .decode(seq.get_toks(), false)
                     .expect("Detokenization failed!");
@@ -256,18 +317,23 @@
                 }
 
                 let mut image_id = 0;
+                let mut audio_id = 0;
                 for chunk in &mut text_chunks {
                     if chunk == IMAGE_TAG {
                         *chunk = self.get_slice_image_placeholder(image_sizes_all[image_id], image_id);
                         image_id += 1;
                     }
+                    if chunk == AUDIO_TAG {
+                        *chunk = audio_placeholders[audio_id].clone();
+                        audio_id += 1;
+                    }
                 }
 
                 let final_text = text_chunks.join("");
 
                 seq.set_initial_prompt(final_text.clone());
 
-                let (input_ids, image_bounds) = {
+                let (input_ids, image_bounds, audio_bounds) = {
                     let im_start_id = tokenizer
                         .encode(
                             self.config
@@ -358,13 +424,57 @@
 
                     let image_bounds = Tensor::cat(&[image_start_idx, image_end_idx], 1).unwrap();
 
-                    (input_ids, image_bounds)
+                    let audio_start_id = tokenizer
+                        .encode(AUDIO_START_ID.to_string(), true)
+                        .unwrap()
+                        .get_ids()[0];
+                    let audio_end_id = tokenizer
+                        .encode(AUDIO_END_ID.to_string(), true)
+                        .unwrap()
+                        .get_ids()[0];
+
+                    let audio_start_idx = input_ids
+                        .iter()
+                        .enumerate()
+                        .filter_map(|(i, &id)| {
+                            if id == audio_start_id {
+                                Some(i as u32 + 1)
+                            } else {
+                                None
+                            }
+                        })
+                        .collect::<Vec<_>>();
+
+                    let audio_end_idx = input_ids
+                        .iter()
+                        .enumerate()
+                        .filter_map(|(i, &id)| {
+                            if id == audio_end_id {
+                                Some(i as u32)
+                            } else {
+                                None
+                            }
+                        })
+                        .collect::<Vec<_>>();
+
+                    assert_eq!(audio_start_idx.len(), audio_end_idx.len());
+                    let audio_idx_len = audio_start_idx.len();
+
+                    let audio_start_idx =
+                        Tensor::from_vec(audio_start_idx, (audio_idx_len, 1), device).unwrap();
+                    let audio_end_idx =
+                        Tensor::from_vec(audio_end_idx, (audio_idx_len, 1), device).unwrap();
+
+                    let audio_bounds = Tensor::cat(&[audio_start_idx, audio_end_idx], 1).unwrap();
+
+                    (input_ids, image_bounds, audio_bounds)
                 };
 
                 pixel_values_accum.push(pixel_values_list);
                 tgt_sizes_accum.push(tgt_sizes);
                 input_ids_accum.push(input_ids);
                 image_bounds_accum.push(image_bounds);
+                audio_bounds_accum.push(audio_bounds);
             }
 
             let mut all_ids_new = Vec::new();
@@ -379,10 +489,11 @@
                 Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                 Some(pixel_values_accum),
                 Some(image_bounds_accum),
+                Some(audio_bounds_accum),
                 Some(tgt_sizes_accum),
             )
         } else {
-            (None, None, None, None)
+            (None, None, None, None, None)
         };
 
         let input = match new_input {
@@ -394,6 +505,9 @@
             pixel_values_all,
             tgt_sizes,
             image_bound,
+            audio_bound,
+            audio_feature_lens_raw: todo!(),
+            audio_features: todo!(),
         };
 
         // Dummy pixel values - real ones are in model specific args
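Note (not part of the patch): the bound convention mirrors `image_bound`. Illustration only, with made-up positions:

    // Hypothetical decoded stream: ... <|audio_start|> <unk> <unk> <unk> <|audio_end|> ...
    // at positions                          10           11    12    13       14
    // audio_start_idx keeps 10 + 1 = 11 and audio_end_idx keeps 14, so the stored row is
    // [11, 14): exactly the three <unk> slots that get_omni_embedding later overwrites.
    fn audio_bound_width(row: [u32; 2]) -> usize {
        (row[1] - row[0]) as usize // 3 == number of pooled audio embeddings expected
    }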
diff --git a/mistralrs-core/src/vision_models/minicpmo/mod.rs b/mistralrs-core/src/vision_models/minicpmo/mod.rs
index 02a3cb4a62..95d64ec5dc 100644
--- a/mistralrs-core/src/vision_models/minicpmo/mod.rs
+++ b/mistralrs-core/src/vision_models/minicpmo/mod.rs
@@ -1,7 +1,7 @@
 use std::{any::Any, collections::HashMap, sync::Arc};
 
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{Linear, Module, VarBuilder};
 pub use config::MiniCpmOConfig;
 pub use inputs_processor::MiniCpmOProcessor;
 use mistralrs_quant::QuantMethod;
@@ -10,6 +10,7 @@ use resampler::Resampler;
 use crate::{
     amoe::AnyMoeBaseModelMixin,
     device_map::DeviceMapper,
+    layers::{Activation, AvgPool1d, GetFloatInfo},
     models::qwen2,
     paged_attention::{AttentionImplementation, ModelConfigMetadata},
     pipeline::{
@@ -21,17 +22,42 @@
 
 use self::siglip::SiglipVisionTransformer;
 
-use super::siglip;
+use super::common::{siglip, whisper::WhisperEncoder};
 
 mod config;
 mod inputs_processor;
 mod resampler;
 
+pub struct MultiModalProjector {
+    act: Activation,
+    linear1: Linear,
+    linear2: Linear,
+}
+
+impl MultiModalProjector {
+    fn new(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Self> {
+        Ok(Self {
+            act: Activation::Relu,
+            linear1: candle_nn::linear(in_dim, out_dim, vb.pp("linear1"))?,
+            linear2: candle_nn::linear(out_dim, out_dim, vb.pp("linear2"))?,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear1)?
+            .apply(&self.act)?
+            .apply(&self.linear2)
+    }
+}
+
 pub struct MiniCpmOModel {
     cfg: MiniCpmOConfig,
     llm: qwen2::Model,
     vpm: SiglipVisionTransformer,
     resampler: Resampler,
+    apm: WhisperEncoder,
+    audio_projection_layer: MultiModalProjector,
+    audio_avg_pooler: AvgPool1d,
 }
 
 impl MiniCpmOModel {
@@ -50,6 +76,7 @@
             normal_loading_metadata,
             attention_mechanism,
         )?;
+        // Vision
         let vpm = SiglipVisionTransformer::new(
            &cfg.vision_config,
            vb.pp("vpm").set_device(real_device.clone()),
@@ -63,11 +90,25 @@
             None,
             vb.pp("resampler").set_device(real_device.clone()),
         )?;
+        // Audio
+        let apm = WhisperEncoder::new(&cfg.audio_config, vb.pp("apm"))?;
+        let audio_projection_layer = MultiModalProjector::new(
+            cfg.audio_config.encoder_ffn_dim / 4,
+            cfg.text_config.hidden_size,
+            vb.pp("audio_projection_layer"),
+        )?;
+        let audio_avg_pooler = AvgPool1d {
+            kernel_size: cfg.audio_pool_step,
+            stride: cfg.audio_pool_step,
+        };
         Ok(Self {
             cfg: cfg.clone(),
             llm,
             vpm,
             resampler,
+            apm,
+            audio_projection_layer,
+            audio_avg_pooler,
         })
     }
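Note (not part of the patch): shape-wise the audio path is (bs, num_mel_bins, 3000) -> encoder (bs, 1500, d_model) -> projector (bs, 1500, hidden_size) -> AvgPool1d (bs, 1500 / audio_pool_step, hidden_size). The projector input of `encoder_ffn_dim / 4` assumes Whisper's usual `encoder_ffn_dim = 4 * d_model`, which recovers `d_model`. A small sketch:

    /// Number of LLM tokens contributed by one full 30 s chunk.
    fn audio_tokens_per_chunk(audio_pool_step: usize) -> usize {
        let encoder_states = 1_500; // (3_000 mel frames - 1) / 2 + 1
        (encoder_states - audio_pool_step) / audio_pool_step + 1 // 750 for audio_pool_step = 2
    }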
@@ -208,6 +249,117 @@
         Ok(vllm_embedding)
     }
 
+    fn get_feat_extract_output_lengths(&self, input_lengths: &Tensor) -> Result<(Tensor, Tensor)> {
+        let input_lengths = input_lengths.to_dtype(DType::I32)?;
+        let input_lengths_after_cnn = (((input_lengths - 1.)? / 2.)?.floor()? + 1.)?;
+        let input_lengths_after_pooling = (((&input_lengths_after_cnn
+            - self.cfg.audio_pool_step as f64)?
+            / self.cfg.audio_pool_step as f64)?
+            .floor()?
+            + 1.)?;
+
+        Ok((input_lengths_after_cnn, input_lengths_after_pooling))
+    }
+
+    fn get_audio_embedding(
+        &self,
+        audio_features: &Tensor,
+        audio_feature_lens_raw: Vec<Tensor>,
+    ) -> Result<Vec<Vec<Tensor>>> {
+        let audio_feature_lens = Tensor::cat(&audio_feature_lens_raw, 0)?;
+        let (bs, _, max_mel_seq_len) = audio_features.dims3()?;
+        let max_seq_len = (max_mel_seq_len - 1) / 2 + 1;
+
+        // Create a sequence tensor of shape (bs, max_seq_len)
+        let seq_range = Tensor::arange(0, max_seq_len as u32, audio_features.device())?
+            .unsqueeze(0)?
+            .expand((bs, max_seq_len))?;
+        let lengths_expand = audio_feature_lens.unsqueeze(1)?.expand((bs, max_seq_len))?;
+
+        // Create mask: 1 for padded values
+        let padding_mask = seq_range.ge(&lengths_expand)?;
+        let audio_attention_mask = padding_mask.reshape((bs, 1, 1, max_seq_len))?.expand((
+            bs,
+            1,
+            max_seq_len,
+            max_seq_len,
+        ))?;
+        let apm_dtype = self.apm.dtype();
+        // 1 -> -inf, 0 -> 0
+        let audio_attention_mask =
+            (audio_attention_mask.to_dtype(apm_dtype)? * apm_dtype.finfo()?.min)?;
+
+        let audio_states = self
+            .apm
+            .forward(audio_features, Some(&audio_attention_mask))?;
+        let mut audio_embeds = self.audio_projection_layer.forward(&audio_states)?;
+
+        audio_embeds = audio_embeds.transpose(1, 2)?;
+        audio_embeds = self.audio_avg_pooler.forward(&audio_embeds)?;
+        audio_embeds = audio_embeds.transpose(1, 2)?;
+
+        let (_, feature_lens_after_pooling) =
+            self.get_feat_extract_output_lengths(&audio_feature_lens)?;
+
+        let num_audio_tokens = feature_lens_after_pooling.to_vec1::<i32>()?;
+
+        let mut final_audio_embeds = Vec::new();
+        let mut idx = 0;
+        for lens_i in &audio_feature_lens_raw {
+            let mut target_audio_embeds = Vec::new();
+            for _ in 0..lens_i.dim(0)? {
+                target_audio_embeds.push(audio_embeds.i((
+                    idx,
+                    ..num_audio_tokens[idx] as usize,
+                    ..,
+                ))?);
+                idx += 1;
+            }
+            final_audio_embeds.push(target_audio_embeds)
+        }
+
+        Ok(final_audio_embeds)
+    }
+
+    fn get_omni_embedding(
+        &self,
+        input_embeddings: &Tensor,
+        audio_features: &Tensor,
+        audio_feature_lens_raw: Vec<Tensor>,
+        audio_bound: Vec<Tensor>,
+    ) -> Result<Tensor> {
+        let audio_embeddings = self.get_audio_embedding(audio_features, audio_feature_lens_raw)?;
+
+        assert_eq!(audio_embeddings.len(), audio_bound.len());
+        let audio_bound_vec = audio_bound
+            .into_iter()
+            .map(|x| x.to_vec2::<u32>())
+            .collect::<Result<Vec<_>>>()?;
+
+        // TODO: chunk input?
+        let mut new_embeddings = input_embeddings.clone();
+        for ((i, audio_embs), bounds) in audio_embeddings.iter().enumerate().zip(audio_bound_vec) {
+            assert_eq!(audio_embs.len(), bounds.len());
+            for (embs, bound) in audio_embs.iter().zip(bounds) {
+                let audio_indices_len = bound[1] - bound[0];
+
+                if embs.dim(0)? != audio_indices_len as usize {
+                    candle_core::bail!(
+                        "Shape mismatch: Trying to assign embeddings of shape {:?} to input indices of length {audio_indices_len}",
+                        embs.dims(),
+                    );
+                }
+
+                new_embeddings = new_embeddings.slice_assign(
+                    &[&i, &(bound[0] as usize..bound[1] as usize), &..],
+                    &embs.to_dtype(input_embeddings.dtype())?,
+                )?;
+            }
+        }
+
+        Ok(new_embeddings)
+    }
+
     #[allow(clippy::too_many_arguments)]
     pub fn forward(
         &self,
@@ -215,12 +367,15 @@
         pixel_values_all: Option<Vec<Vec<Tensor>>>,
         tgt_sizes: Option<Vec<Tensor>>,
         image_bound: Option<Vec<Tensor>>,
+        audio_features: Option<Tensor>,
+        audio_feature_lens_raw: Option<Vec<Tensor>>,
+        audio_bound: Option<Vec<Tensor>>,
         seqlen_offsets: &[usize],
         context_lens: Vec<(usize, usize)>,
         metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
         flash_params: &FlashParams,
     ) -> Result<Tensor> {
-        let vllm_embedding = self.get_vllm_embedding(
+        let mut embedding = self.get_vllm_embedding(
             input_ids,
             self.llm.device(),
             pixel_values_all,
@@ -228,9 +383,21 @@
             image_bound,
         )?;
 
+        if let Some(audio_features) = audio_features {
+            let audio_feature_lens_raw =
+                audio_feature_lens_raw.expect("Require audio_feature_lens_raw");
+            let audio_bound = audio_bound.expect("Require audio_bound");
+            embedding = self.get_omni_embedding(
+                &embedding,
+                &audio_features,
+                audio_feature_lens_raw,
+                audio_bound,
+            )?;
+        }
+
         self.llm.forward_embed(
             input_ids,
-            vllm_embedding,
+            embedding,
             seqlen_offsets,
             context_lens,
             metadata,
@@ -244,6 +411,9 @@ pub(crate) struct MiniCpmOSpecificArgs {
     pub(crate) pixel_values_all: Option<Vec<Vec<Tensor>>>,
     pub(crate) tgt_sizes: Option<Vec<Tensor>>,
     pub(crate) image_bound: Option<Vec<Tensor>>,
+    pub(crate) audio_features: Option<Tensor>,
+    pub(crate) audio_feature_lens_raw: Option<Vec<Tensor>>,
+    pub(crate) audio_bound: Option<Vec<Tensor>>,
 }
 
 impl VisionModel for MiniCpmOModel {
@@ -280,14 +450,21 @@
             pixel_values_all,
             tgt_sizes,
             image_bound,
+            audio_features,
+            audio_feature_lens_raw,
+            audio_bound,
         } = *model_specific_args
             .downcast()
             .expect("Cannot downcast into `MiniCpmOSpecificArgs`");
+
         self.forward(
             input_ids,
             pixel_values_all,
             tgt_sizes,
             image_bound,
+            audio_features,
+            audio_feature_lens_raw,
+            audio_bound,
             seqlen_offsets,
             context_lens,
             metadata,
@@ -299,6 +476,9 @@
             pixel_values_all: None,
             tgt_sizes: None,
             image_bound: None,
+            audio_features: None,
+            audio_feature_lens_raw: None,
+            audio_bound: None,
         })
     }
 }
diff --git a/mistralrs-core/src/vision_models/mod.rs b/mistralrs-core/src/vision_models/mod.rs
index c78bd8a3b9..850ce09dfb 100644
--- a/mistralrs-core/src/vision_models/mod.rs
+++ b/mistralrs-core/src/vision_models/mod.rs
@@ -17,9 +17,9 @@ pub(crate) use llava::llava15;
 pub(crate) use llava::llava_inputs_processor;
 pub(crate) use llava::llava_next;
 pub(crate) use llava::llava_next_inputs_processor;
+pub(crate) mod common;
 pub(crate) mod idefics3;
 pub(crate) mod minicpmo;
-pub(crate) mod siglip;
 
 use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
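Note (not part of the patch): a hedged sketch of how a caller would thread the new audio arguments through `MiniCpmOSpecificArgs` once the processor actually produces features (today it still has `todo!()` for them). The tensors here are dummies with the expected shapes and made-up bound positions:

    use candle_core::{DType, Device, Result, Tensor};

    fn dummy_audio_args(device: &Device) -> Result<MiniCpmOSpecificArgs> {
        // One padded 30 s chunk: (bs = 1, num_mel_bins = 80, 3_000 mel frames).
        let audio_features = Tensor::zeros((1, 80, 3000), DType::F32, device)?;
        // One sequence containing one chunk of 3_000 valid mel frames.
        let audio_feature_lens_raw = vec![Tensor::new(&[3000u32], device)?];
        // One (start, end) row per audio; 750 pooled tokens for audio_pool_step = 2.
        let audio_bound = vec![Tensor::new(&[[10u32, 760u32]], device)?];

        Ok(MiniCpmOSpecificArgs {
            pixel_values_all: None,
            tgt_sizes: None,
            image_bound: None,
            audio_features: Some(audio_features),
            audio_feature_lens_raw: Some(audio_feature_lens_raw),
            audio_bound: Some(audio_bound),
        })
    }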