Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate protocol methods in RestClient #1991

Merged
merged 11 commits into from
Feb 18, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using AutoRest.CSharp.AutoRest.Plugins;
using AutoRest.CSharp.Input;
using Azure.Core;

namespace AutoRest.CSharp.AutoRest.Communication
{
Expand Down Expand Up @@ -85,6 +86,7 @@ internal static string SaveConfiguration(Configuration configuration)
WriteIfNotDefault(writer, Configuration.Options.DataPlane, configuration.DataPlane);
WriteIfNotDefault(writer, Configuration.Options.SingleTopLevelClient, configuration.SingleTopLevelClient);
WriteIfNotDefault(writer, Configuration.Options.ProjectFolder, configuration.ProjectFolder);
Utf8JsonWriterExtensions.WriteNonEmptyArray(writer, nameof(Configuration.ProtocolMethodList), configuration.ProtocolMethodList);

configuration.MgmtConfiguration.SaveConfiguration(writer);

Expand Down Expand Up @@ -135,6 +137,11 @@ internal static Configuration LoadConfiguration(string basePath, string json)
sharedSourceFolders.Add(Path.Combine(basePath, sharedSourceFolder.GetString()));
}

root.TryGetProperty(nameof(Configuration.Options.ProtocolMethodList), out var protocolMethodList);
var protocolMethods = protocolMethodList.ValueKind == JsonValueKind.Array
? protocolMethodList.EnumerateArray().Select(t => t.ToString()).ToArray()
: Array.Empty<string>();

return new Configuration(
Path.Combine(basePath, root.GetProperty(nameof(Configuration.OutputFolder)).GetString()),
root.GetProperty(nameof(Configuration.Namespace)).GetString(),
Expand All @@ -149,6 +156,7 @@ internal static Configuration LoadConfiguration(string basePath, string json)
ReadOption(root, Configuration.Options.DataPlane),
ReadOption(root, Configuration.Options.SingleTopLevelClient),
ReadStringOption(root, Configuration.Options.ProjectFolder),
protocolMethods,
MgmtConfiguration.LoadConfiguration(root)
);
}
Expand Down
6 changes: 5 additions & 1 deletion src/AutoRest.CSharp/Common/AutoRest/Plugins/Configuration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ public static class Options
public const string SingleTopLevelClient = "single-top-level-client";
public const string AttachDebuggerFormat = "{0}.attach";
public const string ProjectFolder = "project-folder";
public const string ProtocolMethodList = "protocol-method-list";
}

