-
Notifications
You must be signed in to change notification settings - Fork 34
/
ssh-r-sync-recv
executable file
·349 lines (299 loc) · 14.5 KB
/
ssh-r-sync-recv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#!/usr/bin/env python3
import itertools as it, operator as op, functools as ft
import os, sys, contextlib, collections, subprocess, socket
import hashlib, math, json, base64, signal, time
import tempfile, pathlib as pl
def p_fmt(fmt, *a, **k):
    '''Build the list of values to print for a message.

    A str fmt is passed through str.format(*a, **k), but only if any
    of a/k are given, and is returned as a single-item list.
    Any other fmt is not formatted - returned list then has
    two items: the [fmt, *a] list and the k keywords dict.'''
    if not isinstance(fmt, str):
        return [[fmt, *a], k]
    if a or k:
        return [fmt.format(*a, **k)]
    return [fmt]


def p(fmt, *a, **k):
    'Print message built by p_fmt() to stderr, flushing it right away.'
    print(*p_fmt(fmt, *a, **k), file=sys.stderr, flush=True)
class SSHRsyncReceiverConfig:
    'Static settings for reverse-tunnel port selection and the handshake.'

    # recv_port_* settings should be same on server/client
    recv_port_bind = '127.0.0.1'  # address used to connect to the tunnel port
    recv_port_range = (22000, 22999)  # min/max of ports derived from name hash
    recv_port_retries = 4  # extra hash attempts, giving retries+1 allowed ports
    recv_port_hash_key = b'J\xbc\xed\xa3\xa5\x02E(\xde\xdc#h\xa4\xa5\xa4t'

    hs_hello = b'ssh-r-sync o/ 1'  # hello line exchanged over ssh pipe
    hs_timeout = 20  # timeout for handshake pipe/socket operations
    hs_key_size = 64  # expected length of keys exchanged over tunnel socket
class SSHRsyncError(Exception):
    '''Error with a message built from p_fmt()-style arguments,
    i.e. a format string followed by the values to put into it.'''

    def __init__(self, *args):
        msg_parts = p_fmt(*args)
        super().__init__(str(msg_parts[0]))
def hash_to_int(name, retry, n_max, **blake2_kws):
    '''Returns integer hash within n_max range from name/retry values.

    Hash input is one byte of "retry" followed by encoded "name",
    re-hashed with blake2b (using any blake2_kws) up to 1000 times.
    First chunks of each digest are tried as candidates, and the first
    one that is <= n_max is returned (rejection sampling).
    Raises ValueError if no such value was found.'''
    assert n_max > 0, n_max
    n_bits = math.ceil(math.log(n_max + 1, 2))
    n_bytes = math.ceil(n_bits / 8)
    # Number of excess low bits to drop from each n_bytes-sized chunk
    extra_bits = 8 - n_bits % 8 if n_bits % 8 else 0
    chunk_count = 8 // n_bytes
    digest = retry.to_bytes(1, 'big') + name.encode()
    for _ in range(1000):
        digest = hashlib.blake2b(digest, **blake2_kws).digest()
        for pos in range(0, chunk_count * n_bytes, n_bytes):
            value = int.from_bytes(digest[pos:pos + n_bytes], 'big') >> extra_bits
            # (value % n_max) would add bias for small values
            if value <= n_max:
                return value
    raise ValueError('Failed to get int within range after many hashes')
def b64_str(v):
    'Encode bytes value to an urlsafe-base64 str.'
    encoded = base64.urlsafe_b64encode(v)
    return encoded.decode()


def b64_bytes(v):
    'Decode urlsafe-base64 str value back to bytes.'
    raw = v.encode()
    return base64.urlsafe_b64decode(raw)
@contextlib.contextmanager
def err_timeout(timeout, err_t, *err_args):
    '''Context manager to raise err_t(*err_args) if its body runs longer than timeout.

    Uses SIGALRM with ITIMER_REAL timer, so the exception is raised from a signal handler.
    Timer is always disarmed on exit, and SIGALRM handler that was
    set before is restored, including the default/ignore (SIG_DFL/SIG_IGN) ones.'''
    def timeout_handler(sig, frm):
        raise err_t(*err_args)

    handler_prev = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        # Timer is armed inside try-block, so that it gets
        #  disarmed and handler restored even if the check below fails
        delay, interval = signal.setitimer(signal.ITIMER_REAL, timeout)
        # NOTE(review): parsed as "(not delay) or interval", which only rejects
        #  a pending one-shot timer - presumably meant to reject any pending timer, confirm
        assert not delay or interval
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.SIG_DFL is a falsy enum value (0), so truthiness
        #  check can't be used here - only None means "nothing to restore"
        if handler_prev is not None:
            signal.signal(signal.SIGALRM, handler_prev)
def proc_close(proc, wait=1, wait_base=0.3):
    '''Stop subprocess via a sequence of increasingly forceful steps.

    Steps are: close stdin (if process has it), terminate(), kill().
    After each step, process is waited-for with a timeout, and the rest are skipped if it exits.
    Does nothing if proc is empty or has exited already.

    wait can be a positive number, to use wait, 2*wait, 3*wait, ... timeouts for steps,
    or a list of per-step timeouts. wait_base timeout is used for any steps
    without a value, which includes all steps if wait is empty or a non-positive number.'''
    if not proc or proc.poll() is not None:
        return
    close_ops = [proc.terminate, proc.kill]
    if proc.stdin:
        close_ops.insert(0, proc.stdin.close)
    if isinstance(wait, (int, float)):
        # Non-positive number used to end up in enumerate() below and raise TypeError
        wait = [n * wait for n in range(1, len(close_ops) + 1)] if wait > 0 else []
    timeouts = dict(enumerate(wait or []))
    for n, func in enumerate(close_ops):
        with contextlib.suppress(OSError):
            func()
        with contextlib.suppress(subprocess.TimeoutExpired):
            proc.wait(timeouts.get(n, wait_base))
            break  # process has exited, no need for more forceful steps
class BackupHooks(dict):
    '''Mapping of hook-point names to hook-script paths, with a method to run these.

    "points" lists all supported hook-points with their descriptions.
    "timeout" is a limit for running each hook script, passed to subprocess.run().
    It can be omitted, None or non-positive to not have any time limit.'''

    points = {
        'script.start':
            'Before starting handshake with authenticated remote.\n'
            'args: backup root dir.',
        'script.done':
            'After successful backup, before closing connection to remote.\n'
            'args: backup root dir, backup dir, remote name.',
        'script.handshake':
            'After handshake with remote and before any backup actions.\n'
            'args: backup root dir, backup dir, remote name.',
        'rsync.pre':
            'Right before backup-rsync is started, if it will be run.\n'
            'args: backup root dir, backup dir, remote name, privileged sync (0 or 1).\n'
            'stdout: any additional \\0-separated args to pass to rsync.\n'
            ' These must be terminated by \\0, if passed,\n'
            ' and can start with \\0 to avoid passing any default options.',
        'rsync.done':
            'Right after backup-rsync is finished, e.g. to check/process its output.\n'
            'args: backup root dir, backup dir, remote name, rsync exit code.\n'
            'stdin: interleaved stdout/stderr from rsync.\n'
            'stdout: optional replacement for rsync return code to check, must be int if non-empty.',
    }

    __slots__ = 'timeout'.split()

    def __init__(self, *args, **kws):
        'Set slot values from positional (in __slots__ order) and keyword arguments.'
        super().__init__()
        # Default is needed, as reading a slot that was never set raises AttributeError
        self.timeout = None
        for k, v in it.chain(zip(self.__slots__, args), kws.items()):
            setattr(self, k, v)
        if self.timeout is not None and self.timeout <= 0:
            self.timeout = None

    def run(self, hook, *hook_args, **run_kws):
        '''Run script set for the "hook" point, passing str values of hook_args to it.

        run_kws are passed to subprocess.run(), on top of check=True and timeout= defaults,
        so failed or timed-out hook raises CalledProcessError or TimeoutExpired by default.
        Returns CompletedProcess for the script, or None if there is no script for the hook.'''
        hook_script = self.get(hook)
        if not hook_script:
            return
        kws = dict(check=True, timeout=self.timeout)
        kws.update(run_kws)
        return subprocess.run([hook_script, *map(str, hook_args)], **kws)
