moved model_downloader into simple.cpp
guocuimi committed Dec 5, 2023
1 parent ed6c0ed commit 6580956
Showing 5 changed files with 66 additions and 61 deletions.
12 changes: 0 additions & 12 deletions src/model_loader/CMakeLists.txt
@@ -44,16 +44,4 @@ cc_library(
    torch
)

cc_library(
  NAME
    model_downloader
  HDRS
    model_downloader.h
  SRCS
    model_downloader.cpp
  DEPS
    torch
    Python::Python
)

add_subdirectory(safetensors)
25 changes: 0 additions & 25 deletions src/model_loader/model_downloader.cpp

This file was deleted.

8 changes: 0 additions & 8 deletions src/model_loader/model_downloader.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/server/CMakeLists.txt
@@ -35,11 +35,11 @@ cc_binary(
    simple.cpp
  DEPS
    :engine
    :model_downloader
    torch
    absl::strings
    gflags::gflags
    glog::glog
    Python::Python
)

cc_binary(
80 changes: 65 additions & 15 deletions src/server/simple.cpp
@@ -2,16 +2,14 @@
#include <c10/core/Device.h>
#include <c10/core/ScalarType.h>
#include <gflags/gflags.h>
#include <pybind11/embed.h>

#include <filesystem>
#include <iostream>
#include <string>

#include "common/logging.h"
#include "engine/engine.h"
#include "model_loader/model_downloader.h"
#include "models/args.h"
#include "models/input_parameters.h"
#include "request/sampling_parameter.h"
#include "request/sequence.h"
#include "request/stopping_criteria.h"
@@ -20,7 +18,14 @@ DEFINE_string(model_name_or_path,
"TheBloke/Llama-2-7B-GPTQ",
"hf model name or path to the model file.");

DEFINE_string(device, "cuda:0", "Device to run the model on.");
DEFINE_string(model_allow_patterns,
              "*.json,*.safetensors,*.model",
              "Allow patterns for model files.");

DEFINE_string(device,
              "auto",
              "Device to run the model on, e.g. cpu, cuda:0, cuda:0,cuda:1, or "
              "auto to use all available gpus.");

DEFINE_int32(max_seq_len, 256, "Maximum sequence length.");

@@ -34,34 +39,79 @@ DEFINE_double(repetition_penalty, 1.0, "Repetition penalty for sampling.");
DEFINE_double(frequency_penalty, 0.0, "Frequency penalty for sampling.");
DEFINE_double(presence_penalty, 0.0, "Presence penalty for sampling.");

int main(int argc, char* argv[]) {
  // initialize glog and gflags
  google::InitGoogleLogging(argv[0]);
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  torch::InferenceMode guard;
std::string download_model(const std::string& model_name) {
  namespace py = pybind11;
  py::scoped_interpreter guard{};  // Start the interpreter

  py::dict globals = py::globals();
  globals["repo_id"] = model_name;
  globals["allow_patterns"] = FLAGS_model_allow_patterns;
  py::exec(R"(
from huggingface_hub import snapshot_download
model_path = snapshot_download(repo_id, allow_patterns=allow_patterns.split(','))
)", globals, globals);
  return globals["model_path"].cast<std::string>();
}

  // split device into chunks
  const std::vector<std::string> device_strs =
      absl::StrSplit(FLAGS_device, ',');
std::vector<torch::Device> parse_devices(const std::string& device_str) {
  std::vector<torch::Device> devices;
  devices.reserve(device_strs.size());
  if (device_str == "auto") {
    // use all available gpus if any
    const auto num_gpus = torch::cuda::device_count();
    if (num_gpus == 0) {
      GLOG(INFO) << "no gpus found, using cpu.";
      return {torch::kCPU};
    }
    devices.reserve(num_gpus);
    for (int i = 0; i < num_gpus; ++i) {
      devices.emplace_back(torch::kCUDA, i);
    }
    return devices;
  }

  // parse device string
  const std::vector<std::string> device_strs = absl::StrSplit(device_str, ',');
  std::set<torch::DeviceType> device_types;
  devices.reserve(device_strs.size());
  for (const auto& device_str : device_strs) {
    devices.emplace_back(device_str);
    device_types.insert(devices.back().type());
  }
  GCHECK(!devices.empty()) << "No devices specified.";
  GCHECK(device_types.size() == 1)
      << "All devices must be of the same type. Got: " << FLAGS_device;
  return devices;
}

std::string to_string(const std::vector<torch::Device>& devices) {
  std::stringstream ss;
  for (size_t i = 0; i < devices.size(); ++i) {
    const auto& device = devices[i];
    if (i == 0) {
      ss << device;
    } else {
      ss << "," << device;
    }
  }
  return ss.str();
}

int main(int argc, char* argv[]) {
  // initialize glog and gflags
  google::InitGoogleLogging(argv[0]);
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // check if model path exists
  std::string model_path = FLAGS_model_name_or_path;
  if (!std::filesystem::exists(model_path)) {
    // not a model path, try to download the model from huggingface hub
    model_path = llm::hf::download_model(FLAGS_model_name_or_path);
    model_path = download_model(FLAGS_model_name_or_path);
  }

  // parse devices
  const auto devices = parse_devices(FLAGS_device);
  GLOG(INFO) << "Using devices: " << to_string(devices);

  llm::Engine engine(devices);
  GCHECK(engine.init(model_path));
  auto tokenizer = engine.tokenizer();

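The core of this commit is the inlined download_model above: instead of linking the separate model_downloader library, simple.cpp now embeds a Python interpreter and calls huggingface_hub.snapshot_download directly, passing inputs in and reading the result back through the interpreter's globals dict. Below is a minimal self-contained sketch of that pybind11 embedding idiom; it is illustrative only, not part of this commit, and assumes pybind11 and a Python runtime are installed.

#include <pybind11/embed.h>

#include <iostream>
#include <string>

namespace py = pybind11;

int main() {
  // RAII: starts the interpreter here, shuts it down at end of scope.
  py::scoped_interpreter guard{};

  // Pass C++ values into the interpreter via the globals dict.
  py::dict globals = py::globals();
  globals["name"] = std::string("world");

  // Run a snippet; it reads `name` and leaves `greeting` in globals.
  py::exec(R"(
greeting = 'hello, ' + name
)", globals, globals);

  // Read the result back and cast it to a C++ type.
  std::cout << globals["greeting"].cast<std::string>() << std::endl;
  return 0;
}

Note that py::scoped_interpreter is scope-bound: in download_model the interpreter is started and torn down within the single call, which is fine for a one-shot download at startup.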
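The non-auto path of parse_devices relies on torch::Device being constructible directly from strings like "cpu" or "cuda:1". A small standalone sketch of that behavior, again illustrative and assuming a libtorch build:

#include <torch/torch.h>

#include <iostream>

int main() {
  // torch::Device parses "cpu", "cuda", and "cuda:<index>" directly,
  // which is what parse_devices leans on for explicit device lists.
  torch::Device cpu("cpu");
  torch::Device gpu1("cuda:1");
  std::cout << cpu << " " << gpu1 << std::endl;  // prints: cpu cuda:1

  // The "auto" path enumerates GPUs the same way parse_devices does.
  const auto num_gpus = torch::cuda::device_count();
  std::cout << num_gpus << " visible gpu(s)" << std::endl;
  return 0;
}

Given the new flag semantics, the binary built from simple.cpp would presumably be invoked with something like --model_name_or_path=TheBloke/Llama-2-7B-GPTQ --device=auto, using every visible GPU and falling back to CPU when none is found.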