Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement lowrank_svd in NiSparseArrays #35

Closed
jieli-matrix opened this issue Sep 11, 2021 · 4 comments · Fixed by #36 or #37
Closed

Implement lowrank_svd in NiSparseArrays #35

jieli-matrix opened this issue Sep 11, 2021 · 4 comments · Fixed by #36 or #37

Comments

@jieli-matrix
Copy link
Owner

目标:
通过NiSparseArrays的稀疏矩阵乘法加速lowrank_svd的微分过程
实现:
在项目申请阶段已对lowrank_svd进行实现;计划将lowrank_svd实现在src/目录下,其使用可提供在examples/目录下,类似

using LinearAlgebra, Zygote, SparseArrays

loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)

A = sprand(1000, 1000, 0.1);
F = low_rank_svd(A);
Zygote.gradient(x->loss(low_rank_svd(x), x[:,1], x[:,2]), A)

b.t.w 为什么loss里定义了z作为输入最后却只用了w?
这只是一个示意例子,取出稀疏矩阵A的列有更合适的写法。

此外,由于计算稠密基涉及到了qr,而ChainRules未提供关于qr的实现,所以会对金国学长的实现进行ChainRules包装。
其余部分我理解的应该和之前写的没太大区别,当然潜在的问题是测试:准确率与性能。关键是和谁比?low_rank_svd似乎没有官方的julia实现?
我目前着手在写,如果有相关的问题/建议也麻烦提出来讨论,十分感谢~
cc: @GiggleLiu @johnnychen94

@johnnychen94
Copy link
Collaborator

low_rank_svd似乎没有官方的julia实现?

https://github.com/JuliaMatrices/LowRankApprox.jl psvd

@GiggleLiu
Copy link
Collaborator

GiggleLiu commented Sep 11, 2021

那个loss 应该是typo,应该这么写。

loss(res, z, w) = sum(res.U * Diagonal(z) * res.V) + sum(res.S .* w)

所以会对金国学长的实现进行ChainRules包装

QR的backward rule也有人正在写PR,你可以先checkout ChainRules.jl对应这个PR的branch。
JuliaDiff/ChainRules.jl#469

pkg> add https://github.com/rkube/ChainRules.jl.git

然后看看能不能满足你的要求,如果可以,你去issue下面comment下,帮助这个PR早日merge。不然的话,我们得要自己提PR。
(EDIT: 我自己试了下,发现它工作的ChainRules版本太低了,我觉得还是自己写比较可靠~)

测试:准确率与性能。

你先和LinearAlgebra.svd对比结果是否可靠。
完了就是测试AD,这个ChainRulesTestUtils可能没有对应的测试工具,我们可以通过和FiniteDifference中的jacobian函数给的结果做对比来测试,这个应该很直接就可以做。关于性能我感觉没有特别的必要,这个没有很好的参考对象,如果要和pytorch比,这还需要配置pytorch环境,不值得。

This was linked to pull requests Sep 14, 2021
@jieli-matrix
Copy link
Owner Author

jieli-matrix commented Sep 25, 2021

关于svdloss我还有些疑问:

loss(res, z, w) = sum(res.U * Diagonal(z) * res.V) + sum(res.S .* w)

这里的zS的维度必须一致,也就是估计的rank大小;我不太明白这里的w的含义,会有维度的相关限制吗?
我想象中的loss(pred, tgt)一般是用于衡量predtgt的差异;所以,可以这样写吗?

function loss(U, S, Vt, z)
      residual_mat = U * Diagonal(S) * Vt - U * Diagonal(z) * Vt
      return tr(residual_mat'*residual_mat) # square of Frobenius norm
     # return norm(residual_mat)
end

示例如下:

julia> using NiSparseArrays:low_rank_svd

julia> using SparseArrays

julia> using LinearAlgebra

julia> A = sprand(100,10,0.2)*sprand(10,20,0.2);

julia> U, S, Vt = low_rank_svd(A, 10);

julia> function loss(U, S, Vt, z)
             residual_mat = U * Diagonal(S) * Vt - U * Diagonal(z) * Vt
             return tr(residual_mat'*residual_mat) # square of Frobenius norm
            # return norm(residual_mat)
       end
loss (generic function with 1 method)

julia> z = rand(10);

julia> loss(U, S, Vt, z)
67.38981243097754

cc: @GiggleLiu

@GiggleLiu
Copy link
Collaborator

GiggleLiu commented Sep 26, 2021

loss 衡量差异是狭义的理解,广义它就是一个实数标量,用来作为变分的准则。因此这里的loss可以是任意函数,之所以取

sum(res.U * Diagonal(z) * res.V)

这一项,是因为U和V之间有“规范自由度”,比如(-U, -V)和(U, V)对应同一个解。如果乘起来符号可以相消,因此这样可以构成一个和规范自由度无关,同时又可以测试U,V的loss.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
3 participants