import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
#SAMPLING LAYER
class Sample(layers.Layer):
    """Draw a latent vector z from (z_mean, z_log_var) via reparameterization.

    The layer receives a pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is standard
    normal noise with the same shape and dtype as ``z_mean``.
    """

    def call(self, inputs):
        """Return one random sample per element of ``z_mean``.

        Args:
            inputs: A two-element sequence ``(z_mean, z_log_var)`` of tensors
                with matching shapes.

        Returns:
            A tensor shaped like ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Match the noise to the input's shape and dtype so the layer also
        # works when the inputs are not float32 (e.g. mixed precision) and
        # does not depend on the legacy `keras.backend` random helpers.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# Size of the latent vector; the encoder's z_mean / z_log_var heads and the
# decoder's input are built with this width.
latent_dim = 2
#VAE MODEL
class VAE(keras.Model):
    """Convolutional variational autoencoder for 28x28x1 images.

    Attributes:
        encoder: Functional model mapping an image batch to
            ``[z_mean, z_log_var, z]``.
        decoder: Functional model mapping a latent batch back to a
            28x28x1 image with sigmoid-activated pixels.

    The latent width is read from the module-level ``latent_dim``.
    """

    def __init__(self, **kwargs):
        """Build the encoder and decoder sub-models.

        Args:
            **kwargs: Forwarded unchanged to ``keras.Model.__init__``.
        """
        super().__init__(**kwargs)

        # Encoder: two stride-2 convolutions (28 -> 14 -> 7), then dense heads
        # for the mean and log-variance of the approximate posterior.
        encoder_inputs = keras.Input(shape=(28, 28, 1))
        x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
        x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Flatten()(x)
        x = layers.Dense(16, activation="relu")(x)
        z_mean = layers.Dense(latent_dim, name="z_mean")(x)
        z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
        z = Sample()([z_mean, z_log_var])
        self.encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

        # Decoder: mirror of the encoder, upsampling 7 -> 14 -> 28.
        latent_inputs = keras.Input(shape=(latent_dim,))
        x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
        x = layers.Reshape((7, 7, 64))(x)
        x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
        decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
        self.decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def train_step(self, data):
        """Run one optimization step on a batch and report the losses.

        Args:
            data: A batch of images, or a tuple whose first element is the
                image batch (any further elements are ignored).

        Returns:
            A dict with ``"loss"``, ``"reconstruction_loss"`` and
            ``"kl_loss"`` for this batch.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Per-image reconstruction error: binary_crossentropy averages
            # over the channel axis, so sum the remaining height/width axes
            # and then average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # KL(q(z|x) || N(0, I)) is a SUM over latent dimensions; only the
            # batch axis is averaged. Averaging over the latent axis as well
            # would under-weight this term by a factor of latent_dim.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
# Load MNIST; the labels are bound to `_` and never used.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

# Stack the train and test images into one array, append a trailing channel
# axis, and convert to float32 divided by 255.
mnist_digits = (
    np.concatenate((x_train, x_test), axis=0)[..., np.newaxis].astype("float32") / 255
)

