Skip to content

Commit

Permalink
Fix SEGV when using dynamic addresses.
Browse files Browse the repository at this point in the history
Fixes  vcabbage#197

Signed-off-by: Alan Conway <aconway@redhat.com>
  • Loading branch information
alanconway committed Dec 2, 2019
1 parent 86a6a19 commit e721817
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 55 deletions.
79 changes: 29 additions & 50 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,20 @@ type link struct {
msg Message // current message being decoded
}

func (l *link) getSource() *source {
if l.source == nil {
l.source = new(source)
}
return l.source
}

func (l *link) getTarget() *target {
if l.target == nil {
l.target = new(target)
}
return l.target
}

// attachLink is used by Receiver and Sender to create new links
func attachLink(s *Session, r *Receiver, opts []LinkOption) (*link, error) {
l, err := newLink(s, r, opts)
Expand Down Expand Up @@ -978,7 +992,7 @@ func attachLink(s *Session, r *Receiver, opts []LinkOption) (*link, error) {
if isReceiver {
// if dynamic address requested, copy assigned name to address
if l.dynamicAddr && resp.Source != nil {
l.source.Address = resp.Source.Address
l.getSource().Address = resp.Source.Address
}
// deliveryCount is a sequence number, must initialize to sender's initial sequence number
l.deliveryCount = resp.InitialDeliveryCount
Expand All @@ -987,7 +1001,7 @@ func attachLink(s *Session, r *Receiver, opts []LinkOption) (*link, error) {
} else {
// if dynamic address requested, copy assigned name to address
if l.dynamicAddr && resp.Target != nil {
l.target.Address = resp.Target.Address
l.getTarget().Address = resp.Target.Address
}
l.transfers = make(chan performTransfer)
}
Expand Down Expand Up @@ -1578,39 +1592,26 @@ func LinkName(name string) LinkOption {
// LinkSourceCapabilities sets the source capabilities.
func LinkSourceCapabilities(capabilities ...string) LinkOption {
return func(l *link) error {
if l.source == nil {
l.source = new(source)
source := l.getSource()
for _, v := range capabilities {
source.Capabilities = append(source.Capabilities, symbol(v))
}

// Convert string to symbol
symbolCapabilities := make([]symbol, len(capabilities))
for i, v := range capabilities {
symbolCapabilities[i] = symbol(v)
}

l.source.Capabilities = append(l.source.Capabilities, symbolCapabilities...)
return nil
}
}

// LinkSourceAddress sets the source address.
func LinkSourceAddress(addr string) LinkOption {
return func(l *link) error {
if l.source == nil {
l.source = new(source)
}
l.source.Address = addr
l.getSource().Address = addr
return nil
}
}

// LinkTargetAddress sets the target address.
func LinkTargetAddress(addr string) LinkOption {
return func(l *link) error {
if l.target == nil {
l.target = new(target)
}
l.target.Address = addr
l.getTarget().Address = addr
return nil
}
}
Expand Down Expand Up @@ -1721,11 +1722,9 @@ func LinkSelectorFilter(filter string) LinkOption {
// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#section-descriptor-values
func LinkSourceFilter(name string, code uint64, value interface{}) LinkOption {
return func(l *link) error {
if l.source == nil {
l.source = new(source)
}
if l.source.Filter == nil {
l.source.Filter = make(map[symbol]*describedType)
source := l.getSource()
if source.Filter == nil {
source.Filter = make(map[symbol]*describedType)
}

var descriptor interface{}
Expand All @@ -1735,7 +1734,7 @@ func LinkSourceFilter(name string, code uint64, value interface{}) LinkOption {
descriptor = symbol(name)
}

l.source.Filter[symbol(name)] = &describedType{
source.Filter[symbol(name)] = &describedType{
descriptor: descriptor,
value: value,
}
Expand Down Expand Up @@ -1764,12 +1763,7 @@ func LinkTargetDurability(d Durability) LinkOption {
if d > DurabilityUnsettledState {
return errorErrorf("invalid Durability %d", d)
}

if l.target == nil {
l.target = new(target)
}
l.target.Durable = d

l.getTarget().Durable = d
return nil
}
}
Expand All @@ -1783,12 +1777,7 @@ func LinkTargetExpiryPolicy(p ExpiryPolicy) LinkOption {
if err != nil {
return err
}

if l.target == nil {
l.target = new(target)
}
l.target.ExpiryPolicy = p

l.getTarget().ExpiryPolicy = p
return nil
}
}
Expand All @@ -1801,12 +1790,7 @@ func LinkSourceDurability(d Durability) LinkOption {
if d > DurabilityUnsettledState {
return errorErrorf("invalid Durability %d", d)
}

if l.source == nil {
l.source = new(source)
}
l.source.Durable = d

l.getSource().Durable = d
return nil
}
}
Expand All @@ -1820,12 +1804,7 @@ func LinkSourceExpiryPolicy(p ExpiryPolicy) LinkOption {
if err != nil {
return err
}

if l.source == nil {
l.source = new(source)
}
l.source.ExpiryPolicy = p

l.getSource().ExpiryPolicy = p
return nil
}
}
Expand Down
37 changes: 32 additions & 5 deletions local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
package amqp_test

import (
"context"
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"pack.ag/amqp"
)

// Tests that require a local broker running on the standard AMQP port.

func TestDial_IPV6(t *testing.T) {
if c, err := amqp.Dial("amqp://localhost"); err != nil {
t.Skip("can't connect to local AMQP server")
} else {
c.Close()
}
c, err := amqp.Dial("amqp://localhost")
assert.NoError(t, err)
c.Close()

l, err := net.Listen("tcp6", "[::]:0")
if err != nil {
t.Skip("ipv6 not supported")
Expand All @@ -35,3 +37,28 @@ func TestDial_IPV6(t *testing.T) {
})
}
}

func TestSendReceive(t *testing.T) {
c, err := amqp.Dial("amqp://")
require.NoError(t, err)
defer c.Close()

ssn, err := c.NewSession()
require.NoError(t, err)

r, err := ssn.NewReceiver(amqp.LinkAddressDynamic())
require.NoError(t, err)
var m *amqp.Message
done := make(chan error)
go func() {
var err error
defer func() { done <- err; close(done) }()
m, err = r.Receive(context.Background())
m.Accept()
}()

s, err := ssn.NewSender(amqp.LinkAddress(r.Address()))
require.NoError(t, s.Send(context.Background(), amqp.NewMessage([]byte("hello"))))
require.NoError(t, <-done)
assert.Equal(t, "hello", string(m.GetData()))
}

0 comments on commit e721817

Please sign in to comment.