sqliteのデータベースファイルを読み込む(2)

SQLiteのデータベースファイルを読み込むPythonスクリプトを書いています。読み込みがある程度できるようになったら、書き込みもできるようにしたいです。前回からの変更点は、SELECT文のように表示するカラムを指定できるようになったことです(現在はカラム名ではなく、何番目のカラムかを0始まりの番号で指定します)。WHERE句、ORDER BY、LIMITなどはまだ使えません。


$ ./pysql.py insert.db Products 1 0
[40, u'apple']
[30, u'orange']

使い方はこんな感じです。

$./pysql.py
usage:./pysql.py <dabasefile> <table>? <cols>?

DBファイルだけを指定すると、自動的にsqlite_masterを読み込んで、その内容を標準出力に表示します。

$ ./pysql.py insert.db
['table', 'Products', 'Products', 2, 'CREATE TABLE Products(name text, price integer)']
['table', 'People', 'People', 3, 'CREATE TABLE People(name text, age integer)']
['table', 'Tbl', 'Tbl', 4, 'CREATE TABLE Tbl(col1 text, col2 text, col3 integer, col4 integer)']

カラムを指定しないと、"SELECT * FROM Products"のようにすべてのカラムを表示します。

$ ./pysql.py insert.db Products
['apple', 40]
['orange', 30]

カラムを番号で指定することもできます。たとえば「1 0」と指定すると、"SELECT price, name FROM Products"のような動作になります。

$ ./pysql.py insert.db Products 1
[40]
[30]
$ ./pysql.py insert.db Products 0
['apple']
['orange']
$ ./pysql.py insert.db Products 1 0
[40, 'apple']
[30, 'orange']

#!/usr/bin/env python2.6
# -*- coding: utf-8 -*-
import struct

import bitstring
HEADER_OFFSET_PAGE1 = 100  # page 1 begins with the 100-byte database file header
# Page-type flags: bits of the first byte of a b-tree page header.
INTKEY = 0x01
ZERO_DATA = 0x02
LEAF_DATA = 0x04
LEAF = 0x08
# Known tables: name -> (root page number, CREATE statement).
# sqlite_master is always rooted at page 1; tables_add() registers the others.
TABLES = {'sqlite_master': (1,
"""CREATE TABLE sqlite_master(
                type text,
                name text,
                tbl_name text,
                rootpage integer,
                sql text)
""")}
# Content size in bytes of each serial type 0-11 (10 and 11 are reserved).
SIZE = [0, 1, 2, 3, 4, 6, 8, 8, 0, 0, 0, 0]
def get_fieldsize(serial_type):
    """Return the number of content bytes used by a field of *serial_type*.

    Serial types >= 12 are BLOB (even) or TEXT (odd) values whose length is
    encoded in the type itself; smaller types are looked up in SIZE.
    """
    if serial_type >= 12:
        # Explicit floor division: an odd (TEXT) type must not yield a
        # fractional size under true division.
        return (serial_type - 12) // 2
    return SIZE[serial_type]
