Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix bug in NaN support (#2077)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Feb 19, 2020
1 parent 914cc1f commit be2f3c7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,8 @@ def handle_report_metric_data(self, data):
Data type not supported
"""
logger.debug('handle report metric data = %s', data)

if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
Expand Down Expand Up @@ -627,6 +628,8 @@ def handle_import_data(self, data):
AssertionError
data doesn't have required key 'parameter' and 'value'
"""
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
Expand Down
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def handle_report_metric_data(self, data):
ValueError
Data type not supported
"""
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
Expand Down
5 changes: 4 additions & 1 deletion src/sdk/pynni/nni/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def handle_import_data(self, data):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data)

def handle_add_customized_trial(self, data):
Expand All @@ -128,7 +130,8 @@ def handle_report_metric_data(self, data):
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
# metrics value is dumped as json string in trial, so we need to decode it here
data['value'] = json_tricks.loads(data['value'])
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL:
Expand Down

0 comments on commit be2f3c7

Please sign in to comment.