Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parsing of enum key: allow absolute paths and fix value instances #310

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
891 changes: 891 additions & 0 deletions jvm/src/test/scala/io/kaitai/struct/ClassTypeProvider$Test.scala

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions jvm/src/test/scala/io/kaitai/struct/exprlang/EnumRefSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.kaitai.struct.exprlang

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers._

class EnumRefSpec extends AnyFunSpec {
describe("Expressions.parseEnumRef") {
describe("parses local enum refs") {
it("some_enum") {
Expressions.parseEnumRef("some_enum") should be(Ast.EnumRef(
false, Seq(), "some_enum"
))
}
it("with spaces: ' some_enum '") {
Expressions.parseEnumRef(" some_enum ") should be(Ast.EnumRef(
false, Seq(), "some_enum"
))
}

it("::some_enum") {
Expressions.parseEnumRef("::some_enum") should be(Ast.EnumRef(
true, Seq(), "some_enum"
))
}
it("with spaces: ' :: some_enum '") {
Expressions.parseEnumRef(" :: some_enum ") should be(Ast.EnumRef(
true, Seq(), "some_enum"
))
}
}

describe("parses path enum refs") {
it("some::enum") {
Expressions.parseEnumRef("some::enum") should be(Ast.EnumRef(
false, Seq("some"), "enum"
))
}
it("with spaces: ' some :: enum '") {
Expressions.parseEnumRef(" some :: enum ") should be(Ast.EnumRef(
false, Seq("some"), "enum"
))
}

it("::some::enum") {
Expressions.parseEnumRef("::some::enum") should be(Ast.EnumRef(
true, Seq("some"), "enum"
))
}
it("with spaces: ' :: some :: enum '") {
Expressions.parseEnumRef(" :: some :: enum ") should be(Ast.EnumRef(
true, Seq("some"), "enum"
))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,35 +133,32 @@ class ExpressionsSpec extends AnyFunSpec {

// Enums
it("parses port::http") {
Expressions.parse("port::http") should be (EnumByLabel(identifier("port"), identifier("http")))
Expressions.parse("port::http") should be (EnumByLabel(EnumRef(false, Seq(), "port"), identifier("http")))
}

it("parses some_type::port::http") {
Expressions.parse("some_type::port::http") should be (
EnumByLabel(
identifier("port"),
EnumRef(false, Seq("some_type"), "port"),
identifier("http"),
typeId(absolute = false, Seq("some_type"))
)
)
}

it("parses parent_type::child_type::port::http") {
Expressions.parse("parent_type::child_type::port::http") should be (
EnumByLabel(
identifier("port"),
EnumRef(false, Seq("parent_type", "child_type"), "port"),
identifier("http"),
typeId(absolute = false, Seq("parent_type", "child_type"))
)
)
}

it("parses ::parent_type::child_type::port::http") {
Expressions.parse("::parent_type::child_type::port::http") should be (
EnumByLabel(
identifier("port"),
EnumRef(true, Seq("parent_type", "child_type"), "port"),
identifier("http"),
typeId(absolute = true, Seq("parent_type", "child_type"))
)
)
}
Expand All @@ -171,7 +168,7 @@ class ExpressionsSpec extends AnyFunSpec {
Compare(
BinOp(
Attribute(
EnumByLabel(identifier("port"),identifier("http")),
EnumByLabel(EnumRef(false, Seq(), "port"), identifier("http")),
identifier("to_i")
),
Add,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object TestTypeProviders {
abstract class FakeTypeProvider extends TypeProvider {
val nowClass = ClassSpec.opaquePlaceholder(List("top_class"))

override def resolveEnum(inType: Ast.typeId, enumName: String) =
override def resolveEnum(ref: Ast.EnumRef) =
throw new NotImplementedError

override def resolveType(typeName: Ast.typeId): DataType = {
Expand Down
57 changes: 35 additions & 22 deletions shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kaitai.struct.datatype.DataType
import io.kaitai.struct.datatype.DataType._
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.format._
import io.kaitai.struct.precompile.{EnumNotFoundError, FieldNotFoundError, TypeNotFoundError, TypeUndecidedError}
import io.kaitai.struct.precompile.{EnumNotFoundInHierarchyError, EnumNotFoundInTypeError, FieldNotFoundError, TypeNotFoundInHierarchyError, TypeNotFoundInTypeError, TypeUndecidedError}
import io.kaitai.struct.translators.TypeProvider

class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends TypeProvider {
Expand Down Expand Up @@ -90,46 +90,59 @@ class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends
throw new FieldNotFoundError(attrName, inClass)
}

override def resolveEnum(inType: Ast.typeId, enumName: String): EnumSpec =
resolveEnum(resolveClassSpec(inType), enumName)
override def resolveEnum(ref: Ast.EnumRef): EnumSpec = {
val inClass = if (ref.absolute) topClass else nowClass
// When concrete type is not defined, search enum definition in all enclosing types
if (ref.typePath.isEmpty) {
resolveEnumName(inClass, ref.name)
} else {
val ty = resolveTypePath(inClass, ref.typePath)
ty.enums.get(ref.name) match {
case Some(spec) =>
spec
case None =>
throw new EnumNotFoundInTypeError(ref.name, ty)
}
}
}

def resolveEnum(inClass: ClassSpec, enumName: String): EnumSpec = {
private def resolveEnumName(inClass: ClassSpec, enumName: String): EnumSpec = {
inClass.enums.get(enumName) match {
case Some(spec) =>
spec
case None =>
// let's try upper levels of hierarchy
inClass.upClass match {
case Some(upClass) => resolveEnum(upClass, enumName)
case Some(upClass) => resolveEnumName(upClass, enumName)
case None =>
throw new EnumNotFoundError(enumName, nowClass)
throw new EnumNotFoundInHierarchyError(enumName, nowClass)
}
}
}

override def resolveType(typeName: Ast.typeId): DataType =
resolveClassSpec(typeName).toDataType

def resolveClassSpec(typeName: Ast.typeId): ClassSpec =
resolveClassSpec(
resolveTypePath(
if (typeName.absolute) topClass else nowClass,
typeName.names
)
).toDataType

def resolveClassSpec(inClass: ClassSpec, typeName: Seq[String]): ClassSpec = {
if (typeName.isEmpty)
def resolveTypePath(inClass: ClassSpec, path: Seq[String]): ClassSpec = {
if (path.isEmpty)
return inClass

val headTypeName :: restTypesNames = typeName.toList
val nextClass = resolveClassSpec(inClass, headTypeName)
if (restTypesNames.isEmpty) {
nextClass
} else {
resolveClassSpec(nextClass, restTypesNames)
val headTypeName :: restTypesNames = path.toList
var nextClass = resolveTypeName(inClass, headTypeName)
for (name <- restTypesNames) {
nextClass = nextClass.types.get(name) match {
case Some(spec) => spec
case None =>
throw new TypeNotFoundInTypeError(name, nextClass)
}
}
nextClass
}

def resolveClassSpec(inClass: ClassSpec, typeName: String): ClassSpec = {
def resolveTypeName(inClass: ClassSpec, typeName: String): ClassSpec = {
if (inClass.name.last == typeName)
return inClass

Expand All @@ -139,12 +152,12 @@ class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends
case None =>
// let's try upper levels of hierarchy
inClass.upClass match {
case Some(upClass) => resolveClassSpec(upClass, typeName)
case Some(upClass) => resolveTypeName(upClass, typeName)
case None =>
classSpecs.get(typeName) match {
case Some(spec) => spec
case None =>
throw new TypeNotFoundError(typeName, nowClass)
throw new TypeNotFoundInHierarchyError(typeName, nowClass)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends
List()
case _: Ast.expr.EnumByLabel =>
List()
case Ast.expr.EnumById(_, id, _) =>
affectedVars(id)
case Ast.expr.EnumById(_, expr) =>
affectedVars(expr)
case Ast.expr.Attribute(value, attr) =>
if (attr.name == Identifier.SIZEOF) {
val vars = value match {
Expand Down Expand Up @@ -484,7 +484,7 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
): LanguageCompiler = ???

def type2class(name: List[String]) = name.last
def type2display(name: List[String]) = name.map(Utils.upperCamelCase).mkString("::")
def type2display(name: Seq[String]) = name.map(Utils.upperCamelCase).mkString("::")

def dataTypeName(dataType: DataType, valid: Option[ValidationSpec]): String = {
dataType match {
Expand All @@ -508,7 +508,7 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
val comma = if (bytesStr.isEmpty) "" else ", "
s"str($bytesStr$comma$encoding)"
case EnumType(name, basedOn) =>
s"${dataTypeName(basedOn, valid)}→${type2display(name)}"
s"${dataTypeName(basedOn, valid)}→${type2display(name.fullName)}"
case BitsType(width, bitEndian) => s"b$width${bitEndian.toSuffix}"
case BitsType1(bitEndian) => s"b1${bitEndian.toSuffix}→bool"
case _ => dataType.toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ object DataType {
def isOwning = false
}

case class EnumType(name: List[String], basedOn: IntType) extends DataType {
case class EnumType(name: Ast.EnumRef, basedOn: IntType) extends DataType {
var enumSpec: Option[EnumSpec] = None

/**
Expand Down Expand Up @@ -487,7 +487,7 @@ object DataType {
enumRef match {
case Some(enumName) =>
r match {
case numType: IntType => EnumType(classNameToList(enumName), numType)
case numType: IntType => EnumType(Expressions.parseEnumRef(enumName), numType)
case _ =>
throw KSYParseError(s"tried to resolve non-integer $r to enum", path).toException
}
Expand Down
19 changes: 17 additions & 2 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ object Ast {
case class FloatNum(n: BigDecimal) extends expr
case class Str(s: String) extends expr
case class Bool(n: Boolean) extends expr
case class EnumByLabel(enumName: identifier, label: identifier, inType: typeId = EmptyTypeId) extends expr
case class EnumById(enumName: identifier, id: expr, inType: typeId = EmptyTypeId) extends expr
/** Take named enumeration constant from the specified enumeration. */
case class EnumByLabel(ref: EnumRef, label: identifier) extends expr
/** Cast specified expression to the enumerated type. Used only by value instances with `enum` key. */
case class EnumById(ref: EnumRef, expr: expr) extends expr

case class Attribute(value: expr, attr: identifier) extends expr
case class CastToType(value: expr, typeName: typeId) extends expr
Expand Down Expand Up @@ -141,5 +143,18 @@ object Ast {
case object GtE extends cmpop
}

/**
* Reference to an enum in scope. Scope is defined by the `absolute` flag and
* a path to a type (which can be empty) in which enum is defined.
*/
case class EnumRef(absolute: Boolean, typePath: Seq[String], name: String) {
/** @return Type path and name of enum in one list. */
def fullName: Seq[String] = typePath :+ name
/**
* @return Enum designation name as human-readable string, to be used in compiler
* error messages.
*/
def asStr: String = fullName.mkString(if (absolute) "::" else "", "::", "")
}
case class TypeWithArguments(typeName: typeId, arguments: expr.List)
}
26 changes: 19 additions & 7 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,11 @@ object Expressions {
case (first, names: Seq[Ast.identifier]) =>
val isAbsolute = first.nonEmpty
val (enumName, enumLabel) = names.takeRight(2) match {
case Seq(a, b) => (a, b)
}
val typePath = names.dropRight(2)
if (typePath.isEmpty) {
Ast.expr.EnumByLabel(enumName, enumLabel, Ast.EmptyTypeId)
} else {
Ast.expr.EnumByLabel(enumName, enumLabel, Ast.typeId(isAbsolute, typePath.map(_.name)))
case Seq(a, b) => (a.name, b)
}
val typePath = names.dropRight(2).map(n => n.name)
val ref = Ast.EnumRef(isAbsolute, typePath, enumName)
Ast.expr.EnumByLabel(ref, enumLabel)
}

def byteSizeOfType[$: P]: P[Ast.expr.ByteSizeOfType] =
Expand All @@ -195,6 +192,13 @@ object Expressions {
case (path, Some(args)) => Ast.TypeWithArguments(path, args)
}

def enumRef[$: P]: P[Ast.EnumRef] = P(Start ~ "::".!.? ~ NAME.rep(1, "::") ~ End).map {
case (absolute, names) =>
// List have at least one element, so we always can split it into head and the last element
val typePath :+ enumName = names
Ast.EnumRef(absolute.nonEmpty, typePath.map(i => i.name), enumName.name)
}

class ParseException(val src: String, val failure: Parsed.Failure)
extends RuntimeException(failure.msg)

Expand All @@ -211,6 +215,14 @@ object Expressions {
*/
def parseTypeRef(src: String): Ast.TypeWithArguments = realParse(src, typeRef(_))

/**
* Parse string with reference to enumeration definition, optionally in full path format.
*
* @param src Enum reference as string, like `::path::to::enum`
* @return Object that represents path to enum
*/
def parseEnumRef(src: String): Ast.EnumRef = realParse(src, enumRef(_))

private def realParse[T](src: String, parser: P[_] => P[T]): T = {
val r = fastparse.parse(src.trim, parser)
r match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ object InstanceSpec {
case None =>
value
case Some(enumName) =>
Ast.expr.EnumById(Ast.identifier(enumName), value)
Ast.expr.EnumById(Expressions.parseEnumRef(enumName), value)
}

val ifExpr = ParseUtils.getOptValueExpression(srcMap, "if", path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ object CSharpCompiler extends LanguageCompilerStatic
case KaitaiStreamType | OwnedKaitaiStreamType => kstreamName

case t: UserType => types2class(t.name)
case EnumType(name, _) => types2class(name)
case EnumType(ref, _) => types2class(ref.fullName)

case at: ArrayType => {
importList.add("System.Collections.Generic")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ object CppCompiler extends LanguageCompilerStatic
types2class(if (absolute) {
t.enumSpec.get.name
} else {
t.name
t.name.fullName
})

case at: ArrayType => {
Expand Down Expand Up @@ -1210,7 +1210,7 @@ object CppCompiler extends LanguageCompilerStatic
)
}

def types2class(components: List[String]) =
def types2class(components: Seq[String]) =
components.map(type2class).mkString("::")

def type2class(name: String) = Utils.lowerUnderscoreCase(name) + "_t"
Expand Down
Loading
Loading