Skip to content

Commit

Permalink
add differentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Sep 25, 2024
1 parent 7fa497c commit 1238437
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 8 deletions.
30 changes: 30 additions & 0 deletions cirkit/backend/torch/parameters/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,33 @@ def forward(self, coeff1: Tensor, coeff2: Tensor) -> Tensor:
)

return ifft(spec, n=degp1, dim=-1) # shape (F, K1*K2, dp1).


class TorchPolynomialDifferential(TorchUnaryParameterOp):
    """Torch parameter op taking the derivative of polynomial coefficients.

    The last axis of the input holds the coefficients in ascending power order
    (index ``n`` is the coefficient of ``x^n``), and the derivative of the
    configured ``order`` is taken along that axis.
    """

    def __init__(self, in_shape: Tuple[int, ...], *, num_folds: int = 1, order: int = 1) -> None:
        """Set up the op.

        Args:
            in_shape: The shape of the input parameter, forwarded to the base class.
            num_folds: The number of folds, forwarded to the base class.
            order: The order of differentiation.

        Raises:
            ValueError: If ``order`` is not positive.
        """
        if order <= 0:
            raise ValueError("The order of differentiation must be positive.")
        super().__init__(in_shape, num_folds=num_folds)
        self.order = order

    @property
    def shape(self) -> Tuple[int, ...]:
        """The output shape: the last dim shrinks by ``order``, or collapses to 1."""
        leading_dim = self.in_shapes[0][0]
        degp1 = self.in_shapes[0][1]
        if degp1 > self.order:
            # deg >= order: every differentiation drops one coefficient.
            return (leading_dim, degp1 - self.order)
        # deg < order: the derivative is the constant 0, stored as one coefficient.
        return (leading_dim, 1)

    @classmethod
    def _diff_once(cls, x: Tensor) -> Tensor:
        """Differentiate once: a_n x^n -> n a_n x^(n-1), dropping a_0."""
        # Multipliers 1..deg, cast to the dtype/device of x via .to(x).
        powers = torch.arange(1, x.shape[-1]).to(x)
        return x[..., 1:] * powers

    def forward(self, coeff: Tensor) -> Tensor:
        """Differentiate the coefficients ``order`` times along the last axis.

        Args:
            coeff: The polynomial coefficients, with the degree axis last.

        Returns:
            Tensor: The differentiated coefficients, or a single zero coefficient
                along the last axis when the input has at most ``order`` coefficients.
        """
        if coeff.shape[-1] <= self.order:
            return torch.zeros_like(coeff[..., :1])

        result = coeff
        for _ in range(self.order):
            result = self._diff_once(result)
        return result
9 changes: 9 additions & 0 deletions cirkit/backend/torch/rules/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TorchOuterProductParameter,
TorchOuterSumParameter,
TorchPointerParameter,
TorchPolynomialDifferential,
TorchPolynomialProduct,
TorchReduceLSEParameter,
TorchReduceProductParameter,
Expand Down Expand Up @@ -45,6 +46,7 @@
LogSoftmaxParameter,
OuterProductParameter,
OuterSumParameter,
PolynomialDifferential,
PolynomialProduct,
ReduceLSEParameter,
ReduceProductParameter,
Expand Down Expand Up @@ -242,6 +244,12 @@ def compile_polynomial_product(
return TorchPolynomialProduct(*p.in_shapes)


def compile_polynomial_differential(
    compiler: "TorchCompiler", p: PolynomialDifferential
) -> TorchPolynomialDifferential:
    """Compile a symbolic polynomial differential op into its torch counterpart.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic op, providing the input shapes and the differentiation order.

    Returns:
        TorchPolynomialDifferential: The torch op built from the input shapes and order of p.
    """
    torch_op = TorchPolynomialDifferential(*p.in_shapes, order=p.order)
    return torch_op


DEFAULT_PARAMETER_COMPILATION_RULES: Dict[ParameterCompilationSign, ParameterCompilationFunc] = { # type: ignore[misc]
TensorParameter: compile_tensor_parameter,
ConstantParameter: compile_constant_parameter,
Expand All @@ -268,4 +276,5 @@ def compile_polynomial_product(
GaussianProductStddev: compile_gaussian_product_stddev,
GaussianProductLogPartition: compile_gaussian_product_log_partition,
PolynomialProduct: compile_polynomial_product,
PolynomialDifferential: compile_polynomial_differential,
}
12 changes: 8 additions & 4 deletions cirkit/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ def multiply(self, lhs_cc: CompiledCircuit, rhs_cc: CompiledCircuit) -> Compiled
prod_sc = SF.multiply(lhs_sc, rhs_sc, registry=self._op_registry)
return self.compile(prod_sc)

def differentiate(self, cc: CompiledCircuit, *, order: int = 1) -> CompiledCircuit:
    """Differentiate a compiled circuit and compile the result.

    Args:
        cc: A compiled circuit that was compiled by this pipeline.
        order: The order of differentiation.

    Returns:
        CompiledCircuit: The compiled differential circuit.

    Raises:
        ValueError: If cc is not known in this pipeline, or order is not positive.
    """
    if not self._compiler.has_symbolic(cc):
        raise ValueError("The given compiled circuit is not known in this pipeline")
    if order <= 0:
        raise ValueError("The order of differentiation must be positive.")
    # Differentiation is done on the symbolic circuit, which is then compiled.
    sc = self._compiler.get_symbolic_circuit(cc)
    diff_sc = SF.differentiate(sc, registry=self._op_registry, order=order)
    return self.compile(diff_sc)

def conjugate(self, cc: CompiledCircuit) -> CompiledCircuit:
Expand Down Expand Up @@ -157,10 +159,12 @@ def multiply(
return ctx.multiply(lhs_cc, rhs_cc)


def differentiate(
    cc: CompiledCircuit, ctx: Optional[PipelineContext] = None, *, order: int = 1
) -> CompiledCircuit:
    """Differentiate a compiled circuit using a pipeline context.

    Args:
        cc: The compiled circuit to differentiate.
        ctx: The pipeline context to use. If None, the current context is retrieved
            from _PIPELINE_CONTEXT.
        order: The order of differentiation, forwarded to the context.

    Returns:
        CompiledCircuit: The result of ctx.differentiate.
    """
    if ctx is None:
        ctx = _PIPELINE_CONTEXT.get()
    return ctx.differentiate(cc, order=order)


def conjugate(cc: CompiledCircuit, ctx: Optional[PipelineContext] = None) -> CompiledCircuit:
Expand Down
178 changes: 175 additions & 3 deletions cirkit/symbolic/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import heapq
import itertools
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple, TypeVar

from cirkit.symbolic.circuit import (
Circuit,
Expand Down Expand Up @@ -263,16 +264,187 @@ def multiply(sc1: Circuit, sc2: Circuit, registry: Optional[OperatorRegistry] =
)


def differentiate(sc: Circuit, registry: Optional[OperatorRegistry] = None) -> Circuit:
class _ScopeVarAndBlockAndInputs(NamedTuple):
    """A partial differential of a ProductLayer, labelled by its variable.

    Bundles the id of a scope variable, the circuit block that is the differential
    of the ProductLayer w.r.t. that variable, and the input blocks of that block.
    Used when differentiating a ProductLayer.
    """

    # The id of a variable in the scope of THE ProductLayer.
    scope_var: int
    # The partial diff of THE ProductLayer w.r.t. the var.
    diff_block: CircuitBlock
    # The inputs to the layer of diff_block.
    diff_in_blocks: List[CircuitBlock]


_T = TypeVar("_T") # TODO: for _repeat. move together


# TODO: this can be made public and moved to utils, might be used elsewhere.
def _repeat(iterable: Iterable[_T], /, *, times: int) -> Iterable[_T]:
"""Repeat each element of the given iterable by given times.
The elements are generated lazily. The iterable passed in will be iterated once.
This function differs from itertools in that it repeats an interable instead of only one elem.
Args:
iterable (Iterable[_T]): The iterable to generate the original elements.
times (int): The times to repeat each element.
Returns:
Iterable[_T]: The iterable with repeated elements.
"""
return itertools.chain.from_iterable(itertools.repeat(elem, times=times) for elem in iterable)


def differentiate(
    sc: Circuit, registry: Optional[OperatorRegistry] = None, *, order: int = 1
) -> Circuit:
    """Construct the symbolic circuit of the partial differentials of a circuit.

    For each layer of sc, one differential block is built per pair of (variable in
    the scope of the layer, channel), followed by a copy of the layer itself. The
    outputs of the resulting circuit are, for each output layer of sc, all of its
    differential blocks and then its copy.

    Args:
        sc (Circuit): The symbolic circuit to differentiate.
        registry (Optional[OperatorRegistry]): The registry providing the
            differentiation rules of input layers. If None, the registry in the
            current context is used.
        order (int): The order of differentiation.

    Returns:
        Circuit: The symbolic circuit of the differentials, with the operation
            metadata set to the differentiation of sc with the given order.

    Raises:
        StructuralPropertyError: If sc is not smooth or not decomposable.
        ValueError: If order is not positive.
        AssertionError: If a layer is neither an input, a sum nor a product layer.
    """
    if not sc.is_smooth or not sc.is_decomposable:
        raise StructuralPropertyError(
            "Only smooth and decomposable circuits can be efficiently differentiated."
        )
    if order <= 0:
        raise ValueError("The order of differentiation must be positive.")

    # Use the registry in the current context, if not specified otherwise
    if registry is None:
        registry = OPERATOR_REGISTRY.get()

    # Mapping the symbolic circuit layers with blocks of circuit layers
    layers_to_blocks: Dict[Layer, List[CircuitBlock]] = {}

    # For each new circuit block, keep track of its inputs
    in_blocks: Dict[CircuitBlock, Sequence[CircuitBlock]] = {}

    for sl in sc.topological_ordering():
        # "diff_blocks: List[CircuitBlock]" is the diff of sl wrt each variable and channel in
        # order, and then at the end we append a copy of sl.
        sl_params = {name: p.ref() for name, p in sl.params.items()}

        if isinstance(sl, InputLayer):
            # TODO: no type hint for func, also cannot quick jump in static analysis
            func = registry.retrieve_rule(LayerOperator.DIFFERENTIATION, type(sl))
            diff_blocks = [
                func(sl, var_idx=var_idx, ch_idx=ch_idx, order=order)
                for var_idx, ch_idx in itertools.product(
                    range(len(sl.scope)), range(sc.num_channels)
                )
            ]

        elif isinstance(sl, SumLayer):
            # Zip to transpose the generator into an iterable of length (num_vars * num_chs),
            # corresponding to each var to take diff.
            # Each item is a tuple of length arity, which are inputs to that diff.
            # TODO: typeshed issue?
            # ANNOTATE: zip gives Any when using *iterables.
            zip_blocks_in: Iterable[Tuple[CircuitBlock, ...]] = zip(
                # This is a generator of length arity, corresponding to each input of sl.
                # Each item is a list of length (num_vars * num_chs), corresponding to the diff
                # wrt each variable of that input.
                # NOTE: [-1] is omitted and will be added at the end.
                *(layers_to_blocks[sl_in][:-1] for sl_in in sc.layer_inputs(sl))
            )

            # The layers are the same for all diffs of a SumLayer. We retrieve
            # (num_vars * num_chs) from the length of one input blocks.
            var_ch = len(layers_to_blocks[sc.layer_inputs(sl)[0]][:-1])
            diff_blocks = [
                CircuitBlock.from_layer(type(sl)(**sl.config, **sl_params)) for _ in range(var_ch)
            ]

            # Connect the layers to their inputs, by zipping a length of (num_vars * num_chs).
            in_blocks.update(zip(diff_blocks, zip_blocks_in))

        elif isinstance(sl, ProductLayer):
            # NOTE: Only the outmost level can be a generator, and inner levels must be lists,
            #       otherwise reference to locals will be broken.

            # This is a generator of length arity, corresponding to each input of sl.
            # Each item is a list of length (num_vars * num_chs) of that input, corresponding to
            # the diff wrt each var and ch of that input.
            all_scope_var_diff_block = (
                # Each list is all the diffs of sl wrt each var and each channel in the scope of
                # the cur_layer in the input of sl.
                [
                    # Each named-tuple is a diff of sl and its inputs, where the diff is wrt the
                    # current variable and channel as in the double loop.
                    _ScopeVarAndBlockAndInputs(
                        # Label the named-tuple as the var id in the whole scope, for sorting.
                        scope_var=scope_var,
                        # The layers are the same for all diffs of a ProductLayer.
                        diff_block=CircuitBlock.from_layer(type(sl)(**sl.config, **sl_params)),
                        # The inputs to the diff is the copy of input to sl (retrieved by [-1]),
                        # only with cur_layer replaced by its diff.
                        diff_in_blocks=[
                            diff_cur_layer if sl_in == cur_layer else layers_to_blocks[sl_in][-1]
                            for sl_in in sc.layer_inputs(sl)
                        ],
                    )
                    # Loop over the (num_vars * num_chs) diffs of cur_layer, while also providing
                    # the corresponding scope_var which the current diff is wrt.
                    # We need the scope_var to label and sort the diff layers of sl. We do not
                    # need channel ids because they are always saved densely in order.
                    for scope_var, diff_cur_layer in zip(
                        _repeat(cur_layer.scope, times=sc.num_channels),
                        layers_to_blocks[cur_layer][:-1],
                    )
                ]
                # Loop over each input of sl for the diffs wrt vars and chs in its scope.
                for cur_layer in sc.layer_inputs(sl)
            )

            # NOTE: This relies on the fact that Scope object is iterated in id order.
            # Merge sort the named-tuples by the var id in the scope, so that the diffs are
            # correctly ordered according to the scope of sl.
            sorted_scope_var_diff_block = list(
                heapq.merge(
                    # Unpack the generator into several lists, where each list is the
                    # named-tuples wrt the scope of each input to sl.
                    *all_scope_var_diff_block,
                    key=lambda scope_var_diff_block: scope_var_diff_block.scope_var,
                )
            )

            # Take out the diffs of sl and save them in diff_blocks in correct order.
            diff_blocks = [
                scope_var_diff_block.diff_block
                for scope_var_diff_block in sorted_scope_var_diff_block
            ]

            # Connect the diffs with its corresponding inputs as saved in the named-tuples.
            in_blocks.update(
                (scope_var_diff_block.diff_block, scope_var_diff_block.diff_in_blocks)
                for scope_var_diff_block in sorted_scope_var_diff_block
            )

        else:
            # NOTE: In the above if/elif, we made all conditions explicit to make it more
            #       readable and also easier for static analysis inside the blocks. Yet the
            #       completeness cannot be inferred and is only guaranteed by larger picture.
            #       Also, should anything really go wrong, we will hit this guard statement
            #       instead of going into a wrong branch. An explicit raise is used instead of
            #       `assert False` so that the guard is not stripped under `python -O`.
            raise AssertionError("This should not happen.")

        # Save a copy of sl in the diff circuit and connect inputs. This can be accessed through
        # diff_blocks[-1], as in the [-1] above for ProductLayer.
        diff_blocks.append(CircuitBlock.from_layer(type(sl)(**sl.config, **sl_params)))
        in_blocks[diff_blocks[-1]] = [layers_to_blocks[sl_in][-1] for sl_in in sc.layer_inputs(sl)]

        # Save all the blocks including a copy of sl at [-1] as the diff layers of sl.
        layers_to_blocks[sl] = diff_blocks

    # Construct the differential symbolic circuit and set the differentiation operation metadata.
    # The lists of blocks are flattened with chain, which is linear in the number of blocks.
    return Circuit.from_operation(
        sc.scope,
        sc.num_channels,
        list(itertools.chain.from_iterable(layers_to_blocks.values())),
        in_blocks,  # TODO: in_blocks uses Sequence, and Sequence should work.
        list(itertools.chain.from_iterable(layers_to_blocks[sl] for sl in sc.outputs)),
        operation=CircuitOperation(
            operator=CircuitOperator.DIFFERENTIATION,
            operands=(sc,),
            metadata=dict(order=order),
        ),
    )


def conjugate(
Expand Down
19 changes: 18 additions & 1 deletion cirkit/symbolic/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LogParameter,
OuterSumParameter,
Parameter,
PolynomialDifferential,
PolynomialProduct,
ReduceLSEParameter,
ReduceSumParameter,
Expand Down Expand Up @@ -193,6 +194,22 @@ def multiply_polynomial_layers(sl1: PolynomialLayer, sl2: PolynomialLayer) -> Ci
return CircuitBlock.from_layer(sl)


def differentiate_polynomial_layer(
    sl: PolynomialLayer, *, var_idx: int, ch_idx: int, order: int = 1
) -> CircuitBlock:
    """Build the circuit block of the differential of a polynomial layer.

    Args:
        sl: The polynomial layer to differentiate.
        var_idx: The index of the variable to differentiate against. Must be 0.
        ch_idx: The index of the channel to differentiate against. Must be 0.
        order: The order of differentiation.

    Returns:
        CircuitBlock: A block holding a PolynomialLayer whose coefficients are the
            PolynomialDifferential of the coefficients of sl.

    Raises:
        ValueError: If order is not positive.
    """
    # PolynomialLayer is constructed univariate, but we still take the 2 idx for unified interface
    assert (var_idx, ch_idx) == (0, 0), "This should not happen"
    if order <= 0:
        raise ValueError("The order of differentiation must be positive.")

    diff_op = PolynomialDifferential(sl.coeff.shape, order=order)
    diff_coeff = Parameter.from_unary(diff_op, sl.coeff.ref())
    # The new degree is read off the shape of the differentiated coefficients (deg+1 last).
    diff_layer = PolynomialLayer(
        sl.scope,
        sl.num_output_units,
        sl.num_channels,
        degree=diff_coeff.shape[-1] - 1,
        coeff=diff_coeff,
    )
    return CircuitBlock.from_layer(diff_layer)


def conjugate_categorical_layer(sl: CategoricalLayer) -> CircuitBlock:
logits = sl.logits.ref() if sl.logits is not None else None
probs = sl.probs.ref() if sl.probs is not None else None
Expand Down Expand Up @@ -291,7 +308,7 @@ def __call__(self, *sl: Layer, **kwargs) -> CircuitBlock:

DEFAULT_OPERATOR_RULES: Dict[LayerOperator, List[LayerOperatorFunc]] = {
LayerOperator.INTEGRATION: [integrate_categorical_layer, integrate_gaussian_layer],
LayerOperator.DIFFERENTIATION: [],
LayerOperator.DIFFERENTIATION: [differentiate_polynomial_layer],
LayerOperator.MULTIPLICATION: [
multiply_categorical_layers,
multiply_gaussian_layers,
Expand Down
16 changes: 16 additions & 0 deletions cirkit/symbolic/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,19 @@ def shape(self) -> Tuple[int, ...]:
self.in_shapes[0][0] * self.in_shapes[1][0], # dim Ko
self.in_shapes[0][1] + self.in_shapes[1][1] - 1, # dim deg+1
)


class PolynomialDifferential(UnaryParameterOp):
    """Symbolic parameter op for the derivative of polynomial coefficients."""

    def __init__(self, in_shape: Tuple[int, ...], *, order: int = 1):
        """Set up the op.

        Args:
            in_shape: The shape of the input parameter, forwarded to the base class.
            order: The order of differentiation.

        Raises:
            ValueError: If ``order`` is not positive.
        """
        if order <= 0:
            raise ValueError("The order of differentiation must be positive.")
        super().__init__(in_shape)
        self.order = order

    @property
    def shape(self) -> Tuple[int, ...]:
        """The output shape: the last dim shrinks by ``order``, or collapses to 1."""
        leading_dim = self.in_shapes[0][0]
        degp1 = self.in_shapes[0][1]
        if degp1 > self.order:
            # deg >= order: every differentiation drops one coefficient.
            return (leading_dim, degp1 - self.order)
        # deg < order: the derivative is the constant 0, stored as one coefficient.
        return (leading_dim, 1)

0 comments on commit 1238437

Please sign in to comment.