diff --git a/README.md b/README.md index c0d6ee4..041959f 100644 --- a/README.md +++ b/README.md @@ -228,6 +228,58 @@ py.test ``` +New-Style Serializers +================ + +In 2021, an enhanced set of mixins were added that permit fine-grained control of nested +Serializer behavior using a `match_on` argument. New-style serializers delegate control +of the Create/Update behavior to the nested Serializer. The parent Serializer need only +resolve nested serializers in the right order; this is handled by the `RelatedSaveMixin`. + +New-style Serializers provide the following semantics: + + - Get: retrieve a matching object (but DO NOT update) + - Update: retrieve and update a matching object + - Create: create an object using the entire payload + - Combinations of the above e.g. GetOrCreate and UpdateOrCreate + +The matching of `data` to a specific `instance` is driven by a list of fields found in +`match_on`. This value is obtained from: + + - the `match_on` kwarg provided when the field is initialized + - the DEFAULT_MATCH_ON class attribute + +The new-style Serializers may be used as top-level Serializers to provide get-or-create +behaviors to DRF endpoints. Examples of use can be found in +`test_nested_serializer_mixins.py`. + +Migration +--------- + +To convert an existing serializer to the new style serializers, the following procedure +is recommended: + +1. Convert nested serializers by replacing `serializers.ModelSerializer` with +`UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer` which preserves +backwards-compatible behavior. +1. Convert parent serializer by replacing `WritableNestedModelSerializer` with +`RelatedSaveMixin, serializers.ModelSerializer`. +1. Verify that your test cases still pass. +1. Modify serializers (and test cases) to new-style behavior. For example, add an +explicit `match_on` or switch the mixin to an alternative behavior like +`GetOrCreateNestedSerializerMixin`. + +All test cases were duplicated for new-style serializers so you can see examples of +converted serializers in `tests/serializers.py`. For example `TeamSerializer` and +`UserSerializer` become `NewTeamSerializer` and `NewUserSerializer`. Examples of +`DEFAULT_MATCH_ON` can be found in `tests/serializers.py`. One example of an explicit +specified `match_on` is present, but non-default `match_on` values are not found in +`tests` because they were not required to produce existing behaviors. + +NOTE: While `RelatedSaveMixin` is the backwards-compatible mixin for the top-level +class, it is also possible to use other mixins to get complex matching behavior without +modifying the view. + Known problems with solutions ============================= diff --git a/drf_writable_nested/mixins.py b/drf_writable_nested/mixins.py index d3dc4c4..49be6c5 100644 --- a/drf_writable_nested/mixins.py +++ b/drf_writable_nested/mixins.py @@ -1,16 +1,24 @@ # -*- coding: utf-8 -*- +import logging from collections import OrderedDict, defaultdict from typing import List, Tuple from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist -from django.db.models import ProtectedError -from django.db.models.fields.related import ForeignObjectRel +from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist +from django.db import transaction, router, IntegrityError +from django.db.models import OneToOneRel, ProtectedError +from django.db.models.fields.related import ForeignObjectRel, ManyToManyField from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from rest_framework.exceptions import ValidationError -from rest_framework.validators import UniqueValidator +from rest_framework.fields import empty +from rest_framework.relations import ManyRelatedField +from rest_framework.serializers import BaseSerializer, ListSerializer +from rest_framework.validators import UniqueValidator, UniqueTogetherValidator + +# permit writable nested serializers +serializers.raise_errors_on_nested_writes = lambda a, b, c: None class BaseNestedModelSerializer(serializers.ModelSerializer): @@ -427,3 +435,519 @@ def create(self, validated_data): def update(self, instance, validated_data): self._validate_unique_fields(validated_data) return super(UniqueFieldsMixin, self).update(instance, validated_data) + + +class FieldLookupMixin(serializers.Serializer): + + def __init__(self, *args, **kwargs): + self.logger = logging.getLogger(self.__class__.__name__) + super(FieldLookupMixin, self).__init__(*args, **kwargs) + + def _get_model_field(self, source): + """Returns the field on the model""" + # for serializers like ModelSerializer, the Meta.model can be used to classify fields + if not hasattr(self, 'Meta') or not hasattr(self.Meta, 'model'): + return None + try: + return self.Meta.model._meta.get_field(source) + except FieldDoesNotExist: + pass + try: + # If `related_name` is not set, field name does not include + # `_set` -> remove it and check again + default_postfix = '_set' + if source.endswith(default_postfix): + return self.Meta.model._meta.get_field(source[:-len(default_postfix)]) + except FieldDoesNotExist: + pass + return None + + TYPE_READ_ONLY = 'read-only' + TYPE_LOCAL = 'local' + TYPE_DIRECT = 'direct' + TYPE_REVERSE = 'reverse' + + _cache_field_types = None + _cache_field_sources = None + + @property + def field_sources(self): + if self._cache_field_sources is None: + self._populate_field_types() + return self._cache_field_sources + + @property + def field_types(self): + if self._cache_field_types is None: + self._populate_field_types() + return self._cache_field_types + + def _populate_field_types(self): + self._cache_field_types = {} + self._cache_field_sources = {} + for field_name, field in self.fields.items(): + if isinstance(self._get_model_field(field.source), GenericRelation): + raise TypeError("GenericRelation not currently supported") + if field.read_only: + self._cache_field_types[field_name] = self.TYPE_READ_ONLY + self._cache_field_sources[field.source] = self.TYPE_READ_ONLY + continue + if not isinstance(field, BaseSerializer): + self._cache_field_types[field_name] = self.TYPE_LOCAL + self._cache_field_sources[field.source] = self.TYPE_LOCAL + continue + if field.source == '*': + self._cache_field_types[field_name] = self.TYPE_DIRECT + continue + model_field = self._get_model_field(field.source) + if isinstance(model_field, OneToOneRel): + self._cache_field_types[field_name] = self.TYPE_REVERSE + self._cache_field_sources[field.source] = self.TYPE_REVERSE + # TODO continue? + if isinstance(model_field, ForeignObjectRel): + self._cache_field_types[field_name] = self.TYPE_REVERSE + self._cache_field_sources[field.source] = self.TYPE_REVERSE + continue + self._cache_field_types[field_name] = self.TYPE_DIRECT + self._cache_field_sources[field.source] = self.TYPE_DIRECT + + +class RelatedSaveMixin(FieldLookupMixin): + """ + RelatedSaveMixin handes the saving of nested fields, both direct and reverse relations: + - Direct relations needs to be saved first + - The focal object can then be saved (which ensures the focal PK is available) + - Finally, reverse relations can be udpated with the object PK + """ + _is_saved = False + + def run_validation(self, data=empty): + """Cache nested call to `to_representation` on _validate_data for use when saving.""" + self.logger.debug("{} validating: {}".format(self.__class__.__name__, data)) + self._validated_data = super(RelatedSaveMixin, self).run_validation(data) + self.logger.debug("{} validated: {}".format(self.__class__.__name__, self._validated_data)) + self._errors = {} + return self._validated_data + + def to_internal_value(self, data): + """Injects the PK of this field into reverse relations so they validate when created in to_internal_value.""" + self._make_reverse_relations_valid() + return super(RelatedSaveMixin, self).to_internal_value(data) + + def _make_reverse_relations_valid(self): + """Make the reverse ForeignKey field optional since we may not have a key for the base object yet.""" + for field_name, field in self.fields.items(): + if self.field_types[field_name] != self.TYPE_REVERSE: + continue + # we know this is a reverse so reverse_field.field is valid + related_field = self._get_model_field(field.source).field + if isinstance(field, serializers.ListSerializer): + field = field.child + # find the serializer field matching the reverse model relation + for sub_field in field.fields.values(): + if sub_field.source == related_field.name: + sub_field.required = False + # found the matching field, move on + break + + @property + def validated_data(self): + """If mixed into a standard Serializer, prevents `save` from accessing reverse relations""" + return {k: v for k, v in super(RelatedSaveMixin, self).validated_data.items() + if k not in self.field_sources or self.field_sources[k] != self.TYPE_REVERSE} + + def save(self, **kwargs): + """Convert validated data into related objects and save.""" + # Create or update direct relations (foreign key, one-to-one) + self.logger.debug("RelatedSaveMixin.save for {} with data, kwargs: {}, {}".format(self.__class__.__name__, self._validated_data, kwargs)) + self._save_direct_relations(kwargs=kwargs) + instance = super(RelatedSaveMixin, self).save(**kwargs) + if instance is None: # possibly with GetOnly or UpdateOnly + # cannot create reverse relations with no object + return None + self._save_reverse_relations(instance=instance, kwargs=kwargs) + return instance + + def _save_direct_relations(self, kwargs): + """Save direct relations so FKs exist when committing the base instance""" + if self._validated_data is None and kwargs == {}: + return # delete-only + for field_name, field in self.fields.items(): + if self.field_types[field_name] != self.TYPE_DIRECT: + continue + self.logger.debug("{} direct save {} with data, kwargs; {}, {}".format(self.__class__.__name__, field_name, self._validated_data.get(field.source, empty), kwargs.get(field_name, {}))) + if self._validated_data.get(field.source, empty) == empty and kwargs.get(field_name, empty) == empty: + continue # nothing to save + #if self._validated_data.get(field_name) is None or kwargs.get(field_name) is None: + # continue # delete existing objects + # we need to pop from kwargs so the value doesn't "overwrite" the value generated by save + self._validated_data[field.source] = field.save(**kwargs.pop(field_name, {})) + self.logger.debug("{}._validated_data[{}] set to direct {}".format(self.__class__.__name__, field_name, self._validated_data[field.source])) + + def _format_generic_lookup(self, instance, related_field): + return { + related_field.content_type_field_name: ContentType.objects.get_for_model(instance), + related_field.object_id_field_name: instance.pk, + } + + def _save_reverse_relations(self, instance, kwargs): + """Inject the current object as the FK in the reverse related objects and save them""" + for field_name, field in self.fields.items(): + if self.field_types[field_name] != self.TYPE_REVERSE: + continue + if self._validated_data is None and kwargs == {}: + return # delete-only + if self._validated_data.get(field.source, empty) == empty and kwargs.get(field_name, empty) == empty: + continue # nothing to save + model_field = self._get_model_field(field.source) + self.logger.debug("{} populating reverse field {}".format(self.__class__.__name__, model_field.field.name)) + if isinstance(field, serializers.ListSerializer): + # reverse FK, inject the instance into reverse relations so the _id FK field is valid when saved + for obj in field._validated_data: + obj[model_field.field.name] = instance + elif isinstance(field, serializers.ModelSerializer): + # 1:1 + if self._validated_data[field.source] is None: + # indicates that we should delete 1:1 relation (if it exists) + try: + getattr(instance, field.source).delete() + continue + except ObjectDoesNotExist: + pass + else: + field._validated_data[model_field.field.name] = instance + else: + raise Exception("unexpected serializer type") + # create/update (as appropriate) related objects + self._validated_data[field.source] = field.save(**kwargs.get(field_name, {})) + self.logger.debug("{}._validated_data[{}] to reverse {}".format(self.__class__.__name__, field_name, self._validated_data[field.source])) + + # eliminate related objects that weren't in the request + if isinstance(field, ListSerializer): + # due to a bug in Django, calling `set` on a non-nullable reverse relation will only `add` + if model_field.field.null: + getattr(instance, field.source).set(self._validated_data[field.source]) + else: + # models should be attached when saved so we only need to delete + obj_field = getattr(instance, field.source) + db = router.db_for_write(obj_field.model, instance=instance) + old_objs = set(obj_field.using(db).all()) + for obj in old_objs: + if obj not in self._validated_data[field.source]: + obj.delete() + + +class FocalSaveMixin(FieldLookupMixin): + """Provides a framework for extracting the values needed to get or create the focal object.""" + + default_error_messages = { + 'incorrect_type': _('Nested field received an incorrect data type ({data_type}): {exception_message}'), + } + + def to_internal_value(self, data): + """Injects the PK of this field into reverse relations so they validate when created in to_internal_value.""" + # to_internal_value only preserves writable fields and match_on may need read-only like PK + self._validated_data = super(FocalSaveMixin, self).to_internal_value(data) + # patch read-only fields for match_on + for field_name, field in self.fields.items(): + if not field.read_only: + continue + if field_name not in data: + continue + try: + self._validated_data[field.source] = field.to_internal_value(data[field_name]) + except NotImplementedError: + pass # if `to_internal_value` isn't provided, we won't be able to match on it + return self._validated_data + + def build_match_on(self, kwargs): + match_on = {} + for field_name, field in self.fields.items(): + if self.match_on == '__all__' or field_name in self.match_on: + # build match_on dict + if hasattr(field, 'build_match_on'): + # create matching criteria + related_match_on = field.build_match_on(kwargs) + # apply match to nested field + match_on.update({"{}__{}".format(field.source, k): v for k, v in related_match_on.items()}) + else: + match_on[field.source] = kwargs.get(field_name, self._validated_data.get(field.source)) + # a parent serializer may inject a value that isn't among the fields, but is in `match_on` + for key in self.match_on: + if key not in self.fields.keys(): + match_on[key] = kwargs.get(key, None) + return match_on + + def build_direct_values(self, kwargs): + values = {} + for field_name, field in self.fields.items(): + if isinstance(field, ManyRelatedField) or isinstance(self._get_model_field(field.source), ManyToManyField): + continue # m2m fields + elif field.source == '*': + continue + elif self.field_types[field_name] == self.TYPE_LOCAL: + # need to check kwargs dict since there's no pre-processing + values[field.source] = kwargs.get(field_name, self._validated_data.get(field.source)) + elif self.field_types[field_name] == self.TYPE_DIRECT: + # kwargs should have been injected into _validated_data when direct relations were saved + values[field.source] = self._validated_data.get(field.source) + # reverse relations aren't sent to a create + return values + + def match(self, kwargs): + self.logger.debug("FocalSaveMixin.match with no super and kwargs {}".format(kwargs)) + return self.instance, False + + @transaction.atomic + def save(self, **kwargs): + self.logger.debug("FocalSaveMixin.save for {} with data, kwargs {}".format(self.__class__.__name__, self._validated_data, kwargs)) + if self._validated_data is None and kwargs == {}: + return None # deleted + match, needs_saved = self.match(kwargs) + self.logger.debug("Match: {}".format(match)) + needs_saved = self.do_update(match, kwargs) or needs_saved + try: + if needs_saved: + match.save() + except (TypeError, ValueError) as e: + self.fail('incorrect_type', data_type=type(self._validated_data).__name__, exception_message=e.args) + self.do_m2m_update(match, kwargs) + return match + + def do_update(self, match, create_values): + """Update the match (if appropriate) and returns a boolean indicating whether or not a save is required.""" + return False + + def do_m2m_update(self, match, kwargs): + return # no update + + +class NestedSaveListSerializer(serializers.ListSerializer): + """Need a special save() method that cascades to the list of child instances""" + + def __init__(self, *args, **kwargs): + self.logger = logging.getLogger(self.__class__.__name__) + super(NestedSaveListSerializer, self).__init__(*args, **kwargs) + + def save(self, **kwargs): + """ + Save and return a list of object instances. + """ + # Guard against incorrect use of `serializer.save(commit=False)` + assert 'commit' not in kwargs, ( + "'commit' is not a valid keyword argument to the 'save()' method. " + "If you need to access data before committing to the database then " + "inspect 'serializer.validated_data' instead. " + "You can also pass additional keyword arguments to 'save()' if you " + "need to set extra attributes on the saved model instance. " + "For example: 'serializer.save(owner=request.user)'.'" + ) + self.logger.debug("List-{} save with data, kwargs: {}, {}".format(self.child.__class__.__name__, self._validated_data, kwargs)) + + new_values = [] + + for item in self._validated_data: + # integrate save kwargs + self.child._validated_data = item + new_value = self.child.save(**kwargs) + # None indicates a deleted item + if new_value is not None: + new_values.append(new_value) + + self.logger.debug("List-{} saved: {}".format(self.child.__class__.__name__, new_values)) + return new_values + + def run_validation(self, data=empty): + """Since a nested serializer is treated like a Field, `is_valid` will not be called so we need to set + _validated_data in the mixin.""" + self.logger.debug("List-{} validating: {}".format(self.child.__class__.__name__, data)) + self._validated_data = super(NestedSaveListSerializer, self).run_validation(data) + self.logger.debug("List-{} validated: {}".format(self.child.__class__.__name__, self._validated_data)) + return self._validated_data + + def build_match_on(self, kwargs): + matches = [] + for item in self._validated_data: + self.child._validated_data = item + match, needs_save = self.child.match(kwargs) + if not needs_save: # we retrieved an instance + matches.append(match.pk) + return { + # TODO: make flexible e.g. to support match all + 'in': matches + } + + +class NestedSaveSerializer(RelatedSaveMixin, FocalSaveMixin): + """Provides a general framework for nested serializers including argument validation.""" + + default_list_serializer = NestedSaveListSerializer + DEFAULT_MATCH_ON = ['pk'] + queryset = None + + @classmethod + def many_init(cls, *args, **kwargs): + # inject the default into list_serializer_class (if not present) + meta = getattr(cls, 'Meta', None) + if meta is None: + class Meta: + pass + meta = Meta + setattr(cls, 'Meta', meta) + list_serializer_class = getattr(meta, 'list_serializer_class', None) + if list_serializer_class is None: + setattr(meta, 'list_serializer_class', cls.default_list_serializer) + assert issubclass(meta.list_serializer_class, NestedSaveListSerializer), \ + "NestedSaveMixin expects a NestedSaveListSerializer for correct save behavior. Please override " \ + "default_list_serializer or Meta.list_serializer_class and provide an appropriate class." + return super(NestedSaveSerializer, cls).many_init(*args, **kwargs) + + def __init__(self, *args, **kwargs): + if kwargs.get('partial', False) is True: + raise ValueError("Partial Updates not currently supported by NestedSaveSerializer") + self.queryset = kwargs.pop('queryset', self.queryset) + if self.queryset is None and hasattr(self, 'Meta') and hasattr(self.Meta, 'model'): + self.queryset = self.Meta.model.objects.all() + assert self.queryset is not None, \ + "NestedSerializerBase requires a Meta.model, a `queryset` on the Serializer, or a `queryset` kwarg" + self.match_on = kwargs.pop('match_on', self.DEFAULT_MATCH_ON) + assert self.match_on == '__all__' or isinstance(self.match_on, (tuple, list, set)), \ + "match_on only accepts as Collection of strings or the special value __all__" + if isinstance(self.match_on, (tuple, list, set)): + for match in self.match_on: + assert isinstance(match, str), "match_on collection can only contain strings" + super(NestedSaveSerializer, self).__init__(*args, **kwargs) + + def run_validation(self, data=empty): + """A nested serializer is treated like a Field so `is_valid` will not be called and `_validated_data` not set.""" + self.logger.debug("{} validating: {}".format(self.__class__.__name__, data)) + # ensure Unique and UniqueTogether don't collide with a DB match + validators = self.remove_validation_unique() + self._validated_data = super(NestedSaveSerializer, self).run_validation(data) + self.logger.debug("{} validated: {}".format(self.__class__.__name__, self._validated_data)) + # restore Unique or UniqueTogether + self.restore_validation_unique(validators) + return self._validated_data + + def remove_validation_unique(self): + """ + Removes unique validators from a serializers. This is critical for get-or-create style serialization. It can also + be used to distinguish 409 errors from client-side validation errors. + """ + fields = {} + # extract unique validators + for field_name, field in self.fields.items(): + fields[field_name] = [] + if not hasattr(field, 'validators'): + continue + for validator in field.validators: + if isinstance(validator, UniqueValidator): + fields[field_name].append(validator) + for validator in fields[field_name]: + field.validators.remove(validator) + # extract unique_together validators + fields['_'] = [] + for validator in self.validators: + if isinstance(validator, UniqueTogetherValidator): + fields['_'].append(validator) + for validator in fields['_']: + self.validators.remove(validator) + return fields + + def restore_validation_unique(self, unique_validators): + together_validators = unique_validators.pop('_') + for serializer in together_validators: + self.validators.append(serializer) + fields = self.fields + for name, validators in unique_validators.items(): + for validator in validators: + fields[name].validators.append(validator) + + def update(self, instance, validated_data): + raise KeyError( + "Update should never be called on a NestedSerializerBase. Make sure parent object uses NestedSaveMixin") + + def create(self, validated_data): + raise KeyError( + "Update should never be called on a NestedSerializerBase. Make sure parent object uses NestedSaveMixin") + + +class UpdateDoSaveMixin(NestedSaveSerializer): + """Adds behavior to update the focal object, including M2M relations""" + + def do_update(self, match, kwargs): + update_values = self.build_direct_values(kwargs) + for k, v in update_values.items(): + setattr(match, k, v) + return True + + def do_m2m_update(self, match, kwargs): + # assign relations to forward many-to-many fields + for field_name, field in self.fields.items(): + # we can't provide m2m values as kwargs; must use set() instead + # we don't care whether it's forward or reverse + # if we provide a custom serializer, it may not inherit from ManyRelatedField + if isinstance(field, ManyRelatedField) or isinstance(self._get_model_field(field.source), ManyToManyField): + value = kwargs.get(field_name, self._validated_data.get(field.source, empty)) + if value is empty: + continue # no information + if value is None: + value = [] # explicitly clear + self.logger.debug("{}: m2m set to {}, {}".format(self.__class__.__name__, field.source, value)) + getattr(match, field.source).set(value) + + +class GetOnlyNestedSerializerMixin(NestedSaveSerializer): + """Gets (without updating) requetsed object or fails.""" + + def match(self, kwargs): + match, needs_saved = super(GetOnlyNestedSerializerMixin, self).match(kwargs) + self.logger.debug("GetOnlyNestedSerializerMixin.match with super {} and kwargs {}".format(match, kwargs)) + if match is not None: + return match, needs_saved + try: + match_on = self.build_match_on(kwargs) + self.logger.debug("Matching on: {}".format(match_on)) + # if we don't filter().distint() we can get multiple copies of the same item and get() fails + # we can't distinct() and select_for_update() in the same query so we must use a subquery + return self.queryset.filter(pk__in=self.queryset.filter(**match_on).distinct()).select_for_update().get(), False + except self.queryset.model.DoesNotExist: + return None, False + + +class UpdateOnlyNestedSerializerMixin(UpdateDoSaveMixin, GetOnlyNestedSerializerMixin): + """Gets requested object (or fails) and updates object.""" + + +class GetOrCreateNestedSerializerMixin(GetOnlyNestedSerializerMixin): + """Gets (without updating) or creates requested object.""" + + def match(self, kwargs): + match, needs_saved = super(GetOrCreateNestedSerializerMixin, self).match(kwargs) + self.logger.debug("GetOrCreateNestedSerializerMixin.match with super {} and kwargs {}".format(match, kwargs)) + if match is not None: + return match, needs_saved + create_values = self.build_direct_values(kwargs) + return self.queryset.model(**create_values), True + + +class UpdateOrCreateNestedSerializerMixin(UpdateDoSaveMixin, GetOrCreateNestedSerializerMixin): + """Gets (without updating) or creates requested object.""" + + +class CreateOnlyNestedSerializerMixin(GetOnlyNestedSerializerMixin): + """Creates requested object or fails.""" + + def match(self, kwargs): + match, needs_saved = super(CreateOnlyNestedSerializerMixin, self).match(kwargs) + self.logger.debug("CreateOnlyNestedSerializerMixin.match with super {} and kwargs {}".format(match, kwargs)) + if match is not None: + raise IntegrityError("Matching {} object already exists".format(match.__class__.__name__)) + create_values = self.build_direct_values(kwargs) + return self.queryset.model(**create_values), True + + def do_m2m_update(self, match, m2m_relations): + # assign relations to forward many-to-many fields + for k, v in m2m_relations.items(): + self.logger.debug("{}: m2m add to {}, {}".format(self.__class__.__name__, k, v)) + getattr(match, k).add(v) diff --git a/tests/migrations/0002_auto_20210201_1452.py b/tests/migrations/0002_auto_20210201_1452.py new file mode 100644 index 0000000..6950140 --- /dev/null +++ b/tests/migrations/0002_auto_20210201_1452.py @@ -0,0 +1,190 @@ +# Generated by Django 2.1.3 on 2021-02-01 14:52 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('tests', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Child', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='ContextChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ('owner', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + ), + migrations.CreateModel( + name='GrandParent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ], + ), + migrations.CreateModel( + name='LookupChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='LookupGrandParent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ], + ), + migrations.CreateModel( + name='LookupOneToOneChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='LookupParent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('child', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='parent', to='tests.LookupChild')), + ('child2', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='parent2', to='tests.LookupChild')), + ], + ), + migrations.CreateModel( + name='LookupReverseChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ('parent', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='children', to='tests.LookupParent')), + ], + ), + migrations.CreateModel( + name='M2MSource', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='M2MTarget', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='NewProfile', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('age', models.IntegerField()), + ], + ), + migrations.CreateModel( + name='NewUser', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('username', models.TextField()), + ], + ), + migrations.CreateModel( + name='Parent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('child', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.Child')), + ], + ), + migrations.CreateModel( + name='ParentMany', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('children', models.ManyToManyField(to='tests.Child')), + ], + ), + migrations.CreateModel( + name='ReadOnlyChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='ReadOnlyParent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('child', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.ReadOnlyChild')), + ], + ), + migrations.CreateModel( + name='ReverseChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='ReverseManyChild', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField()), + ], + ), + migrations.CreateModel( + name='ReverseManyParent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ], + ), + migrations.CreateModel( + name='ReverseParent', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ], + ), + migrations.AddField( + model_name='reversemanychild', + name='parent', + field=models.ManyToManyField(related_name='children', to='tests.ReverseManyParent'), + ), + migrations.AddField( + model_name='reversechild', + name='parent', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='children', to='tests.ReverseParent'), + ), + migrations.AddField( + model_name='newprofile', + name='user', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='profile', to='tests.NewUser'), + ), + migrations.AddField( + model_name='m2msource', + name='forward', + field=models.ManyToManyField(related_name='reverse', to='tests.M2MTarget'), + ), + migrations.AddField( + model_name='lookuponetoonechild', + name='parent', + field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='one_to_one', to='tests.LookupParent'), + ), + migrations.AddField( + model_name='lookupgrandparent', + name='child', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.LookupParent'), + ), + migrations.AddField( + model_name='grandparent', + name='child', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.Parent'), + ), + ] diff --git a/tests/models.py b/tests/models.py index 4455d97..ccf501f 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,4 +1,6 @@ import uuid + +from django.conf import settings from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.models import ContentType from django.db import models @@ -147,3 +149,90 @@ class I86Name(models.Model): class I86Genre(models.Model): pass + +class ReadOnlyChild(models.Model): + name = models.TextField() + + +class ReadOnlyParent(models.Model): + child = models.ForeignKey(ReadOnlyChild, on_delete=models.CASCADE) + + +class Child(models.Model): + name = models.TextField() + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) + + +class ParentMany(models.Model): + children = models.ManyToManyField(Child) + + +class ReverseParent(models.Model): + pass + + +class ReverseChild(models.Model): + name = models.TextField() + parent = models.ForeignKey(ReverseParent, on_delete=models.CASCADE, related_name='children') + + +class ReverseManyParent(models.Model): + pass + + +class ReverseManyChild(models.Model): + name = models.TextField() + parent = models.ManyToManyField(ReverseManyParent, related_name='children') + + +class LookupChild(models.Model): + name = models.TextField() + + +class LookupParent(models.Model): + child = models.ForeignKey(LookupChild, on_delete=models.CASCADE, related_name='parent') + child2 = models.ForeignKey(LookupChild, on_delete=models.CASCADE, related_name='parent2') + + +class LookupReverseChild(models.Model): + name = models.TextField() + parent = models.ForeignKey(LookupParent, on_delete=models.CASCADE, related_name='children') + + +class LookupOneToOneChild(models.Model): + name = models.TextField() + parent = models.OneToOneField(LookupParent, on_delete=models.CASCADE, related_name='one_to_one') + + +class LookupGrandParent(models.Model): + child = models.ForeignKey(LookupParent, on_delete=models.CASCADE) + + +class M2MTarget(models.Model): + name = models.TextField() + + +class M2MSource(models.Model): + forward = models.ManyToManyField(M2MTarget, related_name='reverse') + name = models.TextField() + + +class GrandParent(models.Model): + child = models.ForeignKey(Parent, on_delete=models.CASCADE) + + +class ContextChild(models.Model): + name = models.TextField() + owner = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) + + +class NewUser (models.Model): + username = models.TextField() + + +class NewProfile(models.Model): + user = models.OneToOneField(NewUser, on_delete=models.CASCADE, related_name='profile') + age = models.IntegerField() \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index b96a9f1..d2fcad8 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -4,7 +4,7 @@ from rest_framework.validators import UniqueValidator from drf_writable_nested.serializers import WritableNestedModelSerializer -from drf_writable_nested.mixins import UniqueFieldsMixin +from drf_writable_nested.mixins import UniqueFieldsMixin, RelatedSaveMixin, UpdateOrCreateNestedSerializerMixin from . import models @@ -338,3 +338,387 @@ class I86GenreSerializer(WritableNestedModelSerializer): class Meta: model = models.I86Genre fields = ('id', 'names',) + + +############ +# NEW STYLE +############ + + +class NewAvatarSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + image = serializers.CharField() + + class Meta: + model = models.Avatar + fields = ('pk', 'image', + # new-style serializers must include parent FK field + 'profile', + ) + + +class NewMessageSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + + class Meta: + model = models.Message + fields = ('pk', 'message', + # new-style serializers must include parent FK field + 'profile' + ) + + +class NewSiteSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + url = serializers.CharField() + + class Meta: + model = models.Site + fields = ('pk', 'url',) + + +class NewAccessKeySerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + + class Meta: + model = models.AccessKey + fields = ('pk', 'key',) + + +class NewProfileSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + # new-style serializers can prioritize one-to-one relations by explicitly listing the one-to-one relation as the match criteria + DEFAULT_MATCH_ON = ['user'] + + # Direct ManyToMany relation + sites = NewSiteSerializer(many=True) + + # Reverse FK relation + avatars = NewAvatarSerializer(many=True) + + # Direct FK relation + access_key = NewAccessKeySerializer(allow_null=True) + + # Reverse FK relation with UUID + message_set = NewMessageSerializer(many=True) + + class Meta: + model = models.Profile + fields = ('pk', 'sites', 'avatars', 'access_key', 'message_set', + # new-style serializers must include parent FK field + 'user', + ) + + +class NewBaseProfileSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + # this top-level serializer expects to find the Profile by pk + DEFAULT_MATCH_ON = ['pk'] # actually the default + + # Direct ManyToMany relation + sites = NewSiteSerializer(many=True) + + # Reverse FK relation + avatars = NewAvatarSerializer(many=True) + + # Direct FK relation + access_key = NewAccessKeySerializer(allow_null=True) + + # Reverse FK relation with UUID + message_set = NewMessageSerializer(many=True) + + class Meta: + model = models.Profile + # because this is not a nested serializer (i.e. found by parent), we don't include `user` + fields = ('pk', 'sites', 'avatars', 'access_key', 'message_set',) + + +class NewUserSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + # Reverse OneToOne relation + profile = NewProfileSerializer(required=False, allow_null=True) + user_avatar = NewAvatarSerializer(required=False, allow_null=True) + + class Meta: + model = models.User + fields = ('pk', 'profile', 'username', 'user_avatar') + + +class NewCustomSerializer(UserSerializer): + # Simulate having non-modelfield information on the serializer + custom_field = serializers.CharField() + + class Meta(UserSerializer.Meta): + fields = NewUserSerializer.Meta.fields + ( + 'custom_field', + ) + + def validate(self, attrs): + attrs.pop('custom_field', None) + return attrs + + +class NewTagSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + + class Meta: + model = models.Tag + fields = ( + 'pk', + 'tag', + ) + + +class NewTaggedItemSerializer(RelatedSaveMixin, serializers.ModelSerializer): + tags = NewTagSerializer(many=True) + + class Meta: + model = models.TaggedItem + fields = ( + 'tags', + ) + + +class NewTeamSerializer(RelatedSaveMixin, serializers.ModelSerializer): + members = NewUserSerializer(many=True, required=False) + + class Meta: + model = models.Team + fields = ( + 'members', + 'name', + ) + + +class NewCustomPKSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + DEFAULT_MATCH_ON = ['slug'] + + class Meta: + model = models.CustomPK + fields = ( + 'slug', + # new-style serializers must include parent FK field + 'user', + ) + + +class NewUserWithCustomPKSerializer(RelatedSaveMixin, serializers.ModelSerializer): + custompks = NewCustomPKSerializer( + many=True, + ) + + class Meta: + model = models.User + fields = ( + 'custompks', + 'username', + ) + + +class NewAnotherAvatarSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + image = serializers.CharField() + + class Meta: + model = models.AnotherAvatar + fields = ('pk', 'image', + # new-style serializers must include parent FK field + 'profile',) + + +# test_update_another_user_with_explicit_source expect to match an existing AnotherProfile by PK (but we get to the same place by-user) +class NewAnotherProfileSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + # new-style serializers can prioritize one-to-one relations by explicitly listing the one-to-one relation as the match criteria + DEFAULT_MATCH_ON = ['user'] + + # Direct ManyToMany relation + another_sites = NewSiteSerializer(source='sites', many=True) + + # Reverse FK relation + another_avatars = NewAnotherAvatarSerializer(source='avatars', many=True) + + # Direct FK relation + another_access_key = NewAccessKeySerializer( + source='access_key', allow_null=True) + + class Meta: + model = models.AnotherProfile + fields = ('pk', 'another_sites', 'another_avatars', + 'another_access_key', + # new-style serializers must include parent FK field + 'user', + ) + + +# UpdateOrCreate because test_update_another_user_with_explicit_source expects serializer to find existing User by PK +class NewAnotherUserSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + # Reverse OneToOne relation + another_profile = NewAnotherProfileSerializer( + source='anotherprofile', required=False, allow_null=True) + + class Meta: + model = models.User + fields = ('pk', 'another_profile', 'username',) + + +class NewPageSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = models.Page + fields = ('pk', 'title') + + +class NewDocumentSerializer(RelatedSaveMixin, serializers.ModelSerializer): + page = NewPageSerializer() + + class Meta: + model = models.Document + fields = ('pk', 'page', 'source') + + +class NewUFMChildSerializer(UniqueFieldsMixin, serializers.ModelSerializer): + class Meta: + model = models.UFMChild + fields = ('pk', 'field') + + +class NewUFMParentSerializer(RelatedSaveMixin, serializers.ModelSerializer): + child = NewUFMChildSerializer() + + class Meta: + model = models.UFMParent + fields = ('pk', 'child') + + +# Different relations + + +class NewRaiseErrorMixin(RaiseErrorMixin, UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + raise_error = serializers.BooleanField(required=False, default=False) + + def save(self, **kwargs): + raise_error = self.validated_data.pop('raise_error', False) + if raise_error: + raise ValidationError({'raise_error': ['should be False']}) + + return super(RaiseErrorMixin, self).save(**kwargs) + + +class NewDirectForeignKeyChildSerializer(RaiseErrorMixin, + serializers.ModelSerializer): + class Meta: + model = models.ForeignKeyChild + fields = ('id', 'raise_error',) + + +class NewDirectForeignKeyParentSerializer(RelatedSaveMixin, serializers.ModelSerializer): + child = NewDirectForeignKeyChildSerializer() + + class Meta: + model = models.ForeignKeyParent + fields = ('id', 'child',) + + +class NewReverseForeignKeyParentSerializer(RaiseErrorMixin, + serializers.ModelSerializer): + class Meta: + model = models.ForeignKeyParent + fields = ('id', 'raise_error', + # new-style serializers must include parent FK field + 'child' + ) + + +class NewReverseForeignKeyChildSerializer(RelatedSaveMixin, serializers.ModelSerializer): + parents = NewReverseForeignKeyParentSerializer(many=True) + + class Meta: + model = models.ForeignKeyChild + fields = ('id', 'parents',) + + +class NewDirectOneToOneChildSerializer(RaiseErrorMixin, + serializers.ModelSerializer): + class Meta: + model = models.OneToOneChild + fields = ('id', 'raise_error',) + + +class NewDirectOneToOneParentSerializer(RelatedSaveMixin, serializers.ModelSerializer): + child = NewDirectOneToOneChildSerializer() + + class Meta: + model = models.OneToOneParent + fields = ('id', 'child',) + + +class NewReverseOneToOneParentSerializer(RaiseErrorMixin, + serializers.ModelSerializer): + class Meta: + model = models.OneToOneParent + fields = ('id', 'raise_error', + # new-style serializers must include parent FK field + 'child', + ) + + +class NewReverseOneToOneChildSerializer(RelatedSaveMixin, serializers.ModelSerializer): + parent = NewReverseOneToOneParentSerializer() + + class Meta: + model = models.OneToOneChild + fields = ('id', 'parent',) + + +class NewDirectManyToManyChildSerializer(RaiseErrorMixin, + serializers.ModelSerializer): + class Meta: + model = models.ManyToManyChild + fields = ('id', 'raise_error',) + + +class NewDirectManyToManyParentSerializer(RelatedSaveMixin, serializers.ModelSerializer): + children = NewDirectManyToManyChildSerializer(many=True) + + class Meta: + model = models.ManyToManyParent + fields = ('id', 'children',) + + +class NewReverseManyToManyParentSerializer(RaiseErrorMixin, + serializers.ModelSerializer): + class Meta: + model = models.ManyToManyParent + fields = ('id', 'raise_error',) + + +class NewReverseManyToManyChildSerializer(RelatedSaveMixin, serializers.ModelSerializer): + parents = NewReverseManyToManyParentSerializer(many=True) + + class Meta: + model = models.ManyToManyChild + fields = ('id', 'parents',) + + +class NewI86GenreNameSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = models.I86Name + fields = ('id', 'string', + # new-style serializers must include parent FK field + 'item', + ) + + +class NewI86GenreSerializer(RelatedSaveMixin, serializers.ModelSerializer): + names = NewI86GenreNameSerializer(many=True) + + class Meta: + model = models.I86Genre + fields = ('id', 'names',) + + +class NewReadOnlyChildSerializer(UpdateOrCreateNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = models.ReadOnlyChild + fields = ('id', 'name') + extra_kwargs = { + 'name': {'read_only': True} + } + + +class NewReadOnlyParentSerializer(RelatedSaveMixin, serializers.ModelSerializer): + child = NewReadOnlyChildSerializer(match_on=['id']) + + class Meta: + model = models.ReadOnlyParent + fields = ('id', 'child') diff --git a/tests/test_field_lookup.py b/tests/test_field_lookup.py new file mode 100644 index 0000000..dad73b0 --- /dev/null +++ b/tests/test_field_lookup.py @@ -0,0 +1,255 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + +from drf_writable_nested import mixins +from tests.models import LookupChild, LookupParent, LookupReverseChild, LookupOneToOneChild, LookupGrandParent, \ + M2MTarget, M2MSource + + +class ChildSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupChild + fields = '__all__' + + +class ReverseChildSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupReverseChild + fields = '__all__' + + +class OneToOneChildSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupOneToOneChild + fields = '__all__' + + +class ParentSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupParent + # otherwise child2 will get created by the ModelSerializer (and duplicate child_source) + exclude = ['child2'] + # source of a 1:many relationship + child = ChildSerializer() + child_source = ChildSerializer(source='child2') + children = ReverseChildSerializer(many=True) + one_to_one = OneToOneChildSerializer() + + +class NestedParentSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupParent + fields = '__all__' + + +class OneToOneForwardSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupOneToOneChild + fields = '__all__' + + parent = NestedParentSerializer() + + +class GrandParentSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = LookupGrandParent + fields = '__all__' + # source of a 1:many relationship + child = ParentSerializer() + + +class M2MForwardTargetSerializer(serializers.ModelSerializer): + class Meta: + model = M2MTarget + fields = '__all__' + + +class M2MForwardSourceSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = M2MSource + fields = '__all__' + forward = M2MForwardTargetSerializer(many=True) + + +class M2MReverseTargetSerializer(serializers.ModelSerializer): + class Meta: + model = M2MSource + fields = '__all__' + + +class M2MReverseSourceSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer): + class Meta: + model = M2MTarget + fields = '__all__' + reverse = M2MReverseTargetSerializer(many=True) + + +class GetModelFieldTest(TestCase): + """Field types are determined by accessing the model. These test confirm that the methods for identifying these + fields do not change.""" + + def test_fk(self): + """Confirm that the test works correctly for fields with a source value""" + serializer = ParentSerializer() + model_field = serializer._get_model_field(serializer.fields['child'].source) + self.assertIsInstance( + model_field, + models.ForeignKey, + "Found {}, expected ForeignKey".format(type(model_field)) + ) + + def test_fk_source(self): + """Confirm that the test works correctly for fields with a source value""" + serializer = ParentSerializer() + model_field = serializer._get_model_field(serializer.fields['child_source'].source) + self.assertIsInstance( + model_field, + models.ForeignKey, + "Found {}, expected ForeignKey".format(type(model_field)) + ) + + def test_reverse_fk(self): + """Confirm that a reverse ForeignKey is a ManyToOneRel""" + serializer = ParentSerializer() + model_field = serializer._get_model_field(serializer.fields['children'].source) + # opposite side of a ForeignKeyField is a ManyToOneRel + self.assertIsInstance( + model_field, + models.ManyToOneRel, + "Found {}, expected ManyToOneRel".format(type(model_field)) + ) + + def test_onetoone_reverse(self): + """A reverse OneToOne relationship is a OneToOneRel""" + serializer = ParentSerializer() + model_field = serializer._get_model_field(serializer.fields['one_to_one'].source) + # opposite side of a OneToOneField is a ManyToOne + self.assertIsInstance( + model_field, + models.OneToOneRel, + "Found {}, expected OneToOneRel".format(type(model_field)) + ) + + def test_onetoone_forward(self): + """A forward OneToOne relationship is a OneToOneField""" + serializer = OneToOneForwardSerializer() + model_field = serializer._get_model_field(serializer.fields['parent'].source) + # opposite side of a OneToOneField is a ManyToOne + self.assertIsInstance( + model_field, + models.OneToOneField, + "Found {}, expected OneToOneRel".format(type(model_field)) + ) + + def test_m2m_reverse(self): + """A reverse OneToOne relationship is a OneToOneRel""" + serializer = M2MReverseSourceSerializer() + model_field = serializer._get_model_field(serializer.fields['reverse'].source) + # opposite side of a OneToOneField is a ManyToOne + self.assertIsInstance( + model_field, + models.ManyToManyRel, + "Found {}, expected ManyToManyRel".format(type(model_field)) + ) + + def test_m2m_forward(self): + """A forward OneToOne relationship is a OneToOneField""" + serializer = M2MForwardSourceSerializer() + model_field = serializer._get_model_field(serializer.fields['forward'].source) + # opposite side of a OneToOneField is a ManyToOne + self.assertIsInstance( + model_field, + models.ManyToManyField, + "Found {}, expected ManyToManyField".format(type(model_field)) + ) + + +class FieldTypesTest(TestCase): + """Tests resolution of field types. ID is always read-only.""" + + def test_field_types_grandparent(self): + """Nested serializer should be direct""" + serializer = GrandParentSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'child': serializer.TYPE_DIRECT, + }, + serializer.field_types + ) + + def test_field_types_parent(self): + """Reverse one-to-one and reverse FK should be classified as Reverse""" + serializer = ParentSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'child': serializer.TYPE_DIRECT, + 'child_source': serializer.TYPE_DIRECT, + 'children': serializer.TYPE_REVERSE, + 'one_to_one': serializer.TYPE_REVERSE, + }, + serializer.field_types + ) + + def test_field_sources_parent(self): + """Reverse one-to-one and reverse FK should be classified as Reverse""" + serializer = ParentSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'child': serializer.TYPE_DIRECT, + 'child2': serializer.TYPE_DIRECT, + 'children': serializer.TYPE_REVERSE, + 'one_to_one': serializer.TYPE_REVERSE, + }, + serializer.field_sources + ) + + def test_field_types_child(self): + """""" + serializer = ChildSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'name': serializer.TYPE_LOCAL, + }, + serializer.field_types + ) + + def test_field_types_reversechild(self): + serializer = ReverseChildSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'name': serializer.TYPE_LOCAL, + # must have a nested serializer to be "direct" otherwise it's just a local value + 'parent': serializer.TYPE_LOCAL, + }, + serializer.field_types + ) + + def test_field_types_onetoone_reverse(self): + serializer = OneToOneChildSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'name': serializer.TYPE_LOCAL, + # must have a nested serializer to be "direct" otherwise it's just a local value + 'parent': serializer.TYPE_LOCAL, + }, + serializer.field_types + ) + + def test_field_types_onetoone_forward(self): + serializer = OneToOneForwardSerializer() + self.assertEqual( + { + 'id': serializer.TYPE_READ_ONLY, + 'name': serializer.TYPE_LOCAL, + # must have a nested serializer to be "direct" otherwise it's just a local value + 'parent': serializer.TYPE_DIRECT, + }, + serializer.field_types + ) diff --git a/tests/test_nested_serializer_mixins.py b/tests/test_nested_serializer_mixins.py new file mode 100644 index 0000000..13e7d8b --- /dev/null +++ b/tests/test_nested_serializer_mixins.py @@ -0,0 +1,734 @@ +from django.contrib.auth import get_user_model +from django.db import IntegrityError +from django.test import TestCase, RequestFactory +from rest_framework import serializers + +from drf_writable_nested import mixins +from tests.models import Child, Parent, ParentMany, ReverseParent, ReverseChild, ReverseManyParent, ReverseManyChild, \ + GrandParent, ContextChild, NewUser, NewProfile + + +######################### +# GetOrCreate Serializer +######################### +class ChildGetOrCreateSerializer(mixins.GetOrCreateNestedSerializerMixin, serializers.ModelSerializer): + DEFAULT_MATCH_ON = ['name'] + + class Meta: + model = Child + fields = '__all__' + + +class GenericParentRelatedSaveSerializer(mixins.RelatedSaveMixin): + class Meta: + fields = '__all__' + # source of a 1:many relationship + child = ChildGetOrCreateSerializer() + + def create(self, validated_data): + # "container only", no create logic + return validated_data + + +################## +# Direct Relation +################## +class ParentRelatedSaveSerializer(mixins.RelatedSaveMixin, serializers.ModelSerializer): + class Meta: + model = Parent + fields = '__all__' + # source of a 1:many relationship + child = ChildGetOrCreateSerializer() + + +class ParentManyRelatedSaveSerializer(mixins.RelatedSaveMixin, serializers.ModelSerializer): + class Meta: + model = ParentMany + fields = '__all__' + # source of a m2m relationship + children = ChildGetOrCreateSerializer(many=True) + + +################### +# Reverse Relation +################### +class ReverseChildGetOrCreateSerializer(mixins.GetOrCreateNestedSerializerMixin, serializers.ModelSerializer): + DEFAULT_MATCH_ON = ['name'] + + class Meta: + model = ReverseChild + fields = '__all__' + + +class ReverseManyChildGetOrCreateSerializer(mixins.GetOrCreateNestedSerializerMixin, serializers.ModelSerializer): + DEFAULT_MATCH_ON = ['name'] + + class Meta: + model = ReverseManyChild + fields = '__all__' + + +class ReverseParentRelatedSaveSerializer(mixins.RelatedSaveMixin, serializers.ModelSerializer): + class Meta: + model = ReverseParent + fields = '__all__' + # target of a 1:many relationship + children = ReverseChildGetOrCreateSerializer(many=True) + + +class ReverseParentGetOnlySerializer(mixins.GetOnlyNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = ReverseParent + fields = '__all__' + # target of a m2m relationship + children = ReverseChildGetOrCreateSerializer(many=True) + + +class ReverseManyParentRelatedSaveSerializer(mixins.RelatedSaveMixin, serializers.ModelSerializer): + class Meta: + model = ReverseManyParent + fields = '__all__' + # target of a m2m relationship + children = ReverseManyChildGetOrCreateSerializer(many=True) + + +class WritableNestedModelSerializerTest(TestCase): + + def test_generic_nested_create(self): + data = { + "child": { + "name": "test", + } + } + + serializer = GenericParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + instance = serializer.save() + self.assertIsInstance( + instance, + dict, + ) + self.assertIn( + 'child', + instance, + ) + self.assertIsInstance( + instance['child'], + Child, + ) + self.assertEqual( + 'test', + instance['child'].name, + ) + + def test_generic_nested_get(self): + """A second run with a GetOrCreate nested serializer should find same child object (by name)""" + data = { + "child": { + "name": "test", + } + } + + serializer = GenericParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + serializer = GenericParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 1, + Child.objects.count(), + ) + + def test_direct_nested_create(self): + data = { + "child": { + "name": "test", + } + } + + serializer = ParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + def test_direct_nested_get(self): + """A second run with a GetOrCreate nested serializer should find same child object (by name)""" + data = { + "child": { + "name": "test", + } + } + + serializer = ParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + serializer = ParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 2, + Parent.objects.count() + ) + + self.assertEqual( + 1, + Child.objects.count(), + ) + + def test_direct_many_nested_create(self): + data = { + "children": [{ + "name": "test", + }] + } + + serializer = ParentManyRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + def test_direct_many_nested_get(self): + """A second run with a GetOrCreate nested serializer should find same child object (by name)""" + data = { + "children": [{ + "name": "test", + }] + } + + serializer = ParentManyRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + serializer = ParentManyRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 1, + Child.objects.count(), + ) + + def test_reverse_nested_create(self): + data = { + "children": [{ + "name": "test", + }] + } + + serializer = ReverseParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + def test_reverse_nested_get(self): + """A second run with a GetOrCreate nested serializer should find same child object (by name)""" + data = { + "children": [{ + "name": "test", + }] + } + + serializer = ReverseParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + serializer = ReverseParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 2, + ReverseParent.objects.count() + ) + + self.assertEqual( + 1, + ReverseChild.objects.count(), + ) + + def test_reverse_many_nested_create(self): + data = { + "children": [{ + "name": "test", + }] + } + + serializer = ReverseManyParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + def test_reverse_many_nested_get(self): + """A second run with a GetOrCreate nested serializer should find same child object (by name)""" + data = { + "children": [{ + "name": "test", + }] + } + + serializer = ReverseManyParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + serializer = ReverseManyParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 1, + ReverseManyChild.objects.count(), + ) + + def test_reverse_set(self): + """We had to implement a workaround because `set` does not work correctly on non-nullable reverse FKs""" + parent = ReverseParent() + parent.save() + + child1 = ReverseChild(name='test1', parent=parent) + child1.save() + child2 = ReverseChild(name='test2', parent=parent) + child2.save() + child3 = ReverseChild(name='test3', parent=parent) + child3.save() + # set is supposed to remove missing children + parent.children.set([child1, child3]) + + # if this ever fails (i.e. returns 2), we may be able to rip out the manual reverse-FK update logic + self.assertEqual( + 3, + parent.children.count() + ) + + def test_reverse_match(self): + p = ReverseParent.objects.create() + ReverseChild.objects.create( + parent=p, + name='test1', + ) + ReverseChild.objects.create( + parent=p, + name='test2', + ) + p = ReverseParent.objects.create() + ReverseChild.objects.create( + parent=p, + name='test3', + ) + ReverseChild.objects.create( + parent=p, + name='test4', + ) + + serializer = ReverseParentGetOnlySerializer( + match_on=['children'], + data={ + 'children': [ + { + 'name': 'test3' + }, + { + 'name': 'test5' + } + ] + } + ) + self.assertTrue( + serializer.is_valid() + ) + obj = serializer.save() + # match on child test3 was successful + self.assertEqual( + p, + obj, + ) + # has two children (i.e. test3, test5) + self.assertEqual( + 2, + obj.children.count(), + ) + # contains correct children + self.assertTrue( + ReverseParent.objects.filter(id=p.id, children__name='test3').exists() + ) + self.assertTrue( + ReverseParent.objects.filter(id=p.id, children__name='test5').exists() + ) + + +################### +# 3-Layer Relation +################### +class NestedParentGetOrCreateSerializer(mixins.GetOrCreateNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = Parent + fields = '__all__' + # source of a 1:many relationship + child = ChildGetOrCreateSerializer() + + +class GrandParentRelatedSaveSerializer(mixins.RelatedSaveMixin, serializers.ModelSerializer): + class Meta: + model = GrandParent + fields = '__all__' + # source of a 1:many relationship + child = NestedParentGetOrCreateSerializer() + + +class DoubleNestedModelSerializerTest(TestCase): + + def test_direct_nested_create(self): + data = { + "child": { + "child": { + "name": "test", + } + } + } + + serializer = GrandParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 1, + GrandParent.objects.count(), + ) + + self.assertEqual( + 1, + Parent.objects.count(), + ) + + self.assertEqual( + 1, + Child.objects.count(), + ) + + instance = serializer.save() + self.assertIsInstance( + instance, + GrandParent, + ) + self.assertIsInstance( + instance.child, + Parent, + ) + self.assertIsInstance( + instance.child.child, + Child, + ) + self.assertEqual( + 'test', + instance.child.child.name, + ) + + +############## +# Create Only +############## +class ChildCreateOnlySerializer(mixins.CreateOnlyNestedSerializerMixin, serializers.ModelSerializer): + DEFAULT_MATCH_ON = ['name'] + + class Meta: + model = Child + fields = '__all__' + + +class ParentRelatedSaveSerializerCreateOnly(mixins.RelatedSaveMixin): + class Meta: + fields = '__all__' + # source of a 1:many relationship + child = ChildCreateOnlySerializer() + + def create(self, validated_data): + # "container only", no create logic + return validated_data + + +class CreateOnlyModelSerializerTest(TestCase): + + def test_create_match_error(self): + """Create Only serializers will not match an existing object (despite match_on)""" + data = { + "child": { + "name": "test", + } + } + + serializer = ParentRelatedSaveSerializerCreateOnly(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 1, + Child.objects.count() + ) + + serializer = ParentRelatedSaveSerializerCreateOnly(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + with self.assertRaises(IntegrityError): + serializer.save() + + def test_create_match_error(self): + """Create Only serializers will error if a match is found""" + data = { + "child": { + "name": "test", + } + } + + serializer = ParentRelatedSaveSerializerCreateOnly(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + self.assertEqual( + 1, + Child.objects.count() + ) + + serializer = ParentRelatedSaveSerializerCreateOnly(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + with self.assertRaises(IntegrityError): + serializer.save() + + +##################### +# Context Conduction +##################### +class ContextChildGetOrCreateSerializer(mixins.GetOrCreateNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = ContextChild + fields = '__all__' + extra_kwargs = { + 'owner': { + 'default': serializers.CurrentUserDefault(), + } + } + + +class GenericContextParentRelatedSaveSerializer(mixins.RelatedSaveMixin): + child = ContextChildGetOrCreateSerializer() + + def create(self, validated_data): + # "container only", no create logic + return validated_data + + +class GenericContextGrandParentRelatedSaveSerializer(mixins.RelatedSaveMixin): + child = GenericContextParentRelatedSaveSerializer() + + def create(self, validated_data): + # "container only", no create logic + return validated_data + + +class ContextConductionTest(TestCase): + + def setUp(self): + self.user = get_user_model().objects.create(username="test_user") + + def test_context_conduction(self): + data = { + "child": { + "child": { + "name": "test", + } + } + } + + request = RequestFactory() + request.user = self.user + + serializer = GenericContextGrandParentRelatedSaveSerializer(data=data) + serializer._context = { + 'request': request + } + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + serializer.save() + + +################## +# Wildcard Source +################## +class WildcardParentRelatedSaveSerializer(mixins.RelatedSaveMixin): + class Meta: + fields = '__all__' + # makes the current class a pass-through + parent = GenericParentRelatedSaveSerializer(source='*') + + def create(self, validated_data): + # "container only", no create logic + return validated_data + + +class WildcardSourceSerializerTest(TestCase): + + def test_wildcard_source(self): + """Wildcard sources should be processed correctly""" + data = { + "child": { + "name": "test", + } + } + + serializer = GenericParentRelatedSaveSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + instance = serializer.save() + self.assertIsInstance( + instance, + dict, + ) + self.assertIn( + 'child', + instance, + ) + self.assertIsInstance( + instance['child'], + Child, + ) + self.assertEqual( + 'test', + instance['child'].name, + ) + + +########### +# OneToOne +########### +class ProfileSerializer(mixins.GetOrCreateNestedSerializerMixin, serializers.ModelSerializer): + class Meta: + model = NewProfile + fields = '__all__' + + +class UserSerializer(mixins.RelatedSaveMixin, serializers.ModelSerializer): + class Meta: + model = NewUser + fields = '__all__' + # makes the current class a pass-through + profile = ProfileSerializer() + + +class OneToOneSerializerTest(TestCase): + + def test_onetoone_source(self): + """Wildcard sources should be processed correctly""" + data = { + "username": "test user", + "profile": { + "age": 50, + } + } + + serializer = UserSerializer(data=data) + valid = serializer.is_valid() + self.assertTrue( + valid, + "Serializer should have been valid: {}".format(serializer.errors) + ) + instance = serializer.save() + self.assertIsInstance( + instance, + NewUser, + ) + self.assertEqual( + "test user", + instance.username, + ) + self.assertIsInstance( + instance.profile, + NewProfile, + ) + self.assertEqual( + 50, + instance.profile.age, + ) diff --git a/tests/test_writable_nested_model_serializer_converted.py b/tests/test_writable_nested_model_serializer_converted.py new file mode 100644 index 0000000..a26ad14 --- /dev/null +++ b/tests/test_writable_nested_model_serializer_converted.py @@ -0,0 +1,957 @@ +import uuid + +from django.db import transaction +from django.db.models import ProtectedError +from django.http.request import QueryDict +from django.test import TestCase +from rest_framework import serializers +from rest_framework.exceptions import ValidationError + +from . import models, serializers +from .utils import get_sample_file + + +class WritableNestedModelSerializerTest(TestCase): + # noinspection PyMethodMayBeStatic + def get_initial_data(self): + return { + 'username': 'test', + 'profile': { + 'access_key': { + 'key': 'key', + }, + 'sites': [ + { + 'url': 'http://google.com', + }, + { + 'url': 'http://yahoo.com', + }, + ], + 'avatars': [ + { + 'image': 'image-1.png', + }, + { + 'image': 'image-2.png', + }, + ], + 'message_set': [ + { + 'message': 'Message 1' + }, + { + 'message': 'Message 2' + }, + { + 'message': 'Message 3' + }, + ] + }, + } + + def test_create(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + self.assertIsNotNone(user) + self.assertEqual(user.username, 'test') + + profile = user.profile + self.assertIsNotNone(profile) + self.assertIsNotNone(profile.access_key) + self.assertEqual(profile.access_key.key, 'key') + self.assertEqual(profile.sites.count(), 2) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://google.com', 'http://yahoo.com'} + ) + self.assertEqual(profile.avatars.count(), 2) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'image-1.png', 'image-2.png'} + ) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + def test_create_with_not_specified_reverse_one_to_one(self): + serializer = serializers.NewUserSerializer(data={'username': 'test'}) + serializer.is_valid(raise_exception=True) + user = serializer.save() + self.assertFalse(models.Profile.objects.filter(user=user).exists()) + + def test_create_with_empty_reverse_one_to_one(self): + serializer = serializers.NewUserSerializer( + data={'username': 'test', 'profile': None}) + serializer.is_valid(raise_exception=True) + user = serializer.save() + self.assertFalse(models.Profile.objects.filter(user=user).exists()) + + def test_create_with_custom_field(self): + data = self.get_initial_data() + data['custom_field'] = 'custom value' + serializer = serializers.NewCustomSerializer(data=data) + serializer.is_valid(raise_exception=True) + user = serializer.save() + self.assertIsNotNone(user) + + def test_create_with_generic_relation(self): + first_tag = 'the_first_tag' + next_tag = 'the_next_tag' + data = { + 'tags': [ + {'tag': first_tag}, + {'tag': next_tag}, + ], + } + serializer = serializers.NewTaggedItemSerializer(data=data) + serializer.is_valid(raise_exception=True) + item = serializer.save() + self.assertIsNotNone(item) + self.assertEqual(2, models.Tag.objects.count()) + self.assertEqual(first_tag, item.tags.all()[0].tag) + self.assertEqual(next_tag, item.tags.all()[1].tag) + + def test_update(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.Message.objects.count(), 3) + + # Update + user_pk = user.pk + profile_pk = user.profile.pk + + message_to_update_str_pk = str(user.profile.message_set.first().pk) + message_to_update_pk = user.profile.message_set.last().pk + serializer = serializers.NewUserSerializer( + instance=user, + data={ + 'pk': user_pk, + 'username': 'new', + 'profile': { + 'pk': profile_pk, + 'access_key': None, + 'sites': [ + { + 'url': 'http://new-site.com', + }, + ], + 'avatars': [ + { + 'pk': user.profile.avatars.earliest('pk').pk, + 'image': 'old-image-1.png', + }, + { + 'image': 'new-image-1.png', + }, + { + 'image': 'new-image-2.png', + }, + ], + 'message_set': [ + { + 'pk': message_to_update_str_pk, + 'message': 'Old message 1' + }, + { + 'pk': message_to_update_pk, + 'message': 'Old message 2' + }, + { + 'message': 'New message 1' + } + ], + }, + }, + ) + + serializer.is_valid(raise_exception=True) + user = serializer.save() + user.refresh_from_db() + self.assertIsNotNone(user) + self.assertEqual(user.pk, user_pk) + self.assertEqual(user.username, 'new') + + profile = user.profile + self.assertIsNotNone(profile) + self.assertIsNone(profile.access_key) + self.assertEqual(profile.pk, profile_pk) + self.assertEqual(profile.sites.count(), 1) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://new-site.com'} + ) + self.assertEqual(profile.avatars.count(), 3) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'old-image-1.png', 'new-image-1.png', 'new-image-2.png'} + ) + self.assertSetEqual( + set(profile.message_set.values_list('message', flat=True)), + {'Old message 1', 'Old message 2', 'New message 1'} + ) + # Check that message which supposed to be updated still in profile + # message_set (new message wasn't created instead of update) + self.assertIn( + message_to_update_pk, + profile.message_set.values_list('id', flat=True) + ) + self.assertIn( + uuid.UUID(message_to_update_str_pk), + profile.message_set.values_list('id', flat=True) + ) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Avatar.objects.count(), 3) + self.assertEqual(models.Message.objects.count(), 3) + # Access key shouldn't be removed because it is FK + self.assertEqual(models.AccessKey.objects.count(), 1) + # Sites shouldn't be deleted either as it is M2M + self.assertEqual(models.Site.objects.count(), 3) + + def test_update_reverse_one_to_one_without_pk(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.Message.objects.count(), 3) + + # Update + user_pk = user.pk + profile_pk = user.profile.pk + + message_to_update_str_pk = str(user.profile.message_set.first().pk) + message_to_update_pk = user.profile.message_set.last().pk + serializer = serializers.NewUserSerializer( + instance=user, + data={ + 'pk': user_pk, + 'username': 'new', + 'profile': { + # omit pk + 'access_key': None, + 'sites': [ + { + 'url': 'http://new-site.com', + }, + ], + 'avatars': [ + { + 'pk': user.profile.avatars.earliest('pk').pk, + 'image': 'old-image-1.png', + }, + { + 'image': 'new-image-1.png', + }, + { + 'image': 'new-image-2.png', + }, + ], + 'message_set': [ + { + 'pk': message_to_update_str_pk, + 'message': 'Old message 1' + }, + { + 'pk': message_to_update_pk, + 'message': 'Old message 2' + }, + { + 'message': 'New message 1' + } + ], + }, + }, + ) + + serializer.is_valid(raise_exception=True) + user = serializer.save() + user.refresh_from_db() + self.assertIsNotNone(user) + self.assertEqual(user.pk, user_pk) + self.assertEqual(user.username, 'new') + + profile = user.profile + self.assertIsNotNone(profile) + self.assertIsNone(profile.access_key) + self.assertEqual(profile.pk, profile_pk) + self.assertEqual(profile.sites.count(), 1) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://new-site.com'} + ) + self.assertEqual(profile.avatars.count(), 3) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'old-image-1.png', 'new-image-1.png', 'new-image-2.png'} + ) + self.assertSetEqual( + set(profile.message_set.values_list('message', flat=True)), + {'Old message 1', 'Old message 2', 'New message 1'} + ) + # Check that message which supposed to be updated still in profile + # messages (new message wasn't created instead of update) + self.assertIn( + message_to_update_pk, + profile.message_set.values_list('id', flat=True) + ) + self.assertIn( + uuid.UUID(message_to_update_str_pk), + profile.message_set.values_list('id', flat=True) + ) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Avatar.objects.count(), 3) + self.assertEqual(models.Message.objects.count(), 3) + # Access key shouldn't be removed because it is FK + self.assertEqual(models.AccessKey.objects.count(), 1) + # Sites shouldn't be deleted either as it is M2M + self.assertEqual(models.Site.objects.count(), 3) + + def test_update_raise_protected_error(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + user.user_avatar = user.profile.avatars.first() + user.save() + + # Since this is not a nested serializer, + serializer = serializers.NewBaseProfileSerializer( + instance=user.profile, + data={ + 'access_key': None, + 'sites': [], + 'avatars': [ + { + 'pk': user.profile.avatars.last().id, + 'image': 'old-image-1.png', + }, + { + 'image': 'new-image-1.png', + }, + ], + 'message_set': [], + } + ) + + serializer.is_valid(raise_exception=True) + # new-style classes don't catch this during Validation so we get a ProtectedError instead + # with self.assertRaises(ValidationError): + with self.assertRaises(ProtectedError): + # TODO: remove transaction.atomic after #48 will be closed + with transaction.atomic(): + serializer.save() + + # Check that protected avatar hasn't been deleted + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertSetEqual( + set(models.Avatar.objects.values_list('pk', flat=True)), + { + user.profile.avatars.first().id, + user.profile.avatars.last().id + }) + + def test_update_with_empty_reverse_one_to_one(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + self.assertIsNotNone(user.profile) + + serializer = serializers.NewUserSerializer( + instance=user, + data={ + 'pk': user.pk, + 'username': 'new', + 'profile': None + } + ) + serializer.is_valid(raise_exception=True) + user = serializer.save() + self.assertFalse(models.Profile.objects.filter(user=user).exists()) + + def test_partial_update(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + # Partial update + user_pk = user.pk + profile_pk = user.profile.pk + + serializer = serializers.NewUserSerializer( + instance=user, + partial=True, + data={ + 'pk': user_pk, + 'username': 'new', + } + ) + serializer.is_valid(raise_exception=True) + user = serializer.save() + user.refresh_from_db() + self.assertIsNotNone(user) + self.assertEqual(user.pk, user_pk) + self.assertEqual(user.username, 'new') + + profile = user.profile + self.assertIsNotNone(profile) + self.assertIsNotNone(profile.access_key) + self.assertEqual(profile.access_key.key, 'key') + self.assertEqual(profile.pk, profile_pk) + self.assertEqual(profile.sites.count(), 2) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://google.com', 'http://yahoo.com'} + ) + self.assertEqual(profile.avatars.count(), 2) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'image-1.png', 'image-2.png'} + ) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + def test_partial_update_direct_fk(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + # Partial update + user_pk = user.pk + profile_pk = user.profile.pk + access_key_pk = user.profile.access_key.pk + + serializer = serializers.NewUserSerializer( + instance=user, + partial=True, + data={ + 'pk': user_pk, + 'profile': { + 'pk': profile_pk, + 'access_key': { + 'pk': access_key_pk, + 'key': 'new', + } + }, + } + ) + serializer.is_valid(raise_exception=True) + user = serializer.save() + user.refresh_from_db() + self.assertIsNotNone(user) + self.assertEqual(user.pk, user_pk) + self.assertEqual(user.username, 'test') + + profile = user.profile + self.assertIsNotNone(profile) + access_key = profile.access_key + self.assertIsNotNone(access_key) + self.assertEqual(access_key.key, 'new') + self.assertEqual(access_key.pk, access_key_pk) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + def test_nested_partial_update(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + # Partial update + user_pk = user.pk + profile_pk = user.profile.pk + + serializer = serializers.NewUserSerializer( + instance=user, + partial=True, + data={ + 'pk': user_pk, + 'profile': { + 'pk': profile_pk, + 'access_key': { + 'key': 'new', + } + }, + } + ) + serializer.is_valid(raise_exception=True) + user = serializer.save() + user.refresh_from_db() + self.assertIsNotNone(user) + self.assertEqual(user.pk, user_pk) + self.assertEqual(user.username, 'test') + + profile = user.profile + self.assertIsNotNone(profile) + self.assertIsNotNone(profile.access_key) + self.assertEqual(profile.access_key.key, 'new') + self.assertEqual(profile.pk, profile_pk) + self.assertEqual(profile.sites.count(), 2) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://google.com', 'http://yahoo.com'} + ) + self.assertEqual(profile.avatars.count(), 2) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'image-1.png', 'image-2.png'} + ) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.Profile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.Avatar.objects.count(), 2) + # Old access key shouldn't be deleted + self.assertEqual(models.AccessKey.objects.count(), 2) + + def test_nested_partial_update_failed_with_empty_direct_fk_object(self): + serializer = serializers.NewUserSerializer(data=self.get_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Check nested instances is None + self.assertIsNone(user.user_avatar) + + serializer = serializers.NewUserSerializer( + instance=user, + partial=True, + data={ + 'username': 'new', + 'user_avatar': {}, + } + ) + serializer.is_valid(raise_exception=True) + with self.assertRaises(ValidationError): + serializer.save() + + def test_update_with_generic_relation(self): + item = models.TaggedItem.objects.create() + serializer = serializers.NewTaggedItemSerializer( + instance=item, + data={ + 'tags': [{ + 'tag': 'the_tag', + }] + } + ) + serializer.is_valid(raise_exception=True) + serializer.save() + item.refresh_from_db() + self.assertEqual(1, item.tags.count()) + + serializer = serializers.NewTaggedItemSerializer( + instance=item, + data={ + 'tags': [{ + 'pk': item.tags.get().pk, + 'tag': 'the_new_tag', + }] + } + ) + serializer.is_valid(raise_exception=True) + serializer.save() + item.refresh_from_db() + self.assertEqual('the_new_tag', item.tags.get().tag) + + serializer = serializers.NewTaggedItemSerializer( + instance=item, + data={ + 'tags': [{ + 'tag': 'the_third_tag', + }] + } + ) + serializer.is_valid(raise_exception=True) + serializer.save() + item.refresh_from_db() + self.assertEqual(1, item.tags.count()) + self.assertEqual('the_third_tag', item.tags.get().tag) + + def test_create_m2m_with_existing_related_objects(self): + users = [ + models.User.objects.create(username='first user'), + models.User.objects.create(username='second user'), + ] + users_data = serializers.NewUserSerializer( + users, + many=True + ).data + print("users_data: {}".format(users_data)) + users_data.append({'username': 'third user'}) + data = { + 'name': 'Team', + 'members': users_data, + } + serializer = serializers.NewTeamSerializer(data=data) + self.assertTrue(serializer.is_valid()) + team = serializer.save() + self.assertEqual(3, team.members.count()) + self.assertEqual(3, models.User.objects.count()) + self.assertEqual('first user', team.members.first().username) + + # Update + data = serializers.NewTeamSerializer(team).data + data['members'].append({'username': 'fourth user'}) + serializer = serializers.NewTeamSerializer(team, data=data) + self.assertTrue(serializer.is_valid()) + team = serializer.save() + self.assertEqual(4, team.members.count()) + self.assertEqual(4, models.User.objects.count()) + self.assertEqual('fourth user', team.members.last().username) + + def test_create_fk_with_existing_related_object(self): + user = models.User.objects.create(username='user one') + profile = models.Profile.objects.create(user=user) + avatar = models.Avatar.objects.create(profile=profile) + data = self.get_initial_data() + # sets one of the avatars to the PK of the existing (expecting a match) + data['profile']['avatars'][0]['pk'] = avatar.pk + serializer = serializers.NewUserSerializer(data=data) + self.assertTrue(serializer.is_valid()) + new_user = serializer.save() + # one created, one match + self.assertEqual(2, models.Avatar.objects.count()) + avatar.refresh_from_db() + self.assertEqual('image-1.png', avatar.image) + self.assertNotEqual(new_user.profile, profile) + self.assertEqual(new_user.profile, avatar.profile) + + def test_create_with_existing_direct_fk_object(self): + access_key = models.AccessKey.objects.create( + key='the-key', + ) + serializer = serializers.NewAccessKeySerializer( + instance=access_key, + ) + data = self.get_initial_data() + data['profile']['access_key'] = serializer.data + data['profile']['access_key']['key'] = 'new-key' + serializer = serializers.NewUserSerializer( + data=data, + ) + self.assertTrue(serializer.is_valid()) + user = serializer.save() + access_key.refresh_from_db() + self.assertEqual(access_key, user.profile.access_key) + self.assertEqual('new-key', access_key.key) + + def test_create_with_save_kwargs(self): + data = self.get_initial_data() + serializer = serializers.NewUserSerializer(data=data) + serializer.is_valid(raise_exception=True) + user = serializer.save( + profile={ + 'access_key': {'key': 'key2'}, + 'sites': {'url': 'http://test.com'} + }, + ) + self.assertEqual('key2', user.profile.access_key.key) + sites = list(user.profile.sites.all()) + self.assertEqual('http://test.com', sites[0].url) + self.assertEqual('http://test.com', sites[1].url) + + def test_create_with_save_kwargs_failed(self): + data = self.get_initial_data() + serializer = serializers.NewUserSerializer(data=data) + serializer.is_valid(raise_exception=True) + + with self.assertRaises(TypeError): + user = serializer.save( + profile=None, + ) + + def test_custom_pk(self): + data = { + 'username': 'username', + 'custompks': [{ + 'slug': 'custom-key', + }] + } + serializer = serializers.NewUserWithCustomPKSerializer( + data=data, + ) + self.assertTrue(serializer.is_valid()) + user = serializer.save() + self.assertEqual('custom-key', + user.custompks.first().slug) + data['custompks'].append({ + 'slug': 'next-key', + }) + data['custompks'][0]['slug'] = 'key2' + serializer = serializers.NewUserWithCustomPKSerializer( + data=data, + instance=user, + ) + self.assertTrue(serializer.is_valid()) + user = serializer.save() + user.refresh_from_db() + custompks = list(user.custompks.all()) + self.assertEqual(2, len(custompks)) + self.assertEqual('key2', custompks[0].slug) + self.assertEqual('next-key', custompks[1].slug) + self.assertEqual(2, models.CustomPK.objects.count()) + + def get_another_initial_data(self): + return { + 'username': 'test', + 'another_profile': { + 'another_access_key': { + 'key': 'key', + }, + 'another_sites': [ + { + 'url': 'http://google.com', + }, + { + 'url': 'http://yahoo.com', + }, + ], + 'another_avatars': [ + { + 'image': 'image-1.png', + }, + { + 'image': 'image-2.png', + }, + ], + }, + } + + def test_create_another_user_with_explicit_source(self): + serializer = serializers.NewAnotherUserSerializer( + data=self.get_another_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + self.assertIsNotNone(user) + self.assertEqual(user.username, 'test') + + profile = user.anotherprofile + self.assertIsNotNone(profile) + self.assertIsNotNone(profile.access_key) + self.assertEqual(profile.access_key.key, 'key') + self.assertEqual(profile.sites.count(), 2) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://google.com', 'http://yahoo.com'} + ) + self.assertEqual(profile.avatars.count(), 2) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'image-1.png', 'image-2.png'} + ) + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.AnotherProfile.objects.count(), 1) + self.assertEqual(models.Site.objects.count(), 2) + self.assertEqual(models.AnotherAvatar.objects.count(), 2) + self.assertEqual(models.AccessKey.objects.count(), 1) + + def test_update_another_user_with_explicit_source(self): + serializer = serializers.NewAnotherUserSerializer( + data=self.get_another_initial_data()) + serializer.is_valid(raise_exception=True) + user = serializer.save() + + # Update + user_pk = user.pk + profile_pk = user.anotherprofile.pk + + serializer = serializers.NewAnotherUserSerializer( + instance=user, + data={ + 'pk': user_pk, + 'username': 'new', + 'another_profile': { + 'pk': profile_pk, + 'another_access_key': None, + 'another_sites': [ + { + 'url': 'http://new-site.com', + }, + ], + 'another_avatars': [ + { + 'pk': user.anotherprofile.avatars.earliest('pk').pk, + 'image': 'old-image-1.png', + }, + { + 'image': 'new-image-1.png', + }, + { + 'image': 'new-image-2.png', + }, + ], + }, + }, + ) + + serializer.is_valid(raise_exception=True) + user = serializer.save() + user.refresh_from_db() + self.assertIsNotNone(user) + self.assertEqual(user.pk, user_pk) + self.assertEqual(user.username, 'new') + + profile = user.anotherprofile + self.assertIsNotNone(profile) + self.assertIsNone(profile.access_key) + self.assertEqual(profile.pk, profile_pk) + self.assertEqual(profile.sites.count(), 1) + self.assertSetEqual( + set(profile.sites.values_list('url', flat=True)), + {'http://new-site.com'} + ) + self.assertEqual(profile.avatars.count(), 3) + self.assertSetEqual( + set(profile.avatars.values_list('image', flat=True)), + {'old-image-1.png', 'new-image-1.png', 'new-image-2.png'} + ) + + # Check instances count + self.assertEqual(models.User.objects.count(), 1) + self.assertEqual(models.AnotherProfile.objects.count(), 1) + self.assertEqual(models.AnotherAvatar.objects.count(), 3) + # Access key shouldn't be removed because it is FK + self.assertEqual(models.AccessKey.objects.count(), 1) + # Sites shouldn't be deleted either as it is M2M + self.assertEqual(models.Site.objects.count(), 3) + + def test_create_with_html_input_data(self): + """Serializer should not fail if request type is multipart + """ + # DRF sets data to `QueryDict` when request type is `multipart` + data = QueryDict('name=team') + serializer = serializers.NewTeamSerializer( + data=data + ) + serializer.is_valid(raise_exception=True) + team = serializer.save() + + self.assertTrue(models.Team.objects.filter(id=team.id).exists()) + self.assertEqual(team.name, 'team') + + def test_create_with_file(self): + data = { + 'page.title': 'some page', + 'source': get_sample_file(name='sample name') + } + qdict = QueryDict('', mutable=True) + qdict.update(data) + + serializer = serializers.NewDocumentSerializer( + data=qdict + ) + serializer.is_valid(raise_exception=True) + doc = serializer.save() + + self.assertTrue(models.Document.objects.filter(pk=doc.pk).exists()) + self.assertEqual(doc.page.title, 'some page') + + +class WritableNestedModelSerializerIssuesTest(TestCase): + def test_issue_86(self): + serializer = serializers.NewI86GenreSerializer(data={ + 'names': [ + { + 'string': 'Genre' + } + ] + }) + self.assertTrue(serializer.is_valid()) + instance = serializer.save() + + update_serializer = serializers.NewI86GenreSerializer( + instance=instance, + data={ + 'id': instance.pk, + 'names': [ + { + 'id': instance.names.first().pk, + 'string': 'Genre changed' + } + ] + } + ) + self.assertTrue(update_serializer.is_valid()) + update_serializer.save() + self.assertEqual(serializer.data['id'], update_serializer.data['id']) + self.assertEqual( + serializer.data['names'][0]['id'], + update_serializer.data['names'][0]['id']) + + +class ReadOnlyIssueTest(TestCase): + def test_pr_101_readonly_issue(self): + child = models.ReadOnlyChild.objects.create(name='blue') + parent = models.ReadOnlyParent.objects.create(child=child) + serializer = serializers.NewReadOnlyParentSerializer(data={ + 'id': parent.pk, + 'child': { + 'id': child.pk, + 'name': 'ReadOnly' + } + }) + self.assertTrue(serializer.is_valid()) + instance = serializer.save() + self.assertEqual( + child.pk, + instance.child.pk, + ) + self.assertEqual( + 'blue', + instance.child.name, + )