Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More DB unit tests #234

Merged
merged 2 commits into from
Nov 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions pkg/db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ func (d *dbConn) GetSuggestionParam(paramID string) ([]*api.SuggestionParameter,
func (d *dbConn) GetSuggestionParamList(studyID string) ([]*api.SuggestionParameterSet, error) {
var rows *sql.Rows
var err error
rows, err = d.db.Query("SELECT * FROM suggestion_param WHERE study_id = ?", studyID)
rows, err = d.db.Query("SELECT id, suggestion_algo, parameters FROM suggestion_param WHERE study_id = ?", studyID)
if err != nil {
return nil, err
}
Expand All @@ -976,14 +976,10 @@ func (d *dbConn) GetSuggestionParamList(studyID string) ([]*api.SuggestionParame
var id string
var algorithm string
var params string
var sID string
err := rows.Scan(&id, &sID, &algorithm, &params)
err := rows.Scan(&id, &algorithm, &params)
if err != nil {
return nil, err
}
if studyID != sID {
continue
}
var pArray []string
if len(params) > 0 {
pArray = strings.Split(params, ",\n")
Expand Down Expand Up @@ -1021,7 +1017,7 @@ func (d *dbConn) SetEarlyStopParam(algorithm string, studyID string, params []*a
}
var paramID string
for true {
paramID := generateRandid()
paramID = generateRandid()
_, err = d.db.Exec("INSERT INTO earlystopping_param VALUES (?,?, ?, ?)",
paramID, algorithm, studyID, strings.Join(ps, ",\n"))
if err == nil {
Expand Down Expand Up @@ -1077,7 +1073,7 @@ func (d *dbConn) GetEarlyStopParam(paramID string) ([]*api.EarlyStoppingParamete
func (d *dbConn) GetEarlyStopParamList(studyID string) ([]*api.EarlyStoppingParameterSet, error) {
var rows *sql.Rows
var err error
rows, err = d.db.Query("SELECT * FROM earlystopping_param WHERE study_id = ?", studyID)
rows, err = d.db.Query("SELECT id, earlystop_algo, parameters FROM earlystopping_param WHERE study_id = ?", studyID)
if err != nil {
return nil, err
}
Expand All @@ -1086,14 +1082,10 @@ func (d *dbConn) GetEarlyStopParamList(studyID string) ([]*api.EarlyStoppingPara
var id string
var algorithm string
var params string
var sID string
err := rows.Scan(&id, &sID, &algorithm, &params)
err := rows.Scan(&id, &algorithm, &params)
if err != nil {
return nil, err
}
if studyID != sID {
continue
}
var pArray []string
if len(params) > 0 {
pArray = strings.Split(params, ",\n")
Expand Down
247 changes: 228 additions & 19 deletions pkg/db/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ import (
var dbInterface VizierDBInterface
var mock sqlmock.Sqlmock

var studyColumns = []string{
"id", "name", "owner", "optimization_type", "optimization_goal",
"parameter_configs", "tags", "objective_value_name",
"metrics", "job_id"}
var trialColumns = []string{
"id", "study_id", "parameters", "objective_value", "tags"}
var workerColumns = []string{"id",
"study_id", "trial_id", "type",
"status", "template_path", "tags"}

func TestMain(m *testing.M) {
db, sm, err := sqlmock.New()
mock = sm
Expand Down Expand Up @@ -55,25 +65,43 @@ func TestGetStudyConfig(t *testing.T) {
}
// mock.ExpectExec("SELECT * FROM studies WHERE id").WithArgs(id).WillReturnRows(sqlmock.NewRows())
mock.ExpectQuery("SELECT").WillReturnRows(
sqlmock.NewRows([]string{
"id",
"name",
"owner",
"optimization_type",
"optimization_goal",
"parameter_configs",
"tags",
"objective_value_name",
"metrics",
"job_id",
}).
AddRow("abc", "test", "admin", 1, 0.99, "{}", "", "", "", "test"))
sqlmock.NewRows(studyColumns).AddRow(
"abc", "test", "admin", 1, 0.99, "{}", "", "", "", "test"))
study, err := dbInterface.GetStudyConfig(id)
if err != nil {
t.Errorf("GetStudyConfig failed: %v", err)
} else if study.Name != "test" || study.Owner != "admin" {
t.Errorf("GetStudyConfig incorrect return %v", study)
}
}

func TestGetStudyList(t *testing.T) {
ids := []string{"abcde1234567890f", "bcde1234567890fa"}
mock.ExpectQuery("SELECT id FROM studies").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(ids[0]).AddRow(ids[1]))
r, err := dbInterface.GetStudyList()
if err != nil {
t.Errorf("GetStudyList error %v", err)
}
if len(r) != len(ids) {
t.Errorf("GetStudyList returned incorrect number of ids %d != %d",
len(r), len(ids))
}
for i, id := range r {
if ids[i] != id {
t.Errorf("GetStudyList returned incorrect ID %s != %s",
id, ids[i])
}
}
}

func TestDeleteStudy(t *testing.T) {
studyID := generateRandid()
mock.ExpectExec(`DELETE FROM studies WHERE id = \?`).WithArgs(studyID).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.DeleteStudy(studyID)
if err != nil {
t.Errorf("DeleteStudy error %v", err)
}
fmt.Printf("%v", study)
// TODO: check study data
}

func TestCreateStudyIdGeneration(t *testing.T) {
Expand Down Expand Up @@ -108,6 +136,58 @@ func TestCreateStudyIdGeneration(t *testing.T) {
}
}

func TestGetTrial(t *testing.T) {
id := generateRandid()
mock.ExpectQuery(`SELECT \* FROM trials WHERE id = \?`).WillReturnRows(
sqlmock.NewRows(trialColumns).AddRow(
id, "s1234567890abcde",
"{\"name\": \"1\"},\n{}", "obj_val",
"{\"name\": \"foo\"},\n{}"))
trial, err := dbInterface.GetTrial(id)
if err != nil {
t.Errorf("GetTrial error %v", err)
} else if len((*trial).Tags) != 2 {
t.Errorf("GetTrial returned incorrect Tag %v", (*trial).Tags)
}
}

func TestGetTrialList(t *testing.T) {
studyID := generateRandid()
var ids = []string{"abcdef1234567890", "bcdef1234567890a"}
rows := sqlmock.NewRows(trialColumns)
for _, id := range ids {
rows.AddRow(id, studyID, "", "obj_val", "")
}
mock.ExpectQuery(`SELECT \* FROM trials WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(rows)
trials, err := dbInterface.GetTrialList(studyID)
if err != nil {
t.Errorf("GetTrialList error %v", err)
} else if len(trials) != len(ids) {
t.Errorf("GetTrialList returned incorrect number of trials %d != %d",
len(trials), len(ids))
}
}

func TestCreateTrial(t *testing.T) {
var trial api.Trial
trial.StudyId = generateRandid()
mock.ExpectExec(`INSERT INTO trials VALUES \(`).WithArgs(sqlmock.AnyArg(),
trial.StudyId, "", "", "").WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.CreateTrial(&trial)
if err != nil {
t.Errorf("CreateTrial error %v", err)
}
}

func TestDeleteTrial(t *testing.T) {
id := generateRandid()
mock.ExpectExec(`DELETE FROM trials WHERE id = \?`).WithArgs(id).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.DeleteTrial(id)
if err != nil {
t.Errorf("DeleteTrial error %v", err)
}
}

func TestCreateWorker(t *testing.T) {
var w api.Worker
w.StudyId = generateRandid()
Expand All @@ -124,10 +204,6 @@ func TestCreateWorker(t *testing.T) {
}
}

var workerColumns = []string{"id",
"study_id", "trial_id", "type",
"status", "template_path", "tags"}

const defaultWorkerID = "w123456789abcdef"
const objValueName = "obj_value"

Expand Down Expand Up @@ -189,6 +265,31 @@ func TestDeleteWorker(t *testing.T) {

}

func TestGetWorkerFullInfo(t *testing.T) {
studyID := generateRandid()
wRows := sqlmock.NewRows(workerColumns)
wRows.AddRow("w1134567890abcde", studyID, "", "", "1", "", "")
wRows.AddRow("w2234567890abcde", studyID, "", "", "2", "", "")
mock.ExpectQuery(`SELECT \* FROM workers WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(wRows)
mock.ExpectQuery(`SELECT \* FROM trials WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows(trialColumns))
mock.ExpectQuery(`SELECT \* FROM studies WHERE id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows(studyColumns).AddRow(
studyID, "test", "admin", 1, 0.99, "{}", "", "", "foo,\nbar", "test"))
WMRows := sqlmock.NewRows([]string{"WM.worker_id", "WM.time", "WM.name", "WM.value"})
WMRows.AddRow("w1134567890abcde", "2012-01-01 09:54:32", "foo", "1")
WMRows.AddRow("w1134567890abcde", "2012-01-01 09:54:32", "bar", "1")
mock.ExpectQuery(`SELECT WM.worker_id, WM.time, WM.name, WM.value FROM .* MaxID .* ON WM.worker_id`).WithArgs(studyID).WillReturnRows(WMRows)

fi, err := dbInterface.GetWorkerFullInfo(studyID, "", "", true)
if err != nil {
t.Errorf("GetWorkerFullInfo error %v", err)
} else if len(fi.WorkerFullInfos) != 2 ||
len(fi.WorkerFullInfos[0].MetricsLogs) != 2 {
t.Errorf("GetWorkerFullInfo incorrect return %v", fi)
}
}

type MetricsLogData struct {
stored bool
name string
Expand Down Expand Up @@ -337,3 +438,111 @@ func TestGetWorkerLogs(t *testing.T) {
}
}
}

func TestSetSuggestionParam(t *testing.T) {
sp := make([]*api.SuggestionParameter, 1)
sp[0] = &api.SuggestionParameter{Name: "DefaultGrid", Value: "1"}
studyID := generateRandid()
mock.ExpectExec("INSERT INTO suggestion_param VALUES").WithArgs(
sqlmock.AnyArg(), "grid", studyID,
`{"name":"DefaultGrid","value":"1"}`).WillReturnResult(sqlmock.NewResult(1, 1))
id, err := dbInterface.SetSuggestionParam("grid", studyID, sp)
if err != nil {
t.Errorf("SetSuggestionParam error %v", err)
} else if len(id) != 16 {
t.Errorf("SetSuggestionParam returned incorrect ID %s", id)
}
}

func TestUpdateSuggestionParam(t *testing.T) {
sp := make([]*api.SuggestionParameter, 1)
sp[0] = &api.SuggestionParameter{Name: "DefaultGrid", Value: "12"}
id := generateRandid()
mock.ExpectExec(`UPDATE suggestion_param SET parameters = \? WHERE id = \?`).WithArgs(
`{"name":"DefaultGrid","value":"12"}`, id).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.UpdateSuggestionParam(id, sp)
if err != nil {
t.Errorf("UpdateSuggestionParam error %v", err)
}
}

func TestGetSuggestionParam(t *testing.T) {
id := generateRandid()
mock.ExpectQuery(`SELECT parameters FROM suggestion_param WHERE id = \?`).WithArgs(id).WillReturnRows(
sqlmock.NewRows([]string{"parameters"}).AddRow(
`{"name":"DefaultGrid","value":"12"}`))
sp, err := dbInterface.GetSuggestionParam(id)
if err != nil {
t.Errorf("GetSuggestionParam error %v", err)
} else if len(sp) != 1 {
t.Errorf("GetSuggestionParam returned incorrect number of data %v", sp)
}
}

func TestGetSuggestionParamList(t *testing.T) {
studyID := generateRandid()
mock.ExpectQuery(`SELECT id, suggestion_algo, parameters FROM suggestion_param WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows([]string{"id", "suggestion_algo", "parameters"}).AddRow(
generateRandid(), "random", "{}"))

sp, err := dbInterface.GetSuggestionParamList(studyID)
if err != nil {
t.Errorf("GetSuggestionParamList error %v", err)
} else if len(sp) != 1 {
t.Errorf("GetSuggestionParamList returned incorrect number of data %v", sp)
}
}

func TestSetEarlyStopParam(t *testing.T) {
ep := make([]*api.EarlyStoppingParameter, 1)
ep[0] = &api.EarlyStoppingParameter{Name: "LeastStep", Value: "1"}
studyID := generateRandid()
mock.ExpectExec("INSERT INTO earlystopping_param VALUES").WithArgs(
sqlmock.AnyArg(), "medianstopping", studyID,
`{"name":"LeastStep","value":"1"}`).WillReturnResult(sqlmock.NewResult(1, 1))
id, err := dbInterface.SetEarlyStopParam("medianstopping", studyID, ep)
if err != nil {
t.Errorf("SetEarlyStopParam error %v", err)
} else if len(id) != 16 {
t.Errorf("SetEarlyStopParam returned incorrect ID %s", id)
}
}

func TestUpdateEarlyStopParam(t *testing.T) {
ep := make([]*api.EarlyStoppingParameter, 1)
ep[0] = &api.EarlyStoppingParameter{Name: "LeastStep", Value: "12"}
id := generateRandid()
mock.ExpectExec(`UPDATE earlystopping_param SET parameters = \? WHERE id = \?`).WithArgs(
`{"name":"LeastStep","value":"12"}`, id).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.UpdateEarlyStopParam(id, ep)
if err != nil {
t.Errorf("UpdateEarlyStopParamerror %v", err)
}
}

func TestGetEarlyStopParam(t *testing.T) {
id := generateRandid()
mock.ExpectQuery(`SELECT parameters FROM earlystopping_param WHERE id = \?`).WithArgs(id).WillReturnRows(
sqlmock.NewRows([]string{"parameters"}).AddRow(
`{"name":"LeastStep","value":"12"}`))
ep, err := dbInterface.GetEarlyStopParam(id)
if err != nil {
t.Errorf("GetEarlyStopParam error %v", err)
} else if len(ep) != 1 {
t.Errorf("GetEarlyStopParam returned incorrect number of data %v", ep)
}
}

func TestGetEarlyStopParamList(t *testing.T) {
studyID := generateRandid()
mock.ExpectQuery(`SELECT id, earlystop_algo, parameters FROM earlystopping_param WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows([]string{"id", "earlystop_algo", "parameters"}).AddRow(
generateRandid(), "medianstopping", "{}"))

ep, err := dbInterface.GetEarlyStopParamList(studyID)
if err != nil {
t.Errorf("GetEarlyStopParamList error %v", err)
} else if len(ep) != 1 {
t.Errorf("GetEarlyStopParamList returned incorrect number of data %v", ep)
}
}