diff --git a/.golangci.yml b/.golangci.yml index 02ed57d6..425896f6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,7 +5,7 @@ run: # default concurrency is a available CPU number concurrency: 4 # timeout for analysis, e.g. 30s, 5m, default is 1m - timeout: 3m + timeout: 5m # exit code when at least one issue was found, default is 1 issues-exit-code: 1 # include test files or not, default is true diff --git a/go.mod b/go.mod index a5d6367b..07b0fd7e 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,13 @@ go 1.21 require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/free5gc/aper v1.0.6-0.20240503143507-2c4c4780b98f + github.com/free5gc/ike v1.1.1-0.20241014015325-083f89768f43 github.com/free5gc/ngap v1.0.9-0.20240708062829-734d184eed74 github.com/free5gc/sctp v1.0.1 - github.com/free5gc/util v1.0.7-0.20240713162917-350ee8f4af4c + github.com/free5gc/util v1.0.7-0.20241017071924-da29aef99a1c github.com/gin-contrib/pprof v1.5.0 github.com/gin-gonic/gin v1.10.0 + github.com/google/gopacket v1.1.19 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index d93d4cd2..199902dd 100644 --- a/go.sum +++ b/go.sum @@ -16,14 +16,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/free5gc/aper v1.0.6-0.20240503143507-2c4c4780b98f h1:sO8FFwAq7feSw/vKN9ioY+fX1gNTXd6/xQOqaeclzsA= github.com/free5gc/aper v1.0.6-0.20240503143507-2c4c4780b98f/go.mod h1:oh3dtNsje2W4/q3pfidMWQKXbXIehXK3t6CD9tXmHx0= +github.com/free5gc/ike v1.1.1-0.20241014015325-083f89768f43 h1:cgpG06umqWTAwYy/bLXXcdNg+k7+qkinsElCVZzuOSI= +github.com/free5gc/ike v1.1.1-0.20241014015325-083f89768f43/go.mod h1:57Ujd9Xjva02mt3OVfepYKiheFHO5Y0YCQyBgB1p1Qs= github.com/free5gc/ngap v1.0.9-0.20240708062829-734d184eed74 h1:foSd3OVtTfDmn3EZbsBngK+U93Mv8YE+qSja7FvKEVU= github.com/free5gc/ngap v1.0.9-0.20240708062829-734d184eed74/go.mod h1:UsPP9LWVyNwu5sm7ZE5toAFeBNkkyj0rh+4Q3ylRBi8= github.com/free5gc/openapi v1.0.9-0.20240503143645-eac9f06c2f6b h1:+VcgZq+3apB6Xr4jEqgGf/uAECRF038SwixEvvxhYrM= github.com/free5gc/openapi v1.0.9-0.20240503143645-eac9f06c2f6b/go.mod h1:0qRW+H1/Nyzw5tjjvyp+90m+2SOZZefGQC9QV8iPwu8= github.com/free5gc/sctp v1.0.1 h1:g8WDO97r8B9ubkT5Hyk9b4I1fZUOii9Z39gQ2eRaASo= github.com/free5gc/sctp v1.0.1/go.mod h1:7QXfRWCmlkBGD0EIu3qL5o71bslfIakydz4h2QDZdjQ= -github.com/free5gc/util v1.0.7-0.20240713162917-350ee8f4af4c h1:baToZn4hxGKoCm3BWwYlRuZoCQ74cMZUJzg9BVLEdE0= -github.com/free5gc/util v1.0.7-0.20240713162917-350ee8f4af4c/go.mod h1:IHKIBd4OM9rwSJ0fG/hv6pXbVC+Eu4Lcaq++BWkfSsY= +github.com/free5gc/util v1.0.7-0.20241017071924-da29aef99a1c h1:vJ3IJPvW4gt7i7d3y8KMp42jypeKsfUG+CqSiFRoXAU= +github.com/free5gc/util v1.0.7-0.20241017071924-da29aef99a1c/go.mod h1:IHKIBd4OM9rwSJ0fG/hv6pXbVC+Eu4Lcaq++BWkfSsY= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/pprof v1.5.0 h1:E/Oy7g+kNw94KfdCy3bZxQFtyDnAX2V7axRS7sNYVrU= @@ -49,6 +51,8 @@ github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVI github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -114,23 +118,35 @@ github.com/wmnsk/go-gtp v0.8.11-0.20240705144331-f53bfdd4233b/go.mod h1:pXocxsDk golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/pkg/context/3gpp_types.go b/internal/context/3gpp_types.go similarity index 100% rename from pkg/context/3gpp_types.go rename to internal/context/3gpp_types.go diff --git a/pkg/context/amf.go b/internal/context/amf.go similarity index 95% rename from pkg/context/amf.go rename to internal/context/amf.go index 1479b2df..fa9b31bf 100644 --- a/pkg/context/amf.go +++ b/internal/context/amf.go @@ -21,7 +21,7 @@ type N3IWFAMF struct { // Overload related AMFOverloadContent *AMFOverloadContent // Relative Context - N3iwfRanUeList map[int64]*N3IWFRanUe // ranUeNgapId as key + N3iwfRanUeList map[int64]RanUe // ranUeNgapId as key } type AMFTNLAssociationItem struct { @@ -47,12 +47,12 @@ func (amf *N3IWFAMF) init(sctpAddr string, conn *sctp.SCTPConn) { amf.SCTPAddr = sctpAddr amf.SCTPConn = conn amf.AMFTNLAssociationList = make(map[string]*AMFTNLAssociationItem) - amf.N3iwfRanUeList = make(map[int64]*N3IWFRanUe) + amf.N3iwfRanUeList = make(map[int64]RanUe) } -func (amf *N3IWFAMF) FindUeByAmfUeNgapID(id int64) *N3IWFRanUe { +func (amf *N3IWFAMF) FindUeByAmfUeNgapID(id int64) RanUe { for _, ranUe := range amf.N3iwfRanUeList { - if ranUe.AmfUeNgapId == id { + if ranUe.GetSharedCtx().AmfUeNgapId == id { return ranUe } } diff --git a/pkg/context/context.go b/internal/context/context.go similarity index 85% rename from pkg/context/context.go rename to internal/context/context.go index 578b5e11..f55f72f9 100644 --- a/pkg/context/context.go +++ b/internal/context/context.go @@ -4,7 +4,7 @@ import ( "context" "crypto/rand" "crypto/rsa" - "crypto/sha1" + "crypto/sha1" // #nosec G505 "crypto/x509" "encoding/pem" "fmt" @@ -24,6 +24,7 @@ import ( "github.com/free5gc/ngap/ngapType" "github.com/free5gc/sctp" "github.com/free5gc/util/idgenerator" + "github.com/free5gc/util/ippool" ) type n3iwf interface { @@ -45,9 +46,9 @@ type N3IWFContext struct { ChildSA sync.Map // map[uint32]*ChildSecurityAssociation, inboundSPI as key GTPConnectionWithUPF sync.Map // map[string]*gtpv1.UPlaneConn, UPF address as key AllocatedUEIPAddress sync.Map // map[string]*N3IWFIkeUe, IPAddr as key - AllocatedUETEID sync.Map // map[uint32]*N3IWFRanUe, TEID as key + AllocatedUETEID sync.Map // map[uint32]*RanUe, TEID as key IKEUePool sync.Map // map[uint64]*N3IWFIkeUe, SPI as key - RANUePool sync.Map // map[int64]*N3IWFRanUe, RanUeNgapID as key + RANUePool sync.Map // map[int64]*RanUe, RanUeNgapID as key IKESPIToNGAPId sync.Map // map[uint64]RanUeNgapID, SPI as key NGAPIdToIKESPI sync.Map // map[uint64]SPI, RanUeNgapID as key @@ -56,12 +57,12 @@ type N3IWFContext struct { N3IWFCertificate []byte N3IWFPrivateKey *rsa.PrivateKey - UeIPRange *net.IPNet + IPSecInnerIPPool *ippool.IPPool + // TODO: [TWIF] TwifUe may has its own IP address pool // XFRM interface XfrmIfaces sync.Map // map[uint32]*netlink.Link, XfrmIfaceId as key XfrmParentIfaceName string - // Every UE's first UP IPsec will use default XFRM interface, additoinal UP IPsec will offset its XFRM id XfrmIfaceIdOffsetForUP uint32 } @@ -120,11 +121,11 @@ func NewContext(n3iwf n3iwf) (*N3IWFContext, error) { n.N3IWFCertificate = block.Bytes // UE IP address range - _, ueIPRange, err := net.ParseCIDR(cfg.GetUEIPAddrRange()) + ueIPPool, err := ippool.NewIPPool(cfg.GetUEIPAddrRange()) if err != nil { - return nil, errors.Errorf("Parse CIDR failed: %+v", err) + return nil, errors.Errorf("NewContext(): %+v", err) } - n.UeIPRange = ueIPRange + n.IPSecInnerIPPool = ueIPPool // XFRM related ikeBindIfaceName, err := getInterfaceName(cfg.GetIKEBindAddr()) @@ -190,11 +191,12 @@ func (c *N3IWFContext) NewN3iwfRanUe() *N3IWFRanUe { return nil } n3iwfRanUe := &N3IWFRanUe{ - N3iwfCtx: c, + RanUeSharedCtx: RanUeSharedCtx{ + N3iwfCtx: c, + }, } n3iwfRanUe.init(ranUeNgapId) c.RANUePool.Store(ranUeNgapId, n3iwfRanUe) - n3iwfRanUe.TemporaryPDUSessionSetupData = new(PDUSessionSetupTemporaryData) return n3iwfRanUe } @@ -218,10 +220,21 @@ func (c *N3IWFContext) IkeUePoolLoad(spi uint64) (*N3IWFIkeUe, bool) { } } -func (c *N3IWFContext) RanUePoolLoad(ranUeNgapId int64) (*N3IWFRanUe, bool) { +func (c *N3IWFContext) RanUePoolLoad(id interface{}) (RanUe, bool) { + var ranUeNgapId int64 + + cfgLog := logger.CfgLog + switch id := id.(type) { + case int64: + ranUeNgapId = id + default: + cfgLog.Warnf("RanUePoolLoad unhandle type: %t", id) + return nil, false + } + ranUe, ok := c.RANUePool.Load(ranUeNgapId) if ok { - return ranUe.(*N3IWFRanUe), ok + return ranUe.(RanUe), ok } else { return nil, ok } @@ -256,7 +269,7 @@ func (c *N3IWFContext) DeleteIkeSPIFromNgapId(ranUeNgapId int64) { c.NGAPIdToIKESPI.Delete(ranUeNgapId) } -func (c *N3IWFContext) RanUeLoadFromIkeSPI(spi uint64) (*N3IWFRanUe, error) { +func (c *N3IWFContext) RanUeLoadFromIkeSPI(spi uint64) (RanUe, error) { ranNgapId, ok := c.IKESPIToNGAPId.Load(spi) if ok { ranUe, err := c.RanUePoolLoad(ranNgapId.(int64)) @@ -371,26 +384,30 @@ func (c *N3IWFContext) GTPConnectionWithUPFStore(upfAddr string, conn *gtpv1.UPl c.GTPConnectionWithUPF.Store(upfAddr, conn) } -func (c *N3IWFContext) NewInternalUEIPAddr(ikeUe *N3IWFIkeUe) net.IP { +func (c *N3IWFContext) NewIPsecInnerUEIP(ikeUe *N3IWFIkeUe) (net.IP, error) { var ueIPAddr net.IP - + var err error cfg := c.Config() ipsecGwAddr := cfg.GetIPSecGatewayAddr() - // TODO: Check number of allocated IP to detect running out of IPs + for { - ueIPAddr = generateRandomIPinRange(c.UeIPRange) - if ueIPAddr != nil { - if ueIPAddr.String() == ipsecGwAddr { - continue - } - _, ok := c.AllocatedUEIPAddress.LoadOrStore(ueIPAddr.String(), ikeUe) - if !ok { - break - } + ueIPAddr, err = c.IPSecInnerIPPool.Allocate(nil) + if err != nil { + return nil, errors.Wrapf(err, "NewIPsecInnerUEIP()") + } + if ueIPAddr.String() == ipsecGwAddr { + continue + } + _, ok := c.AllocatedUEIPAddress.LoadOrStore(ueIPAddr.String(), ikeUe) + if ok { + logger.CtxLog.Warnf("NewIPsecInnerUEIP(): IP(%v) is used by other IkeUE", + ueIPAddr.String()) + } else { + break } } - return ueIPAddr + return ueIPAddr, nil } func (c *N3IWFContext) DeleteInternalUEIPAddr(ipAddr string) { @@ -405,12 +422,16 @@ func (c *N3IWFContext) AllocatedUEIPAddressLoad(ipAddr string) (*N3IWFIkeUe, boo return nil, false } -func (c *N3IWFContext) NewTEID(ranUe *N3IWFRanUe) uint32 { +func (c *N3IWFContext) NewTEID(ranUe RanUe) uint32 { teid64, err := c.TEIDGenerator.Allocate() if err != nil { logger.CtxLog.Errorf("New TEID failed: %+v", err) return 0 } + if teid64 < 0 || teid64 > math.MaxUint32 { + logger.CtxLog.Warnf("NewTEID teid64 out of uint32 range: %d, use maxUint32", teid64) + return 0 + } teid32 := uint32(teid64) c.AllocatedUETEID.Store(teid32, ranUe) @@ -423,10 +444,10 @@ func (c *N3IWFContext) DeleteTEID(teid uint32) { c.AllocatedUETEID.Delete(teid) } -func (c *N3IWFContext) AllocatedUETEIDLoad(teid uint32) (*N3IWFRanUe, bool) { +func (c *N3IWFContext) AllocatedUETEIDLoad(teid uint32) (RanUe, bool) { ranUe, ok := c.AllocatedUETEID.Load(teid) if ok { - return ranUe.(*N3IWFRanUe), ok + return ranUe.(RanUe), ok } return nil, false } @@ -435,9 +456,12 @@ func (c *N3IWFContext) AMFSelection( ueSpecifiedGUAMI *ngapType.GUAMI, ueSpecifiedPLMNId *ngapType.PLMNIdentity, ) *N3IWFAMF { - var availableAMF *N3IWFAMF + var availableAMF, defaultAMF *N3IWFAMF c.AMFPool.Range(func(key, value interface{}) bool { amf := value.(*N3IWFAMF) + if defaultAMF == nil { + defaultAMF = amf + } if amf.FindAvalibleAMFByCompareGUAMI(ueSpecifiedGUAMI) { availableAMF = amf return false @@ -451,24 +475,9 @@ func (c *N3IWFContext) AMFSelection( return true } }) - return availableAMF -} - -func generateRandomIPinRange(subnet *net.IPNet) net.IP { - ipAddr := make([]byte, 4) - randomNumber := make([]byte, 4) - - _, err := rand.Read(randomNumber) - if err != nil { - logger.CtxLog.Errorf("Generate random number for IP address failed: %+v", err) - return nil - } - - // TODO: elimenate network name, gateway, and broadcast - for i := 0; i < 4; i++ { - alter := randomNumber[i] & (subnet.Mask[i] ^ 255) - ipAddr[i] = subnet.IP[i] + alter + if availableAMF == nil && + defaultAMF != nil { + availableAMF = defaultAMF } - - return net.IPv4(ipAddr[0], ipAddr[1], ipAddr[2], ipAddr[3]) + return availableAMF } diff --git a/internal/context/context_test.go b/internal/context/context_test.go new file mode 100644 index 00000000..da321d30 --- /dev/null +++ b/internal/context/context_test.go @@ -0,0 +1,106 @@ +package context_test + +import ( + "context" + "net" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/pkg/factory" + "github.com/free5gc/util/ippool" +) + +type n3iwfTestApp struct { + cfg *factory.Config + n3iwfCtx *n3iwf_context.N3IWFContext + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup +} + +func (a *n3iwfTestApp) Config() *factory.Config { + return a.cfg +} + +func (a *n3iwfTestApp) Context() *n3iwf_context.N3IWFContext { + return a.n3iwfCtx +} + +func (a *n3iwfTestApp) CancelContext() context.Context { + return a.ctx +} + +func NewN3iwfTestApp(cfg *factory.Config) (*n3iwfTestApp, error) { + var err error + ctx, cancel := context.WithCancel(context.Background()) + + n3iwfApp := &n3iwfTestApp{ + cfg: cfg, + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, + } + + n3iwfApp.n3iwfCtx, err = n3iwf_context.NewTestContext(n3iwfApp) + if err != nil { + return nil, err + } + return n3iwfApp, err +} + +func NewTestCfg() *factory.Config { + return &factory.Config{ + Configuration: &factory.Configuration{ + IPSecGatewayAddr: "10.0.0.1", + UEIPAddressRange: "10.0.0.0/24", + }, + } +} + +func TestNewInternalUEIPAddr(t *testing.T) { + cfg := NewTestCfg() + var app *n3iwfTestApp + var err error + var ip, invalidIP, invalidIP2 net.IP + + app, err = NewN3iwfTestApp(cfg) + require.NoError(t, err) + + n3iwfCtx := app.n3iwfCtx + + invalidIP = net.ParseIP("10.0.0.0") + invalidIP2 = net.ParseIP("10.0.0.255") + n3iwfCtx.IPSecInnerIPPool, err = ippool.NewIPPool("10.0.0.0/24") + require.NoError(t, err) + + for i := 1; i <= 253; i++ { + ip, err = n3iwfCtx.NewIPsecInnerUEIP(&n3iwf_context.N3IWFIkeUe{}) + require.NoError(t, err) + require.NotEqual(t, cfg.GetIPSecGatewayAddr(), ip.String()) + require.NotEqual(t, ip, invalidIP) + require.NotEqual(t, ip, invalidIP2) + } + + _, err = n3iwfCtx.NewIPsecInnerUEIP(&n3iwf_context.N3IWFIkeUe{}) + require.Error(t, err) + + n3iwfCtx.AllocatedUEIPAddress = sync.Map{} + + n3iwfCtx.IPSecInnerIPPool, err = ippool.NewIPPool("10.0.0.0/16") + require.NoError(t, err) + + invalidIP2 = net.ParseIP("10.0.255.255") + for i := 1; i <= 65533; i++ { + ip, err = n3iwfCtx.NewIPsecInnerUEIP(&n3iwf_context.N3IWFIkeUe{}) + require.NoError(t, err) + require.NotEqual(t, cfg.GetIPSecGatewayAddr(), ip.String()) + require.NotEqual(t, ip, invalidIP) + require.NotEqual(t, ip, invalidIP2) + } + + _, err = n3iwfCtx.NewIPsecInnerUEIP(&n3iwf_context.N3IWFIkeUe{}) + require.Error(t, err) +} diff --git a/internal/context/gtp.go b/internal/context/gtp.go new file mode 100644 index 00000000..e1719a66 --- /dev/null +++ b/internal/context/gtp.go @@ -0,0 +1,30 @@ +package context + +type GtpEventType int64 + +// GTP Event Type +const ( + ForwardUL GtpEventType = iota +) + +type GtpEvt interface { + Type() GtpEventType +} + +type ForwardULEvt struct { + GtpConnInfo *GTPConnectionInfo + QFI *uint8 + Payload []byte +} + +func (forwardDLEvt *ForwardULEvt) Type() GtpEventType { + return ForwardUL +} + +func NewForwardULEvt(gtpConnInfo *GTPConnectionInfo, qfi *uint8, payload []byte) *ForwardULEvt { + return &ForwardULEvt{ + GtpConnInfo: gtpConnInfo, + QFI: qfi, + Payload: payload, + } +} diff --git a/pkg/context/ike.go b/internal/context/ike.go similarity index 100% rename from pkg/context/ike.go rename to internal/context/ike.go diff --git a/pkg/context/ikeue.go b/internal/context/ikeue.go similarity index 61% rename from pkg/context/ikeue.go rename to internal/context/ikeue.go index ebfd65f2..3768d87c 100644 --- a/pkg/context/ikeue.go +++ b/internal/context/ikeue.go @@ -1,13 +1,15 @@ package context import ( - "errors" "fmt" + "math" "net" + "github.com/pkg/errors" "github.com/vishvananda/netlink" - ike_message "github.com/free5gc/n3iwf/pkg/ike/message" + ike_message "github.com/free5gc/ike/message" + ike_security "github.com/free5gc/ike/security" ) const ( @@ -17,23 +19,23 @@ const ( type N3IWFIkeUe struct { N3iwfCtx *N3IWFContext - /* UE identity */ + // UE identity IPSecInnerIP net.IP IPSecInnerIPAddr *net.IPAddr // Used to send UP packets to UE - /* IKE Security Association */ + // IKE Security Association N3IWFIKESecurityAssociation *IKESecurityAssociation N3IWFChildSecurityAssociation map[uint32]*ChildSecurityAssociation // inbound SPI as key - /* Temporary Mapping of two SPIs */ + // Temporary Mapping of two SPIs // Exchange Message ID(including a SPI) and ChildSA(including a SPI) // Mapping of Message ID of exchange in IKE and Child SA when creating new child SA TemporaryExchangeMsgIDChildSAMapping map[uint32]*ChildSecurityAssociation // Message ID as a key - /* Security */ + // Security Kn3iwf []uint8 // 32 bytes (256 bits), value is from NGAP IE "Security Key" - /* NAS IKE Connection */ + // NAS IKE Connection IKEConnection *UDPSocketInfo // Length of PDU Session List @@ -47,6 +49,7 @@ type IkeMsgTemporaryData struct { } type IKESecurityAssociation struct { + *ike_security.IKESAKey // SPI RemoteSPI uint64 LocalSPI uint64 @@ -55,25 +58,8 @@ type IKESecurityAssociation struct { InitiatorMessageID uint32 ResponderMessageID uint32 - // Transforms for IKE SA - EncryptionAlgorithm *ike_message.Transform - PseudorandomFunction *ike_message.Transform - IntegrityAlgorithm *ike_message.Transform - DiffieHellmanGroup *ike_message.Transform - ExpandedSequenceNumber *ike_message.Transform - // Used for key generating - ConcatenatedNonce []byte - DiffieHellmanSharedKey []byte - - // Keys - SK_d []byte // used for child SA key deriving - SK_ai []byte // used by initiator for integrity checking - SK_ar []byte // used by responder for integrity checking - SK_ei []byte // used by initiator for encrypting - SK_er []byte // used by responder for encrypting - SK_pi []byte // used by initiator for IKE authentication - SK_pr []byte // used by responder for IKE authentication + ConcatenatedNonce []byte // State for IKE_AUTH State uint8 @@ -94,11 +80,8 @@ type IKESecurityAssociation struct { InitiatorSignedOctets []byte // NAT detection - // If UEIsBehindNAT == true, N3IWF should enable NAT traversal and - // TODO: should support dynamic updating network address (MOBIKE) - UEIsBehindNAT bool - // If N3IWFIsBehindNAT == true, N3IWF should send UDP keepalive periodically - N3IWFIsBehindNAT bool + UeBehindNAT bool // If true, N3IWF should enable NAT traversal and + N3iwfBehindNAT bool // TODO: If true, N3IWF should send UDP keepalive periodically // IKE UE context IkeUE *N3IWFIkeUe @@ -112,6 +95,13 @@ type IKESecurityAssociation struct { IsUseDPD bool } +func (ikeSA *IKESecurityAssociation) String() string { + return "====== IKE Security Association Info =====" + + "\nInitiator's SPI: " + fmt.Sprintf("%016x", ikeSA.RemoteSPI) + + "\nResponder's SPI: " + fmt.Sprintf("%016x", ikeSA.LocalSPI) + + "\nIKESAKey: " + ikeSA.IKESAKey.String() +} + // Temporary State Data Args const ( ArgsUEUDPConn string = "UE UDP Socket Info" @@ -138,13 +128,7 @@ type ChildSecurityAssociation struct { TrafficSelectorRemote net.IPNet // Security - EncryptionAlgorithm uint16 - InitiatorToResponderEncryptionKey []byte - ResponderToInitiatorEncryptionKey []byte - IntegrityAlgorithm uint16 - InitiatorToResponderIntegrityKey []byte - ResponderToInitiatorIntegrityKey []byte - ESN bool + *ike_security.ChildSAKey // Encapsulate EnableEncapsulate bool @@ -156,6 +140,60 @@ type ChildSecurityAssociation struct { // IKE UE context IkeUE *N3IWFIkeUe + + LocalIsInitiator bool +} + +func (childSA *ChildSecurityAssociation) String(xfrmiId uint32) string { + var inboundEncryptionKey, inboundIntegrityKey, outboundEncryptionKey, outboundIntegrityKey []byte + + if childSA.LocalIsInitiator { + inboundEncryptionKey = childSA.ResponderToInitiatorEncryptionKey + inboundIntegrityKey = childSA.ResponderToInitiatorIntegrityKey + outboundEncryptionKey = childSA.InitiatorToResponderEncryptionKey + outboundIntegrityKey = childSA.InitiatorToResponderIntegrityKey + } else { + inboundEncryptionKey = childSA.InitiatorToResponderEncryptionKey + inboundIntegrityKey = childSA.InitiatorToResponderIntegrityKey + outboundEncryptionKey = childSA.ResponderToInitiatorEncryptionKey + outboundIntegrityKey = childSA.ResponderToInitiatorIntegrityKey + } + + return fmt.Sprintf("====== IPSec/Child SA Info ======"+ + "\n====== Inbound ======"+ + "\nXFRM interface if_id: %d"+ + "\nIPSec Inbound SPI: 0x%08x"+ + "\n[UE:%+v] -> [N3IWF:%+v]"+ + "\nIPSec Encryption Algorithm: %d"+ + "\nIPSec Encryption Key: 0x%x"+ + "\nIPSec Integrity Algorithm: %d"+ + "\nIPSec Integrity Key: 0x%x"+ + "\n====== IPSec/Child SA Info ======"+ + "\n====== Outbound ======"+ + "\nXFRM interface if_id: %d"+ + "\nIPSec Outbound SPI: 0x%08x"+ + "\n[N3IWF:%+v] -> [UE:%+v]"+ + "\nIPSec Encryption Algorithm: %d"+ + "\nIPSec Encryption Key: 0x%x"+ + "\nIPSec Integrity Algorithm: %d"+ + "\nIPSec Integrity Key: 0x%x", + xfrmiId, + childSA.InboundSPI, + childSA.PeerPublicIPAddr, + childSA.LocalPublicIPAddr, + childSA.EncrKInfo.TransformID(), + inboundEncryptionKey, + childSA.IntegKInfo.TransformID(), + inboundIntegrityKey, + xfrmiId, + childSA.OutboundSPI, + childSA.LocalPublicIPAddr, + childSA.PeerPublicIPAddr, + childSA.EncrKInfo.TransformID(), + outboundEncryptionKey, + childSA.IntegKInfo.TransformID(), + outboundIntegrityKey, + ) } type UDPSocketInfo struct { @@ -179,6 +217,11 @@ func (ikeUe *N3IWFIkeUe) Remove() error { n3iwfCtx.DeleteIKESecurityAssociation(ikeUe.N3IWFIKESecurityAssociation.LocalSPI) n3iwfCtx.DeleteInternalUEIPAddr(ikeUe.IPSecInnerIP.String()) + err := n3iwfCtx.IPSecInnerIPPool.Release(net.ParseIP(ikeUe.IPSecInnerIP.String()).To4()) + if err != nil { + return errors.Wrapf(err, "N3IWFIkeUe Remove()") + } + for _, childSA := range ikeUe.N3IWFChildSecurityAssociation { if err := ikeUe.DeleteChildSA(childSA); err != nil { return err @@ -189,28 +232,45 @@ func (ikeUe *N3IWFIkeUe) Remove() error { return nil } -func (ikeUe *N3IWFIkeUe) DeleteChildSA(childSA *ChildSecurityAssociation) error { +func (ikeUe *N3IWFIkeUe) DeleteChildSAXfrm(childSA *ChildSecurityAssociation) error { n3iwfCtx := ikeUe.N3iwfCtx iface := childSA.XfrmIface // Delete child SA xfrmState - for _, xfrmState := range childSA.XfrmStateList { + for idx := range childSA.XfrmStateList { + xfrmState := childSA.XfrmStateList[idx] if err := netlink.XfrmStateDel(&xfrmState); err != nil { - return fmt.Errorf("Delete xfrmstate error : %+v", err) + return errors.Wrapf(err, "Delete xfrmstate") } } // Delete child SA xfrmPolicy - for _, xfrmPolicy := range childSA.XfrmPolicyList { + for idx := range childSA.XfrmPolicyList { + xfrmPolicy := childSA.XfrmPolicyList[idx] if err := netlink.XfrmPolicyDel(&xfrmPolicy); err != nil { - return fmt.Errorf("Delete xfrmPolicy error : %+v", err) + return errors.Wrapf(err, "Delete xfrmPolicy") } } if iface == nil || iface.Attrs().Name == "xfrmi-default" { } else if err := netlink.LinkDel(iface); err != nil { - return fmt.Errorf("Delete interface %s fail: %+v", iface.Attrs().Name, err) + return errors.Wrapf(err, "Delete interface[%s]", iface.Attrs().Name) } else { - n3iwfCtx.XfrmIfaces.Delete(uint32(childSA.XfrmStateList[0].Ifid)) + ifId := childSA.XfrmStateList[0].Ifid + if ifId < 0 || ifId > math.MaxUint32 { + return errors.Errorf("DeleteChildSAXfrm Ifid has out of uint32 range value: %d", ifId) + } + n3iwfCtx.XfrmIfaces.Delete(uint32(ifId)) + } + + childSA.XfrmStateList = nil + childSA.XfrmPolicyList = nil + + return nil +} + +func (ikeUe *N3IWFIkeUe) DeleteChildSA(childSA *ChildSecurityAssociation) error { + if err := ikeUe.DeleteChildSAXfrm(childSA); err != nil { + return err } delete(ikeUe.N3IWFChildSecurityAssociation, childSA.InboundSPI) @@ -252,18 +312,10 @@ func (ikeUe *N3IWFIkeUe) CompleteChildSA(msgID uint32, outboundSPI uint32, childSA.OutboundSPI = outboundSPI - if len(chosenSecurityAssociation.Proposals[0].EncryptionAlgorithm) != 0 { - childSA.EncryptionAlgorithm = chosenSecurityAssociation.Proposals[0].EncryptionAlgorithm[0].TransformID - } - if len(chosenSecurityAssociation.Proposals[0].IntegrityAlgorithm) != 0 { - childSA.IntegrityAlgorithm = chosenSecurityAssociation.Proposals[0].IntegrityAlgorithm[0].TransformID - } - if len(chosenSecurityAssociation.Proposals[0].ExtendedSequenceNumbers) != 0 { - if chosenSecurityAssociation.Proposals[0].ExtendedSequenceNumbers[0].TransformID == 0 { - childSA.ESN = false - } else { - childSA.ESN = true - } + var err error + childSA.ChildSAKey, err = ike_security.NewChildSAKeyByProposal(chosenSecurityAssociation.Proposals[0]) + if err != nil { + return nil, errors.Wrapf(err, "CompleteChildSA") } // Record to UE context with inbound SPI as key diff --git a/internal/context/n3iwf_ue.go b/internal/context/n3iwf_ue.go new file mode 100644 index 00000000..5b4c7cc4 --- /dev/null +++ b/internal/context/n3iwf_ue.go @@ -0,0 +1,89 @@ +package context + +import ( + "net" + + "github.com/pkg/errors" + + "github.com/free5gc/ngap/ngapConvert" + "github.com/free5gc/ngap/ngapType" +) + +type N3IWFRanUe struct { + RanUeSharedCtx + + // Temporary cached NAS message + // Used when NAS registration accept arrived before + // UE setup NAS TCP connection with N3IWF, and + // Forward pduSessionEstablishmentAccept to UE after + // UE send CREATE_CHILD_SA response + TemporaryCachedNASMessage []byte + + // NAS TCP Connection Established + IsNASTCPConnEstablished bool + IsNASTCPConnEstablishedComplete bool + + // NAS TCP Connection + TCPConnection net.Conn +} + +func (n3iwfUe *N3IWFRanUe) init(ranUeNgapId int64) { + n3iwfUe.RanUeNgapId = ranUeNgapId + n3iwfUe.AmfUeNgapId = AmfUeNgapIdUnspecified + n3iwfUe.PduSessionList = make(map[int64]*PDUSession) + n3iwfUe.TemporaryPDUSessionSetupData = new(PDUSessionSetupTemporaryData) + n3iwfUe.IsNASTCPConnEstablished = false + n3iwfUe.IsNASTCPConnEstablishedComplete = false +} + +func (ranUe *N3IWFRanUe) Remove() error { + // remove from AMF context + ranUe.DetachAMF() + + // remove from RAN UE context + n3iwfCtx := ranUe.N3iwfCtx + n3iwfCtx.DeleteRanUe(ranUe.RanUeNgapId) + + for _, pduSession := range ranUe.PduSessionList { + n3iwfCtx.DeleteTEID(pduSession.GTPConnInfo.IncomingTEID) + } + + if ranUe.TCPConnection != nil { + if err := ranUe.TCPConnection.Close(); err != nil { + return errors.Errorf("Close TCP conn error : %v", err) + } + } + + return nil +} + +func (n3iwfUe *N3IWFRanUe) AttachAMF(sctpAddr string) bool { + if amf, ok := n3iwfUe.N3iwfCtx.AMFPoolLoad(sctpAddr); ok { + amf.N3iwfRanUeList[n3iwfUe.RanUeNgapId] = n3iwfUe + n3iwfUe.AMF = amf + return true + } else { + return false + } +} + +func (n3iwfUe *N3IWFRanUe) DetachAMF() { + if n3iwfUe.AMF == nil { + return + } + delete(n3iwfUe.AMF.N3iwfRanUeList, n3iwfUe.RanUeNgapId) +} + +// Implement RanUe interface +func (n3iwfUe *N3IWFRanUe) GetUserLocationInformation() *ngapType.UserLocationInformation { + userLocationInformation := new(ngapType.UserLocationInformation) + + userLocationInformation.Present = ngapType.UserLocationInformationPresentUserLocationInformationN3IWF + userLocationInformation.UserLocationInformationN3IWF = new(ngapType.UserLocationInformationN3IWF) + + userLocationInfoN3IWF := userLocationInformation.UserLocationInformationN3IWF + userLocationInfoN3IWF.IPAddress = ngapConvert.IPAddressToNgap(n3iwfUe.IPAddrv4, n3iwfUe.IPAddrv6) + userLocationInfoN3IWF.PortNumber = ngapConvert.PortNumberToNgap(n3iwfUe.PortNumber) + + return userLocationInformation +} diff --git a/pkg/context/ngap.go b/internal/context/ngap.go similarity index 69% rename from pkg/context/ngap.go rename to internal/context/ngap.go index 4627bd78..f697a907 100644 --- a/pkg/context/ngap.go +++ b/internal/context/ngap.go @@ -1,19 +1,27 @@ package context +import ( + "github.com/free5gc/ngap/ngapType" +) + type NgapEventType int64 // NGAP event type const ( UnmarshalEAP5GData NgapEventType = iota + NASTCPConnEstablishedComplete + GetNGAPContext SendInitialUEMessage SendPDUSessionResourceSetupResponse SendNASMsg StartTCPSignalNASMsg - NASTCPConnEstablishedComplete + SendUEContextRelease SendUEContextReleaseRequest SendUEContextReleaseComplete + SendPDUSessionResourceRelease SendPDUSessionResourceReleaseResponse - GetNGAPContext + SendUplinkNASTransport + SendInitialContextSetupResponse ) type EvtError string @@ -197,3 +205,74 @@ func NewGetNGAPContextEvt(ranUeNgapId int64, ngapCxtReqNumlist []int64) *GetNGAP NgapCxtReqNumlist: ngapCxtReqNumlist, } } + +type SendUplinkNASTransportEvt struct { + RanUeNgapId int64 + Pdu []byte +} + +func (e *SendUplinkNASTransportEvt) Type() NgapEventType { + return SendUplinkNASTransport +} + +func NewSendUplinkNASTransportEvt(ranUeNgapId int64, pdu []byte) *SendUplinkNASTransportEvt { + return &SendUplinkNASTransportEvt{ + RanUeNgapId: ranUeNgapId, + Pdu: pdu, + } +} + +type SendInitialContextSetupRespEvt struct { + RanUeNgapId int64 + ResponseList *ngapType.PDUSessionResourceSetupListCxtRes + FailedList *ngapType.PDUSessionResourceFailedToSetupListCxtRes + CriticalityDiagnostics *ngapType.CriticalityDiagnostics +} + +func (e *SendInitialContextSetupRespEvt) Type() NgapEventType { + return SendInitialContextSetupResponse +} + +func NewSendInitialContextSetupRespEvt( + ranUeNgapId int64, + responseList *ngapType.PDUSessionResourceSetupListCxtRes, + failedList *ngapType.PDUSessionResourceFailedToSetupListCxtRes, + criticalityDiagnostics *ngapType.CriticalityDiagnostics, +) *SendInitialContextSetupRespEvt { + return &SendInitialContextSetupRespEvt{ + RanUeNgapId: ranUeNgapId, + ResponseList: responseList, + FailedList: failedList, + CriticalityDiagnostics: criticalityDiagnostics, + } +} + +type SendUEContextReleaseEvt struct { + RanUeNgapId int64 +} + +func (e *SendUEContextReleaseEvt) Type() NgapEventType { + return SendUEContextRelease +} + +func NewSendUEContextReleaseEvt(ranUeNgapId int64) *SendUEContextReleaseEvt { + return &SendUEContextReleaseEvt{ + RanUeNgapId: ranUeNgapId, + } +} + +type SendPDUSessionResourceReleaseEvt struct { + RanUeNgapId int64 + DeletPduIds []int64 +} + +func (e *SendPDUSessionResourceReleaseEvt) Type() NgapEventType { + return SendPDUSessionResourceRelease +} + +func NewendPDUSessionResourceReleaseEvt(ranUeNgapId int64, deletPduIds []int64) *SendPDUSessionResourceReleaseEvt { + return &SendPDUSessionResourceReleaseEvt{ + RanUeNgapId: ranUeNgapId, + DeletPduIds: deletPduIds, + } +} diff --git a/internal/context/nwuup.go b/internal/context/nwuup.go new file mode 100644 index 00000000..52d367eb --- /dev/null +++ b/internal/context/nwuup.go @@ -0,0 +1,28 @@ +package context + +import gtpQoSMsg "github.com/free5gc/n3iwf/internal/gtp/message" + +type NwuupEventType int64 + +// NWuup Event Type +const ( + NwuupForwardDL NwuupEventType = iota +) + +type NwuupEvt interface { + Type() NwuupEventType +} + +type NwuupForwardDLEvt struct { + Packet gtpQoSMsg.QoSTPDUPacket +} + +func (nwuupForwardDLEvt *NwuupForwardDLEvt) Type() NwuupEventType { + return NwuupForwardDL +} + +func NewNwuupForwardDLEvt(packet gtpQoSMsg.QoSTPDUPacket) *NwuupForwardDLEvt { + return &NwuupForwardDLEvt{ + Packet: packet, + } +} diff --git a/pkg/context/ranue.go b/internal/context/ranue.go similarity index 62% rename from pkg/context/ranue.go rename to internal/context/ranue.go index 1ddfa48d..c2851c26 100644 --- a/pkg/context/ranue.go +++ b/internal/context/ranue.go @@ -4,13 +4,43 @@ import ( "fmt" "net" - "github.com/pkg/errors" - "github.com/free5gc/ngap/ngapType" ) -type N3IWFRanUe struct { - /* UE identity */ +type UeCtxRelState bool + +const ( + // NGAP has already received UE Context release command + UeCtxRelStateNone UeCtxRelState = false + UeCtxRelStateOngoing UeCtxRelState = true +) + +type PduSessResRelState bool + +const ( + // NGAP has not received Pdu Session resouces release request + PduSessResRelStateNone PduSessResRelState = false + PduSessResRelStateOngoing PduSessResRelState = true +) + +type RanUe interface { + // Get Attributes + GetUserLocationInformation() *ngapType.UserLocationInformation + GetSharedCtx() *RanUeSharedCtx + + // User Plane Traffic + // ForwardDL(gtpQoSMsg.QoSTPDUPacket) + // ForwardUL() + + // Others + CreatePDUSession(int64, ngapType.SNSSAI) (*PDUSession, error) + DeletePDUSession(int64) + FindPDUSession(int64) *PDUSession + Remove() error +} + +type RanUeSharedCtx struct { + // UE identity RanUeNgapId int64 AmfUeNgapId int64 IPAddrv4 string @@ -19,34 +49,20 @@ type N3IWFRanUe struct { MaskedIMEISV *ngapType.MaskedIMEISV // TS 38.413 9.3.1.54 Guti string - /* Relative Context */ + // Relative Context N3iwfCtx *N3IWFContext AMF *N3IWFAMF - /* Security */ + // Security SecurityCapabilities *ngapType.UESecurityCapabilities // TS 38.413 9.3.1.86 - /* PDU Session */ + // PDU Session PduSessionList map[int64]*PDUSession // pduSessionId as key - /* PDU Session Setup Temporary Data */ + // PDU Session Setup Temporary Data TemporaryPDUSessionSetupData *PDUSessionSetupTemporaryData - /* Temporary cached NAS message */ - // Used when NAS registration accept arrived before - // UE setup NAS TCP connection with N3IWF, and - // Forward pduSessionEstablishmentAccept to UE after - // UE send CREATE_CHILD_SA response - TemporaryCachedNASMessage []byte - - /* NAS TCP Connection Established */ - IsNASTCPConnEstablished bool - IsNASTCPConnEstablishedComplete bool - - /* NAS TCP Connection */ - TCPConnection net.Conn - - /* Others */ + // Others Guami *ngapType.GUAMI IndexToRfsp int64 Ambr *ngapType.UEAggregateMaximumBitRate @@ -56,6 +72,8 @@ type N3IWFRanUe struct { IMSVoiceSupported int32 RRCEstablishmentCause int16 PduSessionReleaseList ngapType.PDUSessionResourceReleasedListRelRes + UeCtxRelState UeCtxRelState + PduSessResRelState PduSessResRelState } type PDUSession struct { @@ -68,7 +86,7 @@ type PDUSession struct { SecurityIntegrity bool MaximumIntegrityDataRateUplink *ngapType.MaximumIntegrityProtectedDataRate MaximumIntegrityDataRateDownlink *ngapType.MaximumIntegrityProtectedDataRate - GTPConnection *GTPConnectionInfo + GTPConnInfo *GTPConnectionInfo QFIList []uint8 QosFlows map[int64]*QosFlow // QosFlowIdentifier as key } @@ -102,36 +120,11 @@ type PDUSessionSetupTemporaryData struct { Index int } -func (ranUe *N3IWFRanUe) init(ranUeNgapId int64) { - ranUe.RanUeNgapId = ranUeNgapId - ranUe.AmfUeNgapId = AmfUeNgapIdUnspecified - ranUe.PduSessionList = make(map[int64]*PDUSession) - ranUe.IsNASTCPConnEstablished = false - ranUe.IsNASTCPConnEstablishedComplete = false +func (ranUe *RanUeSharedCtx) GetSharedCtx() *RanUeSharedCtx { + return ranUe } -func (ranUe *N3IWFRanUe) Remove() error { - // remove from AMF context - ranUe.DetachAMF() - - // remove from RAN UE context - n3iwfCtx := ranUe.N3iwfCtx - n3iwfCtx.DeleteRanUe(ranUe.RanUeNgapId) - - for _, pduSession := range ranUe.PduSessionList { - n3iwfCtx.DeleteTEID(pduSession.GTPConnection.IncomingTEID) - } - - if ranUe.TCPConnection != nil { - if err := ranUe.TCPConnection.Close(); err != nil { - return errors.Errorf("Close TCP conn error : %v", err) - } - } - - return nil -} - -func (ranUe *N3IWFRanUe) FindPDUSession(pduSessionID int64) *PDUSession { +func (ranUe *RanUeSharedCtx) FindPDUSession(pduSessionID int64) *PDUSession { if pduSession, ok := ranUe.PduSessionList[pduSessionID]; ok { return pduSession } else { @@ -139,7 +132,7 @@ func (ranUe *N3IWFRanUe) FindPDUSession(pduSessionID int64) *PDUSession { } } -func (ranUe *N3IWFRanUe) CreatePDUSession(pduSessionID int64, snssai ngapType.SNSSAI) (*PDUSession, error) { +func (ranUe *RanUeSharedCtx) CreatePDUSession(pduSessionID int64, snssai ngapType.SNSSAI) (*PDUSession, error) { if _, exists := ranUe.PduSessionList[pduSessionID]; exists { return nil, fmt.Errorf("PDU Session[ID:%d] is already exists", pduSessionID) } @@ -152,19 +145,6 @@ func (ranUe *N3IWFRanUe) CreatePDUSession(pduSessionID int64, snssai ngapType.SN return pduSession, nil } -func (ranUe *N3IWFRanUe) AttachAMF(sctpAddr string) bool { - if amf, ok := ranUe.N3iwfCtx.AMFPoolLoad(sctpAddr); ok { - amf.N3iwfRanUeList[ranUe.RanUeNgapId] = ranUe - ranUe.AMF = amf - return true - } else { - return false - } -} - -func (ranUe *N3IWFRanUe) DetachAMF() { - if ranUe.AMF == nil { - return - } - delete(ranUe.AMF.N3iwfRanUeList, ranUe.RanUeNgapId) +func (ranUe *RanUeSharedCtx) DeletePDUSession(pduSessionId int64) { + delete(ranUe.PduSessionList, pduSessionId) } diff --git a/pkg/context/testing_app.go b/internal/context/testing_app.go similarity index 100% rename from pkg/context/testing_app.go rename to internal/context/testing_app.go diff --git a/pkg/context/timer.go b/internal/context/timer.go similarity index 100% rename from pkg/context/timer.go rename to internal/context/timer.go diff --git a/internal/gre/message.go b/internal/gre/message.go index 86b0566e..e396b992 100644 --- a/internal/gre/message.go +++ b/internal/gre/message.go @@ -1,6 +1,11 @@ package gre -import "encoding/binary" +import ( + "encoding/binary" + "math" + + "github.com/pkg/errors" +) // [TS 24.502] 9.3.3 GRE encapsulated user data packet const ( @@ -77,8 +82,14 @@ func (p *GREPacket) setRQI(rqi bool) { } } -func (p *GREPacket) GetQFI() uint8 { - return uint8((p.key >> 24) & 0x3F) +func (p *GREPacket) GetQFI() (uint8, error) { + value := (p.key >> 24) & 0x3F + + if value > math.MaxUint8 { + return 0, errors.Errorf("GetQFI() value exceeds uint8: %d", value) + } else { + return uint8(value), nil + } } func (p *GREPacket) GetRQI() bool { diff --git a/internal/gtp/message/message.go b/internal/gtp/message/message.go index 2dba4847..eb84a55d 100644 --- a/internal/gtp/message/message.go +++ b/internal/gtp/message/message.go @@ -79,3 +79,20 @@ func (p *QoSTPDUPacket) unmarshalExtensionHeader() error { return nil } + +func BuildQoSGTPPacket(teid uint32, qfi uint8, payload []byte) ([]byte, error) { + header := gtpMsg.NewHeader(0x34, gtpMsg.MsgTypeTPDU, teid, 0x00, payload).WithExtensionHeaders( + gtpMsg.NewExtensionHeader( + gtpMsg.ExtHeaderTypePDUSessionContainer, + []byte{UL_PDU_SESSION_INFORMATION_TYPE, qfi}, + gtpMsg.ExtHeaderTypeNoMoreExtensionHeaders, + ), + ) + + b := make([]byte, header.MarshalLen()) + if err := header.MarshalTo(b); err != nil { + return nil, errors.Wrapf(err, "go-gtp Marshal failed") + } + + return b, nil +} diff --git a/pkg/ike/dispatcher.go b/internal/ike/dispatcher.go similarity index 55% rename from pkg/ike/dispatcher.go rename to internal/ike/dispatcher.go index 50f29493..4c30d52c 100644 --- a/pkg/ike/dispatcher.go +++ b/internal/ike/dispatcher.go @@ -4,14 +4,16 @@ import ( "net" "runtime/debug" + ike_message "github.com/free5gc/ike/message" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" - ike_message "github.com/free5gc/n3iwf/pkg/ike/message" ) func (s *Server) Dispatch( udpConn *net.UDPConn, localAddr, remoteAddr *net.UDPAddr, - msg []byte, + ikeMessage *ike_message.IKEMessage, msg []byte, + ikeSA *n3iwf_context.IKESecurityAssociation, ) { ikeLog := logger.IKELog defer func() { @@ -21,37 +23,15 @@ func (s *Server) Dispatch( } }() - // As specified in RFC 7296 section 3.1, the IKE message send from/to UDP port 4500 - // should prepend a 4 bytes zero - if localAddr.Port == 4500 { - for i := 0; i < 4; i++ { - if msg[i] != 0 { - ikeLog.Warn( - "Received an IKE packet that does not prepend 4 bytes zero from UDP port 4500," + - " this packet may be the UDP encapsulated ESP. The packet will not be handled.") - return - } - } - msg = msg[4:] - } - - ikeMessage := new(ike_message.IKEMessage) - - err := ikeMessage.Decode(msg) - if err != nil { - ikeLog.Error(err) - return - } - switch ikeMessage.ExchangeType { case ike_message.IKE_SA_INIT: s.HandleIKESAINIT(udpConn, localAddr, remoteAddr, ikeMessage, msg) case ike_message.IKE_AUTH: - s.HandleIKEAUTH(udpConn, localAddr, remoteAddr, ikeMessage) + s.HandleIKEAUTH(udpConn, localAddr, remoteAddr, ikeMessage, ikeSA) case ike_message.CREATE_CHILD_SA: - s.HandleCREATECHILDSA(udpConn, localAddr, remoteAddr, ikeMessage) + s.HandleCREATECHILDSA(udpConn, localAddr, remoteAddr, ikeMessage, ikeSA) case ike_message.INFORMATIONAL: - s.HandleInformational(udpConn, localAddr, remoteAddr, ikeMessage) + s.HandleInformational(udpConn, localAddr, remoteAddr, ikeMessage, ikeSA) default: ikeLog.Warnf("Unimplemented IKE message type, exchange type: %d", ikeMessage.ExchangeType) } diff --git a/pkg/ike/handler.go b/internal/ike/handler.go similarity index 62% rename from pkg/ike/handler.go rename to internal/ike/handler.go index 2b84fca3..2d4f421a 100644 --- a/pkg/ike/handler.go +++ b/internal/ike/handler.go @@ -5,11 +5,11 @@ import ( "crypto" "crypto/rand" "crypto/rsa" - "crypto/sha1" + "crypto/sha1" // #nosec G505 "encoding/binary" "encoding/hex" "fmt" - math_rand "math/rand" + "math" "net" "runtime/debug" "sync/atomic" @@ -19,11 +19,15 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + ike_message "github.com/free5gc/ike/message" + ike_security "github.com/free5gc/ike/security" + "github.com/free5gc/ike/security/dh" + "github.com/free5gc/ike/security/encr" + "github.com/free5gc/ike/security/integ" + "github.com/free5gc/ike/security/prf" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/ike/xfrm" "github.com/free5gc/n3iwf/internal/logger" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" - ike_message "github.com/free5gc/n3iwf/pkg/ike/message" - "github.com/free5gc/n3iwf/pkg/ike/security" - "github.com/free5gc/n3iwf/pkg/ike/xfrm" ) func (s *Server) HandleIKESAINIT( @@ -45,34 +49,12 @@ func (s *Server) HandleIKESAINIT( cfg := s.Config() // For response or needed data - responseIKEMessage := new(ike_message.IKEMessage) - var sharedKeyData, localNonce, concatenatedNonce []byte + var responseIKEPayload ike_message.IKEPayloadContainer + var localNonce, concatenatedNonce []byte // Chosen transform from peer's proposal - var encryptionAlgorithmTransform, pseudorandomFunctionTransform *ike_message.Transform - var integrityAlgorithmTransform, diffieHellmanGroupTransform *ike_message.Transform - // For NAT-T - var ueIsBehindNAT, n3iwfIsBehindNAT bool - - if message == nil { - ikeLog.Error("IKE Message is nil") - return - } - - // parse IKE header and setup IKE context - // check major version - majorVersion := ((message.Version & 0xf0) >> 4) - if majorVersion > 2 { - ikeLog.Warn("Received an IKE message with higher major version") - // send INFORMATIONAL type message with INVALID_MAJOR_VERSION Notify payload - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.INFORMATIONAL, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, - ike_message.INVALID_MAJOR_VERSION, nil, nil) - - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) - return - } + var chooseProposal ike_message.ProposalContainer + var localPublicValue []byte + var chosenDiffieHellmanGroup uint16 for _, ikePayload := range message.Payloads { switch ikePayload.Type() { @@ -92,95 +74,22 @@ func (s *Server) HandleIKESAINIT( } if securityAssociation != nil { - responseSecurityAssociation := responseIKEMessage.Payloads.BuildSecurityAssociation() - - for _, proposal := range securityAssociation.Proposals { - // We need ENCR, PRF, INTEG, DH, but not ESN - encryptionAlgorithmTransform = nil - pseudorandomFunctionTransform = nil - integrityAlgorithmTransform = nil - diffieHellmanGroupTransform = nil - - if len(proposal.EncryptionAlgorithm) > 0 { - for _, transform := range proposal.EncryptionAlgorithm { - if isTransformSupported(ike_message.TypeEncryptionAlgorithm, transform.TransformID, - transform.AttributePresent, transform.AttributeValue) { - encryptionAlgorithmTransform = transform - break - } - } - if encryptionAlgorithmTransform == nil { - continue - } - } else { - continue // mandatory - } - if len(proposal.PseudorandomFunction) > 0 { - for _, transform := range proposal.PseudorandomFunction { - if isTransformSupported(ike_message.TypePseudorandomFunction, transform.TransformID, - transform.AttributePresent, transform.AttributeValue) { - pseudorandomFunctionTransform = transform - break - } - } - if pseudorandomFunctionTransform == nil { - continue - } - } else { - continue // mandatory - } - if len(proposal.IntegrityAlgorithm) > 0 { - for _, transform := range proposal.IntegrityAlgorithm { - if isTransformSupported(ike_message.TypeIntegrityAlgorithm, transform.TransformID, - transform.AttributePresent, transform.AttributeValue) { - integrityAlgorithmTransform = transform - break - } - } - if integrityAlgorithmTransform == nil { - continue - } - } else { - continue // mandatory - } - if len(proposal.DiffieHellmanGroup) > 0 { - for _, transform := range proposal.DiffieHellmanGroup { - if isTransformSupported(ike_message.TypeDiffieHellmanGroup, transform.TransformID, - transform.AttributePresent, transform.AttributeValue) { - diffieHellmanGroupTransform = transform - break - } - } - if diffieHellmanGroupTransform == nil { - continue - } - } else { - continue // mandatory - } - if len(proposal.ExtendedSequenceNumbers) > 0 { - continue // No ESN - } - - // Construct chosen proposal, with ENCR, PRF, INTEG, DH, and each - // contains one transform expectively - chosenProposal := responseSecurityAssociation.Proposals.BuildProposal( - proposal.ProposalNumber, proposal.ProtocolID, nil) - chosenProposal.EncryptionAlgorithm = append(chosenProposal.EncryptionAlgorithm, encryptionAlgorithmTransform) - chosenProposal.PseudorandomFunction = append(chosenProposal.PseudorandomFunction, pseudorandomFunctionTransform) - chosenProposal.IntegrityAlgorithm = append(chosenProposal.IntegrityAlgorithm, integrityAlgorithmTransform) - chosenProposal.DiffieHellmanGroup = append(chosenProposal.DiffieHellmanGroup, diffieHellmanGroupTransform) - break - } + responseSecurityAssociation := responseIKEPayload.BuildSecurityAssociation() + chooseProposal = SelectProposal(securityAssociation.Proposals) + responseSecurityAssociation.Proposals = append(responseSecurityAssociation.Proposals, chooseProposal...) if len(responseSecurityAssociation.Proposals) == 0 { ikeLog.Warn("No proposal chosen") // Respond NO_PROPOSAL_CHOSEN to UE - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_SA_INIT, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, ike_message.NO_PROPOSAL_CHOSEN, nil, nil) + responseIKEPayload.Reset() + responseIKEPayload.BuildNotification(ike_message.TypeNone, ike_message.NO_PROPOSAL_CHOSEN, nil, nil) + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_SA_INIT, true, false, message.MessageID, responseIKEPayload) - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err := SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, nil) + if err != nil { + ikeLog.Errorf("HandleIKESAINIT(): %v", err) + } return } } else { @@ -190,29 +99,26 @@ func (s *Server) HandleIKESAINIT( } if keyExcahge != nil { - chosenDiffieHellmanGroup := diffieHellmanGroupTransform.TransformID + chosenDiffieHellmanGroup = chooseProposal[0].DiffieHellmanGroup[0].TransformID if chosenDiffieHellmanGroup != keyExcahge.DiffieHellmanGroup { ikeLog.Warn("The Diffie-Hellman group defined in key exchange payload not matches the one in chosen proposal") // send INVALID_KE_PAYLOAD to UE - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_SA_INIT, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() notificationData := make([]byte, 2) binary.BigEndian.PutUint16(notificationData, chosenDiffieHellmanGroup) - responseIKEMessage.Payloads.BuildNotification( + responseIKEPayload.BuildNotification( ike_message.TypeNone, ike_message.INVALID_KE_PAYLOAD, nil, notificationData) - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_SA_INIT, true, false, message.MessageID, responseIKEPayload) + + err := SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, nil) + if err != nil { + ikeLog.Errorf("HandleIKESAINIT(): %v", err) + } return } - - var localPublicValue []byte - localPublicValue, sharedKeyData = security.CalculateDiffieHellmanMaterials( - security.GenerateRandomNumber(), - keyExcahge.KeyExchangeData, - chosenDiffieHellmanGroup) - responseIKEMessage.Payloads.BUildKeyExchange(chosenDiffieHellmanGroup, localPublicValue) } else { ikeLog.Error("The key exchange field is nil") // TODO: send error message to UE @@ -220,113 +126,61 @@ func (s *Server) HandleIKESAINIT( } if nonce != nil { - localNonce = security.GenerateRandomNumber().Bytes() + localNonceBigInt, err := ike_security.GenerateRandomNumber() + if err != nil { + ikeLog.Errorf("HandleIKESAINIT: %v", err) + return + } + localNonce = localNonceBigInt.Bytes() concatenatedNonce = append(nonce.NonceData, localNonce...) - responseIKEMessage.Payloads.BuildNonce(localNonce) + responseIKEPayload.BuildNonce(localNonce) } else { ikeLog.Error("The nonce field is nil") // TODO: send error message to UE return } - if len(notifications) != 0 { - for _, notification := range notifications { - switch notification.NotifyMessageType { - case ike_message.NAT_DETECTION_SOURCE_IP: - ikeLog.Trace("Received IKE Notify: NAT_DETECTION_SOURCE_IP") - // Calculate local NAT_DETECTION_SOURCE_IP hash - // : sha1(ispi | rspi | ueip | ueport) - localDetectionData := make([]byte, 22) - binary.BigEndian.PutUint64(localDetectionData[0:8], message.InitiatorSPI) - binary.BigEndian.PutUint64(localDetectionData[8:16], message.ResponderSPI) - copy(localDetectionData[16:20], ueAddr.IP.To4()) - binary.BigEndian.PutUint16(localDetectionData[20:22], uint16(ueAddr.Port)) - - sha1HashFunction := sha1.New() // #nosec G401 - _, err := sha1HashFunction.Write(localDetectionData) - if err != nil { - ikeLog.Errorf("Hash function write error: %v", err) - return - } - - if !bytes.Equal(notification.NotificationData, sha1HashFunction.Sum(nil)) { - ueIsBehindNAT = true - } - case ike_message.NAT_DETECTION_DESTINATION_IP: - ikeLog.Trace("Received IKE Notify: NAT_DETECTION_DESTINATION_IP") - // Calculate local NAT_DETECTION_SOURCE_IP hash - // : sha1(ispi | rspi | n3iwfip | n3iwfport) - localDetectionData := make([]byte, 22) - binary.BigEndian.PutUint64(localDetectionData[0:8], message.InitiatorSPI) - binary.BigEndian.PutUint64(localDetectionData[8:16], message.ResponderSPI) - copy(localDetectionData[16:20], n3iwfAddr.IP.To4()) - binary.BigEndian.PutUint16(localDetectionData[20:22], uint16(n3iwfAddr.Port)) - - sha1HashFunction := sha1.New() // #nosec G401 - _, err := sha1HashFunction.Write(localDetectionData) - if err != nil { - ikeLog.Errorf("Hash function write error: %v", err) - return - } - - if !bytes.Equal(notification.NotificationData, sha1HashFunction.Sum(nil)) { - n3iwfIsBehindNAT = true - } - default: - } - } + ueBehindNAT, n3iwfBehindNAT, err := s.handleNATDetect( + message.InitiatorSPI, message.ResponderSPI, + notifications, ueAddr, n3iwfAddr) + if err != nil { + ikeLog.Errorf("Handle IKE_SA_INIT: %v", err) + return } // Create new IKE security association ikeSecurityAssociation := n3iwfCtx.NewIKESecurityAssociation() ikeSecurityAssociation.RemoteSPI = message.InitiatorSPI ikeSecurityAssociation.InitiatorMessageID = message.MessageID - ikeSecurityAssociation.UEIsBehindNAT = ueIsBehindNAT - ikeSecurityAssociation.N3IWFIsBehindNAT = n3iwfIsBehindNAT - // Record algorithm in context - ikeSecurityAssociation.EncryptionAlgorithm = encryptionAlgorithmTransform - ikeSecurityAssociation.IntegrityAlgorithm = integrityAlgorithmTransform - ikeSecurityAssociation.PseudorandomFunction = pseudorandomFunctionTransform - ikeSecurityAssociation.DiffieHellmanGroup = diffieHellmanGroupTransform + ikeSecurityAssociation.IKESAKey, localPublicValue, err = ike_security.NewIKESAKey(chooseProposal[0], + keyExcahge.KeyExchangeData, concatenatedNonce, + ikeSecurityAssociation.RemoteSPI, ikeSecurityAssociation.LocalSPI) + if err != nil { + ikeLog.Errorf("Handle IKE_SA_INIT: %v", err) + return + } + + ikeLog.Debugln(ikeSecurityAssociation.String()) // Record concatenated nonce - ikeSecurityAssociation.ConcatenatedNonce = append(ikeSecurityAssociation.ConcatenatedNonce, concatenatedNonce...) - // Record Diffie-Hellman shared key - ikeSecurityAssociation.DiffieHellmanSharedKey = append(ikeSecurityAssociation.DiffieHellmanSharedKey, sharedKeyData...) + ikeSecurityAssociation.ConcatenatedNonce = append( + ikeSecurityAssociation.ConcatenatedNonce, concatenatedNonce...) + ikeSecurityAssociation.UeBehindNAT = ueBehindNAT + ikeSecurityAssociation.N3iwfBehindNAT = n3iwfBehindNAT - err := security.GenerateKeyForIKESA(ikeSecurityAssociation) + responseIKEPayload.BUildKeyExchange(chosenDiffieHellmanGroup, localPublicValue) + err = s.buildNATDetectNotifPayload( + ikeSecurityAssociation, &responseIKEPayload, ueAddr, n3iwfAddr) if err != nil { - ikeLog.Errorf("Generate key for IKE SA failed: %v", err) + ikeLog.Warnf("Handle IKE_SA_INIT: %v", err) return } // IKE response to UE - responseIKEMessage.BuildIKEHeader(ikeSecurityAssociation.RemoteSPI, ikeSecurityAssociation.LocalSPI, - ike_message.IKE_SA_INIT, ike_message.ResponseBitCheck, message.MessageID) - - // Calculate NAT_DETECTION_SOURCE_IP for NAT-T - natDetectionSourceIP := make([]byte, 22) - binary.BigEndian.PutUint64(natDetectionSourceIP[0:8], ikeSecurityAssociation.RemoteSPI) - binary.BigEndian.PutUint64(natDetectionSourceIP[8:16], ikeSecurityAssociation.LocalSPI) - copy(natDetectionSourceIP[16:20], n3iwfAddr.IP.To4()) - binary.BigEndian.PutUint16(natDetectionSourceIP[20:22], uint16(n3iwfAddr.Port)) - - // Build and append notify payload for NAT_DETECTION_SOURCE_IP - responseIKEMessage.Payloads.BuildNotification( - ike_message.TypeNone, ike_message.NAT_DETECTION_SOURCE_IP, nil, natDetectionSourceIP) - - // Calculate NAT_DETECTION_DESTINATION_IP for NAT-T - natDetectionDestinationIP := make([]byte, 22) - binary.BigEndian.PutUint64(natDetectionDestinationIP[0:8], ikeSecurityAssociation.RemoteSPI) - binary.BigEndian.PutUint64(natDetectionDestinationIP[8:16], ikeSecurityAssociation.LocalSPI) - copy(natDetectionDestinationIP[16:20], ueAddr.IP.To4()) - binary.BigEndian.PutUint16(natDetectionDestinationIP[20:22], uint16(ueAddr.Port)) - - // Build and append notify payload for NAT_DETECTION_DESTINATION_IP - responseIKEMessage.Payloads.BuildNotification( - ike_message.TypeNone, ike_message.NAT_DETECTION_DESTINATION_IP, nil, natDetectionDestinationIP) + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, ikeSecurityAssociation.LocalSPI, + ike_message.IKE_SA_INIT, true, false, message.MessageID, responseIKEPayload) // Prepare authentication data - InitatorSignedOctet // InitatorSignedOctet = RealMessage1 | NonceRData | MACedIDForI @@ -337,8 +191,7 @@ func (s *Server) HandleIKESAINIT( // ResponderSignedOctet = RealMessage2 | NonceIData | MACedIDForR responseIKEMessageData, err := responseIKEMessage.Encode() if err != nil { - ikeLog.Errorln(err) - ikeLog.Error("Encoding IKE message failed") + ikeLog.Errorf("Encoding IKE message failed: %v", err) return } ikeSecurityAssociation.ResponderSignedOctets = append(responseIKEMessageData, nonce.NonceData...) @@ -347,29 +200,27 @@ func (s *Server) HandleIKESAINIT( idPayload.BuildIdentificationResponder(ike_message.ID_FQDN, []byte(cfg.GetFQDN())) idPayloadData, err := idPayload.Encode() if err != nil { - ikeLog.Errorln(err) - ikeLog.Error("Encode IKE payload failed.") - return - } - pseudorandomFunction, ok := security.NewPseudorandomFunction( - ikeSecurityAssociation.SK_pr, - ikeSecurityAssociation.PseudorandomFunction.TransformID) - if !ok { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") + ikeLog.Errorf("Encode IKE payload failed: %v", err) return } - _, err = pseudorandomFunction.Write(idPayloadData[4:]) + + ikeSecurityAssociation.Prf_r.Reset() + _, err = ikeSecurityAssociation.Prf_r.Write(idPayloadData[4:]) if err != nil { ikeLog.Errorf("Pseudorandom function write error: %v", err) return } + ikeSecurityAssociation.ResponderSignedOctets = append(ikeSecurityAssociation.ResponderSignedOctets, - pseudorandomFunction.Sum(nil)...) + ikeSecurityAssociation.Prf_r.Sum(nil)...) ikeLog.Tracef("Local unsigned authentication data:\n%s", hex.Dump(ikeSecurityAssociation.ResponderSignedOctets)) // Send response to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err = SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, nil) + if err != nil { + ikeLog.Errorf("HandleIKESAINIT(): %v", err) + } } const ( @@ -387,73 +238,18 @@ func (s *Server) HandleIKEAUTH( udpConn *net.UDPConn, n3iwfAddr, ueAddr *net.UDPAddr, message *ike_message.IKEMessage, + ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation, ) { ikeLog := logger.IKELog ikeLog.Infoln("Handle IKE_AUTH") - var encryptedPayload *ike_message.Encrypted - n3iwfCtx := s.Context() cfg := s.Config() ipsecGwAddr := cfg.GetIPSecGatewayAddr() // Used for response - responseIKEMessage := new(ike_message.IKEMessage) var responseIKEPayload ike_message.IKEPayloadContainer - if message == nil { - ikeLog.Error("IKE Message is nil") - return - } - - // parse IKE header and setup IKE context - // check major version - majorVersion := ((message.Version & 0xf0) >> 4) - if majorVersion > 2 { - ikeLog.Warn("Received an IKE message with higher major version") - // send INFORMATIONAL type message with INVALID_MAJOR_VERSION Notify payload ( OUTSIDE IKE SA ) - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.INFORMATIONAL, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, ike_message.INVALID_MAJOR_VERSION, nil, nil) - - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) - return - } - - // Find corresponding IKE security association - localSPI := message.ResponderSPI - ikeSecurityAssociation, ok := n3iwfCtx.IKESALoad(localSPI) - if !ok { - ikeLog.Warn("Unrecognized SPI") - // send INFORMATIONAL type message with INVALID_IKE_SPI Notify payload ( OUTSIDE IKE SA ) - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, 0, ike_message.INFORMATIONAL, - ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, ike_message.INVALID_IKE_SPI, nil, nil) - - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) - return - } - - for _, ikePayload := range message.Payloads { - switch ikePayload.Type() { - case ike_message.TypeSK: - encryptedPayload = ikePayload.(*ike_message.Encrypted) - default: - ikeLog.Warnf( - "Get IKE payload (type %d) in IKE_AUTH message, this payload will not be handled by IKE handler", - ikePayload.Type()) - } - } - - decryptedIKEPayload, err := security.DecryptProcedure( - ikeSecurityAssociation, message, encryptedPayload) - if err != nil { - ikeLog.Errorf("Decrypt IKE message failed: %v", err) - return - } - // Parse payloads var initiatorID *ike_message.IdentificationInitiator var certificateRequest *ike_message.CertificateRequest @@ -464,8 +260,9 @@ func (s *Server) HandleIKEAUTH( var eap *ike_message.EAP var authentication *ike_message.Authentication var configuration *ike_message.Configuration + var ok bool - for _, ikePayload := range decryptedIKEPayload { + for _, ikePayload := range message.Payloads { switch ikePayload.Type() { case ike_message.TypeIDi: initiatorID = ikePayload.(*ike_message.IdentificationInitiator) @@ -492,8 +289,6 @@ func (s *Server) HandleIKEAUTH( } } - // NOTE: tune it - transformPseudorandomFunction := ikeSecurityAssociation.PseudorandomFunction ikeSecurityAssociation.InitiatorMessageID = message.MessageID switch ikeSecurityAssociation.State { @@ -512,20 +307,14 @@ func (s *Server) HandleIKEAUTH( ikeLog.Error("Encoding ID payload message failed.") return } - pseudorandomFunction, ok1 := security.NewPseudorandomFunction( - ikeSecurityAssociation.SK_pi, - transformPseudorandomFunction.TransformID) - if !ok1 { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") - return - } - if _, err := pseudorandomFunction.Write(idPayloadData[4:]); err != nil { + ikeSecurityAssociation.Prf_i.Reset() + if _, err := ikeSecurityAssociation.Prf_i.Write(idPayloadData[4:]); err != nil { ikeLog.Errorf("Pseudorandom function write error: %v", err) return } ikeSecurityAssociation.InitiatorSignedOctets = append( ikeSecurityAssociation.InitiatorSignedOctets, - pseudorandomFunction.Sum(nil)...) + ikeSecurityAssociation.Prf_i.Sum(nil)...) } else { ikeLog.Error("The initiator identification field is nil") // TODO: send error message to UE @@ -542,7 +331,7 @@ func (s *Server) HandleIKEAUTH( // authorities. This can be a chain of certificates. if certificateRequest != nil { ikeLog.Info("UE request N3IWF certificate") - if security.CompareRootCertificate( + if ike_security.CompareRootCertificate( n3iwfCtx.CertificateAuthority, certificateRequest.CertificateEncoding, certificateRequest.CertificationAuthority) { @@ -643,29 +432,20 @@ func (s *Server) HandleIKEAUTH( if len(responseSecurityAssociation.Proposals) == 0 { ikeLog.Warn("No proposal chosen") // Respond NO_PROPOSAL_CHOSEN to UE - // Build IKE message - responseIKEMessage.BuildIKEHeader( - message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, - message.MessageID) - responseIKEMessage.Payloads.Reset() - - // Build response - responseIKEPayload.Reset() - // Notification responseIKEPayload.BuildNotification( ike_message.TypeNone, ike_message.NO_PROPOSAL_CHOSEN, nil, nil) - err := security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_AUTH, true, false, message.MessageID, responseIKEPayload) // Send IKE message to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err := SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + } return } @@ -692,11 +472,7 @@ func (s *Server) HandleIKEAUTH( ikeLog.Info("Received traffic selector initiator from UE") ikeSecurityAssociation.TrafficSelectorResponder = trafficSelectorResponder - // Build response IKE message - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - + responseIKEPayload.Reset() // Identification responseIKEPayload.BuildIdentificationResponder(ike_message.ID_FQDN, []byte(cfg.GetFQDN())) @@ -714,7 +490,9 @@ func (s *Server) HandleIKEAUTH( return } - signedAuth, err := rsa.SignPKCS1v15( + var signedAuth []byte + + signedAuth, err = rsa.SignPKCS1v15( rand.Reader, n3iwfCtx.N3IWFPrivateKey, crypto.SHA1, sha1HashFunction.Sum(nil)) if err != nil { @@ -726,7 +504,7 @@ func (s *Server) HandleIKEAUTH( // EAP expanded 5G-Start var identifier uint8 for { - identifier, err = security.GenerateRandomUint8() + identifier, err = ike_security.GenerateRandomUint8() if err != nil { ikeLog.Errorf("Random number failed: %v", err) return @@ -738,18 +516,20 @@ func (s *Server) HandleIKEAUTH( } responseIKEPayload.BuildEAP5GStart(identifier) - err = security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_AUTH, true, false, message.MessageID, responseIKEPayload) // Shift state ikeSecurityAssociation.State++ // Send IKE message to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err = SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + return + } case EAPSignalling: // If success, N3IWF will send an UPLinkNASTransport to AMF @@ -792,28 +572,26 @@ func (s *Server) HandleIKEAUTH( if eap5GMessageID == ike_message.EAP5GType5GStop { // Send EAP failure - // Build IKE message - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() // EAP - identifier, err := security.GenerateRandomUint8() + identifier, err := ike_security.GenerateRandomUint8() if err != nil { ikeLog.Errorf("Generate random uint8 failed: %v", err) return } responseIKEPayload.BuildEAPfailure(identifier) - err = security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_AUTH, true, false, message.MessageID, responseIKEPayload) // Send IKE message to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err = SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + } return } @@ -823,12 +601,12 @@ func (s *Server) HandleIKEAUTH( ranNgapId = 0 } - s.NgapEvtCh() <- n3iwf_context.NewUnmarshalEAP5GDataEvt( + s.SendNgapEvt(n3iwf_context.NewUnmarshalEAP5GDataEvt( ikeSecurityAssociation.LocalSPI, eapExpanded.VendorData, ikeSecurityAssociation.IkeUE != nil, ranNgapId, - ) + )) ikeSecurityAssociation.IKEConnection = &n3iwf_context.UDPSocketInfo{ Conn: udpConn, @@ -846,24 +624,14 @@ func (s *Server) HandleIKEAUTH( ikeUE := ikeSecurityAssociation.IkeUE // Prepare pseudorandom function for calculating/verifying authentication data - pseudorandomFunction, ok := security.NewPseudorandomFunction( - ikeUE.Kn3iwf, transformPseudorandomFunction.TransformID) - if !ok { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") - return - } + pseudorandomFunction := ikeSecurityAssociation.PrfInfo.Init(ikeUE.Kn3iwf) _, err := pseudorandomFunction.Write([]byte("Key Pad for IKEv2")) if err != nil { ikeLog.Errorf("Pseudorandom function write error: %v", err) return } secret := pseudorandomFunction.Sum(nil) - pseudorandomFunction, ok = security.NewPseudorandomFunction( - secret, transformPseudorandomFunction.TransformID) - if !ok { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") - return - } + pseudorandomFunction = ikeSecurityAssociation.PrfInfo.Init(secret) if authentication != nil { // Verifying remote AUTH @@ -875,28 +643,29 @@ func (s *Server) HandleIKEAUTH( } expectedAuthenticationData := pseudorandomFunction.Sum(nil) + ikeLog.Tracef("Kn3iwf:\n%s", hex.Dump(ikeUE.Kn3iwf)) + ikeLog.Tracef("secret:\n%s", hex.Dump(secret)) + ikeLog.Tracef("InitiatorSignedOctets:\n%s", hex.Dump(ikeSecurityAssociation.InitiatorSignedOctets)) ikeLog.Tracef("Expected Authentication Data:\n%s", hex.Dump(expectedAuthenticationData)) if !bytes.Equal(authentication.AuthenticationData, expectedAuthenticationData) { ikeLog.Warn("Peer authentication failed.") // Inform UE the authentication has failed - // Build IKE message - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() // Notification responseIKEPayload.BuildNotification( ike_message.TypeNone, ike_message.AUTHENTICATION_FAILED, nil, nil) - err = security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_AUTH, true, false, message.MessageID, responseIKEPayload) // Send IKE message to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err = SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + } return } else { ikeLog.Tracef("Peer authentication success") @@ -904,23 +673,21 @@ func (s *Server) HandleIKEAUTH( } else { ikeLog.Warn("Peer authentication failed.") // Inform UE the authentication has failed - // Build IKE message - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() // Notification responseIKEPayload.BuildNotification(ike_message.TypeNone, ike_message.AUTHENTICATION_FAILED, nil, nil) - err = security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_AUTH, true, false, message.MessageID, responseIKEPayload) // Send IKE message to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err = SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + } return } @@ -948,10 +715,7 @@ func (s *Server) HandleIKEAUTH( ikeLog.Warn("Configuration is nil. UE did not sent any configuration request.") } - // Build response IKE message - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() // Calculate local AUTH pseudorandomFunction.Reset() @@ -969,7 +733,13 @@ func (s *Server) HandleIKEAUTH( var ueIPAddr, n3iwfIPAddr net.IP if addrRequest { // IP addresses (IPSec) - ueIPAddr = n3iwfCtx.NewInternalUEIPAddr(ikeUE).To4() + var ueIp net.IP + ueIp, err = n3iwfCtx.NewIPsecInnerUEIP(ikeUE) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + return + } + ueIPAddr = ueIp.To4() n3iwfIPAddr = net.ParseIP(ipsecGwAddr).To4() responseConfiguration := responseIKEPayload.BuildConfiguration( @@ -977,7 +747,7 @@ func (s *Server) HandleIKEAUTH( responseConfiguration.ConfigurationAttribute.BuildConfigurationAttribute( ike_message.INTERNAL_IP4_ADDRESS, ueIPAddr) responseConfiguration.ConfigurationAttribute.BuildConfigurationAttribute( - ike_message.INTERNAL_IP4_NETMASK, n3iwfCtx.UeIPRange.Mask) + ike_message.INTERNAL_IP4_NETMASK, n3iwfCtx.IPSecInnerIPPool.IPSubnet.Mask) var ipsecInnerIPAddr *net.IPAddr ikeUE.IPSecInnerIP = ueIPAddr @@ -1014,17 +784,23 @@ func (s *Server) HandleIKEAUTH( var inboundSPI uint32 inboundSPIByte := make([]byte, 4) for { - randomUint64 := security.GenerateRandomNumber().Uint64() + buf := make([]byte, 4) + _, err = rand.Read(buf) + if err != nil { + ikeLog.Errorf("Handle IKE_AUTH Generate ChildSA inboundSPI: %v", err) + return + } + randomUint32 := binary.BigEndian.Uint32(buf) // check if the inbound SPI havn't been allocated by N3IWF - if _, ok1 := n3iwfCtx.ChildSA.Load(uint32(randomUint64)); !ok1 { - inboundSPI = uint32(randomUint64) + if _, ok1 := n3iwfCtx.ChildSA.Load(randomUint32); !ok1 { + inboundSPI = randomUint32 break } } binary.BigEndian.PutUint32(inboundSPIByte, inboundSPI) outboundSPI := binary.BigEndian.Uint32(ikeSecurityAssociation.IKEAuthResponseSA.Proposals[0].SPI) - ikeLog.Infof("Inbound SPI: %+v, Outbound SPI: %+v", inboundSPI, outboundSPI) + ikeLog.Infof("Inbound SPI: 0x%08x, Outbound SPI: 0x%08x", inboundSPI, outboundSPI) // SPI field of IKEAuthResponseSA is used to save outbound SPI temporarily. // After N3IWF produced its inbound SPI, the field will be overwritten with the SPI. @@ -1035,7 +811,7 @@ func (s *Server) HandleIKEAUTH( childSecurityAssociationContext, err := ikeUE.CompleteChildSA( 0x01, outboundSPI, ikeSecurityAssociation.IKEAuthResponseSA) if err != nil { - ikeLog.Errorf("Create child security association context failed: %v", err) + ikeLog.Errorf("HandleIKEAUTH(): Create child security association context failed: %v", err) return } err = s.parseIPAddressInformationToChildSecurityAssociation( @@ -1049,14 +825,14 @@ func (s *Server) HandleIKEAUTH( // Select TCP traffic childSecurityAssociationContext.SelectedIPProtocol = unix.IPPROTO_TCP - errGen := security.GenerateKeyForChildSA( - ikeSecurityAssociation, childSecurityAssociationContext) + errGen := childSecurityAssociationContext.ChildSAKey.GenerateKeyForChildSA(ikeSecurityAssociation.IKESAKey, + ikeSecurityAssociation.ConcatenatedNonce) if errGen != nil { ikeLog.Errorf("Generate key for child SA failed: %v", errGen) return } // NAT-T concern - if ikeSecurityAssociation.UEIsBehindNAT || ikeSecurityAssociation.N3IWFIsBehindNAT { + if ikeSecurityAssociation.UeBehindNAT || ikeSecurityAssociation.N3iwfBehindNAT { childSecurityAssociationContext.EnableEncapsulate = true childSecurityAssociationContext.N3IWFPort = n3iwfAddr.Port childSecurityAssociationContext.NATPort = ueAddr.Port @@ -1068,13 +844,11 @@ func (s *Server) HandleIKEAUTH( // Notification(NSA_TCP_PORT) responseIKEPayload.BuildNotifyNAS_TCP_PORT(cfg.GetNasTcpPort()) - errEncrypt := security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if errEncrypt != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", errEncrypt) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(message.InitiatorSPI, message.ResponderSPI, + ike_message.IKE_AUTH, true, false, message.MessageID, responseIKEPayload) + childSecurityAssociationContext.LocalIsInitiator = false // Aplly XFRM rules // IPsec for CP always use default XFRM interface err = xfrm.ApplyXFRMRule(false, cfg.GetXfrmIfaceId(), childSecurityAssociationContext) @@ -1082,9 +856,15 @@ func (s *Server) HandleIKEAUTH( ikeLog.Errorf("Applying XFRM rules failed: %v", err) return } + ikeLog.Debugln(childSecurityAssociationContext.String(cfg.GetXfrmIfaceId())) // Send IKE message to UE - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) + err = SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleIKEAUTH(): %v", err) + return + } ranNgapId, ok := n3iwfCtx.NgapIdLoad(ikeUE.N3IWFIKESecurityAssociation.LocalSPI) if !ok { @@ -1096,14 +876,12 @@ func (s *Server) HandleIKEAUTH( ikeSecurityAssociation.State++ // After this, N3IWF will forward NAS with Child SA (IPSec SA) - s.NgapEvtCh() <- n3iwf_context.NewStartTCPSignalNASMsgEvt( - ranNgapId, - ) + s.SendNgapEvt(n3iwf_context.NewStartTCPSignalNASMsgEvt(ranNgapId)) // Get TempPDUSessionSetupData from NGAP to setup PDU session if needed - s.NgapEvtCh() <- n3iwf_context.NewGetNGAPContextEvt( + s.SendNgapEvt(n3iwf_context.NewGetNGAPContextEvt( ranNgapId, []int64{n3iwf_context.CxtTempPDUSessionSetupData}, - ) + )) } } @@ -1111,68 +889,16 @@ func (s *Server) HandleCREATECHILDSA( udpConn *net.UDPConn, n3iwfAddr, ueAddr *net.UDPAddr, message *ike_message.IKEMessage, + ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation, ) { ikeLog := logger.IKELog ikeLog.Infoln("Handle CREATE_CHILD_SA") - var encryptedPayload *ike_message.Encrypted - n3iwfCtx := s.Context() - responseIKEMessage := new(ike_message.IKEMessage) - - if message == nil { - ikeLog.Error("IKE Message is nil") - return - } - - // parse IKE header and setup IKE context - // check major version - majorVersion := ((message.Version & 0xf0) >> 4) - if majorVersion > 2 { - ikeLog.Warn("Received an IKE message with higher major version") - // send INFORMATIONAL type message with INVALID_MAJOR_VERSION Notify payload ( OUTSIDE IKE SA ) - responseIKEMessage.BuildIKEHeader(message.InitiatorSPI, message.ResponderSPI, - ike_message.INFORMATIONAL, ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, ike_message.INVALID_MAJOR_VERSION, nil, nil) - - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) - return - } - - // Find corresponding IKE security association - responderSPI := message.ResponderSPI - - ikeLog.Warnf("CREATE_CHILD_SA responderSPI: %+v", responderSPI) - ikeSecurityAssociation, ok := n3iwfCtx.IKESALoad(responderSPI) - if !ok { - ikeLog.Warn("Unrecognized SPI") - // send INFORMATIONAL type message with INVALID_IKE_SPI Notify payload ( OUTSIDE IKE SA ) - responseIKEMessage.BuildIKEHeader(0, message.ResponderSPI, ike_message.INFORMATIONAL, - ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, ike_message.INVALID_IKE_SPI, nil, nil) - - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) - return - } - - for _, ikePayload := range message.Payloads { - switch ikePayload.Type() { - case ike_message.TypeSK: - encryptedPayload = ikePayload.(*ike_message.Encrypted) - default: - ikeLog.Warnf( - "Get IKE payload (type %d) in CREATE_CHILD_SA message, this payload will not be handled by IKE handler", - ikePayload.Type()) - } - } - - decryptedIKEPayload, err := security.DecryptProcedure( - ikeSecurityAssociation, message, encryptedPayload) - if err != nil { - ikeLog.Errorf("Decrypt IKE message failed: %v", err) + if !ikeSecurityAssociation.IKEConnection.UEAddr.IP.Equal(ueAddr.IP) || + !ikeSecurityAssociation.IKEConnection.N3IWFAddr.IP.Equal(n3iwfAddr.IP) { + ikeLog.Warnf("Get unexpteced IP in SPI: %016x", ikeSecurityAssociation.LocalSPI) return } @@ -1182,7 +908,7 @@ func (s *Server) HandleCREATECHILDSA( var trafficSelectorInitiator *ike_message.TrafficSelectorInitiator var trafficSelectorResponder *ike_message.TrafficSelectorResponder - for _, ikePayload := range decryptedIKEPayload { + for _, ikePayload := range message.Payloads { switch ikePayload.Type() { case ike_message.TypeSA: securityAssociation = ikePayload.(*ike_message.SecurityAssociation) @@ -1239,8 +965,7 @@ func (s *Server) HandleCREATECHILDSA( ngapCxtReqNumlist := []int64{n3iwf_context.CxtTempPDUSessionSetupData} - s.NgapEvtCh() <- n3iwf_context.NewGetNGAPContextEvt(ranNgapId, - ngapCxtReqNumlist) + s.SendNgapEvt(n3iwf_context.NewGetNGAPContextEvt(ranNgapId, ngapCxtReqNumlist)) } func (s *Server) continueCreateChildSA( @@ -1314,14 +1039,14 @@ func (s *Server) continueCreateChildSA( // Select GRE traffic childSecurityAssociationContext.SelectedIPProtocol = unix.IPPROTO_GRE - err = security.GenerateKeyForChildSA( - ikeSecurityAssociation, childSecurityAssociationContext) + err = childSecurityAssociationContext.ChildSAKey.GenerateKeyForChildSA(ikeSecurityAssociation.IKESAKey, + ikeSecurityAssociation.ConcatenatedNonce) if err != nil { ikeLog.Errorf("Generate key for child SA failed: %v", err) return } // NAT-T concern - if ikeSecurityAssociation.UEIsBehindNAT || ikeSecurityAssociation.N3IWFIsBehindNAT { + if ikeSecurityAssociation.UeBehindNAT || ikeSecurityAssociation.N3iwfBehindNAT { childSecurityAssociationContext.EnableEncapsulate = true childSecurityAssociationContext.N3IWFPort = ikeConnection.N3IWFAddr.Port childSecurityAssociationContext.NATPort = ikeConnection.UEAddr.Port @@ -1337,7 +1062,7 @@ func (s *Server) continueCreateChildSA( // Setup XFRM interface for ipsec var linkIPSec netlink.Link n3iwfIPAddr := net.ParseIP(ipsecGwAddr).To4() - n3iwfIPAddrAndSubnet := net.IPNet{IP: n3iwfIPAddr, Mask: n3iwfCtx.UeIPRange.Mask} + n3iwfIPAddrAndSubnet := net.IPNet{IP: n3iwfIPAddr, Mask: n3iwfCtx.IPSecInnerIPPool.IPSubnet.Mask} newXfrmiId += cfg.GetXfrmIfaceId() + n3iwfCtx.XfrmIfaceIdOffsetForUP newXfrmiName := fmt.Sprintf("%s-%d", cfg.GetXfrmIfaceName(), newXfrmiId) @@ -1363,11 +1088,13 @@ func (s *Server) continueCreateChildSA( } // Aplly XFRM rules + childSecurityAssociationContext.LocalIsInitiator = true err = xfrm.ApplyXFRMRule(true, newXfrmiId, childSecurityAssociationContext) if err != nil { ikeLog.Errorf("Applying XFRM rules failed: %v", err) return } + ikeLog.Debugln(childSecurityAssociationContext.String(newXfrmiId)) ranNgapId, ok := n3iwfCtx.NgapIdLoad(ikeSecurityAssociation.LocalSPI) if !ok { @@ -1376,9 +1103,7 @@ func (s *Server) continueCreateChildSA( return } // Forward PDU Seesion Establishment Accept to UE - s.NgapEvtCh() <- n3iwf_context.NewSendNASMsgEvt( - ranNgapId, - ) + s.SendNgapEvt(n3iwf_context.NewSendNASMsgEvt(ranNgapId)) temporaryPDUSessionSetupData.FailedErrStr = append(temporaryPDUSessionSetupData.FailedErrStr, n3iwf_context.ErrNil) @@ -1392,50 +1117,14 @@ func (s *Server) HandleInformational( udpConn *net.UDPConn, n3iwfAddr, ueAddr *net.UDPAddr, message *ike_message.IKEMessage, + ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation, ) { ikeLog := logger.IKELog ikeLog.Infoln("Handle Informational") - if message == nil { - ikeLog.Error("IKE Message is nil") - return - } - - n3iwfCtx := s.Context() - responseIKEMessage := new(ike_message.IKEMessage) - responderSPI := message.ResponderSPI - ikeSecurityAssociation, ok := n3iwfCtx.IKESALoad(responderSPI) - var encryptedPayload *ike_message.Encrypted - - if !ok { - ikeLog.Warn("Unrecognized SPI") - // send INFORMATIONAL type message with INVALID_IKE_SPI Notify payload ( OUTSIDE IKE SA ) - responseIKEMessage.BuildIKEHeader(0, message.ResponderSPI, ike_message.INFORMATIONAL, - ike_message.ResponseBitCheck, message.MessageID) - responseIKEMessage.Payloads.Reset() - responseIKEMessage.Payloads.BuildNotification(ike_message.TypeNone, ike_message.INVALID_IKE_SPI, nil, nil) - - SendIKEMessageToUE(udpConn, n3iwfAddr, ueAddr, responseIKEMessage) - return - } - - for _, ikePayload := range message.Payloads { - switch ikePayload.Type() { - case ike_message.TypeSK: - encryptedPayload = ikePayload.(*ike_message.Encrypted) - default: - ikeLog.Warnf( - "Get IKE payload (type %d) in Inoformational message, this payload will not be handled by IKE handler", - ikePayload.Type()) - } - } - - decryptedIKEPayload, err := security.DecryptProcedure( - ikeSecurityAssociation, message, encryptedPayload) - if err != nil { - ikeLog.Errorf("Decrypt IKE message failed: %v", err) - return - } + var deletePayload *ike_message.Delete + var err error + responseIKEPayload := new(ike_message.IKEPayloadContainer) n3iwfIke := ikeSecurityAssociation.IkeUE @@ -1445,42 +1134,32 @@ func (s *Server) HandleInformational( atomic.StoreInt32(&n3iwfIke.N3IWFIKESecurityAssociation.CurrentRetryTimes, 0) } - if len(decryptedIKEPayload) == 0 { // Receive DPD message - return - } - - for _, ikePayload := range decryptedIKEPayload { + for _, ikePayload := range message.Payloads { switch ikePayload.Type() { case ike_message.TypeD: - deletePayload := ikePayload.(*ike_message.Delete) - - ranNgapId, ok := n3iwfCtx.NgapIdLoad(n3iwfIke.N3IWFIKESecurityAssociation.LocalSPI) - if !ok { - ikeLog.Errorf("Cannot get RanNgapId from SPI : %+v", - n3iwfIke.N3IWFIKESecurityAssociation.LocalSPI) - return - } - - if deletePayload.ProtocolID == ike_message.TypeIKE { // Check if UE is response to a request that delete the ike SA - err := n3iwfIke.Remove() - if err != nil { - ikeLog.Errorf("Delete IkeUe Context error : %v", err) - } - s.NgapEvtCh() <- n3iwf_context.NewSendUEContextReleaseCompleteEvt( - ranNgapId, - ) - } else if deletePayload.ProtocolID == ike_message.TypeESP { - s.NgapEvtCh() <- n3iwf_context.NewSendPDUSessionResourceReleaseResEvt( - ranNgapId, - ) - } + deletePayload = ikePayload.(*ike_message.Delete) default: ikeLog.Warnf( "Get IKE payload (type %d) in Inoformational message, this payload will not be handled by IKE handler", ikePayload.Type()) } } - ikeSecurityAssociation.ResponderMessageID++ + + if deletePayload != nil { + responseIKEPayload, err = s.handleDeletePayload(deletePayload, message.IsResponse(), ikeSecurityAssociation) + if err != nil { + ikeLog.Errorf("HandleInformational(): %v", err) + return + } + } + + if message.IsResponse() { + ikeSecurityAssociation.ResponderMessageID++ + } else { // Get Request message + SendUEInformationExchange(ikeSecurityAssociation, ikeSecurityAssociation.IKESAKey, + responseIKEPayload, false, true, message.MessageID, + udpConn, ueAddr, n3iwfAddr) + } } func (s *Server) HandleEvent(ikeEvt n3iwf_context.IkeEvt) { @@ -1534,12 +1213,12 @@ func (s *Server) HandleUnmarshalEAP5GDataResponse(ikeEvt n3iwf_context.IkeEvt) { n3iwfCtx.IkeSpiNgapIdMapping(ikeUe.N3IWFIKESecurityAssociation.LocalSPI, ranUeNgapId) - s.NgapEvtCh() <- n3iwf_context.NewSendInitialUEMessageEvt( + s.SendNgapEvt(n3iwf_context.NewSendInitialUEMessageEvt( ranUeNgapId, ikeSecurityAssociation.IKEConnection.UEAddr.IP.To4().String(), ikeSecurityAssociation.IKEConnection.UEAddr.Port, nasPDU, - ) + )) } func (s *Server) HandleSendEAP5GFailureMsg(ikeEvt n3iwf_context.IkeEvt) { @@ -1554,32 +1233,28 @@ func (s *Server) HandleSendEAP5GFailureMsg(ikeEvt n3iwf_context.IkeEvt) { ikeSecurityAssociation, _ := n3iwfCtx.IKESALoad(localSPI) ikeLog.Warnf("EAP Failure : %s", errMsg.Error()) - responseIKEMessage := new(ike_message.IKEMessage) var responseIKEPayload ike_message.IKEPayloadContainer // Send EAP failure - // Build IKE message - responseIKEMessage.BuildIKEHeader(ikeSecurityAssociation.RemoteSPI, ikeSecurityAssociation.LocalSPI, - ike_message.IKE_AUTH, ike_message.ResponseBitCheck, ikeSecurityAssociation.InitiatorMessageID) - responseIKEMessage.Payloads.Reset() // EAP - identifier, err := security.GenerateRandomUint8() + identifier, err := ike_security.GenerateRandomUint8() if err != nil { ikeLog.Errorf("Generate random uint8 failed: %v", err) return } responseIKEPayload.BuildEAPfailure(identifier) - if err := security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage); err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(ikeSecurityAssociation.RemoteSPI, ikeSecurityAssociation.LocalSPI, + ike_message.IKE_AUTH, true, false, ikeSecurityAssociation.InitiatorMessageID, responseIKEPayload) // Send IKE message to UE - SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, + err = SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, ikeSecurityAssociation.IKEConnection.N3IWFAddr, ikeSecurityAssociation.IKEConnection.UEAddr, - responseIKEMessage) + responseIKEMessage, ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleSendEAP5GFailureMsg(): %v", err) + } } func (s *Server) HandleSendEAPSuccessMsg(ikeEvt n3iwf_context.IkeEvt) { @@ -1600,18 +1275,18 @@ func (s *Server) HandleSendEAPSuccessMsg(ikeEvt n3iwf_context.IkeEvt) { ikeSecurityAssociation.IkeUE.PduSessionListLen = pduSessionListLen - responseIKEMessage := new(ike_message.IKEMessage) var responseIKEPayload ike_message.IKEPayloadContainer - // Build IKE message - responseIKEMessage.BuildIKEHeader(ikeSecurityAssociation.RemoteSPI, - ikeSecurityAssociation.LocalSPI, ike_message.IKE_AUTH, ike_message.ResponseBitCheck, - ikeSecurityAssociation.InitiatorMessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() var identifier uint8 + var err error for { - identifier = uint8(math_rand.Uint32()) + identifier, err = ike_security.GenerateRandomUint8() + if err != nil { + ikeLog.Errorf("HandleSendEAPSuccessMsg() rand : %v", err) + return + } if identifier != ikeSecurityAssociation.LastEAPIdentifier { ikeSecurityAssociation.LastEAPIdentifier = identifier break @@ -1620,17 +1295,20 @@ func (s *Server) HandleSendEAPSuccessMsg(ikeEvt n3iwf_context.IkeEvt) { responseIKEPayload.BuildEAPSuccess(identifier) - err := security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) - return - } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(ikeSecurityAssociation.RemoteSPI, + ikeSecurityAssociation.LocalSPI, ike_message.IKE_AUTH, true, false, + ikeSecurityAssociation.InitiatorMessageID, responseIKEPayload) // Send IKE message to UE - SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, + err = SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, ikeSecurityAssociation.IKEConnection.N3IWFAddr, - ikeSecurityAssociation.IKEConnection.UEAddr, responseIKEMessage) + ikeSecurityAssociation.IKEConnection.UEAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleSendEAPSuccessMsg(): %v", err) + return + } ikeSecurityAssociation.State++ } @@ -1646,37 +1324,42 @@ func (s *Server) HandleSendEAPNASMsg(ikeEvt n3iwf_context.IkeEvt) { n3iwfCtx := s.Context() ikeSecurityAssociation, _ := n3iwfCtx.IKESALoad(localSPI) - responseIKEMessage := new(ike_message.IKEMessage) var responseIKEPayload ike_message.IKEPayloadContainer - - // Build IKE message - responseIKEMessage.BuildIKEHeader(ikeSecurityAssociation.RemoteSPI, - ikeSecurityAssociation.LocalSPI, ike_message.IKE_AUTH, ike_message.ResponseBitCheck, - ikeSecurityAssociation.InitiatorMessageID) - responseIKEMessage.Payloads.Reset() + responseIKEPayload.Reset() var identifier uint8 + var err error for { - identifier = uint8(math_rand.Uint32()) + identifier, err = ike_security.GenerateRandomUint8() + if err != nil { + ikeLog.Errorf("HandleSendEAPNASMsg() rand : %v", err) + return + } if identifier != ikeSecurityAssociation.LastEAPIdentifier { ikeSecurityAssociation.LastEAPIdentifier = identifier break } } - responseIKEPayload.BuildEAP5GNAS(identifier, nasPDU) - - err := security.EncryptProcedure( - ikeSecurityAssociation, responseIKEPayload, responseIKEMessage) + err = responseIKEPayload.BuildEAP5GNAS(identifier, nasPDU) if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) + ikeLog.Errorf("HandleSendEAPNASMsg() BuildEAP5GNAS: %v", err) return } + // Build IKE message + responseIKEMessage := ike_message.NewMessage(ikeSecurityAssociation.RemoteSPI, + ikeSecurityAssociation.LocalSPI, ike_message.IKE_AUTH, true, false, + ikeSecurityAssociation.InitiatorMessageID, responseIKEPayload) + // Send IKE message to UE - SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, + err = SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, ikeSecurityAssociation.IKEConnection.N3IWFAddr, - ikeSecurityAssociation.IKEConnection.UEAddr, responseIKEMessage) + ikeSecurityAssociation.IKEConnection.UEAddr, responseIKEMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("HandleSendEAPNASMsg(): %v", err) + } } func (s *Server) HandleCreatePDUSession(ikeEvt n3iwf_context.IkeEvt) { @@ -1820,26 +1503,28 @@ func (s *Server) CreatePDUSessionChildSA( pduSessionID := pduSession.Id // Send CREATE_CHILD_SA to UE - ikeMessage := new(ike_message.IKEMessage) - var ikePayload ike_message.IKEPayloadContainer + var responseIKEPayload ike_message.IKEPayloadContainer errStr := n3iwf_context.ErrNil - // Build IKE message - ikeMessage.BuildIKEHeader(ikeSecurityAssociation.RemoteSPI, - ikeSecurityAssociation.LocalSPI, ike_message.CREATE_CHILD_SA, - 0, ikeSecurityAssociation.ResponderMessageID) - ikeMessage.Payloads.Reset() + responseIKEPayload.Reset() // Build SA - requestSA := ikePayload.BuildSecurityAssociation() + requestSA := responseIKEPayload.BuildSecurityAssociation() // Allocate SPI var spi uint32 spiByte := make([]byte, 4) for { - randomUint64 := security.GenerateRandomNumber().Uint64() - if _, ok := n3iwfCtx.ChildSA.Load(uint32(randomUint64)); !ok { - spi = uint32(randomUint64) + var err error + buf := make([]byte, 4) + _, err = rand.Read(buf) + if err != nil { + ikeLog.Errorf("CreatePDUSessionChildSA Generate SPI: %v", err) + return + } + randomUint32 := binary.BigEndian.Uint32(buf) + if _, ok := n3iwfCtx.ChildSA.Load(randomUint32); !ok { + spi = randomUint32 break } } @@ -1849,14 +1534,18 @@ func (s *Server) CreatePDUSessionChildSA( proposal := requestSA.Proposals.BuildProposal(1, ike_message.TypeESP, spiByte) // Encryption transform - var attributeType uint16 = ike_message.AttributeTypeKeyLength - var attributeValue uint16 = 256 - proposal.EncryptionAlgorithm.BuildTransform(ike_message.TypeEncryptionAlgorithm, - ike_message.ENCR_AES_CBC, &attributeType, &attributeValue, nil) + encrTranform, err := encr.ToTransform(ikeSecurityAssociation.EncrInfo) + if err != nil { + ikeLog.Errorf("encr ToTransform error: %v", err) + break + } + + proposal.EncryptionAlgorithm = append(proposal.EncryptionAlgorithm, + encrTranform) // Integrity transform if pduSession.SecurityIntegrity { - proposal.IntegrityAlgorithm.BuildTransform(ike_message.TypeIntegrityAlgorithm, - ikeUe.N3IWFIKESecurityAssociation.IntegrityAlgorithm.TransformID, nil, nil, nil) + proposal.IntegrityAlgorithm = append(proposal.IntegrityAlgorithm, + integ.ToTransform(ikeSecurityAssociation.IntegInfo)) } // RFC 7296 @@ -1866,58 +1555,73 @@ func (s *Server) CreatePDUSessionChildSA( // ESN transform proposal.ExtendedSequenceNumbers.BuildTransform( - ike_message.TypeExtendedSequenceNumbers, ike_message.ESN_NO, nil, nil, nil) + ike_message.TypeExtendedSequenceNumbers, ike_message.ESN_DISABLE, nil, nil, nil) - ikeUe.CreateHalfChildSA(ikeMessage.MessageID, spi, pduSessionID) + ikeUe.CreateHalfChildSA(ikeSecurityAssociation.ResponderMessageID, spi, pduSessionID) // Build Nonce - nonceData := security.GenerateRandomNumber().Bytes() - ikePayload.BuildNonce(nonceData) + nonceDataBigInt, errGen := ike_security.GenerateRandomNumber() + if errGen != nil { + ikeLog.Errorf("CreatePDUSessionChildSA Build Nonce: %v", errGen) + return + } + nonceData := nonceDataBigInt.Bytes() + responseIKEPayload.BuildNonce(nonceData) // Store nonce into context ikeSecurityAssociation.ConcatenatedNonce = nonceData // TSi n3iwfIPAddr := net.ParseIP(ipsecGwAddr) - tsi := ikePayload.BuildTrafficSelectorInitiator() + tsi := responseIKEPayload.BuildTrafficSelectorInitiator() tsi.TrafficSelectors.BuildIndividualTrafficSelector( ike_message.TS_IPV4_ADDR_RANGE, ike_message.IPProtocolAll, 0, 65535, n3iwfIPAddr.To4(), n3iwfIPAddr.To4()) // TSr ueIPAddr := ikeUe.IPSecInnerIP - tsr := ikePayload.BuildTrafficSelectorResponder() + tsr := responseIKEPayload.BuildTrafficSelectorResponder() tsr.TrafficSelectors.BuildIndividualTrafficSelector( ike_message.TS_IPV4_ADDR_RANGE, ike_message.IPProtocolAll, 0, 65535, ueIPAddr.To4(), ueIPAddr.To4()) + if pduSessionID < 0 || pduSessionID > math.MaxUint8 { + ikeLog.Errorf("CreatePDUSessionChildSA pduSessionID exceeds uint8 range: %d", pduSessionID) + break + } // Notify-Qos - ikePayload.BuildNotify5G_QOS_INFO(uint8(pduSessionID), pduSession.QFIList, true, false, 0) + err = responseIKEPayload.BuildNotify5G_QOS_INFO(uint8(pduSessionID), pduSession.QFIList, true, false, 0) + if err != nil { + ikeLog.Errorf("CreatePDUSessionChildSA error : %v", err) + break + } // Notify-UP_IP_ADDRESS - ikePayload.BuildNotifyUP_IP4_ADDRESS(ipsecGwAddr) + responseIKEPayload.BuildNotifyUP_IP4_ADDRESS(ipsecGwAddr) temporaryPDUSessionSetupData.Index++ - if err := security.EncryptProcedure( - ikeUe.N3IWFIKESecurityAssociation, ikePayload, ikeMessage); err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %v", err) + // Build IKE message + ikeMessage := ike_message.NewMessage(ikeSecurityAssociation.RemoteSPI, ikeSecurityAssociation.LocalSPI, + ike_message.CREATE_CHILD_SA, false, false, ikeSecurityAssociation.ResponderMessageID, + responseIKEPayload) + + err = SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, + ikeSecurityAssociation.IKEConnection.N3IWFAddr, + ikeSecurityAssociation.IKEConnection.UEAddr, ikeMessage, + ikeSecurityAssociation.IKESAKey) + if err != nil { + ikeLog.Errorf("CreatePDUSessionChildSA error : %v", err) errStr = n3iwf_context.ErrTransportResourceUnavailable temporaryPDUSessionSetupData.FailedErrStr = append(temporaryPDUSessionSetupData.FailedErrStr, errStr) - continue + } else { + temporaryPDUSessionSetupData.FailedErrStr = append(temporaryPDUSessionSetupData.FailedErrStr, + errStr) + break } - - temporaryPDUSessionSetupData.FailedErrStr = append(temporaryPDUSessionSetupData.FailedErrStr, - errStr) - - SendIKEMessageToUE(ikeSecurityAssociation.IKEConnection.Conn, ikeSecurityAssociation.IKEConnection.N3IWFAddr, - ikeSecurityAssociation.IKEConnection.UEAddr, ikeMessage) - break } else { - s.NgapEvtCh() <- n3iwf_context.NewSendPDUSessionResourceSetupResEvt( - ranNgapId, - ) + s.SendNgapEvt(n3iwf_context.NewSendPDUSessionResourceSetupResEvt(ranNgapId)) break } } @@ -1936,36 +1640,41 @@ func (s *Server) StartDPD(ikeUe *n3iwf_context.N3IWFIkeUe) { n3iwfCtx := s.Context() cfg := s.Config() + ikeSA := ikeUe.N3IWFIKESecurityAssociation liveness := cfg.GetLivenessCheck() if liveness.Enable { - ikeUe.N3IWFIKESecurityAssociation.IsUseDPD = true + ikeSA.IsUseDPD = true timer := time.NewTicker(liveness.TransFreq) for { select { - case <-ikeUe.N3IWFIKESecurityAssociation.IKESAClosedCh: - close(ikeUe.N3IWFIKESecurityAssociation.IKESAClosedCh) + case <-ikeSA.IKESAClosedCh: + close(ikeSA.IKESAClosedCh) timer.Stop() return case <-timer.C: - SendUEInformationExchange(ikeUe, nil) - var DPDReqRetransTime time.Duration = 2 * time.Second - ikeUe.N3IWFIKESecurityAssociation.DPDReqRetransTimer = n3iwf_context.NewDPDPeriodicTimer( - DPDReqRetransTime, liveness.MaxRetryTimes, ikeUe.N3IWFIKESecurityAssociation, + var payload *ike_message.IKEPayloadContainer + SendUEInformationExchange(ikeSA, ikeSA.IKESAKey, payload, false, false, + ikeSA.ResponderMessageID, ikeUe.IKEConnection.Conn, ikeUe.IKEConnection.UEAddr, + ikeUe.IKEConnection.N3IWFAddr) + + var DPDReqRetransTime time.Duration = 2 * time.Second // TODO: make it configurable + ikeSA.DPDReqRetransTimer = n3iwf_context.NewDPDPeriodicTimer( + DPDReqRetransTime, liveness.MaxRetryTimes, ikeSA, func() { ikeLog.Errorf("UE is down") - ranNgapId, ok := n3iwfCtx.NgapIdLoad(ikeUe.N3IWFIKESecurityAssociation.LocalSPI) + ranNgapId, ok := n3iwfCtx.NgapIdLoad(ikeSA.LocalSPI) if !ok { ikeLog.Infof("Cannot find ranNgapId form SPI : %+v", - ikeUe.N3IWFIKESecurityAssociation.LocalSPI) + ikeSA.LocalSPI) return } - s.NgapEvtCh() <- n3iwf_context.NewSendUEContextReleaseRequestEvt( + s.SendNgapEvt(n3iwf_context.NewSendUEContextReleaseRequestEvt( ranNgapId, n3iwf_context.ErrRadioConnWithUeLost, - ) + )) - ikeUe.N3IWFIKESecurityAssociation.DPDReqRetransTimer = nil + ikeSA.DPDReqRetransTimer = nil timer.Stop() }) } @@ -1973,113 +1682,133 @@ func (s *Server) StartDPD(ikeUe *n3iwf_context.N3IWFIkeUe) { } } -func isTransformSupported( - transformType uint8, - transformID uint16, - attributePresent bool, - attributeValue uint16, -) bool { - switch transformType { - case ike_message.TypeEncryptionAlgorithm: - switch transformID { - case ike_message.ENCR_DES_IV64: - return false - case ike_message.ENCR_DES: - return false - case ike_message.ENCR_3DES: - return false - case ike_message.ENCR_RC5: - return false - case ike_message.ENCR_IDEA: - return false - case ike_message.ENCR_CAST: - return false - case ike_message.ENCR_BLOWFISH: - return false - case ike_message.ENCR_3IDEA: - return false - case ike_message.ENCR_DES_IV32: - return false - case ike_message.ENCR_NULL: - return false - case ike_message.ENCR_AES_CBC: - if attributePresent { - switch attributeValue { - case 128: - return true - case 192: - return true - case 256: - return true - default: - return false - } - } else { - return false +func (s *Server) handleNATDetect( + initiatorSPI, responderSPI uint64, + notifications []*ike_message.Notification, + ueAddr, n3iwfAddr *net.UDPAddr, +) (bool, bool, error) { + ikeLog := logger.IKELog + ueBehindNAT := false + n3iwfBehindNAT := false + + srcNatDData, err := s.generateNATDetectHash(initiatorSPI, responderSPI, ueAddr) + if err != nil { + return false, false, errors.Wrapf(err, "handle NATD") + } + + dstNatDData, err := s.generateNATDetectHash(initiatorSPI, responderSPI, n3iwfAddr) + if err != nil { + return false, false, errors.Wrapf(err, "handle NATD") + } + + for _, notification := range notifications { + switch notification.NotifyMessageType { + case ike_message.NAT_DETECTION_SOURCE_IP: + ikeLog.Tracef("Received IKE Notify: NAT_DETECTION_SOURCE_IP") + if !bytes.Equal(notification.NotificationData, srcNatDData) { + ikeLog.Tracef("UE(SPI: %016x) is behind NAT", responderSPI) + ueBehindNAT = true + } + case ike_message.NAT_DETECTION_DESTINATION_IP: + ikeLog.Tracef("Received IKE Notify: NAT_DETECTION_DESTINATION_IP") + if !bytes.Equal(notification.NotificationData, dstNatDData) { + ikeLog.Tracef("N3IWF is behind NAT") + n3iwfBehindNAT = true } - case ike_message.ENCR_AES_CTR: - return false - default: - return false - } - case ike_message.TypePseudorandomFunction: - switch transformID { - case ike_message.PRF_HMAC_MD5: - return true - case ike_message.PRF_HMAC_SHA1: - return true - case ike_message.PRF_HMAC_TIGER: - return false - case ike_message.PRF_HMAC_SHA2_256: - return true default: - return false } - case ike_message.TypeIntegrityAlgorithm: - switch transformID { - case ike_message.AUTH_NONE: - return false - case ike_message.AUTH_HMAC_MD5_96: - return true - case ike_message.AUTH_HMAC_SHA1_96: - return true - case ike_message.AUTH_DES_MAC: - return false - case ike_message.AUTH_KPDK_MD5: - return false - case ike_message.AUTH_AES_XCBC_96: - return false - case ike_message.AUTH_HMAC_SHA2_256_128: - return true - default: - return false + } + return ueBehindNAT, n3iwfBehindNAT, nil +} + +func (s *Server) generateNATDetectHash( + initiatorSPI, responderSPI uint64, + addr *net.UDPAddr, +) ([]byte, error) { + // Calculate NAT_DETECTION hash for NAT-T + // : sha1(ispi | rspi | ip | port) + natdData := make([]byte, 22) + binary.BigEndian.PutUint64(natdData[0:8], initiatorSPI) + binary.BigEndian.PutUint64(natdData[8:16], responderSPI) + copy(natdData[16:20], addr.IP.To4()) + binary.BigEndian.PutUint16(natdData[20:22], uint16(addr.Port)) // #nosec G115 + + sha1HashFunction := sha1.New() // #nosec G401 + _, err := sha1HashFunction.Write(natdData) + if err != nil { + return nil, errors.Wrapf(err, "generate NATD Hash") + } + return sha1HashFunction.Sum(nil), nil +} + +func (s *Server) buildNATDetectNotifPayload( + ikeSA *n3iwf_context.IKESecurityAssociation, + payload *ike_message.IKEPayloadContainer, + ueAddr, n3iwfAddr *net.UDPAddr, +) error { + srcNatDHash, err := s.generateNATDetectHash(ikeSA.RemoteSPI, ikeSA.LocalSPI, n3iwfAddr) + if err != nil { + return errors.Wrapf(err, "build NATD") + } + // Build and append notify payload for NAT_DETECTION_SOURCE_IP + payload.BuildNotification( + ike_message.TypeNone, ike_message.NAT_DETECTION_SOURCE_IP, nil, srcNatDHash) + + dstNatDHash, err := s.generateNATDetectHash(ikeSA.RemoteSPI, ikeSA.LocalSPI, ueAddr) + if err != nil { + return errors.Wrapf(err, "build NATD") + } + // Build and append notify payload for NAT_DETECTION_DESTINATION_IP + payload.BuildNotification( + ike_message.TypeNone, ike_message.NAT_DETECTION_DESTINATION_IP, nil, dstNatDHash) + + return nil +} + +func (s *Server) handleDeletePayload(payload *ike_message.Delete, isResponse bool, + ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation) ( + *ike_message.IKEPayloadContainer, error, +) { + var evt n3iwf_context.NgapEvt + var err error + n3iwfCtx := s.Context() + n3iwfIke := ikeSecurityAssociation.IkeUE + responseIKEPayload := new(ike_message.IKEPayloadContainer) + + ranNgapId, ok := n3iwfCtx.NgapIdLoad(n3iwfIke.N3IWFIKESecurityAssociation.LocalSPI) + if !ok { + return nil, errors.Errorf("handleDeletePayload: Cannot get RanNgapId from SPI : %+v", + n3iwfIke.N3IWFIKESecurityAssociation.LocalSPI) + } + + switch payload.ProtocolID { + case ike_message.TypeIKE: + if !isResponse { + err = n3iwfIke.Remove() + if err != nil { + return nil, errors.Wrapf(err, "handleDeletePayload: Delete IkeUe Context error") + } } - case ike_message.TypeDiffieHellmanGroup: - switch transformID { - case ike_message.DH_NONE: - return false - case ike_message.DH_768_BIT_MODP: - return false - case ike_message.DH_1024_BIT_MODP: - return true - case ike_message.DH_1536_BIT_MODP: - return false - case ike_message.DH_2048_BIT_MODP: - return true - case ike_message.DH_3072_BIT_MODP: - return false - case ike_message.DH_4096_BIT_MODP: - return false - case ike_message.DH_6144_BIT_MODP: - return false - case ike_message.DH_8192_BIT_MODP: - return false - default: - return false + + evt = n3iwf_context.NewSendUEContextReleaseEvt(ranNgapId) + case ike_message.TypeESP: + var deletSPIs []uint32 + var deletPduIds []int64 + if !isResponse { + deletSPIs, deletPduIds, err = s.deleteChildSAFromSPIList(n3iwfIke, payload.SPIs) + if err != nil { + return nil, errors.Wrapf(err, "handleDeletePayload") + } + responseIKEPayload.BuildDeletePayload(ike_message.TypeESP, 4, uint16(len(deletSPIs)), deletSPIs) } + + evt = n3iwf_context.NewendPDUSessionResourceReleaseEvt(ranNgapId, deletPduIds) default: - return false + return nil, errors.Errorf("Get Protocol ID %d in Informational delete payload, "+ + "this payload will not be handled by IKE handler", payload.ProtocolID) } + s.SendNgapEvt(evt) + return responseIKEPayload, nil } func isTransformKernelSupported( @@ -2176,32 +1905,32 @@ func isTransformKernelSupported( } case ike_message.TypeDiffieHellmanGroup: switch transformID { - case ike_message.DH_NONE: - return false - case ike_message.DH_768_BIT_MODP: - return false - case ike_message.DH_1024_BIT_MODP: - return false - case ike_message.DH_1536_BIT_MODP: - return false - case ike_message.DH_2048_BIT_MODP: - return false - case ike_message.DH_3072_BIT_MODP: - return false - case ike_message.DH_4096_BIT_MODP: - return false - case ike_message.DH_6144_BIT_MODP: - return false - case ike_message.DH_8192_BIT_MODP: - return false + // case ike_message.DH_NONE: + // return false + // case ike_message.DH_768_BIT_MODP: + // return false + // case ike_message.DH_1024_BIT_MODP: + // return false + // case ike_message.DH_1536_BIT_MODP: + // return false + // case ike_message.DH_2048_BIT_MODP: + // return false + // case ike_message.DH_3072_BIT_MODP: + // return false + // case ike_message.DH_4096_BIT_MODP: + // return false + // case ike_message.DH_6144_BIT_MODP: + // return false + // case ike_message.DH_8192_BIT_MODP: + // return false default: return false } case ike_message.TypeExtendedSequenceNumbers: switch transformID { - case ike_message.ESN_NO: + case ike_message.ESN_ENABLE: return true - case ike_message.ESN_NEED: + case ike_message.ESN_DISABLE: return true default: return false @@ -2242,3 +1971,118 @@ func (s *Server) parseIPAddressInformationToChildSecurityAssociation( return nil } + +func SelectProposal(proposals ike_message.ProposalContainer) ike_message.ProposalContainer { + var chooseProposal ike_message.ProposalContainer + + for _, proposal := range proposals { + // We need ENCR, PRF, INTEG, DH, but not ESN + + var encryptionAlgorithmTransform, pseudorandomFunctionTransform *ike_message.Transform + var integrityAlgorithmTransform, diffieHellmanGroupTransform *ike_message.Transform + var chooseDH dh.DHType + var chooseEncr encr.ENCRType + var chooseInte integ.INTEGType + var choosePrf prf.PRFType + + for _, transform := range proposal.DiffieHellmanGroup { + dhType := dh.DecodeTransform(transform) + if dhType != nil { + if diffieHellmanGroupTransform == nil { + diffieHellmanGroupTransform = transform + chooseDH = dhType + } + } + } + if chooseDH == nil { + continue // mandatory + } + + for _, transform := range proposal.EncryptionAlgorithm { + encrType := encr.DecodeTransform(transform) + if encrType != nil { + if encryptionAlgorithmTransform == nil { + encryptionAlgorithmTransform = transform + chooseEncr = encrType + } + } + } + if chooseEncr == nil { + continue // mandatory + } + + for _, transform := range proposal.IntegrityAlgorithm { + integType := integ.DecodeTransform(transform) + if integType != nil { + if integrityAlgorithmTransform == nil { + integrityAlgorithmTransform = transform + chooseInte = integType + } + } + } + if chooseInte == nil { + continue // mandatory + } + + for _, transform := range proposal.PseudorandomFunction { + prfType := prf.DecodeTransform(transform) + if prfType != nil { + if pseudorandomFunctionTransform == nil { + pseudorandomFunctionTransform = transform + choosePrf = prfType + } + } + } + if choosePrf == nil { + continue // mandatory + } + if len(proposal.ExtendedSequenceNumbers) > 0 { + continue // No ESN + } + + // Construct chosen proposal, with ENCR, PRF, INTEG, DH, and each + // contains one transform expectively + chosenProposal := chooseProposal.BuildProposal(proposal.ProposalNumber, proposal.ProtocolID, nil) + chosenProposal.EncryptionAlgorithm = append(chosenProposal.EncryptionAlgorithm, encryptionAlgorithmTransform) + chosenProposal.IntegrityAlgorithm = append(chosenProposal.IntegrityAlgorithm, integrityAlgorithmTransform) + chosenProposal.PseudorandomFunction = append(chosenProposal.PseudorandomFunction, pseudorandomFunctionTransform) + chosenProposal.DiffieHellmanGroup = append(chosenProposal.DiffieHellmanGroup, diffieHellmanGroupTransform) + break + } + return chooseProposal +} + +func (s *Server) deleteChildSAFromSPIList(ikeUe *n3iwf_context.N3IWFIkeUe, spiList []uint32) ( + []uint32, []int64, error, +) { + ikeLog := logger.IKELog + var deleteSPIs []uint32 + var deletePduIds []int64 + + for _, spi := range spiList { + found := false + for _, childSA := range ikeUe.N3IWFChildSecurityAssociation { + if childSA.OutboundSPI == spi { + found = true + deleteSPIs = append(deleteSPIs, childSA.InboundSPI) + + if len(childSA.PDUSessionIds) == 0 { + return nil, nil, errors.Errorf("Child_SA SPI: 0x%08x doesn't have PDU Session ID", + spi) + } + deletePduIds = append(deletePduIds, childSA.PDUSessionIds[0]) + + err := ikeUe.DeleteChildSA(childSA) + if err != nil { + return nil, nil, errors.Wrapf(err, "DeleteChildSAFromSPIList") + } + break + } + } + if !found { + ikeLog.Warnf("deleteChildSAFromSPIList(): Get unknown Child_SA with SPI: 0x%08x", spi) + } + } + + return deleteSPIs, deletePduIds, nil +} diff --git a/internal/ike/handler_test.go b/internal/ike/handler_test.go new file mode 100644 index 00000000..f4366ecf --- /dev/null +++ b/internal/ike/handler_test.go @@ -0,0 +1,294 @@ +package ike + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + ike_message "github.com/free5gc/ike/message" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/pkg/factory" + "github.com/free5gc/util/ippool" +) + +func TestRemoveIkeUe(t *testing.T) { + n3iwf, err := NewN3iwfTestApp(&factory.Config{}) + require.NoError(t, err) + + n3iwf.ikeServer, err = NewServer(n3iwf) + require.NoError(t, err) + + n3iwfCtx := n3iwf.n3iwfCtx + ikeSA := n3iwfCtx.NewIKESecurityAssociation() + ikeUe := n3iwfCtx.NewN3iwfIkeUe(ikeSA.LocalSPI) + ikeUe.N3IWFIKESecurityAssociation = ikeSA + ikeUe.IPSecInnerIP = net.ParseIP("10.0.0.1") + ikeSA.IsUseDPD = false + + n3iwfCtx.IPSecInnerIPPool, err = ippool.NewIPPool("10.0.0.0/24") + require.NoError(t, err) + _, err = n3iwfCtx.IPSecInnerIPPool.Allocate(nil) + require.NoError(t, err) + + ikeUe.CreateHalfChildSA(1, 123, 1) + + ikeAuth := &ike_message.SecurityAssociation{} + + proposal := ikeAuth.Proposals.BuildProposal(1, 1, []byte{0, 1, 2, 3}) + var attributeType uint16 = ike_message.AttributeTypeKeyLength + var attributeValue uint16 = 256 + proposal.EncryptionAlgorithm.BuildTransform(ike_message.TypeEncryptionAlgorithm, + ike_message.ENCR_AES_CBC, &attributeType, &attributeValue, nil) + + proposal.IntegrityAlgorithm.BuildTransform(ike_message.TypeIntegrityAlgorithm, + ike_message.AUTH_HMAC_SHA1_96, nil, nil, nil) + + proposal.ExtendedSequenceNumbers.BuildTransform( + ike_message.TypeExtendedSequenceNumbers, ike_message.ESN_DISABLE, nil, nil, nil) + + childSA, err := ikeUe.CompleteChildSA(1, 456, ikeAuth) + require.NoError(t, err) + + err = n3iwf.ikeServer.removeIkeUe(ikeSA.LocalSPI) + require.NoError(t, err) + + _, ok := n3iwfCtx.IkeUePoolLoad(ikeSA.LocalSPI) + require.False(t, ok) + + _, ok = n3iwfCtx.IKESALoad(ikeSA.LocalSPI) + require.False(t, ok) + + _, ok = ikeUe.N3IWFChildSecurityAssociation[childSA.InboundSPI] + require.False(t, ok) +} + +func TestGenerateNATDetectHash(t *testing.T) { + n3iwf, err := NewN3iwfTestApp(&factory.Config{}) + require.NoError(t, err) + + n3iwf.ikeServer, err = NewServer(n3iwf) + require.NoError(t, err) + + tests := []struct { + name string + initiatorSPI uint64 + responderSPI uint64 + Addr net.UDPAddr + expectedData []byte + }{ + { + name: "Generate NAT-D hash", + initiatorSPI: 0x1122334455667788, + responderSPI: 0xaabbeeddeeff1122, + Addr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.1"), + Port: 4500, + }, + expectedData: []byte{ + 0xd2, 0xee, 0x40, 0x2d, 0x5d, 0x53, 0xe4, 0x4a, + 0x01, 0x2d, 0x44, 0x2a, 0x90, 0x05, 0xc1, 0xea, + 0x38, 0x8a, 0x81, 0x7e, + }, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + data, err := n3iwf.ikeServer.generateNATDetectHash(tt.initiatorSPI, tt.responderSPI, &tt.Addr) + require.NoError(t, err) + + require.Equal(t, tt.expectedData, data) + }) + } +} + +func TestBuildNATDetectMsg(t *testing.T) { + n3iwf, err := NewN3iwfTestApp(&factory.Config{}) + require.NoError(t, err) + + n3iwf.ikeServer, err = NewServer(n3iwf) + require.NoError(t, err) + + remoteSPI := uint64(0x1234567890abcdef) + localSPI := uint64(0xfedcba0987654321) + ikeSA := &n3iwf_context.IKESecurityAssociation{ + LocalSPI: localSPI, + RemoteSPI: remoteSPI, + } + payload := &ike_message.IKEPayloadContainer{} + + ueAddr := net.UDPAddr{ + IP: net.ParseIP("192.168.1.1"), + Port: 4500, + } + n3iwfAddr := net.UDPAddr{ + IP: net.ParseIP("192.168.1.2"), + Port: 4500, + } + + err = n3iwf.ikeServer.buildNATDetectNotifPayload(ikeSA, payload, &ueAddr, &n3iwfAddr) + require.NoError(t, err) + + var notifications []*ike_message.Notification + for _, ikePayload := range *payload { + switch ikePayload.Type() { + case ike_message.TypeN: + notifications = append(notifications, ikePayload.(*ike_message.Notification)) + default: + require.Fail(t, "Get unexpected IKE payload type : %v", ikePayload.Type()) + } + } + + for _, notification := range notifications { + switch notification.NotifyMessageType { + case ike_message.NAT_DETECTION_SOURCE_IP: + expectedData := []byte{ + 0x13, 0xd8, 0x9e, 0xdc, 0xfa, 0x39, 0xe4, 0xc0, + 0x06, 0x80, 0x5f, 0xde, 0x11, 0x62, 0xd8, 0x76, + 0xee, 0xe8, 0xf2, 0x00, + } + require.Equal(t, expectedData, notification.NotificationData) + case ike_message.NAT_DETECTION_DESTINATION_IP: + expectedData := []byte{ + 0x0d, 0x36, 0x26, 0x71, 0xaf, 0x7f, 0x0b, 0x19, + 0x32, 0xec, 0xf8, 0xf3, 0xe1, 0x84, 0x87, 0xf0, + 0x47, 0x76, 0x83, 0x04, + } + require.Equal(t, expectedData, notification.NotificationData) + } + } +} + +func TestHandleNATDetect(t *testing.T) { + n3iwf, err := NewN3iwfTestApp(&factory.Config{}) + require.NoError(t, err) + + n3iwf.ikeServer, err = NewServer(n3iwf) + require.NoError(t, err) + + tests := []struct { + name string + initiatorSPI uint64 + responderSPI uint64 + notification []*ike_message.Notification + ueAddr net.UDPAddr + n3iwfAddr net.UDPAddr + expectedUeBehindNAT bool + expectedN3iwfBehindNAT bool + }{ + { + name: "UE and N3IWF is not behind NAT", + initiatorSPI: 0x1234567890abcdef, + responderSPI: 0xfedcba0987654321, + notification: []*ike_message.Notification{ + { + NotifyMessageType: ike_message.NAT_DETECTION_SOURCE_IP, + NotificationData: []byte{ + 0x0d, 0x36, 0x26, 0x71, 0xaf, 0x7f, 0x0b, 0x19, + 0x32, 0xec, 0xf8, 0xf3, 0xe1, 0x84, 0x87, 0xf0, + 0x47, 0x76, 0x83, 0x04, + }, + }, + { + NotifyMessageType: ike_message.NAT_DETECTION_DESTINATION_IP, + NotificationData: []byte{ + 0x13, 0xd8, 0x9e, 0xdc, 0xfa, 0x39, 0xe4, 0xc0, + 0x06, 0x80, 0x5f, 0xde, 0x11, 0x62, 0xd8, 0x76, + 0xee, 0xe8, 0xf2, 0x00, + }, + }, + }, + ueAddr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.1"), + Port: 4500, + }, + n3iwfAddr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.2"), + Port: 4500, + }, + expectedUeBehindNAT: false, + expectedN3iwfBehindNAT: false, + }, + { + name: "UE is behind NAT and N3IWF is not behind NAT", + initiatorSPI: 0x1234567890abcdef, + responderSPI: 0xfedcba0987654321, + notification: []*ike_message.Notification{ + { + NotifyMessageType: ike_message.NAT_DETECTION_SOURCE_IP, + NotificationData: []byte{ + 0x0b, 0x17, 0x2d, 0x42, 0xaf, 0x7f, 0x0b, 0x19, + 0x32, 0xec, 0xf8, 0xf3, 0xe1, 0x84, 0x87, 0xf0, + 0x47, 0x76, 0x83, 0x04, + }, + }, + { + NotifyMessageType: ike_message.NAT_DETECTION_DESTINATION_IP, + NotificationData: []byte{ + 0x13, 0xd8, 0x9e, 0xdc, 0xfa, 0x39, 0xe4, 0xc0, + 0x06, 0x80, 0x5f, 0xde, 0x11, 0x62, 0xd8, 0x76, + 0xee, 0xe8, 0xf2, 0x00, + }, + }, + }, + ueAddr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.1"), + Port: 4500, + }, + n3iwfAddr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.2"), + Port: 4500, + }, + expectedUeBehindNAT: true, + expectedN3iwfBehindNAT: false, + }, + { + name: "UE and N3IWF is behind NAT", + initiatorSPI: 0x1234567890abcdef, + responderSPI: 0xfedcba0987654321, + notification: []*ike_message.Notification{ + { + NotifyMessageType: ike_message.NAT_DETECTION_SOURCE_IP, + NotificationData: []byte{ + 0x0b, 0x16, 0x26, 0x71, 0xaf, 0x7f, 0x0b, 0x19, + 0x32, 0xec, 0xf8, 0xf3, 0xe1, 0x84, 0x87, 0xf0, + 0x47, 0x76, 0x83, 0x04, + }, + }, + { + NotifyMessageType: ike_message.NAT_DETECTION_DESTINATION_IP, + NotificationData: []byte{ + 0x0f, 0xd9, 0x9e, 0xdc, 0xfa, 0x39, 0xe4, 0xc0, + 0x06, 0x80, 0x5f, 0xde, 0x11, 0x62, 0xd8, 0x76, + 0xee, 0xe8, 0xf2, 0x00, + }, + }, + }, + ueAddr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.1"), + Port: 4500, + }, + n3iwfAddr: net.UDPAddr{ + IP: net.ParseIP("192.168.1.2"), + Port: 4500, + }, + expectedUeBehindNAT: true, + expectedN3iwfBehindNAT: true, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + ueBehindNAT, n3iwfBehindNAT, err := n3iwf.ikeServer.handleNATDetect( + tt.initiatorSPI, tt.responderSPI, + tt.notification, &tt.ueAddr, &tt.n3iwfAddr) + require.NoError(t, err) + + require.Equal(t, tt.expectedUeBehindNAT, ueBehindNAT) + require.Equal(t, tt.expectedN3iwfBehindNAT, n3iwfBehindNAT) + }) + } +} diff --git a/internal/ike/send.go b/internal/ike/send.go new file mode 100644 index 00000000..387cb516 --- /dev/null +++ b/internal/ike/send.go @@ -0,0 +1,119 @@ +package ike + +import ( + "math" + "net" + + "github.com/pkg/errors" + + "github.com/free5gc/ike" + "github.com/free5gc/ike/message" + ike_message "github.com/free5gc/ike/message" + "github.com/free5gc/ike/security" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/logger" +) + +func SendIKEMessageToUE( + udpConn *net.UDPConn, + srcAddr, dstAddr *net.UDPAddr, + message *ike_message.IKEMessage, + ikeSAKey *security.IKESAKey, +) error { + ikeLog := logger.IKELog + ikeLog.Trace("Send IKE message to UE") + ikeLog.Trace("Encoding...") + pkt, err := ike.EncodeEncrypt(message, ikeSAKey, ike_message.Role_Responder) + if err != nil { + return errors.Wrapf(err, "SendIKEMessageToUE") + } + // As specified in RFC 7296 section 3.1, the IKE message send from/to UDP port 4500 + // should prepend a 4 bytes zero + if srcAddr.Port == 4500 { + prependZero := make([]byte, 4) + pkt = append(prependZero, pkt...) + } + + ikeLog.Trace("Sending...") + n, err := udpConn.WriteToUDP(pkt, dstAddr) + if err != nil { + return errors.Wrapf(err, "SendIKEMessageToUE") + } + if n != len(pkt) { + return errors.Errorf("SendIKEMessageToUE Not all of the data is sent. Total length: %d. Sent: %d.", + len(pkt), n) + } + return nil +} + +func SendUEInformationExchange( + ikeSA *n3iwf_context.IKESecurityAssociation, + ikeSAKey *security.IKESAKey, + payload *ike_message.IKEPayloadContainer, initiator bool, + response bool, messageID uint32, conn *net.UDPConn, + ueAddr *net.UDPAddr, n3iwfAddr *net.UDPAddr, +) { + ikeLog := logger.IKELog + + // Build IKE message + responseIKEMessage := ike_message.NewMessage(ikeSA.RemoteSPI, ikeSA.LocalSPI, + ike_message.INFORMATIONAL, response, initiator, messageID, nil) + + if payload != nil && len(*payload) > 0 { + responseIKEMessage.Payloads = append(responseIKEMessage.Payloads, *payload...) + } + + err := SendIKEMessageToUE(conn, n3iwfAddr, ueAddr, responseIKEMessage, ikeSAKey) + if err != nil { + ikeLog.Errorf("SendUEInformationExchange err: %+v", err) + return + } +} + +func SendIKEDeleteRequest(n3iwfCtx *n3iwf_context.N3IWFContext, localSPI uint64) { + ikeLog := logger.IKELog + ikeUe, ok := n3iwfCtx.IkeUePoolLoad(localSPI) + if !ok { + ikeLog.Errorf("Cannot get IkeUE from SPI : %+v", localSPI) + return + } + + var deletePayload message.IKEPayloadContainer + deletePayload.BuildDeletePayload(message.TypeIKE, 0, 0, nil) + SendUEInformationExchange(ikeUe.N3IWFIKESecurityAssociation, ikeUe.N3IWFIKESecurityAssociation.IKESAKey, + &deletePayload, false, false, ikeUe.N3IWFIKESecurityAssociation.ResponderMessageID, + ikeUe.IKEConnection.Conn, ikeUe.IKEConnection.UEAddr, ikeUe.IKEConnection.N3IWFAddr) +} + +func SendChildSADeleteRequest( + ikeUe *n3iwf_context.N3IWFIkeUe, + relaseList []int64, +) { + ikeLog := logger.IKELog + var deleteSPIs []uint32 + spiLen := uint16(0) + for _, releaseItem := range relaseList { + for _, childSA := range ikeUe.N3IWFChildSecurityAssociation { + if childSA.PDUSessionIds[0] == releaseItem { + spi := childSA.XfrmStateList[0].Spi + if spi < 0 || spi > math.MaxUint32 { + ikeLog.Errorf("SendChildSADeleteRequest spi out of uint32 range : %d", spi) + return + } + deleteSPIs = append(deleteSPIs, uint32(spi)) + spiLen += 1 + err := ikeUe.DeleteChildSA(childSA) + if err != nil { + ikeLog.Errorf("Delete Child SA error : %v", err) + return + } + } + } + } + + var deletePayload message.IKEPayloadContainer + deletePayload.BuildDeletePayload(message.TypeESP, 4, spiLen, deleteSPIs) + SendUEInformationExchange(ikeUe.N3IWFIKESecurityAssociation, ikeUe.N3IWFIKESecurityAssociation.IKESAKey, + &deletePayload, false, false, ikeUe.N3IWFIKESecurityAssociation.ResponderMessageID, + ikeUe.IKEConnection.Conn, ikeUe.IKEConnection.UEAddr, ikeUe.IKEConnection.N3IWFAddr) +} diff --git a/internal/ike/server.go b/internal/ike/server.go new file mode 100644 index 00000000..fca6d5f0 --- /dev/null +++ b/internal/ike/server.go @@ -0,0 +1,388 @@ +package ike + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "net" + "runtime/debug" + "sync" + "syscall" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/pkg/errors" + + "github.com/free5gc/ike" + ike_message "github.com/free5gc/ike/message" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/logger" + "github.com/free5gc/n3iwf/pkg/factory" + "github.com/free5gc/util/safe_channel" +) + +const ( + RECEIVE_IKEPACKET_CHANNEL_LEN = 512 + RECEIVE_IKEEVENT_CHANNEL_LEN = 512 + + DEFAULT_IKE_PORT = 500 + DEFAULT_NATT_PORT = 4500 +) + +type n3iwf interface { + Config() *factory.Config + Context() *n3iwf_context.N3IWFContext + CancelContext() context.Context + + SendNgapEvt(n3iwf_context.NgapEvt) +} + +type EspHandler func(srcIP, dstIP *net.UDPAddr, espPkt []byte) error + +type Server struct { + n3iwf + + Listener map[int]*net.UDPConn + StopServer chan struct{} + rcvPktCh *safe_channel.SafeCh[IkeReceivePacket] + rcvEvtCh *safe_channel.SafeCh[n3iwf_context.IkeEvt] +} + +type IkeReceivePacket struct { + Listener net.UDPConn + LocalAddr net.UDPAddr + RemoteAddr net.UDPAddr + Msg []byte +} + +func NewServer(n3iwf n3iwf) (*Server, error) { + s := &Server{ + n3iwf: n3iwf, + Listener: make(map[int]*net.UDPConn), + StopServer: make(chan struct{}), + } + s.rcvPktCh = safe_channel.NewSafeCh[IkeReceivePacket](RECEIVE_IKEPACKET_CHANNEL_LEN) + s.rcvEvtCh = safe_channel.NewSafeCh[n3iwf_context.IkeEvt](RECEIVE_IKEEVENT_CHANNEL_LEN) + return s, nil +} + +func (s *Server) Run(wg *sync.WaitGroup) error { + cfg := s.Config() + + // Resolve UDP addresses + ip := cfg.GetIKEBindAddr() + ikeAddrPort, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", ip, DEFAULT_IKE_PORT)) + if err != nil { + return err + } + nattAddrPort, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", ip, DEFAULT_NATT_PORT)) + if err != nil { + return err + } + + // Listen and serve + var errChan chan error + + wg.Add(1) + errChan = make(chan error) + go s.receiver(ikeAddrPort, errChan, wg) + if err, ok := <-errChan; ok { + return errors.Wrapf(err, "ikeAddrPort") + } + + wg.Add(1) + errChan = make(chan error) + go s.receiver(nattAddrPort, errChan, wg) + if err, ok := <-errChan; ok { + return errors.Wrapf(err, "nattAddrPort") + } + + wg.Add(1) + go s.server(wg) + + return nil +} + +func (s *Server) server(wg *sync.WaitGroup) { + ikeLog := logger.IKELog + defer func() { + if p := recover(); p != nil { + // Print stack for panic to log. Fatalf() will let program exit. + ikeLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) + } + ikeLog.Infof("Ike server stopped") + s.rcvPktCh.Close() + s.rcvEvtCh.Close() + close(s.StopServer) + wg.Done() + }() + + rcvEvtCh := s.rcvEvtCh.GetRcvChan() + rcvPktCh := s.rcvPktCh.GetRcvChan() + + for { + select { + case rcvPkt := <-rcvPktCh: + ikeMsg, ikeSA, err := s.checkIKEMessage( + rcvPkt.Msg, &rcvPkt.Listener, &rcvPkt.LocalAddr, &rcvPkt.RemoteAddr) + if err != nil { + ikeLog.Warnln(err) + continue + } + s.Dispatch(&rcvPkt.Listener, &rcvPkt.LocalAddr, &rcvPkt.RemoteAddr, + ikeMsg, rcvPkt.Msg, ikeSA) + case rcvIkeEvent := <-rcvEvtCh: + s.HandleEvent(rcvIkeEvent) + case <-s.StopServer: + return + } + } +} + +func (s *Server) receiver( + localAddr *net.UDPAddr, + errChan chan<- error, + wg *sync.WaitGroup, +) { + ikeLog := logger.IKELog + defer func() { + if p := recover(); p != nil { + // Print stack for panic to log. Fatalf() will let program exit. + ikeLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) + } + ikeLog.Infof("Ike receiver stopped") + wg.Done() + }() + + listener, err := net.ListenUDP("udp", localAddr) + if err != nil { + ikeLog.Errorf("Listen UDP failed: %+v", err) + errChan <- errors.New("listenAndServe failed") + return + } + + close(errChan) + + s.Listener[localAddr.Port] = listener + + buf := make([]byte, factory.MAX_BUF_MSG_LEN) + + for { + n, remoteAddr, err := listener.ReadFromUDP(buf) + if err != nil { + ikeLog.Errorf("ReadFromUDP failed: %+v", err) + return + } + + msgBuf := make([]byte, n) + copy(msgBuf, buf) + ikeLog.Tracef("recv from port(%d):\n%s", localAddr.Port, hex.Dump(msgBuf)) + + // As specified in RFC 7296 section 3.1, the IKE message send from/to UDP port 4500 + // should prepend a 4 bytes zero + if localAddr.Port == DEFAULT_NATT_PORT { + msgBuf, err = handleNattMsg(msgBuf, remoteAddr, localAddr, handleESPPacket) + if err != nil { + ikeLog.Errorf("Handle NATT msg: %v", err) + continue + } + if msgBuf == nil { + continue + } + } + + if len(msgBuf) < ike_message.IKE_HEADER_LEN { + ikeLog.Warnf("Received IKE msg is too short from %s", remoteAddr) + continue + } + + ikePkt := IkeReceivePacket{ + RemoteAddr: *remoteAddr, + Listener: *listener, + LocalAddr: *localAddr, + Msg: msgBuf, + } + s.rcvPktCh.Send(ikePkt) + } +} + +func handleNattMsg( + msgBuf []byte, + rAddr, lAddr *net.UDPAddr, + espHandler EspHandler, +) ([]byte, error) { + if len(msgBuf) == 1 && msgBuf[0] == 0xff { + // skip NAT-T Keepalive + return nil, nil + } + + nonEspMarker := []byte{0, 0, 0, 0} // Non-ESP Marker + nonEspMarkerLen := len(nonEspMarker) + if len(msgBuf) < nonEspMarkerLen { + return nil, errors.Errorf("Received msg is too short") + } + if !bytes.Equal(msgBuf[:nonEspMarkerLen], nonEspMarker) { + // ESP packet + if espHandler != nil { + err := espHandler(rAddr, lAddr, msgBuf) + if err != nil { + return nil, errors.Wrapf(err, "Handle ESP") + } + } + return nil, nil + } + + // IKE message: skip Non-ESP Marker + msgBuf = msgBuf[nonEspMarkerLen:] + return msgBuf, nil +} + +func (s *Server) SendIkeEvt(evt n3iwf_context.IkeEvt) { + s.rcvEvtCh.Send(evt) +} + +func (s *Server) Stop() { + ikeLog := logger.IKELog + ikeLog.Infof("Close Ike server...") + + for _, ikeServerListener := range s.Listener { + if err := ikeServerListener.Close(); err != nil { + ikeLog.Errorf("Stop ike server : %s error : %+v", err, ikeServerListener.LocalAddr().String()) + } + } + + s.StopServer <- struct{}{} +} + +func (s *Server) checkIKEMessage( + msg []byte, udpConn *net.UDPConn, + localAddr, remoteAddr *net.UDPAddr, +) (*ike_message.IKEMessage, + *n3iwf_context.IKESecurityAssociation, error, +) { + var ikeHeader *ike_message.IKEHeader + var ikeMessage *ike_message.IKEMessage + var ikeSA *n3iwf_context.IKESecurityAssociation + var err error + + // parse IKE header and setup IKE context + ikeHeader, err = ike_message.ParseHeader(msg) + if err != nil { + return nil, nil, errors.Wrapf(err, "IKE msg decode header") + } + + // check major version + if ikeHeader.MajorVersion > 2 { + // send INFORMATIONAL type message with INVALID_MAJOR_VERSION Notify payload + // For response or needed data + payload := new(ike_message.IKEPayloadContainer) + payload.BuildNotification(ike_message.TypeNone, + ike_message.INVALID_MAJOR_VERSION, nil, nil) + responseIKEMessage := ike_message.NewMessage(ikeHeader.InitiatorSPI, ikeHeader.ResponderSPI, + ike_message.INFORMATIONAL, true, false, ikeHeader.MessageID, *payload) + + err = SendIKEMessageToUE(udpConn, localAddr, remoteAddr, responseIKEMessage, nil) + if err != nil { + return nil, nil, errors.Wrapf(err, "Received an IKE message with higher major version "+ + "(%d>2)", ikeHeader.MajorVersion) + } + return nil, nil, errors.Errorf("Received an IKE message with higher major version (%d>2)", ikeHeader.MajorVersion) + } + + if ikeHeader.ExchangeType == ike_message.IKE_SA_INIT { + ikeMessage, err = ike.DecodeDecrypt(msg, ikeHeader, + nil, ike_message.Role_Responder) + if err != nil { + return nil, nil, errors.Wrapf(err, "Decrypt IkeMsg error") + } + } else if ikeHeader.ExchangeType != ike_message.IKE_SA_INIT { + localSPI := ikeHeader.ResponderSPI + var ok bool + n3iwfCtx := s.Context() + + ikeSA, ok = n3iwfCtx.IKESALoad(localSPI) + if !ok { + payload := new(ike_message.IKEPayloadContainer) + // send INFORMATIONAL type message with INVALID_IKE_SPI Notify payload ( OUTSIDE IKE SA ) + payload.BuildNotification(ike_message.TypeNone, ike_message.INVALID_IKE_SPI, nil, nil) + responseIKEMessage := ike_message.NewMessage(ikeHeader.InitiatorSPI, ikeHeader.ResponderSPI, + ike_message.INFORMATIONAL, true, false, ikeHeader.MessageID, *payload) + + err = SendIKEMessageToUE(udpConn, localAddr, remoteAddr, responseIKEMessage, nil) + if err != nil { + return nil, nil, errors.Wrapf(err, "checkIKEMessage():") + } + return nil, nil, errors.Errorf("Received an unrecognized SPI message: %d", localSPI) + } + + ikeMessage, err = ike.DecodeDecrypt(msg, ikeHeader, + ikeSA.IKESAKey, ike_message.Role_Responder) + if err != nil { + return nil, nil, errors.Wrapf(err, "Decrypt IkeMsg error") + } + } + + return ikeMessage, ikeSA, nil +} + +func constructPacketWithESP(srcIP, dstIP *net.UDPAddr, espPacket []byte) ([]byte, error) { + ipLayer := &layers.IPv4{ + SrcIP: srcIP.IP, + DstIP: dstIP.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolESP, + } + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err := gopacket.SerializeLayers(buffer, + options, + ipLayer, + gopacket.Payload(espPacket), + ) + if err != nil { + return nil, errors.Errorf("Error serializing layers: %v", err) + } + + packetData := buffer.Bytes() + return packetData, nil +} + +func handleESPPacket(srcIP, dstIP *net.UDPAddr, espPacket []byte) error { + ikeLog := logger.IKELog + ikeLog.Tracef("Handle ESPPacket") + + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return errors.Errorf("socket error: %v", err) + } + + defer func() { + if err = syscall.Close(fd); err != nil { + ikeLog.Errorf("Close fd error : %v", err) + } + }() + + ipPacket, err := constructPacketWithESP(srcIP, dstIP, espPacket) + if err != nil { + return err + } + + addr := syscall.SockaddrInet4{ + Addr: [4]byte(dstIP.IP), + Port: dstIP.Port, + } + + err = syscall.Sendto(fd, ipPacket, 0, &addr) + if err != nil { + return errors.Errorf("sendto error: %v", err) + } + + return nil +} diff --git a/internal/ike/server_test.go b/internal/ike/server_test.go new file mode 100644 index 00000000..bd7a267a --- /dev/null +++ b/internal/ike/server_test.go @@ -0,0 +1,268 @@ +package ike + +import ( + "context" + "net" + "sync" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + ike_message "github.com/free5gc/ike/message" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/ngap" + "github.com/free5gc/n3iwf/pkg/factory" +) + +type n3iwfTestApp struct { + cfg *factory.Config + n3iwfCtx *n3iwf_context.N3IWFContext + ngapServer *ngap.Server + ikeServer *Server + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup +} + +func (a *n3iwfTestApp) Config() *factory.Config { + return a.cfg +} + +func (a *n3iwfTestApp) Context() *n3iwf_context.N3IWFContext { + return a.n3iwfCtx +} + +func (a *n3iwfTestApp) CancelContext() context.Context { + return a.ctx +} + +func (a *n3iwfTestApp) SendNgapEvt(evt n3iwf_context.NgapEvt) { + a.ngapServer.SendNgapEvt(evt) +} + +func NewN3iwfTestApp(cfg *factory.Config) (*n3iwfTestApp, error) { + var err error + ctx, cancel := context.WithCancel(context.Background()) + + n3iwfApp := &n3iwfTestApp{ + cfg: cfg, + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, + } + + n3iwfApp.n3iwfCtx, err = n3iwf_context.NewTestContext(n3iwfApp) + if err != nil { + return nil, err + } + return n3iwfApp, err +} + +func NewTestCfg() *factory.Config { + return &factory.Config{ + Configuration: &factory.Configuration{}, + } +} + +func TestHandleNattMsg(t *testing.T) { + initiatorSPI := uint64(0x123) + ikeMessage := ike_message.NewMessage(initiatorSPI, 0, ike_message.IKE_SA_INIT, + true, false, 0, nil) + pkt, err := ikeMessage.Encode() + require.NoError(t, err) + + NonESPPkt := append([]byte{0, 0, 0, 0}, pkt...) + + tests := []struct { + name string + conn *net.UDPConn + rcvPkt []byte + lAddr, rAddr *net.UDPAddr + msg *ike_message.IKEMessage + expectedErr bool + }{ + { + name: "Received NAT-T Keepalive", + rcvPkt: []byte{0xff}, + lAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.2"), + Port: 4500, + }, + rAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.1"), + Port: 4500, + }, + expectedErr: false, + }, + { + name: "Received NAT-T Msg is too short", + rcvPkt: []byte{0x01, 0x02}, + lAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.2"), + Port: 4500, + }, + rAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.1"), + Port: 4500, + }, + expectedErr: true, + }, + { + name: "Received IKE packet from port 4500, and no need to drop", + rcvPkt: NonESPPkt, + lAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.2"), + Port: 4500, + }, + rAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.1"), + Port: 4500, + }, + expectedErr: false, + }, + { + name: "Received ESP packet from port 4500", + rcvPkt: []byte{0x1, 0x2, 0x3, 0x4, 0x5}, + lAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.2"), + Port: 4500, + }, + rAddr: &net.UDPAddr{ + IP: net.ParseIP("10.100.100.1"), + Port: 4500, + }, + expectedErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := handleNattMsg(tt.rcvPkt, tt.rAddr, tt.lAddr, nil) + if tt.expectedErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCheckIKEMessage(t *testing.T) { + n3iwf, err := NewN3iwfTestApp(NewTestCfg()) + require.NoError(t, err) + + n3iwf.ikeServer, err = NewServer(n3iwf) + require.NoError(t, err) + ikeServer := n3iwf.ikeServer + + srcIP := &net.UDPAddr{ + IP: net.ParseIP("10.100.100.1"), + Port: 500, + } + dstIP := &net.UDPAddr{ + IP: net.ParseIP("10.100.100.2"), + Port: 500, + } + + mockConn, err := net.DialUDP("udp", nil, dstIP) + require.NoError(t, err) + + initiatorSPI := uint64(0x123) + nonceData := []byte("randomNonce") + payload := new(ike_message.IKEPayloadContainer) + payload.BuildNonce(nonceData) + + ikeMsg := ike_message.NewMessage(initiatorSPI, 0, ike_message.IKE_SA_INIT, + true, false, 0, *payload) + + tests := []struct { + name string + conn *net.UDPConn + localAddr *net.UDPAddr + remoteAddr *net.UDPAddr + msg *ike_message.IKEMessage + expectedErr bool + }{ + { + name: "Receive packet has IKE version error", + conn: mockConn, + localAddr: dstIP, + remoteAddr: srcIP, + msg: &ike_message.IKEMessage{ + IKEHeader: &ike_message.IKEHeader{ + InitiatorSPI: initiatorSPI, + ExchangeType: ike_message.IKE_SA_INIT, + Flags: 0, + MajorVersion: 3, + MinorVersion: 0, + }, + }, + expectedErr: true, + }, + { + name: "Decode IKE_SA_INIT msg", + conn: mockConn, + localAddr: dstIP, + remoteAddr: srcIP, + msg: ikeMsg, + expectedErr: false, + }, + { + name: "SPI not found from IKE header", + conn: mockConn, + localAddr: dstIP, + remoteAddr: srcIP, + msg: &ike_message.IKEMessage{ + IKEHeader: &ike_message.IKEHeader{ + InitiatorSPI: initiatorSPI, + ExchangeType: ike_message.IKE_AUTH, + Flags: ike_message.ResponseBitCheck, + MajorVersion: 2, + MinorVersion: 0, + }, + Payloads: ikeMsg.Payloads, + }, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, err := tt.msg.Encode() + require.NoError(t, err) + + _, _, err = ikeServer.checkIKEMessage( + msg, tt.conn, tt.localAddr, tt.remoteAddr) + if tt.expectedErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestConstructPacketWithESP(t *testing.T) { + srcIP := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 1), + } + dstIP := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 2), + } + + espPacket := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} + + packet, err := constructPacketWithESP(srcIP, dstIP, espPacket) + require.NoError(t, err) + + packetParsed := gopacket.NewPacket(packet, layers.LayerTypeIPv4, gopacket.Default) + ipLayer := packetParsed.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer) + + ipv4, _ := ipLayer.(*layers.IPv4) + require.Equal(t, ipv4.SrcIP.To4().String(), srcIP.IP.String()) + require.Equal(t, ipv4.DstIP.To4().String(), dstIP.IP.String()) + require.Equal(t, ipv4.Protocol, layers.IPProtocolESP) +} diff --git a/pkg/ike/xfrm/xfrm.go b/internal/ike/xfrm/xfrm.go similarity index 67% rename from pkg/ike/xfrm/xfrm.go rename to internal/ike/xfrm/xfrm.go index c6581b15..7657aa0f 100644 --- a/pkg/ike/xfrm/xfrm.go +++ b/internal/ike/xfrm/xfrm.go @@ -1,15 +1,15 @@ package xfrm import ( - "errors" "fmt" "net" + "github.com/pkg/errors" "github.com/vishvananda/netlink" + "github.com/free5gc/ike/message" + "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" - "github.com/free5gc/n3iwf/pkg/context" - "github.com/free5gc/n3iwf/pkg/ike/message" ) type XFRMEncryptionAlgorithmType uint16 @@ -55,7 +55,6 @@ func (xfrmIntegrityAlgorithmType XFRMIntegrityAlgorithmType) String() string { func ApplyXFRMRule(n3iwf_is_initiator bool, xfrmiId uint32, childSecurityAssociation *context.ChildSecurityAssociation, ) error { - ikeLog := logger.IKELog // Build XFRM information data structure for incoming traffic. // Direction: {private_network} -> this_server @@ -63,26 +62,26 @@ func ApplyXFRMRule(n3iwf_is_initiator bool, xfrmiId uint32, var xfrmEncryptionAlgorithm, xfrmIntegrityAlgorithm *netlink.XfrmStateAlgo if n3iwf_is_initiator { xfrmEncryptionAlgorithm = &netlink.XfrmStateAlgo{ - Name: XFRMEncryptionAlgorithmType(childSecurityAssociation.EncryptionAlgorithm).String(), + Name: XFRMEncryptionAlgorithmType(childSecurityAssociation.EncrKInfo.TransformID()).String(), Key: childSecurityAssociation.ResponderToInitiatorEncryptionKey, } - if childSecurityAssociation.IntegrityAlgorithm != 0 { + if childSecurityAssociation.IntegKInfo != nil { xfrmIntegrityAlgorithm = &netlink.XfrmStateAlgo{ - Name: XFRMIntegrityAlgorithmType(childSecurityAssociation.IntegrityAlgorithm).String(), + Name: XFRMIntegrityAlgorithmType(childSecurityAssociation.IntegKInfo.TransformID()).String(), Key: childSecurityAssociation.ResponderToInitiatorIntegrityKey, - TruncateLen: getTruncateLength(childSecurityAssociation.IntegrityAlgorithm), + TruncateLen: getTruncateLength(childSecurityAssociation.IntegKInfo.TransformID()), } } } else { xfrmEncryptionAlgorithm = &netlink.XfrmStateAlgo{ - Name: XFRMEncryptionAlgorithmType(childSecurityAssociation.EncryptionAlgorithm).String(), + Name: XFRMEncryptionAlgorithmType(childSecurityAssociation.EncrKInfo.TransformID()).String(), Key: childSecurityAssociation.InitiatorToResponderEncryptionKey, } - if childSecurityAssociation.IntegrityAlgorithm != 0 { + if childSecurityAssociation.IntegKInfo != nil { xfrmIntegrityAlgorithm = &netlink.XfrmStateAlgo{ - Name: XFRMIntegrityAlgorithmType(childSecurityAssociation.IntegrityAlgorithm).String(), + Name: XFRMIntegrityAlgorithmType(childSecurityAssociation.IntegKInfo.TransformID()).String(), Key: childSecurityAssociation.InitiatorToResponderIntegrityKey, - TruncateLen: getTruncateLength(childSecurityAssociation.IntegrityAlgorithm), + TruncateLen: getTruncateLength(childSecurityAssociation.IntegKInfo.TransformID()), } } } @@ -97,21 +96,12 @@ func ApplyXFRMRule(n3iwf_is_initiator bool, xfrmiId uint32, xfrmState.Ifid = int(xfrmiId) xfrmState.Auth = xfrmIntegrityAlgorithm xfrmState.Crypt = xfrmEncryptionAlgorithm - xfrmState.ESN = childSecurityAssociation.ESN - - if childSecurityAssociation.EnableEncapsulate { - xfrmState.Encap = &netlink.XfrmStateEncap{ - Type: netlink.XFRM_ENCAP_ESPINUDP, - SrcPort: childSecurityAssociation.NATPort, - DstPort: childSecurityAssociation.N3IWFPort, - } - } + xfrmState.ESN = childSecurityAssociation.EsnInfo.GetNeedESN() // Commit xfrm state to netlink var err error if err = netlink.XfrmStateAdd(xfrmState); err != nil { - ikeLog.Errorf("Set XFRM rules failed: %+v", err) - return errors.New("Set XFRM state rule failed") + return errors.Wrapf(err, "Add XFRM state") } childSecurityAssociation.XfrmStateList = append(childSecurityAssociation.XfrmStateList, *xfrmState) @@ -138,8 +128,7 @@ func ApplyXFRMRule(n3iwf_is_initiator bool, xfrmiId uint32, // Commit xfrm policy to netlink if err = netlink.XfrmPolicyAdd(xfrmPolicy); err != nil { - ikeLog.Errorf("Set XFRM rules failed: %+v", err) - return errors.New("Set XFRM policy rule failed") + return errors.Wrapf(err, "Add XFRM policy") } childSecurityAssociation.XfrmPolicyList = append(childSecurityAssociation.XfrmPolicyList, *xfrmPolicy) @@ -148,26 +137,34 @@ func ApplyXFRMRule(n3iwf_is_initiator bool, xfrmiId uint32, // State if n3iwf_is_initiator { xfrmEncryptionAlgorithm.Key = childSecurityAssociation.InitiatorToResponderEncryptionKey - if childSecurityAssociation.IntegrityAlgorithm != 0 { + if childSecurityAssociation.IntegKInfo != nil { xfrmIntegrityAlgorithm.Key = childSecurityAssociation.InitiatorToResponderIntegrityKey } } else { xfrmEncryptionAlgorithm.Key = childSecurityAssociation.ResponderToInitiatorEncryptionKey - if childSecurityAssociation.IntegrityAlgorithm != 0 { + if childSecurityAssociation.IntegKInfo != nil { xfrmIntegrityAlgorithm.Key = childSecurityAssociation.ResponderToInitiatorIntegrityKey } } xfrmState.Spi = int(childSecurityAssociation.OutboundSPI) xfrmState.Src, xfrmState.Dst = xfrmState.Dst, xfrmState.Src + + if childSecurityAssociation.EnableEncapsulate { + xfrmState.Encap = &netlink.XfrmStateEncap{ + Type: netlink.XFRM_ENCAP_ESPINUDP, + SrcPort: childSecurityAssociation.NATPort, + DstPort: childSecurityAssociation.N3IWFPort, + } + } + if xfrmState.Encap != nil { xfrmState.Encap.SrcPort, xfrmState.Encap.DstPort = xfrmState.Encap.DstPort, xfrmState.Encap.SrcPort } // Commit xfrm state to netlink if err = netlink.XfrmStateAdd(xfrmState); err != nil { - ikeLog.Errorf("Set XFRM rules failed: %+v", err) - return errors.New("Set XFRM state rule failed") + return errors.Wrapf(err, "Add XFRM state") } childSecurityAssociation.XfrmStateList = append(childSecurityAssociation.XfrmStateList, *xfrmState) @@ -184,54 +181,13 @@ func ApplyXFRMRule(n3iwf_is_initiator bool, xfrmiId uint32, // Commit xfrm policy to netlink if err = netlink.XfrmPolicyAdd(xfrmPolicy); err != nil { - ikeLog.Errorf("Set XFRM rules failed: %+v", err) - return errors.New("Set XFRM policy rule failed") + return errors.Wrapf(err, "Add XFRM policy") } childSecurityAssociation.XfrmPolicyList = append(childSecurityAssociation.XfrmPolicyList, *xfrmPolicy) - - printSAInfo(n3iwf_is_initiator, xfrmiId, childSecurityAssociation) - return nil } -func printSAInfo(n3iwf_is_initiator bool, xfrmiId uint32, childSecurityAssociation *context.ChildSecurityAssociation) { - ikeLog := logger.IKELog - var InboundEncryptionKey, InboundIntegrityKey, OutboundEncryptionKey, OutboundIntegrityKey []byte - - if n3iwf_is_initiator { - InboundEncryptionKey = childSecurityAssociation.ResponderToInitiatorEncryptionKey - InboundIntegrityKey = childSecurityAssociation.ResponderToInitiatorIntegrityKey - OutboundEncryptionKey = childSecurityAssociation.InitiatorToResponderEncryptionKey - OutboundIntegrityKey = childSecurityAssociation.InitiatorToResponderIntegrityKey - } else { - InboundEncryptionKey = childSecurityAssociation.InitiatorToResponderEncryptionKey - InboundIntegrityKey = childSecurityAssociation.InitiatorToResponderIntegrityKey - OutboundEncryptionKey = childSecurityAssociation.ResponderToInitiatorEncryptionKey - OutboundIntegrityKey = childSecurityAssociation.ResponderToInitiatorIntegrityKey - } - ikeLog.Debug("====== IPSec/Child SA Info ======") - // ====== Inbound ====== - ikeLog.Debugf("XFRM interface if_id: %d", xfrmiId) - ikeLog.Debugf("IPSec Inbound SPI: 0x%016x", childSecurityAssociation.InboundSPI) - ikeLog.Debugf("[UE:%+v] -> [N3IWF:%+v]", - childSecurityAssociation.PeerPublicIPAddr, childSecurityAssociation.LocalPublicIPAddr) - ikeLog.Debugf("IPSec Encryption Algorithm: %d", childSecurityAssociation.EncryptionAlgorithm) - ikeLog.Debugf("IPSec Encryption Key: 0x%x", InboundEncryptionKey) - ikeLog.Debugf("IPSec Integrity Algorithm: %d", childSecurityAssociation.IntegrityAlgorithm) - ikeLog.Debugf("IPSec Integrity Key: 0x%x", InboundIntegrityKey) - ikeLog.Debug("====== IPSec/Child SA Info ======") - // ====== Outbound ====== - ikeLog.Debugf("XFRM interface if_id: %d", xfrmiId) - ikeLog.Debugf("IPSec Outbound SPI: 0x%016x", childSecurityAssociation.OutboundSPI) - ikeLog.Debugf("[N3IWF:%+v] -> [UE:%+v]", - childSecurityAssociation.LocalPublicIPAddr, childSecurityAssociation.PeerPublicIPAddr) - ikeLog.Debugf("IPSec Encryption Algorithm: %d", childSecurityAssociation.EncryptionAlgorithm) - ikeLog.Debugf("IPSec Encryption Key: 0x%x", OutboundEncryptionKey) - ikeLog.Debugf("IPSec Integrity Algorithm: %d", childSecurityAssociation.IntegrityAlgorithm) - ikeLog.Debugf("IPSec Integrity Key: 0x%x", OutboundIntegrityKey) -} - func SetupIPsecXfrmi(xfrmIfaceName, parentIfaceName string, xfrmIfaceId uint32, xfrmIfaceAddr net.IPNet, ) (netlink.Link, error) { diff --git a/internal/logger/logger.go b/internal/logger/logger.go index cca85e34..bb946ea3 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -8,12 +8,13 @@ import ( var ( Log *logrus.Logger - NfLog *logrus.Entry + N3iwfLog *logrus.Entry MainLog *logrus.Entry InitLog *logrus.Entry CfgLog *logrus.Entry CtxLog *logrus.Entry GinLog *logrus.Entry + NasLog *logrus.Entry NgapLog *logrus.Entry IKELog *logrus.Entry GTPLog *logrus.Entry @@ -21,23 +22,24 @@ var ( NWuUPLog *logrus.Entry RelayLog *logrus.Entry UtilLog *logrus.Entry + GmmLog *logrus.Entry ) -func UpdateNfLog(s string) { - NfLog = Log.WithField(logger_util.FieldNF, s) - // update logs created from NfLog - MainLog = NfLog.WithField(logger_util.FieldCategory, "Main") - InitLog = NfLog.WithField(logger_util.FieldCategory, "Init") - CfgLog = NfLog.WithField(logger_util.FieldCategory, "CFG") - CtxLog = NfLog.WithField(logger_util.FieldCategory, "CTX") - GinLog = NfLog.WithField(logger_util.FieldCategory, "GIN") - NgapLog = NfLog.WithField(logger_util.FieldCategory, "NGAP") - IKELog = NfLog.WithField(logger_util.FieldCategory, "IKE") - GTPLog = NfLog.WithField(logger_util.FieldCategory, "GTP") - NWuCPLog = NfLog.WithField(logger_util.FieldCategory, "NWuCP") - NWuUPLog = NfLog.WithField(logger_util.FieldCategory, "NWuUP") - RelayLog = NfLog.WithField(logger_util.FieldCategory, "Relay") - UtilLog = NfLog.WithField(logger_util.FieldCategory, "Util") +func UpdateN3iwfLog() { + N3iwfLog = Log.WithField(logger_util.FieldNF, "N3IWF") + // update logs created from N3iwfLog + MainLog = N3iwfLog.WithField(logger_util.FieldCategory, "Main") + InitLog = N3iwfLog.WithField(logger_util.FieldCategory, "Init") + CfgLog = N3iwfLog.WithField(logger_util.FieldCategory, "CFG") + CtxLog = N3iwfLog.WithField(logger_util.FieldCategory, "CTX") + GinLog = N3iwfLog.WithField(logger_util.FieldCategory, "GIN") + NgapLog = N3iwfLog.WithField(logger_util.FieldCategory, "NGAP") + IKELog = N3iwfLog.WithField(logger_util.FieldCategory, "IKE") + GTPLog = N3iwfLog.WithField(logger_util.FieldCategory, "GTP") + NWuCPLog = N3iwfLog.WithField(logger_util.FieldCategory, "NWuCP") + NWuUPLog = N3iwfLog.WithField(logger_util.FieldCategory, "NWuUP") + RelayLog = N3iwfLog.WithField(logger_util.FieldCategory, "Relay") + UtilLog = N3iwfLog.WithField(logger_util.FieldCategory, "Util") } func init() { @@ -46,5 +48,5 @@ func init() { logger_util.FieldCategory, } Log = logger_util.New(fieldsOrder) - UpdateNfLog("N3IWF") + UpdateN3iwfLog() } diff --git a/internal/nas/nas_security/security.go b/internal/nas/nas_security/security.go new file mode 100644 index 00000000..6670babe --- /dev/null +++ b/internal/nas/nas_security/security.go @@ -0,0 +1,18 @@ +package nas_security + +import ( + "encoding/binary" +) + +// API for N3IWF +func EncapNasMsgToEnvelope(nasPDU []byte) []byte { + // According to TS 24.502 8.2.4, + // in order to transport a NAS message over the non-3GPP access between the UE and the N3IWF, + // the NAS message shall be framed in a NAS message envelope as defined in subclause 9.4. + // According to TS 24.502 9.4, + // a NAS message envelope = Length | NAS Message + nasEnv := make([]byte, 2) + binary.BigEndian.PutUint16(nasEnv, uint16(len(nasPDU))) + nasEnv = append(nasEnv, nasPDU...) + return nasEnv +} diff --git a/internal/ngap/handler.go b/internal/ngap/handler.go index b4c20df6..208303f1 100644 --- a/internal/ngap/handler.go +++ b/internal/ngap/handler.go @@ -2,6 +2,7 @@ package ngap import ( "encoding/binary" + "math" "net" "time" @@ -9,9 +10,10 @@ import ( "github.com/wmnsk/go-gtp/gtpv1" "github.com/free5gc/aper" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" + "github.com/free5gc/n3iwf/internal/nas/nas_security" "github.com/free5gc/n3iwf/internal/ngap/message" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/ngap/ngapConvert" "github.com/free5gc/ngap/ngapType" "github.com/free5gc/sctp" @@ -23,7 +25,7 @@ func (s *Server) HandleNGSetupResponse( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle NG Setup Response") + ngapLog.Infoln("Handle NG Setup Response") var amfName *ngapType.AMFName var servedGUAMIList *ngapType.ServedGUAMIList @@ -136,7 +138,7 @@ func (s *Server) HandleNGSetupFailure( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle NG Setup Failure") + ngapLog.Infoln("Handle NG Setup Failure") var cause *ngapType.Cause var timeToWait *ngapType.TimeToWait @@ -243,7 +245,7 @@ func (s *Server) HandleNGReset( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle NG Reset") + ngapLog.Infoln("Handle NG Reset") var cause *ngapType.Cause var resetType *ngapType.ResetType @@ -320,7 +322,7 @@ func (s *Server) HandleNGReset( return } - var ranUe *n3iwf_context.N3IWFRanUe + var ranUe n3iwf_context.RanUe for _, ueAssociatedLogicalNGConnectionItem := range partOfNGInterface.List { if ueAssociatedLogicalNGConnectionItem.RANUENGAPID != nil { @@ -357,7 +359,7 @@ func (s *Server) HandleNGResetAcknowledge( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle NG Reset Acknowledge") + ngapLog.Infoln("Handle NG Reset Acknowledge") var uEAssociatedLogicalNGConnectionList *ngapType.UEAssociatedLogicalNGConnectionList var criticalityDiagnostics *ngapType.CriticalityDiagnostics @@ -419,7 +421,7 @@ func (s *Server) HandleInitialContextSetupRequest( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Initial Context Setup Request") + ngapLog.Infoln("Handle Initial Context Setup Request") var amfUeNgapID *ngapType.AMFUENGAPID var ranUeNgapID *ngapType.RANUENGAPID @@ -439,7 +441,9 @@ func (s *Server) HandleInitialContextSetupRequest( var emergencyFallbackIndicator *ngapType.EmergencyFallbackIndicator var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList - var ranUe *n3iwf_context.N3IWFRanUe + var ranUe n3iwf_context.RanUe + var ranUeCtx *n3iwf_context.RanUeSharedCtx + n3iwfCtx := s.Context() if pdu == nil { @@ -589,7 +593,8 @@ func (s *Server) HandleInitialContextSetupRequest( // Cause: Unknown local UE NGAP ID return } - if ranUe.AmfUeNgapId != amfUeNgapID.Value { + ranUeCtx = ranUe.GetSharedCtx() + if ranUeCtx.AmfUeNgapId != amfUeNgapID.Value { // TODO: build cause and handle error // Cause: Inconsistent remote UE NGAP ID return @@ -601,12 +606,12 @@ func (s *Server) HandleInitialContextSetupRequest( return } - ranUe.AmfUeNgapId = amfUeNgapID.Value - ranUe.RanUeNgapId = ranUeNgapID.Value + ranUeCtx.AmfUeNgapId = amfUeNgapID.Value + ranUeCtx.RanUeNgapId = ranUeNgapID.Value if pduSessionResourceSetupListCxtReq != nil { if ueAggregateMaximumBitRate != nil { - ranUe.Ambr = ueAggregateMaximumBitRate + ranUeCtx.Ambr = ueAggregateMaximumBitRate } else { ngapLog.Errorln("IE[UEAggregateMaximumBitRate] is nil") cause := message.BuildCause(ngapType.CausePresentProtocol, @@ -635,11 +640,11 @@ func (s *Server) HandleInitialContextSetupRequest( failedListCxtRes := new(ngapType.PDUSessionResourceFailedToSetupListCxtRes) // UE temporary data for PDU session setup response - ranUe.TemporaryPDUSessionSetupData.SetupListCxtRes = setupListCxtRes - ranUe.TemporaryPDUSessionSetupData.FailedListCxtRes = failedListCxtRes - ranUe.TemporaryPDUSessionSetupData.Index = 0 - ranUe.TemporaryPDUSessionSetupData.UnactivatedPDUSession = nil - ranUe.TemporaryPDUSessionSetupData.NGAPProcedureCode.Value = ngapType.ProcedureCodeInitialContextSetup + ranUeCtx.TemporaryPDUSessionSetupData.SetupListCxtRes = setupListCxtRes + ranUeCtx.TemporaryPDUSessionSetupData.FailedListCxtRes = failedListCxtRes + ranUeCtx.TemporaryPDUSessionSetupData.Index = 0 + ranUeCtx.TemporaryPDUSessionSetupData.UnactivatedPDUSession = nil + ranUeCtx.TemporaryPDUSessionSetupData.NGAPProcedureCode.Value = ngapType.ProcedureCodeInitialContextSetup for _, item := range pduSessionResourceSetupListCxtReq.List { pduSessionID := item.PDUSessionID.Value @@ -654,7 +659,7 @@ func (s *Server) HandleInitialContextSetupRequest( pduSessionID, err) } - pduSession, err := ranUe.CreatePDUSession(pduSessionID, snssai) + pduSession, err := ranUeCtx.CreatePDUSession(pduSessionID, snssai) if err != nil { ngapLog.Errorf("Create PDU Session Error: %v\n", err) @@ -670,16 +675,15 @@ func (s *Server) HandleInitialContextSetupRequest( continue } - success, resTransfer := s.handlePDUSessionResourceSetupRequestTransfer( - ranUe, pduSession, transfer) + success, resTransfer := s.handlePDUSessionResourceSetupRequestTransfer(ranUe, pduSession, transfer) if success { // Append this PDU session to unactivated PDU session list - ranUe.TemporaryPDUSessionSetupData.UnactivatedPDUSession = append( - ranUe.TemporaryPDUSessionSetupData.UnactivatedPDUSession, + ranUeCtx.TemporaryPDUSessionSetupData.UnactivatedPDUSession = append( + ranUeCtx.TemporaryPDUSessionSetupData.UnactivatedPDUSession, pduSession) } else { // Delete the pdusession store in UE conext - delete(ranUe.PduSessionList, pduSessionID) + delete(ranUeCtx.PduSessionList, pduSessionID) message. AppendPDUSessionResourceFailedToSetupListCxtRes(failedListCxtRes, pduSessionID, resTransfer) } @@ -691,46 +695,48 @@ func (s *Server) HandleInitialContextSetupRequest( } if guami != nil { - ranUe.Guami = guami + ranUeCtx.Guami = guami } if allowedNSSAI != nil { - ranUe.AllowedNssai = allowedNSSAI + ranUeCtx.AllowedNssai = allowedNSSAI } if maskedIMEISV != nil { - ranUe.MaskedIMEISV = maskedIMEISV + ranUeCtx.MaskedIMEISV = maskedIMEISV } if ueRadioCapability != nil { - ranUe.RadioCapability = ueRadioCapability + ranUeCtx.RadioCapability = ueRadioCapability } if coreNetworkAssistanceInformation != nil { - ranUe.CoreNetworkAssistanceInformation = coreNetworkAssistanceInformation + ranUeCtx.CoreNetworkAssistanceInformation = coreNetworkAssistanceInformation } if indexToRFSP != nil { - ranUe.IndexToRfsp = indexToRFSP.Value + ranUeCtx.IndexToRfsp = indexToRFSP.Value } if ueSecurityCapabilities != nil { - ranUe.SecurityCapabilities = ueSecurityCapabilities - } - - spi, ok := n3iwfCtx.IkeSpiLoad(ranUe.RanUeNgapId) - if !ok { - ngapLog.Errorf("Cannot get spi from ngapid : %+v", ranUe.RanUeNgapId) - return + ranUeCtx.SecurityCapabilities = ueSecurityCapabilities } - // if nasPDU != nil { - // TODO: Send NAS UE - // } - // Send EAP Success to UE - s.IkeEvtCh() <- n3iwf_context.NewSendEAPSuccessMsgEvt(spi, securityKey.Value.Bytes, - len(ranUe.PduSessionList)) + switch ue := ranUe.(type) { + case *n3iwf_context.N3IWFRanUe: + spi, ok := n3iwfCtx.IkeSpiLoad(ranUeCtx.RanUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get spi from ngapid : %+v", ranUeCtx.RanUeNgapId) + return + } + + s.SendIkeEvt(n3iwf_context.NewSendEAPSuccessMsgEvt( + spi, securityKey.Value.Bytes, len(ranUeCtx.PduSessionList), + )) + default: + ngapLog.Errorf("Unknown UE type: %T", ue) + } } // handlePDUSessionResourceSetupRequestTransfer parse and store needed information from NGAP @@ -742,7 +748,7 @@ func (s *Server) HandleInitialContextSetupRequest( // a status value indicate whether the handlling is "success" :: // if failed, an unsuccessfulTransfer is set, otherwise, set to nil func (s *Server) handlePDUSessionResourceSetupRequestTransfer( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, pduSession *n3iwf_context.PDUSession, transfer ngapType.PDUSessionResourceSetupRequestTransfer, ) (bool, []byte) { @@ -853,15 +859,22 @@ func (s *Server) handlePDUSessionResourceSetupRequestTransfer( qosFlow.Identifier = item.QosFlowIdentifier.Value qosFlow.Parameters = item.QosFlowLevelQosParameters pduSession.QosFlows[item.QosFlowIdentifier.Value] = qosFlow + + value := item.QosFlowIdentifier.Value + if value < 0 || value > math.MaxUint8 { + ngapLog.Errorf("handlePDUSessionResourceSetupRequestTransfer() "+ + "item.QosFlowIdentifier.Value exceeds uint8 range: %d", value) + return false, nil + } // QFI List - pduSession.QFIList = append(pduSession.QFIList, uint8(item.QosFlowIdentifier.Value)) + pduSession.QFIList = append(pduSession.QFIList, uint8(value)) } // Setup GTP tunnel with UPF // TODO: Support IPv6 upfIPv4, _ := ngapConvert.IPAddressToString(ulNGUUPTNLInformation.GTPTunnel.TransportLayerAddress) if upfIPv4 != "" { - gtpConnection := &n3iwf_context.GTPConnectionInfo{ + gtpConnInfo := &n3iwf_context.GTPConnectionInfo{ UPFIPAddr: upfIPv4, OutgoingTEID: binary.BigEndian.Uint32(ulNGUUPTNLInformation.GTPTunnel.GTPTEID.Value), } @@ -870,10 +883,12 @@ func (s *Server) handlePDUSessionResourceSetupRequestTransfer( upfAddr := upfIPv4 + gtpv1.GTPUPort upfUDPAddr, err := net.ResolveUDPAddr("udp", upfAddr) if err != nil { + var responseTransfer []byte + ngapLog.Errorf("Resolve UPF addr [%s] failed: %v", upfAddr, err) cause := message.BuildCause(ngapType.CausePresentTransport, ngapType.CauseTransportPresentTransportResourceUnavailable) - responseTransfer, err := message.BuildPDUSessionResourceSetupUnsuccessfulTransfer(*cause, nil) + responseTransfer, err = message.BuildPDUSessionResourceSetupUnsuccessfulTransfer(*cause, nil) if err != nil { ngapLog.Errorf("Build PDUSessionResourceSetupUnsuccessfulTransfer Error: %v\n", err) } @@ -883,11 +898,13 @@ func (s *Server) handlePDUSessionResourceSetupRequestTransfer( // UE TEID ueTEID := n3iwfCtx.NewTEID(ranUe) if ueTEID == 0 { + var responseTransfer []byte + ngapLog.Error("Invalid TEID (0).") cause := message.BuildCause( ngapType.CausePresentProtocol, ngapType.CauseProtocolPresentUnspecified) - responseTransfer, err := message.BuildPDUSessionResourceSetupUnsuccessfulTransfer(*cause, nil) + responseTransfer, err = message.BuildPDUSessionResourceSetupUnsuccessfulTransfer(*cause, nil) if err != nil { ngapLog.Errorf("Build PDUSessionResourceSetupUnsuccessfulTransfer Error: %v\n", err) } @@ -895,10 +912,10 @@ func (s *Server) handlePDUSessionResourceSetupRequestTransfer( } // Setup GTP connection with UPF - gtpConnection.UPFUDPAddr = upfUDPAddr - gtpConnection.IncomingTEID = ueTEID + gtpConnInfo.UPFUDPAddr = upfUDPAddr + gtpConnInfo.IncomingTEID = ueTEID - pduSession.GTPConnection = gtpConnection + pduSession.GTPConnInfo = gtpConnInfo } else { ngapLog.Error( "Cannot parse \"PDU session resource setup request transfer\" message \"UL NG-U UP TNL Information\"") @@ -919,7 +936,7 @@ func (s *Server) HandleUEContextModificationRequest( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle UE Context Modification Request") + ngapLog.Infoln("Handle UE Context Modification Request") if amf == nil { ngapLog.Error("Corresponding AMF context not found") @@ -935,7 +952,9 @@ func (s *Server) HandleUEContextModificationRequest( var indexToRFSP *ngapType.IndexToRFSP var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList - var ranUe *n3iwf_context.N3IWFRanUe + var ranUe n3iwf_context.RanUe + var ranUeCtx *n3iwf_context.RanUeSharedCtx + n3iwfCtx := s.Context() if pdu == nil { @@ -1014,7 +1033,8 @@ func (s *Server) HandleUEContextModificationRequest( // Cause: Unknown local UE NGAP ID return } - if ranUe.AmfUeNgapId != amfUeNgapID.Value { + ranUeCtx = ranUe.GetSharedCtx() + if ranUeCtx.AmfUeNgapId != amfUeNgapID.Value { // TODO: build cause and handle error // Cause: Inconsistent remote UE NGAP ID return @@ -1023,34 +1043,33 @@ func (s *Server) HandleUEContextModificationRequest( if newAmfUeNgapID != nil { ngapLog.Debugf("New AmfUeNgapID[%d]\n", newAmfUeNgapID.Value) - ranUe.AmfUeNgapId = newAmfUeNgapID.Value + ranUeCtx.AmfUeNgapId = newAmfUeNgapID.Value } if ueAggregateMaximumBitRate != nil { - ranUe.Ambr = ueAggregateMaximumBitRate + ranUeCtx.Ambr = ueAggregateMaximumBitRate // TODO: use the received UE Aggregate Maximum Bit Rate for all non-GBR QoS flows } if ueSecurityCapabilities != nil { - ranUe.SecurityCapabilities = ueSecurityCapabilities + ranUeCtx.SecurityCapabilities = ueSecurityCapabilities } // TODO: use new security key to update security context if indexToRFSP != nil { - ranUe.IndexToRfsp = indexToRFSP.Value + ranUeCtx.IndexToRfsp = indexToRFSP.Value } message.SendUEContextModificationResponse(ranUe, nil) - spi, ok := n3iwfCtx.IkeSpiLoad(ranUe.RanUeNgapId) + spi, ok := n3iwfCtx.IkeSpiLoad(ranUeCtx.RanUeNgapId) if !ok { - ngapLog.Errorf("Cannot get spi from ngapid : %+v", ranUe.RanUeNgapId) + ngapLog.Errorf("Cannot get spi from ngapid : %+v", ranUeCtx.RanUeNgapId) return } - s.IkeEvtCh() <- n3iwf_context.NewIKEContextUpdateEvt(spi, - securityKey.Value.Bytes) // Kn3iwf + s.SendIkeEvt(n3iwf_context.NewIKEContextUpdateEvt(spi, securityKey.Value.Bytes)) // Kn3iwf } func (s *Server) HandleUEContextReleaseCommand( @@ -1058,7 +1077,7 @@ func (s *Server) HandleUEContextReleaseCommand( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle UE Context Release Command") + ngapLog.Infoln("Handle UE Context Release Command") if amf == nil { ngapLog.Error("Corresponding AMF context not found") @@ -1068,8 +1087,8 @@ func (s *Server) HandleUEContextReleaseCommand( var ueNgapIDs *ngapType.UENGAPIDs var cause *ngapType.Cause var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList + var ranUe n3iwf_context.RanUe - var ranUe *n3iwf_context.N3IWFRanUe n3iwfCtx := s.Context() if pdu == nil { @@ -1135,6 +1154,8 @@ func (s *Server) HandleUEContextReleaseCommand( printAndGetCause(cause) } + ranUe.GetSharedCtx().UeCtxRelState = n3iwf_context.UeCtxRelStateOngoing + message.SendUEContextReleaseComplete(ranUe, nil) err := s.releaseIkeUeAndRanUe(ranUe) @@ -1143,37 +1164,27 @@ func (s *Server) HandleUEContextReleaseCommand( } } -func (s *Server) releaseIkeUeAndRanUe(ranUe *n3iwf_context.N3IWFRanUe) error { +func (s *Server) releaseIkeUeAndRanUe(ranUe n3iwf_context.RanUe) error { n3iwfCtx := s.Context() - localSPI, ok := n3iwfCtx.IkeSpiLoad(ranUe.RanUeNgapId) + ranUeNgapID := ranUe.GetSharedCtx().RanUeNgapId + + localSPI, ok := n3iwfCtx.IkeSpiLoad(ranUeNgapID) if ok { - s.IkeEvtCh() <- n3iwf_context.NewIKEDeleteRequestEvt(localSPI) + s.SendIkeEvt(n3iwf_context.NewIKEDeleteRequestEvt(localSPI)) } if err := ranUe.Remove(); err != nil { - return errors.Wrapf(err, "releaseIkeUeAndRanUe RanUeNgapId[%016x]", ranUe.RanUeNgapId) + return errors.Wrapf(err, "releaseIkeUeAndRanUe RanUeNgapId[%016x]", ranUeNgapID) } return nil } -func encapNasMsgToEnvelope(nasPDU *ngapType.NASPDU) []byte { - // According to TS 24.502 8.2.4, - // in order to transport a NAS message over the non-3GPP access between the UE and the N3IWF, - // the NAS message shall be framed in a NAS message envelope as defined in subclause 9.4. - // According to TS 24.502 9.4, - // a NAS message envelope = Length | NAS Message - nasEnv := make([]byte, 2) - binary.BigEndian.PutUint16(nasEnv, uint16(len(nasPDU.Value))) - nasEnv = append(nasEnv, nasPDU.Value...) - return nasEnv -} - func (s *Server) HandleDownlinkNASTransport( amf *n3iwf_context.N3IWFAMF, pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Downlink NAS Transport") + ngapLog.Infoln("Handle Downlink NAS Transport") if amf == nil { ngapLog.Error("Corresponding AMF context not found") @@ -1188,8 +1199,8 @@ func (s *Server) HandleDownlinkNASTransport( var ueAggregateMaximumBitRate *ngapType.UEAggregateMaximumBitRate var allowedNSSAI *ngapType.AllowedNSSAI var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList + var ranUe n3iwf_context.RanUe - var ranUe *n3iwf_context.N3IWFRanUe n3iwfCtx := s.Context() if pdu == nil { @@ -1265,13 +1276,14 @@ func (s *Server) HandleDownlinkNASTransport( return } } + ranUeCtx := ranUe.GetSharedCtx() if amfUeNgapID != nil { - if ranUe.AmfUeNgapId == n3iwf_context.AmfUeNgapIdUnspecified { + if ranUeCtx.AmfUeNgapId == n3iwf_context.AmfUeNgapIdUnspecified { ngapLog.Tracef("Create new logical UE-associated NG-connection") - ranUe.AmfUeNgapId = amfUeNgapID.Value + ranUeCtx.AmfUeNgapId = amfUeNgapID.Value } else { - if ranUe.AmfUeNgapId != amfUeNgapID.Value { + if ranUeCtx.AmfUeNgapId != amfUeNgapID.Value { ngapLog.Warn("AMFUENGAPID unmatched") return } @@ -1283,46 +1295,48 @@ func (s *Server) HandleDownlinkNASTransport( } if indexToRFSP != nil { - ranUe.IndexToRfsp = indexToRFSP.Value + ranUeCtx.IndexToRfsp = indexToRFSP.Value } if ueAggregateMaximumBitRate != nil { - ranUe.Ambr = ueAggregateMaximumBitRate + ranUeCtx.Ambr = ueAggregateMaximumBitRate } if allowedNSSAI != nil { - ranUe.AllowedNssai = allowedNSSAI + ranUeCtx.AllowedNssai = allowedNSSAI } if nasPDU != nil { - // TODO: Send NAS PDU to UE - - // Send EAP5G NAS to UE - spi, ok := n3iwfCtx.IkeSpiLoad(ranUe.RanUeNgapId) - if !ok { - ngapLog.Errorf("Cannot get SPI from RanUeNGAPId : %+v", ranUe.RanUeNgapId) - return - } + switch ue := ranUe.(type) { + case *n3iwf_context.N3IWFRanUe: + // Send EAP5G NAS to UE + spi, ok := n3iwfCtx.IkeSpiLoad(ue.RanUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get SPI from RanUeNGAPId : %+v", ue.RanUeNgapId) + return + } - if !ranUe.IsNASTCPConnEstablished { - s.IkeEvtCh() <- n3iwf_context.NewSendEAPNASMsgEvt(spi, - []byte(nasPDU.Value)) - } else { - // Using a "NAS message envelope" to transport a NAS message - // over the non-3GPP access between the UE and the N3IWF - nasEnv := encapNasMsgToEnvelope(nasPDU) - - if ranUe.IsNASTCPConnEstablishedComplete { - // Send to UE - if n, err := ranUe.TCPConnection.Write(nasEnv); err != nil { - ngapLog.Errorf("Writing via IPSec signalling SA failed: %v", err) + if !ue.IsNASTCPConnEstablished { + s.SendIkeEvt(n3iwf_context.NewSendEAPNASMsgEvt(spi, []byte(nasPDU.Value))) + } else { + // Using a "NAS message envelope" to transport a NAS message + // over the non-3GPP access between the UE and the N3IWF + nasEnv := nas_security.EncapNasMsgToEnvelope([]byte(nasPDU.Value)) + + if ue.IsNASTCPConnEstablishedComplete { + // Send to UE + if n, err := ue.TCPConnection.Write(nasEnv); err != nil { + ngapLog.Errorf("Writing via IPSec signalling SA failed: %v", err) + } else { + ngapLog.Trace("Forward NWu <- N2") + ngapLog.Tracef("Wrote %d bytes", n) + } } else { - ngapLog.Trace("Forward NWu <- N2") - ngapLog.Tracef("Wrote %d bytes", n) + ue.TemporaryCachedNASMessage = nasEnv } - } else { - ranUe.TemporaryCachedNASMessage = nasEnv } + default: + ngapLog.Errorf("Unknown UE type: %T", ue) } } } @@ -1332,7 +1346,7 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle PDU Session Resource Setup Request") + ngapLog.Infoln("Handle PDU Session Resource Setup Request") if amf == nil { ngapLog.Error("Corresponding AMF context not found") @@ -1345,8 +1359,9 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( var pduSessionResourceSetupListSUReq *ngapType.PDUSessionResourceSetupListSUReq var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList var pduSessionEstablishmentAccept *ngapType.NASPDU + var ranUe n3iwf_context.RanUe + var ranUeCtx *n3iwf_context.RanUeSharedCtx - var ranUe *n3iwf_context.N3IWFRanUe n3iwfCtx := s.Context() if pdu == nil { @@ -1417,7 +1432,8 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( // Cause: Unknown local UE NGAP ID return } - if ranUe.AmfUeNgapId != amfUeNgapID.Value { + ranUeCtx = ranUe.GetSharedCtx() + if ranUeCtx.AmfUeNgapId != amfUeNgapID.Value { // TODO: build cause and handle error // Cause: Inconsistent remote UE NGAP ID return @@ -1425,17 +1441,21 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( } if nasPDU != nil { - // TODO: Send NAS to UE - if ranUe.TCPConnection == nil { + n3iwfUe, ok := ranUe.(*n3iwf_context.N3IWFRanUe) + if !ok { + ngapLog.Errorln("HandlePDUSessionResourceSetupRequest(): [Type Assertion] RanUe -> N3iwfRanUe failed") + return + } + if n3iwfUe.TCPConnection == nil { ngapLog.Error("No IPSec NAS signalling SA for this UE") return } // Using a "NAS message envelope" to transport a NAS message // over the non-3GPP access between the UE and the N3IWF - nasEnv := encapNasMsgToEnvelope(nasPDU) + nasEnv := nas_security.EncapNasMsgToEnvelope([]byte(nasPDU.Value)) - n, err := ranUe.TCPConnection.Write(nasEnv) + n, err := n3iwfUe.TCPConnection.Write(nasEnv) if err != nil { ngapLog.Errorf("Send NAS to UE failed: %v", err) return @@ -1443,7 +1463,7 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( ngapLog.Tracef("Wrote %d bytes", n) } - tempPDUSessionSetupData := ranUe.TemporaryPDUSessionSetupData + tempPDUSessionSetupData := ranUeCtx.TemporaryPDUSessionSetupData tempPDUSessionSetupData.NGAPProcedureCode.Value = ngapType.ProcedureCodeInitialContextSetup if pduSessionResourceSetupListSUReq != nil { @@ -1468,7 +1488,7 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( pduSessionID, err) } - pduSession, err := ranUe.CreatePDUSession(pduSessionID, snssai) + pduSession, err := ranUeCtx.CreatePDUSession(pduSessionID, snssai) if err != nil { ngapLog.Errorf("Create PDU Session Error: %v\n", err) @@ -1484,6 +1504,7 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( continue } + // Process the message for AN success, resTransfer := s.handlePDUSessionResourceSetupRequestTransfer( ranUe, pduSession, transfer) if success { @@ -1493,29 +1514,35 @@ func (s *Server) HandlePDUSessionResourceSetupRequest( pduSession) } else { // Delete the pdusession store in UE conext - delete(ranUe.PduSessionList, pduSessionID) + delete(ranUeCtx.PduSessionList, pduSessionID) message.AppendPDUSessionResourceFailedToSetupListSURes( failedListSURes, pduSessionID, resTransfer) } } } - if tempPDUSessionSetupData != nil { - spi, ok := n3iwfCtx.IkeSpiLoad(ranUe.RanUeNgapId) - if !ok { - ngapLog.Errorf("Cannot get SPI from ranNgapID : %+v", ranUeNgapID) - return - } - s.IkeEvtCh() <- n3iwf_context.NewCreatePDUSessionEvt(spi, - len(ranUe.PduSessionList), ranUe.TemporaryPDUSessionSetupData) + if tempPDUSessionSetupData != nil && len(tempPDUSessionSetupData.UnactivatedPDUSession) != 0 { + switch ue := ranUe.(type) { + case *n3iwf_context.N3IWFRanUe: + spi, ok := n3iwfCtx.IkeSpiLoad(ue.RanUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get SPI from ranNgapID : %+v", ranUeNgapID) + return + } + + s.SendIkeEvt(n3iwf_context.NewCreatePDUSessionEvt(spi, + len(ue.PduSessionList), + ue.TemporaryPDUSessionSetupData), + ) - // TS 23.501 4.12.5 Requested PDU Session Establishment via Untrusted non-3GPP Access - // After all IPsec Child SAs are established, the N3IWF shall forward to UE via the signalling IPsec SA - // the PDU Session Establishment Accept message - nasEnv := encapNasMsgToEnvelope(pduSessionEstablishmentAccept) + // TS 23.501 4.12.5 Requested PDU Session Establishment via Untrusted non-3GPP Access + // After all IPsec Child SAs are established, the N3IWF shall forward to UE via the signalling IPsec SA + // the PDU Session Establishment Accept message + nasEnv := nas_security.EncapNasMsgToEnvelope([]byte(pduSessionEstablishmentAccept.Value)) - // Cache the pduSessionEstablishmentAccept and forward to the UE after all CREATE_CHILD_SAs finish - ranUe.TemporaryCachedNASMessage = nasEnv + // Cache the pduSessionEstablishmentAccept and forward to the UE after all CREATE_CHILD_SAs finish + ue.TemporaryCachedNASMessage = nasEnv + } } } @@ -1524,7 +1551,7 @@ func (s *Server) HandlePDUSessionResourceModifyRequest( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle PDU Session Resource Modify Request") + ngapLog.Infoln("Handle PDU Session Resource Modify Request") if amf == nil { ngapLog.Error("Corresponding AMF context not found") @@ -1535,8 +1562,9 @@ func (s *Server) HandlePDUSessionResourceModifyRequest( var ranUeNgapID *ngapType.RANUENGAPID var pduSessionResourceModifyListModReq *ngapType.PDUSessionResourceModifyListModReq var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList + var ranUe n3iwf_context.RanUe + var ranUeCtx *n3iwf_context.RanUeSharedCtx - var ranUe *n3iwf_context.N3IWFRanUe n3iwfCtx := s.Context() if pdu == nil { @@ -1604,7 +1632,8 @@ func (s *Server) HandlePDUSessionResourceModifyRequest( // Cause: Unknown local UE NGAP ID return } - if ranUe.AmfUeNgapId != amfUeNgapID.Value { + ranUeCtx = ranUe.GetSharedCtx() + if ranUeCtx.AmfUeNgapId != amfUeNgapID.Value { // TODO: build cause and send error indication // Cause: Inconsistent remote UE NGAP ID return @@ -1627,7 +1656,7 @@ func (s *Server) HandlePDUSessionResourceModifyRequest( pduSessionID, err) } - if pduSession = ranUe.FindPDUSession(pduSessionID); pduSession == nil { + if pduSession = ranUeCtx.FindPDUSession(pduSessionID); pduSession == nil { ngapLog.Errorf("[PDUSessionID: %d] Unknown PDU session ID", pduSessionID) cause := message.BuildCause(ngapType.CausePresentRadioNetwork, @@ -1663,7 +1692,7 @@ func (s *Server) handlePDUSessionResourceModifyRequestTransfer( success bool, responseTransfer []byte, ) { ngapLog := logger.NgapLog - ngapLog.Trace("[N3IWF] Handle PDU Session Resource Modify Request Transfer") + ngapLog.Trace("Handle PDU Session Resource Modify Request Transfer") var pduSessionAMBR *ngapType.PDUSessionAggregateMaximumBitRate var ulNGUUPTNLModifyList *ngapType.ULNGUUPTNLModifyList @@ -1819,7 +1848,7 @@ func (s *Server) HandlePDUSessionResourceModifyConfirm( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle PDU Session Resource Modify Confirm") + ngapLog.Infoln("Handle PDU Session Resource Modify Confirm") var aMFUENGAPID *ngapType.AMFUENGAPID var rANUENGAPID *ngapType.RANUENGAPID @@ -1827,6 +1856,8 @@ func (s *Server) HandlePDUSessionResourceModifyConfirm( var pDUSessionResourceFailedToModifyListModCfm *ngapType.PDUSessionResourceFailedToModifyListModCfm var criticalityDiagnostics *ngapType.CriticalityDiagnostics // var iesCriticalityDiagnostics ngapType.CriticalityDiagnosticsIEList + var ranUe n3iwf_context.RanUe + var ranUeCtx *n3iwf_context.RanUeSharedCtx n3iwfCtx := s.Context() @@ -1872,8 +1903,6 @@ func (s *Server) HandlePDUSessionResourceModifyConfirm( } } - var ranUe *n3iwf_context.N3IWFRanUe - if rANUENGAPID != nil { var ok bool ranUe, ok = n3iwfCtx.RanUePoolLoad(rANUENGAPID.Value) @@ -1881,12 +1910,14 @@ func (s *Server) HandlePDUSessionResourceModifyConfirm( ngapLog.Errorf("Unknown local UE NGAP ID. RanUENGAPID: %d", rANUENGAPID.Value) return } + ranUeCtx = ranUe.GetSharedCtx() } + if aMFUENGAPID != nil { if ranUe != nil { - if ranUe.AmfUeNgapId != aMFUENGAPID.Value { + if ranUeCtx.AmfUeNgapId != aMFUENGAPID.Value { ngapLog.Errorf("Inconsistent remote UE NGAP ID, AMFUENGAPID: %d, RanUe.AmfUeNgapId: %d", - aMFUENGAPID.Value, ranUe.AmfUeNgapId) + aMFUENGAPID.Value, ranUeCtx.AmfUeNgapId) return } } else { @@ -1898,18 +1929,20 @@ func (s *Server) HandlePDUSessionResourceModifyConfirm( } } } + if ranUe == nil { ngapLog.Warn("RANUENGAPID and AMFUENGAPID are both nil") return } + if pDUSessionResourceModifyListModCfm != nil { for _, item := range pDUSessionResourceModifyListModCfm.List { pduSessionId := item.PDUSessionID.Value ngapLog.Tracef("PDU Session Id[%d] in Pdu Session Resource Modification Confrim List", pduSessionId) - sess, exist := ranUe.PduSessionList[pduSessionId] + sess, exist := ranUeCtx.PduSessionList[pduSessionId] if !exist { ngapLog.Warnf( - "PDU Session Id[%d] is not exist in Ue[ranUeNgapId:%d]", pduSessionId, ranUe.RanUeNgapId) + "PDU Session Id[%d] is not exist in Ue[ranUeNgapId:%d]", pduSessionId, ranUeCtx.RanUeNgapId) } else { transfer := ngapType.PDUSessionResourceModifyConfirmTransfer{} err := aper.UnmarshalWithParams(item.PDUSessionResourceModifyConfirmTransfer, &transfer, "valueExt") @@ -1943,7 +1976,7 @@ func (s *Server) HandlePDUSessionResourceModifyConfirm( } ngapLog.Tracef( "Release PDU Session Id[%d] due to PDU Session Resource Modify Indication Unsuccessful", pduSessionId) - delete(ranUe.PduSessionList, pduSessionId) + delete(ranUeCtx.PduSessionList, pduSessionId) } } @@ -1957,7 +1990,7 @@ func (s *Server) HandlePDUSessionResourceReleaseCommand( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle PDU Session Resource Release Command") + ngapLog.Infoln("Handle PDU Session Resource Release Command") var aMFUENGAPID *ngapType.AMFUENGAPID var rANUENGAPID *ngapType.RANUENGAPID // var rANPagingPriority *ngapType.RANPagingPriority @@ -2046,10 +2079,11 @@ func (s *Server) HandlePDUSessionResourceReleaseCommand( message.SendErrorIndication(amf, nil, nil, cause, nil) return } + ranUeCtx := ranUe.GetSharedCtx() - if ranUe.AmfUeNgapId != aMFUENGAPID.Value { + if ranUeCtx.AmfUeNgapId != aMFUENGAPID.Value { ngapLog.Errorf("Inconsistent remote UE NGAP ID, AMFUENGAPID: %d, RanUe.AmfUeNgapId: %d", - aMFUENGAPID.Value, ranUe.AmfUeNgapId) + aMFUENGAPID.Value, ranUeCtx.AmfUeNgapId) cause := message.BuildCause(ngapType.CausePresentRadioNetwork, ngapType.CauseRadioNetworkPresentInconsistentRemoteUENGAPID) message.SendErrorIndication(amf, nil, &rANUENGAPID.Value, cause, nil) @@ -2074,7 +2108,7 @@ func (s *Server) HandlePDUSessionResourceReleaseCommand( printAndGetCause(&transfer.Cause) } ngapLog.Tracef("Release PDU Session Id[%d] due to PDU Session Resource Release Command", pduSessionId) - delete(ranUe.PduSessionList, pduSessionId) + delete(ranUeCtx.PduSessionList, pduSessionId) // response list releaseItem := ngapType.PDUSessionResourceReleasedItemRelRes{ @@ -2091,11 +2125,11 @@ func (s *Server) HandlePDUSessionResourceReleaseCommand( ngapLog.Errorf("Cannot get SPI from RanUeNgapID : %+v", rANUENGAPID.Value) return } + ranUe.GetSharedCtx().PduSessResRelState = n3iwf_context.PduSessResRelStateOngoing - s.IkeEvtCh() <- n3iwf_context.NewSendChildSADeleteRequestEvt( - localSPI, releaseIdList) + s.SendIkeEvt(n3iwf_context.NewSendChildSADeleteRequestEvt(localSPI, releaseIdList)) - ranUe.PduSessionReleaseList = releaseList + ranUeCtx.PduSessionReleaseList = releaseList // if nASPDU != nil { // TODO: Send NAS to UE // } @@ -2106,7 +2140,7 @@ func (s *Server) HandleErrorIndication( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Error Indication") + ngapLog.Infoln("Handle Error Indication") var aMFUENGAPID *ngapType.AMFUENGAPID var rANUENGAPID *ngapType.RANUENGAPID @@ -2198,7 +2232,7 @@ func (s *Server) HandleUERadioCapabilityCheckRequest( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle UE Radio Capability Check Request") + ngapLog.Infoln("Handle UE Radio Capability Check Request") var aMFUENGAPID *ngapType.AMFUENGAPID var rANUENGAPID *ngapType.RANUENGAPID var uERadioCapability *ngapType.UERadioCapability @@ -2273,7 +2307,7 @@ func (s *Server) HandleUERadioCapabilityCheckRequest( return } - ranUe.RadioCapability = uERadioCapability + ranUe.GetSharedCtx().RadioCapability = uERadioCapability } func (s *Server) HandleAMFConfigurationUpdate( @@ -2281,7 +2315,7 @@ func (s *Server) HandleAMFConfigurationUpdate( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle AMF Configuration Updaet") + ngapLog.Infoln("Handle AMF Configuration Updaet") var aMFName *ngapType.AMFName var servedGUAMIList *ngapType.ServedGUAMIList @@ -2405,7 +2439,7 @@ func (s *Server) HandleRANConfigurationUpdateAcknowledge( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle RAN Configuration Update Acknowledge") + ngapLog.Infoln("Handle RAN Configuration Update Acknowledge") var criticalityDiagnostics *ngapType.CriticalityDiagnostics @@ -2449,7 +2483,7 @@ func (s *Server) HandleRANConfigurationUpdateFailure( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle RAN Configuration Update Failure") + ngapLog.Infoln("Handle RAN Configuration Update Failure") var cause *ngapType.Cause var timeToWait *ngapType.TimeToWait @@ -2536,35 +2570,35 @@ func (s *Server) HandleDownlinkRANConfigurationTransfer( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Downlink RAN Configuration Transfer") + ngapLog.Infoln("Handle Downlink RAN Configuration Transfer") } func (s *Server) HandleDownlinkRANStatusTransfer( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Downlink RAN Status Transfer") + ngapLog.Infoln("Handle Downlink RAN Status Transfer") } func (s *Server) HandleAMFStatusIndication( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle AMF Status Indication") + ngapLog.Infoln("Handle AMF Status Indication") } func (s *Server) HandleLocationReportingControl( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Location Reporting Control") + ngapLog.Infoln("Handle Location Reporting Control") } func (s *Server) HandleUETNLAReleaseRequest( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle UE TNLA Release Request") + ngapLog.Infoln("Handle UE TNLA Release Request") } func (s *Server) HandleOverloadStart( @@ -2572,7 +2606,7 @@ func (s *Server) HandleOverloadStart( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Overload Start") + ngapLog.Infoln("Handle Overload Start") var aMFOverloadResponse *ngapType.OverloadResponse var aMFTrafficLoadReductionIndication *ngapType.TrafficLoadReductionIndication @@ -2622,7 +2656,7 @@ func (s *Server) HandleOverloadStop( pdu *ngapType.NGAPPDU, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Handle Overload Stop") + ngapLog.Infoln("Handle Overload Stop") if amf == nil { ngapLog.Error("AMF Context is nil") @@ -2776,14 +2810,22 @@ func (s *Server) HandleEvent(ngapEvent n3iwf_context.NgapEvt) { s.HandleStartTCPSignalNASMsg(ngapEvent) case n3iwf_context.NASTCPConnEstablishedComplete: s.HandleNASTCPConnEstablishedComplete(ngapEvent) + case n3iwf_context.SendUEContextRelease: + s.HandleSendSendUEContextRelease(ngapEvent) case n3iwf_context.SendUEContextReleaseRequest: s.HandleSendUEContextReleaseRequest(ngapEvent) case n3iwf_context.SendUEContextReleaseComplete: s.HandleSendUEContextReleaseComplete(ngapEvent) + case n3iwf_context.SendPDUSessionResourceRelease: + s.HandleSendSendPDUSessionResourceRelease(ngapEvent) case n3iwf_context.SendPDUSessionResourceReleaseResponse: s.HandleSendPDUSessionResourceReleaseRes(ngapEvent) case n3iwf_context.GetNGAPContext: s.HandleGetNGAPContext(ngapEvent) + case n3iwf_context.SendUplinkNASTransport: + s.HandleSendUplinkNASTransport(ngapEvent) + case n3iwf_context.SendInitialContextSetupResponse: + s.HandleSendInitialContextSetupResponse(ngapEvent) default: ngapLog.Errorf("Undefine NGAP event type") return @@ -2796,9 +2838,9 @@ func (s *Server) HandleGetNGAPContext( ngapLog := logger.NgapLog ngapLog.Tracef("Handle HandleGetNGAPContext Event") - getNGAPContextEvt := ngapEvent.(*n3iwf_context.GetNGAPContextEvt) - ranUeNgapId := getNGAPContextEvt.RanUeNgapId - ngapCxtReqNumlist := getNGAPContextEvt.NgapCxtReqNumlist + evt := ngapEvent.(*n3iwf_context.GetNGAPContextEvt) + ranUeNgapId := evt.RanUeNgapId + ngapCxtReqNumlist := evt.NgapCxtReqNumlist n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) @@ -2812,7 +2854,7 @@ func (s *Server) HandleGetNGAPContext( for _, num := range ngapCxtReqNumlist { switch num { case n3iwf_context.CxtTempPDUSessionSetupData: - ngapCxt = append(ngapCxt, ranUe.TemporaryPDUSessionSetupData) + ngapCxt = append(ngapCxt, ranUe.GetSharedCtx().TemporaryPDUSessionSetupData) default: ngapLog.Errorf("Receive undefine NGAP Context Request number : %d", num) } @@ -2824,8 +2866,7 @@ func (s *Server) HandleGetNGAPContext( return } - s.IkeEvtCh() <- n3iwf_context.NewGetNGAPContextRepEvt(spi, - ngapCxtReqNumlist, ngapCxt) + s.SendIkeEvt(n3iwf_context.NewGetNGAPContextRepEvt(spi, ngapCxtReqNumlist, ngapCxt)) } func (s *Server) HandleUnmarshalEAP5GData( @@ -2834,10 +2875,10 @@ func (s *Server) HandleUnmarshalEAP5GData( ngapLog := logger.NgapLog ngapLog.Tracef("Handle UnmarshalEAP5GData Event") - unmarshalEAP5GDataEvt := ngapEvent.(*n3iwf_context.UnmarshalEAP5GDataEvt) - spi := unmarshalEAP5GDataEvt.LocalSPI - eapVendorData := unmarshalEAP5GDataEvt.EAPVendorData - isInitialUE := unmarshalEAP5GDataEvt.IsInitialUE + evt := ngapEvent.(*n3iwf_context.UnmarshalEAP5GDataEvt) + spi := evt.LocalSPI + eapVendorData := evt.EAPVendorData + isInitialUE := evt.IsInitialUE n3iwfCtx := s.Context() @@ -2881,19 +2922,25 @@ func (s *Server) HandleUnmarshalEAP5GData( selectedAMF := n3iwfCtx.AMFSelection(anParameters.GUAMI, anParameters.SelectedPLMNID) if selectedAMF == nil { - s.IkeEvtCh() <- n3iwf_context.NewSendEAP5GFailureMsgEvt(spi, n3iwf_context.ErrAMFSelection) + s.SendIkeEvt(n3iwf_context.NewSendEAP5GFailureMsgEvt(spi, n3iwf_context.ErrAMFSelection)) } else { - ranUe := n3iwfCtx.NewN3iwfRanUe() - ranUe.AMF = selectedAMF + n3iwfUe := n3iwfCtx.NewN3iwfRanUe() + n3iwfUe.AMF = selectedAMF if anParameters.EstablishmentCause != nil { - ranUe.RRCEstablishmentCause = int16(anParameters.EstablishmentCause.Value) + value := uint64(anParameters.EstablishmentCause.Value) + if value > uint64(math.MaxInt16) { + ngapLog.Errorf("HandleUnmarshalEAP5GData() anParameters.EstablishmentCause.Value "+ + "exceeds int16: %+v", value) + return + } else { + n3iwfUe.RRCEstablishmentCause = int16(value) + } } - s.IkeEvtCh() <- n3iwf_context.NewUnmarshalEAP5GDataResponseEvt(spi, - ranUe.RanUeNgapId, nasPDU) + s.SendIkeEvt(n3iwf_context.NewUnmarshalEAP5GDataResponseEvt(spi, n3iwfUe.RanUeNgapId, nasPDU)) } } else { - ranUeNgapId := unmarshalEAP5GDataEvt.RanUeNgapId + ranUeNgapId := evt.RanUeNgapId ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) if !ok { ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) @@ -2909,11 +2956,11 @@ func (s *Server) HandleSendInitialUEMessage( ngapLog := logger.NgapLog ngapLog.Tracef("Handle SendInitialUEMessage Event") - sendInitialUEMessageEvt := ngapEvent.(*n3iwf_context.SendInitialUEMessageEvt) - ranUeNgapId := sendInitialUEMessageEvt.RanUeNgapId - ipv4Addr := sendInitialUEMessageEvt.IPv4Addr - ipv4Port := sendInitialUEMessageEvt.IPv4Port - nasPDU := sendInitialUEMessageEvt.NasPDU + evt := ngapEvent.(*n3iwf_context.SendInitialUEMessageEvt) + ranUeNgapId := evt.RanUeNgapId + ipv4Addr := evt.IPv4Addr + ipv4Port := evt.IPv4Port + nasPDU := evt.NasPDU n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) @@ -2921,9 +2968,11 @@ func (s *Server) HandleSendInitialUEMessage( ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) return } - ranUe.IPAddrv4 = ipv4Addr - ranUe.PortNumber = int32(ipv4Port) - message.SendInitialUEMessage(ranUe.AMF, ranUe, nasPDU) + ranUeCtx := ranUe.GetSharedCtx() + + ranUeCtx.IPAddrv4 = ipv4Addr + ranUeCtx.PortNumber = int32(ipv4Port) // #nosec G115 + message.SendInitialUEMessage(ranUeCtx.AMF, ranUe, nasPDU) } func (s *Server) HandleSendPDUSessionResourceSetupResponse( @@ -2932,18 +2981,18 @@ func (s *Server) HandleSendPDUSessionResourceSetupResponse( ngapLog := logger.NgapLog ngapLog.Tracef("Handle SendPDUSessionResourceSetupResponse Event") - sendPDUSessionResourceSetupResEvt := ngapEvent.(*n3iwf_context.SendPDUSessionResourceSetupResEvt) - ranUeNgapId := sendPDUSessionResourceSetupResEvt.RanUeNgapId + evt := ngapEvent.(*n3iwf_context.SendPDUSessionResourceSetupResEvt) + ranUeNgapId := evt.RanUeNgapId n3iwfCtx := s.Context() - cfg := s.Config() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) if !ok { ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) return } + ranUeCtx := ranUe.GetSharedCtx() - temporaryPDUSessionSetupData := ranUe.TemporaryPDUSessionSetupData + temporaryPDUSessionSetupData := ranUeCtx.TemporaryPDUSessionSetupData if len(temporaryPDUSessionSetupData.UnactivatedPDUSession) != 0 { for index, pduSession := range temporaryPDUSessionSetupData.UnactivatedPDUSession { @@ -2977,9 +3026,15 @@ func (s *Server) HandleSendPDUSessionResourceSetupResponse( temporaryPDUSessionSetupData.FailedListSURes, pduSession.Id, transfer) } } else { + var gtpAddr string + switch ranUe.(type) { + case *n3iwf_context.N3IWFRanUe: + gtpAddr = s.Config().GetN3iwfGtpBindAddress() + } + // Append NGAP PDU session resource setup response transfer transfer, err := message.BuildPDUSessionResourceSetupResponseTransfer( - pduSession, cfg.GetGTPBindAddr()) + pduSession, gtpAddr) if err != nil { ngapLog.Errorf("Build PDU session resource setup response transfer failed: %v", err) return @@ -3014,8 +3069,8 @@ func (s *Server) HandleSendNASMsg( ngapLog := logger.NgapLog ngapLog.Tracef("Handle SendNASMsg Event") - sendNASMsgEvt := ngapEvent.(*n3iwf_context.SendNASMsgEvt) - ranUeNgapId := sendNASMsgEvt.RanUeNgapId + evt := ngapEvent.(*n3iwf_context.SendNASMsgEvt) + ranUeNgapId := evt.RanUeNgapId n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) @@ -3024,11 +3079,17 @@ func (s *Server) HandleSendNASMsg( return } - if n, ikeErr := ranUe.TCPConnection.Write(ranUe.TemporaryCachedNASMessage); ikeErr != nil { + n3iwfUe, ok := ranUe.(*n3iwf_context.N3IWFRanUe) + if !ok { + ngapLog.Errorln("HandleSendNASMsg(): [Type Assertion] RanUe -> N3iwfUe failed") + return + } + + if n, ikeErr := n3iwfUe.TCPConnection.Write(n3iwfUe.TemporaryCachedNASMessage); ikeErr != nil { ngapLog.Errorf("Writing via IPSec signalling SA failed: %v", ikeErr) } else { ngapLog.Tracef("Forward PDU Seesion Establishment Accept to UE. Wrote %d bytes", n) - ranUe.TemporaryCachedNASMessage = nil + n3iwfUe.TemporaryCachedNASMessage = nil } } @@ -3038,8 +3099,8 @@ func (s *Server) HandleStartTCPSignalNASMsg( ngapLog := logger.NgapLog ngapLog.Tracef("Handle StartTCPSignalNASMsg Event") - startTCPSignalNASMsgEvt := ngapEvent.(*n3iwf_context.StartTCPSignalNASMsgEvt) - ranUeNgapId := startTCPSignalNASMsgEvt.RanUeNgapId + evt := ngapEvent.(*n3iwf_context.StartTCPSignalNASMsgEvt) + ranUeNgapId := evt.RanUeNgapId n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) @@ -3048,7 +3109,13 @@ func (s *Server) HandleStartTCPSignalNASMsg( return } - ranUe.IsNASTCPConnEstablished = true + n3iwfUe, ok := ranUe.(*n3iwf_context.N3IWFRanUe) + if !ok { + ngapLog.Errorln("HandleStartTCPSignalNASMsg(): [Type Assertion] RanUe -> N3iwfUe failed") + return + } + + n3iwfUe.IsNASTCPConnEstablished = true } func (s *Server) HandleNASTCPConnEstablishedComplete( @@ -3057,8 +3124,8 @@ func (s *Server) HandleNASTCPConnEstablishedComplete( ngapLog := logger.NgapLog ngapLog.Tracef("Handle NASTCPConnEstablishedComplete Event") - nasTCPConnEstablishedCompleteEvt := ngapEvent.(*n3iwf_context.NASTCPConnEstablishedCompleteEvt) - ranUeNgapId := nasTCPConnEstablishedCompleteEvt.RanUeNgapId + evt := ngapEvent.(*n3iwf_context.NASTCPConnEstablishedCompleteEvt) + ranUeNgapId := evt.RanUeNgapId n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) @@ -3066,18 +3133,23 @@ func (s *Server) HandleNASTCPConnEstablishedComplete( ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) return } + n3iwfUe, ok := ranUe.(*n3iwf_context.N3IWFRanUe) + if !ok { + ngapLog.Errorln("HandleNASTCPConnEstablishedComplete(): [Type Assertion] RanUe -> N3iwfUe failed") + return + } - ranUe.IsNASTCPConnEstablishedComplete = true + n3iwfUe.IsNASTCPConnEstablishedComplete = true - if ranUe.TemporaryCachedNASMessage != nil { + if n3iwfUe.TemporaryCachedNASMessage != nil { // Send to UE - if n, err := ranUe.TCPConnection.Write(ranUe.TemporaryCachedNASMessage); err != nil { + if n, err := n3iwfUe.TCPConnection.Write(n3iwfUe.TemporaryCachedNASMessage); err != nil { ngapLog.Errorf("Writing via IPSec signalling SA failed: %v", err) } else { ngapLog.Trace("Forward NWu <- N2") ngapLog.Tracef("Wrote %d bytes", n) } - ranUe.TemporaryCachedNASMessage = nil + n3iwfUe.TemporaryCachedNASMessage = nil } } @@ -3087,10 +3159,10 @@ func (s *Server) HandleSendUEContextReleaseRequest( ngapLog := logger.NgapLog ngapLog.Tracef("Handle SendUEContextReleaseRequest Event") - sendUEContextReleaseReqEvt := ngapEvent.(*n3iwf_context.SendUEContextReleaseRequestEvt) + evt := ngapEvent.(*n3iwf_context.SendUEContextReleaseRequestEvt) - ranUeNgapId := sendUEContextReleaseReqEvt.RanUeNgapId - errMsg := sendUEContextReleaseReqEvt.ErrMsg + ranUeNgapId := evt.RanUeNgapId + errMsg := evt.ErrMsg var cause *ngapType.Cause switch errMsg { @@ -3119,8 +3191,8 @@ func (s *Server) HandleSendUEContextReleaseComplete( ngapLog := logger.NgapLog ngapLog.Tracef("Handle SendUEContextReleaseComplete Event") - sendUEContextReleaseCompleteEvt := ngapEvent.(*n3iwf_context.SendUEContextReleaseCompleteEvt) - ranUeNgapId := sendUEContextReleaseCompleteEvt.RanUeNgapId + evt := ngapEvent.(*n3iwf_context.SendUEContextReleaseCompleteEvt) + ranUeNgapId := evt.RanUeNgapId n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) @@ -3141,9 +3213,93 @@ func (s *Server) HandleSendPDUSessionResourceReleaseRes( ngapLog := logger.NgapLog ngapLog.Tracef("Handle SendPDUSessionResourceReleaseResponse Event") - sendPDUSessionResourceReleaseResEvt := ngapEvent.(*n3iwf_context.SendPDUSessionResourceReleaseResEvt) - ranUeNgapId := sendPDUSessionResourceReleaseResEvt.RanUeNgapId + evt := ngapEvent.(*n3iwf_context.SendPDUSessionResourceReleaseResEvt) + ranUeNgapId := evt.RanUeNgapId + + n3iwfCtx := s.Context() + ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) + return + } + + message.SendPDUSessionResourceReleaseResponse(ranUe, ranUe.GetSharedCtx().PduSessionReleaseList, nil) +} + +func (s *Server) HandleSendUplinkNASTransport( + ngapEvent n3iwf_context.NgapEvt, +) { + ngapLog := logger.NgapLog + ngapLog.Tracef("Handle SendUplinkNASTransport Event") + + evt := ngapEvent.(*n3iwf_context.SendUplinkNASTransportEvt) + ranUeNgapId := evt.RanUeNgapId + n3iwfCtx := s.Context() + ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) + return + } + + message.SendUplinkNASTransport(ranUe, evt.Pdu) +} + +func (s *Server) HandleSendInitialContextSetupResponse( + ngapEvent n3iwf_context.NgapEvt, +) { + ngapLog := logger.NgapLog + ngapLog.Tracef("Handle SendInitialContextSetupResponse Event") + + evt := ngapEvent.(*n3iwf_context.SendInitialContextSetupRespEvt) + ranUeNgapId := evt.RanUeNgapId + n3iwfCtx := s.Context() + ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) + return + } + + message.SendInitialContextSetupResponse(ranUe, evt.ResponseList, evt.FailedList, evt.CriticalityDiagnostics) +} + +func (s *Server) HandleSendSendUEContextRelease( + ngapEvent n3iwf_context.NgapEvt, +) { + ngapLog := logger.NgapLog + ngapLog.Tracef("Handle SendSendUEContextRelease Event") + + evt := ngapEvent.(*n3iwf_context.SendUEContextReleaseEvt) + ranUeNgapId := evt.RanUeNgapId + n3iwfCtx := s.Context() + ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) + if !ok { + ngapLog.Errorf("Cannot get RanUE from ranUeNgapId : %+v", ranUeNgapId) + return + } + + if ranUe.GetSharedCtx().UeCtxRelState { + if err := ranUe.Remove(); err != nil { + ngapLog.Errorf("Delete RanUe Context error : %v", err) + } + message.SendUEContextReleaseComplete(ranUe, nil) + ranUe.GetSharedCtx().UeCtxRelState = n3iwf_context.UeCtxRelStateNone + } else { + cause := message.BuildCause(ngapType.CausePresentRadioNetwork, + ngapType.CauseRadioNetworkPresentRadioConnectionWithUeLost) + message.SendUEContextReleaseRequest(ranUe, *cause) + ranUe.GetSharedCtx().UeCtxRelState = n3iwf_context.UeCtxRelStateOngoing + } +} + +func (s *Server) HandleSendSendPDUSessionResourceRelease( + ngapEvent n3iwf_context.NgapEvt, +) { + ngapLog := logger.NgapLog + ngapLog.Tracef("Handle SendSendPDUSessionResourceRelease Event") + evt := ngapEvent.(*n3iwf_context.SendPDUSessionResourceReleaseEvt) + ranUeNgapId := evt.RanUeNgapId + deletPduIds := evt.DeletPduIds n3iwfCtx := s.Context() ranUe, ok := n3iwfCtx.RanUePoolLoad(ranUeNgapId) if !ok { @@ -3151,5 +3307,13 @@ func (s *Server) HandleSendPDUSessionResourceReleaseRes( return } - message.SendPDUSessionResourceReleaseResponse(ranUe, ranUe.PduSessionReleaseList, nil) + if ranUe.GetSharedCtx().PduSessResRelState { + message.SendPDUSessionResourceReleaseResponse(ranUe, ranUe.GetSharedCtx().PduSessionReleaseList, nil) + ranUe.GetSharedCtx().PduSessResRelState = n3iwf_context.PduSessResRelStateNone + } else { + for _, id := range deletPduIds { + ranUe.GetSharedCtx().DeletePDUSession(id) + } + ranUe.GetSharedCtx().PduSessResRelState = n3iwf_context.PduSessResRelStateOngoing + } } diff --git a/internal/ngap/handler_test.go b/internal/ngap/handler_test.go index 9c675221..37ad392a 100644 --- a/internal/ngap/handler_test.go +++ b/internal/ngap/handler_test.go @@ -5,9 +5,9 @@ import ( "github.com/stretchr/testify/require" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/ike" "github.com/free5gc/n3iwf/pkg/factory" - "github.com/free5gc/n3iwf/pkg/ike" ) func TestReleaseIkeUeAndRanUe(t *testing.T) { @@ -22,7 +22,9 @@ func TestReleaseIkeUeAndRanUe(t *testing.T) { n3iwfCtx := n3iwf.n3iwfCtx ranUe := &n3iwf_context.N3IWFRanUe{ - N3iwfCtx: n3iwfCtx, + RanUeSharedCtx: n3iwf_context.RanUeSharedCtx{ + N3iwfCtx: n3iwfCtx, + }, } ranUeNgapId := int64(0x1234567890ABCDEF) @@ -33,13 +35,14 @@ func TestReleaseIkeUeAndRanUe(t *testing.T) { n3iwfCtx.IKESPIToNGAPId.Store(spi, ranUeNgapId) stopCh := make(chan struct{}) + rcvIkeEvtCh := n3iwf.mockIkeEvtCh.GetRcvChan() go func() { for { select { case <-stopCh: return - case rcvEvt := <-n3iwf.ikeServer.RcvEventCh: + case rcvEvt := <-rcvIkeEvtCh: if rcvEvt.Type() != n3iwf_context.IKEDeleteRequest { t.Errorf("Receive Wrong Event") } diff --git a/internal/ngap/message/build.go b/internal/ngap/message/build.go index 103f7ae6..f51eaf9c 100644 --- a/internal/ngap/message/build.go +++ b/internal/ngap/message/build.go @@ -4,10 +4,12 @@ import ( "encoding/binary" "encoding/hex" + "github.com/pkg/errors" + "github.com/free5gc/aper" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" "github.com/free5gc/n3iwf/internal/util" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/n3iwf/pkg/factory" "github.com/free5gc/ngap" "github.com/free5gc/ngap/ngapConvert" @@ -112,21 +114,21 @@ func BuildNGSetupRequest( nGSetupRequestIEs.List = append(nGSetupRequestIEs.List, ie) /* - * The reason PagingDRX ie was commented is that in TS23.501 - * PagingDRX was mentioned to be used only for 3GPP access. - * However, the question that if the paging function for N3IWF - * is needed requires verification. - - // PagingDRX - ie = ngapType.NGSetupRequestIEs{} - ie.Id.Value = ngapType.ProtocolIEIDDefaultPagingDRX - ie.Criticality.Value = ngapType.CriticalityPresentIgnore - ie.Value.Present = ngapType.NGSetupRequestIEsPresentDefaultPagingDRX - ie.Value.DefaultPagingDRX = new(ngapType.PagingDRX) - - pagingDRX := ie.Value.DefaultPagingDRX - pagingDRX.Value = ngapType.PagingDRXPresentV128 - nGSetupRequestIEs.List = append(nGSetupRequestIEs.List, ie) + // The reason PagingDRX ie was commented is that in TS23.501 + // PagingDRX was mentioned to be used only for 3GPP access. + // However, the question that if the paging function for N3IWF + // is needed requires verification. + + // PagingDRX + ie = ngapType.NGSetupRequestIEs{} + ie.Id.Value = ngapType.ProtocolIEIDDefaultPagingDRX + ie.Criticality.Value = ngapType.CriticalityPresentIgnore + ie.Value.Present = ngapType.NGSetupRequestIEsPresentDefaultPagingDRX + ie.Value.DefaultPagingDRX = new(ngapType.PagingDRX) + + pagingDRX := ie.Value.DefaultPagingDRX + pagingDRX.Value = ngapType.PagingDRXPresentV128 + nGSetupRequestIEs.List = append(nGSetupRequestIEs.List, ie) */ return ngap.Encoder(pdu) @@ -235,11 +237,13 @@ func BuildNGResetAcknowledge( } func BuildInitialContextSetupResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, responseList *ngapType.PDUSessionResourceSetupListCxtRes, failedList *ngapType.PDUSessionResourceFailedToSetupListCxtRes, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -262,7 +266,7 @@ func BuildInitialContextSetupResponse( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId initialContextSetupResponseIEs.List = append(initialContextSetupResponseIEs.List, ie) @@ -274,7 +278,7 @@ func BuildInitialContextSetupResponse( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId initialContextSetupResponseIEs.List = append(initialContextSetupResponseIEs.List, ie) @@ -311,11 +315,13 @@ func BuildInitialContextSetupResponse( } func BuildInitialContextSetupFailure( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, cause ngapType.Cause, failedList *ngapType.PDUSessionResourceFailedToSetupListCxtFail, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentUnsuccessfulOutcome pdu.UnsuccessfulOutcome = new(ngapType.UnsuccessfulOutcome) @@ -338,7 +344,7 @@ func BuildInitialContextSetupFailure( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId initialContextSetupFailureIEs.List = append(initialContextSetupFailureIEs.List, ie) @@ -350,7 +356,7 @@ func BuildInitialContextSetupFailure( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId initialContextSetupFailureIEs.List = append(initialContextSetupFailureIEs.List, ie) @@ -385,8 +391,10 @@ func BuildInitialContextSetupFailure( } func BuildUEContextModificationResponse( - ranUe *n3iwf_context.N3IWFRanUe, criticalityDiagnostics *ngapType.CriticalityDiagnostics, + ranUe n3iwf_context.RanUe, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -409,7 +417,7 @@ func BuildUEContextModificationResponse( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId uEContextModificationResponseIEs.List = append(uEContextModificationResponseIEs.List, ie) @@ -421,7 +429,7 @@ func BuildUEContextModificationResponse( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId uEContextModificationResponseIEs.List = append(uEContextModificationResponseIEs.List, ie) @@ -435,9 +443,11 @@ func BuildUEContextModificationResponse( return ngap.Encoder(pdu) } -func BuildUEContextModificationFailure(ranUe *n3iwf_context.N3IWFRanUe, cause ngapType.Cause, +func BuildUEContextModificationFailure(ranUe n3iwf_context.RanUe, cause ngapType.Cause, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentUnsuccessfulOutcome pdu.UnsuccessfulOutcome = new(ngapType.UnsuccessfulOutcome) @@ -460,7 +470,7 @@ func BuildUEContextModificationFailure(ranUe *n3iwf_context.N3IWFRanUe, cause ng ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId uEContextModificationFailureIEs.List = append(uEContextModificationFailureIEs.List, ie) @@ -472,7 +482,7 @@ func BuildUEContextModificationFailure(ranUe *n3iwf_context.N3IWFRanUe, cause ng ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId uEContextModificationFailureIEs.List = append(uEContextModificationFailureIEs.List, ie) @@ -494,9 +504,11 @@ func BuildUEContextModificationFailure(ranUe *n3iwf_context.N3IWFRanUe, cause ng return ngap.Encoder(pdu) } -func BuildUEContextReleaseComplete(ranUe *n3iwf_context.N3IWFRanUe, +func BuildUEContextReleaseComplete(ranUe n3iwf_context.RanUe, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -519,7 +531,7 @@ func BuildUEContextReleaseComplete(ranUe *n3iwf_context.N3IWFRanUe, ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId uEContextReleaseCompleteIEs.List = append(uEContextReleaseCompleteIEs.List, ie) @@ -531,7 +543,7 @@ func BuildUEContextReleaseComplete(ranUe *n3iwf_context.N3IWFRanUe, ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId uEContextReleaseCompleteIEs.List = append(uEContextReleaseCompleteIEs.List, ie) @@ -540,20 +552,12 @@ func BuildUEContextReleaseComplete(ranUe *n3iwf_context.N3IWFRanUe, ie.Id.Value = ngapType.ProtocolIEIDUserLocationInformation ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.UEContextReleaseCompleteIEsPresentUserLocationInformation - ie.Value.UserLocationInformation = new(ngapType.UserLocationInformation) - - userLocationInformation := ie.Value.UserLocationInformation - userLocationInformation.Present = ngapType.UserLocationInformationPresentUserLocationInformationN3IWF - userLocationInformation.UserLocationInformationN3IWF = new(ngapType.UserLocationInformationN3IWF) - - userLocationInfoN3IWF := userLocationInformation.UserLocationInformationN3IWF - userLocationInfoN3IWF.IPAddress = ngapConvert.IPAddressToNgap(ranUe.IPAddrv4, ranUe.IPAddrv6) - userLocationInfoN3IWF.PortNumber = ngapConvert.PortNumberToNgap(ranUe.PortNumber) + ie.Value.UserLocationInformation = ranUe.GetUserLocationInformation() uEContextReleaseCompleteIEs.List = append(uEContextReleaseCompleteIEs.List, ie) // PDU Session Resource List (optional) - if len(ranUe.PduSessionList) > 0 { + if len(ranUeCtx.PduSessionList) > 0 { ie = ngapType.UEContextReleaseCompleteIEs{} ie.Id.Value = ngapType.ProtocolIEIDPDUSessionResourceListCxtRelCpl ie.Criticality.Value = ngapType.CriticalityPresentReject @@ -563,7 +567,7 @@ func BuildUEContextReleaseComplete(ranUe *n3iwf_context.N3IWFRanUe, pDUSessionResourceListCxtRelCpl := ie.Value.PDUSessionResourceListCxtRelCpl // PDU Session Resource Item (in PDU Session Resource List) - for _, pduSession := range ranUe.PduSessionList { + for _, pduSession := range ranUeCtx.PduSessionList { pDUSessionResourceItemCxtRelCpl := ngapType.PDUSessionResourceItemCxtRelCpl{} pDUSessionResourceItemCxtRelCpl.PDUSessionID.Value = pduSession.Id pDUSessionResourceListCxtRelCpl.List = append(pDUSessionResourceListCxtRelCpl.List, @@ -585,7 +589,9 @@ func BuildUEContextReleaseComplete(ranUe *n3iwf_context.N3IWFRanUe, return ngap.Encoder(pdu) } -func BuildUEContextReleaseRequest(ranUe *n3iwf_context.N3IWFRanUe, cause ngapType.Cause) ([]byte, error) { +func BuildUEContextReleaseRequest(ranUe n3iwf_context.RanUe, cause ngapType.Cause) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentInitiatingMessage pdu.InitiatingMessage = new(ngapType.InitiatingMessage) @@ -608,7 +614,7 @@ func BuildUEContextReleaseRequest(ranUe *n3iwf_context.N3IWFRanUe, cause ngapTyp ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId uEContextReleaseRequestIEs.List = append(uEContextReleaseRequestIEs.List, ie) @@ -620,27 +626,29 @@ func BuildUEContextReleaseRequest(ranUe *n3iwf_context.N3IWFRanUe, cause ngapTyp ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId uEContextReleaseRequestIEs.List = append(uEContextReleaseRequestIEs.List, ie) // PDU Session Resource List - ie = ngapType.UEContextReleaseRequestIEs{} - ie.Id.Value = ngapType.ProtocolIEIDPDUSessionResourceListCxtRelReq - ie.Criticality.Value = ngapType.CriticalityPresentReject - ie.Value.Present = ngapType.UEContextReleaseRequestIEsPresentPDUSessionResourceListCxtRelReq - ie.Value.PDUSessionResourceListCxtRelReq = new(ngapType.PDUSessionResourceListCxtRelReq) + if len(ranUeCtx.PduSessionList) > 0 { + ie = ngapType.UEContextReleaseRequestIEs{} + ie.Id.Value = ngapType.ProtocolIEIDPDUSessionResourceListCxtRelReq + ie.Criticality.Value = ngapType.CriticalityPresentReject + ie.Value.Present = ngapType.UEContextReleaseRequestIEsPresentPDUSessionResourceListCxtRelReq + ie.Value.PDUSessionResourceListCxtRelReq = new(ngapType.PDUSessionResourceListCxtRelReq) - pDUSessionResourceListCxtRelReq := ie.Value.PDUSessionResourceListCxtRelReq + pDUSessionResourceListCxtRelReq := ie.Value.PDUSessionResourceListCxtRelReq - // PDU Session Resource Item in PDU session Resource List - for _, pduSession := range ranUe.PduSessionList { - pDUSessionResourceItem := ngapType.PDUSessionResourceItemCxtRelReq{} - pDUSessionResourceItem.PDUSessionID.Value = pduSession.Id - pDUSessionResourceListCxtRelReq.List = append(pDUSessionResourceListCxtRelReq.List, - pDUSessionResourceItem) + // PDU Session Resource Item in PDU session Resource List + for _, pduSession := range ranUeCtx.PduSessionList { + pDUSessionResourceItem := ngapType.PDUSessionResourceItemCxtRelReq{} + pDUSessionResourceItem.PDUSessionID.Value = pduSession.Id + pDUSessionResourceListCxtRelReq.List = append(pDUSessionResourceListCxtRelReq.List, + pDUSessionResourceItem) + } + uEContextReleaseRequestIEs.List = append(uEContextReleaseRequestIEs.List, ie) } - uEContextReleaseRequestIEs.List = append(uEContextReleaseRequestIEs.List, ie) // Cause ie = ngapType.UEContextReleaseRequestIEs{} @@ -653,11 +661,13 @@ func BuildUEContextReleaseRequest(ranUe *n3iwf_context.N3IWFRanUe, cause ngapTyp return ngap.Encoder(pdu) } -func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, +func BuildInitialUEMessage(ranUe n3iwf_context.RanUe, nasPdu []byte, allowedNSSAI *ngapType.AllowedNSSAI, ) ([]byte, error) { ngapLog := logger.NgapLog + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentInitiatingMessage pdu.InitiatingMessage = new(ngapType.InitiatingMessage) @@ -680,7 +690,7 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId initialUEMessageIEs.List = append(initialUEMessageIEs.List, ie) } @@ -703,15 +713,7 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, ie.Id.Value = ngapType.ProtocolIEIDUserLocationInformation ie.Criticality.Value = ngapType.CriticalityPresentReject ie.Value.Present = ngapType.InitialUEMessageIEsPresentUserLocationInformation - ie.Value.UserLocationInformation = new(ngapType.UserLocationInformation) - - userLocationInformation := ie.Value.UserLocationInformation - userLocationInformation.Present = ngapType.UserLocationInformationPresentUserLocationInformationN3IWF - userLocationInformation.UserLocationInformationN3IWF = new(ngapType.UserLocationInformationN3IWF) - - userLocationInfoN3IWF := userLocationInformation.UserLocationInformationN3IWF - userLocationInfoN3IWF.IPAddress = ngapConvert.IPAddressToNgap(ranUe.IPAddrv4, ranUe.IPAddrv6) - userLocationInfoN3IWF.PortNumber = ngapConvert.PortNumberToNgap(ranUe.PortNumber) + ie.Value.UserLocationInformation = ranUe.GetUserLocationInformation() initialUEMessageIEs.List = append(initialUEMessageIEs.List, ie) } @@ -724,11 +726,16 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, ie.Value.RRCEstablishmentCause = new(ngapType.RRCEstablishmentCause) rRCEstablishmentCause := ie.Value.RRCEstablishmentCause - rRCEstablishmentCause.Value = aper.Enumerated(ranUe.RRCEstablishmentCause) + value := ranUeCtx.RRCEstablishmentCause + if value < 0 { + return nil, errors.Errorf("BuildInitialUEMessage() ranUe.RRCEstablishmentCause "+ + "negative value: %d", value) + } + rRCEstablishmentCause.Value = aper.Enumerated(value) initialUEMessageIEs.List = append(initialUEMessageIEs.List, ie) } // FiveGSTMSI - if len(ranUe.Guti) != 0 { + if len(ranUeCtx.Guti) != 0 { ie := ngapType.InitialUEMessageIEs{} ie.Id.Value = ngapType.ProtocolIEIDFiveGSTMSI ie.Criticality.Value = ngapType.CriticalityPresentReject @@ -738,12 +745,12 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, fiveGSTMSI := ie.Value.FiveGSTMSI var amfID string var tmsi string - if len(ranUe.Guti) == 19 { - amfID = ranUe.Guti[5:11] - tmsi = ranUe.Guti[11:] + if len(ranUeCtx.Guti) == 19 { + amfID = ranUeCtx.Guti[5:11] + tmsi = ranUeCtx.Guti[11:] } else { - amfID = ranUe.Guti[6:12] - tmsi = ranUe.Guti[12:] + amfID = ranUeCtx.Guti[6:12] + tmsi = ranUeCtx.Guti[12:] } _, amfSetID, amfPointer := ngapConvert.AmfIdToNgap(amfID) @@ -757,7 +764,7 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, initialUEMessageIEs.List = append(initialUEMessageIEs.List, ie) } // AMFSetID - if len(ranUe.Guti) != 0 { + if len(ranUeCtx.Guti) != 0 { ie := ngapType.InitialUEMessageIEs{} ie.Id.Value = ngapType.ProtocolIEIDAMFSetID ie.Criticality.Value = ngapType.CriticalityPresentIgnore @@ -769,10 +776,10 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, // is 3 bytes, is 3 bytes // 1 byte is 2 characters var amfID string - if len(ranUe.Guti) == 19 { // MNC is 2 char - amfID = ranUe.Guti[5:11] + if len(ranUeCtx.Guti) == 19 { // MNC is 2 char + amfID = ranUeCtx.Guti[5:11] } else { - amfID = ranUe.Guti[6:12] + amfID = ranUeCtx.Guti[6:12] } _, aMFSetID.Value, _ = ngapConvert.AmfIdToNgap(amfID) @@ -805,7 +812,9 @@ func BuildInitialUEMessage(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, return ngap.Encoder(pdu) } -func BuildUplinkNASTransport(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte) ([]byte, error) { +func BuildUplinkNASTransport(ranUe n3iwf_context.RanUe, nasPdu []byte) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentInitiatingMessage pdu.InitiatingMessage = new(ngapType.InitiatingMessage) @@ -828,7 +837,7 @@ func BuildUplinkNASTransport(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte) ([] ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId uplinkNasTransportIEs.List = append(uplinkNasTransportIEs.List, ie) @@ -840,7 +849,7 @@ func BuildUplinkNASTransport(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte) ([] ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId uplinkNasTransportIEs.List = append(uplinkNasTransportIEs.List, ie) @@ -859,14 +868,7 @@ func BuildUplinkNASTransport(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte) ([] ie.Id.Value = ngapType.ProtocolIEIDUserLocationInformation ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.UplinkNASTransportIEsPresentUserLocationInformation - ie.Value.UserLocationInformation = new(ngapType.UserLocationInformation) - - userLocationInformation := ie.Value.UserLocationInformation - userLocationInformation.Present = ngapType.UserLocationInformationPresentUserLocationInformationN3IWF - userLocationInformation.UserLocationInformationN3IWF = new(ngapType.UserLocationInformationN3IWF) - userLocationInformationN3IWF := userLocationInformation.UserLocationInformationN3IWF - userLocationInformationN3IWF.IPAddress = ngapConvert.IPAddressToNgap(ranUe.IPAddrv4, ranUe.IPAddrv6) - userLocationInformationN3IWF.PortNumber = ngapConvert.PortNumberToNgap(ranUe.PortNumber) + ie.Value.UserLocationInformation = ranUe.GetUserLocationInformation() uplinkNasTransportIEs.List = append(uplinkNasTransportIEs.List, ie) @@ -874,9 +876,11 @@ func BuildUplinkNASTransport(ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte) ([] } func BuildNASNonDeliveryIndication( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, nasPdu []byte, cause ngapType.Cause, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentInitiatingMessage pdu.InitiatingMessage = new(ngapType.InitiatingMessage) @@ -899,7 +903,7 @@ func BuildNASNonDeliveryIndication( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId nASNonDeliveryIndicationIEs.List = append(nASNonDeliveryIndicationIEs.List, ie) } @@ -912,7 +916,7 @@ func BuildNASNonDeliveryIndication( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId nASNonDeliveryIndicationIEs.List = append(nASNonDeliveryIndicationIEs.List, ie) } @@ -950,11 +954,13 @@ func BuildRerouteNASRequest() ([]byte, error) { } func BuildPDUSessionResourceSetupResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, responseList *ngapType.PDUSessionResourceSetupListSURes, failedList *ngapType.PDUSessionResourceFailedToSetupListSURes, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -977,7 +983,7 @@ func BuildPDUSessionResourceSetupResponse( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId pduSessionResourceSetupResponseIEs.List = append(pduSessionResourceSetupResponseIEs.List, ie) @@ -989,7 +995,7 @@ func BuildPDUSessionResourceSetupResponse( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId pduSessionResourceSetupResponseIEs.List = append(pduSessionResourceSetupResponseIEs.List, ie) @@ -1026,11 +1032,13 @@ func BuildPDUSessionResourceSetupResponse( } func BuildPDUSessionResourceModifyResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, responseList *ngapType.PDUSessionResourceModifyListModRes, failedList *ngapType.PDUSessionResourceFailedToModifyListModRes, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -1051,7 +1059,7 @@ func BuildPDUSessionResourceModifyResponse( ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.PDUSessionResourceModifyResponseIEsPresentAMFUENGAPID ie.Value.AMFUENGAPID = &ngapType.AMFUENGAPID{ - Value: ranUe.AmfUeNgapId, + Value: ranUeCtx.AmfUeNgapId, } pduSessionResourceModifyResponseIEs.List = append(pduSessionResourceModifyResponseIEs.List, ie) @@ -1061,7 +1069,7 @@ func BuildPDUSessionResourceModifyResponse( ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.PDUSessionResourceModifyResponseIEsPresentRANUENGAPID ie.Value.RANUENGAPID = &ngapType.RANUENGAPID{ - Value: ranUe.RanUeNgapId, + Value: ranUeCtx.RanUeNgapId, } pduSessionResourceModifyResponseIEs.List = append(pduSessionResourceModifyResponseIEs.List, ie) @@ -1090,15 +1098,7 @@ func BuildPDUSessionResourceModifyResponse( ie.Id.Value = ngapType.ProtocolIEIDUserLocationInformation ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.PDUSessionResourceModifyResponseIEsPresentUserLocationInformation - ie.Value.UserLocationInformation = new(ngapType.UserLocationInformation) - - userLocationInformation := ie.Value.UserLocationInformation - userLocationInformation.Present = ngapType.UserLocationInformationPresentUserLocationInformationN3IWF - userLocationInformation.UserLocationInformationN3IWF = new(ngapType.UserLocationInformationN3IWF) - - userLocationInformationN3IWF := userLocationInformation.UserLocationInformationN3IWF - userLocationInformationN3IWF.IPAddress = ngapConvert.IPAddressToNgap(ranUe.IPAddrv4, ranUe.IPAddrv6) - userLocationInformationN3IWF.PortNumber = ngapConvert.PortNumberToNgap(ranUe.PortNumber) + ie.Value.UserLocationInformation = ranUe.GetUserLocationInformation() pduSessionResourceModifyResponseIEs.List = append(pduSessionResourceModifyResponseIEs.List, ie) @@ -1115,9 +1115,11 @@ func BuildPDUSessionResourceModifyResponse( } func BuildPDUSessionResourceModifyIndication( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, modifyList []ngapType.PDUSessionResourceModifyItemModInd, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentInitiatingMessage pdu.InitiatingMessage = new(ngapType.InitiatingMessage) @@ -1140,7 +1142,7 @@ func BuildPDUSessionResourceModifyIndication( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId pDUSessionResourceModifyIndicationIEs.List = append(pDUSessionResourceModifyIndicationIEs.List, ie) } @@ -1153,7 +1155,7 @@ func BuildPDUSessionResourceModifyIndication( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId pDUSessionResourceModifyIndicationIEs.List = append(pDUSessionResourceModifyIndicationIEs.List, ie) } @@ -1175,10 +1177,12 @@ func BuildPDUSessionResourceModifyIndication( } func BuildPDUSessionResourceNotify( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, notiList *ngapType.PDUSessionResourceNotifyList, relList *ngapType.PDUSessionResourceReleasedListNot, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentInitiatingMessage pdu.InitiatingMessage = new(ngapType.InitiatingMessage) @@ -1201,7 +1205,7 @@ func BuildPDUSessionResourceNotify( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId pDUSessionResourceNotifyIEs.List = append(pDUSessionResourceNotifyIEs.List, ie) } @@ -1214,7 +1218,7 @@ func BuildPDUSessionResourceNotify( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId pDUSessionResourceNotifyIEs.List = append(pDUSessionResourceNotifyIEs.List, ie) } @@ -1245,21 +1249,12 @@ func BuildPDUSessionResourceNotify( pDUSessionResourceNotifyIEs.List = append(pDUSessionResourceNotifyIEs.List, ie) } // UserLocationInformation - if (ranUe.IPAddrv4 != "" || ranUe.IPAddrv6 != "") && ranUe.PortNumber != 0 { + if (ranUeCtx.IPAddrv4 != "" || ranUeCtx.IPAddrv6 != "") && ranUeCtx.PortNumber != 0 { ie := ngapType.PDUSessionResourceNotifyIEs{} ie.Id.Value = ngapType.ProtocolIEIDUserLocationInformation ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.PDUSessionResourceNotifyIEsPresentUserLocationInformation - ie.Value.UserLocationInformation = new(ngapType.UserLocationInformation) - - userLocationInformation := ie.Value.UserLocationInformation - *userLocationInformation = ngapType.UserLocationInformation{ - Present: ngapType.UserLocationInformationPresentUserLocationInformationN3IWF, - UserLocationInformationN3IWF: &ngapType.UserLocationInformationN3IWF{ - IPAddress: ngapConvert.IPAddressToNgap(ranUe.IPAddrv4, ranUe.IPAddrv6), - PortNumber: ngapConvert.PortNumberToNgap(ranUe.PortNumber), - }, - } + ie.Value.UserLocationInformation = ranUe.GetUserLocationInformation() pDUSessionResourceNotifyIEs.List = append(pDUSessionResourceNotifyIEs.List, ie) } @@ -1268,10 +1263,12 @@ func BuildPDUSessionResourceNotify( } func BuildPDUSessionResourceReleaseResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, relList ngapType.PDUSessionResourceReleasedListRelRes, diagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -1294,7 +1291,7 @@ func BuildPDUSessionResourceReleaseResponse( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId pDUSessionResourceReleaseResponseIEs.List = append(pDUSessionResourceReleaseResponseIEs.List, ie) } @@ -1307,7 +1304,7 @@ func BuildPDUSessionResourceReleaseResponse( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId pDUSessionResourceReleaseResponseIEs.List = append(pDUSessionResourceReleaseResponseIEs.List, ie) } @@ -1325,21 +1322,12 @@ func BuildPDUSessionResourceReleaseResponse( pDUSessionResourceReleaseResponseIEs.List = append(pDUSessionResourceReleaseResponseIEs.List, ie) } // UserLocationInformation - if (ranUe.IPAddrv4 != "" || ranUe.IPAddrv6 != "") && ranUe.PortNumber != 0 { + if (ranUeCtx.IPAddrv4 != "" || ranUeCtx.IPAddrv6 != "") && ranUeCtx.PortNumber != 0 { ie := ngapType.PDUSessionResourceReleaseResponseIEs{} ie.Id.Value = ngapType.ProtocolIEIDUserLocationInformation ie.Criticality.Value = ngapType.CriticalityPresentIgnore ie.Value.Present = ngapType.PDUSessionResourceReleaseResponseIEsPresentUserLocationInformation - ie.Value.UserLocationInformation = new(ngapType.UserLocationInformation) - - userLocationInformation := ie.Value.UserLocationInformation - *userLocationInformation = ngapType.UserLocationInformation{ - Present: ngapType.UserLocationInformationPresentUserLocationInformationN3IWF, - UserLocationInformationN3IWF: &ngapType.UserLocationInformationN3IWF{ - IPAddress: ngapConvert.IPAddressToNgap(ranUe.IPAddrv4, ranUe.IPAddrv6), - PortNumber: ngapConvert.PortNumberToNgap(ranUe.PortNumber), - }, - } + ie.Value.UserLocationInformation = ranUe.GetUserLocationInformation() pDUSessionResourceReleaseResponseIEs.List = append(pDUSessionResourceReleaseResponseIEs.List, ie) } @@ -1427,9 +1415,11 @@ func BuildUERadioCapabilityInfoIndication() ([]byte, error) { } func BuildUERadioCapabilityCheckResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, diagnostics *ngapType.CriticalityDiagnostics, ) ([]byte, error) { + ranUeCtx := ranUe.GetSharedCtx() + var pdu ngapType.NGAPPDU pdu.Present = ngapType.NGAPPDUPresentSuccessfulOutcome pdu.SuccessfulOutcome = new(ngapType.SuccessfulOutcome) @@ -1452,7 +1442,7 @@ func BuildUERadioCapabilityCheckResponse( ie.Value.AMFUENGAPID = new(ngapType.AMFUENGAPID) aMFUENGAPID := ie.Value.AMFUENGAPID - aMFUENGAPID.Value = ranUe.AmfUeNgapId + aMFUENGAPID.Value = ranUeCtx.AmfUeNgapId uERadioCapabilityCheckResponseIEs.List = append(uERadioCapabilityCheckResponseIEs.List, ie) } // RANUENGAPID @@ -1464,7 +1454,7 @@ func BuildUERadioCapabilityCheckResponse( ie.Value.RANUENGAPID = new(ngapType.RANUENGAPID) rANUENGAPID := ie.Value.RANUENGAPID - rANUENGAPID.Value = ranUe.RanUeNgapId + rANUENGAPID.Value = ranUeCtx.RanUeNgapId uERadioCapabilityCheckResponseIEs.List = append(uERadioCapabilityCheckResponseIEs.List, ie) } // IMSVoiceSupportIndicator @@ -1476,7 +1466,13 @@ func BuildUERadioCapabilityCheckResponse( ie.Value.IMSVoiceSupportIndicator = new(ngapType.IMSVoiceSupportIndicator) iMSVoiceSupportIndicator := ie.Value.IMSVoiceSupportIndicator - iMSVoiceSupportIndicator.Value = aper.Enumerated(ranUe.IMSVoiceSupported) + value := ranUeCtx.IMSVoiceSupported + if value < 0 { + return nil, errors.Errorf("BuildUERadioCapabilityCheckResponse() ranUe.IMSVoiceSupported "+ + "negative value: %d", value) + } + + iMSVoiceSupportIndicator.Value = aper.Enumerated(value) uERadioCapabilityCheckResponseIEs.List = append(uERadioCapabilityCheckResponseIEs.List, ie) } // CriticalityDiagnostics @@ -1758,7 +1754,7 @@ func BuildPDUSessionResourceSetupResponseTransfer( gtpTunnel := qosFlowPerTNLInformation.UPTransportLayerInformation.GTPTunnel teid := make([]byte, 4) - binary.BigEndian.PutUint32(teid, pduSession.GTPConnection.IncomingTEID) + binary.BigEndian.PutUint32(teid, pduSession.GTPConnInfo.IncomingTEID) gtpTunnel.GTPTEID.Value = teid gtpTunnel.TransportLayerAddress = ngapConvert.IPAddressToNgap(gtpBindIPv4, "") diff --git a/internal/ngap/message/send.go b/internal/ngap/message/send.go index 9050a82d..f7dcd306 100644 --- a/internal/ngap/message/send.go +++ b/internal/ngap/message/send.go @@ -3,8 +3,8 @@ package message import ( "runtime/debug" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/ngap/ngapType" "github.com/free5gc/sctp" ) @@ -12,7 +12,7 @@ import ( func SendToAmf(amf *n3iwf_context.N3IWFAMF, pkt []byte) { ngapLog := logger.NgapLog if amf == nil { - ngapLog.Errorf("[N3IWF] AMF Context is nil ") + ngapLog.Errorf("AMF Context is nil ") } else { if n, err := amf.SCTPConn.Write(pkt); err != nil { ngapLog.Errorf("Write to SCTP socket failed: %+v", err) @@ -34,14 +34,14 @@ func SendNGSetupRequest( } }() - ngapLog.Infoln("[N3IWF] Send NG Setup Request") + ngapLog.Infoln("Send NG Setup Request") cfg := n3iwfCtx.Config() sctpAddr := conn.RemoteAddr().String() if available, _ := n3iwfCtx.AMFReInitAvailableListLoad(sctpAddr); !available { ngapLog.Warnf( - "[N3IWF] Please Wait at least for the indicated time before reinitiating toward same AMF[%s]", + "Please Wait at least for the indicated time before reinitiating toward same AMF[%s]", sctpAddr) return } @@ -69,7 +69,7 @@ func SendNGReset( partOfNGInterface *ngapType.UEAssociatedLogicalNGConnectionList, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send NG Reset") + ngapLog.Infoln("Send NG Reset") pkt, err := BuildNGReset(cause, partOfNGInterface) if err != nil { @@ -86,7 +86,7 @@ func SendNGResetAcknowledge( diagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send NG Reset Acknowledge") + ngapLog.Infoln("Send NG Reset Acknowledge") if partOfNGInterface != nil && len(partOfNGInterface.List) == 0 { ngapLog.Error("length of partOfNGInterface is 0") @@ -103,13 +103,13 @@ func SendNGResetAcknowledge( } func SendInitialContextSetupResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, responseList *ngapType.PDUSessionResourceSetupListCxtRes, failedList *ngapType.PDUSessionResourceFailedToSetupListCxtRes, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Initial Context Setup Response") + ngapLog.Infoln("Send Initial Context Setup Response") if responseList != nil && len(responseList.List) > n3iwf_context.MaxNumOfPDUSessions { ngapLog.Errorln("Pdu List out of range") @@ -127,17 +127,17 @@ func SendInitialContextSetupResponse( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendInitialContextSetupFailure( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, cause ngapType.Cause, failedList *ngapType.PDUSessionResourceFailedToSetupListCxtFail, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Initial Context Setup Failure") + ngapLog.Infoln("Send Initial Context Setup Failure") if failedList != nil && len(failedList.List) > n3iwf_context.MaxNumOfPDUSessions { ngapLog.Errorln("Pdu List out of range") @@ -150,15 +150,15 @@ func SendInitialContextSetupFailure( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendUEContextModificationResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send UE Context Modification Response") + ngapLog.Infoln("Send UE Context Modification Response") pkt, err := BuildUEContextModificationResponse(ranUe, criticalityDiagnostics) if err != nil { @@ -166,16 +166,16 @@ func SendUEContextModificationResponse( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendUEContextModificationFailure( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, cause ngapType.Cause, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send UE Context Modification Failure") + ngapLog.Infoln("Send UE Context Modification Failure") pkt, err := BuildUEContextModificationFailure(ranUe, cause, criticalityDiagnostics) if err != nil { @@ -183,15 +183,15 @@ func SendUEContextModificationFailure( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendUEContextReleaseComplete( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send UE Context Release Complete") + ngapLog.Infoln("Send UE Context Release Complete") pkt, err := BuildUEContextReleaseComplete(ranUe, criticalityDiagnostics) if err != nil { @@ -199,14 +199,14 @@ func SendUEContextReleaseComplete( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendUEContextReleaseRequest( - ranUe *n3iwf_context.N3IWFRanUe, cause ngapType.Cause, + ranUe n3iwf_context.RanUe, cause ngapType.Cause, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send UE Context Release Request") + ngapLog.Infoln("Send UE Context Release Request") pkt, err := BuildUEContextReleaseRequest(ranUe, cause) if err != nil { @@ -214,14 +214,14 @@ func SendUEContextReleaseRequest( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendInitialUEMessage(amf *n3iwf_context.N3IWFAMF, - ranUe *n3iwf_context.N3IWFRanUe, nasPdu []byte, + ranUe n3iwf_context.RanUe, nasPdu []byte, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Initial UE Message") + ngapLog.Infoln("Send Initial UE Message") // Attach To AMF pkt, err := BuildInitialUEMessage(ranUe, nasPdu, nil) @@ -230,16 +230,16 @@ func SendInitialUEMessage(amf *n3iwf_context.N3IWFAMF, return } - SendToAmf(ranUe.AMF, pkt) - // ranUe.AttachAMF() + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) + // ranUe.AttachAMF() // TODO: Check AttachAMF if is necessary } func SendUplinkNASTransport( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, nasPdu []byte, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Uplink NAS Transport") + ngapLog.Infoln("Send Uplink NAS Transport") if len(nasPdu) == 0 { ngapLog.Errorln("NAS Pdu is nil") @@ -252,16 +252,16 @@ func SendUplinkNASTransport( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendNASNonDeliveryIndication( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, nasPdu []byte, cause ngapType.Cause, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send NAS NonDelivery Indication") + ngapLog.Infoln("Send NAS NonDelivery Indication") if len(nasPdu) == 0 { ngapLog.Errorln("NAS Pdu is nil") @@ -270,26 +270,26 @@ func SendNASNonDeliveryIndication( pkt, err := BuildNASNonDeliveryIndication(ranUe, nasPdu, cause) if err != nil { - ngapLog.Errorf("Build Uplink NAS Transport failed : %+v\n", err) + ngapLog.Errorf("Build NAS Non Delivery Indication failed : %+v\n", err) return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendRerouteNASRequest() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Reroute NAS Request") + ngapLog.Infoln("Send Reroute NAS Request") } func SendPDUSessionResourceSetupResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, responseList *ngapType.PDUSessionResourceSetupListSURes, failedListSURes *ngapType.PDUSessionResourceFailedToSetupListSURes, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send PDU Session Resource Setup Response") + ngapLog.Infoln("Send PDU Session Resource Setup Response") if ranUe == nil { ngapLog.Error("UE context is nil, this information is mandatory.") @@ -302,17 +302,17 @@ func SendPDUSessionResourceSetupResponse( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendPDUSessionResourceModifyResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, responseList *ngapType.PDUSessionResourceModifyListModRes, failedList *ngapType.PDUSessionResourceFailedToModifyListModRes, criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send PDU Session Resource Modify Response") + ngapLog.Infoln("Send PDU Session Resource Modify Response") if ranUe == nil && criticalityDiagnostics == nil { ngapLog.Error("UE context is nil, this information is mandatory") @@ -325,15 +325,15 @@ func SendPDUSessionResourceModifyResponse( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendPDUSessionResourceModifyIndication( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, modifyList []ngapType.PDUSessionResourceModifyItemModInd, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send PDU Session Resource Modify Indication") + ngapLog.Infoln("Send PDU Session Resource Modify Indication") if ranUe == nil { ngapLog.Error("UE context is nil, this information is mandatory") @@ -351,16 +351,16 @@ func SendPDUSessionResourceModifyIndication( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendPDUSessionResourceNotify( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, notiList *ngapType.PDUSessionResourceNotifyList, relList *ngapType.PDUSessionResourceReleasedListNot, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send PDU Session Resource Notify") + ngapLog.Infoln("Send PDU Session Resource Notify") if ranUe == nil { ngapLog.Error("UE context is nil, this information is mandatory") @@ -373,16 +373,16 @@ func SendPDUSessionResourceNotify( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendPDUSessionResourceReleaseResponse( - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, relList ngapType.PDUSessionResourceReleasedListRelRes, diagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send PDU Session Resource Release Response") + ngapLog.Infoln("Send PDU Session Resource Release Response") if ranUe == nil { ngapLog.Error("UE context is nil, this information is mandatory") @@ -400,7 +400,7 @@ func SendPDUSessionResourceReleaseResponse( return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendErrorIndication( @@ -411,7 +411,7 @@ func SendErrorIndication( criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Error Indication") + ngapLog.Infoln("Send Error Indication") if (cause == nil) && (criticalityDiagnostics == nil) { ngapLog.Errorln("Both cause and criticality is nil. This message shall contain at least one of them.") @@ -435,7 +435,7 @@ func SendErrorIndicationWithSctpConn( criticalityDiagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Error Indication") + ngapLog.Infoln("Send Error Indication") if (cause == nil) && (criticalityDiagnostics == nil) { ngapLog.Errorln("Both cause and criticality is nil. This message shall contain at least one of them.") @@ -457,23 +457,23 @@ func SendErrorIndicationWithSctpConn( func SendUERadioCapabilityInfoIndication() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send UE Radio Capability Info Indication") + ngapLog.Infoln("Send UE Radio Capability Info Indication") } func SendUERadioCapabilityCheckResponse( amf *n3iwf_context.N3IWFAMF, - ranUe *n3iwf_context.N3IWFRanUe, + ranUe n3iwf_context.RanUe, diagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send UE Radio Capability Check Response") + ngapLog.Infoln("Send UE Radio Capability Check Response") pkt, err := BuildUERadioCapabilityCheckResponse(ranUe, diagnostics) if err != nil { ngapLog.Errorf("Build UERadio Capability Check Response failed : %+v\n", err) return } - SendToAmf(ranUe.AMF, pkt) + SendToAmf(ranUe.GetSharedCtx().AMF, pkt) } func SendAMFConfigurationUpdateAcknowledge( @@ -483,7 +483,7 @@ func SendAMFConfigurationUpdateAcknowledge( diagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send AMF Configuration Update Acknowledge") + ngapLog.Infoln("Send AMF Configuration Update Acknowledge") pkt, err := BuildAMFConfigurationUpdateAcknowledge(setupList, failList, diagnostics) if err != nil { @@ -501,7 +501,7 @@ func SendAMFConfigurationUpdateFailure( diagnostics *ngapType.CriticalityDiagnostics, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send AMF Configuration Update Failure") + ngapLog.Infoln("Send AMF Configuration Update Failure") pkt, err := BuildAMFConfigurationUpdateFailure(ngCause, time, diagnostics) if err != nil { ngapLog.Errorf("Build AMF Configuration Update Failure failed : %+v\n", err) @@ -516,12 +516,12 @@ func SendRANConfigurationUpdate( amf *n3iwf_context.N3IWFAMF, ) { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send RAN Configuration Update") + ngapLog.Infoln("Send RAN Configuration Update") available, _ := n3iwfCtx.AMFReInitAvailableListLoad(amf.SCTPAddr) if !available { ngapLog.Warnf( - "[N3IWF] Please Wait at least for the indicated time before reinitiating toward same AMF[%s]", + "Please Wait at least for the indicated time before reinitiating toward same AMF[%s]", amf.SCTPAddr) return } @@ -540,25 +540,25 @@ func SendRANConfigurationUpdate( func SendUplinkRANConfigurationTransfer() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Uplink RAN Configuration Transfer") + ngapLog.Infoln("Send Uplink RAN Configuration Transfer") } func SendUplinkRANStatusTransfer() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Uplink RAN Status Transfer") + ngapLog.Infoln("Send Uplink RAN Status Transfer") } func SendLocationReportingFailureIndication() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Location Reporting Failure Indication") + ngapLog.Infoln("Send Location Reporting Failure Indication") } func SendLocationReport() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send Location Report") + ngapLog.Infoln("Send Location Report") } func SendRRCInactiveTransitionReport() { ngapLog := logger.NgapLog - ngapLog.Infoln("[N3IWF] Send RRC Inactive Transition Report") + ngapLog.Infoln("Send RRC Inactive Transition Report") } diff --git a/internal/ngap/server.go b/internal/ngap/server.go index e8e58b52..03058395 100644 --- a/internal/ngap/server.go +++ b/internal/ngap/server.go @@ -8,15 +8,16 @@ import ( "sync" "time" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" "github.com/free5gc/n3iwf/internal/ngap/message" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/n3iwf/pkg/factory" lib_ngap "github.com/free5gc/ngap" "github.com/free5gc/sctp" + "github.com/free5gc/util/safe_channel" ) -var ( +const ( RECEIVE_NGAPPACKET_CHANNEL_LEN = 512 RECEIVE_NGAPEVENT_CHANNEL_LEN = 512 ) @@ -25,15 +26,16 @@ type n3iwf interface { Config() *factory.Config Context() *n3iwf_context.N3IWFContext CancelContext() context.Context - IkeEvtCh() chan n3iwf_context.IkeEvt + + SendIkeEvt(n3iwf_context.IkeEvt) } type Server struct { n3iwf - Conn []*sctp.SCTPConn - RcvNgapPktCh chan ReceiveNGAPPacket - RcvEventCh chan n3iwf_context.NgapEvt + conn []*sctp.SCTPConn + rcvPktCh *safe_channel.SafeCh[ReceiveNGAPPacket] + rcvEvtCh *safe_channel.SafeCh[n3iwf_context.NgapEvt] } type ReceiveNGAPPacket struct { @@ -43,10 +45,10 @@ type ReceiveNGAPPacket struct { func NewServer(n3iwf n3iwf) (*Server, error) { s := &Server{ - n3iwf: n3iwf, - RcvNgapPktCh: make(chan ReceiveNGAPPacket, RECEIVE_NGAPPACKET_CHANNEL_LEN), - RcvEventCh: make(chan n3iwf_context.NgapEvt, RECEIVE_NGAPEVENT_CHANNEL_LEN), + n3iwf: n3iwf, } + s.rcvPktCh = safe_channel.NewSafeCh[ReceiveNGAPPacket](RECEIVE_NGAPPACKET_CHANNEL_LEN) + s.rcvEvtCh = safe_channel.NewSafeCh[n3iwf_context.NgapEvt](RECEIVE_NGAPEVENT_CHANNEL_LEN) return s, nil } @@ -80,19 +82,22 @@ func (s *Server) runNgapEventHandler(wg *sync.WaitGroup) { ngapLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) } ngapLog.Infof("NGAP server stopped") - close(s.RcvEventCh) - close(s.RcvNgapPktCh) + s.rcvEvtCh.Close() + s.rcvPktCh.Close() wg.Done() }() + rcvEvtCh := s.rcvEvtCh.GetRcvChan() + rcvPktCh := s.rcvPktCh.GetRcvChan() + for { select { - case rcvPkt := <-s.RcvNgapPktCh: + case rcvPkt := <-rcvPktCh: if len(rcvPkt.Buf) == 0 { // receiver closed return } s.NGAPDispatch(rcvPkt.Conn, rcvPkt.Buf) - case rcvEvt := <-s.RcvEventCh: + case rcvEvt := <-rcvEvtCh: s.HandleEvent(rcvEvt) } } @@ -175,12 +180,11 @@ func (s *Server) listenAndServe( close(errChan) - s.Conn = append(s.Conn, conn) + s.conn = append(s.conn, conn) - data := make([]byte, 65535) + buf := make([]byte, factory.MAX_BUF_MSG_LEN) for { - n, info, _, err := conn.SCTPRead(data) - + n, info, _, err := conn.SCTPRead(buf) if err != nil { ngapLog.Debugf("[SCTP] AMF SCTP address: %s", remoteAddr) if err == io.EOF || err == io.ErrUnexpectedEOF { @@ -189,34 +193,40 @@ func (s *Server) listenAndServe( if errConn != nil { ngapLog.Errorf("conn close error: %+v", errConn) } - s.RcvNgapPktCh <- ReceiveNGAPPacket{} + s.rcvPktCh.Send(ReceiveNGAPPacket{}) return } ngapLog.Errorf("[SCTP] Read from SCTP connection failed: %+v", err) - } else { - ngapLog.Tracef("[SCTP] Successfully read %d bytes.", n) + return + } - if info == nil || info.PPID != lib_ngap.PPID { - ngapLog.Warn("Received SCTP PPID != 60") - continue - } + ngapLog.Tracef("[SCTP] Successfully read %d bytes.", n) - forwardData := make([]byte, n) - copy(forwardData, data[:n]) + if info == nil || info.PPID != lib_ngap.PPID { + ngapLog.Warn("Received SCTP PPID != 60") + continue + } - s.RcvNgapPktCh <- ReceiveNGAPPacket{ - Conn: conn, - Buf: forwardData[:n], - } + forwardData := make([]byte, n) + copy(forwardData, buf[:n]) + + ngapPkt := ReceiveNGAPPacket{ + Conn: conn, + Buf: forwardData[:n], } + s.rcvPktCh.Send(ngapPkt) } } +func (s *Server) SendNgapEvt(evt n3iwf_context.NgapEvt) { + s.rcvEvtCh.Send(evt) +} + func (s *Server) Stop() { ngapLog := logger.NgapLog ngapLog.Infof("Close NGAP server....") - for _, ngapServerConn := range s.Conn { + for _, ngapServerConn := range s.conn { if err := ngapServerConn.Close(); err != nil { ngapLog.Errorf("Stop ngap server error : %+v", err) } diff --git a/internal/ngap/server_test.go b/internal/ngap/server_test.go index 5c54a3c1..2861999e 100644 --- a/internal/ngap/server_test.go +++ b/internal/ngap/server_test.go @@ -4,19 +4,23 @@ import ( "context" "sync" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/ike" "github.com/free5gc/n3iwf/pkg/factory" - "github.com/free5gc/n3iwf/pkg/ike" + "github.com/free5gc/util/safe_channel" ) type n3iwfTestApp struct { - cfg *factory.Config - n3iwfCtx *n3iwf_context.N3IWFContext + cfg *factory.Config + n3iwfCtx *n3iwf_context.N3IWFContext + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + ngapServer *Server ikeServer *ike.Server - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup + + mockIkeEvtCh *safe_channel.SafeCh[n3iwf_context.IkeEvt] } func (a *n3iwfTestApp) Config() *factory.Config { @@ -31,12 +35,12 @@ func (a *n3iwfTestApp) CancelContext() context.Context { return a.ctx } -func (a *n3iwfTestApp) NgapEvtCh() chan n3iwf_context.NgapEvt { - return a.ngapServer.RcvEventCh +func (a *n3iwfTestApp) SendNgapEvt(evt n3iwf_context.NgapEvt) { + a.ngapServer.SendNgapEvt(evt) } -func (a *n3iwfTestApp) IkeEvtCh() chan n3iwf_context.IkeEvt { - return a.ikeServer.RcvEventCh +func (a *n3iwfTestApp) SendIkeEvt(evt n3iwf_context.IkeEvt) { + a.mockIkeEvtCh.Send(evt) } func NewN3iwfTestApp(cfg *factory.Config) (*n3iwfTestApp, error) { @@ -49,7 +53,7 @@ func NewN3iwfTestApp(cfg *factory.Config) (*n3iwfTestApp, error) { cancel: cancel, wg: &sync.WaitGroup{}, } - + n3iwfApp.mockIkeEvtCh = safe_channel.NewSafeCh[n3iwf_context.IkeEvt](10) n3iwfApp.n3iwfCtx, err = n3iwf_context.NewTestContext(n3iwfApp) if err != nil { return nil, err diff --git a/internal/nwucp/server.go b/internal/nwucp/server.go index 159b9486..e50c1053 100644 --- a/internal/nwucp/server.go +++ b/internal/nwucp/server.go @@ -1,17 +1,18 @@ package nwucp import ( + "bufio" "context" "encoding/binary" - "encoding/hex" + "io" "net" "runtime/debug" "strings" "sync" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/logger" "github.com/free5gc/n3iwf/internal/ngap/message" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/n3iwf/pkg/factory" ) @@ -19,7 +20,8 @@ type n3iwf interface { Config() *factory.Config Context() *n3iwf_context.N3IWFContext CancelContext() context.Context - NgapEvtCh() chan n3iwf_context.NgapEvt + + SendNgapEvt(n3iwf_context.NgapEvt) } type Server struct { @@ -92,54 +94,46 @@ func (s *Server) listenAndServe(wg *sync.WaitGroup) { ranUe, err := n3iwfCtx.RanUeLoadFromIkeSPI(ikeUe.N3IWFIKESecurityAssociation.LocalSPI) if err != nil { - nwucpLog.Errorf("RanUe context not found : %+v", err) + nwucpLog.Errorf("RanUe context not found : %v", err) + continue + } + + n3iwfUe, ok := ranUe.(*n3iwf_context.N3IWFRanUe) + if !ok { + nwucpLog.Errorf("listenAndServe(): [Type Assertion] RanUe -> N3iwfUe failed") continue } + // Store connection - ranUe.TCPConnection = connection + n3iwfUe.TCPConnection = connection - s.NgapEvtCh() <- n3iwf_context.NewNASTCPConnEstablishedCompleteEvt( - ranUe.RanUeNgapId, - ) + s.SendNgapEvt(n3iwf_context.NewNASTCPConnEstablishedCompleteEvt(n3iwfUe.RanUeNgapId)) wg.Add(1) - go serveConn(ranUe, connection, wg) + go serveConn(n3iwfUe, connection, wg) } } -func decapNasMsgFromEnvelope(envelop []byte) []byte { - // According to TS 24.502 8.2.4, - // in order to transport a NAS message over the non-3GPP access between the UE and the N3IWF, - // the NAS message shall be framed in a NAS message envelope as defined in subclause 9.4. - // According to TS 24.502 9.4, - // a NAS message envelope = Length | NAS Message - - // Get NAS Message Length - nasLen := binary.BigEndian.Uint16(envelop[:2]) - nasMsg := make([]byte, nasLen) - copy(nasMsg, envelop[2:2+nasLen]) - - return nasMsg -} - func (s *Server) Stop() { nwucpLog := logger.NWuCPLog nwucpLog.Infof("Close Nwucp server...") if err := s.tcpListener.Close(); err != nil { - nwucpLog.Errorf("Stop nwuup server error : %+v", err) + nwucpLog.Errorf("Stop nwucp server error : %+v", err) } + // TODO: [Bug] TCPConnection may close twice, need to check s.Context().RANUePool.Range( func(key, value interface{}) bool { - ranUe := value.(*n3iwf_context.N3IWFRanUe) - if ranUe.TCPConnection != nil { + ranUe, ok := value.(*n3iwf_context.N3IWFRanUe) + if ok && ranUe.TCPConnection != nil { if err := ranUe.TCPConnection.Close(); err != nil { logger.InitLog.Errorf("Stop nwucp server error : %+v", err) } } return true - }) + }, + ) } // serveConn handle accepted TCP connection. It reads NAS packets @@ -160,21 +154,33 @@ func serveConn(ranUe *n3iwf_context.N3IWFRanUe, connection net.Conn, wg *sync.Wa wg.Done() }() - data := make([]byte, 65535) + connReader := bufio.NewReader(connection) + buf := make([]byte, factory.MAX_BUF_MSG_LEN) for { - n, err := connection.Read(data) + // Read the length of NAS message + n, err := io.ReadFull(connReader, buf[:2]) if err != nil { - nwucpLog.Errorf("Read TCP connection failed: %+v", err) + nwucpLog.Errorf("Read the length of NAS message failed: %+v", err) ranUe.TCPConnection = nil return } - nwucpLog.Tracef("Get NAS PDU from UE:\nNAS length: %d\nNAS content:\n%s", n, hex.Dump(data[:n])) + nasLen := binary.BigEndian.Uint16(buf[:n]) + if uint64(nasLen) > uint64(cap(buf)) { + buf = make([]byte, 0, nasLen) + } - // Decap Nas envelope - forwardData := decapNasMsgFromEnvelope(data) + // Read the NAS message + n, err = io.ReadFull(connReader, buf[:nasLen]) + if err != nil { + nwucpLog.Errorf("Read the NAS message failed: %+v", err) + ranUe.TCPConnection = nil + return + } + fwdNas := make([]byte, n) + copy(fwdNas, buf[:n]) wg.Add(1) - go forward(ranUe, forwardData, wg) + go forward(ranUe, fwdNas, wg) } } diff --git a/internal/nwuup/server.go b/internal/nwuup/server.go index 1be5de8b..fbcc59ef 100644 --- a/internal/nwuup/server.go +++ b/internal/nwuup/server.go @@ -7,14 +7,15 @@ import ( "sync" "github.com/pkg/errors" + "github.com/sirupsen/logrus" "github.com/wmnsk/go-gtp/gtpv1" gtpMsg "github.com/wmnsk/go-gtp/gtpv1/message" "golang.org/x/net/ipv4" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/internal/gre" gtpQoSMsg "github.com/free5gc/n3iwf/internal/gtp/message" "github.com/free5gc/n3iwf/internal/logger" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/n3iwf/pkg/factory" ) @@ -27,13 +28,15 @@ type n3iwf interface { type Server struct { n3iwf - gtpuConn *gtpv1.UPlaneConn greConn *ipv4.PacketConn + gtpuConn *gtpv1.UPlaneConn + log *logrus.Entry } func NewServer(n3iwf n3iwf) (*Server, error) { s := &Server{ n3iwf: n3iwf, + log: logger.NWuUPLog, } return s, nil } @@ -62,8 +65,7 @@ func (s *Server) Run(wg *sync.WaitGroup) error { } func (s *Server) newGreConn() error { - cfg := s.Config() - listenAddr := cfg.GetIPSecGatewayAddr() + listenAddr := s.Config().GetIPSecGatewayAddr() // Setup IPv4 packet connection socket // This socket will only capture GRE encapsulated packet @@ -79,8 +81,7 @@ func (s *Server) newGreConn() error { } func (s *Server) newGtpuConn() error { - cfg := s.Config() - gtpuAddr := cfg.GetGTPBindAddr() + gtpv1.GTPUPort + gtpuAddr := s.Config().GetN3iwfGtpBindAddress() + gtpv1.GTPUPort laddr, err := net.ResolveUDPAddr("udp", gtpuAddr) if err != nil { @@ -94,111 +95,10 @@ func (s *Server) newGtpuConn() error { return nil } -// Parse the fields not supported by go-gtp and forward data to UE. -func (s *Server) handleQoSTPDU(c gtpv1.Conn, senderAddr net.Addr, msg gtpMsg.Message) error { - pdu := gtpQoSMsg.QoSTPDUPacket{} - err := pdu.Unmarshal(msg.(*gtpMsg.TPDU)) - if err != nil { - return err - } - - s.forwardDL(pdu) - return nil -} - -// Forward user plane packets from N3 to UE with GRE header and new IP header encapsulated -func (s *Server) forwardDL(packet gtpQoSMsg.QoSTPDUPacket) { - gtpLog := logger.GTPLog - - defer func() { - if p := recover(); p != nil { - // Print stack for panic to log. Fatalf() will let program exit. - gtpLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) - } - }() - - n3iwfCtx := s.Context() - - pktTEID := packet.GetTEID() - gtpLog.Tracef("pkt teid : %d", pktTEID) - - // Find UE information - ranUe, ok := n3iwfCtx.AllocatedUETEIDLoad(pktTEID) - if !ok { - gtpLog.Errorf("Cannot find RanUE context from QosPacket TEID : %+v", pktTEID) - return - } - - ikeUe, err := n3iwfCtx.IkeUeLoadFromNgapId(ranUe.RanUeNgapId) - if err != nil { - gtpLog.Errorf("Cannot find IkeUe context from RanUe , NgapID : %+v", ranUe.RanUeNgapId) - return - } - - // UE inner IP in IPSec - ueInnerIPAddr := ikeUe.IPSecInnerIPAddr - - var cm *ipv4.ControlMessage - for _, childSA := range ikeUe.N3IWFChildSecurityAssociation { - pdusession := ranUe.FindPDUSession(childSA.PDUSessionIds[0]) - if pdusession != nil && pdusession.GTPConnection.IncomingTEID == pktTEID { - gtpLog.Tracef("forwarding IPSec xfrm interfaceid : %d", childSA.XfrmIface.Attrs().Index) - cm = &ipv4.ControlMessage{ - IfIndex: childSA.XfrmIface.Attrs().Index, - } - break - } - } - - var ( - qfi uint8 - rqi bool - ) - - // QoS Related Parameter - if packet.HasQoS() { - qfi, rqi = packet.GetQoSParameters() - gtpLog.Tracef("QFI: %v, RQI: %v", qfi, rqi) - } - - // Encasulate IPv4 packet with GRE header before forward to UE through IPsec - grePacket := gre.GREPacket{} - - // TODO:[24.502(v15.7) 9.3.3 ] The Protocol Type field should be set to zero - grePacket.SetPayload(packet.GetPayload(), gre.IPv4) - grePacket.SetQoS(qfi, rqi) - forwardData := grePacket.Marshal() - - // Send to UE through Nwu - if n, err := s.greConn.WriteTo(forwardData, cm, ueInnerIPAddr); err != nil { - gtpLog.Errorf("Write to UE failed: %+v", err) - return - } else { - gtpLog.Trace("Forward NWu <- N3") - gtpLog.Tracef("Wrote %d bytes", n) - } -} - -func (s *Server) gtpuListenAndServe(wg *sync.WaitGroup) { - nwuupLog := logger.NWuUPLog - defer func() { - if p := recover(); p != nil { - // Print stack for panic to log. Fatalf() will let program exit. - nwuupLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) - } - - wg.Done() - }() - - if err := s.gtpuConn.ListenAndServe(context.Background()); err != nil { - nwuupLog.Errorf("GTP-U server err: %v", err) - } -} - // listenAndServe read from socket and call forward() to // forward packet. func (s *Server) greListenAndServe(wg *sync.WaitGroup) { - nwuupLog := logger.NWuUPLog + nwuupLog := s.log defer func() { if p := recover(); p != nil { // Print stack for panic to log. Fatalf() will let program exit. @@ -212,7 +112,7 @@ func (s *Server) greListenAndServe(wg *sync.WaitGroup) { wg.Done() }() - buffer := make([]byte, 65535) + buf := make([]byte, factory.MAX_BUF_MSG_LEN) err := s.greConn.SetControlMessage(ipv4.FlagInterface|ipv4.FlagTTL, true) if err != nil { @@ -221,7 +121,7 @@ func (s *Server) greListenAndServe(wg *sync.WaitGroup) { } for { - n, cm, src, err := s.greConn.ReadFrom(buffer) + n, cm, src, err := s.greConn.ReadFrom(buf) nwuupLog.Tracef("Read %d bytes, %s", n, cm) if err != nil { nwuupLog.Errorf("Error read from IPv4 packet connection: %+v", err) @@ -229,17 +129,33 @@ func (s *Server) greListenAndServe(wg *sync.WaitGroup) { } forwardData := make([]byte, n) - copy(forwardData, buffer) + copy(forwardData, buf) wg.Add(1) go s.forwardUL(src.String(), cm.IfIndex, forwardData, wg) } } +func (s *Server) gtpuListenAndServe(wg *sync.WaitGroup) { + nwuupLog := s.log + defer func() { + if p := recover(); p != nil { + // Print stack for panic to log. Fatalf() will let program exit. + nwuupLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) + } + + wg.Done() + }() + + if err := s.gtpuConn.ListenAndServe(context.Background()); err != nil { + nwuupLog.Errorf("GTP-U server err: %v", err) + } +} + // forward forwards user plane packets from NWu to UPF // with GTP header encapsulated func (s *Server) forwardUL(ueInnerIP string, ifIndex int, rawData []byte, wg *sync.WaitGroup) { - nwuupLog := logger.NWuUPLog + nwuupLog := s.log defer func() { if p := recover(); p != nil { // Print stack for panic to log. Fatalf() will let program exit. @@ -268,7 +184,7 @@ func (s *Server) forwardUL(ueInnerIP string, ifIndex int, rawData []byte, wg *sy // Check which child SA the packet come from with interface index, // and find the corresponding PDU session if childSA.XfrmIface != nil && childSA.XfrmIface.Attrs().Index == ifIndex { - pduSession = ranUe.PduSessionList[childSA.PDUSessionIds[0]] + pduSession = ranUe.GetSharedCtx().PduSessionList[childSA.PDUSessionIds[0]] break } } @@ -278,7 +194,7 @@ func (s *Server) forwardUL(ueInnerIP string, ifIndex int, rawData []byte, wg *sy return } - gtpConnection := pduSession.GTPConnection + gtpConnection := pduSession.GTPConnInfo // Decapsulate GRE header and extract QoS Parameters if exist grePacket := gre.GREPacket{} @@ -296,8 +212,12 @@ func (s *Server) forwardUL(ueInnerIP string, ifIndex int, rawData []byte, wg *sy // Encapsulate UL PDU SESSION INFORMATION with extension header if the QoS parameters exist if grePacket.GetKeyFlag() { - qfi := grePacket.GetQFI() - gtpPacket, err := buildQoSGTPPacket(gtpConnection.OutgoingTEID, qfi, payload) + qfi, err := grePacket.GetQFI() + if err != nil { + nwuupLog.Errorf("forwardUL err: %+v", err) + return + } + gtpPacket, err := gtpQoSMsg.BuildQoSGTPPacket(gtpConnection.OutgoingTEID, qfi, payload) if err != nil { nwuupLog.Errorf("buildQoSGTPPacket err: %+v", err) return @@ -321,34 +241,104 @@ func (s *Server) forwardUL(ueInnerIP string, ifIndex int, rawData []byte, wg *sy nwuupLog.Tracef("Wrote %d bytes", n) } -func buildQoSGTPPacket(teid uint32, qfi uint8, payload []byte) ([]byte, error) { - nwuupLog := logger.NWuUPLog - header := gtpMsg.NewHeader(0x34, gtpMsg.MsgTypeTPDU, teid, 0x00, payload).WithExtensionHeaders( - gtpMsg.NewExtensionHeader( - gtpMsg.ExtHeaderTypePDUSessionContainer, - []byte{gtpQoSMsg.UL_PDU_SESSION_INFORMATION_TYPE, qfi}, - gtpMsg.ExtHeaderTypeNoMoreExtensionHeaders, - ), - ) - - b := make([]byte, header.MarshalLen()) - - if err := header.MarshalTo(b); err != nil { - nwuupLog.Errorf("go-gtp MarshalTo err: %v", err) - return nil, err - } - - return b, nil -} - func (s *Server) Stop() { - nwuupLog := logger.NWuUPLog + nwuupLog := s.log nwuupLog.Infof("Close Nwuup server...") if err := s.greConn.Close(); err != nil { nwuupLog.Errorf("Stop nwuup greConn error : %v", err) } + if err := s.gtpuConn.Close(); err != nil { nwuupLog.Errorf("Stop nwuup gtpuConn error : %v", err) } } + +// Parse the fields not supported by go-gtp and forward data to UE. +func (s *Server) handleQoSTPDU(c gtpv1.Conn, senderAddr net.Addr, msg gtpMsg.Message) error { + pdu := gtpQoSMsg.QoSTPDUPacket{} + err := pdu.Unmarshal(msg.(*gtpMsg.TPDU)) + if err != nil { + return err + } + + s.forwardDL(pdu) + return nil +} + +// Forward user plane packets from N3 to UE with GRE header and new IP header encapsulated +func (s *Server) forwardDL(packet gtpQoSMsg.QoSTPDUPacket) { + nwuupLog := s.log + + defer func() { + if p := recover(); p != nil { + // Print stack for panic to log. Fatalf() will let program exit. + nwuupLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) + } + }() + + n3iwfCtx := s.Context() + pktTEID := packet.GetTEID() + nwuupLog.Tracef("pkt teid : %d", pktTEID) + + // Find UE information + ranUe, ok := n3iwfCtx.AllocatedUETEIDLoad(pktTEID) + if !ok { + nwuupLog.Errorf("Cannot find RanUE context from QosPacket TEID : %+v", pktTEID) + return + } + ranUeNgapID := ranUe.GetSharedCtx().RanUeNgapId + + ikeUe, err := n3iwfCtx.IkeUeLoadFromNgapId(ranUeNgapID) + if err != nil { + nwuupLog.Errorf("Cannot find IkeUe context from RanUe , NgapID : %+v", ranUeNgapID) + return + } + + // UE inner IP in IPSec + ueInnerIPAddr := ikeUe.IPSecInnerIPAddr + + var cm *ipv4.ControlMessage + for _, childSA := range ikeUe.N3IWFChildSecurityAssociation { + pdusession := ranUe.FindPDUSession(childSA.PDUSessionIds[0]) + if pdusession != nil && pdusession.GTPConnInfo.IncomingTEID == pktTEID { + nwuupLog.Tracef("forwarding IPSec xfrm interfaceid : %d", childSA.XfrmIface.Attrs().Index) + cm = &ipv4.ControlMessage{ + IfIndex: childSA.XfrmIface.Attrs().Index, + } + break + } + } + if cm == nil { + nwuupLog.Warnf("forwardDL(): Cannot match TEID(%d) to ChildSA", pktTEID) + return + } + + var ( + qfi uint8 + rqi bool + ) + + // QoS Related Parameter + if packet.HasQoS() { + qfi, rqi = packet.GetQoSParameters() + nwuupLog.Tracef("QFI: %v, RQI: %v", qfi, rqi) + } + + // Encasulate IPv4 packet with GRE header before forward to UE through IPsec + grePacket := gre.GREPacket{} + + // TODO:[24.502(v15.7) 9.3.3 ] The Protocol Type field should be set to zero + grePacket.SetPayload(packet.GetPayload(), gre.IPv4) + grePacket.SetQoS(qfi, rqi) + forwardData := grePacket.Marshal() + + // Send to UE through Nwu + if n, err := s.greConn.WriteTo(forwardData, cm, ueInnerIPAddr); err != nil { + nwuupLog.Errorf("Write to UE failed: %+v", err) + return + } else { + nwuupLog.Trace("Forward NWu <- N3") + nwuupLog.Tracef("Wrote %d bytes", n) + } +} diff --git a/internal/util/hash.go b/internal/util/hash.go new file mode 100644 index 00000000..20abc936 --- /dev/null +++ b/internal/util/hash.go @@ -0,0 +1,13 @@ +package util + +import "hash/crc32" + +func HashCRC32(text string) (uint32, error) { + h := crc32.NewIEEE() + _, err := h.Write([]byte(text)) + if err != nil { + return 0, err + } + + return h.Sum32(), nil +} diff --git a/pkg/app/app.go b/pkg/app/app.go index b6961007..cabfc06a 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -1,7 +1,7 @@ package app import ( - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" "github.com/free5gc/n3iwf/pkg/factory" ) diff --git a/pkg/factory/config.go b/pkg/factory/config.go index 2dbed9a0..bbc37eb6 100644 --- a/pkg/factory/config.go +++ b/pkg/factory/config.go @@ -1,7 +1,3 @@ -/* - * N3IWF Configuration Factory - */ - package factory import ( @@ -26,46 +22,48 @@ const ( N3iwfDefaultConfigPath string = "./config/n3iwfcfg.yaml" N3iwfDefaultXfrmIfaceName string = "ipsec" N3iwfDefaultXfrmIfaceId uint32 = 7 + + MAX_BUF_MSG_LEN int = 65535 ) type N3IWFNFInfo struct { - GlobalN3IWFID *GlobalN3IWFID `yaml:"GlobalN3IWFID" valid:"required"` - RanNodeName string `yaml:"Name,omitempty" valid:"optional"` - SupportedTAList []SupportedTAItem `yaml:"SupportedTAList" valid:"required"` + GlobalN3IWFID *GlobalN3IWFID `yaml:"globalN3IWFID" valid:"required"` + RanNodeName string `yaml:"name,omitempty" valid:"optional"` + SupportedTAList []SupportedTAItem `yaml:"supportedTAList" valid:"required"` } type GlobalN3IWFID struct { - PLMNID *PLMNID `yaml:"PLMNID" valid:"required"` - N3IWFID uint16 `yaml:"N3IWFID" valid:"range(0|65535),required"` // with length 2 bytes + PLMNID *PLMNID `yaml:"plmnID" valid:"required"` + N3IWFID uint16 `yaml:"n3iwfID" valid:"range(0|65535),required"` // with length 2 bytes } type SupportedTAItem struct { - TAC string `yaml:"TAC" valid:"hexadecimal,stringlength(6|6),required"` - BroadcastPLMNList []BroadcastPLMNItem `yaml:"BroadcastPLMNList" valid:"required"` + TAC string `yaml:"tac" valid:"hexadecimal,stringlength(6|6),required"` + BroadcastPLMNList []BroadcastPLMNItem `yaml:"broadcastPlmnList" valid:"required"` } type BroadcastPLMNItem struct { - PLMNID *PLMNID `yaml:"PLMNID" valid:"required"` - TAISliceSupportList []SliceSupportItem `yaml:"TAISliceSupportList" valid:"required"` + PLMNID *PLMNID `yaml:"plmnID" valid:"required"` + TAISliceSupportList []SliceSupportItem `yaml:"taiSliceSupportList" valid:"required"` } type PLMNID struct { - Mcc string `yaml:"MCC" valid:"numeric,stringlength(3|3),required"` - Mnc string `yaml:"MNC" valid:"numeric,stringlength(2|3),required"` + Mcc string `yaml:"mcc" valid:"numeric,stringlength(3|3),required"` + Mnc string `yaml:"mnc" valid:"numeric,stringlength(2|3),required"` } type SliceSupportItem struct { - SNSSAI SNSSAIItem `yaml:"SNSSAI" valid:"required"` + SNSSAI SNSSAIItem `yaml:"snssai" valid:"required"` } type SNSSAIItem struct { - SST int32 `yaml:"SST" valid:"required"` - SD string `yaml:"SD,omitempty" valid:"required,hexadecimal,stringlength(6|6)"` + SST int32 `yaml:"sst" valid:"required"` + SD string `yaml:"sd,omitempty" valid:"required,hexadecimal,stringlength(6|6)"` } type AMFSCTPAddresses struct { - IPAddresses []string `yaml:"IP" valid:"required"` - Port int `yaml:"Port,omitempty" valid:"port,optional"` // Default port is 38412 if not defined. + IPAddresses []string `yaml:"ip" valid:"required"` + Port int `yaml:"port,omitempty" valid:"port,optional"` // Default port is 38412 if not defined. } func (a *AMFSCTPAddresses) validate() error { @@ -115,22 +113,22 @@ type Info struct { } type Configuration struct { - N3IWFInfo *N3IWFNFInfo `yaml:"N3IWFInformation" valid:"required"` + N3IWFInfo *N3IWFNFInfo `yaml:"n3iwfInformation" valid:"required"` LocalSctpAddr string `yaml:"localSctpAddr,omitempty" valid:"optional,host"` - AMFSCTPAddresses []AMFSCTPAddresses `yaml:"AMFSCTPAddresses" valid:"required"` - - TCPPort int `yaml:"NASTCPPort" valid:"required,port"` - IKEBindAddr string `yaml:"IKEBindAddress" valid:"required,host"` - IPSecGatewayAddr string `yaml:"IPSecTunnelAddress" valid:"required,host"` - UEIPAddressRange string `yaml:"UEIPAddressRange" valid:"required,cidr"` // e.g. 10.0.1.0/24 - XfrmIfaceName string `yaml:"XFRMInterfaceName" valid:"optional,stringlength(1|10)"` // must != 0 - XfrmIfaceId uint32 `yaml:"XFRMInterfaceID" valid:"optional"` // must != 0 - GTPBindAddr string `yaml:"GTPBindAddress" valid:"required,host"` - FQDN string `yaml:"FQDN" valid:"required,host"` // e.g. n3iwf.Saviah.com - PrivateKey string `yaml:"PrivateKey" valid:"optional"` - CertificateAuthority string `yaml:"CertificateAuthority" valid:"optional"` - Certificate string `yaml:"Certificate" valid:"optional"` - LivenessCheck *TimerValue `yaml:"LivenessCheck" valid:"required"` + AMFSCTPAddresses []AMFSCTPAddresses `yaml:"amfSCTPAddresses" valid:"required"` + + TCPPort int `yaml:"nasTcpPort" valid:"required,port"` + IKEBindAddr string `yaml:"ikeBindAddress" valid:"required,host"` + UEIPAddressRange string `yaml:"ueIpAddressRange" valid:"required,cidr"` // e.g. 10.0.1.0/24 + IPSecGatewayAddr string `yaml:"ipSecTunnelAddress" valid:"required,host"` + XfrmIfaceName string `yaml:"xfrmInterfaceName" valid:"optional,stringlength(1|10)"` // must != 0 + XfrmIfaceId uint32 `yaml:"xfrmInterfaceID" valid:"optional"` // must != 0 + N3IWFGTPBindAddress string `yaml:"n3iwfGtpBindAddress" valid:"required,host"` + FQDN string `yaml:"fqdn" valid:"required,host"` // e.g. n3iwf.Saviah.com + PrivateKey string `yaml:"privateKey" valid:"optional"` + CertificateAuthority string `yaml:"certificateAuthority" valid:"optional"` + Certificate string `yaml:"certificate" valid:"optional"` + LivenessCheck *TimerValue `yaml:"livenessCheck" valid:"required"` } type Logger struct { @@ -318,10 +316,10 @@ func (c *Config) GetIPSecGatewayAddr() string { return c.Configuration.IPSecGatewayAddr } -func (c *Config) GetGTPBindAddr() string { +func (c *Config) GetN3iwfGtpBindAddress() string { c.RLock() defer c.RUnlock() - return c.Configuration.GTPBindAddr + return c.Configuration.N3IWFGTPBindAddress } func (c *Config) GetNasTcpAddr() string { @@ -333,7 +331,7 @@ func (c *Config) GetNasTcpAddr() string { func (c *Config) GetNasTcpPort() uint16 { c.RLock() defer c.RUnlock() - return uint16(c.Configuration.TCPPort) + return uint16(c.Configuration.TCPPort) // #nosec G115 } func (c *Config) GetFQDN() string { diff --git a/pkg/factory/factory.go b/pkg/factory/factory.go index cb10b749..c21e11c7 100644 --- a/pkg/factory/factory.go +++ b/pkg/factory/factory.go @@ -1,7 +1,3 @@ -/* - * N3IWF Configuration Factory - */ - package factory import ( diff --git a/pkg/ike/handler_test.go b/pkg/ike/handler_test.go deleted file mode 100644 index 2a5a57d8..00000000 --- a/pkg/ike/handler_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package ike - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/free5gc/n3iwf/pkg/factory" - ike_message "github.com/free5gc/n3iwf/pkg/ike/message" -) - -func TestRemoveIkeUe(t *testing.T) { - n3iwf, err := NewN3iwfTestApp(&factory.Config{}) - require.NoError(t, err) - - n3iwf.ikeServer, err = NewServer(n3iwf) - require.NoError(t, err) - - n3iwfCtx := n3iwf.n3iwfCtx - ikeSA := n3iwfCtx.NewIKESecurityAssociation() - ikeUe := n3iwfCtx.NewN3iwfIkeUe(ikeSA.LocalSPI) - ikeUe.N3IWFIKESecurityAssociation = ikeSA - ikeSA.IsUseDPD = false - - ikeUe.CreateHalfChildSA(1, 123, 1) - - ikeAuth := &ike_message.SecurityAssociation{} - - ikeAuth.Proposals.BuildProposal(1, 1, []byte{0, 1, 2, 3}) - - childSA, err := ikeUe.CompleteChildSA(1, 456, ikeAuth) - require.NoError(t, err) - - err = n3iwf.ikeServer.removeIkeUe(ikeSA.LocalSPI) - require.NoError(t, err) - - _, ok := n3iwfCtx.IkeUePoolLoad(ikeSA.LocalSPI) - require.False(t, ok) - - _, ok = n3iwfCtx.IKESALoad(ikeSA.LocalSPI) - require.False(t, ok) - - _, ok = ikeUe.N3IWFChildSecurityAssociation[childSA.InboundSPI] - require.False(t, ok) -} diff --git a/pkg/ike/message/build.go b/pkg/ike/message/build.go deleted file mode 100644 index f709324b..00000000 --- a/pkg/ike/message/build.go +++ /dev/null @@ -1,324 +0,0 @@ -package message - -import ( - "encoding/binary" - "net" - - "github.com/free5gc/n3iwf/internal/logger" -) - -func (ikeMessage *IKEMessage) BuildIKEHeader( - initiatorSPI uint64, - responsorSPI uint64, - exchangeType uint8, - flags uint8, - messageID uint32, -) { - ikeMessage.InitiatorSPI = initiatorSPI - ikeMessage.ResponderSPI = responsorSPI - ikeMessage.Version = 0x20 - ikeMessage.ExchangeType = exchangeType - ikeMessage.Flags = flags - ikeMessage.MessageID = messageID -} - -func (container *IKEPayloadContainer) Reset() { - *container = nil -} - -func (container *IKEPayloadContainer) BuildNotification( - protocolID uint8, - notifyMessageType uint16, - spi []byte, - notificationData []byte, -) { - notification := new(Notification) - notification.ProtocolID = protocolID - notification.NotifyMessageType = notifyMessageType - notification.SPI = append(notification.SPI, spi...) - notification.NotificationData = append(notification.NotificationData, notificationData...) - *container = append(*container, notification) -} - -func (container *IKEPayloadContainer) BuildCertificate(certificateEncode uint8, certificateData []byte) { - certificate := new(Certificate) - certificate.CertificateEncoding = certificateEncode - certificate.CertificateData = append(certificate.CertificateData, certificateData...) - *container = append(*container, certificate) -} - -func (container *IKEPayloadContainer) BuildEncrypted(nextPayload IKEPayloadType, encryptedData []byte) *Encrypted { - encrypted := new(Encrypted) - encrypted.NextPayload = uint8(nextPayload) - encrypted.EncryptedData = append(encrypted.EncryptedData, encryptedData...) - *container = append(*container, encrypted) - return encrypted -} - -func (container *IKEPayloadContainer) BUildKeyExchange(diffiehellmanGroup uint16, keyExchangeData []byte) { - keyExchange := new(KeyExchange) - keyExchange.DiffieHellmanGroup = diffiehellmanGroup - keyExchange.KeyExchangeData = append(keyExchange.KeyExchangeData, keyExchangeData...) - *container = append(*container, keyExchange) -} - -func (container *IKEPayloadContainer) BuildIdentificationInitiator(idType uint8, idData []byte) { - identification := new(IdentificationInitiator) - identification.IDType = idType - identification.IDData = append(identification.IDData, idData...) - *container = append(*container, identification) -} - -func (container *IKEPayloadContainer) BuildIdentificationResponder(idType uint8, idData []byte) { - identification := new(IdentificationResponder) - identification.IDType = idType - identification.IDData = append(identification.IDData, idData...) - *container = append(*container, identification) -} - -func (container *IKEPayloadContainer) BuildAuthentication(authenticationMethod uint8, authenticationData []byte) { - authentication := new(Authentication) - authentication.AuthenticationMethod = authenticationMethod - authentication.AuthenticationData = append(authentication.AuthenticationData, authenticationData...) - *container = append(*container, authentication) -} - -func (container *IKEPayloadContainer) BuildConfiguration(configurationType uint8) *Configuration { - configuration := new(Configuration) - configuration.ConfigurationType = configurationType - *container = append(*container, configuration) - return configuration -} - -func (container *ConfigurationAttributeContainer) Reset() { - *container = nil -} - -func (container *ConfigurationAttributeContainer) BuildConfigurationAttribute( - attributeType uint16, - attributeValue []byte, -) { - configurationAttribute := new(IndividualConfigurationAttribute) - configurationAttribute.Type = attributeType - configurationAttribute.Value = append(configurationAttribute.Value, attributeValue...) - *container = append(*container, configurationAttribute) -} - -func (container *IKEPayloadContainer) BuildNonce(nonceData []byte) { - nonce := new(Nonce) - nonce.NonceData = append(nonce.NonceData, nonceData...) - *container = append(*container, nonce) -} - -func (container *IKEPayloadContainer) BuildTrafficSelectorInitiator() *TrafficSelectorInitiator { - trafficSelectorInitiator := new(TrafficSelectorInitiator) - *container = append(*container, trafficSelectorInitiator) - return trafficSelectorInitiator -} - -func (container *IKEPayloadContainer) BuildTrafficSelectorResponder() *TrafficSelectorResponder { - trafficSelectorResponder := new(TrafficSelectorResponder) - *container = append(*container, trafficSelectorResponder) - return trafficSelectorResponder -} - -func (container *IndividualTrafficSelectorContainer) Reset() { - *container = nil -} - -func (container *IndividualTrafficSelectorContainer) BuildIndividualTrafficSelector( - tsType uint8, - ipProtocolID uint8, - startPort uint16, - endPort uint16, - startAddr []byte, - endAddr []byte, -) { - trafficSelector := new(IndividualTrafficSelector) - trafficSelector.TSType = tsType - trafficSelector.IPProtocolID = ipProtocolID - trafficSelector.StartPort = startPort - trafficSelector.EndPort = endPort - trafficSelector.StartAddress = append(trafficSelector.StartAddress, startAddr...) - trafficSelector.EndAddress = append(trafficSelector.EndAddress, endAddr...) - *container = append(*container, trafficSelector) -} - -func (container *IKEPayloadContainer) BuildSecurityAssociation() *SecurityAssociation { - securityAssociation := new(SecurityAssociation) - *container = append(*container, securityAssociation) - return securityAssociation -} - -func (container *ProposalContainer) Reset() { - *container = nil -} - -func (container *ProposalContainer) BuildProposal(proposalNumber uint8, protocolID uint8, spi []byte) *Proposal { - proposal := new(Proposal) - proposal.ProposalNumber = proposalNumber - proposal.ProtocolID = protocolID - proposal.SPI = append(proposal.SPI, spi...) - *container = append(*container, proposal) - return proposal -} - -func (container *IKEPayloadContainer) BuildDeletePayload( - protocolID uint8, SPISize uint8, numberOfSPI uint16, SPIs []byte, -) { - deletePayload := new(Delete) - deletePayload.ProtocolID = protocolID - deletePayload.SPISize = SPISize - deletePayload.NumberOfSPI = numberOfSPI - deletePayload.SPIs = SPIs - *container = append(*container, deletePayload) -} - -func (container *TransformContainer) Reset() { - *container = nil -} - -func (container *TransformContainer) BuildTransform( - transformType uint8, - transformID uint16, - attributeType *uint16, - attributeValue *uint16, - variableLengthAttributeValue []byte, -) { - transform := new(Transform) - transform.TransformType = transformType - transform.TransformID = transformID - if attributeType != nil { - transform.AttributePresent = true - transform.AttributeType = *attributeType - if attributeValue != nil { - transform.AttributeFormat = AttributeFormatUseTV - transform.AttributeValue = *attributeValue - } else if len(variableLengthAttributeValue) != 0 { - transform.AttributeFormat = AttributeFormatUseTLV - transform.VariableLengthAttributeValue = append(transform.VariableLengthAttributeValue, - variableLengthAttributeValue...) - } else { - return - } - } else { - transform.AttributePresent = false - } - *container = append(*container, transform) -} - -func (container *IKEPayloadContainer) BuildEAP(code uint8, identifier uint8) *EAP { - eap := new(EAP) - eap.Code = code - eap.Identifier = identifier - *container = append(*container, eap) - return eap -} - -func (container *IKEPayloadContainer) BuildEAPSuccess(identifier uint8) { - eap := new(EAP) - eap.Code = EAPCodeSuccess - eap.Identifier = identifier - *container = append(*container, eap) -} - -func (container *IKEPayloadContainer) BuildEAPfailure(identifier uint8) { - eap := new(EAP) - eap.Code = EAPCodeFailure - eap.Identifier = identifier - *container = append(*container, eap) -} - -func (container *EAPTypeDataContainer) BuildEAPExpanded(vendorID uint32, vendorType uint32, vendorData []byte) { - eapExpanded := new(EAPExpanded) - eapExpanded.VendorID = vendorID - eapExpanded.VendorType = vendorType - eapExpanded.VendorData = append(eapExpanded.VendorData, vendorData...) - *container = append(*container, eapExpanded) -} - -func (container *IKEPayloadContainer) BuildEAP5GStart(identifier uint8) { - eap := container.BuildEAP(EAPCodeRequest, identifier) - eap.EAPTypeData.BuildEAPExpanded(VendorID3GPP, VendorTypeEAP5G, []byte{EAP5GType5GStart, EAP5GSpareValue}) -} - -func (container *IKEPayloadContainer) BuildEAP5GNAS(identifier uint8, nasPDU []byte) { - ikeLog := logger.IKELog - if len(nasPDU) == 0 { - ikeLog.Error("BuildEAP5GNAS(): NASPDU is nil") - return - } - - header := make([]byte, 4) - - // Message ID - header[0] = EAP5GType5GNAS - // NASPDU length (2 octets) - binary.BigEndian.PutUint16(header[2:4], uint16(len(nasPDU))) - vendorData := append(header, nasPDU...) - - eap := container.BuildEAP(EAPCodeRequest, identifier) - eap.EAPTypeData.BuildEAPExpanded(VendorID3GPP, VendorTypeEAP5G, vendorData) -} - -func (container *IKEPayloadContainer) BuildNotify5G_QOS_INFO( - pduSessionID uint8, - qfiList []uint8, - isDefault bool, - isDSCPSpecified bool, - DSCP uint8, -) { - notifyData := make([]byte, 1) // For length - // Append PDU session ID - notifyData = append(notifyData, pduSessionID) - // Append QFI list length - notifyData = append(notifyData, uint8(len(qfiList))) - // Append QFI list - notifyData = append(notifyData, qfiList...) - // Append default and differentiated service flags - var defaultAndDifferentiatedServiceFlags uint8 - if isDefault { - defaultAndDifferentiatedServiceFlags |= NotifyType5G_QOS_INFOBitDCSICheck - } - if isDSCPSpecified { - defaultAndDifferentiatedServiceFlags |= NotifyType5G_QOS_INFOBitDSCPICheck - } - - notifyData = append(notifyData, defaultAndDifferentiatedServiceFlags) - if isDSCPSpecified { - notifyData = append(notifyData, DSCP) - } - - // Assign length - notifyData[0] = uint8(len(notifyData)) - - container.BuildNotification(TypeNone, Vendor3GPPNotifyType5G_QOS_INFO, nil, notifyData) -} - -func (container *IKEPayloadContainer) BuildNotifyNAS_IP4_ADDRESS(nasIPAddr string) { - if nasIPAddr == "" { - return - } else { - ipAddrByte := net.ParseIP(nasIPAddr).To4() - container.BuildNotification(TypeNone, Vendor3GPPNotifyTypeNAS_IP4_ADDRESS, nil, ipAddrByte) - } -} - -func (container *IKEPayloadContainer) BuildNotifyUP_IP4_ADDRESS(upIPAddr string) { - if upIPAddr == "" { - return - } else { - ipAddrByte := net.ParseIP(upIPAddr).To4() - container.BuildNotification(TypeNone, Vendor3GPPNotifyTypeUP_IP4_ADDRESS, nil, ipAddrByte) - } -} - -func (container *IKEPayloadContainer) BuildNotifyNAS_TCP_PORT(port uint16) { - if port == 0 { - return - } else { - portData := make([]byte, 2) - binary.BigEndian.PutUint16(portData, port) - container.BuildNotification(TypeNone, Vendor3GPPNotifyTypeNAS_TCP_PORT, nil, portData) - } -} diff --git a/pkg/ike/message/message.go b/pkg/ike/message/message.go deleted file mode 100644 index 413f2a17..00000000 --- a/pkg/ike/message/message.go +++ /dev/null @@ -1,1564 +0,0 @@ -package message - -import ( - "encoding/binary" - "encoding/hex" - "errors" - "fmt" - - "github.com/free5gc/n3iwf/internal/logger" -) - -type IKEMessage struct { - InitiatorSPI uint64 - ResponderSPI uint64 - Version uint8 - ExchangeType uint8 - Flags uint8 - MessageID uint32 - Payloads IKEPayloadContainer -} - -func (ikeMessage *IKEMessage) Encode() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("Encoding IKE message") - - ikeMessageData := make([]byte, 28) - - binary.BigEndian.PutUint64(ikeMessageData[0:8], ikeMessage.InitiatorSPI) - binary.BigEndian.PutUint64(ikeMessageData[8:16], ikeMessage.ResponderSPI) - ikeMessageData[17] = ikeMessage.Version - ikeMessageData[18] = ikeMessage.ExchangeType - ikeMessageData[19] = ikeMessage.Flags - binary.BigEndian.PutUint32(ikeMessageData[20:24], ikeMessage.MessageID) - - if len(ikeMessage.Payloads) > 0 { - ikeMessageData[16] = byte(ikeMessage.Payloads[0].Type()) - } else { - ikeMessageData[16] = NoNext - } - - ikeMessagePayloadData, err := ikeMessage.Payloads.Encode() - if err != nil { - return nil, fmt.Errorf("Encode(): EncodePayload failed: %+v", err) - } - - ikeMessageData = append(ikeMessageData, ikeMessagePayloadData...) - binary.BigEndian.PutUint32(ikeMessageData[24:28], uint32(len(ikeMessageData))) - - ikeLog.Tracef("Encoded %d bytes", len(ikeMessageData)) - ikeLog.Tracef("IKE message data:\n%s", hex.Dump(ikeMessageData)) - - return ikeMessageData, nil -} - -func (ikeMessage *IKEMessage) Decode(rawData []byte) error { - // IKE message packet format this implementation referenced is - // defined in RFC 7296, Section 3.1 - ikeLog := logger.IKELog - ikeLog.Info("Decoding IKE message") - ikeLog.Tracef("Received IKE message:\n%s", hex.Dump(rawData)) - - // bounds checking - if len(rawData) < 28 { - return errors.New("Decode(): Received broken IKE header") - } - ikeMessageLength := binary.BigEndian.Uint32(rawData[24:28]) - if ikeMessageLength < 28 { - return fmt.Errorf("Decode(): Illegal IKE message length %d < header length 20", ikeMessageLength) - } - // len() return int, which is 64 bit on 64-bit host and 32 bit - // on 32-bit host, so this implementation may potentially cause - // problem on 32-bit machine - if len(rawData) != int(ikeMessageLength) { - return errors.New("Decode(): The length of received message not matchs the length specified in header") - } - - nextPayload := rawData[16] - - ikeMessage.InitiatorSPI = binary.BigEndian.Uint64(rawData[:8]) - ikeMessage.ResponderSPI = binary.BigEndian.Uint64(rawData[8:16]) - ikeMessage.Version = rawData[17] - ikeMessage.ExchangeType = rawData[18] - ikeMessage.Flags = rawData[19] - ikeMessage.MessageID = binary.BigEndian.Uint32(rawData[20:24]) - - err := ikeMessage.Payloads.Decode(nextPayload, rawData[28:]) - if err != nil { - return fmt.Errorf("Decode(): DecodePayload failed: %+v", err) - } - - return nil -} - -type IKEPayloadContainer []IKEPayload - -func (container *IKEPayloadContainer) Encode() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("Encoding IKE payloads") - - ikeMessagePayloadData := make([]byte, 0) - - for index, payload := range *container { - payloadData := make([]byte, 4) // IKE payload general header - if (index + 1) < len(*container) { // if it has next payload - payloadData[0] = uint8((*container)[index+1].Type()) - } else { - if payload.Type() == TypeSK { - payloadData[0] = payload.(*Encrypted).NextPayload - } else { - payloadData[0] = NoNext - } - } - - data, err := payload.marshal() - if err != nil { - return nil, fmt.Errorf("EncodePayload(): Failed to marshal payload: %+v", err) - } - - payloadData = append(payloadData, data...) - binary.BigEndian.PutUint16(payloadData[2:4], uint16(len(payloadData))) - - ikeMessagePayloadData = append(ikeMessagePayloadData, payloadData...) - } - - return ikeMessagePayloadData, nil -} - -func (container *IKEPayloadContainer) Decode(nextPayload uint8, rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("Decoding IKE payloads") - - for len(rawData) > 0 { - // bounds checking - ikeLog.Trace("DecodePayload(): Decode 1 payload") - if len(rawData) < 4 { - return errors.New("DecodePayload(): No sufficient bytes to decode next payload") - } - payloadLength := binary.BigEndian.Uint16(rawData[2:4]) - if payloadLength < 4 { - return fmt.Errorf("DecodePayload(): Illegal payload length %d < header length 4", payloadLength) - } - if len(rawData) < int(payloadLength) { - return errors.New("DecodePayload(): The length of received message not matchs the length specified in header") - } - - criticalBit := (rawData[1] & 0x80) >> 7 - - var payload IKEPayload - - switch nextPayload { - case TypeSA: - payload = new(SecurityAssociation) - case TypeKE: - payload = new(KeyExchange) - case TypeIDi: - payload = new(IdentificationInitiator) - case TypeIDr: - payload = new(IdentificationResponder) - case TypeCERT: - payload = new(Certificate) - case TypeCERTreq: - payload = new(CertificateRequest) - case TypeAUTH: - payload = new(Authentication) - case TypeNiNr: - payload = new(Nonce) - case TypeN: - payload = new(Notification) - case TypeD: - payload = new(Delete) - case TypeV: - payload = new(VendorID) - case TypeTSi: - payload = new(TrafficSelectorInitiator) - case TypeTSr: - payload = new(TrafficSelectorResponder) - case TypeSK: - encryptedPayload := new(Encrypted) - encryptedPayload.NextPayload = rawData[0] - payload = encryptedPayload - case TypeCP: - payload = new(Configuration) - case TypeEAP: - payload = new(EAP) - default: - if criticalBit == 0 { - // Skip this payload - nextPayload = rawData[0] - rawData = rawData[payloadLength:] - continue - } else { - // TODO: Reject this IKE message - return fmt.Errorf("Unknown payload type: %d", nextPayload) - } - } - - if err := payload.unmarshal(rawData[4:payloadLength]); err != nil { - return fmt.Errorf("DecodePayload(): Unmarshal payload failed: %+v", err) - } - - *container = append(*container, payload) - - nextPayload = rawData[0] - rawData = rawData[payloadLength:] - } - - return nil -} - -type IKEPayload interface { - // Type specifies the IKE payload types - Type() IKEPayloadType - - // Called by Encode() or Decode() - marshal() ([]byte, error) - unmarshal(rawData []byte) error -} - -// Definition of Security Association - -var _ IKEPayload = &SecurityAssociation{} - -type SecurityAssociation struct { - Proposals ProposalContainer -} - -type ProposalContainer []*Proposal - -type Proposal struct { - ProposalNumber uint8 - ProtocolID uint8 - SPI []byte - EncryptionAlgorithm TransformContainer - PseudorandomFunction TransformContainer - IntegrityAlgorithm TransformContainer - DiffieHellmanGroup TransformContainer - ExtendedSequenceNumbers TransformContainer -} - -type TransformContainer []*Transform - -type Transform struct { - TransformType uint8 - TransformID uint16 - AttributePresent bool - AttributeFormat uint8 - AttributeType uint16 - AttributeValue uint16 - VariableLengthAttributeValue []byte -} - -func (securityAssociation *SecurityAssociation) Type() IKEPayloadType { return TypeSA } - -func (securityAssociation *SecurityAssociation) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[SecurityAssociation] marshal(): Start marshalling") - - securityAssociationData := make([]byte, 0) - - for proposalIndex, proposal := range securityAssociation.Proposals { - proposalData := make([]byte, 8) - - if (proposalIndex + 1) < len(securityAssociation.Proposals) { - proposalData[0] = 2 - } else { - proposalData[0] = 0 - } - - proposalData[4] = proposal.ProposalNumber - proposalData[5] = proposal.ProtocolID - - proposalData[6] = uint8(len(proposal.SPI)) - if len(proposal.SPI) > 0 { - proposalData = append(proposalData, proposal.SPI...) - } - - // combine all transforms - var transformList []*Transform - transformList = append(transformList, proposal.EncryptionAlgorithm...) - transformList = append(transformList, proposal.PseudorandomFunction...) - transformList = append(transformList, proposal.IntegrityAlgorithm...) - transformList = append(transformList, proposal.DiffieHellmanGroup...) - transformList = append(transformList, proposal.ExtendedSequenceNumbers...) - - if len(transformList) == 0 { - return nil, errors.New("One proposal has no any transform") - } - proposalData[7] = uint8(len(transformList)) - - proposalTransformData := make([]byte, 0) - - for transformIndex, transform := range transformList { - transformData := make([]byte, 8) - - if (transformIndex + 1) < len(transformList) { - transformData[0] = 3 - } else { - transformData[0] = 0 - } - - transformData[4] = transform.TransformType - binary.BigEndian.PutUint16(transformData[6:8], transform.TransformID) - - if transform.AttributePresent { - attributeData := make([]byte, 4) - - if transform.AttributeFormat == 0 { - // TLV - if len(transform.VariableLengthAttributeValue) == 0 { - return nil, errors.New("Attribute of one transform not specified") - } - attributeFormatAndType := ((uint16(transform.AttributeFormat) & 0x1) << 15) | transform.AttributeType - binary.BigEndian.PutUint16(attributeData[0:2], attributeFormatAndType) - binary.BigEndian.PutUint16(attributeData[2:4], uint16(len(transform.VariableLengthAttributeValue))) - attributeData = append(attributeData, transform.VariableLengthAttributeValue...) - } else { - // TV - attributeFormatAndType := ((uint16(transform.AttributeFormat) & 0x1) << 15) | transform.AttributeType - binary.BigEndian.PutUint16(attributeData[0:2], attributeFormatAndType) - binary.BigEndian.PutUint16(attributeData[2:4], transform.AttributeValue) - } - - transformData = append(transformData, attributeData...) - } - - binary.BigEndian.PutUint16(transformData[2:4], uint16(len(transformData))) - - proposalTransformData = append(proposalTransformData, transformData...) - } - - proposalData = append(proposalData, proposalTransformData...) - binary.BigEndian.PutUint16(proposalData[2:4], uint16(len(proposalData))) - - securityAssociationData = append(securityAssociationData, proposalData...) - } - - return securityAssociationData, nil -} - -func (securityAssociation *SecurityAssociation) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[SecurityAssociation] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[SecurityAssociation] unmarshal(): Payload length %d bytes", len(rawData)) - - for len(rawData) > 0 { - ikeLog.Trace("[SecurityAssociation] unmarshal(): Unmarshal 1 proposal") - // bounds checking - if len(rawData) < 8 { - return errors.New("Proposal: No sufficient bytes to decode next proposal") - } - proposalLength := binary.BigEndian.Uint16(rawData[2:4]) - if proposalLength < 8 { - return errors.New("Proposal: Illegal payload length %d < header length 8") - } - if len(rawData) < int(proposalLength) { - return errors.New("Proposal: The length of received message not matchs the length specified in header") - } - - // Log whether this proposal is the last - if rawData[0] == 0 { - ikeLog.Trace("[SecurityAssociation] This proposal is the last") - } - // Log the number of transform in the proposal - ikeLog.Tracef("[SecurityAssociation] This proposal contained %d transform", rawData[7]) - - proposal := new(Proposal) - var transformData []byte - - proposal.ProposalNumber = rawData[4] - proposal.ProtocolID = rawData[5] - - spiSize := rawData[6] - if spiSize > 0 { - // bounds checking - if len(rawData) < int(8+spiSize) { - return errors.New("Proposal: No sufficient bytes for unmarshalling SPI of proposal") - } - proposal.SPI = append(proposal.SPI, rawData[8:8+spiSize]...) - } - - transformData = rawData[8+spiSize : proposalLength] - - for len(transformData) > 0 { - // bounds checking - ikeLog.Trace("[SecurityAssociation] unmarshal(): Unmarshal 1 transform") - if len(transformData) < 8 { - return errors.New("Transform: No sufficient bytes to decode next transform") - } - transformLength := binary.BigEndian.Uint16(transformData[2:4]) - if transformLength < 8 { - return errors.New("Transform: Illegal payload length %d < header length 8") - } - if len(transformData) < int(transformLength) { - return errors.New("Transform: The length of received message not matchs the length specified in header") - } - - // Log whether this transform is the last - if transformData[0] == 0 { - ikeLog.Trace("[SecurityAssociation] This transform is the last") - } - - transform := new(Transform) - - transform.TransformType = transformData[4] - transform.TransformID = binary.BigEndian.Uint16(transformData[6:8]) - if transformLength > 8 { - transform.AttributePresent = true - transform.AttributeFormat = ((transformData[8] & 0x80) >> 7) - transform.AttributeType = binary.BigEndian.Uint16(transformData[8:10]) & 0x7f - - if transform.AttributeFormat == 0 { - attributeLength := binary.BigEndian.Uint16(transformData[10:12]) - // bounds checking - if (12 + attributeLength) != transformLength { - return fmt.Errorf("Illegal attribute length %d not satisfies the transform length %d", - attributeLength, transformLength) - } - copy(transform.VariableLengthAttributeValue, transformData[12:12+attributeLength]) - } else { - transform.AttributeValue = binary.BigEndian.Uint16(transformData[10:12]) - } - } - - switch transform.TransformType { - case TypeEncryptionAlgorithm: - proposal.EncryptionAlgorithm = append(proposal.EncryptionAlgorithm, transform) - case TypePseudorandomFunction: - proposal.PseudorandomFunction = append(proposal.PseudorandomFunction, transform) - case TypeIntegrityAlgorithm: - proposal.IntegrityAlgorithm = append(proposal.IntegrityAlgorithm, transform) - case TypeDiffieHellmanGroup: - proposal.DiffieHellmanGroup = append(proposal.DiffieHellmanGroup, transform) - case TypeExtendedSequenceNumbers: - proposal.ExtendedSequenceNumbers = append(proposal.ExtendedSequenceNumbers, transform) - } - - transformData = transformData[transformLength:] - } - - securityAssociation.Proposals = append(securityAssociation.Proposals, proposal) - - rawData = rawData[proposalLength:] - } - - return nil -} - -// Definition of Key Exchange - -var _ IKEPayload = &KeyExchange{} - -type KeyExchange struct { - DiffieHellmanGroup uint16 - KeyExchangeData []byte -} - -func (keyExchange *KeyExchange) Type() IKEPayloadType { return TypeKE } - -func (keyExchange *KeyExchange) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[KeyExchange] marshal(): Start marshalling") - - keyExchangeData := make([]byte, 4) - - binary.BigEndian.PutUint16(keyExchangeData[0:2], keyExchange.DiffieHellmanGroup) - keyExchangeData = append(keyExchangeData, keyExchange.KeyExchangeData...) - - return keyExchangeData, nil -} - -func (keyExchange *KeyExchange) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[KeyExchange] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[KeyExchange] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[KeyExchange] unmarshal(): Unmarshal 1 key exchange data") - // bounds checking - if len(rawData) <= 4 { - return errors.New("KeyExchange: No sufficient bytes to decode next key exchange data") - } - - keyExchange.DiffieHellmanGroup = binary.BigEndian.Uint16(rawData[0:2]) - keyExchange.KeyExchangeData = append(keyExchange.KeyExchangeData, rawData[4:]...) - } - - return nil -} - -// Definition of Identification - Initiator - -var _ IKEPayload = &IdentificationInitiator{} - -type IdentificationInitiator struct { - IDType uint8 - IDData []byte -} - -func (identification *IdentificationInitiator) Type() IKEPayloadType { return TypeIDi } - -func (identification *IdentificationInitiator) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Identification] marshal(): Start marshalling") - - identificationData := make([]byte, 4) - - identificationData[0] = identification.IDType - identificationData = append(identificationData, identification.IDData...) - - return identificationData, nil -} - -func (identification *IdentificationInitiator) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Identification] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Identification] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Identification] unmarshal(): Unmarshal 1 identification") - // bounds checking - if len(rawData) <= 4 { - return errors.New("Identification: No sufficient bytes to decode next identification") - } - - identification.IDType = rawData[0] - identification.IDData = append(identification.IDData, rawData[4:]...) - } - - return nil -} - -// Definition of Identification - Responder - -var _ IKEPayload = &IdentificationResponder{} - -type IdentificationResponder struct { - IDType uint8 - IDData []byte -} - -func (identification *IdentificationResponder) Type() IKEPayloadType { return TypeIDr } - -func (identification *IdentificationResponder) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Identification] marshal(): Start marshalling") - - identificationData := make([]byte, 4) - - identificationData[0] = identification.IDType - identificationData = append(identificationData, identification.IDData...) - - return identificationData, nil -} - -func (identification *IdentificationResponder) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Identification] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Identification] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Identification] unmarshal(): Unmarshal 1 identification") - // bounds checking - if len(rawData) <= 4 { - return errors.New("Identification: No sufficient bytes to decode next identification") - } - - identification.IDType = rawData[0] - identification.IDData = append(identification.IDData, rawData[4:]...) - } - - return nil -} - -// Definition of Certificate - -var _ IKEPayload = &Certificate{} - -type Certificate struct { - CertificateEncoding uint8 - CertificateData []byte -} - -func (certificate *Certificate) Type() IKEPayloadType { return TypeCERT } - -func (certificate *Certificate) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Certificate] marshal(): Start marshalling") - - certificateData := make([]byte, 1) - - certificateData[0] = certificate.CertificateEncoding - certificateData = append(certificateData, certificate.CertificateData...) - - return certificateData, nil -} - -func (certificate *Certificate) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Certificate] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Certificate] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Certificate] unmarshal(): Unmarshal 1 certificate") - // bounds checking - if len(rawData) <= 1 { - return errors.New("Certificate: No sufficient bytes to decode next certificate") - } - - certificate.CertificateEncoding = rawData[0] - certificate.CertificateData = append(certificate.CertificateData, rawData[1:]...) - } - - return nil -} - -// Definition of Certificate Request - -var _ IKEPayload = &CertificateRequest{} - -type CertificateRequest struct { - CertificateEncoding uint8 - CertificationAuthority []byte -} - -func (certificateRequest *CertificateRequest) Type() IKEPayloadType { return TypeCERTreq } - -func (certificateRequest *CertificateRequest) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[CertificateRequest] marshal(): Start marshalling") - - certificateRequestData := make([]byte, 1) - - certificateRequestData[0] = certificateRequest.CertificateEncoding - certificateRequestData = append(certificateRequestData, certificateRequest.CertificationAuthority...) - - return certificateRequestData, nil -} - -func (certificateRequest *CertificateRequest) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[CertificateRequest] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[CertificateRequest] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[CertificateRequest] unmarshal(): Unmarshal 1 certificate request") - // bounds checking - if len(rawData) <= 1 { - return errors.New("CertificateRequest: No sufficient bytes to decode next certificate request") - } - - certificateRequest.CertificateEncoding = rawData[0] - certificateRequest.CertificationAuthority = append(certificateRequest.CertificationAuthority, rawData[1:]...) - } - - return nil -} - -// Definition of Authentication - -var _ IKEPayload = &Authentication{} - -type Authentication struct { - AuthenticationMethod uint8 - AuthenticationData []byte -} - -func (authentication *Authentication) Type() IKEPayloadType { return TypeAUTH } - -func (authentication *Authentication) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Authentication] marshal(): Start marshalling") - - authenticationData := make([]byte, 4) - - authenticationData[0] = authentication.AuthenticationMethod - authenticationData = append(authenticationData, authentication.AuthenticationData...) - - return authenticationData, nil -} - -func (authentication *Authentication) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Authentication] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Authentication] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Authentication] unmarshal(): Unmarshal 1 authentication") - // bounds checking - if len(rawData) <= 4 { - return errors.New("Authentication: No sufficient bytes to decode next authentication") - } - - authentication.AuthenticationMethod = rawData[0] - authentication.AuthenticationData = append(authentication.AuthenticationData, rawData[4:]...) - } - - return nil -} - -// Definition of Nonce - -var _ IKEPayload = &Nonce{} - -type Nonce struct { - NonceData []byte -} - -func (nonce *Nonce) Type() IKEPayloadType { return TypeNiNr } - -func (nonce *Nonce) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Nonce] marshal(): Start marshalling") - - nonceData := make([]byte, 0) - nonceData = append(nonceData, nonce.NonceData...) - - return nonceData, nil -} - -func (nonce *Nonce) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Nonce] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Nonce] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Nonce] unmarshal(): Unmarshal 1 nonce") - nonce.NonceData = append(nonce.NonceData, rawData...) - } - - return nil -} - -// Definition of Notification - -var _ IKEPayload = &Notification{} - -type Notification struct { - ProtocolID uint8 - NotifyMessageType uint16 - SPI []byte - NotificationData []byte -} - -func (notification *Notification) Type() IKEPayloadType { return TypeN } - -func (notification *Notification) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Notification] marshal(): Start marshalling") - - notificationData := make([]byte, 4) - - notificationData[0] = notification.ProtocolID - notificationData[1] = uint8(len(notification.SPI)) - binary.BigEndian.PutUint16(notificationData[2:4], notification.NotifyMessageType) - - notificationData = append(notificationData, notification.SPI...) - notificationData = append(notificationData, notification.NotificationData...) - - return notificationData, nil -} - -func (notification *Notification) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Notification] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Notification] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Notification] unmarshal(): Unmarshal 1 notification") - // bounds checking - if len(rawData) < 4 { - return errors.New("Notification: No sufficient bytes to decode next notification") - } - spiSize := rawData[1] - if len(rawData) < int(4+spiSize) { - return errors.New("Notification: No sufficient bytes to get SPI according to the length specified in header") - } - - notification.ProtocolID = rawData[0] - notification.NotifyMessageType = binary.BigEndian.Uint16(rawData[2:4]) - - notification.SPI = append(notification.SPI, rawData[4:4+spiSize]...) - notification.NotificationData = append(notification.NotificationData, rawData[4+spiSize:]...) - } - - return nil -} - -// Definition of Delete - -var _ IKEPayload = &Delete{} - -type Delete struct { - ProtocolID uint8 - SPISize uint8 - NumberOfSPI uint16 - SPIs []byte -} - -func (del *Delete) Type() IKEPayloadType { return TypeD } - -func (del *Delete) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Delete] marshal(): Start marshalling") - - if len(del.SPIs) != (int(del.SPISize) * int(del.NumberOfSPI)) { - return nil, fmt.Errorf("Total bytes of all SPIs not correct") - } - - deleteData := make([]byte, 4) - - deleteData[0] = del.ProtocolID - deleteData[1] = del.SPISize - binary.BigEndian.PutUint16(deleteData[2:4], del.NumberOfSPI) - - if int(del.NumberOfSPI) > 0 { - deleteData = append(deleteData, del.SPIs...) - } - - return deleteData, nil -} - -func (del *Delete) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Delete] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Delete] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Delete] unmarshal(): Unmarshal 1 delete") - // bounds checking - if len(rawData) <= 3 { - return errors.New("Delete: No sufficient bytes to decode next delete") - } - spiSize := rawData[1] - numberOfSPI := binary.BigEndian.Uint16(rawData[2:4]) - if len(rawData) < (4 + (int(spiSize) * int(numberOfSPI))) { - return errors.New("Delete: No Sufficient bytes to get SPIs according to the length specified in header") - } - - del.ProtocolID = rawData[0] - del.SPISize = spiSize - del.NumberOfSPI = numberOfSPI - - del.SPIs = append(del.SPIs, rawData[4:]...) - } - - return nil -} - -// Definition of Vendor ID - -var _ IKEPayload = &VendorID{} - -type VendorID struct { - VendorIDData []byte -} - -func (vendorID *VendorID) Type() IKEPayloadType { return TypeV } - -func (vendorID *VendorID) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[VendorID] marshal(): Start marshalling") - return vendorID.VendorIDData, nil -} - -func (vendorID *VendorID) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[VendorID] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[VendorID] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[VendorID] unmarshal(): Unmarshal 1 vendor ID") - vendorID.VendorIDData = append(vendorID.VendorIDData, rawData...) - } - - return nil -} - -// Definition of Traffic Selector - Initiator - -var _ IKEPayload = &TrafficSelectorInitiator{} - -type TrafficSelectorInitiator struct { - TrafficSelectors IndividualTrafficSelectorContainer -} - -type IndividualTrafficSelectorContainer []*IndividualTrafficSelector - -type IndividualTrafficSelector struct { - TSType uint8 - IPProtocolID uint8 - StartPort uint16 - EndPort uint16 - StartAddress []byte - EndAddress []byte -} - -func (trafficSelector *TrafficSelectorInitiator) Type() IKEPayloadType { return TypeTSi } - -func (trafficSelector *TrafficSelectorInitiator) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[TrafficSelector] marshal(): Start marshalling") - - if len(trafficSelector.TrafficSelectors) > 0 { - trafficSelectorData := make([]byte, 4) - trafficSelectorData[0] = uint8(len(trafficSelector.TrafficSelectors)) - - for _, individualTrafficSelector := range trafficSelector.TrafficSelectors { - if individualTrafficSelector.TSType == TS_IPV4_ADDR_RANGE { - // Address length checking - if len(individualTrafficSelector.StartAddress) != 4 { - ikeLog.Errorf("Address length %d", len(individualTrafficSelector.StartAddress)) - return nil, errors.New("TrafficSelector: Start IPv4 address length is not correct") - } - if len(individualTrafficSelector.EndAddress) != 4 { - return nil, errors.New("TrafficSelector: End IPv4 address length is not correct") - } - - individualTrafficSelectorData := make([]byte, 8) - - individualTrafficSelectorData[0] = individualTrafficSelector.TSType - individualTrafficSelectorData[1] = individualTrafficSelector.IPProtocolID - binary.BigEndian.PutUint16(individualTrafficSelectorData[4:6], individualTrafficSelector.StartPort) - binary.BigEndian.PutUint16(individualTrafficSelectorData[6:8], individualTrafficSelector.EndPort) - - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.StartAddress...) - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.EndAddress...) - - binary.BigEndian.PutUint16(individualTrafficSelectorData[2:4], uint16(len(individualTrafficSelectorData))) - - trafficSelectorData = append(trafficSelectorData, individualTrafficSelectorData...) - } else if individualTrafficSelector.TSType == TS_IPV6_ADDR_RANGE { - // Address length checking - if len(individualTrafficSelector.StartAddress) != 16 { - return nil, errors.New("TrafficSelector: Start IPv6 address length is not correct") - } - if len(individualTrafficSelector.EndAddress) != 16 { - return nil, errors.New("TrafficSelector: End IPv6 address length is not correct") - } - - individualTrafficSelectorData := make([]byte, 8) - - individualTrafficSelectorData[0] = individualTrafficSelector.TSType - individualTrafficSelectorData[1] = individualTrafficSelector.IPProtocolID - binary.BigEndian.PutUint16(individualTrafficSelectorData[4:6], individualTrafficSelector.StartPort) - binary.BigEndian.PutUint16(individualTrafficSelectorData[6:8], individualTrafficSelector.EndPort) - - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.StartAddress...) - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.EndAddress...) - - binary.BigEndian.PutUint16(individualTrafficSelectorData[2:4], uint16(len(individualTrafficSelectorData))) - - trafficSelectorData = append(trafficSelectorData, individualTrafficSelectorData...) - } else { - return nil, errors.New("TrafficSelector: Unsupported traffic selector type") - } - } - - return trafficSelectorData, nil - } else { - return nil, errors.New("TrafficSelector: Contains no traffic selector for marshalling message") - } -} - -func (trafficSelector *TrafficSelectorInitiator) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[TrafficSelector] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[TrafficSelector] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[TrafficSelector] unmarshal(): Unmarshal 1 traffic selector") - // bounds checking - if len(rawData) < 4 { - return errors.New("TrafficSelector: No sufficient bytes to get number of traffic selector in header") - } - - numberOfSPI := rawData[0] - - rawData = rawData[4:] - - for ; numberOfSPI > 0; numberOfSPI-- { - // bounds checking - if len(rawData) < 4 { - return errors.New( - "TrafficSelector: No sufficient bytes to decode next individual traffic selector length in header") - } - trafficSelectorType := rawData[0] - if trafficSelectorType == TS_IPV4_ADDR_RANGE { - selectorLength := binary.BigEndian.Uint16(rawData[2:4]) - if selectorLength != 16 { - return errors.New("TrafficSelector: A TS_IPV4_ADDR_RANGE type traffic selector should has length 16 bytes") - } - if len(rawData) < int(selectorLength) { - return errors.New("TrafficSelector: No sufficient bytes to decode next individual traffic selector") - } - - individualTrafficSelector := &IndividualTrafficSelector{} - - individualTrafficSelector.TSType = rawData[0] - individualTrafficSelector.IPProtocolID = rawData[1] - individualTrafficSelector.StartPort = binary.BigEndian.Uint16(rawData[4:6]) - individualTrafficSelector.EndPort = binary.BigEndian.Uint16(rawData[6:8]) - - individualTrafficSelector.StartAddress = append(individualTrafficSelector.StartAddress, rawData[8:12]...) - individualTrafficSelector.EndAddress = append(individualTrafficSelector.EndAddress, rawData[12:16]...) - - trafficSelector.TrafficSelectors = append(trafficSelector.TrafficSelectors, individualTrafficSelector) - - rawData = rawData[16:] - } else if trafficSelectorType == TS_IPV6_ADDR_RANGE { - selectorLength := binary.BigEndian.Uint16(rawData[2:4]) - if selectorLength != 40 { - return errors.New("TrafficSelector: A TS_IPV6_ADDR_RANGE type traffic selector should has length 40 bytes") - } - if len(rawData) < int(selectorLength) { - return errors.New("TrafficSelector: No sufficient bytes to decode next individual traffic selector") - } - - individualTrafficSelector := &IndividualTrafficSelector{} - - individualTrafficSelector.TSType = rawData[0] - individualTrafficSelector.IPProtocolID = rawData[1] - individualTrafficSelector.StartPort = binary.BigEndian.Uint16(rawData[4:6]) - individualTrafficSelector.EndPort = binary.BigEndian.Uint16(rawData[6:8]) - - individualTrafficSelector.StartAddress = append(individualTrafficSelector.StartAddress, rawData[8:24]...) - individualTrafficSelector.EndAddress = append(individualTrafficSelector.EndAddress, rawData[24:40]...) - - trafficSelector.TrafficSelectors = append(trafficSelector.TrafficSelectors, individualTrafficSelector) - - rawData = rawData[40:] - } else { - return errors.New("TrafficSelector: Unsupported traffic selector type") - } - } - } - - return nil -} - -// Definition of Traffic Selector - Responder - -var _ IKEPayload = &TrafficSelectorResponder{} - -type TrafficSelectorResponder struct { - TrafficSelectors IndividualTrafficSelectorContainer -} - -func (trafficSelector *TrafficSelectorResponder) Type() IKEPayloadType { return TypeTSr } - -func (trafficSelector *TrafficSelectorResponder) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[TrafficSelector] marshal(): Start marshalling") - - if len(trafficSelector.TrafficSelectors) > 0 { - trafficSelectorData := make([]byte, 4) - trafficSelectorData[0] = uint8(len(trafficSelector.TrafficSelectors)) - - for _, individualTrafficSelector := range trafficSelector.TrafficSelectors { - if individualTrafficSelector.TSType == TS_IPV4_ADDR_RANGE { - // Address length checking - if len(individualTrafficSelector.StartAddress) != 4 { - return nil, errors.New("TrafficSelector: Start IPv4 address length is not correct") - } - if len(individualTrafficSelector.EndAddress) != 4 { - return nil, errors.New("TrafficSelector: End IPv4 address length is not correct") - } - - individualTrafficSelectorData := make([]byte, 8) - - individualTrafficSelectorData[0] = individualTrafficSelector.TSType - individualTrafficSelectorData[1] = individualTrafficSelector.IPProtocolID - binary.BigEndian.PutUint16(individualTrafficSelectorData[4:6], individualTrafficSelector.StartPort) - binary.BigEndian.PutUint16(individualTrafficSelectorData[6:8], individualTrafficSelector.EndPort) - - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.StartAddress...) - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.EndAddress...) - - binary.BigEndian.PutUint16(individualTrafficSelectorData[2:4], uint16(len(individualTrafficSelectorData))) - - trafficSelectorData = append(trafficSelectorData, individualTrafficSelectorData...) - } else if individualTrafficSelector.TSType == TS_IPV6_ADDR_RANGE { - // Address length checking - if len(individualTrafficSelector.StartAddress) != 16 { - return nil, errors.New("TrafficSelector: Start IPv6 address length is not correct") - } - if len(individualTrafficSelector.EndAddress) != 16 { - return nil, errors.New("TrafficSelector: End IPv6 address length is not correct") - } - - individualTrafficSelectorData := make([]byte, 8) - - individualTrafficSelectorData[0] = individualTrafficSelector.TSType - individualTrafficSelectorData[1] = individualTrafficSelector.IPProtocolID - binary.BigEndian.PutUint16(individualTrafficSelectorData[4:6], individualTrafficSelector.StartPort) - binary.BigEndian.PutUint16(individualTrafficSelectorData[6:8], individualTrafficSelector.EndPort) - - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.StartAddress...) - individualTrafficSelectorData = append(individualTrafficSelectorData, individualTrafficSelector.EndAddress...) - - binary.BigEndian.PutUint16(individualTrafficSelectorData[2:4], uint16(len(individualTrafficSelectorData))) - - trafficSelectorData = append(trafficSelectorData, individualTrafficSelectorData...) - } else { - return nil, errors.New("TrafficSelector: Unsupported traffic selector type") - } - } - - return trafficSelectorData, nil - } else { - return nil, errors.New("TrafficSelector: Contains no traffic selector for marshalling message") - } -} - -func (trafficSelector *TrafficSelectorResponder) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[TrafficSelector] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[TrafficSelector] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[TrafficSelector] unmarshal(): Unmarshal 1 traffic selector") - // bounds checking - if len(rawData) < 4 { - return errors.New("TrafficSelector: No sufficient bytes to get number of traffic selector in header") - } - - numberOfSPI := rawData[0] - - rawData = rawData[4:] - - for ; numberOfSPI > 0; numberOfSPI-- { - // bounds checking - if len(rawData) < 4 { - return errors.New( - "TrafficSelector: No sufficient bytes to decode next individual traffic selector length in header") - } - trafficSelectorType := rawData[0] - if trafficSelectorType == TS_IPV4_ADDR_RANGE { - selectorLength := binary.BigEndian.Uint16(rawData[2:4]) - if selectorLength != 16 { - return errors.New("TrafficSelector: A TS_IPV4_ADDR_RANGE type traffic selector should has length 16 bytes") - } - if len(rawData) < int(selectorLength) { - return errors.New("TrafficSelector: No sufficient bytes to decode next individual traffic selector") - } - - individualTrafficSelector := &IndividualTrafficSelector{} - - individualTrafficSelector.TSType = rawData[0] - individualTrafficSelector.IPProtocolID = rawData[1] - individualTrafficSelector.StartPort = binary.BigEndian.Uint16(rawData[4:6]) - individualTrafficSelector.EndPort = binary.BigEndian.Uint16(rawData[6:8]) - - individualTrafficSelector.StartAddress = append(individualTrafficSelector.StartAddress, rawData[8:12]...) - individualTrafficSelector.EndAddress = append(individualTrafficSelector.EndAddress, rawData[12:16]...) - - trafficSelector.TrafficSelectors = append(trafficSelector.TrafficSelectors, individualTrafficSelector) - - rawData = rawData[16:] - } else if trafficSelectorType == TS_IPV6_ADDR_RANGE { - selectorLength := binary.BigEndian.Uint16(rawData[2:4]) - if selectorLength != 40 { - return errors.New("TrafficSelector: A TS_IPV6_ADDR_RANGE type traffic selector should has length 40 bytes") - } - if len(rawData) < int(selectorLength) { - return errors.New("TrafficSelector: No sufficient bytes to decode next individual traffic selector") - } - - individualTrafficSelector := &IndividualTrafficSelector{} - - individualTrafficSelector.TSType = rawData[0] - individualTrafficSelector.IPProtocolID = rawData[1] - individualTrafficSelector.StartPort = binary.BigEndian.Uint16(rawData[4:6]) - individualTrafficSelector.EndPort = binary.BigEndian.Uint16(rawData[6:8]) - - individualTrafficSelector.StartAddress = append(individualTrafficSelector.StartAddress, rawData[8:24]...) - individualTrafficSelector.EndAddress = append(individualTrafficSelector.EndAddress, rawData[24:40]...) - - trafficSelector.TrafficSelectors = append(trafficSelector.TrafficSelectors, individualTrafficSelector) - - rawData = rawData[40:] - } else { - return errors.New("TrafficSelector: Unsupported traffic selector type") - } - } - } - - return nil -} - -// Definition of Encrypted Payload - -var _ IKEPayload = &Encrypted{} - -type Encrypted struct { - NextPayload uint8 - EncryptedData []byte -} - -func (encrypted *Encrypted) Type() IKEPayloadType { return TypeSK } - -func (encrypted *Encrypted) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Encrypted] marshal(): Start marshalling") - - if len(encrypted.EncryptedData) == 0 { - ikeLog.Warn("[Encrypted] The encrypted data is empty") - } - - return encrypted.EncryptedData, nil -} - -func (encrypted *Encrypted) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Encrypted] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Encrypted] unmarshal(): Payload length %d bytes", len(rawData)) - encrypted.EncryptedData = append(encrypted.EncryptedData, rawData...) - return nil -} - -// Definition of Configuration - -var _ IKEPayload = &Configuration{} - -type Configuration struct { - ConfigurationType uint8 - ConfigurationAttribute ConfigurationAttributeContainer -} - -type ConfigurationAttributeContainer []*IndividualConfigurationAttribute - -type IndividualConfigurationAttribute struct { - Type uint16 - Value []byte -} - -func (configuration *Configuration) Type() IKEPayloadType { return TypeCP } - -func (configuration *Configuration) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[Configuration] marshal(): Start marshalling") - - configurationData := make([]byte, 4) - - configurationData[0] = configuration.ConfigurationType - - for _, attribute := range configuration.ConfigurationAttribute { - individualConfigurationAttributeData := make([]byte, 4) - - binary.BigEndian.PutUint16(individualConfigurationAttributeData[0:2], (attribute.Type & 0x7fff)) - binary.BigEndian.PutUint16(individualConfigurationAttributeData[2:4], uint16(len(attribute.Value))) - - individualConfigurationAttributeData = append(individualConfigurationAttributeData, attribute.Value...) - - configurationData = append(configurationData, individualConfigurationAttributeData...) - } - - return configurationData, nil -} - -func (configuration *Configuration) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[Configuration] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[Configuration] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[Configuration] unmarshal(): Unmarshal 1 configuration") - // bounds checking - if len(rawData) <= 4 { - return errors.New("Configuration: No sufficient bytes to decode next configuration") - } - configuration.ConfigurationType = rawData[0] - - configurationAttributeData := rawData[4:] - - for len(configurationAttributeData) > 0 { - ikeLog.Trace("[Configuration] unmarshal(): Unmarshal 1 configuration attribute") - // bounds checking - if len(configurationAttributeData) < 4 { - return errors.New("ConfigurationAttribute: No sufficient bytes to decode next configuration attribute") - } - length := binary.BigEndian.Uint16(configurationAttributeData[2:4]) - if len(configurationAttributeData) < int(4+length) { - return errors.New("ConfigurationAttribute: TLV attribute length error") - } - - individualConfigurationAttribute := new(IndividualConfigurationAttribute) - - individualConfigurationAttribute.Type = binary.BigEndian.Uint16(configurationAttributeData[0:2]) - configurationAttributeData = configurationAttributeData[4:] - individualConfigurationAttribute.Value = append(individualConfigurationAttribute.Value, - configurationAttributeData[:length]...) - configurationAttributeData = configurationAttributeData[length:] - - configuration.ConfigurationAttribute = append(configuration.ConfigurationAttribute, individualConfigurationAttribute) - } - } - - return nil -} - -// Definition of IKE EAP - -var _ IKEPayload = &EAP{} - -type EAP struct { - Code uint8 - Identifier uint8 - EAPTypeData EAPTypeDataContainer -} - -func (eap *EAP) Type() IKEPayloadType { return TypeEAP } - -func (eap *EAP) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[EAP] marshal(): Start marshalling") - - eapData := make([]byte, 4) - - eapData[0] = eap.Code - eapData[1] = eap.Identifier - - if len(eap.EAPTypeData) > 0 { - eapTypeData, err := eap.EAPTypeData[0].marshal() - if err != nil { - return nil, fmt.Errorf("EAP: EAP type data marshal failed: %+v", err) - } - - eapData = append(eapData, eapTypeData...) - } - - binary.BigEndian.PutUint16(eapData[2:4], uint16(len(eapData))) - - return eapData, nil -} - -func (eap *EAP) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[EAP] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[EAP] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - ikeLog.Trace("[EAP] unmarshal(): Unmarshal 1 EAP") - // bounds checking - if len(rawData) < 4 { - return errors.New("EAP: No sufficient bytes to decode next EAP payload") - } - eapPayloadLength := binary.BigEndian.Uint16(rawData[2:4]) - if eapPayloadLength < 4 { - return errors.New("EAP: Payload length specified in the header is too small for EAP") - } - if len(rawData) != int(eapPayloadLength) { - return errors.New("EAP: Received payload length not matches the length specified in header") - } - - eap.Code = rawData[0] - eap.Identifier = rawData[1] - - // EAP Success or Failed - if eapPayloadLength == 4 { - return nil - } - - eapType := rawData[4] - var eapTypeData EAPTypeFormat - - switch eapType { - case EAPTypeIdentity: - eapTypeData = new(EAPIdentity) - case EAPTypeNotification: - eapTypeData = new(EAPNotification) - case EAPTypeNak: - eapTypeData = new(EAPNak) - case EAPTypeExpanded: - eapTypeData = new(EAPExpanded) - default: - // TODO: Create unsupprted type to handle it - return errors.New("EAP: Not supported EAP type") - } - - if err := eapTypeData.unmarshal(rawData[4:]); err != nil { - return fmt.Errorf("EAP: Unamrshal EAP type data failed: %+v", err) - } - - eap.EAPTypeData = append(eap.EAPTypeData, eapTypeData) - } - - return nil -} - -type EAPTypeDataContainer []EAPTypeFormat - -type EAPTypeFormat interface { - // Type specifies EAP types - Type() EAPType - - // Called by EAP.marshal() or EAP.unmarshal() - marshal() ([]byte, error) - unmarshal(rawData []byte) error -} - -// Definition of EAP Identity - -var _ EAPTypeFormat = &EAPIdentity{} - -type EAPIdentity struct { - IdentityData []byte -} - -func (eapIdentity *EAPIdentity) Type() EAPType { return EAPTypeIdentity } - -func (eapIdentity *EAPIdentity) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Identity] marshal(): Start marshalling") - - if len(eapIdentity.IdentityData) == 0 { - return nil, errors.New("EAPIdentity: EAP identity is empty") - } - - eapIdentityData := []byte{EAPTypeIdentity} - eapIdentityData = append(eapIdentityData, eapIdentity.IdentityData...) - - return eapIdentityData, nil -} - -func (eapIdentity *EAPIdentity) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Identity] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[EAP][Identity] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 1 { - eapIdentity.IdentityData = append(eapIdentity.IdentityData, rawData[1:]...) - } - - return nil -} - -// Definition of EAP Notification - -var _ EAPTypeFormat = &EAPNotification{} - -type EAPNotification struct { - NotificationData []byte -} - -func (eapNotification *EAPNotification) Type() EAPType { return EAPTypeNotification } - -func (eapNotification *EAPNotification) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Notification] marshal(): Start marshalling") - - if len(eapNotification.NotificationData) == 0 { - return nil, errors.New("EAPNotification: EAP notification is empty") - } - - eapNotificationData := []byte{EAPTypeNotification} - eapNotificationData = append(eapNotificationData, eapNotification.NotificationData...) - - return eapNotificationData, nil -} - -func (eapNotification *EAPNotification) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Notification] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[EAP][Notification] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 1 { - eapNotification.NotificationData = append(eapNotification.NotificationData, rawData[1:]...) - } - - return nil -} - -// Definition of EAP Nak - -var _ EAPTypeFormat = &EAPNak{} - -type EAPNak struct { - NakData []byte -} - -func (eapNak *EAPNak) Type() EAPType { return EAPTypeNak } - -func (eapNak *EAPNak) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Nak] marshal(): Start marshalling") - - if len(eapNak.NakData) == 0 { - return nil, errors.New("EAPNak: EAP nak is empty") - } - - eapNakData := []byte{EAPTypeNak} - eapNakData = append(eapNakData, eapNak.NakData...) - - return eapNakData, nil -} - -func (eapNak *EAPNak) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Nak] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[EAP][Nak] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 1 { - eapNak.NakData = append(eapNak.NakData, rawData[1:]...) - } - - return nil -} - -// Definition of EAP expanded - -var _ EAPTypeFormat = &EAPExpanded{} - -type EAPExpanded struct { - VendorID uint32 - VendorType uint32 - VendorData []byte -} - -func (eapExpanded *EAPExpanded) Type() EAPType { return EAPTypeExpanded } - -func (eapExpanded *EAPExpanded) marshal() ([]byte, error) { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Expanded] marshal(): Start marshalling") - - eapExpandedData := make([]byte, 8) - - vendorID := eapExpanded.VendorID & 0x00ffffff - typeAndVendorID := (uint32(EAPTypeExpanded)<<24 | vendorID) - - binary.BigEndian.PutUint32(eapExpandedData[0:4], typeAndVendorID) - binary.BigEndian.PutUint32(eapExpandedData[4:8], eapExpanded.VendorType) - - if len(eapExpanded.VendorData) == 0 { - ikeLog.Warn("[EAP][Expanded] marshal(): EAP vendor data field is empty") - return eapExpandedData, nil - } - - eapExpandedData = append(eapExpandedData, eapExpanded.VendorData...) - - return eapExpandedData, nil -} - -func (eapExpanded *EAPExpanded) unmarshal(rawData []byte) error { - ikeLog := logger.IKELog - ikeLog.Info("[EAP][Expanded] unmarshal(): Start unmarshalling received bytes") - ikeLog.Tracef("[EAP][Expanded] unmarshal(): Payload length %d bytes", len(rawData)) - - if len(rawData) > 0 { - if len(rawData) < 8 { - return errors.New("EAPExpanded: No sufficient bytes to decode the EAP expanded type") - } - - typeAndVendorID := binary.BigEndian.Uint32(rawData[0:4]) - eapExpanded.VendorID = typeAndVendorID & 0x00ffffff - - eapExpanded.VendorType = binary.BigEndian.Uint32(rawData[4:8]) - - if len(rawData) > 8 { - eapExpanded.VendorData = append(eapExpanded.VendorData, rawData[8:]...) - } - } - - return nil -} diff --git a/pkg/ike/message/message_test.go b/pkg/ike/message/message_test.go deleted file mode 100644 index 614ce1b2..00000000 --- a/pkg/ike/message/message_test.go +++ /dev/null @@ -1,435 +0,0 @@ -package message - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - Crand "crypto/rand" - "encoding/binary" - "encoding/hex" - "io" - Mrand "math/rand" - "net" - "testing" -) - -// TestEncodeDecode tests the Encode() and Decode() function using the data -// build manually. -// First, build each payload with correct value, then the IKE message for -// IKE_SA_INIT type. -// Second, encode/decode the IKE message using Encode/Decode function, and then -// re-encode the decoded message again. -// Third, send the encoded data to the UDP connection for verification with Wireshark. -// Compare the dataFirstEncode and dataSecondEncode and return the result. -func TestEncodeDecode(t *testing.T) { - conn, err := net.Dial("udp", "127.0.0.1:500") - if err != nil { - t.Fatalf("udp Dial failed: %+v", err) - } - testPacket := &IKEMessage{} - - // random an SPI - src := Mrand.NewSource(63579) - localRand := Mrand.New(src) - ispi := localRand.Uint64() - - testPacket.InitiatorSPI = ispi - testPacket.Version = 0x20 - testPacket.ExchangeType = 34 // IKE_SA_INIT - testPacket.Flags = 16 // flagI is set - testPacket.MessageID = 0 // for IKE_SA_INIT - - testSA := &SecurityAssociation{} - - testProposal1 := &Proposal{} - testProposal1.ProposalNumber = 1 // first - testProposal1.ProtocolID = 1 // IKE - - testtransform1 := &Transform{} - testtransform1.TransformType = 1 // ENCR - testtransform1.TransformID = 12 // ENCR_AES_CBC - testtransform1.AttributePresent = true - testtransform1.AttributeFormat = 1 - testtransform1.AttributeType = 14 - testtransform1.AttributeValue = 128 - - testProposal1.EncryptionAlgorithm = append(testProposal1.EncryptionAlgorithm, testtransform1) - - testtransform2 := &Transform{} - testtransform2.TransformType = 1 // ENCR - testtransform2.TransformID = 12 // ENCR_AES_CBC - testtransform2.AttributePresent = true - testtransform2.AttributeFormat = 1 - testtransform2.AttributeType = 14 - testtransform2.AttributeValue = 192 - - testProposal1.EncryptionAlgorithm = append(testProposal1.EncryptionAlgorithm, testtransform2) - - testtransform3 := &Transform{} - testtransform3.TransformType = 3 // INTEG - testtransform3.TransformID = 5 // AUTH_AES_XCBC_96 - testtransform3.AttributePresent = false - - testProposal1.IntegrityAlgorithm = append(testProposal1.IntegrityAlgorithm, testtransform3) - - testtransform4 := &Transform{} - testtransform4.TransformType = 3 // INTEG - testtransform4.TransformID = 2 // AUTH_HMAC_SHA1_96 - testtransform4.AttributePresent = false - - testProposal1.IntegrityAlgorithm = append(testProposal1.IntegrityAlgorithm, testtransform4) - - testSA.Proposals = append(testSA.Proposals, testProposal1) - - testProposal2 := &Proposal{} - testProposal2.ProposalNumber = 2 // second - testProposal2.ProtocolID = 1 // IKE - - testtransform1 = &Transform{} - testtransform1.TransformType = 1 // ENCR - testtransform1.TransformID = 12 // ENCR_AES_CBC - testtransform1.AttributePresent = true - testtransform1.AttributeFormat = 1 - testtransform1.AttributeType = 14 - testtransform1.AttributeValue = 128 - - testProposal2.EncryptionAlgorithm = append(testProposal2.EncryptionAlgorithm, testtransform1) - - testtransform2 = &Transform{} - testtransform2.TransformType = 1 // ENCR - testtransform2.TransformID = 12 // ENCR_AES_CBC - testtransform2.AttributePresent = true - testtransform2.AttributeFormat = 1 - testtransform2.AttributeType = 14 - testtransform2.AttributeValue = 192 - - testProposal2.EncryptionAlgorithm = append(testProposal2.EncryptionAlgorithm, testtransform2) - - testtransform3 = &Transform{} - testtransform3.TransformType = 3 // INTEG - testtransform3.TransformID = 1 // AUTH_HMAC_MD5_96 - testtransform3.AttributePresent = false - - testProposal2.IntegrityAlgorithm = append(testProposal2.IntegrityAlgorithm, testtransform3) - - testtransform4 = &Transform{} - testtransform4.TransformType = 3 // INTEG - testtransform4.TransformID = 2 // AUTH_HMAC_SHA1_96 - testtransform4.AttributePresent = false - - testProposal2.IntegrityAlgorithm = append(testProposal2.IntegrityAlgorithm, testtransform4) - - testSA.Proposals = append(testSA.Proposals, testProposal2) - - testPacket.Payloads = append(testPacket.Payloads, testSA) - - testKE := &KeyExchange{} - - testKE.DiffieHellmanGroup = 1 - for i := 0; i < 8; i++ { - partKeyExchangeData := make([]byte, 8) - binary.BigEndian.PutUint64(partKeyExchangeData, 7482105748278537214) - testKE.KeyExchangeData = append(testKE.KeyExchangeData, partKeyExchangeData...) - } - - testPacket.Payloads = append(testPacket.Payloads, testKE) - - testIDr := &IdentificationResponder{} - - testIDr.IDType = 3 - for i := 0; i < 8; i++ { - partIdentification := make([]byte, 8) - binary.BigEndian.PutUint64(partIdentification, 4378215321473912643) - testIDr.IDData = append(testIDr.IDData, partIdentification...) - } - - testPacket.Payloads = append(testPacket.Payloads, testIDr) - - testCert := &Certificate{} - - testCert.CertificateEncoding = 1 - for i := 0; i < 8; i++ { - partCertificate := make([]byte, 8) - binary.BigEndian.PutUint64(partCertificate, 4378217432157543265) - testCert.CertificateData = append(testCert.CertificateData, partCertificate...) - } - - testPacket.Payloads = append(testPacket.Payloads, testCert) - - testCertReq := &CertificateRequest{} - - testCertReq.CertificateEncoding = 1 - for i := 0; i < 8; i++ { - partCertificateRquest := make([]byte, 8) - binary.BigEndian.PutUint64(partCertificateRquest, 7438274381754372584) - testCertReq.CertificationAuthority = append(testCertReq.CertificationAuthority, partCertificateRquest...) - } - - testPacket.Payloads = append(testPacket.Payloads, testCertReq) - - testAuth := &Authentication{} - - testAuth.AuthenticationMethod = 1 - for i := 0; i < 8; i++ { - partAuthentication := make([]byte, 8) - binary.BigEndian.PutUint64(partAuthentication, 4632714362816473824) - testAuth.AuthenticationData = append(testAuth.AuthenticationData, partAuthentication...) - } - - testPacket.Payloads = append(testPacket.Payloads, testAuth) - - testNonce := &Nonce{} - - for i := 0; i < 8; i++ { - partNonce := make([]byte, 8) - binary.BigEndian.PutUint64(partNonce, 8984327463782167381) - testNonce.NonceData = append(testNonce.NonceData, partNonce...) - } - - testPacket.Payloads = append(testPacket.Payloads, testNonce) - - testNotification := &Notification{} - - testNotification.ProtocolID = 1 - testNotification.NotifyMessageType = 2 - - for i := 0; i < 5; i++ { - partSPI := make([]byte, 8) - binary.BigEndian.PutUint64(partSPI, 4372847328749832794) - testNotification.SPI = append(testNotification.SPI, partSPI...) - } - - for i := 0; i < 19; i++ { - partNotification := make([]byte, 8) - binary.BigEndian.PutUint64(partNotification, 9721437148392747354) - testNotification.NotificationData = append(testNotification.NotificationData, partNotification...) - } - - testPacket.Payloads = append(testPacket.Payloads, testNotification) - - testDelete := &Delete{} - - testDelete.ProtocolID = 1 - testDelete.SPISize = 9 - testDelete.NumberOfSPI = 4 - - for i := 0; i < 36; i++ { - testDelete.SPIs = append(testDelete.SPIs, 87) - } - - testPacket.Payloads = append(testPacket.Payloads, testDelete) - - testVendor := &VendorID{} - - for i := 0; i < 5; i++ { - partVendorData := make([]byte, 8) - binary.BigEndian.PutUint64(partVendorData, 5421487329873941748) - testVendor.VendorIDData = append(testVendor.VendorIDData, partVendorData...) - } - - testPacket.Payloads = append(testPacket.Payloads, testVendor) - - testTSi := &TrafficSelectorResponder{} - - testIndividualTS := &IndividualTrafficSelector{} - - testIndividualTS.TSType = 7 - testIndividualTS.IPProtocolID = 6 - testIndividualTS.StartPort = 1989 - testIndividualTS.EndPort = 2020 - - testIndividualTS.StartAddress = []byte{192, 168, 0, 15} - testIndividualTS.EndAddress = []byte{192, 168, 0, 192} - - testTSi.TrafficSelectors = append(testTSi.TrafficSelectors, testIndividualTS) - - testIndividualTS = &IndividualTrafficSelector{} - - testIndividualTS.TSType = 8 - testIndividualTS.IPProtocolID = 6 - testIndividualTS.StartPort = 2010 - testIndividualTS.EndPort = 2050 - - testIndividualTS.StartAddress = net.ParseIP("2001:db8::68") - testIndividualTS.EndAddress = net.ParseIP("2001:db8::72") - - testTSi.TrafficSelectors = append(testTSi.TrafficSelectors, testIndividualTS) - - testPacket.Payloads = append(testPacket.Payloads, testTSi) - - testCP := new(Configuration) - - testCP.ConfigurationType = 1 - - testIndividualConfigurationAttribute := new(IndividualConfigurationAttribute) - - testIndividualConfigurationAttribute.Type = 1 - testIndividualConfigurationAttribute.Value = []byte{10, 1, 14, 1} - - testCP.ConfigurationAttribute = append(testCP.ConfigurationAttribute, testIndividualConfigurationAttribute) - - testPacket.Payloads = append(testPacket.Payloads, testCP) - - testEAP := new(EAP) - - testEAP.Code = 1 - testEAP.Identifier = 123 - - // testEAPExpanded := &EAPExpanded{ - // VendorID: 26838, - // VendorType: 1, - // VendorData: []byte{9, 4, 8, 7}, - // } - - testEAPNotification := new(EAPNotification) - - rawstr := "I'm tired" - testEAPNotification.NotificationData = []byte(rawstr) - - testEAP.EAPTypeData = append(testEAP.EAPTypeData, testEAPNotification) - - testPacket.Payloads = append(testPacket.Payloads, testEAP) - - testSK := new(Encrypted) - - testSK.NextPayload = TypeSA - - ikePayload := IKEPayloadContainer{ - testSA, - testAuth, - } - - ikePayloadDataForSK, retErr := ikePayload.Encode() - if retErr != nil { - t.Fatalf("EncodePayload failed: %+v", retErr) - } - - // aes 128 key - key, retErr := hex.DecodeString("6368616e676520746869732070617373") - if retErr != nil { - t.Fatalf("HexDecoding failed: %+v", retErr) - } - block, retErr := aes.NewCipher(key) - if retErr != nil { - t.Fatalf("AES NewCipher failed: %+v", retErr) - } - - // padding plaintext - padNum := len(ikePayloadDataForSK) % aes.BlockSize - for i := 0; i < (aes.BlockSize - padNum); i++ { - ikePayloadDataForSK = append(ikePayloadDataForSK, byte(padNum)) - } - - // ciphertext - cipherText := make([]byte, aes.BlockSize+len(ikePayloadDataForSK)) - iv := cipherText[:aes.BlockSize] - _, err = io.ReadFull(Crand.Reader, iv) - if err != nil { - t.Fatalf("IO ReadFull failed: %+v", err) - } - - // CBC mode - mode := cipher.NewCBCEncrypter(block, iv) - mode.CryptBlocks(cipherText[aes.BlockSize:], ikePayloadDataForSK) - - testSK.EncryptedData = cipherText - - testPacket.Payloads = append(testPacket.Payloads, testSK) - - var dataFirstEncode, dataSecondEncode []byte - decodedPacket := new(IKEMessage) - - if dataFirstEncode, err = testPacket.Encode(); err != nil { - t.Fatalf("Encode failed: %+v", err) - } - - t.Logf("%+v", dataFirstEncode) - - if err = decodedPacket.Decode(dataFirstEncode); err != nil { - t.Fatalf("Decode failed: %+v", err) - } - - if dataSecondEncode, err = decodedPacket.Encode(); err != nil { - t.Fatalf("Encode failed: %+v", err) - } - - t.Logf("Original IKE Message: %+v", dataFirstEncode) - t.Logf("Result IKE Message: %+v", dataSecondEncode) - - _, err = conn.Write(dataFirstEncode) - if err != nil { - t.Fatalf("Error: %+v", err) - } - - if !bytes.Equal(dataFirstEncode, dataSecondEncode) { - t.FailNow() - } -} - -// TestEncodeDecodeUsingPublicData tests the Encode() and Decode() function -// using the public data. -// Decode and encode the data, and compare the verifyData and the origin -// data and return the result. -func TestEncodeDecodeUsingPublicData(t *testing.T) { - data := []byte{ - 0x86, 0x43, 0x30, 0xac, 0x30, 0xe6, 0x56, 0x4d, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0x20, 0x22, 0x08, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xc9, 0x22, 0x00, 0x00, - 0x30, 0x00, 0x00, 0x00, 0x2c, 0x01, 0x01, 0x00, 0x04, 0x03, 0x00, - 0x00, 0x0c, 0x01, 0x00, 0x00, 0x0c, 0x80, 0x0e, 0x00, 0x80, - 0x03, 0x00, 0x00, 0x08, 0x02, 0x00, 0x00, 0x02, 0x03, 0x00, 0x00, - 0x08, 0x03, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x04, - 0x00, 0x00, 0x02, 0x28, 0x00, 0x00, 0x88, 0x00, 0x02, 0x00, 0x00, - 0x03, 0xdc, 0xf5, 0x9a, 0x29, 0x05, 0x7b, 0x5a, 0x49, 0xbd, - 0x55, 0x8c, 0x9b, 0x14, 0x7a, 0x11, 0x0e, 0xed, 0xff, 0xe5, 0xea, - 0x2d, 0x12, 0xc2, 0x1e, 0x5c, 0x7a, 0x5f, 0x5e, 0x9c, 0x99, - 0xe3, 0xd1, 0xd3, 0x00, 0x24, 0x3c, 0x89, 0x73, 0x1e, 0x6c, 0x6d, - 0x63, 0x41, 0x7b, 0x33, 0xfa, 0xaf, 0x5a, 0xc7, 0x26, 0xe8, - 0xb6, 0xf8, 0xc3, 0xb5, 0x2a, 0x14, 0xeb, 0xec, 0xd5, 0x6f, 0x1b, - 0xd9, 0x5b, 0x28, 0x32, 0x84, 0x9e, 0x26, 0xfc, 0x59, 0xee, - 0xf1, 0x4e, 0x38, 0x5f, 0x55, 0xc2, 0x1b, 0xe8, 0xf6, 0xa3, 0xfb, - 0xc5, 0x55, 0xd7, 0x35, 0x92, 0x86, 0x24, 0x00, 0x62, 0x8b, - 0xea, 0xce, 0x23, 0xf0, 0x47, 0xaf, 0xaa, 0xf8, 0x61, 0xe4, 0x5c, - 0x42, 0xba, 0x5c, 0xa1, 0x4a, 0x52, 0x6e, 0xd8, 0xe8, 0xf1, - 0xb9, 0x74, 0xae, 0xe4, 0xd1, 0x9c, 0x9f, 0xa5, 0x9b, 0xf0, 0xd7, - 0xdb, 0x55, 0x2b, 0x00, 0x00, 0x44, 0x4c, 0xa7, 0xf3, 0x9b, - 0xcd, 0x1d, 0xc2, 0x01, 0x79, 0xfa, 0xa2, 0xe4, 0x72, 0xe0, 0x61, - 0xc4, 0x45, 0x61, 0xe6, 0x49, 0x2d, 0xb3, 0x96, 0xae, 0xc9, - 0x2c, 0xdb, 0x54, 0x21, 0xf4, 0x98, 0x4f, 0x72, 0xd2, 0x43, 0x78, - 0xab, 0x80, 0xe4, 0x6c, 0x01, 0x78, 0x6a, 0xc4, 0x64, 0x45, - 0xbc, 0xa8, 0x1f, 0x56, 0xbc, 0xed, 0xf9, 0xb5, 0xd8, 0x21, 0x95, - 0x41, 0x71, 0xe9, 0x0e, 0xb4, 0x3c, 0x4e, 0x2b, 0x00, 0x00, - 0x17, 0x43, 0x49, 0x53, 0x43, 0x4f, 0x2d, 0x44, 0x45, 0x4c, 0x45, - 0x54, 0x45, 0x2d, 0x52, 0x45, 0x41, 0x53, 0x4f, 0x4e, 0x2b, - 0x00, 0x00, 0x3b, 0x43, 0x49, 0x53, 0x43, 0x4f, 0x28, 0x43, 0x4f, - 0x50, 0x59, 0x52, 0x49, 0x47, 0x48, 0x54, 0x29, 0x26, 0x43, - 0x6f, 0x70, 0x79, 0x72, 0x69, 0x67, 0x68, 0x74, 0x20, 0x28, 0x63, - 0x29, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x43, 0x69, 0x73, - 0x63, 0x6f, 0x20, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x73, 0x2c, - 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x29, 0x00, 0x00, 0x13, 0x43, - 0x49, 0x53, 0x43, 0x4f, 0x2d, 0x47, 0x52, 0x45, 0x2d, 0x4d, 0x4f, - 0x44, 0x45, 0x02, 0x29, 0x00, 0x00, 0x1c, 0x01, 0x00, 0x40, - 0x04, 0x7e, 0x57, 0x6c, 0xc0, 0x13, 0xd4, 0x05, 0x43, 0xa2, 0xe8, - 0x77, 0x7d, 0x00, 0x34, 0x68, 0xa5, 0xb1, 0x89, 0x0c, 0x58, - 0x2b, 0x00, 0x00, 0x1c, 0x01, 0x00, 0x40, 0x05, 0x52, 0x64, 0x4d, - 0x87, 0xd4, 0x7c, 0x2d, 0x44, 0x23, 0xbd, 0x37, 0xe4, 0x48, - 0xa9, 0xf5, 0x17, 0x01, 0x81, 0xcb, 0x8a, 0x00, 0x00, 0x00, 0x14, - 0x40, 0x48, 0xb7, 0xd5, 0x6e, 0xbc, 0xe8, 0x85, 0x25, 0xe7, - 0xde, 0x7f, 0x00, 0xd6, 0xc2, 0xd3, - } - - ikePacket := new(IKEMessage) - err := ikePacket.Decode(data) - if err != nil { - t.Fatalf("Decode failed: %+v", err) - } - - verifyData, err := ikePacket.Encode() - if err != nil { - t.Fatalf("Encode failed: %+v", err) - } - - if !bytes.Equal(data, verifyData) { - t.FailNow() - } -} diff --git a/pkg/ike/message/types.go b/pkg/ike/message/types.go deleted file mode 100644 index 39e79ca6..00000000 --- a/pkg/ike/message/types.go +++ /dev/null @@ -1,296 +0,0 @@ -package message - -// IKE types -type IKEPayloadType uint8 - -const ( - NoNext = 0 - TypeSA = iota + 32 - TypeKE - TypeIDi - TypeIDr - TypeCERT - TypeCERTreq - TypeAUTH - TypeNiNr - TypeN - TypeD - TypeV - TypeTSi - TypeTSr - TypeSK - TypeCP - TypeEAP -) - -// EAP types -type EAPType uint8 - -const ( - EAPTypeIdentity = iota + 1 - EAPTypeNotification - EAPTypeNak - EAPTypeExpanded = 254 -) - -const ( - EAPCodeRequest = iota + 1 - EAPCodeResponse - EAPCodeSuccess - EAPCodeFailure -) - -// used for SecurityAssociation-Proposal-Transform TransformType -const ( - TypeEncryptionAlgorithm = iota + 1 - TypePseudorandomFunction - TypeIntegrityAlgorithm - TypeDiffieHellmanGroup - TypeExtendedSequenceNumbers -) - -// used for SecurityAssociation-Proposal-Transform AttributeFormat -const ( - AttributeFormatUseTLV = iota - AttributeFormatUseTV -) - -// used for SecurityAssociation-Proposal-Trandform AttributeType -const ( - AttributeTypeKeyLength = 14 -) - -// used for SecurityAssociation-Proposal-Transform TransformID -const ( - ENCR_DES_IV64 = 1 - ENCR_DES = 2 - ENCR_3DES = 3 - ENCR_RC5 = 4 - ENCR_IDEA = 5 - ENCR_CAST = 6 - ENCR_BLOWFISH = 7 - ENCR_3IDEA = 8 - ENCR_DES_IV32 = 9 - ENCR_NULL = 11 - ENCR_AES_CBC = 12 - ENCR_AES_CTR = 13 -) - -const ( - PRF_HMAC_MD5 = iota + 1 - PRF_HMAC_SHA1 - PRF_HMAC_TIGER - PRF_HMAC_SHA2_256 = 5 -) - -const ( - AUTH_NONE = iota - AUTH_HMAC_MD5_96 - AUTH_HMAC_SHA1_96 - AUTH_DES_MAC - AUTH_KPDK_MD5 - AUTH_AES_XCBC_96 - AUTH_HMAC_SHA2_256_128 = 12 -) - -const ( - DH_NONE = 0 - DH_768_BIT_MODP = 1 - DH_1024_BIT_MODP = 2 - DH_1536_BIT_MODP = 5 - DH_2048_BIT_MODP = iota + 10 - DH_3072_BIT_MODP - DH_4096_BIT_MODP - DH_6144_BIT_MODP - DH_8192_BIT_MODP -) - -const ( - ESN_NO = iota - ESN_NEED -) - -// used for TrafficSelector-Individual Traffic Selector TSType -const ( - TS_IPV4_ADDR_RANGE = 7 - TS_IPV6_ADDR_RANGE = 8 -) - -// Exchange Type -const ( - IKE_SA_INIT = iota + 34 - IKE_AUTH - CREATE_CHILD_SA - INFORMATIONAL -) - -// Notify message types -const ( - UNSUPPORTED_CRITICAL_PAYLOAD = 1 - INVALID_IKE_SPI = 4 - INVALID_MAJOR_VERSION = 5 - INVALID_SYNTAX = 7 - INVALID_MESSAGE_ID = 9 - INVALID_SPI = 11 - NO_PROPOSAL_CHOSEN = 14 - INVALID_KE_PAYLOAD = 17 - AUTHENTICATION_FAILED = 24 - SINGLE_PAIR_REQUIRED = 34 - NO_ADDITIONAL_SAS = 35 - INTERNAL_ADDRESS_FAILURE = 36 - FAILED_CP_REQUIRED = 37 - TS_UNACCEPTABLE = 38 - INVALID_SELECTORS = 39 - TEMPORARY_FAILURE = 43 - CHILD_SA_NOT_FOUND = 44 - INITIAL_CONTACT = 16384 - SET_WINDOW_SIZE = 16385 - ADDITIONAL_TS_POSSIBLE = 16386 - IPCOMP_SUPPORTED = 16387 - NAT_DETECTION_SOURCE_IP = 16388 - NAT_DETECTION_DESTINATION_IP = 16389 - COOKIE = 16390 - USE_TRANSPORT_MODE = 16391 - HTTP_CERT_LOOKUP_SUPPORTED = 16392 - REKEY_SA = 16393 - ESP_TFC_PADDING_NOT_SUPPORTED = 16394 - NON_FIRST_FRAGMENTS_ALSO = 16395 -) - -// Protocol ID -const ( - TypeNone = iota - TypeIKE - TypeAH - TypeESP -) - -// Flags -const ( - ResponseBitCheck = 0x20 - VersionBitCheck = 0x10 - InitiatorBitCheck = 0x08 -) - -// Certificate encoding -const ( - PKCS7WrappedX509Certificate = 1 - PGPCertificate = 2 - DNSSignedKey = 3 - X509CertificateSignature = 4 - KerberosToken = 6 - CertificateRevocationList = 7 - AuthorityRevocationList = 8 - SPKICertificate = 9 - X509CertificateAttribute = 10 - HashAndURLOfX509Certificate = 12 - HashAndURLOfX509Bundle = 13 -) - -// ID Types -const ( - ID_IPV4_ADDR = 1 - ID_FQDN = 2 - ID_RFC822_ADDR = 3 - ID_IPV6_ADDR = 5 - ID_DER_ASN1_DN = 9 - ID_DER_ASN1_GN = 10 - ID_KEY_ID = 11 -) - -// Authentication Methods -const ( - RSADigitalSignature = iota + 1 - SharedKeyMesageIntegrityCode - DSSDigitalSignature -) - -// Configuration types -const ( - CFG_REQUEST = 1 - CFG_REPLY = 2 - CFG_SET = 3 - CFG_ACK = 4 -) - -// Configuration attribute types -const ( - INTERNAL_IP4_ADDRESS = 1 - INTERNAL_IP4_NETMASK = 2 - INTERNAL_IP4_DNS = 3 - INTERNAL_IP4_NBNS = 4 - INTERNAL_IP4_DHCP = 6 - APPLICATION_VERSION = 7 - INTERNAL_IP6_ADDRESS = 8 - INTERNAL_IP6_DNS = 10 - INTERNAL_IP6_DHCP = 12 - INTERNAL_IP4_SUBNET = 13 - SUPPORTED_ATTRIBUTES = 14 - INTERNAL_IP6_SUBNET = 15 -) - -// IP protocols ID, used in individual traffic selector -const ( - IPProtocolAll = 0 - IPProtocolICMP = 1 - IPProtocolTCP = 6 - IPProtocolUDP = 17 - IPProtocolGRE = 47 -) - -// Types for EAP-5G -// Used in IKE EAP expanded for vendor ID -const VendorID3GPP = 10415 - -// Used in IKE EAP expanded for vendor data -const VendorTypeEAP5G = 3 - -// Used in EAP-5G for message ID -const ( - EAP5GType5GStart = 1 - EAP5GType5GNAS = 2 - EAP5GType5GStop = 4 -) - -// Used in AN-Parameter field for IE types -const ( - ANParametersTypeGUAMI = 1 - ANParametersTypeSelectedPLMNID = 2 - ANParametersTypeRequestedNSSAI = 3 - ANParametersTypeEstablishmentCause = 4 -) - -// Used for checking if AN-Parameter length field is legal -const ( - ANParametersLenGUAMI = 6 - ANParametersLenPLMNID = 3 - ANParametersLenEstCause = 1 -) - -// Used in IE Establishment Cause field for cause types -const ( - EstablishmentCauseEmergency = 0 - EstablishmentCauseHighPriorityAccess = 1 - EstablishmentCauseMO_Signalling = 3 - EstablishmentCauseMO_Data = 4 - EstablishmentCauseMPS_PriorityAccess = 8 - EstablishmentCauseMCS_PriorityAccess = 9 -) - -// Spare -const EAP5GSpareValue = 0 - -// 3GPP specified IKE Notify -// 3GPP specified IKE Notify Message Types -const ( - Vendor3GPPNotifyType5G_QOS_INFO uint16 = 55501 - Vendor3GPPNotifyTypeNAS_IP4_ADDRESS uint16 = 55502 - Vendor3GPPNotifyTypeUP_IP4_ADDRESS uint16 = 55504 - Vendor3GPPNotifyTypeNAS_TCP_PORT uint16 = 55506 -) - -// Used in NotifyType5G_QOS_INFO -const ( - NotifyType5G_QOS_INFOBitDSCPICheck uint8 = 1 - NotifyType5G_QOS_INFOBitDCSICheck uint8 = 1 << 1 -) diff --git a/pkg/ike/security/security.go b/pkg/ike/security/security.go deleted file mode 100644 index 188eeb6c..00000000 --- a/pkg/ike/security/security.go +++ /dev/null @@ -1,898 +0,0 @@ -package security - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/md5" - "crypto/rand" - "crypto/sha1" // #nosec G505 - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "hash" - "io" - "math/big" - "strings" - - "github.com/pkg/errors" - - "github.com/free5gc/n3iwf/internal/logger" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" - "github.com/free5gc/n3iwf/pkg/ike/message" -) - -// General data -var ( - randomNumberMaximum big.Int - randomNumberMinimum big.Int -) - -func init() { - randomNumberMaximum.SetString(strings.Repeat("F", 512), 16) - randomNumberMinimum.SetString("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16) -} - -func GenerateRandomNumber() *big.Int { - var number *big.Int - var err error - - ikeLog := logger.IKELog - for { - number, err = rand.Int(rand.Reader, &randomNumberMaximum) - if err != nil { - ikeLog.Errorf("Error occurs when generate random number: %+v", err) - return nil - } else { - if number.Cmp(&randomNumberMinimum) == 1 { - break - } - } - } - return number -} - -func GenerateRandomUint8() (uint8, error) { - ikeLog := logger.IKELog - number := make([]byte, 1) - _, err := io.ReadFull(rand.Reader, number) - if err != nil { - ikeLog.Errorf("Read random failed: %+v", err) - return 0, errors.New("Read failed") - } - return number[0], nil -} - -// Diffie-Hellman Exchange -// The strength supplied by group 1 may not be sufficient for typical uses -const ( - Group2PrimeString string = "FFFFFFFFFFFFFFFFC90FDAA22168C234" + - "C4C6628B80DC1CD129024E088A67CC74" + - "020BBEA63B139B22514A08798E3404DD" + - "EF9519B3CD3A431B302B0A6DF25F1437" + - "4FE1356D6D51C245E485B576625E7EC6" + - "F44C42E9A637ED6B0BFF5CB6F406B7ED" + - "EE386BFB5A899FA5AE9F24117C4B1FE6" + - "49286651ECE65381FFFFFFFFFFFFFFFF" - Group2Generator = 2 - Group14PrimeString string = "FFFFFFFFFFFFFFFFC90FDAA22168C234" + - "C4C6628B80DC1CD129024E088A67CC74" + - "020BBEA63B139B22514A08798E3404DD" + - "EF9519B3CD3A431B302B0A6DF25F1437" + - "4FE1356D6D51C245E485B576625E7EC6" + - "F44C42E9A637ED6B0BFF5CB6F406B7ED" + - "EE386BFB5A899FA5AE9F24117C4B1FE6" + - "49286651ECE45B3DC2007CB8A163BF05" + - "98DA48361C55D39A69163FA8FD24CF5F" + - "83655D23DCA3AD961C62F356208552BB" + - "9ED529077096966D670C354E4ABC9804" + - "F1746C08CA18217C32905E462E36CE3B" + - "E39E772C180E86039B2783A2EC07A28F" + - "B5C55DF06F4C52C9DE2BCBF695581718" + - "3995497CEA956AE515D2261898FA0510" + - "15728E5A8AACAA68FFFFFFFFFFFFFFFF" - Group14Generator = 2 -) - -func CalculateDiffieHellmanMaterials(secret *big.Int, peerPublicValue []byte, - diffieHellmanGroupNumber uint16, -) (localPublicValue []byte, sharedKey []byte) { - ikeLog := logger.IKELog - peerPublicValueBig := new(big.Int).SetBytes(peerPublicValue) - var generator, factor *big.Int - var ok bool - - switch diffieHellmanGroupNumber { - case message.DH_1024_BIT_MODP: - generator = new(big.Int).SetUint64(Group2Generator) - factor, ok = new(big.Int).SetString(Group2PrimeString, 16) - if !ok { - ikeLog.Errorf( - "Error occurs when setting big number \"factor\" in %d group", - diffieHellmanGroupNumber) - } - case message.DH_2048_BIT_MODP: - generator = new(big.Int).SetUint64(Group14Generator) - factor, ok = new(big.Int).SetString(Group14PrimeString, 16) - if !ok { - ikeLog.Errorf( - "Error occurs when setting big number \"factor\" in %d group", - diffieHellmanGroupNumber) - } - default: - ikeLog.Errorf("Unsupported Diffie-Hellman group: %d", diffieHellmanGroupNumber) - return localPublicValue, sharedKey - } - - localPublicValue = new(big.Int).Exp(generator, secret, factor).Bytes() - prependZero := make([]byte, len(factor.Bytes())-len(localPublicValue)) - localPublicValue = append(prependZero, localPublicValue...) - - sharedKey = new(big.Int).Exp(peerPublicValueBig, secret, factor).Bytes() - prependZero = make([]byte, len(factor.Bytes())-len(sharedKey)) - sharedKey = append(prependZero, sharedKey...) - - return localPublicValue, sharedKey -} - -// Pseudorandom Function -func NewPseudorandomFunction(key []byte, algorithmType uint16) (hash.Hash, bool) { - ikeLog := logger.IKELog - switch algorithmType { - case message.PRF_HMAC_MD5: - return hmac.New(md5.New, key), true - case message.PRF_HMAC_SHA1: - return hmac.New(sha1.New, key), true // #nosec G401 - case message.PRF_HMAC_SHA2_256: - return hmac.New(sha256.New, key), true - default: - ikeLog.Errorf("Unsupported pseudo random function: %d", algorithmType) - return nil, false - } -} - -// Integrity Algorithm -func calculateIntegrity(key []byte, originData []byte, transform *message.Transform) ([]byte, error) { - expectKeyLen, ok := getKeyLength( - transform.TransformType, transform.TransformID, - transform.AttributePresent, transform.AttributeValue) - if !ok { - return nil, errors.Errorf("calculateIntegrity[%d]: unsupported algo", transform.TransformID) - } - keyLen := len(key) - if keyLen != expectKeyLen { - return nil, errors.Errorf("calculateIntegrity[%d]: Unmatched input key length[%d:%d]", - transform.TransformID, keyLen, expectKeyLen) - } - outputLen, ok := getOutputLength( - transform.TransformType, transform.TransformID, - transform.AttributePresent, transform.AttributeValue) - if !ok { - return nil, errors.Errorf("calculateIntegrity[%d]: unsupported algo", transform.TransformID) - } - - var integrityFunction hash.Hash - switch transform.TransformID { - case message.AUTH_HMAC_MD5_96: - integrityFunction = hmac.New(md5.New, key) - case message.AUTH_HMAC_SHA1_96: - integrityFunction = hmac.New(sha1.New, key) // #nosec G401 - case message.AUTH_HMAC_SHA2_256_128: - integrityFunction = hmac.New(sha256.New, key) - default: - return nil, errors.Errorf("calculateIntegrity[%d]: unsupported algo", transform.TransformID) - } - - if _, err := integrityFunction.Write(originData); err != nil { - return nil, errors.Wrapf(err, "calculateIntegrity[%d]", transform.TransformID) - } - return integrityFunction.Sum(nil)[:outputLen], nil -} - -func verifyIntegrity(key []byte, originData []byte, checksum []byte, transform *message.Transform) (bool, error) { - ikeLog := logger.IKELog - expectChecksum, err := calculateIntegrity(key, originData, transform) - if err != nil { - return false, errors.Wrapf(err, "verifyIntegrity") - } - - ikeLog.Tracef("Calculated checksum:\n%s\nReceived checksum:\n%s", - hex.Dump(expectChecksum), hex.Dump(checksum)) - return hmac.Equal(expectChecksum, checksum), nil -} - -// Encryption Algorithm -func EncryptMessage(key []byte, originData []byte, algorithmType uint16) ([]byte, error) { - ikeLog := logger.IKELog - switch algorithmType { - case message.ENCR_AES_CBC: - // padding message - originData = PKCS7Padding(originData, aes.BlockSize) - originData[len(originData)-1]-- - - block, err := aes.NewCipher(key) - if err != nil { - ikeLog.Errorf("Error occur when create new cipher: %+v", err) - return nil, errors.New("Create cipher failed") - } - - cipherText := make([]byte, aes.BlockSize+len(originData)) - initializationVector := cipherText[:aes.BlockSize] - - _, err = io.ReadFull(rand.Reader, initializationVector) - if err != nil { - ikeLog.Errorf("Read random failed: %+v", err) - return nil, errors.New("Read random initialization vector failed") - } - - cbcBlockMode := cipher.NewCBCEncrypter(block, initializationVector) - cbcBlockMode.CryptBlocks(cipherText[aes.BlockSize:], originData) - - return cipherText, nil - default: - ikeLog.Errorf("Unsupported encryption algorithm: %d", algorithmType) - return nil, errors.New("Unsupported algorithm") - } -} - -func DecryptMessage(key []byte, cipherText []byte, algorithmType uint16) ([]byte, error) { - ikeLog := logger.IKELog - switch algorithmType { - case message.ENCR_AES_CBC: - if len(cipherText) < aes.BlockSize { - ikeLog.Error("Length of cipher text is too short to decrypt") - return nil, errors.New("Cipher text is too short") - } - - initializationVector := cipherText[:aes.BlockSize] - encryptedMessage := cipherText[aes.BlockSize:] - - if len(encryptedMessage)%aes.BlockSize != 0 { - ikeLog.Error("Cipher text is not a multiple of block size") - return nil, errors.New("Cipher text length error") - } - - plainText := make([]byte, len(encryptedMessage)) - - block, err := aes.NewCipher(key) - if err != nil { - ikeLog.Errorf("Error occur when create new cipher: %+v", err) - return nil, errors.New("Create cipher failed") - } - cbcBlockMode := cipher.NewCBCDecrypter(block, initializationVector) - cbcBlockMode.CryptBlocks(plainText, encryptedMessage) - - ikeLog.Tracef("Decrypted content:\n%s", hex.Dump(plainText)) - - padding := int(plainText[len(plainText)-1]) + 1 - plainText = plainText[:len(plainText)-padding] - - ikeLog.Tracef("Decrypted content with out padding:\n%s", hex.Dump(plainText)) - - return plainText, nil - default: - ikeLog.Errorf("Unsupported encryption algorithm: %d", algorithmType) - return nil, errors.New("Unsupported algorithm") - } -} - -func PKCS7Padding(plainText []byte, blockSize int) []byte { - padding := blockSize - (len(plainText) % blockSize) - if padding == 0 { - padding = blockSize - } - paddingText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(plainText, paddingText...) -} - -// Certificate -func CompareRootCertificate( - ca []byte, - certificateEncoding uint8, - requestedCertificateAuthorityHash []byte, -) bool { - ikeLog := logger.IKELog - if certificateEncoding != message.X509CertificateSignature { - ikeLog.Debugf("Not support certificate type: %d. Reject.", certificateEncoding) - return false - } - - if len(ca) == 0 { - ikeLog.Error("Certificate authority in context is empty") - return false - } - - return bytes.Equal(ca, requestedCertificateAuthorityHash) -} - -// Key Gen for IKE SA -func GenerateKeyForIKESA(ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation) error { - ikeLog := logger.IKELog - // Check parameters - if ikeSecurityAssociation == nil { - return errors.New("IKE SA is nil") - } - - // Check if the context contain needed data - if ikeSecurityAssociation.EncryptionAlgorithm == nil { - return errors.New("No encryption algorithm specified") - } - if ikeSecurityAssociation.IntegrityAlgorithm == nil { - return errors.New("No integrity algorithm specified") - } - if ikeSecurityAssociation.PseudorandomFunction == nil { - return errors.New("No pseudorandom function specified") - } - if ikeSecurityAssociation.DiffieHellmanGroup == nil { - return errors.New("No Diffie-hellman group algorithm specified") - } - - if len(ikeSecurityAssociation.ConcatenatedNonce) == 0 { - return errors.New("No concatenated nonce data") - } - if len(ikeSecurityAssociation.DiffieHellmanSharedKey) == 0 { - return errors.New("No Diffie-Hellman shared key") - } - - // Transforms - transformIntegrityAlgorithm := ikeSecurityAssociation.IntegrityAlgorithm - transformEncryptionAlgorithm := ikeSecurityAssociation.EncryptionAlgorithm - transformPseudorandomFunction := ikeSecurityAssociation.PseudorandomFunction - - // Get key length of SK_d, SK_ai, SK_ar, SK_ei, SK_er, SK_pi, SK_pr - var length_SK_d, length_SK_ai, length_SK_ar, length_SK_ei, length_SK_er, length_SK_pi, length_SK_pr, totalKeyLength int - var ok bool - - if length_SK_d, ok = getKeyLength(transformPseudorandomFunction.TransformType, - transformPseudorandomFunction.TransformID, transformPseudorandomFunction.AttributePresent, - transformPseudorandomFunction.AttributeValue); !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return errors.New("Get key length failed") - } - if length_SK_ai, ok = getKeyLength(transformIntegrityAlgorithm.TransformType, - transformIntegrityAlgorithm.TransformID, transformIntegrityAlgorithm.AttributePresent, - transformIntegrityAlgorithm.AttributeValue); !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return errors.New("Get key length failed") - } - length_SK_ar = length_SK_ai - if length_SK_ei, ok = getKeyLength(transformEncryptionAlgorithm.TransformType, - transformEncryptionAlgorithm.TransformID, transformEncryptionAlgorithm.AttributePresent, - transformEncryptionAlgorithm.AttributeValue); !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return errors.New("Get key length failed") - } - length_SK_er = length_SK_ei - length_SK_pi, length_SK_pr = length_SK_d, length_SK_d - totalKeyLength = length_SK_d + length_SK_ai + length_SK_ar + length_SK_ei + length_SK_er + length_SK_pi + length_SK_pr - - // Generate IKE SA key as defined in RFC7296 Section 1.3 and Section 1.4 - var pseudorandomFunction hash.Hash - - if pseudorandomFunction, ok = NewPseudorandomFunction(ikeSecurityAssociation.ConcatenatedNonce, - transformPseudorandomFunction.TransformID); !ok { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") - return errors.New("New pseudorandom function failed") - } - - ikeLog.Tracef("DH shared key:\n%s", hex.Dump(ikeSecurityAssociation.DiffieHellmanSharedKey)) - ikeLog.Tracef("Concatenated nonce:\n%s", hex.Dump(ikeSecurityAssociation.ConcatenatedNonce)) - - if _, err := pseudorandomFunction.Write(ikeSecurityAssociation.DiffieHellmanSharedKey); err != nil { - ikeLog.Errorf("Pseudorandom function write error: %+v", err) - return errors.New("Pseudorandom function write failed") - } - - SKEYSEED := pseudorandomFunction.Sum(nil) - - ikeLog.Tracef("SKEYSEED:\n%s", hex.Dump(SKEYSEED)) - - seed := concatenateNonceAndSPI(ikeSecurityAssociation.ConcatenatedNonce, - ikeSecurityAssociation.RemoteSPI, ikeSecurityAssociation.LocalSPI) - - var keyStream, generatedKeyBlock []byte - var index byte - for index = 1; len(keyStream) < totalKeyLength; index++ { - if pseudorandomFunction, ok = NewPseudorandomFunction(SKEYSEED, transformPseudorandomFunction.TransformID); !ok { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") - return errors.New("New pseudorandom function failed") - } - if _, err := pseudorandomFunction.Write(append(append(generatedKeyBlock, seed...), index)); err != nil { - ikeLog.Errorf("Pseudorandom function write error: %+v", err) - return errors.New("Pseudorandom function write failed") - } - generatedKeyBlock = pseudorandomFunction.Sum(nil) - keyStream = append(keyStream, generatedKeyBlock...) - } - - // Assign keys into context - ikeSecurityAssociation.SK_d = keyStream[:length_SK_d] - keyStream = keyStream[length_SK_d:] - ikeSecurityAssociation.SK_ai = keyStream[:length_SK_ai] - keyStream = keyStream[length_SK_ai:] - ikeSecurityAssociation.SK_ar = keyStream[:length_SK_ar] - keyStream = keyStream[length_SK_ar:] - ikeSecurityAssociation.SK_ei = keyStream[:length_SK_ei] - keyStream = keyStream[length_SK_ei:] - ikeSecurityAssociation.SK_er = keyStream[:length_SK_er] - keyStream = keyStream[length_SK_er:] - ikeSecurityAssociation.SK_pi = keyStream[:length_SK_pi] - keyStream = keyStream[length_SK_pi:] - ikeSecurityAssociation.SK_pr = keyStream[:length_SK_pr] - // keyStream = keyStream[length_SK_pr:] - - ikeLog.Debugln("====== IKE Security Association Info =====") - ikeLog.Debugf("Initiator's SPI: %016x", ikeSecurityAssociation.RemoteSPI) - ikeLog.Debugf("Responder's SPI: %016x", ikeSecurityAssociation.LocalSPI) - ikeLog.Debugf("Encryption Algorithm: %d", ikeSecurityAssociation.EncryptionAlgorithm.TransformID) - ikeLog.Debugf("SK_ei: %x", ikeSecurityAssociation.SK_ei) - ikeLog.Debugf("SK_er: %x", ikeSecurityAssociation.SK_er) - ikeLog.Debugf("Integrity Algorithm: %d", ikeSecurityAssociation.IntegrityAlgorithm.TransformID) - ikeLog.Debugf("SK_ai: %x", ikeSecurityAssociation.SK_ai) - ikeLog.Debugf("SK_ar: %x", ikeSecurityAssociation.SK_ar) - ikeLog.Debugf("SK_pi: %x", ikeSecurityAssociation.SK_pi) - ikeLog.Debugf("SK_pr: %x", ikeSecurityAssociation.SK_pr) - - return nil -} - -// Key Gen for child SA -func GenerateKeyForChildSA(ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation, - childSecurityAssociation *n3iwf_context.ChildSecurityAssociation, -) error { - ikeLog := logger.IKELog - // Check parameters - if ikeSecurityAssociation == nil { - return errors.New("IKE SA is nil") - } - if childSecurityAssociation == nil { - return errors.New("Child SA is nil") - } - - // Check if the context contain needed data - if ikeSecurityAssociation.PseudorandomFunction == nil { - return errors.New("No pseudorandom function specified") - } - if ikeSecurityAssociation.IKEAuthResponseSA == nil { - return errors.New("No IKE_AUTH response SA specified") - } - if len(ikeSecurityAssociation.IKEAuthResponseSA.Proposals) == 0 { - return errors.New("No proposal in IKE_AUTH response SA") - } - if len(ikeSecurityAssociation.IKEAuthResponseSA.Proposals[0].EncryptionAlgorithm) == 0 { - return errors.New("No encryption algorithm specified") - } - - if len(ikeSecurityAssociation.SK_d) == 0 { - return errors.New("No key deriving key") - } - - // Transforms - transformPseudorandomFunction := ikeSecurityAssociation.PseudorandomFunction - transformEncryptionAlgorithmForIPSec := ikeSecurityAssociation.IKEAuthResponseSA.Proposals[0].EncryptionAlgorithm[0] - var transformIntegrityAlgorithmForIPSec *message.Transform - if len(ikeSecurityAssociation.IKEAuthResponseSA.Proposals[0].IntegrityAlgorithm) != 0 { - transformIntegrityAlgorithmForIPSec = ikeSecurityAssociation.IKEAuthResponseSA.Proposals[0].IntegrityAlgorithm[0] - } - - // Get key length for encryption and integrity key for IPSec - var lengthEncryptionKeyIPSec, lengthIntegrityKeyIPSec, totalKeyLength int - var ok bool - - if lengthEncryptionKeyIPSec, ok = getKeyLength(transformEncryptionAlgorithmForIPSec.TransformType, - transformEncryptionAlgorithmForIPSec.TransformID, - transformEncryptionAlgorithmForIPSec.AttributePresent, - transformEncryptionAlgorithmForIPSec.AttributeValue); !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return errors.New("Get key length failed") - } - if transformIntegrityAlgorithmForIPSec != nil { - if lengthIntegrityKeyIPSec, ok = getKeyLength(transformIntegrityAlgorithmForIPSec.TransformType, - transformIntegrityAlgorithmForIPSec.TransformID, - transformIntegrityAlgorithmForIPSec.AttributePresent, - transformIntegrityAlgorithmForIPSec.AttributeValue); !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return errors.New("Get key length failed") - } - } - totalKeyLength = lengthEncryptionKeyIPSec + lengthIntegrityKeyIPSec - totalKeyLength = totalKeyLength * 2 - - // Generate key for child security association as specified in RFC 7296 section 2.17 - seed := ikeSecurityAssociation.ConcatenatedNonce - var pseudorandomFunction hash.Hash - - var keyStream, generatedKeyBlock []byte - var index byte - for index = 1; len(keyStream) < totalKeyLength; index++ { - if pseudorandomFunction, ok = NewPseudorandomFunction(ikeSecurityAssociation.SK_d, - transformPseudorandomFunction.TransformID); !ok { - ikeLog.Error("Get an unsupported pseudorandom funcion. This may imply an unsupported transform is chosen.") - return errors.New("New pseudorandom function failed") - } - if _, err := pseudorandomFunction.Write(append(append(generatedKeyBlock, seed...), index)); err != nil { - ikeLog.Errorf("Pseudorandom function write error: %+v", err) - return errors.New("Pseudorandom function write failed") - } - generatedKeyBlock = pseudorandomFunction.Sum(nil) - keyStream = append(keyStream, generatedKeyBlock...) - } - - childSecurityAssociation.InitiatorToResponderEncryptionKey = append( - childSecurityAssociation.InitiatorToResponderEncryptionKey, - keyStream[:lengthEncryptionKeyIPSec]...) - keyStream = keyStream[lengthEncryptionKeyIPSec:] - childSecurityAssociation.InitiatorToResponderIntegrityKey = append( - childSecurityAssociation.InitiatorToResponderIntegrityKey, - keyStream[:lengthIntegrityKeyIPSec]...) - keyStream = keyStream[lengthIntegrityKeyIPSec:] - childSecurityAssociation.ResponderToInitiatorEncryptionKey = append( - childSecurityAssociation.ResponderToInitiatorEncryptionKey, - keyStream[:lengthEncryptionKeyIPSec]...) - keyStream = keyStream[lengthEncryptionKeyIPSec:] - childSecurityAssociation.ResponderToInitiatorIntegrityKey = append( - childSecurityAssociation.ResponderToInitiatorIntegrityKey, - keyStream[:lengthIntegrityKeyIPSec]...) - - return nil -} - -// Decrypt -func DecryptProcedure(ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation, ikeMessage *message.IKEMessage, - encryptedPayload *message.Encrypted, -) (message.IKEPayloadContainer, error) { - ikeLog := logger.IKELog - // Check parameters - if ikeSecurityAssociation == nil { - return nil, errors.New("IKE SA is nil") - } - if ikeMessage == nil { - return nil, errors.New("IKE message is nil") - } - if encryptedPayload == nil { - return nil, errors.New("IKE encrypted payload is nil") - } - - // Check if the context contain needed data - if ikeSecurityAssociation.IntegrityAlgorithm == nil { - return nil, errors.New("No integrity algorithm specified") - } - if ikeSecurityAssociation.EncryptionAlgorithm == nil { - return nil, errors.New("No encryption algorithm specified") - } - - if len(ikeSecurityAssociation.SK_ai) == 0 { - return nil, errors.New("No initiator's integrity key") - } - if len(ikeSecurityAssociation.SK_ei) == 0 { - return nil, errors.New("No initiator's encryption key") - } - - // Load needed information - transformIntegrityAlgorithm := ikeSecurityAssociation.IntegrityAlgorithm - transformEncryptionAlgorithm := ikeSecurityAssociation.EncryptionAlgorithm - checksumLength, ok := getOutputLength(transformIntegrityAlgorithm.TransformType, - transformIntegrityAlgorithm.TransformID, transformIntegrityAlgorithm.AttributePresent, - transformIntegrityAlgorithm.AttributeValue) - if !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return nil, errors.New("Get key length failed") - } - - // Checksum - checksum := encryptedPayload.EncryptedData[len(encryptedPayload.EncryptedData)-checksumLength:] - - ikeMessageData, err := ikeMessage.Encode() - if err != nil { - ikeLog.Errorln(err) - ikeLog.Error("Error occur when encoding for checksum") - return nil, errors.New("Encoding IKE message failed") - } - - ok, err = verifyIntegrity(ikeSecurityAssociation.SK_ai, - ikeMessageData[:len(ikeMessageData)-checksumLength], checksum, - transformIntegrityAlgorithm) - if err != nil { - ikeLog.Errorf("Error occur when verifying checksum: %+v", err) - return nil, errors.New("Error verify checksum") - } - if !ok { - ikeLog.Warn("Message checksum failed. Drop the message.") - return nil, errors.New("Checksum failed, drop.") - } - - // Decrypt - encryptedData := encryptedPayload.EncryptedData[:len(encryptedPayload.EncryptedData)-checksumLength] - plainText, err := DecryptMessage(ikeSecurityAssociation.SK_ei, encryptedData, - transformEncryptionAlgorithm.TransformID) - if err != nil { - ikeLog.Errorf("Error occur when decrypting message: %+v", err) - return nil, errors.New("Error decrypting message") - } - - var decryptedIKEPayload message.IKEPayloadContainer - err = decryptedIKEPayload.Decode(encryptedPayload.NextPayload, plainText) - if err != nil { - ikeLog.Errorln(err) - return nil, errors.New("Decoding decrypted payload failed") - } - - return decryptedIKEPayload, nil -} - -// Encrypt -func EncryptProcedure(ikeSecurityAssociation *n3iwf_context.IKESecurityAssociation, - ikePayload message.IKEPayloadContainer, responseIKEMessage *message.IKEMessage, -) error { - ikeLog := logger.IKELog - // Check parameters - if ikeSecurityAssociation == nil { - return errors.New("IKE SA is nil") - } - if len(ikePayload) == 0 { - return errors.New("No IKE payload to be encrypted") - } - if responseIKEMessage == nil { - return errors.New("Response IKE message is nil") - } - - // Check if the context contain needed data - if ikeSecurityAssociation.IntegrityAlgorithm == nil { - return errors.New("No integrity algorithm specified") - } - if ikeSecurityAssociation.EncryptionAlgorithm == nil { - return errors.New("No encryption algorithm specified") - } - - if len(ikeSecurityAssociation.SK_ar) == 0 { - return errors.New("No responder's integrity key") - } - if len(ikeSecurityAssociation.SK_er) == 0 { - return errors.New("No responder's encryption key") - } - - // Load needed information - transformIntegrityAlgorithm := ikeSecurityAssociation.IntegrityAlgorithm - transformEncryptionAlgorithm := ikeSecurityAssociation.EncryptionAlgorithm - checksumLength, ok := getOutputLength(transformIntegrityAlgorithm.TransformType, - transformIntegrityAlgorithm.TransformID, transformIntegrityAlgorithm.AttributePresent, - transformIntegrityAlgorithm.AttributeValue) - if !ok { - ikeLog.Error("Get key length of an unsupported algorithm. This may imply an unsupported transform is chosen.") - return errors.New("Get key length failed") - } - - // Encrypting - ikePayloadData, err := ikePayload.Encode() - if err != nil { - ikeLog.Error(err) - return errors.New("Encoding IKE payload failed.") - } - - encryptedData, err := EncryptMessage(ikeSecurityAssociation.SK_er, ikePayloadData, - transformEncryptionAlgorithm.TransformID) - if err != nil { - ikeLog.Errorf("Encrypting data error: %+v", err) - return errors.New("Error encrypting message") - } - - encryptedData = append(encryptedData, make([]byte, checksumLength)...) - sk := responseIKEMessage.Payloads.BuildEncrypted(ikePayload[0].Type(), encryptedData) - - // Calculate checksum - responseIKEMessageData, err := responseIKEMessage.Encode() - if err != nil { - ikeLog.Error(err) - return errors.New("Encoding IKE message error") - } - checksumOfMessage, err := calculateIntegrity(ikeSecurityAssociation.SK_ar, - responseIKEMessageData[:len(responseIKEMessageData)-checksumLength], - transformIntegrityAlgorithm) - if err != nil { - ikeLog.Errorf("Calculating checksum failed: %+v", err) - return errors.New("Error calculating checksum") - } - checksumField := sk.EncryptedData[len(sk.EncryptedData)-checksumLength:] - copy(checksumField, checksumOfMessage) - - return nil -} - -// Get information of algorithm -func getKeyLength(transformType uint8, transformID uint16, attributePresent bool, - attributeValue uint16, -) (int, bool) { - switch transformType { - case message.TypeEncryptionAlgorithm: - switch transformID { - case message.ENCR_DES_IV64: - return 0, false - case message.ENCR_DES: - return 8, true - case message.ENCR_3DES: - return 24, true - case message.ENCR_RC5: - return 0, false - case message.ENCR_IDEA: - return 0, false - case message.ENCR_CAST: - if attributePresent { - switch attributeValue { - case 128: - return 16, true - case 256: - return 0, false - default: - return 0, false - } - } - return 0, false - case message.ENCR_BLOWFISH: // Blowfish support variable key length - if attributePresent { - if attributeValue < 40 { - return 0, false - } else if attributeValue > 448 { - return 0, false - } else { - return int(attributeValue / 8), true - } - } else { - return 0, false - } - case message.ENCR_3IDEA: - return 0, false - case message.ENCR_DES_IV32: - return 0, false - case message.ENCR_NULL: - return 0, true - case message.ENCR_AES_CBC: - if attributePresent { - switch attributeValue { - case 128: - return 16, true - case 192: - return 24, true - case 256: - return 32, true - default: - return 0, false - } - } else { - return 0, false - } - case message.ENCR_AES_CTR: - if attributePresent { - switch attributeValue { - case 128: - return 20, true - case 192: - return 28, true - case 256: - return 36, true - default: - return 0, false - } - } else { - return 0, false - } - default: - return 0, false - } - case message.TypePseudorandomFunction: - switch transformID { - case message.PRF_HMAC_MD5: - return 16, true - case message.PRF_HMAC_SHA1: - return 20, true - case message.PRF_HMAC_SHA2_256: - return 32, true - case message.PRF_HMAC_TIGER: - return 0, false - default: - return 0, false - } - case message.TypeIntegrityAlgorithm: - switch transformID { - case message.AUTH_NONE: - return 0, false - case message.AUTH_HMAC_MD5_96: - return 16, true - case message.AUTH_HMAC_SHA1_96: - return 20, true - case message.AUTH_DES_MAC: - return 0, false - case message.AUTH_KPDK_MD5: - return 0, false - case message.AUTH_AES_XCBC_96: - return 0, false - case message.AUTH_HMAC_SHA2_256_128: - return 32, true - default: - return 0, false - } - case message.TypeDiffieHellmanGroup: - switch transformID { - case message.DH_NONE: - return 0, false - case message.DH_768_BIT_MODP: - return 0, false - case message.DH_1024_BIT_MODP: - return 0, false - case message.DH_1536_BIT_MODP: - return 0, false - case message.DH_2048_BIT_MODP: - return 0, false - case message.DH_3072_BIT_MODP: - return 0, false - case message.DH_4096_BIT_MODP: - return 0, false - case message.DH_6144_BIT_MODP: - return 0, false - case message.DH_8192_BIT_MODP: - return 0, false - default: - return 0, false - } - default: - return 0, false - } -} - -func getOutputLength( - transformType uint8, - transformID uint16, - attributePresent bool, - attributeValue uint16, -) (int, bool) { - _ = attributePresent - _ = attributeValue - - switch transformType { - case message.TypePseudorandomFunction: - switch transformID { - case message.PRF_HMAC_MD5: - return 16, true - case message.PRF_HMAC_SHA1: - return 20, true - case message.PRF_HMAC_TIGER: - return 0, false - case message.PRF_HMAC_SHA2_256: - return 32, true - default: - return 0, false - } - case message.TypeIntegrityAlgorithm: - switch transformID { - case message.AUTH_NONE: - return 0, false - case message.AUTH_HMAC_MD5_96: - return 12, true - case message.AUTH_HMAC_SHA1_96: - return 12, true - case message.AUTH_DES_MAC: - return 0, false - case message.AUTH_KPDK_MD5: - return 0, false - case message.AUTH_AES_XCBC_96: - return 0, false - case message.AUTH_HMAC_SHA2_256_128: - return 16, true - default: - return 0, false - } - default: - return 0, false - } -} - -func concatenateNonceAndSPI(nonce []byte, SPI_initiator uint64, SPI_responder uint64) []byte { - spi := make([]byte, 8) - - binary.BigEndian.PutUint64(spi, SPI_initiator) - newSlice := append(nonce, spi...) - binary.BigEndian.PutUint64(spi, SPI_responder) - newSlice = append(newSlice, spi...) - - return newSlice -} diff --git a/pkg/ike/security/security_test.go b/pkg/ike/security/security_test.go deleted file mode 100644 index ad6a0007..00000000 --- a/pkg/ike/security/security_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package security - -import ( - "encoding/hex" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/free5gc/n3iwf/pkg/ike/message" -) - -func TestVerifyIntegrity(t *testing.T) { - tests := []struct { - name string - key string - originData []byte - checksum string - transform *message.Transform - expectedValid bool - }{ - { - name: "HMAC MD5 96 - valid", - key: "0123456789abcdef0123456789abcdef", - originData: []byte("hello world"), - checksum: "c30f366e411540f68221d04a", - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_MD5_96, - }, - expectedValid: true, - }, - { - name: "HMAC MD5 96 - invalid checksum", - key: "0123456789abcdef", - originData: []byte("hello world"), - checksum: "01231875aa", - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_MD5_96, - }, - expectedValid: false, - }, - { - name: "HMAC MD5 96 - invalid key length", - key: "0123", - originData: []byte("hello world"), - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_MD5_96, - }, - expectedValid: false, - }, - { - name: "HMAC SHA1 96 - valid", - key: "0123456789abcdef0123456789abcdef01234567", - originData: []byte("hello world"), - checksum: "5089f6a86e4dafb89e3fcd23", - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_SHA1_96, - }, - expectedValid: true, - }, - { - name: "HMAC SHA1 96 - invalid checksum", - key: "0123456789abcdef0123456789abcdef01234567", - originData: []byte("hello world"), - checksum: "01231875aa", - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_SHA1_96, - }, - expectedValid: false, - }, - { - name: "HMAC SHA1 96 - invalid key length", - key: "0123", - originData: []byte("hello world"), - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_SHA1_96, - }, - expectedValid: false, - }, - { - name: "HMAC SHA256 128 - valid", - key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - originData: []byte("hello world"), - checksum: "a64166565bc1f48eb3edd4109fcaeb72", - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_SHA2_256_128, - }, - expectedValid: true, - }, - { - name: "HMAC SHA256 128 - invalid checksum", - key: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - originData: []byte("hello world"), - checksum: "01231875aa", - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_SHA1_96, - }, - expectedValid: false, - }, - { - name: "HMAC SHA256 128 - invalid key length", - key: "0123", - originData: []byte("hello world"), - transform: &message.Transform{ - TransformType: message.TypeIntegrityAlgorithm, - TransformID: message.AUTH_HMAC_SHA1_96, - }, - expectedValid: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var key, checksum []byte - var err error - checksum, err = hex.DecodeString(tt.checksum) - require.NoError(t, err, "failed to decode checksum hex string") - - key, err = hex.DecodeString(tt.key) - require.NoError(t, err, "failed to decode key hex string") - - valid, err := verifyIntegrity(key, tt.originData, checksum, tt.transform) - if tt.expectedValid { - require.NoError(t, err, "verifyIntegrity returned an error") - } - require.Equal(t, tt.expectedValid, valid) - }) - } -} diff --git a/pkg/ike/send.go b/pkg/ike/send.go deleted file mode 100644 index c7612d8d..00000000 --- a/pkg/ike/send.go +++ /dev/null @@ -1,109 +0,0 @@ -package ike - -import ( - "encoding/binary" - "net" - - "github.com/free5gc/n3iwf/internal/logger" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" - "github.com/free5gc/n3iwf/pkg/ike/message" - "github.com/free5gc/n3iwf/pkg/ike/security" -) - -func SendIKEMessageToUE( - udpConn *net.UDPConn, - srcAddr, dstAddr *net.UDPAddr, - message *message.IKEMessage, -) { - ikeLog := logger.IKELog - ikeLog.Trace("Send IKE message to UE") - ikeLog.Trace("Encoding...") - pkt, err := message.Encode() - if err != nil { - ikeLog.Errorln(err) - return - } - // As specified in RFC 7296 section 3.1, the IKE message send from/to UDP port 4500 - // should prepend a 4 bytes zero - if srcAddr.Port == 4500 { - prependZero := make([]byte, 4) - pkt = append(prependZero, pkt...) - } - - ikeLog.Trace("Sending...") - n, err := udpConn.WriteToUDP(pkt, dstAddr) - if err != nil { - ikeLog.Error(err) - return - } - if n != len(pkt) { - ikeLog.Errorf("Not all of the data is sent. Total length: %d. Sent: %d.", len(pkt), n) - return - } -} - -func SendUEInformationExchange( - ikeUe *n3iwf_context.N3IWFIkeUe, - payload message.IKEPayloadContainer, -) { - ikeLog := logger.IKELog - ikeSA := ikeUe.N3IWFIKESecurityAssociation - responseIKEMessage := new(message.IKEMessage) - - // Build IKE message - responseIKEMessage.BuildIKEHeader( - ikeSA.RemoteSPI, ikeSA.LocalSPI, - message.INFORMATIONAL, 0, - ikeSA.ResponderMessageID) - if payload != nil { // This message isn't a DPD message - err := security.EncryptProcedure( - ikeSA, payload, responseIKEMessage) - if err != nil { - ikeLog.Errorf("Encrypting IKE message failed: %+v", err) - return - } - } - SendIKEMessageToUE( - ikeUe.IKEConnection.Conn, ikeUe.IKEConnection.N3IWFAddr, - ikeUe.IKEConnection.UEAddr, responseIKEMessage) -} - -func SendIKEDeleteRequest(n3iwfCtx *n3iwf_context.N3IWFContext, localSPI uint64) { - ikeLog := logger.IKELog - ikeUe, ok := n3iwfCtx.IkeUePoolLoad(localSPI) - if !ok { - ikeLog.Errorf("Cannot get IkeUE from SPI : %+v", localSPI) - return - } - - var deletePayload message.IKEPayloadContainer - deletePayload.BuildDeletePayload(message.TypeIKE, 0, 0, nil) - SendUEInformationExchange(ikeUe, deletePayload) -} - -func SendChildSADeleteRequest( - ikeUe *n3iwf_context.N3IWFIkeUe, - relaseList []int64, -) { - ikeLog := logger.IKELog - var deleteSPIs []byte - spiLen := uint16(0) - for _, releaseItem := range relaseList { - for _, childSA := range ikeUe.N3IWFChildSecurityAssociation { - if childSA.PDUSessionIds[0] == releaseItem { - spiByte := make([]byte, 4) - binary.BigEndian.PutUint32(spiByte, uint32(childSA.XfrmStateList[0].Spi)) - deleteSPIs = append(deleteSPIs, spiByte...) - spiLen += 1 - err := ikeUe.DeleteChildSA(childSA) - if err != nil { - ikeLog.Errorf("Delete Child SA error : %v", err) - } - } - } - } - - var deletePayload message.IKEPayloadContainer - deletePayload.BuildDeletePayload(message.TypeESP, 4, spiLen, deleteSPIs) - SendUEInformationExchange(ikeUe, deletePayload) -} diff --git a/pkg/ike/server.go b/pkg/ike/server.go deleted file mode 100644 index 5b8ab7c4..00000000 --- a/pkg/ike/server.go +++ /dev/null @@ -1,177 +0,0 @@ -package ike - -import ( - "context" - "net" - "runtime/debug" - "sync" - - "github.com/pkg/errors" - - "github.com/free5gc/n3iwf/internal/logger" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" - "github.com/free5gc/n3iwf/pkg/factory" -) - -var ( - RECEIVE_IKEPACKET_CHANNEL_LEN = 512 - RECEIVE_IKEEVENT_CHANNEL_LEN = 512 -) - -type n3iwf interface { - Config() *factory.Config - Context() *n3iwf_context.N3IWFContext - CancelContext() context.Context - NgapEvtCh() chan n3iwf_context.NgapEvt -} - -type Server struct { - n3iwf - - Listener map[int]*net.UDPConn - RcvIkePktCh chan IkeReceivePacket - RcvEventCh chan n3iwf_context.IkeEvt - StopServer chan struct{} -} - -type IkeReceivePacket struct { - Listener net.UDPConn - LocalAddr net.UDPAddr - RemoteAddr net.UDPAddr - Msg []byte -} - -func NewServer(n3iwf n3iwf) (*Server, error) { - s := &Server{ - n3iwf: n3iwf, - Listener: make(map[int]*net.UDPConn), - RcvIkePktCh: make(chan IkeReceivePacket, RECEIVE_IKEPACKET_CHANNEL_LEN), - RcvEventCh: make(chan n3iwf_context.IkeEvt, RECEIVE_IKEEVENT_CHANNEL_LEN), - StopServer: make(chan struct{}), - } - return s, nil -} - -func (s *Server) Run(wg *sync.WaitGroup) error { - cfg := s.Config() - - // Resolve UDP addresses - ip := cfg.GetIKEBindAddr() - udpAddrPort500, err := net.ResolveUDPAddr("udp", ip+":500") - if err != nil { - return errors.Wrapf(err, "ResolveUDPAddr (%s:500)", ip) - } - udpAddrPort4500, err := net.ResolveUDPAddr("udp", ip+":4500") - if err != nil { - return errors.Wrapf(err, "ResolveUDPAddr (%s:4500)", ip) - } - - // Listen and serve - var errChan chan error - - // Port 500 - wg.Add(1) - errChan = make(chan error) - go s.receiver(udpAddrPort500, errChan, wg) - if err, ok := <-errChan; ok { - return errors.Wrapf(err, "udp 500") - } - - // Port 4500 - wg.Add(1) - errChan = make(chan error) - go s.receiver(udpAddrPort4500, errChan, wg) - if err, ok := <-errChan; ok { - return errors.Wrapf(err, "udp 4500") - } - - wg.Add(1) - go s.server(wg) - - return nil -} - -func (s *Server) server(wg *sync.WaitGroup) { - ikeLog := logger.IKELog - defer func() { - if p := recover(); p != nil { - // Print stack for panic to log. Fatalf() will let program exit. - ikeLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) - } - ikeLog.Infof("Ike server stopped") - close(s.RcvIkePktCh) - close(s.StopServer) - wg.Done() - }() - - for { - select { - case rcvPkt := <-s.RcvIkePktCh: - ikeLog.Tracef("Receive IKE packet") - s.Dispatch(&rcvPkt.Listener, &rcvPkt.LocalAddr, &rcvPkt.RemoteAddr, rcvPkt.Msg) - case rcvIkeEvent := <-s.RcvEventCh: - s.HandleEvent(rcvIkeEvent) - case <-s.StopServer: - return - } - } -} - -func (s *Server) receiver( - localAddr *net.UDPAddr, - errChan chan<- error, - wg *sync.WaitGroup, -) { - ikeLog := logger.IKELog - defer func() { - if p := recover(); p != nil { - // Print stack for panic to log. Fatalf() will let program exit. - ikeLog.Fatalf("panic: %v\n%s", p, string(debug.Stack())) - } - ikeLog.Infof("Ike receiver stopped") - wg.Done() - }() - - listener, err := net.ListenUDP("udp", localAddr) - if err != nil { - ikeLog.Errorf("Listen UDP failed: %+v", err) - errChan <- errors.New("listenAndServe failed") - return - } - - close(errChan) - - s.Listener[localAddr.Port] = listener - - data := make([]byte, 65535) - - for { - n, remoteAddr, err := listener.ReadFromUDP(data) - if err != nil { - ikeLog.Errorf("ReadFromUDP failed: %+v", err) - return - } - - forwardData := make([]byte, n) - copy(forwardData, data[:n]) - s.RcvIkePktCh <- IkeReceivePacket{ - RemoteAddr: *remoteAddr, - Listener: *listener, - LocalAddr: *localAddr, - Msg: forwardData, - } - } -} - -func (s *Server) Stop() { - ikeLog := logger.IKELog - ikeLog.Infof("Close Ike server...") - - for _, ikeServerListener := range s.Listener { - if err := ikeServerListener.Close(); err != nil { - ikeLog.Errorf("Stop ike server : %s error : %+v", err, ikeServerListener.LocalAddr().String()) - } - } - - s.StopServer <- struct{}{} -} diff --git a/pkg/ike/server_test.go b/pkg/ike/server_test.go deleted file mode 100644 index a0c57030..00000000 --- a/pkg/ike/server_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package ike - -import ( - "context" - "sync" - - "github.com/free5gc/n3iwf/internal/ngap" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" - "github.com/free5gc/n3iwf/pkg/factory" -) - -type n3iwfTestApp struct { - cfg *factory.Config - n3iwfCtx *n3iwf_context.N3IWFContext - ngapServer *ngap.Server - ikeServer *Server - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup -} - -func (a *n3iwfTestApp) Config() *factory.Config { - return a.cfg -} - -func (a *n3iwfTestApp) Context() *n3iwf_context.N3IWFContext { - return a.n3iwfCtx -} - -func (a *n3iwfTestApp) CancelContext() context.Context { - return a.ctx -} - -func (a *n3iwfTestApp) NgapEvtCh() chan n3iwf_context.NgapEvt { - return a.ngapServer.RcvEventCh -} - -func NewN3iwfTestApp(cfg *factory.Config) (*n3iwfTestApp, error) { - var err error - ctx, cancel := context.WithCancel(context.Background()) - - n3iwfApp := &n3iwfTestApp{ - cfg: cfg, - ctx: ctx, - cancel: cancel, - wg: &sync.WaitGroup{}, - } - - n3iwfApp.n3iwfCtx, err = n3iwf_context.NewTestContext(n3iwfApp) - if err != nil { - return nil, err - } - return n3iwfApp, err -} diff --git a/pkg/service/init.go b/pkg/service/init.go index 5d16f87a..fa79750e 100644 --- a/pkg/service/init.go +++ b/pkg/service/init.go @@ -13,15 +13,15 @@ import ( "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + n3iwf_context "github.com/free5gc/n3iwf/internal/context" + "github.com/free5gc/n3iwf/internal/ike" + "github.com/free5gc/n3iwf/internal/ike/xfrm" "github.com/free5gc/n3iwf/internal/logger" "github.com/free5gc/n3iwf/internal/ngap" "github.com/free5gc/n3iwf/internal/nwucp" "github.com/free5gc/n3iwf/internal/nwuup" "github.com/free5gc/n3iwf/pkg/app" - n3iwf_context "github.com/free5gc/n3iwf/pkg/context" "github.com/free5gc/n3iwf/pkg/factory" - "github.com/free5gc/n3iwf/pkg/ike" - "github.com/free5gc/n3iwf/pkg/ike/xfrm" ) var N3IWF *N3iwfApp @@ -106,12 +106,13 @@ func (a *N3iwfApp) SetLogEnable(enable bool) { func (a *N3iwfApp) SetLogLevel(level string) { lvl, err := logrus.ParseLevel(level) + mainLog := logger.MainLog if err != nil { - logger.MainLog.Warnf("Log level [%s] is invalid", level) + mainLog.Warnf("Log level [%s] is invalid", level) return } - logger.MainLog.Infof("Log level is set to [%s]", level) + mainLog.Infof("Log level is set to [%s]", level) if lvl == logger.Log.GetLevel() { return } @@ -134,6 +135,7 @@ func (a *N3iwfApp) Run() error { if err := a.initDefaultXfrmInterface(); err != nil { return err } + mainLog := logger.MainLog a.wg.Add(1) go a.listenShutdownEvent() @@ -142,28 +144,28 @@ func (a *N3iwfApp) Run() error { if err := a.ngapServer.Run(&a.wg); err != nil { return errors.Wrapf(err, "Run()") } - logger.MainLog.Infof("NGAP service running.") + mainLog.Infof("NGAP service running.") // Relay listeners // Control plane if err := a.nwucpServer.Run(&a.wg); err != nil { return errors.Wrapf(err, "Listen NWu control plane traffic failed") } - logger.MainLog.Infof("NAS TCP server successfully started.") + mainLog.Infof("NAS TCP server successfully started.") - // User plane + // User plane of N3IWF if err := a.nwuupServer.Run(&a.wg); err != nil { return errors.Wrapf(err, "Listen NWu user plane traffic failed") } - logger.MainLog.Infof("Listening NWu user plane traffic") + mainLog.Infof("Listening NWu user plane traffic") // IKE if err := a.ikeServer.Run(&a.wg); err != nil { return errors.Wrapf(err, "Start IKE service failed") } - logger.MainLog.Infof("IKE service running") + mainLog.Infof("IKE service running") - logger.MainLog.Infof("N3IWF started") + mainLog.Infof("N3IWF started") a.WaitRoutineStopped() return nil @@ -201,26 +203,27 @@ func (a *N3iwfApp) initDefaultXfrmInterface() error { var err error n3iwfCtx := a.n3iwfCtx cfg := a.Config() + mainLog := logger.MainLog n3iwfIPAddr := net.ParseIP(cfg.GetIPSecGatewayAddr()).To4() - n3iwfIPAddrAndSubnet := net.IPNet{IP: n3iwfIPAddr, Mask: n3iwfCtx.UeIPRange.Mask} + n3iwfIPAddrAndSubnet := net.IPNet{IP: n3iwfIPAddr, Mask: n3iwfCtx.IPSecInnerIPPool.IPSubnet.Mask} newXfrmiName := fmt.Sprintf("%s-default", cfg.GetXfrmIfaceName()) if linkIPSec, err = xfrm.SetupIPsecXfrmi(newXfrmiName, n3iwfCtx.XfrmParentIfaceName, cfg.GetXfrmIfaceId(), n3iwfIPAddrAndSubnet); err != nil { - logger.MainLog.Errorf("Setup XFRM interface %s fail: %+v", newXfrmiName, err) + mainLog.Errorf("Setup XFRM interface %s fail: %+v", newXfrmiName, err) return err } route := &netlink.Route{ LinkIndex: linkIPSec.Attrs().Index, - Dst: n3iwfCtx.UeIPRange, + Dst: n3iwfCtx.IPSecInnerIPPool.IPSubnet, } if err := netlink.RouteAdd(route); err != nil { - logger.MainLog.Warnf("netlink.RouteAdd: %+v", err) + mainLog.Warnf("netlink.RouteAdd: %+v", err) } - logger.MainLog.Infof("Setup XFRM interface %s ", newXfrmiName) + mainLog.Infof("Setup XFRM interface %s ", newXfrmiName) n3iwfCtx.XfrmIfaces.LoadOrStore(cfg.GetXfrmIfaceId(), linkIPSec) n3iwfCtx.XfrmIfaceIdOffsetForUP = 1 @@ -229,16 +232,18 @@ func (a *N3iwfApp) initDefaultXfrmInterface() error { } func (a *N3iwfApp) removeIPsecInterfaces() { + mainLog := logger.MainLog a.n3iwfCtx.XfrmIfaces.Range( func(key, value interface{}) bool { iface := value.(netlink.Link) if err := netlink.LinkDel(iface); err != nil { - logger.MainLog.Errorf("Delete interface %s fail: %+v", iface.Attrs().Name, err) + mainLog.Errorf("Delete interface %s fail: %+v", iface.Attrs().Name, err) } else { - logger.MainLog.Infof("Delete interface: %s", iface.Attrs().Name) + mainLog.Infof("Delete interface: %s", iface.Attrs().Name) } return true - }) + }, + ) } func (a *N3iwfApp) Terminate() { @@ -249,18 +254,15 @@ func (a *N3iwfApp) terminateProcedure() { logger.MainLog.Info("Stopping service created by N3IWF") a.ngapServer.Stop() - a.nwucpServer.Stop() - a.nwuupServer.Stop() - a.ikeServer.Stop() } -func (a *N3iwfApp) NgapEvtCh() chan n3iwf_context.NgapEvt { - return a.ngapServer.RcvEventCh +func (a *N3iwfApp) SendNgapEvt(evt n3iwf_context.NgapEvt) { + a.ngapServer.SendNgapEvt(evt) } -func (a *N3iwfApp) IkeEvtCh() chan n3iwf_context.IkeEvt { - return a.ikeServer.RcvEventCh +func (a *N3iwfApp) SendIkeEvt(evt n3iwf_context.IkeEvt) { + a.ikeServer.SendIkeEvt(evt) }