diff --git a/docs/ABI/Mangling.rst b/docs/ABI/Mangling.rst index d8fb59b09a666..c301dbf6fbe26 100644 --- a/docs/ABI/Mangling.rst +++ b/docs/ABI/Mangling.rst @@ -589,7 +589,7 @@ mangled in to disambiguate. impl-function-type ::= type* 'I' FUNC-ATTRIBUTES '_' impl-function-type ::= type* generic-signature 'I' FUNC-ATTRIBUTES '_' - FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)? + FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? (PARAM-CONVENTION PARAM-DIFFERENTIABILITY?)* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)? PATTERN-SUBS ::= 's' // has pattern substitutions INVOCATION-SUB ::= 'I' // has invocation substitutions @@ -626,6 +626,8 @@ mangled in to disambiguate. PARAM-CONVENTION ::= 'g' // direct guaranteed PARAM-CONVENTION ::= 'e' // direct deallocating + PARAM-DIFFERENTIABILITY ::= 'w' // @noDerivative + RESULT-CONVENTION ::= 'r' // indirect RESULT-CONVENTION ::= 'o' // owned RESULT-CONVENTION ::= 'd' // unowned diff --git a/include/swift/Demangling/DemangleNodes.def b/include/swift/Demangling/DemangleNodes.def index 121268fecfd61..8321255140d40 100644 --- a/include/swift/Demangling/DemangleNodes.def +++ b/include/swift/Demangling/DemangleNodes.def @@ -117,6 +117,7 @@ NODE(ImplDifferentiable) NODE(ImplLinear) NODE(ImplEscaping) NODE(ImplConvention) +NODE(ImplDifferentiability) NODE(ImplFunctionAttribute) NODE(ImplFunctionType) NODE(ImplInvocationSubstitutions) diff --git a/include/swift/Demangling/Demangler.h b/include/swift/Demangling/Demangler.h index 8afd2e929863d..1ca202f0585bf 100644 --- a/include/swift/Demangling/Demangler.h +++ b/include/swift/Demangling/Demangler.h @@ -518,6 +518,7 @@ class Demangler : public NodeFactory { NodePointer demangleInitializer(); NodePointer demangleImplParamConvention(Node::Kind ConvKind); NodePointer demangleImplResultConvention(Node::Kind ConvKind); + NodePointer demangleImplDifferentiability(); NodePointer demangleImplFunctionType(); NodePointer demangleMetatype(); NodePointer demanglePrivateContextDescriptor(); diff --git a/include/swift/Demangling/TypeDecoder.h b/include/swift/Demangling/TypeDecoder.h index a2a2cc16eef5c..4cc6c53d4b5d5 100644 --- a/include/swift/Demangling/TypeDecoder.h +++ b/include/swift/Demangling/TypeDecoder.h @@ -88,15 +88,31 @@ enum class ImplParameterConvention { Direct_Guaranteed, }; +enum class ImplParameterDifferentiability { + DifferentiableOrNotApplicable, + NotDifferentiable +}; + +static inline Optional +getDifferentiabilityFromString(StringRef string) { + if (string.empty()) + return ImplParameterDifferentiability::DifferentiableOrNotApplicable; + if (string == "@noDerivative") + return ImplParameterDifferentiability::NotDifferentiable; + return None; +} + /// Describe a lowered function parameter, parameterized on the type /// representation. template class ImplFunctionParam { ImplParameterConvention Convention; + ImplParameterDifferentiability Differentiability; BuiltType Type; public: using ConventionType = ImplParameterConvention; + using DifferentiabilityType = ImplParameterDifferentiability; static Optional getConventionFromString(StringRef conventionString) { @@ -120,11 +136,16 @@ class ImplFunctionParam { return None; } - ImplFunctionParam(ImplParameterConvention convention, BuiltType type) - : Convention(convention), Type(type) {} + ImplFunctionParam(ImplParameterConvention convention, + ImplParameterDifferentiability diffKind, BuiltType type) + : Convention(convention), Differentiability(diffKind), Type(type) {} ImplParameterConvention getConvention() const { return Convention; } + ImplParameterDifferentiability getDifferentiability() const { + return Differentiability; + } + BuiltType getType() const { return Type; } }; @@ -614,10 +635,8 @@ class TypeDecoder { ImplFunctionDifferentiabilityKind::Linear); } else if (child->getKind() == NodeKind::ImplEscaping) { flags = flags.withEscaping(); - } else if (child->getKind() == NodeKind::ImplEscaping) { - flags = flags.withEscaping(); } else if (child->getKind() == NodeKind::ImplParameter) { - if (decodeImplFunctionPart(child, parameters)) + if (decodeImplFunctionParam(child, parameters)) return BuiltType(); } else if (child->getKind() == NodeKind::ImplResult) { if (decodeImplFunctionPart(child, results)) @@ -897,6 +916,45 @@ class TypeDecoder { return false; } + bool decodeImplFunctionParam( + Demangle::NodePointer node, + SmallVectorImpl> &results) { + // Children: `convention, differentiability?, type` + if (node->getNumChildren() != 2 && node->getNumChildren() != 3) + return true; + + auto *conventionNode = node->getChild(0); + auto *typeNode = node->getLastChild(); + if (conventionNode->getKind() != Node::Kind::ImplConvention || + typeNode->getKind() != Node::Kind::Type) + return true; + + StringRef conventionString = conventionNode->getText(); + auto convention = + ImplFunctionParam::getConventionFromString(conventionString); + if (!convention) + return true; + BuiltType type = decodeMangledType(typeNode); + if (!type) + return true; + + auto diffKind = + ImplParameterDifferentiability::DifferentiableOrNotApplicable; + if (node->getNumChildren() == 3) { + auto diffKindNode = node->getChild(1); + if (diffKindNode->getKind() != Node::Kind::ImplDifferentiability) + return true; + auto optDiffKind = + getDifferentiabilityFromString(diffKindNode->getText()); + if (!optDiffKind) + return true; + diffKind = *optDiffKind; + } + + results.emplace_back(*convention, diffKind, type); + return false; + } + bool decodeMangledTypeDecl(Demangle::NodePointer node, BuiltTypeDecl &typeDecl, BuiltType &parent, diff --git a/lib/AST/ASTDemangler.cpp b/lib/AST/ASTDemangler.cpp index 63146a83464ba..9b81a8e681aa1 100644 --- a/lib/AST/ASTDemangler.cpp +++ b/lib/AST/ASTDemangler.cpp @@ -444,6 +444,16 @@ getParameterConvention(ImplParameterConvention conv) { llvm_unreachable("covered switch"); } +static SILParameterDifferentiability +getParameterDifferentiability(ImplParameterDifferentiability diffKind) { + switch (diffKind) { + case ImplParameterDifferentiability::DifferentiableOrNotApplicable: + return SILParameterDifferentiability::DifferentiableOrNotApplicable; + case ImplParameterDifferentiability::NotDifferentiable: + return SILParameterDifferentiability::NotDifferentiable; + } +} + static ResultConvention getResultConvention(ImplResultConvention conv) { switch (conv) { case Demangle::ImplResultConvention::Indirect: @@ -526,7 +536,8 @@ Type ASTBuilder::createImplFunctionType( for (const auto ¶m : params) { auto type = param.getType()->getCanonicalType(); auto conv = getParameterConvention(param.getConvention()); - funcParams.emplace_back(type, conv); + auto diffKind = getParameterDifferentiability(param.getDifferentiability()); + funcParams.emplace_back(type, conv, diffKind); } for (const auto &result : results) { diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index b1322fa9a84ed..ce887ae986ccf 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -1569,6 +1569,17 @@ static char getParamConvention(ParameterConvention conv) { llvm_unreachable("bad parameter convention"); }; +static Optional +getParamDifferentiability(SILParameterDifferentiability diffKind) { + switch (diffKind) { + case swift::SILParameterDifferentiability::DifferentiableOrNotApplicable: + return None; + case swift::SILParameterDifferentiability::NotDifferentiable: + return 'w'; + } + llvm_unreachable("bad parameter convention"); +}; + static char getResultConvention(ResultConvention conv) { switch (conv) { case ResultConvention::Indirect: return 'r'; @@ -1658,6 +1669,8 @@ void ASTMangler::appendImplFunctionType(SILFunctionType *fn) { // Mangle the parameters. for (auto param : fn->getParameters()) { OpArgs.push_back(getParamConvention(param.getConvention())); + if (auto diffKind = getParamDifferentiability(param.getDifferentiability())) + OpArgs.push_back(*diffKind); appendType(param.getInterfaceType()); } diff --git a/lib/Demangling/Demangler.cpp b/lib/Demangling/Demangler.cpp index c88a98d2a8a31..2eaf9e08d7a57 100644 --- a/lib/Demangling/Demangler.cpp +++ b/lib/Demangling/Demangler.cpp @@ -1732,6 +1732,14 @@ NodePointer Demangler::demangleImplResultConvention(Node::Kind ConvKind) { createNode(Node::Kind::ImplConvention, attr)); } +NodePointer Demangler::demangleImplDifferentiability() { + // Empty string represents default differentiability. + const char *attr = ""; + if (nextIf('w')) + attr = "@noDerivative"; + return createNode(Node::Kind::ImplDifferentiability, attr); +} + NodePointer Demangler::demangleImplFunctionType() { NodePointer type = createNode(Node::Kind::ImplFunctionType); @@ -1817,8 +1825,10 @@ NodePointer Demangler::demangleImplFunctionType() { int NumTypesToAdd = 0; while (NodePointer Param = - demangleImplParamConvention(Node::Kind::ImplParameter)) { + demangleImplParamConvention(Node::Kind::ImplParameter)) { type = addChild(type, Param); + if (NodePointer Diff = demangleImplDifferentiability()) + Param = addChild(Param, Diff); NumTypesToAdd++; } while (NodePointer Result = demangleImplResultConvention( diff --git a/lib/Demangling/NodePrinter.cpp b/lib/Demangling/NodePrinter.cpp index 16f284234b2d1..655cbac01ea24 100644 --- a/lib/Demangling/NodePrinter.cpp +++ b/lib/Demangling/NodePrinter.cpp @@ -394,6 +394,7 @@ class NodePrinter { case Node::Kind::ImplLinear: case Node::Kind::ImplEscaping: case Node::Kind::ImplConvention: + case Node::Kind::ImplDifferentiability: case Node::Kind::ImplFunctionAttribute: case Node::Kind::ImplFunctionType: case Node::Kind::ImplInvocationSubstitutions: @@ -2060,6 +2061,13 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) { case Node::Kind::ImplConvention: Printer << Node->getText(); return nullptr; + case Node::Kind::ImplDifferentiability: + // Skip if text is empty. + if (Node->getText().empty()) + return nullptr; + // Otherwise, print with trailing space. + Printer << Node->getText() << ' '; + return nullptr; case Node::Kind::ImplFunctionAttribute: Printer << Node->getText(); return nullptr; @@ -2072,6 +2080,16 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) { printChildren(Node, " "); return nullptr; case Node::Kind::ImplParameter: + // Children: `convention, differentiability?, type` + // Print convention. + print(Node->getChild(0)); + Printer << " "; + // Print differentiability, if it exists. + if (Node->getNumChildren() == 3) + print(Node->getChild(1)); + // Print type. + print(Node->getLastChild()); + return nullptr; case Node::Kind::ImplResult: printChildren(Node, " "); return nullptr; diff --git a/lib/Demangling/OldRemangler.cpp b/lib/Demangling/OldRemangler.cpp index 4da4cb52a509b..09b1896d9f9e1 100644 --- a/lib/Demangling/OldRemangler.cpp +++ b/lib/Demangling/OldRemangler.cpp @@ -1326,6 +1326,19 @@ void Remangler::mangleImplConvention(Node *node) { } } +void Remangler::mangleImplDifferentiability(Node *node) { + assert(node->getKind() == Node::Kind::ImplDifferentiability); + StringRef text = node->getText(); + // Empty string represents default differentiability. + if (text.empty()) + return; + if (text == "@noDerivative") { + Buffer << 'w'; + return; + } + unreachable("Invalid impl differentiability"); +} + void Remangler::mangleDynamicSelf(Node *node) { Buffer << 'D'; mangleSingleChildNode(node); // type diff --git a/lib/Demangling/Remangler.cpp b/lib/Demangling/Remangler.cpp index 17e8078746e62..4e9907d7d9699 100644 --- a/lib/Demangling/Remangler.cpp +++ b/lib/Demangling/Remangler.cpp @@ -1420,6 +1420,18 @@ void Remangler::mangleImplConvention(Node *node) { Buffer << ConvCh; } +void Remangler::mangleImplDifferentiability(Node *node) { + assert(node->hasText()); + // Empty string represents default differentiability. + if (node->getText().empty()) + return; + char diffChar = llvm::StringSwitch(node->getText()) + .Case("@noDerivative", 'w') + .Default(0); + assert(diffChar && "Invalid impl differentiability"); + Buffer << diffChar; +} + void Remangler::mangleImplFunctionAttribute(Node *node) { unreachable("handled inline"); } @@ -1443,7 +1455,9 @@ void Remangler::mangleImplFunctionType(Node *node) { case Node::Kind::ImplResult: case Node::Kind::ImplYield: case Node::Kind::ImplErrorResult: - mangleChildNode(Child, 1); + // Mangle type. Type should be the last child. + assert(Child->getNumChildren() == 2 || Child->getNumChildren() == 3); + mangle(Child->getLastChild()); break; case Node::Kind::DependentPseudogenericSignature: PseudoGeneric = "P"; @@ -1526,6 +1540,7 @@ void Remangler::mangleImplFunctionType(Node *node) { Buffer << 'Y'; LLVM_FALLTHROUGH; case Node::Kind::ImplParameter: { + // Mangle parameter convention. char ConvCh = llvm::StringSwitch(Child->getFirstChild()->getText()) .Case("@in", 'i') @@ -1540,6 +1555,9 @@ void Remangler::mangleImplFunctionType(Node *node) { .Default(0); assert(ConvCh && "invalid impl parameter convention"); Buffer << ConvCh; + // Mangle parameter differentiability, if it exists. + if (Child->getNumChildren() == 3) + mangleImplDifferentiability(Child->getChild(1)); break; } case Node::Kind::ImplErrorResult: diff --git a/test/AutoDiff/compiler_crashers_fixed/sr12650-noderivative-parameter-type-mangling.swift b/test/AutoDiff/compiler_crashers_fixed/sr12650-noderivative-parameter-type-mangling.swift new file mode 100644 index 0000000000000..2d3083da33ab6 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/sr12650-noderivative-parameter-type-mangling.swift @@ -0,0 +1,37 @@ +// RUN: %target-build-swift -g %s + +// SR-12650: IRGenDebugInfo type reconstruction crash because `@noDerivative` +// parameters are not mangled. + +import _Differentiation +func id(_ x: Float, _ y: Float) -> Float { x } +let transformed: @differentiable (Float, @noDerivative Float) -> Float = id + +// Incorrect reconstructed type for $sS3fIedgyyd_D +// Original type: +// (sil_function_type type=@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float +// (input=struct_type decl=Swift.(file).Float) +// (input=struct_type decl=Swift.(file).Float) +// (result=struct_type decl=Swift.(file).Float) +// (substitution_map generic_signature=) +// (substitution_map generic_signature=)) +// Reconstructed type: +// (sil_function_type type=@differentiable @callee_guaranteed (Float, Float) -> Float +// (input=struct_type decl=Swift.(file).Float) +// (input=struct_type decl=Swift.(file).Float) +// (result=struct_type decl=Swift.(file).Float) +// (substitution_map generic_signature=) +// (substitution_map generic_signature=)) +// Stack dump: +// ... +// 1. Swift version 5.3-dev (LLVM 803d1b184d, Swift 477af9f90d) +// 2. While evaluating request IRGenSourceFileRequest(IR Generation for file "noderiv.swift") +// 0 swift 0x00000001104c7ae8 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 40 +// 1 swift 0x00000001104c6a68 llvm::sys::RunSignalHandlers() + 248 +// 2 swift 0x00000001104c80dd SignalHandler(int) + 285 +// 3 libsystem_platform.dylib 0x00007fff718335fd _sigtramp + 29 +// 4 libsystem_platform.dylib 000000000000000000 _sigtramp + 18446603338611739168 +// 5 libsystem_c.dylib 0x00007fff71709808 abort + 120 +// 6 swift 0x0000000110604152 (anonymous namespace)::IRGenDebugInfoImpl::getOrCreateType(swift::irgen::DebugTypeInfo) (.cold.20) + 146 +// 7 swift 0x000000010c24ab1e (anonymous namespace)::IRGenDebugInfoImpl::getOrCreateType(swift::irgen::DebugTypeInfo) + 3614 +// 8 swift 0x000000010c245437 swift::irgen::IRGenDebugInfo::emitGlobalVariableDeclaration(llvm::GlobalVariable*, llvm::StringRef, llvm::StringRef, swift::irgen::DebugTypeInfo, bool, bool, llvm::Optional) + 167 diff --git a/test/Demangle/Inputs/manglings.txt b/test/Demangle/Inputs/manglings.txt index c987ab733d537..7da56fdad04ab 100644 --- a/test/Demangle/Inputs/manglings.txt +++ b/test/Demangle/Inputs/manglings.txt @@ -357,3 +357,4 @@ $s17property_wrappers10WithTuplesV9fractionsSd_S2dtvpfP --> property wrapper bac $sSo17OS_dispatch_queueC4sync7executeyyyXE_tFTOTA ---> {T:$sSo17OS_dispatch_queueC4sync7executeyyyXE_tFTO} partial apply forwarder for @nonobjc __C.OS_dispatch_queue.sync(execute: () -> ()) -> () $sxq_Idgnr_D ---> @differentiable @callee_guaranteed (@in_guaranteed A) -> (@out B) $sxq_Ilgnr_D ---> @differentiable(linear) @callee_guaranteed (@in_guaranteed A) -> (@out B) +$sS3fIedgyywd_D ---> @escaping @differentiable @callee_guaranteed (@unowned Swift.Float, @unowned @noDerivative Swift.Float) -> (@unowned Swift.Float)