changeset 260:9fd52748fa59

Now dbtables is compatible with Python 3.0
author jcea
date Wed, 13 Aug 2008 14:26:55 +0000
parents 06c3247f9fd5
children 0e8ff3fb940e
files Lib/bsddb/dbtables.py Lib/bsddb/test/test_dbtables.py
diffstat 2 files changed, 227 insertions(+), 76 deletions(-) [+]
line wrap: on
line diff
--- a/Lib/bsddb/dbtables.py	Wed Aug 13 14:26:20 2008 +0000
+++ b/Lib/bsddb/dbtables.py	Wed Aug 13 14:26:55 2008 +0000
@@ -26,17 +26,16 @@
 
 try:
     # For Pythons w/distutils pybsddb
-    from bsddb3.db import *
+    from bsddb3 import db
 except ImportError:
     # For Python 2.3
-    from bsddb.db import *
+    from bsddb import db
 
 # XXX(nnorwitz): is this correct? DBIncompleteError is conditional in _bsddb.c
-try:
-    DBIncompleteError
-except NameError:
+if not hasattr(db,"DBIncompleteError") :
     class DBIncompleteError(Exception):
         pass
+    db.DBIncompleteError = DBIncompleteError
 
 class TableDBError(StandardError):
     pass
@@ -104,6 +103,7 @@
                      # row in the table.  (no data is stored)
 _rowid_str_len = 8   # length in bytes of the unique rowid strings
 
+
 def _data_key(table, col, rowid):
     return table + _data + col + _data + rowid
 
