NFAベースの正規表現エンジン(Python)

lexerを書いていないので、全体で100行ちょっとです。そのため入力には、正規表現をあらかじめ後置記法(postfix)に変換したものを渡す必要があります。たとえば'a|b'なら'ab|'です。また、文字の連結には演算子'.'を明示的に使います。なので'abc'は'ab.c.'、'ab*c'は'ab*.c.'となります。対応している演算子は'.'、'|'、'*'、'+'、'?'の5つで、これらの文字自体をリテラルとしてマッチさせる方法(エスケープ)はありません。

class State(object):
    """Base NFA state; a plain instance is a terminal state with no edges.

    A bare ``State`` consumes no input (``transfer`` adds nothing), which is
    why the module-level ``SuccessState`` instance defined below can serve as
    the accepting state that ``NFA.match`` looks for.
    """

    def transfer(self, c, nlist):
        """Add to ``nlist`` the states reached by consuming character ``c``.

        The base state has no outgoing edges, so nothing is added.
        """
        pass

    def appendTo(self, nlist):
        """Add this state to the active-state list ``nlist`` unless already present.

        Without the membership check a state reachable along several NFA paths
        is appended once per path; for ``(a|a)*`` (postfix ``'aa|*'``) the list
        doubles on every input character.
        """
        # `in` falls back to identity here: states define no __eq__.
        if self not in nlist:
            nlist.append(self)


# The single accepting state: NFA.from_post wires every dangling exit to it.
SuccessState = State()
class CharState(State):
    """State that consumes one specific literal character.

    ``c`` is the character this state accepts; ``out`` is its successor and
    stays None until ``connect`` wires it up.
    """

    def __init__(self, c):
        self.c, self.out = c, None

    def transfer(self, c, nlist):
        """If ``c`` is this state's character, hand the successor to ``appendTo`` on nlist."""
        successor = self.out
        # Nothing to do for a different character or an unconnected state.
        if successor is None or self.c != c:
            return
        successor.appendTo(nlist)

    def connect(self, s):
        """Make ``s`` the successor of this state."""
        self.out = s
class RegExpError(Exception):
    """Error type raised by this regular-expression engine."""
class SplitState(State):
    """Epsilon state that forks into two successors without consuming input.

    ``first`` is the preferred (greedy) branch and ``second`` the alternative;
    ``connect`` fills in ``second``, so ``first`` must be set before that.
    """

    def __init__(self):
        self.first = None
        self.second = None

    def transfer(self, c, nlist):
        # Split states are expanded by appendTo and never placed in a state
        # list, so being asked to consume a character is an internal error.
        raise RegExpError()

    def appendTo(self, nlist):
        """Append every non-split state reachable from here via epsilon edges.

        Uses an explicit stack plus a visited set instead of plain recursion:
        epsilon cycles such as those built for ``(a?)*`` (postfix ``'a?*'``)
        or ``(a*)*`` (``'a**'``) would otherwise recurse forever, and split
        sub-graphs shared by several paths would be re-expanded repeatedly.
        Reachable states are visited in the same order as a recursive walk:
        the whole ``first`` branch, then ``second``.
        """
        seen = set()
        pending = [self]
        while pending:
            state = pending.pop()
            if isinstance(state, SplitState):
                if state in seen:
                    continue
                seen.add(state)
                # LIFO: push `second` first so `first` is expanded first.
                pending.append(state.second)
                pending.append(state.first)
            else:
                state.appendTo(nlist)

    def connect(self, s):
        # `first` (the greedy branch) is wired when the split is created;
        # connecting a dangling exit fills in the fall-through `second`.
        assert(self.first is not None)
        self.second = s
class Fragment(object):
    """A partially built NFA piece: an entry state plus its dangling exits.

    ``start`` is where the fragment is entered; ``out`` is the list of states
    whose exit is not wired yet (each must provide ``connect``).
    """

    def __init__(self, start, out):
        self.start, self.out = start, out

    def patch(self, outer):
        """Concatenate: wire our exits into ``outer`` and return the joined fragment."""
        self.connect(outer.start)
        joined = Fragment(self.start, outer.out)
        return joined

    def connect(self, state):
        """Attach every dangling exit of this fragment to ``state``."""
        for dangling in self.out:
            dangling.connect(state)
