ソースを参照

implements sqlalchemy

Olivier Massot 4 年 前
コミット
e03cd9434e
8 ファイル変更164 行追加156 行削除
  1. 3 1
      core/constants.py
  2. 18 0
      core/db.py
  3. 11 2
      core/file_utilities.py
  4. 68 54
      core/models.py
  5. 46 94
      core/repositories.py
  6. 6 1
      main.py
  7. 1 0
      requirements.txt
  8. 11 4
      ui/window.py

+ 3 - 1
core/constants.py

@@ -18,4 +18,6 @@ DATA_DIR = APP_ROOT / 'data'
 
 LOGGER_NAME = "mew"
 LOG_DIR = DATA_DIR
-LOGGER_LEVEL = 0
+LOGGER_LEVEL = 0
+
+SQL_ALCHEMY_VERBOSE = 1

+ 18 - 0
core/db.py

@@ -0,0 +1,18 @@
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session
+
+from core import constants
+
+engine = create_engine(f'sqlite:///{constants.DB_PATH}', echo=constants.SQL_ALCHEMY_VERBOSE)
+Base = declarative_base()
+
+
+def session():
+    return Session(engine)
+
+
+def create():
+    if constants.DB_PATH.exists():
+        raise FileExistsError('A db file already exists')
+    Base.metadata.create_all(engine)

+ 11 - 2
core/file_utilities.py

@@ -22,9 +22,18 @@ def is_media_file_ext(ext):
 def hash_file(filename):
     """ return a SHA256 hash for the given file """
     h = hashlib.sha256()
-    b = bytearray(128*1024)
+    b = bytearray(128 * 1024)
     mv = memoryview(b)
     with open(filename, 'rb', buffering=0) as f:
         for n in iter(lambda: f.readinto(mv), 0):
             h.update(mv[:n])
-    return h.hexdigest()
+    return h.hexdigest()
+
+
+def is_subdir_of(subject, other):
+    """ is subject a subdirectory of other """
+    if not subject.parent:
+        return False
+    if subject.parent == other:
+        return True
+    return is_subdir_of(subject.parent, other)

+ 68 - 54
core/models.py

@@ -1,84 +1,98 @@
+from datetime import datetime
 
-class Model:
-    def __init__(self, id_=None):
-        self.id = id_
+from sqlalchemy import Column, Integer, String, DateTime, Boolean
 
-    def as_fields_and_values(self, exclude_id=False):
-        fields, values = [], []
-        for attr, val in self.__dict__.items():
-            if attr[0] == '_' or val is None or (exclude_id and attr == 'id'):
-                continue
-            fields.append(attr)
-            values.append(val)
-        return fields, values
+from core import db
+
+
+class Model(db.Base):
+    __abstract__ = True
+
+    id = Column(Integer, primary_key=True)
+    created_on = Column(DateTime, default=datetime.now())
+    deleted = Column(Boolean, nullable=False, default=False)
+
+    def __repr__(self):
+        return f"<{self.__class__.__name__} {self.id}>"
 
 
 class MusicFolder(Model):
+    __tablename__ = 'MusicFolders'
+
     STATUS_UNKNOWN = 0
     STATUS_FOUND = 1
     STATUS_UNAVAILABLE = 2
     STATUS_DELETED = 3
 
-    def __init__(self, id_=None, path_=None, last_scan=None, status=None):
-        super().__init__(id_)
-        self.path = path_
-        self.last_scan = last_scan
-        self.status = status if status is not None else MusicFolder.STATUS_UNKNOWN
+    path = Column(String, nullable=False)
+    last_scan = Column(DateTime)
+    status = Column(Integer, default=0)
 
 
 class Profile(Model):
-    def __init__(self, id_=None, name=None, created_on=None):
-        super().__init__(id_)
-        self.name = name
-        self.created_on = created_on
+    __tablename__ = 'Profiles'
+
+    name = Column(String, nullable=False)
 
 
 class Tag(Model):
