Asyncifying django.contrib.auth? (and signals, and maybe sessions)

Hi, stepping in to add some code, since @carltongibson mentioned a “prototype”.

We’ve been using this session backend successfully, at small scale, in a FastAPI/Django hybrid application for half a year now. (For the record: I’m the founder of the company, I wrote the code below, and I expressly allow the Django community to modify it freely.) The application is “hybrid” in that I use the Django ORM under the hood for database operations, my django-async-redis package for caching, and FastAPI as my API framework:

base.py:

#
#  Copyright (c) 2022-2023. Lazify, Inc. and its affiliates - All Rights Reserved
#
#  This file is part of the Lazify project.
#
#  Unauthorized copying of this file in source and binary
#  forms, via any medium, is strictly prohibited.
#  Proprietary and confidential
#

import logging
from datetime import datetime, timedelta

from asgiref.sync import sync_to_async
from django.conf import settings
from django.contrib.sessions.backends.base import VALID_KEY_CHARS
from django.contrib.sessions.backends.base import SessionBase as DjangoSessionBase
from django.core import signing
from django.utils import timezone
from django.utils.crypto import get_random_string


class SessionBase(DjangoSessionBase):
    """
    Base class for all Session classes.

    Mirrors ``django.contrib.sessions.backends.base.SessionBase`` and adds
    ``a``-prefixed coroutine counterparts of the data-access methods. The
    async variants obtain session data through ``_aget_session()`` (which
    uses ``aload()``) rather than the synchronous ``_session`` property, so
    calling them from a coroutine does not run a blocking ``load()``.
    """

    TEST_COOKIE_NAME = "testcookie"
    TEST_COOKIE_VALUE = "worked"

    # Sentinel that lets pop()/apop() tell "no default given" apart from an
    # explicit default of None.
    __not_given = object()

    def __contains__(self, key):
        return key in self._session

    def __getitem__(self, key):
        return self._session[key]

    def __setitem__(self, key, value):
        self.set(key, value)

    def __delitem__(self, key):
        del self._session[key]
        self.modified = True

    @property
    def key_salt(self):
        # Salt passed to signing.dumps()/loads(); namespaced per concrete class.
        return "django.contrib.sessions." + self.__class__.__qualname__

    def get(self, key, default=None):
        return self._session.get(key, default)

    async def aget(self, key, default=None):
        """Async counterpart of get()."""
        return (await self._aget_session()).get(key, default)

    def set(self, key, value):
        self._session[key] = value
        self.modified = True

    async def aset(self, key, value):
        """Async counterpart of set(); marks the session as modified."""
        (await self._aget_session())[key] = value
        self.modified = True

    def pop(self, key, default=__not_given):
        self.modified = self.modified or key in self._session
        args = () if default is self.__not_given else (default,)
        return self._session.pop(key, *args)

    async def apop(self, key, default=__not_given):
        """
        Async counterpart of pop().

        Raises KeyError when ``key`` is missing and no default is given.
        The session is only marked modified if ``key`` was present.
        """
        session = await self._aget_session()
        self.modified = self.modified or key in session
        args = () if default is self.__not_given else (default,)
        return session.pop(key, *args)

    def setdefault(self, key, value):
        if key in self._session:
            return self._session[key]
        else:
            self.set(key, value)
            return value

    async def asetdefault(self, key, value):
        """Async counterpart of setdefault()."""
        session = await self._aget_session()
        if key in session:
            return session[key]
        await self.aset(key, value)
        return value

    def set_test_cookie(self):
        self[self.TEST_COOKIE_NAME] = self.TEST_COOKIE_VALUE

    async def aset_test_cookie(self):
        """Async counterpart of set_test_cookie()."""
        await self.aset(self.TEST_COOKIE_NAME, self.TEST_COOKIE_VALUE)

    def test_cookie_worked(self):
        return self.get(self.TEST_COOKIE_NAME) == self.TEST_COOKIE_VALUE

    async def atest_cookie_worked(self):
        """Async counterpart of test_cookie_worked()."""
        return (await self.aget(self.TEST_COOKIE_NAME)) == self.TEST_COOKIE_VALUE

    def delete_test_cookie(self):
        del self[self.TEST_COOKIE_NAME]

    async def adelete_test_cookie(self):
        """Async counterpart of delete_test_cookie(); raises KeyError if unset."""
        del (await self._aget_session())[self.TEST_COOKIE_NAME]
        self.modified = True

    def encode(self, session_dict):
        "Return the given session dictionary serialized and encoded as a string."
        return signing.dumps(
            session_dict,
            salt=self.key_salt,
            serializer=self.serializer,
            compress=True,
        )

    def decode(self, session_data):
        """
        Return the session dict stored in ``session_data``.

        Never raises: a bad signature is logged to
        ``django.security.SuspiciousSession`` and any other failure is
        swallowed; both cases yield an empty dict (an empty session).
        """
        try:
            return signing.loads(
                session_data, salt=self.key_salt, serializer=self.serializer
            )
        except signing.BadSignature:
            logger = logging.getLogger("django.security.SuspiciousSession")
            logger.warning("Session data corrupted")
        except Exception:
            # ValueError, unpickling exceptions. If any of these happen, just
            # return an empty dictionary (an empty session).
            pass
        return {}

    def update(self, dict_):
        self._session.update(dict_)
        self.modified = True

    async def aupdate(self, dict_):
        """Async counterpart of update()."""
        (await self._aget_session()).update(dict_)
        self.modified = True

    def has_key(self, key):
        return key in self._session

    async def ahas_key(self, key):
        """Async counterpart of has_key()."""
        return key in (await self._aget_session())

    def keys(self):
        return self._session.keys()

    async def akeys(self):
        """Async counterpart of keys()."""
        return (await self._aget_session()).keys()

    def values(self):
        return self._session.values()

    async def avalues(self):
        """Async counterpart of values()."""
        return (await self._aget_session()).values()

    def items(self):
        return self._session.items()

    async def aitems(self):
        """Async counterpart of items()."""
        return (await self._aget_session()).items()

    def clear(self):
        # To avoid unnecessary persistent storage accesses, we set up the
        # internals directly (loading data wastes time, since we are going to
        # set it to an empty dict anyway). Performs no I/O, so it is also
        # safe to call from async code.
        self._session_cache = {}
        self.accessed = True
        self.modified = True

    def is_empty(self):
        "Return True when there is no session_key and the session is empty."
        try:
            return not self._session_key and not self._session_cache
        except AttributeError:
            return True

    def _get_new_session_key(self):
        "Return session key that isn't being used."
        while True:
            session_key = get_random_string(32, VALID_KEY_CHARS)
            if not self.exists(session_key):
                return session_key

    async def _aget_new_session_key(self):
        """Async counterpart of _get_new_session_key() (uses aexists())."""
        while True:
            session_key = get_random_string(32, VALID_KEY_CHARS)
            if not await self.aexists(session_key):
                return session_key

    def _get_or_create_session_key(self):
        if self._session_key is None:
            self._session_key = self._get_new_session_key()
        return self._session_key

    async def _aget_or_create_session_key(self):
        """Async counterpart of _get_or_create_session_key()."""
        if self._session_key is None:
            self._session_key = await self._aget_new_session_key()
        return self._session_key

    def _validate_session_key(self, key):
        """
        Key must be truthy and at least 8 characters long. 8 characters is an
        arbitrary lower bound for some minimal key security.
        """
        return key and len(key) >= 8

    def _get_session_key(self):
        return self.__session_key

    def _set_session_key(self, value):
        """
        Validate session key on assignment. Invalid values will set to None.
        """
        if self._validate_session_key(value):
            self.__session_key = value
        else:
            self.__session_key = None

    # Public read-only view of the key; the private property validates writes.
    session_key = property(_get_session_key)
    _session_key = property(_get_session_key, _set_session_key)

    def _get_session(self, no_load=False):
        """
        Lazily load session from storage (unless "no_load" is True, when only
        an empty dict is stored) and store it in the current instance.
        """
        self.accessed = True
        try:
            return self._session_cache
        except AttributeError:
            if self.session_key is None or no_load:
                self._session_cache = {}
            else:
                self._session_cache = self.load()
        return self._session_cache

    async def _aget_session(self, no_load=False):
        """Async counterpart of _get_session(); loads through aload()."""
        self.accessed = True
        try:
            return self._session_cache
        except AttributeError:
            if self.session_key is None or no_load:
                self._session_cache = {}
            else:
                self._session_cache = await self.aload()
        return self._session_cache

    _session = property(_get_session)

    def get_session_cookie_age(self):
        return settings.SESSION_COOKIE_AGE

    def _expiry_age_from(self, modification, expiry):
        """
        Return the number of seconds from ``modification`` until ``expiry``.

        ``expiry`` may be falsy (fall back to the cookie age), a number of
        seconds (returned as-is), a ``datetime``, or an ISO-8601 string as
        written by set_expiry().
        """
        if not expiry:  # Checks both None and 0 cases
            return self.get_session_cookie_age()
        if not isinstance(expiry, (datetime, str)):
            return expiry
        if isinstance(expiry, str):
            expiry = datetime.fromisoformat(expiry)
        delta = expiry - modification
        return delta.days * 86400 + delta.seconds

    def _expiry_date_from(self, modification, expiry):
        """
        Return the expiry ``datetime`` for the given ``modification`` time.

        ``expiry`` has the same accepted forms as in _expiry_age_from().
        """
        if isinstance(expiry, datetime):
            return expiry
        elif isinstance(expiry, str):
            return datetime.fromisoformat(expiry)
        expiry = expiry or self.get_session_cookie_age()
        return modification + timedelta(seconds=expiry)

    def get_expiry_age(self, **kwargs):
        """Get the number of seconds until the session expires.

        Optionally, this function accepts `modification` and `expiry` keyword
        arguments specifying the modification and expiry of the session.
        """
        try:
            modification = kwargs["modification"]
        except KeyError:
            modification = timezone.now()
        # Make the difference between "expiry=None passed in kwargs" and
        # "expiry not passed in kwargs", in order to guarantee not to trigger
        # self.load() when expiry is provided.
        try:
            expiry = kwargs["expiry"]
        except KeyError:
            expiry = self.get("_session_expiry")
        return self._expiry_age_from(modification, expiry)

    async def aget_expiry_age(self, **kwargs):
        """Async counterpart of get_expiry_age()."""
        try:
            modification = kwargs["modification"]
        except KeyError:
            modification = timezone.now()
        # Only touch the session data when `expiry` was not passed at all.
        try:
            expiry = kwargs["expiry"]
        except KeyError:
            expiry = await self.aget("_session_expiry")
        return self._expiry_age_from(modification, expiry)

    def get_expiry_date(self, **kwargs):
        """Get session the expiry date (as a datetime object).

        Optionally, this function accepts `modification` and `expiry` keyword
        arguments specifying the modification and expiry of the session.
        """
        try:
            modification = kwargs["modification"]
        except KeyError:
            modification = timezone.now()
        # Same comment as in get_expiry_age
        try:
            expiry = kwargs["expiry"]
        except KeyError:
            expiry = self.get("_session_expiry")
        return self._expiry_date_from(modification, expiry)

    async def aget_expiry_date(self, **kwargs):
        """Async counterpart of get_expiry_date()."""
        try:
            modification = kwargs["modification"]
        except KeyError:
            modification = timezone.now()
        try:
            expiry = kwargs["expiry"]
        except KeyError:
            expiry = await self.aget("_session_expiry")
        return self._expiry_date_from(modification, expiry)

    def set_expiry(self, value):
        """
        Set a custom expiration for the session. ``value`` can be an integer,
        a Python ``datetime`` or ``timedelta`` object or ``None``.

        If ``value`` is an integer, the session will expire after that many
        seconds of inactivity. If set to ``0`` then the session will expire on
        browser close.

        If ``value`` is a ``datetime`` or ``timedelta`` object, the session
        will expire at that specific future time.

        If ``value`` is ``None``, the session uses the global session expiry
        policy.
        """
        if value is None:
            # Remove any custom expiration for this session.
            try:
                del self["_session_expiry"]
            except KeyError:
                pass
            return
        if isinstance(value, timedelta):
            value = timezone.now() + value
        if isinstance(value, datetime):
            value = value.isoformat()
        self["_session_expiry"] = value

    async def aset_expiry(self, value):
        """Async counterpart of set_expiry(); accepts the same values."""
        if value is None:
            # Remove any custom expiration for this session by dropping the
            # "_session_expiry" entry from the session data. (adelete() must
            # not be used here: it deletes a stored session by *session key*.)
            try:
                await self.apop("_session_expiry")
            except KeyError:
                pass
            return
        if isinstance(value, timedelta):
            value = timezone.now() + value
        if isinstance(value, datetime):
            value = value.isoformat()
        await self.aset("_session_expiry", value)

    def get_expire_at_browser_close(self):
        """
        Return ``True`` if the session is set to expire when the browser
        closes, and ``False`` if there's an expiry date. Use
        ``get_expiry_date()`` or ``get_expiry_age()`` to find the actual expiry
        date/age, if there is one.
        """
        if (expiry := self.get("_session_expiry")) is None:
            return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
        return expiry == 0

    async def aget_expire_at_browser_close(self):
        """Async counterpart of get_expire_at_browser_close()."""
        if (expiry := await self.aget("_session_expiry")) is None:
            return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
        return expiry == 0

    def flush(self):
        """
        Remove the current session data from the database and regenerate the
        key.
        """
        self.clear()
        self.delete()
        self._session_key = None

    async def aflush(self):
        """Async counterpart of flush()."""
        self.clear()
        await self.adelete()
        self._session_key = None

    def cycle_key(self):
        """
        Create a new session key, while retaining the current session data.
        """
        data = self._session
        key = self.session_key
        self.create()
        self._session_cache = data
        if key:
            self.delete(key)

    async def acycle_key(self):
        """
        Create a new session key, while retaining the current session data.
        """
        # Load through the async path; the sync ``_session`` property would
        # call the blocking load() when the data is not cached yet.
        data = await self._aget_session()
        key = self.session_key
        await self.acreate()
        self._session_cache = data
        if key:
            await self.adelete(key)

    # Methods that child classes must implement. Each async counterpart
    # defaults to running the sync method via sync_to_async(thread_sensitive=True);
    # subclasses may override them with native coroutines.

    def exists(self, session_key):
        """
        Return True if the given session_key already exists.
        """
        raise NotImplementedError(
            "subclasses of SessionBase must provide an exists() method"
        )

    async def aexists(self, session_key):
        return await sync_to_async(self.exists, thread_sensitive=True)(session_key)

    def create(self):
        """
        Create a new session instance. Guaranteed to create a new object with
        a unique key and will have saved the result once (with empty data)
        before the method returns.
        """
        raise NotImplementedError(
            "subclasses of SessionBase must provide a create() method"
        )

    async def acreate(self):
        return await sync_to_async(self.create, thread_sensitive=True)()

    def save(self, must_create=False):
        """
        Save the session data. If 'must_create' is True, create a new session
        object (or raise CreateError). Otherwise, only update an existing
        object and don't create one (raise UpdateError if needed).
        """
        raise NotImplementedError(
            "subclasses of SessionBase must provide a save() method"
        )

    async def asave(self, must_create=False):
        return await sync_to_async(self.save, thread_sensitive=True)(must_create)

    def delete(self, session_key=None):
        """
        Delete the session data under this key. If the key is None, use the
        current session key value.
        """
        raise NotImplementedError(
            "subclasses of SessionBase must provide a delete() method"
        )

    async def adelete(self, session_key=None):
        return await sync_to_async(self.delete, thread_sensitive=True)(session_key)

    def load(self):
        """
        Load the session data and return a dictionary.
        """
        raise NotImplementedError(
            "subclasses of SessionBase must provide a load() method"
        )

    async def aload(self):
        return await sync_to_async(self.load, thread_sensitive=True)()

    @classmethod
    def clear_expired(cls):
        """
        Remove expired sessions from the session store.

        If this operation isn't possible on a given backend, it should raise
        NotImplementedError. If it isn't necessary, because the backend has
        a built-in expiration mechanism, it should be a no-op.
        """
        raise NotImplementedError("This backend does not support clear_expired().")

    @classmethod
    async def aclear_expired(cls):
        return await sync_to_async(cls.clear_expired, thread_sensitive=True)()

