diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 54698e48142..f9c357668b2 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -43,16 +43,15 @@ type token struct { SortByFieldName string // SortByFieldValue is the value of the sorted field of the next row to be // returned. - SortByFieldValue interface{} - // SortByFieldIsRunMetric indicates whether the SortByFieldName field is - // a run metric field or not. - SortByFieldIsRunMetric bool + SortByFieldValue interface{} + SortByFieldPrefix string // KeyFieldName is the name of the primary key for the model being queried. KeyFieldName string // KeyFieldValue is the value of the sorted field of the next row to be // returned. - KeyFieldValue interface{} + KeyFieldValue interface{} + KeyFieldPrefix string // IsDesc is true if the sorting order should be descending. IsDesc bool @@ -101,7 +100,7 @@ type Options struct { // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { - return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldIsRunMetric == opts.SortByFieldIsRunMetric && + return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldPrefix == opts.SortByFieldPrefix && o.IsDesc == opts.IsDesc && reflect.DeepEqual(o.Filter, opts.Filter) } @@ -147,24 +146,16 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api } token.SortByFieldName = listable.DefaultSortField() - token.SortByFieldIsRunMetric = false if len(queryList) > 0 { - var err error - n, ok := listable.APIToModelFieldMap()[queryList[0]] + n, ok := listable.GetField(queryList[0]) if ok { token.SortByFieldName = n - } else if strings.HasPrefix(queryList[0], "metric:") { - // Sorting on metrics is only available on certain runs. - model := reflect.ValueOf(listable).Elem().Type().Name() - if model != "Run" { - return nil, util.NewInvalidInputError("Invalid sorting field: %q on %q : %s", queryList[0], model, err) - } - token.SortByFieldName = queryList[0][7:] - token.SortByFieldIsRunMetric = true } else { - return nil, util.NewInvalidInputError("Invalid sorting field: %q: %s", queryList[0], err) + return nil, util.NewInvalidInputError("Invalid sorting field: %q on listable type %s", queryList[0], reflect.ValueOf(listable).Elem().Type().Name()) } } + token.SortByFieldPrefix = listable.GetSortByFieldPrefix(token.SortByFieldName) + token.KeyFieldPrefix = listable.GetKeyFieldPrefix() if len(queryList) == 2 { token.IsDesc = queryList[1] == "desc" @@ -196,31 +187,18 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu // AddSortingToSelect adds Order By clause. func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { // When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed. - var keyFieldPrefix string - var sortByFieldPrefix string - if len(o.ModelName) == 0 { - keyFieldPrefix = "" - sortByFieldPrefix = "" - } else if o.SortByFieldIsRunMetric { - keyFieldPrefix = o.ModelName + "." - sortByFieldPrefix = "" - } else { - keyFieldPrefix = o.ModelName + "." - sortByFieldPrefix = o.ModelName + "." - } - // If next row's value is specified, set those values in the clause. if o.SortByFieldValue != nil && o.KeyFieldValue != nil { if o.IsDesc { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Lt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.LtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Lt{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.LtOrEq{o.KeyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } else { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Gt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.GtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Gt{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.GtOrEq{o.KeyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } } @@ -229,25 +207,12 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild order = "DESC" } sqlBuilder = sqlBuilder. - OrderBy(fmt.Sprintf("%v %v", sortByFieldPrefix+o.SortByFieldName, order)). - OrderBy(fmt.Sprintf("%v %v", keyFieldPrefix+o.KeyFieldName, order)) + OrderBy(fmt.Sprintf("%v %v", o.SortByFieldPrefix+o.SortByFieldName, order)). + OrderBy(fmt.Sprintf("%v %v", o.KeyFieldPrefix+o.KeyFieldName, order)) return sqlBuilder } -// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. -// With the metric as a field in the select clause enable sorting on this metric afterwards. -func (o *Options) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { - if !o.SortByFieldIsRunMetric { - return sqlBuilder - } - // TODO(jingzhang36): address the case where runs doesn't have the specified metric. - return sq. - Select("selected_runs.*, run_metrics.numbervalue as "+o.SortByFieldName). - FromSelect(sqlBuilder, "selected_runs"). - LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + o.SortByFieldName + "'") -} - // AddFilterToSelect adds WHERE clauses with the filtering criteria in the // Options o to the supplied SelectBuilder, and returns the new SelectBuilder // containing these. @@ -345,6 +310,12 @@ type Listable interface { APIToModelFieldMap() map[string]string // GetModelName returns table name used as sort field prefix. GetModelName() string + // Get the prefix of sorting field. + GetSortByFieldPrefix(string) string + // Get the prefix of key field. + GetKeyFieldPrefix() string + // Get a valid field for sorting/filtering in a listable model from the given string. + GetField(name string) (string, bool) // Find the value of a given field in a listable object. GetFieldValue(name string) interface{} } @@ -365,20 +336,8 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { elemName := elem.Type().Name() var sortByField interface{} - // TODO(jingzhang36): this if-else block can be simplified to one call to - // GetFieldValue after all the models (run, job, experiment, etc.) implement - // GetFieldValue method in listable interface. - if !o.SortByFieldIsRunMetric { - if value := elem.FieldByName(o.SortByFieldName); value.IsValid() { - sortByField = value.Interface() - } else { - return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) - } - } else { - sortByField = listable.GetFieldValue(o.SortByFieldName) - if sortByField == nil { - return nil, util.NewInvalidInputError("Unable to find run metric %s", o.SortByFieldName) - } + if sortByField = listable.GetFieldValue(o.SortByFieldName); sortByField == nil { + return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) } keyField := elem.FieldByName(listable.PrimaryKeyColumnName()) @@ -387,14 +346,15 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { } return &token{ - SortByFieldName: o.SortByFieldName, - SortByFieldValue: sortByField, - SortByFieldIsRunMetric: o.SortByFieldIsRunMetric, - KeyFieldName: listable.PrimaryKeyColumnName(), - KeyFieldValue: keyField.Interface(), - IsDesc: o.IsDesc, - Filter: o.Filter, - ModelName: o.ModelName, + SortByFieldName: o.SortByFieldName, + SortByFieldValue: sortByField, + SortByFieldPrefix: listable.GetSortByFieldPrefix(o.SortByFieldName), + KeyFieldName: listable.PrimaryKeyColumnName(), + KeyFieldValue: keyField.Interface(), + KeyFieldPrefix: listable.GetKeyFieldPrefix(), + IsDesc: o.IsDesc, + Filter: o.Filter, + ModelName: o.ModelName, }, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 215961f3c8c..8ad4667914e 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -2,10 +2,12 @@ package list import ( "reflect" + "strings" "testing" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/filter" + "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" "github.com/stretchr/testify/assert" @@ -15,10 +17,16 @@ import ( api "github.com/kubeflow/pipelines/backend/api/go_client" ) +type fakeMetric struct { + Name string + Value float64 +} + type fakeListable struct { PrimaryKey string FakeName string CreatedTimestamp int64 + Metrics []*fakeMetric } func (f *fakeListable) PrimaryKeyColumnName() string { @@ -43,12 +51,52 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetField(name string) (string, bool) { + if field, ok := fakeAPIToModelMap[name]; ok { + return field, true + } + if strings.HasPrefix(name, "metric:") { + return name[7:], true + } + return "", false +} + func (f *fakeListable) GetFieldValue(name string) interface{} { + switch name { + case "CreatedTimestamp": + return f.CreatedTimestamp + case "FakeName": + return f.FakeName + case "PrimaryKey": + return f.PrimaryKey + } + for _, metric := range f.Metrics { + if metric.Name == name { + return metric.Value + } + } return nil } +func (f *fakeListable) GetSortByFieldPrefix(name string) string { + return "" +} + +func (f *fakeListable) GetKeyFieldPrefix() string { + return "" +} + func TestNextPageToken_ValidTokens(t *testing.T) { - l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234} + l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234, Metrics: []*fakeMetric{ + { + Name: "m1", + Value: 1.0, + }, + { + Name: "m2", + Value: 2.0, + }, + }} protoFilter := &api.Filter{Predicates: []*api.Predicate{ &api.Predicate{ @@ -70,11 +118,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "CreatedTimestamp", IsDesc: true}, }, want: &token{ - SortByFieldName: "CreatedTimestamp", - SortByFieldValue: int64(1234), - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: true, + SortByFieldName: "CreatedTimestamp", + SortByFieldValue: int64(1234), + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: true, }, }, { @@ -82,11 +132,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "PrimaryKey", IsDesc: true}, }, want: &token{ - SortByFieldName: "PrimaryKey", - SortByFieldValue: "uuid123", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: true, + SortByFieldName: "PrimaryKey", + SortByFieldValue: "uuid123", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: true, }, }, { @@ -94,11 +146,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "FakeName", IsDesc: false}, }, want: &token{ - SortByFieldName: "FakeName", - SortByFieldValue: "Fake", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, + SortByFieldName: "FakeName", + SortByFieldValue: "Fake", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, }, }, { @@ -110,12 +164,31 @@ func TestNextPageToken_ValidTokens(t *testing.T) { }, }, want: &token{ - SortByFieldName: "FakeName", - SortByFieldValue: "Fake", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, - Filter: testFilter, + SortByFieldName: "FakeName", + SortByFieldValue: "Fake", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, + Filter: testFilter, + }, + }, + { + inOpts: &Options{ + PageSize: 10, + token: &token{ + SortByFieldName: "m1", IsDesc: false, + }, + }, + want: &token{ + SortByFieldName: "m1", + SortByFieldValue: 1.0, + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, }, }, } @@ -173,11 +246,13 @@ func TestValidatePageSize(t *testing.T) { func TestNewOptions_FromValidSerializedToken(t *testing.T) { tok := &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true, } s, err := tok.marshal() @@ -209,11 +284,13 @@ func TestNewOptionsFromToken_FromInValidSerializedToken(t *testing.T) { func TestNewOptionsFromToken_FromInValidPageSize(t *testing.T) { tok := &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true, } s, err := tok.marshal() @@ -239,9 +316,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -250,9 +329,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -261,9 +342,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -272,9 +355,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: false, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -283,9 +368,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: true, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: true, }, }, }, @@ -294,9 +381,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "PrimaryKey", - IsDesc: true, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "PrimaryKey", + SortByFieldPrefix: "", + IsDesc: true, }, }, }, @@ -368,10 +457,12 @@ func TestNewOptions_ValidFilter(t *testing.T) { want := &Options{ PageSize: 10, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, - Filter: f, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, + Filter: f, }, } @@ -427,11 +518,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: true, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField < ? OR (SortField = ? AND KeyField <= ?)) ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -441,11 +534,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: false, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: false, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -455,12 +550,14 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: false, - Filter: f, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: false, + Filter: f, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) AND Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -470,10 +567,12 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: true, + SortByFieldName: "SortField", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + KeyFieldValue: 1111, + IsDesc: true, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -483,10 +582,12 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - IsDesc: false, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + IsDesc: false, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -496,11 +597,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - IsDesc: false, - Filter: f, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + IsDesc: false, + Filter: f, }, }, wantSQL: "SELECT * FROM MyTable WHERE Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -538,50 +641,62 @@ func TestTokenSerialization(t *testing.T) { // string values in sort by fields { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true}, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true}, }, // int values get deserialized as floats by JSON unmarshal. { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: 100, - KeyFieldName: "KeyField", - KeyFieldValue: 200, - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: 100, + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 200, + KeyFieldPrefix: "", + IsDesc: true}, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: float64(100), - KeyFieldName: "KeyField", - KeyFieldValue: float64(200), - IsDesc: true}, + SortByFieldName: "SortField", + SortByFieldValue: float64(100), + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: float64(200), + KeyFieldPrefix: "", + IsDesc: true}, }, // has a filter. { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: 100, - KeyFieldName: "KeyField", - KeyFieldValue: 200, - IsDesc: true, - Filter: testFilter, + SortByFieldName: "SortField", + SortByFieldValue: 100, + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 200, + KeyFieldPrefix: "", + IsDesc: true, + Filter: testFilter, }, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: float64(100), - KeyFieldName: "KeyField", - KeyFieldValue: float64(200), - IsDesc: true, - Filter: testFilter, + SortByFieldName: "SortField", + SortByFieldValue: float64(100), + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: float64(200), + KeyFieldPrefix: "", + IsDesc: true, + Filter: testFilter, }, }, } @@ -828,3 +943,64 @@ func TestFilterOnNamesapce(t *testing.T) { } } } + +func TestAddSortingToSelectWithPipelineVersionModel(t *testing.T) { + listable := &model.PipelineVersion{ + UUID: "version_id_1", + CreatedAtInSec: 1, + Name: "version_name_1", + Parameters: "", + PipelineId: "pipeline_id_1", + Status: model.PipelineVersionReady, + CodeSourceUrl: "", + } + protoFilter := &api.Filter{} + listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + assert.Nil(t, err) + sqlBuilder := sq.Select("*").From("pipeline_versions") + sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + + assert.Contains(t, sql, "pipeline_versions.Name") // sorting field + assert.Contains(t, sql, "pipeline_versions.UUID") // primary key field +} + +func TestAddStatusFilterToSelectWithRunModel(t *testing.T) { + listable := &model.Run{ + UUID: "run_id_1", + CreatedAtInSec: 1, + Name: "run_name_1", + Conditions: "Succeeded", + } + protoFilter := &api.Filter{} + protoFilter.Predicates = []*api.Predicate{ + { + Key: "status", + Op: api.Predicate_EQUALS, + Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, + }, + } + listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + assert.Nil(t, err) + sqlBuilder := sq.Select("*").From("run_details") + sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + assert.Contains(t, sql, "WHERE Conditions = ?") // filtering on status, aka Conditions in db + assert.Contains(t, args, "Succeeded") + + notEqualProtoFilter := &api.Filter{} + notEqualProtoFilter.Predicates = []*api.Predicate{ + { + Key: "status", + Op: api.Predicate_NOT_EQUALS, + Value: &api.Predicate_StringValue{StringValue: "somevalue"}, + }, + } + listableOptions, err = NewOptions(listable, 10, "name", notEqualProtoFilter) + assert.Nil(t, err) + sqlBuilder = sq.Select("*").From("run_details") + sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + assert.Contains(t, sql, "WHERE Conditions <> ?") // filtering on status, aka Conditions in db + assert.Contains(t, args, "somevalue") +} diff --git a/backend/src/apiserver/model/BUILD.bazel b/backend/src/apiserver/model/BUILD.bazel index a6716102f22..75add6c411a 100644 --- a/backend/src/apiserver/model/BUILD.bazel +++ b/backend/src/apiserver/model/BUILD.bazel @@ -22,9 +22,7 @@ go_library( go_test( name = "go_default_test", srcs = [ - "pipeline_version_test.go", "resource_reference_test.go", - "run_test.go", ], embed = [":go_default_library"], importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/model", @@ -32,7 +30,6 @@ go_test( deps = [ "//backend/api:go_default_library", "//backend/src/apiserver/common:go_default_library", - "//backend/src/apiserver/list:go_default_library", "@com_github_masterminds_squirrel//:go_default_library", "@com_github_stretchr_testify//assert:go_default_library", ], diff --git a/backend/src/apiserver/model/experiment.go b/backend/src/apiserver/model/experiment.go index 255ceca3aa6..e79ed1984aa 100644 --- a/backend/src/apiserver/model/experiment.go +++ b/backend/src/apiserver/model/experiment.go @@ -47,7 +47,36 @@ func (e *Experiment) GetModelName() string { return "experiments" } +func (e *Experiment) GetField(name string) (string, bool) { + if field, ok := experimentAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (e *Experiment) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return e.UUID + case "Name": + return e.Name + case "CreatedAtInSec": + return e.CreatedAtInSec + case "Description": + return e.Description + case "Namespace": + return e.Namespace + case "StorageState": + return e.StorageState + default: + return nil + } +} + +func (e *Experiment) GetSortByFieldPrefix(name string) string { + return "experiments." +} + +func (e *Experiment) GetKeyFieldPrefix() string { + return "experiments." } diff --git a/backend/src/apiserver/model/job.go b/backend/src/apiserver/model/job.go index b9ecf9a1433..376132f5346 100644 --- a/backend/src/apiserver/model/job.go +++ b/backend/src/apiserver/model/job.go @@ -105,7 +105,32 @@ func (j *Job) GetModelName() string { return "jobs" } +func (j *Job) GetField(name string) (string, bool) { + if field, ok := jobAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (j *Job) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return j.UUID + case "DisplayName": + return j.DisplayName + case "CreatedAtInSec": + return j.CreatedAtInSec + case "PipelineId": + return j.PipelineId + default: + return nil + } +} + +func (j *Job) GetSortByFieldPrefix(name string) string { + return "jobs." +} + +func (j *Job) GetKeyFieldPrefix() string { + return "jobs." } diff --git a/backend/src/apiserver/model/pipeline.go b/backend/src/apiserver/model/pipeline.go index 5cd120df33c..1a52ec73467 100644 --- a/backend/src/apiserver/model/pipeline.go +++ b/backend/src/apiserver/model/pipeline.go @@ -79,7 +79,32 @@ func (p *Pipeline) GetModelName() string { return "pipelines" } +func (p *Pipeline) GetField(name string) (string, bool) { + if field, ok := pipelineAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (p *Pipeline) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return p.UUID + case "Name": + return p.Name + case "CreatedAtInSec": + return p.CreatedAtInSec + case "Description": + return p.Description + default: + return nil + } +} + +func (p *Pipeline) GetSortByFieldPrefix(name string) string { + return "pipelines." +} + +func (p *Pipeline) GetKeyFieldPrefix() string { + return "pipelines." } diff --git a/backend/src/apiserver/model/pipeline_version.go b/backend/src/apiserver/model/pipeline_version.go index 7ba2983844a..8eca8670aa9 100644 --- a/backend/src/apiserver/model/pipeline_version.go +++ b/backend/src/apiserver/model/pipeline_version.go @@ -74,7 +74,32 @@ func (p *PipelineVersion) GetModelName() string { return "pipeline_versions" } +func (p *PipelineVersion) GetField(name string) (string, bool) { + if field, ok := p.APIToModelFieldMap()[name]; ok { + return field, true + } + return "", false +} + func (p *PipelineVersion) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return p.UUID + case "Name": + return p.Name + case "CreatedAtInSec": + return p.CreatedAtInSec + case "Status": + return p.Status + default: + return nil + } +} + +func (p *PipelineVersion) GetSortByFieldPrefix(name string) string { + return "pipeline_versions." +} + +func (p *PipelineVersion) GetKeyFieldPrefix() string { + return "pipeline_versions." } diff --git a/backend/src/apiserver/model/pipeline_version_test.go b/backend/src/apiserver/model/pipeline_version_test.go deleted file mode 100644 index 1746f333952..00000000000 --- a/backend/src/apiserver/model/pipeline_version_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package model - -import ( - "testing" - - sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/go_client" - "github.com/kubeflow/pipelines/backend/src/apiserver/list" - "github.com/stretchr/testify/assert" -) - -// Test model name usage in sorting clause -func TestAddSortingToSelect(t *testing.T) { - listable := &PipelineVersion{ - UUID: "version_id_1", - CreatedAtInSec: 1, - Name: "version_name_1", - Parameters: "", - PipelineId: "pipeline_id_1", - Status: PipelineVersionReady, - CodeSourceUrl: "", - } - protoFilter := &api.Filter{} - listableOptions, err := list.NewOptions(listable, 10, "name", protoFilter) - assert.Nil(t, err) - sqlBuilder := sq.Select("*").From("pipeline_versions") - sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - - assert.Contains(t, sql, "pipeline_versions.Name") // sorting field - assert.Contains(t, sql, "pipeline_versions.UUID") // primary key field -} diff --git a/backend/src/apiserver/model/resource_reference_test.go b/backend/src/apiserver/model/resource_reference_test.go index ff58545064e..afc05a89ad1 100644 --- a/backend/src/apiserver/model/resource_reference_test.go +++ b/backend/src/apiserver/model/resource_reference_test.go @@ -15,9 +15,10 @@ package model import ( + "testing" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/stretchr/testify/assert" - "testing" ) func TestGetNamespaceFromResourceReferencesModel(t *testing.T) { diff --git a/backend/src/apiserver/model/run.go b/backend/src/apiserver/model/run.go index 288085f5e16..a7ad0476f0a 100644 --- a/backend/src/apiserver/model/run.go +++ b/backend/src/apiserver/model/run.go @@ -14,6 +14,10 @@ package model +import ( + "strings" +) + type Run struct { UUID string `gorm:"column:UUID; not null; primary_key"` ExperimentUUID string `gorm:"column:ExperimentUUID; not null;"` @@ -92,6 +96,16 @@ func (r *Run) GetModelName() string { return "" } +func (r *Run) GetField(name string) (string, bool) { + if field, ok := runAPIToModelFieldMap[name]; ok { + return field, true + } + if strings.HasPrefix(name, "metric:") { + return name[7:], true + } + return "", false +} + func (r *Run) GetFieldValue(name string) interface{} { // "name" could be a field in Run type or a name inside an array typed field // in Run type @@ -120,3 +134,27 @@ func (r *Run) GetFieldValue(name string) interface{} { } return nil } + +// Regular fields are the fields that are mapped to columns in Run table. +// Non-regular fields are the run metrics for now. Could have other non-regular +// sorting fields later. +func (r *Run) IsRegularField(name string) bool { + for _, field := range runAPIToModelFieldMap { + if field == name { + return true + } + } + return false +} + +func (r *Run) GetSortByFieldPrefix(name string) string { + if r.IsRegularField(name) { + return r.GetModelName() + } else { + return "" + } +} + +func (r *Run) GetKeyFieldPrefix() string { + return r.GetModelName() +} diff --git a/backend/src/apiserver/model/run_test.go b/backend/src/apiserver/model/run_test.go deleted file mode 100644 index fb379352ff9..00000000000 --- a/backend/src/apiserver/model/run_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package model - -import ( - "testing" - - sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/go_client" - "github.com/kubeflow/pipelines/backend/src/apiserver/list" - "github.com/stretchr/testify/assert" -) - -// Test model name usage in sorting clause -func TestAddStatusFilterToSelect(t *testing.T) { - listable := &Run{ - UUID: "run_id_1", - CreatedAtInSec: 1, - Name: "run_name_1", - Conditions: "Succeeded", - } - protoFilter := &api.Filter{} - protoFilter.Predicates = []*api.Predicate{ - { - Key: "status", - Op: api.Predicate_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, - }, - } - listableOptions, err := list.NewOptions(listable, 10, "name", protoFilter) - assert.Nil(t, err) - sqlBuilder := sq.Select("*").From("run_details") - sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - assert.Contains(t, sql, "WHERE Conditions = ?") // filtering on status, aka Conditions in db - assert.Contains(t, args, "Succeeded") - - notEqualProtoFilter := &api.Filter{} - notEqualProtoFilter.Predicates = []*api.Predicate{ - { - Key: "status", - Op: api.Predicate_NOT_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "somevalue"}, - }, - } - listableOptions, err = list.NewOptions(listable, 10, "name", notEqualProtoFilter) - assert.Nil(t, err) - sqlBuilder = sq.Select("*").From("run_details") - sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - assert.Contains(t, sql, "WHERE Conditions <> ?") // filtering on status, aka Conditions in db - assert.Contains(t, args, "somevalue") -} diff --git a/backend/src/apiserver/server/list_request_util.go b/backend/src/apiserver/server/list_request_util.go index 88a5b1ca62e..72f5784d474 100644 --- a/backend/src/apiserver/server/list_request_util.go +++ b/backend/src/apiserver/server/list_request_util.go @@ -187,6 +187,10 @@ func parseAPIFilter(encoded string) (*api.Filter, error) { func validatedListOptions(listable list.Listable, pageToken string, pageSize int, sortBy string, filterSpec string) (*list.Options, error) { defaultOpts := func() (*list.Options, error) { + if listable == nil { + return nil, util.NewInvalidInputError("Please specify a valid type to list. E.g., list runs or list jobs.") + } + f, err := parseAPIFilter(filterSpec) if err != nil { return nil, err diff --git a/backend/src/apiserver/server/list_request_util_test.go b/backend/src/apiserver/server/list_request_util_test.go index 17991b039f3..18bfe2d5309 100644 --- a/backend/src/apiserver/server/list_request_util_test.go +++ b/backend/src/apiserver/server/list_request_util_test.go @@ -245,10 +245,34 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetField(name string) (string, bool) { + if field, ok := fakeAPIToModelMap[name]; ok { + return field, true + } else { + return "", false + } +} + func (f *fakeListable) GetFieldValue(name string) interface{} { + switch name { + case "CreatedTimestamp": + return f.CreatedTimestamp + case "FakeName": + return f.FakeName + case "PrimaryKey": + return f.PrimaryKey + } return nil } +func (f *fakeListable) GetSortByFieldPrefix(name string) string { + return "" +} + +func (f *fakeListable) GetKeyFieldPrefix() string { + return "" +} + func TestValidatedListOptions_Errors(t *testing.T) { opts, err := list.NewOptions(&fakeListable{}, 10, "name asc", nil) if err != nil { diff --git a/backend/src/apiserver/storage/run_store.go b/backend/src/apiserver/storage/run_store.go index f283390fd7b..92270c22314 100644 --- a/backend/src/apiserver/storage/run_store.go +++ b/backend/src/apiserver/storage/run_store.go @@ -170,7 +170,7 @@ func (s *RunStore) buildSelectRunsQuery(selectCount bool, opts *list.Options, // If we're not just counting, then also add select columns and perform a left join // to get resource reference information. Also add pagination. if !selectCount { - sqlBuilder = opts.AddSortByRunMetricToSelect(sqlBuilder) + sqlBuilder = s.AddSortByRunMetricToSelect(sqlBuilder, opts) sqlBuilder = opts.AddPaginationToSelect(sqlBuilder) sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder, opts) sqlBuilder = opts.AddSortingToSelect(sqlBuilder) @@ -224,11 +224,12 @@ func Map(vs []string, f func(string) string) []string { } func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { + var r model.Run resourceRefConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rr.Payload", ","), `"]"`}, "") columnsAfterJoiningResourceReferences := append( Map(runColumns, func(column string) string { return "rd." + column }), // Add prefix "rd." to runColumns resourceRefConcatQuery+" AS refs") - if opts != nil && opts.SortByFieldIsRunMetric { + if opts != nil && !r.IsRegularField(opts.SortByFieldName) { columnsAfterJoiningResourceReferences = append(columnsAfterJoiningResourceReferences, "rd."+opts.SortByFieldName) } subQ := sq. @@ -619,3 +620,18 @@ func (s *RunStore) TerminateRun(runId string) error { return nil } + +// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. +// With the metric as a field in the select clause enable sorting on this metric afterwards. +// TODO(jingzhang36): example of resulting SQL query and explanation for it. +func (s *RunStore) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { + var r model.Run + if r.IsRegularField(opts.SortByFieldName) { + return sqlBuilder + } + // TODO(jingzhang36): address the case where runs doesn't have the specified metric. + return sq. + Select("selected_runs.*, run_metrics.numbervalue as "+opts.SortByFieldName). + FromSelect(sqlBuilder, "selected_runs"). + LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + opts.SortByFieldName + "'") +}