Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit cbb7e84

Browse files
committed Jan 18, 2025
Type annotations
1 parent 9e4873d commit cbb7e84

File tree

2 files changed: +68 -42 lines changed
 

‎nbstripout/_nbstripout.py

+45 -22
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@
109109
*.ipynb diff=ipynb
110110
"""
111111

112-
from argparse import ArgumentParser, RawDescriptionHelpFormatter
112+
from argparse import ArgumentParser, RawDescriptionHelpFormatter, Namespace
113113
import collections
114114
import copy
115115
import io
@@ -118,6 +118,7 @@
118118
from pathlib import PureWindowsPath
119119
import re
120120
from subprocess import call, check_call, check_output, CalledProcessError, STDOUT
121+
from typing import Optional
121122
import sys
122123
import warnings
123124

@@ -134,7 +135,7 @@
134135
INSTALL_LOCATION_SYSTEM = 'system'
135136

136137

137-
def _get_system_gitconfig_folder():
138+
def _get_system_gitconfig_folder() -> str:
138139
try:
139140
git_config_output = check_output(
140141
['git', 'config', '--system', '--list', '--show-origin'], universal_newlines=True, stderr=STDOUT
@@ -160,7 +161,9 @@ def _get_system_gitconfig_folder():
160161
return path.abspath(path.dirname(system_gitconfig_file_path))
161162

162163

163-
def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None):
164+
def _get_attrfile(
165+
git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, attrfile: Optional[str] = None
166+
) -> str:
164167
if not attrfile:
165168
if install_location == INSTALL_LOCATION_SYSTEM:
166169
try:
@@ -185,7 +188,7 @@ def _get_attrfile(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=
185188
return attrfile
186189

187190

188-
def _parse_size(num_str):
191+
def _parse_size(num_str: str) -> int:
189192
num_str = num_str.upper()
190193
if num_str[-1].isdigit():
191194
return int(num_str)
@@ -195,11 +198,15 @@ def _parse_size(num_str):
195198
return int(num_str[:-1]) * (10**6)
196199
elif num_str[-1] == 'G':
197200
return int(num_str[:-1]) * (10**9)
198-
else:
199-
raise ValueError(f'Unknown size identifier {num_str[-1]}')
201+
raise ValueError(f'Unknown size identifier {num_str[-1]}')
200202

201203

202-
def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, attrfile=None):
204+
def install(
205+
git_config: str,
206+
install_location: str = INSTALL_LOCATION_LOCAL,
207+
python: Optional[str] = None,
208+
attrfile: Optional[str] = None,
209+
) -> int:
203210
"""Install the git filter and set the git attributes."""
204211
try:
205212
filepath = f'"{PureWindowsPath(python or sys.executable).as_posix()}" -m nbstripout'
@@ -229,7 +236,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
229236
diff_exists = '*.ipynb diff' in attrs
230237

231238
if filt_exists and diff_exists:
232-
return
239+
return 0
233240

234241
try:
235242
with open(attrfile, 'a', newline='') as f:
@@ -242,6 +249,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
242249
print('*.zpln filter=nbstripout', file=f)
243250
if not diff_exists:
244251
print('*.ipynb diff=ipynb', file=f)
252+
return 0
245253
except PermissionError:
246254
print(f'Installation failed: could not write to {attrfile}', file=sys.stderr)
247255

@@ -251,7 +259,7 @@ def install(git_config, install_location=INSTALL_LOCATION_LOCAL, python=None, at
251259
return 1
252260

253261

254-
def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None):
262+
def uninstall(git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, attrfile: Optional[str] = None) -> int:
255263
"""Uninstall the git filter and unset the git attributes."""
256264
try:
257265
call(git_config + ['--unset', 'filter.nbstripout.clean'], stdout=open(devnull, 'w'), stderr=STDOUT)
@@ -274,9 +282,10 @@ def uninstall(git_config, install_location=INSTALL_LOCATION_LOCAL, attrfile=None
274282
f.seek(0)
275283
f.write(''.join(lines))
276284
f.truncate()
285+
return 0
277286

278287

279-
def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
288+
def status(git_config: str, install_location: str = INSTALL_LOCATION_LOCAL, verbose: bool = False) -> int:
280289
"""Return 0 if nbstripout is installed in the current repo, 1 otherwise"""
281290
try:
282291
if install_location == INSTALL_LOCATION_SYSTEM:
@@ -342,22 +351,28 @@ def status(git_config, install_location=INSTALL_LOCATION_LOCAL, verbose=False):
342351
return 1
343352

344353

345-
def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'):
354+
def process_jupyter_notebook(
355+
input_stream: io.IOBase,
356+
output_stream: io.IOBase,
357+
args: Namespace,
358+
extra_keys: list[str],
359+
filename: str = 'input from stdin',
360+
) -> bool:
346361
with warnings.catch_warnings():
347362
warnings.simplefilter('ignore', category=UserWarning)
348363
nb = nbformat.read(input_stream, as_version=nbformat.NO_CONVERT)
349364

350365
nb_orig = copy.deepcopy(nb)
351366
nb_stripped = strip_output(
352-
nb,
353-
args.keep_output,
354-
args.keep_count,
355-
args.keep_id,
356-
extra_keys,
357-
args.drop_empty_cells,
358-
args.drop_tagged_cells.split(),
359-
args.strip_init_cells,
360-
_parse_size(args.max_size),
367+
nb=nb,
368+
keep_output=args.keep_output,
369+
keep_count=args.keep_count,
370+
keep_id=args.keep_id,
371+
extra_keys=extra_keys,
372+
drop_empty_cells=args.drop_empty_cells,
373+
drop_tagged_cells=args.drop_tagged_cells.split(),
374+
strip_init_cells=args.strip_init_cells,
375+
max_size=_parse_size(args.max_size),
361376
)
362377

363378
any_change = nb_orig != nb_stripped
@@ -377,7 +392,13 @@ def process_jupyter_notebook(input_stream, output_stream, args, extra_keys, file
377392
return any_change
378393

379394

380-
def process_zeppelin_notebook(input_stream, output_stream, args, extra_keys, filename='input from stdin'):
395+
def process_zeppelin_notebook(
396+
input_stream: io.IOBase,
397+
output_stream: io.IOBase,
398+
args: Namespace,
399+
extra_keys: list[str],
400+
filename: str = 'input from stdin',
401+
):
381402
nb = json.load(input_stream, object_pairs_hook=collections.OrderedDict)
382403
nb_orig = copy.deepcopy(nb)
383404
nb_stripped = strip_zeppelin_output(nb)
@@ -569,7 +590,9 @@ def main():
569590
try:
570591
with io.open(filename, 'r+', encoding='utf8', newline='') as f:
571592
out = output_stream if args.textconv or args.dry_run else f
572-
if process_notebook(f, out, args, extra_keys, filename):
593+
if process_notebook(
594+
input_stream=f, output_stream=out, args=args, extra_keys=extra_keys, filename=filename
595+
):
573596
any_change = True
574597

575598
except nbformat.reader.NotJSONError:

‎nbstripout/_utils.py

+23 -20
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
11
from collections import defaultdict
22
import sys
3+
from typing import Any, Callable, Iterator, Optional
4+
5+
from nbformat import NotebookNode
36

47
__all__ = ['pop_recursive', 'strip_output', 'strip_zeppelin_output', 'MetadataError']
58

@@ -8,7 +11,7 @@ class MetadataError(Exception):
811
pass
912

1013

11-
def pop_recursive(d, key, default=None):
14+
def pop_recursive(d: dict, key: str, default: Optional[NotebookNode] = None) -> NotebookNode:
1215
"""dict.pop(key) where `key` is a `.`-delimited list of nested keys.
1316
1417
>>> d = {'a': {'b': 1, 'c': 2}}
@@ -25,11 +28,11 @@ def pop_recursive(d, key, default=None):
2528
return default
2629
key_head, key_tail = key.split('.', maxsplit=1)
2730
if key_head in d:
28-
return pop_recursive(d[key_head], key_tail, default)
31+
return pop_recursive(d[key_head], key=key_tail, default=default)
2932
return default
3033

