-
Notifications
You must be signed in to change notification settings - Fork 34
/
wg-mux-client
executable file
·392 lines (342 loc) · 16.3 KB
/
wg-mux-client
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
#!/usr/bin/env python3
import os, sys, io, logging, contextlib, asyncio, socket, signal, time
import math, hashlib, secrets, struct, base64, ipaddress, textwrap
import pathlib as pl, subprocess as sp
class WGMuxConfig:
    """Built-in defaults for the mux client; main() uses an instance as `conf`."""

    # Personalization bytes mixed into the blake2s hash of --auth-secret.
    auth_salt = b'wg-mux-1'

    # Defaults for -n/--attempts, -m/--mux-port and -t/--timeout.
    mux_attempts = 6
    mux_port = 8739
    mux_timeout = 10.0

    # Defaults for --wg-iface and --wg-keepalive.
    wg_iface = 'wg'
    wg_keepalive = 10

    # Exit code returned by main() when --ping-cmd reports a failure.
    ping_err_code = 133
def err_fmt(err):
    """Format an exception as "[ClassName] message" for log/error output."""
    return '[{}] {}'.format(err.__class__.__name__, err)
class LogMessage:
    """Log message that applies str.format() only when it is rendered.

    Holds the format string `fmt` with positional args `a` and keyword
    args `k`; formatting is deferred until str() is called on the object.
    """

    def __init__(self, fmt, a, k):
        self.fmt = fmt
        self.a = a
        self.k = k

    def __str__(self):
        # Without any arguments the template is returned verbatim,
        # so literal braces in a plain message are not interpreted.
        if not (self.a or self.k):
            return self.fmt
        return self.fmt.format(*self.a, **self.k)
class LogStyleAdapter(logging.LoggerAdapter):
    """LoggerAdapter that accepts str.format()-style ("{}") log messages."""

    def __init__(self, logger, extra=None):
        super().__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        """Log `msg` at `level`, deferring "{}" formatting via LogMessage."""
        if not self.isEnabledFor(level):
            return
        # Only exc_info is forwarded to the logger itself; every other
        # keyword ends up as a format() argument of the message.
        log_kws = {}
        if 'exc_info' in kws:
            log_kws['exc_info'] = kws.pop('exc_info')
        msg, kws = self.process(msg, kws)
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    """Return the named stdlib logger wrapped in a LogStyleAdapter."""
    return LogStyleAdapter(logging.getLogger(name))


# Module-wide logger for everything outside the UDP protocol class.
log = get_logger('mux-client.main')
def b64_encode(s):
    """Encode bytes as a standard-alphabet base64 str."""
    return base64.standard_b64encode(s).decode()


def b64_decode(s):
    """Decode a base64 str, using the urlsafe alphabet if "-" or "_" is present."""
    urlsafe = '-' in s or '_' in s
    decoder = base64.urlsafe_b64decode if urlsafe else base64.standard_b64decode
    return decoder(s)


def to_bytes(s):
    """Return bytes unchanged; encode the str() of anything else."""
    if isinstance(s, bytes):
        return s
    return str(s).encode()
def str_part(s, sep, default=None):
    """Split `s` in two on a separator, filling the missing part with `default`.

    `sep` is the separator text with a "<" or ">" marker attached: "<" means
    split on the first occurrence and put `default` in the left slot when the
    separator is absent, otherwise the last occurrence is used and `default`
    goes in the right slot.

    Examples: str_part("user@host", "<@", "root"), str_part("host:port", ":>")
    """
    c = sep.strip('<>')
    found = c in s
    if sep.strip(c) == '<':
        return s.split(c, 1) if found else (default, s)
    return s.rsplit(c, 1) if found else (s, default)
def sockopt_resolve(prefix, v):
    """Map a socket-module constant value back to its name, minus the prefix.

    Scans dir(socket) for the first name starting with the upper-cased
    `prefix` whose value equals `v` and returns the rest of that name
    (e.g. "INET" for prefix "af_"). Returns `v` itself if nothing matches.
    """
    prefix = prefix.upper()
    matches = (
        name for name in dir(socket)
        if name.startswith(prefix) and getattr(socket, name) == v)
    name = next(matches, None)
    return v if name is None else name[len(prefix):]
def bin_pack(fmt, *args):
    """Extend struct.pack with a "z" code for auto-length bytes.

    Each "z" in `fmt` is replaced by "<len>s", with lengths taken in order
    from the bytes values found in `args`.
    """
    sizes = [len(arg) for arg in args if isinstance(arg, bytes)]
    struct_fmt = fmt.replace('z', '{}s').format(*sizes)
    return struct.pack(struct_fmt, *args)
