diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 920318b8bec..51146cc41d3 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -634,7 +634,7 @@ def bodyBuilder(iterVar): [iterVar], rawIndex).result cc.StoreOp(castedEle, targetEleAddr) - self.createInvariantForLoop(sourceSize, bodyBuilder) + self.createForLoop(sourceSize, bodyBuilder, invariant=True) return cc.StdvecInitOp(targetVecTy, targetPtr, length=sourceSize).result def __insertDbgStmt(self, value, dbgStmt): @@ -816,15 +816,16 @@ def checkControlAndTargetTypes(self, controls, targets): for i, target in enumerate(targets) ] - def createInvariantForLoop(self, - endVal, - bodyBuilder, - startVal=None, - stepVal=None, - isDecrementing=False, - elseStmts=None): + def createForLoop(self, + endVal, + bodyBuilder, + startVal=None, + stepVal=None, + isDecrementing=False, + elseStmts=None, + invariant=False): """ - Create an invariant loop using the CC dialect. + Create a loop using the CC dialect. """ startVal = self.getConstantInt(0) if startVal == None else startVal stepVal = self.getConstantInt(1) if stepVal == None else stepVal @@ -868,7 +869,9 @@ def createInvariantForLoop(self, cc.ContinueOp(elseBlock.arguments) self.symbolTable.popScope() - loop.attributes.__setitem__('invariant', UnitAttr.get()) + if invariant: + loop.attributes['invariant'] = UnitAttr.get() + return def __applyQuantumOperation(self, opName, parameters, targets): @@ -885,7 +888,7 @@ def bodyBuilder(iterVal): veqSize = quake.VeqSizeOp(self.getIntegerType(), quantumValue).result - self.createInvariantForLoop(veqSize, bodyBuilder) + self.createForLoop(veqSize, bodyBuilder, invariant=True) elif quake.RefType.isinstance(quantumValue.type): opCtor([], parameters, [], [quantumValue]) else: @@ -1719,11 +1722,12 @@ def bodyBuilder(iterVar): incrementedCounter = arith.AddIOp(loadedCounter, one).result cc.StoreOp(incrementedCounter, counter) - self.createInvariantForLoop(endVal, - bodyBuilder, - startVal=startVal, - stepVal=stepVal, - isDecrementing=isDecrementing) + self.createForLoop(endVal, + bodyBuilder, + startVal=startVal, + stepVal=stepVal, + isDecrementing=isDecrementing, + invariant=True) self.pushValue(iterable) self.pushValue(actualSize) @@ -1820,7 +1824,7 @@ def bodyBuilder(iterVar): DenseI64ArrayAttr.get([1], context=self.ctx)).result cc.StoreOp(element, eleAddr) - self.createInvariantForLoop(totalSize, bodyBuilder) + self.createForLoop(totalSize, bodyBuilder, invariant=True) self.pushValue(enumIterable) self.pushValue(totalSize) return @@ -1929,7 +1933,7 @@ def bodyBuilder(iterVal): veqSize = quake.VeqSizeOp(self.getIntegerType(), target).result - self.createInvariantForLoop(veqSize, bodyBuilder) + self.createForLoop(veqSize, bodyBuilder, invariant=True) return elif quake.RefType.isinstance(target.type): opCtor([], [], [], [target], is_adj=True) @@ -2010,7 +2014,7 @@ def bodyBuilder(iterVal): veqSize = quake.VeqSizeOp(self.getIntegerType(), target).result - self.createInvariantForLoop(veqSize, bodyBuilder) + self.createForLoop(veqSize, bodyBuilder, invariant=True) return self.emitFatalError( 'reset quantum operation on incorrect type {}.'.format( @@ -2780,7 +2784,7 @@ def bodyBuilder(iterVal): veqSize = quake.VeqSizeOp(self.getIntegerType(), target).result - self.createInvariantForLoop(veqSize, bodyBuilder) + self.createForLoop(veqSize, bodyBuilder, invariant=True) return elif quake.RefType.isinstance(target.type): opCtor([], [], [], [target], is_adj=True) @@ -2860,7 +2864,7 @@ def bodyBuilder(iterVal): veqSize = quake.VeqSizeOp(self.getIntegerType(), target).result - self.createInvariantForLoop(veqSize, bodyBuilder) + self.createForLoop(veqSize, bodyBuilder, invariant=True) return elif quake.RefType.isinstance(target.type): opCtor([], [param], [], [target], is_adj=True) @@ -2938,7 +2942,7 @@ def bodyBuilder(iterVal): veqSize = quake.VeqSizeOp(self.getIntegerType(), target).result - self.createInvariantForLoop(veqSize, bodyBuilder) + self.createForLoop(veqSize, bodyBuilder, invariant=True) return elif quake.RefType.isinstance(target.type): opCtor([], params, [], [target], is_adj=True) @@ -3039,14 +3043,21 @@ def visit_ListComp(self, node): node.generators[0].iter) if quake.VeqType.isinstance( self.symbolTable[node.generators[0].iter.id].type): - # now we know we have `[expr(r) for r in iterable]` - # reuse what we do in `visit_For()` - forNode = ast.For() - forNode.iter = node.generators[0].iter - forNode.target = node.generators[0].target - forNode.body = [node.elt] - forNode.orelse = [] - self.visit_For(forNode) + iterable = self.symbolTable[node.generators[0].iter.id] + totalSize = quake.VeqSizeOp(self.getIntegerType(), + iterable).result + + def bodyBuilder(iterVar): + self.symbolTable.pushScope() + q = quake.ExtractRefOp(self.getRefType(), + iterable, + -1, + index=iterVar).result + self.symbolTable[node.generators[0].target.id] = q + self.visit(node.elt) + self.symbolTable.popScope() + + self.createForLoop(totalSize, bodyBuilder, invariant=True) return # General case of @@ -3115,7 +3126,7 @@ def bodyBuilder(iterVar): cc.StoreOp(result, listValueAddr) self.symbolTable.popScope() - self.createInvariantForLoop(iterableSize, bodyBuilder) + self.createForLoop(iterableSize, bodyBuilder, invariant=True) self.pushValue( cc.StdvecInitOp(cc.StdvecType.get(listComputePtrTy), listValue, @@ -3462,12 +3473,12 @@ def bodyBuilder(iterVar): [self.visit(b) for b in node.body] self.symbolTable.popScope() - self.createInvariantForLoop(endVal, - bodyBuilder, - startVal=startVal, - stepVal=stepVal, - isDecrementing=isDecrementing, - elseStmts=node.orelse) + self.createForLoop(endVal, + bodyBuilder, + startVal=startVal, + stepVal=stepVal, + isDecrementing=isDecrementing, + elseStmts=node.orelse) return @@ -3536,9 +3547,9 @@ def bodyBuilder(iterVar): [self.visit(b) for b in node.body] self.symbolTable.popScope() - self.createInvariantForLoop(totalSize, - bodyBuilder, - elseStmts=node.orelse) + self.createForLoop(totalSize, + bodyBuilder, + elseStmts=node.orelse) return self.visit(node.iter) @@ -3656,9 +3667,7 @@ def bodyBuilder(iterVar): [self.visit(b) for b in node.body] self.symbolTable.popScope() - self.createInvariantForLoop(totalSize, - bodyBuilder, - elseStmts=node.orelse) + self.createForLoop(totalSize, bodyBuilder, elseStmts=node.orelse) def visit_While(self, node): """ @@ -3903,8 +3912,9 @@ def check_element(idx): current = cc.LoadOp(accumulator).result cc.StoreOp(arith.OrIOp(current, cmp_result.result), accumulator) - self.createInvariantForLoop(self.__get_vector_size(right_val), - check_element) + self.createForLoop(self.__get_vector_size(right_val), + check_element, + invariant=True) final_result = cc.LoadOp(accumulator).result if isinstance(op, ast.NotIn): diff --git a/python/tests/mlir/ast_break.py b/python/tests/mlir/ast_break.py index a0b36549864..0def54647b1 100644 --- a/python/tests/mlir/ast_break.py +++ b/python/tests/mlir/ast_break.py @@ -60,6 +60,6 @@ def kernel(x: float): # CHECK: ^bb0(%[[VAL_21:.*]]: i64): # CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_22]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_continue.py b/python/tests/mlir/ast_continue.py index 8c1c03b5ea5..48eed4754e9 100644 --- a/python/tests/mlir/ast_continue.py +++ b/python/tests/mlir/ast_continue.py @@ -64,5 +64,5 @@ def kernel(x: float): # CHECK: ^bb0(%[[VAL_23:.*]]: i64): # CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_24]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: } diff --git a/python/tests/mlir/ast_decrementing_range.py b/python/tests/mlir/ast_decrementing_range.py index fa65f383b44..4bbf47c8c6d 100644 --- a/python/tests/mlir/ast_decrementing_range.py +++ b/python/tests/mlir/ast_decrementing_range.py @@ -46,6 +46,6 @@ def test(q: int, p: int): # CHECK: ^bb0(%[[VAL_13:.*]]: i64): # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : i64 # CHECK: cc.continue %[[VAL_14]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_elif.py b/python/tests/mlir/ast_elif.py index ac1703820bb..f781bc6d912 100644 --- a/python/tests/mlir/ast_elif.py +++ b/python/tests/mlir/ast_elif.py @@ -66,7 +66,7 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list # CHECK: ^bb0(%[[VAL_25:.*]]: i64): # CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_5]] : i64 # CHECK: cc.continue %[[VAL_26]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } @@ -103,6 +103,6 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list # CHECK: ^bb0(%[[VAL_21:.*]]: i64): # CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : i64 # CHECK: cc.continue %[[VAL_22]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_for_stdvec.py b/python/tests/mlir/ast_for_stdvec.py index f75a7a9887d..4efc921c92a 100644 --- a/python/tests/mlir/ast_for_stdvec.py +++ b/python/tests/mlir/ast_for_stdvec.py @@ -54,6 +54,6 @@ def cost(thetas: np.ndarray): # can pass 1D ndarray or list # CHECK: ^bb0(%[[VAL_17:.*]]: i64): # CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64 # CHECK: cc.continue %[[VAL_18]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_iterate_loop_init.py b/python/tests/mlir/ast_iterate_loop_init.py index 565f06bfa40..e488601f225 100644 --- a/python/tests/mlir/ast_iterate_loop_init.py +++ b/python/tests/mlir/ast_iterate_loop_init.py @@ -63,6 +63,6 @@ def kernel(x: float): # CHECK: ^bb0(%[[VAL_26:.*]]: i64): # CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_27]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_list_init.py b/python/tests/mlir/ast_list_init.py index 9901d6557a6..06506a6fc69 100644 --- a/python/tests/mlir/ast_list_init.py +++ b/python/tests/mlir/ast_list_init.py @@ -63,6 +63,6 @@ def kernel(): # CHECK: ^bb0(%[[VAL_26:.*]]: i64): # CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_0]] : i64 # CHECK: cc.continue %[[VAL_27]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_list_int.py b/python/tests/mlir/ast_list_int.py index edcaa428219..c67a5d2c202 100644 --- a/python/tests/mlir/ast_list_int.py +++ b/python/tests/mlir/ast_list_int.py @@ -52,6 +52,6 @@ def oracle(register: cudaq.qview, auxillary_qubit: cudaq.qubit, # CHECK: ^bb0(%[[VAL_16:.*]]: i64): # CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_17]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/ast_qreg_slice.py b/python/tests/mlir/ast_qreg_slice.py index f24a9152742..96b2112e8d9 100644 --- a/python/tests/mlir/ast_qreg_slice.py +++ b/python/tests/mlir/ast_qreg_slice.py @@ -114,7 +114,7 @@ def slice(): # CHECK: ^bb0(%[[VAL_43:.*]]: i64): # CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_44]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: %[[VAL_45:.*]] = quake.extract_ref %[[VAL_7]][3] : (!quake.veq<4>) -> !quake.ref # CHECK: quake.rz (%[[VAL_4]]) %[[VAL_45]] : (f64, !quake.ref) -> () # CHECK: return diff --git a/python/tests/mlir/bug_1777.py b/python/tests/mlir/bug_1777.py index fd5b2f75385..95c6838eac2 100644 --- a/python/tests/mlir/bug_1777.py +++ b/python/tests/mlir/bug_1777.py @@ -61,7 +61,7 @@ def test(): # CHECK: ^bb0(%[[VAL_18:.*]]: i64): # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_1]] : i64 # CHECK: cc.continue %[[VAL_19]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: %[[VAL_20:.*]] = cc.load %[[VAL_6]] : !cc.ptr # CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_3]] : i1 # CHECK: cc.if(%[[VAL_21]]) { diff --git a/python/tests/mlir/ghz.py b/python/tests/mlir/ghz.py index ac6dd80116c..ff6dc96d41a 100644 --- a/python/tests/mlir/ghz.py +++ b/python/tests/mlir/ghz.py @@ -48,7 +48,7 @@ def ghz(N: int): # CHECK: ^bb0(%[[VAL_16:.*]]: i64): # CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_1]] : i64 # CHECK: cc.continue %[[VAL_17]] : i64 - # CHECK: } {invariant} + # CHECK: } # CHECK: return # CHECK: } @@ -91,6 +91,6 @@ def simple(numQubits: int): # CHECK: ^bb0(%[[VAL_19:.*]]: i64): # CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : i64 # CHECK: cc.continue %[[VAL_20]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/invalid_subrange.py b/python/tests/mlir/invalid_subrange.py index fbf1e856e39..98305ae77e8 100644 --- a/python/tests/mlir/invalid_subrange.py +++ b/python/tests/mlir/invalid_subrange.py @@ -62,7 +62,7 @@ def bar(): # CHECK: ^bb0(%[[VAL_17:.*]]: i64): # CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64 # CHECK: cc.continue %[[VAL_18]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: } diff --git a/python/tests/mlir/qft.py b/python/tests/mlir/qft.py index 4565611d9a6..e8c83ffa990 100644 --- a/python/tests/mlir/qft.py +++ b/python/tests/mlir/qft.py @@ -60,7 +60,7 @@ def iqft(qubits: cudaq.qview): # CHECK: ^bb0(%[[VAL_19:.*]]: i64): # CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_20]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: %[[VAL_21:.*]] = cc.load %[[VAL_7]] : !cc.ptr # CHECK: %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_3]] : i64 # CHECK: %[[VAL_23:.*]] = cc.loop while ((%[[VAL_24:.*]] = %[[VAL_4]]) -> (i64)) { @@ -91,13 +91,13 @@ def iqft(qubits: cudaq.qview): # CHECK: ^bb0(%[[VAL_41:.*]]: i64): # CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_41]], %[[VAL_2]] : i64 # CHECK: cc.continue %[[VAL_42]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: cc.continue %[[VAL_26]] : i64 # CHECK: } step { # CHECK: ^bb0(%[[VAL_43:.*]]: i64): # CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : i64 # CHECK: cc.continue %[[VAL_44]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: %[[VAL_45:.*]] = cc.load %[[VAL_7]] : !cc.ptr # CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_45]], %[[VAL_3]] : i64 # CHECK: %[[VAL_47:.*]] = quake.extract_ref %[[VAL_0]]{{\[}}%[[VAL_46]]] : (!quake.veq, i64) -> !quake.ref diff --git a/python/tests/mlir/qreg_iterable.py b/python/tests/mlir/qreg_iterable.py index 8ece9fafe23..390a9ed428a 100644 --- a/python/tests/mlir/qreg_iterable.py +++ b/python/tests/mlir/qreg_iterable.py @@ -43,6 +43,6 @@ def foo(N: int): # CHECK: ^bb0(%[[VAL_12:.*]]: i64): # CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : i64 # CHECK: cc.continue %[[VAL_13]] : i64 -# CHECK: } {invariant} +# CHECK: } # CHECK: return # CHECK: }