Skip to content

Commit

Permalink
validate MaxMessageSize in blockwise and tcp session (dustin#42)
Browse files Browse the repository at this point in the history
* validate MaxMessageSize in blockwise and tcp session
  • Loading branch information
arun1587 authored and jkralik committed Mar 26, 2019
1 parent a643abf commit 1958954
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 1 deletion.
41 changes: 41 additions & 0 deletions blockwise.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,29 @@ func (b *blockWiseSession) WriteMsg(msg Message) error {
return b.WriteMsgWithContext(context.Background(), msg)
}

func (b *blockWiseSession) validateMessageSize(msg Message) error {
size, err := msg.ToBytesLength()
if err != nil {
return err
}
session, ok := b.networkSession.(*sessionTCP)
if !ok {
// Not supported for UDP session
return nil
}

if session.peerMaxMessageSize != 0 &&
uint32(size) > session.peerMaxMessageSize {
return ErrMaxMessageSizeLimitExceeded
}

return nil
}

func (b *blockWiseSession) WriteMsgWithContext(ctx context.Context, msg Message) error {
if err := b.validateMessageSize(msg); err != nil {
return err
}
switch msg.Code() {
case CSM, Ping, Pong, Release, Abort, Empty, GET:
return b.networkSession.WriteMsgWithContext(ctx, msg)
Expand Down Expand Up @@ -532,7 +554,26 @@ func (r *blockWiseReceiver) exchange(ctx context.Context, b *blockWiseSession, r
return resp, err
}

func (r *blockWiseReceiver) validateMessageSize(msg Message, b *blockWiseSession) error {
size, err := msg.ToBytesLength()
if err != nil {
return err
}

session, ok := b.networkSession.(*sessionTCP)
if ok {
if session.srv.MaxMessageSize != 0 &&
uint32(size) > session.srv.MaxMessageSize {
return ErrMaxMessageSizeLimitExceeded
}
}
return nil
}

func (r *blockWiseReceiver) processResp(b *blockWiseSession, req Message, resp Message) (Message, error) {
if err := r.validateMessageSize(req, b); err != nil {
return nil, err
}
if respBlock, ok := resp.Option(r.blockType).(uint32); ok {
szx, num, more, err := UnmarshalBlockOption(respBlock)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func testServingObservation(t *testing.T, net string, addrstr string, BlockWiseT
Net: net,
BlockWiseTransfer: &BlockWiseTransfer,
BlockWiseTransferSzx: &BlockWiseTransferSzx,
MaxMessageSize: ^uint32(0),
}

conn, err := client.Dial(addrstr)
Expand Down
3 changes: 3 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,6 @@ const ErrUnexpectedReponseCode = Error("unexpected response code")

// ErrMessageNotInterested message is not of interest to the client
const ErrMessageNotInterested = Error("message not to be sent due to disinterest")

// ErrMaxMessageSizeLimitExceeded message size bigger thab maximum message size limit
const ErrMaxMessageSizeLimitExceeded = Error("maximum message size limit exceeded")
1 change: 1 addition & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ type Message interface {
UnmarshalBinary(data []byte) error
SetToken(t []byte)
SetMessageID(messageID uint16)
ToBytesLength() (int, error)
}

// MessageParams params to create COAP message
Expand Down
21 changes: 21 additions & 0 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1132,3 +1132,24 @@ func TestDecodeMessageWithNoResponseOption(t *testing.T) {
t.Fatalf("parsedMsg.Option(NoResponse): %v", parsedMsg.Option(NoResponse).(uint32))
}
}

func TestToBytesLength(t *testing.T) {
data := []byte{
0x40, 0x1, 0x30, 0x39, 0x46, 0x77,
0x65, 0x65, 0x74, 0x61, 0x67, 0xa1, 0x3,
}

msg, err := ParseDgramMessage(data)
if err != nil {
t.Fatalf("Error parsing request: %v", err)
}

bytesLength, err := msg.ToBytesLength()
if err != nil {
t.Fatalf("Error parsing request: %v", err)
}

if len(data) != bytesLength {
t.Errorf("Expected Length = %d, got %d", len(data), bytesLength)
}
}
11 changes: 11 additions & 0 deletions messagedgram.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package coap

import (
"bytes"
"encoding/binary"
"io"
"sort"
Expand Down Expand Up @@ -114,3 +115,13 @@ func ParseDgramMessage(data []byte) (*DgramMessage, error) {
rv := &DgramMessage{}
return rv, rv.UnmarshalBinary(data)
}

// ToBytesLength gets the length of the message
func (m *DgramMessage) ToBytesLength() (int, error) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
if err := m.MarshalBinary(buf); err != nil {
return 0, err
}

return len(buf.Bytes()), nil
}
9 changes: 9 additions & 0 deletions messagetcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,15 @@ func (m *TcpMessage) fill(mti msgTcpInfo, o options, p []byte) {
m.MessageBase.payload = p
}

func (m *TcpMessage) ToBytesLength() (int, error) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
if err := m.MarshalBinary(buf); err != nil {
return 0, err
}

return len(buf.Bytes()), nil
}

type contextBytesReader struct {
reader io.Reader
}
Expand Down
33 changes: 33 additions & 0 deletions messagetcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,37 @@ func TestTCPDecodeMessageSmallWithPayload(t *testing.T) {
if !bytes.Equal(msg.Payload(), []byte("hi")) {
t.Errorf("Incorrect payload: %q", msg.Payload())
}

}

func TestMessageTCPToBytesLength(t *testing.T) {
msgParams := MessageParams{
Code: COAPCode(02),
Token: []byte{0xab},
Payload: []byte("hi"),
}

msg := NewTcpMessage(msgParams)
msg.AddOption(MaxMessageSize, maxMessageSize)

buf := &bytes.Buffer{}
err := msg.MarshalBinary(buf)
if err != nil {
t.Fatalf("Error encoding request: %v", err)
}

bytesLength, err := msg.ToBytesLength()
if err != nil {
t.Fatalf("Error parsing request: %v", err)
}

lenTkl := 1
lenCode := 1
maxMessageSizeOptionLength := 3
payloadMarker := []byte{0xff}

expectedLength := lenTkl + lenCode + len(msgParams.Token) + maxMessageSizeOptionLength + len(payloadMarker) + len(msgParams.Payload)
if expectedLength != bytesLength {
t.Errorf("Expected Length = %d, got %d", expectedLength, bytesLength)
}
}
16 changes: 16 additions & 0 deletions networksession.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,24 @@ func (s *sessionUDP) ExchangeWithContext(ctx context.Context, req Message) (Mess
}
}

func (s *sessionTCP) validateMessageSize(msg Message) error {
size, err := msg.ToBytesLength()
if err != nil {
return err
}

if uint32(size) > s.peerMaxMessageSize {
return ErrMaxMessageSizeLimitExceeded
}

return nil
}

// Write implements the networkSession.Write method.
func (s *sessionTCP) WriteMsgWithContext(ctx context.Context, req Message) error {
if err := s.validateMessageSize(req); err != nil {
return err
}
buffer := bytes.NewBuffer(make([]byte, 0, 1500))
err := req.MarshalBinary(buffer)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,11 @@ func (srv *Server) serveTCPconnection(ctx *shutdownContext, netConn net.Conn) er
return session.closeWithError(fmt.Errorf("cannot serve tcp connection: %v", err))
}

if srv.MaxMessageSize != 0 &&
uint32(mti.totLen) > srv.MaxMessageSize {
return session.closeWithError(fmt.Errorf("cannot serve tcp connection: %v", ErrMaxMessageSizeLimitExceeded))
}

body := make([]byte, mti.BodyLen())
//ctx, cancel := context.WithTimeout(srv.ctx, srv.readTimeout())
err = conn.ReadFullWithContext(ctx, body)
Expand Down
4 changes: 3 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan
fmt.Printf("networkSession start %v\n", s.RemoteAddr())
}, NotifySessionEndFunc: func(w *ClientConn, err error) {
fmt.Printf("networkSession end %v: %v\n", w.RemoteAddr(), err)
}}
},
MaxMessageSize: ^uint32(0),
}

// fin must be buffered so the goroutine below won't block
// forever if fin is never read from. This always happens
Expand Down

0 comments on commit 1958954

Please sign in to comment.