Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 55 additions & 45 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def bodyBuilder(iterVar):
[iterVar], rawIndex).result
cc.StoreOp(castedEle, targetEleAddr)

self.createInvariantForLoop(sourceSize, bodyBuilder)
self.createForLoop(sourceSize, bodyBuilder, invariant=True)
return cc.StdvecInitOp(targetVecTy, targetPtr, length=sourceSize).result

def __insertDbgStmt(self, value, dbgStmt):
Expand Down Expand Up @@ -816,15 +816,16 @@ def checkControlAndTargetTypes(self, controls, targets):
for i, target in enumerate(targets)
]

def createInvariantForLoop(self,
endVal,
bodyBuilder,
startVal=None,
stepVal=None,
isDecrementing=False,
elseStmts=None):
def createForLoop(self,
endVal,
bodyBuilder,
startVal=None,
stepVal=None,
isDecrementing=False,
elseStmts=None,
invariant=False):
"""
Create an invariant loop using the CC dialect.
Create a loop using the CC dialect.
"""
startVal = self.getConstantInt(0) if startVal == None else startVal
stepVal = self.getConstantInt(1) if stepVal == None else stepVal
Expand Down Expand Up @@ -868,7 +869,9 @@ def createInvariantForLoop(self,
cc.ContinueOp(elseBlock.arguments)
self.symbolTable.popScope()

loop.attributes.__setitem__('invariant', UnitAttr.get())
if invariant:
loop.attributes['invariant'] = UnitAttr.get()

return

def __applyQuantumOperation(self, opName, parameters, targets):
Expand All @@ -885,7 +888,7 @@ def bodyBuilder(iterVal):

veqSize = quake.VeqSizeOp(self.getIntegerType(),
quantumValue).result
self.createInvariantForLoop(veqSize, bodyBuilder)
self.createForLoop(veqSize, bodyBuilder, invariant=True)
elif quake.RefType.isinstance(quantumValue.type):
opCtor([], parameters, [], [quantumValue])
else:
Expand Down Expand Up @@ -1719,11 +1722,12 @@ def bodyBuilder(iterVar):
incrementedCounter = arith.AddIOp(loadedCounter, one).result
cc.StoreOp(incrementedCounter, counter)

self.createInvariantForLoop(endVal,
bodyBuilder,
startVal=startVal,
stepVal=stepVal,
isDecrementing=isDecrementing)
self.createForLoop(endVal,
bodyBuilder,
startVal=startVal,
stepVal=stepVal,
isDecrementing=isDecrementing,
invariant=True)

self.pushValue(iterable)
self.pushValue(actualSize)
Expand Down Expand Up @@ -1820,7 +1824,7 @@ def bodyBuilder(iterVar):
DenseI64ArrayAttr.get([1], context=self.ctx)).result
cc.StoreOp(element, eleAddr)

self.createInvariantForLoop(totalSize, bodyBuilder)
self.createForLoop(totalSize, bodyBuilder, invariant=True)
self.pushValue(enumIterable)
self.pushValue(totalSize)
return
Expand Down Expand Up @@ -1929,7 +1933,7 @@ def bodyBuilder(iterVal):

veqSize = quake.VeqSizeOp(self.getIntegerType(),
target).result
self.createInvariantForLoop(veqSize, bodyBuilder)
self.createForLoop(veqSize, bodyBuilder, invariant=True)
return
elif quake.RefType.isinstance(target.type):
opCtor([], [], [], [target], is_adj=True)
Expand Down Expand Up @@ -2010,7 +2014,7 @@ def bodyBuilder(iterVal):

