Skip to content

Commit

Permalink
Add concurrency support for abstract registrations
Browse files Browse the repository at this point in the history
  • Loading branch information
bradfol committed Dec 13, 2024
1 parent b2c54c8 commit fc751a4
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 7 deletions.
14 changes: 14 additions & 0 deletions Sources/Knit/ConcurrencyAttribute.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//
// Copyright © Block, Inc. All rights reserved.
//

/// The possible concurrency isolation for a registration.
public enum ConcurrencyAttribute: Codable, Sendable {
/// We do not currently have a way to forward this information through Swinject Behavior hooks
/// so registrations that come from behaviors will be unknown.
case unknown

case nonisolated

case MainActor
}
35 changes: 29 additions & 6 deletions Sources/Knit/Module/Container+AbstractRegistration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ extension Container {
public func registerAbstract<Service>(
_ serviceType: Service.Type,
name: String? = nil,
concurrency: ConcurrencyAttribute = .nonisolated,
file: String = #fileID
) {
let registration = RealAbstractRegistration<Service>(name: name, file: file)
let registration = RealAbstractRegistration<Service>(name: name, file: file, concurrency: concurrency)
abstractRegistrations().abstractRegistrations.append(registration)
}

Expand All @@ -26,9 +27,10 @@ extension Container {
public func registerAbstract<Service>(
_ serviceType: Optional<Service>.Type,
name: String? = nil,
concurrency: ConcurrencyAttribute = .nonisolated,
file: String = #fileID
) {
let registration = OptionalAbstractRegistration<Service>(name: name, file: file)
let registration = OptionalAbstractRegistration<Service>(name: name, file: file, concurrency: concurrency)
abstractRegistrations().abstractRegistrations.append(registration)
}

Expand All @@ -50,6 +52,7 @@ extension Container {
internal struct RegistrationKey: Hashable, Equatable {
let typeIdentifier: ObjectIdentifier
let name: String?
let concurrency: ConcurrencyAttribute
}

/// Protocol version to allow storing generic types an array
Expand All @@ -60,6 +63,7 @@ internal protocol AbstractRegistration {
var file: String { get }
var name: String? { get }
var key: RegistrationKey { get }
var concurrency: ConcurrencyAttribute { get }

/// Register a placeholder registration to fill the unfulfilled abstract registration
/// This placeholder cannot be resolved
Expand Down Expand Up @@ -89,8 +93,14 @@ fileprivate struct RealAbstractRegistration<ServiceType>: AbstractRegistration {

var serviceType: ServiceType.Type { ServiceType.self }

let concurrency: ConcurrencyAttribute

var key: RegistrationKey {
return .init(typeIdentifier: ObjectIdentifier(ServiceType.self), name: name)
return .init(
typeIdentifier: ObjectIdentifier(ServiceType.self),
name: name,
concurrency: concurrency
)
}

func registerPlaceholder(
Expand All @@ -113,8 +123,10 @@ fileprivate struct OptionalAbstractRegistration<ServiceType>: AbstractRegistrati

var serviceType: ServiceType.Type { ServiceType.self }

let concurrency: ConcurrencyAttribute

var key: RegistrationKey {
return .init(typeIdentifier: ObjectIdentifier(ServiceType.self), name: name)
return .init(typeIdentifier: ObjectIdentifier(ServiceType.self), name: name, concurrency: concurrency)
}

func registerPlaceholder(
Expand Down Expand Up @@ -171,12 +183,23 @@ extension Container {
toService entry: ServiceEntry<Service>,
withName name: String?
) {
let id = RegistrationKey(typeIdentifier: ObjectIdentifier(Type.self), name: name)
let id = RegistrationKey(
typeIdentifier: ObjectIdentifier(Type.self),
name: name,
concurrency: .unknown
)
concreteRegistrations.insert(id)
}

var unfulfilledRegistrations: [any AbstractRegistration] {
abstractRegistrations.filter { !concreteRegistrations.contains($0.key) }
abstractRegistrations.filter { abstractRegistration in
let abstractKey = abstractRegistration.key
return !concreteRegistrations.contains { concreteKey in
// We need to ignore the concurrency attribute currently due to Swinject limitations
concreteKey.typeIdentifier == abstractKey.typeIdentifier &&
concreteKey.name == abstractKey.name
}
}
}

// Throws an error if any abstract registrations have not been implemented
Expand Down
1 change: 0 additions & 1 deletion Sources/KnitCodeGen/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public struct Configuration: Encodable, Sendable {
public var directives: KnitDirectives

public enum AssemblyType: String, Encodable, Sendable {
/// `Swinject.Assembly`
case moduleAssembly = "ModuleAssembly"
case autoInitAssembly = "AutoInitModuleAssembly"
case abstractAssembly = "AbstractAssembly"
Expand Down
10 changes: 10 additions & 0 deletions Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,19 @@ private func getConcurrencyModifier(
arguments: LabeledExprListSyntax,
trailingClosure: ClosureExprSyntax?
) -> String? {
// Detects concrete registrations that use the explicitly named closure argument
if arguments.contains(where: {$0.label?.text == "mainActorFactory" }) {
return "@MainActor"
}
// Detects abstract registrations
for arg in arguments {
guard arg.label?.text == "concurrency" else { continue }
// Corresponds to `(concurrency: .MainActor)`
// declName is what follows the period
if arg.expression.as(MemberAccessExprSyntax.self)?.declName.baseName.text == "MainActor" {
return "@MainActor"
}
}
guard let signature = trailingClosure?.signature else { return nil }
for att in signature.attributes {
guard case let .attribute(attributeSyntax) = att else {
Expand Down
12 changes: 12 additions & 0 deletions Tests/KnitCodeGenTests/RegistrationParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ final class RegistrationParsingTests: XCTestCase {
serviceName: "AType",
name: "service"
)

try assertRegistrationString(
"""
container.registerAbstract(AType.self, name: "service", concurrency: .MainActor)
""",
serviceName: "AType",
name: "service",
concurrencyModifier: "@MainActor"
)

}

func testForwardedRegistration() throws {
Expand Down Expand Up @@ -637,6 +647,7 @@ private func assertRegistrationString(
accessLevel: AccessLevel = .internal,
name: String? = nil,
isForwarded: Bool = false,
concurrencyModifier: String? = nil,
file: StaticString = #filePath, line: UInt = #line
) throws {
let functionCall = try XCTUnwrap(FunctionCallExprSyntax("\(raw: string)" as ExprSyntax))
Expand All @@ -651,6 +662,7 @@ private func assertRegistrationString(
XCTAssertEqual(registration?.accessLevel, accessLevel, file: file, line: line)
XCTAssertEqual(registration?.name, name, file: file, line: line)
XCTAssertEqual(registration?.isForwarded, isForwarded, file: file, line: line)
XCTAssertEqual(registration?.concurrencyModifier, concurrencyModifier, file: file, line: line)
}

/// Assert that multiple registrations exist within the string.
Expand Down

0 comments on commit fc751a4

Please sign in to comment.