Skip to content

Commit

Permalink
shutdown() should cancel the signal handlers installed by start()
Browse files Browse the repository at this point in the history
motivation: allow easier testing of shutdown hooks

changes:
* introduce ServiceLifecycle.removeTrap which removes a trap
* call ServiceLifecycle.removeTrap when setting up the shutdown hook
* make the shutdown hook cleanup into a lifecycle task to ensure correct ordering
* add tests
* improve logging

rdar://89552798
  • Loading branch information
tomerd committed May 17, 2022
1 parent e63be9e commit f070ede
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 20 deletions.
64 changes: 50 additions & 14 deletions Sources/Lifecycle/Lifecycle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,22 @@ public protocol LifecycleTask {
var shutdownIfNotStarted: Bool { get }
func start(_ callback: @escaping (Error?) -> Void)
func shutdown(_ callback: @escaping (Error?) -> Void)
var logStart: Bool { get }
var logShutdown: Bool { get }
}

extension LifecycleTask {
public var shutdownIfNotStarted: Bool {
return false
}

public var logStart: Bool {
return true
}

public var logShutdown: Bool {
return true
}
}

// MARK: - LifecycleHandler
Expand Down Expand Up @@ -317,9 +327,14 @@ public struct ServiceLifecycle {
self.log("intercepted signal: \(signal)")
self.shutdown()
}, cancelAfterTrap: true)
self.underlying.shutdownGroup.notify(queue: .global()) {
signalSource.cancel()
}
// register cleanup as the last task
self.registerShutdown(label: "\(signal) shutdown hook cleanup", .sync {
// cancel if not already canceled by the trap
if !signalSource.isCancelled {
signalSource.cancel()
ServiceLifecycle.removeTrap(signal: signal)
}
})
}
}

Expand All @@ -343,22 +358,34 @@ extension ServiceLifecycle {
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
// on linux, we can call singal() once per process
self.trappedLock.withLockVoid {
if !trapped.contains(sig.rawValue) {
if !self.trapped.contains(sig.rawValue) {
signal(sig.rawValue, SIG_IGN)
trapped.insert(sig.rawValue)
self.trapped.insert(sig.rawValue)
}
}
let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue)
signalSource.setEventHandler {
// run handler first
handler(sig)
// then cancel trap if so requested
if cancelAfterTrap {
signalSource.cancel()
self.removeTrap(signal: sig)
}
handler(sig)
}
signalSource.resume()
return signalSource
}

public static func removeTrap(signal sig: Signal) {
self.trappedLock.withLockVoid {
if self.trapped.contains(sig.rawValue) {
signal(sig.rawValue, SIG_DFL)
self.trapped.remove(sig.rawValue)
}
}
}

