Skip to content

Commit

Permalink
add packet reflection safeguard in shim
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiSubira committed Nov 4, 2023
1 parent 76805a0 commit 8a16a50
Show file tree
Hide file tree
Showing 3 changed files with 377 additions and 68 deletions.
18 changes: 17 additions & 1 deletion dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//tools/lint:go.bzl", "go_library")
load("//tools/lint:go.bzl", "go_library", "go_test")

go_library(
name = "go_default_library",
Expand All @@ -14,6 +14,22 @@ go_library(
"//pkg/slayers/path/epic:go_default_library",
"//pkg/slayers/path/scion:go_default_library",
"//private/topology:go_default_library",
"//private/underlay/sockctrl:go_default_library",
"@com_github_google_gopacket//:go_default_library",
],
)

go_test(
name = "go_default_test",
srcs = ["dispatcher_test.go"],
embed = [":go_default_library"],
deps = [
"//pkg/addr:go_default_library",
"//pkg/private/xtest:go_default_library",
"//pkg/snet:go_default_library",
"//pkg/snet/path:go_default_library",
"//private/topology:go_default_library",
"@com_github_stretchr_testify//assert:go_default_library",
"@com_github_stretchr_testify//require:go_default_library",
],
)
229 changes: 162 additions & 67 deletions dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020 Anapaya Systems
// Copyright 2023 ETH Zurich
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,9 +15,12 @@
package dispatcher

