
Add rewrite to merge multiple SVD Ops with different settings #769

Merged
18 commits merged into pymc-devs:main on Jun 28, 2024

Conversation

@HangenYuu (Contributor) commented May 14, 2024

Description

When there are two or more SVD Ops with the same inputs in a graph, differing only by compute_uv, compute_uv = False should be changed to True everywhere. This allows pytensor to see that these outputs are equivalent and re-use them, rather than computing the decomposition multiple times.

Related Issue

  • Add rewrite to merge multiple SVD Ops with different settings #732

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@HangenYuu (Contributor, Author) commented May 14, 2024

The PR is still a draft right now. I have added a minimally modified copy of local_det_chol from tensor/rewriting/linalg.py to the same module. I have the following questions:

  1. Am I using the APIs correctly to access and/or modify the argument/attribute of an Op?
  2. I have been tweaking a small example involving computing gradients w.r.t. the input a to check the effect of the rewrite:
import pytensor
import pytensor.tensor as pt
import numpy as np
from pytensor.tensor.type import matrix
from pytensor.tensor.linalg import svd

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
J, updates = pytensor.scan(lambda i, s, a_pt: pt.grad(s[i], a_pt), sequences=pt.arange(s.shape[0]), non_sequences=[s, a_pt])
f = pytensor.function([a_pt], J, updates=updates)
e = pytensor.graph.fg.FunctionGraph([a_pt], [J], clone=False)

which produces a graph for f with 2 SVDs differing only in compute_uv, as required.

[screenshot: symbolic graph with two SVD nodes]

However, the graph of e after rewriting contains only 1 SVD, so the effect is masked.

[screenshot: rewritten graph with a single SVD node]

Tweaking the example either ended up in the same situation or led to TypeError: Cost must be a scalar, e.g. with this Hessian example:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.type import matrix
from pytensor.tensor.linalg import svd

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
gy = pt.grad(pt.sum(s), a_pt)
H, updates = pytensor.scan(lambda i, gy, a_pt: pt.grad(gy[i], a_pt), sequences=pt.arange(gy.shape[0]), non_sequences=[gy, a_pt])
f = pytensor.function([a_pt], H, updates=updates)
e = pytensor.graph.fg.FunctionGraph([a_pt], [H], clone=False)

Do you have a suggestion for a small example to test the rewrite? It can later be reused for unit testing.

if svd_count > 1 and compute_uv:
    for cl in not_compute_uv_svd_list:
        cl.op.core_op.compute_uv = True
    return [cl.outputs[0] for cl in not_compute_uv_svd_list]
Member

I think changing properties of the op in place might lead to problems...

This rewrite function should run for each SVD node, so maybe it is easier to just locate an existing compute_uv = True node, and return that as the replacement for each compute_uv = False node?

So something like:

  • If compute_uv is True, return and do nothing.
  • Check if there is a compute_uv = True node in the graph with the same input. If not, return and do nothing.
  • Return the existing output of that node as the replacement for the current compute_uv = False node.

I wonder, though, if there could be bad interactions somewhere if there is a rewrite that replaces compute_uv = False nodes if they are not used? We don't want to run into any infinite cycles...
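
A minimal sketch of that strategy (hypothetical helper name, not the PR's final rewrite; for simplicity it checks for plain SVD ops, though as discussed later in this thread the graph actually holds Blockwise(SVD) nodes):

from pytensor.tensor.nlinalg import SVD


def find_compute_uv_replacement(fgraph, node):
    """Sketch: for a compute_uv=False SVD node, return the s output of an
    existing compute_uv=True SVD on the same input, or None if there is none."""
    if node.op.compute_uv:
        return None  # nothing to do for compute_uv=True nodes
    (x,) = node.inputs
    for client, _ in fgraph.clients[x]:  # all consumers of the same input
        if client == "output":  # graph outputs appear as the string "output"
            continue
        if isinstance(client.op, SVD) and client.op.compute_uv:
            return client.outputs[1]  # compute_uv=True outputs are [u, s, v]
    return None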

Member

@ricardoV94 Do you know if there are any problems that could happen if a rewrite returns an existing variable instead of a new one?

Member

I think there will be a problem only when a rewrite tries to replace a variable by another that depends on the original variable.

Member

And yes, we shouldn't modify the properties in place. We should replace the smaller Op with the bigger one; just make sure the smaller one is not in the ancestors of the bigger one.

Member

Otherwise, creating a new SVD should be simple: just call the user-facing constructor with the specific flags.
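
For instance, a minimal sketch (x here is just a stand-in for the rewritten node's input):

import pytensor.tensor as pt
from pytensor.tensor.linalg import svd

x = pt.matrix("x")  # stand-in for the input of the node being rewritten
s_only = svd(x, full_matrices=False, compute_uv=False)  # fresh SVD with the desired flags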

@HangenYuu (Contributor, Author) commented May 16, 2024

Sorry, I seem to have dumped information carelessly. The gist was:

  1. I updated the code logic to be a node rewriter.
  2. The rewrite is registered properly in optdb. However, I am having trouble coming up with a test case that shows the effect of the rewrite. Perhaps @jessegrabowski can provide the original use case that led you to open issue Add rewrite to merge multiple SVD Ops with different settings #732?

@jessegrabowski (Member) commented May 19, 2024

It will arise in gradient graphs. For example, you can just do:

X = pt.dmatrix('X')
s = pt.linalg.svd(X, compute_uv=False)
g = pt.grad(s.sum(), X)

The graph for g will re-compute the SVD of X during the backward pass with compute_uv = True, because we require the matrices U and V to compute the gradient of s with respect to X. Pytensor then won't be able to see that these two computations are the same, and will end up computing the SVD twice.

@HangenYuu (Contributor, Author) commented May 19, 2024

a_pt = matrix("a")
s = svd(a_pt, full_matrices=False, compute_uv=False)
gs = pt.grad(pt.sum(s), a_pt)
f = pytensor.function([a_pt], gs)
e = pytensor.graph.fg.FunctionGraph([a_pt], [gs], clone=False)

Thank you. I indeed received a graph for gs and e with 2 different SVD Ops:

[screenshot: graph with two SVD nodes]

But for f, I receive a graph with just a single SVD (which seems to have been rewritten already with compute_uv=True):

[screenshot: graph with a single SVD node]

f's rewritten graph is the one used in the calculation if I run f([[1, 2], [3, 4]]). Does this already satisfy your end goal?

@HangenYuu (Contributor, Author) commented May 19, 2024

This is f's summary profile:

Function profiling
==================
  Message: /tmp/ipykernel_1282122/871230895.py:10
  Time in 1 calls to Function.__call__: 3.448710e-02s
  Time in Function.vm.__call__: 0.03426380921155214s (99.353%)
  Time in thunks: 0.03424406051635742s (99.295%)
  Total compilation time: 4.109558e-02s
    Number of Apply nodes: 2
    PyTensor rewrite time: 2.893809e-02s
       PyTensor validate time: 2.457825e-04s
    PyTensor Linker time (includes C, CUDA code generation/compiling): 0.00876139895990491s
       C-cache preloading 5.506449e-03s
       Import time 8.061258e-04s
       Node make_thunk time 1.967770e-03s
           Node Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2) time 1.942240e-03s
           Node SVD{full_matrices=False, compute_uv=True}(a) time 1.436425e-05s

Time in all call to pytensor.grad() 1.036228e-02s
Time since pytensor import 2.774s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  99.8%    99.8%       0.034s       3.42e-02s     Py       1       1   pytensor.tensor.nlinalg.SVD
   0.2%   100.0%       0.000s       6.60e-05s     C        1       1   pytensor.tensor.blas.Dot22
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  99.8%    99.8%       0.034s       3.42e-02s     Py       1        1   SVD{full_matrices=False, compute_uv=True}
   0.2%   100.0%       0.000s       6.60e-05s     C        1        1   Dot22
   ... (remaining 0 Ops account for   0.00%(0.00s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  99.8%    99.8%       0.034s       3.42e-02s      1     0   SVD{full_matrices=False, compute_uv=True}(a)
   0.2%   100.0%       0.000s       6.60e-05s      1     1   Dot22(SVD{full_matrices=False, compute_uv=True}.0, SVD{full_matrices=False, compute_uv=True}.2)
   ... (remaining 0 Apply instances account for 0.00%(0.00s) of the runtime)

Member

pytensor.dprint may be an easier way to introspect the graphs.
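
For example, a quick way to print a compiled function's graph as text (same toy graph as above):

import pytensor
import pytensor.tensor as pt

a = pt.matrix("a")
s = pt.linalg.svd(a, full_matrices=False, compute_uv=False)
g = pt.grad(pt.sum(s), a)
f = pytensor.function([a], g)
pytensor.dprint(f)  # one line per Apply node, e.g. SVD{full_matrices=..., compute_uv=...}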

codecov bot commented May 18, 2024

Codecov Report

Attention: Patch coverage is 84.00000% with 4 lines in your changes missing coverage. Please review.

Project coverage is 80.98%. Comparing base (8c157a2) to head (3ba3ba4).
Report is 228 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/tensor/rewriting/linalg.py | 84.00% | 2 Missing and 2 partials ⚠️

@@            Coverage Diff             @@
##             main     #769      +/-   ##
==========================================
+ Coverage   80.85%   80.98%   +0.13%     
==========================================
  Files         162      169       +7     
  Lines       47016    46985      -31     
  Branches    11501    11494       -7     
==========================================
+ Hits        38014    38052      +38     
+ Misses       6751     6719      -32     
+ Partials     2251     2214      -37     
Files with missing lines | Coverage Δ
pytensor/tensor/rewriting/linalg.py | 88.02% <84.00%> (-0.67%) ⬇️

... and 55 files with indirect coverage changes

(x,) = node.inputs
compute_uv = False

for cl, _ in fgraph.clients[x]:
Member

You have to be careful because, if the output of the SVD is an output of the function, one of the clients will be the string "output" and the call cl.op will fail.
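
A small demonstration of that pitfall:

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
s = pt.linalg.svd(x, full_matrices=False, compute_uv=False)
fg = pytensor.graph.fg.FunctionGraph([x], [s], clone=False)
# The client entry for a graph output is the literal string "output", not an
# Apply node, so code that unconditionally calls cl.op on every client will raise.
print(fg.clients[s])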


for cl, _ in fgraph.clients[x]:
    if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
        if (not compute_uv) and cl.op.core_op.compute_uv:
Member

I don't think you need that first check?

Suggested change:

-        if (not compute_uv) and cl.op.core_op.compute_uv:
+        if cl.op.core_op.compute_uv:


for cl, _ in fgraph.clients[x]:
    if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
        if (not compute_uv) and cl.op.core_op.compute_uv:
Member

You should check whether the u/v outputs of this node are actually used (i.e., whether they have clients of their own). If not, they are useless and the rewrite shouldn't happen. In fact, this or another rewrite should change the flag from True to False for those nodes.
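
A sketch of that check, as a fragment from inside the rewrite (node is assumed to be a compute_uv=True SVD node, whose outputs are [u, s, v]):

# An output is unused iff its client list is empty.
u_used = len(fgraph.clients[node.outputs[0]]) > 0
v_used = len(fgraph.clients[node.outputs[2]]) > 0
if not (u_used or v_used):
    # u and v are never consumed: this node is a candidate for being
    # rewritten to compute_uv=False.
    ...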

@ricardoV94 (Member) commented May 21, 2024

I would break this rewrite into different logical parts:

  1. Find all SVD clients of the same input X.
  2. Check if any of them has compute_uv=True outputs that are actually being used (i.e., have clients of their own).
  3. If compute_uv is ever needed/used, replace any variable coming out of an SVD with compute_uv == False by one coming out of an SVD with compute_uv == True. You can return a dictionary of replacements {var_from_svd_without_uv: var_from_svd_with_uv, ...}. You should never have to create a new SVD for this case, because compute_uv can only ever be needed if at least one of the nodes already has it set to True and is using those variables elsewhere in the graph.
  4. If compute_uv is never needed, replace any variable with compute_uv == True by one of the existing ones with compute_uv == False. If there is no replacement, you can create a brand-new SVD operation.

Comment on lines 385 to 428


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SVD])
def local_svd_uv_simplify(fgraph, node):
    """If we have more than one `SVD` `Op`s and at least one has keyword argument
    `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere
    and allow `pytensor` to re-use the decomposition outputs instead of recomputing.
    """
    (x,) = node.inputs

    if node.compute_uv:
        # compute_uv=True returns [u, s, v].
        # If at least u or v is used, there is no need to rewrite this node.
        if (
            fgraph.clients[node.outputs[0]] is not None
            or fgraph.clients[node.outputs[2]] is not None
        ):
            return

        # Else, replace the s of this node with the s of an SVD Op that has compute_uv=False.
        # First, iterate to see if there is an SVD Op that can be reused.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if not cl.op.core_op.compute_uv:
                    return {fgraph.clients[node.outputs[1]]: cl.outputs[0]}

        # If no SVD is reusable, return a new one.
        return [svd(x, full_matrices=node.full_matrices, compute_uv=False)]

    else:
        # compute_uv=False returns [s].
        # We want to rewrite if there is another SVD with compute_uv=True.
        # In that case, just reuse the `s` from the one with compute_uv=True.
        for cl, _ in fgraph.clients[x]:
            if cl == "output":
                continue
            if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
                if cl.op.core_op.compute_uv:
                    return [cl.outputs[1]]
Contributor Author

Thanks @ricardoV94. My understanding is as follows: an SVD with compute_uv == False returns [s], while one with compute_uv == True returns [u, s, v]. We want to rewrite when there are 2 SVD Ops in the graph using the same input with different compute_uv values. Take the specific example of 2 SVD Ops: svd_f, which returns [s_f], and svd_t, which returns [u_t, s_t, v_t]. Based on whether at least one of u_t or v_t is used (since we still have to calculate both even if we use just one of them in subsequent calculations), one of the following rewrites can happen:

  • Case 1: If at least one of u_t or v_t is used: return [s_t] in place of [s_f].
  • Case 2: Else: return [s_f] in place of [s_t].
  • Case 3: Additionally, if there is just one SVD Op with compute_uv == True but neither u nor v is used, it must be substituted with a new SVD Op with compute_uv == False.

@ricardoV94 (Member) commented May 22, 2024

Yup, that's it! When you write down the updated rewrite, feel free to add comments with as much explanation as you gave here!

@ricardoV94 (Member) commented May 22, 2024

There could also be some weird cases with three SVDs: one with compute_uv and full_matrices whose u/v actually aren't used, and one with compute_uv and not full_matrices whose u/v are used (or vice versa). In that case we could replace one with the other, but perhaps that's too unlikely to worry about, and I don't see ignoring it causing any bug. I am just raising attention to it so we don't accidentally rewrite full matrices into non-full matrices that are actually used.

Contributor Author

For this one: is return {fgraph.clients[node.outputs[1]]: cl.outputs[0]} the correct syntax?

Member

Yup, that tells PyTensor to replace the key variable with the value variable.

@HangenYuu HangenYuu marked this pull request as ready for review May 23, 2024 01:50
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD):
if cl.op.core_op.compute_uv:
Member

We only want to do this if that other node is actually using the u/v outputs. If not, we would actually want to replace that node with this one.

Contributor Author

That would be taken care of by the first half of the rewrite when it is that node's turn. As this is a local rewrite applied to every SVD node, each node will have its turn.

Member

Even if you don't want to handle that other node, there's no reason to rewrite this node into it. In general, it's better to do as few rewrites as possible, since every time a rewrite succeeds all other candidate rewrites are rerun (until an equilibrium is achieved and nothing changes anymore).

Member

On second thought, I like your eager approach better; it's more readable. Since SVDs are rare, we don't need to over-optimize.

HangenYuu and others added 2 commits May 25, 2024 08:59
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@HangenYuu (Contributor, Author)

[screenshot: passing test run]

The tests run successfully.

@HangenYuu (Contributor, Author)

I will be slower for the next two weeks. I am house-hunting right now, which should be over by then. I didn't expect it to resemble wedding preparation like this, but it is what it is. As for the changes you suggested, @ricardoV94, I will make them in a free slot tomorrow.

@ricardoV94 (Member)

No worries and best of luck!

@HangenYuu (Contributor, Author)

Thanks @ricardoV94 for your patience.

Quick update: I added your suggestions. The tests are not passing right now; I am looking into it. It seems that the rewrite does not happen for the second case:

=================================== FAILURES ===================================
______________________________ test_svd_uv_merge _______________________________

    def test_svd_uv_merge():
        a = matrix("a")
        s_1 = svd(a, full_matrices=False, compute_uv=False)
        _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
        _, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
        u_4, s_4, v_4 = svd(a, full_matrices=False, compute_uv=True)
        # `grad` will introduce an SVD Op with compute_uv=True.
        # full_matrices=True is not supported for the grad of svd.
        gs = pt.grad(pt.sum(s_1), a)
    
        # 1. compute_uv=False needs rewriting with compute_uv=True
        f_1 = pytensor.function([a], gs)
        nodes = f_1.maker.fgraph.apply_nodes
        svd_counter = 0
        for node in nodes:
            if isinstance(node.op, SVD):
                assert node.op.compute_uv
                svd_counter += 1
        assert svd_counter == 1
    
        # 2. compute_uv=True needs rewriting with compute=False, reuse node
        f_2 = pytensor.function([a], [s_1, s_2])
        nodes = f_2.maker.fgraph.apply_nodes
        svd_counter = 0
        for node in nodes:
            if isinstance(node.op, SVD):
>               assert not node.op.compute_uv
E               assert not True
E                +  where True = SVD(full_matrices=False,compute_uv=True).compute_uv
E                +    where SVD(full_matrices=False,compute_uv=True) = SVD{full_matrices=False, compute_uv=True}(a).op

@HangenYuu (Contributor, Author)

I fixed a minor logic error: if u or v is not used in subsequent calculations, the client will be an empty list, which is still not None, so checking the length is the correct logic. For test case 2, which is not passing, I want to ask if the node rewrite happens for one node at a time, or runs for each node in parallel? I suspect the test failed because the 2 SVD Ops were swapped with each other, so the final graph looks exactly the same.
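
As a fragment, the corrected check looks like this (fgraph.clients[var] is a possibly-empty list, never None):

# Before (wrong): `fgraph.clients[...] is not None` is always True.
# After (right): test whether the client list is non-empty.
if len(fgraph.clients[node.outputs[0]]) > 0 or len(fgraph.clients[node.outputs[2]]) > 0:
    return  # u or v is actually used; keep this compute_uv=True node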

@ricardoV94 (Member) commented Jun 7, 2024

I want to ask if the node rewrite happens for one node at a time, or run for each node in parallel?

It does one node at a time, but keeps applying all rewrites in the database to all compatible nodes, until no further changes take place (that is, until an equilibrium is achieved)

@HangenYuu (Contributor, Author) commented Jun 11, 2024

I tried commenting out my added code in the file and reinstalling with pip install -e . What I saw was the same test result: only the first test case passes. I suspect that my rewrite was not registered properly in optdb. I assumed that it would be added automatically after I wrote the code in the rewriting folder, but it seems I was wrong. May I ask for the correct steps to register a new node_rewriter? The current tutorial in our documentation does not apply here.

@ricardoV94 (Member) left a comment

This looks pretty good, just some small questions regarding the test cases

for node in nodes:
    if isinstance(node.op, SVD):
        assert not node.op.compute_uv
        assert node.op.full_matrices
Member

Same here: there's no point worrying about whether we keep the same full_matrices or not, since it plays no role when we don't compute u and v (right?)

Contributor Author

I check the full_matrices parameter to make sure that the rewrite indeed reuses the Op.

@ricardoV94 (Member) commented Jun 11, 2024

@HangenYuu I think you're registering the rewrites correctly, but you're tracking the wrong Op, SVD. When users call the svd helper, it actually returns a Blockwise(SVD), so the rewrite should target Blockwise instead (and check at the top that the Blockwise wraps an SVD).

There's no way to track the Blockwise(SVD) directly because we create those dynamically in the helper function:

def svd(a, full_matrices: bool = True, compute_uv: bool = True):
    """
    This function performs the SVD on CPU.

    Parameters
    ----------
    full_matrices : bool, optional
        If True (default), u and v have the shapes (M, M) and (N, N),
        respectively.
        Otherwise, the shapes are (M, K) and (K, N), respectively,
        where K = min(M, N).
    compute_uv : bool, optional
        Whether or not to compute u and v in addition to s.
        True by default.

    Returns
    -------
    U, V, D : matrices
    """
    return Blockwise(SVD(full_matrices, compute_uv))(a)

Alternatively, we could create all the permutation versions in advance, return those from the helper function, and then track those directly, since there are only 3 of them.*

* When compute_uv=False, full_matrices is redundant; we could always coerce one of those combinations into the other.
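
A sketch of what targeting Blockwise looks like (same registration pattern as the draft above; the body is elided):

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD


@node_rewriter([Blockwise])  # track Blockwise, since svd() builds Blockwise(SVD)(a)
def local_svd_uv_simplify(fgraph, node):
    if not isinstance(node.op.core_op, SVD):
        return None  # some other Blockwise; not ours to rewrite
    ...  # rest of the rewrite logic as in the draft above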

@HangenYuu (Contributor, Author) left a comment

After sitting on it for a while, it passes the tests now. Changing to tracking Blockwise works.

[screenshot: passing test run]

@HangenYuu (Contributor, Author)

@ricardoV94 Sorry for pinging you if you are busy. The tests pass now. Can you review the changes to see if anything else needs modifying, or whether the PR can be merged?

@ricardoV94 (Member)

@HangenYuu I tested without the explicit "remove" and it seems to work. I also did a tiny refactor to reduce indentation, and pushed the commit just now. I think it's ready to merge!

@ricardoV94 ricardoV94 added the graph rewriting, linalg (Linear algebra), and enhancement (New feature or request) labels Jun 28, 2024
@ricardoV94 ricardoV94 merged commit 920b409 into pymc-devs:main Jun 28, 2024
56 of 57 checks passed
@ricardoV94 (Member)

Thanks @HangenYuu

@HangenYuu HangenYuu deleted the svd_graph_rewrite branch July 3, 2024 07:44
Labels
enhancement (New feature or request), graph rewriting, linalg (Linear algebra)

Projects
None yet

Development
Successfully merging this pull request may close these issues: Add rewrite to merge multiple SVD Ops with different settings

4 participants