Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[utils] Backport traverse_obj (etc) from yt-dlp #31156

Merged
merged 16 commits into from
Nov 3, 2022
20 changes: 20 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

# Various small unit tests
import io
import itertools
import json
import re
import xml.etree.ElementTree

from youtube_dl.utils import (
Expand Down Expand Up @@ -40,10 +42,12 @@
get_element_by_attribute,
get_elements_by_class,
get_elements_by_attribute,
get_first,
InAdvancePagedList,
int_or_none,
intlist_to_bytes,
is_html,
join_nonempty,
js_to_json,
dirkf marked this conversation as resolved.
Show resolved Hide resolved
limit_length,
merge_dicts,
Expand Down Expand Up @@ -79,6 +83,7 @@
strip_or_none,
subtitles_filename,
timeconvert,
traverse_obj,
dirkf marked this conversation as resolved.
Show resolved Hide resolved
unescapeHTML,
unified_strdate,
unified_timestamp,
Expand Down Expand Up @@ -112,12 +117,18 @@
compat_getenv,
compat_os_name,
compat_setenv,
compat_str,
compat_urlparse,
compat_parse_qs,
)


class TestUtil(unittest.TestCase):

# yt-dlp shim
def assertCountEqual(self, expected, got, msg='count should be the same'):
return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg)

def test_timeconvert(self):
self.assertTrue(timeconvert('') is None)
self.assertTrue(timeconvert('bougrg') is None)
Expand Down Expand Up @@ -1475,6 +1486,15 @@ def test_clean_podcast_url(self):
self.assertEqual(clean_podcast_url('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3'), 'https://traffic.megaphone.fm/HSW7835899191.mp3')
self.assertEqual(clean_podcast_url('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'), 'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3')

def test_traverse_obj(self):
dirkf marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(traverse_obj({'a': [{'b': 'c'}]}, ('a', Ellipsis, 'b')), ['c'])
dirkf marked this conversation as resolved.
Show resolved Hide resolved

def test_get_first(self):
self.assertEqual(get_first([{'a': 'b'}], 'a'), 'b')
dirkf marked this conversation as resolved.
Show resolved Hide resolved

def test_join_nonempty(self):
self.assertEqual(join_nonempty('a', 'b'), 'a-b')
self.assertEqual(join_nonempty('a', 'b', from_dict={'a': 'c', 'b': 'd'}), 'c-d')
dirkf marked this conversation as resolved.
Show resolved Hide resolved

dirkf marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == '__main__':
unittest.main()
156 changes: 156 additions & 0 deletions youtube_dl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
compat_HTTPError,
compat_basestring,
compat_chr,
compat_collections_abc,
compat_cookiejar,
compat_ctypes_WINFUNCTYPE,
compat_etree_fromstring,
Expand Down Expand Up @@ -1684,6 +1685,7 @@ def random_user_agent():


NO_DEFAULT = object()
IDENTITY = lambda x: x
dirkf marked this conversation as resolved.
Show resolved Hide resolved

ENGLISH_MONTH_NAMES = [
'January', 'February', 'March', 'April', 'May', 'June',
Expand Down Expand Up @@ -4068,6 +4070,23 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):
return d.get(key_or_keys, default)


def try_call(*funcs, **kwargs):

# parameter defaults
expected_type = kwargs.get('expected_type')
fargs = kwargs.get('args', [])
fkwargs = kwargs.get('kwargs', {})

for f in funcs:
try:
val = f(*fargs, **fkwargs)
except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError):
pass
else:
if expected_type is None or isinstance(val, expected_type):
return val


