After running into concurrency issues where my Django model instances were being overwritten with stale values because `.save()` was called without `update_fields` specified, I began refactoring to specify `update_fields` wherever possible. This proved more difficult than anticipated for complex functions, and during the process I had a thought: if we can reliably know which fields have changed in memory during thread execution (tracked on the model instance), can we set `update_fields` automatically?
This would make concurrency much safer as even if the instance gets “overwritten”, only the overlapping fields are an issue.
Though I can see how this could be undesirable in some cases, it seems like the majority of cases would benefit from having `update_fields` restricted to only the fields which have been updated.
Here’s a draft of a BaseModel that can hopefully be used to achieve this functionality. Appreciate any insight into problematic side-effects this may cause.
Note: this accounts for deferred fields and does not incur additional database queries.
import copy

from django.db import models
class _Null:
    """Sentinel type meaning "no value recorded".

    A dedicated sentinel is needed because ``None`` is a legitimate value for
    a nullable model field and therefore cannot signal "missing".
    """

    def __bool__(self):
        # The sentinel is falsy so it can be used in simple truth tests.
        return False

    def __repr__(self):
        # Make the sentinel recognisable in logs and debugger output.
        return "NULL"


# Shared singleton; always compare against it by identity (``is`` / ``is not``).
NULL = _Null()
class BaseModel(models.Model):
    """Model base that restricts ``save()`` to the fields changed in memory.

    A snapshot of every loaded concrete field is taken when the instance is
    built and is re-synchronised after ``save()`` / ``refresh_from_db()``.
    When ``save()`` is called without ``update_fields`` on an existing row,
    the snapshot is diffed against the current values and only the differing
    fields are written.

    Behaviour to be aware of (all follow from how Django treats
    ``update_fields``):

    * If nothing changed, the computed list is empty and Django skips the
      save entirely, so no ``pre_save`` / ``post_save`` signals are sent.
    * A save with a non-empty ``update_fields`` raises ``DatabaseError`` if
      the row no longer exists, instead of re-inserting it.
    * Only plain ``dict`` / ``list`` / ``set`` / ``bytearray`` values are
      copied into the snapshot, so in-place mutation of other mutable objects
      is not detected.

    NOTE(review): presumably intended as an abstract base; if so, add
    ``class Meta: abstract = True`` -- confirm against the real project.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__original_values = self._snapshot_values()

    # ------------------------------------------------------------------
    # Introspection helpers
    # ------------------------------------------------------------------

    def _loaded_concrete_fields(self):
        """Return the concrete ``Field`` objects whose value is in memory.

        ``get_deferred_fields()`` reports deferred fields by ``attname``, so
        the comparison is done on ``attname`` too. This avoids guessing
        foreign keys from an ``_id`` suffix, which misclassifies ordinary
        fields such as ``external_id``.
        """
        deferred_attnames = self.get_deferred_fields()
        return [
            field
            for field in self._meta.concrete_fields
            if field.attname not in deferred_attnames
        ]

    def _get_loaded_field_names(self):
        """Return the names of all concrete fields that are not deferred."""
        return [field.name for field in self._loaded_concrete_fields()]

    def _get_loaded_field_values(self):
        """Return ``{field.name: current value}`` for every loaded field.

        Values are read through ``field.attname``, so a foreign key yields
        its raw id and never triggers a query for the related object.
        """
        return {
            field.name: getattr(self, field.attname)
            for field in self._loaded_concrete_fields()
        }

    # ------------------------------------------------------------------
    # Snapshot management
    # ------------------------------------------------------------------

    @staticmethod
    def _copy_for_snapshot(value):
        """Return a value safe to keep in the snapshot.

        Built-in mutable containers are deep-copied; otherwise an in-place
        edit such as ``obj.data["k"] = 1`` would modify the snapshot as well
        and the change would never be detected.
        """
        if isinstance(value, (dict, list, set, bytearray)):
            return copy.deepcopy(value)
        return value

    def _snapshot_values(self):
        """Build a fresh snapshot of all loaded field values."""
        return {
            name: self._copy_for_snapshot(value)
            for name, value in self._get_loaded_field_values().items()
        }

    def _resync_original_values(self, field_names=None):
        """Bring the snapshot in line with the instance.

        With ``field_names=None`` the whole snapshot is rebuilt. Otherwise
        only the listed fields (matched by ``name`` or ``attname``) are
        refreshed, so pending changes on other fields stay tracked.
        """
        if field_names is None:
            self.__original_values = self._snapshot_values()
            return

        wanted = set(field_names)
        deferred_attnames = self.get_deferred_fields()
        for field in self._meta.concrete_fields:
            if field.name not in wanted and field.attname not in wanted:
                continue
            if field.attname in deferred_attnames:
                # Not in memory: there is nothing to compare against later.
                self.__original_values.pop(field.name, None)
            else:
                self.__original_values[field.name] = self._copy_for_snapshot(
                    getattr(self, field.attname)
                )

    # ------------------------------------------------------------------
    # Change detection
    # ------------------------------------------------------------------

    def _primary_key_changed(self):
        """Return True if any primary-key field differs from the snapshot."""
        original = self.__original_values
        return any(
            field.primary_key
            and getattr(self, field.attname) != original.get(field.name, NULL)
            for field in self._loaded_concrete_fields()
        )

    def _get_update_fields(self):
        """Return the names of loaded fields whose value has changed.

        Primary keys are never listed because Django rejects them in
        ``update_fields``; generated fields are skipped because they are not
        writable. When at least one field changed, unchanged ``auto_now``
        fields are appended so their timestamp keeps being refreshed.
        """
        original = self.__original_values
        changed = []
        auto_now = []
        for field in self._loaded_concrete_fields():
            if field.primary_key or getattr(field, "generated", False):
                continue
            # A field missing from the snapshot (e.g. a deferred field that
            # was assigned) compares unequal to NULL and counts as changed.
            if getattr(self, field.attname) != original.get(field.name, NULL):
                changed.append(field.name)
            elif getattr(field, "auto_now", False):
                auto_now.append(field.name)
        return changed + auto_now if changed else changed

    def _should_auto_restrict(self, args, kwargs):
        """Decide whether ``save()`` may compute ``update_fields`` itself.

        Falls back to Django's normal behaviour when the row is being
        created, when positional arguments or ``force_insert`` are used, when
        saving to a different database, or when the primary key was changed
        (which means "write a different row").
        """
        if self._state.adding or args or kwargs.get("force_insert"):
            return False
        using = kwargs.get("using")
        if using is not None and using != self._state.db:
            return False
        return not self._primary_key_changed()

    # ------------------------------------------------------------------
    # Overridden Django API
    # ------------------------------------------------------------------

    def refresh_from_db(self, *args, **kwargs):
        """Reload from the database and re-sync the affected snapshot entries.

        When ``fields`` is given (Django itself does this when a deferred
        attribute is first read), only those entries are re-synced so that
        unsaved changes to other fields are not forgotten.
        """
        fields = args[1] if len(args) > 1 else kwargs.get("fields")
        if fields is not None:
            # Materialise once so a one-shot iterable is still usable below.
            fields = list(fields)
            if len(args) > 1:
                args = (args[0], fields, *args[2:])
            else:
                kwargs["fields"] = fields
        super().refresh_from_db(*args, **kwargs)
        self._resync_original_values(fields)

    def save(self, *args, update_fields=None, **kwargs):
        """
        Automatically set `update_fields` (if not explicitly specified)
        based on which fields have been modified.

        An explicitly passed `update_fields` is always honoured. After the
        save, only the fields that were actually written are re-synced in
        the snapshot; a full save re-syncs everything.
        """
        if update_fields is None and self._should_auto_restrict(args, kwargs):
            update_fields = self._get_update_fields()

        if update_fields is None:
            super().save(*args, **kwargs)
            self._resync_original_values()
            return

        update_fields = list(update_fields)
        super().save(*args, update_fields=update_fields, **kwargs)
        self._resync_original_values(update_fields)