Define PJRT plugin interface in C++ (#6360)
will-cromar authored Jan 29, 2024
1 parent 32d24ad commit bd95eb1
Showing 5 changed files with 74 additions and 50 deletions.
5 changes: 2 additions & 3 deletions torch_xla/__init__.py
@@ -3,6 +3,8 @@
import re
import tempfile

import torch
import _XLAC
from ._internal import tpu

logging.basicConfig()
@@ -135,12 +137,9 @@ def _setup_tpu_vm_library_path() -> bool:
logger.setLevel(logging.INFO)

import atexit
import torch
from ._patched_functions import _apply_patches
from .version import __version__

import _XLAC

_found_libtpu = _setup_tpu_vm_library_path()

# Setup Neuron library for AWS EC2 inf/trn instances.
43 changes: 35 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -28,6 +28,7 @@
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "torch_xla/csrc/XLANativeFunctions.h"
#include "torch_xla/csrc/aten_autograd_ops.h"
@@ -78,6 +79,28 @@ struct NoGilSection {
PyThreadState* state = nullptr;
};

class PyPjRtPlugin : public runtime::PjRtPlugin {
public:
using runtime::PjRtPlugin::PjRtPlugin;

std::string library_path() const override {
PYBIND11_OVERRIDE_PURE(std::string, runtime::PjRtPlugin, library_path, );
}

// Templates with commas confuse pybind's macros, so use an alias here
// See https://github.com/pybind/pybind11/issues/2185#issuecomment-634005168
using PjRtCreateOptions = std::unordered_map<std::string, xla::PjRtValueType>;
const PjRtCreateOptions client_create_options() const override {
PYBIND11_OVERRIDE_PURE(PjRtCreateOptions, runtime::PjRtPlugin,
client_create_options, );
}

bool requires_xla_coordinator() const override {
PYBIND11_OVERRIDE_PURE(bool, runtime::PjRtPlugin,
requires_xla_coordinator, );
}
};

c10::optional<torch::lazy::BackendDevice> GetOptionalDevice(
const std::string& device_str) {
if (device_str.empty()) {
@@ -2319,14 +2342,18 @@ void InitXlaModuleBindings(py::module m) {
return retlist;
});
// -------------Dynamo Integration API End-------------------------
m.def("_register_pjrt_plugin",
[](std::string name, std::string library_path,
std::unordered_map<std::string, xla::PjRtValueType> create_options,
bool init_coordinator) {
runtime::RegisterPjRtPlugin(
name, library_path,
{create_options.begin(), create_options.end()}, init_coordinator);
});
m.def(
"_register_pjrt_plugin",
[](std::string name, std::shared_ptr<const runtime::PjRtPlugin> plugin) {
runtime::RegisterPjRtPlugin(name, plugin);
});
py::class_<runtime::PjRtPlugin, PyPjRtPlugin,
std::shared_ptr<runtime::PjRtPlugin>>(m, "PjRtPlugin")
.def(py::init<>())
.def("library_path", &runtime::PjRtPlugin::library_path)
.def("client_create_options", &runtime::PjRtPlugin::client_create_options)
.def("requires_xla_coordinator",
&runtime::PjRtPlugin::requires_xla_coordinator);
}
} // namespace
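
For orientation: the PyPjRtPlugin trampoline above forwards each pure-virtual call to a Python override, and the new _register_pjrt_plugin binding accepts the resulting object as a std::shared_ptr<const PjRtPlugin>. A minimal sketch of what this enables, assuming a hypothetical device name, class name, and library path:

import torch_xla

# Hypothetical plugin implemented directly against the new binding.
class MyPjRtPlugin(torch_xla._XLAC.PjRtPlugin):

  def library_path(self) -> str:
    # Hypothetical path to a PJRT C API shared library.
    return "/opt/my_device/lib/pjrt_c_api_my_device.so"

  def client_create_options(self) -> dict:
    # Forwarded to client creation on the C++ side; empty is valid.
    return {}

  def requires_xla_coordinator(self) -> bool:
    # Single-process sketch: no distributed coordinator needed.
    return False

torch_xla._XLAC._register_pjrt_plugin("MY_DEVICE", MyPjRtPlugin())

In practice the torch_xla.experimental.plugins.DevicePlugin wrapper shown further down is the user-facing path; this direct form only illustrates the trampoline.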

40 changes: 18 additions & 22 deletions torch_xla/csrc/runtime/pjrt_registry.cc
@@ -1,3 +1,5 @@
#include "torch_xla/csrc/runtime/pjrt_registry.h"

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/profiler.h"
@@ -18,13 +20,8 @@ namespace runtime {

namespace {

struct PluginEntry {
std::string library_path;
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options;
bool init_coordinator;
};

std::unordered_map<std::string, PluginEntry> pjrt_plugins_;
std::unordered_map<std::string, std::shared_ptr<const PjRtPlugin>>
pjrt_plugins_;

xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
auto allocator_config = xla::GpuAllocatorConfig{};
@@ -43,21 +40,18 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
return allocator_config;
}

std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
const std::string& device_type) {
auto plugin_path = pjrt_plugins_.find(device_type);
return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
: std::nullopt;
return plugin_path != pjrt_plugins_.end() ? plugin_path->second : nullptr;
}

} // namespace

void RegisterPjRtPlugin(
std::string name, std::string library_path,
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options,
bool init_coordinator) {
TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
pjrt_plugins_[name] = {std::move(library_path), std::move(create_options),
init_coordinator};
void RegisterPjRtPlugin(std::string name,
std::shared_ptr<const PjRtPlugin> plugin) {
TF_VLOG(3) << "Registering PjRt plugin " << name;
pjrt_plugins_[name] = plugin;
}

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
@@ -66,12 +60,12 @@ InitializePjRt(const std::string& device_type) {
std::unique_ptr<XlaCoordinator> coordinator;

if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
std::optional<PluginEntry> plugin = GetPjRtPlugin(device_type);
std::shared_ptr<const PjRtPlugin> plugin = GetPjRtPlugin(device_type);
if (plugin) {
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;

std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
if (plugin->init_coordinator) {
if (plugin->requires_xla_coordinator()) {
int local_process_rank = sys_util::GetEnvInt(
env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0));
int global_process_rank =
@@ -100,10 +94,12 @@ InitializePjRt(const std::string& device_type) {
/*key_prefix=*/"pjrt:");
}
const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
absl::AsciiStrToLower(device_type), plugin->library_path);
absl::AsciiStrToLower(device_type), plugin->library_path());
XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
plugin->create_options, kv_store)
auto create_options = plugin->client_create_options();
client = xla::GetCApiClient(
absl::AsciiStrToUpper(device_type),
{create_options.begin(), create_options.end()}, kv_store)
.value();
profiler::RegisterProfilerForPlugin(c_api);
}
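
As the registry code above shows, a registered plugin is only consulted when dynamic plugins are enabled (the env::kEnvPjrtDynamicPlugins flag): InitializePjRt loads library_path() through pjrt::LoadPjrtPlugin, sets up the distributed key-value store only when requires_xla_coordinator() returns true, and forwards client_create_options() to xla::GetCApiClient. Since those options arrive as xla::PjRtValueType values, a Python override should return values that convert cleanly (typically strings, bools, ints, floats, or lists of ints). A sketch under that assumption, with hypothetical option names that no particular plugin is known to accept:

from torch_xla.experimental import plugins

class MyDevicePlugin(plugins.DevicePlugin):
  # library_path() and the other overrides are omitted for brevity.

  def client_create_options(self) -> dict:
    return {
        "node_id": 0,                # int
        "enable_fast_path": True,    # bool
        "visible_devices": [0, 1],   # list of ints
        "platform_label": "demo",    # str
    }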
18 changes: 14 additions & 4 deletions torch_xla/csrc/runtime/pjrt_registry.h
@@ -1,15 +1,25 @@
#ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_
#define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_

#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"

namespace torch_xla {
namespace runtime {

void RegisterPjRtPlugin(
std::string name, std::string library_path,
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options = {},
bool init_coordinator = true);
class PjRtPlugin {
public:
virtual std::string library_path() const = 0;

virtual const std::unordered_map<std::string, xla::PjRtValueType>
client_create_options() const = 0;

virtual bool requires_xla_coordinator() const = 0;
};

void RegisterPjRtPlugin(std::string name,
std::shared_ptr<const PjRtPlugin> plugin);

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type);
18 changes: 5 additions & 13 deletions torch_xla/experimental/plugins.py
@@ -15,8 +15,8 @@
import torch_xla.utils.utils as xu


class DevicePlugin:
"""Base class for device plugings.
class DevicePlugin(torch_xla._XLAC.PjRtPlugin):
"""Base class for device plugins.
Default implementations assume a single device and local process.
"""
@@ -62,6 +62,7 @@ def requires_xla_coordinator(self) -> bool:
return False


# TODO(wcromar): figure out if we can share this map with the C++ code.
_plugin_registry = {}


@@ -84,21 +85,12 @@ def default() -> DevicePlugin:

def register_plugin(name: str, device_plugin: DevicePlugin):
_plugin_registry[name.upper()] = device_plugin
torch_xla._XLAC._register_pjrt_plugin(
name, device_plugin.library_path(), device_plugin.client_create_options(),
device_plugin.requires_xla_coordinator())
torch_xla._XLAC._register_pjrt_plugin(name, device_plugin)


def register_installed_plugins():
pjrt_entry_points = importlib_metadata.entry_points(group='torch_xla.plugins')
for ep in pjrt_entry_points:
device_plugin_class = ep.load()

# HACK: TpuPlugin raises EnvironmentError if libtpu is not installed.
# TODO(wcromar): Decide if catching `EnvironmentError` is a permanent
# behavior or temporary hack.
try:
register_plugin(ep.name.upper(), device_plugin_class())
except EnvironmentError as e:
logging.warning(
"Failed to register plugin {}".format(ep.name), exc_info=e)
register_plugin(ep.name.upper(), device_plugin_class())
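
For completeness, register_installed_plugins discovers third-party plugins through the torch_xla.plugins entry-point group and instantiates whatever class each entry point loads. A minimal sketch of such a package, with hypothetical package, module, and class names:

# my_pjrt_plugin/__init__.py (hypothetical out-of-tree package)
import os

from torch_xla.experimental import plugins

class MyDevicePlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    # PJRT C API shared library bundled with the package (hypothetical name).
    return os.path.join(os.path.dirname(__file__), "lib", "pjrt_c_api_my_device.so")

# Declared in the package's pyproject.toml so the entry point is discoverable:
# [project.entry-points."torch_xla.plugins"]
# my_device = "my_pjrt_plugin:MyDevicePlugin"

Per the DevicePlugin docstring, the defaults assume a single device and local process, so this sketch only overrides library_path().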
