sqliteのデータベースファイルを読み込む(3)[LIMITの実装]
前回のソースコードにDBクラスを追加しました。LIMIT offset, limit句は、cursorをoffset行分進めてから、最大limit行の値を返します。つまりoffset回だけcursorを動かすことになるので、読み飛ばす行についてもその分のコストがかかります。また、cursorのレベルで見れば、1つのレコードへ移動するのにかかる時間はレコードの大きさにはあまり左右されません(かかるのはB-treeの深さ分のページを辿るコストだけです)。
利用の仕方は次のような感じです。
def test_db_limit(self):
    """DB.find with offset/limit must return the same rows as SQLite's LIMIT 3, 10."""
    con = sqlite3.connect('trac.db')
    try:
        # materialize the reference rows so the connection can be closed right away
        expectations = con.cursor().execute(
            "SELECT text FROM wiki LIMIT 3, 10").fetchall()
    finally:
        con.close()
    db = DB('trac.db')
    rows = db.find({'from': 'wiki', 'cols': (5,), 'offset': 3, 'limit': 10})
    for expectation in expectations:
        # the reader returns raw UTF-8 bytes for text columns
        self.assertEqual(expectation[0].encode('utf-8'), next(rows)[0])
    # no rows beyond the limit
    self.assertRaises(StopIteration, next, rows)
以下が実装です。これを見ると、WHERE句も、forループの中に適当なif文を足せば実現できそうです。
#!/usr/bin/env python2.6
# -*- coding: utf-8 -*-
"""Minimal read-only reader for SQLite database files (table b-trees only).

Pages are read with ``bitstring``; every position stored in ``Pager.fp.pos``
and in Cell/Cursor offsets is measured in *bits*.
"""
import sys
from itertools import islice

import bitstring

HEADER_OFFSET_PAGE1 = 100  # page 1 starts with the 100-byte database header

# b-tree page type flag bits
INTKEY = 0x01
ZERO_DATA = 0x02
LEAF_DATA = 0x04
LEAF = 0x08

# table name -> (root page number, CREATE statement); filled in by init_db()
TABLES = {'sqlite_master': (1, """CREATE TABLE sqlite_master(
  type text,
  name text,
  tbl_name text,
  rootpage integer,
  sql text)""")}

# content size in bytes for serial types 0..11 (types >= 12 are blob/text)
SIZE = [0, 1, 2, 3, 4, 6, 8, 8, 0, 0, 0, 0]

try:
    _STRING_TYPES = (basestring,)
except NameError:  # Python 3
    _STRING_TYPES = (str, bytes)


def get_fieldsize(serial_type):
    """Return the number of payload bytes used by a record serial type."""
    if serial_type >= 12:
        return (serial_type - 12) // 2
    return SIZE[serial_type]