@@ -142,37 +142,103 @@
         Use keyword arguments when calling this constructor.
         """
         self.db = None
-        myflags = DB_THREAD
+        myflags = db.DB_THREAD
         if create:
-            myflags |= DB_CREATE
-        flagsforenv = (DB_INIT_MPOOL | DB_INIT_LOCK | DB_INIT_LOG |
-                       DB_INIT_TXN | dbflags)
+            myflags |= db.DB_CREATE
+        flagsforenv = (db.DB_INIT_MPOOL | db.DB_INIT_LOCK | db.DB_INIT_LOG |
+                       db.DB_INIT_TXN | dbflags)
         # DB_AUTO_COMMIT isn't a valid flag for env.open()
         try:
-            dbflags |= DB_AUTO_COMMIT
+            dbflags |= db.DB_AUTO_COMMIT
         except AttributeError:
             pass
         if recover:
-            flagsforenv = flagsforenv | DB_RECOVER
-        self.env = DBEnv()
+            flagsforenv = flagsforenv | db.DB_RECOVER
+        self.env = db.DBEnv()
         # enable auto deadlock avoidance
-        self.env.set_lk_detect(DB_LOCK_DEFAULT)
+        self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
         self.env.open(dbhome, myflags | flagsforenv)
         if truncate:
-            myflags |= DB_TRUNCATE
-        self.db = DB(self.env)
+            myflags |= db.DB_TRUNCATE
+        self.db = db.DB(self.env)
         # this code relies on DBCursor.set* methods to raise exceptions
         # rather than returning None
         self.db.set_get_returns_none(1)
         # allow duplicate entries [warning: be careful w/ metadata]
-        self.db.set_flags(DB_DUP)
-        self.db.open(filename, DB_BTREE, dbflags | myflags, mode)
+        self.db.set_flags(db.DB_DUP)
+        self.db.open(filename, db.DB_BTREE, dbflags | myflags, mode)
         self.dbfilename = filename
+
+        if sys.version_info[0] >= 3 :
+            class cursor_py3k(object) :
+                def __init__(self, dbcursor) :
+                    self._dbcursor = dbcursor
+
+                def close(self) :
+                    return self._dbcursor.close()
+
+                def set_range(self, search) :
+                    v = self._dbcursor.set_range(bytes(search, "iso8859-1"))
+                    if v != None :
+                        v = (v[0].decode("iso8859-1"),
+                                v[1].decode("iso8859-1"))
+                    return v
+
+                def __next__(self) :
+                    v = getattr(self._dbcursor, "next")()
+                    if v != None :
+                        v = (v[0].decode("iso8859-1"),
+                                v[1].decode("iso8859-1"))
+                    return v
+
+            class db_py3k(object) :
+                def __init__(self, db) :
+                    self._db = db
+
+                def cursor(self, txn=None) :
+                    return cursor_py3k(self._db.cursor(txn=txn))
+
+                def has_key(self, key, txn=None) :
+                    return getattr(self._db,"has_key")(bytes(key, "iso8859-1"),
+                            txn=txn)
+
+                def put(self, key, value, flags=0, txn=None) :
+                    key = bytes(key, "iso8859-1")
+                    if value != None :
+                        value = bytes(value, "iso8859-1")
+                    return self._db.put(key, value, flags=flags, txn=txn)
+
+                def put_bytes(self, key, value, txn=None) :
+                    key = bytes(key, "iso8859-1")
+                    return self._db.put(key, value, txn=txn)
+
+                def get(self, key, txn=None, flags=0) :
+                    key = bytes(key, "iso8859-1")
+                    v = self._db.get(key, txn=txn, flags=flags)
+                    if v != None :
+                        v = v.decode("iso8859-1")
+                    return v
+
+                def get_bytes(self, key, txn=None, flags=0) :
+                    key = bytes(key, "iso8859-1")
+                    return self._db.get(key, txn=txn, flags=flags)
+
+                def delete(self, key, txn=None) :
+                    key = bytes(key, "iso8859-1")
+                    return self._db.delete(key, txn=txn)
+
+                def close (self) :
+                    return self._db.close()
+
+            self.db = db_py3k(self.db)
+        else :  # Python 2.x
+            pass
+
         # Initialize the table names list if this is a new database
         txn = self.env.txn_begin()
         try:
-            if not self.db.has_key(_table_names_key, txn):
-                self.db.put(_table_names_key, pickle.dumps([], 1), txn=txn)
+            if not getattr(self.db, "has_key")(_table_names_key, txn):
+                self.db.put_bytes(_table_names_key, pickle.dumps([], 1), txn=txn)
         # Yes, bare except
         except:
             txn.abort()
@@ -196,13 +262,13 @@
     def checkpoint(self, mins=0):
         try:
             self.env.txn_checkpoint(mins)
-        except DBIncompleteError:
+        except db.DBIncompleteError:
             pass
 
     def sync(self):
         try:
             self.db.sync()
-        except DBIncompleteError:
+        except db.DBIncompleteError:
             pass
 
     def _db_print(self) :
@@ -219,7 +285,7 @@
                 else:
                     cur.close()
                     return
-        except DBNotFoundError:
+        except db.DBNotFoundError:
             cur.close()
 
 
@@ -229,6 +295,7 @@
         raises TableDBError if it already exists or for other DB errors.
         """
         assert isinstance(columns, list)
+
         txn = None
         try:
             # checking sanity of the table and column names here on
@@ -242,27 +309,30 @@
                         "bad column name: contains reserved metastrings")
 
             columnlist_key = _columns_key(table)
-            if self.db.has_key(columnlist_key):
+            if getattr(self.db, "has_key")(columnlist_key):
                 raise TableAlreadyExists, "table already exists"
 
             txn = self.env.txn_begin()
             # store the table's column info
-            self.db.put(columnlist_key, pickle.dumps(columns, 1), txn=txn)
+            self.db.put_bytes(columnlist_key, pickle.dumps(columns, 1), txn=txn)
 
             # add the table name to the tablelist
-            tablelist = pickle.loads(self.db.get(_table_names_key, txn=txn,
-                                                 flags=DB_RMW))
+            tablelist = pickle.loads(self.db.get_bytes(_table_names_key, txn=txn,
+                                                 flags=db.DB_RMW))
             tablelist.append(table)
             # delete 1st, in case we opened with DB_DUP
             self.db.delete(_table_names_key, txn=txn)
