Skip to content

Commit

Permalink
Use pure Lua endian conversion throughout
Browse files Browse the repository at this point in the history
Also fix a bug in lib.bitfield() to properly truncate the value to be
set to the size of the bitfield.
  • Loading branch information
alexandergall committed Oct 2, 2015
1 parent fadf244 commit 4665cdf
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 80 deletions.
71 changes: 36 additions & 35 deletions src/core/lib.lua
Original file line number Diff line number Diff line change
Expand Up @@ -171,41 +171,6 @@ function bitset (value, n)
return band(value, lshift(1, n)) ~= 0
end

-- Manipulation of bit fields in uint{8,16,32)_t stored in network
-- byte order. Using bit fields in C structs is compiler-dependent
-- and a little awkward for handling endianness and fields that cross
-- byte boundaries. We're bound to the LuaJIT compiler, so I guess
-- this would be save, but masking and shifting is guaranteed to be
-- portable. Performance could be an issue, though.

local bitfield_endian_conversion =
{ [16] = { ntoh = C.ntohs, hton = C.htons },
[32] = { ntoh = C.ntohl, hton = C.htonl }
}

function bitfield(size, struct, member, offset, nbits, value)
local conv = bitfield_endian_conversion[size]
local field
if conv then
field = conv.ntoh(struct[member])
else
field = struct[member]
end
local shift = size-(nbits+offset)
local mask = lshift(2^nbits-1, shift)
local imask = bnot(mask)
if value then
field = bor(band(field, imask), lshift(value, shift))
if conv then
struct[member] = conv.hton(field)
else
struct[member] = field
end
else
return rshift(band(field, mask), shift)
end
end

-- Iterator factory for splitting a string by pattern
-- (http://lua-users.org/lists/lua-l/2006-12/msg00414.html)
function string:split(pat)
Expand Down Expand Up @@ -410,6 +375,42 @@ function eq32 (a, b)
return band(a, 0xFFFFFFFF) == band(b, 0xFFFFFFFF)
end

-- Manipulation of bit fields in uint{8,16,32)_t stored in network
-- byte order. Using bit fields in C structs is compiler-dependent
-- and a little awkward for handling endianness and fields that cross
-- byte boundaries. We're bound to the LuaJIT compiler, so I guess
-- this would be save, but masking and shifting is guaranteed to be
-- portable.

local bitfield_endian_conversion =
{ [16] = { ntoh = ntohs, hton = htons },
[32] = { ntoh = ntohl, hton = htonl }
}

function bitfield(size, struct, member, offset, nbits, value)
local conv = bitfield_endian_conversion[size]
local field
if conv then
field = conv.ntoh(struct[member])
else
field = struct[member]
end
local shift = size-(nbits+offset)
local mask = lshift(2^nbits-1, shift)
local imask = bnot(mask)
if value then
field = bor(band(field, imask),
band(lshift(value, shift), mask))
if conv then
struct[member] = conv.hton(field)
else
struct[member] = field
end
else
return rshift(band(field, mask), shift)
end
end

-- Process ARGS using ACTIONS with getopt OPTS/LONG_OPTS.
-- Return the remaining unprocessed arguments.
function dogetopt (args, actions, opts, long_opts)
Expand Down
6 changes: 4 additions & 2 deletions src/lib/protocol/ethernet.lua
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module(..., package.seeall)
local ffi = require("ffi")
local C = ffi.C
local lib = require("core.lib")
local header = require("lib.protocol.header")
local ipv6 = require("lib.protocol.ipv6")
local band = require("bit").band
local ntohs, htons = lib.ntohs, lib.htons

local ether_header_t = ffi.typeof[[
struct {
Expand Down Expand Up @@ -116,9 +118,9 @@ end
function ethernet:type (t)
local h = self:header()
if t ~= nil then
h.ether_type = C.htons(t)
h.ether_type = htons(t)
else
return(C.ntohs(h.ether_type))
return(ntohs(h.ether_type))
end
end

Expand Down
16 changes: 9 additions & 7 deletions src/lib/protocol/gre.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ local header = require("lib.protocol.header")
local lib = require("core.lib")
local bitfield = lib.bitfield
local ipsum = require("lib.checksum").ipsum
local ntohs, htons, ntohl, htonl =
lib.ntohs, lib.htons, lib.ntohl, lib.htonl

-- GRE uses a variable-length header as specified by RFCs 2784 and
-- 2890. The actual size is determined by flag bits in the base
Expand Down Expand Up @@ -123,16 +125,16 @@ function gre:checksum (payload, length)
end
if payload ~= nil then
-- Calculate and set the checksum
self:header().csum = C.htons(checksum(self:header(), payload, length))
self:header().csum = htons(checksum(self:header(), payload, length))
end
return C.ntohs(self:header().csum)
return ntohs(self:header().csum)
end

function gre:checksum_check (payload, length)
if not self._checksum then
return true
end
return checksum(self:header(), payload, length) == C.ntohs(self:header().csum)
return checksum(self:header(), payload, length) == lib.ntohs(self:header().csum)
end

-- Returns nil if keying is disabled. Otherwise, the key is set to the
Expand All @@ -143,17 +145,17 @@ function gre:key (key)
return nil
end
if key ~= nil then
self:header().key = C.htonl(key)
self:header().key = htonl(key)
else
return C.ntohl(self:header().key)
return ntohl(self:header().key)
end
end

function gre:protocol (protocol)
if protocol ~= nil then
self:header().protocol = C.htons(protocol)
self:header().protocol = htons(protocol)
end
return(C.ntohs(self:header().protocol))
return(ntohs(self:header().protocol))
end

return gre
4 changes: 2 additions & 2 deletions src/lib/protocol/icmp/header.lua
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ end

function icmp:checksum (payload, length, ipv6)
local header = self:header()
header.checksum = C.htons(checksum(header, payload, length, ipv6))
header.checksum = lib.htons(checksum(header, payload, length, ipv6))
end

function icmp:checksum_check (payload, length, ipv6)
return checksum(self:header(), payload, length, ipv6) == C.ntohs(self:header().checksum)
return checksum(self:header(), payload, length, ipv6) == lib.ntohs(self:header().checksum)
end

return icmp
20 changes: 11 additions & 9 deletions src/lib/protocol/ipv4.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ local C = ffi.C
local lib = require("core.lib")
local header = require("lib.protocol.header")
local ipsum = require("lib.checksum").ipsum
local htons, ntohs, htonl, ntohl =
lib.htons, lib.ntohs, lib.htonl, lib.ntohl

-- TODO: generalize
local AF_INET = 2
Expand Down Expand Up @@ -54,7 +56,7 @@ ipv4._ulp = {

function ipv4:new (config)
local o = ipv4:superClass().new(self)
o:header().ihl_v_tos = C.htons(0x4000) -- v4
o:header().ihl_v_tos = htons(0x4000) -- v4
o:ihl(o:sizeof() / 4)
o:dscp(config.dscp or 0)
o:ecn(config.ecn or 0)
Expand Down Expand Up @@ -105,17 +107,17 @@ end

function ipv4:total_length (length)
if length ~= nil then
self:header().total_length = C.htons(length)
self:header().total_length = htons(length)
else
return(C.ntohs(self:header().total_length))
return(ntohs(self:header().total_length))
end
end

function ipv4:id (id)
if id ~= nil then
self:header().id = C.htons(id)
self:header().id = htons(id)
else
return(C.ntohs(self:header().id))
return(ntohs(self:header().id))
end
end

Expand Down Expand Up @@ -144,9 +146,9 @@ function ipv4:protocol (protocol)
end

function ipv4:checksum ()
self:header().checksum = C.htons(ipsum(ffi.cast("uint8_t *", self:header()),
self:sizeof(), 0))
return C.ntohs(self:header().checksum)
self:header().checksum = htons(ipsum(ffi.cast("uint8_t *", self:header()),
self:sizeof(), 0))
return ntohs(self:header().checksum)
end

function ipv4:src (ip)
Expand Down Expand Up @@ -191,7 +193,7 @@ function ipv4:pseudo_header (ulplen, proto)
local ph = ipv4hdr_pseudo_t()
local h = self:header()
ffi.copy(ph, h.src_ip, 2*ipv4_addr_t_size) -- Copy source and destination
ph.ulp_length = C.htons(ulplen)
ph.ulp_length = htons(ulplen)
ph.ulp_protocol = proto
return(ph)
end
Expand Down
7 changes: 4 additions & 3 deletions src/lib/protocol/ipv6.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ local ffi = require("ffi")
local C = ffi.C
local lib = require("core.lib")
local header = require("lib.protocol.header")
local htons, ntohs = lib.htons, lib.ntohs

local AF_INET6 = 10
local INET6_ADDRSTRLEN = 48
Expand Down Expand Up @@ -126,9 +127,9 @@ end

function ipv6:payload_length (length)
if length ~= nil then
self:header().payload_length = C.htons(length)
self:header().payload_length = htons(length)
else
return(C.ntohs(self:header().payload_length))
return(ntohs(self:header().payload_length))
end
end

Expand Down Expand Up @@ -182,7 +183,7 @@ function ipv6:pseudo_header (plen, nh)
ffi.fill(ph, ffi.sizeof(ph))
local h = self:header()
ffi.copy(ph, h.src_ip, 32) -- Copy source and destination
ph.ulp_length = C.htons(plen)
ph.ulp_length = htons(plen)
ph.next_header = nh
return(ph)
end
Expand Down
6 changes: 4 additions & 2 deletions src/lib/protocol/keyed_ipv6_tunnel.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ module(..., package.seeall)
local ffi = require("ffi")
local C = ffi.C
local header = require("lib.protocol.header")
local lib = require("core.lib")
local htonl, ntohl = lib.htonl, lib.ntohl

ffi.cdef[[
typedef union {
Expand Down Expand Up @@ -90,9 +92,9 @@ function tunnel:session_id (id)
local h = self:header()
if id ~= nil then
assert(id ~= 0, "invalid session id 0")
h.session_id = C.htonl(id)
h.session_id = htonl(id)
else
return C.ntohl(h.session_id)
return ntohl(h.session_id)
end
end

Expand Down
26 changes: 14 additions & 12 deletions src/lib/protocol/tcp.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ local C = ffi.C
local lib = require("core.lib")
local header = require("lib.protocol.header")
local ipsum = require("lib.checksum").ipsum
local ntohs, htons, ntohl, htonl =
lib.ntohs, lib.htons, lib.ntohl, lib.htonl

local tcp_header_t = ffi.typeof[[
struct {
Expand Down Expand Up @@ -55,33 +57,33 @@ end
function tcp:src_port (port)
local h = self:header()
if port ~= nil then
h.src_port = C.htons(port)
h.src_port = htons(port)
end
return C.ntohs(h.src_port)
return ntohs(h.src_port)
end

function tcp:dst_port (port)
local h = self:header()
if port ~= nil then
h.dst_port = C.htons(port)
h.dst_port = htons(port)
end
return C.ntohs(h.dst_port)
return ntohs(h.dst_port)
end

function tcp:seq_num (seq)
local h = self:header()
if seq ~= nil then
h.seq = C.htonl(seq)
h.seq = htonl(seq)
end
return C.ntohl(h.seq)
return ntohl(h.seq)
end

function tcp:ack_num (ack)
local h = self:header()
if ack ~= nil then
h.ack = C.htonl(ack)
h.ack = htonl(ack)
end
return C.ntohl(h.ack)
return ntohl(h.ack)
end

function tcp:offset (offset)
Expand Down Expand Up @@ -135,9 +137,9 @@ end
function tcp:window_size (window_size)
local h = self:header()
if window_size ~= nil then
h.window_size = C.htons(window_size)
h.window_size = htons(window_size)
end
return C.ntohs(h.window_size)
return ntohs(h.window_size)
end

function tcp:checksum (payload, length, ip)
Expand All @@ -154,9 +156,9 @@ function tcp:checksum (payload, length, ip)
csum = ipsum(ffi.cast("uint8_t *", h),
self:sizeof(), bit.bnot(csum))
-- Add TCP payload
h.checksum = C.htons(ipsum(payload, length, bit.bnot(csum)))
h.checksum = htons(ipsum(payload, length, bit.bnot(csum)))
end
return C.ntohs(h.checksum)
return ntohs(h.checksum)
end

-- override the default equality method
Expand Down
Loading

0 comments on commit 4665cdf

Please sign in to comment.