Skip to content

Commit

Permalink
Parse the assembly type based on the inheritance clause
Browse files Browse the repository at this point in the history
  • Loading branch information
skorulis-ap committed Feb 13, 2024
1 parent 50fcd9f commit 38e8fb8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
8 changes: 8 additions & 0 deletions Sources/Knit/Module/AbstractAssembly.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//
// Copyright © Block, Inc. All rights reserved.
//

import Foundation

/// An AbstractAssembly can only contain abstract registrations and should not be initialised.
protocol AbstractAssembly: ModuleAssembly { }
20 changes: 17 additions & 3 deletions Sources/KnitCodeGen/AssemblyParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func parseSyntaxTree(
throw AssemblyParsingError.missingModuleName
}

guard let assemblyType = assemblyFileVisitor.assemblyType else {
throw AssemblyParsingError.missingAssemblyType
}

errorsToPrint.append(contentsOf: assemblyFileVisitor.assemblyErrors)
errorsToPrint.append(contentsOf: assemblyFileVisitor.registrationErrors)

Expand All @@ -63,6 +67,7 @@ func parseSyntaxTree(

return Configuration(
name: name,
assemblyType: assemblyType,
registrations: assemblyFileVisitor.registrations,
registrationsIntoCollections: assemblyFileVisitor.registrationsIntoCollections,
imports: assemblyFileVisitor.imports,
Expand All @@ -77,6 +82,8 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {

private(set) var moduleName: String?

private(set) var assemblyType: String?

private var classDeclVisitor: ClassDeclVisitor?

private(set) var assemblyErrors: [Error] = []
Expand Down Expand Up @@ -105,11 +112,11 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
}

override func visit(_ node: StructDeclSyntax) -> SyntaxVisitorContinueKind {
return visitAssemblyType(node)
return visitAssemblyType(node, node.inheritanceClause)
}

override func visit(_ node: ClassDeclSyntax) -> SyntaxVisitorContinueKind {
return visitAssemblyType(node)
return visitAssemblyType(node, node.inheritanceClause)
}

override func visit(_ node: ImportDeclSyntax) -> SyntaxVisitorContinueKind {
Expand All @@ -130,7 +137,7 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
return self.visitIfNode(node)
}

private func visitAssemblyType(_ node: NamedDeclSyntax) -> SyntaxVisitorContinueKind {
private func visitAssemblyType(_ node: NamedDeclSyntax, _ inheritance: InheritanceClauseSyntax?) -> SyntaxVisitorContinueKind {
guard classDeclVisitor == nil else {
// Only the first class declaration should be visited
return .skipChildren
Expand All @@ -143,6 +150,10 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
}

moduleName = node.moduleNameForAssembly
let inheritedTypes = inheritance?.inheritedTypes.map {
$0.type.description.trimmingCharacters(in: .whitespaces)
}
self.assemblyType = inheritedTypes?.first(where: { $0.hasSuffix("Assembly")})
classDeclVisitor = ClassDeclVisitor(viewMode: .fixedUp, directives: directives)
classDeclVisitor?.walk(node)
return .skipChildren
Expand Down Expand Up @@ -233,6 +244,7 @@ extension NamedDeclSyntax {
enum AssemblyParsingError: Error {
case fileReadError(Error, path: String)
case missingModuleName
case missingAssemblyType
case parsingError
}

Expand All @@ -251,6 +263,8 @@ extension AssemblyParsingError: LocalizedError {
"Is your Assembly file setup correctly?"
case .parsingError:
return "There were one or more errors parsing the assembly file"
case .missingAssemblyType:
return "Assembly files must inherit from an *Assembly type"
}
}

Expand Down
8 changes: 4 additions & 4 deletions Sources/KnitCodeGen/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public struct Configuration: Encodable {

/// Name of the module for this configuration.
public var name: String
public var assemblyType: String

public var registrations: [Registration]
public var registrationsIntoCollections: [RegistrationIntoCollection]
Expand All @@ -18,12 +19,14 @@ public struct Configuration: Encodable {

public init(
name: String,
assemblyType: String = "Assembly",
registrations: [Registration],
registrationsIntoCollections: [RegistrationIntoCollection],
imports: [ModuleImport] = [],
targetResolver: String
) {
self.name = name
self.assemblyType = assemblyType
self.registrations = registrations
self.registrationsIntoCollections = registrationsIntoCollections
self.imports = imports
Expand All @@ -32,6 +35,7 @@ public struct Configuration: Encodable {

public enum CodingKeys: CodingKey {
case name
case assemblyType
case registrations
}

Expand All @@ -48,10 +52,6 @@ public extension Configuration {
}

func makeUnitTestSourceFile() throws -> SourceFileSyntax {
var allImports = imports
allImports.append(try .testable(name: name))
allImports.append(try .named("XCTest"))

return try UnitTestSourceFile.make(
configuration: self
)
Expand Down
4 changes: 3 additions & 1 deletion Tests/KnitCodeGenTests/AssemblyParsingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ final class AssemblyParsingTests: XCTestCase {
let sourceFile: SourceFileSyntax = """
import A
import B // Comment after import should be stripped
class FooTestAssembly: Assembly { }
class FooTestAssembly: ModuleAssembly { }
"""

let config = try assertParsesSyntaxTree(sourceFile)
Expand All @@ -25,6 +25,7 @@ final class AssemblyParsingTests: XCTestCase {
]
)
XCTAssertEqual(config.registrations.count, 0, "No registrations")
XCTAssertEqual(config.assemblyType, "ModuleAssembly")
}

func testDebugWrappedAssemblyImports() throws {
Expand Down Expand Up @@ -108,6 +109,7 @@ final class AssemblyParsingTests: XCTestCase {

let config = try assertParsesSyntaxTree(sourceFile)
XCTAssertEqual(config.name, "FooTest")
XCTAssertEqual(config.assemblyType, "Assembly")
}

func testAssemblyStructModuleName() throws {
Expand Down

0 comments on commit 38e8fb8

Please sign in to comment.