This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def step(sum): | |
if sum>=0: | |
return 1 | |
else: | |
return -1 | |
def prediction(w, x): | |
sum = np.dot(w, x) | |
return step(sum) | |
def compare(y, l): | |
if y == l: | |
return True | |
else: | |
return False | |
def trainWeight(w, x, p, l): | |
new_weight = list() | |
for _w, _x in zip(w, x): | |
new_weight.append(_w + p * _x * l ) | |
return np.array(new_weight) | |
def train(x, w, p, l): | |
for _x, label in zip(x, l): | |
prediction_res = prediction(w, _x) | |
if compare(prediction_res, label) is False: | |
w = trainWeight(w, _x, p, label) | |
return w | |
# 学習用データ | |
x = np.array( [[1, 8], [1, 2], [1, 5], [1, 3], [1,6]] ) | |
# 重み | |
w = np.array( [0, 0] ) | |
# ラベル | |
l = np.array( [1, -1, 1, -1, 1]) | |
# 学習係数 | |
p = 1 | |
# エポック回数 | |
epoch = 10 | |
for i in range(epoch): | |
print ("epoc%d start w:%s" % (i+1, w) ) | |
w=train(x, w, p, l) | |
print ("epoc%d end w:%s" % (i+1, w) ) | |
for i in range(len(x)): | |
print ( compare(prediction(w, x[i]), l[i]) ) |
正しいかは未保証.グラフ書いたら分かるかもねー
とりあえず,何かしら学習が進んで分類されてる感じはする.
クラスが3つ以上の奴とか,ロジスティック回帰のアルゴリズムも勉強したい.
(数学的に深く理解したいとは言っていない)
クラスが3つ以上の奴とか,ロジスティック回帰のアルゴリズムも勉強したい.
(数学的に深く理解したいとは言っていない)
0 件のコメント:
コメントを投稿