db.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. '''
  2. Convenient access to various databases
  3. '''
  4. from collections import namedtuple
  5. from datetime import datetime
  6. import logging
  7. import pypyodbc
  # Global pypyodbc option (affects every connection in the process).
  # Presumably this stops pypyodbc from lower-casing column names, so the
  # namedtuple fields built from cursor.description keep the database's
  # casing -- TODO confirm against the pypyodbc version in use.
  pypyodbc.lowercase = False
  # Logger shared by the connection classes (CustomDb.connect logs through it).
  logger = logging.getLogger("database")
  class CustomDb(pypyodbc.Connection):
      """Connection to a database through pypyodbc.

      Subclasses describe how to reach a particular engine with class
      attributes:

      * ``dsn``: first positional argument given to ``pypyodbc.Connection``
        (the driver part of the connection string);
      * ``default_user`` / ``default_pwd``: used as ``uid`` / ``pwd`` when the
        caller does not supply them (empty defaults are not sent).

      The query helpers (``read``, ``read_all``, ``first``, ``exists``) return
      rows as ``namedtuple`` instances whose fields are named after the result
      columns. Extra positional arguments are sent to the driver as query
      parameters rather than being formatted into the SQL text.
      """

      # Not used by the methods below; presumably kept for subclasses/callers.
      _cache = {}
      default_name = ""
      drivername = ""
      # Connection settings, overridden by subclasses.
      dsn = ""
      default_user = ""
      default_pwd = ""

      def __init__(self, **kwargs):
          """Initialise the pypyodbc connection.

          Keyword arguments are forwarded to ``pypyodbc.Connection`` after
          ``uid`` / ``pwd`` have been filled in from the class defaults.
          """
          cls = self.__class__
          if "uid" not in kwargs and cls.default_user:
              kwargs["uid"] = cls.default_user
          if "pwd" not in kwargs and cls.default_pwd:
              kwargs["pwd"] = cls.default_pwd
          super(CustomDb, self).__init__(cls.dsn, **kwargs)

      @staticmethod
      def _masked_connect_string(connect_string):
          """Return *connect_string* with password values replaced by ``***``.

          Used when logging so credentials do not end up in log files.
          """
          parts = []
          for part in connect_string.split(";"):
              key, sep, _ = part.partition("=")
              if sep and key.strip().lower() in ("pwd", "password"):
                  part = "{}=***".format(key)
              parts.append(part)
          return ";".join(parts)

      def connect(self, *args, **kwargs):
          """Establish the connection to the database, logging the target.

          The logged connection string has its password masked.
          """
          logger.info(
              "Connection to %s: %s",
              self.__class__.__name__,
              self._masked_connect_string(self.connectString),
          )
          super(CustomDb, self).connect(*args, **kwargs)

      @staticmethod
      def _row_model(cursor):
          """Return a namedtuple class matching the result columns of *cursor*.

          Column names that are not identifiers become ``field_<index>``.
          ``rename=True`` then replaces any name namedtuple would still reject
          (keywords, duplicates such as two ``id`` columns from a join, leading
          underscores) with ``_<index>`` instead of raising ValueError.
          """
          fieldnames = [
              column[0] if column[0].isidentifier() else "field_{}".format(index)
              for index, column in enumerate(cursor.description)
          ]
          return namedtuple("Row", fieldnames, rename=True)

      def read(self, sql, *args):
          """Yield the rows selected by *sql* one at a time, as namedtuples.

          *args* are bound to the query's parameter placeholders. The cursor
          is closed when the rows are exhausted, when an error occurs, or when
          the generator is closed early (as ``first`` does).
          """
          cursor = self.execute(sql, *args)
          try:
              row = cursor.fetchone()
              rowmodel = self._row_model(cursor)
              while row is not None:
                  yield rowmodel(*row)
                  row = cursor.fetchone()
          finally:
              cursor.close()

      def read_all(self, sql, *args):
          """Return every row selected by *sql* as a list of namedtuples.

          *args* are bound to the query's parameter placeholders. The cursor
          is always closed before returning.
          """
          cursor = self.execute(sql, *args)
          try:
              rowmodel = self._row_model(cursor)
              return [rowmodel(*row) for row in cursor.fetchall()]
          finally:
              cursor.close()

      def first(self, sql, *args):
          """Return the first row selected by *sql*, or None if there is none."""
          rows = self.read(sql, *args)
          try:
              return next(rows, None)
          finally:
              # Close the generator now so read() releases its cursor
              # immediately instead of waiting for garbage collection.
              rows.close()

      def exists(self, sql, *args):
          """Return True if *sql* selects at least one row."""
          return self.first(sql, *args) is not None

      def execute(self, sql, *args):
          """Execute *sql* on a new cursor and return that cursor.

          *args*, when given, are passed to the driver as a parameter tuple.
          The caller owns the returned cursor and must close it. If execution
          fails, the cursor is closed before the error propagates.
          """
          cursor = self.cursor()
          try:
              if args:
                  cursor.execute(sql, tuple(args))
              else:
                  cursor.execute(sql)
          except Exception:
              cursor.close()
              raise
          return cursor
  class AccessDb(CustomDb):
      """Connection to a Microsoft Access database file."""

      dsn = "DRIVER={Microsoft Access Driver (*.mdb, *.accdb)};FIL={MS Access};"
      default_user = "admin"
      default_pwd = ""

      def __init__(self, dbpath, **kwargs):
          """Connect to the Access file at *dbpath* (sent to the driver as ``dbq``)."""
          super(AccessDb, self).__init__(dbq=dbpath, **kwargs)

      def assert_connected(self):
          """Check that the database answers a query.

          Raises:
              AssertionError: if ``MSysObjects`` yields no row. Errors raised
                  by the driver itself (connection, permissions) propagate
                  unchanged.
          """
          # exists() is False exactly when the query returns no row, which is
          # the failure to report. A row object is always truthy, so testing
          # the row itself could never detect it.
          if not self.exists("SELECT TOP 1 * FROM MSysObjects;"):
              raise AssertionError("Unable to connect to: {}".format(self.connectString))

      @staticmethod
      def format_date(dat, in_format="%Y-%m-%dT%H:%M:%S", out_format="%m/%d/%Y"):
          """Reformat the date *dat* as a string following *out_format*.

          ``datetime`` instances are formatted directly. Any other value is
          converted with ``str()`` and parsed with *in_format* (by default
          ``YYYY-MM-DDTHH:MM:SS``); ``ValueError`` is raised when it does not
          match. The default output is month/day/year.
          """
          # str(datetime) separates date and time with a space, so it would not
          # match the default "T"-separated in_format; format it directly.
          if isinstance(dat, datetime):
              return dat.strftime(out_format)
          return datetime.strptime(str(dat), in_format).strftime(out_format)

      @staticmethod
      def nz(val, default=""):
          """Return *val*, or *default* when *val* is falsy.

          NOTE(review): unlike Access's Nz(), which only replaces Null, this
          also replaces 0, "" and False -- confirm callers rely on that before
          narrowing it to ``None``.
          """
          return val if val else default
  class AccessSDb(AccessDb):
      """Access database opened with explicit credentials and a system database.

      The ``mdwpath`` file is handed to the driver as ``systemdb``; there are no
      default credentials, so ``uid`` / ``pwd`` are always passed explicitly.
      """

      dsn = "DRIVER={Microsoft Access Driver (*.mdb, *.accdb)};FIL={MS Access};"
      default_user = ""
      default_pwd = ""

      def __init__(self, dbpath, mdwpath, uid, pwd, **kwargs):
          connection_options = {"uid": uid, "pwd": pwd, "systemdb": mdwpath}
          super().__init__(dbpath, **connection_options, **kwargs)
  class OracleDb(CustomDb):
      """Oracle database reached through the ``Oracle dans ORA102`` ODBC driver.

      *dbname* is sent to the driver as ``dbq``; *user* / *pwd* as ``uid`` / ``pwd``.
      """

      dsn = "DRIVER={Oracle dans ORA102};"

      def __init__(self, dbname, user, pwd, **kwargs):
          credentials = {"uid": user, "pwd": pwd}
          super().__init__(dbq=dbname, **credentials, **kwargs)
  class SqlServerDb(CustomDb):
      """SQL Server database reached through the ``SQL Server`` ODBC driver.

      *server* and *dbname* are sent as ``server`` / ``database``; *user* and
      *pwd* as ``uid`` / ``pwd``.
      """

      dsn = "DRIVER={SQL Server};"

      def __init__(self, server, dbname, user, pwd, **kwargs):
          target = {"server": server, "database": dbname}
          super().__init__(**target, uid=user, pwd=pwd, **kwargs)
  class PostgresDb(CustomDb):
      """PostgreSQL database reached through the ``PostgreSQL Unicode`` ODBC driver.

      *server* and *dbname* are sent as ``server`` / ``database``; *user* and
      *pwd* as ``uid`` / ``pwd``.
      """

      dsn = "DRIVER={PostgreSQL Unicode};"
      # Class-level placeholders; the constructor below does not read them.
      server = ""
      db = ""
      user = ""
      pwd = ""

      def __init__(self, server, dbname, user, pwd, **kwargs):
          options = dict(server=server, database=dbname, uid=user, pwd=pwd)
          super().__init__(**options, **kwargs)
  100. # class SqliteDb(CustomDb):
  101. # drivername = "QODBC"
  102. # dsn = "DRIVER={{Microsoft Access Driver (*.mdb, *.accdb)}};FIL={{MS Access}}"
  103. # default_user = "admin"
  104. # pwd = ""
  105. # def __init__(self, dbpath, **kwargs):
  106. # CustomDb.__init__(self, dbq=dbpath, **kwargs)
  107. #
  108. ### SQL Helpers ###
  class SQLHelper:
      """Abstract SQL code generator.

      Every method raises NotImplementedError; dialect-specific subclasses
      supply the real implementations.
      """

      @classmethod
      def _sql_format(cls, val):
          """Render *val* as a literal that can be inlined in an SQL statement."""
          raise NotImplementedError()

      @classmethod
      def select(cls, where=""):
          """Return a SELECT statement, filtered by the *where* condition if given."""
          raise NotImplementedError()

      @classmethod
      def update(cls, tblname, data, where):
          """Return an UPDATE of *tblname* setting *data* on rows matching *where*."""
          raise NotImplementedError()

      @classmethod
      def insert(cls, tblname, data):
          """Return an INSERT of the *data* mapping into *tblname*."""
          raise NotImplementedError()

      @classmethod
      def delete(cls, tblname, where):
          """Return a DELETE from *tblname* of the rows matching *where*."""
          raise NotImplementedError()
  class AccessSqlHelper(SQLHelper):
      """SQL helper producing MS Access SQL with values inlined as literals."""

      @classmethod
      def _sql_format(cls, val):
          """Render *val* as an MS Access SQL literal.

          * ``None`` -> ``Null``
          * ``str`` -> double-quoted, embedded ``"`` doubled so the literal
            cannot be terminated early by the value
          * ``datetime`` -> ``#YYYY-MM-DD HH:MM:SS#``
          * ``date`` -> ``#YYYY-MM-DD#`` (unquoted, Access would evaluate it
            as a subtraction)
          * anything else (int, bool, float, ...) -> its ``str.format`` text
          """
          from datetime import date  # the module imports only the datetime class

          if val is None:
              return "Null"
          if isinstance(val, str):
              return '"{}"'.format(val.replace('"', '""'))
          # datetime is a subclass of date, so it must be tested first.
          if isinstance(val, datetime):
              return "#{:%Y-%m-%d %H:%M:%S}#".format(val)
          if isinstance(val, date):
              return "#{:%Y-%m-%d}#".format(val)
          return "{}".format(val)

      @classmethod
      def _where_clause(cls, where):
          """Join the *where* mapping into ``AND``-ed conditions.

          ``None`` values become ``col IS Null`` because ``col = Null`` never
          matches any row.

          Raises:
              ValueError: if *where* is empty (the statement would be invalid).
          """
          if not where:
              raise ValueError("'where' must contain at least one condition")
          conditions = []
          for key in where:
              value = where[key]
              if value is None:
                  conditions.append("{} IS Null".format(key))
              else:
                  conditions.append("{} = {}".format(key, cls._sql_format(value)))
          return " AND ".join(conditions)

      @classmethod
      def select(cls, where=""):
          """Return ``SELECT * FROM <cls._tblname>``, plus ``WHERE <where>`` if given.

          *where* is inserted verbatim. Subclasses must define ``_tblname``;
          this class does not.
          """
          sql = "SELECT * FROM {}".format(cls._tblname)
          if where:
              sql = "{} WHERE {}".format(sql, where)
          return sql

      @classmethod
      def update(cls, tblname, data, where):
          """Return an UPDATE of *tblname* setting *data* on rows matching *where*.

          Raises:
              ValueError: if *data* or *where* is empty.
          """
          if not data:
              raise ValueError("'data' must contain at least one column")
          assignments = ",".join(
              "{} = {}".format(key, cls._sql_format(data[key])) for key in data
          )
          return "UPDATE {} SET {} WHERE {}".format(
              tblname, assignments, cls._where_clause(where)
          )

      @classmethod
      def insert(cls, tblname, data):
          """Return an INSERT of the *data* mapping (column -> value) into *tblname*.

          Raises:
              ValueError: if *data* is empty.
          """
          if not data:
              raise ValueError("'data' must contain at least one column")
          columns = list(data)
          return "INSERT INTO {} ({}) VALUES ({})".format(
              tblname,
              ",".join(columns),
              ",".join(cls._sql_format(data[key]) for key in columns),
          )

      @classmethod
      def delete(cls, tblname, where):
          """Return a DELETE from *tblname* of the rows matching *where*.

          Raises:
              ValueError: if *where* is empty (never emits an unfiltered DELETE).
          """
          return "DELETE * FROM {} WHERE {}".format(tblname, cls._where_clause(where))