diff --git a/build.sbt b/build.sbt index 411b8fedf0..e5c843cc27 100644 --- a/build.sbt +++ b/build.sbt @@ -2,7 +2,7 @@ import com.typesafe.tools.mima.core._ Global / onChangedBuildSource := ReloadOnSourceChanges -ThisBuild / tlBaseVersion := "3.6" +ThisBuild / tlBaseVersion := "3.7" ThisBuild / organization := "co.fs2" ThisBuild / organizationName := "Functional Streams for Scala" @@ -178,7 +178,14 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq( ), ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.io.package.readBytesFromInputStream"), ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.io.package.readInputStreamGeneric"), - ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.io.package.") + ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.io.package."), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("fs2.io.net.Socket.forAsync"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "fs2.io.net.SocketCompanionPlatform#AsyncSocket.this" + ), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "fs2.io.net.unixsocket.UnixSocketsCompanionPlatform#AsyncSocket.this" + ) ) lazy val root = tlCrossRootProject @@ -213,9 +220,9 @@ lazy val core = crossProject(JVMPlatform, JSPlatform, NativePlatform) libraryDependencies ++= Seq( "org.scodec" %%% "scodec-bits" % "1.1.35", "org.typelevel" %%% "cats-core" % "2.9.0", - "org.typelevel" %%% "cats-effect" % "3.4.7", - "org.typelevel" %%% "cats-effect-laws" % "3.4.7" % Test, - "org.typelevel" %%% "cats-effect-testkit" % "3.4.7" % Test, + "org.typelevel" %%% "cats-effect" % "3.5.0-RC2", + "org.typelevel" %%% "cats-effect-laws" % "3.5.0-RC2" % Test, + "org.typelevel" %%% "cats-effect-testkit" % "3.5.0-RC2" % Test, "org.typelevel" %%% "cats-laws" % "2.9.0" % Test, "org.typelevel" %%% "discipline-munit" % "2.0.0-M3" % Test, "org.typelevel" %%% "munit-cats-effect" % "2.0.0-M3" % Test, diff --git a/core/shared/src/main/scala/fs2/concurrent/Channel.scala b/core/shared/src/main/scala/fs2/concurrent/Channel.scala index 86bda25773..70b9cf7ddd 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Channel.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Channel.scala @@ -153,8 +153,8 @@ object Channel { def send(a: A) = F.deferred[Unit].flatMap { producer => - F.uncancelable { poll => - state.modify { + state.flatModifyFull { case (poll, state) => + state match { case s @ State(_, _, _, _, closed @ true) => (s, Channel.closed[Unit].pure[F]) @@ -169,12 +169,12 @@ object Channel { State(values, size, None, (a, producer) :: producers, false), notifyStream(waiting).as(rightUnit) <* waitOnBound(producer, poll) ) - }.flatten + } } } def trySend(a: A) = - state.modify { + state.flatModify { case s @ State(_, _, _, _, closed @ true) => (s, Channel.closed[Boolean].pure[F]) @@ -186,22 +186,19 @@ object Channel { ) else (s, rightFalse.pure[F]) - }.flatten + } def close = - state - .modify { - case s @ State(_, _, _, _, closed @ true) => - (s, Channel.closed[Unit].pure[F]) + state.flatModify { + case s @ State(_, _, _, _, closed @ true) => + (s, Channel.closed[Unit].pure[F]) - case State(values, size, waiting, producers, closed @ false) => - ( - State(values, size, None, producers, true), - notifyStream(waiting).as(rightUnit) <* signalClosure - ) - } - .flatten - .uncancelable + case State(values, size, waiting, producers, closed @ false) => + ( + State(values, size, None, producers, true), + notifyStream(waiting).as(rightUnit) <* signalClosure + ) + } def isClosed = closedGate.tryGet.map(_.isDefined) diff --git a/core/shared/src/main/scala/fs2/concurrent/Signal.scala b/core/shared/src/main/scala/fs2/concurrent/Signal.scala index c85940c2b4..5c4319d743 100644 --- a/core/shared/src/main/scala/fs2/concurrent/Signal.scala +++ b/core/shared/src/main/scala/fs2/concurrent/Signal.scala @@ -26,6 +26,7 @@ import cats.data.OptionT import cats.kernel.Eq import cats.effect.kernel.{Concurrent, Deferred, Ref, Resource} import cats.effect.std.MapRef +import cats.effect.syntax.all._ import cats.syntax.all._ import cats.{Applicative, Functor, Invariant, Monad} @@ -270,14 +271,16 @@ object SignallingRef { private[this] def getAndDiscreteUpdatesImpl = { def go(id: Long, lastSeen: Long): Stream[F, A] = { def getNext: F[(A, Long)] = - F.deferred[(A, Long)].flatMap { wait => - state.modify { case state @ State(value, lastUpdate, listeners) => - if (lastUpdate != lastSeen) - state -> (value -> lastUpdate).pure[F] - else - state.copy(listeners = listeners + (id -> wait)) -> wait.get - }.flatten - } + F.deferred[(A, Long)] + .flatMap { wait => + state.modify { case state @ State(value, lastUpdate, listeners) => + if (lastUpdate != lastSeen) + state -> (value -> lastUpdate).pure[F] + else + state.copy(listeners = listeners + (id -> wait)) -> wait.get + } + } + .flatten // cancelable Stream.eval(getNext).flatMap { case (a, lastUpdate) => Stream.emit(a) ++ go(id, lastSeen = lastUpdate) @@ -297,10 +300,10 @@ object SignallingRef { def update(f: A => A): F[Unit] = modify(a => (f(a), ())) def modify[B](f: A => (A, B)): F[B] = - state.modify(updateAndNotify(_, f)).flatten + state.flatModify(updateAndNotify(_, f)) def tryModify[B](f: A => (A, B)): F[Option[B]] = - state.tryModify(updateAndNotify(_, f)).flatMap(_.sequence) + state.tryModify(updateAndNotify(_, f)).flatMap(_.sequence).uncancelable def tryUpdate(f: A => A): F[Boolean] = tryModify(a => (f(a), ())).map(_.isDefined) @@ -529,23 +532,25 @@ object SignallingMapRef { private[this] def getAndDiscreteUpdatesImpl = { def go(id: Long, lastSeen: Long): Stream[F, Option[V]] = { def getNext: F[(Option[V], Long)] = - F.deferred[(Option[V], Long)].flatMap { wait => - state.modify { state => - val keyState = state.keys.get(k) - val value = keyState.flatMap(_.value) - val lastUpdate = keyState.fold(-1L)(_.lastUpdate) - val listeners = keyState.fold(LongMap.empty[Listener])(_.listeners) - - if (lastUpdate != lastSeen) - state -> (value -> lastUpdate).pure[F] - else { - val newKeys = - state.keys - .updated(k, KeyState(value, lastUpdate, listeners.updated(id, wait))) - state.copy(keys = newKeys) -> wait.get + F.deferred[(Option[V], Long)] + .flatMap { wait => + state.modify { state => + val keyState = state.keys.get(k) + val value = keyState.flatMap(_.value) + val lastUpdate = keyState.fold(-1L)(_.lastUpdate) + val listeners = keyState.fold(LongMap.empty[Listener])(_.listeners) + + if (lastUpdate != lastSeen) + state -> (value -> lastUpdate).pure[F] + else { + val newKeys = + state.keys + .updated(k, KeyState(value, lastUpdate, listeners.updated(id, wait))) + state.copy(keys = newKeys) -> wait.get + } } - }.flatten - } + } + .flatten // cancelable Stream.eval(getNext).flatMap { case (v, lastUpdate) => Stream.emit(v) ++ go(id, lastSeen = lastUpdate) @@ -580,10 +585,10 @@ object SignallingMapRef { def update(f: Option[V] => Option[V]): F[Unit] = modify(v => (f(v), ())) def modify[U](f: Option[V] => (Option[V], U)): F[U] = - state.modify(updateAndNotify(_, k, f)).flatten + state.flatModify(updateAndNotify(_, k, f)) def tryModify[U](f: Option[V] => (Option[V], U)): F[Option[U]] = - state.tryModify(updateAndNotify(_, k, f)).flatMap(_.sequence) + state.tryModify(updateAndNotify(_, k, f)).flatMap(_.sequence).uncancelable def tryUpdate(f: Option[V] => Option[V]): F[Boolean] = tryModify(a => (f(a), ())).map(_.isDefined) diff --git a/core/shared/src/main/scala/fs2/internal/ScopedResource.scala b/core/shared/src/main/scala/fs2/internal/ScopedResource.scala index 15f42c1787..3d8e3d5f21 100644 --- a/core/shared/src/main/scala/fs2/internal/ScopedResource.scala +++ b/core/shared/src/main/scala/fs2/internal/ScopedResource.scala @@ -142,7 +142,7 @@ private[internal] object ScopedResource { .flatMap(finalizer => finalizer.map(_(ec)).getOrElse(pru)) def acquired(finalizer: Resource.ExitCase => F[Unit]): F[Either[Throwable, Boolean]] = - state.modify { s => + state.flatModify { s => if (s.isFinished) // state is closed and there are no leases, finalizer has to be invoked right away s -> finalizer(Resource.ExitCase.Succeeded).as(false).attempt @@ -154,7 +154,7 @@ private[internal] object ScopedResource { Boolean ]).pure[F] } - }.flatten + } def lease: F[Option[Lease[F]]] = state.modify { s => @@ -173,14 +173,14 @@ private[internal] object ScopedResource { } .flatMap { now => if (now.isFinished) - state.modify { s => + state.flatModify { s => // Scope is closed and this is last lease, assure finalizer is removed from the state and run // previous finalizer shall be always present at this point, this shall invoke it s.copy(finalizer = None) -> (s.finalizer match { case Some(ff) => ff(Resource.ExitCase.Succeeded) case None => pru }) - }.flatten + } else pru } diff --git a/core/shared/src/test/scala/fs2/StreamInterruptSuite.scala b/core/shared/src/test/scala/fs2/StreamInterruptSuite.scala index 34adc1dbbf..fde67705b6 100644 --- a/core/shared/src/test/scala/fs2/StreamInterruptSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamInterruptSuite.scala @@ -56,96 +56,93 @@ class StreamInterruptSuite extends Fs2Suite { } } - // These IO streams cannot be interrupted on JS b/c they never yield execution - if (isJVM) { - test("3 - constant stream") { - val interruptSoon = Stream.sleep_[IO](20.millis).compile.drain.attempt - Stream - .constant(true) - .interruptWhen(interruptSoon) - .compile - .drain - .replicateA(interruptRepeatCount) - } + test("3 - constant stream") { + val interruptSoon = Stream.sleep_[IO](20.millis).compile.drain.attempt + Stream + .constant(true) + .interruptWhen(interruptSoon) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("4 - interruption of constant stream with a flatMap") { - val interrupt = - Stream.sleep_[IO](20.millis).compile.drain.attempt - Stream - .constant(true) - .interruptWhen(interrupt) - .flatMap(_ => Stream.emit(1)) - .compile - .drain - .replicateA(interruptRepeatCount) - } + test("4 - interruption of constant stream with a flatMap") { + val interrupt = + Stream.sleep_[IO](20.millis).compile.drain.attempt + Stream + .constant(true) + .interruptWhen(interrupt) + .flatMap(_ => Stream.emit(1)) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("5 - interruption of an infinitely recursive stream") { - val interrupt = - Stream.sleep_[IO](20.millis).compile.drain.attempt + test("5 - interruption of an infinitely recursive stream") { + val interrupt = + Stream.sleep_[IO](20.millis).compile.drain.attempt - def loop(i: Int): Stream[IO, Int] = - Stream.emit(i).flatMap(i => Stream.emit(i) ++ loop(i + 1)) + def loop(i: Int): Stream[IO, Int] = + Stream.emit(i).flatMap(i => Stream.emit(i) ++ loop(i + 1)) - loop(0) - .interruptWhen(interrupt) - .compile - .drain - .replicateA(interruptRepeatCount) - } + loop(0) + .interruptWhen(interrupt) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("6 - interruption of an infinitely recursive stream that never emits") { - val interrupt = - Stream.sleep_[IO](20.millis).compile.drain.attempt + test("6 - interruption of an infinitely recursive stream that never emits") { + val interrupt = + Stream.sleep_[IO](20.millis).compile.drain.attempt - def loop: Stream[IO, Nothing] = - Stream.eval(IO.unit) >> loop + def loop: Stream[IO, Nothing] = + Stream.eval(IO.unit) >> loop - loop - .interruptWhen(interrupt) - .compile - .drain - .replicateA(interruptRepeatCount) - } + loop + .interruptWhen(interrupt) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("7 - interruption of an infinitely recursive stream that never emits and has no eval") { - val interrupt = Stream.sleep_[IO](20.millis).compile.drain.attempt - def loop: Stream[IO, Int] = Stream.emit(()) >> loop - loop - .interruptWhen(interrupt) - .compile - .drain - .replicateA(interruptRepeatCount) - } + test("7 - interruption of an infinitely recursive stream that never emits and has no eval") { + val interrupt = Stream.sleep_[IO](20.millis).compile.drain.attempt + def loop: Stream[IO, Int] = Stream.emit(()) >> loop + loop + .interruptWhen(interrupt) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("8 - interruption of a stream that repeatedly evaluates") { - val interrupt = - Stream.sleep_[IO](20.millis).compile.drain.attempt - Stream - .repeatEval(IO.unit) - .interruptWhen(interrupt) - .compile - .drain - .replicateA(interruptRepeatCount) - } + test("8 - interruption of a stream that repeatedly evaluates") { + val interrupt = + Stream.sleep_[IO](20.millis).compile.drain.attempt + Stream + .repeatEval(IO.unit) + .interruptWhen(interrupt) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("9 - interruption of the constant drained stream") { - val interrupt = - Stream.sleep_[IO](1.millis).compile.drain.attempt - Stream - .constant(true) - .dropWhile(!_) - .interruptWhen(interrupt) - .compile - .drain - .replicateA(interruptRepeatCount) - } + test("9 - interruption of the constant drained stream") { + val interrupt = + Stream.sleep_[IO](1.millis).compile.drain.attempt + Stream + .constant(true) + .dropWhile(!_) + .interruptWhen(interrupt) + .compile + .drain + .replicateA(interruptRepeatCount) + } - test("10 - terminates when interruption stream is infinitely false") { - forAllF { (s: Stream[Pure, Int]) => - val allFalse = Stream.constant(false) - s.covary[IO].interruptWhen(allFalse).assertEmitsSameAs(s) - } + test("10 - terminates when interruption stream is infinitely false") { + forAllF { (s: Stream[Pure, Int]) => + val allFalse = Stream.constant(false) + s.covary[IO].interruptWhen(allFalse).assertEmitsSameAs(s) } } diff --git a/io/js/src/main/scala/fs2/io/ioplatform.scala b/io/js/src/main/scala/fs2/io/ioplatform.scala index b2eb620aef..68e92f00fd 100644 --- a/io/js/src/main/scala/fs2/io/ioplatform.scala +++ b/io/js/src/main/scala/fs2/io/ioplatform.scala @@ -161,8 +161,9 @@ private[fs2] trait ioplatform { val end = if (endAfterUse) Stream.exec { - F.async_[Unit] { cb => - writable.end(e => cb(e.toLeft(()).leftMap(js.JavaScriptException))) + F.async[Unit] { cb => + F.delay(writable.end(e => cb(e.toLeft(()).leftMap(js.JavaScriptException)))) + .as(Some(F.unit)) } } else Stream.empty diff --git a/io/jvm-native/src/main/scala/fs2/io/net/SocketGroupPlatform.scala b/io/jvm-native/src/main/scala/fs2/io/net/SocketGroupPlatform.scala index c05ddb8017..dd24aa5042 100644 --- a/io/jvm-native/src/main/scala/fs2/io/net/SocketGroupPlatform.scala +++ b/io/jvm-native/src/main/scala/fs2/io/net/SocketGroupPlatform.scala @@ -57,21 +57,25 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type => def connect(ch: AsynchronousSocketChannel): F[AsynchronousSocketChannel] = to.resolve[F].flatMap { ip => - Async[F].async_[AsynchronousSocketChannel] { cb => - ch.connect( - ip.toInetSocketAddress, - null, - new CompletionHandler[Void, Void] { - def completed(result: Void, attachment: Void): Unit = - cb(Right(ch)) - def failed(rsn: Throwable, attachment: Void): Unit = - cb(Left(rsn)) + Async[F].async[AsynchronousSocketChannel] { cb => + Async[F] + .delay { + ch.connect( + ip.toInetSocketAddress, + null, + new CompletionHandler[Void, Void] { + def completed(result: Void, attachment: Void): Unit = + cb(Right(ch)) + def failed(rsn: Throwable, attachment: Void): Unit = + cb(Left(rsn)) + } + ) } - ) + .as(Some(Async[F].delay(ch.close()))) } } - setup.flatMap(ch => Resource.eval(connect(ch))).flatMap(Socket.forAsync(_)) + setup.evalMap(ch => connect(ch) *> Socket.forAsync(ch)) } def serverResource( @@ -104,28 +108,34 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type => sch: AsynchronousServerSocketChannel ): Stream[F, Socket[F]] = { def go: Stream[F, Socket[F]] = { - def acceptChannel: F[AsynchronousSocketChannel] = - Async[F].async_[AsynchronousSocketChannel] { cb => - sch.accept( - null, - new CompletionHandler[AsynchronousSocketChannel, Void] { - def completed(ch: AsynchronousSocketChannel, attachment: Void): Unit = - cb(Right(ch)) - def failed(rsn: Throwable, attachment: Void): Unit = - cb(Left(rsn)) - } - ) + def acceptChannel = Resource.makeFull[F, AsynchronousSocketChannel] { poll => + poll { + Async[F].async[AsynchronousSocketChannel] { cb => + Async[F] + .delay { + sch.accept( + null, + new CompletionHandler[AsynchronousSocketChannel, Void] { + def completed(ch: AsynchronousSocketChannel, attachment: Void): Unit = + cb(Right(ch)) + def failed(rsn: Throwable, attachment: Void): Unit = + cb(Left(rsn)) + } + ) + } + .as(Some(Async[F].delay(sch.close()))) + } } + }(ch => Async[F].delay(if (ch.isOpen) ch.close else ())) def setOpts(ch: AsynchronousSocketChannel) = Async[F].delay { options.foreach(o => ch.setOption(o.key, o.value)) } - Stream.eval(acceptChannel.attempt).flatMap { - case Left(_) => Stream.empty[F] - case Right(accepted) => - Stream.resource(Socket.forAsync(accepted).evalTap(_ => setOpts(accepted))) + Stream.resource(acceptChannel.attempt).flatMap { + case Left(_) => Stream.empty[F] + case Right(accepted) => Stream.eval(setOpts(accepted) *> Socket.forAsync(accepted)) } ++ go } diff --git a/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala b/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala index 7274af643b..ad9d50d683 100644 --- a/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala +++ b/io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala @@ -24,8 +24,8 @@ package io package net import com.comcast.ip4s.{IpAddress, SocketAddress} -import cats.effect.{Async, Resource} -import cats.effect.std.Semaphore +import cats.effect.Async +import cats.effect.std.Mutex import cats.syntax.all._ import java.net.InetSocketAddress @@ -35,22 +35,20 @@ import java.nio.{Buffer, ByteBuffer} private[net] trait SocketCompanionPlatform { private[net] def forAsync[F[_]: Async]( ch: AsynchronousSocketChannel - ): Resource[F, Socket[F]] = - Resource.make { - (Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) => - new AsyncSocket[F](ch, readSemaphore, writeSemaphore) - } - }(_ => Async[F].delay(if (ch.isOpen) ch.close else ())) + ): F[Socket[F]] = + (Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) => + new AsyncSocket[F](ch, readMutex, writeMutex) + } private[net] abstract class BufferedReads[F[_]]( - readSemaphore: Semaphore[F] + readMutex: Mutex[F] )(implicit F: Async[F]) extends Socket[F] { private[this] final val defaultReadSize = 8192 private[this] var readBuffer: ByteBuffer = ByteBuffer.allocateDirect(defaultReadSize) private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] = - readSemaphore.permit.use { _ => + readMutex.lock.surround { F.delay { if (readBuffer.capacity() < size) readBuffer = ByteBuffer.allocateDirect(size) @@ -107,10 +105,10 @@ private[net] trait SocketCompanionPlatform { private final class AsyncSocket[F[_]]( ch: AsynchronousSocketChannel, - readSemaphore: Semaphore[F], - writeSemaphore: Semaphore[F] + readMutex: Mutex[F], + writeMutex: Mutex[F] )(implicit F: Async[F]) - extends BufferedReads[F](readSemaphore) { + extends BufferedReads[F](readMutex) { protected def readChunk(buffer: ByteBuffer): F[Int] = F.async[Int] { cb => @@ -142,8 +140,8 @@ private[net] trait SocketCompanionPlatform { go(buff) else F.unit } - writeSemaphore.permit.use { _ => - go(bytes.toByteBuffer) + writeMutex.lock.surround { + F.delay(bytes.toByteBuffer).flatMap(go) } } diff --git a/io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala b/io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala index 2c61e0a020..9be231c4ba 100644 --- a/io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala +++ b/io/jvm/src/main/scala/fs2/io/net/tls/TLSEngine.scala @@ -28,7 +28,7 @@ import javax.net.ssl.{SSLEngine, SSLEngineResult} import cats.Applicative import cats.effect.kernel.{Async, Sync} -import cats.effect.std.Semaphore +import cats.effect.std.Mutex import cats.syntax.all._ /** Provides the ability to establish and communicate over a TLS session. @@ -65,9 +65,9 @@ private[tls] object TLSEngine { engine.getSession.getPacketBufferSize, engine.getSession.getApplicationBufferSize ) - readSemaphore <- Semaphore[F](1) - writeSemaphore <- Semaphore[F](1) - handshakeSemaphore <- Semaphore[F](1) + readMutex <- Mutex[F] + writeMutex <- Mutex[F] + handshakeMutex <- Mutex[F] sslEngineTaskRunner = SSLEngineTaskRunner[F](engine) } yield new TLSEngine[F] { private val doLog: (() => String) => F[Unit] = @@ -85,7 +85,7 @@ private[tls] object TLSEngine { def stopUnwrap = Sync[F].delay(engine.closeInbound()).attempt.void def write(data: Chunk[Byte]): F[Unit] = - writeSemaphore.permit.use(_ => write0(data)) + writeMutex.lock.surround(write0(data)) private def write0(data: Chunk[Byte]): F[Unit] = wrapBuffer.input(data) >> wrap @@ -104,8 +104,8 @@ private[tls] object TLSEngine { wrapBuffer.inputRemains .flatMap(x => wrap.whenA(x > 0 && result.bytesConsumed > 0)) case _ => - handshakeSemaphore.permit - .use(_ => stepHandshake(result, true)) >> wrap + handshakeMutex.lock + .surround(stepHandshake(result, true)) >> wrap } } case SSLEngineResult.Status.BUFFER_UNDERFLOW => @@ -124,7 +124,7 @@ private[tls] object TLSEngine { } def read(maxBytes: Int): F[Option[Chunk[Byte]]] = - readSemaphore.permit.use(_ => read0(maxBytes)) + readMutex.lock.surround(read0(maxBytes)) private def initialHandshakeDone: F[Boolean] = Sync[F].delay(engine.getSession.getCipherSuite != "SSL_NULL_WITH_NULL_NULL") @@ -168,8 +168,8 @@ private[tls] object TLSEngine { case SSLEngineResult.HandshakeStatus.FINISHED => unwrap(maxBytes) case _ => - handshakeSemaphore.permit - .use(_ => stepHandshake(result, false)) >> unwrap( + handshakeMutex.lock + .surround(stepHandshake(result, false)) >> unwrap( maxBytes ) } diff --git a/io/jvm/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala b/io/jvm/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala index ddd2abbb15..d541f927c4 100644 --- a/io/jvm/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala +++ b/io/jvm/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala @@ -25,7 +25,7 @@ package net package tls import cats.Applicative -import cats.effect.std.Semaphore +import cats.effect.std.Mutex import cats.effect.kernel._ import cats.syntax.all._ @@ -53,7 +53,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => engine: TLSEngine[F] ): F[TLSSocket[F]] = for { - readSem <- Semaphore(1) + readMutex <- Mutex[F] } yield new UnsealedTLSSocket[F] { def write(bytes: Chunk[Byte]): F[Unit] = engine.write(bytes) @@ -62,7 +62,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => engine.read(maxBytes) def readN(numBytes: Int): F[Chunk[Byte]] = - readSem.permit.use { _ => + readMutex.lock.surround { def go(acc: Chunk[Byte]): F[Chunk[Byte]] = { val toRead = numBytes - acc.size if (toRead <= 0) Applicative[F].pure(acc) @@ -76,7 +76,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => } def read(maxBytes: Int): F[Option[Chunk[Byte]]] = - readSem.permit.use(_ => read0(maxBytes)) + readMutex.lock.surround(read0(maxBytes)) def reads: Stream[F, Byte] = Stream.repeatEval(read(8192)).unNoneTerminate.unchunks diff --git a/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala b/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala index 276b67a243..868d9ecb00 100644 --- a/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala +++ b/io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala @@ -22,7 +22,7 @@ package fs2.io.net.unixsocket import cats.effect.kernel.{Async, Resource} -import cats.effect.std.Semaphore +import cats.effect.std.Mutex import cats.syntax.all._ import com.comcast.ip4s.{IpAddress, SocketAddress} import fs2.{Chunk, Stream} @@ -89,17 +89,17 @@ private[unixsocket] trait UnixSocketsCompanionPlatform { ch: SocketChannel ): Resource[F, Socket[F]] = Resource.make { - (Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) => - new AsyncSocket[F](ch, readSemaphore, writeSemaphore) + (Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) => + new AsyncSocket[F](ch, readMutex, writeMutex) } }(_ => Async[F].delay(if (ch.isOpen) ch.close else ())) private final class AsyncSocket[F[_]]( ch: SocketChannel, - readSemaphore: Semaphore[F], - writeSemaphore: Semaphore[F] + readMutex: Mutex[F], + writeMutex: Mutex[F] )(implicit F: Async[F]) - extends Socket.BufferedReads[F](readSemaphore) { + extends Socket.BufferedReads[F](readMutex) { def readChunk(buff: ByteBuffer): F[Int] = F.blocking(ch.read(buff)) @@ -110,8 +110,8 @@ private[unixsocket] trait UnixSocketsCompanionPlatform { if (buff.remaining <= 0) F.unit else go(buff) } - writeSemaphore.permit.use { _ => - go(bytes.toByteBuffer) + writeMutex.lock.surround { + F.delay(bytes.toByteBuffer).flatMap(go) } } diff --git a/io/native/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala b/io/native/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala index 716e7b90c5..2aa0be334d 100644 --- a/io/native/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala +++ b/io/native/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala @@ -26,7 +26,7 @@ package tls import cats.effect.kernel.Async import cats.effect.kernel.Resource -import cats.effect.std.Semaphore +import cats.effect.std.Mutex import cats.syntax.all._ import com.comcast.ip4s.IpAddress import com.comcast.ip4s.SocketAddress @@ -49,17 +49,17 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => connection: S2nConnection[F] )(implicit F: Async[F]): F[TLSSocket[F]] = for { - readSem <- Semaphore(1) - writeSem <- Semaphore(1) + readMutex <- Mutex[F] + writeMutex <- Mutex[F] } yield new UnsealedTLSSocket[F] { def write(bytes: Chunk[Byte]): F[Unit] = - writeSem.permit.surround(connection.write(bytes)) + writeMutex.lock.surround(connection.write(bytes)) private def read0(maxBytes: Int): F[Option[Chunk[Byte]]] = connection.read(maxBytes) def readN(numBytes: Int): F[Chunk[Byte]] = - readSem.permit.use { _ => + readMutex.lock.surround { def go(acc: Chunk[Byte]): F[Chunk[Byte]] = { val toRead = numBytes - acc.size if (toRead <= 0) F.pure(acc) @@ -73,7 +73,7 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => } def read(maxBytes: Int): F[Option[Chunk[Byte]]] = - readSem.permit.surround(read0(maxBytes)) + readMutex.lock.surround(read0(maxBytes)) def reads: Stream[F, Byte] = Stream.repeatEval(read(8192)).unNoneTerminate.unchunks