| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- from abc import abstractmethod
- from core import db
- from core.models import MusicFolder, Track, Tag, TrackTag, SessionTrack, Session
class Repository:
    """Base repository wrapping one database session for a single model class.

    Subclasses set ``MODEL_CLS`` to the model they manage. Each instance
    obtains its own session from ``db.session()`` and closes it when the
    instance is finalized.
    """

    # Model class queried by ``query()``; subclasses are expected to override.
    MODEL_CLS = None

    def __init__(self):
        self.session = db.session()

    def query(self):
        """Return a query over ``MODEL_CLS`` bound to this repository's session."""
        return self.session.query(self.MODEL_CLS)

    def commit(self):
        """Commit the current transaction on the underlying session."""
        self.session.commit()

    def rollback(self):
        """Roll back the current transaction on the underlying session."""
        self.session.rollback()

    def get_by_id(self, id_):
        """Return the first ``MODEL_CLS`` row whose ``id`` equals ``id_``.

        The result is whatever ``first()`` yields on the filtered query.
        """
        # Filter on the model's ``id`` attribute. The bare name ``id`` is the
        # builtin function, and ``id == id_`` is just the constant ``False``.
        # NOTE(review): assumes every MODEL_CLS exposes an ``id`` column.
        return self.query().filter(self.MODEL_CLS.id == id_).first()

    def get_all(self):
        """Return every ``MODEL_CLS`` row."""
        return self.query().all()

    def get_by(self, field, val):
        """Return all ``MODEL_CLS`` rows for which ``field == val``.

        ``field`` is a column-like attribute (e.g. ``Model.name``) whose
        ``==`` produces a filter criterion.
        """
        return self.query().filter(field == val).all()

    def exists(self, field, val):
        """Return ``True`` if at least one ``MODEL_CLS`` row has ``field == val``."""
        matching = self.query().filter(field == val)
        # ``exists()`` on a query only builds an EXISTS clause; it has to be
        # selected and fetched to obtain an actual boolean.
        return bool(self.session.query(matching.exists()).scalar())

    def create(self, model, commit=False):
        """Add ``model`` to the session, committing immediately if ``commit``."""
        self.session.add(model)
        if commit:
            self.commit()

    def delete(self, model, commit=False):
        """Delete ``model``, committing immediately if ``commit``."""
        # NOTE(review): relies on the model providing its own ``delete()``;
        # confirm against core.models (``create`` goes through the session).
        model.delete()
        if commit:
            self.commit()

    def __del__(self):
        # ``session`` is missing when ``__init__`` failed or was bypassed;
        # finalizers must not raise in that case.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()
class MusicFolderRepository(Repository):
    """Repository whose queries target the ``MusicFolder`` model.

    Construction is inherited unchanged from ``Repository``.
    """

    MODEL_CLS = MusicFolder
class TagRepository(Repository):
    """Repository whose queries target the ``Tag`` model."""

    MODEL_CLS = Tag

    def get_by_track_id(self, track_id):
        """Return all tags linked to ``track_id`` through ``TrackTag``."""
        tags = self.session.query(Tag)
        # Link each tag to its TrackTag association rows.
        linked = tags.join(TrackTag, Tag.id == TrackTag.tag_id)
        for_track = linked.filter(TrackTag.track_id == track_id)
        return for_track.all()
class TrackRepository(Repository):
    """Repository whose queries target the ``Track`` model."""

    MODEL_CLS = Track

    def get_by_tag_id(self, tag_id):
        """Return all tracks linked to ``tag_id`` through ``TrackTag``."""
        tagged = self.session.query(Track).join(
            TrackTag, Track.id == TrackTag.track_id
        )
        return tagged.filter(TrackTag.tag_id == tag_id).all()

    def get_by_tag_ids(self, tag_ids):
        """Return all tracks linked to any tag id in ``tag_ids``."""
        tagged = self.session.query(Track).join(
            TrackTag, Track.id == TrackTag.track_id
        )
        return tagged.filter(TrackTag.tag_id.in_(tag_ids)).all()

    def get_by_session_id(self, session_id):
        """Return all tracks linked to ``session_id`` through ``SessionTrack``."""
        in_sessions = self.session.query(Track).join(
            SessionTrack, Track.id == SessionTrack.track_id
        )
        return in_sessions.filter(SessionTrack.session_id == session_id).all()
class SessionRepository(Repository):
    """Repository whose queries target the ``Session`` model."""

    TABLE_NAME = "Sessions"
    # Must be Session so the inherited helpers (get_all, get_by, get_by_id,
    # exists) query sessions rather than tracks.
    MODEL_CLS = Session

    def get_by_track_id(self, track_id):
        """Return all sessions linked to ``track_id`` through ``SessionTrack``."""
        return (
            self.session.query(Session)
            # Join on the session side of the association. Joining on
            # Track.id would never relate Session rows to SessionTrack.
            # NOTE(review): assumes Session exposes an ``id`` column.
            .join(SessionTrack, Session.id == SessionTrack.session_id)
            .filter(SessionTrack.track_id == track_id)
            .all()
        )
|