Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Land PostgresClient that is backed by a ConnectionPool as SPI #430

Merged
merged 7 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ extension PoolStateMachine {
}

@inlinable
mutating func parkConnection(at index: Int) -> Max2Sequence<ConnectionTimer> {
mutating func parkConnection(at index: Int, hasBecomeIdle newIdle: Bool) -> Max2Sequence<ConnectionTimer> {
let scheduleIdleTimeoutTimer: Bool
switch index {
case 0..<self.minimumConcurrentConnections:
Expand All @@ -318,7 +318,7 @@ extension PoolStateMachine {

case self.minimumConcurrentConnections..<self.maximumConcurrentConnectionSoftLimit:
// if a connection is a demand connection, we want a timeout timer
scheduleIdleTimeoutTimer = true
scheduleIdleTimeoutTimer = newIdle

case self.maximumConcurrentConnectionSoftLimit..<self.maximumConcurrentConnectionHardLimit:
preconditionFailure("Overflow connections should never be parked.")
Expand Down Expand Up @@ -626,8 +626,11 @@ extension PoolStateMachine {

case self.minimumConcurrentConnections..<self.maximumConcurrentConnectionSoftLimit:
// the connection to be removed is a demand connection
self.connections.swapAt(indexToDelete, lastConnectedIndex)
self.removeO1(lastConnectedIndex)

switch lastConnectedIndex {
case self.minimumConcurrentConnections..<self.maximumConcurrentConnectionSoftLimit:
case self.maximumConcurrentConnectionSoftLimit..<self.maximumConcurrentConnectionHardLimit:
// an overflow connection was moved to a demand connection. It has to be currently leased
precondition(self.connections[indexToDelete].isLeased)
return nil
Expand Down
4 changes: 2 additions & 2 deletions Sources/ConnectionPoolModule/PoolStateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ struct PoolStateMachine<
case .leased:
return .none()

case .idle:
let timers = self.connections.parkConnection(at: index).map(self.mapTimers)
case .idle(_, let newIdle):
let timers = self.connections.parkConnection(at: index, hasBecomeIdle: newIdle).map(self.mapTimers)

return .init(
request: .none,
Expand Down
4 changes: 2 additions & 2 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public final class PostgresConnection: @unchecked Sendable {
return !self.channel.isActive
}

let id: ID
public let id: ID

private var _logger: Logger

Expand Down Expand Up @@ -391,7 +391,7 @@ extension PostgresConnection {
self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise)
return try await promise.futureResult.get()
} onCancel: {
_ = self.close()
self.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ extension ConnectionStateMachine {
}

return false
case .clientClosedConnection:
case .clientClosedConnection, .poolClosed:
preconditionFailure("A pure client error was thrown directly in PostgresConnection, this shouldn't happen")
case .serverClosedConnection:
return true
Expand Down
16 changes: 13 additions & 3 deletions Sources/PostgresNIO/New/PSQLError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public struct PSQLError: Error {

case listenFailed
case unlistenFailed
case poolClosed
}

internal var base: Base
Expand All @@ -33,22 +34,25 @@ public struct PSQLError: Error {
self.base = base
}

public static let sslUnsupported = Self.init(.sslUnsupported)
public static let sslUnsupported = Self(.sslUnsupported)
public static let failedToAddSSLHandler = Self(.failedToAddSSLHandler)
public static let receivedUnencryptedDataAfterSSLRequest = Self(.receivedUnencryptedDataAfterSSLRequest)
public static let server = Self(.server)
public static let messageDecodingFailure = Self(.messageDecodingFailure)
public static let unexpectedBackendMessage = Self(.unexpectedBackendMessage)
public static let unsupportedAuthMechanism = Self(.unsupportedAuthMechanism)
public static let authMechanismRequiresPassword = Self(.authMechanismRequiresPassword)
public static let saslError = Self.init(.saslError)
public static let saslError = Self(.saslError)
public static let invalidCommandTag = Self(.invalidCommandTag)
public static let queryCancelled = Self(.queryCancelled)
public static let tooManyParameters = Self(.tooManyParameters)
public static let clientClosedConnection = Self(.clientClosedConnection)
public static let serverClosedConnection = Self(.serverClosedConnection)
public static let connectionError = Self(.connectionError)
public static let uncleanShutdown = Self.init(.uncleanShutdown)

public static let uncleanShutdown = Self(.uncleanShutdown)
public static let poolClosed = Self(.poolClosed)

public static let listenFailed = Self.init(.listenFailed)
public static let unlistenFailed = Self.init(.unlistenFailed)

Expand Down Expand Up @@ -92,6 +96,8 @@ public struct PSQLError: Error {
return "connectionError"
case .uncleanShutdown:
return "uncleanShutdown"
case .poolClosed:
return "poolClosed"
case .listenFailed:
return "listenFailed"
case .unlistenFailed:
Expand Down Expand Up @@ -457,6 +463,10 @@ public struct PSQLError: Error {
case sspi
case sasl(mechanisms: [String])
}

static var poolClosed: PSQLError {
Self.init(code: .poolClosed)
}
}

extension PSQLError: CustomStringConvertible {
Expand Down
206 changes: 206 additions & 0 deletions Sources/PostgresNIO/Pool/ConnectionFactory.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import Logging
import NIOConcurrencyHelpers
import NIOCore
import NIOSSL

@available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *)
final class ConnectionFactory: Sendable {

struct ConfigCache: Sendable {
var config: PostgresClient.Configuration
}

let configBox: NIOLockedValueBox<ConfigCache>

struct SSLContextCache: Sendable {
enum State {
case none
case producing(TLSConfiguration, [CheckedContinuation<NIOSSLContext, any Error>])
case cached(TLSConfiguration, NIOSSLContext)
case failed(TLSConfiguration, any Error)
}

var state: State = .none
}

let sslContextBox = NIOLockedValueBox(SSLContextCache())

let eventLoopGroup: any EventLoopGroup

let logger: Logger

init(config: PostgresClient.Configuration, eventLoopGroup: any EventLoopGroup, logger: Logger) {
self.eventLoopGroup = eventLoopGroup
self.configBox = NIOLockedValueBox(ConfigCache(config: config))
self.logger = logger
}

func makeConnection(_ connectionID: PostgresConnection.ID, pool: PostgresClient.Pool) async throws -> PostgresConnection {
let config = try await self.makeConnectionConfig()

var connectionLogger = self.logger
connectionLogger[postgresMetadataKey: .connectionID] = "\(connectionID)"

return try await PostgresConnection.connect(
on: self.eventLoopGroup.any(),
configuration: config,
id: connectionID,
logger: connectionLogger
).get()
}

func makeConnectionConfig() async throws -> PostgresConnection.Configuration {
let config = self.configBox.withLockedValue { $0.config }

let tls: PostgresConnection.Configuration.TLS
switch config.tls.base {
case .prefer(let tlsConfiguration):
let sslContext = try await self.getSSLContext(for: tlsConfiguration)
tls = .prefer(sslContext)

case .require(let tlsConfiguration):
let sslContext = try await self.getSSLContext(for: tlsConfiguration)
tls = .require(sslContext)
case .disable:
tls = .disable
}

var connectionConfig: PostgresConnection.Configuration
switch config.endpointInfo {
case .bindUnixDomainSocket(let path):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Come to think of it, how about the "existing Channel" thing too?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we want to use a single Channel for multiple connections. We will need to come up with a custom factory. That is a valuable goal. But not today.

connectionConfig = PostgresConnection.Configuration(
unixSocketPath: path,
username: config.username,
password: config.password,
database: config.database
)

case .connectTCP(let host, let port):
connectionConfig = PostgresConnection.Configuration(
host: host,
port: port,
username: config.username,
password: config.password,
database: config.database,
tls: tls
)
}

connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout)
connectionConfig.options.tlsServerName = config.options.tlsServerName
connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData

return connectionConfig
}

private func getSSLContext(for tlsConfiguration: TLSConfiguration) async throws -> NIOSSLContext {
enum Action {
case produce
case succeed(NIOSSLContext)
case fail(any Error)
case wait
}

return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<NIOSSLContext, any Error>) in
let action = self.sslContextBox.withLockedValue { cache -> Action in
switch cache.state {
case .none:
cache.state = .producing(tlsConfiguration, [continuation])
return .produce

case .cached(let cachedTLSConfiguration, let context):
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
return .succeed(context)
} else {
cache.state = .producing(tlsConfiguration, [continuation])
return .produce
}

case .failed(let cachedTLSConfiguration, let error):
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
return .fail(error)
} else {
cache.state = .producing(tlsConfiguration, [continuation])
return .produce
}

case .producing(let cachedTLSConfiguration, var continuations):
continuations.append(continuation)
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
cache.state = .producing(cachedTLSConfiguration, continuations)
return .wait
} else {
cache.state = .producing(tlsConfiguration, continuations)
return .produce
}
}
}

switch action {
case .wait:
break

case .produce:
// TBD: we might want to consider moving this off the concurrent executor
self.reportProduceSSLContextResult(
Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}),
for: tlsConfiguration
)

case .succeed(let context):
continuation.resume(returning: context)

case .fail(let error):
continuation.resume(throwing: error)
}
}
}

private func reportProduceSSLContextResult(_ result: Result<NIOSSLContext, any Error>, for tlsConfiguration: TLSConfiguration) {
enum Action {
case fail(any Error, [CheckedContinuation<NIOSSLContext, any Error>])
case succeed(NIOSSLContext, [CheckedContinuation<NIOSSLContext, any Error>])
case none
}

let action = self.sslContextBox.withLockedValue { cache -> Action in
switch cache.state {
case .none:
preconditionFailure("Invalid state: \(cache.state)")

case .cached, .failed:
return .none

case .producing(let cachedTLSConfiguration, let continuations):
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
switch result {
case .success(let context):
cache.state = .cached(cachedTLSConfiguration, context)
return .succeed(context, continuations)

case .failure(let failure):
cache.state = .failed(cachedTLSConfiguration, failure)
return .fail(failure, continuations)
}
} else {
return .none
}
}
}

switch action {
case .none:
break

case .succeed(let context, let continuations):
for continuation in continuations {
continuation.resume(returning: context)
}

case .fail(let error, let continuations):
for continuation in continuations {
continuation.resume(throwing: error)
}
}
}
}
Loading