From 726bb92c38ee7c1f36e1f13e01ffd97ddca30efd Mon Sep 17 00:00:00 2001 From: Dave Bakker Date: Sun, 8 Oct 2023 22:07:41 +0200 Subject: [PATCH] Update UDP tests --- .../src/bin/udp_sample_application.rs | 124 ++++++----------- .../wasi-sockets-tests/src/lib.rs | 125 +++++++++++++++++- 2 files changed, 168 insertions(+), 81 deletions(-) diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs b/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs index 9322683197f6..7dc64fa586a6 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/udp_sample_application.rs @@ -1,108 +1,72 @@ -use wasi::io::poll; use wasi::sockets::network::{ - IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, + IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress, Network, }; -use wasi::sockets::{instance_network, udp, udp_create_socket}; +use wasi::sockets::udp::{Datagram, UdpSocket}; use wasi_sockets_tests::*; fn test_sample_application(family: IpAddressFamily, bind_address: IpSocketAddress) { - let first_message = b"Hello, world!"; - let second_message = b"Greetings, planet!"; + let first_message = &[]; + let second_message = b"Hello, world!"; + let third_message = b"Greetings, planet!"; - let net = instance_network::instance_network(); + let net = Network::default(); - let sock = udp_create_socket::create_udp_socket(family).unwrap(); + let server = UdpSocket::new(family).unwrap(); - let sub = sock.subscribe(); + server.blocking_bind(&net, bind_address).unwrap(); + let addr = server.local_address().unwrap(); - sock.start_bind(&net, bind_address).unwrap(); + let client_addr = { + let client = UdpSocket::new(family).unwrap(); + client.blocking_connect(&net, addr).unwrap(); - poll::poll_one(&sub); - drop(sub); - - sock.finish_bind().unwrap(); - - let sub = sock.subscribe(); - - let addr = sock.local_address().unwrap(); - - let client = udp_create_socket::create_udp_socket(family).unwrap(); - let client_sub = client.subscribe(); - - client.start_connect(&net, addr).unwrap(); - poll::poll_one(&client_sub); - client.finish_connect().unwrap(); - - let _client_addr = client.local_address().unwrap(); - - let n = client - .send(&[ - udp::Datagram { - data: vec![], + let datagrams = [ + Datagram { + data: first_message.to_vec(), remote_address: addr, }, - udp::Datagram { - data: first_message.to_vec(), + Datagram { + data: second_message.to_vec(), remote_address: addr, }, - ]) - .unwrap(); - assert_eq!(n, 2); + ]; + client.blocking_send(&datagrams).unwrap(); - drop(client_sub); - drop(client); - - poll::poll_one(&sub); - let datagrams = sock.receive(2).unwrap(); - let mut datagrams = datagrams.into_iter(); - let (first, second) = match (datagrams.next(), datagrams.next(), datagrams.next()) { - (Some(first), Some(second), None) => (first, second), - (Some(_first), None, None) => panic!("only one datagram received"), - (None, None, None) => panic!("no datagrams received"), - _ => panic!("invalid datagram sequence received"), + client.local_address().unwrap() }; - assert!(first.data.is_empty()); - - // TODO: Verify the `remote_address` - //assert_eq!(first.remote_address, client_addr); + { + // Check that we've received our sent messages. + // Not guaranteed to work but should work in practice. + let datagrams = server.blocking_receive(2..100).unwrap(); + assert_eq!(datagrams.len(), 2); - // Check that we sent and recieved our message! - assert_eq!(second.data, first_message); // Not guaranteed to work but should work in practice. + assert_eq!(datagrams[0].data, first_message); + assert_eq!(datagrams[0].remote_address, client_addr); - // TODO: Verify the `remote_address` - //assert_eq!(second.remote_address, client_addr); + assert_eq!(datagrams[1].data, second_message); + assert_eq!(datagrams[1].remote_address, client_addr); + } // Another client - let client = udp_create_socket::create_udp_socket(family).unwrap(); - let client_sub = client.subscribe(); + { + let client = UdpSocket::new(family).unwrap(); + client.blocking_connect(&net, addr).unwrap(); - client.start_connect(&net, addr).unwrap(); - poll::poll_one(&client_sub); - client.finish_connect().unwrap(); - - let n = client - .send(&[udp::Datagram { - data: second_message.to_vec(), + let datagrams = [Datagram { + data: third_message.to_vec(), remote_address: addr, - }]) - .unwrap(); - assert_eq!(n, 1); - - drop(client_sub); - drop(client); + }]; + client.blocking_send(&datagrams).unwrap(); + } - poll::poll_one(&sub); - let datagrams = sock.receive(2).unwrap(); - let mut datagrams = datagrams.into_iter(); - let first = match (datagrams.next(), datagrams.next()) { - (Some(first), None) => first, - (None, None) => panic!("no datagrams received"), - _ => panic!("invalid datagram sequence received"), - }; + { + // Check that we sent and received our message! + let datagrams = server.blocking_receive(1..100).unwrap(); + assert_eq!(datagrams.len(), 1); - // Check that we sent and recieved our message! - assert_eq!(first.data, second_message); // Not guaranteed to work but should work in practice. + assert_eq!(datagrams[0].data, third_message); // Not guaranteed to work but should work in practice. + } } fn main() { diff --git a/crates/test-programs/wasi-sockets-tests/src/lib.rs b/crates/test-programs/wasi-sockets-tests/src/lib.rs index 794fd1d4cdb3..258099e5a05c 100644 --- a/crates/test-programs/wasi-sockets-tests/src/lib.rs +++ b/crates/test-programs/wasi-sockets-tests/src/lib.rs @@ -1,5 +1,7 @@ wit_bindgen::generate!("test-command-with-sockets" in "../../wasi/wit"); +use std::ops::Range; +use wasi::clocks::monotonic_clock; use wasi::io::poll::{self, Pollable}; use wasi::io::streams::{InputStream, OutputStream, StreamError}; use wasi::sockets::instance_network; @@ -8,12 +10,25 @@ use wasi::sockets::network::{ Network, }; use wasi::sockets::tcp::TcpSocket; -use wasi::sockets::{network, tcp_create_socket, udp, udp_create_socket}; +use wasi::sockets::udp::{Datagram, UdpSocket}; +use wasi::sockets::{tcp_create_socket, udp_create_socket}; + +const TIMEOUT_NS: u64 = 1_000_000_000; impl Pollable { pub fn wait(&self) { poll::poll_one(self); } + + pub fn wait_until(&self, timeout: &Pollable) -> Result<(), ErrorCode> { + let ready = poll::poll_list(&[self, timeout]); + assert!(ready.len() > 0); + match ready[0] { + 0 => Ok(()), + 1 => Err(ErrorCode::Timeout), + _ => unreachable!(), + } + } } impl OutputStream { @@ -108,6 +123,89 @@ impl TcpSocket { } } +impl UdpSocket { + pub fn new(address_family: IpAddressFamily) -> Result { + udp_create_socket::create_udp_socket(address_family) + } + + pub fn blocking_bind( + &self, + network: &Network, + local_address: IpSocketAddress, + ) -> Result<(), ErrorCode> { + let sub = self.subscribe(); + + self.start_bind(&network, local_address)?; + + loop { + match self.finish_bind() { + Err(ErrorCode::WouldBlock) => sub.wait(), + result => return result, + } + } + } + + pub fn blocking_connect( + &self, + network: &Network, + remote_address: IpSocketAddress, + ) -> Result<(), ErrorCode> { + let sub = self.subscribe(); + + self.start_connect(&network, remote_address)?; + + loop { + match self.finish_connect() { + Err(ErrorCode::WouldBlock) => sub.wait(), + result => return result, + } + } + } + + pub fn blocking_send(&self, mut datagrams: &[Datagram]) -> Result<(), ErrorCode> { + let timeout = monotonic_clock::subscribe(TIMEOUT_NS, false); + let pollable = self.subscribe(); + + while !datagrams.is_empty() { + match self.send(datagrams) { + Ok(packets_sent) => { + datagrams = &datagrams[(packets_sent as usize)..]; + } + Err(ErrorCode::WouldBlock) => pollable.wait_until(&timeout)?, + Err(err) => return Err(err), + } + } + + Ok(()) + } + + pub fn blocking_receive(&self, count: Range) -> Result, ErrorCode> { + let timeout = monotonic_clock::subscribe(TIMEOUT_NS, false); + let pollable = self.subscribe(); + let mut datagrams = vec![]; + + loop { + match self.receive(count.end - datagrams.len() as u64) { + Ok(mut chunk) => { + datagrams.append(&mut chunk); + + if datagrams.len() >= count.start as usize { + return Ok(datagrams); + } + } + Err(ErrorCode::WouldBlock) => { + if datagrams.len() >= count.start as usize { + return Ok(datagrams); + } else { + pollable.wait_until(&timeout)?; + } + } + Err(err) => return Err(err), + } + } + } +} + impl IpAddress { pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255)); @@ -189,3 +287,28 @@ impl IpSocketAddress { } } } + +impl PartialEq for Ipv4SocketAddress { + fn eq(&self, other: &Self) -> bool { + self.port == other.port && self.address == other.address + } +} + +impl PartialEq for Ipv6SocketAddress { + fn eq(&self, other: &Self) -> bool { + self.port == other.port + && self.flow_info == other.flow_info + && self.address == other.address + && self.scope_id == other.scope_id + } +} + +impl PartialEq for IpSocketAddress { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0, + (Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0, + _ => false, + } + } +}