repositories.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from abc import abstractmethod
  2. from core import db
  3. from core.models import MusicFolder, Track, Tag, TrackTag, SessionTrack, Session
  4. class Repository:
  5. MODEL_CLS = None
  6. def __init__(self):
  7. self.session = db.session()
  8. def query(self):
  9. return self.session.query(self.MODEL_CLS)
  10. def commit(self):
  11. self.session.commit()
  12. def rollback(self):
  13. self.session.rollback()
  14. def get_by_id(self, id_):
  15. return self.query().filter(id == id_).first()
  16. def get_all(self):
  17. return self.query().all()
  18. def get_by(self, field, val):
  19. return self.query().filter(field == val).all()
  20. def exists(self, field, val):
  21. return self.query().filter(field == val).exists()
  22. def create(self, model, commit=False):
  23. self.session.add(model)
  24. if commit:
  25. self.session.commit()
  26. def delete(self, model, commit=False):
  27. model.delete()
  28. if commit:
  29. self.commit()
  30. def __del__(self):
  31. self.session.close()
  32. class MusicFolderRepository(Repository):
  33. MODEL_CLS = MusicFolder
  34. def __init__(self):
  35. super().__init__()
  36. class TagRepository(Repository):
  37. MODEL_CLS = Tag
  38. def __init__(self):
  39. super().__init__()
  40. def get_by_track_id(self, track_id):
  41. return self.session.query(Tag)\
  42. .join(TrackTag, Tag.id == TrackTag.tag_id)\
  43. .filter(TrackTag.track_id == track_id).\
  44. all()
  45. class TrackRepository(Repository):
  46. MODEL_CLS = Track
  47. def get_by_tag_id(self, tag_id):
  48. return self.session.query(Track)\
  49. .join(TrackTag, Track.id == TrackTag.track_id)\
  50. .filter(TrackTag.tag_id == tag_id).\
  51. all()
  52. def get_by_tag_ids(self, tag_ids):
  53. return self.session.query(Track)\
  54. .join(TrackTag, Track.id == TrackTag.track_id)\
  55. .filter(TrackTag.tag_id.in_(tag_ids)).\
  56. all()
  57. def get_by_session_id(self, session_id):
  58. return self.session.query(Track)\
  59. .join(SessionTrack, Track.id == SessionTrack.track_id)\
  60. .filter(SessionTrack.session_id == session_id).\
  61. all()
  62. class SessionRepository(Repository):
  63. TABLE_NAME = "Sessions"
  64. MODEL_CLS = Track
  65. def __init__(self):
  66. super().__init__()
  67. def get_by_track_id(self, track_id):
  68. return self.session.query(Session)\
  69. .join(SessionTrack, Track.id == SessionTrack.track_id)\
  70. .filter(SessionTrack.track_id == track_id).\
  71. all()