# Pair of binary streams for talking to remote over ssh: send = to remote, recv = from remote
SSHPipe = collections.namedtuple('SSHPipe', ['send', 'recv'])

# Parameters for rsync connection, as sent by remote in the handshake info
RsyncInfo = collections.namedtuple(
    'RsyncInfo', ['user', 'key', 'mod', 'su', 'opts'])
def get_rsync_info(conf, ssh):
    '''Run handshake with remote over ssh pipe and its reverse-tunnel port.

    Steps: exchange hello-lines over ssh pipe, read json info line from remote,
    check that port there is one of the ports derived from remote name hash, then connect
    to that port and exchange keys derived from info "key" value, to make sure
    that tunnel socket is connected to the same remote as the ssh pipe.

    conf is SSHRsyncReceiverConfig, ssh is SSHPipe of binary streams.
    Returns (name, rsync_port, rsync_info) tuple, where rsync_info is RsyncInfo.
    Raises SSHRsyncError on timeouts, mismatches or malformed info from remote.'''
    import hmac  # local import - only used for constant-time key comparison here

    hs_timeout_ctx = ft.partial(
        err_timeout, conf.hs_timeout, SSHRsyncError, 'Timeout waiting for receiver')

    with hs_timeout_ctx():
        ssh.send.write(conf.hs_hello + b'\n')
    with hs_timeout_ctx():
        hello = ssh.recv.readline().rstrip(b'\n')
    if hello != conf.hs_hello:
        raise SSHRsyncError(
            'Hello-string mismatch:'
            ' local={!r} remote={!r}', conf.hs_hello, hello)

    with hs_timeout_ctx():
        info = ssh.recv.readline()
    try:
        info = json.loads(info)
        name, rsync_port = info['name'], info['port']
        if not isinstance(name, str) or not isinstance(rsync_port, int):
            raise TypeError('name/port value type')
        rsync_info = RsyncInfo(**info['rsync'])
        key_base = b64_bytes(info['key'])
    except (ValueError, LookupError, TypeError, AttributeError) as err:
        # Info line itself is not included in the message, as it has key material in it
        raise SSHRsyncError(
            'Invalid handshake info from remote ({}: {})',
            type(err).__name__, err) from None

    # Same set of ports that remote should pick one from, based on name hash
    port_min, port_max = conf.recv_port_range
    tun_ports = set(
        port_min + hash_to_int(
            name, attempt, port_max - port_min, key=conf.recv_port_hash_key)
        for attempt in range(conf.recv_port_retries + 1))
    if rsync_port not in tun_ports:
        raise SSHRsyncError(
            'Tunnel-port mismatch:'
            ' remote={!r} local-set={!r}', rsync_port, tun_ports)

    key_recv = hashlib.blake2b(key_base, person=b'server', key=conf.hs_hello).digest()
    key_push = hashlib.blake2b(key_base, person=b'client', key=conf.hs_hello).digest()
    assert len(key_recv) == len(key_push) == conf.hs_key_size, [len(key_recv), len(key_push)]

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock_conn:
        sock_conn.settimeout(conf.hs_timeout)
        sock_conn.connect((conf.recv_port_bind, rsync_port))
        # Single recv() can return less data than requested, so read until size/EOF
        key = b''
        while len(key) < conf.hs_key_size:
            chunk = sock_conn.recv(conf.hs_key_size - len(key))
            if not chunk:
                break
            key += chunk
        # compare_digest also returns False on length mismatch
        if not hmac.compare_digest(key, key_push):
            raise SSHRsyncError('ssh/socket mismatch (len={}/{})', len(key), conf.hs_key_size)
        sock_conn.sendall(key_recv)  # send() can stop after sending only part of the key

    return name, rsync_port, rsync_info
