DFAベース正規表現エンジンを追加(Python)

前回のものにDFAを構成しながらマッチングするクラスを追加しました。

class State(object):
    """Base NFA state.

    The base class has no outgoing transition: ``transfer`` does nothing,
    so a matcher that is sitting on a plain ``State`` cannot advance on any
    further character.
    """

    def transfer(self, c, nlist):
        """Follow the transition on character ``c``; the base state has none."""
        pass

    def appendTo(self, nlist):
        """Add this state to the list ``nlist`` unless it is already there.

        A state reachable along several paths would otherwise be appended
        once per path.  Those duplicates are then stepped again on the next
        character, so for a pattern such as ``aa|*`` the list would double
        in length with every input character.
        """
        if self not in nlist:
            nlist.append(self)
# Shared accepting state: the matchers report success when this object is
# among the states that are active after the whole input has been consumed.
SuccessState = State()
class CharState(State):
    """NFA state that consumes one specific character."""

    def __init__(self, c):
        self.out = None  # successor state; stays None until connect() is called
        self.c = c       # the character this state accepts

    def connect(self, s):
        """Make ``s`` the state reached after the character is consumed."""
        self.out = s

    def transfer(self, c, nlist):
        """Add the successor to ``nlist`` if ``c`` matches and one is set."""
        if not self.c == c:
            return
        if self.out is None:
            return
        self.out.appendTo(nlist)
class RegExpError(Exception):
    """Error raised by the regular-expression engine."""
class SplitState(State):
    """Branch state with two outgoing arrows that consume no input.

    A split state never appears in a state list itself: ``appendTo`` adds
    whatever its two branches add instead.  That is why ``transfer`` is
    treated as an error.
    """

    def __init__(self):
        self.first = None
        self.second = None
        # True only while appendTo() is expanding this state further up the
        # call stack; used to cut cycles made purely of split states.
        self._expanding = False

    def transfer(self, c, nlist):
        """Always raise ``RegExpError``; a split state must not be stepped."""
        raise RegExpError()

    def appendTo(self, nlist):
        """Let both branches add their states to ``nlist``.

        Nested closures such as ``a**`` or ``a?*`` produce split states that
        point back at each other without a character state in between.
        Re-entering a split state that is already being expanded cannot add
        anything new, so it is skipped; without this guard such patterns
        recurse until the interpreter's recursion limit is hit.
        """
        if self._expanding:
            return
        self._expanding = True
        try:
            self.first.appendTo(nlist)
            self.second.appendTo(nlist)
        finally:
            # Always clear the flag so the state can be expanded again for
            # the next input character, even if a branch raised.
            self._expanding = False

    def connect(self, s):
        """Attach ``s`` to the second branch; the first must already be set."""
        assert(self.first is not None)# greedy
        self.second = s
class Fragment(object):
    """Partially built NFA: an entry state plus its still-dangling exits.

    ``start`` is the entry state; ``out`` is an iterable of objects whose
    ``connect(state)`` method wires them to a successor.
    """

    def __init__(self, start, out):
        self.start, self.out = start, out

    def connect(self, state):
        """Wire every dangling exit in ``out`` to ``state``."""
        for dangling in self.out:
            dangling.connect(state)

    def patch(self, outer):
        """Concatenate with ``outer`` and return the combined fragment.

        The exits of this fragment are connected to ``outer.start``; the
        result starts where this fragment starts and ends where ``outer``
        ends.
        """
        self.connect(outer.start)
        joined = Fragment(self.start, outer.out)
        return joined
def from_post(post, FA):
    """Build an automaton from a postfix regular expression.

    ``post`` is read one character at a time.  ``.`` concatenates the two
    fragments on top of the stack, ``|`` alternates them, and ``*``, ``+``
    and ``?`` apply to the single fragment on top.  Any other character
    becomes a literal ``CharState``.

    Returns ``FA(start)``, where ``start`` is the entry state of the
    finished NFA whose dangling exits have been wired to ``SuccessState``.

    Raises ``RegExpError`` when ``post`` is not a well-formed postfix
    expression: an operator lacks an operand, the expression is empty, or
    operands are left over.  These cases used to surface as ``IndexError``
    or as an ``assert`` that disappears under ``python -O``, in which case
    leftover operands were silently dropped.
    """
    stack = []

    def pop_operand():
        # Report stack underflow as a regex error instead of IndexError.
        if not stack:
            raise RegExpError(
                'malformed postfix expression %r: operator is missing an operand'
                % (post,))
        return stack.pop()

    for c in post:
        if c == '.':
            # Concatenation: f1 followed by f2.
            f2 = pop_operand()
            f1 = pop_operand()
            stack.append(f1.patch(f2))
        elif c == '|':
            # Alternation: a split state entering either fragment.
            f2 = pop_operand()
            f1 = pop_operand()
            s = SplitState()
            s.first = f1.start
            s.second = f2.start
            out = []
            out.extend(f1.out)
            out.extend(f2.out)
            stack.append(Fragment(s, out))
        elif c == '*':
            # Zero or more: enter through the split, loop back to it.
            f = pop_operand()
            s = SplitState()
            f.connect(s)
            s.first = f.start
            stack.append(Fragment(s, [s]))
        elif c == '+':
            # One or more: enter through the fragment, loop via the split.
            f = pop_operand()
            s = SplitState()
            f.connect(s)
            s.first = f.start
            stack.append(Fragment(f.start, [s]))
        elif c == '?':
            # Zero or one: the split can bypass the fragment.
            f = pop_operand()
            s = SplitState()
            s.first = f.start
            out = [s]
            out.extend(f.out)
            stack.append(Fragment(s, out))
        else:
            s = CharState(c)
            stack.append(Fragment(s, [s]))
    # Exactly one fragment must remain; anything else means the expression
    # was empty or had operands that no operator consumed.
    if len(stack) != 1:
        raise RegExpError(
            'malformed postfix expression %r: expected exactly one fragment, got %d'
            % (post, len(stack)))
    f = stack.pop()
    f.connect(SuccessState)
    return FA(f.start)
