You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

87 lines
2.8 KiB

  1. # Simple example presenting how persistent ID can be used to pickle
  2. # external objects by reference.
  3. import pickle
  4. import sqlite3
  5. from collections import namedtuple
  6. # Simple class representing a record in our database.
  7. MemoRecord = namedtuple("MemoRecord", "key, task")
  8. class DBPickler(pickle.Pickler):
  9. def persistent_id(self, obj):
  10. # Instead of pickling MemoRecord as a regular class instance, we emit a
  11. # persistent ID.
  12. if isinstance(obj, MemoRecord):
  13. # Here, our persistent ID is simply a tuple, containing a tag and a
  14. # key, which refers to a specific record in the database.
  15. return ("MemoRecord", obj.key)
  16. else:
  17. # If obj does not have a persistent ID, return None. This means obj
  18. # needs to be pickled as usual.
  19. return None
  20. class DBUnpickler(pickle.Unpickler):
  21. def __init__(self, file, connection):
  22. super().__init__(file)
  23. self.connection = connection
  24. def persistent_load(self, pid):
  25. # This method is invoked whenever a persistent ID is encountered.
  26. # Here, pid is the tuple returned by DBPickler.
  27. cursor = self.connection.cursor()
  28. type_tag, key_id = pid
  29. if type_tag == "MemoRecord":
  30. # Fetch the referenced record from the database and return it.
  31. cursor.execute("SELECT * FROM memos WHERE key=?", (str(key_id),))
  32. key, task = cursor.fetchone()
  33. return MemoRecord(key, task)
  34. else:
  35. # Always raises an error if you cannot return the correct object.
  36. # Otherwise, the unpickler will think None is the object referenced
  37. # by the persistent ID.
  38. raise pickle.UnpicklingError("unsupported persistent object")
  39. def main():
  40. import io
  41. import pprint
  42. # Initialize and populate our database.
  43. conn = sqlite3.connect(":memory:")
  44. cursor = conn.cursor()
  45. cursor.execute("CREATE TABLE memos(key INTEGER PRIMARY KEY, task TEXT)")
  46. tasks = (
  47. 'give food to fish',
  48. 'prepare group meeting',
  49. 'fight with a zebra',
  50. )
  51. for task in tasks:
  52. cursor.execute("INSERT INTO memos VALUES(NULL, ?)", (task,))
  53. # Fetch the records to be pickled.
  54. cursor.execute("SELECT * FROM memos")
  55. memos = [MemoRecord(key, task) for key, task in cursor]
  56. # Save the records using our custom DBPickler.
  57. file = io.BytesIO()
  58. DBPickler(file).dump(memos)
  59. print("Pickled records:")
  60. pprint.pprint(memos)
  61. # Update a record, just for good measure.
  62. cursor.execute("UPDATE memos SET task='learn italian' WHERE key=1")
  63. # Load the records from the pickle data stream.
  64. file.seek(0)
  65. memos = DBUnpickler(file, conn).load()
  66. print("Unpickled records:")
  67. pprint.pprint(memos)
  68. if __name__ == '__main__':
  69. main()