3-SATを解く乱択アルゴリズムを実装してみた

『数学ガール／乱択アルゴリズム』第9章にある、3-SATを解く乱択アルゴリズムをPythonで実装してみました。p.336の「強正美優」問題については、正しく計算できることまで確認しました（テストではupdateとbf_three_satを使っています）。ラウンドあたりの成功率が妥当かどうかは、現在調査中です。

#!/usr/bin/env python2.6
# -*- encoding: utf-8 -*-
import unittest
from three_sat import *
from  test import test_support
class ThreeSatTests(unittest.TestCase):
    """Unit tests for Literal, Not, Clause and CNF from three_sat.

    The variable names 強, 正, 美, 優 are those of the example problem on
    p.336 of the book.
    """
    def test_literal(self):
        values = [True, False]
        x = Literal('優', values, 0)
        self.assertEqual('優', str(x))
        self.assertEqual(True, x.value())
        n_x = Not(x)
        self.assertEqual(False, n_x.value())
        self.assertEqual('¬優', str(n_x))
        y = Literal('正', values, 1)
        self.assertEqual('正', y.name)
        self.assertEqual(False, y.value())
        self.assertEqual(True, Not(y).value())
        # Literals read the shared list when evaluated, so this mutation is
        # visible through n_x, which was built before it.
        values[0] = False
        x = Literal('優', values, 0)
        self.assertEqual('優', x.name)
        self.assertEqual(False, x.value())
        self.assertEqual(True, n_x.value())
    def test_literal_and_or_returns_clause(self):
        values = [True, False, False]
        x1 = Literal('優', values, 0)
        x2 = Literal('正', values, 1)
        x3 = Literal('美', values, 2)
        c1 = x1.Or(Not(x2))
        self.assertEqual('(優v¬正)', str(c1))
        self.assertEqual(True, c1.value())
        c2 = x2.Or(x3)
        self.assertEqual('(正v美)', str(c2))
        self.assertEqual(False, c2.value())
        c = x1.Or(x2).Or(x3)
        self.assertEqual(True, c.value())
        self.assertEqual('(優v正v美)', str(c))
    def test_clause_and(self):
        values = [True, False, False, True]
        x1 = Literal('優', values, 0)
        x2 = Literal('正', values, 1)
        x3 = Literal('美', values, 2)
        x4 = Literal('強', values, 3)
        c1 = x1.Or(Not(x2)).Or(Not(x3))
        c2 = x2.Or(x3).Or(Not(x4))
        cnf = c1.And(c2)
        self.assertEqual('(優v¬正v¬美)^(正v美v¬強)', str(cnf))
        self.assertEqual(False, cnf.value())
    @staticmethod
    def update(cvalues, nvalues):
        """Copy ``nvalues`` into ``cvalues`` in place, position by position."""
        for i, v in enumerate(nvalues):
            cvalues[i] = v
    @staticmethod
    def bf_three_sat(cnf, values, values_iter):
        """Brute-force check: evaluate ``cnf`` under every given assignment.

        Each assignment from ``values_iter`` is written into the shared
        ``values`` list before ``cnf`` is evaluated.

        Returns:
            The 1-based positions of the assignments that satisfy ``cnf``.
        """
        successes = []
        for round_, nvalues in enumerate(values_iter):
            ThreeSatTests.update(values, nvalues)
            if cnf.value():
                successes.append(round_ + 1)
        return successes
if __name__ == '__main__':
    # Run every test_* method of ThreeSatTests via test.test_support.
    test_support.run_unittest(ThreeSatTests)

関数random_walk_three_sat(cnf, n, round_)は、three_sat(cnf, n, round_, selector)のselectorにRandomSelector(n)を渡した場合にあたります。テスト用に、selectorを次のクラスに置き換えた場合も確認しています。

class SimpleSelector(object):
    """Deterministic stand-in for RandomSelector, used in tests.

    next_values() hands out the given assignments in order; select() is a
    no-op, so it never changes any literal.
    """
    def __init__(self, data):
        # Any iterable works; it is consumed lazily, one item per call.
        self.it = iter(data)
    def next_values(self):
        """Return the next assignment; raises StopIteration when exhausted."""
        values = next(self.it)
        return values
    def select(self, clause):
        """Do nothing and return None."""
        return None

テストから呼び出している関数・クラスの本体（three_sat.py）は、次のように定義されています。

# -*- encoding: utf-8 -*-
#three_sat.py
class Literal(object):
    """A possibly negated boolean variable stored in an assignment list.

    ``values`` is held by reference, so literals built on the same list
    (including negations made with Not) see each other's updates.
    """
    def __init__(self, name, values, idx, not_=False):
        self.name = name      # display name used by __str__
        self.values = values  # assignment list, shared by reference
        self.idx = idx        # position of this variable in ``values``
        self.not_ = not_      # True for a negated literal
    def value(self):
        """Return this literal's truth value under the current assignment."""
        if self.not_:
            return not self.values[self.idx]
        return self.values[self.idx]
    def Or(self, other):
        """Return a new Clause containing ``self`` and ``other``."""
        return Clause(self, other)
    def __str__(self):
        if self.not_:
            return '¬' + self.name
        return self.name
    def invert(self):
        """Flip the underlying variable, whether or not the literal is negated."""
        # Flip the stored variable itself. Writing ``not self.value()`` would
        # store the current value back unchanged for a negated literal.
        self.values[self.idx] = not self.values[self.idx]