# Build the model, attach an Adam optimizer and fit for 30 epochs.
vae = VAE()
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=30, batch_size=128)
Epoch 1/30
547/547 [==============================] - 9s 17ms/step - loss: 207.9442 - reconstruction_loss: 206.5694 - kl_loss: 1.3748
Epoch 2/30
547/547 [==============================] - 9s 16ms/step - loss: 167.6381 - reconstruction_loss: 164.8430 - kl_loss: 2.7950
Epoch 3/30
547/547 [==============================] - 9s 17ms/step - loss: 157.1702 - reconstruction_loss: 154.0470 - kl_loss: 3.1231
Epoch 4/30
547/547 [==============================] - 9s 16ms/step - loss: 154.2223 - reconstruction_loss: 151.0141 - kl_loss: 3.2082
Epoch 5/30
547/547 [==============================] - 9s 16ms/step - loss: 152.4078 - reconstruction_loss: 149.1486 - kl_loss: 3.2593
Epoch 6/30
547/547 [==============================] - 9s 16ms/step - loss: 151.1136 - reconstruction_loss: 147.8070 - kl_loss: 3.3066
Epoch 7/30
547/547 [==============================] - 9s 16ms/step - loss: 150.1958 - reconstruction_loss: 146.8621 - kl_loss: 3.3336
Epoch 8/30
547/547 [==============================] - 9s 16ms/step - loss: 149.2447 - reconstruction_loss: 145.8765 - kl_loss: 3.3683
Epoch 9/30
547/547 [==============================] - 9s 16ms/step - loss: 148.5966 - reconstruction_loss: 145.1993 - kl_loss: 3.3972
Epoch 10/30
547/547 [==============================] - 9s 16ms/step - loss: 147.9715 - reconstruction_loss: 144.5540 - kl_loss: 3.4174
Epoch 11/30
547/547 [==============================] - 9s 16ms/step - loss: 147.4371 - reconstruction_loss: 143.9974 - kl_loss: 3.4396
Epoch 12/30
547/547 [==============================] - 9s 16ms/step - loss: 147.0404 - reconstruction_loss: 143.5817 - kl_loss: 3.4587
Epoch 13/30
547/547 [==============================] - 9s 16ms/step - loss: 146.5592 - reconstruction_loss: 143.0784 - kl_loss: 3.4807
Epoch 14/30
547/547 [==============================] - 9s 16ms/step - loss: 146.2075 - reconstruction_loss: 142.7118 - kl_loss: 3.4957
Epoch 15/30
547/547 [==============================] - 9s 16ms/step - loss: 145.9416 - reconstruction_loss: 142.4310 - kl_loss: 3.5106
Epoch 16/30
547/547 [==============================] - 9s 17ms/step - loss: 145.5281 - reconstruction_loss: 142.0085 - kl_loss: 3.5196
Epoch 17/30
547/547 [==============================] - 9s 16ms/step - loss: 145.2842 - reconstruction_loss: 141.7477 - kl_loss: 3.5365
Epoch 18/30
547/547 [==============================] - 9s 16ms/step - loss: 145.1024 - reconstruction_loss: 141.5528 - kl_loss: 3.5496
Epoch 19/30
547/547 [==============================] - 9s 16ms/step - loss: 144.7374 - reconstruction_loss: 141.1775 - kl_loss: 3.5599
Epoch 20/30
547/547 [==============================] - 9s 16ms/step - loss: 144.5054 - reconstruction_loss: 140.9273 - kl_loss: 3.5781
Epoch 21/30
547/547 [==============================] - 9s 16ms/step - loss: 144.3437 - reconstruction_loss: 140.7661 - kl_loss: 3.5776
Epoch 22/30
547/547 [==============================] - 9s 16ms/step - loss: 144.1328 - reconstruction_loss: 140.5432 - kl_loss: 3.5897
Epoch 23/30
547/547 [==============================] - 9s 16ms/step - loss: 143.9308 - reconstruction_loss: 140.3421 - kl_loss: 3.5887
Epoch 24/30
547/547 [==============================] - 9s 16ms/step - loss: 143.7300 - reconstruction_loss: 140.1331 - kl_loss: 3.5968
Epoch 25/30
547/547 [==============================] - 9s 16ms/step - loss: 143.5860 - reconstruction_loss: 139.9617 - kl_loss: 3.6243
Epoch 26/30
547/547 [==============================] - 9s 16ms/step - loss: 143.4559 - reconstruction_loss: 139.8398 - kl_loss: 3.6162
Epoch 27/30
547/547 [==============================] - 9s 16ms/step - loss: 143.3631 - reconstruction_loss: 139.7232 - kl_loss: 3.6399
Epoch 28/30
547/547 [==============================] - 9s 16ms/step - loss: 143.2122 - reconstruction_loss: 139.5813 - kl_loss: 3.6309
Epoch 29/30
547/547 [==============================] - 9s 16ms/step - loss: 143.0500 - reconstruction_loss: 139.4098 - kl_loss: 3.6402
Epoch 30/30
547/547 [==============================] - 9s 16ms/step - loss: 142.8536 - reconstruction_loss: 139.2033 - kl_loss: 3.6503
<tensorflow.python.keras.callbacks.History at 0x7fe18c2a9518>
#Visualising over latent space
import matplotlib.pyplot as plt
def plot_latent(n=30, digit_size=28, scale=2.0, figsize=15):
    """Show an ``n`` x ``n`` mosaic of images decoded from a 2-D latent grid.

    Latent points are taken on a regular grid over
    ``[-scale, scale] x [-scale, scale]`` and decoded with the module-level
    ``vae.decoder``. Each decoded image fills one ``digit_size`` x
    ``digit_size`` tile of the mosaic.

    Args:
        n: Number of grid points (and tiles) per axis.
        digit_size: Side length in pixels of one decoded image.
        scale: Half-width of the latent range sampled on each axis.
        figsize: Width and height of the Matplotlib figure, in inches.
    """
    grid_x = np.linspace(-scale, scale, n)
    # Reversed so that z[1] increases upwards in the displayed image.
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Build every (z[0], z[1]) pair in row-major tile order and decode them
    # in a single predict call instead of one call per tile.
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    z_grid = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)
    decoded = np.asarray(vae.decoder.predict(z_grid))

    # (row, col, h, w) -> (row, h, col, w) -> one 2-D mosaic.
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    figure = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.figure(figsize=(figsize, figsize))
    # Tick at the centre of each tile. The stop value must not admit an
    # (n + 1)-th position, otherwise the tick locations outnumber the labels.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent()