Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[utils] Backport traverse_obj (etc) from yt-dlp #31156

Merged
merged 16 commits into from
Nov 3, 2022
20 changes: 20 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

# Various small unit tests
import io
import itertools
import json
import re
import xml.etree.ElementTree

from youtube_dl.utils import (
Expand Down Expand Up @@ -40,10 +42,12 @@
get_element_by_attribute,
get_elements_by_class,
get_elements_by_attribute,
get_first,
InAdvancePagedList,
int_or_none,
intlist_to_bytes,
is_html,
join_nonempty,
js_to_json,
dirkf marked this conversation as resolved.
Show resolved Hide resolved
limit_length,
merge_dicts,
Expand Down Expand Up @@ -79,6 +83,7 @@
strip_or_none,
subtitles_filename,
timeconvert,
traverse_obj,
dirkf marked this conversation as resolved.
Show resolved Hide resolved
unescapeHTML,
unified_strdate,
unified_timestamp,
Expand Down Expand Up @@ -112,12 +117,18 @@
compat_getenv,
compat_os_name,
compat_setenv,
compat_str,
compat_urlparse,
compat_parse_qs,
)


class TestUtil(unittest.TestCase):

# yt-dlp shim
def assertCountEqual(self, expected, got, msg='count should be the same'):
    """Assert that `expected` and `got` contain the same number of items.

    Only the lengths are compared, not the items themselves. Both
    arguments are materialised into tuples, so one-shot iterators are
    consumed by the check.
    """
    expected_count = len(tuple(expected))
    got_count = len(tuple(got))
    return self.assertEqual(expected_count, got_count, msg=msg)

def test_timeconvert(self):
self.assertTrue(timeconvert('') is None)
self.assertTrue(timeconvert('bougrg') is None)
Expand Down Expand Up @@ -1475,6 +1486,15 @@ def test_clean_podcast_url(self):
self.assertEqual(clean_podcast_url('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3'), 'https://traffic.megaphone.fm/HSW7835899191.mp3')
self.assertEqual(clean_podcast_url('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'), 'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3')

def test_traverse_obj(self):
    """An Ellipsis path element branches over every item of the list under 'a'."""
    nested = {'a': [{'b': 'c'}]}
    branching_path = ('a', Ellipsis, 'b')
    self.assertEqual(traverse_obj(nested, branching_path), ['c'])
dirkf marked this conversation as resolved.
Show resolved Hide resolved

def test_get_first(self):
    """get_first() returns the first match for the key across the list items."""
    candidates = [{'a': 'b'}]
    self.assertEqual(get_first(candidates, 'a'), 'b')
dirkf marked this conversation as resolved.
Show resolved Hide resolved

def test_join_nonempty(self):
    """Values are joined with '-' directly, or after being looked up in from_dict."""
    lookup = {'a': 'c', 'b': 'd'}
    self.assertEqual(join_nonempty('a', 'b'), 'a-b')
    self.assertEqual(join_nonempty('a', 'b', from_dict=lookup), 'c-d')
dirkf marked this conversation as resolved.
Show resolved Hide resolved

dirkf marked this conversation as resolved.
Show resolved Hide resolved
# Run the test suite when this file is executed directly as a script
if __name__ == '__main__':
    unittest.main()
259 changes: 259 additions & 0 deletions youtube_dl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
compat_HTTPError,
compat_basestring,
compat_chr,
compat_collections_abc,
compat_cookiejar,
compat_ctypes_WINFUNCTYPE,
compat_etree_fromstring,
Expand Down Expand Up @@ -1684,6 +1685,7 @@ def random_user_agent():


# Unique marker object; presumably used as a "no default supplied" sentinel
# so that None can still be passed as a real default -- verify against callers
NO_DEFAULT = object()
# Pass-through callable that returns its argument unchanged
# (traverse_obj() falls back to it when no expected_type filter is given)
IDENTITY = lambda x: x
dirkf marked this conversation as resolved.
Show resolved Hide resolved

