diff --git a/README.md b/README.md
index aa775d4c..2c169ef4 100644
--- a/README.md
+++ b/README.md
@@ -69,12 +69,12 @@ docker run -it --gpus=all --net=host --shm-size=1g \
This command starts the Docker container with GPU support and various configuration options.
> **Warning**
-> NCCL might fall back to using the host memory if NVLink or PCI is not available. To allow NCCL to use the host memory, we added '--shm-size=1g' to the command. If you have NVLink or PCI available, you can remove this option.
+> NCCL might fall back to using host shared memory if NVLink or PCI is not available. To give NCCL enough shared memory, we added '--shm-size=1g' to the docker run command.
- `HF_MODEL_ID` specifies which Hugging Face model you want to run.
-- `HF_MODEL_REVISION` specifies which Hugging Face model revision you want to run. by default, it is set to `"main"`.
-- `HF_MODEL_ALLOW_PATTERN` specifies which types of files are allowed to be downloaded. by default, it is set to `"*.json,*.safetensors,*.model"`.
-- `DEVICE` specifies the device on which this model should run. by default, it is set to `"auto"`.
+- `HF_MODEL_REVISION` specifies which Hugging Face model revision you want to run. By default, it is set to `"main"`.
+- `HF_MODEL_ALLOW_PATTERN` specifies which types of files are allowed to be downloaded. By default, it is set to `"*.json,*.safetensors,*.model"`.
+- `DEVICE` specifies the device on which this model should run. By default, it is set to `"auto"`.
- `HUGGING_FACE_HUB_TOKEN` specifies the token from [huggingface](https://huggingface.co/settings/tokens) for gated models.
> **Note**
diff --git a/entrypoint.sh b/entrypoint.sh
index cda907eb..fd6d8838 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -20,10 +20,6 @@ if [ -n "$HF_MODEL_ID" ]; then
exit 1
fi
ARGS+=" --model_path "$MODEL_PATH" --model_id "$HF_MODEL_ID""
-elif [ -n "$HF_MODEL_PATH" ]; then
- echo "Using model from the specified path "$HF_MODEL_PATH""
-
- ARGS+=" --model_path "$HF_MODEL_PATH""
fi
ARGS+=" --device "$DEVICE""
diff --git a/scalellm.yml b/scalellm.yml
index 95219173..6f311952 100644
--- a/scalellm.yml
+++ b/scalellm.yml
@@ -17,6 +17,7 @@ services:
- HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN}
volumes:
- $HOME/.cache/huggingface/hub:/models
+ shm_size: 1g
command: --logtostderr
# turn on GPU access
deploy:
diff --git a/scripts/start_scalellm.sh b/scripts/start_scalellm.sh
index 81cd85e1..5eafd74c 100755
--- a/scripts/start_scalellm.sh
+++ b/scripts/start_scalellm.sh
@@ -18,10 +18,6 @@ if [ -n "$HF_MODEL_ID" ]; then
exit 1
fi
ARGS+=" --model_path "$MODEL_PATH" --model_id "$HF_MODEL_ID""
-elif [ -n "$HF_MODEL_PATH" ]; then
- echo "Using model from the specified path "$HF_MODEL_PATH""
-
- ARGS+=" --model_path "$HF_MODEL_PATH""
fi
ARGS+=" --device "$DEVICE""
diff --git a/src/common/process_group.cpp b/src/common/process_group.cpp
index 0a0f4ac8..558a5ea8 100644
--- a/src/common/process_group.cpp
+++ b/src/common/process_group.cpp
@@ -13,23 +13,21 @@ namespace llm {
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
-#define NCCLCHECK(cmd) \
- do { \
- ncclResult_t r = cmd; \
- if (r != ncclSuccess) { \
- GLOG(FATAL) << "Failed, NCCL error " << __FILE__ << ":" << __LINE__ \
- << " " << ncclGetErrorString(r); \
- } \
+#define NCCLCHECK(cmd) \
+ do { \
+ ncclResult_t r = cmd; \
+ if (r != ncclSuccess) { \
+ GLOG(FATAL) << "Failed, NCCL error :" << ncclGetErrorString(r); \
+ } \
} while (0)
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
-#define CUDACHECK(cmd) \
- do { \
- cudaError_t err = cmd; \
- if (err != cudaSuccess) { \
- GLOG(FATAL) << "Failed, Cuda error " << __FILE__ << ":" << __LINE__ \
- << " " << cudaGetErrorString(err); \
- } \
+#define CUDACHECK(cmd) \
+ do { \
+ cudaError_t err = cmd; \
+ if (err != cudaSuccess) { \
+ GLOG(FATAL) << "Failed, Cuda error :" << cudaGetErrorString(err); \
+ } \
} while (0)
at::Tensor flatten_for_scatter_gather(std::vector<torch::Tensor>& tensors) {
@@ -81,15 +79,14 @@ std::vector<std::unique_ptr<ProcessGroup>> ProcessGroup::create_process_groups(
GCHECK(device.is_cuda()) << "device should be cuda device";
}
- const int world_size = static_cast<int>(devices.size());
-
- std::vector<ncclComm_t> comms;
- comms.reserve(devices.size());
std::vector<int> device_idxs;
device_idxs.reserve(devices.size());
for (const auto& device : devices) {
device_idxs.push_back(device.index());
}
+
+ std::vector<ncclComm_t> comms(devices.size());
+ const int world_size = static_cast<int>(devices.size());
NCCLCHECK(ncclCommInitAll(comms.data(), world_size, device_idxs.data()));
std::vector<std::unique_ptr<ProcessGroup>> process_groups;
@@ -102,16 +99,6 @@ std::vector<std::unique_ptr<ProcessGroup>> ProcessGroup::create_process_groups(
}
// Constructor.
-ProcessGroupNCCL::ProcessGroupNCCL(int rank,
- int world_size,
- const torch::Device& device,
- const ncclUniqueId& comm_id)
- : ProcessGroup(rank, world_size, device) {
- torch::DeviceGuard device_guard(device);
- NCCLCHECK(ncclCommInitRank(&comm_, world_size, comm_id, rank));
- CUDACHECK(cudaStreamCreate(&stream_));
-}
-
ProcessGroupNCCL::ProcessGroupNCCL(int rank,
int world_size,
const torch::Device& device,
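
Note on the hunk above: `reserve()` is replaced with a sized construction because `ncclCommInitAll()` writes into a caller-provided array, so `comms.data()` must point at `world_size` valid slots. A minimal standalone sketch of that pattern (the helper name and the cleanup at the end are illustrative, not part of ScaleLLM):

```cpp
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <nccl.h>

// One communicator per listed CUDA device, mirroring create_process_groups().
void init_comms(const std::vector<int>& device_idxs) {
  // The vector must be sized (not merely reserved) before .data() is handed
  // to NCCL, since ncclCommInitAll() fills the array in place.
  std::vector<ncclComm_t> comms(device_idxs.size());
  const int world_size = static_cast<int>(device_idxs.size());
  const ncclResult_t r =
      ncclCommInitAll(comms.data(), world_size, device_idxs.data());
  if (r != ncclSuccess) {  // same check NCCLCHECK performs
    std::fprintf(stderr, "NCCL error: %s\n", ncclGetErrorString(r));
    std::abort();
  }
  for (ncclComm_t comm : comms) {
    ncclCommDestroy(comm);  // in ScaleLLM each comm is owned by a ProcessGroupNCCL instead
  }
}
```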
diff --git a/src/common/process_group.h b/src/common/process_group.h
index 89f50e04..8f74a5b8 100644
--- a/src/common/process_group.h
+++ b/src/common/process_group.h
@@ -56,11 +56,6 @@ class ProcessGroup {
class ProcessGroupNCCL : public ProcessGroup {
public:
// Constructor.
- ProcessGroupNCCL(int rank,
- int world_size,
- const torch::Device& device,
- const ncclUniqueId& id);
-
ProcessGroupNCCL(int rank,
int world_size,
const torch::Device& device,
diff --git a/src/engine/engine.cpp b/src/engine/engine.cpp
index 2bba5703..94b4a8e9 100644
--- a/src/engine/engine.cpp
+++ b/src/engine/engine.cpp
@@ -1,4 +1,5 @@
#include "engine.h"
+#include
#include
#include
@@ -21,6 +22,8 @@ DEFINE_double(max_memory_utilization,
0.9,
"maximum memory utilization allowed, default 0.9");
+DECLARE_bool(disable_custom_kernels);
+
namespace llm {
namespace {
torch::ScalarType parse_dtype(const std::string& dtype_str,
@@ -64,6 +67,10 @@ Engine::Engine(const std::vector<torch::Device>& devices) : devices_(devices) {
ParallelArgs parallel_args(rank, world_size, pg);
workers_.emplace_back(std::make_unique<Worker>(parallel_args, devices[i]));
}
+
+ if (FLAGS_disable_custom_kernels) {
+ GLOG(WARNING) << "Custom kernels are disabled, using generic kernels.";
+ }
}
bool Engine::init(const std::string& model_weights_path) {
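
Note on the flag wiring: `disable_custom_kernels` is defined once with gflags (the `DEFINE_bool` lives in `src/layers/attention.cpp`, see the hunk below), and every other translation unit that needs it only adds a `DECLARE_bool`, which makes the `FLAGS_disable_custom_kernels` symbol visible. A minimal sketch of that pattern, with the file split indicated by comments (the helper function is illustrative, not ScaleLLM API):

```cpp
// --- attention.cpp: the one translation unit that owns the flag ---
#include <gflags/gflags.h>
DEFINE_bool(disable_custom_kernels, false, "disable all custom kernels");

// --- any other file (engine.cpp, activation.cpp, ...): declare only ---
#include <gflags/gflags.h>
DECLARE_bool(disable_custom_kernels);

void report_kernel_mode() {
  if (FLAGS_disable_custom_kernels) {
    // generic (non-CUDA-kernel) code paths will be used everywhere
  }
}
```

With standard gflags behavior, the flag can then be toggled from the command line as `--disable_custom_kernels`, which is what the warning added in Engine::Engine reports.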
diff --git a/src/layers/activation.cpp b/src/layers/activation.cpp
index abc277a2..e3fd42aa 100644
--- a/src/layers/activation.cpp
+++ b/src/layers/activation.cpp
@@ -8,6 +8,8 @@
#include "common/logging.h"
+DECLARE_bool(disable_custom_kernels);
+
namespace llm {
namespace detail {
torch::Tensor gelu(torch::Tensor x) {
@@ -84,10 +86,12 @@ ActFunc Activation::get_act_func(const std::string& name,
return gelu;
}
if (boost::iequals(name, "gelu_fast")) {
- return device.is_cuda() ? kernel::gelu_fast : gelu_fast;
+ return device.is_cuda() && !FLAGS_disable_custom_kernels ? kernel::gelu_fast
+ : gelu_fast;
}
if (boost::iequals(name, "gelu_new")) {
- return device.is_cuda() ? kernel::gelu_new : gelu_new;
+ return device.is_cuda() && !FLAGS_disable_custom_kernels ? kernel::gelu_new
+ : gelu_new;
}
if (boost::iequals(name, "gelu_pytorch_tanh")) {
return gelu_pytorch_tanh;
@@ -96,7 +100,8 @@ ActFunc Activation::get_act_func(const std::string& name,
return relu;
}
if (boost::iequals(name, "silu")) {
- return device.is_cuda() ? kernel::silu : silu;
+ return device.is_cuda() && !FLAGS_disable_custom_kernels ? kernel::silu
+ : silu;
}
GLOG(ERROR) << "Unsupported activation function: " << name;
@@ -111,10 +116,14 @@ ActFunc Activation::get_act_with_mul_func(const std::string& name,
return gelu_with_mul;
}
if (boost::iequals(name, "gelu_fast")) {
- return device.is_cuda() ? kernel::gelu_fast_with_mul : gelu_fast_with_mul;
+ return device.is_cuda() && !FLAGS_disable_custom_kernels
+ ? kernel::gelu_fast_with_mul
+ : gelu_fast_with_mul;
}
if (boost::iequals(name, "gelu_new")) {
- return device.is_cuda() ? kernel::gelu_new_with_mul : gelu_new_with_mul;
+ return device.is_cuda() && !FLAGS_disable_custom_kernels
+ ? kernel::gelu_new_with_mul
+ : gelu_new_with_mul;
}
if (boost::iequals(name, "gelu_pytorch_tanh")) {
return gelu_pytorch_tanh_with_mul;
@@ -123,7 +132,9 @@ ActFunc Activation::get_act_with_mul_func(const std::string& name,
return relu_with_mul;
}
if (boost::iequals(name, "silu")) {
- return device.is_cuda() ? kernel::silu_with_mul : silu_with_mul;
+ return device.is_cuda() && !FLAGS_disable_custom_kernels
+ ? kernel::silu_with_mul
+ : silu_with_mul;
}
GLOG(ERROR) << "Unsupported activation function: " << name;
diff --git a/src/layers/attention.cpp b/src/layers/attention.cpp
index 570fa7e6..2ebfb24f 100644
--- a/src/layers/attention.cpp
+++ b/src/layers/attention.cpp
@@ -6,16 +6,7 @@
#include "common/logging.h"
-DEFINE_string(varlen_masked_self_attention,
- "",
- "type of attention to use for varlen_masked_self_attention, "
- "slow, cuda, or empty for auto");
-
-DEFINE_string(
- single_query_masked_self_attention,
- "",
- "type of attention to use for single_query_masked_self_attention, slow, "
- "cuda, or empty for auto");
+DEFINE_bool(disable_custom_kernels, false, "disable all custom kernels");
DEFINE_bool(
force_use_paged_attention_v2,
@@ -188,19 +179,16 @@ void varlen_masked_self_attention(
int32_t max_seq_len, // maximum sequence length
float scale, // scale for softmax
torch::Tensor& output) {
- if (query.is_cuda()) {
+ if (query.is_cuda() && !FLAGS_disable_custom_kernels) {
// use cuda kernel
- if (FLAGS_varlen_masked_self_attention.empty() ||
- FLAGS_varlen_masked_self_attention == "cuda") {
- return varlen_masked_self_attention_cuda(query,
- key,
- value,
- cu_seq_lens,
- alibi_slopes,
- max_seq_len,
- scale,
- output);
- }
+ return varlen_masked_self_attention_cuda(query,
+ key,
+ value,
+ cu_seq_lens,
+ alibi_slopes,
+ max_seq_len,
+ scale,
+ output);
}
return varlen_masked_self_attention_generic(
query, key, value, cu_seq_lens, alibi_slopes, scale, output);
@@ -216,20 +204,17 @@ void single_query_masked_self_attention(
int32_t max_context_len, // maximum context length
float scale, // scale for softmax
torch::Tensor& output) {
- if (query.is_cuda()) {
+ if (query.is_cuda() && !FLAGS_disable_custom_kernels) {
// use cuda kernel
- if (FLAGS_single_query_masked_self_attention.empty() ||
- FLAGS_single_query_masked_self_attention == "cuda") {
- return single_query_masked_self_attention_cuda(kv_cache,
- kv_head_mapping,
- query,
- block_tables,
- context_lens,
- alibi_slopes,
- max_context_len,
- scale,
- output);
- }
+ return single_query_masked_self_attention_cuda(kv_cache,
+ kv_head_mapping,
+ query,
+ block_tables,
+ context_lens,
+ alibi_slopes,
+ max_context_len,
+ scale,
+ output);
}
return single_query_masked_self_attention_generic(kv_cache,
query,
diff --git a/src/layers/attention.h b/src/layers/attention.h
index 71147e81..37cc9312 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -7,8 +7,7 @@
#include "memory/kv_cache.h"
#include "models/input_parameters.h"
-DECLARE_string(varlen_masked_self_attention);
-DECLARE_string(single_query_masked_self_attention);
+DECLARE_bool(disable_custom_kernels);
namespace llm {
diff --git a/src/layers/normalization.h b/src/layers/normalization.h
index 2990d902..8f634404 100644
--- a/src/layers/normalization.h
+++ b/src/layers/normalization.h
@@ -7,6 +7,7 @@
#include "kernels/layernorm_kernels.h"
#include "model_loader/state_dict.h"
+DECLARE_bool(disable_custom_kernels);
namespace llm {
namespace detail {
inline torch::Tensor rms_norm(torch::Tensor input,
@@ -58,7 +59,7 @@ class LayerNormImpl : public torch::nn::Module {
}
torch::Tensor forward(torch::Tensor input) {
- if (input.is_cuda()) {
+ if (input.is_cuda() && !FLAGS_disable_custom_kernels) {
auto output = torch::empty_like(input);
kernel::layer_norm(output, input, weight_, bias_, eps_);
return output;
@@ -131,7 +132,7 @@ class RMSNormImpl : public torch::nn::Module {
}
torch::Tensor forward(torch::Tensor input) {
- if (input.is_cuda()) {
+ if (input.is_cuda() && !FLAGS_disable_custom_kernels) {
auto output = torch::empty_like(input);
kernel::rms_norm(output, input, weight_, eps_);
return output;
diff --git a/src/layers/pos_embedding.cpp b/src/layers/pos_embedding.cpp
index f9c2c93c..6d296d41 100644
--- a/src/layers/pos_embedding.cpp
+++ b/src/layers/pos_embedding.cpp
@@ -8,6 +8,7 @@
#include "common/logging.h"
#include "kernels/pos_embedding_kernels.h"
+DECLARE_bool(disable_custom_kernels);
namespace llm {
namespace {
@@ -57,7 +58,8 @@ std::shared_ptr create(int64_t rotary_dim,
bool interleaved,
torch::ScalarType dtype,
const torch::Device& device) {
- if (device.is_cuda()) {
+ if (device.is_cuda() && !FLAGS_disable_custom_kernels) {
+ // use custom kernels
return std::make_shared(rotary_dim,
max_position_embeddings,
scaling_factor,
@@ -146,7 +148,7 @@ RotaryEmbeddingGeneric::RotaryEmbeddingGeneric(int64_t rotary_dim,
const auto cos_sin = torch::cat({emd.cos(), emd.sin()}, /*dim=*/-1);
const auto options = torch::dtype(dtype).device(device);
- cos_sin_cache_ = register_buffer("cos_sin_cached", cos_sin);
+ cos_sin_cache_ = register_buffer("cos_sin_cache", cos_sin.to(options));
}
// inplace rotary positional embedding
@@ -181,8 +183,8 @@ RotaryEmbeddingKernel::RotaryEmbeddingKernel(int64_t rotary_dim,
: rotary_dim_(rotary_dim), interleaved_(interleaved) {
const auto freqs = detail::compute_freqs(
max_position_embeddings, rotary_dim, scaling_factor, theta);
- const auto cos_sin = torch::cat({freqs.cos(), freqs.sin()}, /*dim=*/-1);
+ const auto cos_sin = torch::cat({freqs.cos(), freqs.sin()}, /*dim=*/-1);
const auto options = torch::dtype(dtype).device(device);
cos_sin_cache_ = register_buffer("cos_sin_cache", cos_sin.to(options));
}
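
The pos_embedding.cpp fix does two things: it registers the buffer under the same name the kernel-backed implementation already uses (`cos_sin_cache`), and it converts the precomputed table to the requested dtype and device before registering it, so the generic rotary path no longer keeps a float32 CPU copy. A small standalone illustration of that conversion, assuming `freqs` was produced by `detail::compute_freqs` as in the file (the helper name is hypothetical):

```cpp
#include <torch/torch.h>

// Build the cos/sin table in the model's dtype on the target device;
// without the final .to(options) it would stay fp32 on the CPU.
torch::Tensor make_cos_sin_cache(const torch::Tensor& freqs,
                                 torch::ScalarType dtype,
                                 const torch::Device& device) {
  const auto cos_sin = torch::cat({freqs.cos(), freqs.sin()}, /*dim=*/-1);
  const auto options = torch::dtype(dtype).device(device);
  return cos_sin.to(options);
}
```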
diff --git a/src/model_loader/model_loader.cpp b/src/model_loader/model_loader.cpp
index 62a45d25..66c1a79b 100644
--- a/src/model_loader/model_loader.cpp
+++ b/src/model_loader/model_loader.cpp
@@ -258,8 +258,8 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) {
// always use float16 for quantization
if (!quant_args_.quant_method().empty() && args_.dtype() != "float16") {
- LOG(WARNING) << "Overwriting dtype from " << args_.dtype() << " to float16 "
- << "for quantization";
+ GLOG(WARNING) << "Overwriting dtype from " << args_.dtype()
+ << " to float16 for quantization";
args_.dtype() = "float16";
}
diff --git a/src/server/main.cpp b/src/server/main.cpp
index b06e859f..118b9104 100644
--- a/src/server/main.cpp
+++ b/src/server/main.cpp
@@ -91,7 +91,7 @@ int main(int argc, char** argv) {
// check if model path exists
if (!std::filesystem::exists(FLAGS_model_path)) {
- LOG(FATAL) << "Model path " << FLAGS_model_path << " does not exist.";
+ GLOG(FATAL) << "Model path " << FLAGS_model_path << " does not exist.";
}
if (FLAGS_model_id.empty()) {