Skip to content

Commit

Permalink
Merge pull request #96 from squareup/skorulis/ifdef
Browse files Browse the repository at this point in the history
Add support for parsing registrations wrapped in #if conditions
  • Loading branch information
skorulis-ap authored Nov 17, 2023
2 parents b95eea1 + c4afc54 commit 5c1fed4
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 8 deletions.
6 changes: 6 additions & 0 deletions Example/KnitExample/KnitExampleAssembly.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ final class KnitExampleAssembly: Assembly {
}

container.autoregisterIntoCollection(ExampleService.self, initializer: ExampleService.init)

#if DEBUG
container.autoregister(DebugService.self, initializer: DebugService.init)
#endif
}

}
Expand Down Expand Up @@ -59,3 +63,5 @@ final class ClosureService {

init(closure: @escaping (() -> Void)) { }
}

struct DebugService { }
2 changes: 1 addition & 1 deletion Example/KnitExample/KnitExampleUserAssembly.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © Square, Inc. All rights reserved.

import Foundation
import Knit
import KnitLib

// @knit internal getter-named
/// An assembly expected to be registered at the user level rather than at the app level
Expand Down
35 changes: 34 additions & 1 deletion Sources/KnitCodeGen/AssemblyParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,22 @@ private class ClassDeclVisitor: SyntaxVisitor {

private(set) var targetResolver: String?

/// For any registrations parsed, this should be #if condition should be applied when it is used
private var currentIfConfigCondition: ExprSyntax?

init(viewMode: SyntaxTreeViewMode, directives: KnitDirectives) {
self.directives = directives
super.init(viewMode: viewMode)
}

override func visit(_ node: FunctionCallExprSyntax) -> SyntaxVisitorContinueKind {
do {
let (registrations, registrationsIntoCollections) = try node.getRegistrations(defaultDirectives: directives)
var (registrations, registrationsIntoCollections) = try node.getRegistrations(defaultDirectives: directives)
registrations = registrations.map { registration in
var mutable = registration
mutable.ifConfigCondition = currentIfConfigCondition
return mutable
}
self.registrations.append(contentsOf: registrations)
self.registrationsIntoCollections.append(contentsOf: registrationsIntoCollections)
} catch {
Expand All @@ -172,6 +180,31 @@ private class ClassDeclVisitor: SyntaxVisitor {
return .skipChildren
}

override func visit(_ node: IfConfigClauseSyntax) -> SyntaxVisitorContinueKind {
// Allowing for #else creates a link between the registration inside the #if and those in the #else
// This greatly increases the complexity of handling #if so raise an error when #else is used
if node.poundKeyword.text.contains("#else") {
registrationErrors.append(
RegistrationParsingError.invalidIfConfig(syntax: node, text: node.poundKeyword.text)
)
return .skipChildren
}
// Raise an error for nested if statements
if self.currentIfConfigCondition != nil {
registrationErrors.append(
RegistrationParsingError.nestedIfConfig(syntax: node)
)
return .skipChildren
}
// Set the condition and walk the children to create the registrations
self.currentIfConfigCondition = node.condition
node.children(viewMode: .sourceAccurate).forEach { syntax in
walk(syntax)
}
self.currentIfConfigCondition = nil
return .skipChildren
}

}

extension NamedDeclSyntax {
Expand Down
10 changes: 9 additions & 1 deletion Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ enum RegistrationParsingError: LocalizedError, SyntaxError {
case unwrappedClosureParams(syntax: SyntaxProtocol)
case chainedRegistrations(syntax: SyntaxProtocol)
case nonStaticString(syntax: SyntaxProtocol, name: String)
case invalidIfConfig(syntax: SyntaxProtocol, text: String)
case nestedIfConfig(syntax: SyntaxProtocol)

var errorDescription: String? {
switch self {
Expand All @@ -293,6 +295,10 @@ enum RegistrationParsingError: LocalizedError, SyntaxError {
return "Chained registration calls are not supported"
case let .nonStaticString(_, name):
return "Service name must be a static string. Found: \(name)"
case let .invalidIfConfig(_, text):
return "Invalid IfConfig expression around container registration: \(text)"
case .nestedIfConfig:
return "Nested #if statements are not supported"
}
}

Expand All @@ -301,7 +307,9 @@ enum RegistrationParsingError: LocalizedError, SyntaxError {
case let .missingArgumentType(syntax, _),
let .chainedRegistrations(syntax),
let .nonStaticString(syntax, _),
let .unwrappedClosureParams(syntax):
let .unwrappedClosureParams(syntax),
let .invalidIfConfig(syntax, _),
let .nestedIfConfig(syntax):
return syntax
}
}
Expand Down
9 changes: 9 additions & 0 deletions Sources/KnitCodeGen/Registration.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import SwiftSyntax

public struct Registration: Equatable, Codable {

public var service: String
Expand All @@ -16,6 +18,8 @@ public struct Registration: Equatable, Codable {
/// This registration's getter setting.
public var getterConfig: Set<GetterConfig>

public var ifConfigCondition: ExprSyntax?

public init(
service: String,
name: String? = nil,
Expand All @@ -32,6 +36,11 @@ public struct Registration: Equatable, Codable {
self.getterConfig = getterConfig
}

private enum CodingKeys: CodingKey {
// ifConfigCondition is not encoded since ExprSyntax does not conform to codable
case service, name, accessLevel, arguments, isForwarded, getterConfig
}

}

extension Registration {
Expand Down
18 changes: 14 additions & 4 deletions Sources/KnitCodeGen/TypeSafetySourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public enum TypeSafetySourceFile {
registration: Registration,
enumName: String? = nil,
getterType: GetterConfig = .callAsFunction
) throws -> FunctionDeclSyntax {
) throws -> DeclSyntaxProtocol {
let modifier = registration.accessLevel == .public ? "public " : ""
let nameInput = enumName.map { "name: \($0)" }
let nameUsage = enumName != nil ? "name: name.rawValue" : nil
Expand All @@ -64,9 +64,21 @@ public enum TypeSafetySourceFile {
funcName = name ?? TypeNamer.computedIdentifierName(type: registration.service)
}

return try FunctionDeclSyntax("\(raw: modifier)func \(raw: funcName)(\(raw: inputs)) -> \(raw: registration.service)") {
let function = try FunctionDeclSyntax("\(raw: modifier)func \(raw: funcName)(\(raw: inputs)) -> \(raw: registration.service)") {
"self.resolve(\(raw: usages))!"
}

// Wrap the output an in #if where needed
guard let ifConfigCondition = registration.ifConfigCondition else {
return function
}
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(function))])
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
return IfConfigDeclSyntax(clauses: [clause])
}

private static func argumentString(registration: Registration) -> (input: String?, usage: String?) {
Expand Down Expand Up @@ -114,7 +126,6 @@ extension Registration {
}
return result
}

}

extension Registration.Argument {
Expand All @@ -133,5 +144,4 @@ extension Registration.Argument {
var functionType: String {
return TypeNamer.isClosure(type: type) ? "@escaping \(type)" : type
}

}
19 changes: 18 additions & 1 deletion Sources/KnitCodeGen/UnitTestSourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,25 @@ public enum UnitTestSourceFile {
}
}

static func makeAssertCall(registration: Registration) -> CodeBlockItemListSyntax {
let expression = makeAssertCallExpression(registration: registration)
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(expression))])

// Wrap the output an in #if where needed
guard let ifConfigCondition = registration.ifConfigCondition else {
return codeBlock
}
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
let ifConfig = IfConfigDeclSyntax(clauses: [clause])
return CodeBlockItemListSyntax([.init(item: .init(ifConfig))])
}

/// Generate a function call to test a single registration resolves
static func makeAssertCall(registration: Registration) -> ExprSyntax {
private static func makeAssertCallExpression(registration: Registration) -> ExprSyntax {
if !registration.arguments.isEmpty {
let argParams = argumentParams(registration: registration)
let nameParam = registration.name.map { "name: \"\($0)\""}
Expand Down
101 changes: 101 additions & 0 deletions Tests/KnitCodeGenTests/AssemblyParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,107 @@ final class AssemblyParsingTests: XCTestCase {
XCTAssertEqual(config.targetResolver, "Resolver")
}

func testIfDefElseFailure() throws {
let sourceFile: SourceFileSyntax = """
class ExampleAssembly: Assembly {
func assemble(container: Container) {
#if SOME_FLAG
container.autoregister(B.self, initializer: B.init)
#else
container.autoregister(C.self, initializer: C.init)
#endif
}
}
"""

// Make sure that individual registration errors are bubbled up to be printed
_ = try assertParsesSyntaxTree(
sourceFile,
assertErrorsToPrint: { errors in
XCTAssertEqual(errors.count, 1)
XCTAssertEqual(
errors.first?.localizedDescription,
"Invalid IfConfig expression around container registration: #else"
)
}
)
}

func testIfDefParsing() throws {
let sourceFile: SourceFileSyntax = """
class ExampleAssembly: Assembly {
func assemble(container: Container) {
#if SOME_FLAG
container.autoregister(A.self, initializer: A.init)
#endif
#if SOME_FLAG && !ANOTHER_FLAG
container.autoregister(B.self, initializer: B.init)
container.autoregister(C.self, initializer: C.init)
#endif
}
}
"""

let config = try assertParsesSyntaxTree(sourceFile)
XCTAssertEqual(config.name, "Example")
XCTAssertEqual(config.registrations.count, 3)

XCTAssertEqual(config.registrations[0].service, "A")
XCTAssertEqual(config.registrations[0].ifConfigCondition?.description, "SOME_FLAG")

XCTAssertEqual(config.registrations[1].service, "B")
XCTAssertEqual(config.registrations[1].ifConfigCondition?.description, "SOME_FLAG && !ANOTHER_FLAG")

XCTAssertEqual(config.registrations[2].service, "C")
XCTAssertEqual(config.registrations[2].ifConfigCondition?.description, "SOME_FLAG && !ANOTHER_FLAG")
}

func testIfSimulatorParsing() throws {
let sourceFile: SourceFileSyntax = """
class ExampleAssembly: Assembly {
func assemble(container: Container) {
#if targetEnvironment(simulator)
container.autoregister(A.self, initializer: A.init)
#endif
}
}
"""

let config = try assertParsesSyntaxTree(sourceFile)
XCTAssertEqual(config.registrations.count, 1)
XCTAssertEqual(
config.registrations.first?.ifConfigCondition?.description,
"targetEnvironment(simulator)"
)
}

func testNestedIfConfig() throws {
let sourceFile: SourceFileSyntax = """
class ExampleAssembly: Assembly {
func assemble(container: Container) {
#if DEBUG
#if FEATURE
container.autoregister(A.self, initializer: A.init)
#endif
#endif
}
}
"""

// Make sure that individual registration errors are bubbled up to be printed
_ = try assertParsesSyntaxTree(
sourceFile,
assertErrorsToPrint: { errors in
XCTAssertEqual(errors.count, 1)
XCTAssertEqual(
errors.first?.localizedDescription,
"Nested #if statements are not supported"
)
}
)
}

}

private func assertParsesSyntaxTree(
Expand Down
18 changes: 18 additions & 0 deletions Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ final class TypeSafetySourceFileTests: XCTestCase {
)
}

func testRegistrationWithIfConfig() {
var registration = Registration(service: "A", accessLevel: .public)
registration.ifConfigCondition = ExprSyntax("SOME_FLAG")
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
enumName: nil
).formatted().description,
"""
#if SOME_FLAG
public func callAsFunction() -> A {
self.resolve(A.self)!
}
#endif
"""
)
}

func testArgumentNames() {
let registration1 = Registration(service: "A", accessLevel: .public, arguments: [.init(type: "String?")])
XCTAssertEqual(
Expand Down
15 changes: 15 additions & 0 deletions Tests/KnitCodeGenTests/UnitTestSourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,21 @@ final class UnitTestSourceFileTests: XCTestCase {
XCTAssertEqual(formattedResult, expected)
}

func test_registrationAssertIfConfig() {
var registration = Registration(service: "A", accessLevel: .hidden)
registration.ifConfigCondition = ExprSyntax("SOME_FLAG && !DEBUG")
let result = UnitTestSourceFile.makeAssertCall(registration: registration)
let formattedResult = result.formatted().description
XCTAssertEqual(
formattedResult,
"""
#if SOME_FLAG && !DEBUG
resolver.assertTypeResolves(A.self)
#endif
"""
)
}

func test_registrationAssertArgument() {
let registration = Registration(
service: "A",
Expand Down

0 comments on commit 5c1fed4

Please sign in to comment.