-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement lowrank_svd in NiSparseArrays #35
Comments
|
那个loss 应该是typo,应该这么写。 loss(res, z, w) = sum(res.U * Diagonal(z) * res.V) + sum(res.S .* w)
QR的backward rule也有人正在写PR,你可以先checkout ChainRules.jl对应这个PR的branch。 pkg> add https://github.com/rkube/ChainRules.jl.git 然后看看能不能满足你的要求,如果可以,你去issue下面comment下,帮助这个PR早日merge。不然的话,我们得要自己提PR。
你先和LinearAlgebra.svd对比结果是否可靠。 |
关于 loss(res, z, w) = sum(res.U * Diagonal(z) * res.V) + sum(res.S .* w) 这里的 function loss(U, S, Vt, z)
residual_mat = U * Diagonal(S) * Vt - U * Diagonal(z) * Vt
return tr(residual_mat'*residual_mat) # square of Frobenius norm
# return norm(residual_mat)
end 示例如下: julia> using NiSparseArrays:low_rank_svd
julia> using SparseArrays
julia> using LinearAlgebra
julia> A = sprand(100,10,0.2)*sprand(10,20,0.2);
julia> U, S, Vt = low_rank_svd(A, 10);
julia> function loss(U, S, Vt, z)
residual_mat = U * Diagonal(S) * Vt - U * Diagonal(z) * Vt
return tr(residual_mat'*residual_mat) # square of Frobenius norm
# return norm(residual_mat)
end
loss (generic function with 1 method)
julia> z = rand(10);
julia> loss(U, S, Vt, z)
67.38981243097754 cc: @GiggleLiu |
这一项,是因为U和V之间有“规范自由度”,比如(-U, -V)和(U, V)对应同一个解。如果乘起来符号可以相消,因此这样可以构成一个和规范自由度无关,同时又可以测试U,V的loss. |
目标:
通过
NiSparseArrays
的稀疏矩阵乘法加速lowrank_svd
的微分过程实现:
在项目申请阶段已对
lowrank_svd
进行实现;计划将lowrank_svd
实现在src/目录下,其使用可提供在examples/目录下,类似b.t.w 为什么
loss
里定义了z
作为输入最后却只用了w
?这只是一个示意例子,取出稀疏矩阵
A
的列有更合适的写法。此外,由于计算稠密基涉及到了
qr
,而ChainRules
未提供关于qr
的实现,所以会对金国学长的实现进行ChainRules
包装。其余部分我理解的应该和之前写的没太大区别,当然潜在的问题是测试:准确率与性能。关键是和谁比?
low_rank_svd
似乎没有官方的julia
实现?我目前着手在写,如果有相关的问题/建议也麻烦提出来讨论,十分感谢~
cc: @GiggleLiu @johnnychen94
The text was updated successfully, but these errors were encountered: