Skip to content

Commit

Permalink
Externalize extension type registry so that Python can keep it alive …
Browse files Browse the repository at this point in the history
…until PyExtensionType is unregistered
  • Loading branch information
wesm committed Aug 23, 2019
1 parent 84e7bdd commit 5524868
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 32 deletions.
88 changes: 63 additions & 25 deletions cpp/src/arrow/extension_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,39 +65,77 @@ void ExtensionArray::SetData(const std::shared_ptr<ArrayData>& data) {
storage_ = MakeArray(storage_data);
}

// File-level table mapping an extension name to the registered ExtensionType
// instance.
std::unordered_map<std::string, std::shared_ptr<ExtensionType>> g_extension_registry;
// Mutex guarding g_extension_registry; lock it before reading or modifying
// the table.
std::mutex g_extension_registry_guard;
class ExtensionTypeRegistryImpl : public ExtensionTypeRegistry {
public:
ExtensionTypeRegistryImpl() {}

Status RegisterType(std::shared_ptr<ExtensionType> type) {
std::lock_guard<std::mutex> lock(lock_);
std::string type_name = type->extension_name();
auto it = name_to_type_.find(type_name);
if (it != name_to_type_.end()) {
return Status::KeyError("A type extension with name ", type_name,
" already defined");
}
name_to_type_[type_name] = std::move(type);
return Status::OK();
}

Status RegisterExtensionType(std::shared_ptr<ExtensionType> type) {
std::lock_guard<std::mutex> lock_(g_extension_registry_guard);
std::string type_name = type->extension_name();
auto it = g_extension_registry.find(type_name);
if (it != g_extension_registry.end()) {
return Status::KeyError("A type extension with name ", type_name, " already defined");
Status UnregisterType(const std::string& type_name) {
std::lock_guard<std::mutex> lock(lock_);
auto it = name_to_type_.find(type_name);
if (it == name_to_type_.end()) {
return Status::KeyError("No type extension with name ", type_name, " found");
}
name_to_type_.erase(it);
return Status::OK();
}

std::shared_ptr<ExtensionType> GetType(const std::string& type_name) {
std::lock_guard<std::mutex> lock(lock_);
auto it = name_to_type_.find(type_name);
if (it == name_to_type_.end()) {
return nullptr;
} else {
return it->second;
}
return nullptr;
}
g_extension_registry[type_name] = std::move(type);
return Status::OK();

private:
std::mutex lock_;
std::unordered_map<std::string, std::shared_ptr<ExtensionType>> name_to_type_;
};

// The shared registry instance; assigned by internal::CreateGlobalRegistry()
// and handed out by ExtensionTypeRegistry::GetGlobalRegistry().
static std::shared_ptr<ExtensionTypeRegistry> g_registry;
// Flag passed to std::call_once in GetGlobalRegistry() so that g_registry is
// created a single time.
static std::once_flag registry_initialized;

namespace internal {

// Allocates an ExtensionTypeRegistryImpl and stores it in g_registry.
// Invoked through std::call_once from ExtensionTypeRegistry::GetGlobalRegistry().
static void CreateGlobalRegistry() {
  g_registry = std::make_shared<ExtensionTypeRegistryImpl>();
}

} // namespace internal

// Returns the shared registry, creating it on the first call.
// std::call_once runs internal::CreateGlobalRegistry at most once; every call
// then returns a copy of the g_registry shared_ptr, so each caller holds its
// own reference to the registry object.
std::shared_ptr<ExtensionTypeRegistry> ExtensionTypeRegistry::GetGlobalRegistry() {
  std::call_once(registry_initialized, internal::CreateGlobalRegistry);
  return g_registry;
}

// Registers `type` with the registry returned by
// ExtensionTypeRegistry::GetGlobalRegistry() and returns the Status produced
// by its RegisterType() call.
Status RegisterExtensionType(std::shared_ptr<ExtensionType> type) {
  auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
  // `type` is a by-value sink parameter: move it along rather than copying the
  // shared_ptr a second time.
  return registry->RegisterType(std::move(type));
}

// Removes the type named `type_name` from the registry returned by
// ExtensionTypeRegistry::GetGlobalRegistry() and returns the Status produced
// by its UnregisterType() call. This is the same registry object that
// RegisterExtensionType() stores into.
Status UnregisterExtensionType(const std::string& type_name) {
  auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
  return registry->UnregisterType(type_name);
}

// Looks up `type_name` in the registry returned by
// ExtensionTypeRegistry::GetGlobalRegistry() and returns whatever its
// GetType() call yields. This is the same registry object that
// RegisterExtensionType() stores into.
std::shared_ptr<ExtensionType> GetExtensionType(const std::string& type_name) {
  auto registry = ExtensionTypeRegistry::GetGlobalRegistry();
  return registry->GetType(type_name);
}

} // namespace arrow
14 changes: 14 additions & 0 deletions cpp/src/arrow/extension_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ class ARROW_EXPORT ExtensionArray : public Array {
std::shared_ptr<Array> storage_;
};

/// \brief Abstract interface for a collection of extension types addressed by
/// name
class ARROW_EXPORT ExtensionTypeRegistry {
 public:
  /// \brief Provide access to the global registry to allow code to control for
  /// race conditions in registry teardown when some types need to be
  /// unregistered and destroyed first
  static std::shared_ptr<ExtensionTypeRegistry> GetGlobalRegistry();