class Pager(object):
    """Random access to the pages and header fields of a database file."""

    def __init__(self, fname):
        self.fp = bitstring.ConstBitStream(filename=fname)
        self.pagesize = self.get_pagesize()
        self.pages = {}  # page number -> Page cache (filled by Page.__init__)
        # header byte 20: bytes reserved at the end of every page
        self.fp.pos = 20 * 8
        nReserve = self.fp.read('uint:8')
        self.usableSize = self.pagesize - nReserve
        # payload spill thresholds (table leaves use maxLeaf/minLeaf)
        self.maxLeaf = self.usableSize - 35
        self.minLeaf = (self.usableSize - 12) * 32 // 255 - 23
        self.maxLocal = (self.usableSize - 12) * 64 // 255 - 23
        self.minLocal = self.minLeaf

    def read(self, type_fmt, pos):
        """Read one bitstring token (e.g. 'uint:8', 'bytes:4') at bit *pos*."""
        self.fp.pos = pos
        return self.fp.read(type_fmt)

    def getPage(self, iTab):
        """Return the cached Page for page number *iTab*, loading it if needed."""
        page = self.pages.get(iTab)
        if page is None:
            page = Page(self, iTab)  # Page.__init__ registers itself in self.pages
        return page

    # primitive readers; all of them read at the current self.fp.pos
    def get2byte(self):
        """Read a 2-byte big-endian unsigned integer."""
        return self.fp.read('uint:8') << 8 | self.fp.read('uint:8')

    def get4byte(self):
        """Read a 4-byte big-endian unsigned integer."""
        return (self.fp.read('uint:8') << 24 | self.fp.read('uint:8') << 16 |
                self.fp.read('uint:8') << 8 | self.fp.read('uint:8'))

    def get_pagesize(self):
        """Read the page size from header bytes 16-17.

        The stored value 1 means 65536; shifting the second byte by 16 maps
        it to that size, while every other legal (power-of-two) size has a
        zero second byte.
        """
        self.fp.pos = 16 * 8
        return self.fp.read('uint:8') << 8 | self.fp.read('uint:8') << 16

    def getVarint(self):
        """Read an SQLite varint; return (value, number_of_bytes_read).

        Up to eight bytes contribute their low 7 bits while their high bit
        is set; a ninth byte, if reached, contributes all 8 bits.
        """
        v = 0
        for i in range(8):
            byte = self.fp.read('uint:8')
            v = (v << 7) | (byte & 0x7f)
            if not byte & 0x80:
                return v, i + 1
        v = (v << 8) | self.fp.read('uint:8')
        return v, 9

    def set_cellsize(self, page):
        """Store the page's cell count (page-header bytes 3-4) in page.nCell."""
        self.fp.pos = (page.pos + page.hdroffset + 3) * 8
        page.nCell = self.get2byte()

    def get_pagetype(self, page):
        """Return the page-type byte (first byte of the b-tree page header)."""
        self.fp.pos = (page.hdroffset + self.pagesize * (page.pageno - 1)) * 8
        return self.fp.read('uint:8')

    def find_cell_offset(self, iCell, page):
        """Position fp at cell *iCell* of *page* and return that bit position.

        For iCell == page.nCell it positions fp at the right-most child
        pointer in the interior-page header instead.
        """
        mask = self.pagesize - 0x01
        celloffset = page.pos + page.hdroffset + 8 + page.childPtrSize
        if iCell == page.nCell:
            self.fp.pos = (celloffset - 4) * 8
            return self.fp.pos
        # cell pointer array: 2-byte offsets relative to the start of the page
        self.fp.pos = (celloffset + iCell * 2) * 8
        self.fp.pos = (page.pos + (mask & self.get2byte())) * 8
        return self.fp.pos


def _read_payload(pager, cell, start, nbytes):
    """Return bytes [start, start + nbytes) of *cell*'s payload.

    The part beyond cell.nLocal is collected from the overflow-page chain
    (4-byte next-page number followed by usableSize - 4 payload bytes per page).
    Raises Exception if the chain ends before *nbytes* bytes were read.
    """
    base = cell.pos + cell.hdr          # bit position of payload byte 0
    ovflSize = pager.usableSize - 4     # payload bytes carried per overflow page
    chunks = []
    if start < cell.nLocal:
        n = min(nbytes, cell.nLocal - start)
        chunks.append(pager.read('bytes:%d' % n, base + start * 8))
        start += n
        nbytes -= n
    # the first overflow page number follows the locally stored payload
    pager.fp.pos = base + cell.nLocal * 8
    pgno = pager.get4byte()
    chunk_start = cell.nLocal           # payload offset of the current overflow page
    while nbytes > 0:
        if pgno == 0:
            raise Exception("database file is broken")
        page_pos = pager.pagesize * (pgno - 1) * 8
        pager.fp.pos = page_pos
        next_pgno = pager.get4byte()
        if start < chunk_start + ovflSize:
            skip = start - chunk_start
            n = min(nbytes, ovflSize - skip)
            chunks.append(pager.read('bytes:%d' % n, page_pos + (4 + skip) * 8))
            start += n
            nbytes -= n
        chunk_start += ovflSize
        pgno = next_pgno
    return b''.join(chunks)


MAX_DEPTH = 20


