From 22363fed316cd9942b56bcd1a1df8875df79b794 Mon Sep 17 00:00:00 2001 From: tomer doron Date: Mon, 13 Jun 2022 10:28:56 -0700 Subject: [PATCH] add the ability to de-register a task (#121) motivation: sometimes a task needs to be de-registered since the task was manually shutdown outside the lifecycle scope changes: * registration APIs now return a registration key which can be used as a cancellation token * add API to de-register a task * refactor state to use a registery instead of array of tasks * add tests --- Sources/Lifecycle/Lifecycle.swift | 160 ++++++++++++++---- .../ComponentLifecycleTests+XCTest.swift | 2 + .../ComponentLifecycleTests.swift | 67 ++++++++ docker/docker-compose.yaml | 2 +- 4 files changed, 195 insertions(+), 36 deletions(-) diff --git a/Sources/Lifecycle/Lifecycle.swift b/Sources/Lifecycle/Lifecycle.swift index 7601a61..f5c5106 100644 --- a/Sources/Lifecycle/Lifecycle.swift +++ b/Sources/Lifecycle/Lifecycle.swift @@ -417,8 +417,13 @@ extension ServiceLifecycle { } extension ServiceLifecycle: LifecycleTasksContainer { - public func register(_ tasks: [LifecycleTask]) { - self.underlying.register(tasks) + @discardableResult + public func register(_ tasks: [LifecycleTask]) -> [RegistrationKey] { + return self.underlying.register(tasks) + } + + public func deregister(_ key: RegistrationKey) { + self.underlying.deregister(key) } } @@ -462,7 +467,7 @@ public class ComponentLifecycle: LifecycleTask { fileprivate let logger: Logger fileprivate let shutdownGroup = DispatchGroup() - private var state = State.idle([]) + private var state = State.idle(Registry()) private let stateLock = Lock() /// Creates a `ComponentLifecycle` instance. @@ -492,10 +497,10 @@ public class ComponentLifecycle: LifecycleTask { /// - on: `DispatchQueue` to run the handlers callback on /// - callback: The handler which is called after the start operation completes. The parameter will be `nil` on success and contain the `Error` otherwise. public func start(on queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { - guard case .idle(let tasks) = (self.stateLock.withLock { self.state }) else { + guard case .idle(let registry) = (self.stateLock.withLock { self.state }) else { preconditionFailure("invalid state, \(self.state)") } - self._start(on: queue, tasks: tasks, callback: callback) + self._start(on: queue, registry: registry, callback: callback) } /// Starts the provided `LifecycleTask` array and waits (blocking) until `shutdown` is called on another thread. @@ -530,15 +535,15 @@ public class ComponentLifecycle: LifecycleTask { self.stateLock.lock() switch self.state { - case .idle(let tasks) where tasks.isEmpty: + case .idle(let registry) where registry.isEmpty: self.state = .shutdown(nil) self.stateLock.unlock() defer { self.shutdownGroup.leave() } callback(nil) - case .idle(let tasks): + case .idle(let registry): self.stateLock.unlock() // attempt to shutdown any registered tasks - let stoppable = tasks.filter { $0.shutdownIfNotStarted } + let stoppable = registry.tasks.filter { $0.shutdownIfNotStarted } setupShutdownListener(.global()) self._shutdown(on: .global(), tasks: stoppable, callback: self.shutdownGroup.leave) case .shutdown: @@ -552,10 +557,10 @@ public class ComponentLifecycle: LifecycleTask { case .shuttingDown(let queue): self.stateLock.unlock() setupShutdownListener(queue) - case .started(let queue, let tasks): + case .started(let queue, let registry): self.stateLock.unlock() setupShutdownListener(queue) - self._shutdown(on: queue, tasks: tasks, callback: self.shutdownGroup.leave) + self._shutdown(on: queue, tasks: registry.tasks, callback: self.shutdownGroup.leave) } } @@ -576,7 +581,7 @@ public class ComponentLifecycle: LifecycleTask { // MARK: - private - private func _start(on queue: DispatchQueue, tasks: [LifecycleTask], callback: @escaping (Error?) -> Void) { + private func _start(on queue: DispatchQueue, registry: Registry, callback: @escaping (Error?) -> Void) { self.stateLock.withLock { guard case .idle = self.state else { preconditionFailure("invalid state, \(self.state)") @@ -587,10 +592,10 @@ public class ComponentLifecycle: LifecycleTask { self.logger.info("starting") Counter(label: "\(self.label).lifecycle.start").increment() - if tasks.count == 0 { + if registry.isEmpty { self.logger.notice("no tasks provided") } - self.startTask(on: queue, tasks: tasks, index: 0) { started, error in + self.startTask(on: queue, tasks: registry.tasks, index: 0) { started, error in self.stateLock.lock() if error != nil { self.state = .shuttingDown(queue) @@ -600,8 +605,8 @@ public class ComponentLifecycle: LifecycleTask { self.stateLock.unlock() // shutdown was called while starting, or start failed, shutdown what we can var stoppable = started - if started.count < tasks.count { - let shutdownIfNotStarted = tasks.enumerated() + if started.count < registry.tasks.count { + let shutdownIfNotStarted = registry.tasks.enumerated() .filter { $0.offset >= started.count } .map { $0.element } .filter { $0.shutdownIfNotStarted } @@ -612,7 +617,7 @@ public class ComponentLifecycle: LifecycleTask { self.shutdownGroup.leave() } case .starting: - self.state = .started(queue, tasks) + self.state = .started(queue, registry) self.stateLock.unlock() callback(nil) default: @@ -697,70 +702,116 @@ public class ComponentLifecycle: LifecycleTask { } private enum State { - case idle([LifecycleTask]) + case idle(Registry) case starting(DispatchQueue) - case started(DispatchQueue, [LifecycleTask]) + case started(DispatchQueue, Registry) case shuttingDown(DispatchQueue) case shutdown([String: Error]?) } } extension ComponentLifecycle: LifecycleTasksContainer { - public func register(_ tasks: [LifecycleTask]) { + @discardableResult + public func register(_ newTasks: [LifecycleTask]) -> [RegistrationKey] { + let registrationKeys = self.stateLock.withLock { () -> [RegistrationKey] in + guard case .idle(let registry) = self.state else { + preconditionFailure("invalid state, \(self.state)") + } + return registry.add(newTasks) + } + return registrationKeys + } + + public func deregister(_ key: RegistrationKey) { + func remove(key: RegistrationKey, tasks: [LifecycleTask], keys: [RegistrationKey]) -> ([LifecycleTask], [RegistrationKey]) { + guard let index = keys.firstIndex(of: key) else { + return (tasks, keys) + } + var updatedTasks = tasks + updatedTasks.remove(at: index) + var updatedKeys = keys + updatedKeys.remove(at: index) + return (updatedTasks, updatedKeys) + } + self.stateLock.withLock { - guard case .idle(let existing) = self.state else { + switch self.state { + case .idle(let registry), .started(_, let registry): + registry.remove(key) + default: preconditionFailure("invalid state, \(self.state)") } - self.state = .idle(existing + tasks) } } } /// A container of `LifecycleTask`, used to register additional `LifecycleTask` public protocol LifecycleTasksContainer { - /// Adds a `LifecycleTask` to a `LifecycleTasks` collection. + typealias RegistrationKey = String + + /// Register a `LifecycleTask` with a `LifecycleTasksContainer`. /// /// - parameters: /// - tasks: array of `LifecycleTask`. - func register(_ tasks: [LifecycleTask]) + @discardableResult + func register(_ tasks: [LifecycleTask]) -> [RegistrationKey] + + /// De-register a `LifecycleTask` from a `LifecycleTasksContainer`. + /// + /// - parameters: + /// - registrationKey: The key returned by a register operation. + func deregister(_ key: RegistrationKey) } extension LifecycleTasksContainer { - /// Adds a `LifecycleTask` to a `LifecycleTasks` collection. + /// Register a `LifecycleTask` with a `LifecycleTasksContainer`. + /// + /// - parameters: + /// - tasks: one or more `LifecycleTask`. + @discardableResult + public func register(_ tasks: LifecycleTask ...) -> [RegistrationKey] { + return self.register(tasks) + } + + /// Register a `LifecycleTask` with a `LifecycleTasksContainer`. /// /// - parameters: /// - tasks: one or more `LifecycleTask`. - public func register(_ tasks: LifecycleTask ...) { - self.register(tasks) + @discardableResult + public func register(_ tasks: LifecycleTask) -> RegistrationKey { + return self.register(tasks).first! // force the optional on the first in this case is safe } - /// Adds a `LifecycleTask` to a `LifecycleTasks` collection. + /// Register a `LifecycleTask` with a `LifecycleTasksContainer`. /// /// - parameters: /// - label: label of the item, useful for debugging. /// - start: `Handler` to perform the startup. /// - shutdown: `Handler` to perform the shutdown. - public func register(label: String, start: LifecycleHandler, shutdown: LifecycleHandler, shutdownIfNotStarted: Bool? = nil) { - self.register(_LifecycleTask(label: label, shutdownIfNotStarted: shutdownIfNotStarted, start: start, shutdown: shutdown)) + @discardableResult + public func register(label: String, start: LifecycleHandler, shutdown: LifecycleHandler, shutdownIfNotStarted: Bool? = nil) -> RegistrationKey { + return self.register(_LifecycleTask(label: label, shutdownIfNotStarted: shutdownIfNotStarted, start: start, shutdown: shutdown)) } - /// Adds a `LifecycleTask` to a `LifecycleTasks` collection. + /// Register a `LifecycleTask` with a `LifecycleTasksContainer`. /// /// - parameters: /// - label: label of the item, useful for debugging. /// - handler: `Handler` to perform the shutdown. - public func registerShutdown(label: String, _ handler: LifecycleHandler) { - self.register(label: label, start: .none, shutdown: handler) + @discardableResult + public func registerShutdown(label: String, _ handler: LifecycleHandler) -> RegistrationKey { + return self.register(label: label, start: .none, shutdown: handler) } - /// Add a stateful `LifecycleTask` to a `LifecycleTasks` collection. + /// Register a stateful `LifecycleTask` with a `LifecycleTasksContainer`. /// /// - parameters: /// - label: label of the item, useful for debugging. /// - start: `LifecycleStartHandler` to perform the startup and return the state. /// - shutdown: `LifecycleShutdownHandler` to perform the shutdown given the state. - public func registerStateful(label: String, start: LifecycleStartHandler, shutdown: LifecycleShutdownHandler) { - self.register(StatefulLifecycleTask(label: label, start: start, shutdown: shutdown)) + @discardableResult + public func registerStateful(label: String, start: LifecycleStartHandler, shutdown: LifecycleShutdownHandler) -> RegistrationKey { + return self.register(StatefulLifecycleTask(label: label, start: start, shutdown: shutdown)) } } @@ -830,3 +881,42 @@ internal class StatefulLifecycleTask: LifecycleTask { struct UnknownState: Error {} } + +private class Registry { + typealias RegistrationKey = LifecycleTasksContainer.RegistrationKey + + private var _tasks: [LifecycleTask] = [] + private var keys: [LifecycleTasksContainer.RegistrationKey] = [] + private let lock = Lock() + + func add(_ tasks: [LifecycleTask]) -> [RegistrationKey] { + // FIXME: better id generation scheme (cant use UUID) + let keys: [RegistrationKey] = tasks.map { _ in + let random = UInt64.random(in: UInt64.min ..< UInt64.max).addingReportingOverflow(DispatchTime.now().uptimeNanoseconds).partialValue + return "task-\(random)" + } + self.lock.withLock { + self._tasks.append(contentsOf: tasks) + self.keys.append(contentsOf: keys) + } + return keys + } + + func remove(_ key: RegistrationKey) { + self.lock.withLock { + guard let index = self.keys.firstIndex(of: key) else { + return + } + self._tasks.remove(at: index) + self.keys.remove(at: index) + } + } + + var tasks: [LifecycleTask] { + return self.lock.withLock { self._tasks } + } + + var isEmpty: Bool { + return self.lock.withLock { self._tasks.isEmpty } + } +} diff --git a/Tests/LifecycleTests/ComponentLifecycleTests+XCTest.swift b/Tests/LifecycleTests/ComponentLifecycleTests+XCTest.swift index bc3f74d..656bb1f 100644 --- a/Tests/LifecycleTests/ComponentLifecycleTests+XCTest.swift +++ b/Tests/LifecycleTests/ComponentLifecycleTests+XCTest.swift @@ -26,6 +26,8 @@ extension ComponentLifecycleTests { static var allTests: [(String, (ComponentLifecycleTests) -> () throws -> Void)] { return [ ("testStartThenShutdown", testStartThenShutdown), + ("testDeregister", testDeregister), + ("testDeregisterAfterStart", testDeregisterAfterStart), ("testDefaultCallbackQueue", testDefaultCallbackQueue), ("testUserDefinedCallbackQueue", testUserDefinedCallbackQueue), ("testShutdownWhileStarting", testShutdownWhileStarting), diff --git a/Tests/LifecycleTests/ComponentLifecycleTests.swift b/Tests/LifecycleTests/ComponentLifecycleTests.swift index 8e09c92..b20116d 100644 --- a/Tests/LifecycleTests/ComponentLifecycleTests.swift +++ b/Tests/LifecycleTests/ComponentLifecycleTests.swift @@ -33,6 +33,73 @@ final class ComponentLifecycleTests: XCTestCase { items.forEach { XCTAssertEqual($0.state, .shutdown, "expected item to be shutdown, but \($0.state)") } } + func testDeregister() { + class BadItem: LifecycleTask { + let label: String = UUID().uuidString + + func start(_ callback: (Error?) -> Void) { + callback(TestError()) + } + + func shutdown(_ callback: (Error?) -> Void) { + callback(TestError()) + } + } + + let lifecycle = ComponentLifecycle(label: "test") + let itemToDeregister1 = BadItem() + let itemToDeregister2 = BadItem() + lifecycle.register(GoodItem()) + let key1 = lifecycle.register(itemToDeregister1) + lifecycle.register(GoodItem()) + lifecycle.register(GoodItem()) + let key2 = lifecycle.register(itemToDeregister2) + + lifecycle.deregister(key1) + lifecycle.deregister(key2) + + lifecycle.start { startError in + XCTAssertNil(startError, "not expecting error") + lifecycle.shutdown { shutdownErrors in + XCTAssertNil(shutdownErrors, "not expecting error") + } + } + lifecycle.wait() + } + + func testDeregisterAfterStart() { + class BadItem: LifecycleTask { + let label: String = UUID().uuidString + + func start(_ callback: (Error?) -> Void) { + callback(.none) // okay + } + + func shutdown(_ callback: (Error?) -> Void) { + callback(TestError()) + } + } + + let lifecycle = ComponentLifecycle(label: "test") + let itemToDeregister1 = BadItem() + let itemToDeregister2 = BadItem() + lifecycle.register(GoodItem()) + let key1 = lifecycle.register(itemToDeregister1) + lifecycle.register(GoodItem()) + lifecycle.register(GoodItem()) + let key2 = lifecycle.register(itemToDeregister2) + + lifecycle.start { startError in + XCTAssertNil(startError, "not expecting error") + lifecycle.deregister(key1) + lifecycle.deregister(key2) + lifecycle.shutdown { shutdownErrors in + XCTAssertNil(shutdownErrors, "not expecting error") + } + } + lifecycle.wait() + } + func testDefaultCallbackQueue() throws { guard #available(OSX 10.12, *) else { return diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 858dbbd..0c380f8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -34,4 +34,4 @@ services: shell: <<: *common - entrypoint: /bin/bash + entrypoint: /bin/bash -l