Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid allocations in TLSEngine when logging is disabled #2462

Merged
merged 6 commits into from
Jul 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,16 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[MissingClassProblem]("fs2.Pull$CloseScope$"),
ProblemFilters.exclude[ReversedAbstractMethodProblem]("fs2.Pull#CloseScope.*"),
ProblemFilters.exclude[Problem]("fs2.io.Watcher#Registration.*"),
ProblemFilters.exclude[Problem]("fs2.io.Watcher#DefaultWatcher.*")
ProblemFilters.exclude[Problem]("fs2.io.Watcher#DefaultWatcher.*"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.io.net.tls.TLSContext.clientBuilder"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.io.net.tls.TLSContext.serverBuilder"),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"fs2.io.net.tls.TLSContext.dtlsClientBuilder"
),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"fs2.io.net.tls.TLSContext.dtlsServerBuilder"
),
ProblemFilters.exclude[Problem]("fs2.io.net.tls.TLSEngine*")
)

lazy val root = project
Expand Down
3 changes: 1 addition & 2 deletions core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import scala.concurrent.duration._
import scala.concurrent.TimeoutException
import cats.effect.{IO, SyncIO}
import cats.effect.kernel.Ref
import cats.effect.std.Queue
import cats.effect.std.Semaphore
import cats.syntax.all._
import org.scalacheck.Gen
Expand Down Expand Up @@ -1316,7 +1315,7 @@ class StreamCombinatorsSuite extends Fs2Suite {
val action =
Vector.fill(streamSize)(Deferred[IO, Unit]).sequence.map { seenArr =>
def peek(ind: Int)(f: Option[Unit] => Boolean) =
seenArr.get(ind).fold(true.pure[IO])(_.tryGet.map(f))
seenArr.get(ind.toLong).fold(true.pure[IO])(_.tryGet.map(f))

Stream
.emits(0 until streamSize)
Expand Down
187 changes: 118 additions & 69 deletions io/src/main/scala/fs2/io/net/tls/TLSContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ import javax.net.ssl.{
SSLContext,
SSLEngine,
TrustManagerFactory,
X509ExtendedTrustManager,
X509TrustManager
X509ExtendedTrustManager
}
import cats.Applicative
import cats.effect.kernel.{Async, Resource}
Expand All @@ -47,43 +46,85 @@ import java.util.function.BiFunction
*/
sealed trait TLSContext[F[_]] {

/** Creates a `TLSSocket` in client mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
/** Creates a `TLSSocket` builder in client mode. */
def client(socket: Socket[F]): Resource[F, TLSSocket[F]] =
clientBuilder(socket).build

/** Creates a `TLSSocket` builder in client mode, allowing optional parameters to be configured. */
def clientBuilder(socket: Socket[F]): TLSContext.SocketBuilder[F, TLSSocket]

@deprecated("Use client(socket) or clientBuilder(socket).with(...).build", "3.0.6")
def client(
socket: Socket[F],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, TLSSocket[F]]
): Resource[F, TLSSocket[F]] =
clientBuilder(socket).withParameters(params).withOldLogging(logger).build

/** Creates a `TLSSocket` builder in server mode. */
def server(socket: Socket[F]): Resource[F, TLSSocket[F]] =
serverBuilder(socket).build

/** Creates a `TLSSocket` in server mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
/** Creates a `TLSSocket` builder in server mode, allowing optional parameters to be configured. */
def serverBuilder(socket: Socket[F]): TLSContext.SocketBuilder[F, TLSSocket]

@deprecated("Use server(socket) or serverBuilder(socket).with(...).build", "3.0.6")
def server(
socket: Socket[F],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, TLSSocket[F]]
): Resource[F, TLSSocket[F]] =
serverBuilder(socket).withParameters(params).withOldLogging(logger).build

/** Creates a `DTLSSocket` builder in client mode. */
def dtlsClient(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): Resource[F, DTLSSocket[F]] =
dtlsClientBuilder(socket, remoteAddress).build

/** Creates a `DTLSSocket` builder in client mode, allowing optional parameters to be configured. */
def dtlsClientBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): TLSContext.SocketBuilder[F, DTLSSocket]

