読者です 読者をやめる 読者になる 読者になる

Insertion Sort

SQLiteソースコードリーディングしているのですが、ぼくはソート・アルゴリズムについてあまり勉強したことがないので、アルゴリズムのところでいちいち躓いています。SQLiteで使われているアルゴリズムは本にも載っているような古典アルゴリズムがほとんどだと思いますが、オープンソースでも使われているってことで勉強したついでにここで紹介します。

まずはオリジナルから引用しておきます。
:btree.c 6207行目-

  /*
  ** Put the new pages in accending order.  This helps to
  ** keep entries in the disk file in order so that a scan
  ** of the table is a linear scan through the file.  That
  ** in turn helps the operating system to deliver pages
  ** from the disk more rapidly.
  **
  ** An O(n^2) insertion sort algorithm is used, but since
  ** n is never more than NB (a small constant), that should
  ** not be a problem.
  **
  ** When NB==3, this one optimization makes the database
  ** about 25% faster for large insertions and deletions.
  */
  for(i=0; i<k-1; i++){
    int minV = apNew[i]->pgno;
    int minI = i;
    for(j=i+1; j<k; j++){
      if( apNew[j]->pgno<(unsigned)minV ){
        minI = j;
        minV = apNew[j]->pgno;
      }
    }
    if( minI>i ){
      MemPage *pT;
      pT = apNew[i];
      apNew[i] = apNew[minI];
      apNew[minI] = pT;
    }
  }

次のコードは、このアルゴリズムの学習のためにコードを少し関数を抽出して、少しテストを書きたしたものです。

#!/usr/bin/env python2.6
def insertion_sort(inp):
    length = len(inp)
    for i in range(length-1):
        min_idx, min_value = find_smallest_value(inp, inp[i], i)
        is_find = min_idx > i
        if is_find:
            swap(inp, i, min_idx)
def swap(inp, i, j):
    t = inp[i]
    inp[i] = inp[j]
    inp[j] = t
def find_smallest_value(inp, value, i):
    for j in range(i+1, len(inp)):
        if inp[j] < value:
            value = inp[j]
            i = j
    return i, value
import unittest
import random
class InsertionSortTests(unittest.TestCase):
    def test_sort(self):
        inp = range(23)
        random.shuffle(inp)
        insertion_sort(inp)
        self.assertEqual(range(23), inp)
    def test_find_smallest_value(self):
        self.assertEqual((0, 3), find_smallest_value([3,6,9,12], 3, 0))
        self.assertEqual((1, 3), find_smallest_value([12,3, 6,9], 12, 0))
        self.assertEqual((3, 3), find_smallest_value([12, 9, 6,3], 12, 0))

if __name__ == '__main__':
    unittest.main()