Skip to content

Commit

Permalink
fix: chanio constructors return errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jensdrenhaus committed Sep 10, 2024
1 parent fe26b8d commit 2737ff9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 9 deletions.
34 changes: 32 additions & 2 deletions cmds/dutagent/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,43 @@ func (s *session) Print(text string) {

//nolint:nonamedreturns
func (s *session) Console() (stdin io.Reader, stdout, stderr io.Writer) {
return chanio.NewChanReader(s.stdin), chanio.NewChanWriter(s.stdout), chanio.NewChanWriter(s.stderr)
var (
stdinReader io.Reader
stdoutWriter, stderrWriter io.Writer
err error
)

stdinReader, err = chanio.NewChanReader(s.stdin)
if err != nil {
log.Fatalf("session.Console() failed to create stdinReader: %v", err)
}

stdoutWriter, err = chanio.NewChanWriter(s.stdout)
if err != nil {
log.Fatalf("session.Console() failed to create stdoutWriter: %v", err)
}

stderrWriter, err = chanio.NewChanWriter(s.stderr)
if err != nil {
log.Fatalf("session.Console() failed to create stderrWriter: %v", err)
}

return stdinReader, stdoutWriter, stderrWriter
}

func (s *session) RequestFile(name string) (io.Reader, error) {
if s.fileReq == nil {
log.Fatal("session.RequestFile() called but session.fileReq is nil")
}

r, err := chanio.NewChanReader(s.file)
if err != nil {
log.Fatalf("session.RequestFile() failed to create reader: %v", err)
}

s.fileReq <- name

return chanio.NewChanReader(s.file), nil
return r, nil
}

func (s *session) SendFile(_ string, _ io.Reader) error {
Expand Down
20 changes: 15 additions & 5 deletions internal/chanio/chanio.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package chanio

import (
"errors"
"io"
"log"
)
Expand All @@ -14,11 +15,16 @@ type ChanReader struct {
buf []byte // Buffer to store excess bytes
}

func NewChanReader(ch <-chan []byte) *ChanReader {
// NewChanReader returns a new ChanReader reading from ch. The provided channel must not be nil.
func NewChanReader(ch <-chan []byte) (*ChanReader, error) {
if ch == nil {
return nil, errors.New("cannot create a ChanReader with a nil channel")
}

return &ChanReader{
ch: ch,
buf: make([]byte, 0),
}
}, nil
}

// Read reads up to len(bytes) bytes into bytes. It returns the number of bytes
Expand Down Expand Up @@ -85,11 +91,15 @@ type ChanWriter struct {
ch chan<- []byte
}

// NewChanWriter returns a new ChanWriter writing to ch.
func NewChanWriter(ch chan<- []byte) *ChanWriter {
// NewChanWriter returns a new ChanWriter writing to ch. The provided channel must not be nil.
func NewChanWriter(ch chan<- []byte) (*ChanWriter, error) {
if ch == nil {
return nil, errors.New("cannot create a ChanWriter with a nil channel")
}

return &ChanWriter{
ch: ch,
}
}, nil
}

// Write writes len(bytes) bytes from bytes to the underlying data stream.
Expand Down
36 changes: 34 additions & 2 deletions internal/chanio/chanio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,38 @@ import (
)

func TestNewChanWriter(t *testing.T) {
// Test with valid channel
ch := make(chan []byte)
writer := NewChanWriter(ch)
writer, err := NewChanWriter(ch)

if err != nil {
t.Fatalf("NewChanWriter() returned an error: %v", err)
}

if writer.ch == nil {
t.Errorf("NewChanWriter() returned a ChanWriter with a nil channel")
}

// Test with nil channel
writer, err = NewChanWriter(nil)

if err == nil {
t.Fatalf("NewChanWriter() did not return an error for nil channel")
}

if writer != nil {
t.Errorf("NewChanWriter() returned a non-nil ChanWriter for nil channel")
}
}

func TestNewChanReader(t *testing.T) {
// Test with valid channel
ch := make(chan []byte)
reader := NewChanReader(ch)
reader, err := NewChanReader(ch)

if err != nil {
t.Fatalf("NewChanReader() returned an error: %v", err)
}

if reader.ch == nil {
t.Errorf("NewChanReader() returned a ChanReader with a nil channel")
Expand All @@ -27,6 +48,17 @@ func TestNewChanReader(t *testing.T) {
if reader.buf == nil {
t.Errorf("NewChanReader() returned a ChanReader with a nil buffer")
}

// Test with nil channel
reader, err = NewChanReader(nil)

if err == nil {
t.Fatalf("NewChanReader() did not return an error for nil channel")
}

if reader != nil {
t.Errorf("NewChanReader() returned a non-nil ChanReader for nil channel")
}
}

func TestChanReader_Read(t *testing.T) {
Expand Down

0 comments on commit 2737ff9

Please sign in to comment.