ENGLISH_MONTH_NAMES = [
'January', 'February', 'March', 'April', 'May', 'June',
Expand Down Expand Up @@ -3834,6 +3836,105 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
return unrecognized


class LazyList(compat_collections_abc.Sequence):
    """Lazy immutable list from an iterable
    Note that slices of a LazyList are lists and not LazyList

    Items are pulled from the wrapped iterable only when needed and are
    kept in an internal cache, so each item is produced at most once.
    Negative indices, open-ended slices, len(), repr() and str() need the
    whole iterable and therefore must not be used with infinite iterables.
    """

    class IndexError(IndexError):
        """IndexError raised by LazyList lookups; optionally records the original error"""

        def __init__(self, cause=None):
            if cause:
                # reproduce `raise from`
                self.__cause__ = cause
            # Name the nested class explicitly: inside a method a bare
            # `IndexError` resolves to the builtin, which would make super()
            # skip one level of the MRO
            super(LazyList.IndexError, self).__init__()

    def __init__(self, iterable, **kwargs):
        """
        @param iterable  Source of the items
        @param reverse   (kwarg-only) Present the items in reverse order
        @param _cache    (kwarg-only, internal) Existing cache list to share
        """
        # kwarg-only
        reverse = kwargs.get('reverse', False)
        _cache = kwargs.get('_cache')

        self._iterable = iter(iterable)
        self._cache = [] if _cache is None else _cache
        self._reversed = reverse

    def __iter__(self):
        if self._reversed:
            # We need to consume the entire iterable to iterate in reverse
            for item in self.exhaust():
                yield item
            return
        # Replay what has already been fetched, then continue with the source
        for item in self._cache:
            yield item
        for item in self._iterable:
            self._cache.append(item)
            yield item

    def _exhaust(self):
        """Pull every remaining item into the cache and return the cache (forward order)"""
        self._cache.extend(self._iterable)
        self._iterable = []  # Discard the emptied iterable to make it pickle-able
        return self._cache

    def exhaust(self):
        """Evaluate the entire iterable"""
        return self._exhaust()[::-1 if self._reversed else 1]

    @staticmethod
    def _reverse_index(x):
        """Map an index of the reversed view onto the cache: 0 -> -1, 1 -> -2, ..."""
        return None if x is None else ~x

    def __getitem__(self, idx):
        """Return the item (for an int) or a plain list (for a slice) at `idx`

        Raises TypeError for any other index type and LazyList.IndexError
        when an integer index is out of range.
        """
        if isinstance(idx, slice):
            if self._reversed:
                idx = slice(self._reverse_index(idx.start), self._reverse_index(idx.stop), -(idx.step or 1))
            start, stop, step = idx.start, idx.stop, idx.step or 1
        elif isinstance(idx, int):
            if self._reversed:
                idx = self._reverse_index(idx)
            start, stop, step = idx, idx, 0
        else:
            raise TypeError('indices must be integers or slices')
        if ((start or 0) < 0 or (stop or 0) < 0
                or (start is None and step < 0)
                or (stop is None and step > 0)):
            # We need to consume the entire iterable to be able to slice from the end
            # Obviously, never use this with infinite iterables
            self._exhaust()
            try:
                return self._cache[idx]
            except IndexError as e:
                raise self.IndexError(e)
        # Number of extra items needed so that the largest index is cached
        n = max(start or 0, stop or 0) - len(self._cache) + 1
        if n > 0:
            self._cache.extend(itertools.islice(self._iterable, n))
        try:
            return self._cache[idx]
        except IndexError as e:
            raise self.IndexError(e)

    def __bool__(self):
        """True if there is at least one item; fetches at most one item to find out"""
        try:
            self[-1] if self._reversed else self[0]
        except self.IndexError:
            return False
        return True

    # Python 2 looks up __nonzero__, not __bool__; without this alias
    # truthiness would fall back to __len__() and exhaust the iterable
    __nonzero__ = __bool__

    def __len__(self):
        self._exhaust()
        return len(self._cache)

    def __reversed__(self):
        """Return a LazyList over the same source and cache with the order flipped"""
        return type(self)(self._iterable, reverse=not self._reversed, _cache=self._cache)

    def __copy__(self):
        """Return a LazyList sharing this one's source iterator and cache"""
        return type(self)(self._iterable, reverse=self._reversed, _cache=self._cache)

    def __repr__(self):
        # repr and str should mimic a list. So we exhaust the iterable
        return repr(self.exhaust())

    def __str__(self):
        return repr(self.exhaust())


class PagedList(object):
def __len__(self):
# This is only useful for tests
Expand Down Expand Up @@ -4058,6 +4159,10 @@ def multipart_encode(data, boundary=None):
return out, content_type


def variadic(x, allowed_types=(compat_str, bytes, dict)):
    """Return `x` as something that can be iterated as a list of values.

    Iterables are returned unchanged, except for instances of
    `allowed_types` (by default text, bytes and dicts), which count as
    single values. Single values are wrapped in a 1-tuple.
    """
    if isinstance(x, allowed_types):
        return (x,)
    if isinstance(x, compat_collections_abc.Iterable):
        return x
    return (x,)


def dict_get(d, key_or_keys, default=None, skip_false_values=True):
if isinstance(key_or_keys, (list, tuple)):
for key in key_or_keys:
Expand All @@ -4068,6 +4173,23 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):
return d.get(key_or_keys, default)


