Skip to content

Commit

Permalink
Fix a crash on unknown user inbound events (#46)
Browse files Browse the repository at this point in the history
Also adds a lot of test utility :)
  • Loading branch information
lovetodream authored Jun 29, 2024
1 parent b7c9323 commit 619336f
Show file tree
Hide file tree
Showing 14 changed files with 666 additions and 45 deletions.
67 changes: 61 additions & 6 deletions Sources/OracleNIO/Connection/OracleConnection+Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,43 @@ extension OracleConnection {
public var options: Options = .init()

/// The name or IP address of the machine hosting the database or the database listener.
public var host: String
public var host: String {
get {
switch self.endpointInfo {
case .configureChannel(let channel):
channel.localAddress?.ipAddress ?? ""
case .connectTCP(let host, _):
host
}
}
set {
switch self.endpointInfo {
case .configureChannel:
break // not part of public api
case .connectTCP(_, let port):
self.endpointInfo = .connectTCP(host: newValue, port: port)
}
}
}
/// The port number on which the database listener is listening.
public var port: Int
public var port: Int {
get {
switch self.endpointInfo {
case .configureChannel(let channel):
channel.localAddress?.port ?? 1521
case .connectTCP(_, let port):
port
}
}
set {
switch self.endpointInfo {
case .configureChannel:
break // not part of public api
case .connectTCP(let host, _):
self.endpointInfo = .connectTCP(host: host, port: newValue)
}
}
}

public var tls: TLS
public var serverNameForTLS: String? {
Expand Down Expand Up @@ -253,8 +287,10 @@ extension OracleConnection {
password: String,
tls: TLS = .disable
) {
self.host = host
self.port = port
self.endpointInfo = .connectTCP(
host: host,
port: port
)
self.service = service
self.authenticationMethod = {
.init(username: username, password: password)
Expand All @@ -270,20 +306,39 @@ extension OracleConnection {
@Sendable @autoclosure @escaping () -> OracleAuthenticationMethod,
tls: TLS = .disable
) {
self.host = host
self.port = port
self.endpointInfo = .connectTCP(
host: host,
port: port
)
self.service = service
self.authenticationMethod = authenticationMethod
self.tls = tls
}

init(
establishedChannel channel: Channel,
service: OracleServiceMethod,
username: String,
password: String
) {
self.endpointInfo = .configureChannel(channel)
self.service = service
self.authenticationMethod = {
.init(username: username, password: password)
}
self.tls = .disable
}


// MARK: - Implementation details

enum EndpointInfo {
case configureChannel(Channel)
case connectTCP(host: String, port: Int)
}

var endpointInfo: EndpointInfo

internal func getDescription() -> Description {
let address = Address(
protocol: self._protocol, host: self.host, port: self.port
Expand Down
56 changes: 36 additions & 20 deletions Sources/OracleNIO/Connection/OracleConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,36 @@ public final class OracleConnection: Sendable {
id connectionID: ID,
logger: Logger
) -> EventLoopFuture<OracleConnection> {
var logger = logger
logger[oracleMetadataKey: .connectionID] = "\(connectionID)"

return eventLoop.flatSubmit { [logger] in
makeBootstrap(on: eventLoop, configuration: configuration)
.connect(host: configuration.host, port: configuration.port)
.flatMap { channel -> EventLoopFuture<OracleConnection> in
return OracleConnection.start(
configuration: configuration,
connectionID: connectionID,
channel: channel,
logger: logger
var mutableLogger = logger
mutableLogger[oracleMetadataKey: .connectionID] = "\(connectionID)"
let logger = mutableLogger

return eventLoop.flatSubmit {
let connectFuture: EventLoopFuture<Channel>

switch configuration.endpointInfo {
case .configureChannel(let channel):
guard channel.isActive else {
return eventLoop.makeFailedFuture(
OracleSQLError.connectionError(
underlying: ChannelError.alreadyClosed
)
)
}
connectFuture = eventLoop.makeSucceededFuture(channel)
case .connectTCP(let host, let port):
connectFuture = makeBootstrap(on: eventLoop, configuration: configuration)
.connect(host: host, port: port)
}

return connectFuture.flatMap { channel -> EventLoopFuture<OracleConnection> in
return OracleConnection.start(
configuration: configuration,
connectionID: connectionID,
channel: channel,
logger: logger
)
}
}
}

Expand Down Expand Up @@ -293,7 +309,7 @@ public final class OracleConnection: Sendable {
}

/// Closes the connection to the database server.
private func close() -> EventLoopFuture<Void> {
private func _close() -> EventLoopFuture<Void> {
guard !self.isClosed else {
return self.eventLoop.makeSucceededVoidFuture()
}
Expand All @@ -303,21 +319,21 @@ public final class OracleConnection: Sendable {
}

/// Sends a ping to the database server.
private func ping() -> EventLoopFuture<Void> {
private func _ping() -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.channel.write(OracleTask.ping(promise), promise: nil)
return promise.futureResult
}

/// Sends a commit to the database server.
private func commit() -> EventLoopFuture<Void> {
private func _commit() -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.channel.write(OracleTask.commit(promise), promise: nil)
return promise.futureResult
}

/// Sends a rollback to the database server.
private func rollback() -> EventLoopFuture<Void> {
private func _rollback() -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.channel.write(OracleTask.rollback(promise), promise: nil)
return promise.futureResult
Expand Down Expand Up @@ -393,22 +409,22 @@ extension OracleConnection {

/// Closes the connection to the database server.
public func close() async throws {
try await self.close().get()
try await self._close().get()
}

/// Sends a ping to the database server.
public func ping() async throws {
try await self.ping().get()
try await self._ping().get()
}

/// Sends a commit to the database server.
public func commit() async throws {
try await self.commit().get()
try await self._commit().get()
}

/// Sends a rollback to the database server.
public func rollback() async throws {
try await self.rollback().get()
try await self._rollback().get()
}

/// Run a statement on the Oracle server the connection is connected to.
Expand Down
2 changes: 1 addition & 1 deletion Sources/OracleNIO/OracleEventsHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ final class OracleEventsHandler: ChannelInboundHandler {
break
}
default:
preconditionFailure()
context.fireUserInboundEventTriggered(event)
}
}

Expand Down
8 changes: 4 additions & 4 deletions Tests/IntegrationTests/CustomTypeTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,22 @@ struct CustomOracleObject: OracleDecodable {
case .object:
let typeOID =
if try buffer.throwingReadUB4() > 0 {
try buffer.readOracleSpecificLengthPrefixedSlice()!
try buffer.throwingReadOracleSpecificLengthPrefixedSlice()
} else { ByteBuffer() }
let oid =
if try buffer.throwingReadUB4() > 0 {
try buffer.readOracleSpecificLengthPrefixedSlice()!
try buffer.throwingReadOracleSpecificLengthPrefixedSlice()
} else { ByteBuffer() }
let snapshot =
if try buffer.throwingReadUB4() > 0 {
try buffer.readOracleSpecificLengthPrefixedSlice()!
try buffer.throwingReadOracleSpecificLengthPrefixedSlice()
} else { ByteBuffer() }
buffer.skipUB2() // version
let dataLength = try buffer.throwingReadUB4()
buffer.skipUB2() // flags
let data =
if dataLength > 0 {
try buffer.readOracleSpecificLengthPrefixedSlice()!
try buffer.throwingReadOracleSpecificLengthPrefixedSlice()
} else { ByteBuffer() }
self.init(typeOID: typeOID, oid: oid, snapshot: snapshot, data: data)
default:
Expand Down
2 changes: 1 addition & 1 deletion Tests/IntegrationTests/OracleClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ final class OracleClientTests: XCTestCase {

@available(macOS 14.0, *)
func testPingPong() async throws {
let idleTimeout = Duration.seconds(60)
let idleTimeout = Duration.seconds(20)
let config = try OracleConnection.testConfig()
var options = OracleClient.Options()
options.keepAliveBehavior?.frequency = .seconds(10)
Expand Down
7 changes: 6 additions & 1 deletion Tests/OracleNIOTests/OracleCodableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,12 @@ final class OracleCodableTests: XCTestCase {

}

extension DataRow: ExpressibleByArrayLiteral {
#if compiler(>=6.0)
extension DataRow: @retroactive ExpressibleByArrayLiteral {}
#else
extension DataRow: ExpressibleByArrayLiteral {}
#endif
extension DataRow {
public typealias ArrayLiteralElement = OracleThrowingEncodable

public init(arrayLiteral elements: any OracleThrowingEncodable...) {
Expand Down
Loading

0 comments on commit 619336f

Please sign in to comment.