Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor remaining non-request usages of the GenericSignatureBuilder #39826

Merged
merged 4 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/swift/AST/GenericParamList.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@ class RequirementRepr {
void print(ASTPrinter &Printer) const;
};

using GenericParamSource = PointerUnion<GenericContext *, GenericParamList *>;

/// GenericParamList - A list of generic parameters that is part of a generic
/// function or type, along with extra requirements placed on those generic
/// parameters and types derived from them.
Expand Down
22 changes: 15 additions & 7 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ struct WhereClauseOwner {
SpecializeAttr *, DifferentiableAttr *>
source;

WhereClauseOwner() : dc(nullptr) {}

WhereClauseOwner(GenericContext *genCtx);
WhereClauseOwner(AssociatedTypeDecl *atd);

Expand All @@ -480,6 +482,10 @@ struct WhereClauseOwner {
return llvm::hash_value(owner.source.getOpaqueValue());
}

operator bool() const {
return dc != nullptr;
}

friend bool operator==(const WhereClauseOwner &lhs,
const WhereClauseOwner &rhs) {
return lhs.source.getOpaqueValue() == rhs.source.getOpaqueValue();
Expand Down Expand Up @@ -1437,11 +1443,12 @@ class AbstractGenericSignatureRequest :
class InferredGenericSignatureRequest :
public SimpleRequest<InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *,
const GenericSignatureImpl *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool),
const GenericSignatureImpl *,
GenericParamList *,
WhereClauseOwner,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
Expand All @@ -1452,9 +1459,10 @@ class InferredGenericSignatureRequest :
// Evaluation.
GenericSignature
evaluate(Evaluator &evaluator,
ModuleDecl *module,
ModuleDecl *parentModule,
const GenericSignatureImpl *baseSignature,
GenericParamSource paramSource,
GenericParamList *genericParams,
WhereClauseOwner whereClause,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool allowConcreteGenericParams) const;
Expand Down
10 changes: 6 additions & 4 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ SWIFT_REQUEST(TypeChecker, HasImplementationOnlyImportsRequest,
SWIFT_REQUEST(TypeChecker, ModuleLibraryLevelRequest,
LibraryLevel(ModuleDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *, const GenericSignatureImpl *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool),
GenericSignature (ModuleDecl *,
const GenericSignatureImpl *,
GenericParamList *,
WhereClauseOwner,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DistributedModuleIsAvailableRequest,
bool(ModuleDecl *), Cached, NoLocationInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
#define SWIFT_SILOPTIMIZER_ANALYSIS_DIFFERENTIABLEACTIVITYANALYSIS_H_

#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/SILValue.h"
Expand Down
23 changes: 7 additions & 16 deletions lib/AST/GenericSignatureBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8679,9 +8679,11 @@ AbstractGenericSignatureRequest::evaluate(

GenericSignature
InferredGenericSignatureRequest::evaluate(
Evaluator &evaluator, ModuleDecl *parentModule,
Evaluator &evaluator,
ModuleDecl *parentModule,
const GenericSignatureImpl *parentSig,
GenericParamSource paramSource,
GenericParamList *genericParams,
WhereClauseOwner whereClause,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool allowConcreteGenericParams) const {
Expand Down Expand Up @@ -8729,12 +8731,6 @@ InferredGenericSignatureRequest::evaluate(
return false;
};

GenericParamList *genericParams = nullptr;
if (auto params = paramSource.dyn_cast<GenericParamList *>())
genericParams = params;
else
genericParams = paramSource.get<GenericContext *>()->getGenericParams();

if (genericParams) {
// Extensions never have a parent signature.
if (genericParams->getOuterParameters())
Expand Down Expand Up @@ -8777,15 +8773,10 @@ InferredGenericSignatureRequest::evaluate(
}
}

if (auto *ctx = paramSource.dyn_cast<GenericContext *>()) {
// The declaration might have a trailing where clause.
if (auto *where = ctx->getTrailingWhereClause()) {
// Determine where and how to perform name lookup.
lookupDC = ctx;

WhereClauseOwner(lookupDC, where).visitRequirements(
if (whereClause) {
lookupDC = whereClause.dc;
std::move(whereClause).visitRequirements(
TypeResolutionStage::Structural, visitRequirement);
}
}

/// Perform any remaining requirement inference.
Expand Down
1 change: 0 additions & 1 deletion lib/IRGen/GenType.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ namespace llvm {
}

namespace swift {
class GenericSignatureBuilder;
class ArchetypeType;
class CanType;
class ClassDecl;
Expand Down
1 change: 0 additions & 1 deletion lib/IRGen/IRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ namespace clang {

namespace swift {
class GenericSignature;
class GenericSignatureBuilder;
class AssociatedConformance;
class AssociatedType;
class ASTContext;
Expand Down
24 changes: 10 additions & 14 deletions lib/SILOptimizer/Differentiation/Thunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
#include "swift/SILOptimizer/Differentiation/Common.h"

#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/Requirement.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"

Expand Down Expand Up @@ -53,30 +53,26 @@ CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
}

auto &ctx = fn->getASTContext();
GenericSignatureBuilder builder(ctx);

// Add the existing generic signature.
GenericSignature baseGenericSig;
int depth = 0;
if (inheritGenericSig) {
if (auto genericSig =
fn->getLoweredFunctionType()->getSubstGenericSignature()) {
builder.addGenericSignature(genericSig);
depth = genericSig.getGenericParams().back()->getDepth() + 1;
}
baseGenericSig = fn->getLoweredFunctionType()->getSubstGenericSignature();
if (baseGenericSig)
depth = baseGenericSig.getGenericParams().back()->getDepth() + 1;
}

// Add a new generic parameter to replace the opened existential.
auto *newGenericParam = GenericTypeParamType::get(depth, 0, ctx);

builder.addGenericParameter(newGenericParam);
Requirement newRequirement(RequirementKind::Conformance, newGenericParam,
openedExistential->getOpenedExistentialType());
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
builder.addRequirement(newRequirement, source, nullptr);

auto genericSig = std::move(builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
auto genericSig = evaluateOrDefault(
ctx.evaluator,
AbstractGenericSignatureRequest{
baseGenericSig.getPointer(), { newGenericParam }, { newRequirement }},
GenericSignature());
genericEnv = genericSig.getGenericEnvironment();

newArchetype =
Expand Down
99 changes: 35 additions & 64 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "swift/AST/DiagnosticsParse.h"
#include "swift/AST/Effects.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/ImportCache.h"
#include "swift/AST/ModuleNameLookup.h"
#include "swift/AST/NameLookup.h"
Expand Down Expand Up @@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
return;
}

// Form a new generic signature based on the old one.
GenericSignatureBuilder Builder(D->getASTContext());
InferredGenericSignatureRequest request{
DC->getParentModule(),
genericSig.getPointer(),
/*genericParams=*/nullptr,
WhereClauseOwner(FD, attr),
/*addedRequirements=*/{},
/*inferenceSources=*/{},
/*allowConcreteGenericParams=*/true};

// First, add the old generic signature.
Builder.addGenericSignature(genericSig);

// Go over the set of requirements, adding them to the builder.
WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface,
[&](const Requirement &req, RequirementRepr *reqRepr) {
// Add the requirement to the generic signature builder.
using FloatingRequirementSource =
GenericSignatureBuilder::FloatingRequirementSource;
Builder.addRequirement(req, reqRepr,
FloatingRequirementSource::forExplicit(
reqRepr->getSeparatorLoc()),
nullptr, DC->getParentModule());
return false;
});

// Check the result.
auto specializedSig = std::move(Builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
auto specializedSig = evaluateOrDefault(Ctx.evaluator, request,
GenericSignature());

// Check the validity of provided requirements.
checkSpecializeAttrRequirements(attr, genericSig, specializedSig, Ctx);
Expand Down Expand Up @@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
// - If the `@differentiable` attribute has a `where` clause, use it to
// compute the derivative generic signature.
// - Otherwise, use the original function's generic signature by default.
derivativeGenSig = original->getGenericSignature();
auto originalGenSig = original->getGenericSignature();
derivativeGenSig = originalGenSig;

// Handle the `where` clause, if it exists.
// - Resolve attribute where clause requirements and store in the attribute
Expand All @@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
return true;
}

auto originalGenSig = original->getGenericSignature();
if (!originalGenSig) {
// `where` clauses are valid only when the original function is generic.
diags
Expand All @@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
return true;
}

// Build a new generic signature for autodiff derivative functions.
GenericSignatureBuilder builder(ctx);
// Add the original function's generic signature.
builder.addGenericSignature(originalGenSig);

using FloatingRequirementSource =
GenericSignatureBuilder::FloatingRequirementSource;

bool errorOccurred = false;
WhereClauseOwner(original, attr)
.visitRequirements(
TypeResolutionStage::Structural,
[&](const Requirement &req, RequirementRepr *reqRepr) {
switch (req.getKind()) {
case RequirementKind::SameType:
case RequirementKind::Superclass:
case RequirementKind::Conformance:
break;

// Layout requirements are not supported.
case RequirementKind::Layout:
diags
.diagnose(attr->getLocation(),
diag::differentiable_attr_layout_req_unsupported)
.highlight(reqRepr->getSourceRange());
errorOccurred = true;
return false;
}
InferredGenericSignatureRequest request{
original->getParentModule(),
originalGenSig.getPointer(),
/*genericParams=*/nullptr,
WhereClauseOwner(original, attr),
/*addedRequirements=*/{},
/*inferenceSources=*/{},
/*allowConcreteParams=*/true};

// Compute generic signature for derivative functions.
derivativeGenSig = evaluateOrDefault(ctx.evaluator, request,
GenericSignature());

// Add requirement to generic signature builder.
builder.addRequirement(
req, reqRepr, FloatingRequirementSource::forExplicit(
reqRepr->getSeparatorLoc()),
nullptr, original->getModuleContext());
return false;
});
bool hadInvalidRequirements = false;
for (auto req : derivativeGenSig.requirementsNotSatisfiedBy(originalGenSig)) {
if (req.getKind() == RequirementKind::Layout) {
// Layout requirements are not supported.
diags
.diagnose(attr->getLocation(),
diag::differentiable_attr_layout_req_unsupported);
hadInvalidRequirements = true;
}
}

if (errorOccurred) {
if (hadInvalidRequirements) {
attr->setInvalid();
return true;
}

// Compute generic signature for derivative functions.
derivativeGenSig = std::move(builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
}

attr->setDerivativeGenericSignature(derivativeGenSig);
Expand Down
Loading