Skip to content

Commit

Permalink
added call tests, added mock package
Browse files Browse the repository at this point in the history
  • Loading branch information
kndndrj committed Jan 11, 2024
1 parent 4acfedf commit 5663c33
Show file tree
Hide file tree
Showing 14 changed files with 610 additions and 162 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
run: ./ci/target-matrix.sh >> "$GITHUB_OUTPUT"

go-build:
needs: [ assemble-os-matrix ]
needs: [assemble-os-matrix]
name: Go Build
strategy:
matrix: ${{ fromJson(needs.assemble-os-matrix.outputs.matrix) }}
Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
runs-on: ubuntu-22.04
if: github.event_name == 'release'
name: Create Install Manifest
needs: [ go-build ]
needs: [go-build]
env:
MANIFEST_FILE: "lua/dbee/install/__manifest.lua"
steps:
Expand Down
4 changes: 4 additions & 0 deletions dbee/core/builders/next.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func NextSingle(value any) (func() (core.Row, error), func() bool) {
// NextSlice creates next and hasNext functions from provided values
// preprocessor is an optional function which parses a single value from slice before adding it to a row
func NextSlice[T any](values []T, preprocess func(T) any) (func() (core.Row, error), func() bool) {
if preprocess == nil {
preprocess = func(v T) any { return v }
}

index := 0

hasNext := func() bool {
Expand Down
67 changes: 36 additions & 31 deletions dbee/core/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ type (
timeTaken time.Duration
timestamp time.Time

result *Result
archive *archive
cancelFunc func()
onEventFunc func(*Call)
result *Result
archive *archive
cancelFunc func()

// any error that might occur during execution
err error
Expand Down Expand Up @@ -99,46 +98,66 @@ func (c *Call) UnmarshalJSON(data []byte) error {
return nil
}

func newCallFromExecutor(executor func(context.Context) (ResultStream, error), query string, onEvent func(*Call)) *Call {
func newCallFromExecutor(executor func(context.Context) (ResultStream, error), query string, onEvent func(CallState, *Call)) *Call {
id := CallID(uuid.New().String())
c := &Call{
id: id,
query: query,
state: CallStateUnknown,

result: new(Result),
archive: newArchive(id),
onEventFunc: onEvent,
result: new(Result),
archive: newArchive(id),

done: make(chan struct{}),
}

eventsCh := make(chan CallState, 10)

ctx, cancel := context.WithCancel(context.Background())
c.timestamp = time.Now()
c.cancelFunc = func() {
cancel()
c.timeTaken = time.Since(c.timestamp)
c.setState(CallStateCanceled)
eventsCh <- CallStateCanceled
}

// event function handler
go func() {
for state := range eventsCh {
if c.state == CallStateExecutingFailed ||
c.state == CallStateRetrievingFailed ||
c.state == CallStateCanceled {
return
}
c.state = state

// trigger event callback
if onEvent != nil {
onEvent(state, c)
}
}
}()

go func() {
defer close(eventsCh)

// execute the function
c.setState(CallStateExecuting)
eventsCh <- CallStateExecuting
iter, err := executor(ctx)
if err != nil {
c.timeTaken = time.Since(c.timestamp)
c.err = err
c.setState(CallStateExecutingFailed)
eventsCh <- CallStateExecutingFailed
close(c.done)
return
}

// set iterator to result
err = c.result.setIter(iter, func() { c.setState(CallStateRetrieving) })
err = c.result.SetIter(iter, func() { eventsCh <- CallStateRetrieving })
if err != nil {
c.timeTaken = time.Since(c.timestamp)
c.err = err
c.setState(CallStateRetrievingFailed)
eventsCh <- CallStateRetrievingFailed
close(c.done)
return
}
Expand All @@ -148,13 +167,13 @@ func newCallFromExecutor(executor func(context.Context) (ResultStream, error), q
if err != nil {
c.timeTaken = time.Since(c.timestamp)
c.err = err
c.setState(CallStateArchiveFailed)
eventsCh <- CallStateArchiveFailed
close(c.done)
return
}

c.timeTaken = time.Since(c.timestamp)
c.setState(CallStateArchived)
eventsCh <- CallStateArchived
close(c.done)
}()

Expand Down Expand Up @@ -191,22 +210,8 @@ func (c *Call) Done() chan struct{} {
return c.done
}

func (c *Call) setState(state CallState) {
if c.state == CallStateExecutingFailed ||
c.state == CallStateRetrievingFailed ||
c.state == CallStateCanceled {
return
}
c.state = state

// trigger event callback
if c.onEventFunc != nil {
c.onEventFunc(c)
}
}

func (c *Call) Cancel() {
if c.state != CallStateExecuting {
if c.state > CallStateExecuting {
return
}
if c.cancelFunc != nil {
Expand All @@ -220,7 +225,7 @@ func (c *Call) GetResult() (*Result, error) {
if err != nil {
return nil, fmt.Errorf("c.archive.getResult: %w", err)
}
err = c.result.setIter(iter, nil)
err = c.result.SetIter(iter, nil)
if err != nil {
return nil, fmt.Errorf("c.result.setIter: %w", err)
}
Expand Down
196 changes: 196 additions & 0 deletions dbee/core/call_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package core_test

import (
"context"
"encoding/json"
"errors"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/kndndrj/nvim-dbee/dbee/core"
"github.com/kndndrj/nvim-dbee/dbee/core/mock"
)

func TestCall_Success(t *testing.T) {
r := require.New(t)

rows := mock.NewRows(0, 10)

connection, err := core.NewConnection(&core.ConnectionParams{}, mock.NewAdapter(rows,
mock.AdapterWithResultStreamOpts(mock.ResultStreamWithNextSleep(300*time.Millisecond)),
))
r.NoError(err)

expectedEvents := []core.CallState{
core.CallStateExecuting,
core.CallStateRetrieving,
core.CallStateArchived,
}

eventIndex := 0
call := connection.Execute("", func(state core.CallState, c *core.Call) {
// make sure events were in order
r.Equal(expectedEvents[eventIndex], state)
eventIndex++

if state == core.CallStateRetrieving {
result, err := c.GetResult()
r.NoError(err)

actualRows, err := result.Rows(0, len(rows))
r.NoError(err)

r.Equal(rows, actualRows)
}
})

// wait for call to finish
select {
case <-call.Done():
// wait a bit for event index to stabilize
time.Sleep(100 * time.Millisecond)
case <-time.After(5 * time.Second):
t.Error("call did not finish in expected time")
}

// make sure all events passed
r.Equal(len(expectedEvents), eventIndex)
}

func TestCall_Cancel(t *testing.T) {
r := require.New(t)

rows := mock.NewRows(0, 10)

adapter := mock.NewAdapter(rows,
mock.AdapterWithQuerySideEffect("wait", func(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(10 * time.Second):
}
return nil
}),
mock.AdapterWithResultStreamOpts(mock.ResultStreamWithNextSleep(300*time.Millisecond)),
)

connection, err := core.NewConnection(&core.ConnectionParams{}, adapter)
r.NoError(err)

expectedEvents := []core.CallState{
core.CallStateExecuting,
core.CallStateCanceled,
}

eventIndex := 0
call := connection.Execute("wait", func(state core.CallState, c *core.Call) {
// wait for first event and cancel request
c.Cancel()
// make sure events were in order
r.Equal(expectedEvents[eventIndex], state)
eventIndex++
})

// wait for call to finish
select {
case <-call.Done():
// wait a bit for event index to stabilize
time.Sleep(100 * time.Millisecond)
case <-time.After(5 * time.Second):
t.Error("call did not finish in expected time")
}

// make sure all events passed
r.Equal(len(expectedEvents), eventIndex)
}

func TestCall_FailedQuery(t *testing.T) {
r := require.New(t)

rows := mock.NewRows(0, 10)

adapter := mock.NewAdapter(rows,
mock.AdapterWithQuerySideEffect("fail", func(ctx context.Context) error {
return errors.New("query failed")
}),
mock.AdapterWithResultStreamOpts(mock.ResultStreamWithNextSleep(300*time.Millisecond)),
)

connection, err := core.NewConnection(&core.ConnectionParams{}, adapter)
r.NoError(err)

expectedEvents := []core.CallState{
core.CallStateExecuting,
core.CallStateExecutingFailed,
}

eventIndex := 0
call := connection.Execute("fail", func(state core.CallState, c *core.Call) {
// make sure events were in order
r.Equal(expectedEvents[eventIndex], state)
eventIndex++

if state == core.CallStateExecutingFailed {
r.NotNil(c.Err())
}
})

// wait for call to finish
select {
case <-call.Done():
// wait a bit for event index to stabilize
time.Sleep(100 * time.Millisecond)
case <-time.After(5 * time.Second):
t.Error("call did not finish in expected time")
}

// make sure all events passed
r.Equal(len(expectedEvents), eventIndex)
}

func TestCall_Archive(t *testing.T) {
r := require.New(t)

rows := mock.NewRows(0, 10)

connection, err := core.NewConnection(&core.ConnectionParams{}, mock.NewAdapter(rows,
mock.AdapterWithResultStreamOpts(mock.ResultStreamWithNextSleep(300*time.Millisecond)),
))
r.NoError(err)

call := connection.Execute("", nil)

// wait for call to finish
select {
case <-call.Done():
// wait a bit for event index to stabilize
time.Sleep(100 * time.Millisecond)
case <-time.After(5 * time.Second):
t.Error("call did not finish in expected time")
}

// check result
result, err := call.GetResult()
r.NoError(err)
actualRows, err := result.Rows(0, len(rows))
r.NoError(err)
r.Equal(rows, actualRows)

// marshal to json
b, err := json.Marshal(call)
r.NoError(err)

// marshal back
restoredCall := new(core.Call)
err = json.Unmarshal(b, restoredCall)
r.NoError(err)

// check result again
result, err = restoredCall.GetResult()
r.NoError(err)
actualRows, err = result.Rows(0, len(rows))
r.NoError(err)
r.Equal(rows, actualRows)
}
2 changes: 1 addition & 1 deletion dbee/core/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (c *Connection) GetParams() *ConnectionParams {
return c.unexpandedParams
}

func (c *Connection) Execute(query string, onEvent func(*Call)) *Call {
func (c *Connection) Execute(query string, onEvent func(CallState, *Call)) *Call {
exec := func(ctx context.Context) (ResultStream, error) {
return c.driver.Query(ctx, query)
}
Expand Down
Loading

0 comments on commit 5663c33

Please sign in to comment.