diff --git a/oak/proto/oak_api.proto b/oak/proto/oak_api.proto index df16dbfb9ea..0f85433b6b4 100644 --- a/oak/proto/oak_api.proto +++ b/oak/proto/oak_api.proto @@ -42,6 +42,8 @@ enum OakStatus { ERR_TERMINATED = 9; // Channel has no messages available to read. ERR_CHANNEL_EMPTY = 10; + // The node does not have sufficient permissions to perform the requested operation. + ERR_PERMISSION_DENIED = 11; } // Single byte values used to indicate the read status of a channel on the diff --git a/oak/server/rust/oak_runtime/src/runtime/channel.rs b/oak/server/rust/oak_runtime/src/runtime/channel.rs index c407e287986..521fc2a0760 100644 --- a/oak/server/rust/oak_runtime/src/runtime/channel.rs +++ b/oak/server/rust/oak_runtime/src/runtime/channel.rs @@ -62,7 +62,7 @@ pub struct Channel { /// This is set at channel creation time and does not change after that. /// /// See https://github.com/project-oak/oak/blob/master/docs/concepts.md#labels - label: oak_abi::label::Label, + pub label: oak_abi::label::Label, } /// A reference to a [`Channel`]. Each [`Handle`] has an implicit direction such that it is only diff --git a/oak/server/rust/oak_runtime/src/runtime/mod.rs b/oak/server/rust/oak_runtime/src/runtime/mod.rs index 60840095869..51061f16092 100644 --- a/oak/server/rust/oak_runtime/src/runtime/mod.rs +++ b/oak/server/rust/oak_runtime/src/runtime/mod.rs @@ -44,8 +44,6 @@ struct Node { /// This is set at node creation time and does not change after that. /// /// See https://github.com/project-oak/oak/blob/master/docs/concepts.md#labels - // TODO(#630): Remove exception when label tracking is implemented. - #[allow(dead_code)] label: oak_abi::label::Label, /// A [`HashSet`] containing all the handles associated with this Node. @@ -155,7 +153,7 @@ impl Runtime { // will prevent additional nodes from starting to wait again, because `wait_on_channels` // will return immediately with `OakStatus::ErrTerminated`. let instances: Vec<_> = { - let mut nodes = self.nodes.write().unwrap(); + let mut nodes = self.nodes.write().expect("could not acquire lock on nodes"); self.terminating.store(true, SeqCst); nodes @@ -204,10 +202,8 @@ impl Runtime { return; } - let nodes = self.nodes.read().unwrap(); - let node = nodes - .get(&node_id) - .expect("Invalid node_id passed into track_handles_in_node!"); + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); let mut tracked_handles = node.handles.lock().unwrap(); for handle in handles { @@ -223,12 +219,13 @@ impl Runtime { return Ok(()); } - let nodes = self.nodes.read().unwrap(); + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); // Lookup the node_id in the runtime's nodes hashmap. - let node = nodes - .get(&node_id) - .expect("Invalid node_id passed into validate_handle_access!"); - let tracked_handles = node.handles.lock().unwrap(); + let node = nodes.get(&node_id).expect("invalid node_id"); + let tracked_handles = node + .handles + .lock() + .expect("could not acquire lock on tracked handles"); // Check the handle exists in the handles associated with a node, otherwise // return ErrBadHandle. @@ -245,21 +242,22 @@ impl Runtime { /// Validate the [`NodeId`] has access to all [`Handle`]'s passed in the iterator, returning /// `Err(OakStatus::ErrBadHandle)` if access is not allowed. - fn validate_handles_access<'a, I>(&self, node_id: NodeId, handles: I) -> Result<(), OakStatus> + fn validate_handles_access(&self, node_id: NodeId, handles: I) -> Result<(), OakStatus> where - I: IntoIterator, + I: IntoIterator, { // Allow RUNTIME_NODE_ID access to all handles. if node_id == RUNTIME_NODE_ID { return Ok(()); } - let nodes = self.nodes.read().unwrap(); - let node = nodes - .get(&node_id) - .expect("Invalid node_id passed into filter_optional_handles!"); + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); - let tracked_handles = node.handles.lock().unwrap(); + let tracked_handles = node + .handles + .lock() + .expect("could not acquire lock on node handles"); for handle in handles { // Check handle is accessible by the node. if !tracked_handles.contains(&handle) { @@ -273,8 +271,106 @@ impl Runtime { Ok(()) } + fn validate_can_read_from_channel( + &self, + node_id: NodeId, + channel_handle: Handle, + ) -> Result<(), OakStatus> { + debug!( + "validating whether node {:?} can read from channel {:?}", + node_id, channel_handle + ); + + // Allow RUNTIME_NODE_ID access to all handles. + if node_id == RUNTIME_NODE_ID { + return Ok(()); + } + + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); + let node_label = &node.label; + + let channel_label = self.channels.with_channel( + self.channels.get_reader_channel(channel_handle)?, + |channel| Ok(channel.label.clone()), + )?; + + if channel_label.flows_to(node_label) { + debug!( + "node {:?} can read from channel {:?}", + node_id, channel_handle + ); + Ok(()) + } else { + debug!( + "node {:?} cannot read from channel {:?}", + node_id, channel_handle + ); + Err(OakStatus::ErrPermissionDenied) + } + } + + fn validate_can_read_from_channels( + &self, + node_id: NodeId, + channel_handles: I, + ) -> Result<(), OakStatus> + where + I: IntoIterator, + { + let all_chanel_handles_ok = channel_handles.into_iter().all(|channel_handle| { + self.validate_can_read_from_channel(node_id, channel_handle) + .is_ok() + }); + if all_chanel_handles_ok { + Ok(()) + } else { + Err(OakStatus::ErrPermissionDenied) + } + } + + fn validate_can_write_to_channel( + &self, + node_id: NodeId, + channel_handle: Handle, + ) -> Result<(), OakStatus> { + debug!( + "validating whether node {:?} can write to channel {:?}", + node_id, channel_handle + ); + + // Allow RUNTIME_NODE_ID access to all handles. + if node_id == RUNTIME_NODE_ID { + return Ok(()); + } + + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); + let node_label = &node.label; + + let channel_label = self.channels.with_channel( + self.channels.get_writer_channel(channel_handle)?, + |channel| Ok(channel.label.clone()), + )?; + + if node_label.flows_to(&channel_label) { + debug!( + "node {:?} can write to channel {:?}", + node_id, channel_handle + ); + Ok(()) + } else { + debug!( + "node {:?} cannot write to channel {:?}", + node_id, channel_handle + ); + Err(OakStatus::ErrPermissionDenied) + } + } + /// Creates a new [`Channel`] and returns a `(writer handle, reader handle)` pair. pub fn new_channel(&self, node_id: NodeId, label: &oak_abi::label::Label) -> (Handle, Handle) { + // TODO(#630): Check whether the calling node can create a node with the specified label. let (writer, reader) = self.channels.new_channel(label); self.track_handles_in_node(node_id, vec![writer, reader]); (writer, reader) @@ -317,7 +413,11 @@ impl Runtime { node_id: NodeId, readers: &[Option], ) -> Result, OakStatus> { - self.validate_handles_access(node_id, readers.iter().filter_map(|x| x.as_ref()))?; + self.validate_handles_access(node_id, readers.iter().filter_map(|x| x.as_ref()).copied())?; + self.validate_can_read_from_channels( + node_id, + readers.iter().filter_map(|x| x.as_ref()).copied(), + )?; let thread = thread::current(); while !self.is_terminating() { @@ -377,6 +477,7 @@ impl Runtime { msg: Message, ) -> Result<(), OakStatus> { self.validate_handle_access(node_id, reference)?; + self.validate_can_write_to_channel(node_id, reference)?; self.channels.with_channel(self.channels.get_writer_channel(reference)?, |channel|{ if channel.is_orphan() { @@ -442,6 +543,7 @@ impl Runtime { reference: Handle, ) -> Result, OakStatus> { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; self.channels .with_channel(self.channels.get_reader_channel(reference)?, |channel| { let mut messages = channel.messages.write().unwrap(); @@ -471,6 +573,7 @@ impl Runtime { reference: Handle, ) -> Result { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; self.channels .with_channel(self.channels.get_reader_channel(reference)?, |channel| { let messages = channel.messages.read().unwrap(); @@ -499,6 +602,7 @@ impl Runtime { handles_capacity: usize, ) -> Result, OakStatus> { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; let result = self.channels .with_channel(self.channels.get_reader_channel(reference)?, |channel| { let mut messages = channel.messages.write().unwrap(); @@ -545,6 +649,7 @@ impl Runtime { reference: Handle, ) -> Result { self.validate_handle_access(node_id, reference)?; + // self.validate_can_read_from_channel(node_id, reference)?; { let readers = self.channels.readers.read().unwrap(); if readers.contains_key(&reference) { diff --git a/sdk/rust/oak/src/io/mod.rs b/sdk/rust/oak/src/io/mod.rs index e7ce80dfbd3..7f875c7265b 100644 --- a/sdk/rust/oak/src/io/mod.rs +++ b/sdk/rust/oak/src/io/mod.rs @@ -68,5 +68,8 @@ pub fn error_from_nonok_status(status: OakStatus) -> io::Error { OakStatus::ErrInternal => io::Error::new(io::ErrorKind::Other, "Internal error"), OakStatus::ErrTerminated => io::Error::new(io::ErrorKind::Other, "Node terminated"), OakStatus::ErrChannelEmpty => io::Error::new(io::ErrorKind::UnexpectedEof, "Channel empty"), + OakStatus::ErrPermissionDenied => { + io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied") + } } }