def Not(literal):
    """Return the negation of ``literal`` as a new Literal.

    The result shares ``name``, ``values`` and ``idx`` with ``literal``, so
    it always evaluates to the opposite truth value. Negating an already
    negated literal yields the positive literal again.
    """
    # Toggle the flag instead of forcing True, so that Not(Not(x)) acts like x.
    return Literal(literal.name, literal.values, literal.idx, not literal.not_)
class Clause(object):
    """A disjunction of literals, built from two and extended with Or().

    ``size`` counts the literals currently held.
    """
    def __init__(self, x, y):
        self.size = 2
        self.literals = [x, y]
    def Or(self, x):
        """Append literal ``x`` in place and return this clause for chaining."""
        self.size = self.size + 1
        self.literals.append(x)
        return self
    def And(self, x):
        """Combine this clause with clause ``x`` into a new CNF."""
        return CNF(self, x)
    def value(self):
        """Return True if any literal is true, otherwise False."""
        return any(lit.value() for lit in self.literals)
    def values(self):
        """Return the assignment list referenced by the first literal."""
        first = self.literals[0]
        return first.values
    def __str__(self):
        joined = 'v'.join(str(lit) for lit in self.literals)
        return '(' + joined + ')'
    def update(self, idx):
        """Call invert() on the literal at position ``idx``."""
        target = self.literals[idx]
        target.invert()
class CNF(object):
    """A conjunction of clauses, built from two and extended with And().

    After value() returns False, ``failed`` holds the first clause that
    evaluated to false; after value() returns True it is None.
    """
    def __init__(self, c1, c2):
        self.clauses = [c1, c2]
        self.failed = None  # first false clause found by the last value() call
        self.size = 2       # number of clauses
    def And(self, x):
        """Append clause ``x`` in place and return this CNF for chaining."""
        self.clauses.append(x)
        self.size += 1
        return self
    def __str__(self):
        return '^'.join([str(c) for c in self.clauses])
    def value(self):
        """Return True if every clause is true, recording the first false one."""
        for clause in self.clauses:
            if not clause.value():
                self.failed = clause
                return False
        # Clear any clause recorded by an earlier call so callers never act
        # on a stale result once the formula is satisfied.
        self.failed = None
        return True
    def values(self):
        """Return the assignment list referenced by the first clause."""
        return self.clauses[0].values()
    def update(self, nvalues):
        """Overwrite the shared assignment in place with ``nvalues``."""
        cvalues = self.values()
        for i, v in enumerate(nvalues):
            cvalues[i] = v
import random
class RandomSelector(object):
    """Selector whose choices all come from the ``random`` module.

    ``size`` is the length of each generated assignment.
    """
    def __init__(self, size):
        self.size = size
    def next_values(self):
        """Return a list of ``size`` random booleans."""
        assignment = []
        for _ in range(self.size):
            assignment.append(random.randint(0, 1) == 1)
        return assignment
    def select(self, clause):
        """Pick a literal position of ``clause`` at random and update it."""
        last = clause.size - 1
        clause.update(random.randint(0, last))
def random_walk_three_sat(cnf, n, round_):
    """Run three_sat on ``cnf`` with a RandomSelector over ``n`` variables.

    ``n`` is used both as the selector's assignment length and as
    three_sat's ``size``; ``round_`` is forwarded unchanged. Returns
    whatever three_sat returns.
    """
    selector = RandomSelector(n)
    return three_sat(cnf, n, round_, selector)
def three_sat(cnf, size, round_, selector):
    """Search for a satisfying assignment of ``cnf`` by random walk.

    Each round starts from a fresh assignment from selector.next_values()
    (step W4 on p.353). It then repeatedly asks selector.select() to flip a
    literal of the clause that cnf.value() found false (steps W10-12).

    Args:
        cnf: formula providing value(), update(values) and ``failed`` (the
            false clause recorded by value()).
        size: number of variables; bounds the walk length of each round.
        round_: maximum number of rounds.
        selector: provides next_values() and select(clause).

    Returns:
        The 1-based number of the round in which ``cnf`` became true, or
        None if it was not satisfied within ``round_`` rounds.
    """
    # NOTE(review): the walk length of 3 * size per round follows the usual
    # statement of Schoening's algorithm -- confirm against the book's W6.
    steps = 3 * size
    for round_no in range(1, round_ + 1):
        # W4: start this round from a new assignment.
        cnf.update(selector.next_values())
        for _ in range(steps):
            if cnf.value():
                return round_no
            # W10-12: flip a literal of the false clause found by value().
            selector.select(cnf.failed)
    return None