diff --git a/client.go b/client.go index 3c37fa9a..8b54a17e 100644 --- a/client.go +++ b/client.go @@ -870,6 +870,20 @@ type link struct { msg Message // current message being decoded } +func (l *link) getSource() *source { + if l.source == nil { + l.source = new(source) + } + return l.source +} + +func (l *link) getTarget() *target { + if l.target == nil { + l.target = new(target) + } + return l.target +} + // attachLink is used by Receiver and Sender to create new links func attachLink(s *Session, r *Receiver, opts []LinkOption) (*link, error) { l, err := newLink(s, r, opts) @@ -991,7 +1005,7 @@ func attachLink(s *Session, r *Receiver, opts []LinkOption) (*link, error) { if isReceiver { // if dynamic address requested, copy assigned name to address if l.dynamicAddr && resp.Source != nil { - l.source.Address = resp.Source.Address + l.getSource().Address = resp.Source.Address } // deliveryCount is a sequence number, must initialize to sender's initial sequence number l.deliveryCount = resp.InitialDeliveryCount @@ -1000,7 +1014,7 @@ func attachLink(s *Session, r *Receiver, opts []LinkOption) (*link, error) { } else { // if dynamic address requested, copy assigned name to address if l.dynamicAddr && resp.Target != nil { - l.target.Address = resp.Target.Address + l.getTarget().Address = resp.Target.Address } l.transfers = make(chan performTransfer) } @@ -1591,17 +1605,10 @@ func LinkName(name string) LinkOption { // LinkSourceCapabilities sets the source capabilities. func LinkSourceCapabilities(capabilities ...string) LinkOption { return func(l *link) error { - if l.source == nil { - l.source = new(source) + source := l.getSource() + for _, v := range capabilities { + source.Capabilities = append(source.Capabilities, symbol(v)) } - - // Convert string to symbol - symbolCapabilities := make([]symbol, len(capabilities)) - for i, v := range capabilities { - symbolCapabilities[i] = symbol(v) - } - - l.source.Capabilities = append(l.source.Capabilities, symbolCapabilities...) return nil } } @@ -1609,10 +1616,7 @@ func LinkSourceCapabilities(capabilities ...string) LinkOption { // LinkSourceAddress sets the source address. func LinkSourceAddress(addr string) LinkOption { return func(l *link) error { - if l.source == nil { - l.source = new(source) - } - l.source.Address = addr + l.getSource().Address = addr return nil } } @@ -1620,10 +1624,7 @@ func LinkSourceAddress(addr string) LinkOption { // LinkTargetAddress sets the target address. func LinkTargetAddress(addr string) LinkOption { return func(l *link) error { - if l.target == nil { - l.target = new(target) - } - l.target.Address = addr + l.getTarget().Address = addr return nil } } @@ -1734,11 +1735,9 @@ func LinkSelectorFilter(filter string) LinkOption { // http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#section-descriptor-values func LinkSourceFilter(name string, code uint64, value interface{}) LinkOption { return func(l *link) error { - if l.source == nil { - l.source = new(source) - } - if l.source.Filter == nil { - l.source.Filter = make(map[symbol]*describedType) + source := l.getSource() + if source.Filter == nil { + source.Filter = make(map[symbol]*describedType) } var descriptor interface{} @@ -1748,7 +1747,7 @@ func LinkSourceFilter(name string, code uint64, value interface{}) LinkOption { descriptor = symbol(name) } - l.source.Filter[symbol(name)] = &describedType{ + source.Filter[symbol(name)] = &describedType{ descriptor: descriptor, value: value, } @@ -1777,12 +1776,7 @@ func LinkTargetDurability(d Durability) LinkOption { if d > DurabilityUnsettledState { return errorErrorf("invalid Durability %d", d) } - - if l.target == nil { - l.target = new(target) - } - l.target.Durable = d - + l.getTarget().Durable = d return nil } } @@ -1796,12 +1790,7 @@ func LinkTargetExpiryPolicy(p ExpiryPolicy) LinkOption { if err != nil { return err } - - if l.target == nil { - l.target = new(target) - } - l.target.ExpiryPolicy = p - + l.getTarget().ExpiryPolicy = p return nil } } @@ -1814,12 +1803,7 @@ func LinkSourceDurability(d Durability) LinkOption { if d > DurabilityUnsettledState { return errorErrorf("invalid Durability %d", d) } - - if l.source == nil { - l.source = new(source) - } - l.source.Durable = d - + l.getSource().Durable = d return nil } } @@ -1833,12 +1817,7 @@ func LinkSourceExpiryPolicy(p ExpiryPolicy) LinkOption { if err != nil { return err } - - if l.source == nil { - l.source = new(source) - } - l.source.ExpiryPolicy = p - + l.getSource().ExpiryPolicy = p return nil } } diff --git a/local_test.go b/local_test.go index c3c2992c..5d23267c 100644 --- a/local_test.go +++ b/local_test.go @@ -3,20 +3,22 @@ package amqp_test import ( + "context" "net" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "pack.ag/amqp" ) // Tests that require a local broker running on the standard AMQP port. func TestDial_IPV6(t *testing.T) { - if c, err := amqp.Dial("amqp://localhost"); err != nil { - t.Skip("can't connect to local AMQP server") - } else { - c.Close() - } + c, err := amqp.Dial("amqp://localhost") + assert.NoError(t, err) + c.Close() + l, err := net.Listen("tcp6", "[::]:0") if err != nil { t.Skip("ipv6 not supported") @@ -35,3 +37,28 @@ func TestDial_IPV6(t *testing.T) { }) } } + +func TestSendReceive(t *testing.T) { + c, err := amqp.Dial("amqp://") + require.NoError(t, err) + defer c.Close() + + ssn, err := c.NewSession() + require.NoError(t, err) + + r, err := ssn.NewReceiver(amqp.LinkAddressDynamic()) + require.NoError(t, err) + var m *amqp.Message + done := make(chan error) + go func() { + var err error + defer func() { done <- err; close(done) }() + m, err = r.Receive(context.Background()) + m.Accept() + }() + + s, err := ssn.NewSender(amqp.LinkAddress(r.Address())) + require.NoError(t, s.Send(context.Background(), amqp.NewMessage([]byte("hello")))) + require.NoError(t, <-done) + assert.Equal(t, "hello", string(m.GetData())) +}