From e13cbc9920dd2d5bf1a44ace11938e204b3a2bed Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 23 Nov 2022 14:10:38 -0800 Subject: [PATCH 1/2] allow-encoded-ds-config --- src/accelerate/utils/deepspeed.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/accelerate/utils/deepspeed.py b/src/accelerate/utils/deepspeed.py index 02d1ab8bc9b..b56334223d2 100644 --- a/src/accelerate/utils/deepspeed.py +++ b/src/accelerate/utils/deepspeed.py @@ -15,6 +15,8 @@ import io import json from copy import deepcopy +import base64 +import os from ..optimizer import AcceleratedOptimizer from ..scheduler import AcceleratedScheduler @@ -43,11 +45,18 @@ def __init__(self, config_file_or_dict): # Don't modify user's data should they want to reuse it (e.g. in tests), because once we # modified it, it will not be accepted here again, since `auto` values would have been overridden config = deepcopy(config_file_or_dict) - elif isinstance(config_file_or_dict, str): + elif os.path.exists(config_file_or_dict): with io.open(config_file_or_dict, "r", encoding="utf-8") as f: config = json.load(f) else: - raise ValueError("expecting either a path to a DeepSpeed config file or a pre-populated dict") + try: + config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode('utf-8') + config = json.loads(config_decoded) + except (UnicodeDecodeError, AttributeError): + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config}" + ) + self.config = config # zero stage - this is done as early as possible, before model is created, to allow From 99e7ffae442c9aa15854f30e8e037c2d6eea16e1 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Mon, 9 Jan 2023 15:29:55 -0800 Subject: [PATCH 2/2] fix style --- src/accelerate/utils/deepspeed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/accelerate/utils/deepspeed.py b/src/accelerate/utils/deepspeed.py index b56334223d2..afd876e9dc5 100644 --- a/src/accelerate/utils/deepspeed.py +++ b/src/accelerate/utils/deepspeed.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import io import json -from copy import deepcopy -import base64 import os +from copy import deepcopy from ..optimizer import AcceleratedOptimizer from ..scheduler import AcceleratedScheduler @@ -50,7 +50,7 @@ def __init__(self, config_file_or_dict): config = json.load(f) else: try: - config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode('utf-8') + config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8") config = json.loads(config_decoded) except (UnicodeDecodeError, AttributeError): raise ValueError(