veqSize = quake.VeqSizeOp(self.getIntegerType(),
target).result
self.createInvariantForLoop(veqSize, bodyBuilder)
self.createForLoop(veqSize, bodyBuilder, invariant=True)
return
self.emitFatalError(
'reset quantum operation on incorrect type {}.'.format(
Expand Down Expand Up @@ -2780,7 +2784,7 @@ def bodyBuilder(iterVal):

veqSize = quake.VeqSizeOp(self.getIntegerType(),
target).result
self.createInvariantForLoop(veqSize, bodyBuilder)
self.createForLoop(veqSize, bodyBuilder, invariant=True)
return
elif quake.RefType.isinstance(target.type):
opCtor([], [], [], [target], is_adj=True)
Expand Down Expand Up @@ -2860,7 +2864,7 @@ def bodyBuilder(iterVal):

veqSize = quake.VeqSizeOp(self.getIntegerType(),
target).result
self.createInvariantForLoop(veqSize, bodyBuilder)
self.createForLoop(veqSize, bodyBuilder, invariant=True)
return
elif quake.RefType.isinstance(target.type):
opCtor([], [param], [], [target], is_adj=True)
Expand Down Expand Up @@ -2938,7 +2942,7 @@ def bodyBuilder(iterVal):

veqSize = quake.VeqSizeOp(self.getIntegerType(),
target).result
self.createInvariantForLoop(veqSize, bodyBuilder)
self.createForLoop(veqSize, bodyBuilder, invariant=True)
return
elif quake.RefType.isinstance(target.type):
opCtor([], params, [], [target], is_adj=True)
Expand Down Expand Up @@ -3039,14 +3043,21 @@ def visit_ListComp(self, node):
node.generators[0].iter)
if quake.VeqType.isinstance(
self.symbolTable[node.generators[0].iter.id].type):
# now we know we have `[expr(r) for r in iterable]`
# reuse what we do in `visit_For()`
forNode = ast.For()
forNode.iter = node.generators[0].iter
forNode.target = node.generators[0].target
forNode.body = [node.elt]
forNode.orelse = []
self.visit_For(forNode)
iterable = self.symbolTable[node.generators[0].iter.id]
totalSize = quake.VeqSizeOp(self.getIntegerType(),
iterable).result

def bodyBuilder(iterVar):
self.symbolTable.pushScope()
q = quake.ExtractRefOp(self.getRefType(),
iterable,
-1,
index=iterVar).result
self.symbolTable[node.generators[0].target.id] = q
self.visit(node.elt)
self.symbolTable.popScope()

self.createForLoop(totalSize, bodyBuilder, invariant=True)
return

# General case of
Expand Down Expand Up @@ -3115,7 +3126,7 @@ def bodyBuilder(iterVar):
cc.StoreOp(result, listValueAddr)
self.symbolTable.popScope()

self.createInvariantForLoop(iterableSize, bodyBuilder)
self.createForLoop(iterableSize, bodyBuilder, invariant=True)
self.pushValue(
cc.StdvecInitOp(cc.StdvecType.get(listComputePtrTy),
listValue,
Expand Down Expand Up @@ -3462,12 +3473,12 @@ def bodyBuilder(iterVar):
[self.visit(b) for b in node.body]
self.symbolTable.popScope()

self.createInvariantForLoop(endVal,
bodyBuilder,
startVal=startVal,
stepVal=stepVal,
isDecrementing=isDecrementing,
elseStmts=node.orelse)
self.createForLoop(endVal,
bodyBuilder,
startVal=startVal,
stepVal=stepVal,
isDecrementing=isDecrementing,
elseStmts=node.orelse)

return

Expand Down Expand Up @@ -3536,9 +3547,9 @@ def bodyBuilder(iterVar):
[self.visit(b) for b in node.body]
self.symbolTable.popScope()

self.createInvariantForLoop(totalSize,
bodyBuilder,
elseStmts=node.orelse)
self.createForLoop(totalSize,
bodyBuilder,
elseStmts=node.orelse)
return

self.visit(node.iter)
Expand Down Expand Up @@ -3656,9 +3667,7 @@ def bodyBuilder(iterVar):
[self.visit(b) for b in node.body]
self.symbolTable.popScope()

self.createInvariantForLoop(totalSize,
bodyBuilder,
elseStmts=node.orelse)
self.createForLoop(totalSize, bodyBuilder, elseStmts=node.orelse)

def visit_While(self, node):
"""
Expand Down Expand Up @@ -3903,8 +3912,9 @@ def check_element(idx):
current = cc.LoadOp(accumulator).result
cc.StoreOp(arith.OrIOp(current, cmp_result.result), accumulator)

self.createInvariantForLoop(self.__get_vector_size(right_val),
check_element)
self.createForLoop(self.__get_vector_size(right_val),
check_element,
invariant=True)

final_result = cc.LoadOp(accumulator).result
if isinstance(op, ast.NotIn):
Expand Down
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_break.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ def kernel(x: float):
# CHECK: ^bb0(%[[VAL_21:.*]]: i64):
# CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_22]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ def kernel(x: float):
# CHECK: ^bb0(%[[VAL_23:.*]]: i64):
# CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_24]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_decrementing_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ def test(q: int, p: int):
# CHECK: ^bb0(%[[VAL_13:.*]]: i64):
# CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : i64
# CHECK: cc.continue %[[VAL_14]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
4 changes: 2 additions & 2 deletions python/tests/mlir/ast_elif.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list
# CHECK: ^bb0(%[[VAL_25:.*]]: i64):
# CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_5]] : i64
# CHECK: cc.continue %[[VAL_26]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }

Expand Down Expand Up @@ -103,6 +103,6 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list
# CHECK: ^bb0(%[[VAL_21:.*]]: i64):
# CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : i64
# CHECK: cc.continue %[[VAL_22]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_for_stdvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list
# CHECK: ^bb0(%[[VAL_17:.*]]: i64):
# CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_18]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_iterate_loop_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ def kernel(x: float):
# CHECK: ^bb0(%[[VAL_26:.*]]: i64):
# CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_27]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_list_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ def kernel():
# CHECK: ^bb0(%[[VAL_26:.*]]: i64):
# CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_0]] : i64
# CHECK: cc.continue %[[VAL_27]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_list_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ def oracle(register: cudaq.qview, auxillary_qubit: cudaq.qubit,
# CHECK: ^bb0(%[[VAL_16:.*]]: i64):
# CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_17]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/ast_qreg_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def slice():
# CHECK: ^bb0(%[[VAL_43:.*]]: i64):
# CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_44]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: %[[VAL_45:.*]] = quake.extract_ref %[[VAL_7]][3] : (!quake.veq<4>) -> !quake.ref
# CHECK: quake.rz (%[[VAL_4]]) %[[VAL_45]] : (f64, !quake.ref) -> ()
# CHECK: return
Expand Down
2 changes: 1 addition & 1 deletion python/tests/mlir/bug_1777.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test():
# CHECK: ^bb0(%[[VAL_18:.*]]: i64):
# CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_19]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: %[[VAL_20:.*]] = cc.load %[[VAL_6]] : !cc.ptr<i1>
# CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_3]] : i1
# CHECK: cc.if(%[[VAL_21]]) {
Expand Down
4 changes: 2 additions & 2 deletions python/tests/mlir/ghz.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def ghz(N: int):
# CHECK: ^bb0(%[[VAL_16:.*]]: i64):
# CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_17]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }

Expand Down Expand Up @@ -91,6 +91,6 @@ def simple(numQubits: int):
# CHECK: ^bb0(%[[VAL_19:.*]]: i64):
# CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : i64
# CHECK: cc.continue %[[VAL_20]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
2 changes: 1 addition & 1 deletion python/tests/mlir/invalid_subrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def bar():
# CHECK: ^bb0(%[[VAL_17:.*]]: i64):
# CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_18]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }

Expand Down
6 changes: 3 additions & 3 deletions python/tests/mlir/qft.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def iqft(qubits: cudaq.qview):
# CHECK: ^bb0(%[[VAL_19:.*]]: i64):
# CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_20]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: %[[VAL_21:.*]] = cc.load %[[VAL_7]] : !cc.ptr<i64>
# CHECK: %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_3]] : i64
# CHECK: %[[VAL_23:.*]] = cc.loop while ((%[[VAL_24:.*]] = %[[VAL_4]]) -> (i64)) {
Expand Down Expand Up @@ -91,13 +91,13 @@ def iqft(qubits: cudaq.qview):
# CHECK: ^bb0(%[[VAL_41:.*]]: i64):
# CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_41]], %[[VAL_2]] : i64
# CHECK: cc.continue %[[VAL_42]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: cc.continue %[[VAL_26]] : i64
# CHECK: } step {
# CHECK: ^bb0(%[[VAL_43:.*]]: i64):
# CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : i64
# CHECK: cc.continue %[[VAL_44]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: %[[VAL_45:.*]] = cc.load %[[VAL_7]] : !cc.ptr<i64>
# CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_45]], %[[VAL_3]] : i64
# CHECK: %[[VAL_47:.*]] = quake.extract_ref %[[VAL_0]]{{\[}}%[[VAL_46]]] : (!quake.veq<?>, i64) -> !quake.ref
Expand Down
2 changes: 1 addition & 1 deletion python/tests/mlir/qreg_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ def foo(N: int):
# CHECK: ^bb0(%[[VAL_12:.*]]: i64):
# CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_13]] : i64
# CHECK: } {invariant}
# CHECK: }
# CHECK: return
# CHECK: }
Loading