Commit

feat(metrics): append more metrics to nn model (#1042) (#1047)
* feat(metrics): append more metrics to nn model

* feat(metrics): fix comments
lixiaoguang01 authored Sep 7, 2022
1 parent 556f6a9 commit 696ef31
Showing 8 changed files with 109 additions and 83 deletions.
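Before the per-file diffs, a minimal sketch of the migration pattern this commit applies throughout: StatsD-style stats_client calls are replaced by the metric_collector facade from fedlearner/common/metric_collector.py. The mapping is inferred from the hunks below; do_rpc and the literal values are illustrative, not part of the commit.

    from fedlearner.common.metric_collector import metric_collector

    def do_rpc():
        pass  # hypothetical stand-in for a gRPC call

    # old: timer = stats_client.timer(name).start(); ...; timer.stop()
    with metric_collector.emit_timing('model.grpc.channel.call_timing'):
        do_rpc()

    # old: stats_client.incr(name, n)
    metric_collector.emit_counter('model.grpc.channel.call_error', 1)

    # old: stats_client.gauge(name, v) / stats_client.timing(name, ms)
    metric_collector.emit_store('model.grpc.channel.status', 3)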
20 changes: 11 additions & 9 deletions fedlearner/channel/channel.py
@@ -24,6 +24,7 @@
 import grpc
 from fedlearner.common import fl_logging, stats
 from fedlearner.channel import channel_pb2, channel_pb2_grpc
+from fedlearner.common.metric_collector import metric_collector
 from fedlearner.proxy.channel import make_insecure_channel, ChannelType
 from fedlearner.channel.client_interceptor import ClientInterceptor
 from fedlearner.channel.server_interceptor import ServerInterceptor
@@ -189,8 +190,7 @@ def __init__(self,
             identifier=self._identifier,
             retry_interval=self._retry_interval,
             wait_fn=self.wait_for_ready,
-            check_fn=self._channel_response_check_fn,
-            stats_client=stats_client)
+            check_fn=self._channel_response_check_fn)
         self._channel = grpc.intercept_channel(self._channel,
                                                self._channel_interceptor)

@@ -404,13 +404,14 @@ def _call_locked(self, call_type):
                 token=self._token,
                 identifier=self._identifier,
                 peer_identifier=self._peer_identifier)
-            timer = self._stats_client.timer("channel.call_timing").start()
-            res = self._channel_call.Call(req,
-                                          timeout=self._heartbeat_interval,
-                                          wait_for_ready=True)
-            timer.stop()
+            with metric_collector.emit_timing(
+                'model.grpc.channel.call_timing'
+            ):
+                res = self._channel_call.Call(req,
+                                              timeout=self._heartbeat_interval,
+                                              wait_for_ready=True)
         except Exception as e:
-            self._stats_client.incr("channel.call_error")
+            metric_collector.emit_counter('model.grpc.channel.call_error', 1)
             if isinstance(e, grpc.RpcError):
                 fl_logging.warning("[Channel] grpc error, code: %s, "
                                    "details: %s.(call type: %s)",
@@ -469,7 +470,8 @@ def _state_fn(self):
             saved_state = self._state
             wait_timeout = 10

-            self._stats_client.gauge("channel.status", self._state.value)
+            metric_collector.emit_store('model.grpc.channel.status',
+                                        self._state.value)
             if self._state in (Channel.State.DONE, Channel.State.ERROR):
                 break
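One behavioral nuance of the _call_locked hunk, shown as a minimal sketch (the call_rpc wrapper is hypothetical): the old timer.stop() was only reached on success, so failed calls produced no timing sample, whereas the with-based emit_timing span is closed by __exit__ on both the success and the exception path.

    from fedlearner.common.metric_collector import metric_collector

    def call_rpc(channel_call, req, timeout):
        # The span ends even if Call raises; the old start()/stop() pair
        # dropped the sample on error because stop() was never reached.
        with metric_collector.emit_timing('model.grpc.channel.call_timing'):
            return channel_call.Call(req, timeout=timeout, wait_for_ready=True)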
67 changes: 35 additions & 32 deletions fedlearner/channel/client_interceptor.py
@@ -20,7 +20,7 @@
 import time
 import grpc

-from fedlearner.common import fl_logging, stats
+from fedlearner.common import fl_logging
+from fedlearner.common.metric_collector import metric_collector
 from fedlearner.channel import channel_pb2

@@ -47,8 +47,7 @@ def __init__(self,
                  identifier,
                  retry_interval,
                  wait_fn=None,
-                 check_fn=None,
-                 stats_client=None):
+                 check_fn=None):
         self._retry_interval = retry_interval
         self._identifer = identifier
         self._wait_fn = wait_fn
@@ -57,7 +56,6 @@ def __init__(self,
         self._method_details = dict()

         self._fl_metadata = ("fl-channel-id", self._identifer)
-        self._stats_client = stats_client or stats.NoneClient()

     def _wait(self):
         if self._wait_fn:
@@ -103,8 +101,10 @@ def call():
                 self._wait()
                 return continuation(client_call_details, request)

-        with self._stats_client.timer("channel.client.unary_unary_timing",
-                tags={"grpc_method": method_details.method}):
+        with metric_collector.emit_timing(
+            'model.grpc.interceptor.channel.client.unary_unary_timing',
+            tags={'grpc_method': method_details.method}
+        ):
             _call = _grpc_with_retry(call, self._retry_interval)
             return _UnaryOutcome(
                 method_details.response_deserializer, _call, self._check_fn)
@@ -125,18 +125,20 @@ def call():
                 self._wait()
                 return continuation(client_call_details, request)

-        timer = self._stats_client.timer(
-            "channel.client.unary_stream_timing",
-            tags={"grpc_method": method_details.method}
-        ).start()
+        start = time.time()

         stream_response = _grpc_with_retry(call, self._retry_interval)

         def response_iterator():
             for response in stream_response:
                 self._check_fn(response)
                 yield method_details.response_deserializer(response.payload)
-            timer.stop()
+            value = (time.time() - start) * 1000
+            metric_collector.emit_store(
+                'model.grpc.interceptor.channel.client.unary_stream_timing',
+                value,
+                tags={'grpc_method': method_details.method}
+            )

         return response_iterator()

@@ -157,8 +159,10 @@ def call():
                 consumer = srq.consumer()
                 return continuation(client_call_details, iter(consumer))

-        with self._stats_client.timer("channel.client.stream_unary_timing",
-                tags={"grpc_method": method_details.method}):
+        with metric_collector.emit_timing(
+            'model.grpc.interceptor.channel.client.stream_unary_timing',
+            tags={'grpc_method': method_details.method}
+        ):
             _call = _grpc_with_retry(call, self._retry_interval)
             return _UnaryOutcome(method_details.response_deserializer,
                                  _call, self._check_fn)
@@ -174,9 +178,7 @@ def intercept_stream_stream(self, continuation, client_call_details,

         srq = _SingleConsumerSendRequestQueue(
             request_iterator, method_details.request_serializer,
-            stats_client=self._stats_client.with_tags(
-                tags={"grpc_method": method_details.method}
-            ),
+            stats_tags={'grpc_method': method_details.method},
             stats_prefix="channel.client.stream_stream")
         consumer = None

@@ -229,7 +231,7 @@ def __next__(self):
     next = __next__

     def __init__(self, request_iterator, request_serializer,
-                 stats_client=None, stats_prefix=""):
+                 stats_tags=None, stats_prefix=""):
         self._lock = threading.Lock()
         self._seq = 0
         self._offset = 0
@@ -239,13 +241,16 @@ def __init__(self, request_iterator, request_serializer,
         self._request_lock = threading.Lock()
         self._request_iterator = request_iterator
         self._request_serializer = request_serializer
-        self._stats_client = stats_client or stats.NoneClient()
+        self._stats_tags = stats_tags or {}
         self._stats_prefix = stats_prefix

     def _reset(self):
         if self._offset > 0:
-            self._stats_client.incr(
-                "%s_resend"%self._stats_prefix, self._offset)
+            metric_collector.emit_counter(
+                f'model.grpc.interceptor.{self._stats_prefix}_resend',
+                self._offset,
+                tags=self._stats_tags
+            )
         self._offset = 0

     def _empty(self):
@@ -285,18 +290,16 @@ def ack(self, consumer, ack):
             return False
         now = time.time()
         n = self._seq - ack
-        with self._stats_client.pipeline() as pipe:
-            while len(self._deque) >= n:
-                req = self._deque.popleft()
-                self._offset -= 1
-                # TODO(lixiaoguang.01) old version, to be deleted
-                pipe.timing("%s_timing"%self._stats_prefix,
-                            (now-req.ts)*1000)
-                # new version
-                name_prefix = f'model.grpc.interceptor'
-                value = (now - req.ts) * 1000
-                metric_collector.emit_store(
-                    f'{name_prefix}.{self._stats_prefix}_timing', value)
+        while len(self._deque) >= n:
+            req = self._deque.popleft()
+            self._offset -= 1
+            name_prefix = 'model.grpc.interceptor'
+            value = (now - req.ts) * 1000
+            metric_collector.emit_store(
+                f'{name_prefix}.{self._stats_prefix}_timing',
+                value,
+                tags=self._stats_tags
+            )
         return True

     def next(self, consumer):
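Why intercept_unary_stream times manually instead of using emit_timing: the interceptor returns before the caller drains the stream, so a with-block at interception time would close the span too early. The commit records the start at interception and emits elapsed milliseconds once the generator is exhausted. A standalone sketch of that pattern, with a plain iterable standing in for the gRPC stream and a hypothetical timed_stream helper:

    import time
    from fedlearner.common.metric_collector import metric_collector

    def timed_stream(stream_response, name, tags):
        # Measure inside the generator and emit only after the last
        # response is yielded, mirroring intercept_unary_stream above.
        start = time.time()

        def response_iterator():
            for response in stream_response:
                yield response
            metric_collector.emit_store(name, (time.time() - start) * 1000,
                                        tags=tags)

        return response_iterator()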
28 changes: 24 additions & 4 deletions fedlearner/common/metric_collector.py
@@ -40,6 +40,10 @@

 class AbstractCollector(ABC):

+    @abstractmethod
+    def add_global_tags(self, global_tags: Dict[str, str]):
+        pass
+
     @abstractmethod
     def emit_single_point(self,
                           name: str,
@@ -80,6 +84,9 @@ def __enter__(self):
         def __exit__(self, *a):
             pass

+    def add_global_tags(self, global_tags: Dict[str, str]):
+        pass
+
     def emit_single_point(self,
                           name: str,
                           value: Union[int, float],
@@ -188,6 +195,11 @@ def __init__(
         self._cache: \
             Dict[str, Union[UpDownCounter, MetricCollector.Callback]] = {}

+        self._global_tags = {}
+
+    def add_global_tags(self, global_tags: Dict[str, str]):
+        self._global_tags.update(global_tags)
+
     def emit_single_point(self,
                           name: str,
                           value: Union[int, float],
@@ -196,12 +208,13 @@ def emit_single_point(self,
         self._meter.create_observable_gauge(
             name=f'values.{name}', callback=cb
         )
-        cb.record(value=value, tags=tags)
+        cb.record(value=value, tags=self._get_merged_tags(tags))

     def emit_timing(self,
                     name: str,
                     tags: Dict[str, str] = None) -> Iterator[Span]:
-        return self._tracer.start_as_current_span(name=name, attributes=tags)
+        return self._tracer.start_as_current_span(
+            name=name, attributes=self._get_merged_tags(tags))

     def emit_counter(self,
                      name: str,
@@ -216,7 +229,7 @@ def emit_counter(self,
             )
             self._cache[name] = counter
         assert isinstance(self._cache[name], UpDownCounter)
-        self._cache[name].add(value, attributes=tags)
+        self._cache[name].add(value, attributes=self._get_merged_tags(tags))

     def emit_store(self,
                    name: str,
@@ -232,7 +245,14 @@ def emit_store(self,
             )
             self._cache[name] = cb
         assert isinstance(self._cache[name], self.Callback)
-        self._cache[name].record(value=value, tags=tags)
+        self._cache[name].record(value=value,
+                                 tags=self._get_merged_tags(tags))
+
+    def _get_merged_tags(self, tags: Dict[str, str] = None):
+        merged = self._global_tags.copy()
+        if tags is not None:
+            merged.update(tags)
+        return merged


 enable = True
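Note the precedence rule _get_merged_tags establishes: per-call tags win over global tags on key collision, because update() runs after copy(). A standalone sketch of just that merge rule (all values made up):

    global_tags = {'task': 'ps', 'pod_name': 'pod-0'}
    call_tags = {'pod_name': 'pod-1', 'grpc_method': 'Transmit'}

    merged = global_tags.copy()   # start from process-wide identity tags
    merged.update(call_tags)      # per-call tags overwrite on collision

    assert merged == {'task': 'ps', 'pod_name': 'pod-1',
                      'grpc_method': 'Transmit'}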
12 changes: 3 additions & 9 deletions fedlearner/trainer/bridge.py
@@ -397,11 +397,6 @@ def commit(self):
         duration = (time.time() - self._iter_started_at) * 1000
         self._current_iter_id = None

-        # TODO(lixiaoguang.01) old version, to be deleted
-        with _gctx.stats_client.pipeline() as pipe:
-            pipe.gauge("trainer.bridge.iterator_step", iter_id)
-            pipe.timing("trainer.bridge.iterator_timing", duration)
-        # new version
         name_prefix = 'model.grpc.bridge'
         metric_collector.emit_store(f'{name_prefix}.iterator_step', iter_id)
         metric_collector.emit_store(f'{name_prefix}.iterator_timing', duration)
@@ -485,10 +480,9 @@ def _receive(self, name):
         data = self._received_data[iter_id][name]

         duration = time.time() - start_time
-        _gctx.stats_client.timing(
-            "trainer.bridge.receive_timing", duration * 1000,
-            {"bridge_receive_name": name}
-        )
+        metric_collector.emit_store('model.grpc.bridge.receive_timing',
+                                    duration * 1000,
+                                    {'bridge_receive_name': name})
         fl_logging.debug("[Bridge] Data: received iter_id: %d, name: %s "
                          "after %f sec",
                          iter_id, name, duration)
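The receive path now keeps a single metric name and distinguishes tensors through the bridge_receive_name tag instead of separate metric names. A minimal sketch of an equivalent emit at the _receive call site; fetch_data and the tensor name are hypothetical stand-ins:

    import time
    from fedlearner.common.metric_collector import metric_collector

    def fetch_data(name):
        pass  # hypothetical stand-in for the bridge's blocking receive

    start_time = time.time()
    data = fetch_data('example_tensor')   # illustrative tensor name
    duration = time.time() - start_time
    # One metric name, many series: the tag selects the tensor.
    metric_collector.emit_store('model.grpc.bridge.receive_timing',
                                duration * 1000,
                                {'bridge_receive_name': 'example_tensor'})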
14 changes: 11 additions & 3 deletions fedlearner/trainer/parameter_server.py
@@ -19,7 +19,8 @@
 import argparse

 import tensorflow.compat.v1 as tf
-from fedlearner.common import stats
+
+from fedlearner.common.metric_collector import metric_collector
 from fedlearner.trainer.cluster_server import ClusterServer
 from fedlearner.trainer._global_context import global_context as _gctx

@@ -32,8 +33,15 @@
     args = parser.parse_args()

     _gctx.task = "ps"
-    stats.enable_cpu_stats(_gctx.stats_client)
-    stats.enable_mem_stats(_gctx.stats_client)
+    global_tags = {
+        'task': _gctx.task,
+        'task_index': str(_gctx.task_index),
+        'node_name': os.environ.get('HOSTNAME', 'default_node_name'),
+        'pod_name': os.environ.get('POD_NAME', 'default_pod_name'),
+    }
+    metric_collector.add_global_tags(global_tags)
+    name_prefix = 'model.common.nn_vertical'
+    metric_collector.emit_counter(f'{name_prefix}.start_count', 1)

     cluster_spec = tf.train.ClusterSpec({'ps': {0: args.address}})
     cluster_server = ClusterServer(cluster_spec, "ps")
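Calling add_global_tags once at startup stamps every later emit with the parameter server's identity, with no per-call plumbing. A sketch of the same bootstrap outside the fedlearner entrypoint; the task index and env-var defaults are illustrative (in the commit they come from _gctx and the pod environment):

    import os
    from fedlearner.common.metric_collector import metric_collector

    metric_collector.add_global_tags({
        'task': 'ps',
        'task_index': '0',   # illustrative; the real value comes from _gctx
        'node_name': os.environ.get('HOSTNAME', 'default_node_name'),
        'pod_name': os.environ.get('POD_NAME', 'default_pod_name'),
    })

    # Tagged automatically with all four identity tags above.
    metric_collector.emit_counter('model.common.nn_vertical.start_count', 1)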
6 changes: 3 additions & 3 deletions fedlearner/trainer/run_hooks.py
@@ -87,9 +87,6 @@ def after_run(self, run_context, run_values):

     def _stats_metric(self, global_step, results):
         with self._stats_client.pipeline() as pipe:
-            # TODO(lixiaoguang.01) old version, to be deleted
-            pipe.gauge("trainer.metric_global_step", global_step)
-            # new version
             name_prefix = 'model.train.nn_vertical'
             metric_collector.emit_store(
                 f'{name_prefix}.global_step', global_step)
@@ -104,6 +101,9 @@ def _stats_metric(self, global_step, results):

                 metric_collector.emit_store(
                     f'{name_prefix}.{key}', value.sum())
+                # for compatibility, also emit one with metric name in tags
+                metric_collector.emit_store(f'{name_prefix}.metric_value',
+                                            value.sum(), tags={'metric': key})


 class StepMetricsHook(GlobalStepMetricTensorHook):
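The compatibility hunk emits each training metric twice: once under its own name, and once under a shared metric_value name with the metric carried as a tag, so consumers that group by tag and consumers keyed on the per-metric name both keep working. A minimal sketch with made-up values:

    from fedlearner.common.metric_collector import metric_collector

    name_prefix = 'model.train.nn_vertical'
    key, metric_sum = 'auc', 0.93   # illustrative metric name and value

    # Old consumers read f'{name_prefix}.auc'; new ones filter
    # f'{name_prefix}.metric_value' by the 'metric' tag.
    metric_collector.emit_store(f'{name_prefix}.{key}', metric_sum)
    metric_collector.emit_store(f'{name_prefix}.metric_value',
                                metric_sum, tags={'metric': key})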
(Diffs for the remaining two changed files were not loaded in this view.)