誤り訂正法のパーセプトロン

『わかりやすいパターン認識』の2章で説明されているパーセプトロンをPythonで実装しました。説明はここでは省きます。

#!/usr/bin/env python2.7
from matplotlib import pyplot
from matplotlib.path import Path
import matplotlib.patches as patches
# Single shared figure; `ax` is the module-level axes that lineTo() adds
# its patches to.
fig = pyplot.figure()
# 1x1 grid, first (only) cell -- the three-integer spelling of 111.
ax = fig.add_subplot(1, 1, 1)
def lineTo(w0, w, c):
  """Add a straight segment from weight vector w0 to w on the shared axes.

  Each weight vector is plotted with w[1] as the horizontal coordinate and
  w[0] as the vertical coordinate.  The segment is added to the
  module-level `ax` with edge color `c` and line width 2.
  """
  start = (w0[1], w0[0])
  end = (w[1], w[0])
  segment = Path([start, end], [Path.MOVETO, Path.LINETO])
  ax.add_patch(patches.PathPatch(segment, edgecolor=c, lw=2))
def draw(w, c):
  """Scatter-plot the weight vector w as one point in color c.

  The point is placed at horizontal coordinate w[1] and vertical
  coordinate w[0].
  """
  horizontal, vertical = w[1], w[0]
  pyplot.scatter(horizontal, vertical, c=c)
def g(w, x):
  """Return the sum of w[k] * x[k] over every index k of x.

  Only the first len(x) entries of w are read; a w shorter than x raises
  IndexError.  An empty x yields 0.
  """
  return sum(w[k] * component for k, component in enumerate(x))
def train(rho, w, data, c):
  """Train a perceptron with the error-correction rule, plotting each step.

  `data` is scanned from the start.  On the first sample for which
  t * g(w, x) < 0, the weights are replaced by w + t * rho * x, the move is
  drawn with lineTo() and draw(), and the scan restarts from the first
  sample.  Training ends once a complete pass finds no such sample.

  Args:
    rho: learning rate multiplied into every correction.
    w: initial weight list; it is never modified in place.
    data: sequence of dicts with keys "x" (feature list) and "t" (label).
    c: color handed through to the plotting helpers.

  Returns:
    The final weight list.  This is `w` itself when no correction was
    needed, including when `data` is empty.

  Note:
    There is no iteration limit: the loop only exits after a pass with no
    correction, so it runs forever if such a pass never happens.
  """
  while True:
    for sample in data:
      x = sample["x"]
      t = sample["t"]
      if t * g(w, x) < 0:
        previous = w
        # Use an index name of its own: on Python 2 a list-comprehension
        # variable leaks into the enclosing scope and would clobber any
        # outer loop index of the same name.
        w = [w[k] + t * rho * v for k, v in enumerate(x)]
        lineTo(previous, w, c)
        draw(w, c)
        # Weights changed, so earlier samples must be re-checked.
        break
    else:
      # The pass completed without a break: nothing was corrected.
      # Deciding this via for/else (instead of inspecting the last loop
      # index) stays correct when the final sample triggered the update
      # and when `data` is empty.
      return w
def valid(data, w):
  """Return True when no sample in data satisfies t * g(w, x) < 0.

  Each element of `data` is a dict with keys "x" and "t".  Evaluation
  stops at the first offending sample.  Empty `data` yields True.
  """
  return not any(sample["t"] * g(w, sample["x"]) < 0 for sample in data)
if __name__ == '__main__':
  # Training set: "t" is the class label, "x" is the feature list whose
  # first component is the constant 1 paired with the weight w[0].
  data = [
    {"t": 1, "x":[1,  1.2]},
    {"t": 1, "x":[1,  0.2]},
    {"t": 1, "x":[1, -0.2]},
    {"t": -1,"x":[1, -0.5]},
    {"t": -1, "x":[1, -1.0]},
    {"t": -1, "x":[1, -1.5]}
  ]
  # One experiment per row: (learning rate, initial weights, plot color).
  experiments = [
    (2.0, [11, 5], "r"),
    (3.6, [-7, 2], "b"),
    (1.2, [-7, 2], "g"),
  ]
  for rho, w0, c in experiments:
    draw(w0, c)
    w = train(rho, w0, data, c)
    if valid(data, w):
      # Single parenthesised argument: prints the same text under the
      # Python 2 print statement and the Python 3 print function.
      # The last field is -w[0]/w[1], the x value at which
      # w[0] + w[1] * x equals zero.
      print("rho=%f, w0 = %s, w = %s, (%f)" % (rho, w0, w, - w[0]/w[1]))
  pyplot.xlabel("w1")
  pyplot.ylabel("w0")
  pyplot.show()

赤がrho=2.0、青がrho=3.6、緑がrho=1.2の場合の重みベクトルの移動の様子を示しています。
[f:id:nabeyang:20130402170550p:plain]