Skip to content

Commit

Permalink
Run nft with context in TrafPol
Browse files Browse the repository at this point in the history
Signed-off-by: hwipl <33433250+hwipl@users.noreply.github.com>
  • Loading branch information
hwipl committed Sep 14, 2023
1 parent 1c1c0f0 commit 384172a
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 50 deletions.
10 changes: 6 additions & 4 deletions internal/trafpol/allowdevs.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
package trafpol

import "context"

// AllowDevs contains allowed devices
type AllowDevs struct {
m map[string]string
}

// Add adds device to the allowed devices
func (a *AllowDevs) Add(device string) {
func (a *AllowDevs) Add(ctx context.Context, device string) {
if a.m[device] != device {
a.m[device] = device
addAllowedDevice(device)
addAllowedDevice(ctx, device)
}
}

// Remove removes device from the allowed devices
func (a *AllowDevs) Remove(device string) {
func (a *AllowDevs) Remove(ctx context.Context, device string) {
if a.m[device] == device {
delete(a.m, device)
removeAllowedDevice(device)
removeAllowedDevice(ctx, device)
}
}

Expand Down
17 changes: 10 additions & 7 deletions internal/trafpol/allowdevs_test.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
package trafpol

import (
"context"
"reflect"
"testing"
)

// TestAllowDevsAdd tests Add of AllowDevs
func TestAllowDevsAdd(t *testing.T) {
a := NewAllowDevs()
ctx := context.Background()

got := []string{}
runNft = func(s string) {
runNft = func(ctx context.Context, s string) {
got = append(got, s)
}

// test adding
want := []string{
"add element inet oc-daemon-filter allowdevs { eth3 }",
}
a.Add("eth3")
a.Add(ctx, "eth3")
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}

// test adding again
// should not change anything
a.Add("eth3")
a.Add(ctx, "eth3")
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
Expand All @@ -34,26 +36,27 @@ func TestAllowDevsAdd(t *testing.T) {
// TestAllowDevsRemove tests Remove of AllowDevs
func TestAllowDevsRemove(t *testing.T) {
a := NewAllowDevs()
ctx := context.Background()

got := []string{}
runNft = func(s string) {
runNft = func(ctx context.Context, s string) {
got = append(got, s)
}

// test removing device
a.Add("eth3")
a.Add(ctx, "eth3")
want := []string{
"delete element inet oc-daemon-filter allowdevs { eth3 }",
}
got = []string{}
a.Remove("eth3")
a.Remove(ctx, "eth3")
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}

// test removing again (not existing device)
// should not change anything
a.Remove("eth3")
a.Remove(ctx, "eth3")
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
Expand Down
8 changes: 4 additions & 4 deletions internal/trafpol/allowhosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (a *AllowHosts) getAndClearUpdates() bool {
}

// setFilter sets the allowed hosts in the traffic filter
func (a *AllowHosts) setFilter() {
func (a *AllowHosts) setFilter(ctx context.Context) {
a.Lock()
defer a.Unlock()

Expand All @@ -190,14 +190,14 @@ func (a *AllowHosts) setFilter() {
}

// set ips in traffic filter
setAllowedIPs(ips)
setAllowedIPs(ctx, ips)
}

// update updates all allowed hosts
func (a *AllowHosts) update(ctx context.Context, upDone chan<- struct{}) {
a.resolveAll(ctx)
if a.getAndClearUpdates() {
a.setFilter()
a.setFilter(ctx)
}
upDone <- struct{}{}
}
Expand Down Expand Up @@ -228,7 +228,7 @@ func (a *AllowHosts) resolvePeriodic(ctx context.Context) {
func (a *AllowHosts) updatePeriodic(ctx context.Context, upDone chan<- struct{}) {
a.resolvePeriodic(ctx)
if a.getAndClearUpdates() {
a.setFilter()
a.setFilter(ctx)
}
upDone <- struct{}{}
}
Expand Down
44 changes: 25 additions & 19 deletions internal/trafpol/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package trafpol

import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os/exec"
Expand All @@ -11,17 +13,21 @@ import (
)

// runNft runs nft and passes s to it via stdin
var runNft = func(s string) {
var runNft = func(ctx context.Context, s string) {
cmd := "nft -f -"
c := exec.Command("bash", "-c", cmd)
c := exec.CommandContext(ctx, "bash", "-c", cmd)
c.Stdin = bytes.NewBufferString(s)
if err := c.Run(); err != nil {
if errors.Is(err, context.Canceled) {
log.WithError(err).Debug("TrafPol nft execution canceled")
return
}
log.WithError(err).Error("TrafPol nft execution error")
}
}

// setFilterRules sets the filter rules
func setFilterRules(fwMark string) {
func setFilterRules(ctx context.Context, fwMark string) {
const filterRules = `
table inet oc-daemon-filter {
# set for allowed devices
Expand Down Expand Up @@ -164,59 +170,59 @@ table inet oc-daemon-filter {
`
r := strings.NewReplacer("$FWMARK", fwMark)
rules := r.Replace(filterRules)
runNft(rules)
runNft(ctx, rules)
}

// unsetFilterRules unsets the filter rules
func unsetFilterRules() {
runNft("delete table inet oc-daemon-filter")
func unsetFilterRules(ctx context.Context) {
runNft(ctx, "delete table inet oc-daemon-filter")
}

// addAllowedDevice adds device to the allowed devices
func addAllowedDevice(device string) {
func addAllowedDevice(ctx context.Context, device string) {
nftconf := fmt.Sprintf("add element inet oc-daemon-filter allowdevs { %s }", device)
runNft(nftconf)
runNft(ctx, nftconf)
}

// removeAllowedDevice removes device from the allowed devices
func removeAllowedDevice(device string) {
func removeAllowedDevice(ctx context.Context, device string) {
nftconf := fmt.Sprintf("delete element inet oc-daemon-filter allowdevs { %s }", device)
runNft(nftconf)
runNft(ctx, nftconf)
}

// setAllowedIPs set the allowed hosts
func setAllowedIPs(ips []*net.IPNet) {
func setAllowedIPs(ctx context.Context, ips []*net.IPNet) {
// we perform all nft commands separately here and not as one atomic
// operation to avoid issues where the whole update fails because nft
// runs into "file exists" errors even though we remove duplicates from
// ips before calling this function and we flush the existing entries

runNft("flush set inet oc-daemon-filter allowhosts4")
runNft("flush set inet oc-daemon-filter allowhosts6")
runNft(ctx, "flush set inet oc-daemon-filter allowhosts4")
runNft(ctx, "flush set inet oc-daemon-filter allowhosts6")

fmt4 := "add element inet oc-daemon-filter allowhosts4 { %s }"
fmt6 := "add element inet oc-daemon-filter allowhosts6 { %s }"
for _, ip := range ips {
if ip.IP.To4() != nil {
// ipv4 address
runNft(fmt.Sprintf(fmt4, ip))
runNft(ctx, fmt.Sprintf(fmt4, ip))
} else {
// ipv6 address
runNft(fmt.Sprintf(fmt6, ip))
runNft(ctx, fmt.Sprintf(fmt6, ip))
}
}
}

// addPortalPorts adds ports for a captive portal to the allowed ports
func addPortalPorts() {
func addPortalPorts(ctx context.Context) {
nftconf := "add element inet oc-daemon-filter allowports { 80, 443 }"
runNft(nftconf)
runNft(ctx, nftconf)
}

// removePortalPorts removes ports for a captive portal from the allowed ports
func removePortalPorts() {
func removePortalPorts(ctx context.Context) {
nftconf := "delete element inet oc-daemon-filter allowports { 80, 443 }"
runNft(nftconf)
runNft(ctx, nftconf)
}

// runCleanupNft runs nft for cleanups
Expand Down
25 changes: 15 additions & 10 deletions internal/trafpol/trafpol.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package trafpol

import (
"context"

log "github.com/sirupsen/logrus"
"github.com/telekom-mms/oc-daemon/internal/cpd"
"github.com/telekom-mms/oc-daemon/internal/devmon"
Expand All @@ -25,18 +27,18 @@ type TrafPol struct {
}

// handleDeviceUpdate handles a device update
func (t *TrafPol) handleDeviceUpdate(u *devmon.Update) {
func (t *TrafPol) handleDeviceUpdate(ctx context.Context, u *devmon.Update) {
// skip physical devices and only allow virtual devices
if u.Type == "device" {
return
}

// add or remove virtual device to/from allowed devices
if u.Add {
t.allowDevs.Add(u.Device)
t.allowDevs.Add(ctx, u.Device)
return
}
t.allowDevs.Remove(u.Device)
t.allowDevs.Remove(ctx, u.Device)
}

// handleDNSUpdate handles a dns config update
Expand All @@ -49,7 +51,7 @@ func (t *TrafPol) handleDNSUpdate() {
}

// handleCPDReport handles a CPD report
func (t *TrafPol) handleCPDReport(report *cpd.Report) {
func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) {
if !report.Detected {
// no captive portal detected
// check if there was a portal before
Expand All @@ -59,15 +61,15 @@ func (t *TrafPol) handleCPDReport(report *cpd.Report) {
t.allowHosts.Update()

// remove ports from allowed ports
removePortalPorts()
removePortalPorts(ctx)
t.capPortal = false
}
return
}

// add ports to allowed ports
if !t.capPortal {
addPortalPorts()
addPortalPorts(ctx)
t.capPortal = true
}
}
Expand All @@ -77,9 +79,12 @@ func (t *TrafPol) start() {
log.Debug("TrafPol starting")
defer close(t.loopDone)

// create context
ctx := context.Background()

// set firewall config
setFilterRules(t.config.FirewallMark)
defer unsetFilterRules()
setFilterRules(ctx, t.config.FirewallMark)
defer unsetFilterRules(ctx)

// add CPD hosts to allowed hosts
for _, h := range t.cpd.Hosts() {
Expand Down Expand Up @@ -108,7 +113,7 @@ func (t *TrafPol) start() {
case u := <-t.devmon.Updates():
// Device Update
log.WithField("update", u).Debug("TrafPol got DevMon update")
t.handleDeviceUpdate(u)
t.handleDeviceUpdate(ctx, u)

case <-t.dnsmon.Updates():
// DNS Update
Expand All @@ -118,7 +123,7 @@ func (t *TrafPol) start() {
case r := <-t.cpd.Results():
// CPD Result
log.WithField("result", r).Debug("TrafPol got CPD result")
t.handleCPDReport(r)
t.handleCPDReport(ctx, r)

case <-t.done:
// shutdown
Expand Down
Loading

0 comments on commit 384172a

Please sign in to comment.