Skip to content

Commit

Permalink
Add monitoring for sparsity tooling API usage.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 390032555
  • Loading branch information
fredrec authored and tensorflower-gardener committed Aug 11, 2021
1 parent c1b190b commit 85e6860
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow_model_optimization/python/core/keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
class MonitorBoolGauge():
"""Monitoring utility class for usage metrics."""

_PRUNE_FOR_BENCHMARK_USAGE = monitoring.BoolGauge(
'/tfmot/api/sparsity/prune_for_benchmark',
'prune_for_benchmark usage.', 'status')

_PRUNE_LOW_MAGNITUDE_USAGE = monitoring.BoolGauge(
'/tfmot/api/sparsity/prune_low_magnitude',
'prune_low_magnitude usage.', 'status')
Expand All @@ -43,6 +47,9 @@ def __init__(self, name):
self.bool_gauge = self.get_usage_gauge(name)

def get_usage_gauge(self, name):
"""Gets a gauge by name."""
if name == 'prune_for_benchmark_usage':
return MonitorBoolGauge._PRUNE_FOR_BENCHMARK_USAGE
if name == 'prune_low_magnitude_usage':
return MonitorBoolGauge._PRUNE_LOW_MAGNITUDE_USAGE
if name == 'prune_low_magnitude_wrapper_usage':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ py_strict_library(
visibility = ["//visibility:public"],
deps = [
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:metrics",
"//tensorflow_model_optimization/python/core/sparsity/keras:prune",
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_wrapper",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import tensorflow as tf

from tensorflow_model_optimization.python.core.keras import metrics
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
Expand Down Expand Up @@ -65,6 +66,7 @@ def _apply_pruning(prunable_object):
layer.pruning_obj.weight_mask_op() # weight = weight * mask


@metrics.MonitorBoolGauge('prune_for_benchmark_usage')
def prune_for_benchmark(keras_model,
target_sparsity,
block_size=(1, 1)):
Expand Down

0 comments on commit 85e6860

Please sign in to comment.