Commit 0b594e50 authored by Lars Kruse's avatar Lars Kruse
Browse files

refactor: use short-lived sessions when accessing the database

Previously the database connection ("DB.session") was established during
the startup of Aura Engine.
This was causing conflicts in the threaded execution environment (#75).
Now all sessions are acquired in a short-lived context:

  with DB.Session() as session:
      ...

If database query performance is a real concern, then the session
contexts could be moved to a higher level.

Closes: #75
parent d041f86c
Pipeline #1111 passed with stage
in 1 minute and 27 seconds
...@@ -138,7 +138,8 @@ class FallbackManager: ...@@ -138,7 +138,8 @@ class FallbackManager:
if fallback_type != self.state.get("previous_fallback_type"): if fallback_type != self.state.get("previous_fallback_type"):
timeslot = self.state["timeslot"] timeslot = self.state["timeslot"]
if timeslot: if timeslot:
DB.session.merge(timeslot) with DB.Session() as session:
session.merge(timeslot)
self.engine.event_dispatcher.on_fallback_active(timeslot, fallback_type) self.engine.event_dispatcher.on_fallback_active(timeslot, fallback_type)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import sys import sys
import time import time
import logging import logging
...@@ -24,6 +25,7 @@ import datetime ...@@ -24,6 +25,7 @@ import datetime
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy
from sqlalchemy import BigInteger, Boolean, Column, DateTime, Integer, String, ForeignKey, ColumnDefault from sqlalchemy import BigInteger, Boolean, Column, DateTime, Integer, String, ForeignKey, ColumnDefault
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
...@@ -40,14 +42,35 @@ config = AuraConfig() ...@@ -40,14 +42,35 @@ config = AuraConfig()
engine = sa.create_engine(config.get_database_uri()) engine = sa.create_engine(config.get_database_uri())
Base = declarative_base() Base = declarative_base()
Base.metadata.bind = engine Base.metadata.bind = engine
__sqlalchemy_version = tuple(int(item) for item in sqlalchemy.__version__.split(".")[:2])
class DB(): class DB():
session_factory = sessionmaker(bind=engine) session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory) Session = scoped_session(session_factory)
session = Session()
Model = Base Model = Base
# Monkey-patch the above DB.Session generator for SQLAlchemy before v1.4.
# Such older versions of SQLAlchemy do not support contexts.
if __sqlalchemy_version < (1, 4):
@contextlib.contextmanager
def get_session_context():
""" provide a context for a session
This context is the same as the one provided by a "scoped_session" in SQLAlchemy v1.4 or
later.
see https://docs.sqlalchemy.org/en/13/orm/session_basics.html#when-do-i-construct-a-session-when-do-i-commit-it-and-when-do-i-close-it
"""
session = scoped_session(DB.session_factory)
try:
yield session
finally:
session.close()
DB.Session = get_session_context
class AuraDatabaseModel(): class AuraDatabaseModel():
""" """
...@@ -69,29 +92,32 @@ class AuraDatabaseModel(): ...@@ -69,29 +92,32 @@ class AuraDatabaseModel():
""" """
Store to the database Store to the database
""" """
if add: with DB.Session() as session:
DB.session.add(self) if add:
else: session.add(self)
DB.session.merge(self) else:
if commit: session.merge(self)
DB.session.commit() if commit:
session.commit()
def delete(self, commit=False): def delete(self, commit=False):
""" """
Delete from the database Delete from the database
""" """
DB.session.delete(self) with DB.Session() as session:
if commit: session.delete(self)
DB.session.commit() if commit:
session.commit()
def refresh(self): def refresh(self):
""" """
Refreshes the currect record Refreshes the currect record
""" """
DB.session.expire(self) with DB.Session() as session:
DB.session.refresh(self) session.expire(self)
session.refresh(self)
def _asdict(self): def _asdict(self):
...@@ -132,7 +158,8 @@ class AuraDatabaseModel(): ...@@ -132,7 +158,8 @@ class AuraDatabaseModel():
""" """
Base.metadata.drop_all() Base.metadata.drop_all()
Base.metadata.create_all() Base.metadata.create_all()
DB.session.commit() with DB.Session() as session:
session.commit()
if systemexit: if systemexit:
sys.exit(0) sys.exit(0)
...@@ -223,7 +250,12 @@ class Timeslot(DB.Model, AuraDatabaseModel): ...@@ -223,7 +250,12 @@ class Timeslot(DB.Model, AuraDatabaseModel):
Args: Args:
date_time (datetime): date and time when the timeslot starts date_time (datetime): date and time when the timeslot starts
""" """
return DB.session.query(Timeslot).filter(Timeslot.timeslot_start == date_time).first() with DB.Session() as session:
return (
session.query(Timeslot)
.filter(Timeslot.timeslot_start == date_time)
.first()
)
@staticmethod @staticmethod
...@@ -238,10 +270,13 @@ class Timeslot(DB.Model, AuraDatabaseModel): ...@@ -238,10 +270,13 @@ class Timeslot(DB.Model, AuraDatabaseModel):
Returns: Returns:
([Timeslot]): List of timeslots ([Timeslot]): List of timeslots
""" """
timeslots = DB.session.query(Timeslot).\ with DB.Session() as session:
filter(Timeslot.timeslot_start >= date_from).\ return (
order_by(Timeslot.timeslot_start).all() session.query(Timeslot)
return timeslots .filter(Timeslot.timeslot_start >= date_from)
.order_by(Timeslot.timeslot_start)
.all()
)
def set_active_entry(self, entry): def set_active_entry(self, entry):
...@@ -355,7 +390,8 @@ class Playlist(DB.Model, AuraDatabaseModel): ...@@ -355,7 +390,8 @@ class Playlist(DB.Model, AuraDatabaseModel):
# """ # """
# Fetches all entries # Fetches all entries
# """ # """
# all_entries = DB.session.query(Playlist).filter(Playlist.fallback_type == 0).all() # with DB.Session() as session:
# all_entries = session.query(Playlist).filter(Playlist.fallback_type == 0).all()
# cnt = 0 # cnt = 0
# for entry in all_entries: # for entry in all_entries:
...@@ -381,7 +417,12 @@ class Playlist(DB.Model, AuraDatabaseModel): ...@@ -381,7 +417,12 @@ class Playlist(DB.Model, AuraDatabaseModel):
Exception: In case there a inconsistent database state, such es multiple playlists for given date/time. Exception: In case there a inconsistent database state, such es multiple playlists for given date/time.
""" """
playlist = None playlist = None
playlists = DB.session.query(Playlist).filter(Playlist.timeslot_start == start_date).all() with DB.Session() as session:
playlists = (
session.query(Playlist)
.filter(Playlist.timeslot_start == start_date)
.all()
)
for p in playlists: for p in playlists:
if p.playlist_id == playlist_id: if p.playlist_id == playlist_id:
...@@ -401,7 +442,13 @@ class Playlist(DB.Model, AuraDatabaseModel): ...@@ -401,7 +442,13 @@ class Playlist(DB.Model, AuraDatabaseModel):
Returns: Returns:
(Array<Playlist>): An array holding the playlists (Array<Playlist>): An array holding the playlists
""" """
return DB.session.query(Playlist).filter(Playlist.playlist_id == playlist_id).order_by(Playlist.timeslot_start).all() with DB.Session() as session:
return (
session.query(Playlist)
.filter(Playlist.playlist_id == playlist_id)
.order_by(Playlist.timeslot_start)
.all()
)
@staticmethod @staticmethod
...@@ -409,10 +456,11 @@ class Playlist(DB.Model, AuraDatabaseModel): ...@@ -409,10 +456,11 @@ class Playlist(DB.Model, AuraDatabaseModel):
""" """
Checks if the given is empty Checks if the given is empty
""" """
try: with DB.Session() as session:
return not DB.session.query(Playlist).one_or_none() try:
except sa.orm.exc.MultipleResultsFound: return not session.query(Playlist).one_or_none()
return False except sa.orm.exc.MultipleResultsFound:
return False
@hybrid_property @hybrid_property
...@@ -509,26 +557,37 @@ class PlaylistEntry(DB.Model, AuraDatabaseModel): ...@@ -509,26 +557,37 @@ class PlaylistEntry(DB.Model, AuraDatabaseModel):
""" """
Selects one entry identified by `playlist_id` and `entry_num`. Selects one entry identified by `playlist_id` and `entry_num`.
""" """
return DB.session.query(PlaylistEntry).filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id, PlaylistEntry.entry_num == entry_num).first() with DB.Session() as session:
return (
session.query(PlaylistEntry)
.filter(PlaylistEntry.entry_num == entry_num)
.filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id)
.first()
)
@staticmethod @staticmethod
def delete_entry(artificial_playlist_id, entry_num): def delete_entry(artificial_playlist_id, entry_num):
""" """
Deletes the playlist entry and associated metadata. Deletes the playlist entry and associated metadata.
""" """
entry = PlaylistEntry.select_playlistentry_for_playlist(artificial_playlist_id, entry_num) with DB.Session() as session:
metadata = PlaylistEntryMetaData.select_metadata_for_entry(entry.artificial_id) entry = PlaylistEntry.select_playlistentry_for_playlist(artificial_playlist_id, entry_num)
metadata.delete() metadata = PlaylistEntryMetaData.select_metadata_for_entry(entry.artificial_id)
entry.delete() metadata.delete()
DB.session.commit() entry.delete()
session.commit()
@staticmethod @staticmethod
def count_entries(artificial_playlist_id): def count_entries(artificial_playlist_id):
""" """
Returns the count of all entries. Returns the count of all entries.
""" """
result = DB.session.query(PlaylistEntry).filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id).count() with DB.Session() as session:
return result return (
session.query(PlaylistEntry)
.filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id)
.count()
)
@hybrid_property @hybrid_property
def entry_end(self): def entry_end(self):
...@@ -633,6 +692,9 @@ class PlaylistEntryMetaData(DB.Model, AuraDatabaseModel): ...@@ -633,6 +692,9 @@ class PlaylistEntryMetaData(DB.Model, AuraDatabaseModel):
@staticmethod @staticmethod
def select_metadata_for_entry(artificial_playlistentry_id): def select_metadata_for_entry(artificial_playlistentry_id):
return DB.session.query(PlaylistEntryMetaData).filter(PlaylistEntryMetaData.artificial_entry_id == artificial_playlistentry_id).first() with DB.Session() as session:
return (
session.query(PlaylistEntryMetaData)
.filter(PlaylistEntryMetaData.artificial_entry_id == artificial_playlistentry_id)
.first()
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment