
Recurrent Neural Network (ver.R)

References

https://www.tensorflow.org (TensorFlow)

https://tensorflow.rstudio.com (TensorFlow for R)

https://medium.com/@erikhallstrm/hello-world-rnn-83cd7105b767
How to build a Recurrent Neural Network in TensorFlow (1/7) - Erik Hallstrom

See also

Recurrent Neural Network (ver.Python)

Recurrent Neural Network (ver.R)

In [1]:
require(tensorflow)

num_epochs                <- 100    # epochs, each run on freshly generated data
total_series_length       <- 50000  # length of the full binary series
truncated_backprop_length <- 15     # unrolled steps per truncated-BPTT window
state_size                <- 4      # width of the hidden state
num_classes               <- 2      # binary classification (0 or 1)
echo_step                 <- 3      # target = input delayed by this many steps
batch_size                <- 5
num_batches <- total_series_length / batch_size / truncated_backprop_length
num_batches <- as.integer(num_batches)  # batches per epoch
Loading required package: tensorflow
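
Note that total_series_length / batch_size / truncated_backprop_length = 50000 / 5 / 15 ≈ 666.7, so as.integer truncates num_batches to 666. Each epoch therefore consumes 666 × 15 = 9990 of the 10000 columns per batch row, and the progress messages in the training output below stop at step 600.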

Generate data

In [2]:
generateData <- function(){
    # x: random binary series; y: x delayed, zero-padded at the front.
    # The negative indices drop the last echo_step + 1 elements of x,
    # so the delay implemented here is echo_step + 1 positions.
    x <- rbinom(total_series_length, 1, 0.5)
    y <- c(rep(0, echo_step + 1), x[(-length(x) + echo_step):(-length(x))])
    
    # reshape into batch_size rows; note that R fills matrices column-major
    dim(x) <- c(batch_size, total_series_length / batch_size)
    dim(y) <- c(batch_size, total_series_length / batch_size)
    
    return(list(x, y))
}
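
Two porting details are worth flagging here. In the NumPy original, reshape fills row-major, so each batch row is a contiguous stretch of the series, and np.roll delays x by exactly echo_step. The R code above fills column-major (consecutive time steps within a row end up batch_size positions apart in the original series) and delays by echo_step + 1. A sketch of a variant that mirrors the NumPy behaviour (generateDataRowMajor is a name introduced here, not from the original post):

# Sketch: keep each batch row contiguous and delay by exactly echo_step,
# mirroring np.roll + row-major reshape in Hallstrom's Python original.
generateDataRowMajor <- function(){
    x <- rbinom(total_series_length, 1, 0.5)
    y <- c(rep(0, echo_step), x[1:(total_series_length - echo_step)])
    
    x <- matrix(x, nrow = batch_size, byrow = TRUE)
    y <- matrix(y, nrow = batch_size, byrow = TRUE)
    
    return(list(x, y))
}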

Building the computational graph

Variables and placeholders

In [3]:
# one truncated-backprop window of inputs and labels, plus the hidden
# state carried over from the previous window
batchX_placeholder <- tf$placeholder(
    tf$float32, as.integer(c(batch_size, truncated_backprop_length)))
batchY_placeholder <- tf$placeholder(
    tf$int32, as.integer(c(batch_size, truncated_backprop_length)))

init_state <- tf$placeholder(
    tf$float32, as.integer(c(batch_size, state_size)))

In [4]:
# W, b: RNN cell weights acting on the concatenated [input, state] vector
W <- tf$Variable(matrix(runif((state_size + 1) * state_size), state_size + 1), dtype = tf$float32)
b <- tf$Variable(t(rep(0, state_size)), dtype = tf$float32)

# W2, b2: output layer projecting the state onto the two classes
W2 <- tf$Variable(matrix(runif(state_size * num_classes), state_size), dtype = tf$float32)
b2 <- tf$Variable(t(rep(0, num_classes)), dtype = tf$float32)

Unpacking

In [5]:
# split the (batch_size x truncated_backprop_length) placeholders into
# lists of truncated_backprop_length tensors, one per time step
inputs_series <- tf$unstack(batchX_placeholder, axis = 1L)
labels_series <- tf$unstack(batchY_placeholder, axis = 1L)

Forward pass

In [6]:
current_state <- init_state
states_series <- list()
for(current_input in inputs_series){
    # make the scalar input a (batch_size x 1) column and append the previous state
    current_input <- tf$reshape(current_input, as.integer(c(batch_size, 1)))
    input_and_state_concatenated <- tf$concat(list(current_input, current_state), 1L)
    
    # vanilla RNN cell: next state from the concatenated [input, state]
    next_state <- tf$tanh(tf$matmul(input_and_state_concatenated, W) + b)
    states_series <- c(states_series, list(next_state))
    current_state <- next_state
}
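
Each pass through this loop adds one step of the vanilla RNN recurrence: with x_t the (batch_size x 1) input column for step t and s_{t-1} the previous state, the update is

    s_t = tanh( cbind(x_t, s_{t-1}) %*% W + b )

which is why W needs state_size + 1 rows: one for the scalar input plus state_size for the carried state. states_series collects all truncated_backprop_length states so that a loss can be attached to every unrolled step.
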
In [7]:
logits_series <- list()
predictions_series <- list()
losses <- list()
for(i in seq_along(states_series)){
    # project each hidden state onto the two classes
    state <- states_series[[i]]
    logits_series[[i]] <- tf$matmul(state, W2) + b2
    
    logits <- logits_series[[i]]
    predictions_series[[i]] <- tf$nn$softmax(logits)
    
    # per-time-step cross-entropy against the integer labels
    labels <- labels_series[[i]]
    losses[[i]] <- tf$nn$sparse_softmax_cross_entropy_with_logits(labels = labels, logits = logits)
}

# average the per-step losses over all time steps and batch entries,
# then minimize with Adagrad
total_loss <- tf$reduce_mean(losses)

train_step <- tf$train$AdagradOptimizer(0.3)$minimize(total_loss)
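
For each unrolled step the logits are s_t %*% W2 + b2, the predictions softmax(logits), and the per-step loss the sparse cross-entropy between the logits and the integer labels; total_loss averages this over all truncated_backprop_length steps and the batch. For two balanced classes a model that has learned nothing sits at -log(0.5) = ln 2 ≈ 0.693, a baseline worth keeping in mind when reading the training output below.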

Running a training session

In [8]:
# TF1-style execution: open a session and initialize the variables
sess <- tf$Session()

sess$run(tf$global_variables_initializer())
loss_list <- NULL
In [9]:
for(epoch_idx in 1:num_epochs){
    temp <- generateData()
    x <- temp[[1]] ; y <- temp[[2]]
    temp_current_state <- matrix(0, nrow = batch_size, ncol = state_size)
    
    cat("New data, epoch",epoch_idx,"\n")
    
    for(batch_idx in 1:num_batches){
        start_idx <- (batch_idx-1) * truncated_backprop_length +1
        end_idx <- start_idx + truncated_backprop_length -1
        
        batchX <- x[,start_idx:end_idx]
        batchY <- y[,start_idx:end_idx]
        
        temp_total_loss <- sess$run(total_loss,
                                    feed_dict = dict(batchX_placeholder = batchX,
                                                     batchY_placeholder = batchY,
                                                     init_state = temp_current_state))
        temp_train_step <- sess$run(train_step,
                                    feed_dict = dict(batchX_placeholder = batchX,
                                                     batchY_placeholder = batchY,
                                                     init_state = temp_current_state))
        temp_current_state <- sess$run(current_state,
                                       feed_dict = dict(batchX_placeholder = batchX,
                                                        batchY_placeholder = batchY,
                                                        init_state = temp_current_state))
        temp_predictions_series <- sess$run(predictions_series,
                                            feed_dict = dict(batchX_placeholder = batchX,
                                                             batchY_placeholder = batchY,
                                                             init_state = temp_current_state))
        
        loss_list <- c(loss_list , temp_total_loss)
        
        if(batch_idx%%100 == 0){
            cat("step", batch_idx, "Loss", temp_total_loss, "\n")
        }
    }
}
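
A side note on the training loop: it calls sess$run() four times per batch, so the graph is executed four times, and total_loss is measured on a different forward pass than the one train_step actually updates. A minimal sketch of fetching everything in a single call (assuming the R bindings accept a list of fetches, as the Python Session$run does):

run_out <- sess$run(
    list(total_loss, train_step, current_state, predictions_series),
    feed_dict = dict(batchX_placeholder = batchX,
                     batchY_placeholder = batchY,
                     init_state = temp_current_state))
temp_total_loss         <- run_out[[1]]   # run_out[[2]] is train_step's NULL result
temp_current_state      <- run_out[[3]]
temp_predictions_series <- run_out[[4]]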

New data, epoch 1
step 100 Loss 0.6923453
step 200 Loss 0.6969035
step 300 Loss 0.6889008
step 400 Loss 0.7145686
step 500 Loss 0.7020864
step 600 Loss 0.6930443
New data, epoch 2
step 100 Loss 0.6960838
step 200 Loss 0.686676
step 300 Loss 0.6962051
step 400 Loss 0.6945438
step 500 Loss 0.6925812
step 600 Loss 0.696036
...
New data, epoch 100
step 100 Loss 0.6929061
step 200 Loss 0.6938135
step 300 Loss 0.6919
step 400 Loss 0.6934645
step 500 Loss 0.6927102
step 600 Loss 0.6929337

What if the input is all zeros?

The loss above never moves away from ln 2 ≈ 0.693, which is chance level for two balanced classes: as noted after generateData, the column-major reshape puts consecutive time steps of a batch row batch_size positions apart in the original series, so the echo relationship is invisible within a row. As a sanity check that the graph and optimizer can fit anything at all, feed a degenerate task where input and target are constant zero.

In [10]:
generateData2 <- function(){
    # degenerate data: both input and target are constant zero
    x <- y <- rep(0, total_series_length)
    
    dim(x) <- c(batch_size, total_series_length / batch_size)
    dim(y) <- c(batch_size, total_series_length / batch_size)
    
    return(list(x, y))
}
In [11]:
for(epoch_idx in 1:num_epochs){
    temp <- generateData2()
    x <- temp[[1]] ; y <- temp[[2]]
    temp_current_state <- matrix(0, nrow = batch_size, ncol = state_size)
    
    cat("New data, epoch",epoch_idx,"\n")
    
    for(batch_idx in 1:num_batches){
        start_idx <- (batch_idx-1) * truncated_backprop_length +1
        end_idx <- start_idx + truncated_backprop_length -1
        
        batchX <- x[,start_idx:end_idx]
        batchY <- y[,start_idx:end_idx]
        
        temp_total_loss <- sess$run(total_loss,
                                    feed_dict = dict(batchX_placeholder = batchX,
                                                     batchY_placeholder = batchY,
                                                     init_state = temp_current_state))
        temp_train_step <- sess$run(train_step,
                                    feed_dict = dict(batchX_placeholder = batchX,
                                                     batchY_placeholder = batchY,
                                                     init_state = temp_current_state))
        temp_current_state <- sess$run(current_state,
                                       feed_dict = dict(batchX_placeholder = batchX,
                                                        batchY_placeholder = batchY,
                                                        init_state = temp_current_state))
        temp_predictions_series <- sess$run(predictions_series,
                                            feed_dict = dict(batchX_placeholder = batchX,
                                                             batchY_placeholder = batchY,
                                                             init_state = temp_current_state))
        
        loss_list <- c(loss_list , temp_total_loss)
        
        if(batch_idx%%100 == 0){
            cat("step", batch_idx, "Loss", temp_total_loss, "\n")
        }
    }
}

New data, epoch 1
step 100 Loss 0.01274876
step 200 Loss 0.005752675
step 300 Loss 0.003710649
step 400 Loss 0.002737705
step 500 Loss 0.002168686
step 600 Loss 0.001795348
New data, epoch 2
step 100 Loss 0.001393896
step 200 Loss 0.001229484
step 300 Loss 0.001099816
step 400 Loss 0.0009947831
step 500 Loss 0.0009080815
step 600 Loss 0.000835308
...
New data, epoch 100
step 100 Loss 1.56163e-05
step 200 Loss 1.56163e-05
step 300 Loss 1.549709e-05
step 400 Loss 1.549709e-05
step 500 Loss 1.549709e-05
step 600 Loss 1.549709e-05

This time the loss collapses immediately: with constant zero targets the network only has to predict class 0, so the training machinery clearly works. The chance-level result on the echo task therefore points at the data layout, not at the graph or the optimizer.