class NFA(object):
    """Nondeterministic finite automaton compiled from a postfix regex.

    ``start`` is the entry state. ``match`` tracks the list of currently
    active states and advances all of them together for each character.
    """

    def __init__(self, start):
        self.start = start

    def match(self, string):
        """Return True if the whole of ``string`` (not a substring) is accepted.

        ``clist`` holds the states active before each character; each of them
        adds the states it reaches on that character to a fresh ``nlist``.
        """
        clist = []
        self.start.appendTo(clist)
        for c in string:
            if not clist:
                # No live states left: the rest of the input cannot match.
                return False
            nlist = []
            for s in clist:
                s.transfer(c, nlist)
            clist = nlist
        return SuccessState in clist

    @staticmethod
    def from_post(post):
        """Compile a postfix regular expression into an NFA.

        Operators pop their operand fragments from a stack:
          ``.`` concatenation, ``|`` alternation,
          ``*`` zero or more, ``+`` one or more, ``?`` zero or one.
        Any other character is a literal matching itself, e.g. ``'ab*.c.'``
        is the postfix form of ``ab*c``.

        Raises:
            RegExpError: if ``post`` is empty, an operator lacks operands, or
                operands are left over with no operator joining them.
        """
        stack = []

        def pop(op):
            # Report operand underflow as a pattern error rather than leaking
            # an IndexError from list.pop().
            if not stack:
                raise RegExpError('operator %r is missing an operand in %r'
                                  % (op, post))
            return stack.pop()

        for c in post:
            if c == '.':
                f2 = pop(c)
                f1 = pop(c)
                stack.append(f1.patch(f2))
            elif c == '|':
                f2 = pop(c)
                f1 = pop(c)
                s = SplitState()
                s.first = f1.start
                s.second = f2.start
                out = []
                out.extend(f1.out)
                out.extend(f2.out)
                stack.append(Fragment(s, out))
            elif c == '*':
                f = pop(c)
                s = SplitState()
                # Loop back: the fragment's exits return to the split.
                f.connect(s)
                s.first = f.start
                stack.append(Fragment(s, [s]))
            elif c == '+':
                f = pop(c)
                s = SplitState()
                f.connect(s)
                s.first = f.start
                # Unlike '*', entry is the fragment itself: at least one pass.
                stack.append(Fragment(f.start, [s]))
            elif c == '?':
                f = pop(c)
                s = SplitState()
                s.first = f.start
                # The split's own exit skips the fragment entirely.
                out = [s]
                out.extend(f.out)
                stack.append(Fragment(s, out))
            else:
                s = CharState(c)
                stack.append(Fragment(s, [s]))
        # An explicit check (not `assert`) so validation survives `python -O`.
        if len(stack) != 1:
            raise RegExpError('malformed postfix expression %r: expected 1 '
                              'fragment, got %d' % (post, len(stack)))
        f = stack.pop()
        # Wire every remaining dangling exit to the accepting state.
        f.connect(SuccessState)
        return NFA(f.start)

続いてテストコードです。

#!/usr/bin/env python2.6
import unittest
from regex import *
class RegexTests(unittest.TestCase):
    """Tests for the NFA state primitives and postfix-regex matching."""

    def _check(self, nfa, accepted=(), rejected=()):
        # Every string in `accepted` must match; none in `rejected` may.
        for text in accepted:
            self.assertTrue(nfa.match(text), 'expected %r to match' % (text,))
        for text in rejected:
            self.assertFalse(nfa.match(text), 'expected %r not to match' % (text,))

    def test_char_transfer(self):
        head = CharState('a')
        head.out = CharState('b')
        found = []
        # A different character reaches nothing.
        head.transfer('b', found)
        self.assertEqual([], found)
        # The matching character reaches the successor.
        head.transfer('a', found)
        self.assertEqual([head.out], found)
        # An unconnected state reaches nothing even on a match.
        found = []
        head.out.transfer('b', found)
        self.assertEqual([], found)

    def test_char_fragment(self):
        left, right = CharState('a'), CharState('b')
        frag_left = Fragment(left, [left])
        frag_right = Fragment(right, [right])
        self.assertEqual(None, left.out)
        joined = frag_left.patch(frag_right)
        self.assertEqual(right, left.out)
        self.assertEqual(left, joined.start)
        self.assertEqual([right], joined.out)

    def test_cat_match(self):
        self._check(NFA.from_post('ab.'),
                    accepted=['ab'],
                    rejected=['a', 'b', 'abc'])

    def test_split_appendTo(self):
        fork = SplitState()
        left, right = CharState('a'), CharState('b')
        fork.first, fork.second = left, right
        found = []
        fork.appendTo(found)
        self.assertEqual([left, right], found)

    def test_nfa_alt_match(self):
        self._check(NFA.from_post('ab|'),
                    accepted=['a', 'b'],
                    rejected=['', 'ab'])

    def test_nfa_star_match(self):
        self._check(NFA.from_post('a*'),
                    accepted=['', 'a', 'aa', 'aaaaa'],
                    rejected=['aab'])

    def test_nfa_plus_match(self):
        self._check(NFA.from_post('b+'),
                    accepted=['b', 'bbbbb'],
                    rejected=['', 'bbba'])

    def test_nfa_quest_match(self):
        self._check(NFA.from_post('a?'),
                    accepted=['', 'a'],
                    rejected=['aa', 'abc'])

    def test_nfa_mixed_case(self):
        self._check(NFA.from_post('a?b+.c*.'),
                    accepted=['b', 'abc', 'abbc', 'abbccc'],
                    rejected=['aabc', 'ac', ''])

    def test_nfa_complex_case(self):
        pieces = [
            '0r|',           # 0|r
            '20*.1r|.',      # 20*(1|r)
            '120*.2.|0*.',   # (1|20*2)0*
            '2r|10*.1r|.|',  # (2|r)|(10*(1|r))
        ]
        # Concatenate the last two pieces, alternate all branches, then star.
        prog = NFA.from_post(''.join(pieces) + '.||*')
        self._check(prog,
                    accepted=['', '111', '21', '2100011r', 'r', 'r00',
                              '20002101'],
                    rejected=['11'])
if __name__ == "__main__":
    # Run the RegexTests suite when this file is executed as a script.
    unittest.main()