diff --git a/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala b/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala index 257a435fea..f10eb46aa5 100644 --- a/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala +++ b/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala @@ -24,7 +24,7 @@ package io package net import com.comcast.ip4s.{IpAddress, SocketAddress} -import cats.effect.Async +import cats.effect.{Async, Resource} import cats.effect.std.Mutex import cats.syntax.all._ @@ -33,55 +33,69 @@ import java.nio.channels.{AsynchronousSocketChannel, CompletionHandler} import java.nio.{Buffer, ByteBuffer} private[net] trait SocketCompanionPlatform { + + /** Creates a [[Socket]] instance for given `AsynchronousSocketChannel` + * with 16 KiB max. read chunk size and exclusive access guards for both reads abd writes. + */ private[net] def forAsync[F[_]: Async]( ch: AsynchronousSocketChannel ): F[Socket[F]] = - (Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) => - new AsyncSocket[F](ch, readMutex, writeMutex) + forAsync(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true) + + /** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`. + * + * @param ch async socket channel for actual reads and writes + * @param maxReadChunkSize maximum chunk size for [[Socket#reads]] method + * @param withExclusiveReads set to `true` if reads should be guarded by mutex + * @param withExclusiveWrites set to `true` if writes should be guarded by mutex + */ + private[net] def forAsync[F[_]]( + ch: AsynchronousSocketChannel, + maxReadChunkSize: Int, + withExclusiveReads: Boolean = false, + withExclusiveWrites: Boolean = false + )(implicit F: Async[F]): F[Socket[F]] = { + def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None)) + (maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN { + (readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize) } + } private[net] abstract class BufferedReads[F[_]]( - readMutex: Mutex[F] + readMutex: Option[Mutex[F]], + writeMutex: Option[Mutex[F]], + maxReadChunkSize: Int )(implicit F: Async[F]) extends Socket[F] { - private[this] final val defaultReadSize = 8192 - private[this] var readBuffer: ByteBuffer = ByteBuffer.allocate(defaultReadSize) + private def lock(mutex: Option[Mutex[F]]): Resource[F, Unit] = + mutex match { + case Some(mutex) => mutex.lock + case None => Resource.unit + } private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] = - readMutex.lock.surround { - F.delay { - if (readBuffer.capacity() < size) - readBuffer = ByteBuffer.allocate(size) - else - (readBuffer: Buffer).limit(size) - f(readBuffer) - }.flatten + lock(readMutex).surround { + F.delay(ByteBuffer.allocate(size)).flatMap(f) } /** Performs a single channel read operation in to the supplied buffer. */ protected def readChunk(buffer: ByteBuffer): F[Int] - /** Copies the contents of the supplied buffer to a `Chunk[Byte]` and clears the buffer contents. */ - private def releaseBuffer(buffer: ByteBuffer): F[Chunk[Byte]] = - F.delay { - val read = buffer.position() - val result = - if (read == 0) Chunk.empty - else { - val dest = new Array[Byte](read) - (buffer: Buffer).flip() - buffer.get(dest) - Chunk.array(dest) - } - (buffer: Buffer).clear() - result - } + /** Performs a channel write operation(-s) from the supplied buffer. + * + * Write could be performed multiple times till all buffer remaining contents are written. + */ + protected def writeChunk(buffer: ByteBuffer): F[Unit] def read(max: Int): F[Option[Chunk[Byte]]] = withReadBuffer(max) { buffer => - readChunk(buffer).flatMap { read => - if (read < 0) F.pure(None) - else releaseBuffer(buffer).map(Some(_)) + readChunk(buffer).map { read => + if (read < 0) None + else if (buffer.position() == 0) Some(Chunk.empty) + else { + (buffer: Buffer).flip() + Some(Chunk.byteBuffer(buffer.asReadOnlyBuffer())) + } } } @@ -89,26 +103,54 @@ private[net] trait SocketCompanionPlatform { withReadBuffer(max) { buffer => def go: F[Chunk[Byte]] = readChunk(buffer).flatMap { readBytes => - if (readBytes < 0 || buffer.position() >= max) - releaseBuffer(buffer) - else go + if (readBytes < 0 || buffer.position() >= max) { + (buffer: Buffer).flip() + F.pure(Chunk.byteBuffer(buffer.asReadOnlyBuffer())) + } else go } go } def reads: Stream[F, Byte] = - Stream.repeatEval(read(defaultReadSize)).unNoneTerminate.unchunks + Stream.resource(lock(readMutex)).flatMap { _ => + Stream.unfoldChunkEval(ByteBuffer.allocate(maxReadChunkSize)) { case buffer => + readChunk(buffer).flatMap { read => + if (read < 0) none[(Chunk[Byte], ByteBuffer)].pure + else if (buffer.position() == 0) { + (Chunk.empty[Byte] -> buffer).some.pure + } else if (buffer.remaining() == 0) { + val bytes = buffer.asReadOnlyBuffer() + val fresh = ByteBuffer.allocate(maxReadChunkSize) + (Chunk.byteBuffer(bytes) -> fresh).some.pure + } else { + val bytes = buffer.duplicate().asReadOnlyBuffer() + val slice = buffer.slice() + (bytes: Buffer).flip() + (Chunk.byteBuffer(bytes) -> slice).some.pure + } + } + } + } + + def write(bytes: Chunk[Byte]): F[Unit] = + lock(writeMutex).surround { + F.delay(bytes.toByteBuffer).flatMap(writeChunk) + } - def writes: Pipe[F, Byte, Nothing] = - _.chunks.foreach(write) + def writes: Pipe[F, Byte, Nothing] = { in => + Stream.resource(lock(writeMutex)).flatMap { _ => + in.chunks.foreach(chunk => writeChunk(chunk.toByteBuffer)) + } + } } private final class AsyncSocket[F[_]]( ch: AsynchronousSocketChannel, - readMutex: Mutex[F], - writeMutex: Mutex[F] + readMutex: Option[Mutex[F]], + writeMutex: Option[Mutex[F]], + maxReadChunkSize: Int )(implicit F: Async[F]) - extends BufferedReads[F](readMutex) { + extends BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) { protected def readChunk(buffer: ByteBuffer): F[Int] = F.async[Int] { cb => @@ -120,24 +162,18 @@ private[net] trait SocketCompanionPlatform { F.delay(Some(endOfInput.voidError)) } - def write(bytes: Chunk[Byte]): F[Unit] = { - def go(buff: ByteBuffer): F[Unit] = - F.async[Int] { cb => - ch.write( - buff, - null, - new IntCompletionHandler(cb) - ) - F.delay(Some(endOfOutput.voidError)) - }.flatMap { written => - if (written >= 0 && buff.remaining() > 0) - go(buff) - else F.unit - } - writeMutex.lock.surround { - F.delay(bytes.toByteBuffer).flatMap(go) + protected def writeChunk(buffer: ByteBuffer): F[Unit] = + F.async[Int] { cb => + ch.write( + buffer, + null, + new IntCompletionHandler(cb) + ) + F.delay(Some(endOfOutput.voidError)) + }.flatMap { written => + if (written < 0 || buffer.remaining() == 0) F.unit + else writeChunk(buffer) } - } def localAddress: F[SocketAddress[IpAddress]] = F.delay( diff --git a/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala b/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala index fa4bbef9fc..085bda00c8 100644 --- a/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala +++ b/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala @@ -28,7 +28,7 @@ import cats.effect.std.Mutex import cats.effect.syntax.all._ import cats.syntax.all._ import com.comcast.ip4s.{IpAddress, SocketAddress} -import fs2.{Chunk, Stream} +import fs2.Stream import fs2.io.file.{Files, Path} import fs2.io.net.Socket import java.nio.ByteBuffer @@ -89,29 +89,36 @@ private[unixsocket] trait UnixSocketsCompanionPlatform { private def makeSocket[F[_]: Async]( ch: SocketChannel ): F[Socket[F]] = - (Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) => - new AsyncSocket[F](ch, readMutex, writeMutex) + makeSocket(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true) + + private def makeSocket[F[_]]( + ch: SocketChannel, + maxReadChunkSize: Int, + withExclusiveReads: Boolean, + withExclusiveWrites: Boolean + )(implicit F: Async[F]): F[Socket[F]] = { + def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None)) + (maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN { + (readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize) } + } private final class AsyncSocket[F[_]]( ch: SocketChannel, - readMutex: Mutex[F], - writeMutex: Mutex[F] + readMutex: Option[Mutex[F]], + writeMutex: Option[Mutex[F]], + maxReadChunkSize: Int )(implicit F: Async[F]) - extends Socket.BufferedReads[F](readMutex) { + extends Socket.BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) { - def readChunk(buff: ByteBuffer): F[Int] = - F.blocking(ch.read(buff)).cancelable(close) + protected def readChunk(buffer: ByteBuffer): F[Int] = + F.blocking(ch.read(buffer)).cancelable(close) - def write(bytes: Chunk[Byte]): F[Unit] = { - def go(buff: ByteBuffer): F[Unit] = - F.blocking(ch.write(buff)).cancelable(close) *> - F.delay(buff.remaining <= 0).ifM(F.unit, go(buff)) - - writeMutex.lock.surround { - F.delay(bytes.toByteBuffer).flatMap(go) + protected def writeChunk(buffer: ByteBuffer): F[Unit] = + F.blocking(ch.write(buffer)).cancelable(close).flatMap { _ => + if (buffer.remaining() == 0) F.unit + else writeChunk(buffer) } - } def localAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError def remoteAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError