Skip to content

Commit

Permalink
refactoring(rust): updated unit tests for now-proto-pdu
Browse files Browse the repository at this point in the history
  • Loading branch information
pacmancoder committed Dec 24, 2024
1 parent b17ed2c commit 767d128
Show file tree
Hide file tree
Showing 34 changed files with 792 additions and 374 deletions.
3 changes: 1 addition & 2 deletions crates/now-proto-pdu/src/channel/capset.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use core::time;

use bitflags::bitflags;

use ironrdp_core::{
ensure_fixed_part_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult, ReadCursor, WriteCursor,
};
Expand Down Expand Up @@ -227,7 +226,7 @@ impl NowChannelCapsetMsg {
let system_capset = NowSystemCapsetFlags::from_bits_retain(src.read_u16());
let session_capset = NowSessionCapsetFlags::from_bits_retain(src.read_u16());
let exec_capset = NowExecCapsetFlags::from_bits_retain(src.read_u16());
// Read heartbeat interval unconditionaly even if `SET_HEARTBEAT` flags is not set.
// Read heartbeat interval unconditionally even if `SET_HEARTBEAT` flags is not set.
let heartbeat_interval_value = src.read_u32();

let heartbeat_interval = flags
Expand Down
10 changes: 6 additions & 4 deletions crates/now-proto-pdu/src/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ mod capset;
mod heartbeat;
mod terminate;

pub use capset::{
NowChannelCapsetFlags, NowChannelCapsetMsg, NowExecCapsetFlags, NowProtoVersion, NowSessionCapsetFlags,
NowSystemCapsetFlags,
};
pub use heartbeat::NowChannelHeartbeatMsg;
use ironrdp_core::{DecodeResult, Encode, EncodeResult, IntoOwned, ReadCursor, WriteCursor};
pub use terminate::{NowChannelTerminateMsg, OwnedNowChannelTerminateMsg};

use crate::NowHeader;

pub use capset::NowChannelCapsetMsg;
pub use heartbeat::NowChannelHeartbeatMsg;
pub use terminate::{NowChannelTerminateMsg, OwnedNowChannelTerminateMsg};

// Wrapper for the `NOW_CHANNEL_MSG_CLASS_ID` message class.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NowChannelMessage<'a> {
Expand Down
27 changes: 17 additions & 10 deletions crates/now-proto-pdu/src/channel/terminate.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ironrdp_core::{invalid_field_err, Decode, DecodeResult, Encode, EncodeResult, IntoOwned, ReadCursor, WriteCursor};

use crate::{NowChannelMessage, NowChannelMsgKind, NowHeader, NowMessage, NowMessageClass, NowStatus};
use crate::{NowChannelMessage, NowChannelMsgKind, NowHeader, NowMessage, NowMessageClass, NowStatus, NowStatusError};

/// Channel termination notice, could be sent by either parties at any moment of communication to
/// gracefully close DVC channel.
Expand All @@ -23,29 +23,36 @@ impl IntoOwned for NowChannelTerminateMsg<'_> {
}
}

impl Default for NowChannelTerminateMsg<'_> {
fn default() -> Self {
let status = NowStatus::new_success();

Self { status }
}
}

impl<'a> NowChannelTerminateMsg<'a> {
const NAME: &'static str = "NOW_CHANNEL_TERMINATE_MSG";

pub fn from_status(status: NowStatus<'a>) -> EncodeResult<Self> {
pub fn from_error(error: impl Into<NowStatusError>) -> EncodeResult<Self> {
let status = NowStatus::new_error(error);

let msg = Self { status };

msg.ensure_message_size().expect("validated in constructor");
ensure_now_message_size!(msg.status.size());

Ok(msg)
}

pub fn to_result(&self) -> Result<(), NowStatusError> {
self.status.to_result()
}

pub(super) fn decode_from_body(_header: NowHeader, src: &mut ReadCursor<'a>) -> DecodeResult<Self> {
let status = NowStatus::decode(src)?;

Ok(Self { status })
}

fn ensure_message_size(&self) -> EncodeResult<()> {
let _message_size =
u32::try_from(self.status.size()).map_err(|_| invalid_field_err!("size", "message size overflow"))?;

Ok(())
}
}

impl Encode for NowChannelTerminateMsg<'_> {
Expand Down
13 changes: 3 additions & 10 deletions crates/now-proto-pdu/src/core/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ use crate::VarU32;
///
/// NOW-PROTO: NOW_VARBUF
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NowVarBuf<'a>(Cow<'a, [u8]>);

impl_pdu_borrowing!(NowVarBuf<'_>, OwnedNowVarBuf);
pub(crate) struct NowVarBuf<'a>(Cow<'a, [u8]>);

impl IntoOwned for NowVarBuf<'_> {
type Owned = OwnedNowVarBuf;
type Owned = NowVarBuf<'static>;

