Skip to content

Commit

Permalink
refactor to make sigmatch use LayeredIdTable for bindings (#24216)
Browse files Browse the repository at this point in the history
split from #24198

This is a required refactor for the only good solution I've been able to
think of for #4858 etc. Explanation:

---

`sigmatch` currently [disables
bindings](https://github.com/nim-lang/Nim/blob/d6a71a10671b66ee4f5be09f99234b3d834e7fce/compiler/sigmatch.nim#L1956)
(except for binding to other generic parameters) when matching against
constraints of generic parameters. This is so when the constraint is a
general metatype like `seq`, the type matching will not treat all
following uses of `seq` as the type matched against that generic
parameter.

However to solve #4858 etc we need to bind `or` types with a conversion
match to the type they are supposed to be converted to (i.e. matching
`int literal(123)` against `int8 | int16` should bind `int8`[^1], not
`int`). The generic parameter constraint binding needs some way to keep
track of this so that matching `int literal(123)` against `T: int8 |
int16` also binds `T` to `int8`[^1].

The only good way to do this IMO is to generate a new "binding context"
when matching against constraints, then binding the generic param to
what the constraint was bound to in that context (in #24198 this is
restricted to just `or` types & concrete types with convertible matches,
it doesn't work in general).

---

`semtypinst` already does something similar for bindings of generic
invocations using `LayeredIdTable`, so `LayeredIdTable` is now split
into its own module and used in `sigmatch` for type bindings as well,
rather than a single-layer `TypeMapping`. Other modules which act on
`sigmatch`'s binding map are also updated to use this type instead.

The type is also made into an `object` type rather than a `ref object`
to reduce the pointer indirection when embedding it inside
`TCandidate`/`TReplTypeVars`, but only on arc/orc since there are some
weird aliasing bugs on refc/markAndSweep that cause a segfault when
setting a layer to its previous layer. If we want we can also just
remove the conditional compilation altogether and always use `ref
object` at the cost of some performance.

[^1]: `int8` binding here and not `int16` might seem weird, since they
match equally well. But we need to resolve the ambiguity here, in #24012
I tested disallowing ambiguities like this and it broke many packages
that tries to match int literals to things like `int16 | uint16` or
`int8 | int16`. Instead of making these packages stop working I think
it's better we resolve the ambiguity with a rule like "the earliest `or`
branch with the best match, matches". This is the rule used in #24198.
  • Loading branch information
metagn authored Oct 6, 2024
1 parent aa605da commit cad8726
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 85 deletions.
12 changes: 6 additions & 6 deletions compiler/concepts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
## for details. Note this is a first implementation and only the "Concept matching"
## section has been implemented.

import ast, astalgo, semdata, lookups, lineinfos, idents, msgs, renderer, types
import ast, semdata, lookups, lineinfos, idents, msgs, renderer, types, layeredtable

import std/intsets

Expand Down Expand Up @@ -309,7 +309,7 @@ proc conceptMatchNode(c: PContext; n: PNode; m: var MatchCon): bool =
# error was reported earlier.
result = false

proc conceptMatch*(c: PContext; concpt, arg: PType; bindings: var TypeMapping; invocation: PType): bool =
proc conceptMatch*(c: PContext; concpt, arg: PType; bindings: var LayeredIdTable; invocation: PType): bool =
## Entry point from sigmatch. 'concpt' is the concept we try to match (here still a PType but
## we extract its AST via 'concpt.n.lastSon'). 'arg' is the type that might fulfill the
## concept's requirements. If so, we return true and fill the 'bindings' with pairs of
Expand All @@ -328,16 +328,16 @@ proc conceptMatch*(c: PContext; concpt, arg: PType; bindings: var TypeMapping; i
dest = existingBinding(m, dest)
if dest == nil or dest.kind != tyGenericParam: break
if dest != nil:
bindings.idTablePut(a, dest)
bindings.put(a, dest)
when logBindings: echo "A bind ", a, " ", dest
else:
bindings.idTablePut(a, b)
bindings.put(a, b)
when logBindings: echo "B bind ", a, " ", b
# we have a match, so bind 'arg' itself to 'concpt':
bindings.idTablePut(concpt, arg)
bindings.put(concpt, arg)
# invocation != nil means we have a non-atomic concept:
if invocation != nil and arg.kind == tyGenericInst and invocation.kidsLen == arg.kidsLen-1:
# bind even more generic parameters
assert invocation.kind == tyGenericInvocation
for i in FirstGenericParamAt ..< invocation.kidsLen:
bindings.idTablePut(invocation[i], arg[i])
bindings.put(invocation[i], arg[i])
82 changes: 82 additions & 0 deletions compiler/layeredtable.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import std/tables
import ast

type
LayeredIdTableObj* {.acyclic.} = object
## stack of type binding contexts implemented as a linked list
topLayer*: TypeMapping
## the mappings on the current layer
nextLayer*: ref LayeredIdTableObj
## the parent type binding context, possibly `nil`
previousLen*: int
## total length of the bindings up to the parent layer,
## used to track if new bindings were added

const useRef = not defined(gcDestructors)
# implementation detail, only arc/orc doesn't cause issues when
# using LayeredIdTable as an object and not a ref

when useRef:
type LayeredIdTable* = ref LayeredIdTableObj
else:
type LayeredIdTable* = LayeredIdTableObj

proc initLayeredTypeMap*(pt: sink TypeMapping = initTypeMapping()): LayeredIdTable =
result = LayeredIdTable(topLayer: pt, nextLayer: nil)

proc shallowCopy*(pt: LayeredIdTable): LayeredIdTable {.inline.} =
## copies only the type bindings of the current layer, but not any parent layers,
## useful for write-only bindings
result = LayeredIdTable(topLayer: pt.topLayer, nextLayer: pt.nextLayer, previousLen: pt.previousLen)

proc currentLen*(pt: LayeredIdTable): int =
## the sum of the cached total binding count of the parents and
## the current binding count, just used to track if bindings were added
pt.previousLen + pt.topLayer.len

proc newTypeMapLayer*(pt: LayeredIdTable): LayeredIdTable =
result = LayeredIdTable(topLayer: initTable[ItemId, PType](), previousLen: pt.currentLen)
when useRef:
result.nextLayer = pt
else:
new(result.nextLayer)
result.nextLayer[] = pt

proc setToPreviousLayer*(pt: var LayeredIdTable) {.inline.} =
when useRef:
pt = pt.nextLayer
else:
when defined(gcDestructors):
pt = pt.nextLayer[]
else:
# workaround refc
let tmp = pt.nextLayer[]
pt = tmp

proc lookup(typeMap: ref LayeredIdTableObj, key: ItemId): PType =
result = nil
var tm = typeMap
while tm != nil:
result = getOrDefault(tm.topLayer, key)
if result != nil: return
tm = tm.nextLayer

template lookup*(typeMap: ref LayeredIdTableObj, key: PType): PType =
## recursively looks up binding of `key` in all parent layers
lookup(typeMap, key.itemId)

when not useRef:
proc lookup(typeMap: LayeredIdTableObj, key: ItemId): PType {.inline.} =
result = getOrDefault(typeMap.topLayer, key)
if result == nil and typeMap.nextLayer != nil:
result = lookup(typeMap.nextLayer, key)

template lookup*(typeMap: LayeredIdTableObj, key: PType): PType =
lookup(typeMap, key.itemId)

proc put(typeMap: var LayeredIdTable, key: ItemId, value: PType) {.inline.} =
typeMap.topLayer[key] = value

template put*(typeMap: var LayeredIdTable, key, value: PType) =
## binds `key` to `value` only in current layer
put(typeMap, key.itemId, value)
6 changes: 3 additions & 3 deletions compiler/sem.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import
evaltempl, patterns, parampatterns, sempass2, linter, semmacrosanity,
lowerings, plugins/active, lineinfos, int128,
isolation_check, typeallowed, modulegraphs, enumtostr, concepts, astmsgs,
extccomp
extccomp, layeredtable

import vtables
import std/[strtabs, math, tables, intsets, strutils, packedsets]
Expand Down Expand Up @@ -478,15 +478,15 @@ proc semAfterMacroCall(c: PContext, call, macroResult: PNode,
# e.g. template foo(T: typedesc): seq[T]
# We will instantiate the return type here, because
# we now know the supplied arguments
var paramTypes = initTypeMapping()
var paramTypes = initLayeredTypeMap()
for param, value in genericParamsInMacroCall(s, call):
var givenType = value.typ
# the sym nodes used for the supplied generic arguments for
# templates and macros leave type nil so regular sem can handle it
# in this case, get the type directly from the sym
if givenType == nil and value.kind == nkSym and value.sym.typ != nil:
givenType = value.sym.typ
idTablePut(paramTypes, param.typ, givenType)
put(paramTypes, param.typ, givenType)

retType = generateTypeInstance(c, paramTypes,
macroResult.info, retType)
Expand Down
4 changes: 2 additions & 2 deletions compiler/semcall.nim
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ proc inheritBindings(c: PContext, x: var TCandidate, expectedType: PType) =
if t[i] == nil or u[i] == nil: return
stackPut(t[i], u[i])
of tyGenericParam:
let prebound = x.bindings.idTableGet(t)
let prebound = x.bindings.lookup(t)
if prebound != nil:
continue # Skip param, already bound

Expand All @@ -769,7 +769,7 @@ proc inheritBindings(c: PContext, x: var TCandidate, expectedType: PType) =
discard
# update bindings
for i in 0 ..< flatUnbound.len():
x.bindings.idTablePut(flatUnbound[i], flatBound[i])
x.bindings.put(flatUnbound[i], flatBound[i])

proc semResolvedCall(c: PContext, x: var TCandidate,
n: PNode, flags: TExprFlags;
Expand Down
10 changes: 5 additions & 5 deletions compiler/semdata.nim
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ when defined(nimPreviewSlimSystem):
import std/assertions

import
options, ast, astalgo, msgs, idents, renderer,
magicsys, vmdef, modulegraphs, lineinfos, pathutils
options, ast, msgs, idents, renderer,
magicsys, vmdef, modulegraphs, lineinfos, pathutils, layeredtable

import ic / ic

Expand Down Expand Up @@ -136,10 +136,10 @@ type
semOverloadedCall*: proc (c: PContext, n, nOrig: PNode,
filter: TSymKinds, flags: TExprFlags, expectedType: PType = nil): PNode {.nimcall.}
semTypeNode*: proc(c: PContext, n: PNode, prev: PType): PType {.nimcall.}
semInferredLambda*: proc(c: PContext, pt: Table[ItemId, PType], n: PNode): PNode
semGenerateInstance*: proc (c: PContext, fn: PSym, pt: Table[ItemId, PType],
semInferredLambda*: proc(c: PContext, pt: LayeredIdTable, n: PNode): PNode
semGenerateInstance*: proc (c: PContext, fn: PSym, pt: LayeredIdTable,
info: TLineInfo): PSym
instantiateOnlyProcType*: proc (c: PContext, pt: TypeMapping,
instantiateOnlyProcType*: proc (c: PContext, pt: LayeredIdTable,
prc: PSym, info: TLineInfo): PType
# used by sigmatch for explicit generic instantiations
includedFiles*: IntSet # used to detect recursive include files
Expand Down
4 changes: 2 additions & 2 deletions compiler/semexprs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2538,8 +2538,8 @@ proc instantiateCreateFlowVarCall(c: PContext; t: PType;
let sym = magicsys.getCompilerProc(c.graph, "nimCreateFlowVar")
if sym == nil:
localError(c.config, info, "system needs: nimCreateFlowVar")
var bindings = initTypeMapping()
bindings.idTablePut(sym.ast[genericParamsPos][0].typ, t)
var bindings = initLayeredTypeMap()
bindings.put(sym.ast[genericParamsPos][0].typ, t)
result = c.semGenerateInstance(c, sym, bindings, info)
# since it's an instantiation, we unmark it as a compilerproc. Otherwise
# codegen would fail:
Expand Down
12 changes: 6 additions & 6 deletions compiler/seminst.nim
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ proc pushProcCon*(c: PContext; owner: PSym) =
const
errCannotInstantiateX = "cannot instantiate: '$1'"

iterator instantiateGenericParamList(c: PContext, n: PNode, pt: TypeMapping): PSym =
iterator instantiateGenericParamList(c: PContext, n: PNode, pt: LayeredIdTable): PSym =
internalAssert c.config, n.kind == nkGenericParams
for a in n.items:
internalAssert c.config, a.kind == nkSym
Expand All @@ -43,7 +43,7 @@ iterator instantiateGenericParamList(c: PContext, n: PNode, pt: TypeMapping): PS
let symKind = if q.typ.kind == tyStatic: skConst else: skType
var s = newSym(symKind, q.name, c.idgen, getCurrOwner(c), q.info)
s.flags.incl {sfUsed, sfFromGeneric}
var t = idTableGet(pt, q.typ)
var t = lookup(pt, q.typ)
if t == nil:
if tfRetType in q.typ.flags:
# keep the generic type and allow the return type to be bound
Expand Down Expand Up @@ -220,7 +220,7 @@ proc referencesAnotherParam(n: PNode, p: PSym): bool =
if referencesAnotherParam(n[i], p): return true
return false

proc instantiateProcType(c: PContext, pt: TypeMapping,
proc instantiateProcType(c: PContext, pt: LayeredIdTable,
prc: PSym, info: TLineInfo) =
# XXX: Instantiates a generic proc signature, while at the same
# time adding the instantiated proc params into the current scope.
Expand All @@ -237,7 +237,7 @@ proc instantiateProcType(c: PContext, pt: TypeMapping,
# will need to use openScope, addDecl, etc.
#addDecl(c, prc)
pushInfoContext(c.config, info)
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(c, typeMap, info, nil)
var result = instCopyType(cl, prc.typ)
let originalParams = result.n
Expand Down Expand Up @@ -324,7 +324,7 @@ proc instantiateProcType(c: PContext, pt: TypeMapping,
prc.typ = result
popInfoContext(c.config)

proc instantiateOnlyProcType(c: PContext, pt: TypeMapping, prc: PSym, info: TLineInfo): PType =
proc instantiateOnlyProcType(c: PContext, pt: LayeredIdTable, prc: PSym, info: TLineInfo): PType =
# instantiates only the type of a given proc symbol
# used by sigmatch for explicit generics
# wouldn't be needed if sigmatch could handle complex cases,
Expand Down Expand Up @@ -360,7 +360,7 @@ proc getLocalPassC(c: PContext, s: PSym): string =
for p in n:
extractPassc(p)

proc generateInstance(c: PContext, fn: PSym, pt: TypeMapping,
proc generateInstance(c: PContext, fn: PSym, pt: LayeredIdTable,
info: TLineInfo): PSym =
## Generates a new instance of a generic procedure.
## The `pt` parameter is a type-unsafe mapping table used to link generic
Expand Down
2 changes: 1 addition & 1 deletion compiler/semstmts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1969,7 +1969,7 @@ proc semProcAnnotation(c: PContext, prc: PNode;

return result

proc semInferredLambda(c: PContext, pt: TypeMapping, n: PNode): PNode =
proc semInferredLambda(c: PContext, pt: LayeredIdTable, n: PNode): PNode =
## used for resolving 'auto' in lambdas based on their callsite
var n = n
let original = n[namePos].sym
Expand Down
43 changes: 12 additions & 31 deletions compiler/semtypinst.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import std / tables

import ast, astalgo, msgs, types, magicsys, semdata, renderer, options,
lineinfos, modulegraphs
lineinfos, modulegraphs, layeredtable

when defined(nimPreviewSlimSystem):
import std/assertions
Expand Down Expand Up @@ -65,10 +65,6 @@ proc cacheTypeInst(c: PContext; inst: PType) =
addToGenericCache(c, gt.sym, inst)

type
LayeredIdTable* {.acyclic.} = ref object
topLayer*: TypeMapping
nextLayer*: LayeredIdTable

TReplTypeVars* = object
c*: PContext
typeMap*: LayeredIdTable # map PType to PType
Expand All @@ -88,23 +84,8 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType
proc replaceTypeVarsS(cl: var TReplTypeVars, s: PSym, t: PType): PSym
proc replaceTypeVarsN*(cl: var TReplTypeVars, n: PNode; start=0; expectedType: PType = nil): PNode

proc initLayeredTypeMap*(pt: sink TypeMapping): LayeredIdTable =
result = LayeredIdTable()
result.topLayer = pt

proc newTypeMapLayer*(cl: var TReplTypeVars): LayeredIdTable =
result = LayeredIdTable(nextLayer: cl.typeMap, topLayer: initTable[ItemId, PType]())

proc lookup(typeMap: LayeredIdTable, key: PType): PType =
result = nil
var tm = typeMap
while tm != nil:
result = getOrDefault(tm.topLayer, key.itemId)
if result != nil: return
tm = tm.nextLayer

template put(typeMap: LayeredIdTable, key, value: PType) =
typeMap.topLayer[key.itemId] = value
result = newTypeMapLayer(cl.typeMap)

template checkMetaInvariants(cl: TReplTypeVars, t: PType) = # noop code
when false:
Expand Down Expand Up @@ -500,7 +481,7 @@ proc handleGenericInvocation(cl: var TReplTypeVars, t: PType): PType =
newbody.flags = newbody.flags + (t.flags + body.flags - tfInstClearedFlags)
result.flags = result.flags + newbody.flags - tfInstClearedFlags

cl.typeMap = cl.typeMap.nextLayer
setToPreviousLayer(cl.typeMap)

# This is actually wrong: tgeneric_closure fails with this line:
#newbody.callConv = body.callConv
Expand Down Expand Up @@ -791,19 +772,19 @@ proc initTypeVars*(p: PContext, typeMap: LayeredIdTable, info: TLineInfo;
localCache: initTypeMapping(), typeMap: typeMap,
info: info, c: p, owner: owner)

proc replaceTypesInBody*(p: PContext, pt: TypeMapping, n: PNode;
proc replaceTypesInBody*(p: PContext, pt: LayeredIdTable, n: PNode;
owner: PSym, allowMetaTypes = false,
fromStaticExpr = false, expectedType: PType = nil): PNode =
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, n.info, owner)
cl.allowMetaTypes = allowMetaTypes
pushInfoContext(p.config, n.info)
result = replaceTypeVarsN(cl, n, expectedType = expectedType)
popInfoContext(p.config)

proc prepareTypesInBody*(p: PContext, pt: TypeMapping, n: PNode;
proc prepareTypesInBody*(p: PContext, pt: LayeredIdTable, n: PNode;
owner: PSym = nil): PNode =
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, n.info, owner)
pushInfoContext(p.config, n.info)
result = prepareNode(cl, n)
Expand Down Expand Up @@ -836,13 +817,13 @@ proc recomputeFieldPositions*(t: PType; obj: PNode; currPosition: var int) =
inc currPosition
else: discard "cannot happen"

proc generateTypeInstance*(p: PContext, pt: TypeMapping, info: TLineInfo,
proc generateTypeInstance*(p: PContext, pt: LayeredIdTable, info: TLineInfo,
t: PType): PType =
# Given `t` like Foo[T]
# pt: Table with type mappings: T -> int
# Desired result: Foo[int]
# proc (x: T = 0); T -> int ----> proc (x: int = 0)
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, info, nil)
pushInfoContext(p.config, info)
result = replaceTypeVarsT(cl, t)
Expand All @@ -852,15 +833,15 @@ proc generateTypeInstance*(p: PContext, pt: TypeMapping, info: TLineInfo,
var position = 0
recomputeFieldPositions(objType, objType.n, position)

proc prepareMetatypeForSigmatch*(p: PContext, pt: TypeMapping, info: TLineInfo,
proc prepareMetatypeForSigmatch*(p: PContext, pt: LayeredIdTable, info: TLineInfo,
t: PType): PType =
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, info, nil)
cl.allowMetaTypes = true
pushInfoContext(p.config, info)
result = replaceTypeVarsT(cl, t)
popInfoContext(p.config)

template generateTypeInstance*(p: PContext, pt: TypeMapping, arg: PNode,
template generateTypeInstance*(p: PContext, pt: LayeredIdTable, arg: PNode,
t: PType): untyped =
generateTypeInstance(p, pt, arg.info, t)
Loading

0 comments on commit cad8726

Please sign in to comment.