Skip to content

Commit

Permalink
fix: Fix error in tensorboard uploader thrown when time_series_id is …
Browse files Browse the repository at this point in the history
…None

PiperOrigin-RevId: 666966843
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 24, 2024
1 parent 7af80c6 commit d59a052
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
20 changes: 12 additions & 8 deletions google/cloud/aiplatform/tensorboard/uploader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,16 @@ def _get_or_create_time_series(
ValueError:
More than one time series with the resource name was found.
"""
time_series = None
run_name = run_resource_name.split("/")[-1]
run = self._get_or_create_run_resource(run_name)
time_series_id = run.get_tensorboard_time_series_id(tag_name)
time_series = self._api.get_tensorboard_time_series(
request=tensorboard_service.GetTensorboardTimeSeriesRequest(
name=run_resource_name + "/timeSeries/" + time_series_id
if time_series_id:
time_series = self._api.get_tensorboard_time_series(
request=tensorboard_service.GetTensorboardTimeSeriesRequest(
name=run_resource_name + "/timeSeries/" + time_series_id
)
)
)
if not time_series:
time_series = time_series_resource_creator()
time_series.display_name = tag_name
Expand Down Expand Up @@ -416,13 +418,15 @@ def get_or_create(
if tag_name in self._tag_to_time_series_proto:
return self._tag_to_time_series_proto[tag_name]

time_series = None
tb_run = self._get_run_resource()
time_series_id = tb_run.get_tensorboard_time_series_id(tag_name)
time_series = self._api.get_tensorboard_time_series(
request=tensorboard_service.GetTensorboardTimeSeriesRequest(
name=self._run_resource_id + "/timeSeries/" + time_series_id
if time_series_id:
time_series = self._api.get_tensorboard_time_series(
request=tensorboard_service.GetTensorboardTimeSeriesRequest(
name=self._run_resource_id + "/timeSeries/" + time_series_id
)
)
)
if not time_series:
time_series = time_series_resource_creator()
time_series.display_name = tag_name
Expand Down
14 changes: 12 additions & 2 deletions tests/unit/aiplatform/test_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,10 @@ def test_start_uploading_without_create_experiment_fails(self):
with self.assertRaisesRegex(RuntimeError, "call create_experiment()"):
uploader.start_uploading()

@parameterized.parameters(
{"time_series_name": None},
{"time_series_name": _TEST_TIME_SERIES_NAME},
)
@patch.object(
uploader_utils.OnePlatformResourceManager,
"get_run_resource_name",
Expand All @@ -629,10 +633,16 @@ def test_start_uploading_without_create_experiment_fails(self):
@patch.object(metadata, "_experiment_tracker", autospec=True)
@patch.object(experiment_resources, "Experiment", autospec=True)
def test_start_uploading_scalars(
self, experiment_resources_mock, experiment_tracker_mock, run_resource_mock
self,
experiment_resources_mock,
experiment_tracker_mock,
run_resource_mock,
time_series_name,
):
experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME
self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock()
self.mock_run_resource_mock.return_value = _create_tensorboard_run_mock(
time_series_name=time_series_name
)
experiment_tracker_mock.set_experiment.return_value = _TEST_EXPERIMENT_NAME
experiment_tracker_mock.set_tensorboard.return_value = (
_TEST_TENSORBOARD_RESOURCE_NAME
Expand Down

0 comments on commit d59a052

Please sign in to comment.