from abc import abstractmethod

from core import db
from core.models import MusicFolder, Track, Tag, TrackTag, SessionTrack, Session


class Repository:
    """Generic data-access helper bound to a single model class.

    Subclasses set ``MODEL_CLS`` to the model they manage. Each instance
    opens its own session via ``db.session()`` and closes it when the
    repository object is garbage-collected.
    """

    MODEL_CLS = None  # overridden by subclasses with the managed model class

    def __init__(self):
        self.session = db.session()

    def query(self):
        """Return a query over ``MODEL_CLS`` in this repository's session."""
        return self.session.query(self.MODEL_CLS)

    def commit(self):
        """Commit the current transaction of this repository's session."""
        self.session.commit()

    def rollback(self):
        """Roll back the current transaction of this repository's session."""
        self.session.rollback()

    def get_by_id(self, id_):
        """Return the row whose ``id`` equals ``id_``, or ``None`` if absent.

        NOTE(review): assumes every ``MODEL_CLS`` exposes an ``id`` column.
        """
        # The original compared the builtin ``id`` function to ``id_``,
        # which is always False; compare against the model's column instead.
        return self.query().filter(self.MODEL_CLS.id == id_).first()

    def get_all(self):
        """Return every row of ``MODEL_CLS``."""
        return self.query().all()

    def get_by(self, field, val):
        """Return all rows where the column ``field`` equals ``val``.

        :param field: a mapped column attribute, e.g. ``Track.id``.
        :param val: the value to compare against.
        """
        return self.query().filter(field == val).all()

    def exists(self, field, val):
        """Return ``True`` if at least one row has ``field == val``.

        :param field: a mapped column attribute, e.g. ``Tag.id``.
        :param val: the value to compare against.
        """
        # ``Query.exists()`` only builds an EXISTS clause; it must be
        # selected and executed to obtain an actual boolean.
        exists_clause = self.query().filter(field == val).exists()
        return bool(self.session.query(exists_clause).scalar())

    def create(self, model, commit=False):
        """Add ``model`` to the session, committing immediately if requested."""
        self.session.add(model)
        if commit:
            self.commit()

    def delete(self, model, commit=False):
        """Mark ``model`` for deletion, committing immediately if requested."""
        # Deletion goes through the session, mirroring ``create``'s use of
        # ``session.add``. NOTE(review): the original called
        # ``model.delete()``; confirm no model relied on a custom method.
        self.session.delete(model)
        if commit:
            self.commit()

    def __del__(self):
        # Guard: if ``__init__`` failed, ``self.session`` was never set.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()


class MusicFolderRepository(Repository):
    """Repository for ``MusicFolder`` rows."""

    MODEL_CLS = MusicFolder


class TagRepository(Repository):
    """Repository for ``Tag`` rows."""

    MODEL_CLS = Tag

    def get_by_track_id(self, track_id):
        """Return all tags linked to the track ``track_id`` via ``TrackTag``."""
        return (
            self.query()
            .join(TrackTag, Tag.id == TrackTag.tag_id)
            .filter(TrackTag.track_id == track_id)
            .all()
        )


class TrackRepository(Repository):
    """Repository for ``Track`` rows."""

    MODEL_CLS = Track

    def get_by_tag_id(self, tag_id):
        """Return all tracks linked to the tag ``tag_id`` via ``TrackTag``."""
        return (
            self.query()
            .join(TrackTag, Track.id == TrackTag.track_id)
            .filter(TrackTag.tag_id == tag_id)
            .all()
        )

    def get_by_tag_ids(self, tag_ids):
        """Return all tracks linked to any tag id in ``tag_ids``."""
        return (
            self.query()
            .join(TrackTag, Track.id == TrackTag.track_id)
            .filter(TrackTag.tag_id.in_(tag_ids))
            .all()
        )

    def get_by_session_id(self, session_id):
        """Return all tracks linked to the session ``session_id`` via ``SessionTrack``."""
        return (
            self.query()
            .join(SessionTrack, Track.id == SessionTrack.track_id)
            .filter(SessionTrack.session_id == session_id)
            .all()
        )


class SessionRepository(Repository):
    """Repository for ``Session`` rows."""

    TABLE_NAME = "Sessions"  # kept for backward compatibility; unused here
    # Was ``Track``, which made the inherited helpers (get_by_id, get_all,
    # ...) return tracks from the session repository.
    MODEL_CLS = Session

    def get_by_track_id(self, track_id):
        """Return all sessions linked to the track ``track_id`` via ``SessionTrack``.

        NOTE(review): assumes ``Session`` has an ``id`` primary-key column.
        """
        # The join must relate Session to SessionTrack; the original joined
        # on ``Track.id``, which never constrained the Session side.
        return (
            self.query()
            .join(SessionTrack, Session.id == SessionTrack.session_id)
            .filter(SessionTrack.track_id == track_id)
            .all()
        )