-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvire.py
252 lines (231 loc) · 8.91 KB
/
vire.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
from __future__ import annotations
from typing import *
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
import argparse
import importlib
import os
import runpy
import signal
import sys
import termios
import threading
import tty
import traceback
import time
from inotify_simple import INotify, flags # type: ignore
# File descriptor numbers of this process's standard streams, captured once at import time.
# NOTE(review): sys.stdin.fileno() / sys.stdout.fileno() raise if the stream has been replaced
# by an object with no real descriptor (e.g. an io.StringIO), which would make importing this
# module fail in that environment.
# NOTE(review): STDOUT_FILENO is not referenced anywhere in the visible source.
STDIN_FILENO = sys.stdin.fileno()
STDOUT_FILENO = sys.stdout.fileno()
def spawn(f: Callable[[], None]) -> None:
    """Start ``f`` on a new background daemon thread.

    Returns None, so when used as a decorator (``@spawn``) the function is
    started immediately and the decorated name ends up bound to None.
    """
    worker = threading.Thread(target=f, daemon=True)
    worker.start()
def sigterm(pid: int, grace: float = 1.0) -> None:
    """Stop child process ``pid`` and reap it.

    Sends SIGTERM, then blocks in ``os.waitpid`` until the child is gone. If
    the child is still running after ``grace`` seconds, a watchdog thread
    sends SIGKILL. A pid that no longer exists is silently ignored.

    Raises:
        ValueError: if ``pid`` <= 0. ``os.kill`` interprets 0 and negative
            values as process groups, which could signal this process too.
    """
    if pid <= 0:
        raise ValueError(f'sigterm: refusing to signal a process group (pid={pid})')
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        return

    # Set once waitpid has returned, so the watchdog wakes up early and does nothing.
    reaped = threading.Event()

    def watchdog():
        if reaped.wait(grace):
            return
        print(f'waited for {grace:g}s, now sending SIGKILL...')
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # the child exited between the timeout and the kill

    threading.Thread(target=watchdog, daemon=True).start()
    try:
        os.waitpid(pid, 0)
    finally:
        reaped.set()
def clear(clear_opt: int):
    """Clear the terminal according to ``clear_opt``.

    1 erases the display and homes the cursor; 2 or more additionally erases
    the scrollback buffer; anything else leaves the screen alone.
    """
    # Escape sequences as used by entr: https://github.com/eradman/entr/blob/master/entr.c
    #   ESC[2J  erase the entire display
    #   ESC[3J  erase the scrollback buffer
    #   ESC[H   move the cursor to the home position
    if clear_opt == 1:
        codes = '\033[2J\033[H'
    elif clear_opt >= 2:
        codes = '\033[2J\033[3J\033[H'
    else:
        return
    print(codes, end='', flush=True)
def run_child(argv: list[str], is_module: bool):
    """Execute the target program inside the current process.

    ``argv[0]`` is a module name when ``is_module`` is true, otherwise a
    script path whose directory is put at the front of ``sys.path``;
    ``argv[1:]`` replace ``sys.argv[1:]``. The target runs with
    ``__name__ == '__main__'``. SystemExit is swallowed and any other
    exception is printed as a traceback, never propagated.
    """
    sys.dont_write_bytecode = True  # don't write .pyc files for the re-run program
    sys.argv[1:] = argv[1:]
    try:
        target = argv[0]
        if not is_module:
            sys.path = [str(Path(target).parent), *sys.path]
            runpy.run_path(target, run_name='__main__')
        else:
            runpy.run_module(target, run_name='__main__')
    except SystemExit:
        pass
    except BaseException:
        traceback.print_exc()