def try_get(src, getter, expected_type=None):
if not isinstance(getter, (list, tuple)):
getter = [getter]
Expand Down Expand Up @@ -5801,3 +5820,140 @@ def clean_podcast_url(url):
st\.fm # https://podsights.com/docs/
)/e
)/''', '', url)


def traverse_obj(obj, *path_list, **kwargs):
''' Traverse nested list/dict/tuple
@param path_list A list of paths which are checked one by one.
Each path is a list of keys where each key is a:
- None: Do nothing
- string: A dictionary key
- int: An index into a list
- tuple: A list of keys all of which will be traversed
- Ellipsis: Fetch all values in the object
- Function: Takes the key and value as arguments
and returns whether the key matches or not
@param default Default value to return
@param expected_type Only accept final value of this type (Can also be any callable)
@param get_all Return all the values obtained from a path or only the first one
@param casesense Whether to consider dictionary keys as case sensitive
@param is_user_input Whether the keys are generated from user input. If True,
strings are converted to int/slice if necessary
@param traverse_string Whether to traverse inside strings. If True, any
non-compatible object will also be converted into a string
# TODO: Write tests
'''

# parameter defaults
default = kwargs.get('default')
expected_type = kwargs.get('expected_type')
get_all = kwargs.get('get_all', True)
casesense = kwargs.get('casesense', True)
is_user_input = kwargs.get('is_user_input', False)
traverse_string = kwargs.get('traverse_string', False)

def listish(l):
# TODO support LazyList when ported
return isinstance(l, (list, tuple))

def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
for it in iterables:
for element in it:
yield element

class Nonlocal:
pass
nl = Nonlocal()

if not casesense:
_lower = lambda k: (k.lower() if isinstance(k, compat_str) else k)
path_list = (map(_lower, variadic(path)) for path in path_list)

def _traverse_obj(obj, path, _current_depth=0):
path = tuple(variadic(path))
for i, key in enumerate(path):
if None in (key, obj):
return obj
if listish(key):
obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key]
key = Ellipsis
if key is Ellipsis:
obj = (obj.values() if isinstance(obj, dict)
else obj if listish(obj)
else compat_str(obj) if traverse_string else [])
_current_depth += 1
nl.depth = max(nl.depth, _current_depth)
return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
elif callable(key):
if listish(obj):
obj = enumerate(obj)
elif isinstance(obj, dict):
obj = obj.items()
else:
if not traverse_string:
return None
obj = compat_str(obj)
_current_depth += 1
nl.depth = max(nl.depth, _current_depth)
return [_traverse_obj(v, path[i + 1:], _current_depth) for k, v in obj if try_call(key, args=(k, v))]
elif isinstance(obj, dict) and not (is_user_input and key == ':'):
obj = (obj.get(key) if casesense or (key in obj)
else next((v for k, v in obj.items() if _lower(k) == key), None))
else:
if is_user_input:
key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
if key == slice(None):
return _traverse_obj(obj, (Ellipsis,) + path[i + 1:], _current_depth)
if not isinstance(key, (int, slice)):
return None
if not listish(obj):
if not traverse_string:
return None
obj = compat_str(obj)
try:
obj = obj[key]
except IndexError:
return None
return obj

if isinstance(expected_type, type):
type_test = lambda val: val if isinstance(val, expected_type) else None
else:
type_test = expected_type or IDENTITY

for path in path_list:
nl.depth = 0
val = _traverse_obj(obj, path)
if val is not None:
if nl.depth:
for _ in range(nl.depth - 1):
val = from_iterable(v for v in val if v is not None)
val = [v for v in map(type_test, val) if v is not None]
if val:
return val if get_all else val[0]
else:
val = type_test(val)
if val is not None:
return val
return default
dirkf marked this conversation as resolved.
Show resolved Hide resolved


def get_first(obj, keys, **kwargs):
return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)


def variadic(x, allowed_types=(str, bytes, dict)):
return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,)


dirkf marked this conversation as resolved.
Show resolved Hide resolved
def join_nonempty(*values, **kwargs):

# parameter defaults
delim = kwargs.get('delim', '-')
from_dict = kwargs.get('from_dict')

if from_dict is not None:
values = (traverse_obj(from_dict, variadic(v)) for v in values)
return delim.join(map(str, filter(None, values)))
dirkf marked this conversation as resolved.
Show resolved Hide resolved