Skip to content
Snippets Groups Projects
sessions.go 7.54 KiB
//
//  tank
//
//  Import and Playlist Daemon for autoradio project
//
//
//  Copyright (C) 2017-2019 Christian Pointner <equinox@helsinki.at>
//
//  This file is part of tank.
//
//  tank is free software: you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation, either version 3 of the License, or
//  any later version.
//
//  tank is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with tank. If not, see <http://www.gnu.org/licenses/>.
//

package auth

import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

const (
	// defaultAge is the maximum session lifetime used when the session
	// configuration does not specify a positive MaxAge.
	defaultAge = time.Hour * 24

	// tokenSeperator splits a bearer token into "<session-id>:<secret>".
	// The (misspelled) identifier is kept for compatibility.
	tokenSeperator = ':'
)

// SessionState is the lifecycle state of a Session. Its underlying type is
// uint32 so that Session can read and write it with sync/atomic.
type SessionState uint32

// Session states. The numeric values follow declaration order (iota).
const (
	SessionStateNew SessionState = iota
	SessionStateLoginStarted
	SessionStateLoginFinalizing
	SessionStateLoggedIn
	SessionStateLoginFailed
	SessionStateLoginTimeout
	SessionStateStale
	SessionStateRemoved
)

// String returns the kebab-case name of the state, or "unknown" for any
// value outside the declared range.
func (s SessionState) String() string {
	names := [...]string{
		SessionStateNew:             "new",
		SessionStateLoginStarted:    "login-started",
		SessionStateLoginFinalizing: "login-finalizing",
		SessionStateLoggedIn:        "logged-in",
		SessionStateLoginFailed:     "login-failed",
		SessionStateLoginTimeout:    "login-timeout",
		SessionStateStale:           "stale",
		SessionStateRemoved:         "removed",
	}
	// Compare in the SessionState domain so no value can wrap to a
	// negative index on 32-bit platforms.
	if s < SessionState(len(names)) {
		return names[s]
	}
	return "unknown"
}

// MarshalText implements encoding.TextMarshaler so that the state is
// encoded by its String() name (for example inside JSON documents).
func (s SessionState) MarshalText() ([]byte, error) {
	return []byte(s.String()), nil
}

// Session is a single authentication session. A client refers to it by
// presenting the bearer token "<id><tokenSeperator><secret>" (built by
// getToken, split by parseBearerAuthHeader). The exported fields hold the
// user name and access flags; MarshalJSON includes them together with id,
// state and expires, but never the secret.
type Session struct {
	// id and secret are the two halves of the bearer token.
	id      string
	secret  string
	// state is accessed via sync/atomic in State, setState and updateState;
	// any other access should be atomic as well.
	state   SessionState
	// expires is the expiry instant; nil means the session never expires
	// (see Expired). SessionManager.insert replaces a zero or too-late
	// expiry with now+maxAge.
	expires *time.Time

	// subs is the broadcast channel returned by subscribe and closed by
	// signalSubscribers; subsMutex guards it.
	subs      chan struct{}
	subsMutex sync.Mutex
	// oidc holds OIDC-related data for this session.
	// NOTE(review): OIDCSession is defined elsewhere; semantics not visible here.
	oidc      *OIDCSession

	Username string   `json:"username"`
	ReadOnly bool     `json:"readonly"`
	AllShows bool     `json:"all-shows"`
	Shows    []string `json:"shows"`
}

// NewSession allocates a fresh session in SessionStateNew whose id and
// secret come from generateRandomString. The expiry is set to the zero
// time, which SessionManager.insert replaces with now+maxAge.
//
// On error the returned session is nil. Previously a half-initialised
// session (e.g. with an empty secret) was returned alongside the error,
// which would yield a trivially guessable token if a caller ever ignored err.
func NewSession() (*Session, error) {
	id, err := generateRandomString(16)
	if err != nil {
		return nil, err
	}
	secret, err := generateRandomString(32)
	if err != nil {
		return nil, err
	}
	return &Session{
		id:      id,
		secret:  secret,
		state:   SessionStateNew,
		expires: &time.Time{},
	}, nil
}

// Pre-built sessions with the user name "anonymous".
// NOTE(review): judging by their names these are handed out when requests
// are not authenticated; confirm against the callers. Shows is an empty,
// non-nil slice so that MarshalJSON emits [] rather than null.
var (
	// anonAllowNone: ReadOnly=false, AllShows=false, no shows listed.
	anonAllowNone = &Session{
		Username: "anonymous",
		Shows:    []string{},
	}

	// anonAllowAll: ReadOnly=false, AllShows=true.
	anonAllowAll = &Session{
		Username: "anonymous",
		AllShows: true,
		Shows:    []string{},
	}

	// anonAllowAllRO: ReadOnly=true, AllShows=true.
	anonAllowAllRO = &Session{
		Username: "anonymous",
		ReadOnly: true,
		AllShows: true,
		Shows:    []string{},
	}
)

// ID returns the session identifier, i.e. the part of the bearer token
// that precedes tokenSeperator.
func (s *Session) ID() string {
	id := s.id
	return id
}

