Skip to content

Commit

Permalink
Refactor PSQLRowStream to make async/await easier (#201)
Browse files Browse the repository at this point in the history
### Motivation

`PSQLRowStream`'s current implementation is hard to follow. It should be better tested and easier to reason about, in preparation for adding async/await support later.

### Changes

- Simplify `PSQLRowStream`'s implementation by folding the separate upstream state into a single `downstreamState`
- Add unit tests for `PSQLRowStream`

### Result

Adding async/await support becomes easier.
  • Loading branch information
fabianfett authored Nov 26, 2021
1 parent 81ca909 commit 780a510
Show file tree
Hide file tree
Showing 3 changed files with 429 additions and 104 deletions.
6 changes: 6 additions & 0 deletions Sources/PostgresNIO/New/PSQLRow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ struct PSQLRow {
}
}

extension PSQLRow: Equatable {
    /// Two rows compare equal when both their `data` and their `columns` match.
    ///
    /// Only these two properties take part in the comparison.
    static func ==(lhs: Self, rhs: Self) -> Bool {
        // Compare the row payload first; the columns are only checked when the data matches.
        guard lhs.data == rhs.data else {
            return false
        }
        return lhs.columns == rhs.columns
    }
}

extension PSQLRow {
/// Access the data in the provided column and decode it into the target type.
///
Expand Down
204 changes: 100 additions & 104 deletions Sources/PostgresNIO/New/PSQLRowStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import NIOCore
import Logging

final class PSQLRowStream {

enum RowSource {
case stream(PSQLRowsDataSource)
case noRows(Result<String, Error>)
Expand All @@ -11,23 +10,21 @@ final class PSQLRowStream {
let eventLoop: EventLoop
let logger: Logger

private enum UpstreamState {
private enum BufferState {
case streaming(buffer: CircularBuffer<DataRow>, dataSource: PSQLRowsDataSource)
case finished(buffer: CircularBuffer<DataRow>, commandTag: String)
case failure(Error)
case consumed(Result<String, Error>)
case modifying
}

private enum DownstreamState {
case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise<Void>)
case waitingForAll(EventLoopPromise<[PSQLRow]>)
case consuming
case waitingForConsumer(BufferState)
case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
case waitingForAll([PSQLRow], EventLoopPromise<[PSQLRow]>, PSQLRowsDataSource)
case consumed(Result<String, Error>)
}

internal let rowDescription: [RowDescription.Column]
private let lookupTable: [String: Int]
private var upstreamState: UpstreamState
private var downstreamState: DownstreamState
private let jsonDecoder: PSQLJSONDecoder

Expand All @@ -36,30 +33,33 @@ final class PSQLRowStream {
eventLoop: EventLoop,
rowSource: RowSource)
{
let buffer = CircularBuffer<DataRow>()

self.downstreamState = .consuming
let bufferState: BufferState
switch rowSource {
case .stream(let dataSource):
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
bufferState = .streaming(buffer: .init(), dataSource: dataSource)
case .noRows(.success(let commandTag)):
self.upstreamState = .finished(buffer: .init(), commandTag: commandTag)
bufferState = .finished(buffer: .init(), commandTag: commandTag)
case .noRows(.failure(let error)):
self.upstreamState = .failure(error)
bufferState = .failure(error)
}

self.downstreamState = .waitingForConsumer(bufferState)

self.eventLoop = eventLoop
self.logger = queryContext.logger
self.jsonDecoder = queryContext.jsonDecoder

self.rowDescription = rowDescription

var lookup = [String: Int]()
lookup.reserveCapacity(rowDescription.count)
rowDescription.enumerated().forEach { (index, column) in
lookup[column.name] = index
}
self.lookupTable = lookup
}

// MARK: Consume in array

func all() -> EventLoopFuture<[PSQLRow]> {
if self.eventLoop.inEventLoop {
Expand All @@ -74,40 +74,37 @@ final class PSQLRowStream {
private func all0() -> EventLoopFuture<[PSQLRow]> {
self.eventLoop.preconditionInEventLoop()

guard case .consuming = self.downstreamState else {
preconditionFailure("Invalid state")
guard case .waitingForConsumer(let bufferState) = self.downstreamState else {
preconditionFailure("Invalid state: \(self.downstreamState)")
}

switch self.upstreamState {
case .streaming(_, let dataSource):
dataSource.request(for: self)
switch bufferState {
case .streaming(let bufferedRows, let dataSource):
let promise = self.eventLoop.makePromise(of: [PSQLRow].self)
self.downstreamState = .waitingForAll(promise)
let rows = bufferedRows.map { data in
PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
}
self.downstreamState = .waitingForAll(rows, promise, dataSource)
// immediately request more
dataSource.request(for: self)
return promise.futureResult

case .finished(let buffer, let commandTag):
self.upstreamState = .modifying

let rows = buffer.map {
PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
}

self.downstreamState = .consuming
self.upstreamState = .consumed(.success(commandTag))
self.downstreamState = .consumed(.success(commandTag))
return self.eventLoop.makeSucceededFuture(rows)

case .consumed:
preconditionFailure("We already signaled, that the stream has completed, why are we asked again?")

case .modifying:
preconditionFailure("Invalid state")

case .failure(let error):
self.upstreamState = .consumed(.failure(error))
self.downstreamState = .consumed(.failure(error))
return self.eventLoop.makeFailedFuture(error)
}
}

// MARK: Consume on EventLoop

func onRow(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture<Void> {
if self.eventLoop.inEventLoop {
return self.onRow0(onRow)
Expand All @@ -121,7 +118,11 @@ final class PSQLRowStream {
private func onRow0(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture<Void> {
self.eventLoop.preconditionInEventLoop()

switch self.upstreamState {
guard case .waitingForConsumer(let bufferState) = self.downstreamState else {
preconditionFailure("Invalid state: \(self.downstreamState)")
}

switch bufferState {
case .streaming(var buffer, let dataSource):
let promise = self.eventLoop.makePromise(of: Void.self)
do {
Expand All @@ -136,12 +137,11 @@ final class PSQLRowStream {
}

buffer.removeAll()
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
self.downstreamState = .iteratingRows(onRow: onRow, promise)
self.downstreamState = .iteratingRows(onRow: onRow, promise, dataSource)
// immediately request more
dataSource.request(for: self)
} catch {
self.upstreamState = .failure(error)
self.downstreamState = .consumed(.failure(error))
dataSource.cancel(for: self)
promise.fail(error)
}
Expand All @@ -160,22 +160,15 @@ final class PSQLRowStream {
try onRow(row)
}

self.upstreamState = .consumed(.success(commandTag))
self.downstreamState = .consuming
self.downstreamState = .consumed(.success(commandTag))
return self.eventLoop.makeSucceededVoidFuture()
} catch {
self.upstreamState = .consumed(.failure(error))
self.downstreamState = .consumed(.failure(error))
return self.eventLoop.makeFailedFuture(error)
}

case .consumed:
preconditionFailure("We already signaled, that the stream has completed, why are we asked again?")

case .modifying:
preconditionFailure("Invalid state")

case .failure(let error):
self.upstreamState = .consumed(.failure(error))
self.downstreamState = .consumed(.failure(error))
return self.eventLoop.makeFailedFuture(error)
}
}
Expand All @@ -193,13 +186,15 @@ final class PSQLRowStream {
"row_count": "\(newRows.count)"
])

guard case .streaming(var buffer, let dataSource) = self.upstreamState else {
preconditionFailure("Invalid state")
}

switch self.downstreamState {
case .iteratingRows(let onRow, let promise):
precondition(buffer.isEmpty)
case .waitingForConsumer(.streaming(buffer: var buffer, dataSource: let dataSource)):
buffer.append(contentsOf: newRows)
self.downstreamState = .waitingForConsumer(.streaming(buffer: buffer, dataSource: dataSource))

case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
preconditionFailure("How can new rows be received, if an end was already signalled?")

case .iteratingRows(let onRow, let promise, let dataSource):
do {
for data in newRows {
let row = PSQLRow(
Expand All @@ -214,82 +209,83 @@ final class PSQLRowStream {
dataSource.request(for: self)
} catch {
dataSource.cancel(for: self)
self.upstreamState = .failure(error)
self.downstreamState = .consumed(.failure(error))
promise.fail(error)
return
}
case .waitingForAll:
self.upstreamState = .modifying
buffer.append(contentsOf: newRows)
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)


case .waitingForAll(var rows, let promise, let dataSource):
newRows.forEach { data in
let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
rows.append(row)
}
self.downstreamState = .waitingForAll(rows, promise, dataSource)
// immediately request more
dataSource.request(for: self)

case .consuming:
// this might happen, if the query has finished while the user is consuming data
// we don't need to ask for more since the user is consuming anyway
self.upstreamState = .modifying
buffer.append(contentsOf: newRows)
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
case .consumed(.success):
preconditionFailure("How can we receive further rows, if we are supposed to be done")

case .consumed(.failure):
break
}
}

internal func receive(completion result: Result<String, Error>) {
self.eventLoop.preconditionInEventLoop()

guard case .streaming(let oldBuffer, _) = self.upstreamState else {
preconditionFailure("Invalid state")
switch result {
case .success(let commandTag):
self.receiveEnd(commandTag)
case .failure(let error):
self.receiveError(error)
}
}

private func receiveEnd(_ commandTag: String) {
switch self.downstreamState {
case .iteratingRows(_, let promise):
precondition(oldBuffer.isEmpty)
self.downstreamState = .consuming
self.upstreamState = .consumed(result)
switch result {
case .success:
promise.succeed(())
case .failure(let error):
promise.fail(error)
}
case .waitingForConsumer(.streaming(buffer: let buffer, _)):
self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag))

case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
preconditionFailure("How can we get another end, if an end was already signalled?")

case .consuming:
switch result {
case .success(let commandTag):
self.upstreamState = .finished(buffer: oldBuffer, commandTag: commandTag)
case .failure(let error):
self.upstreamState = .failure(error)
}

case .waitingForAll(let promise):
switch result {
case .failure(let error):
self.upstreamState = .consumed(.failure(error))
promise.fail(error)
case .success(let commandTag):
let rows = oldBuffer.map {
PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
}
self.upstreamState = .consumed(.success(commandTag))
promise.succeed(rows)
}
case .iteratingRows(_, let promise, _):
self.downstreamState = .consumed(.success(commandTag))
promise.succeed(())

case .waitingForAll(let rows, let promise, _):
self.downstreamState = .consumed(.success(commandTag))
promise.succeed(rows)

case .consumed:
break
}
}

func cancel() {
guard case .streaming(_, let dataSource) = self.upstreamState else {
// We don't need to cancel any upstream resource. All needed data is already
// included in this
return
}

dataSource.cancel(for: self)
/// Handles an error that terminates the upstream row source.
///
/// Invoked from `receive(completion:)` when the completion result is `.failure`.
/// Depending on the current `downstreamState`, the error is either stored for a
/// consumer that has not attached yet, or delivered to the waiting promise.
///
/// - Parameter error: The error that ended the stream.
private func receiveError(_ error: Error) {
switch self.downstreamState {
case .waitingForConsumer(.streaming):
// No consumer has attached yet: remember the failure. The rows buffered so far
// are not carried over into the `.failure` state.
self.downstreamState = .waitingForConsumer(.failure(error))

case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
// An end (success or failure) was already recorded; a second one is a programmer error.
preconditionFailure("How can we get another end, if an end was already signalled?")

case .iteratingRows(_, let promise, _):
// Move to the terminal state first, then fail the promise of the `onRow` consumer.
self.downstreamState = .consumed(.failure(error))
promise.fail(error)

case .waitingForAll(_, let promise, _):
// The rows collected so far are dropped; the `all()` promise fails with the error.
self.downstreamState = .consumed(.failure(error))
promise.fail(error)

case .consumed:
// The stream has already reached its terminal state; the error is ignored.
// NOTE(review): presumably this covers an error arriving after the consumer
// itself failed and cancelled the data source — confirm against callers.
break
}
}

var commandTag: String {
guard case .consumed(.success(let commandTag)) = self.upstreamState else {
guard case .consumed(.success(let commandTag)) = self.downstreamState else {
preconditionFailure("commandTag may only be called if all rows have been consumed")
}
return commandTag
Expand Down
Loading

0 comments on commit 780a510

Please sign in to comment.