Skip to content

Commit

Permalink
Remove the use of anyhow in netlink-packet-core
Browse files Browse the repository at this point in the history
  • Loading branch information
hch12907 committed Jan 4, 2025
1 parent 63effb9 commit 6c4748a
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 55 deletions.
26 changes: 6 additions & 20 deletions src/buffer.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// SPDX-License-Identifier: MIT

use byteorder::{ByteOrder, NativeEndian};
use netlink_packet_utils::DecodeError;

use crate::{Field, Rest};
use crate::{CoreError, Field, Rest};

const LENGTH: Field = 0..4;
const MESSAGE_TYPE: Field = 4..6;
Expand Down Expand Up @@ -156,33 +155,20 @@ impl<T: AsRef<[u8]>> NetlinkBuffer<T> {
/// 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
/// assert!(NetlinkBuffer::new_checked(&BYTES[..]).is_err());
/// ```
pub fn new_checked(buffer: T) -> Result<NetlinkBuffer<T>, DecodeError> {
pub fn new_checked(buffer: T) -> Result<NetlinkBuffer<T>, CoreError> {
let packet = Self::new(buffer);
packet.check_buffer_length()?;
Ok(packet)
}

fn check_buffer_length(&self) -> Result<(), DecodeError> {
fn check_buffer_length(&self) -> Result<(), CoreError> {
let len = self.buffer.as_ref().len();
if len < PORT_NUMBER.end {
Err(format!(
"invalid netlink buffer: length is {} but netlink packets are at least {} bytes",
len, PORT_NUMBER.end
)
.into())
Err(CoreError::PacketTooShort { received: len, expected: PORT_NUMBER.end })
} else if len < self.length() as usize {
Err(format!(
"invalid netlink buffer: length field says {} the buffer is {} bytes long",
self.length(),
len
)
.into())
Err(CoreError::NonmatchingLength { expected: self.length(), actual: len })
} else if (self.length() as usize) < PORT_NUMBER.end {
Err(format!(
"invalid netlink buffer: length field says {} but netlink packets are at least {} bytes",
self.length(),
len
).into())
Err(CoreError::InvalidLength { given: self.length(), at_least: len })
} else {
Ok(())
}
Expand Down
17 changes: 7 additions & 10 deletions src/done.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
use std::mem::size_of;

use byteorder::{ByteOrder, NativeEndian};
use netlink_packet_utils::DecodeError;

use crate::{Emitable, Field, Parseable, Rest};
use crate::{CoreError, Emitable, Field, Parseable, Rest};

const CODE: Field = 0..4;
const EXTENDED_ACK: Rest = 4..;
Expand All @@ -27,20 +26,16 @@ impl<T: AsRef<[u8]>> DoneBuffer<T> {
self.buffer
}

pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
pub fn new_checked(buffer: T) -> Result<Self, CoreError> {
let packet = Self::new(buffer);
packet.check_buffer_length()?;
Ok(packet)
}

fn check_buffer_length(&self) -> Result<(), DecodeError> {
fn check_buffer_length(&self) -> Result<(), CoreError> {
let len = self.buffer.as_ref().len();
if len < DONE_HEADER_LEN {
Err(format!(
"invalid DoneBuffer: length is {len} but DoneBuffer are \
at least {DONE_HEADER_LEN} bytes"
)
.into())
Err(CoreError::InvalidDoneBuffer { received: len })
} else {
Ok(())
}
Expand Down Expand Up @@ -100,7 +95,9 @@ impl Emitable for DoneMessage {
impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<DoneBuffer<&'buffer T>>
for DoneMessage
{
fn parse(buf: &DoneBuffer<&'buffer T>) -> Result<DoneMessage, DecodeError> {
type Error = CoreError;

fn parse(buf: &DoneBuffer<&'buffer T>) -> Result<DoneMessage, Self::Error> {
Ok(DoneMessage {
code: buf.code(),
extended_ack: buf.extended_ack().to_vec(),
Expand Down
17 changes: 7 additions & 10 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
use std::{fmt, io, mem::size_of, num::NonZeroI32};

use byteorder::{ByteOrder, NativeEndian};
use netlink_packet_utils::DecodeError;

use crate::{Emitable, Field, Parseable, Rest};
use crate::{CoreError, Emitable, Field, Parseable, Rest};

const CODE: Field = 0..4;
const PAYLOAD: Rest = 4..;
Expand All @@ -27,20 +26,16 @@ impl<T: AsRef<[u8]>> ErrorBuffer<T> {
self.buffer
}

pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
pub fn new_checked(buffer: T) -> Result<Self, CoreError> {
let packet = Self::new(buffer);
packet.check_buffer_length()?;
Ok(packet)
}

fn check_buffer_length(&self) -> Result<(), DecodeError> {
fn check_buffer_length(&self) -> Result<(), CoreError> {
let len = self.buffer.as_ref().len();
if len < ERROR_HEADER_LEN {
Err(format!(
"invalid ErrorBuffer: length is {len} but ErrorBuffer are \
at least {ERROR_HEADER_LEN} bytes"
)
.into())
Err(CoreError::InvalidErrorBuffer { received: len })
} else {
Ok(())
}
Expand Down Expand Up @@ -118,9 +113,11 @@ impl Emitable for ErrorMessage {
impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<ErrorBuffer<&'buffer T>>
for ErrorMessage
{
type Error = CoreError;

fn parse(
buf: &ErrorBuffer<&'buffer T>,
) -> Result<ErrorMessage, DecodeError> {
) -> Result<ErrorMessage, Self::Error> {
// FIXME: The payload of an error is basically a truncated packet, which
// requires custom logic to parse correctly. For now we just
// return it as a Vec<u8> let header: NetlinkHeader = {
Expand Down
7 changes: 4 additions & 3 deletions src/header.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// SPDX-License-Identifier: MIT

use netlink_packet_utils::DecodeError;

use crate::CoreError;
use crate::{buffer::NETLINK_HEADER_LEN, Emitable, NetlinkBuffer, Parseable};

/// A Netlink header representation. A netlink header has the following
Expand Down Expand Up @@ -57,7 +56,9 @@ impl Emitable for NetlinkHeader {
impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NetlinkBuffer<&'a T>>
for NetlinkHeader
{
fn parse(buf: &NetlinkBuffer<&'a T>) -> Result<NetlinkHeader, DecodeError> {
type Error = CoreError;

fn parse(buf: &NetlinkBuffer<&'a T>) -> Result<NetlinkHeader, Self::Error> {
Ok(NetlinkHeader {
length: buf.length(),
message_type: buf.message_type(),
Expand Down
48 changes: 48 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,51 @@ pub use self::constants::*;

pub(crate) use self::utils::traits::*;
pub(crate) use netlink_packet_utils as utils;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum CoreError {
#[error("invalid netlink buffer: length is {received} but netlink packets are at least {expected} bytes")]
PacketTooShort { received: usize, expected: usize },

#[error("invalid netlink buffer: length field says {expected} but the buffer is {actual} bytes long")]
NonmatchingLength { expected: u32, actual: usize },

#[error("invalid netlink buffer: length field says {given} but netlink packets are at least {at_least} bytes")]
InvalidLength { given: u32, at_least: usize },

#[error(
"invalid ErrorBuffer: length is {received}, expected at least 4 bytes"
)]
InvalidErrorBuffer { received: usize },

#[error(
"invalid DoneBuffer: length is {received}, expected at least 4 bytes"
)]
InvalidDoneBuffer { received: usize },

#[error("invalid Netlink header")]
InvalidHeader {
#[source]
due_to: Box<Self>,
},

#[error("invalid Netlink message of type NLMSG_ERROR")]
InvalidErrorMsg {
#[source]
due_to: Box<Self>,
},

#[error("invalid Netlink message of type NLMSG_DONE")]
InvalidDoneMsg {
#[source]
due_to: Box<Self>,
},

#[error("failed to parse the netlink message, of type {message_type}")]
ParseFailure {
message_type: u16,
#[source]
due_to: Box<dyn std::error::Error>,
},
}
31 changes: 19 additions & 12 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

use std::fmt::Debug;

use anyhow::Context;
use netlink_packet_utils::DecodeError;

use crate::{
payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
CoreError, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
NetlinkSerializable, Parseable,
};
Expand Down Expand Up @@ -39,7 +36,7 @@ where
I: NetlinkDeserializable,
{
/// Parse the given buffer as a netlink message
pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
pub fn deserialize(buffer: &[u8]) -> Result<Self, CoreError> {
let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?;
<Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
}
Expand Down Expand Up @@ -88,33 +85,43 @@ where
B: AsRef<[u8]> + 'buffer,
I: NetlinkDeserializable,
{
fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, DecodeError> {
type Error = CoreError;

fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, Self::Error> {
use self::NetlinkPayload::*;

let header =
<NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
.context("failed to parse netlink header")?;
.map_err(|e| CoreError::InvalidHeader { due_to: e.into() })?;

let bytes = buf.payload();
let payload = match header.message_type {
NLMSG_ERROR => {
let msg = ErrorBuffer::new_checked(&bytes)
.and_then(|buf| ErrorMessage::parse(&buf))
.context("failed to parse NLMSG_ERROR")?;
.map_err(|e| CoreError::InvalidErrorMsg {
due_to: e.into(),
})?;
Error(msg)
}
NLMSG_NOOP => Noop,
NLMSG_DONE => {
let msg = DoneBuffer::new_checked(&bytes)
.and_then(|buf| DoneMessage::parse(&buf))
.context("failed to parse NLMSG_DONE")?;
.map_err(|e| CoreError::InvalidDoneMsg {
due_to: e.into(),
})?;
Done(msg)
}
NLMSG_OVERRUN => Overrun(bytes.to_vec()),
message_type => {
let inner_msg = I::deserialize(&header, bytes).context(
format!("Failed to parse message with type {message_type}"),
)?;
let inner_msg =
I::deserialize(&header, bytes).map_err(|e| {
CoreError::ParseFailure {
message_type,
due_to: e.into(),
}
})?;
InnerMessage(inner_msg)
}
};
Expand Down

0 comments on commit 6c4748a

Please sign in to comment.