|
|
@@ -1,79 +1,51 @@
|
|
|
-import sqlite3
|
|
|
from abc import abstractmethod
|
|
|
|
|
|
-from core import constants
|
|
|
-from core.models import MusicFolder, Track, Tag, Profile
|
|
|
+from core import db
|
|
|
+from core.models import MusicFolder, Track, Tag, TrackTag, SessionTrack, Session
|
|
|
|
|
|
|
|
|
class Repository:
|
|
|
- TABLE_NAME = None
|
|
|
MODEL_CLS = None
|
|
|
|
|
|
- @abstractmethod
|
|
|
def __init__(self):
|
|
|
- self.cnn = sqlite3.connect(constants.DB_PATH)
|
|
|
+ self.session = db.session()
|
|
|
|
|
|
- def execute(self, sql, *parameters):
|
|
|
- cur = self.cnn.cursor()
|
|
|
- cur.execute(sql, parameters)
|
|
|
- return cur
|
|
|
+ def query(self):
|
|
|
+ return self.session.query(self.MODEL_CLS)
|
|
|
+
|
|
|
+ def commit(self):
|
|
|
+ self.session.commit()
|
|
|
+
|
|
|
+ def rollback(self):
|
|
|
+ self.session.rollback()
|
|
|
|
|
|
def get_by_id(self, id_):
|
|
|
- cur = self.execute(
|
|
|
- f"SELECT * FROM {self.TABLE_NAME} WHERE id=?;",
|
|
|
- id_
|
|
|
- )
|
|
|
- return self.MODEL_CLS(**cur.fetchone())
|
|
|
+ return self.query().filter(id == id_).first()
|
|
|
|
|
|
def get_all(self):
|
|
|
- cur = self.execute(f"SELECT * FROM {self.TABLE_NAME};", )
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
+ return self.query().all()
|
|
|
|
|
|
def get_by(self, field, val):
|
|
|
- cur = self.execute(f"SELECT * FROM {self.TABLE_NAME} WHERE {field}=?;", val)
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
+ return self.query().filter(field == val).all()
|
|
|
|
|
|
- def get_by_raw_sql(self, sql, parameters=None):
|
|
|
- parameters = parameters if parameters is not None else []
|
|
|
- cur = self.execute(sql, parameters)
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
+ def exists(self, field, val):
|
|
|
+ return self.query().filter(field == val).exists()
|
|
|
|
|
|
def create(self, model, commit=False):
|
|
|
- fields, values = model.as_fields_and_values(True)
|
|
|
- self.execute(
|
|
|
- f"INSERT INTO {self.TABLE_NAME} ({', '.join(fields)}) VALUES ({', '.join(['?' for v in values])});",
|
|
|
- *values
|
|
|
- )
|
|
|
+ self.session.add(model)
|
|
|
if commit:
|
|
|
- self.commit()
|
|
|
-
|
|
|
- def update(self, model, commit=False):
|
|
|
- fields, values = model.as_fields_and_values(True)
|
|
|
- values.append(model.id)
|
|
|
- self.execute(
|
|
|
- f"UPDATE {self.TABLE_NAME} SET {', '.join([f'{f}=?' for f in fields])} WHERE id=?;",
|
|
|
- *values
|
|
|
- )
|
|
|
- if commit:
|
|
|
- self.commit()
|
|
|
+ self.session.commit()
|
|
|
|
|
|
def delete(self, model, commit=False):
|
|
|
- self.execute(f"DELETE FROM {self.TABLE_NAME} WHERE id=?;", model.id)
|
|
|
+ model.delete()
|
|
|
if commit:
|
|
|
self.commit()
|
|
|
|
|
|
- def commit(self):
|
|
|
- self.cnn.commit()
|
|
|
-
|
|
|
- def rollback(self):
|
|
|
- self.cnn.rollback()
|
|
|
-
|
|
|
def __del__(self):
|
|
|
- self.cnn.close()
|
|
|
+ self.session.close()
|
|
|
|
|
|
|
|
|
class MusicFolderRepository(Repository):
|
|
|
- TABLE_NAME = "MusicFolders"
|
|
|
MODEL_CLS = MusicFolder
|
|
|
|
|
|
def __init__(self):
|
|
|
@@ -81,55 +53,38 @@ class MusicFolderRepository(Repository):
|
|
|
|
|
|
|
|
|
class TagRepository(Repository):
|
|
|
- TABLE_NAME = "Tags"
|
|
|
MODEL_CLS = Tag
|
|
|
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
|
|
|
- def get_by_track(self, track):
|
|
|
- cur = self.execute(
|
|
|
- f"""SELECT *
|
|
|
- FROM Tags t
|
|
|
- INNER JOIN TracksTags tt
|
|
|
- ON tt.tag_id = t.id
|
|
|
- WHERE tt.track_id=?;""", track.id)
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
+ def get_by_track_id(self, track_id):
|
|
|
+ return self.session.query(Tag)\
|
|
|
+ .join(TrackTag, Tag.id == TrackTag.tag_id)\
|
|
|
+ .filter(TrackTag.track_id == track_id).\
|
|
|
+ all()
|
|
|
|
|
|
|
|
|
class TrackRepository(Repository):
|
|
|
- TABLE_NAME = "Tracks"
|
|
|
MODEL_CLS = Track
|
|
|
|
|
|
- def __init__(self):
|
|
|
- super().__init__()
|
|
|
+ def get_by_tag_id(self, tag_id):
|
|
|
+ return self.session.query(Track)\
|
|
|
+ .join(TrackTag, Track.id == TrackTag.track_id)\
|
|
|
+ .filter(TrackTag.tag_id == tag_id).\
|
|
|
+ all()
|
|
|
+
|
|
|
+ def get_by_tag_ids(self, tag_ids):
|
|
|
+ return self.session.query(Track)\
|
|
|
+ .join(TrackTag, Track.id == TrackTag.track_id)\
|
|
|
+ .filter(TrackTag.tag_id.in_(tag_ids)).\
|
|
|
+ all()
|
|
|
|
|
|
- def get_by_tag(self, tag):
|
|
|
- cur = self.execute(
|
|
|
- f"""SELECT *
|
|
|
- FROM Tracks t
|
|
|
- INNER JOIN TracksTags tt
|
|
|
- ON tt.track_id = t.id
|
|
|
- WHERE tt.tag_id=?;""", tag.id)
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
-
|
|
|
- def get_by_tags(self, tags):
|
|
|
- cur = self.execute(
|
|
|
- f"""SELECT *
|
|
|
- FROM Tracks t
|
|
|
- INNER JOIN TracksTags tt
|
|
|
- ON tt.track_id = t.id
|
|
|
- WHERE tt.tag_id in ({', '.join(['?' for _ in tags])};""", *[tag.id for tag in tags])
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
-
|
|
|
- def get_by_session(self, session):
|
|
|
- cur = self.execute(
|
|
|
- f"""SELECT *
|
|
|
- FROM Tracks t
|
|
|
- INNER JOIN SessionsTracks st
|
|
|
- ON st.track_id = t.id
|
|
|
- WHERE st.session_id=?;""", session.id)
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
+ def get_by_session_id(self, session_id):
|
|
|
+ return self.session.query(Track)\
|
|
|
+ .join(SessionTrack, Track.id == SessionTrack.track_id)\
|
|
|
+ .filter(SessionTrack.session_id == session_id).\
|
|
|
+ all()
|
|
|
|
|
|
|
|
|
class SessionRepository(Repository):
|
|
|
@@ -139,11 +94,8 @@ class SessionRepository(Repository):
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
|
|
|
- def get_by_track(self, track):
|
|
|
- cur = self.execute(
|
|
|
- f"""SELECT *
|
|
|
- FROM Tracks t
|
|
|
- INNER JOIN SessionsTracks st
|
|
|
- ON st.session_id = t.id
|
|
|
- WHERE st.track_id=?;""", track.id)
|
|
|
- return [self.MODEL_CLS(*row) for row in cur.fetchall()]
|
|
|
+ def get_by_track_id(self, track_id):
|
|
|
+ return self.session.query(Session)\
|
|
|
+ .join(SessionTrack, Track.id == SessionTrack.track_id)\
|
|
|
+ .filter(SessionTrack.track_id == track_id).\
|
|
|
+ all()
|