fn into_owned(self) -> Self::Owned {
NowVarBuf(Cow::Owned(self.0.into_owned()))
Expand All @@ -30,7 +28,7 @@ impl<'a> NowVarBuf<'a> {
const NAME: &'static str = "NOW_VARBUF";

/// Create a new `NowVarBuf` instance. Returns an error if the provided value is too large.
pub fn new(value: impl Into<Cow<'a, [u8]>>) -> EncodeResult<Self> {
pub(crate) fn new(value: impl Into<Cow<'a, [u8]>>) -> EncodeResult<Self> {
let value = value.into();

let _: u32 = value
Expand All @@ -42,11 +40,6 @@ impl<'a> NowVarBuf<'a> {

Ok(NowVarBuf(value))
}

/// Get the buffer value.
pub fn value(&self) -> &[u8] {
&self.0
}
}

impl Encode for NowVarBuf<'_> {
Expand Down
16 changes: 11 additions & 5 deletions crates/now-proto-pdu/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
//! This module contains `NOW-PROTO` core types definitions.
//!
//! Note that these types are not intended to be used directly by the user, and not exported in the
//! public API.
mod buffer;
mod header;
mod number;
mod status;
mod string;

pub use buffer::{NowVarBuf, OwnedNowVarBuf};
pub use header::{NowHeader, NowMessageClass};
pub use number::VarU32;
pub use status::{NowProtoError, NowStatus, NowStatusError, NowStatusErrorKind, OwnedNowStatus};
pub use string::{NowVarStr, OwnedNowVarStr};
pub(crate) use buffer::NowVarBuf;
pub(crate) use header::{NowHeader, NowMessageClass};
pub(crate) use number::VarU32;
pub(crate) use status::NowStatus;
// Only public-exported type is the status error, which should be available to the user for error
// handling.
pub use status::{NowProtoError, NowStatusError, NowStatusErrorKind};
pub(crate) use string::NowVarStr;
53 changes: 39 additions & 14 deletions crates/now-proto-pdu/src/core/status.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use alloc::{borrow::Cow, fmt};
use alloc::borrow::Cow;
use alloc::fmt;
use core::ops::Deref;

use bitflags::bitflags;

use ironrdp_core::{
ensure_fixed_part_size, Decode, DecodeResult, Encode, EncodeResult, IntoOwned, ReadCursor, WriteCursor,
};
Expand Down Expand Up @@ -216,23 +217,49 @@ impl NowStatusErrorKind {
}

/// Wrapper type around NOW_STATUS errors. Provides rust-friendly interface for error handling.
#[derive(Debug)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NowStatusError {
kind: NowStatusErrorKind,
message: NowVarStr<'static>,
}

impl NowStatusError {
pub fn new_generic(code: u32) -> Self {
Self {
kind: NowStatusErrorKind::Generic(code),
message: Default::default(),
}
}

pub fn new_proto(error: NowProtoError) -> Self {
Self {
kind: NowStatusErrorKind::Now(error),
message: Default::default(),
}
}

pub fn new_winapi(code: u32) -> Self {
Self {
kind: NowStatusErrorKind::WinApi(code),
message: Default::default(),
}
}

pub fn new_unix(code: u32) -> Self {
Self {
kind: NowStatusErrorKind::Unix(code),
message: Default::default(),
}
}

pub fn kind(&self) -> NowStatusErrorKind {
self.kind
}

pub fn message(&self) -> &str {
&self.message
}
}

impl NowStatusError {
/// Attach optional message to NOW_STATUS error.
pub fn with_message(self, message: impl Into<Cow<'static, str>>) -> EncodeResult<Self> {
Ok(Self {
Expand All @@ -248,7 +275,7 @@ impl core::fmt::Display for NowStatusError {

// Write optional message if provided.
if !self.message.is_empty() {
write!(f, " ({})", self.message.value())?;
write!(f, " ({})", self.message.deref())?;
}

Ok(())
Expand Down Expand Up @@ -276,20 +303,18 @@ impl core::error::Error for NowStatusError {}
///
/// NOW-PROTO: NOW_STATUS
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NowStatus<'a> {
pub(crate) struct NowStatus<'a> {
flags: NowStatusFlags,
kind: RawNowStatusKind,
code: u32,
message: NowVarStr<'a>,
}

impl_pdu_borrowing!(NowStatus<'_>, OwnedNowStatus);

impl IntoOwned for NowStatus<'_> {
type Owned = OwnedNowStatus;
type Owned = NowStatus<'static>;

fn into_owned(self) -> Self::Owned {
OwnedNowStatus {
Self::Owned {
flags: self.flags,
kind: self.kind,
code: self.code,
Expand All @@ -303,7 +328,7 @@ impl NowStatus<'_> {
const FIXED_PART_SIZE: usize = 8;

/// Create a new success status.
pub fn new_success() -> Self {
pub(crate) fn new_success() -> Self {
Self {
flags: NowStatusFlags::empty(),
kind: RawNowStatusKind::GENERIC,
Expand All @@ -313,7 +338,7 @@ impl NowStatus<'_> {
}

/// Create a new status with error.
pub fn new_error(error: impl Into<NowStatusError>) -> Self {
pub(crate) fn new_error(error: impl Into<NowStatusError>) -> Self {
let error: NowStatusError = error.into();

let flags = if error.message.is_empty() {
Expand All @@ -331,7 +356,7 @@ impl NowStatus<'_> {
}

/// Convert status to result with 'static error.
pub fn to_result(&self) -> Result<(), NowStatusError> {
pub(crate) fn to_result(&self) -> Result<(), NowStatusError> {
if !self.flags.contains(NowStatusFlags::ERROR) {
return Ok(());
}
Expand Down
17 changes: 5 additions & 12 deletions crates/now-proto-pdu/src/core/string.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! String types
use alloc::borrow::Cow;
use core::{ops::Deref, str};
use core::ops::Deref;
use core::str;

use ironrdp_core::{
cast_length, ensure_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult, IntoOwned, ReadCursor,
Expand All @@ -14,25 +15,21 @@ use crate::VarU32;
///
/// NOW-PROTO: NOW_VARSTR
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NowVarStr<'a>(Cow<'a, str>);

impl_pdu_borrowing!(NowVarStr<'_>, OwnedNowVarStr);
pub(crate) struct NowVarStr<'a>(Cow<'a, str>);

impl IntoOwned for NowVarStr<'_> {
type Owned = OwnedNowVarStr;
type Owned = NowVarStr<'static>;

fn into_owned(self) -> Self::Owned {
NowVarStr(Cow::Owned(self.0.into_owned()))
}
}

impl<'a> NowVarStr<'a> {
pub const MAX_SIZE: usize = VarU32::MAX as usize;

const NAME: &'static str = "NOW_VARSTR";

/// Creates `NowVarStr` from std string. Returns error if string is too big for the protocol.
pub fn new(value: impl Into<Cow<'a, str>>) -> EncodeResult<Self> {
pub(crate) fn new(value: impl Into<Cow<'a, str>>) -> EncodeResult<Self> {
let value = value.into();
// IMPORTANT: we need to check for encoded UTF-8 size, not the string length.

Expand All @@ -46,10 +43,6 @@ impl<'a> NowVarStr<'a> {

Ok(NowVarStr(value))
}

pub fn value(&self) -> &str {
&self.0
}
}

impl Encode for NowVarStr<'_> {
Expand Down
11 changes: 4 additions & 7 deletions crates/now-proto-pdu/src/exec/batch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::borrow::Cow;
use bitflags::bitflags;

use bitflags::bitflags;
use ironrdp_core::{
cast_length, ensure_fixed_part_size, invalid_field_err, Decode, DecodeResult, Encode, EncodeResult, IntoOwned,
ReadCursor, WriteCursor,
Expand Down Expand Up @@ -95,18 +95,15 @@ impl<'a> NowExecBatchMsg<'a> {
}

fn ensure_message_size(&self) -> EncodeResult<()> {
let _message_size = Self::FIXED_PART_SIZE
.checked_add(self.command.size())
.and_then(|size| size.checked_add(self.directory.size()))
.ok_or_else(|| invalid_field_err!("size", "message size overflow"))?;
ensure_now_message_size!(Self::FIXED_PART_SIZE, self.command.size(), self.directory.size());

Ok(())
}

pub(super) fn decode_from_body(_header: NowHeader, src: &mut ReadCursor<'a>) -> DecodeResult<Self> {
pub(super) fn decode_from_body(header: NowHeader, src: &mut ReadCursor<'a>) -> DecodeResult<Self> {
ensure_fixed_part_size!(in: src);

let flags = NowExecBatchFlags::from_bits_retain(src.read_u16());
let flags = NowExecBatchFlags::from_bits_retain(header.flags);
let session_id = src.read_u32();
let command = NowVarStr::decode(src)?;
let directory = NowVarStr::decode(src)?;
Expand Down
13 changes: 1 addition & 12 deletions crates/now-proto-pdu/src/exec/cancel_rsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ impl<'a> NowExecCancelRspMsg<'a> {
status: NowStatus::new_success(),
};

msg.ensure_message_size()
.expect("success message size always fits into payload");

msg
}

Expand All @@ -49,7 +46,7 @@ impl<'a> NowExecCancelRspMsg<'a> {
status: NowStatus::new_error(error),
};

msg.ensure_message_size()?;
ensure_now_message_size!(Self::FIXED_PART_SIZE, msg.status.size());

Ok(msg)
}
Expand All @@ -68,14 +65,6 @@ impl<'a> NowExecCancelRspMsg<'a> {
Self::FIXED_PART_SIZE + self.status.size()
}

fn ensure_message_size(&self) -> EncodeResult<()> {
let _message_size = Self::FIXED_PART_SIZE
.checked_add(self.status.size())
.ok_or_else(|| invalid_field_err!("size", "message size overflow"))?;

Ok(())
}

pub(super) fn decode_from_body(_header: NowHeader, src: &mut ReadCursor<'a>) -> DecodeResult<Self> {
ensure_fixed_part_size!(in: src);
let session_id = src.read_u32();
Expand Down
Loading

0 comments on commit 767d128

Please sign in to comment.