Skip to content

Commit

Permalink
Add Vision label detection.
Browse files Browse the repository at this point in the history
  • Loading branch information
daspecster committed Sep 12, 2016
1 parent fb2fb6b commit c49f0c2
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 1 deletion.
4 changes: 3 additions & 1 deletion google/cloud/vision/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def from_api_repr(cls, response):
:rtype: :class:`~google.cloud.vision.entiy.EntityAnnotation`
:returns: Instance of ``EntityAnnotation``.
"""
bounds = Bounds.from_api_repr(response['boundingPoly'])
bounds = []
if 'boundingPoly' in response:
bounds = Bounds.from_api_repr(response['boundingPoly'])
description = response['description']
locations = [LocationInformation.from_api_repr(location)
for location in response.get('locations', [])]
Expand Down
13 changes: 13 additions & 0 deletions google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _detect_annotation(self, feature):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
reverse_types = {
'LABEL_DETECTION': 'labelAnnotations',
'LANDMARK_DETECTION': 'landmarkAnnotations',
'LOGO_DETECTION': 'logoAnnotations',
}
Expand Down Expand Up @@ -122,6 +123,18 @@ def detect_faces(self, limit=10):

return faces

def detect_labels(self, limit=10):
"""Detect labels that describe objects in an image.
:type limit: int
:param limit: The maximum number of labels to try and detect.
:rtype: list
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
"""
feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
return self._detect_annotation(feature)

def detect_landmarks(self, limit=10):
"""Detect landmarks in an image.
Expand Down
25 changes: 25 additions & 0 deletions unit_tests/vision/_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
LABEL_DETECTION_RESPONSE = {
'responses': [
{
'labelAnnotations': [
{
'mid': '/m/0k4j',
'description': 'automobile',
'score': 0.9776855
},
{
'mid': '/m/07yv9',
'description': 'vehicle',
'score': 0.947987
},
{
'mid': '/m/07r04',
'description': 'truck',
'score': 0.88429511
}
]
}
]
}


LANDMARK_DETECTION_RESPONSE = {
'responses': [
{
Expand Down
21 changes: 21 additions & 0 deletions unit_tests/vision/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ def test_face_detection_from_content(self):
image_request['image']['content'])
self.assertEqual(5, image_request['features'][0]['maxResults'])

def test_label_detection_from_source(self):
from google.cloud.vision.entity import EntityAnnotation
from unit_tests.vision._fixtures import (LABEL_DETECTION_RESPONSE as
RETURNED)
credentials = _Credentials()
client = self._makeOne(project=self.PROJECT, credentials=credentials)
client.connection = _Connection(RETURNED)

image = client.image(source_uri=_IMAGE_SOURCE)
labels = image.detect_labels(limit=3)
self.assertEqual(3, len(labels))
self.assertTrue(isinstance(labels[0], EntityAnnotation))
image_request = client.connection._requested[0]['data']['requests'][0]
self.assertEqual(_IMAGE_SOURCE,
image_request['image']['source']['gcs_image_uri'])
self.assertEqual(3, image_request['features'][0]['maxResults'])
self.assertEqual('automobile', labels[0].description)
self.assertEqual('vehicle', labels[1].description)
self.assertEqual('/m/0k4j', labels[0].mid)
self.assertEqual('/m/07yv9', labels[1].mid)

def test_landmark_detection_from_source(self):
from google.cloud.vision.entity import EntityAnnotation
from unit_tests.vision._fixtures import (LANDMARK_DETECTION_RESPONSE as
Expand Down

0 comments on commit c49f0c2

Please sign in to comment.