N-WAY MERGE の実装(2)

"ORDER BY"のときにSQLite中で動いているsort関数について調べているのですが、前回のは少しinputと結果が違っているようで、実はこんなのだと今のところ思っています。やっていることはcountの大きさのトーナメントを用意して、inputの配列の0-(count-1)成分のどれが一番小さいか戦わせます。エントリされた要素の中で一位が抜けて、抜けたところにcount-(len(inp)-1)成分が順次オブザーバーとして最エントリしトーナメントをし直します。これを繰り返した結果は後ろの成分の内count分だけsortされています。これでは完全にソートされていないので、結果をインプットとしてトーナメントをやり直します。トーナメントをやり直す度にcount分だけsortされた結果が出るので、sortされる条件はトーナメントをした回数(i+1)とすると"i*MAX_MERGE_COUNT < len(inp)"が成り立つだけsortを繰り返せば、完全にsortできることが分かります。

#!/usr/bin/env python2.6
class Merge(object):
    MAX_MERGE_COUNT = 16
    def __init__(self, inp):
        self.len = len(inp)
        self.inp = inp
    @staticmethod
    def count(inp):
        length = len(inp)
        cnt = 2
        i = 2
        while cnt < length:
            cnt += 1 << i
            i += 1
        if cnt < Merge.MAX_MERGE_COUNT:
            return cnt
        return Merge.MAX_MERGE_COUNT
    def sort(self):
        result = []
        i = 0
        result = self.inp
        while i * self.MAX_MERGE_COUNT < self.len:
            sorter = Sorter(result, Merge.count(result))
            result = [it.value for it in sorter]
            i += 1
        return result

class Sorter(object):
    def __init__(self, inp, count):
        it = iter(inp)
        self.count = count
        self.aTree = [None]*self.count
        self.aIter = []
        for i in range(self.count):
            self.aIter.append(Iterator(it))
        for idx in reversed(xrange(1, self.count)):
            self.cmp(idx)
    def cmp(self, idx):
        if idx >= self.count/2:
            i = (idx - self.count/2)*2
            j = i + 1
        else:
            i = self.aTree[idx*2]
            j = self.aTree[idx*2 + 1]
        if self.aIter[i] <= self.aIter[j]:
            self.aTree[idx] =  i
        else:
            self.aTree[idx] = j
    def __iter__(self):
        return self
    def next(self):
        idx = self.aTree[1]
        ret = self.aIter[idx]
        if ret.value is None:
            raise StopIteration
        self.aIter[idx] = next(ret)
        i = (self.count + idx)/2
        while i:
            self.cmp(i)
            i = i / 2
        return ret

class Iterator(object):
    def __init__(self, inp):
        try:
            self.value = next(inp)
        except StopIteration:
            self.value = None
        self._iter = inp
    def next(self):
        return Iterator(self._iter)
    def __ge__(self, other):
        if self.value is None:
            return True
        elif other.value is None:
            return False
        else:
            return self.value >= other.value
import unittest
import random    
class MergeTests(unittest.TestCase):
    def test_iter(self):
        inp = iter([3,6,9,12])
        it1 = Iterator(inp)
        it2 = Iterator(inp)
        it3 = Iterator(inp)
        self.assertEqual(3, it1.value)
        self.assertEqual(6, it2.value)
        self.assertEqual(9, it3.value)
        it4 = next(it1)
        self.assertEqual(12, it4.value)
        it5 = next(it2)
        self.assertEqual(None, it5.value)
        self.assertTrue(it1 <= it2)
        self.assertTrue(it3 >= it2)
        self.assertTrue(it4 <= it5)
    def test_sorter_cmp(self):
        # even case
        inp = [3,2,5,1,4,6]
        sorter = Sorter(inp, Merge.count(inp))
        self.assertEqual([None, 3,3,1,3,4], sorter.aTree)
        # odd case 
        inp = [1,9,7,2,4]
        sorter = Sorter(inp, Merge.count(inp))
        self.assertEqual([None, 0, 3, 0, 3, 4], sorter.aTree)
        # single
        sorter = Sorter([2], 2)
        self.assertEqual([None, 0], sorter.aTree)
    def test_sorter_next(self):
        inp = [2,3,5,1,4,6]
        s = Sorter(inp, Merge.count(inp))
        self.assertEqual(1, next(s).value)
        self.assertEqual([2,3,4,5,6], [i.value for i in s])
        inp = [5,3,6,4,2,7,1]
        s = Sorter(inp, 4)
        self.assertEqual([3,2,4,1,5,6,7], [i.value for i in s]) 
    def test_merge(self):
        inp = range(101)
        random.shuffle(inp)
        m = Merge(inp)
        self.assertEqual(range(101), m.sort())
if __name__ == '__main__':
    unittest.main()