After running into concurrency issues where my Django model instances were being overwritten by previous values due to .save() not having update_fields specified, I began refactoring to specify update_fields wherever possible. This process proved more difficult than anticipated for complex functions and during the process I had a thought - if we can reliably know which fields have changed in-memory during thread execution (tracked on the model instance), can we automatically set update_fields?
This would make concurrency much safer as even if the instance gets “overwritten”, only the overlapping fields are an issue.
Though I can see how this could be undesirable in some cases, it seems the majority of cases would benefit from having update_fields restricted to only the fields that have actually been changed.
Here’s a draft of a BaseModel that can hopefully be used to achieve this functionality. Appreciate any insight into problematic side-effects this may cause.
Note: this accounts for deferred fields and does not incur additional database queries.
import copy

from django.db import models
class _Null:
    """A class to represent a Null type."""
    def __bool__(self):
        return False
NULL = _Null()
class BaseModel(models.Model):
    """Model base class that limits ``save()`` to the fields changed in memory.

    A snapshot of every loaded concrete field is taken when the instance is
    constructed and is brought up to date after ``refresh_from_db()`` and
    ``save()``. Unless the caller passes ``update_fields`` explicitly,
    ``save()`` compares the current values with that snapshot and hands only
    the changed fields to ``Model.save()``.

    Fields are read through their ``attname`` (``<name>_id`` for foreign
    keys), so tracking never fetches a related object and never reads a
    deferred field.

    NOTE(review): no ``Meta.abstract = True`` is declared here; if this class
    is only meant to be inherited from, confirm whether it should be abstract.
    """

    def __init__(self, *args, **kwargs):
        """Build the instance, then record the initial field snapshot."""
        super().__init__(*args, **kwargs)
        self.__original_values = self._snapshot_loaded_field_values()

    def _get_loaded_field_names(self):
        """Return the names of all concrete fields that are not deferred."""
        # get_deferred_fields() reports attnames, so compare attnames directly
        # rather than guessing at foreign keys from an "_id" suffix.
        deferred_attnames = self.get_deferred_fields()
        return [
            field.name
            for field in self._meta.concrete_fields
            if field.attname not in deferred_attnames
        ]

    def _get_loaded_field_values(self):
        """Return ``{field name: current raw value}`` for every loaded field.

        Foreign keys are reported by their raw id, not the related object.
        """
        get_field = self._meta.get_field
        return {
            name: getattr(self, get_field(name).attname, NULL)
            for name in self._get_loaded_field_names()
        }

    def _snapshot_loaded_field_values(self):
        """Return the loaded field values in a form safe to keep as a baseline.

        Mutable containers are deep-copied; otherwise an in-place change such
        as ``instance.data["key"] = 1`` would alter the baseline as well and
        could never be detected.
        """
        return {
            name: (
                copy.deepcopy(value)
                if isinstance(value, (dict, list, set, bytearray))
                else value
            )
            for name, value in self._get_loaded_field_values().items()
        }

    def _sync_original_values(self, field_names=None):
        """Bring the snapshot in line with the instance's current values.

        Args:
            field_names: Field names or attnames to re-snapshot. ``None``
                replaces the whole snapshot. Restricting the sync keeps
                pending, unsaved changes to other fields detectable.
        """
        current = self._snapshot_loaded_field_values()
        if field_names is None:
            self.__original_values = current
            return
        wanted = set(field_names)
        for field in self._meta.concrete_fields:
            if field.name not in wanted and field.attname not in wanted:
                continue
            if field.name in current:
                self.__original_values[field.name] = current[field.name]
            else:
                self.__original_values.pop(field.name, None)

    def _get_update_fields(self):
        """Return the names of loaded fields whose value differs from the snapshot.

        A field with no snapshot entry (it was deferred when the snapshot was
        taken) counts as changed once it holds a value.
        """
        originals = self.__original_values
        return [
            name
            for name, value in self._get_loaded_field_values().items()
            if value is not NULL and value != originals.get(name, NULL)
        ]

    def _can_restrict_update(self, args, kwargs):
        """Return whether ``save()`` may safely compute ``update_fields`` itself.

        This is refused for new instances, for forced inserts, and whenever
        the primary key is unset or differs from the snapshot (for example
        the ``instance.pk = None; instance.save()`` copy idiom). In those
        cases the regular ``Model.save()`` logic has to decide between
        INSERT and UPDATE.
        """
        if self._state.adding:
            return False
        # force_insert is the first positional parameter of Model.save().
        if kwargs.get("force_insert") or (args and args[0]):
            return False
        original_pk = self.__original_values.get(self._meta.pk.name, NULL)
        current_pk = self.pk
        return (
            current_pk is not None
            and original_pk is not NULL
            and current_pk == original_pk
        )

    def _get_auto_update_fields(self):
        """Return the ``update_fields`` list used by an automatic save.

        Primary key fields are left out because ``Model.save()`` rejects them
        in ``update_fields``. When anything changed, ``auto_now`` fields are
        added so their timestamps keep being refreshed.
        """
        get_field = self._meta.get_field
        changed = [
            name
            for name in self._get_update_fields()
            if not get_field(name).primary_key
        ]
        if not changed:
            return changed
        auto_now_names = [
            field.name
            for field in self._meta.concrete_fields
            if getattr(field, "auto_now", False) and field.name not in changed
        ]
        return changed + auto_now_names

    def refresh_from_db(self, using=None, fields=None, *args, **kwargs):
        """Reload values from the database and re-snapshot what was reloaded.

        With ``fields`` given, only those fields are re-snapshotted. Django
        itself calls ``refresh_from_db(fields=[...])`` when a deferred
        attribute is first read; resetting the whole snapshot at that point
        would discard pending changes to every other field.
        """
        if fields is not None:
            # Materialise once so the same names can be reused after the call.
            fields = list(fields)
        super().refresh_from_db(using, fields, *args, **kwargs)
        self._sync_original_values(fields)

    def save(self, *args, update_fields=None, **kwargs):
        """
        Automatically set `update_fields` (if not explicitly specified)
        based on which fields have been modified

        An explicitly passed ``update_fields`` is used unchanged. If nothing
        was modified, the computed list is empty and ``Model.save()`` returns
        without touching the database. Afterwards only the fields that were
        written are re-snapshotted, or all of them after an unrestricted save.
        """
        if update_fields is not None:
            update_fields = list(update_fields)
        elif self._can_restrict_update(args, kwargs):
            update_fields = self._get_auto_update_fields()
        super().save(*args, update_fields=update_fields, **kwargs)
        self._sync_original_values(update_fields)