-            self.db.put(_table_names_key, pickle.dumps(tablelist, 1), txn=txn)
+            self.db.put_bytes(_table_names_key, pickle.dumps(tablelist, 1), txn=txn)
 
             txn.commit()
             txn = None
-        except DBError, dberror:
+        except db.DBError, dberror:
             if txn:
                 txn.abort()
-            raise TableDBError, dberror[1]
+            if sys.version_info[0] < 3 :
+                raise TableDBError, dberror[1]
+            else :
+                raise TableDBError, dberror.args[1]
 
 
     def ListTableColumns(self, table):
@@ -274,9 +344,9 @@
             raise ValueError, "bad table name: contains reserved metastrings"
 
         columnlist_key = _columns_key(table)
-        if not self.db.has_key(columnlist_key):
+        if not getattr(self.db, "has_key")(columnlist_key):
             return []
-        pickledcolumnlist = self.db.get(columnlist_key)
+        pickledcolumnlist = self.db.get_bytes(columnlist_key)
         if pickledcolumnlist:
             return pickle.loads(pickledcolumnlist)
         else:
@@ -284,7 +354,7 @@
 
     def ListTables(self):
         """Return a list of tables in this database."""
-        pickledtablelist = self.db.get(_table_names_key)
+        pickledtablelist = self.db.get_get(_table_names_key)
         if pickledtablelist:
             return pickle.loads(pickledtablelist)
         else:
@@ -300,6 +370,7 @@
         all of its current columns.
         """
         assert isinstance(columns, list)
+
         try:
             self.CreateTable(table, columns)
         except TableAlreadyExists:
@@ -311,7 +382,7 @@
 
                 # load the current column list
                 oldcolumnlist = pickle.loads(
-                    self.db.get(columnlist_key, txn=txn, flags=DB_RMW))
+                    self.db.get_bytes(columnlist_key, txn=txn, flags=db.DB_RMW))
                 # create a hash table for fast lookups of column names in the
                 # loop below
                 oldcolumnhash = {}
@@ -329,7 +400,7 @@
                 if newcolumnlist != oldcolumnlist :
                     # delete the old one first since we opened with DB_DUP
                     self.db.delete(columnlist_key, txn=txn)
-                    self.db.put(columnlist_key,
+                    self.db.put_bytes(columnlist_key,
                                 pickle.dumps(newcolumnlist, 1),
                                 txn=txn)
 
@@ -337,18 +408,21 @@
                 txn = None
 
                 self.__load_column_info(table)
-            except DBError, dberror:
+            except db.DBError, dberror:
                 if txn:
                     txn.abort()
-                raise TableDBError, dberror[1]
+                if sys.version_info[0] < 3 :
+                    raise TableDBError, dberror[1]
+                else :
+                    raise TableDBError, dberror.args[1]
 
 
     def __load_column_info(self, table) :
         """initialize the self.__tablecolumns dict"""
         # check the column names
         try:
-            tcolpickles = self.db.get(_columns_key(table))
-        except DBNotFoundError:
+            tcolpickles = self.db.get_bytes(_columns_key(table))
+        except db.DBNotFoundError:
             raise TableDBError, "unknown table: %r" % (table,)
         if not tcolpickles:
             raise TableDBError, "unknown table: %r" % (table,)
@@ -366,11 +440,14 @@
                 blist.append(random.randint(0,255))
             newid = struct.pack('B'*_rowid_str_len, *blist)
 
+            if sys.version_info[0] >= 3 :
+                newid = newid.decode("iso8859-1")  # 8 bits
+
             # Guarantee uniqueness by adding this key to the database
             try:
                 self.db.put(_rowid_key(table, newid), None, txn=txn,
-                            flags=DB_NOOVERWRITE)
-            except DBKeyExistError:
+                            flags=db.DB_NOOVERWRITE)
+            except db.DBKeyExistError:
                 pass
             else:
                 unique = 1
@@ -382,9 +459,10 @@
         """Insert(table, datadict) - Insert a new row into the table
         using the keys+values from rowdict as the column values.
         """
+
         txn = None
         try:
-            if not self.db.has_key(_columns_key(table)):
+            if not getattr(self.db, "has_key")(_columns_key(table)):
                 raise TableDBError, "unknown table"
 
             # check the validity of each column name
@@ -406,7 +484,7 @@
             txn.commit()
             txn = None
 
-        except DBError, dberror:
+        except db.DBError, dberror:
             # WIBNI we could just abort the txn and re-raise the exception?
             # But no, because TableDBError is not related to DBError via
             # inheritance, so it would be backwards incompatible.  Do the next
@@ -415,7 +493,10 @@
             if txn:
                 txn.abort()
                 self.db.delete(_rowid_key(table, rowid))
-            raise TableDBError, dberror[1], info[2]
+            if sys.version_info[0] < 3 :
+                raise TableDBError, dberror[1], info[2]
+            else :
+                raise TableDBError, dberror.args[1], info[2]
 
 
     def Modify(self, table, conditions={}, mappings={}):
@@ -429,6 +510,7 @@
           condition callable expecting the data string as an argument and
           returning the new string for that column.
         """
