Skip to content

Commit

Permalink
Implement KrylovKit.eigsolve for Tensors (#157)
Browse files Browse the repository at this point in the history
* Add KrylovKit as weakdep and extension

* Add extension TenetKrylovKitExt

* Add KrylovKit tests to integration tests

* Add module semantic

* Format code

* Extend tests

* Return info from KrylovKit

* Refactor code

* Fix permutation test

* Fix test

* Fix unexisting variable name

---------

Co-authored-by: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git@bsc.es>
Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 8, 2024
1 parent 5a09e3d commit 320cb02
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143"
Expand All @@ -38,6 +39,7 @@ TenetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"]
TenetDaggerExt = "Dagger"
TenetFiniteDifferencesExt = "FiniteDifferences"
TenetGraphMakieExt = ["GraphMakie", "Makie"]
TenetKrylovKitExt = ["KrylovKit"]
TenetReactantExt = "Reactant"
TenetQuacExt = "Quac"
TenetYaoExt = "Yao"
Expand All @@ -55,6 +57,7 @@ EinExprs = "0.5, 0.6"
FiniteDifferences = "0.12"
GraphMakie = "0.4,0.5"
Graphs = "1.7"
KrylovKit = "0.8.1"
LinearAlgebra = "1.9"
Makie = "0.18,0.19,0.20, 0.21"
Muscle = "0.2"
Expand Down
98 changes: 98 additions & 0 deletions ext/TenetKrylovKitExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
module TenetKrylovKitExt

using Tenet
using KrylovKit

function eigsolve_prehook_tensor_reshape(A::Tensor, left_inds, right_inds)
left_inds, right_inds = Tenet.factorinds(A, left_inds, right_inds)

# Determine the left and right indices
left_sizes = size.((A,), left_inds)
right_sizes = size.((A,), right_inds)
prod_left_sizes = prod(left_sizes)
prod_right_sizes = prod(right_sizes)

if prod_left_sizes != prod_right_sizes
throw(
ArgumentError("The resulting matrix must be square, but got sizes $prod_left_sizes and $prod_right_sizes.")
)
end

# Permute and reshape the tensor
A = permutedims(A, [left_inds..., right_inds...])
Amat = reshape(parent(A), prod_left_sizes, prod_right_sizes)

return Amat, left_sizes, right_sizes
end

function KrylovKit.eigselector(A::Tensor, T::Type; left_inds=Symbol[], right_inds=Symbol[], kwargs...)
Amat, _, _ = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)
return KrylovKit.eigselector(Amat, T; kwargs...)
end

function KrylovKit.eigsolve(
A::Tensor,
howmany::Int=1,
which::KrylovKit.Selector=:LM,
T::Type=eltype(A);
left_inds=Symbol[],
right_inds=Symbol[],
kwargs...,
)
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, howmany, which; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

function KrylovKit.eigsolve(
f::Tensor, x₀, howmany::Int=1, which::KrylovKit.Selector=:LM; left_inds=Symbol[], right_inds=Symbol[], kwargs...
)
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀, howmany, which; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

"""
KrylovKit.eigsolve(tensor::Tensor; left_inds, right_inds, kwargs...)
Perform eigenvalue decomposition on a tensor.
# Keyword arguments
- `left_inds`: left indices to be used in the eigenvalue decomposition. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the eigenvalue decomposition. Defaults to all indices of `t` except `left_inds`.
"""
function KrylovKit.eigsolve(
A::Tensor,
x₀,
howmany::Int,
which::KrylovKit.Selector,
alg::Algorithm;
left_inds=Symbol[],
right_inds=Symbol[],
kwargs...,
) where {Algorithm<:KrylovKit.Lanczos} # KrylovKit.KrylovAlgorithm}
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds)

# Compute eigenvalues and eigenvectors
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀, howmany, which, alg; kwargs...)

# Tensorify the eigenvectors
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs]