  virtual ~ExtensionTypeRegistry() = default;

  /// \brief Add an extension type to this registry
  /// \param[in] type the type instance to register
  /// \return Status
  virtual Status RegisterType(std::shared_ptr<ExtensionType> type) = 0;

  /// \brief Remove the extension type registered under the given name
  /// \param[in] type_name the name of the type to remove
  /// \return Status
  virtual Status UnregisterType(const std::string& type_name) = 0;

  /// \brief Look up a registered extension type by name
  /// \param[in] type_name the name of the type to retrieve
  /// \return the registered type instance
  virtual std::shared_ptr<ExtensionType> GetType(const std::string& type_name) = 0;
};

/// \brief Register an extension type globally. The name returned by the type's
/// extension_name() method should be unique. This method is thread-safe
/// \param[in] type an instance of the extension type
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1412,6 +1412,10 @@ cdef extern from 'arrow/python/inference.h' namespace 'arrow::py':


cdef extern from 'arrow/extension_type.h' namespace 'arrow':
# Binding for arrow::ExtensionTypeRegistry. Only the static accessor is
# declared here.
cdef cppclass CExtensionTypeRegistry" arrow::ExtensionTypeRegistry":
    @staticmethod
    shared_ptr[CExtensionTypeRegistry] GetGlobalRegistry()

cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType):
c_string extension_name()
shared_ptr[CDataType] storage_type()
Expand Down
24 changes: 17 additions & 7 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,21 @@ def is_float_value(object obj):
return IsPyFloat(obj)


cdef class _ExtensionRegistryNanny:
    """
    Holds a reference to the C++ extension type registry so that the
    registry is kept alive until PyExtensionType has been unregistered.
    """
    cdef shared_ptr[CExtensionTypeRegistry] registry

    def __cinit__(self):
        # Acquire our own shared_ptr to the registry at construction time
        self.registry = CExtensionTypeRegistry.GetGlobalRegistry()

    def release_registry(self):
        """Drop the reference to the registry held by this object."""
        self.registry.reset()


# Module-level instance: constructing it here takes a reference to the
# registry when this module is executed.
_registry_nanny = _ExtensionRegistryNanny()


def _register_py_extension_type():
cdef:
DataType storage_type
Expand All @@ -1979,13 +1994,8 @@ def _unregister_py_extension_type():
# finalized. If the C++ type is destroyed later in the process
# teardown stage, it will invoke CPython APIs such as Py_DECREF
# with a destroyed interpreter.
#
# As reported in ARROW-6301 there are cases where UnregisterPyExtensionType
# might fail
cdef CStatus s = UnregisterPyExtensionType()
if not s.ok():
print("Calling UnregisterPyExtensionType failed, allowing to "
"pass silently: {}".format(s.ToString()))
check_status(UnregisterPyExtensionType())
_registry_nanny.release_registry()


_register_py_extension_type()
Expand Down

0 comments on commit 5524868

Please sign in to comment.