diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ca60810df1..176c8e9400 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1435,20 +1435,26 @@ def get_extra_kwargs(self): def get_unique_together_constraints(self, model): """ - Returns iterator of (fields, queryset, condition_fields, condition), + Returns iterator of (fields, queryset, condition_fields, condition, nulls_distinct), each entry describes an unique together constraint on `fields` in `queryset` - with respect of constraint's `condition`. + with respect of constraint's `condition` and `nulls_distinct` option. """ for parent_class in [model] + list(model._meta.parents): for unique_together in parent_class._meta.unique_together: - yield unique_together, model._default_manager, [], None + yield unique_together, model._default_manager, [], None, None for constraint in parent_class._meta.constraints: if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: if constraint.condition is None: condition_fields = [] else: condition_fields = list(get_referenced_base_fields_from_q(constraint.condition)) - yield (constraint.fields, model._default_manager, condition_fields, constraint.condition) + yield ( + constraint.fields, + model._default_manager, + condition_fields, + constraint.condition, + getattr(constraint, 'nulls_distinct', None), + ) def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ @@ -1481,7 +1487,7 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs # Include each of the `unique_together` and `UniqueConstraint` field names, # so long as all the field names are included on the serializer. - for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model): + for unique_together_list, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(model): unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields) if model_fields_names.issuperset(unique_together_list_and_condition_fields): unique_constraint_names |= unique_together_list_and_condition_fields @@ -1624,7 +1630,7 @@ def get_unique_together_validators(self): # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] - for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model): + for unique_together, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(self.Meta.model): # Skip if serializer does not map to all unique together sources unique_together_and_condition_fields = set(unique_together) | set(condition_fields) if not set(source_map).issuperset(unique_together_and_condition_fields): @@ -1658,6 +1664,7 @@ def get_unique_together_validators(self): condition=condition, message=violation_error_message, code=getattr(constraint, 'violation_error_code', None), + nulls_distinct=nulls_distinct, ) validators.append(validator) return validators diff --git a/rest_framework/validators.py b/rest_framework/validators.py index cc759b39cc..4e58a20767 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -113,13 +113,14 @@ class UniqueTogetherValidator: requires_context = True code = 'unique' - def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None): + def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None, nulls_distinct=None): self.queryset = queryset self.fields = fields self.message = message or self.message self.condition_fields = [] if condition_fields is None else condition_fields self.condition = condition self.code = code or self.code + self.nulls_distinct = nulls_distinct def enforce_required_fields(self, attrs, serializer): """ @@ -197,17 +198,21 @@ def __call__(self, attrs, serializer): else getattr(serializer.instance, source) for source in condition_sources } - if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs): - field_names = ', '.join(self.fields) - message = self.message.format(field_names=field_names) - raise ValidationError(message, code=self.code) + if checked_values: + # Skip validation for None values unless nulls_distinct is False + if self.nulls_distinct is not False and None in checked_values: + return + if qs_exists_with_condition(queryset, self.condition, condition_kwargs): + field_names = ', '.join(self.fields) + message = self.message.format(field_names=field_names) + raise ValidationError(message, code=self.code) def __repr__(self): return '<{}({})>'.format( self.__class__.__name__, ', '.join( f'{attr}={smart_repr(getattr(self, attr))}' - for attr in ('queryset', 'fields', 'condition') + for attr in ('queryset', 'fields', 'condition', 'nulls_distinct') if getattr(self, attr) is not None) ) @@ -220,6 +225,7 @@ def __eq__(self, other): and self.queryset == other.queryset and self.fields == other.fields and self.code == other.code + and self.nulls_distinct == other.nulls_distinct ) diff --git a/tests/test_validators.py b/tests/test_validators.py index 96354b9b13..39b86ce63e 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -616,6 +616,23 @@ class Meta: ] +# Only define nulls_distinct model for Django 5.0+ +if django_version >= (5, 0): + class UniqueConstraintNullsDistinctModel(models.Model): + name = models.CharField(max_length=100) + code = models.CharField(max_length=100, null=True) + category = models.CharField(max_length=100, null=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + name='unique_code_category_nulls_not_distinct', + fields=('code', 'category'), + nulls_distinct=False, + ), + ] + + class UniqueConstraintCustomMessageCodeModel(models.Model): username = models.CharField(max_length=32) company_id = models.IntegerField() @@ -1063,3 +1080,118 @@ def test_equality_operator(self): assert validator == validator2 validator2.date_field = "bar2" assert validator != validator2 + + +# Tests for `nulls_distinct` option (Django 5.0+) +# ----------------------------------------------- + +@pytest.mark.skipif( + django_version < (5, 0), + reason="nulls_distinct requires Django 5.0+" +) +class TestUniqueConstraintNullsDistinct(TestCase): + """ + Tests for UniqueConstraint with nulls_distinct=False option. + When nulls_distinct=False, NULL values should be treated as equal + for uniqueness validation. + """ + + def setUp(self): + from tests.test_validators import UniqueConstraintNullsDistinctModel + + class UniqueConstraintNullsDistinctSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintNullsDistinctModel + fields = ('name', 'code', 'category') + + self.serializer_class = UniqueConstraintNullsDistinctSerializer + + def test_nulls_distinct_false_validates_null_as_duplicate(self): + """ + When nulls_distinct=False, creating a second record with NULL values + in the constrained fields should fail validation. + """ + from tests.test_validators import UniqueConstraintNullsDistinctModel + + # Create first record with NULL values + UniqueConstraintNullsDistinctModel.objects.create( + name='First', + code=None, + category=None + ) + + # Attempt to create second record with same NULL values + serializer = self.serializer_class(data={ + 'name': 'Second', + 'code': None, + 'category': None + }) + + # Should fail validation because nulls_distinct=False + assert not serializer.is_valid() + + def test_nulls_distinct_false_allows_different_non_null_values(self): + """ + Non-NULL values should still work normally with uniqueness validation. + """ + from tests.test_validators import UniqueConstraintNullsDistinctModel + + # Create first record with non-NULL values + UniqueConstraintNullsDistinctModel.objects.create( + name='First', + code='A', + category='X' + ) + + # Create second record with different values - should pass + serializer = self.serializer_class(data={ + 'name': 'Second', + 'code': 'B', + 'category': 'Y' + }) + assert serializer.is_valid(), serializer.errors + + def test_nulls_distinct_false_rejects_duplicate_non_null_values(self): + """ + Duplicate non-NULL values should still fail validation. + """ + from tests.test_validators import UniqueConstraintNullsDistinctModel + + # Create first record + UniqueConstraintNullsDistinctModel.objects.create( + name='First', + code='A', + category='X' + ) + + # Attempt to create duplicate - should fail + serializer = self.serializer_class(data={ + 'name': 'Second', + 'code': 'A', + 'category': 'X' + }) + assert not serializer.is_valid() + + def test_unique_together_validator_nulls_distinct_equality(self): + """ + Test that UniqueTogetherValidator equality considers nulls_distinct. + """ + mock_queryset = MagicMock() + validator1 = UniqueTogetherValidator( + queryset=mock_queryset, + fields=('a', 'b'), + nulls_distinct=False + ) + validator2 = UniqueTogetherValidator( + queryset=mock_queryset, + fields=('a', 'b'), + nulls_distinct=False + ) + validator3 = UniqueTogetherValidator( + queryset=mock_queryset, + fields=('a', 'b'), + nulls_distinct=True + ) + + assert validator1 == validator2 + assert validator1 != validator3