diff --git a/Project.toml b/Project.toml
index 3c57e214..79d9e5fe 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,10 +6,15 @@ version = "0.3.20"
 [deps]
 ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
+LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+NearestNeighborDescent = "dd2c4c9e-a32f-5b2f-b342-08c2f244fce8"
+NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
+RandomMatrix = "0af1cf96-9b30-454e-9d9e-87908f700846"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [compat]
 ExactOptimalTransport = "0.1, 0.2"
diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl
index bbf0a29a..a7385391 100644
--- a/src/OptimalTransport.jl
+++ b/src/OptimalTransport.jl
@@ -13,6 +13,7 @@ using LinearAlgebra
 using IterativeSolvers
 using LogExpFunctions: LogExpFunctions
 using NNlib: NNlib
+using SparseArrays
 
 export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
 export SinkhornBarycenterGibbs
@@ -39,6 +40,7 @@ include("entropic/sinkhorn_solve.jl")
 
 include("quadratic.jl")
 include("quadratic_newton.jl")
+include("quadratic_newton_symm.jl")
 
 include("dual/entropic_dual.jl")
 
diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl
new file mode 100644
index 00000000..415494d1
--- /dev/null
+++ b/src/quadratic_newton_symm.jl
@@ -0,0 +1,444 @@
+"""
+    SymmetricQuadraticOTNewton{T,K,D} <: QuadraticOT
+
+Semi-smooth Newton algorithm for the symmetric quadratically regularised problem, in
+which a single dual potential `u` is shared by both marginals.
+
+# Fields
+- `θ`: slope factor of the Armijo sufficient-decrease test used in the line search.
+- `κ`: factor by which the step size is multiplied after a failed Armijo test.
+- `δ`: shift added to the diagonal of the Newton system matrix.
+- `armijo_max`: maximum number of step-size reductions per line search.
+"""
+struct SymmetricQuadraticOTNewton{T<:Real,K<:Real,D<:Real} <: QuadraticOT
+    θ::T
+    κ::K
+    δ::D
+    armijo_max::Int
+end
+
+"""
+    SymmetricQuadraticOTNewtonAS{T,K,D,ST,IT,JT} <: QuadraticOT
+
+Active-set variant of the symmetric semi-smooth Newton algorithm. All computations are
+restricted to the stored entries of the sparse matrix `S`.
+
+# Fields
+- `θ`, `κ`, `δ`, `armijo_max`: line-search and regularisation parameters, used exactly
+  as in `SymmetricQuadraticOTNewton`.
+- `S`: sparse matrix whose stored entries define the active set.
+- `I`, `J`: row and column indices of the stored entries of `S`.
+"""
+struct SymmetricQuadraticOTNewtonAS{
+    T<:Real,K<:Real,D<:Real,ST<:AbstractSparseArray,IT,JT
+} <: QuadraticOT
+    θ::T
+    κ::K
+    δ::D
+    armijo_max::Int
+    S::ST
+    I::IT
+    J::JT
+end
+
+# Human-readable names; they are interpolated into the solver's log messages.
+function Base.show(io::IO, ::SymmetricQuadraticOTNewton)
+    return print(io, "Symmetric semi-smooth Newton algorithm")
+end
+
+function Base.show(io::IO, ::SymmetricQuadraticOTNewtonAS)
+    return print(io, "Symmetric semi-smooth Newton algorithm (active set)")
+end
+
+"""
+    SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50)
+
+Create the symmetric semi-smooth Newton algorithm.
+
+# Keywords
+- `θ`: slope factor of the Armijo sufficient-decrease test used in the line search.
+- `κ`: factor by which the step size is multiplied after a failed Armijo test.
+- `δ`: shift added to the diagonal of the Newton system matrix.
+- `armijo_max`: maximum number of step-size reductions per line search.
+"""
+function SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50)
+    return SymmetricQuadraticOTNewton(θ, κ, δ, armijo_max)
+end
+
+"""
+    SymmetricQuadraticOTNewtonAS(S; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50)
+
+Create the active-set variant of the symmetric semi-smooth Newton algorithm.
+
+`S` is a sparse matrix; the row and column indices of its stored entries (as returned by
+`findnz`) are cached in the algorithm. The keyword arguments have the same meaning as for
+`SymmetricQuadraticOTNewton`.
+"""
+function SymmetricQuadraticOTNewtonAS(S; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50)
+    I, J, _ = findnz(S)
+    return SymmetricQuadraticOTNewtonAS(θ, κ, δ, armijo_max, S, I, J)
+end
+
+"""
+    check_convergence(μ, cache::QuadraticOTNewtonCache, convergence_cache, atol, rtol)
+
+Compare the row sums of the plan stored in `cache.γ` with the marginal `μ`.
+
+Return the tuple `(isconverged, norm_diff)`, where `norm_diff` is the maximum norm of
+`sum(γ; dims=2) - μ` and `isconverged` is `true` if `norm_diff` is smaller than
+`max(atol, rtol * max(norm_source, norm_target))`, with both norms read from
+`convergence_cache`.
+"""
+function check_convergence(
+    μ::AbstractVector,
+    cache::QuadraticOTNewtonCache,
+    convergence_cache::QuadraticOTConvergenceCache,
+    atol::Real,
+    rtol::Real,
+)
+    norm_diff = norm(vec(sum(cache.γ; dims=2)) .- μ, Inf)
+    tol = max(
+        atol, rtol * max(convergence_cache.norm_source, convergence_cache.norm_target)
+    )
+    return norm_diff < tol, norm_diff
+end
+
+"""
+    descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+
+Compute the Newton direction for the current dual potential and store it in
+`solver.cache.δu`.
+
+The caches `γ` (plan `relu(u .+ u' .- C) / eps`) and `σ` (indicator of `u .+ u' .- C ≥ 0`)
+are overwritten. The linear system is solved with `cg!`, using `solver.cache.x` as
+initial guess and as storage for the solution.
+"""
+function descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+    eps = solver.eps
+    C = solver.C
+    μ = solver.source
+    cache = solver.cache
+    u = cache.u
+    σ = cache.σ
+    γ = cache.γ
+    x = cache.x
+    δ = solver.alg.δ
+
+    # plan for the current potential and indicator of its support
+    @. γ = u + u' - C
+    @. σ = γ ≥ 0
+    @. γ = NNlib.relu(γ) / eps
+
+    # Newton system matrix, shifted by δ on the diagonal
+    G = Diagonal(vec(sum(σ; dims=2))) + σ + δ * I
+
+    # right-hand side: scaled violation of the marginal constraint
+    b = -eps * (vec(sum(γ; dims=2)) .- μ)
+    cg!(x, G, b)
+    cache.δu .= x
+    return cache.δu
+end
+
+"""
+    descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS})
+
+Compute the Newton direction restricted to the active set and store it in
+`solver.cache.δu`.
+
+Only the stored values of the sparse caches `γ` and `σ` are updated. This assumes that
+`solver.C`, `γ` and `σ` all share the sparsity pattern of `solver.alg.S`, so that the
+`k`-th stored value belongs to entry `(I[k], J[k])`.
+"""
+function descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS})
+    eps = solver.eps
+    C = solver.C
+    μ = solver.source
+    cache = solver.cache
+    u = cache.u
+    σ = cache.σ
+    γ = cache.γ
+    x = cache.x
+    δ = solver.alg.δ
+    I = solver.alg.I
+    J = solver.alg.J
+
+    # plan and support indicator, evaluated on the stored entries only
+    @. γ.nzval = u[I] + u[J] - C.nzval
+    @. σ.nzval = γ.nzval ≥ 0
+    @. γ.nzval = NNlib.relu(γ.nzval) / eps
+
+    # Newton system matrix: σ plus its row sums (shifted by δ) on the diagonal
+    G = σ + Diagonal(vec(sum(σ; dims=2)) .+ δ)
+
+    # right-hand side: scaled violation of the marginal constraint
+    b = -eps * (vec(sum(γ; dims=2)) .- μ)
+    cg!(x, G, b)
+    cache.δu .= x
+    return cache.δu
+end
+
+"""
+    descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+
+Update `solver.cache.u` along `solver.cache.δu` using a backtracking line search.
+
+Starting from step size one, the step is multiplied by `κ` until the dual objective
+satisfies the Armijo condition with slope factor `θ`, or until `armijo_max` reductions
+have been made. The step found last is always applied. Return the updated `u`.
+"""
+function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+    eps = solver.eps
+    C = solver.C
+    μ = solver.source
+    u = solver.cache.u
+    δu = solver.cache.δu
+    γ = solver.cache.γ
+    θ = solver.alg.θ
+    κ = solver.alg.κ
+    armijo_max = solver.alg.armijo_max
+
+    # dual objective as a function of the potential
+    Φ(w) = norm(NNlib.relu.(w .+ w' .- C))^2 / 2 - 2 * eps * dot(μ, w)
+
+    # directional derivative of Φ at u along δu (γ still holds the plan for u)
+    d = -eps * (2 * dot(δu, μ)) + eps * dot(γ, δu .+ δu')
+    Φ0 = Φ(u)
+    # start from the unit step in the type of κ so that t never changes type
+    t = one(κ)
+    armijo_counter = 0
+    while armijo_counter < armijo_max && Φ(u + t * δu) ≥ Φ0 + t * θ * d
+        t = κ * t
+        armijo_counter += 1
+    end
+    u .+= t .* δu
+    return u
+end
+
+"""
+    descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS})
+
+Update `solver.cache.u` along `solver.cache.δu` using a backtracking line search in which
+the dual objective is evaluated on the stored entries `(I[k], J[k])` of the active set
+only.
+
+The step size starts at one and is multiplied by `κ` until the Armijo condition with
+slope factor `θ` holds or `armijo_max` reductions have been made. Return the updated `u`.
+"""
+function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS})
+    eps = solver.eps
+    C = solver.C
+    μ = solver.source
+    u = solver.cache.u
+    δu = solver.cache.δu
+    γ = solver.cache.γ
+    I = solver.alg.I
+    J = solver.alg.J
+    θ = solver.alg.θ
+    κ = solver.alg.κ
+    armijo_max = solver.alg.armijo_max
+
+    # dual objective restricted to the active set
+    function Φ(w)
+        return norm(NNlib.relu.(w[I] .+ w[J] .- C.nzval))^2 / 2 - 2 * eps * dot(μ, w)
+    end
+
+    # directional derivative of Φ at u along δu (γ still holds the plan for u)
+    d = -eps * (2 * dot(δu, μ)) + eps * dot(γ.nzval, δu[I] + δu[J])
+    Φ0 = Φ(u)
+    # start from the unit step in the type of κ so that t never changes type
+    t = one(κ)
+    armijo_counter = 0
+    while armijo_counter < armijo_max && Φ(u + t * δu) ≥ Φ0 + t * θ * d
+        t = κ * t
+        armijo_counter += 1
+    end
+    u .+= t .* δu
+    return u
+end
+
+"""
+    solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+
+Run at most `solver.maxiter` Newton iterations (direction plus line search).
+
+Convergence is tested every `solver.check_convergence` iterations and always after the
+last one. A warning is logged if the iteration stops without convergence. Return
+`nothing`; the result is left in `solver.cache`.
+"""
+function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+    μ = solver.source
+    atol = solver.atol
+    rtol = solver.rtol
+    maxiter = solver.maxiter
+    check_every = solver.check_convergence
+    cache = solver.cache
+    convergence_cache = solver.convergence_cache
+
+    isconverged = false
+    to_check_step = check_every
+    for iter in 1:maxiter
+        descent_dir!(solver)
+        descent_step!(solver)
+
+        # only test every `check_every` iterations, but never skip the last one
+        to_check_step -= 1
+        if to_check_step == 0 || iter == maxiter
+            to_check_step = check_every
+
+            # the symmetric problem uses μ as both source and target marginal
+            isconverged, abserror = OptimalTransport.check_convergence(
+                μ, μ, cache, convergence_cache, atol, rtol
+            )
+            @debug "$(solver.alg) ($iter/$maxiter): absolute error of source marginal = " *
+                string(maximum(abserror))
+
+            if isconverged
+                @debug "$(solver.alg) ($iter/$maxiter): converged"
+                break
+            end
+        end
+    end
+
+    if !isconverged
+        @warn "$(solver.alg) ($maxiter/$maxiter): not converged"
+    end
+
+    return nothing
+end
+
+"""
+    solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS})
+
+Run at most `solver.maxiter` Newton iterations restricted to the active set.
+
+Every `solver.check_convergence` iterations (and after the last one) the dual potential
+is compared with its value at the previous check; the iteration stops once the norm of
+the change is smaller than `max(atol, rtol * max(norm(u), norm(u_prev)))`. A warning is
+logged if this never happens. Return `nothing`.
+
+# Throws
+- `ArgumentError` if the entries of `solver.source` are not all equal.
+"""
+function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS})
+    μ = solver.source
+    atol = solver.atol
+    rtol = solver.rtol
+    maxiter = solver.maxiter
+    check_every = solver.check_convergence
+    u = solver.cache.u
+
+    # the active-set method only works for the symmetric case with uniform weights
+    if !all(==(first(μ)), μ)
+        throw(ArgumentError("Active set method only works for uniform weights."))
+    end
+
+    # potential at the time of the previous convergence check
+    u_prev = copy(u)
+
+    isconverged = false
+    to_check_step = check_every
+    for iter in 1:maxiter
+        descent_dir!(solver)
+        descent_step!(solver)
+
+        # only test every `check_every` iterations, but never skip the last one
+        to_check_step -= 1
+        if to_check_step == 0 || iter == maxiter
+            to_check_step = check_every
+
+            # convergence is measured by the change of the dual potential
+            norm_diff = norm(u - u_prev)
+            isconverged = norm_diff < max(atol, rtol * max(norm(u), norm(u_prev)))
+            @debug "$(solver.alg) ($iter/$maxiter): change of dual potential = " *
+                string(norm_diff)
+
+            if isconverged
+                @debug "$(solver.alg) ($iter/$maxiter): converged"
+                break
+            end
+            copy!(u_prev, u)
+        end
+    end
+
+    if !isconverged
+        @warn "$(solver.alg) ($maxiter/$maxiter): not converged"
+    end
+
+    return nothing
+end
+
+"""
+    build_cache(::Type{T}, ::SymmetricQuadraticOTNewton, μ, ν, C, ε) where {T}
+
+Allocate the `QuadraticOTNewtonCache` for the dense symmetric Newton algorithm.
+
+The dual potentials, the matrix `G` and the conjugate-gradient iterate `x` are
+initialised with zeros of type `T`; the remaining buffers are left uninitialised. `ε` is
+not used.
+"""
+function build_cache(
+    ::Type{T},
+    ::SymmetricQuadraticOTNewton,
+    μ::AbstractVector,
+    ν::AbstractVector,
+    C::AbstractMatrix,
+    ε::Real,
+) where {T}
+    M = size(μ, 1)
+    N = size(ν, 1)
+    # dual potentials, initialised with zeros
+    u = fill!(similar(μ, T, M), zero(T))
+    v = fill!(similar(ν, T, N), zero(T))
+    # search directions (overwritten before they are read)
+    δu = similar(u, T)
+    δv = similar(v, T)
+    # intermediate variables (don't need to be initialised)
+    σ = similar(C, T)
+    γ = similar(C, T)
+    G = fill!(similar(u, T, M, M), zero(T))
+    # initial guess for conjugate gradient
+    x = fill!(similar(u, T, M), zero(T))
+    return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N)
+end
+
+"""
+    build_cache(::Type{T}, alg::SymmetricQuadraticOTNewtonAS, μ, ν, C, ε) where {T}
+
+Allocate the `QuadraticOTNewtonCache` for the active-set Newton algorithm.
+
+`σ` and `γ` are created with `similar(C, T)`. `G` gets the sparsity pattern of `alg.S`
+together with the full diagonal. `ε` is not used.
+"""
+function build_cache(
+    ::Type{T},
+    alg::SymmetricQuadraticOTNewtonAS,
+    μ::AbstractVector,
+    ν::AbstractVector,
+    C::AbstractMatrix,
+    ε::Real,
+) where {T}
+    M = size(μ, 1)
+    N = size(ν, 1)
+    # dual potentials, initialised with zeros
+    u = fill!(similar(μ, T, M), zero(T))
+    v = fill!(similar(ν, T, N), zero(T))
+    # search directions (overwritten before they are read)
+    δu = similar(u, T)
+    δv = similar(v, T)
+    # intermediate variables (don't need to be initialised)
+    σ = similar(C, T)
+    γ = similar(C, T)
+    # sparse system matrix with the pattern of the active set plus the diagonal;
+    # its stored values are zeroed so that the cache never holds uninitialised memory
+    G = similar(alg.S + sparse(LinearAlgebra.I, M, M), T)
+    fill!(nonzeros(G), zero(T))
+    # initial guess for conjugate gradient
+    x = fill!(similar(u, T, M), zero(T))
+    return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N)
+end
+
+"""
+    build_solver(
+        μ, C, ε, alg::QuadraticOT;
+        atol=nothing, rtol=nothing, check_convergence=1, maxiter=100,
+    )
+
+Build a `QuadraticOTSolver` for the symmetric problem, in which `μ` is used as both the
+source and the target marginal.
+
+# Keywords
+- `atol`: absolute tolerance; `nothing` is replaced by `0`.
+- `rtol`: relative tolerance; `nothing` is replaced by `sqrt(eps(T))` if the absolute
+  tolerance is zero and by zero otherwise, where `T` is the compute type.
+- `check_convergence`: number of iterations between two convergence checks.
+- `maxiter`: maximum number of iterations.
+"""
+function build_solver(
+    μ::AbstractVector,
+    C::AbstractMatrix,
+    ε::Real,
+    alg::QuadraticOT;
+    atol=nothing,
+    rtol=nothing,
+    check_convergence=1,
+    maxiter::Int=100,
+)
+    # check that the marginal and the cost matrix have matching sizes
+    # do not use checksize2 since for quadratic OT (at least for now) we do not support
+    # batch computations
+    checksize(μ, μ, C)
+
+    # compute type
+    T = float(Base.promote_eltype(μ, one(eltype(C)) / ε))
+
+    # build caches
+    cache = build_cache(T, alg, μ, μ, C, ε)
+    convergence_cache = build_convergence_cache(T, μ, μ)
+
+    # set tolerances
+    _atol = atol === nothing ? 0 : atol
+    _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol
+
+    return QuadraticOTSolver(
+        μ, μ, C, ε, alg, _atol, _rtol, maxiter, check_convergence, cache, convergence_cache
+    )
+end
+
+# interface
+
+"""
+    quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewton; kwargs...)
+
+Solve the symmetric quadratically regularised problem with marginal `μ`, cost matrix `C`
+and regularisation parameter `ε`, and return the plan. The keyword arguments are
+forwarded to `build_solver`.
+"""
+function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewton; kwargs...)
+    solver = build_solver(μ, C, ε, alg; kwargs...)
+    solve!(solver)
+    return plan(solver)
+end
+
+"""
+    quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewtonAS; maxiter_as=5, kwargs...)
+
+Solve the symmetric quadratically regularised problem with an active-set strategy and
+return the plan as a sparse matrix.
+
+In each of at most `maxiter_as` outer steps the problem is solved with the cost matrix
+restricted to the current support, warm-started from the previous dual potential. The
+plan is then evaluated on the full cost matrix `C`. If its row sums match `μ` within the
+solver tolerances the plan is returned; otherwise the support is enlarged by the entries
+at which the plan is positive. A warning is logged if no outer step converges. The
+remaining keyword arguments are forwarded to `build_solver`.
+
+# Throws
+- `DimensionMismatch` if `alg.S` and `C` do not have the same size.
+"""
+function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewtonAS; maxiter_as=5, kwargs...)
+    if size(alg.S) != size(C)
+        throw(
+            DimensionMismatch(
+                "support has size $(size(alg.S)) but cost matrix has size $(size(C))"
+            ),
+        )
+    end
+
+    # same compute type as in build_solver, so that the warm start is never converted
+    T = float(Base.promote_eltype(μ, one(eltype(C)) / ε))
+
+    S = alg.S
+    I = alg.I
+    J = alg.J
+    γ = spzeros(T, size(C)...)
+    u = fill!(similar(μ, T), zero(T))
+    isconverged = false
+    for iter in 1:maxiter_as
+        _alg = SymmetricQuadraticOTNewtonAS(alg.θ, alg.κ, alg.δ, alg.armijo_max, S, I, J)
+
+        # cost matrix restricted to the current support; the size is passed explicitly
+        # because otherwise it would be inferred from the largest row/column index
+        Cvals = similar(C, length(I))
+        for k in eachindex(I, J)
+            Cvals[k] = C[I[k], J[k]]
+        end
+        solver = build_solver(μ, sparse(I, J, Cvals, size(C)...), ε, _alg; kwargs...)
+
+        # warm start from the potential of the previous outer step
+        copy!(solver.cache.u, u)
+        solve!(solver)
+        copy!(u, solver.cache.u)
+
+        # plan on the full cost matrix and violation of its row marginal
+        γ = plan(solver, C)
+        norm_diff = norm(vec(sum(γ; dims=2)) .- μ, Inf)
+        cc = solver.convergence_cache
+        isconverged =
+            norm_diff < max(solver.atol, solver.rtol * max(cc.norm_source, cc.norm_target))
+        if isconverged
+            @debug "$(solver.alg) AS step ($iter/$maxiter_as): converged"
+            break
+        end
+
+        # enlarge the support by all entries at which the full plan is positive
+        S = max.(_alg.S, sign.(γ))
+        I, J, _ = findnz(S)
+        @debug "$(solver.alg) AS growing support to $(nnz(S)) (sparsity $(nnz(S)/length(S)))"
+    end
+
+    if !isconverged
+        @warn "$(alg) AS step ($maxiter_as/$maxiter_as): not converged"
+    end
+
+    return γ
+end
+
+"""
+    plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+
+Return the dense plan `relu(u .+ u' .- C) / eps` for the dual potential `u` stored in
+the solver cache.
+"""
+function plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton})
+    u = solver.cache.u
+    return NNlib.relu.(u .+ u' .- solver.C) / solver.eps
+end
+
+"""
+    plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}, Cfull)
+
+Return the plan `relu(u .+ u' .- Cfull) / eps` as a sparse matrix, where `u` is the dual
+potential stored in the solver cache. The plan is evaluated on the cost matrix `Cfull`
+passed by the caller rather than on the restricted `solver.C`.
+"""
+function plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}, Cfull)
+    u = solver.cache.u
+    return sparse(NNlib.relu.(u .+ u' .- Cfull) / solver.eps)
+end