Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] Mangle @noDerivative parameters. #31201

Merged
merged 2 commits into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ mangled in to disambiguate.
impl-function-type ::= type* 'I' FUNC-ATTRIBUTES '_'
impl-function-type ::= type* generic-signature 'I' FUNC-ATTRIBUTES '_'

FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? (PARAM-CONVENTION PARAM-DIFFERENTIABILITY?)* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?

PATTERN-SUBS ::= 's' // has pattern substitutions
INVOCATION-SUB ::= 'I' // has invocation substitutions
Expand Down Expand Up @@ -626,6 +626,8 @@ mangled in to disambiguate.
PARAM-CONVENTION ::= 'g' // direct guaranteed
PARAM-CONVENTION ::= 'e' // direct deallocating

PARAM-DIFFERENTIABILITY ::= 'w' // @noDerivative

RESULT-CONVENTION ::= 'r' // indirect
RESULT-CONVENTION ::= 'o' // owned
RESULT-CONVENTION ::= 'd' // unowned
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ NODE(ImplDifferentiable)
NODE(ImplLinear)
NODE(ImplEscaping)
NODE(ImplConvention)
NODE(ImplDifferentiability)
NODE(ImplFunctionAttribute)
NODE(ImplFunctionType)
NODE(ImplInvocationSubstitutions)
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/Demangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ class Demangler : public NodeFactory {
NodePointer demangleInitializer();
NodePointer demangleImplParamConvention(Node::Kind ConvKind);
NodePointer demangleImplResultConvention(Node::Kind ConvKind);
NodePointer demangleImplDifferentiability();
NodePointer demangleImplFunctionType();
NodePointer demangleMetatype();
NodePointer demanglePrivateContextDescriptor();
Expand Down
68 changes: 63 additions & 5 deletions include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,31 @@ enum class ImplParameterConvention {
Direct_Guaranteed,
};

enum class ImplParameterDifferentiability {
DifferentiableOrNotApplicable,
NotDifferentiable
};

static inline Optional<ImplParameterDifferentiability>
getDifferentiabilityFromString(StringRef string) {
if (string.empty())
return ImplParameterDifferentiability::DifferentiableOrNotApplicable;
if (string == "@noDerivative")
return ImplParameterDifferentiability::NotDifferentiable;
return None;
}

/// Describe a lowered function parameter, parameterized on the type
/// representation.
template <typename BuiltType>
class ImplFunctionParam {
ImplParameterConvention Convention;
ImplParameterDifferentiability Differentiability;
BuiltType Type;

public:
using ConventionType = ImplParameterConvention;
using DifferentiabilityType = ImplParameterDifferentiability;

static Optional<ConventionType>
getConventionFromString(StringRef conventionString) {
Expand All @@ -120,11 +136,16 @@ class ImplFunctionParam {
return None;
}

ImplFunctionParam(ImplParameterConvention convention, BuiltType type)
: Convention(convention), Type(type) {}
ImplFunctionParam(ImplParameterConvention convention,
ImplParameterDifferentiability diffKind, BuiltType type)
: Convention(convention), Differentiability(diffKind), Type(type) {}

ImplParameterConvention getConvention() const { return Convention; }

ImplParameterDifferentiability getDifferentiability() const {
return Differentiability;
}

BuiltType getType() const { return Type; }
};

Expand Down Expand Up @@ -614,10 +635,8 @@ class TypeDecoder {
ImplFunctionDifferentiabilityKind::Linear);
} else if (child->getKind() == NodeKind::ImplEscaping) {
flags = flags.withEscaping();
} else if (child->getKind() == NodeKind::ImplEscaping) {
flags = flags.withEscaping();
} else if (child->getKind() == NodeKind::ImplParameter) {
if (decodeImplFunctionPart(child, parameters))
if (decodeImplFunctionParam(child, parameters))
return BuiltType();
} else if (child->getKind() == NodeKind::ImplResult) {
if (decodeImplFunctionPart(child, results))
Expand Down Expand Up @@ -897,6 +916,45 @@ class TypeDecoder {
return false;
}

bool decodeImplFunctionParam(
Demangle::NodePointer node,
SmallVectorImpl<ImplFunctionParam<BuiltType>> &results) {
// Children: `convention, differentiability?, type`
if (node->getNumChildren() != 2 && node->getNumChildren() != 3)
return true;

auto *conventionNode = node->getChild(0);
auto *typeNode = node->getLastChild();
if (conventionNode->getKind() != Node::Kind::ImplConvention ||
typeNode->getKind() != Node::Kind::Type)
return true;

StringRef conventionString = conventionNode->getText();
auto convention =
ImplFunctionParam<BuiltType>::getConventionFromString(conventionString);
if (!convention)
return true;
BuiltType type = decodeMangledType(typeNode);
if (!type)
return true;

auto diffKind =
ImplParameterDifferentiability::DifferentiableOrNotApplicable;
if (node->getNumChildren() == 3) {
auto diffKindNode = node->getChild(1);
if (diffKindNode->getKind() != Node::Kind::ImplDifferentiability)
return true;
auto optDiffKind =
getDifferentiabilityFromString(diffKindNode->getText());
if (!optDiffKind)
return true;
diffKind = *optDiffKind;
}

results.emplace_back(*convention, diffKind, type);
return false;
}

bool decodeMangledTypeDecl(Demangle::NodePointer node,
BuiltTypeDecl &typeDecl,
BuiltType &parent,
Expand Down
13 changes: 12 additions & 1 deletion lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,16 @@ getParameterConvention(ImplParameterConvention conv) {
llvm_unreachable("covered switch");
}

static SILParameterDifferentiability
getParameterDifferentiability(ImplParameterDifferentiability diffKind) {
switch (diffKind) {
case ImplParameterDifferentiability::DifferentiableOrNotApplicable:
return SILParameterDifferentiability::DifferentiableOrNotApplicable;
case ImplParameterDifferentiability::NotDifferentiable:
return SILParameterDifferentiability::NotDifferentiable;
}
}

static ResultConvention getResultConvention(ImplResultConvention conv) {
switch (conv) {
case Demangle::ImplResultConvention::Indirect:
Expand Down Expand Up @@ -526,7 +536,8 @@ Type ASTBuilder::createImplFunctionType(
for (const auto &param : params) {
auto type = param.getType()->getCanonicalType();
auto conv = getParameterConvention(param.getConvention());
funcParams.emplace_back(type, conv);
auto diffKind = getParameterDifferentiability(param.getDifferentiability());
funcParams.emplace_back(type, conv, diffKind);
}

for (const auto &result : results) {
Expand Down
13 changes: 13 additions & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,17 @@ static char getParamConvention(ParameterConvention conv) {
llvm_unreachable("bad parameter convention");
};

static Optional<char>
getParamDifferentiability(SILParameterDifferentiability diffKind) {
switch (diffKind) {
case swift::SILParameterDifferentiability::DifferentiableOrNotApplicable:
return None;
case swift::SILParameterDifferentiability::NotDifferentiable:
return 'w';
}
llvm_unreachable("bad parameter convention");
};

static char getResultConvention(ResultConvention conv) {
switch (conv) {
case ResultConvention::Indirect: return 'r';
Expand Down Expand Up @@ -1658,6 +1669,8 @@ void ASTMangler::appendImplFunctionType(SILFunctionType *fn) {
// Mangle the parameters.
for (auto param : fn->getParameters()) {
OpArgs.push_back(getParamConvention(param.getConvention()));
if (auto diffKind = getParamDifferentiability(param.getDifferentiability()))
OpArgs.push_back(*diffKind);
appendType(param.getInterfaceType());
}

Expand Down
12 changes: 11 additions & 1 deletion lib/Demangling/Demangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,14 @@ NodePointer Demangler::demangleImplResultConvention(Node::Kind ConvKind) {
createNode(Node::Kind::ImplConvention, attr));
}

NodePointer Demangler::demangleImplDifferentiability() {
// Empty string represents default differentiability.
const char *attr = "";
if (nextIf('w'))
attr = "@noDerivative";
return createNode(Node::Kind::ImplDifferentiability, attr);
}

NodePointer Demangler::demangleImplFunctionType() {
NodePointer type = createNode(Node::Kind::ImplFunctionType);

Expand Down Expand Up @@ -1817,8 +1825,10 @@ NodePointer Demangler::demangleImplFunctionType() {

int NumTypesToAdd = 0;
while (NodePointer Param =
demangleImplParamConvention(Node::Kind::ImplParameter)) {
demangleImplParamConvention(Node::Kind::ImplParameter)) {
type = addChild(type, Param);
if (NodePointer Diff = demangleImplDifferentiability())
Param = addChild(Param, Diff);
NumTypesToAdd++;
}
while (NodePointer Result = demangleImplResultConvention(
Expand Down
18 changes: 18 additions & 0 deletions lib/Demangling/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ class NodePrinter {
case Node::Kind::ImplLinear:
case Node::Kind::ImplEscaping:
case Node::Kind::ImplConvention:
case Node::Kind::ImplDifferentiability:
case Node::Kind::ImplFunctionAttribute:
case Node::Kind::ImplFunctionType:
case Node::Kind::ImplInvocationSubstitutions:
Expand Down Expand Up @@ -2060,6 +2061,13 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
case Node::Kind::ImplConvention:
Printer << Node->getText();
return nullptr;
case Node::Kind::ImplDifferentiability:
// Skip if text is empty.
if (Node->getText().empty())
return nullptr;
// Otherwise, print with trailing space.
Printer << Node->getText() << ' ';
return nullptr;
case Node::Kind::ImplFunctionAttribute:
Printer << Node->getText();
return nullptr;
Expand All @@ -2072,6 +2080,16 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
printChildren(Node, " ");
return nullptr;
case Node::Kind::ImplParameter:
// Children: `convention, differentiability?, type`
// Print convention.
print(Node->getChild(0));
Printer << " ";
// Print differentiability, if it exists.
if (Node->getNumChildren() == 3)
print(Node->getChild(1));
// Print type.
print(Node->getLastChild());
return nullptr;
case Node::Kind::ImplResult:
printChildren(Node, " ");
return nullptr;
Expand Down
13 changes: 13 additions & 0 deletions lib/Demangling/OldRemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,19 @@ void Remangler::mangleImplConvention(Node *node) {
}
}

void Remangler::mangleImplDifferentiability(Node *node) {
assert(node->getKind() == Node::Kind::ImplDifferentiability);
StringRef text = node->getText();
// Empty string represents default differentiability.
if (text.empty())
return;
if (text == "@noDerivative") {
Buffer << 'w';
return;
}
unreachable("Invalid impl differentiability");
}

void Remangler::mangleDynamicSelf(Node *node) {
Buffer << 'D';
mangleSingleChildNode(node); // type
Expand Down
20 changes: 19 additions & 1 deletion lib/Demangling/Remangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,18 @@ void Remangler::mangleImplConvention(Node *node) {
Buffer << ConvCh;
}

void Remangler::mangleImplDifferentiability(Node *node) {
assert(node->hasText());
// Empty string represents default differentiability.
if (node->getText().empty())
return;
char diffChar = llvm::StringSwitch<char>(node->getText())
.Case("@noDerivative", 'w')
.Default(0);
assert(diffChar && "Invalid impl differentiability");
Buffer << diffChar;
}

void Remangler::mangleImplFunctionAttribute(Node *node) {
unreachable("handled inline");
}
Expand All @@ -1443,7 +1455,9 @@ void Remangler::mangleImplFunctionType(Node *node) {
case Node::Kind::ImplResult:
case Node::Kind::ImplYield:
case Node::Kind::ImplErrorResult:
mangleChildNode(Child, 1);
// Mangle type. Type should be the last child.
assert(Child->getNumChildren() == 2 || Child->getNumChildren() == 3);
mangle(Child->getLastChild());
break;
case Node::Kind::DependentPseudogenericSignature:
PseudoGeneric = "P";
Expand Down Expand Up @@ -1526,6 +1540,7 @@ void Remangler::mangleImplFunctionType(Node *node) {
Buffer << 'Y';
LLVM_FALLTHROUGH;
case Node::Kind::ImplParameter: {
// Mangle parameter convention.
char ConvCh =
llvm::StringSwitch<char>(Child->getFirstChild()->getText())
.Case("@in", 'i')
Expand All @@ -1540,6 +1555,9 @@ void Remangler::mangleImplFunctionType(Node *node) {
.Default(0);
assert(ConvCh && "invalid impl parameter convention");
Buffer << ConvCh;
// Mangle parameter differentiability, if it exists.
if (Child->getNumChildren() == 3)
mangleImplDifferentiability(Child->getChild(1));
break;
}
case Node::Kind::ImplErrorResult:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: %target-build-swift -g %s

// SR-12650: IRGenDebugInfo type reconstruction crash because `@noDerivative`
// parameters are not mangled.

import _Differentiation
func id(_ x: Float, _ y: Float) -> Float { x }
let transformed: @differentiable (Float, @noDerivative Float) -> Float = id

// Incorrect reconstructed type for $sS3fIedgyyd_D
// Original type:
// (sil_function_type type=@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
// (input=struct_type decl=Swift.(file).Float)
// (input=struct_type decl=Swift.(file).Float)
// (result=struct_type decl=Swift.(file).Float)
// (substitution_map generic_signature=<nullptr>)
// (substitution_map generic_signature=<nullptr>))
// Reconstructed type:
// (sil_function_type type=@differentiable @callee_guaranteed (Float, Float) -> Float
// (input=struct_type decl=Swift.(file).Float)
// (input=struct_type decl=Swift.(file).Float)
// (result=struct_type decl=Swift.(file).Float)
// (substitution_map generic_signature=<nullptr>)
// (substitution_map generic_signature=<nullptr>))
// Stack dump:
// ...
// 1. Swift version 5.3-dev (LLVM 803d1b184d, Swift 477af9f90d)
// 2. While evaluating request IRGenSourceFileRequest(IR Generation for file "noderiv.swift")
// 0 swift 0x00000001104c7ae8 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 40
// 1 swift 0x00000001104c6a68 llvm::sys::RunSignalHandlers() + 248
// 2 swift 0x00000001104c80dd SignalHandler(int) + 285
// 3 libsystem_platform.dylib 0x00007fff718335fd _sigtramp + 29
// 4 libsystem_platform.dylib 000000000000000000 _sigtramp + 18446603338611739168
// 5 libsystem_c.dylib 0x00007fff71709808 abort + 120
// 6 swift 0x0000000110604152 (anonymous namespace)::IRGenDebugInfoImpl::getOrCreateType(swift::irgen::DebugTypeInfo) (.cold.20) + 146
// 7 swift 0x000000010c24ab1e (anonymous namespace)::IRGenDebugInfoImpl::getOrCreateType(swift::irgen::DebugTypeInfo) + 3614
// 8 swift 0x000000010c245437 swift::irgen::IRGenDebugInfo::emitGlobalVariableDeclaration(llvm::GlobalVariable*, llvm::StringRef, llvm::StringRef, swift::irgen::DebugTypeInfo, bool, bool, llvm::Optional<swift::SILLocation>) + 167
1 change: 1 addition & 0 deletions test/Demangle/Inputs/manglings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,4 @@ $s17property_wrappers10WithTuplesV9fractionsSd_S2dtvpfP --> property wrapper bac
$sSo17OS_dispatch_queueC4sync7executeyyyXE_tFTOTA ---> {T:$sSo17OS_dispatch_queueC4sync7executeyyyXE_tFTO} partial apply forwarder for @nonobjc __C.OS_dispatch_queue.sync(execute: () -> ()) -> ()
$sxq_Idgnr_D ---> @differentiable @callee_guaranteed (@in_guaranteed A) -> (@out B)
$sxq_Ilgnr_D ---> @differentiable(linear) @callee_guaranteed (@in_guaranteed A) -> (@out B)
$sS3fIedgyywd_D ---> @escaping @differentiable @callee_guaranteed (@unowned Swift.Float, @unowned @noDerivative Swift.Float) -> (@unowned Swift.Float)