Skip to content

Commit

Permalink
Merge pull request #39826 from slavapestov/encapsulate-gsb
Browse files Browse the repository at this point in the history
Refactor remaining non-request usages of the GenericSignatureBuilder
  • Loading branch information
slavapestov authored Oct 20, 2021
2 parents 5e2a5cb + a9794dc commit 9481130
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 190 deletions.
2 changes: 0 additions & 2 deletions include/swift/AST/GenericParamList.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@ class RequirementRepr {
void print(ASTPrinter &Printer) const;
};

using GenericParamSource = PointerUnion<GenericContext *, GenericParamList *>;

/// GenericParamList - A list of generic parameters that is part of a generic
/// function or type, along with extra requirements placed on those generic
/// parameters and types derived from them.
Expand Down
22 changes: 15 additions & 7 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ struct WhereClauseOwner {
SpecializeAttr *, DifferentiableAttr *>
source;

WhereClauseOwner() : dc(nullptr) {}

WhereClauseOwner(GenericContext *genCtx);
WhereClauseOwner(AssociatedTypeDecl *atd);

Expand All @@ -480,6 +482,10 @@ struct WhereClauseOwner {
return llvm::hash_value(owner.source.getOpaqueValue());
}

operator bool() const {
return dc != nullptr;
}

friend bool operator==(const WhereClauseOwner &lhs,
const WhereClauseOwner &rhs) {
return lhs.source.getOpaqueValue() == rhs.source.getOpaqueValue();
Expand Down Expand Up @@ -1437,11 +1443,12 @@ class AbstractGenericSignatureRequest :
class InferredGenericSignatureRequest :
public SimpleRequest<InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *,
const GenericSignatureImpl *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool),
const GenericSignatureImpl *,
GenericParamList *,
WhereClauseOwner,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
Expand All @@ -1452,9 +1459,10 @@ class InferredGenericSignatureRequest :
// Evaluation.
GenericSignature
evaluate(Evaluator &evaluator,
ModuleDecl *module,
ModuleDecl *parentModule,
const GenericSignatureImpl *baseSignature,
GenericParamSource paramSource,
GenericParamList *genericParams,
WhereClauseOwner whereClause,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool allowConcreteGenericParams) const;
Expand Down
10 changes: 6 additions & 4 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ SWIFT_REQUEST(TypeChecker, HasImplementationOnlyImportsRequest,
SWIFT_REQUEST(TypeChecker, ModuleLibraryLevelRequest,
LibraryLevel(ModuleDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
GenericSignature (ModuleDecl *, const GenericSignatureImpl *,
GenericParamSource,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool),
GenericSignature (ModuleDecl *,
const GenericSignatureImpl *,
GenericParamList *,
WhereClauseOwner,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DistributedModuleIsAvailableRequest,
bool(ModuleDecl *), Cached, NoLocationInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
#define SWIFT_SILOPTIMIZER_ANALYSIS_DIFFERENTIABLEACTIVITYANALYSIS_H_

#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/SILValue.h"
Expand Down
23 changes: 7 additions & 16 deletions lib/AST/GenericSignatureBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8679,9 +8679,11 @@ AbstractGenericSignatureRequest::evaluate(

GenericSignature
InferredGenericSignatureRequest::evaluate(
Evaluator &evaluator, ModuleDecl *parentModule,
Evaluator &evaluator,
ModuleDecl *parentModule,
const GenericSignatureImpl *parentSig,
GenericParamSource paramSource,
GenericParamList *genericParams,
WhereClauseOwner whereClause,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool allowConcreteGenericParams) const {
Expand Down Expand Up @@ -8729,12 +8731,6 @@ InferredGenericSignatureRequest::evaluate(
return false;
};

GenericParamList *genericParams = nullptr;
if (auto params = paramSource.dyn_cast<GenericParamList *>())
genericParams = params;
else
genericParams = paramSource.get<GenericContext *>()->getGenericParams();

if (genericParams) {
// Extensions never have a parent signature.
if (genericParams->getOuterParameters())
Expand Down Expand Up @@ -8777,15 +8773,10 @@ InferredGenericSignatureRequest::evaluate(
}
}

if (auto *ctx = paramSource.dyn_cast<GenericContext *>()) {
// The declaration might have a trailing where clause.
if (auto *where = ctx->getTrailingWhereClause()) {
// Determine where and how to perform name lookup.
lookupDC = ctx;

WhereClauseOwner(lookupDC, where).visitRequirements(
if (whereClause) {
lookupDC = whereClause.dc;
std::move(whereClause).visitRequirements(
TypeResolutionStage::Structural, visitRequirement);
}
}

/// Perform any remaining requirement inference.
Expand Down
1 change: 0 additions & 1 deletion lib/IRGen/GenType.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ namespace llvm {
}

namespace swift {
class GenericSignatureBuilder;
class ArchetypeType;
class CanType;
class ClassDecl;
Expand Down
1 change: 0 additions & 1 deletion lib/IRGen/IRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ namespace clang {

namespace swift {
class GenericSignature;
class GenericSignatureBuilder;
class AssociatedConformance;
class AssociatedType;
class ASTContext;
Expand Down
24 changes: 10 additions & 14 deletions lib/SILOptimizer/Differentiation/Thunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
#include "swift/SILOptimizer/Differentiation/Common.h"

#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/Requirement.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"

Expand Down Expand Up @@ -53,30 +53,26 @@ CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
}

auto &ctx = fn->getASTContext();
GenericSignatureBuilder builder(ctx);

// Add the existing generic signature.
GenericSignature baseGenericSig;
int depth = 0;
if (inheritGenericSig) {
if (auto genericSig =
fn->getLoweredFunctionType()->getSubstGenericSignature()) {
builder.addGenericSignature(genericSig);
depth = genericSig.getGenericParams().back()->getDepth() + 1;
}
baseGenericSig = fn->getLoweredFunctionType()->getSubstGenericSignature();
if (baseGenericSig)
depth = baseGenericSig.getGenericParams().back()->getDepth() + 1;
}

// Add a new generic parameter to replace the opened existential.
auto *newGenericParam = GenericTypeParamType::get(depth, 0, ctx);

builder.addGenericParameter(newGenericParam);
Requirement newRequirement(RequirementKind::Conformance, newGenericParam,
openedExistential->getOpenedExistentialType());
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
builder.addRequirement(newRequirement, source, nullptr);

auto genericSig = std::move(builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
auto genericSig = evaluateOrDefault(
ctx.evaluator,
AbstractGenericSignatureRequest{
baseGenericSig.getPointer(), { newGenericParam }, { newRequirement }},
GenericSignature());
genericEnv = genericSig.getGenericEnvironment();

newArchetype =
Expand Down
99 changes: 35 additions & 64 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "swift/AST/DiagnosticsParse.h"
#include "swift/AST/Effects.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/ImportCache.h"
#include "swift/AST/ModuleNameLookup.h"
#include "swift/AST/NameLookup.h"
Expand Down Expand Up @@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
return;
}

// Form a new generic signature based on the old one.
GenericSignatureBuilder Builder(D->getASTContext());
InferredGenericSignatureRequest request{
DC->getParentModule(),
genericSig.getPointer(),
/*genericParams=*/nullptr,
WhereClauseOwner(FD, attr),
/*addedRequirements=*/{},
/*inferenceSources=*/{},
/*allowConcreteGenericParams=*/true};

// First, add the old generic signature.
Builder.addGenericSignature(genericSig);

// Go over the set of requirements, adding them to the builder.
WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface,
[&](const Requirement &req, RequirementRepr *reqRepr) {
// Add the requirement to the generic signature builder.
using FloatingRequirementSource =
GenericSignatureBuilder::FloatingRequirementSource;
Builder.addRequirement(req, reqRepr,
FloatingRequirementSource::forExplicit(
reqRepr->getSeparatorLoc()),
nullptr, DC->getParentModule());
return false;
});

// Check the result.
auto specializedSig = std::move(Builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
auto specializedSig = evaluateOrDefault(Ctx.evaluator, request,
GenericSignature());

// Check the validity of provided requirements.
checkSpecializeAttrRequirements(attr, genericSig, specializedSig, Ctx);
Expand Down Expand Up @@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
// - If the `@differentiable` attribute has a `where` clause, use it to
// compute the derivative generic signature.
// - Otherwise, use the original function's generic signature by default.
derivativeGenSig = original->getGenericSignature();
auto originalGenSig = original->getGenericSignature();
derivativeGenSig = originalGenSig;

// Handle the `where` clause, if it exists.
// - Resolve attribute where clause requirements and store in the attribute
Expand All @@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
return true;
}

auto originalGenSig = original->getGenericSignature();
if (!originalGenSig) {
// `where` clauses are valid only when the original function is generic.
diags
Expand All @@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
return true;
}

// Build a new generic signature for autodiff derivative functions.
GenericSignatureBuilder builder(ctx);
// Add the original function's generic signature.
builder.addGenericSignature(originalGenSig);

using FloatingRequirementSource =
GenericSignatureBuilder::FloatingRequirementSource;

bool errorOccurred = false;
WhereClauseOwner(original, attr)
.visitRequirements(
TypeResolutionStage::Structural,
[&](const Requirement &req, RequirementRepr *reqRepr) {
switch (req.getKind()) {
case RequirementKind::SameType:
case RequirementKind::Superclass:
case RequirementKind::Conformance:
break;

// Layout requirements are not supported.
case RequirementKind::Layout:
diags
.diagnose(attr->getLocation(),
diag::differentiable_attr_layout_req_unsupported)
.highlight(reqRepr->getSourceRange());
errorOccurred = true;
return false;
}
InferredGenericSignatureRequest request{
original->getParentModule(),
originalGenSig.getPointer(),
/*genericParams=*/nullptr,
WhereClauseOwner(original, attr),
/*addedRequirements=*/{},
/*inferenceSources=*/{},
/*allowConcreteParams=*/true};

// Compute generic signature for derivative functions.
derivativeGenSig = evaluateOrDefault(ctx.evaluator, request,
GenericSignature());

// Add requirement to generic signature builder.
builder.addRequirement(
req, reqRepr, FloatingRequirementSource::forExplicit(
reqRepr->getSeparatorLoc()),
nullptr, original->getModuleContext());
return false;
});
bool hadInvalidRequirements = false;
for (auto req : derivativeGenSig.requirementsNotSatisfiedBy(originalGenSig)) {
if (req.getKind() == RequirementKind::Layout) {
// Layout requirements are not supported.
diags
.diagnose(attr->getLocation(),
diag::differentiable_attr_layout_req_unsupported);
hadInvalidRequirements = true;
}
}

if (errorOccurred) {
if (hadInvalidRequirements) {
attr->setInvalid();
return true;
}

// Compute generic signature for derivative functions.
derivativeGenSig = std::move(builder).computeGenericSignature(
/*allowConcreteGenericParams=*/true);
}

attr->setDerivativeGenericSignature(derivativeGenSig);
Expand Down
Loading

0 comments on commit 9481130

Please sign in to comment.