def _set_priorities(parser, spec):
    '''Apply --nice "(prio:)(io-class(.io-level))" spec to the current process.

    "nice" prio is set via os.setpriority(), ionice via ioprio_set syscall through libc.
    Any invalid values in the spec are reported via parser.error(), which exits.
    Raises OSError if ioprio_set syscall fails.'''
    nice, _, ionice = spec.partition(':')
    ionice = ionice.rstrip(':')
    if not ionice:
        # Single value - "nice" prio if it is an integer (can be negative), ionice spec otherwise
        try:
            int(nice)
        except ValueError:
            nice, ionice = None, nice
    if nice:
        try:
            nice = int(nice)
        except ValueError:
            parser.error(f'--nice prio value must be an integer: {nice!r}')
        os.setpriority(os.PRIO_PROCESS, os.getpid(), nice)
    if not ionice:
        return

    # grep -r ioprio_set /usr/share/gdb/syscalls/
    arch = os.uname().machine
    try:
        syscall_id = dict(x86_64=251, armv7l=314, aarch64=30)[arch]
    except KeyError:
        parser.error(f'--nice ionice is not supported on arch: {arch}')
    io_class, _, io_level = ionice.partition('.')
    try:
        io_class = dict(rt=1, be=2, idle=3)[io_class.lower()]
    except KeyError:
        parser.error(f'--nice ionice class must be one of rt, be, idle: {io_class!r}')
    if io_class == 3:
        io_level = 0  # any level specified for "idle" class is discarded
    else:
        try:
            io_level = int(io_level or 0)
        except ValueError:
            io_level = -1  # reported by the range check below
        if not 0 <= io_level <= 7:
            parser.error('--nice ionice prio level must be in 0-7 range')
    ioprio = (io_class << 13) | io_level

    import ctypes as ct
    err = ct.CDLL('libc.so.6', use_errno=True).syscall(syscall_id, 1, os.getpid(), ioprio)
    if err != 0:
        err = ct.get_errno()
        raise OSError(err, 'ionice_set failed - ' + os.strerror(err))


