Skip to content

Commit

Permalink
Merge pull request #135 from openziti/dtls-support
Browse files Browse the repository at this point in the history
DTLS support. Fixes #134
  • Loading branch information
plorenz authored Aug 19, 2024
2 parents 84786ad + 792ba99 commit a67fa18
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 462 deletions.
7 changes: 7 additions & 0 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ type Underlay interface {
GetRemoteAddr() net.Addr
}

type classicUnderlay interface {
Underlay
getPeer() transport.Conn
init(id string, connectionId string, headers Headers)
rxHello() (*Message, error)
}

const AnyContentType = -1
const HelloSequence = -1

Expand Down
84 changes: 49 additions & 35 deletions classic_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,63 @@
package channel

import (
"errors"
"fmt"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/identity"
"github.com/openziti/transport/v2"
"github.com/pkg/errors"
"time"
)

type classicDialer struct {
identity *identity.TokenId
endpoint transport.Address
localBinding string
headers map[int32][]byte
identity *identity.TokenId
endpoint transport.Address
localBinding string
headers map[int32][]byte
underlayFactory func(peer transport.Conn, version uint32) classicUnderlay
}

func NewClassicDialerWithBindAddress(identity *identity.TokenId, endpoint transport.Address, localBinding string, headers map[int32][]byte) UnderlayFactory {
return &classicDialer{
result := &classicDialer{
identity: identity,
endpoint: endpoint,
localBinding: localBinding,
headers: headers,
}

if endpoint.Type() == "dtls" {
result.underlayFactory = newDatagramUnderlay
} else {
result.underlayFactory = newClassicImpl
}

return result
}

func NewClassicDialer(identity *identity.TokenId, endpoint transport.Address, headers map[int32][]byte) UnderlayFactory {
return NewClassicDialerWithBindAddress(identity, endpoint, "", headers)
}

func (dialer *classicDialer) Create(timeout time.Duration, tcfg transport.Configuration) (Underlay, error) {
log := pfxlog.ContextLogger(dialer.endpoint.String())
func (self *classicDialer) Create(timeout time.Duration, tcfg transport.Configuration) (Underlay, error) {
log := pfxlog.ContextLogger(self.endpoint.String())
log.Debug("started")
defer log.Debug("exited")

deadline := time.Now().Add(timeout)

version := uint32(2)
tryCount := 0

log.Debugf("Attempting to dial with bind: %s", dialer.localBinding)
log.Debugf("Attempting to dial with bind: %s", self.localBinding)

for {
peer, err := dialer.endpoint.DialWithLocalBinding("classic", dialer.localBinding, dialer.identity, timeout, tcfg)
for time.Now().Before(deadline) {
peer, err := self.endpoint.DialWithLocalBinding("classic", self.localBinding, self.identity, timeout, tcfg)
if err != nil {
return nil, err
}

impl := newClassicImpl(peer, version)
if err := dialer.sendHello(impl, timeout); err != nil {
underlay := self.underlayFactory(peer, version)
if err = self.sendHello(underlay, deadline); err != nil {
if tryCount > 0 {
return nil, err
} else {
Expand All @@ -73,57 +84,60 @@ func (dialer *classicDialer) Create(timeout time.Duration, tcfg transport.Config
log.Warnf("Retrying dial with protocol version %v", version)
continue
}
return impl, nil
return underlay, nil
}
return nil, errors.New("timeout waiting for dial")
}

func (dialer *classicDialer) sendHello(impl *classicImpl, timeout time.Duration) error {
log := pfxlog.ContextLogger(impl.Label())
func (self *classicDialer) sendHello(underlay classicUnderlay, deadline time.Time) error {
log := pfxlog.ContextLogger(underlay.Label())
defer log.Debug("exited")
log.Debug("started")

if timeout == 0 {
timeout = time.Minute
}
peer := underlay.getPeer()

if err := impl.peer.SetReadDeadline(time.Now().Add(timeout)); err != nil {
if err := peer.SetDeadline(deadline); err != nil {
return err
}

defer func() {
_ = impl.peer.SetReadDeadline(time.Time{})
if err := peer.SetDeadline(time.Time{}); err != nil { // clear write deadline
log.WithError(err).Error("unable to clear deadline")
}
}()

request := NewHello(dialer.identity.Token, dialer.headers)
request.sequence = HelloSequence
if err := impl.Tx(request); err != nil {
_ = impl.peer.Close()
request := NewHello(self.identity.Token, self.headers)
request.SetSequence(HelloSequence)
if err := underlay.Tx(request); err != nil {
_ = underlay.Close()
return err
}

response, err := impl.Rx()
response, err := underlay.Rx()
if err != nil {
if errors.Is(err, BadMagicNumberError) {
return errors.Errorf("could not negotiate connection with %v, invalid header", impl.peer.RemoteAddr().String())
return fmt.Errorf("could not negotiate connection with %v, invalid header", peer.RemoteAddr().String())
}
return err
}
if !response.IsReplyingTo(request.sequence) || response.ContentType != ContentTypeResultType {
return fmt.Errorf("channel synchronization error, expected %v, got %v", request.sequence, response.ReplyFor())
if !response.IsReplyingTo(HelloSequence) || response.ContentType != ContentTypeResultType {
return fmt.Errorf("channel synchronization error, expected %v, got %v", HelloSequence, response.ReplyFor())
}
result := UnmarshalResult(response)
if !result.Success {
return errors.New(result.Message)
}
impl.connectionId = string(response.Headers[ConnectionIdHeader])

if id, ok := response.GetStringHeader(IdHeader); ok {
impl.id = &identity.TokenId{Token: id}
} else if certs := impl.Certificates(); len(certs) > 0 {
impl.id = &identity.TokenId{Token: certs[0].Subject.CommonName}
connectionId := string(response.Headers[ConnectionIdHeader])
id := ""

if val, ok := response.GetStringHeader(IdHeader); ok {
id = val
} else if certs := underlay.Certificates(); len(certs) > 0 {
id = certs[0].Subject.CommonName
}

impl.headers = response.Headers
underlay.init(id, connectionId, response.Headers)

return nil
}
17 changes: 13 additions & 4 deletions classic_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package channel
import (
"crypto/x509"
"fmt"
"github.com/openziti/identity"
"github.com/openziti/transport/v2"
"github.com/pkg/errors"
"net"
Expand All @@ -29,7 +28,7 @@ import (

type classicImpl struct {
peer transport.Conn
id *identity.TokenId
id string
connectionId string
headers map[int32][]byte
closed atomic.Bool
Expand Down Expand Up @@ -86,7 +85,7 @@ func (impl *classicImpl) Tx(m *Message) error {
}

func (impl *classicImpl) Id() string {
return impl.id.Token
return impl.id
}

func (impl *classicImpl) Headers() map[int32][]byte {
Expand Down Expand Up @@ -120,7 +119,17 @@ func (impl *classicImpl) IsClosed() bool {
return impl.closed.Load()
}

func newClassicImpl(peer transport.Conn, version uint32) *classicImpl {
func (impl *classicImpl) init(id string, connectionId string, headers Headers) {
impl.id = id
impl.connectionId = connectionId
impl.headers = headers
}

func (impl *classicImpl) getPeer() transport.Conn {
return impl.peer
}

func newClassicImpl(peer transport.Conn, version uint32) classicUnderlay {
readF := ReadV2
marshalF := MarshalV2

Expand Down
Loading

0 comments on commit a67fa18

Please sign in to comment.