+
         try:
             matching_rowids = self.__Select(table, [], conditions)
 
@@ -447,7 +529,7 @@
                             self.db.delete(
                                 _data_key(table, column, rowid),
                                 txn=txn)
-                        except DBNotFoundError:
+                        except db.DBNotFoundError:
                              # XXXXXXX row key somehow didn't exist, assume no
                              # error
                             dataitem = None
@@ -465,8 +547,11 @@
                         txn.abort()
                     raise
 
-        except DBError, dberror:
-            raise TableDBError, dberror[1]
+        except db.DBError, dberror:
+            if sys.version_info[0] < 3 :
+                raise TableDBError, dberror[1]
+            else :
+                raise TableDBError, dberror.args[1]
 
     def Delete(self, table, conditions={}):
         """Delete(table, conditions) - Delete items matching the given
@@ -476,6 +561,7 @@
           condition functions expecting the data string as an
           argument and returning a boolean.
         """
+
         try:
             matching_rowids = self.__Select(table, [], conditions)
 
@@ -490,23 +576,26 @@
                         try:
                             self.db.delete(_data_key(table, column, rowid),
                                            txn=txn)
-                        except DBNotFoundError:
+                        except db.DBNotFoundError:
                             # XXXXXXX column may not exist, assume no error
                             pass
 
                     try:
                         self.db.delete(_rowid_key(table, rowid), txn=txn)
-                    except DBNotFoundError:
+                    except db.DBNotFoundError:
                         # XXXXXXX row key somehow didn't exist, assume no error
                         pass
                     txn.commit()
                     txn = None
-                except DBError, dberror:
+                except db.DBError, dberror:
                     if txn:
                         txn.abort()
                     raise
-        except DBError, dberror:
-            raise TableDBError, dberror[1]
+        except db.DBError, dberror:
+            if sys.version_info[0] < 3 :
+                raise TableDBError, dberror[1]
+            else :
+                raise TableDBError, dberror.args[1]
 
 
     def Select(self, table, columns, conditions={}):
@@ -525,8 +614,11 @@
             if columns is None:
                 columns = self.__tablecolumns[table]
             matching_rowids = self.__Select(table, columns, conditions)
-        except DBError, dberror:
-            raise TableDBError, dberror[1]
+        except db.DBError, dberror:
+            if sys.version_info[0] < 3 :
+                raise TableDBError, dberror[1]
+            else :
+                raise TableDBError, dberror.args[1]
         # return the matches as a list of dictionaries
         return matching_rowids.values()
 
