Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 69 additions & 11 deletions temporalio/ext/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use temporal_client::{
};

use magnus::{
DataTypeFunctions, Error, RString, Ruby, TypedData, Value, class, function, method, prelude::*,
scan_args,
DataTypeFunctions, Error, RHash, RString, Ruby, TypedData, Value, class, function, method,
prelude::*, scan_args,
};
use tonic::{Status, metadata::MetadataKey};
use url::Url;
Expand Down Expand Up @@ -84,9 +84,11 @@ macro_rules! rpc_call {
impl Client {
pub fn async_new(runtime: &Runtime, options: Struct, queue: Value) -> Result<(), Error> {
runtime.handle.fork_check("create client")?;
let ruby = Ruby::get().expect("Ruby not available");
// Build options
let mut opts_build = ClientOptionsBuilder::default();
let tls = options.child(id!("tls"))?;
let headers = partition_grpc_headers(&ruby, options.member(id!("rpc_metadata"))?)?;
opts_build
.target_url(
Url::parse(
Expand All @@ -101,7 +103,8 @@ impl Client {
)
.client_name(options.member::<String>(id!("client_name"))?)
.client_version(options.member::<String>(id!("client_version"))?)
.headers(Some(options.member(id!("rpc_metadata"))?))
.headers(Some(headers.headers))
.binary_headers(Some(headers.binary_headers))
.api_key(options.member(id!("api_key"))?)
.identity(options.member::<String>(id!("identity"))?);
if let Some(tls) = tls {
Expand Down Expand Up @@ -193,6 +196,7 @@ impl Client {

pub fn async_invoke_rpc(&self, args: &[Value]) -> Result<(), Error> {
self.runtime_handle.fork_check("use client")?;
let ruby = Ruby::get().expect("Ruby not available");
let args = scan_args::scan_args::<(), (), (), (), _, ()>(args)?;
let (service, rpc, request, retry, metadata, timeout, cancel_token, queue) =
scan_args::get_kwargs::<
Expand All @@ -202,7 +206,7 @@ impl Client {
String,
RString,
bool,
Option<HashMap<String, String>>,
Option<RHash>,
Option<f64>,
Option<&CancellationToken>,
Value,
Expand All @@ -224,11 +228,16 @@ impl Client {
&[],
)?
.required;
let headers = if let Some(metadata) = metadata {
Some(partition_grpc_headers(&ruby, metadata)?)
} else {
None
};
let call = RpcCall {
rpc,
request: unsafe { request.as_slice() },
retry,
metadata,
headers,
timeout,
cancel_token: cancel_token.map(|c| c.token.clone()),
_not_send_sync: PhantomData,
Expand All @@ -237,18 +246,59 @@ impl Client {
self.invoke_rpc(service, callback, call)
}

pub fn update_metadata(&self, headers: HashMap<String, String>) -> Result<(), Error> {
pub fn update_metadata(&self, headers: RHash) -> Result<(), Error> {
let ruby = Ruby::get().expect("Ruby not available");
let headers = partition_grpc_headers(&ruby, headers)?;
self.core
.get_client()
.set_headers(headers.headers)
.map_err(|err| error!("Invalid headers: {}", err))?;
self.core
.get_client()
.set_headers(headers)
.map_err(|err| error!("Invalid headers: {}", err))
.set_binary_headers(headers.binary_headers)
.map_err(|err| error!("Invalid headers: {}", err))?;
Ok(())
}

pub fn update_api_key(&self, api_key: Option<String>) {
self.core.get_client().set_api_key(api_key);
}
}

pub(crate) struct GrpcHeaders {
headers: HashMap<String, String>,
binary_headers: HashMap<String, Vec<u8>>,
}

fn partition_grpc_headers(ruby: &Ruby, hash: RHash) -> Result<GrpcHeaders, Error> {
let mut headers = HashMap::new();
let mut binary_headers = HashMap::new();
hash.foreach(|key: String, value: RString| {
if key.ends_with("-bin") {
if value.enc_get() != ruby.ascii8bit_encindex() {
return Err(Error::new(
ruby.exception_arg_error(),
format!("Value for metadata key {key} must be ASCII-8BIT"),
));
}
binary_headers.insert(key, unsafe { value.as_slice().to_vec() });
} else {
let value = value.to_string().map_err(|err| {
Error::new(
ruby.exception_arg_error(),
format!("Value for metadata key {key} invalid: {err}"),
)
})?;
headers.insert(key, value);
}
Ok(magnus::r_hash::ForEach::Continue)
})?;
Ok(GrpcHeaders {
headers,
binary_headers,
})
}

#[derive(DataTypeFunctions, TypedData)]
#[magnus(
class = "Temporalio::Internal::Bridge::Client::RPCFailure",
Expand Down Expand Up @@ -280,7 +330,7 @@ pub(crate) struct RpcCall<'a> {
pub rpc: String,
pub request: &'a [u8],
pub retry: bool,
pub metadata: Option<HashMap<String, String>>,
pub headers: Option<GrpcHeaders>,
pub timeout: Option<f64>,
pub cancel_token: Option<tokio_util::sync::CancellationToken>,

Expand All @@ -294,15 +344,23 @@ impl RpcCall<'_> {
pub fn into_request<P: prost::Message + Default>(self) -> Result<tonic::Request<P>, Error> {
let proto = P::decode(self.request).map_err(|err| error!("Invalid proto: {}", err))?;
let mut req = tonic::Request::new(proto);
if let Some(metadata) = self.metadata {
for (k, v) in metadata {
if let Some(headers) = self.headers {
for (k, v) in headers.headers {
req.metadata_mut().insert(
MetadataKey::from_str(k.as_str())
.map_err(|err| error!("Invalid metadata key: {}", err))?,
v.parse()
.map_err(|err| error!("Invalid metadata value: {}", err))?,
);
}
for (k, v) in headers.binary_headers {
req.metadata_mut().insert_bin(
MetadataKey::from_str(k.as_str())
.map_err(|err| error!("Invalid metadata key: {}", err))?,
v.try_into()
.map_err(|err| error!("Invalid metadata value: {}", err))?,
);
}
}
if let Some(timeout) = self.timeout {
req.set_timeout(Duration::from_secs_f64(timeout));
Expand Down
33 changes: 33 additions & 0 deletions temporalio/test/client_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,37 @@ def test_fork
assert status.success?
assert_equal 'started workflow', reader.read.strip
end

def test_binary_metadata
orig_metadata = env.client.connection.rpc_metadata

# Connect a new client with some bad metadata
err = assert_raises(ArgumentError) do
Temporalio::Client.connect(
env.client.connection.target_host,
env.client.namespace,
rpc_metadata: { 'connect-bin' => 'not-allowed' }
)
end
assert_equal 'Value for metadata key connect-bin must be ASCII-8BIT', err.message

# Update a client with some bad metadata
err = assert_raises(ArgumentError) do
env.client.connection.rpc_metadata = { 'update-bin' => 'not-allowed' }
end
assert_equal 'Value for metadata key update-bin must be ASCII-8BIT', err.message

# Make an RPC call with some bad metadata
err = assert_raises(ArgumentError) do
env.client.start_workflow(
:MyWorkflow,
id: "wf-#{SecureRandom.uuid}",
task_queue: "tq-#{SecureRandom.uuid}",
rpc_options: Temporalio::Client::RPCOptions.new(metadata: { 'rpc-bin' => 'not-allowed' })
)
end
assert_equal 'Value for metadata key rpc-bin must be ASCII-8BIT', err.message
ensure
env.client.connection.rpc_metadata = orig_metadata
end
end
Loading