diff --git a/Example/KnitExample/KnitExampleAssembly.swift b/Example/KnitExample/KnitExampleAssembly.swift index fa727a3..9dc136f 100644 --- a/Example/KnitExample/KnitExampleAssembly.swift +++ b/Example/KnitExample/KnitExampleAssembly.swift @@ -32,6 +32,10 @@ final class KnitExampleAssembly: Assembly { } container.autoregisterIntoCollection(ExampleService.self, initializer: ExampleService.init) + + #if DEBUG + container.autoregister(DebugService.self, initializer: DebugService.init) + #endif } } @@ -59,3 +63,5 @@ final class ClosureService { init(closure: @escaping (() -> Void)) { } } + +struct DebugService { } diff --git a/Example/KnitExample/KnitExampleUserAssembly.swift b/Example/KnitExample/KnitExampleUserAssembly.swift index b7caff1..7c8aac5 100644 --- a/Example/KnitExample/KnitExampleUserAssembly.swift +++ b/Example/KnitExample/KnitExampleUserAssembly.swift @@ -1,7 +1,7 @@ // Copyright © Square, Inc. All rights reserved. import Foundation -import Knit +import KnitLib // @knit internal getter-named /// An assembly expected to be registered at the user level rather than at the app level diff --git a/Sources/KnitCodeGen/AssemblyParsing.swift b/Sources/KnitCodeGen/AssemblyParsing.swift index 680a8ec..55d4edf 100644 --- a/Sources/KnitCodeGen/AssemblyParsing.swift +++ b/Sources/KnitCodeGen/AssemblyParsing.swift @@ -116,6 +116,9 @@ private class ClassDeclVisitor: SyntaxVisitor { private(set) var registrationErrors = [Error]() + /// For any registrations parsed, this should be #if condition should be applied when it is used + private var currentIfConfigCondition: ExprSyntax? + init(viewMode: SyntaxTreeViewMode, directives: KnitDirectives) { self.directives = directives super.init(viewMode: viewMode) @@ -123,7 +126,12 @@ private class ClassDeclVisitor: SyntaxVisitor { override func visit(_ node: FunctionCallExprSyntax) -> SyntaxVisitorContinueKind { do { - let (registrations, registrationsIntoCollections) = try node.getRegistrations(defaultDirectives: directives) + var (registrations, registrationsIntoCollections) = try node.getRegistrations(defaultDirectives: directives) + registrations = registrations.map { registration in + var mutable = registration + mutable.ifConfigCondition = currentIfConfigCondition + return mutable + } self.registrations.append(contentsOf: registrations) self.registrationsIntoCollections.append(contentsOf: registrationsIntoCollections) } catch { @@ -138,6 +146,24 @@ private class ClassDeclVisitor: SyntaxVisitor { return .skipChildren } + override func visit(_ node: IfConfigClauseSyntax) -> SyntaxVisitorContinueKind { + // Allowing for #else creates a link between the registration inside the #if and those in the #else + // This greatly increases the complexity of handling #if so raise an error when #else is used + if node.poundKeyword.text.contains("#else") { + registrationErrors.append( + RegistrationParsingError.invalidIfConfig(syntax: node, text: node.poundKeyword.text) + ) + return .skipChildren + } + // Set the condition and walk the children to create the registrations + self.currentIfConfigCondition = node.condition + node.children(viewMode: .sourceAccurate).forEach { syntax in + walk(syntax) + } + self.currentIfConfigCondition = nil + return .skipChildren + } + } extension NamedDeclSyntax { diff --git a/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift b/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift index eb4c78c..a69aef7 100644 --- a/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift +++ b/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift @@ -282,6 +282,7 @@ enum RegistrationParsingError: LocalizedError, SyntaxError { case unwrappedClosureParams(syntax: SyntaxProtocol) case chainedRegistrations(syntax: SyntaxProtocol) case nonStaticString(syntax: SyntaxProtocol, name: String) + case invalidIfConfig(syntax: SyntaxProtocol, text: String) var errorDescription: String? { switch self { @@ -293,6 +294,8 @@ enum RegistrationParsingError: LocalizedError, SyntaxError { return "Chained registration calls are not supported" case let .nonStaticString(_, name): return "Service name must be a static string. Found: \(name)" + case let .invalidIfConfig(_, text): + return "Invalid IfConfig expression around container registration: \(text)" } } @@ -301,7 +304,8 @@ enum RegistrationParsingError: LocalizedError, SyntaxError { case let .missingArgumentType(syntax, _), let .chainedRegistrations(syntax), let .nonStaticString(syntax, _), - let .unwrappedClosureParams(syntax): + let .unwrappedClosureParams(syntax), + let .invalidIfConfig(syntax, _): return syntax } } diff --git a/Sources/KnitCodeGen/Registration.swift b/Sources/KnitCodeGen/Registration.swift index 1904892..5f0fda5 100644 --- a/Sources/KnitCodeGen/Registration.swift +++ b/Sources/KnitCodeGen/Registration.swift @@ -1,3 +1,5 @@ +import SwiftSyntax + public struct Registration: Equatable, Codable { public var service: String @@ -16,6 +18,8 @@ public struct Registration: Equatable, Codable { /// This registration's getter setting. public var getterConfig: Set + public var ifConfigCondition: ExprSyntax? + public init( service: String, name: String? = nil, @@ -32,6 +36,10 @@ public struct Registration: Equatable, Codable { self.getterConfig = getterConfig } + private enum CodingKeys: CodingKey { + case service, name, accessLevel, arguments, isForwarded, getterConfig + } + } extension Registration { diff --git a/Sources/KnitCodeGen/TypeSafetySourceFile.swift b/Sources/KnitCodeGen/TypeSafetySourceFile.swift index c6756ef..a74b286 100644 --- a/Sources/KnitCodeGen/TypeSafetySourceFile.swift +++ b/Sources/KnitCodeGen/TypeSafetySourceFile.swift @@ -49,7 +49,7 @@ public enum TypeSafetySourceFile { registration: Registration, enumName: String? = nil, getterType: GetterConfig = .callAsFunction - ) throws -> FunctionDeclSyntax { + ) throws -> DeclSyntaxProtocol { let modifier = registration.accessLevel == .public ? "public " : "" let nameInput = enumName.map { "name: \($0)" } let nameUsage = enumName != nil ? "name: name.rawValue" : nil @@ -64,9 +64,21 @@ public enum TypeSafetySourceFile { funcName = name ?? TypeNamer.computedIdentifierName(type: registration.service) } - return try FunctionDeclSyntax("\(raw: modifier)func \(raw: funcName)(\(raw: inputs)) -> \(raw: registration.service)") { + let function = try FunctionDeclSyntax("\(raw: modifier)func \(raw: funcName)(\(raw: inputs)) -> \(raw: registration.service)") { "self.resolve(\(raw: usages))!" } + + // Wrap the output an in #if where needed + guard let ifConfigCondition = registration.ifConfigCondition else { + return function + } + let codeBlock = CodeBlockItemListSyntax([.init(item: .init(function))]) + let clause = IfConfigClauseSyntax( + poundKeyword: .poundIfToken(), + condition: ifConfigCondition, + elements: .statements(codeBlock) + ) + return IfConfigDeclSyntax(clauses: [clause]) } private static func argumentString(registration: Registration) -> (input: String?, usage: String?) { @@ -114,7 +126,6 @@ extension Registration { } return result } - } extension Registration.Argument { @@ -133,5 +144,4 @@ extension Registration.Argument { var functionType: String { return TypeNamer.isClosure(type: type) ? "@escaping \(type)" : type } - } diff --git a/Sources/KnitCodeGen/UnitTestSourceFile.swift b/Sources/KnitCodeGen/UnitTestSourceFile.swift index 5748644..00e4205 100644 --- a/Sources/KnitCodeGen/UnitTestSourceFile.swift +++ b/Sources/KnitCodeGen/UnitTestSourceFile.swift @@ -87,8 +87,25 @@ public enum UnitTestSourceFile { } } + static func makeAssertCall(registration: Registration) -> CodeBlockItemListSyntax { + let expression = makeAssertCallExpression(registration: registration) + let codeBlock = CodeBlockItemListSyntax([.init(item: .init(expression))]) + + // Wrap the output an in #if where needed + guard let ifConfigCondition = registration.ifConfigCondition else { + return codeBlock + } + let clause = IfConfigClauseSyntax( + poundKeyword: .poundIfToken(), + condition: ifConfigCondition, + elements: .statements(codeBlock) + ) + let ifConfig = IfConfigDeclSyntax(clauses: [clause]) + return CodeBlockItemListSyntax([.init(item: .init(ifConfig))]) + } + /// Generate a function call to test a single registration resolves - static func makeAssertCall(registration: Registration) -> ExprSyntax { + private static func makeAssertCallExpression(registration: Registration) -> ExprSyntax { if !registration.arguments.isEmpty { let argParams = argumentParams(registration: registration) let nameParam = registration.name.map { "name: \"\($0)\""} diff --git a/Tests/KnitCodeGenTests/AssemblyParsingTests.swift b/Tests/KnitCodeGenTests/AssemblyParsingTests.swift index a1e6885..8459cee 100644 --- a/Tests/KnitCodeGenTests/AssemblyParsingTests.swift +++ b/Tests/KnitCodeGenTests/AssemblyParsingTests.swift @@ -244,6 +244,81 @@ final class AssemblyParsingTests: XCTestCase { ) } + func testIfDefElseFailure() throws { + let sourceFile: SourceFileSyntax = """ + class ExampleAssembly: Assembly { + func assemble(container: Container) { + #if SOME_FLAG + container.autoregister(B.self, initializer: B.init) + #else + container.autoregister(C.self, initializer: C.init) + #endif + } + } + """ + + // Make sure that individual registration errors are bubbled up to be printed + _ = try assertParsesSyntaxTree( + sourceFile, + assertErrorsToPrint: { errors in + XCTAssertEqual(errors.count, 1) + XCTAssertEqual( + errors.first?.localizedDescription, + "Invalid IfConfig expression around container registration: #else" + ) + } + ) + } + + func testIfDefParsing() throws { + let sourceFile: SourceFileSyntax = """ + class ExampleAssembly: Assembly { + func assemble(container: Container) { + #if SOME_FLAG + container.autoregister(A.self, initializer: A.init) + #endif + + #if SOME_FLAG && !ANOTHER_FLAG + container.autoregister(B.self, initializer: B.init) + container.autoregister(C.self, initializer: C.init) + #endif + } + } + """ + + let config = try assertParsesSyntaxTree(sourceFile) + XCTAssertEqual(config.name, "Example") + XCTAssertEqual(config.registrations.count, 3) + + XCTAssertEqual(config.registrations[0].service, "A") + XCTAssertEqual(config.registrations[0].ifConfigCondition?.description, "SOME_FLAG") + + XCTAssertEqual(config.registrations[1].service, "B") + XCTAssertEqual(config.registrations[1].ifConfigCondition?.description, "SOME_FLAG && !ANOTHER_FLAG") + + XCTAssertEqual(config.registrations[2].service, "C") + XCTAssertEqual(config.registrations[2].ifConfigCondition?.description, "SOME_FLAG && !ANOTHER_FLAG") + } + + func testIfSimulatorParsing() throws { + let sourceFile: SourceFileSyntax = """ + class ExampleAssembly: Assembly { + func assemble(container: Container) { + #if targetEnvironment(simulator) + container.autoregister(A.self, initializer: A.init) + #endif + } + } + """ + + let config = try assertParsesSyntaxTree(sourceFile) + XCTAssertEqual(config.registrations.count, 1) + XCTAssertEqual( + config.registrations.first?.ifConfigCondition?.description, + "targetEnvironment(simulator)" + ) + } + } private func assertParsesSyntaxTree( diff --git a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift index 79eb1f9..b97617e 100644 --- a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift +++ b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift @@ -134,6 +134,24 @@ final class TypeSafetySourceFileTests: XCTestCase { ) } + func testRegistrationWithIfConfig() { + var registration = Registration(service: "A", accessLevel: .public) + registration.ifConfigCondition = ExprSyntax("SOME_FLAG") + XCTAssertEqual( + try TypeSafetySourceFile.makeResolver( + registration: registration, + enumName: nil + ).formatted().description, + """ + #if SOME_FLAG + public func callAsFunction() -> A { + self.resolve(A.self)! + } + #endif + """ + ) + } + func testArgumentNames() { let registration1 = Registration(service: "A", accessLevel: .public, arguments: [.init(type: "String?")]) XCTAssertEqual( diff --git a/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift b/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift index 1200612..8d0fcb8 100644 --- a/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift +++ b/Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift @@ -170,6 +170,21 @@ final class UnitTestSourceFileTests: XCTestCase { XCTAssertEqual(formattedResult, expected) } + func test_registrationAssertIfConfig() { + var registration = Registration(service: "A", accessLevel: .hidden) + registration.ifConfigCondition = ExprSyntax("SOME_FLAG && !DEBUG") + let result = UnitTestSourceFile.makeAssertCall(registration: registration) + let formattedResult = result.formatted().description + XCTAssertEqual( + formattedResult, + """ + #if SOME_FLAG && !DEBUG + resolver.assertTypeResolves(A.self) + #endif + """ + ) + } + func test_registrationAssertArgument() { let registration = Registration( service: "A",