@@ -579,8 +671,19 @@
             # leave all unknown condition callables alone as equals
             return 0
 
-        conditionlist = conditions.items()
-        conditionlist.sort(cmp_conditions)
+        if sys.version_info[0] < 3 :
+            conditionlist = conditions.items()
+            conditionlist.sort(cmp_conditions)
+        else :  # Insertion Sort. Please, improve
+            conditionlist = []
+            for i in conditions.items() :
+                for j, k in enumerate(conditionlist) :
+                    r = cmp_conditions(k, i)
+                    if r == 1 :
+                        conditionlist.insert(j, i)
+                        break
+                else :
+                    conditionlist.append(i)
 
         # Apply conditions to column data to find what we want
         cur = self.db.cursor()
@@ -615,9 +718,13 @@
 
                     key, data = cur.next()
 
-            except DBError, dberror:
-                if dberror[0] != DB_NOTFOUND:
-                    raise
+            except db.DBError, dberror:
+                if sys.version_info[0] < 3 :
+                    if dberror[0] != db.DB_NOTFOUND:
+                        raise
+                else :
+                    if dberror.args[0] != db.DB_NOTFOUND:
+                        raise
                 continue
 
         cur.close()
@@ -635,9 +742,13 @@
                     try:
                         rowdata[column] = self.db.get(
                             _data_key(table, column, rowid))
-                    except DBError, dberror:
-                        if dberror[0] != DB_NOTFOUND:
-                            raise
+                    except db.DBError, dberror:
+                        if sys.version_info[0] < 3 :
+                            if dberror[0] != db.DB_NOTFOUND:
+                                raise
+                        else :
+                            if dberror.args[0] != db.DB_NOTFOUND:
+                                raise
                         rowdata[column] = None
 
         # return the matches
@@ -660,7 +771,7 @@
             while 1:
                 try:
                     key, data = cur.set_range(table_key)
-                except DBNotFoundError:
+                except db.DBNotFoundError:
                     break
                 # only delete items in this table
                 if key[:len(table_key)] != table_key:
@@ -672,7 +783,7 @@
             while 1:
                 try:
                     key, data = cur.set_range(table_key)
-                except DBNotFoundError:
+                except db.DBNotFoundError:
                     break
                 # only delete items in this table
                 if key[:len(table_key)] != table_key:
@@ -683,7 +794,7 @@
 
             # delete the tablename from the table name list
             tablelist = pickle.loads(
-                self.db.get(_table_names_key, txn=txn, flags=DB_RMW))
+                self.db.get_bytes(_table_names_key, txn=txn, flags=db.DB_RMW))
             try:
                 tablelist.remove(table)
             except ValueError:
@@ -691,7 +802,7 @@
                 pass
             # delete 1st, incase we opened with DB_DUP
             self.db.delete(_table_names_key, txn=txn)
-            self.db.put(_table_names_key, pickle.dumps(tablelist, 1), txn=txn)
+            self.db.put_bytes(_table_names_key, pickle.dumps(tablelist, 1), txn=txn)
 
             txn.commit()
             txn = None
@@ -699,7 +810,11 @@
             if self.__tablecolumns.has_key(table):
                 del self.__tablecolumns[table]
 
-        except DBError, dberror:
+        except db.DBError, dberror:
             if txn:
                 txn.abort()
-            raise TableDBError, dberror[1]
+            if sys.version_info[0] < 3 :
+                raise TableDBError, dberror[1]
+            else :
+                 raise TableDBError, dberror.args[1]
+
--- a/Lib/bsddb/test/test_dbtables.py	Wed Aug 13 14:26:20 2008 +0000
+++ b/Lib/bsddb/test/test_dbtables.py	Wed Aug 13 14:26:55 2008 +0000
@@ -37,6 +37,11 @@
     db_name = 'test-table.db'
 
     def setUp(self):
