From 05c0fe7ed8efeb5937b88501e9a5013d4a6e2fb3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 13 Nov 2024 18:45:28 +0000 Subject: [PATCH] [Minor] print_directory_tree returns a string ghstack-source-id: d57f19dd8efcef06676fca40a4d6f95367ff1d55 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1086 (cherry picked from commit 2b19ef1112dde7c114d3971f35b267fcff09bfe2) --- tensordict/utils.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index e43e4e8fd..fb485573c 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2138,7 +2138,7 @@ def _is_json_serializable(item): return isinstance(item, (str, int, float, bool)) or item is None -def print_directory_tree(path, indent="", display_metadata=True): +def print_directory_tree(path, indent="", display_metadata=True) -> str: """Prints the directory tree starting from the specified path. Args: @@ -2147,7 +2147,11 @@ def print_directory_tree(path, indent="", display_metadata=True): display_metadata (bool): if ``True``, metadata of the dir will be displayed too. + Returns: + the string printed with the logger. + """ + string = [] if display_metadata: def get_directory_size(path="."): @@ -2169,17 +2173,23 @@ def format_size(size): total_size_bytes = get_directory_size(path) formatted_size = format_size(total_size_bytes) - logger.info(f"Directory size: {formatted_size}") + string.append(f"Directory size: {formatted_size}") + logger.info(string[-1]) if os.path.isdir(path): - logger.info(indent + os.path.basename(path) + "/") + string.append(indent + os.path.basename(path) + "/") + logger.info(string[-1]) indent += " " for item in os.listdir(path): - print_directory_tree( - os.path.join(path, item), indent=indent, display_metadata=False + string.append( + print_directory_tree( + os.path.join(path, item), indent=indent, display_metadata=False + ) ) else: - logger.info(indent + os.path.basename(path)) + string.append(indent + os.path.basename(path)) + logger.info(string[-1]) + return "\n".join(string) def isin(