Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add cuda feature flag #96

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vllm/src/tests/llama.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::models::llama::LlamaModel;
use crate::{
llm_service::LlmService,
types::{GenerateParameters, GenerateRequest},
};
use crate::models::llama::LlamaModel;
use futures::{stream::FuturesUnordered, StreamExt};
use std::{path::PathBuf, time::Instant};
use tracing::info;
Expand Down
3 changes: 2 additions & 1 deletion models/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ serde_json = { workspace = true }
tokenizers = { workspace = true }

[features]
nccl = ["dep:cudarc", "cudarc/nccl"]
nccl = ["dep:cudarc", "cudarc/nccl"]
cuda = ["dep:cudarc"]
12 changes: 7 additions & 5 deletions models/src/flash_attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ impl FlashAttention {
mod tests {
use super::*;
use candle_core::{DType, Device, Tensor};

#[cfg(feature = "cuda")]
#[test]
fn test_new() {
let device = Device::new_cuda(0).unwrap();
Expand All @@ -489,13 +489,14 @@ mod tests {
assert_eq!(flash_attention.kv_cache_dtype, DType::F32);
}

#[cfg(feature = "cpu")]
#[test]
fn test_new_invalid_heads() {
let device = Device::Cpu;
let result = FlashAttention::new(7, 4, 64, 1.0, None, None, DType::F32, device);
assert!(result.is_err());
}

#[cfg(feature = "cpu")]
#[test]
fn test_new_invalid_head_dim() {
let device = Device::Cpu;
Expand All @@ -514,7 +515,7 @@ mod tests {
let shape = FlashAttention::get_kv_cache_shape(10, 32, 4, 64);
assert_eq!(shape, vec![2, 10, 32, 4, 64]);
}

#[cfg(feature = "cpu")]
#[test]
fn test_split_kv_cache() {
let device = Device::Cpu;
Expand All @@ -531,7 +532,7 @@ mod tests {
assert_eq!(key_cache.shape().dims(), &[10, 32, 4, 64]);
assert_eq!(value_cache.shape().dims(), &[10, 32, 4, 64]);
}

#[cfg(feature = "cpu")]
#[test]
fn test_split_kv_cache_invalid_shape() {
let device = Device::Cpu;
Expand All @@ -542,7 +543,7 @@ mod tests {
let result = flash_attention.split_kv_cache(&kv_cache);
assert!(result.is_err());
}

#[cfg(feature = "cuda")]
#[test]
fn test_forward() {
let device = Device::new_cuda(0).unwrap();
Expand Down Expand Up @@ -627,6 +628,7 @@ mod tests {
.any(|&x| x == 1));
}

#[cfg(feature = "cuda")]
#[test]
fn test_forward_with_varlen() {
let device = Device::new_cuda(0).unwrap();
Expand Down
6 changes: 5 additions & 1 deletion models/src/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ mod tests {
use tokenizers::Tokenizer;

const EOS_TOKEN: &str = "</s>";

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_llama_model() -> Result<()> {
Expand Down Expand Up @@ -704,6 +704,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_llama_model_long() -> Result<()> {
Expand Down Expand Up @@ -915,6 +916,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_llama_model_random_block_order() -> Result<()> {
Expand Down Expand Up @@ -1151,6 +1153,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_llama_model_llama3_2_1b() -> Result<()> {
Expand Down Expand Up @@ -1396,6 +1399,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_llama_model_batch() -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions models/src/llama_nccl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ mod tests {

const EOS_TOKEN: &str = "</s>";

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_llama_nccl_model_random_block_order() -> Result<()> {
Expand Down
4 changes: 3 additions & 1 deletion models/src/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ mod tests {
const EOS_TOKEN: &str = "ӏ ";
const BLOCK_SIZE: usize = 16;

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_mistral_model() -> Result<()> {
Expand Down Expand Up @@ -646,7 +647,7 @@ mod tests {
Ok(())
}

// cargo test test_mistral_model_batch -- --exact
#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_mistral_model_batch() -> Result<()> {
Expand Down Expand Up @@ -1018,6 +1019,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_mistral_model_long() -> Result<()> {
Expand Down
3 changes: 3 additions & 0 deletions models/src/phi3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ mod tests {
const EOS_TOKEN: &str = "ӏ ";
const BLOCK_SIZE: usize = 16;

#[cfg(feature = "cuda")]
#[test]
#[serial]

Expand Down Expand Up @@ -636,6 +637,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_phi3_model_batch() -> Result<()> {
Expand Down Expand Up @@ -995,6 +997,7 @@ mod tests {
Ok(())
}

#[cfg(feature = "cuda")]
#[test]
#[serial]
fn test_phi3_model_long() -> Result<()> {
Expand Down
13 changes: 13 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ serde_json.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal"] }
tracing.workspace = true
tracing-subscriber.workspace = true
cudarc = { workspace = true, optional = true }

[dev-dependencies]
expect-test.workspace = true
Expand All @@ -23,3 +24,15 @@ async-openai.workspace = true
[features]
vllm = ["atoma-backends/vllm"]
nccl = ["vllm", "atoma-backends/nccl"]
cuda = [
"dep:cudarc",
"cudarc/std",
"cudarc/cublas",
"cudarc/cublaslt",
"cudarc/curand",
"cudarc/driver",
"cudarc/nvrtc",
"cudarc/f16",
"cudarc/cuda-version-from-build-system",
"cudarc/dynamic-linking"
]
19 changes: 19 additions & 0 deletions server/src/api/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ impl TryFrom<(String, GenerateRequestOutput)> for ChatCompletionResponse {
}
}

#[cfg(feature = "cpu")]
#[cfg(test)]
pub mod json_schema_tests {
// TODO: Move check functions to a test utils module.
Expand All @@ -773,6 +774,7 @@ pub mod json_schema_tests {
expect_file.assert_eq(schema);
}

#[cfg(feature = "cpu")]
#[test]
/// Used in tandem with a schema file, this will check if there are
/// changes to the JSON API schema, and show a diff if so.
Expand All @@ -790,6 +792,7 @@ pub mod json_schema_tests {

// TODO: Add the above test for response_schema

#[cfg(feature = "cpu")]
#[test]
fn request_schema_control() {
let schema_path = concat!(env!("CARGO_MANIFEST_DIR"), "/request_schema.json");
Expand All @@ -806,6 +809,7 @@ pub mod json_schema_tests {
);
}

#[cfg(feature = "cpu")]
#[test]
fn deserialize_request_body_basic() {
let json_request_body = r#"
Expand Down Expand Up @@ -845,6 +849,7 @@ pub mod json_schema_tests {
);
}

#[cfg(feature = "cpu")]
#[test]
fn deserialize_user_message() {
let json_user_message_text = r#"
Expand Down Expand Up @@ -889,6 +894,7 @@ pub mod json_schema_tests {
);
}

#[cfg(feature = "cpu")]
#[test]
fn deserialize_assistant_message() {
let json_assistant_message_text = r#"
Expand Down Expand Up @@ -931,6 +937,7 @@ pub mod json_schema_tests {
);
}

#[cfg(feature = "cpu")]
#[test]
fn deserialize_tool_message() {
let json_tool_message_text = r#"
Expand All @@ -952,6 +959,7 @@ pub mod json_schema_tests {
);
}

#[cfg(feature = "cpu")]
#[test]
fn deserialize_message_content_text() {
let json_message_content_text = r#"
Expand All @@ -963,6 +971,7 @@ pub mod json_schema_tests {
assert!(message_content.is_ok());
}

#[cfg(feature = "cpu")]
#[test]
fn deserialize_message_content_array() {
let json_message_content_array = r#"
Expand All @@ -986,6 +995,7 @@ pub mod json_schema_tests {
assert!(message_content.is_ok());
}

#[cfg(feature = "cpu")]
#[test]
fn test_messages_to_prompt() {
// Create some sample messages
Expand Down Expand Up @@ -1021,6 +1031,7 @@ pub mod json_schema_tests {
assert_eq!(prompt, expected_prompt);
}

#[cfg(feature = "cpu")]
#[test]
fn test_empty_messages() {
let messages: Vec<Message> = vec![];
Expand All @@ -1030,6 +1041,7 @@ pub mod json_schema_tests {
assert_eq!(prompt, expected_prompt);
}

#[cfg(feature = "cpu")]
#[test]
fn test_message_with_no_content() {
let messages = vec![
Expand Down Expand Up @@ -1058,6 +1070,7 @@ pub mod json_schema_tests {
assert_eq!(prompt, expected_prompt);
}

#[cfg(feature = "cpu")]
#[test]
fn test_message_to_prompt_without_sytem() {
let messages = vec![
Expand All @@ -1081,6 +1094,7 @@ pub mod json_schema_tests {
assert_eq!(prompt, expected_prompt);
}

#[cfg(feature = "cpu")]
#[test]
fn test_message_to_prompt_without_user() {
let messages = vec![
Expand Down Expand Up @@ -1108,6 +1122,7 @@ pub mod json_schema_tests {
assert_eq!(prompt, expected_prompt);
}

#[cfg(feature = "cpu")]
#[test]
fn test_deserialize_chat_completion_response() {
let json = json!({
Expand Down Expand Up @@ -1157,6 +1172,7 @@ pub mod json_schema_tests {
assert_eq!(response.usage.total_tokens, 21);
}

#[cfg(feature = "cpu")]
#[test]
fn test_deserialize_choice() {
let json = json!({
Expand All @@ -1175,6 +1191,7 @@ pub mod json_schema_tests {
assert!(matches!(choice.finish_reason, FinishReason::Stopped));
}

#[cfg(feature = "cpu")]
#[test]
fn test_deserialize_finish_reason() {
assert_eq!(
Expand All @@ -1191,6 +1208,7 @@ pub mod json_schema_tests {
);
}

#[cfg(feature = "cpu")]
#[test]
fn test_deserialize_usage() {
let json = json!({
Expand All @@ -1206,6 +1224,7 @@ pub mod json_schema_tests {
assert_eq!(usage.total_tokens, 21);
}

#[cfg(feature = "cpu")]
#[test]
fn test_deserialize_choice_with_logprobs() {
let json = json!({
Expand Down
Loading