Skip to content

Commit

Permalink
Merge pull request #2237 from daspecster/vision-label-detection
Browse files Browse the repository at this point in the history
Add vision label detection
  • Loading branch information
daspecster authored Sep 15, 2016
2 parents 1f7ed9d + 408f929 commit 143b0a1
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 7 deletions.
2 changes: 1 addition & 1 deletion google/cloud/vision/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def from_api_repr(cls, response):
:rtype: :class:`~google.cloud.vision.entiy.EntityAnnotation`
:returns: Instance of ``EntityAnnotation``.
"""
bounds = Bounds.from_api_repr(response['boundingPoly'])
bounds = Bounds.from_api_repr(response.get('boundingPoly'))
description = response['description']
locations = [LocationInformation.from_api_repr(location)
for location in response.get('locations', [])]
Expand Down
13 changes: 7 additions & 6 deletions google/cloud/vision/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ def from_api_repr(cls, response_vertices):
:type response_vertices: dict
:param response_vertices: List of vertices.
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase`
:returns: Instance of BoundsBase with populated verticies.
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
:returns: Instance of BoundsBase with populated verticies or None.
"""
vertices = []
for vertex in response_vertices['vertices']:
vertices.append(Vertex(vertex.get('x', None),
vertex.get('y', None)))
if not response_vertices:
return None

vertices = [Vertex(vertex.get('x', None), vertex.get('y', None)) for
vertex in response_vertices.get('vertices', [])]
return cls(vertices)

@property
Expand Down
13 changes: 13 additions & 0 deletions google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _detect_annotation(self, feature):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
reverse_types = {
'LABEL_DETECTION': 'labelAnnotations',
'LANDMARK_DETECTION': 'landmarkAnnotations',
'LOGO_DETECTION': 'logoAnnotations',
}
Expand Down Expand Up @@ -122,6 +123,18 @@ def detect_faces(self, limit=10):

return faces

def detect_labels(self, limit=10):
"""Detect labels that describe objects in an image.
:type limit: int
:param limit: The maximum number of labels to try and detect.
:rtype: list
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
"""
feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
return self._detect_annotation(feature)

def detect_landmarks(self, limit=10):
"""Detect landmarks in an image.
Expand Down
25 changes: 25 additions & 0 deletions unit_tests/vision/_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
LABEL_DETECTION_RESPONSE = {
'responses': [
{
'labelAnnotations': [
{
'mid': '/m/0k4j',
'description': 'automobile',
'score': 0.9776855
},
{
'mid': '/m/07yv9',
'description': 'vehicle',
'score': 0.947987
},
{
'mid': '/m/07r04',
'description': 'truck',
'score': 0.88429511
}
]
}
]
}


LANDMARK_DETECTION_RESPONSE = {
'responses': [
{
Expand Down
21 changes: 21 additions & 0 deletions unit_tests/vision/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ def test_face_detection_from_content(self):
image_request['image']['content'])
self.assertEqual(5, image_request['features'][0]['maxResults'])

def test_label_detection_from_source(self):
from google.cloud.vision.entity import EntityAnnotation
from unit_tests.vision._fixtures import (LABEL_DETECTION_RESPONSE as
RETURNED)
credentials = _Credentials()
client = self._makeOne(project=self.PROJECT, credentials=credentials)
client.connection = _Connection(RETURNED)

image = client.image(source_uri=_IMAGE_SOURCE)
labels = image.detect_labels(limit=3)
self.assertEqual(3, len(labels))
self.assertTrue(isinstance(labels[0], EntityAnnotation))
image_request = client.connection._requested[0]['data']['requests'][0]
self.assertEqual(_IMAGE_SOURCE,
image_request['image']['source']['gcs_image_uri'])
self.assertEqual(3, image_request['features'][0]['maxResults'])
self.assertEqual('automobile', labels[0].description)
self.assertEqual('vehicle', labels[1].description)
self.assertEqual('/m/0k4j', labels[0].mid)
self.assertEqual('/m/07yv9', labels[1].mid)

def test_landmark_detection_from_source(self):
from google.cloud.vision.entity import EntityAnnotation
from unit_tests.vision._fixtures import (LANDMARK_DETECTION_RESPONSE as
Expand Down

0 comments on commit 143b0a1

Please sign in to comment.