/** Creates a `DTLSSocket` in client mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
@deprecated(
"Use dtlsClient(socket, remoteAddress) or dtlsClientBuilder(socket, remoteAddress).with(...).build",
"3.0.6"
)
def dtlsClient(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, DTLSSocket[F]]
): Resource[F, DTLSSocket[F]] =
dtlsClientBuilder(socket, remoteAddress).withParameters(params).withOldLogging(logger).build

/** Creates a `DTLSSocket` in server mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
/** Creates a `DTLSSocket` builder in server mode. */
def dtlsServer(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): Resource[F, DTLSSocket[F]] =
dtlsServerBuilder(socket, remoteAddress).build

/** Creates a `DTLSSocket` builder in client mode, allowing optional parameters to be configured. */
def dtlsServerBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): TLSContext.SocketBuilder[F, DTLSSocket]

@deprecated(
"Use dtlsServer(socket, remoteAddress) or dtlsClientBuilder(socket, remoteAddress).with(...).build",
"3.0.6"
)
def dtlsServer(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, DTLSSocket[F]]
): Resource[F, DTLSSocket[F]] =
dtlsServerBuilder(socket, remoteAddress).withParameters(params).withOldLogging(logger).build
}

object TLSContext {
Expand Down Expand Up @@ -128,35 +169,17 @@ object TLSContext {
ctx: SSLContext
): TLSContext[F] =
new TLSContext[F] {
def client(
socket: Socket[F],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, TLSSocket[F]] =
mkSocket(
socket,
true,
params,
logger
)

def server(
socket: Socket[F],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, TLSSocket[F]] =
mkSocket(
socket,
false,
params,
logger
)
def clientBuilder(socket: Socket[F]) =
SocketBuilder((p, l) => mkSocket(socket, true, p, l))

def serverBuilder(socket: Socket[F]) =
SocketBuilder((p, l) => mkSocket(socket, false, p, l))

private def mkSocket(
socket: Socket[F],
clientMode: Boolean,
params: TLSParameters,
logger: Option[String => F[Unit]]
logger: TLSLogger[F]
): Resource[F, TLSSocket[F]] =
Resource
.eval(
Expand All @@ -174,40 +197,24 @@ object TLSContext {
)
.flatMap(engine => TLSSocket(socket, engine))

def dtlsClient(
def dtlsClientBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, DTLSSocket[F]] =
mkDtlsSocket(
socket,
remoteAddress,
true,
params,
logger
)

def dtlsServer(
remoteAddress: SocketAddress[IpAddress]
) =
SocketBuilder((p, l) => mkDtlsSocket(socket, remoteAddress, true, p, l))

def dtlsServerBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, DTLSSocket[F]] =
mkDtlsSocket(
socket,
remoteAddress,
false,
params,
logger
)
remoteAddress: SocketAddress[IpAddress]
) =
SocketBuilder((p, l) => mkDtlsSocket(socket, remoteAddress, false, p, l))

private def mkDtlsSocket(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
clientMode: Boolean,
params: TLSParameters,
logger: Option[String => F[Unit]]
logger: TLSLogger[F]
): Resource[F, DTLSSocket[F]] =
Resource
.eval(
Expand All @@ -230,7 +237,7 @@ object TLSContext {
binding: TLSEngine.Binding[F],
clientMode: Boolean,
params: TLSParameters,
logger: Option[String => F[Unit]]
logger: TLSLogger[F]
): F[TLSEngine[F]] = {
val sslEngine = Async[F].blocking {
val engine = ctx.createSSLEngine()
Expand Down Expand Up @@ -345,4 +352,46 @@ object TLSContext {
.map(fromSSLContext(_))
}
}

sealed trait SocketBuilder[F[_], S[_[_]]] {
def withParameters(params: TLSParameters): SocketBuilder[F, S]
def withLogging(log: (=> String) => F[Unit]): SocketBuilder[F, S]
def withoutLogging: SocketBuilder[F, S]
def withLogger(logger: TLSLogger[F]): SocketBuilder[F, S]
private[TLSContext] def withOldLogging(log: Option[String => F[Unit]]): SocketBuilder[F, S]
def build: Resource[F, S[F]]
}

object SocketBuilder {
private[TLSContext] type Build[F[_], S[_[_]]] =
(TLSParameters, TLSLogger[F]) => Resource[F, S[F]]

private[TLSContext] def apply[F[_], S[_[_]]](
mkSocket: Build[F, S]
): SocketBuilder[F, S] =
instance(mkSocket, TLSParameters.Default, TLSLogger.Disabled)

private def instance[F[_], S[_[_]]](
mkSocket: Build[F, S],
params: TLSParameters,
logger: TLSLogger[F]
): SocketBuilder[F, S] =
new SocketBuilder[F, S] {
def withParameters(params: TLSParameters): SocketBuilder[F, S] =
instance(mkSocket, params, logger)
def withLogging(log: (=> String) => F[Unit]): SocketBuilder[F, S] =
withLogger(TLSLogger.Enabled(log))
def withoutLogging: SocketBuilder[F, S] =
withLogger(TLSLogger.Disabled)
def withLogger(logger: TLSLogger[F]): SocketBuilder[F, S] =
instance(mkSocket, params, logger)
private[TLSContext] def withOldLogging(
log: Option[String => F[Unit]]
): SocketBuilder[F, S] =
log.map(f => withLogging(m => f(m))).getOrElse(withoutLogging)
def build: Resource[F, S[F]] =
mkSocket(params, logger)
}
}

}
11 changes: 8 additions & 3 deletions io/src/main/scala/fs2/io/net/tls/TLSEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ private[tls] object TLSEngine {
def apply[F[_]: Async](
engine: SSLEngine,
binding: Binding[F],
logger: Option[String => F[Unit]] = None
logger: TLSLogger[F]
): F[TLSEngine[F]] =
for {
wrapBuffer <- InputOutputBuffer[F](
Expand All @@ -70,8 +70,13 @@ private[tls] object TLSEngine {
handshakeSemaphore <- Semaphore[F](1)
sslEngineTaskRunner = SSLEngineTaskRunner[F](engine)
} yield new TLSEngine[F] {
private def log(msg: String): F[Unit] =
logger.map(_(msg)).getOrElse(Applicative[F].unit)
private val doLog: (() => String) => F[Unit] =
logger match {
case e: TLSLogger.Enabled[_] => msg => e.log(msg())
case TLSLogger.Disabled => _ => Applicative[F].unit
}

private def log(msg: => String): F[Unit] = doLog(() => msg)

def beginHandshake = Sync[F].delay(engine.beginHandshake())
def session = Sync[F].delay(engine.getSession())
Expand Down
29 changes: 29 additions & 0 deletions io/src/main/scala/fs2/io/net/tls/TLSLogger.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2013 Functional Streams for Scala
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.io.net.tls

sealed trait TLSLogger[+F[_]]

object TLSLogger {
case object Disabled extends TLSLogger[Nothing]
case class Enabled[F[_]](log: (=> String) => F[Unit]) extends TLSLogger[F]
}
10 changes: 8 additions & 2 deletions io/src/test/scala/fs2/io/net/tls/DTLSSocketSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,14 @@ class DTLSSocketSuite extends TLSSuite {
serverAddress <- address(serverSocket)
clientSocket <- Network[IO].openDatagramSocket()
clientAddress <- address(clientSocket)
tlsServerSocket <- tlsContext.dtlsServer(serverSocket, clientAddress, logger = logger)
tlsClientSocket <- tlsContext.dtlsClient(clientSocket, serverAddress, logger = logger)
tlsServerSocket <- tlsContext
.dtlsServerBuilder(serverSocket, clientAddress)
.withLogger(logger)
.build
tlsClientSocket <- tlsContext
.dtlsClientBuilder(clientSocket, serverAddress)
.withLogger(logger)
.build
} yield (tlsServerSocket, tlsClientSocket, serverAddress)

Stream
Expand Down
5 changes: 3 additions & 2 deletions io/src/test/scala/fs2/io/net/tls/TLSDebugExample.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ object TLSDebug {
host.resolve.flatMap { socketAddress =>
Network[F].client(socketAddress).use { rawSocket =>
tlsContext
.client(
rawSocket,
.clientBuilder(rawSocket)
.withParameters(
TLSParameters(serverNames = Some(List(new SNIHostName(host.host.toString))))
)
.build
.use { tlsSocket =>
tlsSocket.write(Chunk.empty) >>
tlsSocket.session.map { session =>
Expand Down
Loading