-    def __init__(self, id_=None, label=None, color=None, deleted=0):
-        super().__init__(id_)
-        self.label = label
-        self.color = color
-        self.deleted = deleted
+    __tablename__ = 'Tags'
+
+    label = Column(String, nullable=False)
+    color = Column(String, nullable=False, default="#6666ff")
 
 
 class Track(Model):
+    __tablename__ = 'Tracks'
+
     STATUS_UNKNOWN = 0
     STATUS_FOUND = 1
     STATUS_UNAVAILABLE = 2
     STATUS_UNREADABLE = 3
 
-    def __init__(self, id_=None, profile_id=None, music_folder_id=None, title=None,
-                 format_=None, artist=None, album=None, track_num=None, year=None,
-                 duration=None, size=None, note=None, status=None, path_=None,
-                 hash_=None, origin=None):
-        super().__init__(id_)
-        self.profile_id = profile_id
-        self.music_folder_id = music_folder_id
-        self.title = title
-        self.format = format_
-        self.artist = artist
-        self.album = album
-        self.track_num = track_num
-        self.year = year
-        self.duration = duration
-        self.size = size
-        self.note = note
-        self.status = status
-        self.path = path_
-        self.hash = hash_
-        self.origin = origin
+    profile_id = Column(Integer)
+    music_folder_id = Column(Integer)
+
+    title = Column(String)
+    format = Column(String)
+    artist = Column(String)
+    album = Column(String)
+    track_num = Column(Integer)
+    year = Column(Integer)
+    duration = Column(Integer)
+    size = Column(Integer)
+    note = Column(String)
+    status = Column(Integer, nullable=False, default=0)
+    path = Column(String, nullable=False)
+    hash = Column(String, nullable=False)
+    origin = Column(String)
 
 
 class TrackTag(Model):
-    def __init__(self, id_=None, track_id=None, tag_id=None):
-        super().__init__(id_)
-        self.track_id = track_id
-        self.tag_id = tag_id
+    __tablename__ = 'TracksTags'
+
+    track_id = Column(Integer)
+    tag_id = Column(Integer)
 
 
 class Session(Model):
-    def __init__(self, id_=None, name=None, date_=None, notes=None):
-        super().__init__(id_)
-        self.name = name
-        self.date = date_
-        self.notes = notes
+    __tablename__ = 'Sessions'
+
+    name = Column(String, nullable=False)
+    date = Column(DateTime)
+    notes = Column(String)
+
+
+class SessionTrack(Model):
+    __tablename__ = 'SessionsTracks'
+
+    track_id = Column(Integer)
+    session_id = Column(Integer)
+
+
+if __name__ == "__main__":
+    from core import constants
+
+    session = db.session()
+    for track in session.query(Track).all():
+        print(track)

+ 46 - 94
core/repositories.py

@@ -1,79 +1,51 @@
-import sqlite3
 from abc import abstractmethod
 
-from core import constants
-from core.models import MusicFolder, Track, Tag, Profile
+from core import db
+from core.models import MusicFolder, Track, Tag, TrackTag, SessionTrack, Session
 
 
 class Repository:
-    TABLE_NAME = None
     MODEL_CLS = None
 
-    @abstractmethod
     def __init__(self):
-        self.cnn = sqlite3.connect(constants.DB_PATH)
+        self.session = db.session()
 
-    def execute(self, sql, *parameters):
-        cur = self.cnn.cursor()
-        cur.execute(sql, parameters)
-        return cur
+    def query(self):
+        return self.session.query(self.MODEL_CLS)
+
+    def commit(self):
+        self.session.commit()
+
+    def rollback(self):
+        self.session.rollback()
 
     def get_by_id(self, id_):
