Linear regression (ver. Flux in Julia)

딥스탯 (DeepStat) 2018. 11. 25. 21:28

References

http://jorditorres.org/first-contact-with-tensorflow/#cap2 (First Contact with TensorFlow)

http://fluxml.ai/ (Flux: The Elegant Machine Learning Stack)

https://deepstat.tistory.com/5 (Linear regression, ver. TensorFlow for Python)

https://deepstat.tistory.com/6 (Linear regression, ver. TensorFlow for R)

linear regression (ver. Julia)

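1. Generating the data points

Generate 1000 points with x ~ N(0, 1) and y = 0.1x + 0.3 + ε, where the noise ε ~ N(0, 0.03) (standard deviation √0.03 ≈ 0.17).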

In [1]:
num_point = 1000
x_data = randn(num_point)                                      # x ~ N(0, 1)
y_data = x_data * 0.1 .+ 0.3 + sqrt(0.03) * randn(num_point);  # y = 0.1x + 0.3 + noise
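As a reference point, the exact least-squares estimates for data generated this way can be computed in closed form with Julia's backslash operator, which solves an overdetermined system in the least-squares sense. This is a minimal sketch; the fixed seed via the `Random` standard library is an assumption added for reproducibility (the cell above is unseeded, so its numbers will differ slightly).

```julia
using Random

Random.seed!(1234)                  # assumed seed, for reproducibility only
num_point = 1000
x_data = randn(num_point)
y_data = x_data * 0.1 .+ 0.3 + sqrt(0.03) * randn(num_point)

# Design matrix: a column of ones for the intercept b, a column of x for the slope W
X = [ones(num_point) x_data]

# Least-squares solution of X * [b, W] ≈ y_data
b_ols, W_ols = X \ y_data
println("W = ", W_ols, "  b = ", b_ols)
```

These are the values the gradient-descent loop in section 2 should approach; with 1000 points and noise of standard deviation about 0.17, they should land close to the true values 0.1 and 0.3.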

Code for plotting the data

In [2]:
using Plots
In [3]:
scatter(x_data, y_data, alpha = 0.25)
Out[3]:
[Figure: scatter plot of y_data against x_data, legend "y1"]

2. Cost function and gradient descent

Load the Flux package

In [4]:
using Flux

Specify the parameters and the model

In [5]:
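# Slope W and intercept b, initialised at random; param() marks them for
# gradient tracking (API of the Flux version used in this post)
W = param(rand(1))
b = param(rand(1))

yhat(x) = W .* x .+ b   # use the argument x, not the global x_data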
Out[5]:
yhat (generic function with 1 method)

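Specify the cost (loss) function. The code below computes the sum of squared errors (SSE), not the mean squared error (MSE); the two differ only by the factor 1/n, so they have the same minimiser.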

In [6]:
# Sum of squared errors between the targets y and the predictions yhat(x)
function loss(x, y)
    sum((y .- yhat(x)).^2)
end
Out[6]:
loss (generic function with 1 method)
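Written out by hand, the gradients that `Flux.train!` computes automatically for this loss are as follows, with n = 1000 points and the same W and b as the model above. The last line shows why the small step size 0.0002 works: on the sum of squared errors it is equivalent to a step of nη = 0.2 on the MSE.

```latex
L(W, b) = \sum_{i=1}^{n} \bigl(y_i - W x_i - b\bigr)^2

\frac{\partial L}{\partial W} = -2 \sum_{i=1}^{n} x_i \bigl(y_i - W x_i - b\bigr),
\qquad
\frac{\partial L}{\partial b} = -2 \sum_{i=1}^{n} \bigl(y_i - W x_i - b\bigr)

\mathrm{MSE}(W, b) = \frac{1}{n} L(W, b)
\quad\Longrightarrow\quad
\eta \, \nabla L = (n \eta) \, \nabla \mathrm{MSE}
```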

Optimization (GD: gradient descent algorithm)

In [7]:
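data = [(x_data, y_data)]   # one batch containing all 1000 points

# Each train! call makes one gradient step on the whole batch, i.e. full-batch
# gradient descent with step size 0.0002 (SGD(params, η) is the late-2018 Flux API)
for i = 0:9
    Flux.train!(loss, data, SGD(params(W, b), 0.0002))
    println(i, ".  W = ", W[1], "  b = ", b[1])
end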
0.  W = 0.05462474597510691 (tracked)  b = 0.1366739953471074 (tracked)
1.  W = 0.07265082938365555 (tracked)  b = 0.2030295733909739 (tracked)
2.  W = 0.08369064180104697 (tracked)  b = 0.24286431174510387 (tracked)
3.  W = 0.09045086040144956 (tracked)  b = 0.2667782556874395 (tracked)
4.  W = 0.09458990559617989 (tracked)  b = 0.28113464439473934 (tracked)
5.  W = 0.09712375953000592 (tracked)  b = 0.2897533894182263 (tracked)
6.  W = 0.09867474060772886 (tracked)  b = 0.2949276433532155 (tracked)
7.  W = 0.09962398146088063 (tracked)  b = 0.29803403626127095 (tracked)
8.  W = 0.10020486995015893 (tracked)  b = 0.2998989984688887 (tracked)
9.  W = 0.10056030227998272 (tracked)  b = 0.30101866513301195 (tracked)
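To make explicit what each pass of the loop above does, the same full-batch update can be written by hand in plain Julia, without automatic differentiation. This is a sketch under assumptions: the function name `fit_gd` and the seed are hypothetical, and the gradients are the hand-derived ones for the sum of squared errors rather than anything taken from Flux.

```julia
using Random

# Full-batch gradient descent on the sum of squared errors (hypothetical helper).
# Returns the fitted slope W and intercept b.
function fit_gd(x, y; η = 0.0002, iters = 10, verbose = true)
    W, b = rand(), rand()                  # random initial values, as in the Flux version
    for i = 0:iters-1
        r = y .- (W .* x .+ b)             # residuals y_i - (W x_i + b)
        gW = -2 * sum(x .* r)              # ∂L/∂W
        gb = -2 * sum(r)                   # ∂L/∂b
        W -= η * gW                        # one gradient step over all points
        b -= η * gb
        verbose && println(i, ".  W = ", W, "  b = ", b)
    end
    return W, b
end

Random.seed!(1234)                         # assumed seed, for reproducibility only
num_point = 1000
x_data = randn(num_point)
y_data = x_data * 0.1 .+ 0.3 + sqrt(0.03) * randn(num_point)

W, b = fit_gd(x_data, y_data)
```

With η = 0.0002, n = 1000 and x of roughly unit variance, each step shrinks the remaining error in W and b by a factor of about 1 − 2nη = 0.6, consistent with the gaps between successive values in the output above.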