Skip to content

Commit

Permalink
Add AdiabaticStateSelector (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesgardner1421 authored Sep 10, 2022
1 parent 01fb132 commit 284612b
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NQCModels"
uuid = "c814dc9f-a51f-4eaf-877f-82eda4edad48"
authors = ["James Gardner <james.gardner1421@gmail.com>"]
version = "0.8.14"
version = "0.8.15"

[deps]
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
Expand Down
5 changes: 4 additions & 1 deletion src/diabatic/DiabaticModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Unitful: @u_str
using UnitfulAtomic: austrip

using Parameters: Parameters
using LinearAlgebra: Hermitian
using LinearAlgebra: LinearAlgebra, Hermitian
using StaticArrays: SMatrix, SVector

"""
Expand Down Expand Up @@ -170,4 +170,7 @@ export FullGaussLegendre
include("anderson_holstein.jl")
export AndersonHolstein

include("adiabatic_state_selector.jl")
export AdiabaticStateSelector

end # module
28 changes: 28 additions & 0 deletions src/diabatic/adiabatic_state_selector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

struct AdiabaticStateSelector{M} <: NQCModels.AdiabaticModels.AdiabaticModel
model::M
state::Int
function AdiabaticStateSelector(model, state)
state < 1 && throw(DomainError(state, "selected state must be greater than 0"))
state > NQCModels.nstates(model) && throw(DomainError(state, "selected state must be less than the total number of states of the diabatic model"))
return new{typeof(model)}(model, state)
end
end

NQCModels.ndofs(model::AdiabaticStateSelector) = NQCModels.ndofs(model.model)

function NQCModels.potential(model::AdiabaticStateSelector, r::AbstractMatrix)
V = NQCModels.potential(model.model, r)
eigenvalues = LinearAlgebra.eigvals(V)
return eigenvalues[model.state]
end

function NQCModels.derivative!(model::AdiabaticStateSelector, output::AbstractMatrix, r::AbstractMatrix)
V = NQCModels.potential(model.model, r)
U = LinearAlgebra.eigvecs(V)
D = NQCModels.derivative(model.model, r)
for I in eachindex(output, D)
output[I] = (U' * D[I] * U)[model.state, model.state]
end
return output
end
35 changes: 2 additions & 33 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,13 @@ using Test
using NQCBase
using NQCModels
using LinearAlgebra
using FiniteDiff
using SafeTestsets

@time @safetestset "Wide band bath discretisations" begin include("wide_band_bath_discretisations.jl") end
@time @safetestset "Anderson Holstein" begin include("anderson_holstein.jl") end
@safetestset "AdiabaticStateSelector" begin include("test_adiabatic_state_selector.jl") end

function finite_difference_gradient(model::NQCModels.AdiabaticModels.AdiabaticModel, R)
f(x) = potential(model, x)
FiniteDiff.finite_difference_gradient(f, R)
end

function finite_difference_gradient(model::NQCModels.DiabaticModels.DiabaticModel, R)
f(x, j, i) = potential(model, x)[j,i]
grad = [Hermitian(zeros(nstates(model), nstates(model))) for _ in CartesianIndices(R)]
for i=1:nstates(model)
for j=1:nstates(model)
gradient = FiniteDiff.finite_difference_gradient(x->f(x,j,i), R)
for k in eachindex(R)
grad[k].data[j,i] = gradient[k]
end
end
end
grad
end

function test_model(model::NQCModels.Model, atoms; rtol=1e-5)
R = rand(ndofs(model), atoms)
D = derivative(model, R)
finite_diff = finite_difference_gradient(model, R)
return isapprox(finite_diff, D, rtol=rtol)
end

function test_model(model::NQCModels.FrictionModels.AdiabaticFrictionModel, atoms)
R = rand(ndofs(model), atoms)
D = derivative(model, R)
friction(model, R)
return finite_difference_gradient(model, R) D
end
include("test_utils.jl")

@testset "Potential abstraction" begin
struct TestModel <: NQCModels.Model end
Expand Down
29 changes: 29 additions & 0 deletions test/test_adiabatic_state_selector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Test
using NQCModels
using LinearAlgebra: eigvals

include("test_utils.jl")

innermodel = DoubleWell()

@testset "State index out of bounds" begin
@test_throws DomainError AdiabaticStateSelector(innermodel, 0)
@test_throws DomainError AdiabaticStateSelector(innermodel, 3)
end

@testset "Potential, state: $selected_state" for selected_state = 1:2
model = AdiabaticStateSelector(innermodel, selected_state)
r = randn(1,1)

diabatic_potential = potential(innermodel, r)
correct_value = eigvals(diabatic_potential)[selected_state]
new_value = potential(model, r)
@test correct_value new_value
end

@testset "Potential, state: $selected_state" for selected_state = 1:2
model = AdiabaticStateSelector(innermodel, selected_state)
@test test_model(model, 1)
end


35 changes: 35 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using FiniteDiff
using NQCModels

function finite_difference_gradient(model::NQCModels.AdiabaticModels.AdiabaticModel, R)
f(x) = potential(model, x)
FiniteDiff.finite_difference_gradient(f, R)
end

function finite_difference_gradient(model::NQCModels.DiabaticModels.DiabaticModel, R)
f(x, j, i) = potential(model, x)[j,i]
grad = [Hermitian(zeros(nstates(model), nstates(model))) for _ in CartesianIndices(R)]
for i=1:nstates(model)
for j=1:nstates(model)
gradient = FiniteDiff.finite_difference_gradient(x->f(x,j,i), R)
for k in eachindex(R)
grad[k].data[j,i] = gradient[k]
end
end
end
grad
end

function test_model(model::NQCModels.Model, atoms; rtol=1e-5)
R = rand(ndofs(model), atoms)
D = derivative(model, R)
finite_diff = finite_difference_gradient(model, R)
return isapprox(finite_diff, D, rtol=rtol)
end

function test_model(model::NQCModels.FrictionModels.AdiabaticFrictionModel, atoms)
R = rand(ndofs(model), atoms)
D = derivative(model, R)
friction(model, R)
return finite_difference_gradient(model, R) D
end

0 comments on commit 284612b

Please sign in to comment.