class Pager(object):
    """Random-access reader for a SQLite database file.

    ``self.fp`` is a bitstring.ConstBitStream over the whole file.  Stream
    positions (``self.fp.pos`` and the ``pos`` argument of read()) are *bit*
    offsets, so byte offsets are multiplied by 8 before seeking.
    """
    def __init__(self, fname):
        self.fp = bitstring.ConstBitStream(filename=fname)
        self.pagesize = self.get_pagesize()
        self.pages = {}  # page number -> Page; Page.__init__ registers itself here
        # Byte 20 of the file header: number of bytes reserved at the end
        # of every page, which are not usable for b-tree content.
        self.fp.pos = 20*8
        nReserve = self.fp.read('uint:8')
        self.usableSize = self.pagesize - nReserve
        # Thresholds deciding how much of a cell's payload stays on the
        # b-tree page (Page copies maxLeaf/minLeaf).  Floor division is
        # explicit so the results stay integers.
        self.maxLeaf = self.usableSize - 35
        self.minLeaf = (self.usableSize - 12) * 32 // 255 - 23
        self.maxLocal = (self.usableSize - 12) * 64 // 255 - 23
        self.minLocal = self.minLeaf
    def read(self, type_fmt, pos):
        """Seek to bit offset *pos* and read one value of bitstring format *type_fmt*."""
        self.fp.pos = pos
        return self.fp.read(type_fmt)
    def getPage(self, iTab):
        """Return the Page for page number *iTab*, loading it on first use."""
        page = self.pages.get(iTab)
        if page is None:
            page = Page(self, iTab)  # caches itself in self.pages
        return page
    # --- primitive readers: each reads at the current stream position ---
    def get2byte(self):
        """Read a 2-byte big-endian unsigned integer."""
        return self.fp.read('uint:8') << 8 | self.fp.read('uint:8')
    def get4byte(self):
        """Read a 4-byte big-endian unsigned integer."""
        value = 0
        for _ in range(4):
            value = (value << 8) | self.fp.read('uint:8')
        return value
    def get_pagesize(self):
        """Return the page size stored at file-header offset 16.

        Per the SQLite file format the size is a 2-byte big-endian value in
        which 1 stands for 65536.  Every other legal size is a power of two
        >= 512 whose low byte is 0, so placing the second byte at bit 16
        gives the right answer in both cases.
        """
        self.fp.pos = 16*8
        return self.fp.read('uint:8') << 8 | self.fp.read('uint:8') << 16
    def getVarint(self):
        """Read a SQLite varint and return ``(value, number_of_bytes_read)``.

        A varint is 1-9 bytes, big-endian.  Each of the first eight bytes
        contributes its low 7 bits and continues while its high bit is set;
        a ninth byte, if reached, contributes all 8 bits.  The value is
        returned as an unsigned 64-bit integer.
        """
        value = 0
        for nbytes in range(1, 9):
            byte = self.fp.read('uint:8')
            value = (value << 7) | (byte & 0x7f)
            if not byte & 0x80:
                return value, nbytes
        byte = self.fp.read('uint:8')
        return (value << 8) | byte, 9
    def set_cellsize(self, page):
        """Set ``page.nCell`` from the 2-byte cell count at page-header offset 3."""
        self.fp.pos = (page.pos + page.hdroffset + 3)*8
        page.nCell = self.get2byte()
    def get_pagetype(self, page):
        """Return the first byte of *page*'s header (its page-type flags)."""
        self.fp.pos = (page.hdroffset + self.pagesize * (page.pageno-1)) * 8
        return self.fp.read('uint:8')
    def find_cell_offset(self, iCell, page):
        """Seek to entry *iCell* of *page* and return the stream position (bits).

        For 0 <= iCell < nCell this follows the 2-byte cell pointer to the
        cell content.  ``iCell == page.nCell`` (meaningful on interior pages
        only) instead seeks to the right-most child pointer in the page header.
        """
        mask = self.pagesize - 0x01  # keeps the cell pointer inside the page
        # The cell pointer array follows the 8-byte header, plus 4 bytes of
        # right-child pointer on interior pages (childPtrSize).
        celloffset = page.pos + page.hdroffset + 8 + page.childPtrSize
        if iCell == page.nCell:
            self.fp.pos = (celloffset-4)*8
            return self.fp.pos
        self.fp.pos = (celloffset + iCell*2)*8
        self.fp.pos = (page.pos + (mask & self.get2byte()))*8
        return self.fp.pos