// State atomically loads and returns the session's current state.
func (s *Session) State() SessionState {
	raw := atomic.LoadUint32((*uint32)(&s.state))
	return SessionState(raw)
}

// setState atomically stores st, regardless of the current state. It wakes
// subscribers only if the stored value actually changed.
func (s *Session) setState(st SessionState) {
	prev := SessionState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
	if prev == st {
		return
	}
	s.signalSubscribers()
}

// updateState atomically moves the session from state old to state new.
// It reports false (and changes nothing) if the current state is not old.
// On a real transition (old != new) subscribers are woken.
func (s *Session) updateState(old, new SessionState) bool {
	if !atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(old), uint32(new)) {
		return false
	}
	if old != new {
		s.signalSubscribers()
	}
	return true
}

// Expired reports whether the session's expiry lies in the past. A session
// without an expiry (nil) never expires.
func (s *Session) Expired() bool {
	return s.expires != nil && time.Now().After(*s.expires)
}

// subscribe returns a channel that will be closed by the next call to
// signalSubscribers. All callers subscribing between two signals share the
// same channel.
func (s *Session) subscribe() <-chan struct{} {
	s.subsMutex.Lock()
	defer s.subsMutex.Unlock()

	if s.subs != nil {
		return s.subs
	}
	ch := make(chan struct{})
	s.subs = ch
	return ch
}

// signalSubscribers closes the channel handed out by subscribe, waking every
// receiver at once. It then forgets it, so the next subscribe creates a
// fresh channel. It does nothing if nobody has subscribed.
func (s *Session) signalSubscribers() {
	s.subsMutex.Lock()
	defer s.subsMutex.Unlock()

	if s.subs == nil {
		return
	}
	close(s.subs)
	s.subs = nil
}

// MarshalJSON serialises the public view of the session. The secret and
// the OIDC data are not part of the output, and "expires" is omitted when
// the session has no expiry.
func (s *Session) MarshalJSON() ([]byte, error) {
	type publicSession struct {
		ID       string       `json:"id"`
		State    SessionState `json:"state"`
		Expires  *time.Time   `json:"expires,omitempty"`
		Username string       `json:"username"`
		ReadOnly bool         `json:"readonly"`
		AllShows bool         `json:"all-shows"`
		Shows    []string     `json:"shows"`
	}

	var out publicSession
	out.ID = s.id
	out.State = s.State()
	out.Expires = s.expires
	out.Username = s.Username
	out.ReadOnly = s.ReadOnly
	out.AllShows = s.AllShows
	out.Shows = s.Shows
	return json.Marshal(out)
}

// getToken builds the bearer token "<id><tokenSeperator><secret>" that
// parseBearerAuthHeader later splits apart again.
func (s *Session) getToken() string {
	return strings.Join([]string{s.id, s.secret}, string(tokenSeperator))
}

// parseBearerAuthHeader splits an Authorization header value of the form
// "Bearer <id><tokenSeperator><secret>" into id and secret.
//
// The "Bearer " scheme prefix is matched case-insensitively and the token
// is split at the first tokenSeperator. The final result is false when the
// prefix is missing, the token is empty, or it contains no separator. The
// returned id or secret may still be empty (e.g. for "Bearer :x").
func parseBearerAuthHeader(authHeader string) (string, string, bool) {
	const prefix = "Bearer "

	if len(authHeader) < len(prefix) {
		return "", "", false
	}
	scheme, token := authHeader[:len(prefix)], authHeader[len(prefix):]
	if !strings.EqualFold(scheme, prefix) || token == "" {
		return "", "", false
	}

	cut := strings.IndexByte(token, tokenSeperator)
	if cut == -1 {
		return "", "", false
	}
	return token[:cut], token[cut+1:], true
}

// getSessionFromBearerToken resolves the session referenced by the
// request's "Authorization: Bearer <id>:<secret>" header.
//
// It returns nil in any of these cases:
//   - the header is absent or malformed;
//   - the id is unknown;
//   - the secret does not match.
//
// If the matching session has expired, it is removed from the session
// manager and nil is returned.
//
// NOTE(review): the session's State() is not checked here; confirm that
// callers reject sessions that are not SessionStateLoggedIn.
func getSessionFromBearerToken(r *http.Request) *Session {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return nil
	}

	sID, secret, ok := parseBearerAuthHeader(authHeader)
	if !ok {
		return nil
	}

	s := auth.sessions.get(sID)
	if s == nil {
		return nil
	}
	// The secret comes from an untrusted header: compare it in constant
	// time so response timing does not reveal how many leading bytes of a
	// guessed secret were correct.
	if subtle.ConstantTimeCompare([]byte(s.secret), []byte(secret)) != 1 {
		return nil
	}
	if s.Expired() {
		auth.sessions.remove(sID)
		return nil
	}

	return s
}

// attachSessionToRequest returns a shallow copy of r whose context carries
// s under sessionContextKey. Use SessionFromRequest to retrieve it.
func attachSessionToRequest(r *http.Request, s *Session) *http.Request {
	return r.WithContext(context.WithValue(r.Context(), sessionContextKey, s))
}

