Skip to content

Commit

Permalink
#36 Reorder nested and filter aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Nov 10, 2017
1 parent 4c05ea5 commit 3b4eb7f
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions app/org/elastic4play/services/Aggregations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import scala.collection.JavaConverters._
import play.api.libs.json.{ JsArray, JsNumber, JsObject }

import com.sksamuel.elastic4s.ElasticDsl.{ avgAggregation, dateHistogramAggregation, filterAggregation, matchAllQuery, maxAggregation, minAggregation, nestedAggregation, sumAggregation, termsAggregation, topHitsAggregation }
import com.sksamuel.elastic4s.script.ScriptDefinition
import com.sksamuel.elastic4s.searches.RichSearchHit
import com.sksamuel.elastic4s.searches.aggs._
import com.sksamuel.elastic4s.searches.queries.QueryDefinition
import org.elasticsearch.search.aggregations.bucket.filter.Filter
import org.elasticsearch.search.aggregations.bucket.filters.Filters
import org.elasticsearch.search.aggregations.bucket.histogram.{ DateHistogramInterval, Histogram }
Expand All @@ -32,19 +32,7 @@ abstract class Agg(val aggregationName: String) {
def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject
}

trait AggQuery { _: Agg
def filteredAgg(agg: AggregationDefinition, query: Option[QueryDef]): AggregationDefinition = {
query match {
case None agg
case Some(q) filterAggregation(agg.name).query(q.query).subAggregations(agg)
}
}
def filteredResult(aggregations: RichAggregations, query: Option[QueryDef]): RichAggregations = {
query.fold(aggregations)(_ RichAggregations(aggregations.aggregations.get[Filter](aggregationName).getAggregations))
}
}