def try_call(*funcs, **kwargs):
    """Call each of `funcs` in turn and return the first acceptable result.

    @param funcs          Callables to try, in order
    @param expected_type  (kwarg-only) If given, a result is only accepted
                          when it is an instance of this type
    @param args           (kwarg-only) Positional arguments for every call
    @param kwargs         (kwarg-only) Keyword arguments for every call
    @returns              The first accepted result, or None if there is none

    A call that raises AttributeError, KeyError, TypeError, IndexError or
    ZeroDivisionError is skipped; any other exception propagates.
    """
    wanted_type = kwargs.get('expected_type')
    call_args = kwargs.get('args', [])
    call_kwargs = kwargs.get('kwargs', {})
    skippable = (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError)

    for func in funcs:
        try:
            result = func(*call_args, **call_kwargs)
        except skippable:
            continue
        # The type check stays outside the try so its own errors are not swallowed
        if wanted_type is None or isinstance(result, wanted_type):
            return result
    return None


def try_get(src, getter, expected_type=None):
if not isinstance(getter, (list, tuple)):
getter = [getter]
Expand Down Expand Up @@ -5801,3 +5923,140 @@ def clean_podcast_url(url):
st\.fm # https://podsights.com/docs/
)/e
)/''', '', url)


def traverse_obj(obj, *path_list, **kwargs):
    ''' Traverse nested list/dict/tuple
    @param path_list        A list of paths which are checked one by one.
                            Each path is a list of keys where each key is a:
                              - None:     Do nothing
                              - string:   A dictionary key
                              - int:      An index into a list
                              - tuple:    A list of keys all of which will be traversed
                              - Ellipsis: Fetch all values in the object
                              - Function: Takes the key and value as arguments
                                          and returns whether the key matches or not
    @param default          Default value to return
    @param expected_type    Only accept final value of this type (Can also be any callable)
    @param get_all          Return all the values obtained from a path or only the first one
    @param casesense        Whether to consider dictionary keys as case sensitive
    @param is_user_input    Whether the keys are generated from user input. If True,
                            strings are converted to int/slice if necessary
    @param traverse_string  Whether to traverse inside strings. If True, any
                            non-compatible object will also be converted into a string
    @returns                The value found by the first path that yields a result
                            (a list if that path branched and get_all is true),
                            otherwise `default`
    # TODO: Write tests
    '''

    # parameter defaults (all of these are keyword-only)
    default = kwargs.get('default')
    expected_type = kwargs.get('expected_type')
    get_all = kwargs.get('get_all', True)
    casesense = kwargs.get('casesense', True)
    is_user_input = kwargs.get('is_user_input', False)
    traverse_string = kwargs.get('traverse_string', False)

    def listish(l):
        # TODO support LazyList when ported
        return isinstance(l, (list, tuple))

    # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
    from_iterable = itertools.chain.from_iterable

    # Mutable holder standing in for `nonlocal`, which Python 2 lacks;
    # nl.depth records how many branching levels the current path went through
    class Nonlocal:
        pass
    nl = Nonlocal()

    if not casesense:
        _lower = lambda k: (k.lower() if isinstance(k, compat_str) else k)
        path_list = (map(_lower, variadic(path)) for path in path_list)

    def _traverse_obj(obj, path, _current_depth=0):
        path = tuple(variadic(path))
        for i, key in enumerate(path):
            if None in (key, obj):
                return obj
            if listish(key):
                # Follow every alternative key, then branch over the results
                obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key]
                key = Ellipsis
            if key is Ellipsis:
                obj = (obj.values() if isinstance(obj, dict)
                       else obj if listish(obj)
                       else compat_str(obj) if traverse_string else [])
                _current_depth += 1
                nl.depth = max(nl.depth, _current_depth)
                return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
            elif callable(key):
                if listish(obj):
                    obj = enumerate(obj)
                elif isinstance(obj, dict):
                    obj = obj.items()
                else:
                    if not traverse_string:
                        return None
                    # Enumerate so that the loop below receives (index, char)
                    # pairs; iterating the bare string cannot be unpacked
                    obj = enumerate(compat_str(obj))
                _current_depth += 1
                nl.depth = max(nl.depth, _current_depth)
                return [_traverse_obj(v, path[i + 1:], _current_depth) for k, v in obj if try_call(key, args=(k, v))]
            elif isinstance(obj, dict) and not (is_user_input and key == ':'):
                obj = (obj.get(key) if casesense or (key in obj)
                       else next((v for k, v in obj.items() if _lower(k) == key), None))
            else:
                # Keys that are already int/slice need no conversion; testing
                # `':' not in key` on them would raise TypeError
                if is_user_input and not isinstance(key, (int, slice)):
                    key = (int_or_none(key) if ':' not in key
                           else slice(*map(int_or_none, key.split(':'))))
                    if key == slice(None):
                        # A bare ':' selects everything, like Ellipsis
                        return _traverse_obj(obj, (Ellipsis,) + path[i + 1:], _current_depth)
                if not isinstance(key, (int, slice)):
                    return None
                if not listish(obj):
                    if not traverse_string:
                        return None
                    obj = compat_str(obj)
                try:
                    obj = obj[key]
                except IndexError:
                    return None
        return obj

    if isinstance(expected_type, type):
        type_test = lambda val: val if isinstance(val, expected_type) else None
    else:
        type_test = expected_type or IDENTITY

    for path in path_list:
        nl.depth = 0
        val = _traverse_obj(obj, path)
        if val is not None:
            if nl.depth:
                # Each branching level added one level of list nesting;
                # flatten all but the outermost and drop the misses
                for _ in range(nl.depth - 1):
                    val = from_iterable(v for v in val if v is not None)
                val = [v for v in map(type_test, val) if v is not None]
                if val:
                    return val if get_all else val[0]
            else:
                val = type_test(val)
                if val is not None:
                    return val
    return default
dirkf marked this conversation as resolved.
Show resolved Hide resolved


def get_first(obj, keys, **kwargs):
    """Return the first match for `keys` among the values contained in `obj`.

    `keys` may be a single key or a sequence of keys forming a path. The
    path is applied to every value of `obj`, and the first result that
    traverse_obj() accepts is returned. Extra keyword arguments are passed
    on to traverse_obj().
    """
    branching_path = (Ellipsis,) + tuple(variadic(keys))
    return traverse_obj(obj, branching_path, get_all=False, **kwargs)


def variadic(x, allowed_types=(compat_str, bytes, dict)):
    """Return `x` as something that can be iterated as a list of values.

    Iterables are returned unchanged, except for instances of
    `allowed_types` (by default text, bytes and dicts), which count as
    single values. Single values are wrapped in a 1-tuple.

    NOTE: this repeats the definition of variadic() that appears earlier
    in this module and, being later, is the one that takes effect. The
    default uses compat_str, like that earlier definition: with plain
    `str`, Python 2 text strings would not be recognised as single values.
    """
    return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,)


dirkf marked this conversation as resolved.
Show resolved Hide resolved
def join_nonempty(*values, **kwargs):
    """Join the truthy items of `values` into a single string.

    @param values     Items to join; falsy ones (None, '', 0, ...) are dropped
    @param delim      (kwarg-only) Separator placed between the items, default '-'
    @param from_dict  (kwarg-only) If given, each value is used as a
                      traverse_obj() path into this object and replaced by
                      the result before the falsy items are dropped
    @returns          The joined string, '' if no item is left
    """

    # parameter defaults
    delim = kwargs.get('delim', '-')
    from_dict = kwargs.get('from_dict')

    if from_dict is not None:
        values = (traverse_obj(from_dict, variadic(v)) for v in values)
    # compat_str, not str: on Python 2 str() of non-ASCII text raises
    # UnicodeEncodeError
    return delim.join(map(compat_str, filter(None, values)))
dirkf marked this conversation as resolved.
Show resolved Hide resolved