diff --git a/drf_writable_nested/mixins.py b/drf_writable_nested/mixins.py index f40dd48..ebec84f 100644 --- a/drf_writable_nested/mixins.py +++ b/drf_writable_nested/mixins.py @@ -4,8 +4,9 @@ from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist -from django.db.models import ProtectedError, SET_NULL, SET_DEFAULT +from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist +from django.db.models import (SET_DEFAULT, SET_NULL, OneToOneField, + ProtectedError) from django.db.models.fields.related import ForeignObjectRel, ManyToManyRel from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -13,6 +14,29 @@ from rest_framework.validators import UniqueValidator +class NestedOnlySerializerMixin(serializers.ModelSerializer): + """ + Required for all serializers that are nested under BaseNestedModelSerializer. + """ + + def save(self, **kwargs): + self._save_kwargs = defaultdict(dict, kwargs) + validated_data = {**self.validated_data, **kwargs} + + if self.instance is not None: + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) + else: + self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) + + return self.instance + + class BaseNestedModelSerializer(serializers.ModelSerializer): def _extract_relations(self, validated_data): reverse_relations = OrderedDict() @@ -121,19 +145,22 @@ def _extract_related_pks(self, field, related_data): return pk_list - def _prefetch_related_instances(self, field, related_data): - model_class = field.Meta.model + def _prefetch_related_instances(self, field, related_data, field_name, instance): pk_list = self._extract_related_pks(field, related_data) + try: + related_manager = getattr(instance, field_name) + except ObjectDoesNotExist: + return {} + instances = { str(related_instance.pk): related_instance - for related_instance in model_class.objects.filter( - pk__in=pk_list - ) + for related_instance in related_manager.filter(pk__in=pk_list) } return instances + def update_or_create_reverse_relations(self, instance, reverse_relations): # Update or create reverse relations: # many-to-one, many-to-many, reversed one-to-one @@ -147,6 +174,8 @@ def update_or_create_reverse_relations(self, instance, reverse_relations): if related_data is None: continue + related_validated_data = self._validated_data[field_source] + if related_field.one_to_one: # If an object already exists, fill in the pk so # we don't try to duplicate it @@ -160,8 +189,14 @@ def update_or_create_reverse_relations(self, instance, reverse_relations): # Expand to array of one item for one-to-one for uniformity related_data = [related_data] + related_validated_data = [related_validated_data] - instances = self._prefetch_related_instances(field, related_data) + instances = self._prefetch_related_instances( + field, + related_data, + field_name, + instance + ) save_kwargs = self._get_save_kwargs(field_name) if isinstance(related_field, GenericRelation): @@ -173,7 +208,7 @@ def update_or_create_reverse_relations(self, instance, reverse_relations): new_related_instances = [] errors = [] - for data in related_data: + for index, data in enumerate(related_data): obj = instances.get( self._get_related_pk(data, field.Meta.model) ) @@ -183,7 +218,8 @@ def update_or_create_reverse_relations(self, instance, reverse_relations): data=data, ) try: - serializer.is_valid(raise_exception=True) + serializer._errors = {} + serializer._validated_data = related_validated_data[index] related_instance = serializer.save(**save_kwargs) data['pk'] = related_instance.pk new_related_instances.append(related_instance) @@ -208,10 +244,21 @@ def update_or_create_direct_relations(self, attrs, relations): data = self.get_initial()[field_name] model_class = field.Meta.model pk = self._get_related_pk(data, model_class) - if pk: + # pk needs to be specified if it's not one to one or creation of new object is not intended + + is_one_to_one = isinstance(self.Meta.model._meta.get_field(field_source), OneToOneField) + + if pk and not is_one_to_one: + # for direct ForeignKey + # potential filtering should be done in the child serializer + # as it is too project-specific obj = model_class.objects.filter( pk=pk, ).first() + else: + # for direct OneToOne or current ForeignKey + obj = getattr(self.instance, field_source, None) + serializer = self._get_serializer_for_field( field, instance=obj, @@ -219,7 +266,9 @@ def update_or_create_direct_relations(self, attrs, relations): ) try: - serializer.is_valid(raise_exception=True) + + serializer._errors = {} + serializer._validated_data = self._validated_data[field_source] attrs[field_source] = serializer.save( **self._get_save_kwargs(field_name) ) @@ -310,6 +359,10 @@ def perform_nested_delete_or_update(self, pks_to_delete, model_class, instance, qs.delete() def delete_reverse_relations_if_need(self, instance, reverse_relations): + if self.partial: + # bypass deletion if set to partial update + return + # Reverse `reverse_relations` for correct delete priority reverse_relations = OrderedDict( reversed(list(reverse_relations.items())))