From 9c6738fa800e2447a3ab8141917203aab125974b Mon Sep 17 00:00:00 2001 From: jingzhang36 Date: Fri, 31 Jul 2020 17:41:06 +0800 Subject: [PATCH] feat(backend): sort by run metrics - step 3. Part of #3591 (#4251) * enable pagination when expanding experiment in both the home page and the archive page * Revert "enable pagination when expanding experiment in both the home page and the archive page" This reverts commit 5b672739dd88235d41b5867666db917a8eb80a10. * sorting by run metrics is different from sorting by name, uuid, created at, etc. The lattre are direct field in listable object, the former is an element in an arrary-typed field in listable object. In other words, the latter are columns in table, the former is not. * unit test: add sorting on metrics with both asc and desc order * GetFieldValue in all models * fix unit test * whether to test in list_test. It's hacky when check mode == 'run' * move model specific code to model; prevent model package depends on list package; let list package depends on modelpackage; marshal/unmarshal listable interface; include listable interface in token. * some assumption on token's Model field * fix the regular field checking logic * add comment to help devs to use the new field * add a validation check * Listable object can be too large to be in token. So replace it with only relevant fields taken out of it. In the future, if more fields in Listable object become relevant, manually add it to token * matches func update --- backend/src/apiserver/list/list.go | 108 ++--- backend/src/apiserver/list/list_test.go | 406 +++++++++++++----- backend/src/apiserver/model/BUILD.bazel | 3 - backend/src/apiserver/model/experiment.go | 33 +- backend/src/apiserver/model/job.go | 29 +- backend/src/apiserver/model/pipeline.go | 29 +- .../src/apiserver/model/pipeline_version.go | 29 +- .../apiserver/model/pipeline_version_test.go | 32 -- .../model/resource_reference_test.go | 3 +- backend/src/apiserver/model/run.go | 38 ++ backend/src/apiserver/model/run_test.go | 51 --- .../src/apiserver/server/list_request_util.go | 4 + .../server/list_request_util_test.go | 24 ++ backend/src/apiserver/storage/run_store.go | 20 +- 14 files changed, 523 insertions(+), 286 deletions(-) delete mode 100644 backend/src/apiserver/model/pipeline_version_test.go delete mode 100644 backend/src/apiserver/model/run_test.go diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 54698e48142..f9c357668b2 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -43,16 +43,15 @@ type token struct { SortByFieldName string // SortByFieldValue is the value of the sorted field of the next row to be // returned. - SortByFieldValue interface{} - // SortByFieldIsRunMetric indicates whether the SortByFieldName field is - // a run metric field or not. - SortByFieldIsRunMetric bool + SortByFieldValue interface{} + SortByFieldPrefix string // KeyFieldName is the name of the primary key for the model being queried. KeyFieldName string // KeyFieldValue is the value of the sorted field of the next row to be // returned. - KeyFieldValue interface{} + KeyFieldValue interface{} + KeyFieldPrefix string // IsDesc is true if the sorting order should be descending. IsDesc bool @@ -101,7 +100,7 @@ type Options struct { // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { - return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldIsRunMetric == opts.SortByFieldIsRunMetric && + return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldPrefix == opts.SortByFieldPrefix && o.IsDesc == opts.IsDesc && reflect.DeepEqual(o.Filter, opts.Filter) } @@ -147,24 +146,16 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api } token.SortByFieldName = listable.DefaultSortField() - token.SortByFieldIsRunMetric = false if len(queryList) > 0 { - var err error - n, ok := listable.APIToModelFieldMap()[queryList[0]] + n, ok := listable.GetField(queryList[0]) if ok { token.SortByFieldName = n - } else if strings.HasPrefix(queryList[0], "metric:") { - // Sorting on metrics is only available on certain runs. - model := reflect.ValueOf(listable).Elem().Type().Name() - if model != "Run" { - return nil, util.NewInvalidInputError("Invalid sorting field: %q on %q : %s", queryList[0], model, err) - } - token.SortByFieldName = queryList[0][7:] - token.SortByFieldIsRunMetric = true } else { - return nil, util.NewInvalidInputError("Invalid sorting field: %q: %s", queryList[0], err) + return nil, util.NewInvalidInputError("Invalid sorting field: %q on listable type %s", queryList[0], reflect.ValueOf(listable).Elem().Type().Name()) } } + token.SortByFieldPrefix = listable.GetSortByFieldPrefix(token.SortByFieldName) + token.KeyFieldPrefix = listable.GetKeyFieldPrefix() if len(queryList) == 2 { token.IsDesc = queryList[1] == "desc" @@ -196,31 +187,18 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu // AddSortingToSelect adds Order By clause. func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { // When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed. - var keyFieldPrefix string - var sortByFieldPrefix string - if len(o.ModelName) == 0 { - keyFieldPrefix = "" - sortByFieldPrefix = "" - } else if o.SortByFieldIsRunMetric { - keyFieldPrefix = o.ModelName + "." - sortByFieldPrefix = "" - } else { - keyFieldPrefix = o.ModelName + "." - sortByFieldPrefix = o.ModelName + "." - } - // If next row's value is specified, set those values in the clause. if o.SortByFieldValue != nil && o.KeyFieldValue != nil { if o.IsDesc { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Lt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.LtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Lt{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.LtOrEq{o.KeyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } else { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Gt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.GtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Gt{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.GtOrEq{o.KeyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } } @@ -229,25 +207,12 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild order = "DESC" } sqlBuilder = sqlBuilder. - OrderBy(fmt.Sprintf("%v %v", sortByFieldPrefix+o.SortByFieldName, order)). - OrderBy(fmt.Sprintf("%v %v", keyFieldPrefix+o.KeyFieldName, order)) + OrderBy(fmt.Sprintf("%v %v", o.SortByFieldPrefix+o.SortByFieldName, order)). + OrderBy(fmt.Sprintf("%v %v", o.KeyFieldPrefix+o.KeyFieldName, order)) return sqlBuilder } -// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. -// With the metric as a field in the select clause enable sorting on this metric afterwards. -func (o *Options) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { - if !o.SortByFieldIsRunMetric { - return sqlBuilder - } - // TODO(jingzhang36): address the case where runs doesn't have the specified metric. - return sq. - Select("selected_runs.*, run_metrics.numbervalue as "+o.SortByFieldName). - FromSelect(sqlBuilder, "selected_runs"). - LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + o.SortByFieldName + "'") -} - // AddFilterToSelect adds WHERE clauses with the filtering criteria in the // Options o to the supplied SelectBuilder, and returns the new SelectBuilder // containing these. @@ -345,6 +310,12 @@ type Listable interface { APIToModelFieldMap() map[string]string // GetModelName returns table name used as sort field prefix. GetModelName() string + // Get the prefix of sorting field. + GetSortByFieldPrefix(string) string + // Get the prefix of key field. + GetKeyFieldPrefix() string + // Get a valid field for sorting/filtering in a listable model from the given string. + GetField(name string) (string, bool) // Find the value of a given field in a listable object. GetFieldValue(name string) interface{} } @@ -365,20 +336,8 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { elemName := elem.Type().Name() var sortByField interface{} - // TODO(jingzhang36): this if-else block can be simplified to one call to - // GetFieldValue after all the models (run, job, experiment, etc.) implement - // GetFieldValue method in listable interface. - if !o.SortByFieldIsRunMetric { - if value := elem.FieldByName(o.SortByFieldName); value.IsValid() { - sortByField = value.Interface() - } else { - return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) - } - } else { - sortByField = listable.GetFieldValue(o.SortByFieldName) - if sortByField == nil { - return nil, util.NewInvalidInputError("Unable to find run metric %s", o.SortByFieldName) - } + if sortByField = listable.GetFieldValue(o.SortByFieldName); sortByField == nil { + return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) } keyField := elem.FieldByName(listable.PrimaryKeyColumnName()) @@ -387,14 +346,15 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { } return &token{ - SortByFieldName: o.SortByFieldName, - SortByFieldValue: sortByField, - SortByFieldIsRunMetric: o.SortByFieldIsRunMetric, - KeyFieldName: listable.PrimaryKeyColumnName(), - KeyFieldValue: keyField.Interface(), - IsDesc: o.IsDesc, - Filter: o.Filter, - ModelName: o.ModelName, + SortByFieldName: o.SortByFieldName, + SortByFieldValue: sortByField, + SortByFieldPrefix: listable.GetSortByFieldPrefix(o.SortByFieldName), + KeyFieldName: listable.PrimaryKeyColumnName(), + KeyFieldValue: keyField.Interface(), + KeyFieldPrefix: listable.GetKeyFieldPrefix(), + IsDesc: o.IsDesc, + Filter: o.Filter, + ModelName: o.ModelName, }, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 215961f3c8c..8ad4667914e 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -2,10 +2,12 @@ package list import ( "reflect" + "strings" "testing" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/filter" + "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" "github.com/stretchr/testify/assert" @@ -15,10 +17,16 @@ import ( api "github.com/kubeflow/pipelines/backend/api/go_client" ) +type fakeMetric struct { + Name string + Value float64 +} + type fakeListable struct { PrimaryKey string FakeName string CreatedTimestamp int64 + Metrics []*fakeMetric } func (f *fakeListable) PrimaryKeyColumnName() string { @@ -43,12 +51,52 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetField(name string) (string, bool) { + if field, ok := fakeAPIToModelMap[name]; ok { + return field, true + } + if strings.HasPrefix(name, "metric:") { + return name[7:], true + } + return "", false +} + func (f *fakeListable) GetFieldValue(name string) interface{} { + switch name { + case "CreatedTimestamp": + return f.CreatedTimestamp + case "FakeName": + return f.FakeName + case "PrimaryKey": + return f.PrimaryKey + } + for _, metric := range f.Metrics { + if metric.Name == name { + return metric.Value + } + } return nil } +func (f *fakeListable) GetSortByFieldPrefix(name string) string { + return "" +} + +func (f *fakeListable) GetKeyFieldPrefix() string { + return "" +} + func TestNextPageToken_ValidTokens(t *testing.T) { - l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234} + l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234, Metrics: []*fakeMetric{ + { + Name: "m1", + Value: 1.0, + }, + { + Name: "m2", + Value: 2.0, + }, + }} protoFilter := &api.Filter{Predicates: []*api.Predicate{ &api.Predicate{ @@ -70,11 +118,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "CreatedTimestamp", IsDesc: true}, }, want: &token{ - SortByFieldName: "CreatedTimestamp", - SortByFieldValue: int64(1234), - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: true, + SortByFieldName: "CreatedTimestamp", + SortByFieldValue: int64(1234), + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: true, }, }, { @@ -82,11 +132,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "PrimaryKey", IsDesc: true}, }, want: &token{ - SortByFieldName: "PrimaryKey", - SortByFieldValue: "uuid123", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: true, + SortByFieldName: "PrimaryKey", + SortByFieldValue: "uuid123", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: true, }, }, { @@ -94,11 +146,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "FakeName", IsDesc: false}, }, want: &token{ - SortByFieldName: "FakeName", - SortByFieldValue: "Fake", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, + SortByFieldName: "FakeName", + SortByFieldValue: "Fake", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, }, }, { @@ -110,12 +164,31 @@ func TestNextPageToken_ValidTokens(t *testing.T) { }, }, want: &token{ - SortByFieldName: "FakeName", - SortByFieldValue: "Fake", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, - Filter: testFilter, + SortByFieldName: "FakeName", + SortByFieldValue: "Fake", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, + Filter: testFilter, + }, + }, + { + inOpts: &Options{ + PageSize: 10, + token: &token{ + SortByFieldName: "m1", IsDesc: false, + }, + }, + want: &token{ + SortByFieldName: "m1", + SortByFieldValue: 1.0, + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, }, }, } @@ -173,11 +246,13 @@ func TestValidatePageSize(t *testing.T) { func TestNewOptions_FromValidSerializedToken(t *testing.T) { tok := &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true, } s, err := tok.marshal() @@ -209,11 +284,13 @@ func TestNewOptionsFromToken_FromInValidSerializedToken(t *testing.T) { func TestNewOptionsFromToken_FromInValidPageSize(t *testing.T) { tok := &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true, } s, err := tok.marshal() @@ -239,9 +316,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -250,9 +329,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -261,9 +342,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -272,9 +355,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -283,9 +368,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: true, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: true, }, }, }, @@ -294,9 +381,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "PrimaryKey", - IsDesc: true, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "PrimaryKey", + SortByFieldPrefix: "", + IsDesc: true, }, }, }, @@ -368,10 +457,12 @@ func TestNewOptions_ValidFilter(t *testing.T) { want := &Options{ PageSize: 10, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, - Filter: f, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, + Filter: f, }, } @@ -427,11 +518,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: true, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField < ? OR (SortField = ? AND KeyField <= ?)) ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -441,11 +534,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: false, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: false, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -455,12 +550,14 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: false, - Filter: f, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: false, + Filter: f, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) AND Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -470,10 +567,12 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + KeyFieldValue: 1111, + IsDesc: true, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -483,10 +582,12 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - IsDesc: false, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + IsDesc: false, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -496,11 +597,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - IsDesc: false, - Filter: f, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + IsDesc: false, + Filter: f, }, }, wantSQL: "SELECT * FROM MyTable WHERE Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -538,50 +641,62 @@ func TestTokenSerialization(t *testing.T) { // string values in sort by fields { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true}, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true}, }, // int values get deserialized as floats by JSON unmarshal. { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: 100, - KeyFieldName: "KeyField", - KeyFieldValue: 200, - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: 100, + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 200, + KeyFieldPrefix: "", + IsDesc: true}, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: float64(100), - KeyFieldName: "KeyField", - KeyFieldValue: float64(200), - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: float64(100), + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: float64(200), + KeyFieldPrefix: "", + IsDesc: true}, }, // has a filter. { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: 100, - KeyFieldName: "KeyField", - KeyFieldValue: 200, - IsDesc: true, - Filter: testFilter, + SortByFieldName: "SortField", + SortByFieldValue: 100, + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 200, + KeyFieldPrefix: "", + IsDesc: true, + Filter: testFilter, }, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: float64(100), - KeyFieldName: "KeyField", - KeyFieldValue: float64(200), - IsDesc: true, - Filter: testFilter, + SortByFieldName: "SortField", + SortByFieldValue: float64(100), + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: float64(200), + KeyFieldPrefix: "", + IsDesc: true, + Filter: testFilter, }, }, } @@ -828,3 +943,64 @@ func TestFilterOnNamesapce(t *testing.T) { } } } + +func TestAddSortingToSelectWithPipelineVersionModel(t *testing.T) { + listable := &model.PipelineVersion{ + UUID: "version_id_1", + CreatedAtInSec: 1, + Name: "version_name_1", + Parameters: "", + PipelineId: "pipeline_id_1", + Status: model.PipelineVersionReady, + CodeSourceUrl: "", + } + protoFilter := &api.Filter{} + listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + assert.Nil(t, err) + sqlBuilder := sq.Select("*").From("pipeline_versions") + sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + + assert.Contains(t, sql, "pipeline_versions.Name") // sorting field + assert.Contains(t, sql, "pipeline_versions.UUID") // primary key field +} + +func TestAddStatusFilterToSelectWithRunModel(t *testing.T) { + listable := &model.Run{ + UUID: "run_id_1", + CreatedAtInSec: 1, + Name: "run_name_1", + Conditions: "Succeeded", + } + protoFilter := &api.Filter{} + protoFilter.Predicates = []*api.Predicate{ + { + Key: "status", + Op: api.Predicate_EQUALS, + Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, + }, + } + listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + assert.Nil(t, err) + sqlBuilder := sq.Select("*").From("run_details") + sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + assert.Contains(t, sql, "WHERE Conditions = ?") // filtering on status, aka Conditions in db + assert.Contains(t, args, "Succeeded") + + notEqualProtoFilter := &api.Filter{} + notEqualProtoFilter.Predicates = []*api.Predicate{ + { + Key: "status", + Op: api.Predicate_NOT_EQUALS, + Value: &api.Predicate_StringValue{StringValue: "somevalue"}, + }, + } + listableOptions, err = NewOptions(listable, 10, "name", notEqualProtoFilter) + assert.Nil(t, err) + sqlBuilder = sq.Select("*").From("run_details") + sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + assert.Contains(t, sql, "WHERE Conditions <> ?") // filtering on status, aka Conditions in db + assert.Contains(t, args, "somevalue") +} diff --git a/backend/src/apiserver/model/BUILD.bazel b/backend/src/apiserver/model/BUILD.bazel index a6716102f22..75add6c411a 100644 --- a/backend/src/apiserver/model/BUILD.bazel +++ b/backend/src/apiserver/model/BUILD.bazel @@ -22,9 +22,7 @@ go_library( go_test( name = "go_default_test", srcs = [ - "pipeline_version_test.go", "resource_reference_test.go", - "run_test.go", ], embed = [":go_default_library"], importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/model", @@ -32,7 +30,6 @@ go_test( deps = [ "//backend/api:go_default_library", "//backend/src/apiserver/common:go_default_library", - "//backend/src/apiserver/list:go_default_library", "@com_github_masterminds_squirrel//:go_default_library", "@com_github_stretchr_testify//assert:go_default_library", ], diff --git a/backend/src/apiserver/model/experiment.go b/backend/src/apiserver/model/experiment.go index 255ceca3aa6..e79ed1984aa 100644 --- a/backend/src/apiserver/model/experiment.go +++ b/backend/src/apiserver/model/experiment.go @@ -47,7 +47,36 @@ func (e *Experiment) GetModelName() string { return "experiments" } +func (e *Experiment) GetField(name string) (string, bool) { + if field, ok := experimentAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (e *Experiment) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return e.UUID + case "Name": + return e.Name + case "CreatedAtInSec": + return e.CreatedAtInSec + case "Description": + return e.Description + case "Namespace": + return e.Namespace + case "StorageState": + return e.StorageState + default: + return nil + } +} + +func (e *Experiment) GetSortByFieldPrefix(name string) string { + return "experiments." +} + +func (e *Experiment) GetKeyFieldPrefix() string { + return "experiments." } diff --git a/backend/src/apiserver/model/job.go b/backend/src/apiserver/model/job.go index b9ecf9a1433..376132f5346 100644 --- a/backend/src/apiserver/model/job.go +++ b/backend/src/apiserver/model/job.go @@ -105,7 +105,32 @@ func (j *Job) GetModelName() string { return "jobs" } +func (j *Job) GetField(name string) (string, bool) { + if field, ok := jobAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (j *Job) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return j.UUID + case "DisplayName": + return j.DisplayName + case "CreatedAtInSec": + return j.CreatedAtInSec + case "PipelineId": + return j.PipelineId + default: + return nil + } +} + +func (j *Job) GetSortByFieldPrefix(name string) string { + return "jobs." +} + +func (j *Job) GetKeyFieldPrefix() string { + return "jobs." } diff --git a/backend/src/apiserver/model/pipeline.go b/backend/src/apiserver/model/pipeline.go index 5cd120df33c..1a52ec73467 100644 --- a/backend/src/apiserver/model/pipeline.go +++ b/backend/src/apiserver/model/pipeline.go @@ -79,7 +79,32 @@ func (p *Pipeline) GetModelName() string { return "pipelines" } +func (p *Pipeline) GetField(name string) (string, bool) { + if field, ok := pipelineAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (p *Pipeline) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return p.UUID + case "Name": + return p.Name + case "CreatedAtInSec": + return p.CreatedAtInSec + case "Description": + return p.Description + default: + return nil + } +} + +func (p *Pipeline) GetSortByFieldPrefix(name string) string { + return "pipelines." +} + +func (p *Pipeline) GetKeyFieldPrefix() string { + return "pipelines." } diff --git a/backend/src/apiserver/model/pipeline_version.go b/backend/src/apiserver/model/pipeline_version.go index 7ba2983844a..8eca8670aa9 100644 --- a/backend/src/apiserver/model/pipeline_version.go +++ b/backend/src/apiserver/model/pipeline_version.go @@ -74,7 +74,32 @@ func (p *PipelineVersion) GetModelName() string { return "pipeline_versions" } +func (p *PipelineVersion) GetField(name string) (string, bool) { + if field, ok := p.APIToModelFieldMap()[name]; ok { + return field, true + } + return "", false +} + func (p *PipelineVersion) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return p.UUID + case "Name": + return p.Name + case "CreatedAtInSec": + return p.CreatedAtInSec + case "Status": + return p.Status + default: + return nil + } +} + +func (p *PipelineVersion) GetSortByFieldPrefix(name string) string { + return "pipeline_versions." +} + +func (p *PipelineVersion) GetKeyFieldPrefix() string { + return "pipeline_versions." } diff --git a/backend/src/apiserver/model/pipeline_version_test.go b/backend/src/apiserver/model/pipeline_version_test.go deleted file mode 100644 index 1746f333952..00000000000 --- a/backend/src/apiserver/model/pipeline_version_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package model - -import ( - "testing" - - sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/go_client" - "github.com/kubeflow/pipelines/backend/src/apiserver/list" - "github.com/stretchr/testify/assert" -) - -// Test model name usage in sorting clause -func TestAddSortingToSelect(t *testing.T) { - listable := &PipelineVersion{ - UUID: "version_id_1", - CreatedAtInSec: 1, - Name: "version_name_1", - Parameters: "", - PipelineId: "pipeline_id_1", - Status: PipelineVersionReady, - CodeSourceUrl: "", - } - protoFilter := &api.Filter{} - listableOptions, err := list.NewOptions(listable, 10, "name", protoFilter) - assert.Nil(t, err) - sqlBuilder := sq.Select("*").From("pipeline_versions") - sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - - assert.Contains(t, sql, "pipeline_versions.Name") // sorting field - assert.Contains(t, sql, "pipeline_versions.UUID") // primary key field -} diff --git a/backend/src/apiserver/model/resource_reference_test.go b/backend/src/apiserver/model/resource_reference_test.go index ff58545064e..afc05a89ad1 100644 --- a/backend/src/apiserver/model/resource_reference_test.go +++ b/backend/src/apiserver/model/resource_reference_test.go @@ -15,9 +15,10 @@ package model import ( + "testing" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/stretchr/testify/assert" - "testing" ) func TestGetNamespaceFromResourceReferencesModel(t *testing.T) { diff --git a/backend/src/apiserver/model/run.go b/backend/src/apiserver/model/run.go index 288085f5e16..a7ad0476f0a 100644 --- a/backend/src/apiserver/model/run.go +++ b/backend/src/apiserver/model/run.go @@ -14,6 +14,10 @@ package model +import ( + "strings" +) + type Run struct { UUID string `gorm:"column:UUID; not null; primary_key"` ExperimentUUID string `gorm:"column:ExperimentUUID; not null;"` @@ -92,6 +96,16 @@ func (r *Run) GetModelName() string { return "" } +func (r *Run) GetField(name string) (string, bool) { + if field, ok := runAPIToModelFieldMap[name]; ok { + return field, true + } + if strings.HasPrefix(name, "metric:") { + return name[7:], true + } + return "", false +} + func (r *Run) GetFieldValue(name string) interface{} { // "name" could be a field in Run type or a name inside an array typed field // in Run type @@ -120,3 +134,27 @@ func (r *Run) GetFieldValue(name string) interface{} { } return nil } + +// Regular fields are the fields that are mapped to columns in Run table. +// Non-regular fields are the run metrics for now. Could have other non-regular +// sorting fields later. +func (r *Run) IsRegularField(name string) bool { + for _, field := range runAPIToModelFieldMap { + if field == name { + return true + } + } + return false +} + +func (r *Run) GetSortByFieldPrefix(name string) string { + if r.IsRegularField(name) { + return r.GetModelName() + } else { + return "" + } +} + +func (r *Run) GetKeyFieldPrefix() string { + return r.GetModelName() +} diff --git a/backend/src/apiserver/model/run_test.go b/backend/src/apiserver/model/run_test.go deleted file mode 100644 index fb379352ff9..00000000000 --- a/backend/src/apiserver/model/run_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package model - -import ( - "testing" - - sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/go_client" - "github.com/kubeflow/pipelines/backend/src/apiserver/list" - "github.com/stretchr/testify/assert" -) - -// Test model name usage in sorting clause -func TestAddStatusFilterToSelect(t *testing.T) { - listable := &Run{ - UUID: "run_id_1", - CreatedAtInSec: 1, - Name: "run_name_1", - Conditions: "Succeeded", - } - protoFilter := &api.Filter{} - protoFilter.Predicates = []*api.Predicate{ - { - Key: "status", - Op: api.Predicate_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, - }, - } - listableOptions, err := list.NewOptions(listable, 10, "name", protoFilter) - assert.Nil(t, err) - sqlBuilder := sq.Select("*").From("run_details") - sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - assert.Contains(t, sql, "WHERE Conditions = ?") // filtering on status, aka Conditions in db - assert.Contains(t, args, "Succeeded") - - notEqualProtoFilter := &api.Filter{} - notEqualProtoFilter.Predicates = []*api.Predicate{ - { - Key: "status", - Op: api.Predicate_NOT_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "somevalue"}, - }, - } - listableOptions, err = list.NewOptions(listable, 10, "name", notEqualProtoFilter) - assert.Nil(t, err) - sqlBuilder = sq.Select("*").From("run_details") - sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - assert.Contains(t, sql, "WHERE Conditions <> ?") // filtering on status, aka Conditions in db - assert.Contains(t, args, "somevalue") -} diff --git a/backend/src/apiserver/server/list_request_util.go b/backend/src/apiserver/server/list_request_util.go index 88a5b1ca62e..72f5784d474 100644 --- a/backend/src/apiserver/server/list_request_util.go +++ b/backend/src/apiserver/server/list_request_util.go @@ -187,6 +187,10 @@ func parseAPIFilter(encoded string) (*api.Filter, error) { func validatedListOptions(listable list.Listable, pageToken string, pageSize int, sortBy string, filterSpec string) (*list.Options, error) { defaultOpts := func() (*list.Options, error) { + if listable == nil { + return nil, util.NewInvalidInputError("Please specify a valid type to list. E.g., list runs or list jobs.") + } + f, err := parseAPIFilter(filterSpec) if err != nil { return nil, err diff --git a/backend/src/apiserver/server/list_request_util_test.go b/backend/src/apiserver/server/list_request_util_test.go index 17991b039f3..18bfe2d5309 100644 --- a/backend/src/apiserver/server/list_request_util_test.go +++ b/backend/src/apiserver/server/list_request_util_test.go @@ -245,10 +245,34 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetField(name string) (string, bool) { + if field, ok := fakeAPIToModelMap[name]; ok { + return field, true + } else { + return "", false + } +} + func (f *fakeListable) GetFieldValue(name string) interface{} { + switch name { + case "CreatedTimestamp": + return f.CreatedTimestamp + case "FakeName": + return f.FakeName + case "PrimaryKey": + return f.PrimaryKey + } return nil } +func (f *fakeListable) GetSortByFieldPrefix(name string) string { + return "" +} + +func (f *fakeListable) GetKeyFieldPrefix() string { + return "" +} + func TestValidatedListOptions_Errors(t *testing.T) { opts, err := list.NewOptions(&fakeListable{}, 10, "name asc", nil) if err != nil { diff --git a/backend/src/apiserver/storage/run_store.go b/backend/src/apiserver/storage/run_store.go index f283390fd7b..92270c22314 100644 --- a/backend/src/apiserver/storage/run_store.go +++ b/backend/src/apiserver/storage/run_store.go @@ -170,7 +170,7 @@ func (s *RunStore) buildSelectRunsQuery(selectCount bool, opts *list.Options, // If we're not just counting, then also add select columns and perform a left join // to get resource reference information. Also add pagination. if !selectCount { - sqlBuilder = opts.AddSortByRunMetricToSelect(sqlBuilder) + sqlBuilder = s.AddSortByRunMetricToSelect(sqlBuilder, opts) sqlBuilder = opts.AddPaginationToSelect(sqlBuilder) sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder, opts) sqlBuilder = opts.AddSortingToSelect(sqlBuilder) @@ -224,11 +224,12 @@ func Map(vs []string, f func(string) string) []string { } func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { + var r model.Run resourceRefConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rr.Payload", ","), `"]"`}, "") columnsAfterJoiningResourceReferences := append( Map(runColumns, func(column string) string { return "rd." + column }), // Add prefix "rd." to runColumns resourceRefConcatQuery+" AS refs") - if opts != nil && opts.SortByFieldIsRunMetric { + if opts != nil && !r.IsRegularField(opts.SortByFieldName) { columnsAfterJoiningResourceReferences = append(columnsAfterJoiningResourceReferences, "rd."+opts.SortByFieldName) } subQ := sq. @@ -619,3 +620,18 @@ func (s *RunStore) TerminateRun(runId string) error { return nil } + +// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. +// With the metric as a field in the select clause enable sorting on this metric afterwards. +// TODO(jingzhang36): example of resulting SQL query and explanation for it. +func (s *RunStore) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { + var r model.Run + if r.IsRegularField(opts.SortByFieldName) { + return sqlBuilder + } + // TODO(jingzhang36): address the case where runs doesn't have the specified metric. + return sq. + Select("selected_runs.*, run_metrics.numbervalue as "+opts.SortByFieldName). + FromSelect(sqlBuilder, "selected_runs"). + LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + opts.SortByFieldName + "'") +}