N-WAY MERGEの実装

SQLiteの中のソート関数(ORDER BYを書いたときに内部で動く関数)を読んでいたら、複数のソートされた配列をマージしていたので、それをpythonに焼き直してみました。アルゴリズムとしては、各配列の第0要素をエントリして、一番小さな数を決めるのにトーナメント戦と同じ方法で比較をします。このようにして全配列中で最も小さな数を見つけます。next関数で一番小さな要素を取り出すとき、その要素の属する配列から次に小さな数をトーナメントにエントリし直します。そして、新しい要素と、抜けた要素が戦った要素を比較しなおし再び1位を決めます。このとき1位になった要素がnext関数で次に取り出される要素になります。地味に悩んでいるところが、トーナメントの対戦回数の総数(tot)で、これは登録した配列の数をniterとすると、tot = 1 + 2 + 4 + 8 + ... n(tot < niter)なんだろうなと思います。

#!/usr/bin/env python2.6
class Merge(object):
    def __init__(self, *lists):
        self.count = 0
        self.aIter = []
        for l in lists:
            self.aIter.append(l)
        length = len(lists)
        self.len = 2
        i = 2
        while self.len < length:
            self.len += 1 << i
            i += 1
        self.aTree = [None]*self.len
        for idx in reversed(xrange(1, self.len)):
            self.cmp(idx)
    def cmp(self, idx):
        if idx >= self.len/2:
            i = (idx - self.len/2)*2
            j = i + 1
        else:
            i = self.aTree[idx*2]
            j = self.aTree[idx*2 + 1]
        try:
            v1 = self.aIter[i][0]
        except IndexError:
            self.aTree[idx] = j
            return
        try:
            v2 = self.aIter[j][0]
        except IndexError:
            self.aTree[idx] = i
            return
        if v1 <= v2:
            self.aTree[idx] =  i
        else:
            self.aTree[idx] = j
    def __iter__(self):
        return self
    def next(self):
        idx = self.aTree[1]
        try:
            ret = self.aIter[idx][0]
        except IndexError:
            raise StopIteration
        i = (self.len + idx)/2
        self.aIter[idx] = self.aIter[idx][1:]
        while i:
            self.cmp(i)
            i = i / 2
        return ret
import unittest
class MergeTests(unittest.TestCase):
    def test_cmp(self):
        # even case
        m = Merge([3],[2], [5], [1], [4], [6])
        self.assertEqual([None, 3, 3, 1, 3, 4], m.aTree)
        # odd case 
        m = Merge([1],[9], [7], [2], [4])
        self.assertEqual([None, 0, 3, 0, 3, 4], m.aTree)
        # single
        m = Merge([2,4,6,8])
        self.assertEqual([None, 0], m.aTree)
    
    def test_next(self):
        # even case
        m = Merge([2,4,6,8], [3,5,7], [9,81], [1, 3, 5, 7, 11])
        result = []
        for i in m:
            result.append(i)
        self.assertEqual([1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 8, 9, 11, 81]
        , result)

        # odd case
        m = Merge([2,4,6,8], [3,5,7], [9,81])
        result = []
        for i in m:
            result.append(i)
        self.assertEqual([2,3,4,5,6,7,8,9,81]
        , result)
        
    

if __name__ == '__main__':
    unittest.main()