Skip to content

Commit

Permalink
Merge pull request #174 from cashapp/skorulis/main-actor
Browse files Browse the repository at this point in the history
Add support for registering factories on the main thread
  • Loading branch information
skorulis-ap authored Jul 17, 2024
2 parents 0ef4f36 + 8d3b336 commit 54054ea
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import PackageDescription
let package = Package(
name: "Knit",
platforms: [
.macOS(.v13),
.macOS(.v14),
],
products: [
.library(name: "Knit", targets: ["Knit"]),
Expand Down
42 changes: 42 additions & 0 deletions Sources/Knit/Container+MainActor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// Copyright © Block, Inc. All rights reserved.
//

import Swinject

// This code should move into the Swinject library.
// There is an open pull request to make this change https://github.com/Swinject/Swinject/pull/570

extension Container {
// Register a service type's factory with the assumption that the registration
// will happen on the main thread.
//
// This method relies on the type eventually being resolved by a caller on the main
// thread using the knit generated resolver. If that call is not made on the main
// thread then a crash will occur.
@discardableResult
public func register<Service>(
_ serviceType: Service.Type,
name: String? = nil,
mainActorFactory: @escaping @MainActor (Resolver) -> Service
) -> ServiceEntry<Service> {
return register(serviceType, name: name) { r in
MainActor.assumeIsolated {
return mainActorFactory(r)
}
}
}

@discardableResult
public func register<Service, Arg1>(
_ serviceType: Service.Type,
name: String? = nil,
mainActorFactory: @escaping @MainActor (Resolver, Arg1) -> Service
) -> ServiceEntry<Service> {
return register(serviceType) { (resolver: Resolver, arg1: Arg1) in
MainActor.assumeIsolated {
return mainActorFactory(resolver, arg1)
}
}
}
}
16 changes: 16 additions & 0 deletions Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ extension FunctionCallExprSyntax {
trailingClosure: primaryRegisterMethod.trailingClosure
)

let concurrencyModifier = getConcurrencyModifier(trailingClosure: primaryRegisterMethod.trailingClosure)

// The primary registration (not `.implements()`)
guard let primaryRegistration = try makeRegistrationFor(
defaultDirectives: defaultDirectives,
arguments: primaryRegisterMethod.arguments,
concurrencyModifier: concurrencyModifier,
registrationArguments: registrationArguments,
leadingTrivia: self.leadingTrivia,
functionName: functionName
Expand All @@ -98,6 +101,7 @@ extension FunctionCallExprSyntax {
if let forwardedRegistration = try makeRegistrationFor(
defaultDirectives: defaultDirectives,
arguments: implementsCalledMethod.arguments,
concurrencyModifier: concurrencyModifier,
registrationArguments: registrationArguments,
leadingTrivia: leadingTrivia,
functionName: .implements
Expand Down Expand Up @@ -150,6 +154,7 @@ func recurseAllCalledMethods(
private func makeRegistrationFor(
defaultDirectives: KnitDirectives,
arguments: LabeledExprListSyntax,
concurrencyModifier: String?,
registrationArguments: [Registration.Argument],
leadingTrivia: Trivia?,
functionName: Registration.FunctionName
Expand All @@ -173,6 +178,7 @@ private func makeRegistrationFor(
name: name,
accessLevel: directives.accessLevel ?? defaultDirectives.accessLevel ?? .default,
arguments: registrationArguments,
concurrencyModifier: concurrencyModifier,
getterConfig: getterConfig,
functionName: functionName
)
Expand Down Expand Up @@ -271,6 +277,16 @@ private func getArguments(
return []
}

private func getConcurrencyModifier(trailingClosure: ClosureExprSyntax?) -> String? {
guard let signature = trailingClosure?.signature else { return nil }
for att in signature.attributes {
if att.description.trimmingCharacters(in: .whitespaces) == "@MainActor" {
return "@MainActor"
}
}
return nil
}

private func getArgumentType(arg: LabeledExprSyntax) -> String? {
return arg.expression.as(MemberAccessExprSyntax.self)?.base?.description
.trimmingCharacters(in: .whitespacesAndNewlines)
Expand Down
6 changes: 5 additions & 1 deletion Sources/KnitCodeGen/Registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public struct Registration: Equatable, Codable {

public var accessLevel: AccessLevel

public let concurrencyModifier: String?

/// Argument types required to resolve the registration
public var arguments: [Argument]

Expand All @@ -29,12 +31,14 @@ public struct Registration: Equatable, Codable {
name: String? = nil,
accessLevel: AccessLevel = .internal,
arguments: [Argument] = [],
concurrencyModifier: String? = nil,
getterConfig: Set<GetterConfig> = GetterConfig.default,
functionName: FunctionName = .register
) {
self.service = service
self.name = name
self.accessLevel = accessLevel
self.concurrencyModifier = concurrencyModifier
self.arguments = arguments
self.getterConfig = getterConfig
self.functionName = functionName
Expand All @@ -47,7 +51,7 @@ public struct Registration: Equatable, Codable {

private enum CodingKeys: CodingKey {
// ifConfigCondition is not encoded since ExprSyntax does not conform to codable
case service, name, accessLevel, arguments, getterConfig, functionName
case service, name, accessLevel, arguments, getterConfig, functionName, concurrencyModifier
}

}
Expand Down
6 changes: 5 additions & 1 deletion Sources/KnitCodeGen/TypeSafetySourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ public enum TypeSafetySourceFile {
enumName: String? = nil,
getterType: GetterConfig = .callAsFunction
) throws -> DeclSyntaxProtocol {
let modifier = registration.accessLevel == .public ? "public " : ""
var modifier = ""
if let concurrencyModifier = registration.concurrencyModifier {
modifier = "\(concurrencyModifier) "
}
modifier += registration.accessLevel == .public ? "public " : ""
let nameInput = enumName.map { "name: \($0)" }
let nameUsage = enumName != nil ? "name: name.rawValue" : nil
let (argInput, argUsage) = argumentString(registration: registration)
Expand Down
13 changes: 13 additions & 0 deletions Tests/KnitCodeGenTests/RegistrationParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,19 @@ final class RegistrationParsingTests: XCTestCase {
)
}

func testMainActorParsing() throws {
try assertMultipleRegistrationsString(
"""
container.register(A.self) { @MainActor in A() }
.implements(B.self)
""",
registrations: [
Registration(service: "A", concurrencyModifier: "@MainActor", functionName: .register),
Registration(service: "B", concurrencyModifier: "@MainActor", functionName: .implements),
]
)
}

func testIncorrectRegistrations() throws {
try assertNoRegistrationsString("container.someOtherMethod(AType.self)", message: "Incorrect method name")
try assertNoRegistrationsString("container.register(A)", message: "First param is not a metatype")
Expand Down
23 changes: 23 additions & 0 deletions Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -246,4 +246,27 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(expected, result.formatted().description)
}

func test_mainActor_resolver() throws {
let result = try TypeSafetySourceFile.make(
from: Configuration(
assemblyName: "MainActorAssembly",
moduleName: "Module",
registrations: [
.init(service: "ServiceA", concurrencyModifier: "@MainActor")
],
targetResolver: "Resolver"
)
)
let expected = """
/// Generated from ``MainActorAssembly``
extension Resolver {
@MainActor func serviceA(file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceA {
knitUnwrap(resolve(ServiceA.self), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
}
"""

XCTAssertEqual(expected, result.formatted().description)
}

}

0 comments on commit 54054ea

Please sign in to comment.