티스토리 뷰
단일신경망 Single Layer Neural Network (ver. python)
딥스탯 2017. 6. 24. 17:31출처¶
http://jorditorres.org/first-contact-with-tensorflow/#cap4 (First Contact with tensorflow)
단일신경망 Single Layer Neural Network (ver. python)¶
The MNIST data-set¶
숫자 손글씨(hand-written digits)에 관한 엄청 유명한 데이터 셋.
training set으로 6만개 이상, test set으로 1만개이다.
흑백사진으로 이루어져있고, anti-aliasing 돼있다. 전 처리가 다 돼있기 때문에, 패턴인식을 시작하는 사람들에게 이상적이라고 한다.
supervised learning의 한 예이다.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("~/MNIST_data/", one_hot=True)
모형에 대한 자세한 설명은 생략하도록 하겠습니다.¶
Neural Network, activation function (softmax), loss function (cross-entropy), optimizer (gradient descent, batchsize) etc..
Single layer neural network¶
변수 지정
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
모형 설정
Activation function : softmax
matm = tf.matmul(x,W)
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None, 10])
Loss function : cross-entropy
cross_entropy = tf.reduce_sum(y_ * tf.log(y))
Optimizer : Gradient Descent (learning_rate = 0.01)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
최적화 반복
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(101):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict = {x:batch_xs, y_:batch_ys})
if i in [j * 5 for j in range(21)]:
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(i, sess.run(accuracy, feed_dict = {x: mnist.test.images, y_: mnist.test.labels}))
batch size = 100개로 하고, 계속 Gradient Descent를 반복하면서,
5의 배수일 때마다 test set의 정확도를 계산해서 출력한다.
정확도가 기껏해서 0.098로, 성능 안 좋다는 것을 알 수 있다.
같이보기¶
http://deepstat.tistory.com/8 (단일신경망 Single Layer Neural Network (ver.python)
'Tensorflow > Tensorflow for Python' 카테고리의 다른 글
Tensorflow-GPU 설치 on Ubuntu16.04 (0) | 2018.01.22 |
---|---|
Multilayer Perceptron (ver.python) (편집 예정) (0) | 2017.09.30 |
Convolutional Neural Network (ver. python) (0) | 2017.06.25 |
군집화 k-means Clustering (ver.Python) (0) | 2017.06.20 |
선형 회귀분석 Linear regression (ver.Python) (0) | 2017.06.06 |