VAEって何?[実装編]::8/1
時刻 | 行動 |
---|---|
7:50 | 起床 |
10:00 | バ始 |
19:00 | バ終 |
21:00 | 帰宅 |
21:30 | PRML読書 |
22:00 | VAE実装の続き |
Kerasを使ってVAEを実装する.
- 原著論文:[1312.6114] Auto-Encoding Variational Bayes
- まとめ殴り書き:#10 Auto-Encoding Variational Bayes - HackMD
- 実装Notebook:VAE_MNIST.ipynb · GitHub
ライブラリをインポートする
import numpy as np import matplotlib.pyplot as plt import keras.backend as K from IPython.display import clear_output from keras.layers import * from keras.models import Sequential, Model from keras.optimizers import SGD, Adam, RMSprop from utils.general.visualize import show_image from keras.datasets import mnist import plotly.plotly as py import plotly.graph_objs as go import plotly from mpl_toolkits.mplot3d import Axes3D import matplotlib as mpl import matplotlib.cm as cm
matplotlibのほかにplotly*1という可視化パッケージを使った.matplotlibと似たような使い方で,グリグリ動かせるグラフが作れる+共有ができるようだ.
なおutils.general.visualize
は画像表示は頻繁に使う関数なので別ファイルで作成して,それを読み込んでいる.
utils/general/visualize.py
import matplotlib.gridspec as gs import matplotlib.pyplot as plt import os import numpy as np """ visualize src images: src: array-like object """ def show_image(src, col=4, row=4, size=(12,12),shuffle=False, path=None, name=None,channel_first=False,image_verbose=True): assert src.ndim==4 or src.ndim==3,"image dimension should be 3 or 4." if channel_first: src = np.rollaxis(src,1,4) assert src.ndim==3 or src.shape[-1] == (3 or 1),"number of channels should be 1 or 3" if src.shape[-1] ==1: src=src[:,:,:,0] n = col * row if shuffle: idx = np.random.permutation(src.shape[0]) src=src[idx] plt.gray() src = src[:n] fig = plt.figure(figsize=size) g = gs.GridSpec(row,col) g.update(wspace=0.05, hspace=0.05) for i, sample in enumerate(src): ax = plt.subplot(g[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(sample) if path is not None and name is not None: os.makedirs(path, exist_ok=True) if path[-1] == '/': path = path[:-1] plt.savefig("./{}/{}.png".format(path, name)) if image_verbose: plt.show() else: plt.close()
データのロード
(x1, _), (x2, _) = mnist.load_data() X = np.concatenate((x1,x2)).reshape((-1,28*28))/255.0 X.shape
testsetは生成モデルでは要らないので訓練画像(60000枚)とテスト画像(10000枚)を結合して[70000,784]のデータを作る.
モデルの作成
class VAE(): def __init__(self, image_dim=784, hidden_dim=128, latent_dim=32, optimizer=None): self.im_dim = image_dim self.hd_dim = hidden_dim self.lt_dim = latent_dim if optimizer is None: self.optimizer = RMSprop() else: self.optimizer = optimizer self.model = self.model_compile(self.optimizer) def _build_encoder(self, input_dim=784, hidden_dim=128, latent_dim=32): in_layer = Input((input_dim,)) hidden = Dense(hidden_dim,activation='relu')(in_layer) z_mean = Dense(latent_dim,activation='linear')(hidden) z_std = Dense(latent_dim,activation='linear')(hidden) return Model(in_layer, [z_mean, z_std]) def _build_decoder(self, output_dim=784, hidden_dim=128, latent_dim=32): z_mean = Input((latent_dim,)) z_std = Input((latent_dim,)) sample = Lambda(self._sampling, output_shape=(self.lt_dim,))([z_mean, z_std]) hidden = Dense(hidden_dim,activation='relu')(sample) decoded = Dense(output_dim,activation='relu')(hidden) return Model([z_mean, z_std], decoded) def _sampling(self, args): z_mean, z_std = args epsilon = K.random_normal((self.lt_dim,)) return z_mean + z_std*epsilon def build_vae(self): self.encoder = self._build_encoder(self.im_dim, self.hd_dim, self.lt_dim) self.decoder = self._build_decoder(self.im_dim, self.hd_dim, self.lt_dim) in_layer = Input((self.im_dim,)) z_mean, z_std = self.encoder(in_layer) ######## self.z_m = z_mean self.z_s = z_std ######## decoded = self.decoder([z_mean, z_std]) vae = Model(in_layer, decoded) return vae def vae_loss(self, x_real, x_decoded): z_m2 = K.square(self.z_m) z_s2 = K.square(self.z_s) kl_loss = -0.5* K.mean(K.sum(1+K.log(z_s2)-z_m2-z_s2,axis=-1)) rc_loss = K.mean(K.sum(K.binary_crossentropy(x_real, x_decoded),axis=-1),axis=-1) return kl_loss+rc_loss def model_compile(self, optimizer): vae = self.build_vae() vae.compile(optimizer=optimizer,loss=self.vae_loss) return vae
エンコーダ,デコーダともに隠れ層が1つだけの単純なMLPとした.
エンコーダの出力は潜在変数の平均と標準偏差の2つがあるので,このような多入力多出力のようにちょっと複雑なモデルを作りたい際にはKeras Functional API*2が便利である. モデルの構造は大体こんな感じになってる.
そして誤差は論文にある通り, となるのでスッと実装する.
また,エンコーダの出力は誤差として使うのでメンバ変数に代入しとく(うえのコードの####
で強調した部分).
学習
vae = VAE(hidden_dim=64,latent_dim=3) batch_size = 32 i = 0 batch_len = int(X.shape[0]//batch_size) for epoch in range(10): np.random.shuffle(X) for batch in range(batch_len): X_minibatch = X[batch_size*(batch):batch_size*(batch+1)] if X_minibatch.shape[0] != batch_size: continue vae.model.train_on_batch(X_minibatch, X_minibatch) if batch%500==0: i += 1 clear_output(True) print("epoch:{}, batch:{}/{}".format(epoch, batch, batch_len)) image = vae.model.predict(X_minibatch) score = vae.model.evaluate(X_minibatch, X_minibatch) print("loss:{:.4f}".format(score)) image = image.reshape((-1,28,28)) show_image(image[:16])
今回は3次元の潜在変数へ投影する.
学習時間は高々5分弱だった.(GPU:NVidia GeForce GTX965m)
Matplotlibでの可視化
(_, _),(X_test, y_test) = mnist.load_data() X_t = X_test[:800].reshape((-1,784)) y_t = y_test[:800] z_t,_ = vae.encoder.predict(X_t)
800個のデータをエンコーダに通して次元圧縮する.
%matplotlib notebook fig = plt.figure(figsize=(8,10)) ax = Axes3D(fig) cmap = plt.get_cmap("jet",10) norm = mpl.colors.Normalize(0, 9) sm = cm.ScalarMappable(norm, cmap) sm.set_array([]) divider = make_axes_locatable(ax) fig.suptitle("VAE(ndim:784->3)") ax.scatter3D(z_t[:,0],z_t[:,1],z_t[:,2],c=y_t,cmap=cm.jet) plt.colorbar(sm, ticks=np.arange(0,10),boundaries=np.arange(-0.5,9.55,1),fraction=0.15,pad=0.1,orientation='horizontal')
jupyter上では%matplotlib notebook
と書くことでグリグリ動かせるグラフが出力される.
(colormap周りの実装は結構忘れがちなので少し戸惑った.)
plotlyでの可視化
(↓グリグリできるヨ)
こんな感じにiframeで簡単に埋め込みができて便利. グリグリ動かして楽しもう.
cmap = [ [0.0,'rgb(0, 0, 128)'],[0.1,'rgb(0, 0, 128)'], [0.1,'rgb(0, 0, 255)'],[0.2,'rgb(0, 0, 255)'], [0.2,'rgb(0, 99, 255)'],[0.3,'rgb(0, 99, 255)'], [0.3,'rgb(0, 213, 255)'],[0.4,'rgb(0, 213, 255)'], [0.4,'rgb(78, 255, 169)'],[0.5,'rgb(78, 255, 169)'], [0.5,'rgb(169, 255, 78)'],[0.6,'rgb(169, 255, 78)'], [0.6,'rgb(255, 230, 0)'],[0.7,'rgb(255, 230, 0)'], [0.7,'rgb(255, 125, 0)'],[0.8,'rgb(255, 125, 0)'], [0.8,'rgb(255, 20, 0)'],[0.9,'rgb(255, 20, 0)'], [0.9,'rgb(128, 0, 0)'],[1.0,'rgb(128, 0, 0)'] ]
カラースケールの定義. plotlyでSpectralのカラーマップへの変更の仕方がいまいちわかんなかった.
trace = go.Scatter3d( x=z_t[:, 0], y=z_t[:, 1], z=z_t[:, 2], mode="markers", marker=dict( size=3, color=y_t, cmin=-0.5, cmax=9.5, colorscale=cmap, colorbar=dict(dtick=1))) layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0)) fig = go.Figure(data=[trace], layout=layout) py.iplot(fig, filename="example")
案外楽な文法でつらつら書くだけでできる.
VAE関連の理解がまだ足りない気がするので積極的に論文を読んでいきたい.
PRML
夏休みを通じて確率統計の知見を深める会が計画されているらしいので,手元にあるPRMLを見始めた.
- 作者: C.M.ビショップ,元田浩,栗田多喜夫,樋口知之,松本裕治,村田昇
- 出版社/メーカー: 丸善出版
- 発売日: 2012/04/05
- メディア: 単行本(ソフトカバー)
- 購入: 6人 クリック: 33回
- この商品を含むブログ (20件) を見る
- 作者: C.M.ビショップ,元田浩,栗田多喜夫,樋口知之,松本裕治,村田昇
- 出版社/メーカー: 丸善出版
- 発売日: 2012/02/29
- メディア: 単行本
- 購入: 6人 クリック: 14回
- この商品を含むブログを見る
やはり結構難解で一筋縄ではいかない.頑張る.