class Cursor(object):
    """Forward iterator over the leaf cells of one table b-tree.

    Iteration yields the cursor itself; read columns with getvalue().
    """

    def __init__(self, pager, pgno):
        self.pager = pager
        self.pgno = pgno
        self.cell = None
        self.pages = [None] * MAX_DEPTH   # page stack from the root to the current leaf
        self.iCells = [None] * MAX_DEPTH  # current cell index on each stacked page
        self.depth = -1

    def getrowid(self):
        return self.cell.rowid

    def getvalue(self, pager, iField):
        """Decode column *iField* of the record under the cursor.

        Returns None for NULL (and reserved types 10/11), an int for serial
        types 1-6, 8 and 9, a float for type 7, and raw bytes for text and
        blob columns.  Fields that spill onto overflow pages are reassembled.
        """
        cell = self.cell
        serial_type = cell.stypes[iField]
        offset = cell.offsets[iField]           # bit offset from cell.pos
        payload_size = get_fieldsize(serial_type)
        if serial_type in (0, 10, 11):
            return None
        if serial_type == 8:
            return 0  # constant 0, no bytes stored
        if serial_type == 9:
            return 1  # constant 1, no bytes stored
        if serial_type == 7:
            fmt = 'floatbe:64'
        elif serial_type <= 6:
            fmt = 'int:%d' % (payload_size * 8)
        else:
            fmt = 'bytes:%d' % payload_size
        start = (offset - cell.hdr) // 8        # byte offset inside the payload
        if start + payload_size <= cell.nLocal:
            # fast path: the whole field is stored on the b-tree page itself
            return pager.read(fmt, cell.pos + offset)
        raw = _read_payload(pager, cell, start, payload_size)
        if serial_type >= 12:
            return raw
        return bitstring.ConstBitStream(bytes=raw).read(fmt)

    def moveToLeftMost(self):
        """Descend from page self.pgno to its left-most leaf cell.

        Raises StopIteration if the leaf reached has no cells (empty table).
        """
        page = self.pager.getPage(self.pgno)
        self.depth += 1
        self.pages[self.depth] = page
        self.iCells[self.depth] = 0
        while not page.leaf:
            self.depth += 1
            self.iCells[self.depth] = 0
            page = page.find_entry(self.pager, 0)
            self.pages[self.depth] = page
        assert page.leaf
        if page.nCell == 0:  # empty table
            raise StopIteration
        self.cell = page.find_entry(self.pager, self.iCells[self.depth])

    def moveNextLeaf(self):
        """Advance to the next leaf cell, climbing/descending the page stack."""
        page = self.pages[self.depth]
        self.iCells[self.depth] += 1
        iCell = self.iCells[self.depth]
        # interior pages have nCell + 1 children (right-most pointer included)
        if iCell > page.nCell - page.leaf:
            if self.depth == 0:
                raise StopIteration
            self.depth -= 1
            self.pgno = page.pageno
            return self.moveNextLeaf()
        entry = page.find_entry(self.pager, iCell)
        entry.setcell(self)
        return self.cell

    def next(self):
        if self.cell is None:
            self.moveToLeftMost()
        else:
            self.moveNextLeaf()
        return self

    __next__ = next  # Python 3 iterator protocol

    def __iter__(self):
        return self

    def moveTo(self, iCell, pgno=None):
        """Point the cursor at cell *iCell* of a leaf page (current one by default)."""
        if pgno is None:
            page = self.pages[self.depth]
        else:
            page = self.pager.getPage(pgno)
            self.pages[self.depth] = page
        assert page.leaf
        self.iCells[self.depth] = iCell
        self.cell = page.find_entry(self.pager, iCell)


class Page(object):
    """One b-tree page: header fields plus access to its cells."""

    def __init__(self, pager, pageno):
        # page 1 begins with the 100-byte database header
        if pageno == 1:
            self.hdroffset = HEADER_OFFSET_PAGE1
        else:
            self.hdroffset = 0
        self.pageno = pageno
        leaf = False
        childPtrSize = 4  # interior headers end with a 4-byte right-most child pointer
        if LEAF & pager.get_pagetype(self):
            leaf = True
            childPtrSize = 0
        self.leaf = leaf
        self.childPtrSize = childPtrSize
        self.pos = pager.pagesize * (pageno - 1)  # byte offset of the page in the file
        self.maxLocal = pager.maxLeaf
        self.minLocal = pager.minLeaf
        self.nCell = None
        self.nField = None
        pager.set_cellsize(self)
        pager.pages[pageno] = self

    def find_entry(self, pager, iCell):
        """Return the child Page (interior page) or the Cell (leaf page) at *iCell*."""
        pos = pager.find_cell_offset(iCell, self)
        if not self.leaf:
            return pager.getPage(pager.get4byte())
        n = 0
        nPayload, tn = pager.getVarint()
        n += tn
        intKey, tn = pager.getVarint()
        n += tn
        if intKey >= 1 << 63:
            intKey -= 1 << 64  # rowids are signed 64-bit integers
        cell_hdr_offset = n
        # record header: its own size, then one serial type per column
        keyoff, tn = pager.getVarint()
        n = tn
        stypes = []
        while n < keyoff:
            serial_type, tn = pager.getVarint()
            n += tn
            stypes.append(serial_type)
        # number of payload bytes stored on this page (the rest overflows)
        if nPayload <= self.maxLocal:
            nLocal = nPayload
        else:
            minLocal = self.minLocal
            surplus = minLocal + (nPayload - minLocal) % (pager.usableSize - 4)
            if surplus <= self.maxLocal:
                nLocal = surplus
            else:
                nLocal = minLocal
        return Cell(pos, nPayload, intKey, cell_hdr_offset, keyoff, stypes, nLocal)

    def setcell(self, cursor):
        """Make *cursor* descend into this child page."""
        cursor.pgno = self.pageno
        cursor.moveToLeftMost()