class NFA(object):
    """Matcher that simulates the NFA directly, tracking all active states."""

    def __init__(self, start):
        self.start = start  # entry state of the NFA

    @staticmethod
    def from_post(post):
        """Build an ``NFA`` from the postfix expression ``post``."""
        return from_post(post, FA=NFA)

    def match(self, string):
        """Return True if the whole of ``string`` is accepted."""
        active = []
        self.start.appendTo(active)
        for ch in string:
            # Step every active state on ``ch`` into a fresh list.
            following = []
            for state in active:
                state.transfer(ch, following)
            active = following
        return SuccessState in active
class DFA(object):
    """Matcher that creates DFA states on demand while scanning the input."""

    def __init__(self, start):
        self.states = []    # every DState created so far
        self.start = start  # entry state of the underlying NFA

    @staticmethod
    def from_post(post):
        """Build a ``DFA`` from the postfix expression ``post``."""
        return from_post(post, FA=DFA)

    def find(self, nlist):
        """Return the known DState whose NFA-state set equals ``nlist``, or None."""
        wanted = set(nlist)
        return next((ds for ds in self.states if wanted == ds.nstates), None)

    def match(self, string):
        """Return True if the whole of ``string`` is accepted."""
        seed = []
        self.start.appendTo(seed)
        # Reuse the initial DState from an earlier call when there is one.
        current = self.find(seed)
        if current is None:
            current = DState(seed)
            self.states.append(current)
        for ch in string:
            # getnext() fills the scratch list only when the transition
            # has not been cached yet.
            current = current.getnext(ch, [], self)
        return SuccessState in current.nstates
class DState(object):
    """DFA state: a set of NFA states plus its cached outgoing transitions."""

    def __init__(self, nstates):
        self.next = {}               # character -> DState reached on it
        self.nstates = set(nstates)  # the NFA states this DFA state stands for

    def setnext(self, c, ns):
        """Cache ``ns`` as the state reached on character ``c``."""
        self.next[c] = ns

    def transfer(self, c, nlist):
        """Step every member NFA state on ``c``, collecting into ``nlist``."""
        for member in self.nstates:
            member.transfer(c, nlist)

    def getnext(self, c, nlist, dfa):
        """Return the DState reached on ``c``, creating and caching it if needed.

        ``nlist`` is the scratch list the member states are stepped into;
        ``dfa`` supplies ``find()`` and the ``states`` list that new states
        are registered in.
        """
        cached = self.next.get(c)
        if cached is not None:
            return cached
        self.transfer(c, nlist)
        target = dfa.find(nlist)
        if target is None:
            target = DState(nlist)
            dfa.states.append(target)
        self.setnext(c, target)
        return target

続いてテストです。DFAとNFAのマッチングについてのテストは、同じものを流用しています。

#!/usr/bin/env python2.6
import unittest
from test import test_support
from regex import *
class RegexTests(unittest.TestCase):
    # Unit tests for the individual state and fragment building blocks.

    def test_char_transfer(self):
        # A CharState only advances on its own character.
        tail = CharState('b')
        head = CharState('a')
        head.out = tail
        reached = []
        head.transfer('b', reached)
        self.assertEqual([], reached)
        head.transfer('a', reached)
        self.assertEqual([tail], reached)
        # The tail has no successor, so even its own character adds nothing.
        after_tail = []
        tail.transfer('b', after_tail)
        self.assertEqual([], after_tail)

    def test_char_fragment(self):
        # Patching two one-state fragments links the first state to the second.
        left_state = CharState('a')
        right_state = CharState('b')
        left = Fragment(left_state, [left_state])
        right = Fragment(right_state, [right_state])
        self.assertEqual(None, left_state.out)
        joined = left.patch(right)
        self.assertEqual(right_state, left_state.out)
        self.assertEqual(left_state, joined.start)
        self.assertEqual([right_state], joined.out)

    def test_cat_match(self):
        # 'ab.' is the postfix form of the concatenation "ab".
        nfa = NFA.from_post('ab.')
        self.assertTrue(nfa.match('ab'))
        for text in ('a', 'b', 'abc'):
            self.assertFalse(nfa.match(text))

    def test_split_appendTo(self):
        # A split state contributes its two branches, first branch first.
        branch_a = CharState('a')
        branch_b = CharState('b')
        split = SplitState()
        split.first = branch_a
        split.second = branch_b
        collected = []
        split.appendTo(collected)
        self.assertEqual([branch_a, branch_b], collected)
