Skip to content

Commit

Permalink
m9 - Socket Groups for HttpClient
Browse files Browse the repository at this point in the history
  • Loading branch information
ollls committed Mar 29, 2021
1 parent bbceaa7 commit a8c8fac
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 31 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ThisBuild / publishMavenStyle := true
.settings(
organization := "io.github.ollls",
name := "zio-tls-http",
version := "1.1.0-m8",
version := "1.1.0-m9",
scalaVersion := "2.13.1",
maxErrors := 3,
retrieveManaged := true,
Expand Down
21 changes: 14 additions & 7 deletions src/main/scala/clients/HttpConnection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.KeyManagerFactory
import nio.channels.AsynchronousTlsByteChannel
import nio.channels.AsynchronousSocketChannel
import nio.channels.AsynchronousChannelGroup
import java.security.KeyStore
import zhttp.Headers
import zhttp.ContentType
Expand All @@ -31,6 +32,7 @@ import java.security.cert.X509Certificate
import java.io.FileInputStream
import java.io.File


sealed case class HttpConnectionError(msg: String) extends Exception(msg)
sealed case class HttpResponseHeaderError(msg: String) extends Exception(msg)

Expand Down Expand Up @@ -149,6 +151,7 @@ object HttpConnection {
private def connectSSL(
host: String,
port: Int,
group : AsynchronousChannelGroup,
blindTrust: Boolean = false,
trustKeystore: String = null,
password: String = ""
Expand All @@ -158,7 +161,7 @@ object HttpConnection {
ssl_ctx <- if (trustKeystore == null && blindTrust == false)
effectBlocking(SSLContext.getDefault()).refineToOrDie[Exception]
else buildSSLContextM(TLS_PROTOCOL_TAG, trustKeystore, password)
ch <- AsynchronousSocketChannel()
ch <- if ( group == null ) AsynchronousSocketChannel() else AsynchronousSocketChannel( group )
_ <- ch.connect(address).mapError(e => HttpConnectionError(e.toString))
tls_ch <- AsynchronousTlsByteChannel(ch, ssl_ctx)
} yield (tls_ch)
Expand All @@ -168,27 +171,31 @@ object HttpConnection {

private def connectPlain(
host: String,
port: Int
port: Int,
group : AsynchronousChannelGroup,
): ZIO[zio.ZEnv, Exception, Channel] = {
val T = for {
address <- SocketAddress.inetSocketAddress(host, port)
ch <- AsynchronousSocketChannel()
ch <- if ( group == null ) AsynchronousSocketChannel() else AsynchronousSocketChannel( group )
_ <- ch.connect(address).mapError(e => HttpConnectionError(e.toString))
} yield (ch)

T.map(c => new TcpChannel(c))
T.map(c => new TcpChannel(c)).refineToOrDie[Exception]
}


def connect(
url: String,
socketGroup : AsynchronousChannelGroup = null,
tlsBlindTrust: Boolean = false,
trustKeystore: String = null,
password: String = ""
): ZIO[ZEnv, HttpConnectionError, HttpConnection] =
connectWithFilter(url, req => ZIO.effectTotal(req), tlsBlindTrust, trustKeystore, password)
connectWithFilter(url, socketGroup, req => ZIO.effectTotal(req), tlsBlindTrust, trustKeystore, password)

def connectWithFilter(
url: String,
socketGroup : AsynchronousChannelGroup,
filter: ClientRequest => ZIO[ZEnv with MyLogging, Throwable, ClientRequest],
tlsBlindTrust: Boolean = false,
trustKeystore: String = null,
Expand All @@ -197,11 +204,11 @@ object HttpConnection {
val u = new URI(url)
val port = if (u.getPort == -1) 443 else u.getPort
(if (u.getScheme().equalsIgnoreCase("https")) {
val ss = connectSSL(u.getHost(), port, tlsBlindTrust, trustKeystore, password)
val ss = connectSSL(u.getHost(), port, socketGroup, tlsBlindTrust, trustKeystore, password)
.map(new HttpConnection(u, _, FilterProc(filter)))
ss
} else if (u.getScheme().equalsIgnoreCase("http")) {
val ss = connectPlain(u.getHost(), port).map(new HttpConnection(u, _, FilterProc(filter)))
val ss = connectPlain(u.getHost(), port, socketGroup).map(new HttpConnection(u, _, FilterProc(filter)))
ss
} else
throw new Exception("HttpConnection: Unsupported scheme - " + u.getScheme()))
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/nio/AsynchronousChannelGroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ object AsynchronousChannelGroup {
new AsynchronousChannelGroup(JAsynchronousChannelGroup.withThreadPool(executor))
)
.refineToOrDie[Exception]

//used in Zlayer construction, before unsafeRun ZIO cycle
def make( executor: JExecutorService ) = new AsynchronousChannelGroup(JAsynchronousChannelGroup.withThreadPool(executor))
}

class AsynchronousChannelGroup(private[channels] val channelGroup: JAsynchronousChannelGroup) {
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/nio/channels/AsynchronousByteTlsChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ class AsynchronousTlsByteChannel(private val channel: AsynchronousSocketChannel,

final def keepAlive(ms: Long) = { READ_TIMEOUT_MS = ms; this }

def getSession: IO[Exception, SSLSession] =
final def getSession: IO[Exception, SSLSession] =
IO.effect(sslEngine.engine.getSession()).refineToOrDie[Exception]

def remoteAddress: ZIO[ZEnv, Exception, Option[SocketAddress]] = channel.remoteAddress
final def remoteAddress: ZIO[ZEnv, Exception, Option[SocketAddress]] = channel.remoteAddress

def readBuffer(out_b: java.nio.ByteBuffer): ZIO[ZEnv, Exception, Unit] = {
final def readBuffer(out_b: java.nio.ByteBuffer): ZIO[ZEnv, Exception, Unit] = {

val out = Buffer.byte(out_b)

Expand Down Expand Up @@ -224,7 +224,7 @@ class AsynchronousTlsByteChannel(private val channel: AsynchronousSocketChannel,
}

//////////////////////////////////////////////////////////////////////////
def read: ZIO[ZEnv, Exception, Chunk[Byte]] = {
final def read: ZIO[ZEnv, Exception, Chunk[Byte]] = {
val OUT_BUF_SZ = APP_PACKET_SZ * 2
val result = for {

Expand Down
13 changes: 3 additions & 10 deletions src/main/scala/nio/channels/AsynchronousChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,10 @@ object AsynchronousSocketChannel {
.map(new AsynchronousSocketChannel(_))
}

def apply(channelGroup: AsynchronousChannelGroup): Managed[Exception, AsynchronousSocketChannel] = {
val open = IO
.effect(
JAsynchronousSocketChannel.open(channelGroup.channelGroup)
)
.refineOrDie {
case e: Exception => e
}
def apply( channelGroup: AsynchronousChannelGroup ): ZIO[ZEnv, Exception, AsynchronousSocketChannel] = {
IO.effect(JAsynchronousSocketChannel.open( channelGroup.channelGroup ))
.refineToOrDie[Exception]
.map(new AsynchronousSocketChannel(_))

Managed.make(open)(_.close.orDie)
}

def apply(asyncSocketChannel: JAsynchronousSocketChannel): AsynchronousSocketChannel =
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/server/TLSServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class TLSServer[MyEnv <: Has[MyLogging.Service]](
_ <- ZIO.effectTotal(terminate)
//kick it one last time
c <- clients.HttpConnection
.connect(s"https://localhost:$SERVER_PORT", tlsBlindTrust = false, s"$KEYSTORE_PATH", s"$KEYSTORE_PASSWORD")
.connect(s"https://localhost:$SERVER_PORT", null, tlsBlindTrust = false, s"$KEYSTORE_PATH", s"$KEYSTORE_PASSWORD")
response <- c.send(clients.ClientRequest(zhttp.Method.GET, "/"))

svc <- MyLogging.logService
Expand Down
12 changes: 4 additions & 8 deletions src/main/scala/server/TcpServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,13 @@ class TcpServer[MyEnv <: Has[MyLogging.Service]](port: Int, keepAlive: Int = 200
"console",
"Listens TCP: " + BINDING_SERVER_IP + ":" + SERVER_PORT + ", keep alive: " + KEEP_ALIVE + " ms"
)

executor <- ZIO.runtime.map((runtime: zio.Runtime[Any]) => runtime.platform.executor.asECES)

//executor1 <- ZIO.effect ( java.util.concurrent.Executors.newFixedThreadPool(4) )
address <- SocketAddress.inetSocketAddress(BINDING_SERVER_IP, SERVER_PORT)

group <- AsynchronousChannelGroup(executor)

group <- AsynchronousChannelGroup(executor)
_ <- group.openAsynchronousServerSocketChannel().use { srv =>
{
for {

_ <- srv.bind(address)

loop = srv.accept2
Expand All @@ -54,9 +50,9 @@ class TcpServer[MyEnv <: Has[MyLogging.Service]](port: Int, keepAlive: Int = 200
MyLogging.debug("console", "Connected: " + c.get.toInetSocketAddress.address.canonicalHostName)
}) *>
ZManaged
.make(ZIO.effect(new TcpChannel(channel.keepAlive(KEEP_ALIVE))))( Channel.close( _ ).orDie )
.make(ZIO.effect(new TcpChannel(channel.keepAlive(KEEP_ALIVE))))(Channel.close(_).orDie)
.use { c =>
processor(c).catchAll( e => MyLogging.error("console", e.toString ) *> IO.succeed(0))
processor(c).catchAll(e => MyLogging.error("console", e.toString) *> IO.succeed(0))
}
.fork
)
Expand Down

0 comments on commit a8c8fac

Please sign in to comment.