3134

32-
def _cells(nb, conditionals):
35+
def _cells(nb: NotebookNode, conditionals: Callable[[NotebookNode], bool]) -> Iterator[NotebookNode]:
3336
"""Remove cells not satisfying any conditional in conditionals and yield all other cells."""
3437
if hasattr(nb, 'nbformat') and nb.nbformat < 4:
3538
for ws in nb.worksheets:
@@ -44,7 +47,7 @@ def _cells(nb, conditionals):
4447
yield cell
4548

4649

47-
def get_size(item):
50+
def get_size(item: Any) -> int:
4851
"""Recursively sums length of all strings in `item`"""
4952
if isinstance(item, str):
5053
return len(item)
@@ -56,7 +59,7 @@ def get_size(item):
5659
return len(str(item))
5760

5861

59-
def determine_keep_output(cell, default, strip_init_cells=False):
62+
def determine_keep_output(cell: NotebookNode, default: bool, strip_init_cells: bool = False):
6063
"""Given a cell, determine whether output should be kept
6164
6265
Based on whether the metadata has "init_cell": true,
@@ -80,29 +83,29 @@ def determine_keep_output(cell, default, strip_init_cells=False):
8083
return default
8184

8285

83-
def _zeppelin_cells(nb):
86+
def _zeppelin_cells(nb: dict) -> Iterator[dict]:
8487
for pg in nb['paragraphs']:
8588
yield pg
8689

8790

88-
def strip_zeppelin_output(nb):
91+
def strip_zeppelin_output(nb: dict) -> dict:
8992
for cell in _zeppelin_cells(nb):
9093
if 'results' in cell:
9194
cell['results'] = {}
9295
return nb
9396

9497

9598
def strip_output(
96-
nb,
97-
keep_output,
98-
keep_count,
99-
keep_id,
100-
extra_keys=[],
101-
drop_empty_cells=False,
102-
drop_tagged_cells=[],
103-
strip_init_cells=False,
104-
max_size=0,
105-
):
99+
nb: NotebookNode,
100+
keep_output: bool,
101+
keep_count: bool,
102+
keep_id: bool,
103+
extra_keys: list[str] = [],
104+
drop_empty_cells: bool = False,
105+
drop_tagged_cells: list[str] = [],
106+
strip_init_cells: bool = False,
107+
max_size: int = 0,
108+
) -> NotebookNode:
106109
"""
107110
Strip the outputs, execution count/prompt number and miscellaneous
108111
metadata from a notebook object, unless specified to keep either the outputs
@@ -122,7 +125,7 @@ def strip_output(
122125
keys[namespace].append(subkey)
123126

124127
for field in keys['metadata']:
125-
pop_recursive(nb.metadata, field)
128+
pop_recursive(nb.metadata, key=field)
126129

127130
conditionals = []
128131
# Keep cells if they have any `source` line that contains non-whitespace
@@ -132,7 +135,7 @@ def strip_output(
132135
conditionals.append(lambda c: tag_to_drop not in c.get('metadata', {}).get('tags', []))
133136

134137
for i, cell in enumerate(_cells(nb, conditionals)):
135-
keep_output_this_cell = determine_keep_output(cell, keep_output, strip_init_cells)
138+
keep_output_this_cell = determine_keep_output(cell=cell, default=keep_output, strip_init_cells=strip_init_cells)
136139

137140
# Remove the outputs, unless directed otherwise
138141
if 'outputs' in cell:
@@ -157,5 +160,5 @@ def strip_output(
157160
if 'id' in cell and not keep_id:
158161
cell['id'] = str(i)
159162
for field in keys['cell']:
160-
pop_recursive(cell, field)
163+
pop_recursive(cell, key=field)
161164
return nb

0 commit comments

Comments
 (0)
Please sign in to comment.