/// A system signal
public struct Signal: Equatable, CustomStringConvertible {
internal var rawValue: CInt
Expand Down Expand Up @@ -413,7 +440,8 @@ extension ServiceLifecycle {
logger: Logger? = nil,
callbackQueue: DispatchQueue = .global(),
shutdownSignal: [Signal]? = [.TERM, .INT],
installBacktrace: Bool = true) {
installBacktrace: Bool = true)
{
self.label = label
self.logger = logger ?? Logger(label: label)
self.callbackQueue = callbackQueue
Expand All @@ -433,7 +461,7 @@ struct ShutdownError: Error {
public class ComponentLifecycle: LifecycleTask {
public let label: String
fileprivate let logger: Logger
internal let shutdownGroup = DispatchGroup()
fileprivate let shutdownGroup = DispatchGroup()

private var state = State.idle([])
private let stateLock = Lock()
Expand Down Expand Up @@ -596,13 +624,15 @@ public class ComponentLifecycle: LifecycleTask {

private func startTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, callback: @escaping ([LifecycleTask], Error?) -> Void) {
// async barrier
let start = { (callback) -> Void in queue.async { tasks[index].start(callback) } }
let callback = { (index, error) -> Void in queue.async { callback(index, error) } }
let start = { callback in queue.async { tasks[index].start(callback) } }
let callback = { index, error in queue.async { callback(index, error) } }

if index >= tasks.count {
return callback(tasks, nil)
}
self.logger.info("starting tasks [\(tasks[index].label)]")
if tasks[index].logStart {
self.logger.info("starting tasks [\(tasks[index].label)]")
}
let startTime = DispatchTime.now()
start { error in
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.start").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
Expand Down Expand Up @@ -642,14 +672,16 @@ public class ComponentLifecycle: LifecycleTask {

private func shutdownTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, errors: [String: Error]?, callback: @escaping ([String: Error]?) -> Void) {
// async barrier
let shutdown = { (callback) -> Void in queue.async { tasks[index].shutdown(callback) } }
let callback = { (errors) -> Void in queue.async { callback(errors) } }
let shutdown = { callback in queue.async { tasks[index].shutdown(callback) } }
let callback = { errors in queue.async { callback(errors) } }

if index >= tasks.count {
return callback(errors)
}

self.logger.info("stopping tasks [\(tasks[index].label)]")
if tasks[index].logShutdown {
self.logger.info("stopping tasks [\(tasks[index].label)]")
}
let startTime = DispatchTime.now()
shutdown { error in
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.shutdown").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
Expand Down Expand Up @@ -739,12 +771,16 @@ internal struct _LifecycleTask: LifecycleTask {
let shutdownIfNotStarted: Bool
let start: LifecycleHandler
let shutdown: LifecycleHandler
let logStart: Bool
let logShutdown: Bool

init(label: String, shutdownIfNotStarted: Bool? = nil, start: LifecycleHandler, shutdown: LifecycleHandler) {
self.label = label
self.shutdownIfNotStarted = shutdownIfNotStarted ?? start.noop
self.start = start
self.shutdown = shutdown
self.logStart = !start.noop
self.logShutdown = !shutdown.noop
}

func start(_ callback: @escaping (Error?) -> Void) {
Expand Down
4 changes: 2 additions & 2 deletions Sources/Lifecycle/Locks.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ extension Lock {
/// - Parameter body: The block to execute while holding the lock.
/// - Returns: The value returned by the block.
@inlinable
internal func withLock<T>(_ body: () throws -> T) rethrows -> T {
func withLock<T>(_ body: () throws -> T) rethrows -> T {
self.lock()
defer {
self.unlock()
Expand All @@ -91,7 +91,7 @@ extension Lock {

// specialise Void return (for performance)
@inlinable
internal func withLockVoid(_ body: () throws -> Void) rethrows {
func withLockVoid(_ body: () throws -> Void) rethrows {
try self.withLock(body)
}
}
4 changes: 2 additions & 2 deletions Tests/LifecycleTests/ComponentLifecycleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ final class ComponentLifecycleTests: XCTestCase {
dispatchPrecondition(condition: .onQueue(.global()))
XCTAssertTrue(startCalls.contains(id))
stopCalls.append(id)
})
})
}
lifecycle.register(items)

Expand Down Expand Up @@ -92,7 +92,7 @@ final class ComponentLifecycleTests: XCTestCase {
dispatchPrecondition(condition: .onQueue(testQueue))
XCTAssertTrue(startCalls.contains(id))
stopCalls.append(id)
})
})
}
lifecycle.register(items)

Expand Down
6 changes: 4 additions & 2 deletions Tests/LifecycleTests/Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class GoodItem: LifecycleTask {

init(id: String = UUID().uuidString,
startDelay: Double = Double.random(in: 0.01 ... 0.1),
shutdownDelay: Double = Double.random(in: 0.01 ... 0.1)) {
shutdownDelay: Double = Double.random(in: 0.01 ... 0.1))
{
self.id = id
self.startDelay = startDelay
self.shutdownDelay = shutdownDelay
Expand Down Expand Up @@ -72,7 +73,8 @@ class NIOItem {
init(eventLoopGroup: EventLoopGroup,
id: String = UUID().uuidString,
startDelay: Int64 = Int64.random(in: 10 ... 20),
shutdownDelay: Int64 = Int64.random(in: 10 ... 20)) {
shutdownDelay: Int64 = Int64.random(in: 10 ... 20))
{
self.id = id
self.eventLoopGroup = eventLoopGroup
self.startDelay = startDelay
Expand Down
1 change: 1 addition & 0 deletions Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ extension ServiceLifecycleTests {
("testSignalDescription", testSignalDescription),
("testBacktracesInstalledOnce", testBacktracesInstalledOnce),
("testRepeatShutdown", testRepeatShutdown),
("testShutdownCancelSignal", testShutdownCancelSignal),
]
}
}
43 changes: 43 additions & 0 deletions Tests/LifecycleTests/ServiceLifecycleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,47 @@ final class ServiceLifecycleTests: XCTestCase {

XCTAssertEqual(attempts, count)
}

func testShutdownCancelSignal() {
if ProcessInfo.processInfo.environment["SKIP_SIGNAL_TEST"].flatMap(Bool.init) ?? false {
print("skipping testRepeatShutdown")
return
}

struct Service {
static let signal = ServiceLifecycle.Signal.ALRM

let lifecycle: ServiceLifecycle

init() {
self.lifecycle = ServiceLifecycle(configuration: .init(shutdownSignal: [Service.signal]))
self.lifecycle.register(GoodItem())
}
}

let service = Service()
service.lifecycle.start { error in
XCTAssertNil(error, "not expecting error")
kill(getpid(), Service.signal.rawValue)
}
service.lifecycle.wait()

var count = 0
let sync = DispatchGroup()
sync.enter()
let signalSource = ServiceLifecycle.trap(signal: Service.signal, handler: { _ in
count = count + 1 // not thread safe but fine for this purpose
sync.leave()
}, cancelAfterTrap: false)

// since we are removing the hook added by lifecycle on shutdown,
// this will fail unless a new hook is set up as done above
kill(getpid(), Service.signal.rawValue)

XCTAssertEqual(.success, sync.wait(timeout: .now() + 2))
XCTAssertEqual(count, 1)

signalSource.cancel()
ServiceLifecycle.removeTrap(signal: Service.signal)
}
}

0 comments on commit f070ede

Please sign in to comment.