class MatchTests(unittest.TestCase):
    # Matching tests shared by every matcher; subclasses set ``FA`` to the
    # class under test.

    def _check(self, post, cases):
        # Compile ``post`` once, then verify each (text, expected) pair in order.
        prog = self.FA.from_post(post)
        for text, expected in cases:
            if expected:
                self.assertTrue(prog.match(text))
            else:
                self.assertFalse(prog.match(text))

    def test_alt_match(self):
        self._check('ab|', [
            ('', False),
            ('a', True),
            ('b', True),
            ('ab', False),
        ])

    def test_star_match(self):
        self._check('a*', [
            ('', True),
            ('a', True),
            ('aa', True),
            ('aaaaa', True),
            ('aab', False),
        ])

    def testplus_match(self):
        self._check('b+', [
            ('', False),
            ('b', True),
            ('bbbbb', True),
            ('bbba', False),
        ])

    def test_quest_match(self):
        self._check('a?', [
            ('', True),
            ('a', True),
            ('aa', False),
            ('abc', False),
        ])

    def test_mixed_case(self):
        self._check('a?b+.c*.', [
            ('b', True),
            ('abc', True),
            ('abbc', True),
            ('abbccc', True),
            ('aabc', False),
            ('ac', False),
            ('', False),
        ])

    def test_complex_case(self):
        # Postfix pieces, each with its infix reading alongside.
        pieces = (
            '0r|',           # 0|r
            '20*.1r|.',      # 20*(1|r)
            '120*.2.|0*.',   # (1|20*2)0*
            '2r|10*.1r|.|',  # (2|r)|(10*(1|r))
        )
        source = ''.join(pieces) + '.||*'
        self._check(source, [
            ('11', False),
            ('', True),
            ('111', True),
            ('21', True),
            ('2100011r', True),
            ('r', True),
            ('r00', True),
            ('20002101', True),
        ])
class NFAMatchTests(MatchTests):
    # Runs the shared MatchTests cases against the NFA matcher.
    FA = NFA
class DFAMatchTests(MatchTests):
    # Runs the shared MatchTests cases against the DFA matcher.
    FA = DFA
class DFATests(unittest.TestCase):
    # Tests specific to the lazily built DFA.  All of them use 'ab*|', the
    # postfix form of a|b*.

    def test_DState(self):
        nfa = NFA.from_post('ab*|')
        initial = []
        nfa.start.appendTo(initial)
        dstate = DState(initial)
        # The initial list holds the 'a' state, the 'b' state and the
        # accepting state, in that order.
        alt_split = nfa.start
        char_a = initial[0]
        char_b = initial[1]
        star_split = char_b.out
        accept = initial[2]
        on_a = []
        dstate.transfer('a', on_a)
        self.assertEqual([accept], on_a)
        on_b = []
        dstate.transfer('b', on_b)
        self.assertEqual([char_b, accept], on_b)

    def test_find(self):
        dfa = DFA.from_post('ab*|')
        initial = []
        dfa.start.appendTo(initial)
        registered = DState(initial)
        dfa.states.append(registered)
        alt_split = dfa.start
        char_a = initial[0]
        char_b = initial[1]
        star_split = char_b.out
        accept = initial[2]
        # Lookup compares sets, so the order of the list does not matter.
        self.assertEqual(registered, dfa.find([char_a, char_b, accept]))
        self.assertEqual(registered, dfa.find([char_b, char_a, accept]))
        self.assertEqual(None, dfa.find([alt_split, char_a]))

    def test_match_simple(self):
        dfa = DFA.from_post('ab*|')
        cases = [
            ('a', True),
            ('aa', False),
            ('ab', False),
            ('bb', True),
            ('', True),
            ('bba', False),
            ('bbbba', False),
        ]
        for text, expected in cases:
            if expected:
                self.assertTrue(dfa.match(text))
            else:
                self.assertFalse(dfa.match(text))
        # All of the matches above share one DFA and create four states in total.
        self.assertEqual(4, len(dfa.states))
if __name__ == '__main__':
    # MatchTests itself is left out: it has no FA and only serves as the
    # base of the NFA and DFA variants.
    suites = (RegexTests, NFAMatchTests, DFAMatchTests, DFATests)
    test_support.run_unittest(*suites)