MAX_DEPTH = 20  # capacity of the cursor's per-level pages/iCells stacks
class Cursor(object):
    """Forward iterator over the leaf cells of one table b-tree.

    Iteration yields Cell objects in b-tree order.  ``pages[d]`` and
    ``iCells[d]`` hold the page and entry index at tree level ``d``;
    ``depth`` is the current level (-1 before the first descent).  Both the
    Python 2 (``next``) and Python 3 (``__next__``) protocols are supported.
    """
    def __init__(self, fp, pgno):
        self.fp = fp        # Pager the pages are read from
        self.pgno = pgno    # page the next moveToLeftMost() starts from
        self.cell = None    # current Cell; None until iteration starts
        self.pages = [None]*MAX_DEPTH
        self.iCells = [None]*MAX_DEPTH
        self.depth = -1
    def moveToLeftMost(self):
        """Descend from page ``self.pgno`` to its left-most leaf cell.

        Raises StopIteration if that leaf holds no cells (an empty table).
        """
        page = self.fp.getPage(self.pgno)
        self.depth += 1
        self.pages[self.depth] = page
        self.iCells[self.depth] = 0
        while not page.leaf:
            self.depth += 1
            self.iCells[self.depth] = 0
            page = page.find_entry(self.fp, 0)  # interior entries are child pages
            self.pages[self.depth] = page
        assert(page.leaf)
        if page.nCell == 0:  # empty table
            raise StopIteration
        self.cell = page.find_entry(self.fp, self.iCells[self.depth])
    def moveNextLeaf(self):
        """Advance to the next leaf cell and return it.

        Raises StopIteration once every cell of the tree has been visited.
        """
        page = self.pages[self.depth]
        self.iCells[self.depth] += 1
        iCell = self.iCells[self.depth]
        # A leaf has entries 0..nCell-1; an interior page also has its
        # right-most child at index nCell (page.leaf is a bool, so 1 is
        # subtracted only on leaves).
        if iCell > page.nCell - page.leaf:
            if self.depth == 0:
                raise StopIteration
            # This page is exhausted: go back up and advance the parent.
            self.depth -= 1
            self.pgno = page.pageno
            return self.moveNextLeaf()
        else:
            entry = page.find_entry(self.fp, iCell)
            # A Cell becomes the current cell; a child Page makes the
            # cursor descend into that page's left-most leaf.
            entry.setcell(self)
            return self.cell
    def next(self):
        """Return the next Cell (Python 2 iterator protocol)."""
        if self.cell is None:
            self.moveToLeftMost()
            return self.cell
        else:
            return self.moveNextLeaf()
    __next__ = next  # Python 3 iterator protocol
    def __iter__(self):
        return self
    def moveTo(self, iCell, pgno=None):
        """Position the cursor on leaf cell *iCell* of the current page or of page *pgno*."""
        if pgno is None:
            page = self.pages[self.depth]
        else:
            page = self.fp.getPage(pgno)
            self.pages[self.depth] = page
        assert(page.leaf)
        self.iCells[self.depth] = iCell
        self.cell = page.find_entry(self.fp, iCell)
class Page(object):
    """A b-tree page read through a Pager.

    ``pos`` is the byte offset of the page in the file and ``hdroffset`` the
    byte offset of its page header within the page.  Constructing a Page
    registers it in ``pager.pages``.
    """
    def __init__(self, pager, pageno):
        self.pageno = pageno
        # Page 1 is prefixed by the database file header.
        self.hdroffset = HEADER_OFFSET_PAGE1 if pageno == 1 else 0
        self.leaf = bool(pager.get_pagetype(self) & LEAF)
        # Interior pages keep a 4-byte right-child pointer in their header,
        # which moves the cell pointer array 4 bytes further in.
        self.childPtrSize = 0 if self.leaf else 4
        self.pos = pager.pagesize*(pageno - 1)
        self.maxLocal = pager.maxLeaf
        self.minLocal = pager.minLeaf
        self.nField = None
        self.nCell = None
        pager.set_cellsize(self)  # fills in self.nCell
        pager.pages[pageno] = self
    def find_entry(self, fp, iCell):
        """Return entry *iCell* of this page.

        On an interior page this is the child Page the entry points to
        (``iCell == nCell`` selects the right-most child).  On a leaf page
        it is a Cell built from the cell header and the record header.
        """
        pos = fp.find_cell_offset(iCell, self)
        if not self.leaf:
            return fp.getPage(fp.get4byte())
        # Cell header: payload-size varint, then rowid varint.
        nPayload, size_len = fp.getVarint()
        rowid, rowid_len = fp.getVarint()
        # Record header: its own size (a varint that counts itself), then
        # one serial-type varint per field.
        rec_hdr_size, consumed = fp.getVarint()
        stypes = []
        while consumed < rec_hdr_size:
            stype, width = fp.getVarint()
            stypes.append(stype)
            consumed += width
        return Cell(self, pos, nPayload, rowid, size_len + rowid_len,
                    rec_hdr_size, stypes, self._local_size(fp, nPayload))
    def _local_size(self, fp, nPayload):
        """Return how many of *nPayload* bytes are stored on this page itself."""
        if nPayload <= self.maxLocal:
            return nPayload
        surplus = self.minLocal + (nPayload - self.minLocal) % (fp.usableSize - 4)
        if surplus > self.maxLocal:
            return self.minLocal
        return surplus
    def setcell(self, cursor):
        """Move *cursor* down into this page's left-most leaf cell."""
        cursor.pgno = self.pageno
        cursor.moveToLeftMost()