cache.py:

#
#  Copyright (c) 2022-2023. Lazify, Inc. and its affiliates - All Rights Reserved
#
#  This file is part of the Lazify project.
#
#  Unauthorized copying of this file in source and binary
#  forms, via any medium, is strictly prohibited.
#  Proprietary and confidential
#

from django.conf import settings
from django.contrib.sessions.backends.base import CreateError, UpdateError
from django.core.cache import caches

from app.vendor.django.contrib.sessions.backends.base import SessionBase


KEY_PREFIX = "django.contrib.sessions.cache"


class SessionStore(SessionBase):
    """
    A cache-based session store.

    Session data is stored in the cache selected by ``SESSION_CACHE_ALIAS``
    under ``cache_key_prefix + session_key``. The ``a``-prefixed methods use
    the cache backend's async API (``aget``/``aadd``/``aset``/``ahas_key``/
    ``adelete``) directly instead of wrapping the sync methods in a thread.
    """

    cache_key_prefix = KEY_PREFIX

    def __init__(self, session_key=None):
        self._cache = caches[settings.SESSION_CACHE_ALIAS]
        super().__init__(session_key)

    @property
    def cache_key(self):
        """Cache key for this session; generates a session key if there is none."""
        return self.cache_key_prefix + self._get_or_create_session_key()

    async def acache_key(self):
        """Async counterpart of the ``cache_key`` property."""
        return self.cache_key_prefix + await self._aget_or_create_session_key()

    def load(self):
        """Return the cached session dict, or reset the key and return {}."""
        try:
            session_data = self._cache.get(self.cache_key)
        except Exception:
            # Some backends (e.g. memcache) raise an exception on invalid
            # cache keys. If this happens, reset the session. See #17810.
            session_data = None
        if session_data is not None:
            return session_data
        self._session_key = None
        return {}

    async def aload(self):
        """Async counterpart of load()."""
        try:
            session_data = await self._cache.aget(await self.acache_key())
        except Exception:
            # Same as load(): invalid keys raise on some backends; treat that
            # as a missing session (#17810).
            session_data = None
        if session_data is not None:
            return session_data
        self._session_key = None
        return {}

    def create(self):
        """
        Create and save a new, empty session under a fresh key.

        Raises RuntimeError if no key could be saved after 10000 attempts.
        """
        # Because a cache can fail silently (e.g. memcache), we don't know if
        # we are failing to create a new session because of a key collision or
        # because the cache is missing. So we try for a (large) number of times
        # and then raise an exception. That's the risk you shoulder if using
        # cache backing.
        for _ in range(10000):
            self._session_key = self._get_new_session_key()
            try:
                self.save(must_create=True)
            except CreateError:
                continue
            self.modified = True
            return
        raise RuntimeError(
            "Unable to create a new session key. "
            "It is likely that the cache is unavailable."
        )

    async def acreate(self):
        """Async counterpart of create(); same retry policy."""
        for _ in range(10000):
            self._session_key = await self._aget_new_session_key()
            try:
                await self.asave(must_create=True)
            except CreateError:
                continue
            self.modified = True
            return
        raise RuntimeError(
            "Unable to create a new session key. "
            "It is likely that the cache is unavailable."
        )

    def save(self, must_create=False):
        """
        Write the session data to the cache with a timeout of get_expiry_age().

        With ``must_create`` the entry is added only if absent (CreateError
        otherwise); without it the entry must already exist (UpdateError
        otherwise). A session without a key is created instead.
        """
        if self.session_key is None:
            return self.create()
        cache_key = self.cache_key
        if must_create:
            func = self._cache.add
        elif self._cache.get(cache_key) is not None:
            func = self._cache.set
        else:
            raise UpdateError
        result = func(
            cache_key,
            self._get_session(no_load=must_create),
            self.get_expiry_age(),
        )
        if must_create and not result:
            raise CreateError

    async def asave(self, must_create=False):
        """Async counterpart of save(); same CreateError/UpdateError contract."""
        if self.session_key is None:
            return await self.acreate()
        # Use the async helpers throughout so no sync cache/session access
        # can happen on the event loop, independent of evaluation order.
        cache_key = await self.acache_key()
        if must_create:
            func = self._cache.aadd
        elif await self._cache.ahas_key(cache_key):
            func = self._cache.aset
        else:
            raise UpdateError
        result = await func(
            cache_key,
            await self._aget_session(no_load=must_create),
            await self.aget_expiry_age(),
        )
        if must_create and not result:
            raise CreateError

    def exists(self, session_key):
        """Return True if a non-empty ``session_key`` is present in the cache."""
        return (
            bool(session_key) and (self.cache_key_prefix + session_key) in self._cache
        )

    async def aexists(self, session_key):
        """Async counterpart of exists()."""
        return bool(session_key) and await self._cache.ahas_key(
            self.cache_key_prefix + session_key
        )

    def delete(self, session_key=None):
        """Delete the cached session; defaults to the current key (no-op if none)."""
        if session_key is None:
            if self.session_key is None:
                return
            session_key = self.session_key
        self._cache.delete(self.cache_key_prefix + session_key)

    async def adelete(self, session_key=None):
        """Async counterpart of delete()."""
        if session_key is None:
            if self.session_key is None:
                return
            session_key = self.session_key
        await self._cache.adelete(self.cache_key_prefix + session_key)

    @classmethod
    def clear_expired(cls):
        # No-op: save() writes entries with a timeout of get_expiry_age(),
        # so expiry is left to the cache itself.
        pass

    @classmethod
    async def aclear_expired(cls):
        # No-op for the same reason as clear_expired().
        pass