def retries_within_timeout(
        tries, timeout,
        backoff_func=lambda e, n: ((e**n - 1) / e), slack=1e-2):
    """Return a list of delays to make exactly `tries` tries within `timeout`.

    Bisects for a backoff base `m` in [0, timeout] such that the sum of
    backoff_func(m, n) for n in range(tries) is within `slack` of `timeout`.

    Args:
        tries: number of delays to produce.
        timeout: total time the delays should add up to.
        backoff_func: callable(base, n) returning the n-th delay.
        slack: acceptable absolute error of the sum.

    Returns:
        List of `tries` delays. If no base inside the bracket can reach the
        target (e.g. tries < 3 with the default backoff), the closest
        achievable delays are returned instead of searching forever.
    """
    a, b = 0, timeout
    while True:
        m = (a + b) / 2
        delays = [backoff_func(m, n) for n in range(tries)]
        error = sum(delays) - timeout
        if abs(error) < slack:
            return delays
        # Midpoint no longer lies strictly inside the bracket: float
        # resolution is exhausted, so the target sum is unreachable.
        if m <= a or m >= b:
            return delays
        if error > 0:
            b = m
        else:
            a = m
class MuxClientProtocol:
    """Datagram protocol that queues received UDP payloads for a consumer.

    Every received datagram is put on the `responses` queue as bytes;
    a None item is queued when the connection is lost.
    """

    transport = None

    def __init__(self, loop):
        # `loop` stays in the signature for existing callers, but is not
        # passed to asyncio.Queue - that parameter was removed in
        # Python 3.10 and raises TypeError there.
        self.responses = asyncio.Queue()
        self.log = get_logger('mux-client.udp')

    def connection_made(self, transport):
        self.log.debug('Connection made')
        self.transport = transport

    def datagram_received(self, data, addr):
        self.log.debug('Received {:,d}B from {!r}', len(data), addr)
        self.responses.put_nowait(data)

    def error_received(self, err):
        # Errors are only logged here; nothing is queued for them.
        self.log.debug('Network error: {}', err)

    def connection_lost(self, err):
        self.log.debug('Connection lost: {}', err or 'done')
        # None is the end-of-stream marker for whoever reads the queue.
        self.responses.put_nowait(None)
class AuthError(Exception):
    """Raised when a response fails MAC verification."""
def build_req(auth_secret, ident, wg_key_pk):
    """Build an authenticated request datagram.

    Layout: one byte with the ident length, the ident bytes, the decoded
    public key, a random 16-byte salt, then the keyed blake2b MAC computed
    over everything before the salt.
    """
    ident_buff = to_bytes(ident)
    pk_buff = b64_decode(wg_key_pk)
    salt = os.urandom(16)
    payload = bin_pack('>Bzz', len(ident_buff), ident_buff, pk_buff)
    mac = hashlib.blake2b(payload, key=auth_secret, salt=salt).digest()
    return b''.join([payload, salt, mac])
def parse_res(auth_secret, ident, res):
    """Parse and authenticate a response datagram.

    Expected layout (big-endian): ident length (1B), address kind (1B,
    4 or 6), port (2B), mask (1B), ident bytes, address (4 or 16 bytes),
    salt (16B) and a blake2b MAC (64B) keyed with `auth_secret` over
    everything before the salt.

    Args:
        auth_secret: key for the blake2b MAC.
        ident: not used for parsing; kept for interface compatibility.
        res: raw datagram bytes, may be empty or None.

    Returns:
        (wg_addr, wg_mask, wg_port) with wg_addr as an ipaddress object,
        or None if the datagram is empty, malformed or fails the MAC check.
    """
    if not res:
        return
    try:
        # The first two bytes give the sizes of the variable-length fields.
        # A datagram shorter than that raises IndexError, which is treated
        # like any other malformed input.
        fmt = '>BBHB{}s{}s'.format(res[0], {4: 4, 6: 16}[res[1]])
        (ident_len, ip_len, wg_port, wg_mask, ident_buff,
            wg_addr, salt, mac) = struct.unpack(fmt + '16s64s', res)
        mac_chk = hashlib.blake2b(
            res[:struct.calcsize(fmt)], key=auth_secret, salt=salt).digest()
        if not secrets.compare_digest(mac, mac_chk):
            raise AuthError('MAC mismatch')
    except (IndexError, KeyError, struct.error, AuthError) as err:
        log.debug('Failed to parse/auth response value: {}', err)
        return
    wg_addr = ipaddress.ip_address(wg_addr)
    return wg_addr, wg_mask, wg_port