+        import sys
+        if sys.version_info[0] >= 3 :
+            from test_all import do_proxy_db_py3k
+            self._flag_proxy_db_py3k = do_proxy_db_py3k(False)
+
         self.testHomeDir = get_new_environment_path()
         self.tdb = dbtables.bsdTableDB(
             filename='tabletest.db', dbhome=self.testHomeDir, create=1)
@@ -44,6 +49,10 @@
     def tearDown(self):
         self.tdb.close()
         test_support.rmtree(self.testHomeDir)
+        import sys
+        if sys.version_info[0] >= 3 :
+            from test_all import do_proxy_db_py3k
+            do_proxy_db_py3k(self._flag_proxy_db_py3k)
 
     def test01(self):
         tabname = "test01"
@@ -53,7 +62,12 @@
         except dbtables.TableDBError:
             pass
         self.tdb.CreateTable(tabname, [colname])
-        self.tdb.Insert(tabname, {colname: pickle.dumps(3.14159, 1)})
+        import sys
+        if sys.version_info[0] < 3 :
+            self.tdb.Insert(tabname, {colname: pickle.dumps(3.14159, 1)})
+        else :
+            self.tdb.Insert(tabname, {colname: pickle.dumps(3.14159,
+                1).decode("iso8859-1")})  # 8 bits
 
         if verbose:
             self.tdb._db_print()
@@ -61,7 +75,11 @@
         values = self.tdb.Select(
             tabname, [colname], conditions={colname: None})
 
-        colval = pickle.loads(values[0][colname])
+        import sys
+        if sys.version_info[0] < 3 :
+            colval = pickle.loads(values[0][colname])
+        else :
+            colval = pickle.loads(bytes(values[0][colname], "iso8859-1"))
         self.assert_(colval > 3.141)
         self.assert_(colval < 3.142)
 
@@ -71,11 +89,23 @@
         col0 = 'coolness factor'
         col1 = 'but can it fly?'
         col2 = 'Species'
-        testinfo = [
-            {col0: pickle.dumps(8, 1), col1: 'no', col2: 'Penguin'},
-            {col0: pickle.dumps(-1, 1), col1: 'no', col2: 'Turkey'},
-            {col0: pickle.dumps(9, 1), col1: 'yes', col2: 'SR-71A Blackbird'}
-        ]
+
+        import sys
+        if sys.version_info[0] < 3 :
+            testinfo = [
+                {col0: pickle.dumps(8, 1), col1: 'no', col2: 'Penguin'},
+                {col0: pickle.dumps(-1, 1), col1: 'no', col2: 'Turkey'},
+                {col0: pickle.dumps(9, 1), col1: 'yes', col2: 'SR-71A Blackbird'}
+            ]
+        else :
+            testinfo = [
+                {col0: pickle.dumps(8, 1).decode("iso8859-1"),
+                    col1: 'no', col2: 'Penguin'},
+                {col0: pickle.dumps(-1, 1).decode("iso8859-1"),
+                    col1: 'no', col2: 'Turkey'},
+                {col0: pickle.dumps(9, 1).decode("iso8859-1"),
+                    col1: 'yes', col2: 'SR-71A Blackbird'}
+            ]
 
         try:
             self.tdb.Drop(tabname)
@@ -85,8 +115,14 @@
         for row in testinfo :
             self.tdb.Insert(tabname, row)
 
-        values = self.tdb.Select(tabname, [col2],
-            conditions={col0: lambda x: pickle.loads(x) >= 8})
+        import sys
+        if sys.version_info[0] < 3 :
+            values = self.tdb.Select(tabname, [col2],
+                conditions={col0: lambda x: pickle.loads(x) >= 8})
+        else :
+            values = self.tdb.Select(tabname, [col2],
+                conditions={col0: lambda x:
+                    pickle.loads(bytes(x, "iso8859-1")) >= 8})
 
         self.assertEqual(len(values), 2)
         if values[0]['Species'] == 'Penguin' :