Skip to content

Commit

Permalink
Added listModels() API implementation and a few samples (#9975)
Browse files Browse the repository at this point in the history
* samples for extract receipts

* add codeowner to form recognizer

* add listModels API
  • Loading branch information
mssfang authored Apr 14, 2020
1 parent 074fea8 commit 03035f9
Show file tree
Hide file tree
Showing 17 changed files with 552 additions and 70 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
/sdk/core/azure-core-tracing-opentelemetry/ @samvaity @alzimmermsft
/sdk/cosmos/ @moderakh @kushagraThapar @David-Noble-at-work @kirankumarkolli @mbhaskar
/sdk/eventhubs/ @conniey @srnagar @mssfang
/sdk/formrecognizer/ @samvaity @mssfang
/sdk/identity/ @schaabs @g2vinay @jianghaolu
/sdk/keyvault/ @g2vinay @vcolin7 @samvaity
/sdk/search/ @alzimmermsft @sima-zhu
Expand Down
2 changes: 1 addition & 1 deletion sdk/formrecognizer/azure-ai-formrecognizer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ FormRecognizerAsyncClient formRecognizerAsyncClient = new FormRecognizerClientBu
```java
String receiptSourceUrl = "https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/media/contoso-allinone.jpg";
SyncPoller<OperationResult, IterableStream<ExtractedReceipt>> syncPoller =
formRecognizerClient.beginExtractReceipt(receiptSourceUrl);
formRecognizerClient.beginExtractReceiptsFromUrl(receiptSourceUrl);
IterableStream<ExtractedReceipt> extractedReceipts = syncPoller.getFinalResult();

for (ExtractedReceipt extractedReceiptItem : extractedReceipts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@
import com.azure.ai.formrecognizer.implementation.models.AnalyzeOperationResult;
import com.azure.ai.formrecognizer.implementation.models.ContentType;
import com.azure.ai.formrecognizer.implementation.models.SourcePath;
import com.azure.ai.formrecognizer.models.CustomFormModelInfo;
import com.azure.ai.formrecognizer.models.ExtractedReceipt;
import com.azure.ai.formrecognizer.models.FormContentType;
import com.azure.ai.formrecognizer.models.OperationResult;
import com.azure.core.annotation.ReturnType;
import com.azure.core.annotation.ServiceClient;
import com.azure.core.annotation.ServiceMethod;
import com.azure.core.exception.HttpResponseException;
import com.azure.core.http.HttpPipeline;
import com.azure.core.http.rest.PagedFlux;
import com.azure.core.http.rest.PagedResponse;
import com.azure.core.http.rest.PagedResponseBase;
import com.azure.core.http.rest.SimpleResponse;
import com.azure.core.util.Context;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.IterableStream;
import com.azure.core.util.logging.ClientLogger;
Expand All @@ -30,8 +37,10 @@
import java.util.UUID;
import java.util.function.Function;

import static com.azure.ai.formrecognizer.Transforms.toCustomFormModelInfo;
import static com.azure.ai.formrecognizer.Transforms.toReceipt;
import static com.azure.core.util.FluxUtil.monoError;
import static com.azure.core.util.FluxUtil.withContext;

/**
* This class provides an asynchronous client that contains all the operations that apply to Azure Form Recognizer.
Expand Down Expand Up @@ -158,6 +167,65 @@ public PollerFlux<OperationResult, IterableStream<ExtractedReceipt>> beginExtrac
fetchExtractReceiptResult(includeTextDetails));
}

/**
* List information for all models.
*
* @return {@link PagedFlux} of {@link CustomFormModelInfo}.
*/
@ServiceMethod(returns = ReturnType.COLLECTION)
public PagedFlux<CustomFormModelInfo> listModels() {
try {
return new PagedFlux<>(() -> withContext(context -> listFirstPageModelInfo(context)),
continuationToken -> withContext(context -> listNextPageModelInfo(continuationToken, context)));
} catch (RuntimeException ex) {
return new PagedFlux<>(() -> monoError(logger, ex));
}
}

/**
* List information for all models with taking {@link Context}.
*
* @param context Additional context that is passed through the Http pipeline during the service call.
*
* @return {@link PagedFlux} of {@link CustomFormModelInfo}.
*/
PagedFlux<CustomFormModelInfo> listModels(Context context) {
return new PagedFlux<>(() -> listFirstPageModelInfo(context),
continuationToken -> listNextPageModelInfo(continuationToken, context));
}

private Mono<PagedResponse<CustomFormModelInfo>> listFirstPageModelInfo(Context context) {
return service.listCustomModelsSinglePageAsync(context)
.doOnRequest(ignoredValue -> logger.info("Listing information for all models"))
.doOnSuccess(response -> logger.info("Listed all models"))
.doOnError(error -> logger.warning("Failed to list all models information", error))
.map(res -> new PagedResponseBase<>(
res.getRequest(),
res.getStatusCode(),
res.getHeaders(),
toCustomFormModelInfo(res.getValue()),
res.getContinuationToken(),
null));
}

private Mono<PagedResponse<CustomFormModelInfo>> listNextPageModelInfo(String nextPageLink, Context context) {
if (CoreUtils.isNullOrEmpty(nextPageLink)) {
return Mono.empty();
}
return service.listCustomModelsNextSinglePageAsync(nextPageLink, context)
.doOnSubscribe(ignoredValue -> logger.info("Retrieving the next listing page - Page {}", nextPageLink))
.doOnSuccess(response -> logger.info("Retrieved the next listing page - Page {}", nextPageLink))
.doOnError(error -> logger.warning("Failed to retrieve the next listing page - Page {}", nextPageLink,
error))
.map(res -> new PagedResponseBase<>(
res.getRequest(),
res.getStatusCode(),
res.getHeaders(),
toCustomFormModelInfo(res.getValue()),
res.getContinuationToken(),
null));
}

private Function<PollingContext<OperationResult>, Mono<OperationResult>> receiptAnalyzeActivationOperation(
String sourceUrl, boolean includeTextDetails) {
return (pollingContext) -> {
Expand All @@ -176,8 +244,8 @@ private Function<PollingContext<OperationResult>, Mono<OperationResult>> receipt
Flux<ByteBuffer> buffer, long length, FormContentType formContentType, boolean includeTextDetails) {
return (pollingContext) -> {
try {
return service.analyzeReceiptAsyncWithResponseAsync(includeTextDetails,
ContentType.fromString(formContentType.toString()), buffer, length)
return service.analyzeReceiptAsyncWithResponseAsync(ContentType.fromString(formContentType.toString()),
buffer, length, includeTextDetails)
.map(response -> new OperationResult(
parseModelId(response.getDeserializedHeaders().getOperationLocation())));
} catch (RuntimeException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
package com.azure.ai.formrecognizer;

import com.azure.ai.formrecognizer.implementation.Utility;
import com.azure.ai.formrecognizer.models.CustomFormModelInfo;
import com.azure.ai.formrecognizer.models.ExtractedReceipt;
import com.azure.ai.formrecognizer.models.FormContentType;
import com.azure.ai.formrecognizer.models.OperationResult;
import com.azure.core.annotation.ReturnType;
import com.azure.core.annotation.ServiceClient;
import com.azure.core.annotation.ServiceMethod;
import com.azure.core.http.rest.PagedIterable;
import com.azure.core.util.Context;
import com.azure.core.util.IterableStream;
import com.azure.core.util.polling.SyncPoller;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -112,4 +117,25 @@ public SyncPoller<OperationResult, IterableStream<ExtractedReceipt>> beginExtrac
return client.beginExtractReceipts(buffer, length, includeTextDetails, formContentType, pollInterval)
.getSyncPoller();
}

/**
* List information for all models.
*
* @return {@link PagedIterable} of {@link CustomFormModelInfo} custom form model information.
*/
@ServiceMethod(returns = ReturnType.COLLECTION)
public PagedIterable<CustomFormModelInfo> listModels() {
return new PagedIterable<>(client.listModels(Context.NONE));
}

/**
* List information for all models with taking {@link Context}.
*
* @param context Additional context that is passed through the Http pipeline during the service call.
* @return {@link PagedIterable} of {@link CustomFormModelInfo} custom form model information.
*/
@ServiceMethod(returns = ReturnType.COLLECTION)
public PagedIterable<CustomFormModelInfo> listModels(Context context) {
return new PagedIterable<>(client.listModels(context));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import com.azure.ai.formrecognizer.implementation.models.AnalyzeResult;
import com.azure.ai.formrecognizer.implementation.models.DocumentResult;
import com.azure.ai.formrecognizer.implementation.models.ModelInfo;
import com.azure.ai.formrecognizer.implementation.models.ReadResult;
import com.azure.ai.formrecognizer.implementation.models.TextLine;
import com.azure.ai.formrecognizer.implementation.models.TextWord;
import com.azure.ai.formrecognizer.models.BoundingBox;
import com.azure.ai.formrecognizer.models.CustomFormModelInfo;
import com.azure.ai.formrecognizer.models.DateValue;
import com.azure.ai.formrecognizer.models.DimensionUnit;
import com.azure.ai.formrecognizer.models.Element;
Expand All @@ -17,6 +19,7 @@
import com.azure.ai.formrecognizer.models.FloatValue;
import com.azure.ai.formrecognizer.models.IntegerValue;
import com.azure.ai.formrecognizer.models.LineElement;
import com.azure.ai.formrecognizer.models.ModelTrainingStatus;
import com.azure.ai.formrecognizer.models.PageMetadata;
import com.azure.ai.formrecognizer.models.PageRange;
import com.azure.ai.formrecognizer.models.Point;
Expand Down Expand Up @@ -374,4 +377,41 @@ private static TimeValue toFieldValueTime(com.azure.ai.formrecognizer.implementa
serviceDateValue.getValueTime(), serviceDateValue.getPage());
// TODO: currently returning a string, waiting on swagger update.
}

/**
* Transform a list of {@link ModelInfo} to a list of {@link CustomFormModelInfo}.
*
* @param list A list of {@link ModelInfo}.
* @return A list of {@link CustomFormModelInfo}.
*/
static List<CustomFormModelInfo> toCustomFormModelInfo(List<ModelInfo> list) {
CollectionTransformer<ModelInfo, CustomFormModelInfo> transformer =
new CollectionTransformer<ModelInfo, CustomFormModelInfo>() {
@Override
CustomFormModelInfo transform(ModelInfo modelInfo) {
return new CustomFormModelInfo(modelInfo.getModelId().toString(),
ModelTrainingStatus.fromString(modelInfo.getStatus().toString()),
modelInfo.getCreatedDateTime(), modelInfo.getLastUpdatedDateTime());
}
};
return transformer.transform(list);
}

/**
* A generic transformation class for collection that transform from type {@code E} to type {@code F}.
*
* @param <E> Transform type E to another type.
* @param <F> Transform to type F from another type.
*/
abstract static class CollectionTransformer<E, F> {
abstract F transform(E e);

List<F> transform(List<E> list) {
List<F> newList = new ArrayList<>();
for (E e : list) {
newList.add(transform(e));
}
return newList;
}
}
}
Loading

0 comments on commit 03035f9

Please sign in to comment.