public Configuration(string outputFolder, string? ns, string? name, string[] sharedSourceFolders, bool saveInputs, bool azureArm, bool publicClients, bool modelNamespace, bool headAsBoolean, bool skipCSProjPackageReference, bool dataplane, bool singleTopLevelClient, string projectFolder, MgmtConfiguration mgmtConfiguration)
public Configuration(string outputFolder, string? ns, string? name, string[] sharedSourceFolders, bool saveInputs, bool azureArm, bool publicClients, bool modelNamespace, bool headAsBoolean, bool skipCSProjPackageReference, bool dataplane, bool singleTopLevelClient, string projectFolder, string[] protocolMethodList, MgmtConfiguration mgmtConfiguration)
{
OutputFolder = outputFolder;
Namespace = ns;
Expand All @@ -43,6 +44,7 @@ public Configuration(string outputFolder, string? ns, string? name, string[] sha
DataPlane = dataplane;
SingleTopLevelClient = singleTopLevelClient;
ProjectFolder = Path.IsPathRooted(projectFolder) ? Path.GetRelativePath(outputFolder, projectFolder) : projectFolder;
ProtocolMethodList = protocolMethodList;
MgmtConfiguration = mgmtConfiguration;
}

Expand All @@ -58,6 +60,7 @@ public Configuration(string outputFolder, string? ns, string? name, string[] sha
public bool SkipCSProjPackageReference { get; }
public bool DataPlane { get; }
public bool SingleTopLevelClient { get; }
public string[] ProtocolMethodList { get; }
public MgmtConfiguration MgmtConfiguration { get; }

public string ProjectFolder { get; }
Expand All @@ -78,6 +81,7 @@ public static Configuration GetConfiguration(IPluginCommunication autoRest)
dataplane: GetOptionValue(autoRest, Options.DataPlane),
singleTopLevelClient: GetOptionValue(autoRest, Options.SingleTopLevelClient),
projectFolder: GetOptionStringValue(autoRest, Options.ProjectFolder, TrimFileSuffix),
protocolMethodList: autoRest.GetValue<string[]?>(Options.ProtocolMethodList).GetAwaiter().GetResult() ?? Array.Empty<string>(),
mgmtConfiguration: MgmtConfiguration.GetConfiguration(autoRest)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AutoRest.CSharp.AutoRest.Plugins;
using AutoRest.CSharp.Generation.Types;
using AutoRest.CSharp.Output.Models;
using AutoRest.CSharp.Output.Models.Requests;
Expand All @@ -20,7 +22,7 @@ namespace AutoRest.CSharp.Generation.Writers
{
internal class RestClientWriter
{
public void WriteClient(CodeWriter writer, RestClient restClient)
public void WriteClient(CodeWriter writer, RestClient restClient, Configuration? configuration = null, IReadOnlyList<LowLevelClientMethod>? protocolMethods = null)
{
var cs = restClient.Type;
var @namespace = cs.Namespace;
Expand All @@ -32,12 +34,44 @@ public void WriteClient(CodeWriter writer, RestClient restClient)

WriteClientCtor(writer, restClient, cs);

var responseClassifierTypes = new List<LowLevelClientWriter.ResponseClassifierType>();
foreach (var method in restClient.Methods)
{
WriteRequestCreation(writer, method, restClient.Parameters, restClient.Fields);
WriteOperation(writer, method, true);
WriteOperation(writer, method, false);

if (protocolMethods != null)
{
var protocolMethodList = protocolMethods.Where(m => m.RequestMethod.Operation.Equals(method.Operation));
if (protocolMethodList != null && protocolMethodList.Count() == 1)
{
var protocolMethod = protocolMethodList.FirstOrDefault();
LowLevelClientWriter.WriteRequestCreationMethod(writer, protocolMethod.RequestMethod, restClient.Fields, responseClassifierTypes);

if (protocolMethod.IsLongRunning)
{
LowLevelClientWriter.WriteLongRunningOperationMethod(writer, protocolMethod, restClient.Fields, true);
LowLevelClientWriter.WriteLongRunningOperationMethod(writer, protocolMethod, restClient.Fields, false);
}
else if (protocolMethod.PagingInfo != null)
{
LowLevelClientWriter.WritePagingMethod(writer, protocolMethod, restClient.Fields, true);
LowLevelClientWriter.WritePagingMethod(writer, protocolMethod, restClient.Fields, false);
}
else
{
if (configuration != null)
{
LowLevelClientWriter.WriteClientMethod(writer, protocolMethod, restClient.Fields, configuration, true);
LowLevelClientWriter.WriteClientMethod(writer, protocolMethod, restClient.Fields, configuration, false);
}
}
}
}
}

LowLevelClientWriter.WriteResponseClassifierMethod(writer, responseClassifierTypes);
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/AutoRest.CSharp/DataPlane/AutoRest/DataPlaneOutputLibrary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public DataPlaneOutputLibrary(CodeModel codeModel, BuildContext<DataPlaneOutputL
public IEnumerable<DataPlaneResponseHeaderGroupType> HeaderModels => _headerModels.Values;
internal CachedDictionary<Schema, TypeProvider> SchemaMap => _models;
public IEnumerable<TypeProvider> Models => SchemaMap.Values;
public IDictionary<string, LowLevelOutputLibraryFactory.ClientInfo> DPGClientInfosByName => GetDPGClientInfosByName();

public override CSharpType FindTypeForSchema(Schema schema)
{
Expand Down Expand Up @@ -80,6 +81,16 @@ protected virtual Dictionary<Schema, TypeProvider> BuildModels()
_ => throw new NotImplementedException()
};

private IDictionary<string, LowLevelOutputLibraryFactory.ClientInfo> GetDPGClientInfosByName()
{
var clientInfosByName = _context.CodeModel.OperationGroups
.Select(og => LowLevelOutputLibraryFactory.CreateClientInfo(og, _context))
.ToDictionary(ci => ci.Name);
LowLevelOutputLibraryFactory.SetRequestsToClients(clientInfosByName.Values);

return clientInfosByName;
}

public LongRunningOperation FindLongRunningOperation(Operation operation)
{
Debug.Assert(operation.IsLongRunning);
Expand Down
2 changes: 1 addition & 1 deletion src/AutoRest.CSharp/DataPlane/AutoRest/DataPlaneTarget.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public static void Execute(GeneratedCodeWorkspace project, CodeModel codeModel,
foreach (var client in context.Library.RestClients)
{
var restCodeWriter = new CodeWriter();
restClientWriter.WriteClient(restCodeWriter, client);
restClientWriter.WriteClient(restCodeWriter, client, context.Configuration, client.ProtocolMethods);

project.AddGeneratedFile($"{client.Type.Name}.cs", restCodeWriter.ToString());
}
Expand Down
51 changes: 51 additions & 0 deletions src/AutoRest.CSharp/DataPlane/Output/DataPlaneRestClient.cs
Original file line number Diff line number Diff line change
@@ -1,23 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using AutoRest.CSharp.Common.Output.Builders;
using AutoRest.CSharp.Input;
using AutoRest.CSharp.Output.Models.Requests;
using AutoRest.CSharp.Output.Models.Types;
using AutoRest.CSharp.Utilities;

namespace AutoRest.CSharp.Output.Models
{
internal class DataPlaneRestClient : RestClient
{
private BuildContext<DataPlaneOutputLibrary> _context;

public IReadOnlyList<LowLevelClientMethod>? ProtocolMethods => _context.Configuration.ProtocolMethodList.Length > 0 && GetProtocolMethods() != null
ShivangiReja marked this conversation as resolved.
Show resolved Hide resolved
? GetProtocolMethods().ToArray() : null;

public DataPlaneRestClient(OperationGroup operationGroup, BuildContext<DataPlaneOutputLibrary> context)
: base(operationGroup, context, context.Library.FindClient(operationGroup)?.Declaration.Name)
{
_context = context;
}

private IEnumerable<LowLevelClientMethod>? GetProtocolMethods()
{
// Filter protocol methods based on the config
List<Operation> operations = new();
foreach (var operation in OperationGroup.Operations)
{
if (isProtocolMethodExists(operation))
{
operations.Add(operation);
}

}

// Atleast one match found
if (operations.Count > 0)
{
var clientParameters = RestClientBuilder.GetParametersFromOperations(operations).ToList();
var restClientBuilder = new RestClientBuilder(clientParameters, _context);

var clientInfo = LowLevelOutputLibraryFactory.CreateClientInfo(OperationGroup, _context);
var clientInfoByName = _context.Library.DPGClientInfosByName[clientInfo.Name];

// Filter protocol method requests based on the config
List<(ServiceRequest, Operation)> requests = new();
foreach (var (serviceRequest, operation) in clientInfoByName.Requests)
{
if (isProtocolMethodExists(operation))
{
requests.Add((serviceRequest, operation));
}
}

return LowLevelClient.BuildMethods(restClientBuilder, requests, clientInfo.Name);
}

return null;
}

private bool isProtocolMethodExists(Operation operation)
{
var protocolMethods = _context.Configuration.ProtocolMethodList;
return protocolMethods.Any(m => m.Equals(operation.Language.Default.Name, StringComparison.OrdinalIgnoreCase));
}

protected override Dictionary<ServiceRequest, RestClientMethod> EnsureNormalMethods()
{
var requestMethods = new Dictionary<ServiceRequest, RestClientMethod>();
Expand Down
42 changes: 26 additions & 16 deletions src/AutoRest.CSharp/LowLevel/Generation/LowLevelClientWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ public void WriteClient(CodeWriter writer, LowLevelClient client, BuildContext<L
}
else
{
WriteClientMethod(writer, client, clientMethod, context.Configuration, true);
WriteClientMethod(writer, client, clientMethod, context.Configuration, false);
WriteClientMethod(writer, clientMethod, client.Fields, context.Configuration, true);
WriteClientMethod(writer, clientMethod, client.Fields, context.Configuration, false);
}
}

Expand All @@ -76,15 +76,10 @@ public void WriteClient(CodeWriter writer, LowLevelClient client, BuildContext<L
var responseClassifierTypes = new List<ResponseClassifierType>();
foreach (var method in client.RequestMethods)
{
var responseClassifierType = CreateResponseClassifierType(method);
responseClassifierTypes.Add(responseClassifierType);
RequestWriterHelpers.WriteRequestCreation(writer, method, "internal", client.Fields, responseClassifierType.Name, false);
WriteRequestCreationMethod(writer, method, client.Fields, responseClassifierTypes);
}

foreach ((string name, StatusCodes[] statusCodes) in responseClassifierTypes.Distinct())
{
WriteResponseClassifier(writer, name, statusCodes);
}
WriteResponseClassifierMethod(writer, responseClassifierTypes);
}
}
}
Expand Down Expand Up @@ -201,7 +196,7 @@ private static void WriteSubClientInternalConstructor(CodeWriter writer, LowLeve
writer.Line();
}

private static void WriteClientMethod(CodeWriter writer, LowLevelClient client, LowLevelClientMethod clientMethod, Configuration configuration, bool async)
public static void WriteClientMethod(CodeWriter writer, LowLevelClientMethod clientMethod, ClientFields fields, Configuration configuration, bool async)
{
var restMethod = clientMethod.RequestMethod;
var headAsBoolean = restMethod.Request.HttpMethod == RequestMethod.Head && configuration.HeadAsBoolean;
Expand All @@ -212,7 +207,7 @@ private static void WriteClientMethod(CodeWriter writer, LowLevelClient client,

using (WriteClientMethodDeclaration(writer, clientMethod, clientMethod.OperationSchemas, returnType, async))
{
using (WriteDiagnosticScope(writer, clientMethod.Diagnostic, client.Fields.ClientDiagnosticsProperty.Name))
using (WriteDiagnosticScope(writer, clientMethod.Diagnostic, fields.ClientDiagnosticsProperty.Name))
{
var messageVariable = new CodeWriterDeclaration("message");
writer.Line($"using {typeof(HttpMessage)} {messageVariable:D} = {RequestWriterHelpers.CreateRequestMethodName(restMethod.Name)}({restMethod.Parameters.GetIdentifiersFormattable()});");
Expand All @@ -222,16 +217,16 @@ private static void WriteClientMethod(CodeWriter writer, LowLevelClient client,
: headAsBoolean ? nameof(HttpPipelineExtensions.ProcessHeadAsBoolMessage) : nameof(HttpPipelineExtensions.ProcessMessage);

FormattableString paramString = headAsBoolean
? (FormattableString)$"{messageVariable}, {client.Fields.ClientDiagnosticsProperty.Name}, {KnownParameters.RequestContext.Name:I}"
? (FormattableString)$"{messageVariable}, {fields.ClientDiagnosticsProperty.Name}, {KnownParameters.RequestContext.Name:I}"
: (FormattableString)$"{messageVariable}, {KnownParameters.RequestContext.Name:I}";

writer.AppendRaw("return ").WriteMethodCall(async, $"{client.Fields.PipelineField.Name:I}.{methodName}", paramString);
writer.AppendRaw("return ").WriteMethodCall(async, $"{fields.PipelineField.Name:I}.{methodName}", paramString);
}
}
writer.Line();
}

private static void WritePagingMethod(CodeWriter writer, LowLevelClientMethod clientMethod, ClientFields fields, bool async)
public static void WritePagingMethod(CodeWriter writer, LowLevelClientMethod clientMethod, ClientFields fields, bool async)
{
var method = clientMethod.RequestMethod;
var pagingInfo = clientMethod.PagingInfo!;
Expand Down Expand Up @@ -293,7 +288,7 @@ private static void WritePagingMethod(CodeWriter writer, LowLevelClientMethod cl
writer.Line();
}

private static void WriteLongRunningOperationMethod(CodeWriter writer, LowLevelClientMethod clientMethod, ClientFields fields, bool async)
public static void WriteLongRunningOperationMethod(CodeWriter writer, LowLevelClientMethod clientMethod, ClientFields fields, bool async)
{
var startMethod = clientMethod.RequestMethod;
var pagingInfo = clientMethod.PagingInfo;
Expand Down Expand Up @@ -404,6 +399,21 @@ private void WriteSubClientFactoryMethod(CodeWriter writer, LowLevelClient clien
}
}

public static void WriteRequestCreationMethod(CodeWriter writer, RestClientMethod restMethod, ClientFields fields, List<ResponseClassifierType> responseClassifierTypes)
{
var responseClassifierType = CreateResponseClassifierType(restMethod);
responseClassifierTypes.Add(responseClassifierType);
RequestWriterHelpers.WriteRequestCreation(writer, restMethod, "internal", fields, responseClassifierType.Name, false);
}

public static void WriteResponseClassifierMethod(CodeWriter writer, List<ResponseClassifierType> responseClassifierTypes)
{
foreach ((string name, StatusCodes[] statusCodes) in responseClassifierTypes.Distinct())
{
WriteResponseClassifier(writer, name, statusCodes);
}
}

private static void WriteResponseClassifier(CodeWriter writer, string responseClassifierTypeName, StatusCodes[] statusCodes)
{
using (writer.Scope($"private sealed class {responseClassifierTypeName} : {typeof(ResponseClassifier)}"))
Expand Down Expand Up @@ -636,7 +646,7 @@ public SchemaDocumentation(string schemaName, DocumentationRow[] documentationRo
}
}

private readonly struct ResponseClassifierType : IEquatable<ResponseClassifierType>
public readonly struct ResponseClassifierType : IEquatable<ResponseClassifierType>
{
public string Name { get; }
private readonly StatusCodes[] _statusCodes;
Expand Down
Loading