Skip to content

Commit

Permalink
Added Auto encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
noblec04 committed Jul 28, 2024
1 parent 52a553d commit e97fb30
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 6 deletions.
91 changes: 91 additions & 0 deletions MatlabGP/+NN2/AE.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
classdef AE
    %AE Autoencoder composed of an encoder network followed by a decoder network.
    %
    %   The hyperparameters of both networks are exposed as one stacked vector
    %   (encoder entries first, decoder entries second) so that a single
    %   optimiser can train them jointly.

    properties
        Encoder     % network applied first in forward(); must provide forward/getHPs/setHPs
        Decoder     % network applied to the encoder output; same interface as Encoder
        lossfunc    % loss object; only its forward(y,yp) method is used here

        lb_x        % minimum of the training inputs as returned by min(x) (set by train)
        ub_x        % maximum of the training inputs as returned by max(x) (set by train)
    end

    methods

        function obj = AE(Encoder,Decoder,loss)
            %AE Store the two networks and the loss object.
            obj.Encoder = Encoder;
            obj.Decoder = Decoder;
            obj.lossfunc = loss;
        end

        function [y,obj] = forward(obj,x)
            %FORWARD Encode x, then decode the result.
            %   The object is returned unchanged as the second output.
            %   NOTE(review): the lb_x/ub_x scaling applied in train() is not
            %   applied here - confirm callers pass inputs on the intended scale.

            [y] = obj.Encoder.forward(x);
            [y] = obj.Decoder.forward(y);

        end

        function [y] = predict(obj,x)
            %PREDICT Run only the decoder on x.
            %   NOTE(review): presumably x is a latent code here - confirm
            %   against callers.

            [y] = obj.Decoder.forward(x);

        end

        function V = getHPs(obj)
            %GETHPS Stack the encoder hyperparameters on top of the decoder's.

            V = obj.Encoder.getHPs();
            V = [V;obj.Decoder.getHPs()];

        end

        function obj = setHPs(obj,V)
            %SETHPS Split V back into encoder and decoder hyperparameters.
            %   The first numel(Encoder.getHPs()) entries go to the encoder,
            %   the remainder to the decoder (mirrors the layout of getHPs).

            nE = numel(obj.Encoder.getHPs());
            Vl = V(1:nE);
            obj.Encoder = obj.Encoder.setHPs(Vl);
            Vl = V(nE+1:end);
            obj.Decoder = obj.Decoder.setHPs(Vl);

        end

        function [e,de] = loss(obj,V,x,y)
            %LOSS Objective value and gradient for the hyperparameter vector V.
            %   e  - value of the summed loss between y and forward(x)
            %   de - derivatives of that value w.r.t. V, reshaped to 1-by-numel(V)

            nV = length(V(:));

            % Wrap V so the quantities computed from it carry derivatives.
            V = AutoDiff(V(:));

            obj = obj.setHPs(V(:));

            [yp] = obj.forward(x);

            [eout] = obj.lossfunc.forward(y,yp);

            % Sum the loss output along its second dimension.
            e1 = sum(eout,2);

            e = getvalue(e1);
            de = getderivs(e1);
            de = reshape(full(de),[1 nV]);

        end

        function [obj,fval] = train(obj,x,y)%,xv,fv
            %TRAIN Fit encoder and decoder hyperparameters with fmincon.
            %   x is min-max scaled with lb_x/ub_x before optimisation; y is
            %   used as given.
            %   NOTE(review): y is not rescaled alongside x - confirm this is
            %   intended when the target is meant to equal the input.

            obj.lb_x = min(x);
            obj.ub_x = max(x);

            % A constant column has zero range; dividing by it would turn the
            % whole column into NaN (0/0). Use a unit range there instead so
            % the column maps to 0.
            span = obj.ub_x - obj.lb_x;
            span(span == 0) = 1;

            x = (x - obj.lb_x)./span;

            tx0 = (obj.getHPs())';

            func = @(V) obj.loss(V,x,y);

            % The objective returns its own gradient (second output of loss).
            opts = optimoptions('fmincon','SpecifyObjectiveGradient',true,'MaxFunctionEvaluations',5000,'MaxIterations',2000,'Display','iter');
            [theta,fval] = fmincon(func,tx0,[],[],[],[],[],[],[],opts);

            %[theta,fval,xv,fv] = VSGD(func,tx0,'lr',0.001,'gamma',0.001,'iters',3000,'tol',1*10^(-7));

            obj = obj.setHPs(theta(:));
        end
    end
end
105 changes: 99 additions & 6 deletions MatlabGP/VGP.m
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,19 @@
Y2 = obj2.eval(obj.X);

dy = -1*abs(sum(Y2-Y1));


end

function [thetas,ntm,ntk,tm0,tk0] = getHPs(obj)
    % GETHPS Collect the current hyperparameters of the model.
    %   thetas - mean HPs, kernel HPs and obj.kernel.signn concatenated horizontally
    %   ntm    - number of mean hyperparameters
    %   ntk    - number of kernel hyperparameters
    %   tm0    - hyperparameters returned by obj.mean.getHPs()
    %   tk0    - hyperparameters returned by obj.kernel.getHPs()

    % Mean part
    tm0 = obj.mean.getHPs();
    ntm = numel(tm0);

    % Kernel part
    tk0 = obj.kernel.getHPs();
    ntk = numel(tk0);

    % Full vector, with signn appended last
    thetas = horzcat(tm0, tk0, obj.kernel.signn);

end

function obj = condition(obj,X,Y)
Expand All @@ -157,13 +169,13 @@
obj.kernel.scale = std(Y)/2;

obj.Kuu = obj.kernel.build(xu,xu);
obj.Kuuinv = pinv(obj.Kuu);
obj.Kuuinv = pinv(obj.Kuu,1*10^(-7));

obj.Kuf = obj.kernel.build(xu,xf);

obj.B = obj.Kuf*obj.Kuf'/obj.kernel.signn;
obj.M = obj.Kuu + obj.B;
obj.Minv = pinv(obj.M);
obj.Minv = pinv(obj.M,1*10^(-7));

obj.alpha = obj.Minv*obj.Kuf*(Y - obj.mean.eval(X))/obj.kernel.signn;

Expand All @@ -173,14 +185,15 @@

if regress
obj.kernel.signn = theta(end);
theta(end) = [];
end

tm0 = obj.mean.getHPs();
ntm = numel(tm0);
tk0 = obj.kernel.getHPs();
ntk = numel(tk0);

obj.mean = obj.mean.setHPs(theta(1:ntm));
obj.kernel = obj.kernel.setHPs(theta(ntm+1:end));
obj.kernel = obj.kernel.setHPs(theta(ntm+1:ntm+ntk));

obj = obj.condition(obj.X,obj.Y);

Expand All @@ -194,6 +207,37 @@
nll(isinf(nll)) = 0;
end

function [loss, dloss] = loss(obj,theta,regress)
    % LOSS Objective and gradient for the hyperparameter vector theta,
    % evaluated on a random subsample of the stored data.
    %   theta   - [mean HPs, kernel HPs] followed by signn when regress is true
    %   regress - when true, theta(ntm+ntk+1) is written to obj.kernel.signn
    %   loss    - value of -1 * obj.nLL(...) on the subsample
    %   dloss   - its derivatives w.r.t. theta, reshaped to 1-by-numel(theta)

    nV = length(theta(:));
    tm0 = obj.mean.getHPs();
    ntm = numel(tm0);
    tk0 = obj.kernel.getHPs();
    ntk = numel(tk0);

    % Wrap theta so the quantities computed from it carry derivatives.
    theta = AutoDiff(theta);

    if regress
        obj.kernel.signn = theta(ntm+ntk+1);
    end

    obj.mean = obj.mean.setHPs(theta(1:ntm));
    obj.kernel = obj.kernel.setHPs(theta(ntm+1:ntm+ntk));

    % Re-condition on the full stored data with the new hyperparameters.
    obj = obj.condition(obj.X,obj.Y);

    % Subsample size: at least 5 points, or 2% of the data. It is clamped to
    % the number of rows because randsample draws without replacement and
    % errors when asked for more samples than there are points.
    nX = size(obj.X,1);
    nSub = min(nX,max(5,ceil(nX/50)));
    its = randsample(nX,nSub);

    nll = obj.nLL(obj.X(its,:),obj.Y(its));

    % The sign of the nLL output is flipped before it is used as the objective.
    nll = -1*nll;

    loss = getvalue(nll);
    dloss = getderivs(nll);
    dloss = reshape(full(dloss),[1 nV]);

end

function [obj,nll] = train(obj,regress)

if obj.kernel.signn==0||nargin<2
Expand Down Expand Up @@ -222,7 +266,7 @@
func = @(x) obj.LL(x,regress);


xxt = tlb + (tub - tlb).*lhsdesign(200*length(tlb),length(tlb));
xxt = tlb + (tub - tlb).*lhsdesign(500*length(tlb),length(tlb));

for ii = 1:size(xxt,1)
LL(ii) = -1*func(xxt(ii,:));
Expand Down Expand Up @@ -261,6 +305,55 @@

end

