Skip to content

Commit

Permalink
add option to disable logger while compiling to avoid graph breaks (#…
Browse files Browse the repository at this point in the history
…6496)

adding an option to disable calls for logger while compiling to avoid
graph breaks. Here I used an environment variable to determine whether
to activate this option, but it can also be determined using the json
config file or any other way you see fit.

---------

Co-authored-by: snahir <snahir@habana.ai>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 15, 2024
1 parent bf60fc0 commit ce468c3
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions deepspeed/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import logging
import sys
import os
import torch
from deepspeed.runtime.compiler import is_compile_supported

log_levels = {
"debug": logging.DEBUG,
Expand All @@ -19,6 +21,31 @@

class LoggerFactory:

def create_warning_filter(logger):
warn = False

def warn_once(record):
nonlocal warn
if is_compile_supported() and torch.compiler.is_compiling() and not warn:
warn = True
logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
return True

return warn_once

@staticmethod
def logging_decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if torch.compiler.is_compiling():
return
else:
return func(*args, **kwargs)

return wrapper

@staticmethod
def create_logger(name=None, level=logging.INFO):
"""create a logger
Expand All @@ -44,6 +71,12 @@ def create_logger(name=None, level=logging.INFO):
ch.setLevel(level)
ch.setFormatter(formatter)
logger_.addHandler(ch)
if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']:
original_logger = getattr(logger_, method)
setattr(logger_, method, LoggerFactory.logging_decorator(original_logger))
else:
logger_.addFilter(LoggerFactory.create_warning_filter(logger_))
return logger_


Expand Down

0 comments on commit ce468c3

Please sign in to comment.