
Add HealthChecker Callback #2002

Merged 15 commits on Feb 28, 2023
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -9,6 +9,7 @@
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.health_checker import HealthChecker
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
@@ -29,5 +30,6 @@
    'ExportForInferenceCallback',
    'ThresholdStopper',
    'ImageVisualizer',
    'HealthChecker',
    'RuntimeEstimator',
]
238 changes: 238 additions & 0 deletions composer/callbacks/health_checker.py
@@ -0,0 +1,238 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Check GPU Health during training."""
import logging
from collections import deque
from datetime import datetime
from typing import List, Optional, Tuple

import torch

try:
    import pynvml
except ImportError:
    pynvml = None

import os

import numpy as np
from slack_sdk.webhook import WebhookClient

from composer.core import Callback, State
from composer.core.time import Timestamp
from composer.loggers import Logger
from composer.utils import dist

log = logging.getLogger(__name__)

__all__ = ['HealthChecker']


class HealthChecker(Callback):
"""Checks for GPU health.

This callback checks for GPU health by tracking and alerting for abnormal
hanlint marked this conversation as resolved.
Show resolved Hide resolved
GPU utilizations.

For example, if the average utilization during the observation window is,
[30, 30, 45], then the range (45-30=15) would exceed a threshold of 10%.

Args:
threshold (float, optional): Threshold of GPU utilization range to
trigger an alert. Defaults to 10.
sample_freq (int, optional): Sample frequency in seconds. Default: 5.
window_size (int, optional): Window size in seconds. HealthChecker will
check for abnormalities at this frequency. Default: 120.
wait (int, optional): Seconds to wait for starting to sample. Default: 120.
slack_webhook_url (str, optional): Slack URL to send alerts. Can also
be set with the SLACK_WEBHOOK_URL environment variable. Default: None
test_mode (bool, optional): If True, will send a test alert at the first check.
Default: False
"""

    def __init__(
        self,
        threshold: float = 10,
        sample_freq: int = 5,
        window_size: int = 120,
        wait: int = 120,
        slack_webhook_url: Optional[str] = None,
        test_mode: bool = False,
    ) -> None:
        self.sample_freq = sample_freq
        self.window_size = window_size
        self.wait = wait
        self.slack_webhook_url = slack_webhook_url
        self.test_mode = test_mode

        if not self.slack_webhook_url:
            self.slack_webhook_url = os.environ.get('SLACK_WEBHOOK_URL', None)

        self.last_sample = 0
        self.last_check = 0

        self.metrics = []
        if self._is_available():
            self.metrics.append(GPUUtilization(threshold))

    def init(self, state: State, logger: Logger) -> None:
        pass

    def after_train_batch(self, state: State, logger: Logger):
        if not self.metrics:
            return

        if self._sample(state.timestamp):
            for metric in self.metrics:
                metric.sample()

        if self._check(state.timestamp):
            for metric in self.metrics:
                message, alert = metric.check()
                if self.test_mode and message:
                    alert = True
                    message = '[**THIS IS A TEST**]' + message
                if alert and not metric.alerted:
                    self._alert(message, state)
                    metric.alerted = True
                metric.clear()

    def _sample(self, timestamp: Timestamp) -> bool:
        now = timestamp.total_wct.seconds

        if now < self.wait:
            return False

        if now - self.last_sample >= self.sample_freq:
            self.last_sample = now
            return True

        return False

    def _check(self, timestamp: Timestamp) -> bool:
        now = timestamp.total_wct.seconds

        if now - self.last_check >= self.window_size:
            self.last_check = now
            return True
        return False

    def _alert(self, message: str, state: State) -> None:
        prefix = '[{now}][{run_name}][node_rank={node_rank}]'.format(
            now=datetime.now(),
            run_name=state.run_name,
            node_rank=dist.get_node_rank(),
        )

        node_name = os.environ.get('NODENAME', None)
        if node_name:
            prefix += f'[node={node_name}]'

        message = prefix + ' : ' + message

        logging.warning(message)
        if self.slack_webhook_url:
            client = WebhookClient(url=self.slack_webhook_url)
            client.send(text=message)

    @staticmethod
    def _is_available() -> bool:
        if not torch.cuda.is_available():
            return False
        try:
            pynvml.nvmlInit()  # type: ignore
            return True
        except pynvml.NVMLError_LibraryNotFound:  # type: ignore
            logging.warning('NVML not found, disabling GPU health checking')
        except ImportError:
            logging.warning('pynvml library not found, disabling GPU health checking.')
        except Exception as e:
            logging.warning(f'Error initializing NVML: {e}')

        return False


class GPUUtilization:
    """GPU Utilization Metric."""

    alerted: bool = False

    def __init__(self, threshold=10) -> None:
        self.samples = deque()
        self.threshold = threshold

    def sample(self) -> None:
        if dist.get_local_rank() == 0:
            sample = self._sample()
            if sample is not None:
                self.samples.append(sample)

    def _sample(self) -> Optional[List]:
        try:
            samples = []
            device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
            for i in range(device_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
                samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)  # type: ignore
        except pynvml.NVMLError:  # type: ignore
            return None
        return samples

    def check(self) -> Tuple[Optional[str], bool]:
        if dist.get_local_rank() == 0:
            average_sample = np.nanmean(list(self.samples), axis=0)
            if np.nanmax(average_sample) - np.nanmin(average_sample) > self.threshold:
                message = f'Abnormal GPU utilizations: {average_sample}'
                return message, True
            else:
                message = f':+1: Normal GPU utilizations: {average_sample}'
                return message, False
        return None, False

    def clear(self) -> None:
        self.samples.clear()


class ECCErrors:
    """Metric for ECC counters."""

    alerted: bool = False

    def __init__(self, threshold=100) -> None:
        self.samples = deque()
        self.threshold = threshold

    def sample(self) -> None:
        if dist.get_local_rank() == 0:
            sample = self._sample()
            if sample is not None:
                self.samples.append(sample)

    def _sample(self) -> Optional[List]:
        try:
            samples = []
            device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
            for i in range(device_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
                # query the NVML memory error counter (error type, counter type, location) for this device
                samples.append(pynvml.nvmlDeviceGetMemoryErrorCounter(handle, 0, 0, 2))  # type: ignore
        except pynvml.NVMLError:  # type: ignore
            return None
        return samples

    def check(self) -> Tuple[Optional[str], bool]:
        if dist.get_local_rank() == 0:
            min_counter = np.min(list(self.samples), axis=0)
            max_counter = np.max(list(self.samples), axis=0)
            # np.where returns a tuple of arrays; take the first element to get the GPU indices
            gpus_with_error = np.where(max_counter - min_counter > self.threshold)[0]
            if len(gpus_with_error) > 0:
                message = 'High memory ECC error for GPUs: {gpus}'
                ecc_data = ['GPU: {} ({} -> {})'.format(i, min_counter[i], max_counter[i]) for i in gpus_with_error]
                return message.format(gpus=ecc_data), True

        return None, False

    def clear(self) -> None:
        self.samples.clear()
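
A minimal usage sketch for reviewers who want to try the callback locally, assuming the standard Composer Trainer entry point; my_composer_model, my_train_dataloader, and the commented-out webhook URL are placeholders, not part of this PR:

from composer import Trainer
from composer.callbacks import HealthChecker

# Placeholder model/dataloader: any ComposerModel and torch DataLoader would do here.
trainer = Trainer(
    model=my_composer_model,
    train_dataloader=my_train_dataloader,
    max_duration='1ep',
    callbacks=[
        HealthChecker(
            threshold=10,     # alert when the per-GPU utilization range exceeds 10%
            sample_freq=5,    # sample NVML every 5 seconds
            window_size=120,  # evaluate the window every 120 seconds
            wait=120,         # skip the first 120 seconds of training
            # slack_webhook_url defaults to the SLACK_WEBHOOK_URL environment variable
        ),
    ],
)
trainer.fit()
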
2 changes: 2 additions & 0 deletions setup.py
@@ -87,6 +87,8 @@ def package_files(prefix: str, directory: str, extension: str):
    'py-cpuinfo>=8.0.0,<10',
    'packaging>=21.3.0,<23',
    'importlib-metadata>=5.0.0,<7',
    'pynvml>=11.5.0,<12',
    'slack_sdk>=3.19.5,<4',
]
extra_deps = {}

130 changes: 130 additions & 0 deletions tests/callbacks/test_health_checker.py
@@ -0,0 +1,130 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import datetime
from unittest.mock import MagicMock, patch

import pytest

from composer import Timestamp
from composer.callbacks import HealthChecker
from composer.callbacks.health_checker import ECCErrors, GPUUtilization
from composer.utils import dist
from tests.common import world_size


class MockUtil:

    def __init__(self, util):
        self.gpu = util


@pytest.mark.gpu
@world_size(1, 2)
def test_gpu_utilization(world_size):
    import pynvml
    HealthChecker._is_available()

    gpu_utilization_values = [
        MockUtil(100),
        MockUtil(10),
        MockUtil(100),
        MockUtil(100),
        MockUtil(100),
        MockUtil(100),
    ]

    with patch.multiple(pynvml,
                        nvmlDeviceGetUtilizationRates=MagicMock(side_effect=gpu_utilization_values),
                        nvmlDeviceGetCount=MagicMock(return_value=world_size)):

        gpu_utilization = GPUUtilization()
        gpu_utilization.sample()
        gpu_utilization.sample()
        gpu_utilization.sample()
        _, alert = gpu_utilization.check()

        should_alert = dist.get_local_rank() == 0 and world_size > 1
        assert alert == should_alert


@pytest.mark.gpu
@world_size(1, 2)
def test_ecc_counters(world_size):
    import pynvml
    HealthChecker._is_available()

    ecc_counters = [0, 0, 150, 0, 300, 0]

    with patch.multiple(pynvml,
                        nvmlDeviceGetMemoryErrorCounter=MagicMock(side_effect=ecc_counters),
                        nvmlDeviceGetCount=MagicMock(return_value=world_size)):

        ecc_counter = ECCErrors()
        ecc_counter.sample()
        ecc_counter.sample()
        ecc_counter.sample()
        _, alert = ecc_counter.check()

        # only local rank 0 alerts
        assert alert == (dist.get_local_rank() == 0)


@pytest.mark.gpu
@world_size(1, 2)
def test_health_checker(world_size):
    import pynvml

    state = MagicMock()
    state.run_name = 'pytest-mock-run-kwei73'
    logger = MagicMock()

    health_checker = HealthChecker(
        sample_freq=1,
        window_size=3,
        wait=0,
    )

    gpu_utilization_values = [
        MockUtil(100),
        MockUtil(10),
        MockUtil(100),
        MockUtil(100),
        MockUtil(100),
        MockUtil(100),
    ]

    with patch.multiple(pynvml,
                        nvmlDeviceGetUtilizationRates=MagicMock(side_effect=gpu_utilization_values),
                        nvmlDeviceGetCount=MagicMock(return_value=world_size)):

        # collect data and check
        for seconds in [1, 2, 3]:
            state.timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
            health_checker.after_train_batch(state, logger)

        should_alert = dist.get_local_rank() == 0 and world_size > 1
        assert health_checker.metrics[0].alerted == should_alert


def test_health_checker_sampling():
    timestamp = Timestamp(total_wct=datetime.timedelta(seconds=0))

    health_checker = HealthChecker(
        sample_freq=1,
        window_size=5,
        wait=10,
    )

    config = [
        (5, False),  # before wait
        (11, True),
        (11.5, False),  # below sample frequency
        (12, True),
        (20, True),
        (11, False),  # no time travel
    ]

    for seconds, is_sample in config:
        timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
        assert health_checker._sample(timestamp) == is_sample
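
Separate from the sampling-cadence test above, the alert condition exercised in test_gpu_utilization and test_health_checker reduces to a range check over per-GPU utilization averages, as in GPUUtilization.check(). A small numpy sketch of that rule with illustrative values (not taken from a real run):

import numpy as np

# Three sampling rounds across two GPUs; GPU 1 dips in the second round.
samples = [[100, 100], [100, 10], [100, 100]]
threshold = 10

average = np.nanmean(samples, axis=0)             # per-GPU mean over the window -> [100., 70.]
spread = np.nanmax(average) - np.nanmin(average)  # 30.0
assert spread > threshold                         # range exceeds the 10% threshold -> alert fires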