abstract class FieldAgg(val fieldName: String, aggregationName: String) extends Agg(aggregationName) with AggQuery {
abstract class FieldAgg(val fieldName: String, aggregationName: String, query: Option[QueryDef]) extends Agg(aggregationName) {
def script(s: String): AggregationDefinition

def field(f: String): AggregationDefinition
Expand All @@ -59,36 +47,45 @@ abstract class FieldAgg(val fieldName: String, aggregationName: String) extends
}

def getAggregation(fieldName: String, aggregations: RichAggregations, query: Option[QueryDef]): RichAggregations = {
val agg = if (fieldName.startsWith("computed")) aggregations

val agg = query match {
case None aggregations
case _ RichAggregations(aggregations.aggregations.get[Filter](aggregationName).getAggregations)
}

if (fieldName.startsWith("computed")) agg
else {
fieldName.split("\\.").init.foldLeft(aggregations) { (agg, _)
RichAggregations(agg.getAs[Nested](aggregationName).getAggregations)
fieldName.split("\\.").init.foldLeft(aggregations) { (a, _)
RichAggregations(a.getAs[Nested](aggregationName).getAggregations)
}
}
filteredResult(agg, query)
}

def apply(model: BaseModelDef): Seq[AggregationDefinition] = {
fieldName.split("\\.") match {
val aggs = fieldName.split("\\.") match {
case Array("computed", c)
val s = model.computedMetrics.getOrElse(
c,
throw BadRequestError(s"Field $fieldName is unknown in ${model.name}"))
Seq(script(s))
case array
if (!model.attributes.exists(_.name == array(0))) {
if (array(0) != "" && !model.attributes.exists(_.name == array(0))) {
throw BadRequestError(s"Field $fieldName is unknown in ${model.name}")
}
// TODO check attribute type
nested(fieldName, Seq(field(fieldName)))
}
query match {
case None aggs
case Some(q) Seq(filterAggregation(aggregationName).query(q.query).subAggregations(aggs))
}
}
}

class SelectAvg(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName) with AggQuery {
def script(s: String): AggregationDefinition = filteredAgg(avgAggregation(aggregationName).script(s), query)
class SelectAvg(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName, query) {
def script(s: String): AggregationDefinition = avgAggregation(aggregationName).script(ScriptDefinition(s).lang("groovy"))

def field(f: String): AggregationDefinition = filteredAgg(avgAggregation(aggregationName).field(f), query)
def field(f: String): AggregationDefinition = avgAggregation(aggregationName).field(f)

def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject = {
val avg = getAggregation(fieldName, aggregations, query).getAs[Avg](aggregationName)
Expand All @@ -97,10 +94,10 @@ class SelectAvg(aggregationName: String, fieldName: String, query: Option[QueryD
}
}

class SelectMin(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName) with AggQuery {
def script(s: String): AggregationDefinition = filteredAgg(minAggregation(aggregationName).script(s), query)
class SelectMin(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName, query) {
def script(s: String): AggregationDefinition = minAggregation(aggregationName).script(ScriptDefinition(s).lang("groovy"))

def field(f: String): AggregationDefinition = filteredAgg(minAggregation(aggregationName).field(f), query)
def field(f: String): AggregationDefinition = minAggregation(aggregationName).field(f)

def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject = {
val min = getAggregation(fieldName, aggregations, query).getAs[Min](aggregationName)
Expand All @@ -109,10 +106,10 @@ class SelectMin(aggregationName: String, fieldName: String, query: Option[QueryD
}
}

class SelectMax(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName) with AggQuery {
def script(s: String): AggregationDefinition = filteredAgg(maxAggregation(aggregationName).script(s), query)
class SelectMax(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName, query) {
def script(s: String): AggregationDefinition = maxAggregation(aggregationName).script(ScriptDefinition(s).lang("groovy"))

def field(f: String): AggregationDefinition = filteredAgg(maxAggregation(aggregationName).field(f), query)
def field(f: String): AggregationDefinition = maxAggregation(aggregationName).field(f)

def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject = {
val max = getAggregation(fieldName, aggregations, query).getAs[Max](aggregationName)
Expand All @@ -121,10 +118,10 @@ class SelectMax(aggregationName: String, fieldName: String, query: Option[QueryD
}
}

class SelectSum(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName) with AggQuery {
def script(s: String): AggregationDefinition = filteredAgg(sumAggregation(aggregationName).script(s), query)
class SelectSum(aggregationName: String, fieldName: String, query: Option[QueryDef]) extends FieldAgg(fieldName, aggregationName, query) {
def script(s: String): AggregationDefinition = sumAggregation(aggregationName).script(ScriptDefinition(s).lang("groovy"))

def field(f: String): AggregationDefinition = filteredAgg(sumAggregation(aggregationName).field(f), query)
def field(f: String): AggregationDefinition = sumAggregation(aggregationName).field(f)

def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject = {
val sum = getAggregation(fieldName, aggregations, query).getAs[Sum](aggregationName)
Expand All @@ -133,21 +130,24 @@ class SelectSum(aggregationName: String, fieldName: String, query: Option[QueryD
}
}

class SelectCount(aggregationName: String, query: Option[QueryDef]) extends Agg(aggregationName) {
override def apply(model: BaseModelDef) = Seq(filterAggregation(aggregationName).query(query.fold[QueryDefinition](matchAllQuery)(_.query)))
class SelectCount(aggregationName: String, query: Option[QueryDef]) extends FieldAgg("", aggregationName, query) {
def script(s: String): AggregationDefinition = ???

def field(f: String): AggregationDefinition = filterAggregation(aggregationName).query(matchAllQuery)

def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject = {
val count = aggregations.getAs[Filter](aggregationName)
JsObject(Seq(count.getName JsNumber(count.getDocCount)))
}
}

class SelectTop(aggregationName: String, size: Int, sortBy: Seq[String], query: Option[QueryDef] = None) extends Agg(aggregationName) with AggQuery {
def apply(model: BaseModelDef) = Seq(filteredAgg(topHitsAggregation(aggregationName).size(size).sortBy(DBUtils.sortDefinition(sortBy)), query))
class SelectTop(aggregationName: String, size: Int, sortBy: Seq[String], query: Option[QueryDef] = None) extends FieldAgg("", aggregationName, query) {
def script(s: String): AggregationDefinition = ???

def field(f: String): AggregationDefinition = topHitsAggregation(aggregationName).size(size).sortBy(DBUtils.sortDefinition(sortBy))

def processResult(model: BaseModelDef, aggregations: RichAggregations): JsObject = {
val top = filteredResult(aggregations, query)
.getAs[TopHits](aggregationName)
val top = aggregations.getAs[TopHits](aggregationName)
JsObject(Seq("top" JsArray(top.getHits.getHits.map(h DBUtils.hit2json(RichSearchHit(h))))))
}
}
Expand Down

0 comments on commit 3b4eb7f

Please sign in to comment.