-        cur = self.execute(
-            f"SELECT * FROM {self.TABLE_NAME} WHERE id=?;",
-            id_
-        )
-        return self.MODEL_CLS(**cur.fetchone())
+        return self.query().filter(id == id_).first()
 
     def get_all(self):
-        cur = self.execute(f"SELECT * FROM {self.TABLE_NAME};", )
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
+        return self.query().all()
 
     def get_by(self, field, val):
-        cur = self.execute(f"SELECT * FROM {self.TABLE_NAME} WHERE {field}=?;", val)
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
+        return self.query().filter(field == val).all()
 
-    def get_by_raw_sql(self, sql, parameters=None):
-        parameters = parameters if parameters is not None else []
-        cur = self.execute(sql, parameters)
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
+    def exists(self, field, val):
+        return self.query().filter(field == val).exists()
 
     def create(self, model, commit=False):
-        fields, values = model.as_fields_and_values(True)
-        self.execute(
-            f"INSERT INTO {self.TABLE_NAME} ({', '.join(fields)}) VALUES ({', '.join(['?' for v in values])});",
-            *values
-        )
+        self.session.add(model)
         if commit:
-            self.commit()
-
-    def update(self, model, commit=False):
-        fields, values = model.as_fields_and_values(True)
-        values.append(model.id)
-        self.execute(
-            f"UPDATE {self.TABLE_NAME} SET {', '.join([f'{f}=?' for f in fields])} WHERE id=?;",
-            *values
-        )
-        if commit:
-            self.commit()
+            self.session.commit()
 
     def delete(self, model, commit=False):
-        self.execute(f"DELETE FROM {self.TABLE_NAME} WHERE id=?;", model.id)
+        model.delete()
         if commit:
             self.commit()
 
-    def commit(self):
-        self.cnn.commit()
-
-    def rollback(self):
-        self.cnn.rollback()
-
     def __del__(self):
-        self.cnn.close()
+        self.session.close()
 
 
 class MusicFolderRepository(Repository):
-    TABLE_NAME = "MusicFolders"
     MODEL_CLS = MusicFolder
 
     def __init__(self):
@@ -81,55 +53,38 @@ class MusicFolderRepository(Repository):
 
 
 class TagRepository(Repository):
-    TABLE_NAME = "Tags"
     MODEL_CLS = Tag
 
     def __init__(self):
         super().__init__()
 
-    def get_by_track(self, track):
-        cur = self.execute(
-            f"""SELECT * 
-                FROM Tags t
-                INNER JOIN TracksTags tt 
-                    ON tt.tag_id = t.id
-                WHERE tt.track_id=?;""", track.id)
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
+    def get_by_track_id(self, track_id):
+        return self.session.query(Tag)\
+            .join(TrackTag, Tag.id == TrackTag.tag_id)\
+            .filter(TrackTag.track_id == track_id).\
+            all()
 
 
 class TrackRepository(Repository):
-    TABLE_NAME = "Tracks"
     MODEL_CLS = Track
 
-    def __init__(self):
-        super().__init__()
+    def get_by_tag_id(self, tag_id):
+        return self.session.query(Track)\
+            .join(TrackTag, Track.id == TrackTag.track_id)\
+            .filter(TrackTag.tag_id == tag_id).\
+            all()
+
+    def get_by_tag_ids(self, tag_ids):
+        return self.session.query(Track)\
+            .join(TrackTag, Track.id == TrackTag.track_id)\
+            .filter(TrackTag.tag_id.in_(tag_ids)).\
+            all()
 
