Skip to content

Commit

Permalink
Get nested field for struct column
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Aug 7, 2021
1 parent 45d7a60 commit bcd6432
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dataset/src/main/scala/frameless/RecordEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object NewInstanceExprs {
Literal.fromObject(()) +: tail.from(exprs)
}

implicit def deriveNonUnit[K <: Symbol, V , T <: HList]
implicit def deriveNonUnit[K <: Symbol, V, T <: HList]
(implicit
notUnit: V =:!= Unit,
tail: NewInstanceExprs[T]
Expand Down
13 changes: 13 additions & 0 deletions dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ sealed class TypedColumn[T, U](expr: Expression)(
}

override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn

override def lit[U1: TypedEncoder](c: U1): TypedColumn[T,U1] = flit(c)
}

Expand Down Expand Up @@ -828,6 +829,18 @@ abstract class AbstractTypedColumn[T, U]
w1: With.Aux[TT2, W1, W2]
): ThisType[W2, Boolean] =
typed(self.untyped.between(lowerBound.untyped, upperBound.untyped))

/**
* Returns a nested column matching the field `symbol`.
*
* @param V the type of the nested field
*/
def field[V](symbol: Witness.Lt[Symbol])(implicit
i0: TypedColumn.Exists[U, symbol.T, V],
i1: TypedEncoder[V]
): ThisType[T, V] =
typed(self.untyped.getField(symbol.value.name))

}


Expand Down
13 changes: 13 additions & 0 deletions dataset/src/test/scala/frameless/ColumnTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -425,4 +425,17 @@ class ColumnTests extends TypedDatasetSuite with Matchers {
"ds.select(ds('_1).opt.map(x => x))" shouldNot typeCheck
"ds.select(ds('_2).opt.map(x => x))" shouldNot typeCheck
}

test("field") {
val ds = TypedDataset.create((1, (2.3F, "a")) :: Nil)
val rs = ds.select(ds('_2).field('_2)).collect().run()

rs shouldEqual Seq("a")
}

test("field compiles only for valid field") {
val ds = TypedDataset.create((1, (2.3F, "a")) :: Nil)

"ds.select(ds('_2).field('_3))" shouldNot typeCheck
}
}

0 comments on commit bcd6432

Please sign in to comment.