import (
"bytes"
"encoding/binary"
"fmt"
"net"
"net/netip"
"syscall"

"github.com/google/gopacket"

Expand All @@ -29,6 +32,7 @@ import (
"github.com/scionproto/scion/pkg/slayers/path/epic"
"github.com/scionproto/scion/pkg/slayers/path/scion"
"github.com/scionproto/scion/private/topology"
"github.com/scionproto/scion/private/underlay/sockctrl"
)

const ErrUnsupportedL4 common.ErrMsg = "unsupported SCION L4 protocol"
Expand All @@ -43,6 +47,7 @@ type Server struct {
conn *net.UDPConn

buf []byte
oobuf []byte
outBuffer gopacket.SerializeBuffer
decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser
Expand All @@ -60,8 +65,9 @@ type Server struct {
func NewServer(topo map[addr.AS]*topology.Loader, conn *net.UDPConn) *Server {
server := Server{
topo: topo,
conn: conn,
conn: setPktInfo(conn),
buf: make([]byte, common.SupportedMTU),
oobuf: make([]byte, 1024),
decoded: make([]gopacket.LayerType, 4),
outBuffer: gopacket.NewSerializeBuffer(),
options: gopacket.SerializeOptions{
Expand Down Expand Up @@ -92,93 +98,127 @@ func (s *Server) Serve() error {
for {
s.buf = s.buf[:cap(s.buf)]

n, nextHop, err := s.conn.ReadFromUDPAddrPort(s.buf)
n, _, _, previousHop, err := s.conn.ReadMsgUDPAddrPort(s.buf, s.oobuf)
if err != nil {
log.Error("Decoding layers", "err", err)
log.Error("Reading message", "err", err)
continue
}

err = s.parser.DecodeLayers(s.buf[:n], &s.decoded)
nextHopAddr, err := s.processMsgNextHop(s.buf[:n], previousHop)
if err != nil {
log.Error("Decoding layers", "err", err)
continue
return err
}
if len(s.decoded) < 2 {
log.Error("Unexpected decode packet", "layers decoded", len(s.decoded))
if nextHopAddr == nil {
// some error parsing the address from the incoming packet;
// we discard the packet and keep serving.
continue
}
err = s.outBuffer.Clear()
if err != nil {
return err
}
// Packet handling
var dstAddrPort netip.AddrPort
switch s.decoded[len(s.decoded)-1] {
case slayers.LayerTypeSCMP:
// send response to BR
if s.scmpLayer.TypeCode.Type() == slayers.SCMPTypeTracerouteRequest ||
s.scmpLayer.TypeCode.Type() == slayers.SCMPTypeEchoRequest {
err = s.reverseSCMPInfo()
if err != nil {
return err
}
dstAddrPort = nextHop
} else { // rely to end application
dstAddrPort, err = s.getDstSCMP()
if err != nil {
log.Error("Getting destination for SCMP message", "err", err)
continue
}
}
payload := gopacket.Payload(s.scmpLayer.Payload)
err = payload.SerializeTo(s.outBuffer, s.options)
if err != nil {
return err
}
s.outBuffer.PushLayer(payload.LayerType())

err = s.scmpLayer.SerializeTo(s.outBuffer, s.options)
if err != nil {
return err
// If the incoming message is SCMP Informational message (Echo or Traceroute)
// the dispatcher has already processed the response, and the nextHop address
// belongs to the BR.
isSCMPinfo := s.scmpLayer.TypeCode.Type() == slayers.SCMPTypeTracerouteRequest ||
s.scmpLayer.TypeCode.Type() == slayers.SCMPTypeEchoRequest
if isSCMPinfo || validateNextHopAddr(*nextHopAddr, s.oobuf) {
m, err := s.conn.WriteToUDPAddrPort(s.outBuffer.Bytes(), *nextHopAddr)
if err != nil || m != len(s.outBuffer.Bytes()) {
log.Error("writing packet out", "err", err)
}
s.outBuffer.PushLayer(s.scmpLayer.LayerType())
case slayers.LayerTypeSCIONUDP:
dstAddrPort, err = s.getDstSCIONUDP()
}
}
}

// processMsgNextHop processes the message arriving to the shim dispatcher and serializes
// the outbound packet into buf. The intended nextHop address, either the end application,
// or the next BR (for SCMP informational response) is returned. It returns a non-nil error
// for non-recoverable errors, only. If the incoming packet couldn't be processed due to a
// recoverable error, the returned address will be nil. The caller must check both values
// consistently.
func (s *Server) processMsgNextHop(buf []byte, prevHop netip.AddrPort) (*netip.AddrPort, error) {
err := s.parser.DecodeLayers(buf, &s.decoded)
if err != nil {
log.Error("Decoding layers", "err", err)
return nil, nil
}
if len(s.decoded) < 2 {
log.Error("Unexpected decode packet", "layers decoded", len(s.decoded))
return nil, nil
}
err = s.outBuffer.Clear()
if err != nil {
return nil, err
}
// Packet handling
var dstAddrPort netip.AddrPort
switch s.decoded[len(s.decoded)-1] {
case slayers.LayerTypeSCMP:
// send response to BR
if s.scmpLayer.TypeCode.Type() == slayers.SCMPTypeTracerouteRequest ||
s.scmpLayer.TypeCode.Type() == slayers.SCMPTypeEchoRequest {
err = s.reverseSCMPInfo()
if err != nil {
log.Error("Getting destination for SCION/UDP message", "err", err)
continue
log.Error("Reversing SCMP information", "err", err)
return nil, nil
}
payload := gopacket.Payload(s.udpLayer.Payload)
err = payload.SerializeTo(s.outBuffer, s.options)
dstAddrPort = prevHop
} else { // relay to end application
dstAddrPort, err = s.getDstSCMP()
if err != nil {
return err
log.Error("Getting destination for SCMP message", "err", err)
return nil, nil
}
s.outBuffer.PushLayer(payload.LayerType())
}
payload := gopacket.Payload(s.scmpLayer.Payload)
err = payload.SerializeTo(s.outBuffer, s.options)
if err != nil {
log.Error("Serializing payload", "err", err)
return nil, nil
}
s.outBuffer.PushLayer(payload.LayerType())

err = s.udpLayer.SerializeTo(s.outBuffer, s.options)
if err != nil {
return err
}
s.outBuffer.PushLayer(s.udpLayer.LayerType())
err = s.scmpLayer.SerializeTo(s.outBuffer, s.options)
if err != nil {
log.Error("Serializing SCMP header", "err", err)
return nil, nil
}
if s.decoded[len(s.decoded)-2] == slayers.LayerTypeEndToEndExtn {
err = s.e2e.SerializeTo(s.outBuffer, s.options)
if err != nil {
return err
}
s.outBuffer.PushLayer(s.e2e.LayerType())
s.outBuffer.PushLayer(s.scmpLayer.LayerType())
case slayers.LayerTypeSCIONUDP:
dstAddrPort, err = s.getDstSCIONUDP()
if err != nil {
log.Error("Getting destination for SCION/UDP message", "err", err)
return nil, nil
}
err = s.scionLayer.SerializeTo(s.outBuffer, s.options)
payload := gopacket.Payload(s.udpLayer.Payload)
err = payload.SerializeTo(s.outBuffer, s.options)
if err != nil {
return err
log.Error("Serializing payload", "err", err)
return nil, nil
}
s.outBuffer.PushLayer(s.scionLayer.LayerType())
s.outBuffer.PushLayer(payload.LayerType())

m, err := s.conn.WriteToUDPAddrPort(s.outBuffer.Bytes(), dstAddrPort)
if err != nil || m != len(s.outBuffer.Bytes()) {
log.Error("writing packet out", "err", err)
err = s.udpLayer.SerializeTo(s.outBuffer, s.options)
if err != nil {
log.Error("Serializing udp header", "err", err)
return nil, nil
}
s.outBuffer.PushLayer(s.udpLayer.LayerType())
}
if s.decoded[len(s.decoded)-2] == slayers.LayerTypeEndToEndExtn {
err = s.e2e.SerializeTo(s.outBuffer, s.options)
if err != nil {
log.Error("Serializing e2e extension", "err", err)
return nil, nil
}
s.outBuffer.PushLayer(s.e2e.LayerType())
}
err = s.scionLayer.SerializeTo(s.outBuffer, s.options)
if err != nil {
log.Error("Serializing SCION header", "err", err)
return nil, nil
}
s.outBuffer.PushLayer(s.scionLayer.LayerType())

return &dstAddrPort, nil
}

func (s *Server) reverseSCMPInfo() error {
Expand Down Expand Up @@ -382,3 +422,58 @@ func addrPortFromBytes(addr []byte, port uint16) (netip.AddrPort, error) {
}
return netip.AddrPortFrom(a, port), nil
}

func setPktInfo(conn *net.UDPConn) *net.UDPConn {
err := sockctrl.SetsockoptInt(conn, syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1)
if err != nil {
panic(fmt.Sprintf("cannot set IP_PKTINFO on socket: %s", err))
}
return conn
}

// validateNextHopAddr return true if the underlay address on the UDP/IP wrapper
// header corresponds to the address on the encapsulated UDP/SCION header, otherwise
// it returns false. This implements a safeguard for traffic reflection as discussed in:
// https://github.com/scionproto/scion/pull/4280#issuecomment-1775177351
func validateNextHopAddr(addr netip.AddrPort, oobuffer []byte) bool {
buf := bytes.NewBuffer(oobuffer)

msg := syscall.Cmsghdr{}
if err := binary.Read(buf, binary.LittleEndian, &msg); err != nil {
log.Error("Parsing message", "err", err)
return false
}
if msg.Level == syscall.IPPROTO_IP && msg.Type == syscall.IP_PKTINFO {
if addr.Addr().Unmap().Is4() {
packet_info := syscall.Inet4Pktinfo{}
if err := binary.Read(buf, binary.LittleEndian, &packet_info); err != nil {
log.Error("Parsing Inet4 PKT_INFO", "err", err)
return false
}
pktAddr := netip.AddrFrom4(packet_info.Addr)
if addr.Addr().Unmap().Compare(pktAddr) != 0 {
log.Error("UDP/IP addr destination different from UDP/SCION addr",
"UDP/IP:", pktAddr.String(),
"UDP/SCION:", addr.Addr().String())
return false
}
return true
}
if addr.Addr().Unmap().Is6() {
packet_info := syscall.Inet6Pktinfo{}
if err := binary.Read(buf, binary.LittleEndian, &packet_info); err != nil {
log.Error("Parsing Inet6 PKT_INFO", "err", err)
return false
}
pktAddr := netip.AddrFrom16(packet_info.Addr)
if addr.Addr().Unmap().Compare(pktAddr) != 0 {
log.Error("UDP/IP addr destination different from UDP/SCION addr",
"UDP/IP:", pktAddr.String(),
"UDP/SCION:", addr.Addr().String())
return false
}
return true
}
}
return false
}
Loading

0 comments on commit 8a16a50

Please sign in to comment.