async def mux_negotiate(
        loop, auth_secret, ident, wg_key_pk,
        sock_af, sock_p, host, port, delays):
    """Send the request to the mux server over UDP until a valid reply arrives.

    Args:
        loop: event loop used for timing and for creating the UDP endpoint.
        auth_secret: MAC key shared with the server.
        ident: identity of this node, sent in the request.
        wg_key_pk: base64-encoded public key sent in the request.
        sock_af, sock_p: socket family and protocol for the endpoint.
        host, port: remote address of the mux server.
        delays: list of waits (seconds) after each send; one very long
            final wait is appended so the last send can still be answered.

    Returns:
        The first truthy parse_res() result, i.e. (wg_addr, wg_mask, wg_port).
        Callers are expected to bound the total run time themselves.
    """
    req = build_req(auth_secret, ident, wg_key_pk)
    transport = proto = None
    try:
        for delay in delays + [2**30]:
            deadline = loop.time() + delay
            if not transport:
                # First send, or the previous endpoint was reported lost.
                transport, proto = await loop.create_datagram_endpoint(
                    lambda: MuxClientProtocol(loop),
                    remote_addr=(host, port), family=sock_af, proto=sock_p)
            transport.sendto(req)
            if delay:
                while True:
                    try:
                        response = await asyncio.wait_for(
                            proto.responses.get(), max(0, deadline - loop.time()))
                    except asyncio.TimeoutError:
                        break
                    if response is None:
                        # Connection lost - re-create it on the next attempt.
                        transport = proto = None
                        break
                    response = parse_res(auth_secret, ident, response)
                    if response:
                        return response
                # NOTE(review): nesting of this resend was reconstructed
                # from a flattened source - confirm against upstream.
                if transport:
                    transport.sendto(req)
            # No loop= argument: it was removed from asyncio.sleep()
            # in Python 3.10 and raises TypeError there.
            await asyncio.sleep(max(0, deadline - loop.time()))
    finally:
        if transport:
            transport.close()
def main(args=None, conf=None):
    """Negotiate WireGuard peer parameters with a mux server and apply them.

    Parses command-line options, derives the auth key and node ident,
    queries the mux server over UDP, then runs "wg set" and "ip addr add"
    (or only logs them with --dry-run). With --ping-cmd it then keeps
    running that check in a loop until it fails.

    Args:
        args: list of command-line arguments; sys.argv[1:] if None.
        conf: object with the WGMuxConfig attributes; a fresh WGMuxConfig
            is used if not given.

    Returns:
        None on success or when negotiation is cancelled/times out,
        1 if the returned address is outside --wg-net,
        conf.ping_err_code when --ping-cmd fails.
    """
    if not conf:
        conf = WGMuxConfig()

    import argparse
    parser = argparse.ArgumentParser(
        description='Wrapper for "wg set" + "ip addr" with peer/iface'
            ' config queried from remote server by some unique --ident-* info.')

    group = parser.add_argument_group('Mux server info')
    group.add_argument('host',
        help='Host or address (to be resolved via gai) or a host[:port] spec.'
            ' "port" will be used for -m/--mux-port option, if specified here.')
    group.add_argument('pubkey', help='Base64-encoded WireGuard server public key.')
    group.add_argument('pubkey_client',
        help='Base64-encoded public key of this client or a file with such.')
    group.add_argument('-m', '--mux-port',
        default=conf.mux_port, type=int, metavar='port',
        help='Remote UDP port on which corresponding'
            ' mux-server script is listening (default: %(default)s).'
            ' Can also be specified in the "host" argument, which overrides this option.')

    group = parser.add_argument_group('Auth/ident options')
    group.add_argument('-s', '--auth-secret', metavar='string',
        help='Any string to use as symmetric secret'
            ' to authenticate both sides on --mux-port (default: %(default)s).'
            ' Must be same for both mux-client and server scripts talking to each other.'
            ' Will be read from stdin, if not specified.')
    group.add_argument('-i', '--ident-string',
        metavar='string',
        help='Any string to use as this node identity -'
            ' i.e. serial number, mac/hw address, machine-id, etc.'
            ' Hash of /etc/machine-id contents is used, if not specified.'
            ' Overrides any other --ident-* option.')
    group.add_argument('--ident-rpi', action='store_true',
        help='Use hash of "Serial" from /proc/cpuinfo as ident.'
            ' Only available on Raspberry Pi boards.')
    group.add_argument('--ident-file', metavar='path',
        help='Read specified file contents as the ident string, stripping spaces.'
            ' E.g. /sys/class/dmi/id/board_serial can be used on x86 platforms, if accessible.')
    group.add_argument('--ident-cmd', metavar='shell-cmd',
        help='Shell command to run to get ident string on stdout.'
            ' Must exit with code 0, otherwise script will abort.'
            ' Resulting string be stripped of spaces, otherwise sent as-is,'
            ' so should be hashed in the command if necessary.')

    group = parser.add_argument_group('WireGuard options')
    group.add_argument('--wg-iface', metavar='iface', default=conf.wg_iface,
        help='WireGuard interface name to configure. Default: %(default)s')
    group.add_argument('--wg-port', type=int, metavar='port',
        help='Remote WireGuard port to use. Default is to use one returned by mux-server.')
    group.add_argument('--wg-psk', metavar='file',
        help='File with base64-encoded WireGuard pre-shared-secret key to use for connection.')
    group.add_argument('--wg-net', metavar='ip/mask', default='0.0.0.0/0',
        help='IP/mask network spec to check returned address against or map to.'
            ' This is required if mux-server is configured to only return last IP octet.')
    group.add_argument('--wg-keepalive', metavar='seconds', default=conf.wg_keepalive,
        help='WireGuard keepalive interval. Default: %(default)s')
    group.add_argument('--wg-cmd', metavar='cmd', default='wg',
        help='"wg" command to run, split on spaces.'
            ' Use e.g. "sudo wg" to have it run via sudo or full path'
            ' for a special binary with suid/capabilities. Default: %(default)s')
    group.add_argument('--ip-cmd', metavar='cmd', default='ip',
        help='"ip" command to use for assigning IP address to the interface.'
            ' Will be run as "ip addr add <addr/mask> dev <wg-iface>".'
            ' Wrapper can be used to do more stuff. Split on spaces. Default: %(default)s')

    group = parser.add_argument_group('Endless-run options')
    group.add_argument('-p', '--ping-cmd', metavar='cmd',
        help='Run specified ping-command in a loop after negotiating a tunnel,'
            f' exiting with code={conf.ping_err_code} if/when it exits with non-zero code.'
            ' Command and arguments are split on spaces.'
            ' Command can include ":" to load it as python'
            ' script and call function there with all other args in a list.'
            ' Examples: ping -q -w15 -c3 -i3 10.123.0.1, ident-client.py:main 10.123.0.1')
    group.add_argument('--ping-interval', type=float, metavar='seconds', default=600,
        help='Interval between running --ping-cmd in a loop until it fails. Default: %(default)s')
    group.add_argument('--ping-silent', action='store_true',
        help='Suppress stdout/stderr for --ping-cmd.')
    group.add_argument('--ping-systemd', action='store_true',
        help='Use systemd service notification/watchdog'
            ' mechanisms when running as a daemon, if enabled in systemd unit.'
            ' Service start notification is only issued after first successful ping command.'
            ' Requres python systemd module installed.')

    group = parser.add_argument_group('Misc options')
    group.add_argument('-n', '--attempts',
        type=int, metavar='n', default=conf.mux_attempts,
        help='Number of UDP packets to send to'
            ' --mux-port (to offset packet loss). Default: %(default)s')
    group.add_argument('-t', '--timeout',
        type=float, metavar='seconds', default=conf.mux_timeout,
        help='Negotiation response timeout on --mux-port, in seconds. Default: %(default)ss')
    group.add_argument('--dry-run', action='store_true',
        help='Do not change WireGuard configuration, only pretend to.')
    group.add_argument('--debug', action='store_true', help='Verbose operation mode.')

    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)

    # Shared secret: option value or raw stdin, hashed into a fixed-size key.
    if not opts.auth_secret:
        log.debug('Reading --auth-secret from stdin (exact value, incl. spaces and newlines).')
        with open(sys.stdin.fileno(), 'rb') as src:
            opts.auth_secret = src.read()
    if not opts.auth_secret:
        parser.error('No --auth-secret specified and none provided on stdin.')
    auth_secret = hashlib.blake2s(to_bytes(opts.auth_secret), person=conf.auth_salt).digest()

    # Node identity: explicit string wins, then --ident-* sources in order,
    # with a keyed hash of /etc/machine-id as the fallback.
    ident = opts.ident_string
    if not ident:
        if opts.ident_rpi:
            import re
            with open('/proc/cpuinfo') as src:
                for line in src:
                    m = re.search(r'^\s*Serial\s*:\s*(\S+)\s*$', line)
                    if m:
                        break
                else:
                    parser.error('Failed to find "Serial : ..." line in /proc/cpuinfo (non-RPi kernel?)')
            ident = hashlib.blake2s(m.group(1).encode(), key=auth_secret).digest()
        elif opts.ident_file:
            ident = pl.Path(opts.ident_file).read_bytes().strip()
            ident = hashlib.blake2s(ident, key=auth_secret).digest()
        elif opts.ident_cmd:
            # Output of --ident-cmd is sent as-is (not hashed), per its help text.
            ident = sp.run(opts.ident_cmd, shell=True, check=True, stdout=sp.PIPE)
            ident = ident.stdout.decode().strip()
        else:
            with open('/etc/machine-id', 'rb') as src:
                ident = hashlib.blake2s(src.read(), key=auth_secret).digest()

    # --ping-cmd is either an argv list or "script.py:func" loaded via runpy.
    ping_cmd = opts.ping_cmd
    if ping_cmd:
        ping_cmd = ping_cmd.split()
        if ':' in ping_cmd[0]:
            import runpy
            src, func = ping_cmd[0].rsplit(':', 1)
            p = pl.Path(src)
            if not p.exists():
                # Not a direct path - look the script up in PATH dirs.
                for p in os.environ.get('PATH', '').split(':'):
                    p = pl.Path(p) / src
                    if p.exists():
                        break
                else:
                    parser.error(f'Failed to find import path --ping-cmd: {ping_cmd[0]}')
            ping_cmd[0] = runpy.run_path(p, run_name='ping_cmd')[func]
            if not callable(ping_cmd[0]):
                parser.error(f'--ping-cmd spec not callable: {ping_cmd}')

    ping_sd = opts.ping_systemd
    if ping_sd:
        from systemd import daemon
        sd_pid, sd_usec = (os.environ.get(k) for k in ['WATCHDOG_PID', 'WATCHDOG_USEC'])
        if sd_pid and sd_pid.isdigit() and int(sd_pid) == os.getpid():
            sd_ping_interval = float(sd_usec) / 2e6 # half of interval in seconds
            if sd_ping_interval <= 0:
                parser.error('Passed WATCHDOG_USEC interval <= 0')
            log.debug('Initializing systemd watchdog pinger with interval: {:,.1f}s', sd_ping_interval)
        else:
            sd_ping_interval = None
        sd_ready, sd_ping_ts = False, sd_ping_interval and time.monotonic() + sd_ping_interval

    # Split optional port from the host spec; more than one ":" means an
    # IPv6 address, where a port can only follow a closing bracket.
    host, port_mux, family = opts.host, opts.mux_port, 0
    if host.count(':') > 1:
        host, port_mux = str_part(host, ']:>', port_mux)
    else:
        host, port_mux = str_part(host, ':>', port_mux)
    if '[' in host:
        family = socket.AF_INET6
    host, port_mux = host.strip('[]'), int(port_mux)
    try:
        addrinfo = socket.getaddrinfo(
            host, str(port_mux),
            family=family, type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP)
        if not addrinfo:
            raise socket.gaierror(f'No addrinfo for host: {host}')
    except (socket.gaierror, socket.error) as err:
        parser.error(
            'Failed to resolve socket parameters'
            ' via getaddrinfo: {!r} - {}'.format((host, port_mux), err_fmt(err)))
    sock_af, sock_t, sock_p, _, sock_addr = addrinfo[0]
    log.debug(
        'Resolved mux host:port {!r}:{!r} to endpoint: {} (family: {}, type: {}, proto: {})',
        host, port_mux, sock_addr,
        *(sockopt_resolve(pre, n) for pre, n in [
            ('af_', sock_af), ('sock_', sock_t), ('ipproto_', sock_p)]))
    host, port_mux = sock_addr[:2]

    wg_cmd, ip_cmd = opts.wg_cmd.split(), opts.ip_cmd.split()
    wg_net_check = ipaddress.ip_network(opts.wg_net)
    # pubkey_client is either a path to a key file or the key itself.
    wg_pk, wg_pk_client = opts.pubkey, pl.Path(opts.pubkey_client)
    if wg_pk_client.exists():
        wg_pk_client = wg_pk_client.read_text().strip()
    else:
        wg_pk_client = opts.pubkey_client
    wg_psk = list()
    if opts.wg_psk:
        wg_psk.extend(['preshared-key', opts.wg_psk])

    # One extra delay is computed and dropped, as mux_negotiate
    # appends its own final wait after the last send.
    retry_delays = retries_within_timeout(opts.attempts+1, opts.timeout)[:-1]
    # A fresh loop is created explicitly: asyncio.get_event_loop() without
    # a running loop is deprecated and can raise on recent Python versions.
    with contextlib.closing(asyncio.new_event_loop()) as loop:
        muxer = loop.create_task(mux_negotiate(
            loop, auth_secret, ident, wg_pk_client,
            sock_af, sock_p, host, port_mux, retry_delays))
        for sig in 'INT TERM'.split():
            loop.add_signal_handler(getattr(signal, f'SIG{sig}'), muxer.cancel)
        try:
            wg_addr, wg_mask, wg_port = \
                loop.run_until_complete(asyncio.wait_for(muxer, opts.timeout))
        except (asyncio.CancelledError, asyncio.TimeoutError) as err:
            log.debug('mux_negotiate cancelled: {}', err_fmt(err))
            return

    if opts.wg_port:
        wg_port = opts.wg_port
    if wg_addr not in wg_net_check:
        print(
            'ERROR: mux-server returned address'
            f' outside of specified --wg-net: {wg_addr}', file=sys.stderr)
        return 1
    # IPv6 endpoint addresses must be bracketed, otherwise the port
    # cannot be told apart from the address in "addr:port".
    wg_ep_host = f'[{host}]' if sock_af == socket.AF_INET6 else host
    wg_ep = f'{wg_ep_host}:{wg_port}'
    wg_net = ipaddress.ip_network(f'{wg_addr}/{wg_mask}', strict=False)
    log.debug(
        'Negotiated wg params: ep={} addr={}/{}'
        ' pubkey={}', wg_ep, wg_addr, wg_mask, wg_pk)

    cmd_peer = wg_cmd + [
        'set', opts.wg_iface, 'peer', wg_pk, 'allowed-ips', str(wg_net),
        'endpoint', wg_ep, 'persistent-keepalive', str(opts.wg_keepalive)] + wg_psk
    cmd_iface = ip_cmd + ['addr', 'add', f'{wg_addr}/{wg_mask}', 'dev', opts.wg_iface]
    if not opts.dry_run:
        sp.run(cmd_peer, check=True)
        # Address assignment is best-effort: exit code is not checked
        # and stderr is discarded.
        sp.run(cmd_iface, stderr=sp.DEVNULL)
    else:
        log.debug('Config for peer: {}', ' '.join(cmd_peer))
        log.debug('Config for iface: {}', ' '.join(cmd_iface))

    if ping_cmd:
        for sig in 'int term'.upper().split():
            signal.signal(getattr(signal, f'SIG{sig}'), lambda sig,frm: sys.exit(0))
        ping_cmd_call, ping_kws = callable(ping_cmd[0]), dict()
        if opts.ping_silent:
            ping_kws.update(stdout=sp.DEVNULL, stderr=sp.DEVNULL)
        while True:
            if not ping_cmd_call:
                code = sp.run(ping_cmd, **ping_kws).returncode
            else:
                code = ping_cmd[0](ping_cmd[1:])
            if code:
                print(
                    f'ERROR: --ping-cmd exited with code {code},'
                    f' exiting with code {conf.ping_err_code}', file=sys.stderr)
                return conf.ping_err_code
            # delay_sd becomes True when the watchdog deadline comes before
            # the next ping, so the sleep is cut short to notify systemd.
            delay, delay_sd = opts.ping_interval, None
            if ping_sd:
                if not sd_ready:
                    daemon.notify('READY=1')
                    daemon.notify(f'STATUS=Running with wg addr: {wg_addr}/{wg_mask}')
                    sd_ready = True
                if sd_ping_ts:
                    delay_sd = sd_ping_ts - time.monotonic()
                    if delay_sd <= delay:
                        delay, delay_sd = delay_sd, True
            if delay > 0:
                time.sleep(delay)
            if delay_sd is True:
                sd_ping_ts += sd_ping_interval
                daemon.notify('WATCHDOG=1')
# Script entry point: the return value of main() becomes the exit code.
if __name__ == '__main__':
    sys.exit(main())