Skip to content

Commit

Permalink
wasm: allow execution of multiple instances of the same plugin. (#13753)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Sikora <piotrsikora@google.com>
Co-authored-by: mathetake <takeshi@tetrate.io>
  • Loading branch information
PiotrSikora and mathetake authored Nov 12, 2020
1 parent 36c4191 commit 3c8e56a
Show file tree
Hide file tree
Showing 25 changed files with 222 additions and 176 deletions.
4 changes: 2 additions & 2 deletions bazel/repository_locations.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -870,8 +870,8 @@ REPOSITORY_LOCATIONS_SPEC = dict(
project_name = "WebAssembly for Proxies (C++ host implementation)",
project_desc = "WebAssembly for Proxies (C++ host implementation)",
project_url = "https://github.com/proxy-wasm/proxy-wasm-cpp-host",
version = "4741d2f1cd5eb250f66d0518238c333353259d56",
sha256 = "30fc4becfcc5a95ac875fc5a0658a91aa7ddedd763b52d7810c13ed35d9d81aa",
version = "eceb02d5b7772ec1cd78a4d35356e57d2e6d59bb",
sha256 = "ae9d9b87d21d95647ebda197d130b37bddc5c6ee3e6630909a231fd55fcc9069",
strip_prefix = "proxy-wasm-cpp-host-{version}",
urls = ["https://github.com/proxy-wasm/proxy-wasm-cpp-host/archive/{version}.tar.gz"],
use_category = ["dataplane_ext"],
Expand Down
26 changes: 7 additions & 19 deletions source/extensions/access_loggers/wasm/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ WasmAccessLogFactory::createAccessLogInstance(const Protobuf::Message& proto_con
const auto& config = MessageUtil::downcastAndValidate<
const envoy::extensions::access_loggers::wasm::v3::WasmAccessLog&>(
proto_config, context.messageValidationVisitor());
auto access_log =
std::make_shared<WasmAccessLog>(config.config().root_id(), nullptr, std::move(filter));

// Create a base WASM to verify that the code loads before setting/cloning the for the
// individual threads.
Expand All @@ -35,25 +33,15 @@ WasmAccessLogFactory::createAccessLogInstance(const Protobuf::Message& proto_con
envoy::config::core::v3::TrafficDirection::UNSPECIFIED, context.localInfo(),
nullptr /* listener_metadata */);

auto callback = [access_log, &context, plugin](Common::Wasm::WasmHandleSharedPtr base_wasm) {
auto tls_slot = context.threadLocal().allocateSlot();
auto access_log = std::make_shared<WasmAccessLog>(plugin, nullptr, std::move(filter));

auto callback = [access_log, &context, plugin](Common::Wasm::WasmHandleSharedPtr base_wasm) {
// NB: the Slot set() call doesn't complete inline, so all arguments must outlive this call.
tls_slot->set(
[base_wasm,
plugin](Event::Dispatcher& dispatcher) -> std::shared_ptr<ThreadLocal::ThreadLocalObject> {
if (!base_wasm) {
// There is no way to prevent the connection at this point. The user could choose to use
// an HTTP Wasm plugin and only handle onLog() which would correctly close the
// connection in onRequestHeaders().
if (!plugin->fail_open_) {
ENVOY_LOG(critical, "Plugin configured to fail closed failed to load");
}
return nullptr;
}
return std::static_pointer_cast<ThreadLocal::ThreadLocalObject>(
Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, dispatcher));
});
auto tls_slot =
ThreadLocal::TypedSlot<Common::Wasm::PluginHandle>::makeUnique(context.threadLocal());
tls_slot->set([base_wasm, plugin](Event::Dispatcher& dispatcher) {
return Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, dispatcher);
});
access_log->setTlsSlot(std::move(tls_slot));
};

Expand Down
21 changes: 12 additions & 9 deletions source/extensions/access_loggers/wasm/wasm_access_log_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ namespace Extensions {
namespace AccessLoggers {
namespace Wasm {

using Envoy::Extensions::Common::Wasm::WasmHandle;
using Envoy::Extensions::Common::Wasm::PluginHandle;
using Envoy::Extensions::Common::Wasm::PluginSharedPtr;

class WasmAccessLog : public AccessLog::Instance {
public:
WasmAccessLog(absl::string_view root_id, ThreadLocal::SlotPtr tls_slot,
WasmAccessLog(const PluginSharedPtr& plugin, ThreadLocal::TypedSlotPtr<PluginHandle>&& tls_slot,
AccessLog::FilterPtr filter)
: root_id_(root_id), tls_slot_(std::move(tls_slot)), filter_(std::move(filter)) {}
: plugin_(plugin), tls_slot_(std::move(tls_slot)), filter_(std::move(filter)) {}

void log(const Http::RequestHeaderMap* request_headers,
const Http::ResponseHeaderMap* response_headers,
const Http::ResponseTrailerMap* response_trailers,
Expand All @@ -30,20 +32,21 @@ class WasmAccessLog : public AccessLog::Instance {
}
}

if (tls_slot_->get()) {
tls_slot_->getTyped<WasmHandle>().wasm()->log(root_id_, request_headers, response_headers,
response_trailers, stream_info);
auto handle = tls_slot_->get();
if (handle.has_value()) {
handle->wasm()->log(plugin_, request_headers, response_headers, response_trailers,
stream_info);
}
}

void setTlsSlot(ThreadLocal::SlotPtr tls_slot) {
void setTlsSlot(ThreadLocal::TypedSlotPtr<PluginHandle>&& tls_slot) {
ASSERT(tls_slot_ == nullptr);
tls_slot_ = std::move(tls_slot);
}

private:
std::string root_id_;
ThreadLocal::SlotPtr tls_slot_;
PluginSharedPtr plugin_;
ThreadLocal::TypedSlotPtr<PluginHandle> tls_slot_;
AccessLog::FilterPtr filter_;
};

Expand Down
12 changes: 6 additions & 6 deletions source/extensions/bootstrap/wasm/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ void WasmFactory::createWasm(const envoy::extensions::wasm::v3::WasmService& con
}
if (singleton) {
// Return a Wasm VM which will be stored as a singleton by the Server.
cb(std::make_unique<WasmService>(
Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, context.dispatcher())));
cb(std::make_unique<WasmService>(plugin, Common::Wasm::getOrCreateThreadLocalPlugin(
base_wasm, plugin, context.dispatcher())));
return;
}
// Per-thread WASM VM.
// NB: the Slot set() call doesn't complete inline, so all arguments must outlive this call.
auto tls_slot = context.threadLocal().allocateSlot();
auto tls_slot =
ThreadLocal::TypedSlot<Common::Wasm::PluginHandle>::makeUnique(context.threadLocal());
tls_slot->set([base_wasm, plugin](Event::Dispatcher& dispatcher) {
return std::static_pointer_cast<ThreadLocal::ThreadLocalObject>(
Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, dispatcher));
return Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, dispatcher);
});
cb(std::make_unique<WasmService>(std::move(tls_slot)));
cb(std::make_unique<WasmService>(plugin, std::move(tls_slot)));
};

if (!Common::Wasm::createWasm(
Expand Down
15 changes: 11 additions & 4 deletions source/extensions/bootstrap/wasm/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@ namespace Extensions {
namespace Bootstrap {
namespace Wasm {

using Envoy::Extensions::Common::Wasm::PluginHandle;
using Envoy::Extensions::Common::Wasm::PluginHandleSharedPtr;
using Envoy::Extensions::Common::Wasm::PluginSharedPtr;

class WasmService {
public:
WasmService(Common::Wasm::WasmHandleSharedPtr singleton) : singleton_(std::move(singleton)) {}
WasmService(ThreadLocal::SlotPtr tls_slot) : tls_slot_(std::move(tls_slot)) {}
WasmService(PluginSharedPtr plugin, PluginHandleSharedPtr singleton)
: plugin_(plugin), singleton_(std::move(singleton)) {}
WasmService(PluginSharedPtr plugin, ThreadLocal::TypedSlotPtr<PluginHandle>&& tls_slot)
: plugin_(plugin), tls_slot_(std::move(tls_slot)) {}

private:
Common::Wasm::WasmHandleSharedPtr singleton_;
ThreadLocal::SlotPtr tls_slot_;
PluginSharedPtr plugin_;
PluginHandleSharedPtr singleton_;
ThreadLocal::TypedSlotPtr<PluginHandle> tls_slot_;
};

using WasmServicePtr = std::unique_ptr<WasmService>;
Expand Down
12 changes: 6 additions & 6 deletions source/extensions/common/wasm/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,8 @@ BufferInterface* Context::getBuffer(WasmBufferType type) {
case WasmBufferType::VmConfiguration:
return buffer_.set(wasm()->vm_configuration());
case WasmBufferType::PluginConfiguration:
if (plugin_) {
return buffer_.set(plugin_->plugin_configuration_);
if (temp_plugin_) {
return buffer_.set(temp_plugin_->plugin_configuration_);
}
return nullptr;
case WasmBufferType::HttpRequestBody:
Expand Down Expand Up @@ -1182,18 +1182,18 @@ bool Context::validateConfiguration(absl::string_view configuration,
if (!wasm()->validate_configuration_) {
return true;
}
plugin_ = plugin_base;
temp_plugin_ = plugin_base;
auto result =
wasm()
->validate_configuration_(this, id_, static_cast<uint32_t>(configuration.size()))
.u64_ != 0;
plugin_.reset();
temp_plugin_.reset();
return result;
}

absl::string_view Context::getConfiguration() {
if (plugin_) {
return plugin_->plugin_configuration_;
if (temp_plugin_) {
return temp_plugin_->plugin_configuration_;
} else {
return wasm()->vm_configuration();
}
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/common/wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using proxy_wasm::ContextBase;
using proxy_wasm::Pairs;
using proxy_wasm::PairsWithStringValues;
using proxy_wasm::PluginBase;
using proxy_wasm::PluginHandleBase;
using proxy_wasm::SharedQueueDequeueToken;
using proxy_wasm::SharedQueueEnqueueToken;
using proxy_wasm::WasmBase;
Expand All @@ -45,6 +46,7 @@ using GrpcService = envoy::config::core::v3::GrpcService;

class Wasm;

using PluginHandleBaseSharedPtr = std::shared_ptr<PluginHandleBase>;
using WasmHandleBaseSharedPtr = std::shared_ptr<WasmHandleBase>;

// Opaque context object.
Expand Down
36 changes: 26 additions & 10 deletions source/extensions/common/wasm/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,16 @@ ContextBase* Wasm::createRootContext(const std::shared_ptr<PluginBase>& plugin)

ContextBase* Wasm::createVmContext() { return new Context(this); }

void Wasm::log(absl::string_view root_id, const Http::RequestHeaderMap* request_headers,
void Wasm::log(const PluginSharedPtr& plugin, const Http::RequestHeaderMap* request_headers,
const Http::ResponseHeaderMap* response_headers,
const Http::ResponseTrailerMap* response_trailers,
const StreamInfo::StreamInfo& stream_info) {
auto context = getRootContext(root_id);
auto context = getRootContext(plugin, true);
context->log(request_headers, response_headers, response_trailers, stream_info);
}

void Wasm::onStatsUpdate(absl::string_view root_id, Envoy::Stats::MetricSnapshot& snapshot) {
auto context = getRootContext(root_id);
void Wasm::onStatsUpdate(const PluginSharedPtr& plugin, Envoy::Stats::MetricSnapshot& snapshot) {
auto context = getRootContext(plugin, true);
context->onStatsUpdate(snapshot);
}

Expand Down Expand Up @@ -281,6 +281,14 @@ getCloneFactory(WasmExtension* wasm_extension, Event::Dispatcher& dispatcher,
};
}

static proxy_wasm::PluginHandleFactory getPluginFactory(WasmExtension* wasm_extension) {
auto wasm_plugin_factory = wasm_extension->pluginFactory();
return [wasm_plugin_factory](WasmHandleBaseSharedPtr base_wasm,
absl::string_view plugin_key) -> std::shared_ptr<PluginHandleBase> {
return wasm_plugin_factory(std::static_pointer_cast<WasmHandle>(base_wasm), plugin_key);
};
}

WasmEvent toWasmEvent(const std::shared_ptr<WasmHandleBase>& wasm) {
if (!wasm) {
return WasmEvent::UnableToCreateVM;
Expand Down Expand Up @@ -474,13 +482,21 @@ bool createWasm(const VmConfig& vm_config, const PluginSharedPtr& plugin,
create_root_context_for_testing);
}

WasmHandleSharedPtr getOrCreateThreadLocalWasm(const WasmHandleSharedPtr& base_wasm,
const PluginSharedPtr& plugin,
Event::Dispatcher& dispatcher,
CreateContextFn create_root_context_for_testing) {
return std::static_pointer_cast<WasmHandle>(proxy_wasm::getOrCreateThreadLocalWasm(
PluginHandleSharedPtr
getOrCreateThreadLocalPlugin(const WasmHandleSharedPtr& base_wasm, const PluginSharedPtr& plugin,
Event::Dispatcher& dispatcher,
CreateContextFn create_root_context_for_testing) {
if (!base_wasm) {
if (!plugin->fail_open_) {
ENVOY_LOG_TO_LOGGER(Envoy::Logger::Registry::getLog(Envoy::Logger::Id::wasm), critical,
"Plugin configured to fail closed failed to load");
}
return nullptr;
}
return std::static_pointer_cast<PluginHandle>(proxy_wasm::getOrCreateThreadLocalPlugin(
std::static_pointer_cast<WasmHandle>(base_wasm), plugin,
getCloneFactory(getWasmExtension(), dispatcher, create_root_context_for_testing)));
getCloneFactory(getWasmExtension(), dispatcher, create_root_context_for_testing),
getPluginFactory(getWasmExtension())));
}

} // namespace Wasm
Expand Down
34 changes: 26 additions & 8 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class Wasm : public WasmBase, Logger::Loggable<Logger::Id::wasm> {

Upstream::ClusterManager& clusterManager() const { return cluster_manager_; }
Event::Dispatcher& dispatcher() { return dispatcher_; }
Context* getRootContext(absl::string_view root_id) {
return static_cast<Context*>(WasmBase::getRootContext(root_id));
Context* getRootContext(const std::shared_ptr<PluginBase>& plugin, bool allow_closed) {
return static_cast<Context*>(WasmBase::getRootContext(plugin, allow_closed));
}
void setTimerPeriod(uint32_t root_context_id, std::chrono::milliseconds period) override;
virtual void tickHandler(uint32_t root_context_id);
Expand All @@ -72,12 +72,13 @@ class Wasm : public WasmBase, Logger::Loggable<Logger::Id::wasm> {
void getFunctions() override;

// AccessLog::Instance
void log(absl::string_view root_id, const Http::RequestHeaderMap* request_headers,
void log(const PluginSharedPtr& plugin, const Http::RequestHeaderMap* request_headers,
const Http::ResponseHeaderMap* response_headers,
const Http::ResponseTrailerMap* response_trailers,
const StreamInfo::StreamInfo& stream_info);

void onStatsUpdate(absl::string_view root_id, Envoy::Stats::MetricSnapshot& snapshot);
void onStatsUpdate(const PluginSharedPtr& plugin, Envoy::Stats::MetricSnapshot& snapshot);

virtual std::string buildVersion() { return BUILD_VERSION_NUMBER; }

void initializeLifecycle(Server::ServerLifecycleNotifier& lifecycle_notifier);
Expand Down Expand Up @@ -136,6 +137,23 @@ class WasmHandle : public WasmHandleBase, public ThreadLocal::ThreadLocalObject
WasmSharedPtr wasm_;
};

using WasmHandleSharedPtr = std::shared_ptr<WasmHandle>;

class PluginHandle : public PluginHandleBase, public ThreadLocal::ThreadLocalObject {
public:
explicit PluginHandle(const WasmHandleSharedPtr& wasm_handle, absl::string_view plugin_key)
: PluginHandleBase(std::static_pointer_cast<WasmHandleBase>(wasm_handle), plugin_key),
wasm_handle_(wasm_handle) {}

WasmSharedPtr& wasm() { return wasm_handle_->wasm(); }
WasmHandleSharedPtr& wasmHandleForTest() { return wasm_handle_; }

private:
WasmHandleSharedPtr wasm_handle_;
};

using PluginHandleSharedPtr = std::shared_ptr<PluginHandle>;

using CreateWasmCallback = std::function<void(WasmHandleSharedPtr)>;

// Returns false if createWasm failed synchronously. This is necessary because xDS *MUST* report
Expand All @@ -150,10 +168,10 @@ bool createWasm(const VmConfig& vm_config, const PluginSharedPtr& plugin,
CreateWasmCallback&& callback,
CreateContextFn create_root_context_for_testing = nullptr);

WasmHandleSharedPtr
getOrCreateThreadLocalWasm(const WasmHandleSharedPtr& base_wasm, const PluginSharedPtr& plugin,
Event::Dispatcher& dispatcher,
CreateContextFn create_root_context_for_testing = nullptr);
PluginHandleSharedPtr
getOrCreateThreadLocalPlugin(const WasmHandleSharedPtr& base_wasm, const PluginSharedPtr& plugin,
Event::Dispatcher& dispatcher,
CreateContextFn create_root_context_for_testing = nullptr);

void clearCodeCacheForTesting();
std::string anyToBytes(const ProtobufWkt::Any& any);
Expand Down
8 changes: 8 additions & 0 deletions source/extensions/common/wasm/wasm_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ EnvoyWasm::createEnvoyWasmVmIntegration(const Stats::ScopeSharedPtr& scope,
return std::make_unique<EnvoyWasmVmIntegration>(scope, runtime, short_runtime);
}

PluginHandleExtensionFactory EnvoyWasm::pluginFactory() {
return [](const WasmHandleSharedPtr& base_wasm,
absl::string_view plugin_key) -> PluginHandleBaseSharedPtr {
return std::static_pointer_cast<PluginHandleBase>(
std::make_shared<PluginHandle>(base_wasm, plugin_key));
};
}

WasmHandleExtensionFactory EnvoyWasm::wasmFactory() {
return [](const VmConfig vm_config, const Stats::ScopeSharedPtr& scope,
Upstream::ClusterManager& cluster_manager, Event::Dispatcher& dispatcher,
Expand Down
4 changes: 4 additions & 0 deletions source/extensions/common/wasm/wasm_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class EnvoyWasmVmIntegration;
using WasmHandleSharedPtr = std::shared_ptr<WasmHandle>;
using CreateContextFn =
std::function<ContextBase*(Wasm* wasm, const std::shared_ptr<Plugin>& plugin)>;
using PluginHandleExtensionFactory = std::function<PluginHandleBaseSharedPtr(
const WasmHandleSharedPtr& base_wasm, absl::string_view plugin_key)>;
using WasmHandleExtensionFactory = std::function<WasmHandleBaseSharedPtr(
const VmConfig& vm_config, const Stats::ScopeSharedPtr& scope,
Upstream::ClusterManager& cluster_manager, Event::Dispatcher& dispatcher,
Expand All @@ -54,6 +56,7 @@ class WasmExtension : Logger::Loggable<Logger::Id::wasm> {
virtual std::unique_ptr<EnvoyWasmVmIntegration>
createEnvoyWasmVmIntegration(const Stats::ScopeSharedPtr& scope, absl::string_view runtime,
absl::string_view short_runtime) = 0;
virtual PluginHandleExtensionFactory pluginFactory() = 0;
virtual WasmHandleExtensionFactory wasmFactory() = 0;
virtual WasmHandleExtensionCloneFactory wasmCloneFactory() = 0;
enum class WasmEvent : int {
Expand Down Expand Up @@ -100,6 +103,7 @@ class EnvoyWasm : public WasmExtension {
std::unique_ptr<EnvoyWasmVmIntegration>
createEnvoyWasmVmIntegration(const Stats::ScopeSharedPtr& scope, absl::string_view runtime,
absl::string_view short_runtime) override;
PluginHandleExtensionFactory pluginFactory() override;
WasmHandleExtensionFactory wasmFactory() override;
WasmHandleExtensionCloneFactory wasmCloneFactory() override;
void onEvent(WasmEvent event, const PluginSharedPtr& plugin) override;
Expand Down
Loading

0 comments on commit 3c8e56a

Please sign in to comment.