Skip to content

Commit

Permalink
fix: type inference for lambdas and their parameters (#1304)
Browse files Browse the repository at this point in the history
### Summary of Changes

The following code led to a stack overflow when inferring the type of
the call marked with the arrow:

```
pipeline p {
    g(() -> "1", (a) -> 2); // <-- here
}


@pure
fun g<T>(
    q2: () -> b: T,
    q1: (a: T) -> (),
)
```

This PR fixes this.
  • Loading branch information
lars-reimann authored Jan 4, 2025
1 parent 5a949c2 commit a9e070f
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 71 deletions.
91 changes: 36 additions & 55 deletions packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ export class SafeDsTypeComputer {
private readonly typeChecker: SafeDsTypeChecker;

/**
* Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters,
* they must only access the type of the containing lambda, if they are not contained in this set themselves.
* Otherwise, this would cause endless recursion.
* Contains all calls for which we currently compute substitutions. This prevents endless recursion, since the
* substitutions of a call depend on the inferred types of their arguments, which may be lambdas. The inferred type
* of a lambda in turn depends on the substitutions of the call it is passed to.
*/
private readonly incompleteLambdaParameters = new Set<SdsParameter>();
private readonly incompleteCalls = new Set<SdsAbstractCall>();
private readonly nodeTypeCache: WorkspaceCache<string, Type>;

constructor(services: SafeDsServices) {
Expand Down Expand Up @@ -301,18 +301,21 @@ export class SafeDsTypeComputer {

// Lambda passed as argument
if (isSdsArgument(containerOfLambda)) {
// Lookup parameter type in lambda unless the lambda is being computed. These contain the correct
// substitutions for type parameters.
if (!this.incompleteLambdaParameters.has(node)) {
return this.computeType(containingCallable);
}

const parameter = this.nodeMapper.argumentToParameter(containerOfLambda);
if (!parameter) {
return UnknownType;
}

return this.computeType(parameter);
let result = this.computeType(parameter);

// Substitute type parameters
const call = AstUtils.getContainerOfType(containerOfLambda, isSdsCall);
if (call) {
const substitutions = this.computeSubstitutionsForCall(call, containerOfLambda.$containerIndex);
result = result.substituteTypeParameters(substitutions);
}

return result;
}

// Lambda passed as default value
Expand Down Expand Up @@ -569,29 +572,16 @@ export class SafeDsTypeComputer {
}

private computeTypeOfLambda(node: SdsLambda): Type {
// Remember lambda parameters
const parameters = getParameters(node);
parameters.forEach((it) => {
this.incompleteLambdaParameters.add(it);
});

const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it)));
const resultEntries = this.buildLambdaResultEntries(node);

const unsubstitutedType = this.factory.createCallableType(
return this.factory.createCallableType(
node,
undefined,
this.factory.createNamedTupleType(...parameterEntries),
this.factory.createNamedTupleType(...resultEntries),
);
const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType);

// Forget lambda parameters
parameters.forEach((it) => {
this.incompleteLambdaParameters.delete(it);
});

return unsubstitutedType.substituteTypeParameters(substitutions);
}

private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry<SdsAbstractResult>[] {
Expand Down Expand Up @@ -843,23 +833,32 @@ export class SafeDsTypeComputer {
/**
* Computes substitutions for the type parameters of a callable in the context of a call.
*
* @param node The call to compute substitutions for.
* @param node
* The call to compute substitutions for.
* @param argumentEndIndex
* The index of the first argument that should not be considered for the computation. If not specified, all
* arguments are considered.
*
* @returns The computed substitutions for the type parameters of the callable.
*/
computeSubstitutionsForCall(node: SdsAbstractCall): TypeParameterSubstitutions {
return this.doComputeSubstitutionsForCall(node);
}

private doComputeSubstitutionsForCall(
computeSubstitutionsForCall(
node: SdsAbstractCall,
precomputedArgumentTypes?: Map<AstNode | undefined, Type>,
argumentEndIndex: number | undefined = undefined,
): TypeParameterSubstitutions {
// Compute substitutions for member access
const substitutionsFromReceiver =
isSdsCall(node) && isSdsMemberAccess(node.receiver)
? this.computeSubstitutionsForMemberAccess(node.receiver)
: NO_SUBSTITUTIONS;

// Check if the call is already being computed
if (this.incompleteCalls.has(node)) {
return substitutionsFromReceiver;
}

// Remember call
this.incompleteCalls.add(node);

// Compute substitutions for arguments
const callable = this.nodeMapper.callToCallable(node);
const typeParameters = getTypeParameters(callable);
Expand All @@ -868,24 +867,22 @@ export class SafeDsTypeComputer {
}

const parameters = getParameters(callable);
const args = getArguments(node);
const args = getArguments(node).slice(0, argumentEndIndex);

const parametersToArguments = this.nodeMapper.parametersToArguments(parameters, args);
const parameterTypesToArgumentTypes: [Type, Type][] = parameters.map((parameter) => {
const argument = parametersToArguments.get(parameter);
return [
this.computeType(parameter.type),
// Use precomputed argument types (lambdas) if available. This prevents infinite recursion.
precomputedArgumentTypes?.get(argument?.value) ??
this.computeType(argument?.value ?? parameter.defaultValue),
];
return [this.computeType(parameter.type), this.computeType(argument?.value ?? parameter.defaultValue)];
});

const substitutionsFromArguments = this.computeSubstitutionsForArguments(
typeParameters,
parameterTypesToArgumentTypes,
);

// Forget call
this.incompleteCalls.delete(node);

return new Map([...substitutionsFromReceiver, ...substitutionsFromArguments]);
}

Expand Down Expand Up @@ -918,22 +915,6 @@ export class SafeDsTypeComputer {
return this.computeSubstitutionsForArguments(ownTypeParameters, ownTypesToOverriddenTypes);
}

private computeSubstitutionsForLambda(node: SdsLambda, unsubstitutedType: Type): TypeParameterSubstitutions {
const containerOfLambda = node.$container;
if (!isSdsArgument(containerOfLambda)) {
return NO_SUBSTITUTIONS;
}

const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall);
if (!containingCall) {
/* c8 ignore next 2 */
return NO_SUBSTITUTIONS;
}

const precomputedArgumentTypes = new Map([[node, unsubstitutedType]]);
return this.doComputeSubstitutionsForCall(containingCall, precomputedArgumentTypes);
}

private computeSubstitutionsForMemberAccess(node: SdsMemberAccess): TypeParameterSubstitutions {
const receiverType = this.computeType(node.receiver);
if (receiverType instanceof ClassType) {
Expand Down
2 changes: 1 addition & 1 deletion packages/safe-ds-lang/src/language/validation/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export const argumentTypesMustMatchParameterTypes = (services: SafeDsServices) =
return;
}

const argumentType = typeComputer.computeType(argument).substituteTypeParameters(substitutions);
const argumentType = typeComputer.computeType(argument);
const parameterType = typeComputer.computeType(parameter).substituteTypeParameters(substitutions);

if (!typeChecker.isSubtypeOf(argumentType, parameterType, { ignoreParameterNames: true })) {
Expand Down
4 changes: 2 additions & 2 deletions packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ describe('getNodeOfType', async () => {

it('should throw if no node is found', async () => {
const code = '';
expect(async () => {
await expect(async () => {
await getNodeOfType(services, code, isSdsClass);
}).rejects.toThrowErrorMatchingSnapshot();
});

it('should throw if not enough nodes are found', async () => {
const code = `class C`;
expect(async () => {
await expect(async () => {
await getNodeOfType(services, code, isSdsClass, 1);
}).rejects.toThrowErrorMatchingSnapshot();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ segment mySegment() {
// $TEST$ serialization literal<1>
myFunction(1, (»p«) {});

// $TEST$ serialization literal<"">
// $TEST$ serialization Nothing
myFunction2((»p«) -> "");
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ segment mySegment() {
// $TEST$ serialization literal<1>
myFunction(1, (»p«) -> "");

// $TEST$ serialization literal<"">
// $TEST$ serialization Nothing
myFunction2((»p«) -> "");
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ class MyClass<T>(param: T) sub MySuperclass<T> {
@Pure fun myMethod(callback: (p: T) -> ())
}

@Pure fun myFunction<T>(p: T, id: (p: T) -> (r: T))
@Pure fun myFunction1<T>(p: T, id: (p: T) -> (r: T))
@Pure fun myFunction2<T>(id: (p: T) -> (r: T))
@Pure fun myFunction3<T>(producer: () -> (r: T), consumer: (p: T) -> ())

segment mySegment() {
// $TEST$ serialization (p: literal<1>) -> (r: literal<1>)
Expand All @@ -22,7 +24,19 @@ segment mySegment() {
}«);

// $TEST$ serialization (p: literal<1>) -> (r: literal<1>)
myFunction(1, »(p) {
myFunction1(1, »(p) {
yield r = p;
}«);

// $TEST$ serialization (p: Nothing) -> (r: literal<1>)
myFunction2(»(p) {
yield r = 1;
}«);

// $TEST$ serialization () -> (r: literal<1>)
// $TEST$ serialization (p: literal<1>) -> ()
myFunction3(
»() { yield r = 1; }«,
»(p) {}«,
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ class MyClass<T>(param: T) sub MySuperclass<T> {
@Pure fun myMethod(callback: (p: T) -> ())
}

@Pure fun myFunction<T>(p: T, id: (p: T) -> (r: T))

segment mySegment() {
// $TEST$ serialization (p: literal<1>) -> (result: literal<1>)
MyClass(1).myMethod(»(p) -> p«);

// $TEST$ serialization (p: literal<1>) -> (result: literal<1>)
MyClass(1).myInheritedMethod(»(p) -> p«);
@Pure fun myFunction1<T>(p: T, id: (p: T) -> (r: T))
@Pure fun myFunction2<T>(id: (p: T) -> (r: T))
@Pure fun myFunction3<T>(producer: () -> (r: T), consumer: (p: T) -> ())

segment mySegment() {
// $TEST$ serialization () -> (result: literal<1>)
// $TEST$ serialization (p: literal<1>) -> (result: literal<1>)
myFunction(1, »(p) -> p«);
myFunction3(
»() -> 1«,
»(p) -> 1«,
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ segment mySegment(
// $TEST$ no error r"Expected type .* but got .*\."
f(»(p) -> p«);

// $TEST$ no error r"Expected type .* but got .*\."
// $TEST$ error "Expected type '(p: literal<1>) -> (r: literal<1>)' but got '(p: Nothing) -> (result: literal<1>)'."
f(»(p) -> 1«);

// $TEST$ no error r"Expected type .* but got .*\."
Expand Down

0 comments on commit a9e070f

Please sign in to comment.