function [obj,nll] = train2(obj,regress)
    % TRAIN2 Fit hyperparameters with three randomly started VSGD runs and
    % keep the run with the smallest returned objective value.
    %   regress - when true, the noise term obj.kernel.signn is optimised as
    %             well; forced to 1 if it is omitted or signn is currently 0
    %   nll     - smallest objective value among the three runs

    if obj.kernel.signn==0||nargin<2
        regress=1;
    end

    % Current hyperparameters are only used to size the bound vectors.
    meanHPs = obj.mean.getHPs();
    nMean = numel(meanHPs);
    kernHPs = obj.kernel.getHPs();

    % Box bounds: mean HPs in [-10, 10], kernel HPs in [0.001, 2].
    lower = [0*meanHPs - 10, 0*kernHPs + 0.001];
    upper = [0*meanHPs + 10, 0*kernHPs + 2];

    % Extra trailing entry for the noise term.
    if regress
        lower(end+1) = 0.001;
        upper(end+1) = std(obj.Y)/5;
    end

    objective = @(x) obj.loss(x,regress);

    nStarts = 3;
    candidates = cell(1,nStarts);
    for k = 1:nStarts
        % Uniform random start inside the box.
        start = lower + (upper - lower).*rand(1,length(lower));

        [candidates{k},score(k)] = VSGD(objective,start,'lr',0.02,'lb',lower,'ub',upper,'gamma',0.0001,'iters',20,'tol',1*10^(-4));
    end

    [nll,best] = min(score);
    bestTheta = candidates{best};

    % Peel the noise term off the end before splitting mean / kernel HPs.
    if regress
        obj.kernel.signn = bestTheta(end);
        bestTheta(end) = [];
    end

    obj.mean = obj.mean.setHPs(bestTheta(1:nMean));
    obj.kernel = obj.kernel.setHPs(bestTheta(nMean+1:end));
    obj = obj.condition(obj.X,obj.Y);

end

function obj = resolve(obj,x,y)

replicates = ismembertol(x,obj.X,1e-4,'ByRows',true);
Expand Down
Binary file modified MatlabGP/docs/TestVGPClass.mlx
Binary file not shown.
83 changes: 83 additions & 0 deletions MatlabGP/docs/testAE.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@

clear all
close all
clc

% Training data: 20 lhsdesign samples pushed through the 3-column test
% function. The standard deviation passed to normrnd is 0*forr(...)+0.
xx = lhsdesign(20,1);
yclean = forr(xx,0);
yy = normrnd(yclean,0*yclean+0);

% Rescale all columns jointly using the global minimum and maximum.
ylo = min(yy(:));
yhi = max(yy(:));
yy = (yy-ylo)/(yhi-ylo);

% Dense reference evaluation (not used further below).
xmesh = linspace(0,1,100)';
ymesh = forr(xmesh,0);

% Encoder layers: FF(3,3), FF(3,2), FF(2,1)
layers1 = {NN2.FF(3,3), NN2.FF(3,2), NN2.FF(2,1)};
acts1 = {NN2.SNAKE(1), NN2.SNAKE(1)};

lss = NN2.MAE();

enc = NN2.NN(layers1,acts1,lss);

% Decoder layers mirror the encoder: FF(1,2), FF(2,3), FF(3,3)
layers2 = {NN2.FF(1,2), NN2.FF(2,3), NN2.FF(3,3)};
acts2 = {NN2.SNAKE(1), NN2.SNAKE(1)};

lss = NN2.MAE();

dec = NN2.NN(layers2,acts2,lss);

AE1 = NN2.AE(enc,dec,lss);

%%
% Single loss/gradient evaluation at the initial hyperparameters.

t0 = AE1.getHPs();

[e,de] = AE1.loss(t0,yy,yy);

%%
% Timed training run with yy as both input and target.

tic
[AE2,fval] = AE1.train(yy,yy);%,xv,fv
toc

%%
% Reconstruction of the training data by the trained model.

yp2 = AE2.forward(yy);


%%
% figure
% plot(fv,'.')
% set(gca,'yscale','log')
% set(gca,'xscale','log')

% Reconstruction plotted against the original values.
figure
%plot(xmesh,yp1)
plot(yy,yp2,'.')

%%

function y = forr(x,dx)
    % FORR Three-column test function built on (6x-2)^2 * sin(12x-4).
    %   x  - sample locations; the first length(x) elements are used
    %   dx - offset added to column 1 wherever x is not below 0.45
    %   y  - length(x)-by-3 matrix:
    %          column 1: base function (+ dx for x >= 0.45)
    %          column 2: 0.4*base - x - 1
    %          column 3: A*base + B*(x-0.5) - C
    %   An empty x yields a 0-by-3 result instead of an unassigned output.

    nx = length(x);

    A = 0.5; B = 10; C = -5;

    % Work on a column of the same elements the original loop visited.
    xv = x(1:nx);
    xv = xv(:);

    % Shared factors, kept separate so each column multiplies in the same
    % order as the scalar expressions it replaces.
    sq = (6*xv-2).^2;
    sn = sin(12*xv-4);

    base = sq.*sn;

    % Preallocate the full output in the class of the computed values.
    y = zeros(nx,3,'like',base);

    y(:,1) = base;
    % Anything not strictly below 0.45 (including NaN) receives the offset.
    shifted = ~(xv<0.45);
    y(shifted,1) = y(shifted,1)+dx;

    y(:,2) = 0.4*sq.*sn-xv-1;
    y(:,3) = A*sq.*sn+B*(xv-0.5)-C;

end

0 comments on commit e97fb30

Please sign in to comment.