diff --git a/channel.go b/channel.go index 7c6c7e2..0e2073e 100644 --- a/channel.go +++ b/channel.go @@ -173,6 +173,13 @@ type Underlay interface { GetRemoteAddr() net.Addr } +type classicUnderlay interface { + Underlay + getPeer() transport.Conn + init(id string, connectionId string, headers Headers) + rxHello() (*Message, error) +} + const AnyContentType = -1 const HelloSequence = -1 diff --git a/classic_dialer.go b/classic_dialer.go index 8774eb5..9bcde42 100644 --- a/classic_dialer.go +++ b/classic_dialer.go @@ -17,52 +17,63 @@ package channel import ( + "errors" "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/identity" "github.com/openziti/transport/v2" - "github.com/pkg/errors" "time" ) type classicDialer struct { - identity *identity.TokenId - endpoint transport.Address - localBinding string - headers map[int32][]byte + identity *identity.TokenId + endpoint transport.Address + localBinding string + headers map[int32][]byte + underlayFactory func(peer transport.Conn, version uint32) classicUnderlay } func NewClassicDialerWithBindAddress(identity *identity.TokenId, endpoint transport.Address, localBinding string, headers map[int32][]byte) UnderlayFactory { - return &classicDialer{ + result := &classicDialer{ identity: identity, endpoint: endpoint, localBinding: localBinding, headers: headers, } + + if endpoint.Type() == "dtls" { + result.underlayFactory = newDatagramUnderlay + } else { + result.underlayFactory = newClassicImpl + } + + return result } func NewClassicDialer(identity *identity.TokenId, endpoint transport.Address, headers map[int32][]byte) UnderlayFactory { return NewClassicDialerWithBindAddress(identity, endpoint, "", headers) } -func (dialer *classicDialer) Create(timeout time.Duration, tcfg transport.Configuration) (Underlay, error) { - log := pfxlog.ContextLogger(dialer.endpoint.String()) +func (self *classicDialer) Create(timeout time.Duration, tcfg transport.Configuration) (Underlay, error) { + log := pfxlog.ContextLogger(self.endpoint.String()) log.Debug("started") defer log.Debug("exited") + deadline := time.Now().Add(timeout) + version := uint32(2) tryCount := 0 - log.Debugf("Attempting to dial with bind: %s", dialer.localBinding) + log.Debugf("Attempting to dial with bind: %s", self.localBinding) - for { - peer, err := dialer.endpoint.DialWithLocalBinding("classic", dialer.localBinding, dialer.identity, timeout, tcfg) + for time.Now().Before(deadline) { + peer, err := self.endpoint.DialWithLocalBinding("classic", self.localBinding, self.identity, timeout, tcfg) if err != nil { return nil, err } - impl := newClassicImpl(peer, version) - if err := dialer.sendHello(impl, timeout); err != nil { + underlay := self.underlayFactory(peer, version) + if err = self.sendHello(underlay, deadline); err != nil { if tryCount > 0 { return nil, err } else { @@ -73,57 +84,60 @@ func (dialer *classicDialer) Create(timeout time.Duration, tcfg transport.Config log.Warnf("Retrying dial with protocol version %v", version) continue } - return impl, nil + return underlay, nil } + return nil, errors.New("timeout waiting for dial") } -func (dialer *classicDialer) sendHello(impl *classicImpl, timeout time.Duration) error { - log := pfxlog.ContextLogger(impl.Label()) +func (self *classicDialer) sendHello(underlay classicUnderlay, deadline time.Time) error { + log := pfxlog.ContextLogger(underlay.Label()) defer log.Debug("exited") log.Debug("started") - if timeout == 0 { - timeout = time.Minute - } + peer := underlay.getPeer() - if err := impl.peer.SetReadDeadline(time.Now().Add(timeout)); err != nil { + if err := peer.SetDeadline(deadline); err != nil { return err } defer func() { - _ = impl.peer.SetReadDeadline(time.Time{}) + if err := peer.SetDeadline(time.Time{}); err != nil { // clear write deadline + log.WithError(err).Error("unable to clear deadline") + } }() - request := NewHello(dialer.identity.Token, dialer.headers) - request.sequence = HelloSequence - if err := impl.Tx(request); err != nil { - _ = impl.peer.Close() + request := NewHello(self.identity.Token, self.headers) + request.SetSequence(HelloSequence) + if err := underlay.Tx(request); err != nil { + _ = underlay.Close() return err } - response, err := impl.Rx() + response, err := underlay.Rx() if err != nil { if errors.Is(err, BadMagicNumberError) { - return errors.Errorf("could not negotiate connection with %v, invalid header", impl.peer.RemoteAddr().String()) + return fmt.Errorf("could not negotiate connection with %v, invalid header", peer.RemoteAddr().String()) } return err } - if !response.IsReplyingTo(request.sequence) || response.ContentType != ContentTypeResultType { - return fmt.Errorf("channel synchronization error, expected %v, got %v", request.sequence, response.ReplyFor()) + if !response.IsReplyingTo(HelloSequence) || response.ContentType != ContentTypeResultType { + return fmt.Errorf("channel synchronization error, expected %v, got %v", HelloSequence, response.ReplyFor()) } result := UnmarshalResult(response) if !result.Success { return errors.New(result.Message) } - impl.connectionId = string(response.Headers[ConnectionIdHeader]) - if id, ok := response.GetStringHeader(IdHeader); ok { - impl.id = &identity.TokenId{Token: id} - } else if certs := impl.Certificates(); len(certs) > 0 { - impl.id = &identity.TokenId{Token: certs[0].Subject.CommonName} + connectionId := string(response.Headers[ConnectionIdHeader]) + id := "" + + if val, ok := response.GetStringHeader(IdHeader); ok { + id = val + } else if certs := underlay.Certificates(); len(certs) > 0 { + id = certs[0].Subject.CommonName } - impl.headers = response.Headers + underlay.init(id, connectionId, response.Headers) return nil } diff --git a/classic_impl.go b/classic_impl.go index aa3769a..2389ca5 100644 --- a/classic_impl.go +++ b/classic_impl.go @@ -19,7 +19,6 @@ package channel import ( "crypto/x509" "fmt" - "github.com/openziti/identity" "github.com/openziti/transport/v2" "github.com/pkg/errors" "net" @@ -29,7 +28,7 @@ import ( type classicImpl struct { peer transport.Conn - id *identity.TokenId + id string connectionId string headers map[int32][]byte closed atomic.Bool @@ -86,7 +85,7 @@ func (impl *classicImpl) Tx(m *Message) error { } func (impl *classicImpl) Id() string { - return impl.id.Token + return impl.id } func (impl *classicImpl) Headers() map[int32][]byte { @@ -120,7 +119,17 @@ func (impl *classicImpl) IsClosed() bool { return impl.closed.Load() } -func newClassicImpl(peer transport.Conn, version uint32) *classicImpl { +func (impl *classicImpl) init(id string, connectionId string, headers Headers) { + impl.id = id + impl.connectionId = connectionId + impl.headers = headers +} + +func (impl *classicImpl) getPeer() transport.Conn { + return impl.peer +} + +func newClassicImpl(peer transport.Conn, version uint32) classicUnderlay { readF := ReadV2 marshalF := MarshalV2 diff --git a/classic_listener.go b/classic_listener.go index c4a45ce..9a83d28 100644 --- a/classic_listener.go +++ b/classic_listener.go @@ -30,18 +30,19 @@ import ( ) type classicListener struct { - identity *identity.TokenId - endpoint transport.Address - socket io.Closer - close chan struct{} - handlers []ConnectionHandler - acceptF func(underlay Underlay) - created chan Underlay - connectOptions ConnectOptions - tcfg transport.Configuration - headers map[int32][]byte - closed atomic.Bool - listenerPool goroutines.Pool + identity *identity.TokenId + endpoint transport.Address + socket io.Closer + close chan struct{} + handlers []ConnectionHandler + acceptF func(underlay Underlay) + created chan Underlay + connectOptions ConnectOptions + tcfg transport.Configuration + headers map[int32][]byte + closed atomic.Bool + listenerPool goroutines.Pool + underlayFactory func(peer transport.Conn, version uint32) classicUnderlay } func DefaultListenerConfig() ListenerConfig { @@ -86,15 +87,21 @@ func newClassicListener(identity *identity.TokenId, endpoint transport.Address, panic(err) } + underlayFactory := newClassicImpl + if endpoint.Type() == "dtls" { + underlayFactory = newDatagramUnderlay + } + return &classicListener{ - identity: identity, - endpoint: endpoint, - close: closeNotify, - connectOptions: config.ConnectOptions, - tcfg: config.TransportConfig, - headers: config.Headers, - listenerPool: pool, - handlers: config.ConnectionHandlers, + identity: identity, + endpoint: endpoint, + close: closeNotify, + connectOptions: config.ConnectOptions, + tcfg: config.TransportConfig, + headers: config.Headers, + listenerPool: pool, + handlers: config.ConnectionHandlers, + underlayFactory: underlayFactory, } } @@ -121,77 +128,78 @@ func NewClassicListener(identity *identity.TokenId, endpoint transport.Address, return listener } -func (listener *classicListener) Listen(handlers ...ConnectionHandler) error { - listener.handlers = append(listener.handlers, handlers...) - socket, err := listener.endpoint.Listen("classic", listener.identity, listener.acceptConnection, listener.tcfg) +func (self *classicListener) Listen(handlers ...ConnectionHandler) error { + self.handlers = append(self.handlers, handlers...) + socket, err := self.endpoint.Listen("classic", self.identity, self.acceptConnection, self.tcfg) if err != nil { return err } - listener.socket = socket + self.socket = socket return nil } -func (listener *classicListener) Close() error { - if listener.closed.CompareAndSwap(false, true) { - close(listener.close) - if socket := listener.socket; socket != nil { +func (self *classicListener) Close() error { + if self.closed.CompareAndSwap(false, true) { + close(self.close) + if socket := self.socket; socket != nil { if err := socket.Close(); err != nil { return err } } - listener.socket = nil + self.socket = nil } return nil } -func (listener *classicListener) Create(_ time.Duration, _ transport.Configuration) (Underlay, error) { - if listener.created == nil { +func (self *classicListener) Create(_ time.Duration, _ transport.Configuration) (Underlay, error) { + if self.created == nil { return nil, errors.New("this listener was not set up for Create to be called, programming error") } select { - case impl := <-listener.created: + case impl := <-self.created: if impl != nil { return impl, nil } - case <-listener.close: + case <-self.close: } return nil, ListenerClosedError } -func (listener *classicListener) acceptConnection(peer transport.Conn) { - log := pfxlog.ContextLogger(listener.endpoint.String()) - err := listener.listenerPool.Queue(func() { - impl := newClassicImpl(peer, 2) +func (self *classicListener) acceptConnection(peer transport.Conn) { + log := pfxlog.ContextLogger(self.endpoint.String()) + err := self.listenerPool.Queue(func() { + impl := self.underlayFactory(peer, 2) - var err error - impl.connectionId, err = NextConnectionId() + connectionId, err := NextConnectionId() if err != nil { _ = peer.Close() log.Errorf("error getting connection id for [%s] (%v)", peer.Detail().Address, err) return } - if err = peer.SetReadDeadline(time.Now().Add(listener.connectOptions.ConnectTimeout)); err != nil { - log.Errorf("could not set read timeout for [%s] (%v)", peer.Detail().Address, err) + if err = peer.SetDeadline(time.Now().Add(self.connectOptions.ConnectTimeout)); err != nil { + log.Errorf("could not set connection deadline for [%s] (%v)", peer.Detail().Address, err) _ = peer.Close() return } - request, hello, err := listener.receiveHello(impl) + defer func() { + if err = peer.SetDeadline(time.Time{}); err != nil { + log.Errorf("could not clear connection deadline for [%s] (%v)", peer.Detail().Address, err) + _ = peer.Close() + return + } + }() + + request, hello, err := self.receiveHello(impl) if err != nil { _ = peer.Close() log.Errorf("error receiving hello from [%s] (%v)", peer.Detail().Address, err) return } - if err = peer.SetReadDeadline(time.Time{}); err != nil { - log.Errorf("could not clear read timeout for [%s] (%v)", peer.Detail().Address, err) - _ = peer.Close() - return - } - - for _, h := range listener.handlers { + for _, h := range self.handlers { if err = h.HandleConnection(hello, peer.PeerCertificates()); err != nil { break } @@ -203,13 +211,13 @@ func (listener *classicListener) acceptConnection(peer transport.Conn) { return } - impl.id = &identity.TokenId{Token: hello.IdToken} - impl.headers = hello.Headers + impl.init(hello.IdToken, connectionId, hello.Headers) - if err := listener.ackHello(impl, request, true, ""); err == nil { - listener.acceptF(impl) + if err = self.ackHello(impl, request, true, ""); err == nil { + self.acceptF(impl) } else { log.Errorf("error acknowledging hello for [%s] (%v)", peer.Detail().Address, err) + _ = peer.Close() } }) if err != nil { @@ -217,15 +225,15 @@ func (listener *classicListener) acceptConnection(peer transport.Conn) { } } -func (listener *classicListener) receiveHello(impl *classicImpl) (*Message, *Hello, error) { +func (self *classicListener) receiveHello(impl classicUnderlay) (*Message, *Hello, error) { log := pfxlog.ContextLogger(impl.Label()) log.Debug("started") defer log.Debug("exited") request, err := impl.rxHello() if err != nil { - if err == BadMagicNumberError { - WriteUnknownVersionResponse(impl.peer) + if errors.Is(err, BadMagicNumberError) { + WriteUnknownVersionResponse(impl.getPeer()) } _ = impl.Close() return nil, nil, fmt.Errorf("receive error (%s)", err) @@ -238,16 +246,16 @@ func (listener *classicListener) receiveHello(impl *classicImpl) (*Message, *Hel return request, hello, nil } -func (listener *classicListener) ackHello(impl *classicImpl, request *Message, success bool, message string) error { +func (self *classicListener) ackHello(impl classicUnderlay, request *Message, success bool, message string) error { response := NewResult(success, message) - for key, val := range listener.headers { + for key, val := range self.headers { response.Headers[key] = val } - response.PutStringHeader(ConnectionIdHeader, impl.connectionId) - if listener.identity != nil { - response.PutStringHeader(IdHeader, listener.identity.Token) + response.PutStringHeader(ConnectionIdHeader, impl.ConnectionId()) + if self.identity != nil { + response.PutStringHeader(IdHeader, self.identity.Token) } response.sequence = HelloSequence diff --git a/datagram/dialer.go b/datagram/dialer.go deleted file mode 100644 index f480cfb..0000000 --- a/datagram/dialer.go +++ /dev/null @@ -1,122 +0,0 @@ -//go:build prototype - -/* - Copyright NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package datagram - -import ( - "errors" - "fmt" - "github.com/michaelquigley/pfxlog" - "github.com/openziti/channel/v2" - "github.com/openziti/identity" - "github.com/openziti/transport/v2" - "time" -) - -type dialer struct { - id *identity.TokenId - peer transport.Conn - headers map[int32][]byte -} - -func NewDatagramDialer(id *identity.TokenId, peer transport.Conn, headers map[int32][]byte) channel.UnderlayFactory { - return &dialer{ - id: id, - peer: peer, - headers: headers, - } -} - -func (self *dialer) Create(timeout time.Duration, _ transport.Configuration) (channel.Underlay, error) { - log := pfxlog.Logger() - log.Debug("started") - defer log.Debug("exited") - - if timeout < 10*time.Millisecond { - return nil, errors.New("timeout must be at least 10ms") - } - - version := uint32(2) - - defer func() { - if err := self.peer.SetDeadline(time.Time{}); err != nil { // clear write deadline - log.WithError(err).Error("unable to clear write deadline") - } - }() - - impl := &Underlay{ - id: self.id, - peer: self.peer, - } - - deadline := time.Now().Add(timeout) - - for deadline.After(time.Now()) { - if err := self.sendHello(impl); err != nil { - if retryVersion, _ := channel.GetRetryVersion(err); retryVersion != version { - version = retryVersion - } - - log.Warnf("Retrying dial with protocol version %v", version) - continue - } - impl.id = self.id - return impl, nil - } - - return nil, errors.New("connect timeout") -} - -func (self *dialer) sendHello(impl *Underlay) error { - log := pfxlog.ContextLogger(impl.Label()) - defer log.Debug("exited") - log.Debug("started") - - request := channel.NewHello(self.id.Token, self.headers) - request.SetSequence(channel.HelloSequence) - if err := impl.Tx(request); err != nil { - _ = impl.peer.Close() - return err - } - - if err := impl.peer.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - return err - } - - defer func() { - if err := self.peer.SetReadDeadline(time.Time{}); err != nil { // clear write deadline - log.WithError(err).Error("unable to clear read deadline") - } - }() - - response, err := impl.Rx() - if err != nil { - return err - } - if !response.IsReplyingTo(channel.HelloSequence) || response.ContentType != channel.ContentTypeResultType { - return fmt.Errorf("channel synchronization error, expected %v, got %v", channel.HelloSequence, response.ReplyFor()) - } - result := channel.UnmarshalResult(response) - if !result.Success { - return errors.New(result.Message) - } - impl.connectionId = string(response.Headers[channel.ConnectionIdHeader]) - impl.headers = response.Headers - - return nil -} diff --git a/datagram/listener.go b/datagram/listener.go deleted file mode 100644 index 3188829..0000000 --- a/datagram/listener.go +++ /dev/null @@ -1,122 +0,0 @@ -//go:build prototype - -/* - Copyright NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package datagram - -import ( - "fmt" - "github.com/michaelquigley/pfxlog" - "github.com/openziti/channel/v2" - "github.com/openziti/identity" - "github.com/openziti/transport/v2" - "github.com/pkg/errors" - "time" -) - -type listener struct { - identity *identity.TokenId - peer transport.Conn - headers map[int32][]byte -} - -func NewListener(identity *identity.TokenId, peer transport.Conn, headers map[int32][]byte) channel.UnderlayFactory { - return &listener{ - identity: identity, - peer: peer, - headers: headers, - } -} - -// TODO: need to restructure so we start after receiving hello and responding, but can also -// respond to additional hellos after we're up and running, since initial response -// may have gotten lost. Could add hello receive handler here -func (self *listener) Create(timeout time.Duration, _ transport.Configuration) (channel.Underlay, error) { - log := pfxlog.Logger() - - impl := &Underlay{ - id: self.identity, - peer: self.peer, - } - - connectionId, err := channel.NextConnectionId() - if err != nil { - return nil, errors.Wrap(err, "error getting connection id") - } - impl.connectionId = connectionId - - if timeout > 0 { - defer func() { - if err = self.peer.SetDeadline(time.Time{}); err != nil { - log.WithError(err).Error("unable to clear deadline on conn after create") - } - }() - - if err = self.peer.SetDeadline(time.Now().Add(timeout)); err != nil { - return nil, errors.Wrap(err, "could not set deadline on conn") - } - } - - request, hello, err := self.receiveHello(impl) - if err != nil { - return nil, errors.Wrap(err, "error receiving hello") - } - - impl.id = &identity.TokenId{Token: hello.IdToken} - impl.headers = hello.Headers - - if err = self.ackHello(impl, request, true, ""); err != nil { - return nil, errors.Wrap(err, "unable to acknowledge hello") - } - - return impl, nil -} - -func (self *listener) receiveHello(impl *Underlay) (*channel.Message, *channel.Hello, error) { - log := pfxlog.ContextLogger(impl.Label()) - log.Debug("started") - defer log.Debug("exited") - - request, err := impl.Rx() - if err != nil { - if err == channel.UnknownVersionError { - channel.WriteUnknownVersionResponse(impl.peer) - } - _ = impl.Close() - return nil, nil, fmt.Errorf("receive error (%s)", err) - } - if request.ContentType != channel.ContentTypeHelloType { - _ = impl.Close() - return nil, nil, fmt.Errorf("unexpected content type [%d]", request.ContentType) - } - hello := channel.UnmarshalHello(request) - return request, hello, nil -} - -func (self *listener) ackHello(impl *Underlay, request *channel.Message, success bool, message string) error { - response := channel.NewResult(success, message) - - for key, val := range self.headers { - response.Headers[key] = val - } - - response.Headers[channel.ConnectionIdHeader] = []byte(impl.connectionId) - response.SetSequence(channel.HelloSequence) - - response.ReplyTo(request) - return impl.Tx(response) -} diff --git a/datagram/underlay.go b/datagram/underlay.go deleted file mode 100644 index 712f373..0000000 --- a/datagram/underlay.go +++ /dev/null @@ -1,118 +0,0 @@ -//go:build prototype - -/* - Copyright NetFoundry Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package datagram - -import ( - "bytes" - "crypto/x509" - "fmt" - "github.com/openziti/channel/v2" - "github.com/openziti/foundation/v2/concurrenz" - "github.com/openziti/identity" - "github.com/openziti/transport/v2" - "time" -) - -type Underlay struct { - id *identity.TokenId - connectionId string - headers map[int32][]byte - peer transport.Conn - closed concurrenz.AtomicBoolean -} - -func NewUnderlay(id *identity.TokenId, peer transport.Conn) channel.Underlay { - return &Underlay{ - id: id, - peer: peer, - } -} - -func (impl *Underlay) GetLocalAddr() net.Addr { - return impl.peer.LocalAddr() -} - -func (impl *Underlay) GetRemoteAddr() net.Addr { - return impl.peer.RemoteAddr() -} - -func (self *Underlay) Rx() (*channel.Message, error) { - buf := make([]byte, 65000) - n, err := self.peer.Read(buf) - if err != nil { - return nil, err - } - - buf = buf[:n] - - reader := bytes.NewBuffer(buf) - return channel.ReadV2(reader) -} - -func (self *Underlay) Tx(m *channel.Message) error { - data, err := channel.MarshalV2(m) - if err != nil { - return err - } - _, err = self.peer.Write(data) - return err -} - -func (self *Underlay) Id() *identity.TokenId { - return self.id -} - -func (self *Underlay) LogicalName() string { - return "datagram" -} - -func (self *Underlay) ConnectionId() string { - return self.connectionId -} - -func (self *Underlay) Certificates() []*x509.Certificate { - return self.peer.PeerCertificates() -} - -func (self *Underlay) Label() string { - return fmt.Sprintf("u{%s}->i{%s}", self.LogicalName(), self.ConnectionId()) -} - -func (self *Underlay) Close() error { - if self.closed.CompareAndSwap(false, true) { - return self.peer.Close() - } - return nil -} - -func (self *Underlay) IsClosed() bool { - return self.closed.Get() -} - -func (self *Underlay) Headers() map[int32][]byte { - return self.headers -} - -func (self *Underlay) SetWriteTimeout(duration time.Duration) error { - return self.peer.SetWriteDeadline(time.Now().Add(duration)) -} - -func (self *Underlay) SetWriteDeadline(deadline time.Time) error { - return self.peer.SetWriteDeadline(deadline) -} diff --git a/datagram_underlay.go b/datagram_underlay.go new file mode 100644 index 0000000..18b406b --- /dev/null +++ b/datagram_underlay.go @@ -0,0 +1,128 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package channel + +import ( + "bytes" + "crypto/x509" + "fmt" + "github.com/openziti/transport/v2" + "net" + "sync/atomic" + "time" +) + +type DatagramUnderlay struct { + id string + connectionId string + headers map[int32][]byte + peer transport.Conn + closed atomic.Bool +} + +func newDatagramUnderlay(peer transport.Conn, _ uint32) classicUnderlay { + return &DatagramUnderlay{ + peer: peer, + } +} + +func (self *DatagramUnderlay) GetLocalAddr() net.Addr { + return self.peer.LocalAddr() +} + +func (self *DatagramUnderlay) GetRemoteAddr() net.Addr { + return self.peer.RemoteAddr() +} + +func (self *DatagramUnderlay) Rx() (*Message, error) { + buf := make([]byte, 65000) + n, err := self.peer.Read(buf) + if err != nil { + return nil, err + } + + buf = buf[:n] + + reader := bytes.NewBuffer(buf) + return ReadV2(reader) +} + +func (self *DatagramUnderlay) Tx(m *Message) error { + data, err := MarshalV2(m) + if err != nil { + return err + } + _, err = self.peer.Write(data) + return err +} + +func (self *DatagramUnderlay) Id() string { + return self.id +} + +func (self *DatagramUnderlay) LogicalName() string { + return "datagram" +} + +func (self *DatagramUnderlay) ConnectionId() string { + return self.connectionId +} + +func (self *DatagramUnderlay) Certificates() []*x509.Certificate { + return self.peer.PeerCertificates() +} + +func (self *DatagramUnderlay) Label() string { + return fmt.Sprintf("u{%s}->i{%s}", self.LogicalName(), self.ConnectionId()) +} + +func (self *DatagramUnderlay) Close() error { + if self.closed.CompareAndSwap(false, true) { + return self.peer.Close() + } + return nil +} + +func (self *DatagramUnderlay) IsClosed() bool { + return self.closed.Load() +} + +func (self *DatagramUnderlay) Headers() map[int32][]byte { + return self.headers +} + +func (self *DatagramUnderlay) SetWriteTimeout(duration time.Duration) error { + return self.peer.SetWriteDeadline(time.Now().Add(duration)) +} + +func (self *DatagramUnderlay) SetWriteDeadline(deadline time.Time) error { + return self.peer.SetWriteDeadline(deadline) +} + +func (impl *DatagramUnderlay) init(id string, connectionId string, headers Headers) { + impl.id = id + impl.connectionId = connectionId + impl.headers = headers +} + +func (impl *DatagramUnderlay) getPeer() transport.Conn { + return impl.peer +} + +func (self *DatagramUnderlay) rxHello() (*Message, error) { + return self.Rx() +}