Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3ヶ月コース@鮎川 #23

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions Lecture_TUKR/ayukawa_TUKR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np
import jax,jaxlib
import jax.numpy as jnp
from tqdm import tqdm #プログレスバーを表示させてくれる


class UKR:
def __init__(self, X, latent_dim, sigma, prior='random', Zinit=None):
#--------初期値を設定する.---------
self.X = X
#ここから下は書き換えてね
self.nb_samples, self.ob_dim = X.shape
self.sigma =sigma
self.latent_dim =latent_dim

if Zinit is None:
if prior == 'random': #一様事前分布のとき
self.Z = np.random.uniform(0, self.sigma*0.001, (self.nb_samples, self.latent_dim))
# Z1_vec = np.random.uniform(low=-1, high=1, size=Z)
# Z1_colum_vec = np.random.uniform(low=-1, high=1, size=[Z, 1])
# else: #ガウス事前分布のとき
# else: #Zの初期値が与えられた時
#self.Z = Zinit

self.history = {}

def kernel(self, Z1, Z2): #写像の計算 TUKRの式に変更
Mom = jnp.sum((Z1[:, None, :] - Z2[None, :, :]) ** 2, axis=2)
Chi = jnp.exp(-1/(2*self.sigma**2)*Mom)
f = (Chi@self.X)/jnp.sum(Chi, axis=1, keepdims=True)

return f

def E(self, Z, X, alpha, norm): #目的関数の計算
E = np.sum((X - self.kernel(Z,Z))**2)
R = alpha * jnp.sum(jnp.abs(Z ** norm))
E = E / self.nb_samples + R / self.nb_samples

return E

def fit(self, nb_epoch: int, eta: float, alpha: float, norm: float) :
# 学習過程記録用
self.history['z'] = np.zeros((nb_epoch, self.nb_samples, self.latent_dim))
self.history['kernel'] = np.zeros((nb_epoch, self.nb_samples, self.ob_dim))
self.history['error'] = np.zeros(nb_epoch)

for epoch in tqdm(range(nb_epoch)):
# Zの更新
dEdx = jax.grad(self.E, argnums=0)(self.Z, self.X, alpha, norm)
self.Z -= (eta) * dEdx


# 学習過程記録用
self.history['z'][epoch] = self.Z
self.history['kernel'][epoch] = self.kernel(self.Z,self.Z)
self.history['error'][epoch] = self.E(self.Z,self.X, alpha, norm)

#--------------以下描画用(上の部分が実装できたら実装してね)---------------------
def calc_approximate_f(self, resolution): #fのメッシュ描画用,resolution:一辺の代表点の数
nb_epoch = self.history['z'].shape[0]
self.history['y'] = np.zeros((nb_epoch, resolution ** self.latent_dim, self.ob_dim))
for epoch in tqdm(range(nb_epoch)):
zeta = create_zeta(self.Z, resolution)
Y = self.kernel(zeta, self.history['z'][epoch])
self.history['y'][epoch] = Y
return self.history['y']


def create_zeta(Z, resolution): #fのメッシュの描画用に潜在空間に代表点zetaを作る.
z_x = np.linspace(np.min(Z), np.max(Z), resolution).reshape(-1, 1)
z_y = np.linspace(np.min(Z), np.max(Z), resolution)
XX, YY = np.meshgrid(z_x, z_y)
xx = XX.reshape(-1)
yy = YY.reshape(-1)
zeta = np.concatenate([xx[:, None], yy[:, None]], axis=1)





return zeta


if __name__ == '__main__':
from data_scratch.data import load_kura_tsom
from Lecture_UKR.data import create_rasen
from Lecture_UKR.data import create_2d_sin_curve
from visualizer import visualize_history

#各種パラメータ変えて遊んでみてね.
##
epoch = 300 #学習回数
sigma = 0.4 #カーネルの幅
eta = 2 #学習率
latent_dim = 2 #潜在空間の次元

alpha = 0
norm = 10

seed = 2
np.random.seed(seed)

#入力データ(詳しくはdata.pyを除いてみると良い)
nb_samples = 200 #データ数
X = load_kura_tsom(nb_samples) #鞍型データ ob_dim=3, 真のL=2
# X = create_rasen(nb_samples) #らせん型データ ob_dim=3, 真のL=1
# X = create_2d_sin_curve(nb_samples) #sin型データ ob_dim=2, 真のL=1

ukr = UKR(X, latent_dim, sigma, prior='random')
ukr.fit(epoch, eta, alpha, norm)
#visualize_history(X, ukr.history['kernel'], ukr.history['z'], ukr.history['error'], save_gif=False, filename="tmp")
#----------描画部分が実装されたらコメントアウト外す----------
ukr.calc_approximate_f(resolution=10)
visualize_history(X, ukr.history['y'], ukr.history['z'], ukr.history['error'], save_gif=False, filename="tmp")

24 changes: 12 additions & 12 deletions Lecture_TUKR/data_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import matplotlib.pyplot as plt

def load_kura_tsom(xsamples, ysamples, missing_rate=None,retz=False):
z1 =
z2 =
z1 = np.linspace(-1, 1, xsamples)
z2 = np.linspace(-1, 1, ysamples)

z1_repeated, z2_repeated =
x1 =
x2 =
x3 =
z1_repeated, z2_repeated = np.meshgrid(z1, z2)
x1 = z1_repeated
x2 = z2_repeated
x3 = x1**2 - x2**2
#ノイズを加えたい時はここをいじる,locがガウス分布の平均、scaleが分散,size何個ノイズを作るか
#このノイズを加えることによって三次元空間のデータ点は上下に動く

x =
truez =
x = np.concatenate((x1[:, :, np.newaxis], x2[:, :, np.newaxis], x3[:, :, np.newaxis]), axis=2)
truez = np.concatenate((z2_repeated[:, :, np.newaxis], z2_repeated[:, :, np.newaxis]), axis=2)

if missing_rate == 0 or missing_rate == None:
if retz:
Expand All @@ -25,13 +25,13 @@ def load_kura_tsom(xsamples, ysamples, missing_rate=None,retz=False):
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

xsamples =
ysamples =
xsamples = 20
ysamples = 20

x, truez = load_kura_tsom()
x, truez = load_kura_tsom(xsamples, ysamples, retz=True)

fig = plt.figure(figsize=[5, 5])
ax_x = fig.add_subplot(projection='3d')
ax_x.scatter()
ax_x.scatter(x[:, :, 0], x[:, :, 1], x[:, :, 2])
ax_x.set_title('Generated three-dimensional data')
plt.show()
1 change: 1 addition & 0 deletions Lecture_TUKR/sample/tukr_kura_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def load_kura_tsom(xsamples, ysamples, missing_rate=None,retz=False):
#ノイズを加えたい時はここをいじる,locがガウス分布の平均、scaleが分散,size何個ノイズを作るか
#このノイズを加えることによって三次元空間のデータ点は上下に動く


x = np.concatenate((x1[:, :, np.newaxis], x2[:, :, np.newaxis], x3[:, :, np.newaxis]), axis=2)
truez = np.concatenate((z1_repeated[:, :, np.newaxis], z2_repeated[:, :, np.newaxis]), axis=2)

Expand Down
3 changes: 2 additions & 1 deletion Lecture_TUKR/test.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
import numpy as np
import numpy as np
print("hello")