-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add KrylovKit as weakdep and extension * Add extension TenetKrylovKitExt * Add KrylovKit tests to integration tests * Add module semantic * Format code * Extend tests * Return info from KrylovKit * Refactor code * Fix permutation test * Fix test * Fix unexisting variable name --------- Co-authored-by: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git@bsc.es> Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
- Loading branch information
1 parent
5a09e3d
commit 320cb02
Showing
5 changed files
with
183 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
module TenetKrylovKitExt | ||
|
||
using Tenet | ||
using KrylovKit | ||
|
||
function eigsolve_prehook_tensor_reshape(A::Tensor, left_inds, right_inds) | ||
left_inds, right_inds = Tenet.factorinds(A, left_inds, right_inds) | ||
|
||
# Determine the left and right indices | ||
left_sizes = size.((A,), left_inds) | ||
right_sizes = size.((A,), right_inds) | ||
prod_left_sizes = prod(left_sizes) | ||
prod_right_sizes = prod(right_sizes) | ||
|
||
if prod_left_sizes != prod_right_sizes | ||
throw( | ||
ArgumentError("The resulting matrix must be square, but got sizes $prod_left_sizes and $prod_right_sizes.") | ||
) | ||
end | ||
|
||
# Permute and reshape the tensor | ||
A = permutedims(A, [left_inds..., right_inds...]) | ||
Amat = reshape(parent(A), prod_left_sizes, prod_right_sizes) | ||
|
||
return Amat, left_sizes, right_sizes | ||
end | ||
|
||
function KrylovKit.eigselector(A::Tensor, T::Type; left_inds=Symbol[], right_inds=Symbol[], kwargs...) | ||
Amat, _, _ = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds) | ||
return KrylovKit.eigselector(Amat, T; kwargs...) | ||
end | ||
|
||
function KrylovKit.eigsolve( | ||
A::Tensor, | ||
howmany::Int=1, | ||
which::KrylovKit.Selector=:LM, | ||
T::Type=eltype(A); | ||
left_inds=Symbol[], | ||
right_inds=Symbol[], | ||
kwargs..., | ||
) | ||
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds) | ||
|
||
# Compute eigenvalues and eigenvectors | ||
vals, vecs, info = KrylovKit.eigsolve(Amat, howmany, which; kwargs...) | ||
|
||
# Tensorify the eigenvectors | ||
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs] | ||
|
||
return vals, Avecs, info | ||
end | ||
|
||
function KrylovKit.eigsolve( | ||
f::Tensor, x₀, howmany::Int=1, which::KrylovKit.Selector=:LM; left_inds=Symbol[], right_inds=Symbol[], kwargs... | ||
) | ||
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds) | ||
|
||
# Compute eigenvalues and eigenvectors | ||
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀, howmany, which; kwargs...) | ||
|
||
# Tensorify the eigenvectors | ||
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs] | ||
|
||
return vals, Avecs, info | ||
end | ||
|
||
""" | ||
KrylovKit.eigsolve(tensor::Tensor; left_inds, right_inds, kwargs...) | ||
Perform eigenvalue decomposition on a tensor. | ||
# Keyword arguments | ||
- `left_inds`: left indices to be used in the eigenvalue decomposition. Defaults to all indices of `t` except `right_inds`. | ||
- `right_inds`: right indices to be used in the eigenvalue decomposition. Defaults to all indices of `t` except `left_inds`. | ||
""" | ||
function KrylovKit.eigsolve( | ||
A::Tensor, | ||
x₀, | ||
howmany::Int, | ||
which::KrylovKit.Selector, | ||
alg::Algorithm; | ||
left_inds=Symbol[], | ||
right_inds=Symbol[], | ||
kwargs..., | ||
) where {Algorithm<:KrylovKit.Lanczos} # KrylovKit.KrylovAlgorithm} | ||
Amat, left_sizes, right_sizes = eigsolve_prehook_tensor_reshape(A, left_inds, right_inds) | ||
|
||
# Compute eigenvalues and eigenvectors | ||
vals, vecs, info = KrylovKit.eigsolve(Amat, x₀, howmany, which, alg; kwargs...) | ||
|
||
# Tensorify the eigenvectors | ||
Avecs = [Tensor(reshape(vec, left_sizes...), left_inds) for vec in vecs] | ||
|
||
return vals, Avecs, info | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
@testset "KrylovKit.eigsolve" begin | ||
using Tenet: Tensor | ||
using KrylovKit | ||
|
||
A = rand(ComplexF64, 4, 4) | ||
data = (A + A') / 2 # Make it Hermitian | ||
tensor = Tensor(data, (:i, :j)) | ||
|
||
# Perform eigensolve | ||
vals, vecs, info = eigsolve(tensor; left_inds=[:i], right_inds=[:j]) | ||
|
||
@test length(vals) == 4 | ||
@test length(vecs) == 4 | ||
|
||
for vec in vecs | ||
@test inds(vec) == [:i] | ||
@test size(vec) == (4,) | ||
end | ||
|
||
# throw if index is not present | ||
@test_throws ArgumentError eigsolve(tensor; left_inds=[:z]) | ||
@test_throws ArgumentError eigsolve(tensor; right_inds=[:z]) | ||
|
||
# throw if the resulting matrix is not square | ||
tensor_non_square = Tensor(rand(ComplexF64, 2, 4, 6), (:i, :j, :k)) | ||
@test_throws ArgumentError eigsolve(tensor_non_square; left_inds=[:i, :j], right_inds=[:k]) | ||
@test_throws ArgumentError eigsolve(tensor_non_square; right_inds=[:j, :k]) | ||
|
||
# Convert vecs to matrix form for reconstruction | ||
V_matrix = hcat([reshape(parent(vec), :) for vec in vecs]...) | ||
D_matrix = Diagonal(vals) | ||
reconstructed_matrix = V_matrix * D_matrix * inv(V_matrix) | ||
|
||
# Ensure the reconstruction is correct | ||
reconstructed_tensor = Tensor(reconstructed_matrix, (:i, :j)) | ||
@test isapprox(reconstructed_tensor, tensor) | ||
|
||
# Test consistency with permuted tensor | ||
vals_perm, vecs_perm, info = eigsolve(tensor; left_inds=[:j], right_inds=[:i]) | ||
|
||
@test length(vals_perm) == 4 | ||
@test length(vecs_perm) == 4 | ||
|
||
# Ensure the eigenvalues are the same | ||
@test isapprox(sort(real.(vals)), sort(real.(vals_perm))) && isapprox(sort(imag.(vals)), sort(imag.(vals_perm))) | ||
|
||
V_matrix_perm = hcat([reshape(parent(vec), :) for vec in vecs_perm]...) | ||
D_matrix_perm = Diagonal(vals) | ||
reconstructed_matrix_perm = V_matrix_perm * D_matrix_perm * inv(V_matrix_perm) | ||
|
||
# Ensure the reconstruction is correct | ||
reconstructed_tensor_perm = Tensor(reconstructed_matrix_perm, (:j, :i)) | ||
@test isapprox(reconstructed_tensor_perm, transpose(tensor)) | ||
|
||
@test parent(reconstructed_tensor) ≈ parent(transpose(reconstructed_tensor_perm)) | ||
|
||
@testset "Lanczos" begin | ||
vals_lanczos, vecs_lanczos = eigsolve( | ||
tensor, rand(ComplexF64, 4), 1, :SR, Lanczos(; krylovdim=2, tol=1e-16); left_inds=[:i], right_inds=[:j] | ||
) | ||
|
||
@test length(vals_lanczos) == 1 | ||
@test length(vecs_lanczos) == 1 | ||
|
||
@test minimum(vals) ≈ first(vals_lanczos) | ||
end | ||
|
||
A = rand(ComplexF64, 4, 4) | ||
data = (A + A') / 2 # Make it Hermitian | ||
tensor = Tensor(reshape(data, 2, 2, 2, 2), (:i, :j, :k, :l)) | ||
|
||
vals, vecs, info = eigsolve(tensor; left_inds=[:i, :j], right_inds=[:k, :l]) | ||
|
||
# Convert vecs to matrix form for reconstruction | ||
V_matrix = hcat([reshape(parent(vec), :) for vec in vecs]...) | ||
D_matrix = Diagonal(vals) | ||
reconstructed_matrix = V_matrix * D_matrix * inv(V_matrix) | ||
|
||
@test isapprox(reconstructed_matrix, data) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters