From 3c88388ae5d5378a4b6f7082e1fb8c1f9dfc0077 Mon Sep 17 00:00:00 2001 From: Suyash Kumar Date: Mon, 12 Aug 2024 19:57:55 -0400 Subject: [PATCH 1/4] Initial Median implementation --- interpreter/operator_aggregate.go | 81 +++++++++++++ interpreter/operator_dispatcher.go | 11 ++ model/model.go | 9 ++ parser/operators.go | 17 ++- parser/operators_test.go | 27 +++++ tests/enginetests/operator_aggregate_test.go | 120 +++++++++++++++++++ tests/spectests/exclusions/exclusions.go | 1 - 7 files changed, 264 insertions(+), 2 deletions(-) diff --git a/interpreter/operator_aggregate.go b/interpreter/operator_aggregate.go index 9ae183c..6950dbe 100644 --- a/interpreter/operator_aggregate.go +++ b/interpreter/operator_aggregate.go @@ -16,6 +16,7 @@ package interpreter import ( "fmt" + "sort" "github.com/google/cql/model" "github.com/google/cql/result" @@ -264,6 +265,86 @@ func (i *interpreter) evalMinDateTime(m model.IUnaryExpression, operand result.V return result.New(dt) } +// Median(argument List) Decimal +// https://cql.hl7.org/09-b-cqlreference.html#median +func (i *interpreter) evalMedianDecimal(_ model.IUnaryExpression, operand result.Value) (result.Value, error) { + if result.IsNull(operand) { + return result.New(nil) + } + + l, err := result.ToSlice(operand) + if err != nil { + return result.Value{}, err + } + + var values []float64 + for _, elem := range l { + if result.IsNull(elem) { + continue + } + v, err := result.ToFloat64(elem) + if err != nil { + return result.Value{}, err + } + values = append(values, v) + } + if len(values) == 0 { + return result.New(nil) + } + + median := calculateMedianFloat64(values) + return result.New(median) +} + +// Median(argument List) Quantity +// https://cql.hl7.org/09-b-cqlreference.html#median +func (i *interpreter) evalMedianQuantity(_ model.IUnaryExpression, operand result.Value) (result.Value, error) { + if result.IsNull(operand) { + return result.New(nil) + } + + l, err := result.ToSlice(operand) + if err != nil { + return result.Value{}, err + } + + values := make([]float64, 0, len(l)) + var unit model.Unit + for _, elem := range l { + if result.IsNull(elem) { + continue + } + v, err := result.ToQuantity(elem) + if err != nil { + return result.Value{}, err + } + // We only support List where all the elements have the exact same unit, since we do not support + // mixed unit Quantity math in our engine yet. + if unit == "" { + unit = v.Unit + } else if unit != v.Unit { + return result.Value{}, fmt.Errorf("Median(List) operand has different units which is not supported, got %v and %v", unit, v.Unit) + } + values = append(values, v.Value) + } + if len(values) == 0 { + return result.New(nil) + } + median := calculateMedianFloat64(values) + return result.New(result.Quantity{Value: median, Unit: unit}) +} + +// calculateMedianFloat64 calculates the median of a slice of float64 values. +// This modifies the values slice in place while sorting it. +func calculateMedianFloat64(values []float64) float64 { + sort.Float64s(values) + mid := len(values) / 2 + if len(values)%2 == 0 { + return (values[mid-1] + values[mid]) / 2 + } + return values[mid] +} + // Sum(argument List) Decimal // Sum(argument List) Integer // Sum(argument List) Long diff --git a/interpreter/operator_dispatcher.go b/interpreter/operator_dispatcher.go index b68c327..a9924fc 100644 --- a/interpreter/operator_dispatcher.go +++ b/interpreter/operator_dispatcher.go @@ -565,6 +565,17 @@ func (i *interpreter) unaryOverloads(m model.IUnaryExpression) ([]convert.Overlo Result: i.evalSum, }, }, nil + case *model.Median: + return []convert.Overload[evalUnarySignature]{ + { + Operands: []types.IType{&types.List{ElementType: types.Decimal}}, + Result: i.evalMedianDecimal, + }, + { + Operands: []types.IType{&types.List{ElementType: types.Quantity}}, + Result: i.evalMedianQuantity, + }, + }, nil default: return nil, fmt.Errorf("unsupported Unary Expression %v", m.GetName()) } diff --git a/model/model.go b/model/model.go index fd05d05..ebab4ea 100644 --- a/model/model.go +++ b/model/model.go @@ -809,6 +809,12 @@ type Min struct{ *UnaryExpression } // far as we can tell. type Sum struct{ *UnaryExpression } +// Median ELM expression from https://cql.hl7.org/09-b-cqlreference.html#median +// TODO: b/347346351 - In ELM it's modeled as an AggregateExpression, but for now we model it as an +// UnaryExpression since there is no way to set the AggregateExpression's "path" property for CQL as +// far as we can tell. +type Median struct{ *UnaryExpression } + // CalculateAge CQL expression type type CalculateAge struct { *UnaryExpression @@ -1409,3 +1415,6 @@ func (a *Combine) GetName() string { return "Combine" } // GetName returns the name of the system operator. func (i *Indexer) GetName() string { return "Indexer" } + +// GetName returns the name of the system operator. +func (m *Median) GetName() string { return "Median" } diff --git a/parser/operators.go b/parser/operators.go index 9f2f5c2..ebdf398 100644 --- a/parser/operators.go +++ b/parser/operators.go @@ -19,10 +19,10 @@ import ( "fmt" "strings" + "github.com/antlr4-go/antlr/v4" "github.com/google/cql/internal/convert" "github.com/google/cql/model" "github.com/google/cql/types" - "github.com/antlr4-go/antlr/v4" ) // parseFunction uses the reference resolver to resolve the function, visits the operands, and sets @@ -184,6 +184,9 @@ func (v *visitor) resolveFunction(libraryName, funcName string, operands []model // The operands should be AgeInYearsAt(convertedBirthDate) resolved.WrappedOperands = []model.IExpression{res.WrappedOperand, resolved.WrappedOperands[0]} } + case *model.Median: + listType := resolved.WrappedOperands[0].GetResultType().(*types.List) + t.Expression = model.ResultType(listType.ElementType) } // Set Operands. @@ -1899,6 +1902,18 @@ func (p *Parser) loadSystemOperators() error { return &model.Message{} }, }, + { + name: "Median", + operands: [][]types.IType{ + {&types.List{ElementType: types.Decimal}}, + {&types.List{ElementType: types.Quantity}}, + }, + model: func() model.IExpression { + return &model.Median{ + UnaryExpression: &model.UnaryExpression{}, + } + }, + }, } for _, b := range systemOperators { diff --git a/parser/operators_test.go b/parser/operators_test.go index c4a5f43..2307906 100644 --- a/parser/operators_test.go +++ b/parser/operators_test.go @@ -1227,6 +1227,33 @@ func TestBuiltInFunctions(t *testing.T) { }, }, // AGGREGATE FUNCTIONS - https://cql.hl7.org/09-b-cqlreference.html#aggregate-functions + { + name: "Median Decimal", + cql: "Median({1.0, 2.0, 3.0})", + want: &model.Median{ + UnaryExpression: &model.UnaryExpression{ + Operand: model.NewList([]string{"1.0", "2.0", "3.0"}, types.Decimal), + Expression: model.ResultType(types.Decimal), + }, + }, + }, + { + name: "Median Quantity", + cql: "Median({1.0 'cm', 2.0 'cm', 3.0 'cm'})", + want: &model.Median{ + UnaryExpression: &model.UnaryExpression{ + Operand: &model.List{ + List: []model.IExpression{ + &model.Quantity{Value: 1.0, Unit: "cm", Expression: model.ResultType(types.Quantity)}, + &model.Quantity{Value: 2.0, Unit: "cm", Expression: model.ResultType(types.Quantity)}, + &model.Quantity{Value: 3.0, Unit: "cm", Expression: model.ResultType(types.Quantity)}, + }, + Expression: model.ResultType(&types.List{ElementType: types.Quantity}), + }, + Expression: model.ResultType(types.Quantity), + }, + }, + }, { name: "Count", cql: "Count({1, 2, 3})", diff --git a/tests/enginetests/operator_aggregate_test.go b/tests/enginetests/operator_aggregate_test.go index 0c5a31b..56eb680 100644 --- a/tests/enginetests/operator_aggregate_test.go +++ b/tests/enginetests/operator_aggregate_test.go @@ -565,3 +565,123 @@ func TestSum_Error(t *testing.T) { }) } } + +func TestMedian(t *testing.T) { + tests := []struct { + name string + cql string + wantModel model.IExpression + wantResult result.Value + }{ + { + name: "Median({1.5, 2.5, 3.5, 4.5})", + cql: "Median({1.5, 2.5, 3.5, 4.5})", + wantModel: &model.Median{ + UnaryExpression: &model.UnaryExpression{ + Operand: model.NewList([]string{"1.5", "2.5", "3.5", "4.5"}, types.Decimal), + Expression: model.ResultType(types.Decimal), + }, + }, + wantResult: newOrFatal(t, 3.0), + }, + { + name: "Median({1 'cm', 2 'cm', 3 'cm'})", + cql: "Median({1 'cm', 2 'cm', 3 'cm'})", + wantResult: newOrFatal(t, result.Quantity{Value: 2.0, Unit: "cm"}), + }, + { + name: "Median({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'})", + cql: "Median({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'})", + wantResult: newOrFatal(t, result.Quantity{Value: 3.0, Unit: "g"}), + }, + { + name: "Unordered list: Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})", + cql: "Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})", + wantResult: newOrFatal(t, result.Quantity{Value: 3.0, Unit: "g"}), + }, + { + name: "Median(List{})", + cql: "Median(List{})", + wantResult: newOrFatal(t, nil), + }, + { + name: "Median({null as Decimal})", + cql: "Median({null as Decimal})", + wantResult: newOrFatal(t, nil), + }, + { + name: "Median(null as List)", + cql: "Median(null as List)", + wantResult: newOrFatal(t, nil), + }, + { + name: "Median(List{})", + cql: "Median(List{})", + wantResult: newOrFatal(t, nil), + }, + { + name: "Median({null as Quantity})", + cql: "Median({null as Quantity})", + wantResult: newOrFatal(t, nil), + }, + { + name: "Median(null as List)", + cql: "Median(null as List)", + wantResult: newOrFatal(t, nil), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" { + t.Errorf("Parse diff (-want +got):\n%s", diff) + } + + results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if err != nil { + t.Fatalf("Eval returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" { + t.Errorf("Eval diff (-want +got)\n%v", diff) + } + }) + } +} + +func TestMedian_Error(t *testing.T) { + tests := []struct { + name string + cql string + wantModel model.IExpression + wantErrContains string + }{ + { + name: "Median({1 'cm', 2 'g'})", + cql: "Median({1 'cm', 2 'g'})", + wantErrContains: "Median(List) operand has different units which is not supported", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" { + t.Errorf("Parse diff (-want +got):\n%s", diff) + } + + _, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if !strings.Contains(err.Error(), tc.wantErrContains) { + t.Errorf("Eval returned unexpected error: %v, want error containing %q", err, tc.wantErrContains) + } + }) + } +} diff --git a/tests/spectests/exclusions/exclusions.go b/tests/spectests/exclusions/exclusions.go index c706145..993e9c2 100644 --- a/tests/spectests/exclusions/exclusions.go +++ b/tests/spectests/exclusions/exclusions.go @@ -30,7 +30,6 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions { "CqlAggregateFunctionsTest.xml": XMLTestFileExclusions{ GroupExcludes: []string{ // TODO: b/342061715 - unsupported operators. - "Median", "Mode", "PopulationStdDev", "PopulationVariance", From 4ce9929e61e128048def50939b7556ec0a795413 Mon Sep 17 00:00:00 2001 From: Suyash Kumar Date: Mon, 12 Aug 2024 20:18:40 -0400 Subject: [PATCH 2/4] init values slice capacity --- interpreter/operator_aggregate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interpreter/operator_aggregate.go b/interpreter/operator_aggregate.go index 6950dbe..4138c92 100644 --- a/interpreter/operator_aggregate.go +++ b/interpreter/operator_aggregate.go @@ -277,7 +277,7 @@ func (i *interpreter) evalMedianDecimal(_ model.IUnaryExpression, operand result return result.Value{}, err } - var values []float64 + values := make([]float64, 0, len(l)) for _, elem := range l { if result.IsNull(elem) { continue From fce4610d1fb807a30d18b47d3a88778755b301d7 Mon Sep 17 00:00:00 2001 From: Suyash Kumar Date: Mon, 12 Aug 2024 20:26:04 -0400 Subject: [PATCH 3/4] add decimal overload tests --- tests/enginetests/operator_aggregate_test.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/enginetests/operator_aggregate_test.go b/tests/enginetests/operator_aggregate_test.go index 56eb680..dcf66d7 100644 --- a/tests/enginetests/operator_aggregate_test.go +++ b/tests/enginetests/operator_aggregate_test.go @@ -595,10 +595,25 @@ func TestMedian(t *testing.T) { wantResult: newOrFatal(t, result.Quantity{Value: 3.0, Unit: "g"}), }, { - name: "Unordered list: Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})", + name: "Unordered Quantity list: Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})", cql: "Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})", wantResult: newOrFatal(t, result.Quantity{Value: 3.0, Unit: "g"}), }, + { + name: "Median({1.0, 2.0, 3.0})", + cql: "Median({1.0, 2.0, 3.0})", + wantResult: newOrFatal(t, 2.0), + }, + { + name: "Median({1.5, 2.5, 3.5, 4.5})", + cql: "Median({1.5, 2.5, 3.5, 4.5})", + wantResult: newOrFatal(t, 3.0), + }, + { + name: "Unordered Decimal list: Median({2.5, 3.5, 1.5, 4.5})", + cql: "Median({2.5, 3.5, 1.5, 4.5})", + wantResult: newOrFatal(t, 3.0), + }, { name: "Median(List{})", cql: "Median(List{})", From 738226be4c62c4b92026f0d558680884eda5d538 Mon Sep 17 00:00:00 2001 From: Suyash Kumar Date: Mon, 12 Aug 2024 20:40:56 -0400 Subject: [PATCH 4/4] Correctly error for mixed units where one is an empty unit string --- interpreter/operator_aggregate.go | 13 ++++++++----- tests/enginetests/operator_aggregate_test.go | 7 ++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/interpreter/operator_aggregate.go b/interpreter/operator_aggregate.go index 4138c92..b23b272 100644 --- a/interpreter/operator_aggregate.go +++ b/interpreter/operator_aggregate.go @@ -310,7 +310,7 @@ func (i *interpreter) evalMedianQuantity(_ model.IUnaryExpression, operand resul values := make([]float64, 0, len(l)) var unit model.Unit - for _, elem := range l { + for idx, elem := range l { if result.IsNull(elem) { continue } @@ -318,11 +318,14 @@ func (i *interpreter) evalMedianQuantity(_ model.IUnaryExpression, operand resul if err != nil { return result.Value{}, err } - // We only support List where all the elements have the exact same unit, since we do not support - // mixed unit Quantity math in our engine yet. - if unit == "" { + // We only support List where all the elements have the exact same unit, since we + // do not support mixed unit Quantity math in our engine yet. + if idx == 0 { unit = v.Unit - } else if unit != v.Unit { + } + if unit != v.Unit { + // TODO: b/342061715 - technically we should treat '' unit and '1' unit as the same, but + // for now we don't (and we should apply this globally). return result.Value{}, fmt.Errorf("Median(List) operand has different units which is not supported, got %v and %v", unit, v.Unit) } values = append(values, v.Value) diff --git a/tests/enginetests/operator_aggregate_test.go b/tests/enginetests/operator_aggregate_test.go index dcf66d7..05f24fb 100644 --- a/tests/enginetests/operator_aggregate_test.go +++ b/tests/enginetests/operator_aggregate_test.go @@ -680,6 +680,11 @@ func TestMedian_Error(t *testing.T) { cql: "Median({1 'cm', 2 'g'})", wantErrContains: "Median(List) operand has different units which is not supported", }, + { + name: "Median({1 '', 2 'g'})", + cql: "Median({1 '', 2 'g'})", + wantErrContains: "Median(List) operand has different units which is not supported", + }, } for _, tc := range tests { @@ -694,7 +699,7 @@ func TestMedian_Error(t *testing.T) { } _, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) - if !strings.Contains(err.Error(), tc.wantErrContains) { + if err == nil || !strings.Contains(err.Error(), tc.wantErrContains) { t.Errorf("Eval returned unexpected error: %v, want error containing %q", err, tc.wantErrContains) } })