forked from omrysendik/DCor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCombineGrads.m
27 lines (21 loc) · 1 KB
/
CombineGrads.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
function [grads] = CombineGrads(styleGrads, ACorrGrads, DiversityGrads, SmoothnessGrads, params)
%COMBINEGRADS Weighted sum of the per-loss gradients for each united layer.
%   grads = CombineGrads(styleGrads, ACorrGrads, DiversityGrads,
%   SmoothnessGrads, params) returns a column cell array with one entry per
%   element of params.unitedLayerInds. For the layer params.unitedLayerInds(k),
%   grads{k} starts at scalar 0 and, for each of the four losses (style,
%   ACorr, Diversity, Smoothness, in that order), adds
%       params.<Loss>LossWeight * <Loss>Grads{i}
%   where i is the FIRST position of that layer in params.<Loss>MatchLayerInds.
%   A loss whose match list does not contain the layer contributes nothing.
%
%   Details callers may rely on:
%     - <Loss>Grads{i} is used as the gradient for the layer
%       params.<Loss>MatchLayerInds(i) (positional pairing).
%     - Style gradients are added as given; ACorr, Diversity and Smoothness
%       gradients are converted with single() before weighting.
%     - The Diversity term is skipped entirely when
%       params.DiversityLossWeight is 0, so DiversityGrads is not indexed
%       in that case.
%     - A layer that receives no contribution keeps the scalar double 0.
%
%   NOTE(review): because the style term is not cast to single, the class
%   of grads{k} depends on which losses touch layer k (double if only the
%   style term applies and styleGrads holds doubles) - confirm intended.

% One entry per loss term, applied in this fixed order (summation order
% affects floating-point rounding, so it stays style, ACorr, Diversity,
% Smoothness). The prefix selects params.<prefix>MatchLayerInds and
% params.<prefix>LossWeight.
termGrads   = {styleGrads, ACorrGrads, DiversityGrads, SmoothnessGrads};
termPrefix  = {'style', 'ACorr', 'Diversity', 'Smoothness'};
toSingle    = [false, true, true, true];
skipZeroWgt = [false, false, true, false];

nLayers = length(params.unitedLayerInds);
grads = cell(nLayers, 1);
for layerIdx = 1:nLayers
    layerId = params.unitedLayerInds(layerIdx);
    total = 0;
    for t = 1:numel(termGrads)
        pos = find(params.([termPrefix{t} 'MatchLayerInds']) == layerId, 1);
        if isempty(pos)
            continue;   % this loss does not act on the layer
        end
        wgt = params.([termPrefix{t} 'LossWeight']);
        if skipZeroWgt(t) && wgt == 0
            continue;   % disabled term: its gradients are never read
        end
        g = termGrads{t}{pos};
        if toSingle(t)
            g = single(g);
        end
        total = total + wgt * g;
    end
    grads{layerIdx} = total;
end