class Cell(object):
    """Decoded header of one table-leaf cell; offsets are in bits from pos."""

    def __init__(self, pos, nPayload, rowid, hdr_size, keyoffset, stypes, nLocal):
        offset = (hdr_size + keyoffset) * 8
        self.pos = pos
        self.hdr = hdr_size * 8
        self.nPayload = nPayload
        self.rowid = rowid
        self.stypes = stypes
        self.nLocal = nLocal
        self.nField = len(stypes)
        self.offsets = [offset]
        for serial_type in stypes:
            offset += get_fieldsize(serial_type) * 8
            self.offsets.append(offset)

    def setcell(self, cursor):
        cursor.cell = self


def get_rootpageno(tname):
    """Return the root page number of table *tname* (KeyError if unknown)."""
    rootpage, sql = TABLES[tname]
    return rootpage


def tables_add(row):
    """Register a sqlite_master row (type, name, tbl_name, rootpage, sql) if it is a table."""
    if row[0] == 'table':
        TABLES[row[1]] = (row[3], row[4])


def printf(row):
    print(row)


def init_db(fname):
    """Open *fname*, load every table from sqlite_master into TABLES, return the Pager."""
    iTab = 1
    pager = Pager(fname)
    cursor = Cursor(pager, iTab)
    for nxt in cursor:
        tables_add([nxt.getvalue(pager, i) for i in range(5)])
    return pager


def main(fname, tabname=None, *argv):
    """Print every row of *tabname* (default: sqlite_master).

    Extra arguments select column indices, taken modulo the column count of
    the first row.  An empty table prints nothing.
    """
    pager = init_db(fname)
    if tabname is None:
        tabname = 'sqlite_master'
    cursor = Cursor(pager, get_rootpageno(tabname))
    indices = None
    for row in cursor:
        if indices is None:
            nField = row.cell.nField
            if argv:
                indices = [int(idx) % nField for idx in argv]
            else:
                indices = range(nField)
        print(tuple([row.getvalue(pager, idx) for idx in indices]))


class DB(object):
    """Tiny query front end over a database file."""

    def __init__(self, pager):
        # Accept a file name (as in DB('trac.db')) as well as an existing Pager;
        # a file name is opened through init_db so TABLES holds the schema.
        if isinstance(pager, _STRING_TYPES):
            pager = init_db(pager)
        self.pager = pager

    def find(self, dic):
        """Yield tuples of the requested columns, like SELECT ... LIMIT offset, limit.

        dic keys: 'from' (table name), 'cols' (column indices), and optional
        'offset' (rows to skip) and 'limit' (maximum rows to yield).  Negative
        offset counts as 0 and negative limit as unlimited.  Skipped rows are
        still walked by the cursor, so OFFSET costs time proportional to it.
        """
        cols = dic['cols']
        cursor = Cursor(self.pager, get_rootpageno(dic['from']))
        offset = max(dic.get('offset') or 0, 0)
        limit = dic.get('limit')
        stop = None if limit is None or limit < 0 else offset + limit
        for row in islice(cursor, offset, stop):
            yield tuple([row.getvalue(self.pager, col) for col in cols])


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("usage:%s <databasefile> <table>? <cols>?" % sys.argv[0])
        sys.exit(1)
    main(*sys.argv[1:])