return vals, Avecs, info
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
Expand Down
80 changes: 80 additions & 0 deletions test/integration/KrylovKit_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
@testset "KrylovKit.eigsolve" begin
using Tenet: Tensor
using KrylovKit

A = rand(ComplexF64, 4, 4)
data = (A + A') / 2 # Make it Hermitian
tensor = Tensor(data, (:i, :j))

# Perform eigensolve
vals, vecs, info = eigsolve(tensor; left_inds=[:i], right_inds=[:j])

@test length(vals) == 4
@test length(vecs) == 4

for vec in vecs
@test inds(vec) == [:i]
@test size(vec) == (4,)
end

# throw if index is not present
@test_throws ArgumentError eigsolve(tensor; left_inds=[:z])
@test_throws ArgumentError eigsolve(tensor; right_inds=[:z])

# throw if the resulting matrix is not square
tensor_non_square = Tensor(rand(ComplexF64, 2, 4, 6), (:i, :j, :k))
@test_throws ArgumentError eigsolve(tensor_non_square; left_inds=[:i, :j], right_inds=[:k])
@test_throws ArgumentError eigsolve(tensor_non_square; right_inds=[:j, :k])

# Convert vecs to matrix form for reconstruction
V_matrix = hcat([reshape(parent(vec), :) for vec in vecs]...)
D_matrix = Diagonal(vals)
reconstructed_matrix = V_matrix * D_matrix * inv(V_matrix)

# Ensure the reconstruction is correct
reconstructed_tensor = Tensor(reconstructed_matrix, (:i, :j))
@test isapprox(reconstructed_tensor, tensor)

# Test consistency with permuted tensor
vals_perm, vecs_perm, info = eigsolve(tensor; left_inds=[:j], right_inds=[:i])

@test length(vals_perm) == 4
@test length(vecs_perm) == 4

# Ensure the eigenvalues are the same
@test isapprox(sort(real.(vals)), sort(real.(vals_perm))) && isapprox(sort(imag.(vals)), sort(imag.(vals_perm)))

V_matrix_perm = hcat([reshape(parent(vec), :) for vec in vecs_perm]...)
D_matrix_perm = Diagonal(vals)
reconstructed_matrix_perm = V_matrix_perm * D_matrix_perm * inv(V_matrix_perm)

# Ensure the reconstruction is correct
reconstructed_tensor_perm = Tensor(reconstructed_matrix_perm, (:j, :i))
@test isapprox(reconstructed_tensor_perm, transpose(tensor))

@test parent(reconstructed_tensor) parent(transpose(reconstructed_tensor_perm))

@testset "Lanczos" begin
vals_lanczos, vecs_lanczos = eigsolve(
tensor, rand(ComplexF64, 4), 1, :SR, Lanczos(; krylovdim=2, tol=1e-16); left_inds=[:i], right_inds=[:j]
)

@test length(vals_lanczos) == 1
@test length(vecs_lanczos) == 1

@test minimum(vals) first(vals_lanczos)
end

A = rand(ComplexF64, 4, 4)
data = (A + A') / 2 # Make it Hermitian
tensor = Tensor(reshape(data, 2, 2, 2, 2), (:i, :j, :k, :l))

vals, vecs, info = eigsolve(tensor; left_inds=[:i, :j], right_inds=[:k, :l])

# Convert vecs to matrix form for reconstruction
V_matrix = hcat([reshape(parent(vec), :) for vec in vecs]...)
D_matrix = Diagonal(vals)
reconstructed_matrix = V_matrix * D_matrix * inv(V_matrix)

@test isapprox(reconstructed_matrix, data)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if VERSION >= v"1.10"
# include("integration/BlockArray_test.jl")
include("integration/Dagger_test.jl")
include("integration/Makie_test.jl")
include("integration/KrylovKit_test.jl")
include("integration/Quac_test.jl")
end
end
Expand Down

0 comments on commit 320cb02

Please sign in to comment.