// SessionFromRequest returns the session stored in r's context by
// attachSessionToRequest. The boolean is false when no *Session value is
// stored under sessionContextKey.
func SessionFromRequest(r *http.Request) (*Session, bool) {
	v := r.Context().Value(sessionContextKey)
	s, ok := v.(*Session)
	return s, ok
}

// SessionManager stores all known sessions, indexed by session id. A
// background goroutine started by NewSessionManager periodically purges
// expired and failed sessions.
type SessionManager struct {
	// mutex guards sessions.
	mutex    sync.RWMutex
	// maxAge is the longest lifetime insert grants to a session.
	maxAge   time.Duration
	// sessions maps session id to session.
	sessions map[string]*Session
}

// NewSessionManager creates an empty session store and starts its periodic
// cleanup goroutine. Sessions live at most c.MaxAge, or defaultAge when
// c.MaxAge is not positive. The returned error is currently always nil.
//
// NOTE(review): the maintenance goroutine has no stop mechanism and keeps
// running for the lifetime of the process.
func NewSessionManager(c SessionsConfig) (sm *SessionManager, err error) {
	maxAge := c.MaxAge
	if maxAge <= 0 {
		maxAge = defaultAge
	}
	sm = &SessionManager{
		maxAge:   maxAge,
		sessions: make(map[string]*Session),
	}
	go sm.runMaintenance()
	return sm, nil
}

// runMaintenance calls cleanup every five minutes. It never returns, since
// the ticker channel is never closed.
func (sm *SessionManager) runMaintenance() {
	ticker := time.NewTicker(5 * time.Minute)
	for range ticker.C {
		sm.cleanup()
	}
}

// insert stores s under its id, replacing any entry with the same id.
//
// Beforehand, a non-nil expiry that is zero or later than now+maxAge is
// clamped to now+maxAge. A nil expiry is left alone, so such a session
// never expires. The returned error is currently always nil.
func (sm *SessionManager) insert(s *Session) (err error) {
	limit := time.Now().Add(sm.maxAge)
	if exp := s.expires; exp != nil {
		if exp.IsZero() || exp.After(limit) {
			s.expires = &limit
		}
	}

	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	sm.sessions[s.id] = s
	auth.dbgLog.Printf("authentication: added new session %s", s.id)
	return nil
}

// get returns the session stored under id, or nil if there is none.
func (sm *SessionManager) get(id string) *Session {
	sm.mutex.RLock()
	s := sm.sessions[id] // missing key yields the nil zero value
	sm.mutex.RUnlock()
	return s
}

// getAndSubscribe looks up id and subscribes to that session's next signal
// while still holding the read lock. This means update and remove, which
// need the write lock, cannot run between the lookup and the subscription.
// It returns (nil, nil) if id is unknown.
func (sm *SessionManager) getAndSubscribe(id string) (*Session, <-chan struct{}) {
	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	if s, found := sm.sessions[id]; found {
		return s, s.subscribe()
	}
	return nil, nil
}

// update replaces the session stored under id with s.
//
// s inherits the old session's id, secret, state and expiry, so the caller
// only needs to fill in the exported user fields. The old session object is
// then marked SessionStateStale and its subscribers are woken. An error is
// returned if id is unknown.
func (sm *SessionManager) update(id string, s *Session) error {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	old, ok := sm.sessions[id]
	if !ok {
		return errors.New("session not found.")
	}
	s.secret = old.secret
	s.id = old.id
	// old.state is modified with atomics by setState/updateState, possibly
	// without the manager lock; a plain read here would be a data race.
	s.state = old.State()
	s.expires = old.expires

	sm.sessions[id] = s
	// Mark the old object stale before waking its subscribers, so that
	// anyone woken observes SessionStateStale instead of the pre-update
	// state. setState already signals when the state changes; the explicit
	// signal covers the unchanged case and is a no-op otherwise.
	old.setState(SessionStateStale)
	old.signalSubscribers()
	auth.dbgLog.Printf("authentication: updated session %s", id)
	return nil
}

// remove deletes the session stored under id, if any. First it sets the
// session to SessionStateRemoved, which wakes its subscribers if the state
// changed. Unknown ids are ignored.
func (sm *SessionManager) remove(id string) {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	if s, found := sm.sessions[id]; found {
		s.setState(SessionStateRemoved)
		delete(sm.sessions, id)
		auth.dbgLog.Printf("authentication: removed session %s", id)
	}
}

// cleanup removes every session that has expired, or whose state is
// SessionStateLoginFailed or SessionStateLoginTimeout. Each removed session
// is set to SessionStateRemoved first and the reason is logged ("expired"
// takes precedence over the state name). Deleting map entries while ranging
// over the map is permitted in Go.
func (sm *SessionManager) cleanup() {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	for id, s := range sm.sessions {
		expired := s.Expired()
		st := s.State()

		var reason string
		switch {
		case expired:
			reason = "expired"
		case st == SessionStateLoginFailed, st == SessionStateLoginTimeout:
			reason = st.String()
		default:
			continue
		}

		s.setState(SessionStateRemoved)
		delete(sm.sessions, id)
		auth.dbgLog.Printf("authentication: removed session %s (reason=%s)", id, reason)
	}
}