Tensorflow/Tensorflow for Python

Recurrent Neural Network (ver.Python)

딥스탯 2018. 6. 13. 20:27
Recurrent Neural Network (ver.Python)

참고자료

https://www.tensorflow.org (TensorFlow)

출처

https://medium.com/@erikhallstrm/hello-world-rnn-83cd7105b767 (How to build a Recurrent Neural Network in TensorFlow (1/7), Erik Hall strom)

함께보기

Recurrent Neural Network (ver.R)

Recurrent Neural Network (ver.Python)

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length // batch_size // truncated_backprop_length

Generate data

In [2]:
def generateData():
    x = np.array(np.random.choice(2, total_series_length, p=[.5,.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0
    
    x = x.reshape((batch_size, -1))
    y = y.reshape((batch_size, -1))
    
    return (x, y)

Buliding the computational graph

Variables and placeholders

In [3]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [batch_size, state_size])
In [4]:
W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)
b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)

Unpacking

In [5]:
inputs_series = tf.unstack(batchX_placeholder, axis=1)
labels_series = tf.unstack(batchY_placeholder, axis=1)

Forward pass

In [6]:
current_state = init_state
states_series = []
for current_input in inputs_series:
    current_input = tf.reshape(current_input, [batch_size, 1])
    input_and_state_concatenated = tf.concat([current_input, current_state], 1)

    next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W) + b)
    states_series.append(next_state)
    current_state = next_state

Calculating loss

In [7]:
logits_series = [tf.matmul(state, W2) + b2 for state in states_series]
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels = labels, logits = logits) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.1).minimize(total_loss)

Visualizing the training

In [8]:
def plot(loss_list, predictions_series, batchX, batchY):
    plt.subplot(2, 3, 1)
    plt.cla()
    plt.plot(loss_list)

    for batch_series_idx in range(5):
        one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :]
        single_output_series = np.array([(1 if out[0] < 0.5 else 0) for out in one_hot_output_series])

        plt.subplot(2, 3, batch_series_idx + 2)
        plt.cla()
        plt.axis([0, truncated_backprop_length, 0, 2])
        left_offset = range(truncated_backprop_length)
        plt.bar(left_offset, batchX[batch_series_idx, :], width=1, color="blue")
        plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width=1, color="red")
        plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")

    plt.draw()
    plt.pause(0.0001)

Running a training session

In [9]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        if epoch_idx % 10 == 9:
            print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

            loss_list.append(_total_loss)


            if (epoch_idx % 10 == 9) & (batch_idx % 200 == 199):
                print("Step",batch_idx, "Loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
<Figure size 432x288 with 0 Axes>
New data, epoch 9
Step 199 Loss 0.00079908106
Step 399 Loss 0.0007074815
Step 599 Loss 0.0008050729
New data, epoch 19
Step 199 Loss 0.0004187428
Step 399 Loss 0.00036695477
Step 599 Loss 0.00032974474
New data, epoch 29
Step 199 Loss 0.00022247522
Step 399 Loss 0.00026064744
Step 599 Loss 0.00019080346
New data, epoch 39
Step 199 Loss 0.00015691138
Step 399 Loss 0.00013432371
Step 599 Loss 0.00015947461
New data, epoch 49
Step 199 Loss 0.000120350094
Step 399 Loss 0.00011892631
Step 599 Loss 0.00012949985
New data, epoch 59
Step 199 Loss 9.575777e-05
Step 399 Loss 0.00012161001
Step 599 Loss 9.048325e-05
New data, epoch 69
Step 199 Loss 9.5389856e-05
Step 399 Loss 9.0413916e-05
Step 599 Loss 8.450087e-05
New data, epoch 79
Step 199 Loss 7.612624e-05
Step 399 Loss 6.0770177e-05
Step 599 Loss 9.961249e-05
New data, epoch 89
Step 199 Loss 5.8486232e-05
Step 399 Loss 8.736214e-05
Step 599 Loss 5.8161673e-05
New data, epoch 99
Step 199 Loss 5.9515685e-05
Step 399 Loss 6.8561596e-05
Step 599 Loss 5.846836e-05

만일 input을 0으로 한다면?

In [10]:
def generateData2():
    x = np.array(np.zeros(total_series_length))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0
    
    x = x.reshape((batch_size, -1))
    y = y.reshape((batch_size, -1))
    
    return (x, y)
In [11]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        if epoch_idx % 10 == 9:
            print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

            loss_list.append(_total_loss)


            if (epoch_idx % 10 == 9) & (batch_idx % 200 == 199):
                print("Step",batch_idx, "Loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
<Figure size 432x288 with 0 Axes>
New data, epoch 9
Step 199 Loss 0.0007161029
Step 399 Loss 0.00086138956
Step 599 Loss 0.00075928005
New data, epoch 19
Step 199 Loss 0.00028155174
Step 399 Loss 0.00032595568
Step 599 Loss 0.0003387397
New data, epoch 29
Step 199 Loss 0.00022976821
Step 399 Loss 0.00019787454
Step 599 Loss 0.0002100072
New data, epoch 39
Step 199 Loss 0.00016683446
Step 399 Loss 0.00014772288
Step 599 Loss 0.00015413253
New data, epoch 49
Step 199 Loss 9.6867625e-05
Step 399 Loss 0.00010297788
Step 599 Loss 0.00011774493
New data, epoch 59
Step 199 Loss 0.000100157085
Step 399 Loss 9.909426e-05
Step 599 Loss 9.6904936e-05
New data, epoch 69
Step 199 Loss 7.597725e-05
Step 399 Loss 8.0316684e-05
Step 599 Loss 0.000111936875
New data, epoch 79
Step 199 Loss 7.7389624e-05
Step 399 Loss 6.651924e-05
Step 599 Loss 6.490248e-05
New data, epoch 89
Step 199 Loss 6.890087e-05
Step 399 Loss 6.337313e-05
Step 599 Loss 6.686284e-05
New data, epoch 99
Step 199 Loss 5.7047135e-05
Step 399 Loss 5.8833975e-05
Step 599 Loss 7.5563024e-05