def main(args=None):
    '''Script entry point - parse options, run handshake with remote, then hooks and rsync.

    args is a list of command-line arguments, sys.argv[1:] is used if it is None.
    With exactly two arguments, where first one is "-c", second one is
    treated as a shell command line (shell mode), and is split on spaces.
    Returns 1 when run without any arguments, None otherwise.
    Raises SSHRsyncError on handshake, hook output or rsync failures.'''
    if args is None:
        args = sys.argv[1:]
    if not args:  # very dangerous, as it'd allow rsync to arbitrary paths chosen by remote
        print('ERROR: ssh-r-sync-recv shell setup without ForceCommand!', file=sys.stderr)
        return 1
    if args[0] == '-c' and len(args) == 2:
        args = args[1].split()[1:]  # shell mode
    conf = SSHRsyncReceiverConfig()

    import argparse, textwrap

    def dedent(text):
        return (textwrap.dedent(text).strip('\n') + '\n').replace('\t', ' ')

    class SmartHelpFormatter(argparse.HelpFormatter):
        'Formatter that keeps line breaks in multi-line help/description texts.'

        def __init__(self, *args, **kws):
            kws.setdefault('width', 100)
            super().__init__(*args, **kws)

        def _fill_text(self, text, width, indent):
            if '\n' not in text:
                return super()._fill_text(text, width, indent)
            return ''.join(
                indent + line
                for line in text.replace('\t', ' ').splitlines(keepends=True))

        def _split_lines(self, text, width):
            if '\n' not in text:
                return super()._split_lines(text, width)
            return dedent(text).splitlines()

    parser = argparse.ArgumentParser(
        formatter_class=SmartHelpFormatter,
        description=dedent('''
            Receiver script for ssh sessions from ssh-r-sync.
            Should be run after successful ssh auth to check reverse-tunnel setup,
            perform any necessary maintenance on destination path and run rsync client.
            Can be used as ssh ForceCommand (with any options) and shell at the same time.'''),
        epilog=dedent('''
            sshd_config snippet to set this script for backup-user command:
            Match User backup
            X11Forwarding no
            AllowAgentForwarding no
            PermitTunnel no
            AllowStreamLocalForwarding no
            AllowTcpForwarding remote
            PasswordAuthentication no
            PermitTTY no
            ForceCommand ~/ssh-r-sync-recv -- /mnt/backups
            Any cli opts of this script should be added to ForceCommand (string for shell).
            To avoid running shell (i.e. "/bin/bash -c 'command args...'") entirely,
            in addition to ForceCommand (where all options are specified),
            set this script to be user's shell as well, though note that
            it'll split ForceCommand by spaces and won't interpret any quoting there.
            Authentication is done via usual lines in ~backup/.ssh/authorized_keys'''))
    parser.add_argument(
        'backup_dir', nargs='?', help='Destination backup directory root path.')
    parser.add_argument(
        '-r', '--rsync', metavar='path',
        help='Path to custom rsync binary, e.g. one with all necessary posix caps enabled.')

    group = parser.add_argument_group('Hook options')
    group.add_argument(
        '-x', '--hook', action='append', metavar='hook:path',
        help='''
            Hook-script to run at the specified point.
            Specified path must be executable (chmod +x ...), will be run synchronously, and
            must exit with 0 for tool to continue operation, and non-zero to abort immediately.
            Hooks are run with same uid/gid and env as the main script, can use PATH-lookup.
            See --hook-list output to get full list of
            all supported hook-points and arguments passed to them.
            Example spec: -x rsync.pre:~/hook.pre-sync.sh''')
    group.add_argument(
        '--hook-timeout', metavar='seconds', type=float, default=0,
        help='Timeout for waiting for hook-script to finish running,'
            ' before aborting the operation (treated as hook error).'
            ' Zero or negative value will disable timeout. Default: no-limit')
    group.add_argument(
        '--hook-list', action='store_true',
        help='Print the list of all supported hooks with descriptions/parameters and exit.')

    group = parser.add_argument_group('Misc options')
    group.add_argument(
        '--nice', metavar='(prio:)(io-class(.io-level))',
        help='''
            Set "nice" and/or "ionice" (CFQ I/O) priorities, inherited by hooks and rsync.
            "nice" prio value, if specified, must be
            in [-20;20] range, where lower = higher prio, and base=0.
            "ionice" value should be in class[.level] format, where
            "class" is one of [rt, be, idle] and "level" in 0-7 range (0=highest prio).
            See setpriority(2) / ioprio_set(2) for more info.''')
    opts = parser.parse_args(args)

    # Print only a short message for expected errors, full traceback for the rest
    def print_error_msg_only(err_t, err, err_tb):
        if isinstance(err, SSHRsyncError):
            print(f'ERROR: {err}', file=sys.stderr)
        else:
            sys.__excepthook__(err_t, err, err_tb)
    sys.excepthook = print_error_msg_only

    for sig in 'int hup term'.upper().split():
        signal.signal(getattr(signal, f'SIG{sig}'), lambda sig, frm: sys.exit(1))

    if opts.nice is not None:
        _set_priorities(parser, opts.nice)

    bak_hooks = BackupHooks(opts.hook_timeout)
    if opts.hook_list:
        p('Available hook points:\n')
        for hp, desc in bak_hooks.points.items():
            p(' {}:', hp)
            indent = ' '*4
            if '\n' in desc:
                desc = ''.join(indent + line for line in desc.splitlines(keepends=True))
            else:
                desc = textwrap.fill(
                    desc, width=100, initial_indent=indent, subsequent_indent=indent)
            p(desc + '\n')
        p('Hooks are run synchronously, waiting for subprocess to exit before continuing.')
        p('All hooks must exit with status 0 to continue operation.')
        p('Some hooks get passed arguments, as mentioned in hook descriptions.')
        p('Setting --hook-timeout (defaults to no limit) can be used to abort when hook-scripts hang.')
        return
    for v in opts.hook or list():
        if ':' not in v:
            parser.error(f'Invalid --hook spec (must be hook:path): {v!r}')
        hp, path = v.split(':', 1)
        if hp not in bak_hooks.points:
            parser.error(f'Invalid hook name: {hp!r} (see --hook-list)')
        # Paths without dir separator are left as-is, for PATH-lookup
        if os.sep in path:
            path = str(pl.Path(path).expanduser().resolve())
        bak_hooks[hp] = path

    if not opts.backup_dir:
        parser.error('backup_dir argument must be specified')

    # stdin/stdout are only used as unbuffered binary pipes to remote from here on
    ssh = SSHPipe(
        recv=open(sys.stdin.fileno(), 'rb', 0),
        send=open(sys.stdout.fileno(), 'wb', 0))
    sys.stdin = sys.stdout = None

    bak_root = pl.Path(opts.backup_dir).resolve()
    bak_hooks.run('script.start', bak_root)
    bak_name, rsync_port, rsync_info = get_rsync_info(conf, ssh)
    bak_path = bak_root / bak_name
    # Backup name is chosen by remote, so make sure that it can't be used to point
    #  rsync with --delete at backup root itself or anything outside of it, e.g. via ".." or abs path
    if bak_root not in pl.Path(os.path.normpath(bak_path)).parents:
        raise SSHRsyncError('Backup name must be a path under backup root dir: {!r}', bak_name)
    bak_hooks.run('script.handshake', bak_root, bak_path, bak_name)

    cmd = ssh.recv.readline().rstrip(b'\n')
    if cmd != b'start':
        raise SSHRsyncError('Invalid/missing backup command: {!r}', cmd)

    su_fake = rsync_info.opts.get('fake_super')
    su_mode = not su_fake and rsync_info.su
    rsync_opts = ['--delete', '--delete-excluded', '--zl=1']
    if rsync_info.opts.get('one_file_system'):
        rsync_opts.append('-x')
    if su_mode:
        rsync_opts.extend(['-HaAXS', '--super', '--numeric-ids'])
    else:
        rsync_opts.append('-rltHES')
    if su_fake:
        rsync_opts.append('--fake-super')

    hook = bak_hooks.run(
        'rsync.pre', bak_root, bak_path, bak_name,
        int(bool(su_mode)), stdout=subprocess.PIPE)
    if hook and hook.stdout.strip():
        if b'\0' not in hook.stdout:
            raise SSHRsyncError(
                'rsync.pre hook'
                ' produced non-\\0-terminated stdout: {!r}', hook.stdout)
        hook = hook.stdout.decode().strip(' \n').split('\0')
        if not hook[-1]:
            hook = hook[:-1]
        if not hook[0]:  # leading \0 - replace default options instead of extending them
            rsync_opts[:] = hook[1:]
        else:
            rsync_opts.extend(hook)

    rsync = pl.Path(opts.rsync).expanduser().resolve() if opts.rsync else 'rsync'
    rsync = subprocess.run(
        [
            rsync, *rsync_opts,
            f'rsync://{rsync_info.user}@'
            f'{conf.recv_port_bind}:{rsync_port}/{rsync_info.mod}/.',
            str(bak_path / '.')],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        env=dict(RSYNC_PASSWORD=rsync_info.key, RSYNC_COMPRESS_LIST='zstd none'))
    rsync_code, rsync_output = rsync.returncode, rsync.stdout

    hook = bak_hooks.run(
        'rsync.done', bak_root, bak_path, bak_name, rsync_code,
        input=rsync_output, stdout=subprocess.PIPE)
    if hook and hook.stdout.strip():
        try:
            rsync_code = int(hook.stdout.strip())
        except ValueError:
            raise SSHRsyncError(
                'rsync.done hook stdout'
                ' cannot be parsed as integer: {!r}', hook.stdout) from None

    if rsync_code != 0:
        msg = f'Rsync process exited with non-zero code: {rsync_code}'
        if rsync_output.strip():
            msg += '\nRsync process output:\n' + '\n'.join(
                f' {line}'.rstrip()
                for line in rsync_output.decode('utf-8', 'replace').split('\n'))
        raise SSHRsyncError(msg)

    ssh.send.write(b'done\n')
    ssh.send.close()
    bak_hooks.run('script.done', bak_root, bak_path, bak_name)


if __name__ == '__main__':
    sys.exit(main())