class Cell(object):
    """One leaf cell (one table row) of a table b-tree.

    Positions stored here are *bit* offsets into the pager's stream: ``pos``
    is where the cell starts, ``hdr`` is the size of the cell header (the
    payload-size and rowid varints), and ``offsets[i]`` is where field ``i``
    starts relative to ``pos``.  ``offsets`` has one extra trailing entry
    marking the end of the last field.
    """
    def __init__(self, page, pos, nPayload, rowid, hdr_size, keyoffset, stypes, nLocal):
        self.parent = page        # Page the cell was read from
        self.pos = pos
        self.hdr = hdr_size*8
        self.nPayload = nPayload  # total payload bytes (local + overflow)
        self.rowid = rowid
        self.stypes = stypes      # serial type of each field
        self.nLocal = nLocal      # payload bytes stored on the b-tree page itself

        self.nField = len(stypes)
        # Field data begins after the cell header and the record header.
        offset = (hdr_size + keyoffset)*8
        self.offsets = [offset]
        for serial_type in stypes:
            offset += get_fieldsize(serial_type)*8
            self.offsets.append(offset)
    def getvalue(self, fp, iField):
        """Return the value of field *iField* (negative indices count from the end).

        Serial types 0, 10 and 11 give None; 8 and 9 give the integers 0 and
        1; 1-6 give a signed big-endian integer; 7 gives a big-endian IEEE
        754 double.  BLOB and TEXT fields are returned as raw, undecoded
        bytes, reassembled from overflow pages when they do not fit locally.

        Raises IndexError for an out-of-range *iField*.
        """
        nfields = len(self.stypes)
        if iField < 0:
            iField += nfields
        if not 0 <= iField < nfields:
            raise IndexError('field index out of range')
        serial_type = self.stypes[iField]
        if serial_type in (0, 10, 11):
            return None
        if serial_type in (8, 9):
            return serial_type - 8
        start = (self.offsets[iField] - self.hdr) // 8   # byte offset within payload
        size = (self.offsets[iField + 1] - self.offsets[iField]) // 8
        raw = self._read_payload(fp, start, size)
        if serial_type <= 6:
            return self._decode_int(raw)
        if serial_type == 7:
            return struct.unpack('>d', raw)[0]
        return raw
    def _read_payload(self, fp, start, amount):
        """Return *amount* payload bytes beginning at payload byte *start*.

        The first ``nLocal`` payload bytes are stored in the cell, followed
        by the 4-byte number of the first overflow page.  Each overflow page
        starts with the 4-byte number of the next one (0 ends the chain) and
        carries ``usableSize - 4`` further payload bytes.
        """
        if amount <= 0:
            return b''
        chunks = []
        payload_pos = self.pos + self.hdr  # bit position of payload byte 0
        if start < self.nLocal:
            n = min(amount, self.nLocal - start)
            chunks.append(fp.read('bytes:%d' % n, payload_pos + start*8))
            start += n
            amount -= n
        if amount > 0:
            ovflSize = fp.usableSize - 4
            fp.fp.pos = payload_pos + self.nLocal*8
            pgno = fp.get4byte()
            page_start = self.nLocal  # payload offset of the current overflow page's data
            while amount > 0:
                if pgno == 0:
                    raise Exception("database file is broken")
                page_pos = fp.pagesize*(pgno - 1)  # byte offset of the overflow page
                fp.fp.pos = page_pos*8
                next_pgno = fp.get4byte()
                skip = start - page_start
                if skip < ovflSize:
                    n = min(amount, ovflSize - skip)
                    chunks.append(fp.read('bytes:%d' % n, (page_pos + 4 + skip)*8))
                    start += n
                    amount -= n
                page_start += ovflSize
                pgno = next_pgno
        return b''.join(chunks)
    @staticmethod
    def _decode_int(raw):
        """Decode *raw* as a big-endian two's-complement integer."""
        value = 0
        for byte in bytearray(raw):
            value = (value << 8) | byte
        nbits = 8*len(raw)
        if nbits and value >> (nbits - 1):
            value -= 1 << nbits
        return value
    def setcell(self, cursor):
        """Make this cell the cursor's current cell."""
        cursor.cell = self
