Skip to content

Commit

Permalink
Merge pull request #799 from s1ck/ir_refactoring
Browse files Browse the repository at this point in the history
Refactor okapi expressions
  • Loading branch information
Mats-SX authored Feb 14, 2019
2 parents da982a2 + ae34d9e commit 9b5e7c4
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 570 deletions.
495 changes: 57 additions & 438 deletions okapi-ir/src/main/scala/org/opencypher/okapi/ir/api/expr/Expr.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,14 @@
package org.opencypher.okapi.ir.api.pattern

import org.opencypher.okapi.api.types.CTRelationship
import org.opencypher.okapi.impl.exception.IllegalArgumentException
import org.opencypher.okapi.ir.api._
import org.opencypher.okapi.ir.api.pattern.Orientation.{Cyclic, Directed, Undirected}
import org.opencypher.v9_0.expressions.SemanticDirection
import org.opencypher.v9_0.expressions.SemanticDirection.{BOTH, INCOMING, OUTGOING}
import org.opencypher.v9_0.expressions.SemanticDirection.OUTGOING

import scala.language.higherKinds

sealed trait Connection {
type SELF[XO, XE] <: Connection { type O = XO; type E = XE }
type O <: Orientation[E]
type E <: Endpoints

Expand All @@ -46,8 +44,6 @@ sealed trait Connection {
def source: IRField
def target: IRField

def flip: SELF[O, E]

override def hashCode(): Int = orientation.hash(endpoints, seed)
override def equals(obj: scala.Any): Boolean = super.equals(obj) || (obj != null && equalsIfNotEq(obj))

Expand All @@ -56,7 +52,6 @@ sealed trait Connection {
}

sealed trait DirectedConnection extends Connection {
override type SELF[XO, XE] <: DirectedConnection { type O = XO; type E = XE }
override type O = Directed.type
override type E = DifferentEndpoints

Expand All @@ -67,7 +62,6 @@ sealed trait DirectedConnection extends Connection {
}

sealed trait UndirectedConnection extends Connection {
override type SELF[XO, XE] <: UndirectedConnection { type O = XO; type E = XE }
override type O = Undirected.type
override type E = DifferentEndpoints

Expand All @@ -78,7 +72,6 @@ sealed trait UndirectedConnection extends Connection {
}

sealed trait CyclicConnection extends Connection {
override type SELF[XO, XE] <: CyclicConnection { type O = XO; type E = XE }
override type O = Cyclic.type
override type E = IdenticalEndpoints

Expand All @@ -93,25 +86,16 @@ case object SingleRelationship {
}

sealed trait SingleRelationship extends Connection {
override type SELF[XO, XE] <: SingleRelationship { type O = XO; type E = XE }
final protected override def seed: Int = SingleRelationship.seed
}

final case class DirectedRelationship(endpoints: DifferentEndpoints, semanticDirection: SemanticDirection)
extends SingleRelationship with DirectedConnection {

override type SELF[XO, XE] = DirectedRelationship { type O = XO; type E = XE }

protected def equalsIfNotEq(obj: scala.Any): Boolean = obj match {
case other: DirectedRelationship => orientation.eqv(endpoints, other.endpoints)
case _ => false
}

override def flip: DirectedRelationship = copy(endpoints.flip, semanticDirection = semanticDirection match {
case OUTGOING => INCOMING
case INCOMING => OUTGOING
case BOTH => throw IllegalArgumentException("semantic direction to be OUTGOING or INCOMING", "BOTH")
})
}

case object DirectedRelationship {
Expand All @@ -124,14 +108,10 @@ case object DirectedRelationship {
final case class UndirectedRelationship(endpoints: DifferentEndpoints)
extends SingleRelationship with UndirectedConnection {

override type SELF[XO, XE] = UndirectedRelationship { type O = XO; type E = XE }

protected def equalsIfNotEq(obj: scala.Any): Boolean = obj match {
case other: UndirectedRelationship => orientation.eqv(endpoints, other.endpoints)
case _ => false
}

override def flip: UndirectedRelationship = copy(endpoints.flip)
}

case object UndirectedRelationship {
Expand All @@ -143,22 +123,17 @@ case object UndirectedRelationship {

final case class CyclicRelationship(endpoints: IdenticalEndpoints) extends SingleRelationship with CyclicConnection {

override type SELF[XO, XE] = CyclicRelationship { type O = XO; type E = XE }

protected def equalsIfNotEq(obj: scala.Any): Boolean = obj match {
case other: CyclicRelationship => orientation.eqv(endpoints, other.endpoints)
case _ => false
}

override def flip: CyclicRelationship = this
}

object VarLengthRelationship {
val seed: Int = "VarLengthRelationship".hashCode
}

sealed trait VarLengthRelationship extends Connection {
override type SELF[XO, XE] <: VarLengthRelationship { type O = XO; type E = XE }
final protected override def seed: Int = VarLengthRelationship.seed

def lower: Int
Expand All @@ -173,13 +148,6 @@ final case class DirectedVarLengthRelationship(
upper: Option[Int],
semanticDirection: SemanticDirection = OUTGOING
) extends VarLengthRelationship with DirectedConnection {
override type SELF[XO, XE] = DirectedVarLengthRelationship {type O = XO; type E = XE}

override def flip: DirectedVarLengthRelationship = copy(endpoints = endpoints.flip, semanticDirection = semanticDirection match {
case OUTGOING => INCOMING
case INCOMING => OUTGOING
case BOTH => throw IllegalArgumentException("semantic direction to be OUTGOING or INCOMING", "BOTH")
})

override protected def equalsIfNotEq(obj: Any): Boolean = obj match {
case other: DirectedVarLengthRelationship => orientation.eqv(endpoints, other.endpoints)
Expand All @@ -188,9 +156,6 @@ final case class DirectedVarLengthRelationship(
}

final case class UndirectedVarLengthRelationship(edgeType: CTRelationship, endpoints: DifferentEndpoints, lower: Int, upper: Option[Int]) extends VarLengthRelationship with UndirectedConnection {
override type SELF[XO, XE] = UndirectedVarLengthRelationship { type O = XO; type E = XE }

override def flip: UndirectedVarLengthRelationship = this

override protected def equalsIfNotEq(obj: Any): Boolean = obj match {
case other: UndirectedVarLengthRelationship => orientation.eqv(endpoints, other.endpoints)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ final case class IRBuilderContext(

// TODO: Fuse monads
def infer(expr: ast.Expression): Map[Ref[ast.Expression], CypherType] = {
typer.infer(expr, TypeTracker(List(knownTypes), parameters.value)) match {
typer.infer(expr, TypeTracker(knownTypes, parameters.value)) match {
case Right(result) =>
result.recorder.toMap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,7 @@ object SchemaTyper {

case Ands(exprs) => processAndsOrs(expr, exprs.toVector)

case Ors(exprs) =>
for {
t1 <- get[R, TypeTracker]
t2 <- put[R, TypeTracker](t1.pushScope()) >> get[R, TypeTracker]
result <- processAndsOrs(expr, exprs.toVector)
_ <- t2.popScope() match {
case None => error(TypeTrackerScopeError) >> put[R, TypeTracker](TypeTracker.empty)
case Some(t) =>
put[R, TypeTracker](t)
}
} yield result
case Ors(exprs) => processAndsOrs(expr, exprs.toVector)

case Equals(lhs, rhs) =>
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,16 @@ import org.opencypher.v9_0.expressions.Expression
import scala.annotation.tailrec

object TypeTracker {
val empty = TypeTracker(List.empty)
val empty = TypeTracker(Map.empty)
}

case class TypeTracker(maps: List[Map[Expression, CypherType]], parameters: Map[String, CypherValue] = Map.empty) {
case class TypeTracker(map: Map[Expression, CypherType], parameters: Map[String, CypherValue] = Map.empty) {

def withParameters(newParameters: Map[String, CypherValue]): TypeTracker =
copy(parameters = newParameters)
def withParameters(newParameters: Map[String, CypherValue]): TypeTracker = copy(parameters = newParameters)

def get(e: Expression): Option[CypherType] = get(e, maps)
def get(e: Expression): Option[CypherType] = map.get(e)

def getParameterType(e: String): Option[CypherType] = parameters.get(e).map(_.cypherType)

@tailrec
private def get(e: Expression, maps: List[Map[Expression, CypherType]]): Option[CypherType] = maps.headOption match {
case None => None
case Some(map) if map.contains(e) => map.get(e)
case Some(_) => get(e, maps.tail)
}

def updated(e: Expression, t: CypherType): TypeTracker = copy(maps = head.updated(e, t) +: tail)
def updated(entry: (Expression, CypherType)): TypeTracker = updated(entry._1, entry._2)
def pushScope(): TypeTracker = copy(maps = Map.empty[Expression, CypherType] +: maps)
def popScope(): Option[TypeTracker] = if (maps.isEmpty) None else Some(copy(maps = maps.tail))

private def head: Map[Expression, CypherType] =
maps.headOption.getOrElse(Map.empty[Expression, CypherType])
private def tail: List[Map[Expression, CypherType]] =
if (maps.isEmpty) List.empty else maps.tail
def updated(e: Expression, t: CypherType): TypeTracker = copy(map = map.updated(e, t))
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,4 @@ class ExprTest extends BaseTestSuite {
(n as Var("m")()).cypherType should equal(n.cypherType)
}

test("set new cypher type") {
val n = Var("n")(CTNode("A"))
val updatedN = n.withCypherType(CTNode("B").nullable)
updatedN should equal(n)
updatedN.cypherType should equal(CTNode("B").nullable)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ package org.opencypher.okapi.ir.api.pattern
import org.opencypher.okapi.api.types.CTRelationship
import org.opencypher.okapi.ir.api.IRField
import org.opencypher.okapi.testing.BaseTestSuite
import org.opencypher.v9_0.expressions.SemanticDirection.{INCOMING, OUTGOING}
import org.opencypher.v9_0.expressions.SemanticDirection.OUTGOING

class ConnectionTest extends BaseTestSuite {

Expand All @@ -39,43 +39,13 @@ class ConnectionTest extends BaseTestSuite {

val relType = CTRelationship("FOO")

test("SimpleConnection.flip") {
DirectedRelationship(field_a, field_b).flip should equal(DirectedRelationship(field_b, field_a))
DirectedRelationship(field_a, field_a).flip should equal(DirectedRelationship(field_a, field_a))
}

test("SimpleConnection.flip with semantic direction") {
DirectedRelationship(field_a, field_b, OUTGOING).flip should equal(DirectedRelationship(field_b, field_a, INCOMING))
DirectedRelationship(field_a, field_a, INCOMING).flip should equal(DirectedRelationship(field_a, field_a, OUTGOING))
}

test("SimpleConnection.equals") {
DirectedRelationship(field_a, field_b) shouldNot equal(DirectedRelationship(field_b, field_a))
DirectedRelationship(field_a, field_a) should equal(DirectedRelationship(field_a, field_a))
DirectedRelationship(field_a, field_a, OUTGOING) should equal(DirectedRelationship(field_a, field_a, OUTGOING))
DirectedRelationship(field_a, field_a) shouldNot equal(DirectedRelationship(field_a, field_b))
}

test("VarLengthRelationship.flip") {
DirectedVarLengthRelationship(relType, field_a -> field_b, 0, Some(0)).flip should equal(
DirectedVarLengthRelationship(relType, field_b -> field_a, 0, Some(0)))

DirectedVarLengthRelationship(relType, field_a -> field_a, 0, Some(0)).flip should equal(
DirectedVarLengthRelationship(relType, field_a -> field_a, 0, Some(0)))
}

test("VarLengthRelationship.flip with semantic direction") {
DirectedVarLengthRelationship(relType, field_a -> field_b, 0, Some(0), OUTGOING).flip should equal(
DirectedVarLengthRelationship(relType, field_b -> field_a, 0, Some(0), INCOMING))

DirectedVarLengthRelationship(relType, field_a -> field_a, 0, Some(0), INCOMING).flip should equal(
DirectedVarLengthRelationship(relType, field_a -> field_a, 0, Some(0), OUTGOING))
}

test("UndirectedConnection.flip") {
UndirectedRelationship(field_a, field_b).flip should equal(UndirectedRelationship(field_b, field_a))
}

test("UndirectedConnection.equals") {
UndirectedRelationship(field_a, field_b) should equal(UndirectedRelationship(field_b, field_a))
UndirectedRelationship(field_c, field_c) should equal(UndirectedRelationship(field_c, field_c))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,6 @@ class SchemaTyperTest extends BaseTestSuite with Neo4jAstTestSupport with Mockit

assertExpr.from("n:Person") shouldMake varFor("n") haveType CTNode("Person")
assertExpr.from("n:Person AND n:Dog") shouldMake varFor("n") haveType CTNode("Person", "Dog")

assertExpr.from("n:Person OR n:Dog") shouldMake varFor("n") haveType CTNode // not enough information for us to act
}

it("typing equality") {
Expand Down Expand Up @@ -604,7 +602,7 @@ class SchemaTyperTest extends BaseTestSuite with Neo4jAstTestSupport with Mockit
}

private def typeTracker(typings: (String, CypherType)*): TypeTracker =
TypeTracker(List(typings.map { case (v, t) => varFor(v) -> t }.toMap))
TypeTracker(typings.map { case (v, t) => varFor(v) -> t }.toMap)

private object assertExpr {
def from(exprText: String)(implicit tracker: TypeTracker = TypeTracker.empty) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,4 @@ class TypeTrackerTest extends BaseTestSuite with AstConstructionTestSupport {
tracker.get(True()(pos)) shouldBe Some(CTString)
}

test("push scope and lookup") {
val tracker = TypeTracker.empty.updated(True()(pos), CTString).pushScope()

tracker.get(True()(pos)) shouldBe Some(CTString)
}

test("pushing and popping scope") {
val tracker1 = TypeTracker.empty.updated(True()(pos), CTString)

val tracker2 = tracker1.pushScope().updated(False()(pos), CTBoolean).popScope()

tracker1 should equal(tracker2.get)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,24 @@ case class RecordHeader(exprToColumn: Map[Expr, String]) {
}

def ++(other: RecordHeader): RecordHeader = {
val joined = exprToColumn ++ other.exprToColumn
val withJoinedCypherTypes: Map[Expr, String] = joined.map {
val result = (exprToColumn ++ other.exprToColumn).map {
case (key, value) =>
val leftCT = exprToColumn.keySet.find(_ == key).map(_.cypherType).getOrElse(CTVoid)
val rightCT = other.exprToColumn.keySet.find(_ == key).map(_.cypherType).getOrElse(CTVoid)
key.withCypherType(leftCT join rightCT) -> value
}

copy(exprToColumn = withJoinedCypherTypes)
val resultExpr = (key, leftCT, rightCT) match {
case (v: Var, l: CTNode, r: CTNode) => Var(v.name)(l.join(r))
case (v: Var, l: CTRelationship, r: CTRelationship) => Var(v.name)(l.join(r))
case (_, l, r) if l.subTypeOf(r).isTrue => other.exprToColumn.keySet.collectFirst { case k if k == key => k }.get
case (_, l, r) if r.subTypeOf(l).isTrue => key
case _ => throw IllegalArgumentException(
expected = s"Compatible Cypher types for expression $key",
actual = s"left type `$leftCT` and right type `$rightCT`"
)
}
resultExpr -> value
}
copy(exprToColumn = result)
}

def --[T <: Expr](expressions: Set[T]): RecordHeader = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,13 @@ class RecordHeaderTest extends BaseTestSuite {
unionHeader.ownedBy(o) should equalWithTracing(oExprs)
}

it("can combine headers with same vars but different cypherType") {
it("can combine headers with same vars and compatible cypherType") {
Seq(
CTBoolean -> CTBoolean,
CTInteger -> CTString,
CTNode("A") -> CTNode("B"),
CTRelationship("A") -> CTRelationship("B")
CTNode("A") -> CTNode("A", "B"),
CTRelationship("A") -> CTRelationship("B"),
CTRelationship("A") -> CTRelationship("A", "B")
).foreach {
case (p1Type, p2Type) =>
val p1 = Var("p")(p1Type)
Expand Down

0 comments on commit 9b5e7c4

Please sign in to comment.