Merge branch 'master' into sethvargo/iam_samples
sethvargo authored Jan 29, 2020
2 parents 8b94789 + ee73ade commit 606744f
Showing 23 changed files with 625 additions and 83 deletions.
52 changes: 52 additions & 0 deletions automl/cloud-client/batch_predict.py
@@ -0,0 +1,52 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def batch_predict(project_id, model_id, input_uri, output_uri):
"""Batch predict"""
# [START automl_batch_predict]
from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"
# input_uri = "gs://YOUR_BUCKET_ID/path/to/your/input/csv_or_jsonl"
# output_uri = "gs://YOUR_BUCKET_ID/path/to/save/results/"

prediction_client = automl.PredictionServiceClient()

# Get the full path of the model.
model_full_id = prediction_client.model_path(
project_id, "us-central1", model_id
)

gcs_source = automl.types.GcsSource(input_uris=[input_uri])

input_config = automl.types.BatchPredictInputConfig(gcs_source=gcs_source)
gcs_destination = automl.types.GcsDestination(output_uri_prefix=output_uri)
output_config = automl.types.BatchPredictOutputConfig(
gcs_destination=gcs_destination
)

response = prediction_client.batch_predict(
model_full_id, input_config, output_config
)

print("Waiting for operation to complete...")
print(
"Batch Prediction results saved to Cloud Storage bucket. {}".format(
response.result()
)
)
# [END automl_batch_predict]
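
A minimal usage sketch of the new sample (not part of this commit); the project ID, model ID, and Cloud Storage URIs below are placeholders you would replace with your own values:

# Hypothetical caller for the sample above; all values are placeholders.
import batch_predict

batch_predict.batch_predict(
    project_id="my-project",
    model_id="TEN1234567890123456789",
    input_uri="gs://my-bucket/entity-extraction/input.jsonl",
    output_uri="gs://my-bucket/batch-predict-results/",
)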
47 changes: 47 additions & 0 deletions automl/cloud-client/batch_predict_test.py
@@ -0,0 +1,47 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

import batch_predict

PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
BUCKET_ID = "{}-lcm".format(PROJECT_ID)
MODEL_ID = "TEN0000000000000000000"
PREFIX = "TEST_EXPORT_OUTPUT_" + datetime.datetime.now().strftime(
"%Y%m%d%H%M%S"
)


def test_batch_predict(capsys):
# As batch prediction can take a long time, try to batch predict on a model
# and confirm that the model was not found, but other elements of the
# request were valid.
try:
input_uri = "gs://{}/entity-extraction/input.jsonl".format(BUCKET_ID)
output_uri = "gs://{}/{}/".format(BUCKET_ID, PREFIX)
batch_predict.batch_predict(
PROJECT_ID, MODEL_ID, input_uri, output_uri
)
out, _ = capsys.readouterr()
assert (
"The model is either not found or not supported for prediction yet"
in out
)
except Exception as e:
assert (
"The model is either not found or not supported for prediction yet"
in e.message
)
6 changes: 3 additions & 3 deletions automl/cloud-client/delete_dataset_test.py
@@ -25,7 +25,7 @@


@pytest.fixture(scope="function")
def create_dataset():
def dataset_id():
client = automl.AutoMlClient()
project_location = client.location_path(PROJECT_ID, "us-central1")
display_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
@@ -39,8 +39,8 @@ def create_dataset():
yield dataset_id


def test_delete_dataset(capsys, create_dataset):
def test_delete_dataset(capsys, dataset_id):
# delete dataset
delete_dataset.delete_dataset(PROJECT_ID, create_dataset)
delete_dataset.delete_dataset(PROJECT_ID, dataset_id)
out, _ = capsys.readouterr()
assert "Dataset deleted." in out
6 changes: 3 additions & 3 deletions automl/cloud-client/get_model_evaluation_test.py
@@ -24,7 +24,7 @@


@pytest.fixture(scope="function")
def get_evaluation_id():
def model_evaluation_id():
client = automl.AutoMlClient()
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)
evaluation = None
@@ -37,9 +37,9 @@ def get_evaluation_id():
yield model_evaluation_id


def test_get_model_evaluation(capsys, get_evaluation_id):
def test_get_model_evaluation(capsys, model_evaluation_id):
get_model_evaluation.get_model_evaluation(
PROJECT_ID, MODEL_ID, get_evaluation_id
PROJECT_ID, MODEL_ID, model_evaluation_id
)
out, _ = capsys.readouterr()
assert "Model evaluation name: " in out
34 changes: 34 additions & 0 deletions automl/cloud-client/get_operation_status.py
@@ -0,0 +1,34 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def get_operation_status(operation_full_id):
"""Get operation status."""
# [START automl_get_operation_status]
from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# operation_full_id = \
# "projects/[projectId]/locations/us-central1/operations/[operationId]"

client = automl.AutoMlClient()
# Get the latest state of a long-running operation.
response = client.transport._operations_client.get_operation(
operation_full_id
)

print("Name: {}".format(response.name))
print("Operation details:")
print(response)
# [END automl_get_operation_status]
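
A minimal usage sketch of the new sample (not part of this commit); the operation path below is a placeholder for the name of a long-running operation started earlier:

# Hypothetical caller for the sample above; the operation path is a placeholder.
import get_operation_status

get_operation_status.get_operation_status(
    "projects/my-project/locations/us-central1/operations/TEN1234567890123456789"
)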
40 changes: 40 additions & 0 deletions automl/cloud-client/get_operation_status_test.py
@@ -0,0 +1,40 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from google.cloud import automl
import pytest

import get_operation_status

PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]


@pytest.fixture(scope="function")
def operation_id():
client = automl.AutoMlClient()
project_location = client.location_path(PROJECT_ID, "us-central1")
generator = client.transport._operations_client.list_operations(
project_location, filter_=""
).pages
page = next(generator)
operation = page.next()
yield operation.name


def test_get_operation_status(capsys, operation_id):
get_operation_status.get_operation_status(operation_id)
out, _ = capsys.readouterr()
assert "Operation details" in out
61 changes: 21 additions & 40 deletions automl/cloud-client/import_dataset_test.py
@@ -12,49 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

from google.cloud import automl
import pytest

import import_dataset

PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
BUCKET_ID = "{}-lcm".format(PROJECT_ID)


@pytest.fixture(scope="function")
def create_dataset():
client = automl.AutoMlClient()
project_location = client.location_path(PROJECT_ID, "us-central1")
display_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
metadata = automl.types.TextSentimentDatasetMetadata(
sentiment_max=4
)
dataset = automl.types.Dataset(
display_name=display_name, text_sentiment_dataset_metadata=metadata
)
response = client.create_dataset(project_location, dataset)
dataset_id = response.result().name.split("/")[-1]

yield dataset_id


@pytest.mark.slow
def test_import_dataset(capsys, create_dataset):
data = (
"gs://{}/sentiment-analysis/dataset.csv".format(BUCKET_ID)
)
dataset_id = create_dataset
import_dataset.import_dataset(PROJECT_ID, dataset_id, data)
out, _ = capsys.readouterr()
assert "Data imported." in out

# delete created dataset
client = automl.AutoMlClient()
dataset_full_id = client.dataset_path(
PROJECT_ID, "us-central1", dataset_id
)
response = client.delete_dataset(dataset_full_id)
response.result()
DATASET_ID = "TEN0000000000000000000"


def test_import_dataset(capsys):
# As importing a dataset can take a long time and only four operations can
# be run on a dataset at once, try to import into a nonexistent dataset and
# confirm that the dataset was not found, but other elements of the request
# were valid.
try:
data = "gs://{}/sentiment-analysis/dataset.csv".format(BUCKET_ID)
import_dataset.import_dataset(PROJECT_ID, DATASET_ID, data)
out, _ = capsys.readouterr()
assert (
"The Dataset doesn't exist or is inaccessible for use with AutoMl."
in out
)
except Exception as e:
assert (
"The Dataset doesn't exist or is inaccessible for use with AutoMl."
in e.message
)
automl/cloud-client/language_sentiment_analysis_predict_test.py
@@ -23,8 +23,9 @@
MODEL_ID = os.environ["SENTIMENT_ANALYSIS_MODEL_ID"]


@pytest.fixture(scope="function")
def verify_model_state():
@pytest.fixture(scope="function", autouse=True)
def setup():
# Verify the model is deployed before trying to predict
client = automl.AutoMlClient()
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)

@@ -35,8 +36,7 @@ def verify_model_state():
response.result()


def test_sentiment_analysis_predict(capsys, verify_model_state):
verify_model_state
def test_sentiment_analysis_predict(capsys):
text = "Hopefully this Claritin kicks in soon"
language_sentiment_analysis_predict.predict(PROJECT_ID, MODEL_ID, text)
out, _ = capsys.readouterr()
37 changes: 37 additions & 0 deletions automl/cloud-client/list_operation_status.py
@@ -0,0 +1,37 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def list_operation_status(project_id):
"""List operation status."""
# [START automl_list_operation_status]
from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"

client = automl.AutoMlClient()
# A resource that represents Google Cloud Platform location.
project_location = client.location_path(project_id, "us-central1")
# List all the operations names available in the region.
response = client.transport._operations_client.list_operations(
project_location, ""
)

print("List of operations:")
for operation in response:
print("Name: {}".format(operation.name))
print("Operation details:")
print(operation)
# [END automl_list_operation_status]
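
A minimal usage sketch of the new sample (not part of this commit); "my-project" is a placeholder project ID:

# Hypothetical caller for the sample above; the project ID is a placeholder.
import list_operation_status

list_operation_status.list_operation_status("my-project")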
28 changes: 28 additions & 0 deletions automl/cloud-client/list_operation_status_test.py
@@ -0,0 +1,28 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

import list_operation_status

PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]


@pytest.mark.slow
def test_list_operation_status(capsys):
list_operation_status.list_operation_status(PROJECT_ID)
out, _ = capsys.readouterr()
assert "Operation details" in out
8 changes: 4 additions & 4 deletions automl/cloud-client/translate_predict_test.py
@@ -23,8 +23,9 @@
MODEL_ID = os.environ["TRANSLATION_MODEL_ID"]


@pytest.fixture(scope="function")
def verify_model_state():
@pytest.fixture(scope="function", autouse=True)
def setup():
# Verify the model is deployed before trying to predict
client = automl.AutoMlClient()
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)

@@ -35,8 +36,7 @@ def verify_model_state():
response.result()


def test_translate_predict(capsys, verify_model_state):
verify_model_state
def test_translate_predict(capsys):
translate_predict.predict(PROJECT_ID, MODEL_ID, "resources/input.txt")
out, _ = capsys.readouterr()
assert "Translated content: " in out