@dataclass
class Vire:
    """Run a Python program in a forked child and re-run it when watched files change.

    Modules listed in ``preload`` are imported once, in this process, before
    the first fork, so every child starts with them (and anything else already
    in ``sys.modules``) loaded. Changes to those preloaded files are not picked
    up by re-forking; they are reported instead, and a full reload (re-exec of
    this whole process) is needed, either on the R key or automatically with
    ``auto_full_reload``. Changes to other watched files just re-run the child.

    Single-key commands read from stdin (cbreak mode when stdin is a TTY):
    ``r`` or space re-runs, ``c``/``C`` clear the screen (``C`` also the
    scrollback) and re-run, ``R`` does a full reload, ``q`` or Ctrl-C quits.
    """
    preload: str                          # comma-separated module names to import before forking
    argv: list[str]                       # program (script path, or module name if is_module) and its arguments
    is_module: bool                       # run argv[0] as a module instead of a script path
    glob_patterns: str                    # comma-separated globs, relative to the cwd, of files to watch
    clear_opt: int = 0                    # passed to clear() before each run
    silent: bool = False                  # suppress the "preloaded files have been modified" warning
    auto_full_reload: bool = False        # full reload automatically when preloaded files change
    preload_exclude: Iterable[str] = ()   # files never treated as preloaded (reload() passes the calling script)
    _restore: Callable[[], None] = lambda: None  # restores the terminal mode; replaced in main() for TTYs

    def main(self):
        """Run the watch loop, putting a TTY stdin into cbreak mode for single-key commands."""
        if not sys.stdin.isatty():
            self._main()
            return
        mode = termios.tcgetattr(STDIN_FILENO)
        self._restore = lambda: termios.tcsetattr(STDIN_FILENO, termios.TCSAFLUSH, mode)
        try:
            tty.setcbreak(STDIN_FILENO)  # required for returning single characters from standard input
            self._main()
        finally:
            self._restore()

    def getchar(self):
        """Read one character from stdin ('' at EOF)."""
        # requires tty.setcbreak
        return sys.stdin.read(1)

    @staticmethod
    def _report_out_of_sync(files: Iterable[str], action: str) -> None:
        """Print the modified preloaded files to stderr, followed by ``action``."""
        print('vire: Preloaded files have been modified:', file=sys.stderr)
        print(*[' ' + f for f in sorted(files)], sep='\n', file=sys.stderr)
        print(' ' + action, file=sys.stderr)

    def _main(self):
        """Preload, set up inotify watches, then loop: fork the child, wait for a command or change, act."""
        # Snapshot used verbatim for the re-exec on a full reload.
        sys_argv_copy = [*sys.argv]

        for name in self.preload.split(','):
            name = name.strip()
            if name:
                try:
                    importlib.import_module(name)
                except Exception:
                    traceback.print_exc()

        def resolved(filename: str | Path) -> str:
            # One canonical spelling per file, so the set operations in
            # bg_read_events compare the same strings that wd_to_filename stores.
            return str(Path(filename).resolve())

        # Non-system source files already loaded here; every forked child inherits them.
        imported_at_preload: set[str] = {
            resolved(file)
            for m in list(sys.modules.values())
            for file in [getattr(m, '__file__', None)]
            if file and not file.startswith('/usr')
        }
        imported_at_preload -= {resolved(f) for f in self.preload_exclude}

        wd_to_filename: dict[int, str] = {}
        ino: Any = INotify()

        def add_watch(filename: str | Path):
            wd = ino.add_watch(filename, flags.MODIFY)
            wd_to_filename[wd] = resolved(filename)

        for name in imported_at_preload:
            add_watch(name)
        for glob_pattern in self.glob_patterns.split(','):
            glob_pattern = glob_pattern.strip()
            if not glob_pattern:
                continue  # Path.glob('') raises ValueError; tolerate "a,,b" and trailing commas
            for name in Path('.').glob(glob_pattern):
                add_watch(name)

        # Preloaded files modified since startup. bg_read_events rebinds this name
        # to a new set instead of mutating it, so the main thread can iterate a
        # snapshot without racing against concurrent updates.
        out_of_sync: set[str] = set()
        # Work for the main loop: single characters from stdin, or the set of
        # modified non-preloaded files from one inotify batch (possibly empty).
        q: Queue[Union[set[str], str]] = Queue()

        @spawn
        def bg_getchar():
            while True:
                c = self.getchar()
                if not c:
                    return  # EOF on stdin: stop instead of flooding the queue with ''
                q.put_nowait(c)

        @spawn
        def bg_read_events():
            nonlocal out_of_sync
            while True:
                events = ino.read(read_delay=1)
                # Skip wds we never registered, e.g. IN_Q_OVERFLOW events carry wd == -1.
                files = {wd_to_filename[e.wd] for e in events if e.wd in wd_to_filename}
                out_of_sync = out_of_sync | (files & imported_at_preload)
                q.put(files - imported_at_preload)

        while True:
            clear(self.clear_opt)
            pid = fork(lambda: run_child(self.argv, is_module=self.is_module))
            out_of_sync_reported: set[str] = set()
            while True:
                snapshot = out_of_sync  # never mutated in place, so safe to iterate
                if not self.silent and snapshot != out_of_sync_reported:
                    self._report_out_of_sync(snapshot, 'Press R for full reload.')
                    out_of_sync_reported = snapshot
                try:
                    msg = q.get()
                except KeyboardInterrupt:
                    msg = 'q'
                if self.auto_full_reload and out_of_sync:
                    self._report_out_of_sync(out_of_sync, 'Running full reload.')
                    msg = 'R-auto'
                if msg == 'q':
                    sigterm(pid)
                    sys.exit()
                if msg in ('R', 'R-auto'):
                    sigterm(pid)
                    clear(1 if msg == 'R' else self.clear_opt)
                    self._restore()
                    # Don't leak this process's descriptors (inotify etc.) into the new image.
                    os.closerange(3, os.sysconf("SC_OPEN_MAX"))
                    sys.argv[:] = sys_argv_copy
                    if not os.access(sys.argv[0], os.X_OK):
                        # Not directly executable (e.g. a plain .py file): run it via the interpreter.
                        sys.argv[:] = [sys.executable, *sys_argv_copy]
                    os.execve(sys.argv[0], sys.argv, os.environ)
                if msg == 'c':
                    clear(1)
                    break
                if msg == 'C':
                    clear(2)
                    break
                if msg in (' ', 'r'):
                    break
                if isinstance(msg, set) and msg:
                    break
            sigterm(pid)