def get_rootpageno(tname):
    """Return the root page number registered in TABLES for table *tname*.

    Raises KeyError when *tname* has not been registered.
    """
    entry = TABLES[tname]
    return entry[0]
def tables_add(row):
    """Record a sqlite_master row in TABLES when its type is 'table'.

    *row* is [type, name, tbl_name, rootpage, sql]; rows of any other type
    are ignored.
    """
    if row[0] != 'table':
        return
    TABLES[row[1]] = (row[3], row[4])
def printf(row):
    """Print *row* on standard output."""
    print(row)
def init_db(fname):
    """Open database *fname*, register its tables and return the Pager.

    Every row of sqlite_master (rooted at page 1) is read as its five
    columns and handed to tables_add().
    """
    pager = Pager(fname)
    for master_row in Cursor(pager, 1):
        tables_add([master_row.getvalue(pager, col) for col in range(5)])
    return pager
def main(fname, tabname=None, *argv):
    """Print every row of table *tabname* (default: sqlite_master) as a tuple.

    Each element of *argv* is a 0-based column index; indices wrap modulo
    the column count of the first row, so negative values count from the
    end.  With no indices every column is printed, like
    "SELECT * FROM <tabname>".  An empty table prints nothing.
    """
    fp = init_db(fname)
    if tabname is None:
        tabname = 'sqlite_master'
    cursor = Cursor(fp, get_rootpageno(tabname))
    try:
        cursor.moveToLeftMost()
    except StopIteration:
        return  # empty table
    cell = cursor.cell
    if argv:
        indices = [int(idx) % cell.nField for idx in argv]
    else:
        indices = range(cell.nField)
    # moveToLeftMost() already consumed the first row, so print it in both
    # cases before iterating over the remaining rows.
    print(tuple([cell.getvalue(fp, idx) for idx in indices]))
    for cell in cursor:
        print(tuple([cell.getvalue(fp, idx) for idx in indices]))
class DB(object):
    """Minimal query front end over a database file loaded with init_db()."""
    def __init__(self, filename):
        self.pager = init_db(filename)
    def find(self, dic):
        """Yield rows of a table as lists of column values, like a simple SELECT.

        *dic* keys:
          'from'   -- table name (required).
          'cols'   -- list of 0-based column indices to return; every column
                      of each row when omitted or None.
          'offset' -- number of leading rows to skip (optional).
          'limit'  -- maximum number of rows to yield (optional); a negative
                      value never matches the row count, so it means no limit.
        """
        cols = dic.get('cols')
        cursor = Cursor(self.pager, get_rootpageno(dic['from']))
        offset = dic.get('offset')
        if offset is not None:
            # StopIteration must not escape a generator (PEP 479 turns it
            # into RuntimeError on Python 3.7+), so stop explicitly when the
            # table has fewer rows than the offset.
            try:
                for _ in range(offset):
                    next(cursor)
            except StopIteration:
                return
        limit = dic.get('limit')
        n = 0
        for cell in cursor:
            if limit is not None and limit == n:
                return
            indices = range(cell.nField) if cols is None else cols
            yield [cell.getvalue(self.pager, col) for col in indices]
            n += 1
import sys
if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Usage errors go to stderr so they never mix with query output.
        sys.stderr.write("usage:%s <databasefile> <table>? <cols>?\n" % sys.argv[0])
        sys.exit(1)
    main(*sys.argv[1:])