diff --git a/src/scheduling/fallback.py b/src/scheduling/fallback.py index 6a6133aa6c62e1b745776d5910d780bdab2088ad..bcf669f4cd9bd3c0b052428e65210cdde531d8ff 100644 --- a/src/scheduling/fallback.py +++ b/src/scheduling/fallback.py @@ -138,7 +138,8 @@ class FallbackManager: if fallback_type != self.state.get("previous_fallback_type"): timeslot = self.state["timeslot"] if timeslot: - DB.session.merge(timeslot) + with DB.Session() as session: + session.merge(timeslot) self.engine.event_dispatcher.on_fallback_active(timeslot, fallback_type) diff --git a/src/scheduling/models.py b/src/scheduling/models.py index f5a6b487e61aca67124071155197955511ef49b0..fdb5ebdcebcc99d0dd5cf6638b57b91e871dded8 100644 --- a/src/scheduling/models.py +++ b/src/scheduling/models.py @@ -17,6 +17,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. +import contextlib import sys import time import logging @@ -24,6 +25,7 @@ import datetime import sqlalchemy as sa +import sqlalchemy from sqlalchemy import BigInteger, Boolean, Column, DateTime, Integer, String, ForeignKey, ColumnDefault from sqlalchemy.orm import scoped_session from sqlalchemy.orm import sessionmaker @@ -40,14 +42,35 @@ config = AuraConfig() engine = sa.create_engine(config.get_database_uri()) Base = declarative_base() Base.metadata.bind = engine +__sqlalchemy_version = tuple(int(item) for item in sqlalchemy.__version__.split(".")[:2]) + class DB(): session_factory = sessionmaker(bind=engine) Session = scoped_session(session_factory) - session = Session() Model = Base +# Monkey-patch the above DB.Session generator for SQLAlchemy before v1.4. +# Such older versions of SQLAlchemy do not support contexts. +if __sqlalchemy_version < (1, 4): + @contextlib.contextmanager + def get_session_context(): + """ provide a context for a session + + This context is the same as the one provided by a "scoped_session" in SQLAlchemy v1.4 or + later. + + see https://docs.sqlalchemy.org/en/13/orm/session_basics.html#when-do-i-construct-a-session-when-do-i-commit-it-and-when-do-i-close-it + """ + session = scoped_session(DB.session_factory) + try: + yield session + finally: + session.close() + + DB.Session = get_session_context + class AuraDatabaseModel(): """ @@ -69,29 +92,32 @@ class AuraDatabaseModel(): """ Store to the database """ - if add: - DB.session.add(self) - else: - DB.session.merge(self) - if commit: - DB.session.commit() + with DB.Session() as session: + if add: + session.add(self) + else: + session.merge(self) + if commit: + session.commit() def delete(self, commit=False): """ Delete from the database """ - DB.session.delete(self) - if commit: - DB.session.commit() + with DB.Session() as session: + session.delete(self) + if commit: + session.commit() def refresh(self): """ Refreshes the currect record """ - DB.session.expire(self) - DB.session.refresh(self) + with DB.Session() as session: + session.expire(self) + session.refresh(self) def _asdict(self): @@ -132,7 +158,8 @@ class AuraDatabaseModel(): """ Base.metadata.drop_all() Base.metadata.create_all() - DB.session.commit() + with DB.Session() as session: + session.commit() if systemexit: sys.exit(0) @@ -223,7 +250,12 @@ class Timeslot(DB.Model, AuraDatabaseModel): Args: date_time (datetime): date and time when the timeslot starts """ - return DB.session.query(Timeslot).filter(Timeslot.timeslot_start == date_time).first() + with DB.Session() as session: + return ( + session.query(Timeslot) + .filter(Timeslot.timeslot_start == date_time) + .first() + ) @staticmethod @@ -238,10 +270,13 @@ class Timeslot(DB.Model, AuraDatabaseModel): Returns: ([Timeslot]): List of timeslots """ - timeslots = DB.session.query(Timeslot).\ - filter(Timeslot.timeslot_start >= date_from).\ - order_by(Timeslot.timeslot_start).all() - return timeslots + with DB.Session() as session: + return ( + session.query(Timeslot) + .filter(Timeslot.timeslot_start >= date_from) + .order_by(Timeslot.timeslot_start) + .all() + ) def set_active_entry(self, entry): @@ -355,7 +390,8 @@ class Playlist(DB.Model, AuraDatabaseModel): # """ # Fetches all entries # """ - # all_entries = DB.session.query(Playlist).filter(Playlist.fallback_type == 0).all() + # with DB.Session() as session: + # all_entries = session.query(Playlist).filter(Playlist.fallback_type == 0).all() # cnt = 0 # for entry in all_entries: @@ -381,7 +417,12 @@ class Playlist(DB.Model, AuraDatabaseModel): Exception: In case there a inconsistent database state, such es multiple playlists for given date/time. """ playlist = None - playlists = DB.session.query(Playlist).filter(Playlist.timeslot_start == start_date).all() + with DB.Session() as session: + playlists = ( + session.query(Playlist) + .filter(Playlist.timeslot_start == start_date) + .all() + ) for p in playlists: if p.playlist_id == playlist_id: @@ -401,7 +442,13 @@ class Playlist(DB.Model, AuraDatabaseModel): Returns: (Array<Playlist>): An array holding the playlists """ - return DB.session.query(Playlist).filter(Playlist.playlist_id == playlist_id).order_by(Playlist.timeslot_start).all() + with DB.Session() as session: + return ( + session.query(Playlist) + .filter(Playlist.playlist_id == playlist_id) + .order_by(Playlist.timeslot_start) + .all() + ) @staticmethod @@ -409,10 +456,11 @@ class Playlist(DB.Model, AuraDatabaseModel): """ Checks if the given is empty """ - try: - return not DB.session.query(Playlist).one_or_none() - except sa.orm.exc.MultipleResultsFound: - return False + with DB.Session() as session: + try: + return not session.query(Playlist).one_or_none() + except sa.orm.exc.MultipleResultsFound: + return False @hybrid_property @@ -509,26 +557,37 @@ class PlaylistEntry(DB.Model, AuraDatabaseModel): """ Selects one entry identified by `playlist_id` and `entry_num`. """ - return DB.session.query(PlaylistEntry).filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id, PlaylistEntry.entry_num == entry_num).first() + with DB.Session() as session: + return ( + session.query(PlaylistEntry) + .filter(PlaylistEntry.entry_num == entry_num) + .filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id) + .first() + ) @staticmethod def delete_entry(artificial_playlist_id, entry_num): """ Deletes the playlist entry and associated metadata. """ - entry = PlaylistEntry.select_playlistentry_for_playlist(artificial_playlist_id, entry_num) - metadata = PlaylistEntryMetaData.select_metadata_for_entry(entry.artificial_id) - metadata.delete() - entry.delete() - DB.session.commit() + with DB.Session() as session: + entry = PlaylistEntry.select_playlistentry_for_playlist(artificial_playlist_id, entry_num) + metadata = PlaylistEntryMetaData.select_metadata_for_entry(entry.artificial_id) + metadata.delete() + entry.delete() + session.commit() @staticmethod def count_entries(artificial_playlist_id): """ Returns the count of all entries. """ - result = DB.session.query(PlaylistEntry).filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id).count() - return result + with DB.Session() as session: + return ( + session.query(PlaylistEntry) + .filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id) + .count() + ) @hybrid_property def entry_end(self): @@ -633,6 +692,9 @@ class PlaylistEntryMetaData(DB.Model, AuraDatabaseModel): @staticmethod def select_metadata_for_entry(artificial_playlistentry_id): - return DB.session.query(PlaylistEntryMetaData).filter(PlaylistEntryMetaData.artificial_entry_id == artificial_playlistentry_id).first() - - + with DB.Session() as session: + return ( + session.query(PlaylistEntryMetaData) + .filter(PlaylistEntryMetaData.artificial_entry_id == artificial_playlistentry_id) + .first() + )