Middleware (a Starlette middleware that incorporates the core of Django’s session middleware):

#
#  Copyright (c) 2022-2023. Lazify, Inc. and its affiliates - All Rights Reserved
#
#  This file is part of the Lazify project.
#
#  Unauthorized copying of this file in source and binary
#  forms, via any medium, is strictly prohibited.
#  Proprietary and confidential
#

import asyncio
import time

from django.conf import settings
from django.contrib.sessions.backends.base import UpdateError
from django.utils.http import http_date
from starlette.datastructures import MutableHeaders
from starlette.exceptions import HTTPException
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from app.api.views.client.session import ClientChatSessionStore
from app.core.utils.django.cache import patch_vary_headers
from app.core.utils.django.request import delete_cookie, set_cookie
from app.request import Request, WebSocket
from app.vendor.django.contrib.sessions.backends.cache import SessionStore


class SessionMiddleware:
    """
    Starlette/ASGI middleware that carries over the response-processing part
    of Django's ``SessionMiddleware``: it saves the request's session when
    needed and sets, refreshes or deletes the session cookie on the outgoing
    message's headers.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    @staticmethod
    async def _asave_session(session: SessionStore) -> None:
        """
        Save ``session``, translating UpdateError into an HTTP 400.

        This stands in for Django's SessionInterrupted, which signals that the
        session was deleted before the request finished.
        """
        try:
            await session.asave()
        except UpdateError:
            # originally django.contrib.sessions.exceptions.SessionInterrupted
            raise HTTPException(
                status_code=400,
                detail="The request's session was deleted before the "
                "request completed. The user may have logged "
                "out in a concurrent request, for example.",
            )

    async def _process_response(
        self,
        message: Message,
        scope: Scope,
        receive: Receive,
        send: Send,
        prefix: str = "",
    ):
        """
        Persist ``request.{prefix}session`` and update cookies on ``message``.

        :param message: outgoing ASGI message whose ``headers`` list is
            created if missing and mutated in place (presumably the
            ``http.response.start`` message; confirm against the caller).
        :param prefix: selects both the request attribute
            (``{prefix}session``) and the settings namespace
            (``{PREFIX}SESSION_*``). For ``"chat_"`` the session is only
            saved; no cookie is set or deleted.
        :raises HTTPException: 400 if the session was deleted concurrently.
        """
        request = Request(scope, receive, send)
        message.setdefault("headers", [])
        setting_prefix = prefix.upper()
        headers = MutableHeaders(scope=message)

        def session() -> SessionStore:
            return getattr(request, f"{prefix}session")

        def setting(_setting_name: str):
            return getattr(settings, f"{setting_prefix}{_setting_name}")

        try:
            accessed = session().accessed
            modified = session().modified
            empty = session().is_empty()
        except AttributeError:
            # NOTE(review): the request has no such session; this hands the
            # original scope/receive/send to the wrapped app from inside
            # response processing. Confirm against the caller (not shown).
            await self.app(scope, receive, send)
            return

        if prefix == "chat_":
            if (modified or setting("SESSION_SAVE_EVERY_REQUEST")) and not empty:
                await self._asave_session(session())
            return

        # First check if we need to delete this cookie.
        # The session should be deleted only if the session is entirely empty.
        if setting("SESSION_COOKIE_NAME") in request.cookies and empty:
            delete_cookie(
                headers,
                setting("SESSION_COOKIE_NAME"),
                path=setting("SESSION_COOKIE_PATH"),
                domain=setting("SESSION_COOKIE_DOMAIN"),
                samesite=setting("SESSION_COOKIE_SAMESITE"),
            )
            delete_cookie(headers, "user")
            patch_vary_headers(headers, ("Cookie",))
        else:
            if accessed:
                patch_vary_headers(headers, ("Cookie",))
            if (modified or setting("SESSION_SAVE_EVERY_REQUEST")) and not empty:
                if await session().aget_expire_at_browser_close():
                    max_age = None
                    expires = None
                else:
                    max_age = await session().aget_expiry_age()
                    expires_time = time.time() + max_age
                    expires = http_date(expires_time)
                # Save the session data and refresh the client cookie.
                # Skip session save for 5xx responses, refs #3881.
                # "status" is the ASGI http.response.start field; if it is
                # absent the response is treated as successful, keeping the
                # previous always-save behaviour for such messages.
                if message.get("status", 200) < 500:
                    await self._asave_session(session())
                    set_cookie(
                        headers,
                        setting("SESSION_COOKIE_NAME"),
                        session().session_key,
                        max_age=max_age,
                        expires=expires,
                        domain=setting("SESSION_COOKIE_DOMAIN"),
                        path=setting("SESSION_COOKIE_PATH"),
                        secure=setting("SESSION_COOKIE_SECURE") or None,
                        httponly=setting("SESSION_COOKIE_HTTPONLY") or None,
                        samesite=setting("SESSION_COOKIE_SAMESITE"),
                    )
1 Like