From 5b672739dd88235d41b5867666db917a8eb80a10 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 17 Jun 2020 17:37:42 +0800 Subject: [PATCH 01/14] enable pagination when expanding experiment in both the home page and the archive page --- frontend/src/components/ExperimentList.tsx | 2 +- frontend/src/pages/ExperimentList.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/src/components/ExperimentList.tsx b/frontend/src/components/ExperimentList.tsx index 330e4488ac7..8823ec9b73e 100644 --- a/frontend/src/components/ExperimentList.tsx +++ b/frontend/src/components/ExperimentList.tsx @@ -177,7 +177,7 @@ export class ExperimentList extends React.PureComponent null} {...this.props} - disablePaging={true} + disablePaging={false} noFilterBox={true} storageState={ this.props.storageState === ExperimentStorageState.ARCHIVED diff --git a/frontend/src/pages/ExperimentList.tsx b/frontend/src/pages/ExperimentList.tsx index cc192c40640..4a6a6ce7f72 100644 --- a/frontend/src/pages/ExperimentList.tsx +++ b/frontend/src/pages/ExperimentList.tsx @@ -278,7 +278,7 @@ export class ExperimentList extends Page<{ namespace?: string }, ExperimentListS experimentIdMask={experiment.id} onError={() => null} {...this.props} - disablePaging={true} + disablePaging={false} selectedIds={this.state.selectedIds} noFilterBox={true} storageState={RunStorageState.AVAILABLE} From d3f8d1d1a164c9a6ec49282cdd7b4ebd03564bed Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 17 Jun 2020 17:40:17 +0800 Subject: [PATCH 02/14] Revert "enable pagination when expanding experiment in both the home page and the archive page" This reverts commit 5b672739dd88235d41b5867666db917a8eb80a10. --- frontend/src/components/ExperimentList.tsx | 2 +- frontend/src/pages/ExperimentList.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/src/components/ExperimentList.tsx b/frontend/src/components/ExperimentList.tsx index 8823ec9b73e..330e4488ac7 100644 --- a/frontend/src/components/ExperimentList.tsx +++ b/frontend/src/components/ExperimentList.tsx @@ -177,7 +177,7 @@ export class ExperimentList extends React.PureComponent null} {...this.props} - disablePaging={false} + disablePaging={true} noFilterBox={true} storageState={ this.props.storageState === ExperimentStorageState.ARCHIVED diff --git a/frontend/src/pages/ExperimentList.tsx b/frontend/src/pages/ExperimentList.tsx index 4a6a6ce7f72..cc192c40640 100644 --- a/frontend/src/pages/ExperimentList.tsx +++ b/frontend/src/pages/ExperimentList.tsx @@ -278,7 +278,7 @@ export class ExperimentList extends Page<{ namespace?: string }, ExperimentListS experimentIdMask={experiment.id} onError={() => null} {...this.props} - disablePaging={false} + disablePaging={true} selectedIds={this.state.selectedIds} noFilterBox={true} storageState={RunStorageState.AVAILABLE} From 0dbe684d4611081737569c32a5bc1a4a1b84c345 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 17 Jul 2020 17:35:50 +0800 Subject: [PATCH 03/14] sorting by run metrics is different from sorting by name, uuid, created at, etc. The lattre are direct field in listable object, the former is an element in an arrary-typed field in listable object. In other words, the latter are columns in table, the former is not. --- backend/src/apiserver/list/BUILD.bazel | 2 + backend/src/apiserver/list/list.go | 107 ++++++-- backend/src/apiserver/list/list_test.go | 8 + backend/src/apiserver/model/experiment.go | 5 + backend/src/apiserver/model/job.go | 5 + backend/src/apiserver/model/pipeline.go | 5 + .../src/apiserver/model/pipeline_version.go | 5 + backend/src/apiserver/model/run.go | 29 ++ .../server/list_request_util_test.go | 4 + backend/src/apiserver/storage/run_store.go | 31 ++- .../src/apiserver/storage/run_store_test.go | 257 +++++++++++++++++- 11 files changed, 421 insertions(+), 37 deletions(-) diff --git a/backend/src/apiserver/list/BUILD.bazel b/backend/src/apiserver/list/BUILD.bazel index 3a26548af88..b46aee86303 100644 --- a/backend/src/apiserver/list/BUILD.bazel +++ b/backend/src/apiserver/list/BUILD.bazel @@ -9,7 +9,9 @@ go_library( "//backend/api:go_default_library", "//backend/src/apiserver/common:go_default_library", "//backend/src/apiserver/filter:go_default_library", + "//backend/src/apiserver/model:go_default_library", "//backend/src/common/util:go_default_library", + "@com_github_golang_glog//:go_default_library", "@com_github_masterminds_squirrel//:go_default_library", ], ) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 2a44b39b4ec..54698e48142 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -44,15 +44,22 @@ type token struct { // SortByFieldValue is the value of the sorted field of the next row to be // returned. SortByFieldValue interface{} + // SortByFieldIsRunMetric indicates whether the SortByFieldName field is + // a run metric field or not. + SortByFieldIsRunMetric bool + // KeyFieldName is the name of the primary key for the model being queried. KeyFieldName string // KeyFieldValue is the value of the sorted field of the next row to be // returned. KeyFieldValue interface{} + // IsDesc is true if the sorting order should be descending. IsDesc bool + // ModelName is the table where ***FieldName belongs to. ModelName string + // Filter represents the filtering that should be applied in the query. Filter *filter.Filter } @@ -94,7 +101,7 @@ type Options struct { // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { - return o.SortByFieldName == opts.SortByFieldName && + return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldIsRunMetric == opts.SortByFieldIsRunMetric && o.IsDesc == opts.IsDesc && reflect.DeepEqual(o.Filter, opts.Filter) } @@ -140,13 +147,23 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api } token.SortByFieldName = listable.DefaultSortField() + token.SortByFieldIsRunMetric = false if len(queryList) > 0 { var err error n, ok := listable.APIToModelFieldMap()[queryList[0]] - if !ok { + if ok { + token.SortByFieldName = n + } else if strings.HasPrefix(queryList[0], "metric:") { + // Sorting on metrics is only available on certain runs. + model := reflect.ValueOf(listable).Elem().Type().Name() + if model != "Run" { + return nil, util.NewInvalidInputError("Invalid sorting field: %q on %q : %s", queryList[0], model, err) + } + token.SortByFieldName = queryList[0][7:] + token.SortByFieldIsRunMetric = true + } else { return nil, util.NewInvalidInputError("Invalid sorting field: %q: %s", queryList[0], err) } - token.SortByFieldName = n } if len(queryList) == 2 { @@ -176,28 +193,34 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu return sqlBuilder } -// AddPaginationToSelect adds WHERE clauses with the sorting and pagination criteria in the -// Options o to the supplied SelectBuilder, and returns the new SelectBuilder -// containing these. +// AddSortingToSelect adds Order By clause. func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { - // If next row's value is specified, set those values in the clause. - var modelNamePrefix string + // When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed. + var keyFieldPrefix string + var sortByFieldPrefix string if len(o.ModelName) == 0 { - modelNamePrefix = "" + keyFieldPrefix = "" + sortByFieldPrefix = "" + } else if o.SortByFieldIsRunMetric { + keyFieldPrefix = o.ModelName + "." + sortByFieldPrefix = "" } else { - modelNamePrefix = o.ModelName + "." + keyFieldPrefix = o.ModelName + "." + sortByFieldPrefix = o.ModelName + "." } + + // If next row's value is specified, set those values in the clause. if o.SortByFieldValue != nil && o.KeyFieldValue != nil { if o.IsDesc { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Lt{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.LtOrEq{modelNamePrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Lt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.LtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } else { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Gt{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.GtOrEq{modelNamePrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Gt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.GtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } } @@ -206,12 +229,25 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild order = "DESC" } sqlBuilder = sqlBuilder. - OrderBy(fmt.Sprintf("%v %v", modelNamePrefix+o.SortByFieldName, order)). - OrderBy(fmt.Sprintf("%v %v", modelNamePrefix+o.KeyFieldName, order)) + OrderBy(fmt.Sprintf("%v %v", sortByFieldPrefix+o.SortByFieldName, order)). + OrderBy(fmt.Sprintf("%v %v", keyFieldPrefix+o.KeyFieldName, order)) return sqlBuilder } +// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. +// With the metric as a field in the select clause enable sorting on this metric afterwards. +func (o *Options) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { + if !o.SortByFieldIsRunMetric { + return sqlBuilder + } + // TODO(jingzhang36): address the case where runs doesn't have the specified metric. + return sq. + Select("selected_runs.*, run_metrics.numbervalue as "+o.SortByFieldName). + FromSelect(sqlBuilder, "selected_runs"). + LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + o.SortByFieldName + "'") +} + // AddFilterToSelect adds WHERE clauses with the filtering criteria in the // Options o to the supplied SelectBuilder, and returns the new SelectBuilder // containing these. @@ -309,6 +345,8 @@ type Listable interface { APIToModelFieldMap() map[string]string // GetModelName returns table name used as sort field prefix. GetModelName() string + // Find the value of a given field in a listable object. + GetFieldValue(name string) interface{} } // NextPageToken returns a string that can be used to fetch the subsequent set @@ -326,9 +364,21 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { elem := reflect.ValueOf(listable).Elem() elemName := elem.Type().Name() - sortByField := elem.FieldByName(o.SortByFieldName) - if !sortByField.IsValid() { - return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) + var sortByField interface{} + // TODO(jingzhang36): this if-else block can be simplified to one call to + // GetFieldValue after all the models (run, job, experiment, etc.) implement + // GetFieldValue method in listable interface. + if !o.SortByFieldIsRunMetric { + if value := elem.FieldByName(o.SortByFieldName); value.IsValid() { + sortByField = value.Interface() + } else { + return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) + } + } else { + sortByField = listable.GetFieldValue(o.SortByFieldName) + if sortByField == nil { + return nil, util.NewInvalidInputError("Unable to find run metric %s", o.SortByFieldName) + } } keyField := elem.FieldByName(listable.PrimaryKeyColumnName()) @@ -337,13 +387,14 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { } return &token{ - SortByFieldName: o.SortByFieldName, - SortByFieldValue: sortByField.Interface(), - KeyFieldName: listable.PrimaryKeyColumnName(), - KeyFieldValue: keyField.Interface(), - IsDesc: o.IsDesc, - Filter: o.Filter, - ModelName: o.ModelName, + SortByFieldName: o.SortByFieldName, + SortByFieldValue: sortByField, + SortByFieldIsRunMetric: o.SortByFieldIsRunMetric, + KeyFieldName: listable.PrimaryKeyColumnName(), + KeyFieldValue: keyField.Interface(), + IsDesc: o.IsDesc, + Filter: o.Filter, + ModelName: o.ModelName, }, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 332acf5e37e..335f6d0dd21 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -43,6 +43,10 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetFieldValue(name string) interface{} { + return nil +} + func TestNextPageToken_ValidTokens(t *testing.T) { l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234} @@ -824,3 +828,7 @@ func TestFilterOnNamesapce(t *testing.T) { } } } + +func TestSortByRunMetrics(t *testing.T) { + +} diff --git a/backend/src/apiserver/model/experiment.go b/backend/src/apiserver/model/experiment.go index 40bdc27c6b8..255ceca3aa6 100644 --- a/backend/src/apiserver/model/experiment.go +++ b/backend/src/apiserver/model/experiment.go @@ -46,3 +46,8 @@ func (e *Experiment) APIToModelFieldMap() map[string]string { func (e *Experiment) GetModelName() string { return "experiments" } + +func (e *Experiment) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/job.go b/backend/src/apiserver/model/job.go index 2ea467e83ef..b9ecf9a1433 100644 --- a/backend/src/apiserver/model/job.go +++ b/backend/src/apiserver/model/job.go @@ -104,3 +104,8 @@ func (k *Job) APIToModelFieldMap() map[string]string { func (j *Job) GetModelName() string { return "jobs" } + +func (j *Job) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/pipeline.go b/backend/src/apiserver/model/pipeline.go index 08100e49ce1..5cd120df33c 100644 --- a/backend/src/apiserver/model/pipeline.go +++ b/backend/src/apiserver/model/pipeline.go @@ -78,3 +78,8 @@ func (p *Pipeline) APIToModelFieldMap() map[string]string { func (p *Pipeline) GetModelName() string { return "pipelines" } + +func (p *Pipeline) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/pipeline_version.go b/backend/src/apiserver/model/pipeline_version.go index 1cdb55a9197..7ba2983844a 100644 --- a/backend/src/apiserver/model/pipeline_version.go +++ b/backend/src/apiserver/model/pipeline_version.go @@ -73,3 +73,8 @@ func (p *PipelineVersion) APIToModelFieldMap() map[string]string { func (p *PipelineVersion) GetModelName() string { return "pipeline_versions" } + +func (p *PipelineVersion) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/run.go b/backend/src/apiserver/model/run.go index 079e69602ce..288085f5e16 100644 --- a/backend/src/apiserver/model/run.go +++ b/backend/src/apiserver/model/run.go @@ -91,3 +91,32 @@ func (r *Run) GetModelName() string { // and thus as prefix in sorting fields. return "" } + +func (r *Run) GetFieldValue(name string) interface{} { + // "name" could be a field in Run type or a name inside an array typed field + // in Run type + // First, try to find the value if "name" is a field in Run type + switch name { + case "UUID": + return r.UUID + case "DisplayName": + return r.DisplayName + case "CreatedAtInSec": + return r.CreatedAtInSec + case "Description": + return r.Description + case "ScheduledAtInSec": + return r.ScheduledAtInSec + case "StorageState": + return r.StorageState + case "Conditions": + return r.Conditions + } + // Second, try to find the match of "name" inside an array typed field + for _, metric := range r.Metrics { + if metric.Name == name { + return metric.NumberValue + } + } + return nil +} diff --git a/backend/src/apiserver/server/list_request_util_test.go b/backend/src/apiserver/server/list_request_util_test.go index 76607317cf7..17991b039f3 100644 --- a/backend/src/apiserver/server/list_request_util_test.go +++ b/backend/src/apiserver/server/list_request_util_test.go @@ -245,6 +245,10 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetFieldValue(name string) interface{} { + return nil +} + func TestValidatedListOptions_Errors(t *testing.T) { opts, err := list.NewOptions(&fakeListable{}, 10, "name asc", nil) if err != nil { diff --git a/backend/src/apiserver/storage/run_store.go b/backend/src/apiserver/storage/run_store.go index 3cadb51f7ec..f283390fd7b 100644 --- a/backend/src/apiserver/storage/run_store.go +++ b/backend/src/apiserver/storage/run_store.go @@ -170,8 +170,9 @@ func (s *RunStore) buildSelectRunsQuery(selectCount bool, opts *list.Options, // If we're not just counting, then also add select columns and perform a left join // to get resource reference information. Also add pagination. if !selectCount { + sqlBuilder = opts.AddSortByRunMetricToSelect(sqlBuilder) sqlBuilder = opts.AddPaginationToSelect(sqlBuilder) - sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder) + sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder, opts) sqlBuilder = opts.AddSortingToSelect(sqlBuilder) } sql, args, err := sqlBuilder.ToSql() @@ -187,7 +188,7 @@ func (s *RunStore) GetRun(runId string) (*model.RunDetail, error) { sq.Select(runColumns...). From("run_details"). Where(sq.Eq{"UUID": runId}). - Limit(1)). + Limit(1), nil). ToSql() if err != nil { @@ -213,17 +214,37 @@ func (s *RunStore) GetRun(runId string) (*model.RunDetail, error) { return runs[0], nil } -func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder) sq.SelectBuilder { +// Apply func f to every string in a given string slice. +func Map(vs []string, f func(string) string) []string { + vsm := make([]string, len(vs)) + for i, v := range vs { + vsm[i] = f(v) + } + return vsm +} + +func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { resourceRefConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rr.Payload", ","), `"]"`}, "") + columnsAfterJoiningResourceReferences := append( + Map(runColumns, func(column string) string { return "rd." + column }), // Add prefix "rd." to runColumns + resourceRefConcatQuery+" AS refs") + if opts != nil && opts.SortByFieldIsRunMetric { + columnsAfterJoiningResourceReferences = append(columnsAfterJoiningResourceReferences, "rd."+opts.SortByFieldName) + } subQ := sq. - Select("rd.*", resourceRefConcatQuery+" AS refs"). + Select(columnsAfterJoiningResourceReferences...). FromSelect(filteredSelectBuilder, "rd"). LeftJoin("resource_references AS rr ON rr.ResourceType='Run' AND rd.UUID=rr.ResourceUUID"). GroupBy("rd.UUID") + // TODO(jingzhang36): address the case where some runs don't have the metric used in order by. metricConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rm.Payload", ","), `"]"`}, "") + columnsAfterJoiningRunMetrics := append( + Map(runColumns, func(column string) string { return "subq." + column }), // Add prefix "subq." to runColumns + "subq.refs", + metricConcatQuery+" AS metrics") return sq. - Select("subq.*", metricConcatQuery+" AS metrics"). + Select(columnsAfterJoiningRunMetrics...). FromSelect(subQ, "subq"). LeftJoin("run_metrics AS rm ON subq.UUID=rm.RunUUID"). GroupBy("subq.UUID") diff --git a/backend/src/apiserver/storage/run_store_test.go b/backend/src/apiserver/storage/run_store_test.go index e9e49b578ea..bdfdb0660aa 100644 --- a/backend/src/apiserver/storage/run_store_test.go +++ b/backend/src/apiserver/storage/run_store_test.go @@ -15,6 +15,8 @@ package storage import ( + "fmt" + "sort" "testing" sq "github.com/Masterminds/squirrel" @@ -27,6 +29,12 @@ import ( "google.golang.org/grpc/codes" ) +type RunMetricSorter []*model.RunMetric + +func (r RunMetricSorter) Len() int { return len(r) } +func (r RunMetricSorter) Less(i, j int) bool { return r[i].Name < r[j].Name } +func (r RunMetricSorter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } + func initializeRunStore() (*DB, *RunStore) { db := NewFakeDbOrFatal() expStore := NewExperimentStore(db, util.NewFakeTimeForEpoch(), util.NewFakeUUIDGeneratorOrFatal(defaultFakeExpId, nil)) @@ -104,6 +112,24 @@ func initializeRunStore() (*DB, *RunStore) { runStore.CreateRun(run1) runStore.CreateRun(run2) runStore.CreateRun(run3) + + metric1 := &model.RunMetric{ + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + } + metric2 := &model.RunMetric{ + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + } + runStore.ReportMetric(metric1) + runStore.ReportMetric(metric2) + return db, runStore } @@ -121,6 +147,15 @@ func TestListRuns_Pagination(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -139,6 +174,15 @@ func TestListRuns_Pagination(t *testing.T) { ScheduledAtInSec: 2, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "2", ResourceType: common.Run, @@ -168,6 +212,85 @@ func TestListRuns_Pagination(t *testing.T) { assert.Empty(t, nextPageToken) } +func TestListRuns_Pagination_WithSortingOnMetrics(t *testing.T) { + db, runStore := initializeRunStore() + defer db.Close() + + expectedFirstPageRuns := []*model.Run{ + { + UUID: "1", + Name: "run1", + DisplayName: "run1", + Namespace: "n1", + CreatedAtInSec: 1, + ScheduledAtInSec: 1, + StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), + Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, + ResourceReferences: []*model.ResourceReference{ + { + ResourceUUID: "1", ResourceType: common.Run, + ReferenceUUID: defaultFakeExpId, ReferenceName: "e1", + ReferenceType: common.Experiment, Relationship: common.Creator, + }, + }, + }} + expectedSecondPageRuns := []*model.Run{ + { + UUID: "2", + Name: "run2", + DisplayName: "run2", + Namespace: "n2", + CreatedAtInSec: 2, + ScheduledAtInSec: 2, + StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), + Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, + ResourceReferences: []*model.ResourceReference{ + { + ResourceUUID: "2", ResourceType: common.Run, + ReferenceUUID: defaultFakeExpId, ReferenceName: "e1", + ReferenceType: common.Experiment, Relationship: common.Creator, + }, + }, + }} + + opts, err := list.NewOptions(&model.Run{}, 1, "metric:dummymetric", nil) + assert.Nil(t, err) + + runs, total_size, nextPageToken, err := runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedFirstPageRuns, runs, "Unexpected Run listed.") + assert.NotEmpty(t, nextPageToken) + + opts, err = list.NewOptionsFromToken(nextPageToken, 1) + assert.Nil(t, err) + runs, total_size, nextPageToken, err = runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedSecondPageRuns, runs, "Unexpected Run listed.") + assert.Empty(t, nextPageToken) +} + func TestListRuns_TotalSizeWithNoFilter(t *testing.T) { db, runStore := initializeRunStore() defer db.Close() @@ -219,6 +342,15 @@ func TestListRuns_Pagination_Descend(t *testing.T) { ScheduledAtInSec: 2, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "2", ResourceType: common.Run, @@ -237,6 +369,15 @@ func TestListRuns_Pagination_Descend(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -251,6 +392,10 @@ func TestListRuns_Pagination_Descend(t *testing.T) { runs, total_size, nextPageToken, err := runStore.ListRuns( &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + for _, run := range runs { + fmt.Printf("%+v\n", run) + } + assert.Nil(t, err) assert.Equal(t, 2, total_size) assert.Equal(t, expectedFirstPageRuns, runs, "Unexpected Run listed.") @@ -280,6 +425,15 @@ func TestListRuns_Pagination_LessThanPageSize(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -297,6 +451,15 @@ func TestListRuns_Pagination_LessThanPageSize(t *testing.T) { ScheduledAtInSec: 2, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "2", ResourceType: common.Run, @@ -341,6 +504,15 @@ func TestGetRun(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -389,6 +561,15 @@ func TestCreateOrUpdateRun_UpdateSuccess(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -426,6 +607,15 @@ func TestCreateOrUpdateRun_UpdateSuccess(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -559,6 +749,15 @@ func TestCreateOrUpdateRun_BadStorageStateValue(t *testing.T) { CreatedAtInSec: 1, ScheduledAtInSec: 1, Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -603,6 +802,15 @@ func TestTerminateRun(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Terminating", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -652,7 +860,16 @@ func TestReportMetric_Success(t *testing.T) { runDetail, err := runStore.GetRun("1") assert.Nil(t, err, "Got error: %+v", err) - assert.Equal(t, []*model.RunMetric{metric}, runDetail.Run.Metrics) + sort.Sort(RunMetricSorter(runDetail.Run.Metrics)) + assert.Equal(t, []*model.RunMetric{ + metric, + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }}, runDetail.Run.Metrics) } func TestReportMetric_DupReports_Fail(t *testing.T) { @@ -744,7 +961,16 @@ func TestListRuns_WithMetrics(t *testing.T) { ReferenceType: common.Experiment, Relationship: common.Creator, }, }, - Metrics: []*model.RunMetric{metric1, metric2}, + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + metric1, + metric2}, }, { UUID: "2", @@ -762,15 +988,29 @@ func TestListRuns_WithMetrics(t *testing.T) { ReferenceType: common.Experiment, Relationship: common.Creator, }, }, - Metrics: []*model.RunMetric{metric3}, + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + metric3}, }, } - opts, err := list.NewOptions(&model.Run{}, 2, "", nil) + opts, err := list.NewOptions(&model.Run{}, 2, "id", nil) assert.Nil(t, err) runs, total_size, _, err := runStore.ListRuns(&common.FilterContext{}, opts) assert.Equal(t, 3, total_size) assert.Nil(t, err) + for _, run := range expectedRuns { + sort.Sort(RunMetricSorter(run.Metrics)) + } + for _, run := range runs { + sort.Sort(RunMetricSorter(run.Metrics)) + } assert.Equal(t, expectedRuns, runs, "Unexpected Run listed.") } @@ -866,6 +1106,15 @@ func TestArchiveRun_IncludedInRunList(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_ARCHIVED.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, From 90b0857a5a981541a6531ae1dae362b67b4178a6 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 20 Jul 2020 13:23:09 +0800 Subject: [PATCH 04/14] unit test: add sorting on metrics with both asc and desc order --- .../src/apiserver/storage/run_store_test.go | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/backend/src/apiserver/storage/run_store_test.go b/backend/src/apiserver/storage/run_store_test.go index bdfdb0660aa..d7375713cae 100644 --- a/backend/src/apiserver/storage/run_store_test.go +++ b/backend/src/apiserver/storage/run_store_test.go @@ -271,6 +271,7 @@ func TestListRuns_Pagination_WithSortingOnMetrics(t *testing.T) { }, }} + // Sort in asc order opts, err := list.NewOptions(&model.Run{}, 1, "metric:dummymetric", nil) assert.Nil(t, err) @@ -289,6 +290,26 @@ func TestListRuns_Pagination_WithSortingOnMetrics(t *testing.T) { assert.Equal(t, 2, total_size) assert.Equal(t, expectedSecondPageRuns, runs, "Unexpected Run listed.") assert.Empty(t, nextPageToken) + + // Sort in desc order + opts, err = list.NewOptions(&model.Run{}, 1, "metric:dummymetric desc", nil) + assert.Nil(t, err) + + runs, total_size, nextPageToken, err = runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedSecondPageRuns, runs, "Unexpected Run listed.") + assert.NotEmpty(t, nextPageToken) + + opts, err = list.NewOptionsFromToken(nextPageToken, 1) + assert.Nil(t, err) + runs, total_size, nextPageToken, err = runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedFirstPageRuns, runs, "Unexpected Run listed.") + assert.Empty(t, nextPageToken) } func TestListRuns_TotalSizeWithNoFilter(t *testing.T) { From 76048b72fc87c934deaab530268a3ccb7c8f70f5 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 20 Jul 2020 15:52:17 +0800 Subject: [PATCH 05/14] GetFieldValue in all models --- backend/src/apiserver/list/list.go | 13 ++----------- backend/src/apiserver/model/experiment.go | 18 ++++++++++++++++-- backend/src/apiserver/model/job.go | 14 ++++++++++++-- backend/src/apiserver/model/pipeline.go | 14 ++++++++++++-- .../src/apiserver/model/pipeline_version.go | 14 ++++++++++++-- 5 files changed, 54 insertions(+), 19 deletions(-) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 54698e48142..54c423aaa38 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -368,17 +368,8 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { // TODO(jingzhang36): this if-else block can be simplified to one call to // GetFieldValue after all the models (run, job, experiment, etc.) implement // GetFieldValue method in listable interface. - if !o.SortByFieldIsRunMetric { - if value := elem.FieldByName(o.SortByFieldName); value.IsValid() { - sortByField = value.Interface() - } else { - return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) - } - } else { - sortByField = listable.GetFieldValue(o.SortByFieldName) - if sortByField == nil { - return nil, util.NewInvalidInputError("Unable to find run metric %s", o.SortByFieldName) - } + if sortByField = listable.GetFieldValue(o.SortByFieldName); sortByField == nil { + return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) } keyField := elem.FieldByName(listable.PrimaryKeyColumnName()) diff --git a/backend/src/apiserver/model/experiment.go b/backend/src/apiserver/model/experiment.go index 255ceca3aa6..7a9d17d3591 100644 --- a/backend/src/apiserver/model/experiment.go +++ b/backend/src/apiserver/model/experiment.go @@ -48,6 +48,20 @@ func (e *Experiment) GetModelName() string { } func (e *Experiment) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return e.UUID + case "Name": + return e.Name + case "CreatedAtInSec": + return e.CreatedAtInSec + case "Description": + return e.Description + case "Namespace": + return e.Namespace + case "StorageState": + return e.StorageState + default: + return nil + } } diff --git a/backend/src/apiserver/model/job.go b/backend/src/apiserver/model/job.go index b9ecf9a1433..855038219a3 100644 --- a/backend/src/apiserver/model/job.go +++ b/backend/src/apiserver/model/job.go @@ -106,6 +106,16 @@ func (j *Job) GetModelName() string { } func (j *Job) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return j.UUID + case "DisplayName": + return j.DisplayName + case "CreatedAtInSec": + return j.CreatedAtInSec + case "PipelineId": + return j.PipelineId + default: + return nil + } } diff --git a/backend/src/apiserver/model/pipeline.go b/backend/src/apiserver/model/pipeline.go index 5cd120df33c..68cc9a24295 100644 --- a/backend/src/apiserver/model/pipeline.go +++ b/backend/src/apiserver/model/pipeline.go @@ -80,6 +80,16 @@ func (p *Pipeline) GetModelName() string { } func (p *Pipeline) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return p.UUID + case "Name": + return p.Name + case "CreatedAtInSec": + return p.CreatedAtInSec + case "Description": + return p.Description + default: + return nil + } } diff --git a/backend/src/apiserver/model/pipeline_version.go b/backend/src/apiserver/model/pipeline_version.go index 7ba2983844a..93b76466a50 100644 --- a/backend/src/apiserver/model/pipeline_version.go +++ b/backend/src/apiserver/model/pipeline_version.go @@ -75,6 +75,16 @@ func (p *PipelineVersion) GetModelName() string { } func (p *PipelineVersion) GetFieldValue(name string) interface{} { - // TODO(jingzhang36): follow the example of GetFieldValue in run.go - return nil + switch name { + case "UUID": + return p.UUID + case "Name": + return p.Name + case "CreatedAtInSec": + return p.CreatedAtInSec + case "Status": + return p.Status + default: + return nil + } } From f5859ae3fd7b8c066b97b6e14c1bd5f4181d546f Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 20 Jul 2020 16:33:42 +0800 Subject: [PATCH 06/14] fix unit test --- backend/src/apiserver/list/list_test.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 335f6d0dd21..2410419596d 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -44,7 +44,16 @@ func (f *fakeListable) GetModelName() string { } func (f *fakeListable) GetFieldValue(name string) interface{} { - return nil + switch name { + case "CreatedTimestamp": + return f.CreatedTimestamp + case "FakeName": + return f.FakeName + case "PrimaryKey": + return f.PrimaryKey + default: + return nil + } } func TestNextPageToken_ValidTokens(t *testing.T) { From dd401bf4bf1e1637293fd1cc9c97298dc99cf840 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 20 Jul 2020 21:11:34 +0800 Subject: [PATCH 07/14] whether to test in list_test. It's hacky when check mode == 'run' --- backend/src/apiserver/list/list.go | 3 -- backend/src/apiserver/list/list_test.go | 45 +++++++++++++++++++++---- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 54c423aaa38..528c43739f8 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -365,9 +365,6 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { elemName := elem.Type().Name() var sortByField interface{} - // TODO(jingzhang36): this if-else block can be simplified to one call to - // GetFieldValue after all the models (run, job, experiment, etc.) implement - // GetFieldValue method in listable interface. if sortByField = listable.GetFieldValue(o.SortByFieldName); sortByField == nil { return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 2410419596d..3a00b45aacb 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -15,10 +15,16 @@ import ( api "github.com/kubeflow/pipelines/backend/api/go_client" ) +type fakeMetric struct { + Name: string + Value: float64 +} + type fakeListable struct { PrimaryKey string FakeName string CreatedTimestamp int64 + Metrics []*fakeMetric } func (f *fakeListable) PrimaryKeyColumnName() string { @@ -51,13 +57,26 @@ func (f *fakeListable) GetFieldValue(name string) interface{} { return f.FakeName case "PrimaryKey": return f.PrimaryKey - default: - return nil } + for _, metric := range f.Metrics { + if metric.Name == name { + return metric.Value + } + } + return nil } func TestNextPageToken_ValidTokens(t *testing.T) { - l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234} + l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234, Metrics: []*fakeMetric{ + { + Name: "m1", + Value: 1.0, + }, + { + Name: "m2", + Value: 2.0, + }, + }} protoFilter := &api.Filter{Predicates: []*api.Predicate{ &api.Predicate{ @@ -127,6 +146,22 @@ func TestNextPageToken_ValidTokens(t *testing.T) { Filter: testFilter, }, }, + { + inOpts: &Options{ + PageSize: 10, + token: &token{ + SortByFieldName: "metric:m1", IsDesc: false, + }, + }, + want: &token{ + SortByFieldName: "m1", + SortByFieldValue: "1.0", + SortByFieldIsRunMetric: true, + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + IsDesc: false, + }, + }, } for _, test := range tests { @@ -837,7 +872,3 @@ func TestFilterOnNamesapce(t *testing.T) { } } } - -func TestSortByRunMetrics(t *testing.T) { - -} From 1b3dcc40c36a305abe430a61e7e5c68228ec59b1 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 23 Jul 2020 21:26:57 +0800 Subject: [PATCH 08/14] move model specific code to model; prevent model package depends on list package; let list package depends on modelpackage; marshal/unmarshal listable interface; include listable interface in token. --- backend/src/apiserver/list/list.go | 137 +++++++++++------- backend/src/apiserver/list/list_test.go | 136 +++++++++++++++-- backend/src/apiserver/model/BUILD.bazel | 3 - backend/src/apiserver/model/experiment.go | 15 ++ backend/src/apiserver/model/job.go | 15 ++ backend/src/apiserver/model/pipeline.go | 15 ++ .../src/apiserver/model/pipeline_version.go | 15 ++ .../apiserver/model/pipeline_version_test.go | 32 ---- .../model/resource_reference_test.go | 3 +- backend/src/apiserver/model/run.go | 34 +++++ backend/src/apiserver/model/run_test.go | 51 ------- .../server/list_request_util_test.go | 24 +++ backend/src/apiserver/storage/run_store.go | 20 ++- 13 files changed, 350 insertions(+), 150 deletions(-) delete mode 100644 backend/src/apiserver/model/pipeline_version_test.go delete mode 100644 backend/src/apiserver/model/run_test.go diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 528c43739f8..2f4f2ef3d96 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -29,6 +29,7 @@ import ( api "github.com/kubeflow/pipelines/backend/api/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/filter" + "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" ) @@ -44,9 +45,6 @@ type token struct { // SortByFieldValue is the value of the sorted field of the next row to be // returned. SortByFieldValue interface{} - // SortByFieldIsRunMetric indicates whether the SortByFieldName field is - // a run metric field or not. - SortByFieldIsRunMetric bool // KeyFieldName is the name of the primary key for the model being queried. KeyFieldName string @@ -58,10 +56,21 @@ type token struct { IsDesc bool // ModelName is the table where ***FieldName belongs to. + // TODO(jingzhang36): we probably can deprecate this since we now have + // Model field. ModelName string // Filter represents the filtering that should be applied in the query. Filter *filter.Filter + + // The listable model this token is applied to. Not used in json marshal/unmarshal. + Model Listable `json:"-"` + // ModelType and the ModelMessage are helper fields to unmarshal data correctly to + // the underlying listable model, and this underlying listable model will be stored + // in the above Model field. Those two fields are only used in token's marshal and + // unmarshal methods. + ModelType string + ModelMessage json.RawMessage } func (t *token) unmarshal(pageToken string) error { @@ -77,10 +86,64 @@ func (t *token) unmarshal(pageToken string) error { return errorF(err) } + if t.ModelMessage != nil { + switch t.ModelType { + case "Run": + model := &model.Run{} + err = json.Unmarshal(t.ModelMessage, model) + if err != nil { + return errorF(err) + } + t.Model = model + break + case "Job": + model := &model.Job{} + err = json.Unmarshal(t.ModelMessage, model) + if err != nil { + return errorF(err) + } + t.Model = model + break + case "Experiment": + model := &model.Experiment{} + err = json.Unmarshal(t.ModelMessage, model) + if err != nil { + return errorF(err) + } + t.Model = model + break + case "Pipeline": + model := &model.Pipeline{} + err = json.Unmarshal(t.ModelMessage, model) + if err != nil { + return errorF(err) + } + t.Model = model + break + case "PipelineVersion": + model := &model.PipelineVersion{} + err = json.Unmarshal(t.ModelMessage, model) + if err != nil { + return errorF(err) + } + t.Model = model + break + } + } + return nil } func (t *token) marshal() (string, error) { + if t.Model != nil { + t.ModelType = reflect.ValueOf(t.Model).Elem().Type().Name() + modelMessage, err := json.Marshal(t.Model) + if err != nil { + return "", util.NewInternalServerError(err, "Failed to serialize the listable object in page token.") + } + t.ModelMessage = modelMessage + } // can we set empty raw message explicitly in case of nil model + b, err := json.Marshal(t) if err != nil { return "", util.NewInternalServerError(err, "Failed to serialize page token.") @@ -101,7 +164,7 @@ type Options struct { // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { - return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldIsRunMetric == opts.SortByFieldIsRunMetric && + return o.SortByFieldName == opts.SortByFieldName && o.IsDesc == opts.IsDesc && reflect.DeepEqual(o.Filter, opts.Filter) } @@ -136,7 +199,8 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api token := &token{ KeyFieldName: listable.PrimaryKeyColumnName(), - ModelName: listable.GetModelName()} + ModelName: listable.GetModelName(), + Model: listable} // Ignore the case of the letter. Split query string by space. queryList := strings.Fields(strings.ToLower(sortBy)) @@ -147,22 +211,12 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api } token.SortByFieldName = listable.DefaultSortField() - token.SortByFieldIsRunMetric = false if len(queryList) > 0 { - var err error - n, ok := listable.APIToModelFieldMap()[queryList[0]] + n, ok := listable.GetField(queryList[0]) if ok { token.SortByFieldName = n - } else if strings.HasPrefix(queryList[0], "metric:") { - // Sorting on metrics is only available on certain runs. - model := reflect.ValueOf(listable).Elem().Type().Name() - if model != "Run" { - return nil, util.NewInvalidInputError("Invalid sorting field: %q on %q : %s", queryList[0], model, err) - } - token.SortByFieldName = queryList[0][7:] - token.SortByFieldIsRunMetric = true } else { - return nil, util.NewInvalidInputError("Invalid sorting field: %q: %s", queryList[0], err) + return nil, util.NewInvalidInputError("Invalid sorting field: %q on listable type %s", queryList[0], reflect.ValueOf(listable).Elem().Type().Name()) } } @@ -196,18 +250,8 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu // AddSortingToSelect adds Order By clause. func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { // When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed. - var keyFieldPrefix string - var sortByFieldPrefix string - if len(o.ModelName) == 0 { - keyFieldPrefix = "" - sortByFieldPrefix = "" - } else if o.SortByFieldIsRunMetric { - keyFieldPrefix = o.ModelName + "." - sortByFieldPrefix = "" - } else { - keyFieldPrefix = o.ModelName + "." - sortByFieldPrefix = o.ModelName + "." - } + keyFieldPrefix := o.Model.GetKeyFieldPrefix() + sortByFieldPrefix := o.Model.GetSortByFieldPrefix(o.SortByFieldName) // If next row's value is specified, set those values in the clause. if o.SortByFieldValue != nil && o.KeyFieldValue != nil { @@ -235,19 +279,6 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild return sqlBuilder } -// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. -// With the metric as a field in the select clause enable sorting on this metric afterwards. -func (o *Options) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { - if !o.SortByFieldIsRunMetric { - return sqlBuilder - } - // TODO(jingzhang36): address the case where runs doesn't have the specified metric. - return sq. - Select("selected_runs.*, run_metrics.numbervalue as "+o.SortByFieldName). - FromSelect(sqlBuilder, "selected_runs"). - LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + o.SortByFieldName + "'") -} - // AddFilterToSelect adds WHERE clauses with the filtering criteria in the // Options o to the supplied SelectBuilder, and returns the new SelectBuilder // containing these. @@ -345,6 +376,12 @@ type Listable interface { APIToModelFieldMap() map[string]string // GetModelName returns table name used as sort field prefix. GetModelName() string + // Get the prefix of sorting field. + GetSortByFieldPrefix(string) string + // Get the prefix of key field. + GetKeyFieldPrefix() string + // Get a valid field for sorting/filtering in a listable model from the given string. + GetField(name string) (string, bool) // Find the value of a given field in a listable object. GetFieldValue(name string) interface{} } @@ -375,14 +412,14 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { } return &token{ - SortByFieldName: o.SortByFieldName, - SortByFieldValue: sortByField, - SortByFieldIsRunMetric: o.SortByFieldIsRunMetric, - KeyFieldName: listable.PrimaryKeyColumnName(), - KeyFieldValue: keyField.Interface(), - IsDesc: o.IsDesc, - Filter: o.Filter, - ModelName: o.ModelName, + SortByFieldName: o.SortByFieldName, + SortByFieldValue: sortByField, + KeyFieldName: listable.PrimaryKeyColumnName(), + KeyFieldValue: keyField.Interface(), + IsDesc: o.IsDesc, + Filter: o.Filter, + ModelName: o.ModelName, + Model: listable, }, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 3a00b45aacb..c01d5dc3443 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -1,11 +1,14 @@ package list import ( + "encoding/json" "reflect" + "strings" "testing" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/filter" + "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" "github.com/stretchr/testify/assert" @@ -16,8 +19,8 @@ import ( ) type fakeMetric struct { - Name: string - Value: float64 + Name string + Value float64 } type fakeListable struct { @@ -49,6 +52,16 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetField(name string) (string, bool) { + if field, ok := fakeAPIToModelMap[name]; ok { + return field, true + } + if strings.HasPrefix(name, "metric:") { + return name[7:], true + } + return "", false +} + func (f *fakeListable) GetFieldValue(name string) interface{} { switch name { case "CreatedTimestamp": @@ -66,14 +79,22 @@ func (f *fakeListable) GetFieldValue(name string) interface{} { return nil } +func (f *fakeListable) GetSortByFieldPrefix(name string) string { + return "" +} + +func (f *fakeListable) GetKeyFieldPrefix() string { + return "" +} + func TestNextPageToken_ValidTokens(t *testing.T) { l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234, Metrics: []*fakeMetric{ { - Name: "m1", + Name: "m1", Value: 1.0, }, { - Name: "m2", + Name: "m2", Value: 2.0, }, }} @@ -103,6 +124,7 @@ func TestNextPageToken_ValidTokens(t *testing.T) { KeyFieldName: "PrimaryKey", KeyFieldValue: "uuid123", IsDesc: true, + Model: l, }, }, { @@ -115,6 +137,7 @@ func TestNextPageToken_ValidTokens(t *testing.T) { KeyFieldName: "PrimaryKey", KeyFieldValue: "uuid123", IsDesc: true, + Model: l, }, }, { @@ -127,6 +150,7 @@ func TestNextPageToken_ValidTokens(t *testing.T) { KeyFieldName: "PrimaryKey", KeyFieldValue: "uuid123", IsDesc: false, + Model: l, }, }, { @@ -144,22 +168,23 @@ func TestNextPageToken_ValidTokens(t *testing.T) { KeyFieldValue: "uuid123", IsDesc: false, Filter: testFilter, + Model: l, }, }, { inOpts: &Options{ PageSize: 10, token: &token{ - SortByFieldName: "metric:m1", IsDesc: false, + SortByFieldName: "m1", IsDesc: false, }, }, want: &token{ SortByFieldName: "m1", - SortByFieldValue: "1.0", - SortByFieldIsRunMetric: true, + SortByFieldValue: 1.0, KeyFieldName: "PrimaryKey", KeyFieldValue: "uuid123", IsDesc: false, + Model: l, }, }, } @@ -222,6 +247,7 @@ func TestNewOptions_FromValidSerializedToken(t *testing.T) { KeyFieldName: "KeyField", KeyFieldValue: "string_key_value", IsDesc: true, + Model: &fakeListable{}, } s, err := tok.marshal() @@ -229,6 +255,9 @@ func TestNewOptions_FromValidSerializedToken(t *testing.T) { t.Fatalf("failed to marshal token %+v: %v", tok, err) } + tok.Model = nil + tok.ModelType = "fakeListable" + tok.ModelMessage, _ = json.Marshal(&fakeListable{}) want := &Options{PageSize: 123, token: tok} got, err := NewOptionsFromToken(s, 123) @@ -286,6 +315,7 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { KeyFieldName: "PrimaryKey", SortByFieldName: "CreatedTimestamp", IsDesc: false, + Model: &fakeListable{}, }, }, }, @@ -297,6 +327,7 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { KeyFieldName: "PrimaryKey", SortByFieldName: "CreatedTimestamp", IsDesc: false, + Model: &fakeListable{}, }, }, }, @@ -308,6 +339,7 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { KeyFieldName: "PrimaryKey", SortByFieldName: "FakeName", IsDesc: false, + Model: &fakeListable{}, }, }, }, @@ -319,6 +351,7 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { KeyFieldName: "PrimaryKey", SortByFieldName: "FakeName", IsDesc: false, + Model: &fakeListable{}, }, }, }, @@ -330,6 +363,7 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { KeyFieldName: "PrimaryKey", SortByFieldName: "FakeName", IsDesc: true, + Model: &fakeListable{}, }, }, }, @@ -341,6 +375,7 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { KeyFieldName: "PrimaryKey", SortByFieldName: "PrimaryKey", IsDesc: true, + Model: &fakeListable{}, }, }, }, @@ -416,6 +451,7 @@ func TestNewOptions_ValidFilter(t *testing.T) { SortByFieldName: "CreatedTimestamp", IsDesc: false, Filter: f, + Model: &fakeListable{}, }, } @@ -476,6 +512,7 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { KeyFieldName: "KeyField", KeyFieldValue: 1111, IsDesc: true, + Model: &fakeListable{}, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField < ? OR (SortField = ? AND KeyField <= ?)) ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -490,6 +527,7 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { KeyFieldName: "KeyField", KeyFieldValue: 1111, IsDesc: false, + Model: &fakeListable{}, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -505,6 +543,7 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { KeyFieldValue: 1111, IsDesc: false, Filter: f, + Model: &fakeListable{}, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) AND Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -518,6 +557,7 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { KeyFieldName: "KeyField", KeyFieldValue: 1111, IsDesc: true, + Model: &fakeListable{}, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -531,6 +571,7 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { SortByFieldValue: "value", KeyFieldName: "KeyField", IsDesc: false, + Model: &fakeListable{}, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -545,6 +586,7 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { KeyFieldName: "KeyField", IsDesc: false, Filter: f, + Model: &fakeListable{}, }, }, wantSQL: "SELECT * FROM MyTable WHERE Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -575,6 +617,8 @@ func TestTokenSerialization(t *testing.T) { t.Fatalf("failed to parse filter proto %+v: %v", protoFilter, err) } + modelMessage, _ := json.Marshal(&fakeListable{}) + tests := []struct { in *token want *token @@ -586,13 +630,16 @@ func TestTokenSerialization(t *testing.T) { SortByFieldValue: "string_field_value", KeyFieldName: "KeyField", KeyFieldValue: "string_key_value", - IsDesc: true}, + IsDesc: true, + Model: &fakeListable{}}, want: &token{ SortByFieldName: "SortField", SortByFieldValue: "string_field_value", KeyFieldName: "KeyField", KeyFieldValue: "string_key_value", - IsDesc: true}, + IsDesc: true, + ModelType: "fakeListable", + ModelMessage: modelMessage}, }, // int values get deserialized as floats by JSON unmarshal. { @@ -601,13 +648,16 @@ func TestTokenSerialization(t *testing.T) { SortByFieldValue: 100, KeyFieldName: "KeyField", KeyFieldValue: 200, - IsDesc: true}, + IsDesc: true, + Model: &fakeListable{}}, want: &token{ SortByFieldName: "SortField", SortByFieldValue: float64(100), KeyFieldName: "KeyField", KeyFieldValue: float64(200), - IsDesc: true}, + IsDesc: true, + ModelType: "fakeListable", + ModelMessage: modelMessage}, }, // has a filter. { @@ -618,6 +668,7 @@ func TestTokenSerialization(t *testing.T) { KeyFieldValue: 200, IsDesc: true, Filter: testFilter, + Model: &fakeListable{}, }, want: &token{ SortByFieldName: "SortField", @@ -626,6 +677,8 @@ func TestTokenSerialization(t *testing.T) { KeyFieldValue: float64(200), IsDesc: true, Filter: testFilter, + ModelType: "fakeListable", + ModelMessage: modelMessage, }, }, } @@ -872,3 +925,64 @@ func TestFilterOnNamesapce(t *testing.T) { } } } + +func TestAddSortingToSelectWithPipelineVersionModel(t *testing.T) { + listable := &model.PipelineVersion{ + UUID: "version_id_1", + CreatedAtInSec: 1, + Name: "version_name_1", + Parameters: "", + PipelineId: "pipeline_id_1", + Status: model.PipelineVersionReady, + CodeSourceUrl: "", + } + protoFilter := &api.Filter{} + listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + assert.Nil(t, err) + sqlBuilder := sq.Select("*").From("pipeline_versions") + sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + + assert.Contains(t, sql, "pipeline_versions.Name") // sorting field + assert.Contains(t, sql, "pipeline_versions.UUID") // primary key field +} + +func TestAddStatusFilterToSelectWithRunModel(t *testing.T) { + listable := &model.Run{ + UUID: "run_id_1", + CreatedAtInSec: 1, + Name: "run_name_1", + Conditions: "Succeeded", + } + protoFilter := &api.Filter{} + protoFilter.Predicates = []*api.Predicate{ + { + Key: "status", + Op: api.Predicate_EQUALS, + Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, + }, + } + listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + assert.Nil(t, err) + sqlBuilder := sq.Select("*").From("run_details") + sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + assert.Contains(t, sql, "WHERE Conditions = ?") // filtering on status, aka Conditions in db + assert.Contains(t, args, "Succeeded") + + notEqualProtoFilter := &api.Filter{} + notEqualProtoFilter.Predicates = []*api.Predicate{ + { + Key: "status", + Op: api.Predicate_NOT_EQUALS, + Value: &api.Predicate_StringValue{StringValue: "somevalue"}, + }, + } + listableOptions, err = NewOptions(listable, 10, "name", notEqualProtoFilter) + assert.Nil(t, err) + sqlBuilder = sq.Select("*").From("run_details") + sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() + assert.Nil(t, err) + assert.Contains(t, sql, "WHERE Conditions <> ?") // filtering on status, aka Conditions in db + assert.Contains(t, args, "somevalue") +} diff --git a/backend/src/apiserver/model/BUILD.bazel b/backend/src/apiserver/model/BUILD.bazel index a6716102f22..75add6c411a 100644 --- a/backend/src/apiserver/model/BUILD.bazel +++ b/backend/src/apiserver/model/BUILD.bazel @@ -22,9 +22,7 @@ go_library( go_test( name = "go_default_test", srcs = [ - "pipeline_version_test.go", "resource_reference_test.go", - "run_test.go", ], embed = [":go_default_library"], importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/model", @@ -32,7 +30,6 @@ go_test( deps = [ "//backend/api:go_default_library", "//backend/src/apiserver/common:go_default_library", - "//backend/src/apiserver/list:go_default_library", "@com_github_masterminds_squirrel//:go_default_library", "@com_github_stretchr_testify//assert:go_default_library", ], diff --git a/backend/src/apiserver/model/experiment.go b/backend/src/apiserver/model/experiment.go index 7a9d17d3591..e79ed1984aa 100644 --- a/backend/src/apiserver/model/experiment.go +++ b/backend/src/apiserver/model/experiment.go @@ -47,6 +47,13 @@ func (e *Experiment) GetModelName() string { return "experiments" } +func (e *Experiment) GetField(name string) (string, bool) { + if field, ok := experimentAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (e *Experiment) GetFieldValue(name string) interface{} { switch name { case "UUID": @@ -65,3 +72,11 @@ func (e *Experiment) GetFieldValue(name string) interface{} { return nil } } + +func (e *Experiment) GetSortByFieldPrefix(name string) string { + return "experiments." +} + +func (e *Experiment) GetKeyFieldPrefix() string { + return "experiments." +} diff --git a/backend/src/apiserver/model/job.go b/backend/src/apiserver/model/job.go index 855038219a3..376132f5346 100644 --- a/backend/src/apiserver/model/job.go +++ b/backend/src/apiserver/model/job.go @@ -105,6 +105,13 @@ func (j *Job) GetModelName() string { return "jobs" } +func (j *Job) GetField(name string) (string, bool) { + if field, ok := jobAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (j *Job) GetFieldValue(name string) interface{} { switch name { case "UUID": @@ -119,3 +126,11 @@ func (j *Job) GetFieldValue(name string) interface{} { return nil } } + +func (j *Job) GetSortByFieldPrefix(name string) string { + return "jobs." +} + +func (j *Job) GetKeyFieldPrefix() string { + return "jobs." +} diff --git a/backend/src/apiserver/model/pipeline.go b/backend/src/apiserver/model/pipeline.go index 68cc9a24295..1a52ec73467 100644 --- a/backend/src/apiserver/model/pipeline.go +++ b/backend/src/apiserver/model/pipeline.go @@ -79,6 +79,13 @@ func (p *Pipeline) GetModelName() string { return "pipelines" } +func (p *Pipeline) GetField(name string) (string, bool) { + if field, ok := pipelineAPIToModelFieldMap[name]; ok { + return field, true + } + return "", false +} + func (p *Pipeline) GetFieldValue(name string) interface{} { switch name { case "UUID": @@ -93,3 +100,11 @@ func (p *Pipeline) GetFieldValue(name string) interface{} { return nil } } + +func (p *Pipeline) GetSortByFieldPrefix(name string) string { + return "pipelines." +} + +func (p *Pipeline) GetKeyFieldPrefix() string { + return "pipelines." +} diff --git a/backend/src/apiserver/model/pipeline_version.go b/backend/src/apiserver/model/pipeline_version.go index 93b76466a50..8eca8670aa9 100644 --- a/backend/src/apiserver/model/pipeline_version.go +++ b/backend/src/apiserver/model/pipeline_version.go @@ -74,6 +74,13 @@ func (p *PipelineVersion) GetModelName() string { return "pipeline_versions" } +func (p *PipelineVersion) GetField(name string) (string, bool) { + if field, ok := p.APIToModelFieldMap()[name]; ok { + return field, true + } + return "", false +} + func (p *PipelineVersion) GetFieldValue(name string) interface{} { switch name { case "UUID": @@ -88,3 +95,11 @@ func (p *PipelineVersion) GetFieldValue(name string) interface{} { return nil } } + +func (p *PipelineVersion) GetSortByFieldPrefix(name string) string { + return "pipeline_versions." +} + +func (p *PipelineVersion) GetKeyFieldPrefix() string { + return "pipeline_versions." +} diff --git a/backend/src/apiserver/model/pipeline_version_test.go b/backend/src/apiserver/model/pipeline_version_test.go deleted file mode 100644 index 1746f333952..00000000000 --- a/backend/src/apiserver/model/pipeline_version_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package model - -import ( - "testing" - - sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/go_client" - "github.com/kubeflow/pipelines/backend/src/apiserver/list" - "github.com/stretchr/testify/assert" -) - -// Test model name usage in sorting clause -func TestAddSortingToSelect(t *testing.T) { - listable := &PipelineVersion{ - UUID: "version_id_1", - CreatedAtInSec: 1, - Name: "version_name_1", - Parameters: "", - PipelineId: "pipeline_id_1", - Status: PipelineVersionReady, - CodeSourceUrl: "", - } - protoFilter := &api.Filter{} - listableOptions, err := list.NewOptions(listable, 10, "name", protoFilter) - assert.Nil(t, err) - sqlBuilder := sq.Select("*").From("pipeline_versions") - sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - - assert.Contains(t, sql, "pipeline_versions.Name") // sorting field - assert.Contains(t, sql, "pipeline_versions.UUID") // primary key field -} diff --git a/backend/src/apiserver/model/resource_reference_test.go b/backend/src/apiserver/model/resource_reference_test.go index ff58545064e..afc05a89ad1 100644 --- a/backend/src/apiserver/model/resource_reference_test.go +++ b/backend/src/apiserver/model/resource_reference_test.go @@ -15,9 +15,10 @@ package model import ( + "testing" + "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/stretchr/testify/assert" - "testing" ) func TestGetNamespaceFromResourceReferencesModel(t *testing.T) { diff --git a/backend/src/apiserver/model/run.go b/backend/src/apiserver/model/run.go index 288085f5e16..fc85d206f21 100644 --- a/backend/src/apiserver/model/run.go +++ b/backend/src/apiserver/model/run.go @@ -14,6 +14,10 @@ package model +import ( + "strings" +) + type Run struct { UUID string `gorm:"column:UUID; not null; primary_key"` ExperimentUUID string `gorm:"column:ExperimentUUID; not null;"` @@ -92,6 +96,16 @@ func (r *Run) GetModelName() string { return "" } +func (r *Run) GetField(name string) (string, bool) { + if field, ok := runAPIToModelFieldMap[name]; ok { + return field, true + } + if strings.HasPrefix(name, "metric:") { + return name[7:], true + } + return "", false +} + func (r *Run) GetFieldValue(name string) interface{} { // "name" could be a field in Run type or a name inside an array typed field // in Run type @@ -120,3 +134,23 @@ func (r *Run) GetFieldValue(name string) interface{} { } return nil } + +// Regular fields are the fields that are mapped to columns in Run table. +// Non-regular fields are the run metrics for now. Could have other non-regular +// sorting fields later. +func (r *Run) IsRegularField(name string) bool { + _, ok := runAPIToModelFieldMap[name] + return ok +} + +func (r *Run) GetSortByFieldPrefix(name string) string { + if r.IsRegularField(name) { + return r.GetModelName() + } else { + return "" + } +} + +func (r *Run) GetKeyFieldPrefix() string { + return r.GetModelName() +} diff --git a/backend/src/apiserver/model/run_test.go b/backend/src/apiserver/model/run_test.go deleted file mode 100644 index fb379352ff9..00000000000 --- a/backend/src/apiserver/model/run_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package model - -import ( - "testing" - - sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/go_client" - "github.com/kubeflow/pipelines/backend/src/apiserver/list" - "github.com/stretchr/testify/assert" -) - -// Test model name usage in sorting clause -func TestAddStatusFilterToSelect(t *testing.T) { - listable := &Run{ - UUID: "run_id_1", - CreatedAtInSec: 1, - Name: "run_name_1", - Conditions: "Succeeded", - } - protoFilter := &api.Filter{} - protoFilter.Predicates = []*api.Predicate{ - { - Key: "status", - Op: api.Predicate_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, - }, - } - listableOptions, err := list.NewOptions(listable, 10, "name", protoFilter) - assert.Nil(t, err) - sqlBuilder := sq.Select("*").From("run_details") - sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - assert.Contains(t, sql, "WHERE Conditions = ?") // filtering on status, aka Conditions in db - assert.Contains(t, args, "Succeeded") - - notEqualProtoFilter := &api.Filter{} - notEqualProtoFilter.Predicates = []*api.Predicate{ - { - Key: "status", - Op: api.Predicate_NOT_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "somevalue"}, - }, - } - listableOptions, err = list.NewOptions(listable, 10, "name", notEqualProtoFilter) - assert.Nil(t, err) - sqlBuilder = sq.Select("*").From("run_details") - sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() - assert.Nil(t, err) - assert.Contains(t, sql, "WHERE Conditions <> ?") // filtering on status, aka Conditions in db - assert.Contains(t, args, "somevalue") -} diff --git a/backend/src/apiserver/server/list_request_util_test.go b/backend/src/apiserver/server/list_request_util_test.go index 17991b039f3..18bfe2d5309 100644 --- a/backend/src/apiserver/server/list_request_util_test.go +++ b/backend/src/apiserver/server/list_request_util_test.go @@ -245,10 +245,34 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetField(name string) (string, bool) { + if field, ok := fakeAPIToModelMap[name]; ok { + return field, true + } else { + return "", false + } +} + func (f *fakeListable) GetFieldValue(name string) interface{} { + switch name { + case "CreatedTimestamp": + return f.CreatedTimestamp + case "FakeName": + return f.FakeName + case "PrimaryKey": + return f.PrimaryKey + } return nil } +func (f *fakeListable) GetSortByFieldPrefix(name string) string { + return "" +} + +func (f *fakeListable) GetKeyFieldPrefix() string { + return "" +} + func TestValidatedListOptions_Errors(t *testing.T) { opts, err := list.NewOptions(&fakeListable{}, 10, "name asc", nil) if err != nil { diff --git a/backend/src/apiserver/storage/run_store.go b/backend/src/apiserver/storage/run_store.go index f283390fd7b..92270c22314 100644 --- a/backend/src/apiserver/storage/run_store.go +++ b/backend/src/apiserver/storage/run_store.go @@ -170,7 +170,7 @@ func (s *RunStore) buildSelectRunsQuery(selectCount bool, opts *list.Options, // If we're not just counting, then also add select columns and perform a left join // to get resource reference information. Also add pagination. if !selectCount { - sqlBuilder = opts.AddSortByRunMetricToSelect(sqlBuilder) + sqlBuilder = s.AddSortByRunMetricToSelect(sqlBuilder, opts) sqlBuilder = opts.AddPaginationToSelect(sqlBuilder) sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder, opts) sqlBuilder = opts.AddSortingToSelect(sqlBuilder) @@ -224,11 +224,12 @@ func Map(vs []string, f func(string) string) []string { } func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { + var r model.Run resourceRefConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rr.Payload", ","), `"]"`}, "") columnsAfterJoiningResourceReferences := append( Map(runColumns, func(column string) string { return "rd." + column }), // Add prefix "rd." to runColumns resourceRefConcatQuery+" AS refs") - if opts != nil && opts.SortByFieldIsRunMetric { + if opts != nil && !r.IsRegularField(opts.SortByFieldName) { columnsAfterJoiningResourceReferences = append(columnsAfterJoiningResourceReferences, "rd."+opts.SortByFieldName) } subQ := sq. @@ -619,3 +620,18 @@ func (s *RunStore) TerminateRun(runId string) error { return nil } + +// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. +// With the metric as a field in the select clause enable sorting on this metric afterwards. +// TODO(jingzhang36): example of resulting SQL query and explanation for it. +func (s *RunStore) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { + var r model.Run + if r.IsRegularField(opts.SortByFieldName) { + return sqlBuilder + } + // TODO(jingzhang36): address the case where runs doesn't have the specified metric. + return sq. + Select("selected_runs.*, run_metrics.numbervalue as "+opts.SortByFieldName). + FromSelect(sqlBuilder, "selected_runs"). + LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + opts.SortByFieldName + "'") +} From 80f514e8d26456bf0bedd29892ecf14055a80e0b Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 27 Jul 2020 13:30:12 +0800 Subject: [PATCH 09/14] some assumption on token's Model field --- backend/src/apiserver/list/list.go | 15 +++++++-------- backend/src/apiserver/list/list_test.go | 1 + 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 2f4f2ef3d96..8fd855a1fc3 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -135,14 +135,13 @@ func (t *token) unmarshal(pageToken string) error { } func (t *token) marshal() (string, error) { - if t.Model != nil { - t.ModelType = reflect.ValueOf(t.Model).Elem().Type().Name() - modelMessage, err := json.Marshal(t.Model) - if err != nil { - return "", util.NewInternalServerError(err, "Failed to serialize the listable object in page token.") - } - t.ModelMessage = modelMessage - } // can we set empty raw message explicitly in case of nil model + // Model in a token should not be nil, because this token is created when listing a model (i.e., run, job, experiment, pipeline and pipeline version). + t.ModelType = reflect.ValueOf(t.Model).Elem().Type().Name() + modelMessage, err := json.Marshal(t.Model) + if err != nil { + return "", util.NewInternalServerError(err, "Failed to serialize the listable object in page token.") + } + t.ModelMessage = modelMessage b, err := json.Marshal(t) if err != nil { diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index c01d5dc3443..8ae588678a2 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -287,6 +287,7 @@ func TestNewOptionsFromToken_FromInValidPageSize(t *testing.T) { KeyFieldName: "KeyField", KeyFieldValue: "string_key_value", IsDesc: true, + Model: &fakeListable{}, } s, err := tok.marshal() From 98353f99e77bfbe0a09e926012465a3773e18094 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 27 Jul 2020 18:11:05 +0800 Subject: [PATCH 10/14] fix the regular field checking logic --- backend/src/apiserver/model/run.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/src/apiserver/model/run.go b/backend/src/apiserver/model/run.go index fc85d206f21..a7ad0476f0a 100644 --- a/backend/src/apiserver/model/run.go +++ b/backend/src/apiserver/model/run.go @@ -139,8 +139,12 @@ func (r *Run) GetFieldValue(name string) interface{} { // Non-regular fields are the run metrics for now. Could have other non-regular // sorting fields later. func (r *Run) IsRegularField(name string) bool { - _, ok := runAPIToModelFieldMap[name] - return ok + for _, field := range runAPIToModelFieldMap { + if field == name { + return true + } + } + return false } func (r *Run) GetSortByFieldPrefix(name string) string { From 998ccc790a10e81e8144618fd3dc2675581ac267 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 30 Jul 2020 14:55:51 +0800 Subject: [PATCH 11/14] add comment to help devs to use the new field --- backend/src/apiserver/list/list.go | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 8fd855a1fc3..c49d2318961 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -64,6 +64,7 @@ type token struct { Filter *filter.Filter // The listable model this token is applied to. Not used in json marshal/unmarshal. + // The types that implement the listable interface and hence can be used in this field are at backend/src/apiserver/model. Model Listable `json:"-"` // ModelType and the ModelMessage are helper fields to unmarshal data correctly to // the underlying listable model, and this underlying listable model will be stored From afc06b9fe43efc118cbcbc8db91aab53104cbd94 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 30 Jul 2020 15:06:52 +0800 Subject: [PATCH 12/14] add a validation check --- backend/src/apiserver/server/list_request_util.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backend/src/apiserver/server/list_request_util.go b/backend/src/apiserver/server/list_request_util.go index 88a5b1ca62e..72f5784d474 100644 --- a/backend/src/apiserver/server/list_request_util.go +++ b/backend/src/apiserver/server/list_request_util.go @@ -187,6 +187,10 @@ func parseAPIFilter(encoded string) (*api.Filter, error) { func validatedListOptions(listable list.Listable, pageToken string, pageSize int, sortBy string, filterSpec string) (*list.Options, error) { defaultOpts := func() (*list.Options, error) { + if listable == nil { + return nil, util.NewInvalidInputError("Please specify a valid type to list. E.g., list runs or list jobs.") + } + f, err := parseAPIFilter(filterSpec) if err != nil { return nil, err From be4be6117aae305a5048a348d49abe585c93e80e Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 30 Jul 2020 18:44:19 +0800 Subject: [PATCH 13/14] Listable object can be too large to be in token. So replace it with only relevant fields taken out of it. In the future, if more fields in Listable object become relevant, manually add it to token --- backend/src/apiserver/list/list.go | 113 ++------ backend/src/apiserver/list/list_test.go | 325 +++++++++++++----------- 2 files changed, 195 insertions(+), 243 deletions(-) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index c49d2318961..3ebdbd2d5ef 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -29,7 +29,6 @@ import ( api "github.com/kubeflow/pipelines/backend/api/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/filter" - "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" ) @@ -44,34 +43,24 @@ type token struct { SortByFieldName string // SortByFieldValue is the value of the sorted field of the next row to be // returned. - SortByFieldValue interface{} + SortByFieldValue interface{} + SortByFieldPrefix string // KeyFieldName is the name of the primary key for the model being queried. KeyFieldName string // KeyFieldValue is the value of the sorted field of the next row to be // returned. - KeyFieldValue interface{} + KeyFieldValue interface{} + KeyFieldPrefix string // IsDesc is true if the sorting order should be descending. IsDesc bool // ModelName is the table where ***FieldName belongs to. - // TODO(jingzhang36): we probably can deprecate this since we now have - // Model field. ModelName string // Filter represents the filtering that should be applied in the query. Filter *filter.Filter - - // The listable model this token is applied to. Not used in json marshal/unmarshal. - // The types that implement the listable interface and hence can be used in this field are at backend/src/apiserver/model. - Model Listable `json:"-"` - // ModelType and the ModelMessage are helper fields to unmarshal data correctly to - // the underlying listable model, and this underlying listable model will be stored - // in the above Model field. Those two fields are only used in token's marshal and - // unmarshal methods. - ModelType string - ModelMessage json.RawMessage } func (t *token) unmarshal(pageToken string) error { @@ -87,63 +76,10 @@ func (t *token) unmarshal(pageToken string) error { return errorF(err) } - if t.ModelMessage != nil { - switch t.ModelType { - case "Run": - model := &model.Run{} - err = json.Unmarshal(t.ModelMessage, model) - if err != nil { - return errorF(err) - } - t.Model = model - break - case "Job": - model := &model.Job{} - err = json.Unmarshal(t.ModelMessage, model) - if err != nil { - return errorF(err) - } - t.Model = model - break - case "Experiment": - model := &model.Experiment{} - err = json.Unmarshal(t.ModelMessage, model) - if err != nil { - return errorF(err) - } - t.Model = model - break - case "Pipeline": - model := &model.Pipeline{} - err = json.Unmarshal(t.ModelMessage, model) - if err != nil { - return errorF(err) - } - t.Model = model - break - case "PipelineVersion": - model := &model.PipelineVersion{} - err = json.Unmarshal(t.ModelMessage, model) - if err != nil { - return errorF(err) - } - t.Model = model - break - } - } - return nil } func (t *token) marshal() (string, error) { - // Model in a token should not be nil, because this token is created when listing a model (i.e., run, job, experiment, pipeline and pipeline version). - t.ModelType = reflect.ValueOf(t.Model).Elem().Type().Name() - modelMessage, err := json.Marshal(t.Model) - if err != nil { - return "", util.NewInternalServerError(err, "Failed to serialize the listable object in page token.") - } - t.ModelMessage = modelMessage - b, err := json.Marshal(t) if err != nil { return "", util.NewInternalServerError(err, "Failed to serialize page token.") @@ -199,8 +135,7 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api token := &token{ KeyFieldName: listable.PrimaryKeyColumnName(), - ModelName: listable.GetModelName(), - Model: listable} + ModelName: listable.GetModelName()} // Ignore the case of the letter. Split query string by space. queryList := strings.Fields(strings.ToLower(sortBy)) @@ -219,6 +154,8 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api return nil, util.NewInvalidInputError("Invalid sorting field: %q on listable type %s", queryList[0], reflect.ValueOf(listable).Elem().Type().Name()) } } + token.SortByFieldPrefix = listable.GetSortByFieldPrefix(token.SortByFieldName) + token.KeyFieldPrefix = listable.GetKeyFieldPrefix() if len(queryList) == 2 { token.IsDesc = queryList[1] == "desc" @@ -250,21 +187,18 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu // AddSortingToSelect adds Order By clause. func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { // When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed. - keyFieldPrefix := o.Model.GetKeyFieldPrefix() - sortByFieldPrefix := o.Model.GetSortByFieldPrefix(o.SortByFieldName) - // If next row's value is specified, set those values in the clause. if o.SortByFieldValue != nil && o.KeyFieldValue != nil { if o.IsDesc { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Lt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.LtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Lt{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.LtOrEq{o.KeyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } else { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Gt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.GtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Gt{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{o.SortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.GtOrEq{o.KeyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } } @@ -273,8 +207,8 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild order = "DESC" } sqlBuilder = sqlBuilder. - OrderBy(fmt.Sprintf("%v %v", sortByFieldPrefix+o.SortByFieldName, order)). - OrderBy(fmt.Sprintf("%v %v", keyFieldPrefix+o.KeyFieldName, order)) + OrderBy(fmt.Sprintf("%v %v", o.SortByFieldPrefix+o.SortByFieldName, order)). + OrderBy(fmt.Sprintf("%v %v", o.KeyFieldPrefix+o.KeyFieldName, order)) return sqlBuilder } @@ -412,14 +346,15 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { } return &token{ - SortByFieldName: o.SortByFieldName, - SortByFieldValue: sortByField, - KeyFieldName: listable.PrimaryKeyColumnName(), - KeyFieldValue: keyField.Interface(), - IsDesc: o.IsDesc, - Filter: o.Filter, - ModelName: o.ModelName, - Model: listable, + SortByFieldName: o.SortByFieldName, + SortByFieldValue: sortByField, + SortByFieldPrefix: listable.GetSortByFieldPrefix(o.SortByFieldName), + KeyFieldName: listable.PrimaryKeyColumnName(), + KeyFieldValue: keyField.Interface(), + KeyFieldPrefix: listable.GetKeyFieldPrefix(), + IsDesc: o.IsDesc, + Filter: o.Filter, + ModelName: o.ModelName, }, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 8ae588678a2..8ad4667914e 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -1,7 +1,6 @@ package list import ( - "encoding/json" "reflect" "strings" "testing" @@ -119,12 +118,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "CreatedTimestamp", IsDesc: true}, }, want: &token{ - SortByFieldName: "CreatedTimestamp", - SortByFieldValue: int64(1234), - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: true, - Model: l, + SortByFieldName: "CreatedTimestamp", + SortByFieldValue: int64(1234), + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: true, }, }, { @@ -132,12 +132,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "PrimaryKey", IsDesc: true}, }, want: &token{ - SortByFieldName: "PrimaryKey", - SortByFieldValue: "uuid123", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: true, - Model: l, + SortByFieldName: "PrimaryKey", + SortByFieldValue: "uuid123", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: true, }, }, { @@ -145,12 +146,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { PageSize: 10, token: &token{SortByFieldName: "FakeName", IsDesc: false}, }, want: &token{ - SortByFieldName: "FakeName", - SortByFieldValue: "Fake", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, - Model: l, + SortByFieldName: "FakeName", + SortByFieldValue: "Fake", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, }, }, { @@ -162,13 +164,14 @@ func TestNextPageToken_ValidTokens(t *testing.T) { }, }, want: &token{ - SortByFieldName: "FakeName", - SortByFieldValue: "Fake", - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, - Filter: testFilter, - Model: l, + SortByFieldName: "FakeName", + SortByFieldValue: "Fake", + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, + Filter: testFilter, }, }, { @@ -179,12 +182,13 @@ func TestNextPageToken_ValidTokens(t *testing.T) { }, }, want: &token{ - SortByFieldName: "m1", - SortByFieldValue: 1.0, - KeyFieldName: "PrimaryKey", - KeyFieldValue: "uuid123", - IsDesc: false, - Model: l, + SortByFieldName: "m1", + SortByFieldValue: 1.0, + SortByFieldPrefix: "", + KeyFieldName: "PrimaryKey", + KeyFieldValue: "uuid123", + KeyFieldPrefix: "", + IsDesc: false, }, }, } @@ -242,12 +246,13 @@ func TestValidatePageSize(t *testing.T) { func TestNewOptions_FromValidSerializedToken(t *testing.T) { tok := &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true, } s, err := tok.marshal() @@ -255,9 +260,6 @@ func TestNewOptions_FromValidSerializedToken(t *testing.T) { t.Fatalf("failed to marshal token %+v: %v", tok, err) } - tok.Model = nil - tok.ModelType = "fakeListable" - tok.ModelMessage, _ = json.Marshal(&fakeListable{}) want := &Options{PageSize: 123, token: tok} got, err := NewOptionsFromToken(s, 123) @@ -282,12 +284,13 @@ func TestNewOptionsFromToken_FromInValidSerializedToken(t *testing.T) { func TestNewOptionsFromToken_FromInValidPageSize(t *testing.T) { tok := &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true, } s, err := tok.marshal() @@ -313,10 +316,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -325,10 +329,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -337,10 +342,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: false, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -349,10 +355,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: false, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: false, }, }, }, @@ -361,10 +368,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "FakeName", - IsDesc: true, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "FakeName", + SortByFieldPrefix: "", + IsDesc: true, }, }, }, @@ -373,10 +381,11 @@ func TestNewOptions_ValidSortOptions(t *testing.T) { want: &Options{ PageSize: pageSize, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "PrimaryKey", - IsDesc: true, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "PrimaryKey", + SortByFieldPrefix: "", + IsDesc: true, }, }, }, @@ -448,11 +457,12 @@ func TestNewOptions_ValidFilter(t *testing.T) { want := &Options{ PageSize: 10, token: &token{ - KeyFieldName: "PrimaryKey", - SortByFieldName: "CreatedTimestamp", - IsDesc: false, - Filter: f, - Model: &fakeListable{}, + KeyFieldName: "PrimaryKey", + KeyFieldPrefix: "", + SortByFieldName: "CreatedTimestamp", + SortByFieldPrefix: "", + IsDesc: false, + Filter: f, }, } @@ -508,12 +518,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: true, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: true, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField < ? OR (SortField = ? AND KeyField <= ?)) ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -523,12 +534,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: false, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: false, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -538,13 +550,14 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: false, - Filter: f, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 1111, + KeyFieldPrefix: "", + IsDesc: false, + Filter: f, }, }, wantSQL: "SELECT * FROM MyTable WHERE (SortField > ? OR (SortField = ? AND KeyField >= ?)) AND Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -554,11 +567,12 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - KeyFieldName: "KeyField", - KeyFieldValue: 1111, - IsDesc: true, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + KeyFieldValue: 1111, + IsDesc: true, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField DESC, KeyField DESC LIMIT 124", @@ -568,11 +582,12 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - IsDesc: false, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + IsDesc: false, }, }, wantSQL: "SELECT * FROM MyTable ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -582,12 +597,13 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { in: &Options{ PageSize: 123, token: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "value", - KeyFieldName: "KeyField", - IsDesc: false, - Filter: f, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: "value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldPrefix: "", + IsDesc: false, + Filter: f, }, }, wantSQL: "SELECT * FROM MyTable WHERE Name = ? ORDER BY SortField ASC, KeyField ASC LIMIT 124", @@ -618,8 +634,6 @@ func TestTokenSerialization(t *testing.T) { t.Fatalf("failed to parse filter proto %+v: %v", protoFilter, err) } - modelMessage, _ := json.Marshal(&fakeListable{}) - tests := []struct { in *token want *token @@ -627,59 +641,62 @@ func TestTokenSerialization(t *testing.T) { // string values in sort by fields { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, - Model: &fakeListable{}}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true}, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: "string_field_value", - KeyFieldName: "KeyField", - KeyFieldValue: "string_key_value", - IsDesc: true, - ModelType: "fakeListable", - ModelMessage: modelMessage}, + SortByFieldName: "SortField", + SortByFieldValue: "string_field_value", + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: "string_key_value", + KeyFieldPrefix: "", + IsDesc: true}, }, // int values get deserialized as floats by JSON unmarshal. { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: 100, - KeyFieldName: "KeyField", - KeyFieldValue: 200, - IsDesc: true, - Model: &fakeListable{}}, + SortByFieldName: "SortField", + SortByFieldValue: 100, + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 200, + KeyFieldPrefix: "", + IsDesc: true}, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: float64(100), - KeyFieldName: "KeyField", - KeyFieldValue: float64(200), - IsDesc: true, - ModelType: "fakeListable", - ModelMessage: modelMessage}, + SortByFieldName: "SortField", + SortByFieldValue: float64(100), + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: float64(200), + KeyFieldPrefix: "", + IsDesc: true}, }, // has a filter. { in: &token{ - SortByFieldName: "SortField", - SortByFieldValue: 100, - KeyFieldName: "KeyField", - KeyFieldValue: 200, - IsDesc: true, - Filter: testFilter, - Model: &fakeListable{}, + SortByFieldName: "SortField", + SortByFieldValue: 100, + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: 200, + KeyFieldPrefix: "", + IsDesc: true, + Filter: testFilter, }, want: &token{ - SortByFieldName: "SortField", - SortByFieldValue: float64(100), - KeyFieldName: "KeyField", - KeyFieldValue: float64(200), - IsDesc: true, - Filter: testFilter, - ModelType: "fakeListable", - ModelMessage: modelMessage, + SortByFieldName: "SortField", + SortByFieldValue: float64(100), + SortByFieldPrefix: "", + KeyFieldName: "KeyField", + KeyFieldValue: float64(200), + KeyFieldPrefix: "", + IsDesc: true, + Filter: testFilter, }, }, } From 6124c829e2384db55926f165a0d0e1011d59de6e Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 31 Jul 2020 16:12:57 +0800 Subject: [PATCH 14/14] matches func update --- backend/src/apiserver/list/list.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 3ebdbd2d5ef..f9c357668b2 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -100,7 +100,7 @@ type Options struct { // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { - return o.SortByFieldName == opts.SortByFieldName && + return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldPrefix == opts.SortByFieldPrefix && o.IsDesc == opts.IsDesc && reflect.DeepEqual(o.Filter, opts.Filter) }