diff --git a/Cargo.toml b/Cargo.toml index 44f7a1067..8b0a28091 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ quickcheck = { version = "0.7", default-features = false } tokio-io = "0.1.11" tokio-tcp = "0.1.3" tokio-threadpool = "0.1.10" +futures = "0.1" [features] default = ["miniz-sys"] diff --git a/src/gz/bufread.rs b/src/gz/bufread.rs index ebb795c60..f06205f41 100644 --- a/src/gz/bufread.rs +++ b/src/gz/bufread.rs @@ -3,6 +3,11 @@ use std::io; use std::io::prelude::*; use std::mem; +#[cfg(feature = "tokio")] +use futures::Poll; +#[cfg(feature = "tokio")] +use tokio_io::{AsyncRead, AsyncWrite}; + use super::{GzBuilder, GzHeader}; use super::{FCOMMENT, FEXTRA, FHCRC, FNAME}; use crc::CrcReader; @@ -211,6 +216,19 @@ impl GzEncoder { } } +#[inline] +fn finish(buf: &[u8; 8]) -> (u32, u32) { + let crc = ((buf[0] as u32) << 0) + | ((buf[1] as u32) << 8) + | ((buf[2] as u32) << 16) + | ((buf[3] as u32) << 24); + let amt = ((buf[4] as u32) << 0) + | ((buf[5] as u32) << 8) + | ((buf[6] as u32) << 16) + | ((buf[7] as u32) << 24); + (crc, amt) +} + impl Read for GzEncoder { fn read(&mut self, mut into: &mut [u8]) -> io::Result { let mut amt = 0; @@ -280,69 +298,90 @@ impl Write for GzEncoder { /// ``` #[derive(Debug)] pub struct GzDecoder { - inner: CrcReader>, - header: Option>, - finished: bool, + inner: GzState, + header: Option, + reader: CrcReader>, + multi: bool +} + +#[derive(Debug)] +enum GzState { + Header(Vec), + Body, + Finished(usize, [u8; 8]), + Err(io::Error), + End +} + +struct Buffer<'a, T> { + buf: io::Take>>, + reader: &'a mut T +} + +impl<'a, T> Buffer<'a, T> { + fn new(buf: &'a mut Vec, reader: &'a mut T) -> Buffer<'a, T> { + let len = buf.len(); + Buffer { buf: io::Cursor::new(buf).take(len as _), reader } + } +} + +impl<'a, T: Read> Read for Buffer<'a, T> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut len = self.buf.read(buf)?; + if buf.len() > len { + let len2 = self.reader.read(&mut buf[len..])?; + self.buf.get_mut().get_mut().extend_from_slice(&buf[len..][..len2]); + len += len2; + } + Ok(len) + } } impl GzDecoder { /// Creates a new decoder from the given reader, immediately parsing the /// gzip header. pub fn new(mut r: R) -> GzDecoder { - let header = read_gz_header(&mut r); + let mut buf = Vec::with_capacity(10); // minimum header length + let mut header = None; + + let result = { + let mut reader = Buffer::new(&mut buf, &mut r); + read_gz_header(&mut reader) + }; + + let state = match result { + Ok(hdr) => { + header = Some(hdr); + GzState::Body + }, + Err(ref err) if io::ErrorKind::WouldBlock == err.kind() + => GzState::Header(buf), + Err(err) => GzState::Err(err) + }; - let flate = deflate::bufread::DeflateDecoder::new(r); GzDecoder { - inner: CrcReader::new(flate), - header: Some(header), - finished: false, + inner: state, + reader: CrcReader::new(deflate::bufread::DeflateDecoder::new(r)), + multi: false, + header } } - fn finish(&mut self) -> io::Result<()> { - if self.finished { - return Ok(()); - } - let ref mut buf = [0u8; 8]; - { - let mut len = 0; - - while len < buf.len() { - match self.inner.get_mut().get_mut().read(&mut buf[len..])? { - 0 => return Err(corrupt()), - n => len += n, - } - } - } - - let crc = ((buf[0] as u32) << 0) - | ((buf[1] as u32) << 8) - | ((buf[2] as u32) << 16) - | ((buf[3] as u32) << 24); - let amt = ((buf[4] as u32) << 0) - | ((buf[5] as u32) << 8) - | ((buf[6] as u32) << 16) - | ((buf[7] as u32) << 24); - if crc != self.inner.crc().sum() { - return Err(corrupt()); - } - if amt != self.inner.crc().amount() { - return Err(corrupt()); - } - self.finished = true; - Ok(()) + fn multi(mut self, flag: bool) -> GzDecoder { + self.multi = flag; + self } } impl GzDecoder { /// Returns the header associated with this stream, if it was valid pub fn header(&self) -> Option<&GzHeader> { - self.header.as_ref().and_then(|h| h.as_ref().ok()) + self.header.as_ref() } /// Acquires a reference to the underlying reader. pub fn get_ref(&self) -> &R { - self.inner.get_ref().get_ref() + self.reader.get_ref().get_ref() } /// Acquires a mutable reference to the underlying stream. @@ -350,38 +389,116 @@ impl GzDecoder { /// Note that mutation of the stream may result in surprising results if /// this encoder is continued to be used. pub fn get_mut(&mut self) -> &mut R { - self.inner.get_mut().get_mut() + self.reader.get_mut().get_mut() } /// Consumes this decoder, returning the underlying reader. pub fn into_inner(self) -> R { - self.inner.into_inner().into_inner() + self.reader.into_inner().into_inner() } } impl Read for GzDecoder { fn read(&mut self, into: &mut [u8]) -> io::Result { - match self.header { - None => return Ok(0), // error already returned, - Some(Ok(_)) => {} - Some(Err(_)) => match self.header.take().unwrap() { - Ok(_) => panic!(), - Err(e) => return Err(e), - }, - } - if into.is_empty() { - return Ok(0); - } - match self.inner.read(into)? { - 0 => { - self.finish()?; - Ok(0) - } - n => Ok(n), + let GzDecoder { inner, header, reader, multi } = self; + + loop { + *inner = match mem::replace(inner, GzState::End) { + GzState::Header(mut buf) => { + let result = { + let mut reader = Buffer::new(&mut buf, reader.get_mut().get_mut()); + read_gz_header(&mut reader) + }; + let hdr = result + .map_err(|err| { + if io::ErrorKind::WouldBlock == err.kind() { + *inner = GzState::Header(buf); + } + + err + })?; + *header = Some(hdr); + GzState::Body + }, + GzState::Body => { + if into.is_empty() { + *inner = GzState::Body; + return Ok(0); + } + + let n = reader.read(into) + .map_err(|err| { + if io::ErrorKind::WouldBlock == err.kind() { + *inner = GzState::Body; + } + + err + })?; + + match n { + 0 => GzState::Finished(0, [0; 8]), + n => { + *inner = GzState::Body; + return Ok(n); + } + } + }, + GzState::Finished(pos, mut buf) => if pos < buf.len() { + let n = reader.get_mut().get_mut() + .read(&mut buf[pos..]) + .and_then(|n| if n == 0 { + Err(io::ErrorKind::UnexpectedEof.into()) + } else { + Ok(n) + }) + .map_err(|err| { + if io::ErrorKind::WouldBlock == err.kind() { + *inner = GzState::Finished(pos, buf); + } + + err + })?; + + GzState::Finished(pos + n, buf) + } else { + let (crc, amt) = finish(&buf); + + if crc != reader.crc().sum() || amt != reader.crc().amount() { + return Err(corrupt()); + } else if *multi { + let is_eof = reader.get_mut().get_mut() + .fill_buf() + .map(|buf| buf.is_empty()) + .map_err(|err| { + if io::ErrorKind::WouldBlock == err.kind() { + *inner = GzState::Finished(pos, buf); + } + + err + })?; + + if is_eof { + GzState::End + } else { + reader.reset(); + reader.get_mut().reset_data(); + header.take(); + GzState::Header(Vec::with_capacity(10)) + } + } else { + GzState::End + } + }, + GzState::Err(err) => return Err(err), + GzState::End => return Ok(0) + }; } } } +#[cfg(feature = "tokio")] +impl AsyncRead for GzDecoder {} + impl Write for GzDecoder { fn write(&mut self, buf: &[u8]) -> io::Result { self.get_mut().write(buf) @@ -392,6 +509,13 @@ impl Write for GzDecoder { } } +#[cfg(feature = "tokio")] +impl AsyncWrite for GzDecoder { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.get_mut().shutdown() + } +} + /// A gzip streaming decoder that decodes all members of a multistream /// /// A gzip member consists of a header, compressed data and a trailer. The [gzip @@ -433,87 +557,26 @@ impl Write for GzDecoder { /// } /// ``` #[derive(Debug)] -pub struct MultiGzDecoder { - inner: CrcReader>, - header: io::Result, - finished: bool, -} +pub struct MultiGzDecoder(GzDecoder); impl MultiGzDecoder { /// Creates a new decoder from the given reader, immediately parsing the /// (first) gzip header. If the gzip stream contains multiple members all will /// be decoded. - pub fn new(mut r: R) -> MultiGzDecoder { - let header = read_gz_header(&mut r); - - let flate = deflate::bufread::DeflateDecoder::new(r); - MultiGzDecoder { - inner: CrcReader::new(flate), - header: header, - finished: false, - } - } - - fn finish_member(&mut self) -> io::Result { - if self.finished { - return Ok(0); - } - let ref mut buf = [0u8; 8]; - { - let mut len = 0; - - while len < buf.len() { - match self.inner.get_mut().get_mut().read(&mut buf[len..])? { - 0 => return Err(corrupt()), - n => len += n, - } - } - } - - let crc = ((buf[0] as u32) << 0) - | ((buf[1] as u32) << 8) - | ((buf[2] as u32) << 16) - | ((buf[3] as u32) << 24); - let amt = ((buf[4] as u32) << 0) - | ((buf[5] as u32) << 8) - | ((buf[6] as u32) << 16) - | ((buf[7] as u32) << 24); - if crc != self.inner.crc().sum() as u32 { - return Err(corrupt()); - } - if amt != self.inner.crc().amount() { - return Err(corrupt()); - } - let remaining = match self.inner.get_mut().get_mut().fill_buf() { - Ok(b) => { - if b.is_empty() { - self.finished = true; - return Ok(0); - } else { - b.len() - } - } - Err(e) => return Err(e), - }; - - let next_header = read_gz_header(self.inner.get_mut().get_mut()); - drop(mem::replace(&mut self.header, next_header)); - self.inner.reset(); - self.inner.get_mut().reset_data(); - - Ok(remaining) + pub fn new(r: R) -> MultiGzDecoder { + MultiGzDecoder(GzDecoder::new(r).multi(true)) } } impl MultiGzDecoder { /// Returns the current header associated with this stream, if it's valid pub fn header(&self) -> Option<&GzHeader> { - self.header.as_ref().ok() + self.0.header() } /// Acquires a reference to the underlying reader. pub fn get_ref(&self) -> &R { - self.inner.get_ref().get_ref() + self.0.get_ref() } /// Acquires a mutable reference to the underlying stream. @@ -521,32 +584,24 @@ impl MultiGzDecoder { /// Note that mutation of the stream may result in surprising results if /// this encoder is continued to be used. pub fn get_mut(&mut self) -> &mut R { - self.inner.get_mut().get_mut() + self.0.get_mut() } /// Consumes this decoder, returning the underlying reader. pub fn into_inner(self) -> R { - self.inner.into_inner().into_inner() + self.0.into_inner() } } impl Read for MultiGzDecoder { fn read(&mut self, into: &mut [u8]) -> io::Result { - if let Err(ref mut e) = self.header { - let another_error = io::ErrorKind::Other.into(); - return Err(mem::replace(e, another_error)); - } - match self.inner.read(into)? { - 0 => match self.finish_member() { - Ok(0) => Ok(0), - Ok(_) => self.read(into), - Err(e) => Err(e), - }, - n => Ok(n), - } + self.0.read(into) } } +#[cfg(feature = "tokio")] +impl AsyncRead for MultiGzDecoder {} + impl Write for MultiGzDecoder { fn write(&mut self, buf: &[u8]) -> io::Result { self.get_mut().write(buf) @@ -556,3 +611,10 @@ impl Write for MultiGzDecoder { self.get_mut().flush() } } + +#[cfg(feature = "tokio")] +impl AsyncWrite for MultiGzDecoder { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.get_mut().shutdown() + } +} diff --git a/src/gz/mod.rs b/src/gz/mod.rs index ee16b0087..b9043b7d3 100644 --- a/src/gz/mod.rs +++ b/src/gz/mod.rs @@ -18,7 +18,7 @@ pub mod write; /// /// The header can contain metadata about the file that was compressed, if /// present. -#[derive(PartialEq, Clone, Debug)] +#[derive(PartialEq, Clone, Debug, Default)] pub struct GzHeader { extra: Option>, filename: Option>, diff --git a/src/gz/read.rs b/src/gz/read.rs index a73a54699..774f4f381 100644 --- a/src/gz/read.rs +++ b/src/gz/read.rs @@ -1,6 +1,11 @@ use std::io; use std::io::prelude::*; +#[cfg(feature = "tokio")] +use futures::Poll; +#[cfg(feature = "tokio")] +use tokio_io::{AsyncRead, AsyncWrite}; + use super::bufread; use super::{GzBuilder, GzHeader}; use bufreader::BufReader; @@ -170,6 +175,9 @@ impl Read for GzDecoder { } } +#[cfg(feature = "tokio")] +impl AsyncRead for GzDecoder {} + impl Write for GzDecoder { fn write(&mut self, buf: &[u8]) -> io::Result { self.get_mut().write(buf) @@ -180,6 +188,13 @@ impl Write for GzDecoder { } } +#[cfg(feature = "tokio")] +impl AsyncWrite for GzDecoder { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.get_mut().shutdown() + } +} + /// A gzip streaming decoder that decodes all members of a multistream /// /// A gzip member consists of a header, compressed data and a trailer. The [gzip @@ -267,6 +282,9 @@ impl Read for MultiGzDecoder { } } +#[cfg(feature = "tokio")] +impl AsyncRead for MultiGzDecoder {} + impl Write for MultiGzDecoder { fn write(&mut self, buf: &[u8]) -> io::Result { self.get_mut().write(buf) @@ -276,3 +294,10 @@ impl Write for MultiGzDecoder { self.get_mut().flush() } } + +#[cfg(feature = "tokio")] +impl AsyncWrite for MultiGzDecoder { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.get_mut().shutdown() + } +} diff --git a/tests/async-reader.rs b/tests/async-reader.rs new file mode 100644 index 000000000..b95e4a19b --- /dev/null +++ b/tests/async-reader.rs @@ -0,0 +1,94 @@ +extern crate flate2; +extern crate tokio_io; +extern crate futures; + +use flate2::read::{GzDecoder, MultiGzDecoder}; +use std::cmp; +use std::fs::File; +use std::io::{self, Read}; +use tokio_io::AsyncRead; +use tokio_io::io::read_to_end; +use futures::prelude::*; +use futures::task; + + +struct BadReader { + reader: T, + x: bool +} + +impl BadReader { + fn new(reader: T) -> BadReader { + BadReader { reader, x: true } + } +} + +impl Read for BadReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if self.x { + self.x = false; + let len = cmp::min(buf.len(), 1); + self.reader.read(&mut buf[..len]) + } else { + self.x = true; + Err(io::ErrorKind::WouldBlock.into()) + } + } +} + +struct AssertAsync(T); + +impl Read for AssertAsync { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl AsyncRead for AssertAsync {} + +struct AlwaysNotify(T); + +impl Future for AlwaysNotify { + type Item = T::Item; + type Error = T::Error; + + fn poll(&mut self) -> Poll { + let ret = self.0.poll(); + if let Ok(Async::NotReady) = &ret { + task::current().notify(); + } + ret + } +} + +#[test] +fn test_gz_asyncread() { + let f = File::open("tests/good-file.gz").unwrap(); + + let fut = read_to_end(AssertAsync(GzDecoder::new(BadReader::new(f))), Vec::new()); + let (_, content) = AlwaysNotify(fut).wait().unwrap(); + + let mut expected = Vec::new(); + File::open("tests/good-file.txt") + .unwrap() + .read_to_end(&mut expected) + .unwrap(); + + assert_eq!(content, expected); +} + +#[test] +fn test_multi_gz_asyncread() { + let f = File::open("tests/multi.gz").unwrap(); + + let fut = read_to_end(AssertAsync(MultiGzDecoder::new(BadReader::new(f))), Vec::new()); + let (_, content) = AlwaysNotify(fut).wait().unwrap(); + + let mut expected = Vec::new(); + File::open("tests/multi.txt") + .unwrap() + .read_to_end(&mut expected) + .unwrap(); + + assert_eq!(content, expected); +}