def fork(child: Callable[[], None]) -> int:
    """Run ``child`` in a forked process and return the child's pid (in the parent).

    In the child, stdin is redirected to /dev/null, presumably so the child
    does not consume keystrokes the parent reads as commands. Then ``child()``
    is called and the process ends via ``os._exit``. The exit status is 0 if
    ``child`` returned, the SystemExit code if it raised SystemExit, and 1 for
    any other exception. The child never returns from this function. If it
    did, it would fall through into the caller's code with pid == 0.
    """
    pid = os.fork()
    if pid != 0:
        return pid
    status = 1
    try:
        devnull = os.open(os.devnull, os.O_RDONLY)
        if devnull != STDIN_FILENO:
            os.dup2(devnull, STDIN_FILENO)
            os.close(devnull)  # the duplicate on STDIN_FILENO is all we need
        child()
        status = 0
    except SystemExit as exc:
        status = exc.code if isinstance(exc.code, int) else (0 if exc.code is None else 1)
    except BaseException:
        traceback.print_exc()
    finally:
        # os._exit skips interpreter shutdown, so flush buffered output by hand.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.flush()
            except Exception:
                pass  # best effort: nothing useful to do if a flush fails this late
        os._exit(status)
# Set by the first reload() call. A re-run of the script inside a forked child
# inherits it, so the nested reload() call returns instead of starting a watcher.
_is_reloading: bool = False
# Set by main(); reload() refuses to start when the vire CLI is already running.
_is_running_from_command_line_tool: bool = False


def reload(filepath: str, *, glob: str = '**/*.py', auto_full_reload: bool = False, clear: int = 1):
    """Watch files and re-run the current program (``sys.argv``) when they change.

    Meant to be called from the script being developed, with its own path as
    ``filepath``. That path is excluded from the preloaded set, so edits to it
    count as ordinary changes. The first call enters Vire's watch loop, which
    only leaves via sys.exit or a re-exec. Any later call in the same process
    state returns immediately.

    Args:
        filepath: path of the calling script.
        glob: comma-separated glob patterns of files to watch.
        auto_full_reload: re-exec automatically when preloaded files change.
        clear: screen-clearing level passed to Vire as ``clear_opt``.

    Raises:
        ValueError: if the vire command-line tool is already running.
    """
    global _is_reloading
    if _is_running_from_command_line_tool:
        raise ValueError('Not a good idea to run both vire.reload and the vire command line tool')
    if _is_reloading:
        return
    _is_reloading = True
    watcher = Vire(
        preload='',
        argv=list(sys.argv),
        is_module=False,
        glob_patterns=glob,
        clear_opt=clear,
        silent=False,
        auto_full_reload=auto_full_reload,
        preload_exclude=(filepath,),
    )
    watcher.main()
def main():
    """Command-line entry point: parse options and watch/re-run the given program.

    Everything after the options is taken verbatim as the program and its
    arguments (argparse.REMAINDER), so e.g. a ``-h`` after the program name is
    passed to the program. With ``--help``/``-h``, or with no program, prints
    the help text and exits. The status is 0 for ``--help`` and 2 (argparse's
    usage-error status) when the program is missing.
    """
    global _is_running_from_command_line_tool
    _is_running_from_command_line_tool = True  # makes vire.reload() refuse to start a second watcher
    parser = argparse.ArgumentParser(
        description='Runs a program and reruns it on updating files matching a glob (default **/*.py).',
        add_help=False,
    )
    parser.add_argument('--help', '-h', action='store_true', help=argparse.SUPPRESS)
    parser.add_argument('--clear', '-c', action='count', default=0, help='Clear the screen before invoking the utility. Specify twice to erase the scrollback buffer.')
    parser.add_argument('--preload', '-p', metavar='M', help='Modules to preload, comma-separated. Example: flask,pandas')
    parser.add_argument('--glob', '-g', metavar='G', help='Watch for updates to files matching this glob, Default: **/*.py', default='**/*.py')
    parser.add_argument('--silent', '-s', action='store_true', help='Silence warning about modifications to preloaded modules.')
    parser.add_argument('--auto-full-reload', '-r', action='store_true', help='Automatically do full reload on modifications to preloaded modules.')
    parser.add_argument('-m', action='store_true', help='Argument is a module, will be run like python -m (using runpy)')
    parser.add_argument(dest='argv', nargs=argparse.REMAINDER)
    args = parser.parse_args()
    if args.help or not args.argv:
        parser.print_help()
        # sys.exit rather than quit(): quit is installed by the site module and
        # is missing under `python -S` and in some embedded interpreters.
        sys.exit(0 if args.help else 2)
    Vire(
        preload=args.preload or '',
        argv=args.argv,
        is_module=args.m,
        glob_patterns=args.glob,
        clear_opt=args.clear,
        silent=args.silent,
        auto_full_reload=args.auto_full_reload,
    ).main()
# Run the command-line interface when this file is executed directly as a script.
if __name__ == '__main__':
    main()