repositories.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from abc import abstractmethod
  2. from core import db
  3. from core.models import MusicFolder, Track, Tag, TrackTag, SessionTrack, Session
  4. class Repository:
  5. MODEL_CLS = None
  6. def __init__(self):
  7. self.session = db.session()
  8. def query(self):
  9. return self.session.query(self.MODEL_CLS)
  10. def commit(self):
  11. self.session.commit()
  12. def rollback(self):
  13. self.session.rollback()
  14. def get_by_id(self, id_):
  15. return self.query().filter(id == id_).first()
  16. def get_all(self):
  17. return self.query().all()
  18. def get_by(self, field, val):
  19. return self.query().filter(field == val).all()
  20. def exists(self, field, val):
  21. return self.query().filter(field == val).exists()
  22. def create(self, model, commit=False):
  23. self.session.add(model)
  24. if commit:
  25. self.session.commit()
  26. def delete(self, model, commit=False):
  27. model.delete()
  28. if commit:
  29. self.commit()
  30. def __del__(self):
  31. self.session.close()
  32. class MusicFolderRepository(Repository):
  33. MODEL_CLS = MusicFolder
  34. def __init__(self):
  35. super().__init__()
  36. class TagRepository(Repository):
  37. MODEL_CLS = Tag
  38. def __init__(self):
  39. super().__init__()
  40. def get_by_track_id(self, track_id):
  41. return self.session.query(Tag)\
  42. .join(TrackTag, Tag.id == TrackTag.tag_id)\
  43. .filter(TrackTag.track_id == track_id).\
  44. all()
  45. class TrackRepository(Repository):
  46. MODEL_CLS = Track
  47. def get_by_hash(self, hash_):
  48. return self.query().filter(hash == hash_).first()
  49. def get_by_tag_id(self, tag_id):
  50. return self.session.query(Track)\
  51. .join(TrackTag, Track.id == TrackTag.track_id)\
  52. .filter(TrackTag.tag_id == tag_id).\
  53. all()
  54. def get_by_tag_ids(self, tag_ids):
  55. return self.session.query(Track)\
  56. .join(TrackTag, Track.id == TrackTag.track_id)\
  57. .filter(TrackTag.tag_id.in_(tag_ids)).\
  58. all()
  59. def get_by_session_id(self, session_id):
  60. return self.session.query(Track)\
  61. .join(SessionTrack, Track.id == SessionTrack.track_id)\
  62. .filter(SessionTrack.session_id == session_id).\
  63. all()
  64. class SessionRepository(Repository):
  65. TABLE_NAME = "Sessions"
  66. MODEL_CLS = Track
  67. def __init__(self):
  68. super().__init__()
  69. def get_by_track_id(self, track_id):
  70. return self.session.query(Session)\
  71. .join(SessionTrack, Track.id == SessionTrack.track_id)\
  72. .filter(SessionTrack.track_id == track_id).\
  73. all()