Skip to content
Snippets Groups Projects
Commit 24943467 authored by David Trattnig's avatar David Trattnig
Browse files

Merge branch 'lars-short-db-sessions' into 'master'

refactor: use short-lived sessions when accessing the database

Closes #75

See merge request aura/engine!9
parents d8356a38 0b594e50
No related branches found
No related tags found
No related merge requests found
......@@ -138,7 +138,8 @@ class FallbackManager:
if fallback_type != self.state.get("previous_fallback_type"):
timeslot = self.state["timeslot"]
if timeslot:
DB.session.merge(timeslot)
with DB.Session() as session:
session.merge(timeslot)
self.engine.event_dispatcher.on_fallback_active(timeslot, fallback_type)
......
......@@ -17,6 +17,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import sys
import time
import logging
......@@ -24,6 +25,7 @@ import datetime
import sqlalchemy as sa
import sqlalchemy
from sqlalchemy import BigInteger, Boolean, Column, DateTime, Integer, String, ForeignKey, ColumnDefault
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
......@@ -40,14 +42,35 @@ config = AuraConfig()
engine = sa.create_engine(config.get_database_uri())
Base = declarative_base()
Base.metadata.bind = engine
__sqlalchemy_version = tuple(int(item) for item in sqlalchemy.__version__.split(".")[:2])
class DB():
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
session = Session()
Model = Base
# Monkey-patch the above DB.Session generator for SQLAlchemy before v1.4.
# Such older versions of SQLAlchemy do not support contexts.
if __sqlalchemy_version < (1, 4):
@contextlib.contextmanager
def get_session_context():
""" provide a context for a session
This context is the same as the one provided by a "scoped_session" in SQLAlchemy v1.4 or
later.
see https://docs.sqlalchemy.org/en/13/orm/session_basics.html#when-do-i-construct-a-session-when-do-i-commit-it-and-when-do-i-close-it
"""
session = scoped_session(DB.session_factory)
try:
yield session
finally:
session.close()
DB.Session = get_session_context
class AuraDatabaseModel():
"""
......@@ -69,29 +92,32 @@ class AuraDatabaseModel():
"""
Store to the database
"""
if add:
DB.session.add(self)
else:
DB.session.merge(self)
if commit:
DB.session.commit()
with DB.Session() as session:
if add:
session.add(self)
else:
session.merge(self)
if commit:
session.commit()
def delete(self, commit=False):
"""
Delete from the database
"""
DB.session.delete(self)
if commit:
DB.session.commit()
with DB.Session() as session:
session.delete(self)
if commit:
session.commit()
def refresh(self):
"""
Refreshes the correct record
"""
DB.session.expire(self)
DB.session.refresh(self)
with DB.Session() as session:
session.expire(self)
session.refresh(self)
def _asdict(self):
......@@ -132,7 +158,8 @@ class AuraDatabaseModel():
"""
Base.metadata.drop_all()
Base.metadata.create_all()
DB.session.commit()
with DB.Session() as session:
session.commit()
if systemexit:
sys.exit(0)
......@@ -223,7 +250,12 @@ class Timeslot(DB.Model, AuraDatabaseModel):
Args:
date_time (datetime): date and time when the timeslot starts
"""
return DB.session.query(Timeslot).filter(Timeslot.timeslot_start == date_time).first()
with DB.Session() as session:
return (
session.query(Timeslot)
.filter(Timeslot.timeslot_start == date_time)
.first()
)
@staticmethod
......@@ -238,10 +270,13 @@ class Timeslot(DB.Model, AuraDatabaseModel):
Returns:
([Timeslot]): List of timeslots
"""
timeslots = DB.session.query(Timeslot).\
filter(Timeslot.timeslot_start >= date_from).\
order_by(Timeslot.timeslot_start).all()
return timeslots
with DB.Session() as session:
return (
session.query(Timeslot)
.filter(Timeslot.timeslot_start >= date_from)
.order_by(Timeslot.timeslot_start)
.all()
)
def set_active_entry(self, entry):
......@@ -355,7 +390,8 @@ class Playlist(DB.Model, AuraDatabaseModel):
# """
# Fetches all entries
# """
# all_entries = DB.session.query(Playlist).filter(Playlist.fallback_type == 0).all()
# with DB.Session() as session:
# all_entries = session.query(Playlist).filter(Playlist.fallback_type == 0).all()
# cnt = 0
# for entry in all_entries:
......@@ -381,7 +417,12 @@ class Playlist(DB.Model, AuraDatabaseModel):
Exception: In case there a inconsistent database state, such es multiple playlists for given date/time.
"""
playlist = None
playlists = DB.session.query(Playlist).filter(Playlist.timeslot_start == start_date).all()
with DB.Session() as session:
playlists = (
session.query(Playlist)
.filter(Playlist.timeslot_start == start_date)
.all()
)
for p in playlists:
if p.playlist_id == playlist_id:
......@@ -401,7 +442,13 @@ class Playlist(DB.Model, AuraDatabaseModel):
Returns:
(Array<Playlist>): An array holding the playlists
"""
return DB.session.query(Playlist).filter(Playlist.playlist_id == playlist_id).order_by(Playlist.timeslot_start).all()
with DB.Session() as session:
return (
session.query(Playlist)
.filter(Playlist.playlist_id == playlist_id)
.order_by(Playlist.timeslot_start)
.all()
)
@staticmethod
......@@ -409,10 +456,11 @@ class Playlist(DB.Model, AuraDatabaseModel):
"""
Checks if the given is empty
"""
try:
return not DB.session.query(Playlist).one_or_none()
except sa.orm.exc.MultipleResultsFound:
return False
with DB.Session() as session:
try:
return not session.query(Playlist).one_or_none()
except sa.orm.exc.MultipleResultsFound:
return False
@hybrid_property
......@@ -509,26 +557,37 @@ class PlaylistEntry(DB.Model, AuraDatabaseModel):
"""
Selects one entry identified by `playlist_id` and `entry_num`.
"""
return DB.session.query(PlaylistEntry).filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id, PlaylistEntry.entry_num == entry_num).first()
with DB.Session() as session:
return (
session.query(PlaylistEntry)
.filter(PlaylistEntry.entry_num == entry_num)
.filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id)
.first()
)
@staticmethod
def delete_entry(artificial_playlist_id, entry_num):
"""
Deletes the playlist entry and associated metadata.
"""
entry = PlaylistEntry.select_playlistentry_for_playlist(artificial_playlist_id, entry_num)
metadata = PlaylistEntryMetaData.select_metadata_for_entry(entry.artificial_id)
metadata.delete()
entry.delete()
DB.session.commit()
with DB.Session() as session:
entry = PlaylistEntry.select_playlistentry_for_playlist(artificial_playlist_id, entry_num)
metadata = PlaylistEntryMetaData.select_metadata_for_entry(entry.artificial_id)
metadata.delete()
entry.delete()
session.commit()
@staticmethod
def count_entries(artificial_playlist_id):
"""
Returns the count of all entries.
"""
result = DB.session.query(PlaylistEntry).filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id).count()
return result
with DB.Session() as session:
return (
session.query(PlaylistEntry)
.filter(PlaylistEntry.artificial_playlist_id == artificial_playlist_id)
.count()
)
@hybrid_property
def entry_end(self):
......@@ -633,6 +692,9 @@ class PlaylistEntryMetaData(DB.Model, AuraDatabaseModel):
@staticmethod
def select_metadata_for_entry(artificial_playlistentry_id):
return DB.session.query(PlaylistEntryMetaData).filter(PlaylistEntryMetaData.artificial_entry_id == artificial_playlistentry_id).first()
with DB.Session() as session:
return (
session.query(PlaylistEntryMetaData)
.filter(PlaylistEntryMetaData.artificial_entry_id == artificial_playlistentry_id)
.first()
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment