誤り訂正法のパーセプトロン
『わかりやすいパターン認識』の2章で説明されているパーセプトロンをPythonで実装しました。説明はここでは省きます。
#!/usr/bin/env python2.7 from matplotlib import pyplot from matplotlib.path import Path import matplotlib.patches as patches fig = pyplot.figure() ax = fig.add_subplot(111) def lineTo(w0, w, c): verts = [ (w0[1], w0[0]), (w[1] , w[0]), ] codes = [ Path.MOVETO, Path.LINETO, ] path = Path(verts, codes) patch = patches.PathPatch(path, edgecolor= c, lw=2) ax.add_patch(patch) def draw(w, c): pyplot.scatter(w[1], w[0], c=c) def g(w, x): v = 0 for i, e in enumerate(x): v += w[i] * e return v def train(rho, w, data, c): trained = False n = len(data) while not trained: for i in range(n): x = data[i]["x"] t = data[i]["t"] r = g(w, x) if t * r < 0: w0 = w[:] w = [w[i] + t*rho * v for i, v in enumerate(x)] lineTo(w0, w, c) draw(w, c) break if i == n - 1: trained = True return w def valid(data, w): for e in data: x = e["x"] t = e["t"] r = g(w, x) if t * r < 0: return False return True if __name__ == '__main__': trained = False w = [11, 5] rho = 2.0 data = [ {"t": 1, "x":[1, 1.2]}, {"t": 1, "x":[1, 0.2]}, {"t": 1, "x":[1, -0.2]}, {"t": -1,"x":[1, -0.5]}, {"t": -1, "x":[1, -1.0]}, {"t": -1, "x":[1, -1.5]} ] c = "r" rho, w0 = 2.0, [11, 5] draw(w0, c) w = train(rho, w0, data, c) if valid(data, w): print "rho=%f, w0 = %s, w = %s, (%f)" % (rho, w0, w, - w[0]/w[1]) rho, w0 = 3.6, [-7, 2] c = "b" draw(w0, c) w = train(rho, w0, data, c) if valid(data, w): print "rho=%f, w0 = %s, w = %s, (%f)" % (rho, w0, w, - w[0]/w[1]) rho, w0 = 1.2, [-7, 2] c = "g" draw(w0, c) w = train(rho, w0, data, c) if valid(data, w): print "rho=%f, w0 = %s, w = %s, (%f)" % (rho, w0, w, - w[0]/w[1]) pyplot.xlabel("w1") pyplot.ylabel("w0") pyplot.show()
赤がrho=2.0、青がrho=3.6、緑がrho=1.2の場合の重みベクトルの移動の様子を示しています。