Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(metrics): append more metrics to nn model (#1042) #1047

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions fedlearner/channel/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import grpc
from fedlearner.common import fl_logging, stats
from fedlearner.channel import channel_pb2, channel_pb2_grpc
from fedlearner.common.metric_collector import metric_collector
from fedlearner.proxy.channel import make_insecure_channel, ChannelType
from fedlearner.channel.client_interceptor import ClientInterceptor
from fedlearner.channel.server_interceptor import ServerInterceptor
Expand Down Expand Up @@ -189,8 +190,7 @@ def __init__(self,
identifier=self._identifier,
retry_interval=self._retry_interval,
wait_fn=self.wait_for_ready,
check_fn=self._channel_response_check_fn,
stats_client=stats_client)
check_fn=self._channel_response_check_fn)
self._channel = grpc.intercept_channel(self._channel,
self._channel_interceptor)

Expand Down Expand Up @@ -404,13 +404,14 @@ def _call_locked(self, call_type):
token=self._token,
identifier=self._identifier,
peer_identifier=self._peer_identifier)
timer = self._stats_client.timer("channel.call_timing").start()
res = self._channel_call.Call(req,
timeout=self._heartbeat_interval,
wait_for_ready=True)
timer.stop()
with metric_collector.emit_timing(
'model.grpc.channel.call_timing'
):
res = self._channel_call.Call(req,
timeout=self._heartbeat_interval,
wait_for_ready=True)
except Exception as e:
self._stats_client.incr("channel.call_error")
metric_collector.emit_counter('model.grpc.channel.call_error', 1)
if isinstance(e, grpc.RpcError):
fl_logging.warning("[Channel] grpc error, code: %s, "
"details: %s.(call type: %s)",
Expand Down Expand Up @@ -469,7 +470,8 @@ def _state_fn(self):
saved_state = self._state
wait_timeout = 10

self._stats_client.gauge("channel.status", self._state.value)
metric_collector.emit_store('model.grpc.channel.status',
self._state.value)
if self._state in (Channel.State.DONE, Channel.State.ERROR):
break

Expand Down
67 changes: 35 additions & 32 deletions fedlearner/channel/client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import grpc

from fedlearner.common import fl_logging, stats
from fedlearner.common import fl_logging
from fedlearner.common.metric_collector import metric_collector
from fedlearner.channel import channel_pb2

Expand All @@ -47,8 +47,7 @@ def __init__(self,
identifier,
retry_interval,
wait_fn=None,
check_fn=None,
stats_client=None):
check_fn=None):
self._retry_interval = retry_interval
self._identifer = identifier
self._wait_fn = wait_fn
Expand All @@ -57,7 +56,6 @@ def __init__(self,
self._method_details = dict()

self._fl_metadata = ("fl-channel-id", self._identifer)
self._stats_client = stats_client or stats.NoneClient()

def _wait(self):
if self._wait_fn:
Expand Down Expand Up @@ -103,8 +101,10 @@ def call():
self._wait()
return continuation(client_call_details, request)

with self._stats_client.timer("channel.client.unary_unary_timing",
tags={"grpc_method": method_details.method}):
with metric_collector.emit_timing(
'model.grpc.interceptor.channel.client.unary_unary_timing',
tags={'grpc_method': method_details.method}
):
_call = _grpc_with_retry(call, self._retry_interval)
return _UnaryOutcome(
method_details.response_deserializer, _call, self._check_fn)
Expand All @@ -125,18 +125,20 @@ def call():
self._wait()
return continuation(client_call_details, request)

timer = self._stats_client.timer(
"channel.client.unary_stream_timing",
tags={"grpc_method": method_details.method}
).start()
start = time.time()

stream_response = _grpc_with_retry(call, self._retry_interval)

def response_iterator():
for response in stream_response:
self._check_fn(response)
yield method_details.response_deserializer(response.payload)
timer.stop()
value = (time.time() - start) * 1000
metric_collector.emit_store(
f'model.grpc.interceptor.channel.client.unary_stream_timing',
value,
tags={'grpc_method': method_details.method}
)

return response_iterator()

Expand All @@ -157,8 +159,10 @@ def call():
consumer = srq.consumer()
return continuation(client_call_details, iter(consumer))

with self._stats_client.timer("channel.client.stream_unary_timing",
tags={"grpc_method": method_details.method}):
with metric_collector.emit_timing(
'model.grpc.interceptor.channel.client.stream_unary_timing',
tags={'grpc_method': method_details.method}
):
_call = _grpc_with_retry(call, self._retry_interval)
return _UnaryOutcome(method_details.response_deserializer,
_call, self._check_fn)
Expand All @@ -174,9 +178,7 @@ def intercept_stream_stream(self, continuation, client_call_details,

srq = _SingleConsumerSendRequestQueue(
request_iterator, method_details.request_serializer,
stats_client=self._stats_client.with_tags(
tags={"grpc_method": method_details.method}
),
stats_tags={'grpc_method': method_details.method},
stats_prefix="channel.client.stream_stream")
consumer = None

Expand Down Expand Up @@ -229,7 +231,7 @@ def __next__(self):
next = __next__

def __init__(self, request_iterator, request_serializer,
stats_client=None, stats_prefix=""):
stats_tags=None, stats_prefix=""):
self._lock = threading.Lock()
self._seq = 0
self._offset = 0
Expand All @@ -239,13 +241,16 @@ def __init__(self, request_iterator, request_serializer,
self._request_lock = threading.Lock()
self._request_iterator = request_iterator
self._request_serializer = request_serializer
self._stats_client = stats_client or stats.NoneClient()
self._stats_tags = stats_tags or {}
self._stats_prefix = stats_prefix

def _reset(self):
if self._offset > 0:
self._stats_client.incr(
"%s_resend"%self._stats_prefix, self._offset)
metric_collector.emit_counter(
f'model.grpc.interceptor.{self._stats_prefix}_resend',
self._offset,
tags=self._stats_tags
)
self._offset = 0

def _empty(self):
Expand Down Expand Up @@ -285,18 +290,16 @@ def ack(self, consumer, ack):
return False
now = time.time()
n = self._seq - ack
with self._stats_client.pipeline() as pipe:
while len(self._deque) >= n:
req = self._deque.popleft()
self._offset -= 1
# TODO(lixiaoguang.01) old version, to be deleted
pipe.timing("%s_timing"%self._stats_prefix,
(now-req.ts)*1000)
# new version
name_prefix = f'model.grpc.interceptor'
value = (now - req.ts) * 1000
metric_collector.emit_store(
f'{name_prefix}.{self._stats_prefix}_timing', value)
while len(self._deque) >= n:
req = self._deque.popleft()
self._offset -= 1
name_prefix = 'model.grpc.interceptor'
value = (now - req.ts) * 1000
metric_collector.emit_store(
f'{name_prefix}.{self._stats_prefix}_timing',
value,
tags=self._stats_tags
)
return True

def next(self, consumer):
Expand Down
28 changes: 24 additions & 4 deletions fedlearner/common/metric_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@

class AbstractCollector(ABC):

@abstractmethod
def add_global_tags(self, global_tags: Dict[str, str]):
pass

@abstractmethod
def emit_single_point(self,
name: str,
Expand Down Expand Up @@ -80,6 +84,9 @@ def __enter__(self):
def __exit__(self, *a):
pass

def add_global_tags(self, global_tags: Dict[str, str]):
pass

def emit_single_point(self,
name: str,
value: Union[int, float],
Expand Down Expand Up @@ -188,6 +195,11 @@ def __init__(
self._cache: \
Dict[str, Union[UpDownCounter, MetricCollector.Callback]] = {}

self._global_tags = {}

def add_global_tags(self, global_tags: Dict[str, str]):
self._global_tags.update(global_tags)

def emit_single_point(self,
name: str,
value: Union[int, float],
Expand All @@ -196,12 +208,13 @@ def emit_single_point(self,
self._meter.create_observable_gauge(
name=f'values.{name}', callback=cb
)
cb.record(value=value, tags=tags)
cb.record(value=value, tags=self._get_merged_tags(tags))

def emit_timing(self,
name: str,
tags: Dict[str, str] = None) -> Iterator[Span]:
return self._tracer.start_as_current_span(name=name, attributes=tags)
return self._tracer.start_as_current_span(
name=name, attributes=self._get_merged_tags(tags))

def emit_counter(self,
name: str,
Expand All @@ -216,7 +229,7 @@ def emit_counter(self,
)
self._cache[name] = counter
assert isinstance(self._cache[name], UpDownCounter)
self._cache[name].add(value, attributes=tags)
self._cache[name].add(value, attributes=self._get_merged_tags(tags))

def emit_store(self,
name: str,
Expand All @@ -232,7 +245,14 @@ def emit_store(self,
)
self._cache[name] = cb
assert isinstance(self._cache[name], self.Callback)
self._cache[name].record(value=value, tags=tags)
self._cache[name].record(value=value,
tags=self._get_merged_tags(tags))

def _get_merged_tags(self, tags: Dict[str, str] = None):
merged = self._global_tags.copy()
if tags is not None:
merged.update(tags)
return merged


enable = True
Expand Down
12 changes: 3 additions & 9 deletions fedlearner/trainer/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,6 @@ def commit(self):
duration = (time.time() - self._iter_started_at) * 1000
self._current_iter_id = None

# TODO(lixiaoguang.01) old version, to be deleted
with _gctx.stats_client.pipeline() as pipe:
pipe.gauge("trainer.bridge.iterator_step", iter_id)
pipe.timing("trainer.bridge.iterator_timing", duration)
# new version
name_prefix = 'model.grpc.bridge'
metric_collector.emit_store(f'{name_prefix}.iterator_step', iter_id)
metric_collector.emit_store(f'{name_prefix}.iterator_timing', duration)
Expand Down Expand Up @@ -485,10 +480,9 @@ def _receive(self, name):
data = self._received_data[iter_id][name]

duration = time.time() - start_time
_gctx.stats_client.timing(
"trainer.bridge.receive_timing", duration * 1000,
{"bridge_receive_name": name}
)
metric_collector.emit_store(f'model.grpc.bridge.receive_timing',
duration * 1000,
{'bridge_receive_name': name})
fl_logging.debug("[Bridge] Data: received iter_id: %d, name: %s "
"after %f sec",
iter_id, name, duration)
Expand Down
14 changes: 11 additions & 3 deletions fedlearner/trainer/parameter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import argparse

import tensorflow.compat.v1 as tf
from fedlearner.common import stats

from fedlearner.common.metric_collector import metric_collector
from fedlearner.trainer.cluster_server import ClusterServer
from fedlearner.trainer._global_context import global_context as _gctx

Expand All @@ -32,8 +33,15 @@
args = parser.parse_args()

_gctx.task = "ps"
stats.enable_cpu_stats(_gctx.stats_client)
stats.enable_mem_stats(_gctx.stats_client)
global_tags = {
'task': _gctx.task,
'task_index': str(_gctx.task_index),
'node_name': os.environ.get('HOSTNAME', 'default_node_name'),
'pod_name': os.environ.get('POD_NAME', 'default_pod_name'),
}
metric_collector.add_global_tags(global_tags)
name_prefix = 'model.common.nn_vertical'
metric_collector.emit_counter(f'{name_prefix}.start_count', 1)

cluster_spec = tf.train.ClusterSpec({'ps': {0: args.address}})
cluster_server = ClusterServer(cluster_spec, "ps")
Expand Down
6 changes: 3 additions & 3 deletions fedlearner/trainer/run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ def after_run(self, run_context, run_values):

def _stats_metric(self, global_step, results):
with self._stats_client.pipeline() as pipe:
# TODO(lixiaoguang.01) old version, to be deleted
pipe.gauge("trainer.metric_global_step", global_step)
# new version
name_prefix = 'model.train.nn_vertical'
metric_collector.emit_store(
f'{name_prefix}.global_step', global_step)
Expand All @@ -104,6 +101,9 @@ def _stats_metric(self, global_step, results):

metric_collector.emit_store(
f'{name_prefix}.{key}', value.sum())
# for compatibility, also emit one with metric name in tags
metric_collector.emit_store(f'{name_prefix}.metric_value',
value.sum(), tags={'metric': key})


class StepMetricsHook(GlobalStepMetricTensorHook):
Expand Down
Loading