-    def get_by_tag(self, tag):
-        cur = self.execute(
-            f"""SELECT * 
-                FROM Tracks t
-                INNER JOIN TracksTags tt 
-                    ON tt.track_id = t.id
-                WHERE tt.tag_id=?;""", tag.id)
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
-
-    def get_by_tags(self, tags):
-        cur = self.execute(
-            f"""SELECT * 
-                FROM Tracks t
-                INNER JOIN TracksTags tt 
-                    ON tt.track_id = t.id
-                WHERE tt.tag_id in ({', '.join(['?' for _ in tags])};""", *[tag.id for tag in tags])
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
-
-    def get_by_session(self, session):
-        cur = self.execute(
-            f"""SELECT * 
-                FROM Tracks t
-                INNER JOIN SessionsTracks st
-                    ON st.track_id = t.id
-                WHERE st.session_id=?;""", session.id)
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
+    def get_by_session_id(self, session_id):
+        return self.session.query(Track)\
+            .join(SessionTrack, Track.id == SessionTrack.track_id)\
+            .filter(SessionTrack.session_id == session_id).\
+            all()
 
 
 class SessionRepository(Repository):
@@ -139,11 +94,8 @@ class SessionRepository(Repository):
     def __init__(self):
         super().__init__()
 
-    def get_by_track(self, track):
-        cur = self.execute(
-            f"""SELECT * 
-                FROM Tracks t
-                INNER JOIN SessionsTracks st
-                    ON st.session_id = t.id
-                WHERE st.track_id=?;""", track.id)
-        return [self.MODEL_CLS(*row) for row in cur.fetchall()]
+    def get_by_track_id(self, track_id):
+        return self.session.query(Session)\
+            .join(SessionTrack, Track.id == SessionTrack.track_id)\
+            .filter(SessionTrack.track_id == track_id).\
+            all()

+ 6 - 1
main.py

@@ -12,7 +12,7 @@ import traceback
 from PyQt5.Qt import QApplication
 from PyQt5.QtWidgets import QMessageBox
 
-from core import logging_
+from core import logging_, db
 from core.logging_ import Logger
 from ui.window import MainWindow
 
@@ -33,6 +33,11 @@ def err_handler(typ, value, trace):
     sys_err(typ, value, trace)
 sys.excepthook = err_handler
 
+# Create db if not existing
+try:
+    db.create()
+except FileExistsError:
+    pass
 
 # Start UI
 app = QApplication(sys.argv)

+ 1 - 0
requirements.txt

@@ -1,5 +1,6 @@
 PyQt5
 PyQt5-stubs
+sqlalchemy
 python-vlc~=3.0
 path.py
 pyyaml

+ 11 - 4
ui/window.py

@@ -11,6 +11,7 @@ from path import Path
 from PyQt5.QtGui import QIcon
 from PyQt5.QtWidgets import QMainWindow, QListWidgetItem, QTableWidgetItem, QFileDialog, QDialog, QMessageBox
 
+from core.file_utilities import is_subdir_of
 from core.models import MusicFolder
 from core.repositories import MusicFolderRepository
 from ui.qt.main_ui import Ui_mainWindow
@@ -51,7 +52,6 @@ class MainWindow(QMainWindow):
         self.ui.musicFoldersRemoveButton.clicked.connect(self.remove_music_folder)
         self.populate_music_folders_table()
 
-
     def menu_item_selected(self, e):
         self.ui.stack.setCurrentIndex(e.index)
 
@@ -81,10 +81,17 @@ class MainWindow(QMainWindow):
         repo = MusicFolderRepository()
 
         music_folders = repo.get_all()
-        if any(re.match(str(path), f"^{f.path}.*") for f in music_folders):
-            QMessageBox.warning(self, "Ajout invalide",  "Ce dossier ou un dossier le contenant ont déjà été ajoutés")
 
-        folder = MusicFolder(None, path)
+        for folder in music_folders:
+            if path == Path(folder.path):
+                QMessageBox.warning(self, "Ajout invalide",  "Ce dossier a déjà été ajouté")
+                return
+
+            if is_subdir_of(path, Path(folder.path)):
+                QMessageBox.warning(self, "Ajout invalide",  "Ce dossier est contenu dans un dossier existant")
+                return
+
+        folder = MusicFolder(path=path)
         repo.create(folder, True)
 
         self.populate_music_folders_table()