Skip to content

Commit

Permalink
FR Training client argument and live tests (Azure#11363)
Browse files Browse the repository at this point in the history
* trainig argument and live tests

* PR feedbacl

* fixes

* Wes' work for test infra

* pr feedback
  • Loading branch information
maririos authored Apr 17, 2020
1 parent 4cb7cb7 commit ffb3545
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 43 deletions.
30 changes: 26 additions & 4 deletions sdk/formrecognizer/Azure.AI.FormRecognizer/src/CustomFormModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal CustomFormModel(Model_internal model)
CreatedOn = model.ModelInfo.CreatedDateTime;
LastModified = model.ModelInfo.LastUpdatedDateTime;
Models = ConvertToSubmodels(model);
TrainingDocuments = model.TrainResult?.TrainingDocuments;
TrainingDocuments = ConvertToTrainingDocuments(model.TrainResult);
Errors = ConvertToFormRecognizerError(model.TrainResult);
}

Expand Down Expand Up @@ -73,9 +73,9 @@ private static IReadOnlyList<CustomFormSubModel> ConvertFromUnlabeled(KeysResult
{
var subModels = new List<CustomFormSubModel>();

var fieldMap = new Dictionary<string, CustomFormModelField>();
foreach (var cluster in keys.Clusters)
{
var fieldMap = new Dictionary<string, CustomFormModelField>();
for (int i = 0; i < cluster.Value.Count; i++)
{
string fieldName = "field-" + i;
Expand All @@ -93,9 +93,13 @@ private static IReadOnlyList<CustomFormSubModel> ConvertFromUnlabeled(KeysResult
private static IReadOnlyList<CustomFormSubModel> ConvertFromLabeled(Model_internal model)
{
var fieldMap = new Dictionary<string, CustomFormModelField>();
foreach (var formFieldsReport in model.TrainResult.Fields)

if (model.TrainResult.Fields != null)
{
fieldMap.Add(formFieldsReport.Name, new CustomFormModelField(formFieldsReport.Name, null, formFieldsReport.Accuracy));
foreach (var formFieldsReport in model.TrainResult.Fields)
{
fieldMap.Add(formFieldsReport.Name, new CustomFormModelField(formFieldsReport.Name, null, formFieldsReport.Accuracy));
}
}

return new List<CustomFormSubModel> {
Expand All @@ -105,6 +109,24 @@ private static IReadOnlyList<CustomFormSubModel> ConvertFromLabeled(Model_intern
fieldMap)};
}

private static IReadOnlyList<TrainingDocumentInfo> ConvertToTrainingDocuments(TrainResult_internal trainResult)
{
var trainingDocs = new List<TrainingDocumentInfo>();
if (trainResult?.TrainingDocuments != null)
{
foreach (var docs in trainResult?.TrainingDocuments)
{
trainingDocs.Add(
new TrainingDocumentInfo(
docs.DocumentName,
docs.PageCount,
docs.Errors ?? new List<FormRecognizerError>(),
docs.Status));
}
}
return trainingDocs;
}

private static IReadOnlyList<FormRecognizerError> ConvertToFormRecognizerError(TrainResult_internal trainResult)
{
var errors = new List<FormRecognizerError>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ namespace Azure.AI.FormRecognizer.Training
public partial class CustomFormModelField
{
internal CustomFormModelField(string name, string label, float? accuracy)
: this(name, accuracy)
{
Name = name;
Label = label;
Accuracy = accuracy;
}
/// <summary>
/// Unique name of the field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.FormRecognizer.Models;
Expand Down Expand Up @@ -39,6 +37,10 @@ protected FormTrainingClient()
/// </summary>
public FormTrainingClient(Uri endpoint, AzureKeyCredential credential, FormRecognizerClientOptions options)
{
Argument.AssertNotNull(endpoint, nameof(endpoint));
Argument.AssertNotNull(credential, nameof(credential));
Argument.AssertNotNull(options, nameof(options));

var diagnostics = new ClientDiagnostics(options);
HttpPipeline pipeline = HttpPipelineBuilder.Build(options, new AzureKeyCredentialPolicy(credential, Constants.AuthorizationHeader));
ServiceClient = new ServiceClient(diagnostics, pipeline, endpoint.ToString());
Expand All @@ -58,6 +60,8 @@ public FormTrainingClient(Uri endpoint, AzureKeyCredential credential, FormRecog
[ForwardsClientCalls]
public virtual TrainingOperation StartTraining(Uri trainingFiles, bool useLabels = false, TrainingFileFilter filter = default, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(trainingFiles, nameof(trainingFiles));

var trainRequest = new TrainRequest_internal(trainingFiles.AbsoluteUri, filter, useLabels);

ResponseWithHeaders<ServiceTrainCustomModelAsyncHeaders> response = ServiceClient.RestClient.TrainCustomModelAsync(trainRequest);
Expand All @@ -76,12 +80,18 @@ public virtual TrainingOperation StartTraining(Uri trainingFiles, bool useLabels
[ForwardsClientCalls]
public virtual async Task<TrainingOperation> StartTrainingAsync(Uri trainingFiles, bool useLabels = false, TrainingFileFilter filter = default, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(trainingFiles, nameof(trainingFiles));

var trainRequest = new TrainRequest_internal(trainingFiles.AbsoluteUri, filter, useLabels);

ResponseWithHeaders<ServiceTrainCustomModelAsyncHeaders> response = await ServiceClient.RestClient.TrainCustomModelAsyncAsync(trainRequest).ConfigureAwait(false);
return new TrainingOperation(response.Headers.Location, ServiceClient);
}

#endregion

#region Management Ops

/// <summary>
/// Get a description of a custom model, including the types of forms it can recognize and the fields it will extract for each form type.
/// </summary>
Expand All @@ -91,6 +101,8 @@ public virtual async Task<TrainingOperation> StartTrainingAsync(Uri trainingFile
[ForwardsClientCalls]
public virtual Response<CustomFormModel> GetCustomModel(string modelId, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(modelId, nameof(modelId));

Response<Model_internal> response = ServiceClient.GetCustomModel(new Guid(modelId), includeKeys: true, cancellationToken);
return Response.FromValue(new CustomFormModel(response.Value), response.GetRawResponse());
}
Expand All @@ -104,13 +116,12 @@ public virtual Response<CustomFormModel> GetCustomModel(string modelId, Cancella
[ForwardsClientCalls]
public virtual async Task<Response<CustomFormModel>> GetCustomModelAsync(string modelId, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(modelId, nameof(modelId));

Response<Model_internal> response = await ServiceClient.GetCustomModelAsync(new Guid(modelId), includeKeys: true, cancellationToken).ConfigureAwait(false);
return Response.FromValue(new CustomFormModel(response.Value), response.GetRawResponse());
}

#endregion

#region Management Ops
/// <summary>
/// Delete the model with the specified model ID.
/// </summary>
Expand All @@ -120,6 +131,8 @@ public virtual async Task<Response<CustomFormModel>> GetCustomModelAsync(string
[ForwardsClientCalls]
public virtual Response DeleteModel(string modelId, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(modelId, nameof(modelId));

return ServiceClient.DeleteCustomModel(new Guid(modelId), cancellationToken);
}

Expand All @@ -132,6 +145,8 @@ public virtual Response DeleteModel(string modelId, CancellationToken cancellati
[ForwardsClientCalls]
public virtual async Task<Response> DeleteModelAsync(string modelId, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(modelId, nameof(modelId));

return await ServiceClient.DeleteCustomModelAsync(new Guid(modelId), cancellationToken).ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ public FormRecognizerClientLiveTests(bool isAsync) : base(isAsync)
{
}

/// <summary>
/// Creates a <see cref="FormRecognizerClient" /> with the endpoint and API key provided via environment
/// variables and instruments it to make use of the Azure Core Test Framework functionalities.
/// </summary>
/// <returns>The instrumented <see cref="FormRecognizerClient" />.</returns>
private FormRecognizerClient CreateInstrumentedClient()
{
var endpointEnvironmentVariable = Environment.GetEnvironmentVariable(TestEnvironment.EndpointEnvironmentVariableName);
var keyEnvironmentVariable = Environment.GetEnvironmentVariable(TestEnvironment.ApiKeyEnvironmentVariableName);

Assert.NotNull(endpointEnvironmentVariable);
Assert.NotNull(keyEnvironmentVariable);

var endpoint = new Uri(endpointEnvironmentVariable);
var credential = new AzureKeyCredential(keyEnvironmentVariable);
var client = new FormRecognizerClient(endpoint, credential);

return InstrumentClient(client);
}

/// <summary>
/// Verifies that the <see cref="FormRecognizerClient" /> is able to connect to the Form
/// Recognizer cognitive service and perform operations.
Expand Down Expand Up @@ -210,24 +230,5 @@ public void CreateFormTrainingClientFromFormRecognizerClient()
Assert.IsNotNull(trainingClient);
}

/// <summary>
/// Creates a <see cref="FormRecognizerClient" /> with the endpoint and API key provided via environment
/// variables and instruments it to make use of the Azure Core Test Framework functionalities.
/// </summary>
/// <returns>The instrumented <see cref="FormRecognizerClient" />.</returns>
private FormRecognizerClient CreateInstrumentedClient()
{
var endpointEnvironmentVariable = Environment.GetEnvironmentVariable(TestEnvironment.EndpointEnvironmentVariableName);
var keyEnvironmentVariable = Environment.GetEnvironmentVariable(TestEnvironment.ApiKeyEnvironmentVariableName);

Assert.NotNull(endpointEnvironmentVariable);
Assert.NotNull(keyEnvironmentVariable);

var endpoint = new Uri(endpointEnvironmentVariable);
var credential = new AzureKeyCredential(keyEnvironmentVariable);
var client = new FormRecognizerClient(endpoint, credential);

return InstrumentClient(client);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ public FormRecognizerClientTests(bool isAsync) : base(isAsync)
{
}

/// <summary>
/// Creates a fake <see cref="FormRecognizerClient" /> and instruments it to make use of the Azure Core
/// Test Framework functionalities.
/// </summary>
/// <returns>The instrumented <see cref="FormRecognizerClient" />.</returns>
private FormRecognizerClient CreateInstrumentedClient()
{
var fakeEndpoint = new Uri("http://localhost");
var fakeCredential = new AzureKeyCredential("fakeKey");
var client = new FormRecognizerClient(fakeEndpoint, fakeCredential);

return InstrumentClient(client);
}

/// <summary>
/// Verifies functionality of the <see cref="FormRecognizerClient"/> constructors.
/// </summary>
Expand Down Expand Up @@ -177,18 +191,5 @@ public void StartRecognizeReceiptsFromUriRespectsTheCancellationToken()
Assert.ThrowsAsync<TaskCanceledException>(async () => await client.StartRecognizeReceiptsFromUriAsync(fakeUri, cancellationToken: cancellationSource.Token));
}

/// <summary>
/// Creates a fake <see cref="FormRecognizerClient" /> and instruments it to make use of the Azure Core
/// Test Framework functionalities.
/// </summary>
/// <returns>The instrumented <see cref="FormRecognizerClient" />.</returns>
private FormRecognizerClient CreateInstrumentedClient()
{
var fakeEndpoint = new Uri("http://localhost");
var fakeCredential = new AzureKeyCredential("fakeKey");
var client = new FormRecognizerClient(fakeEndpoint, fakeCredential);

return InstrumentClient(client);
}
}
}
Loading

0 comments on commit ffb3545

Please sign in to comment.