109
109
*.ipynb diff=ipynb
110
110
"""
111
111
112
- from argparse import ArgumentParser , RawDescriptionHelpFormatter
112
+ from argparse import ArgumentParser , RawDescriptionHelpFormatter , Namespace
113
113
import collections
114
114
import copy
115
115
import io
118
118
from pathlib import PureWindowsPath
119
119
import re
120
120
from subprocess import call , check_call , check_output , CalledProcessError , STDOUT
121
+ from typing import Optional
121
122
import sys
122
123
import warnings
123
124
134
135
INSTALL_LOCATION_SYSTEM = 'system'
135
136
136
137
137
- def _get_system_gitconfig_folder ():
138
+ def _get_system_gitconfig_folder () -> str :
138
139
try :
139
140
git_config_output = check_output (
140
141
['git' , 'config' , '--system' , '--list' , '--show-origin' ], universal_newlines = True , stderr = STDOUT
@@ -160,7 +161,9 @@ def _get_system_gitconfig_folder():
160
161
return path .abspath (path .dirname (system_gitconfig_file_path ))
161
162
162
163
163
- def _get_attrfile (git_config , install_location = INSTALL_LOCATION_LOCAL , attrfile = None ):
164
+ def _get_attrfile (
165
+ git_config : str , install_location : str = INSTALL_LOCATION_LOCAL , attrfile : Optional [str ] = None
166
+ ) -> str :
164
167
if not attrfile :
165
168
if install_location == INSTALL_LOCATION_SYSTEM :
166
169
try :
@@ -185,7 +188,7 @@ def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=
185
188
return attrfile
186
189
187
190
188
- def _parse_size (num_str ) :
191
+ def _parse_size (num_str : str ) -> int :
189
192
num_str = num_str .upper ()
190
193
if num_str [- 1 ].isdigit ():
191
194
return int (num_str )
@@ -195,11 +198,15 @@ def _parse_size(num_str):
195
198
return int (num_str [:- 1 ]) * (10 ** 6 )
196
199
elif num_str [- 1 ] == 'G' :
197
200
return int (num_str [:- 1 ]) * (10 ** 9 )
198
- else :
199
- raise ValueError (f'Unknown size identifier { num_str [- 1 ]} ' )
201
+ raise ValueError (f'Unknown size identifier { num_str [- 1 ]} ' )
200
202
201
203
202
- def install (git_config , install_location = INSTALL_LOCATION_LOCAL , python = None , attrfile = None ):
204
+ def install (
205
+ git_config : str ,
206
+ install_location : str = INSTALL_LOCATION_LOCAL ,
207
+ python : Optional [str ] = None ,
208
+ attrfile : Optional [str ] = None ,
209
+ ) -> int :
203
210
"""Install the git filter and set the git attributes."""
204
211
try :
205
212
filepath = f'"{ PureWindowsPath (python or sys .executable ).as_posix ()} " -m nbstripout'
@@ -229,7 +236,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
229
236
diff_exists = '*.ipynb diff' in attrs
230
237
231
238
if filt_exists and diff_exists :
232
- return
239
+ return 0
233
240
234
241
try :
235
242
with open (attrfile , 'a' , newline = '' ) as f :
@@ -242,6 +249,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
242
249
print ('*.zpln filter=nbstripout' , file = f )
243
250
if not diff_exists :
244
251
print ('*.ipynb diff=ipynb' , file = f )
252
+ return 0
245
253
except PermissionError :
246
254
print (f'Installation failed: could not write to { attrfile } ' , file = sys .stderr )
247
255
@@ -251,7 +259,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
251
259
return 1
252
260
253
261
254
- def uninstall (git_config , install_location = INSTALL_LOCATION_LOCAL , attrfile = None ):
262
+ def uninstall (git_config : str , install_location : str = INSTALL_LOCATION_LOCAL , attrfile : Optional [ str ] = None ) -> int :
255
263
"""Uninstall the git filter and unset the git attributes."""
256
264
try :
257
265
call (git_config + ['--unset' , 'filter.nbstripout.clean' ], stdout = open (devnull , 'w' ), stderr = STDOUT )
@@ -274,9 +282,10 @@ def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None
274
282
f .seek (0 )
275
283
f .write ('' .join (lines ))
276
284
f .truncate ()
285
+ return 0
277
286
278
287
279
- def status (git_config , install_location = INSTALL_LOCATION_LOCAL , verbose = False ):
288
+ def status (git_config : str , install_location : str = INSTALL_LOCATION_LOCAL , verbose : bool = False ) -> int :
280
289
"""Return 0 if nbstripout is installed in the current repo, 1 otherwise"""
281
290
try :
282
291
if install_location == INSTALL_LOCATION_SYSTEM :
@@ -342,22 +351,28 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
342
351
return 1
343
352
344
353
345
- def process_jupyter_notebook (input_stream , output_stream , args , extra_keys , filename = 'input from stdin' ):
354
+ def process_jupyter_notebook (
355
+ input_stream : io .IOBase ,
356
+ output_stream : io .IOBase ,
357
+ args : Namespace ,
358
+ extra_keys : list [str ],
359
+ filename : str = 'input from stdin' ,
360
+ ) -> bool :
346
361
with warnings .catch_warnings ():
347
362
warnings .simplefilter ('ignore' , category = UserWarning )
348
363
nb = nbformat .read (input_stream , as_version = nbformat .NO_CONVERT )
349
364
350
365
nb_orig = copy .deepcopy (nb )
351
366
nb_stripped = strip_output (
352
- nb ,
353
- args .keep_output ,
354
- args .keep_count ,
355
- args .keep_id ,
356
- extra_keys ,
357
- args .drop_empty_cells ,
358
- args .drop_tagged_cells .split (),
359
- args .strip_init_cells ,
360
- _parse_size (args .max_size ),
367
+ nb = nb ,
368
+ keep_output = args .keep_output ,
369
+ keep_count = args .keep_count ,
370
+ keep_id = args .keep_id ,
371
+ extra_keys = extra_keys ,
372
+ drop_empty_cells = args .drop_empty_cells ,
373
+ drop_tagged_cells = args .drop_tagged_cells .split (),
374
+ strip_init_cells = args .strip_init_cells ,
375
+ max_size = _parse_size (args .max_size ),
361
376
)
362
377
363
378
any_change = nb_orig != nb_stripped
@@ -377,7 +392,13 @@ def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, file
377
392
return any_change
378
393
379
394
380
- def process_zeppelin_notebook (input_stream , output_stream , args , extra_keys , filename = 'input from stdin' ):
395
+ def process_zeppelin_notebook (
396
+ input_stream : io .IOBase ,
397
+ output_stream : io .IOBase ,
398
+ args : Namespace ,
399
+ extra_keys : list [str ],
400
+ filename : str = 'input from stdin' ,
401
+ ):
381
402
nb = json .load (input_stream , object_pairs_hook = collections .OrderedDict )
382
403
nb_orig = copy .deepcopy (nb )
383
404
nb_stripped = strip_zeppelin_output (nb )
@@ -569,7 +590,9 @@ def main():
569
590
try :
570
591
with io .open (filename , 'r+' , encoding = 'utf8' , newline = '' ) as f :
571
592
out = output_stream if args .textconv or args .dry_run else f
572
- if process_notebook (f , out , args , extra_keys , filename ):
593
+ if process_notebook (
594
+ input_stream = f , output_stream = out , args = args , extra_keys = extra_keys , filename = filename
595
+ ):
573
596
any_change = True
574
597
575
598
except nbformat .reader .NotJSONError :
0 commit comments