Skip to content

Commit b13e1db

Browse files
committed
Add support for registering with dependency and implementation being the same.
1 parent feaee14 commit b13e1db

File tree

7 files changed

+107
-3
lines changed

7 files changed

+107
-3
lines changed

src/dependency_injection/container.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,23 @@ def get_instance(cls, name: str=None) -> Self:
2929

3030
return cls._instances[(cls, name)]
3131

32-
def register_transient(self, dependency: Type, implementation: Type, constructor_args: Optional[Dict[str, Any]] = None) -> None:
32+
def register_transient(self, dependency: Type, implementation: Optional[Type] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
33+
if implementation is None:
34+
implementation = dependency
3335
if dependency in self._registrations:
3436
raise ValueError(f"Dependency {dependency} is already registered.")
3537
self._registrations[dependency] = Registration(dependency, implementation, Scope.TRANSIENT, constructor_args)
3638

37-
def register_scoped(self, dependency: Type, implementation: Type, constructor_args: Optional[Dict[str, Any]] = None) -> None:
39+
def register_scoped(self, dependency: Type, implementation: Optional[Type] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
40+
if implementation is None:
41+
implementation = dependency
3842
if dependency in self._registrations:
3943
raise ValueError(f"Dependency {dependency} is already registered.")
4044
self._registrations[dependency] = Registration(dependency, implementation, Scope.SCOPED, constructor_args)
4145

42-
def register_singleton(self, dependency: Type, implementation: Type, constructor_args: Optional[Dict[str, Any]] = None) -> None:
46+
def register_singleton(self, dependency: Type, implementation: Optional[Type] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
47+
if implementation is None:
48+
implementation = dependency
4349
if dependency in self._registrations:
4450
raise ValueError(f"Dependency {dependency} is already registered.")
4551
self._registrations[dependency] = Registration(dependency, implementation, Scope.SINGLETON, constructor_args)

tests/unit_test/container/register/test_register_scoped.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,19 @@ class Car(Vehicle):
4444
# act + assert
4545
with pytest.raises(ValueError, match="is already registered"):
4646
dependency_container.register_scoped(dependency, implementation)
47+
48+
def test_register_scoped_when_dependency_and_implementation_being_the_same(
49+
self,
50+
):
51+
# arrange
52+
class Vehicle:
53+
pass
54+
55+
dependency_container = DependencyContainer.get_instance()
56+
dependency = Vehicle
57+
58+
# act
59+
dependency_container.register_scoped(dependency)
60+
61+
# assert (no exception thrown)
62+

tests/unit_test/container/register/test_register_singleton.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,18 @@ class Car(Vehicle):
4444
# act + assert
4545
with pytest.raises(ValueError, match="is already registered"):
4646
dependency_container.register_singleton(dependency, implementation)
47+
48+
def test_register_singleton_when_dependency_and_implementation_being_the_same(
49+
self,
50+
):
51+
# arrange
52+
class Vehicle:
53+
pass
54+
55+
dependency_container = DependencyContainer.get_instance()
56+
dependency = Vehicle
57+
58+
# act
59+
dependency_container.register_singleton(dependency)
60+
61+
# assert (no exception thrown)

tests/unit_test/container/register/test_register_transient.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,18 @@ class Car(Vehicle):
4343
# act + assert
4444
with pytest.raises(ValueError, match="is already registered"):
4545
dependency_container.register_transient(dependency, implementation)
46+
47+
def test_register_transient_when_dependency_and_implementation_being_the_same(
48+
self,
49+
):
50+
# arrange
51+
class Vehicle:
52+
pass
53+
54+
dependency_container = DependencyContainer.get_instance()
55+
dependency = Vehicle
56+
57+
# act
58+
dependency_container.register_transient(dependency)
59+
60+
# assert (no exception thrown)

tests/unit_test/container/resolve/test_resolve_scoped.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,21 @@ class Car(Vehicle):
4747

4848
# assert
4949
self.assertNotEqual(resolved_dependency_in_scope_1, resolved_dependency_in_scope_2)
50+
51+
def test_resolve_scoped_when_registered_with_dependency_and_implementation_being_the_same_returns_an_instance(
52+
self,
53+
):
54+
# arrange
55+
class Vehicle:
56+
pass
57+
58+
dependency_container = DependencyContainer.get_instance()
59+
dependency = Vehicle
60+
dependency_container.register_scoped(dependency)
61+
62+
# act
63+
resolved_dependency = dependency_container.resolve(dependency)
64+
65+
# assert
66+
self.assertIsInstance(resolved_dependency, Vehicle)
67+

tests/unit_test/container/resolve/test_resolve_singleton.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,20 @@ class Car(Vehicle):
4646

4747
# assert
4848
self.assertEqual(resolved_dependency_1, resolved_dependency_2)
49+
50+
def test_resolve_singleton_when_registered_with_dependency_and_implementation_being_the_same_returns_an_instance(
51+
self,
52+
):
53+
# arrange
54+
class Vehicle:
55+
pass
56+
57+
dependency_container = DependencyContainer.get_instance()
58+
dependency = Vehicle
59+
dependency_container.register_singleton(dependency)
60+
61+
# act
62+
resolved_dependency = dependency_container.resolve(dependency)
63+
64+
# assert
65+
self.assertIsInstance(resolved_dependency, Vehicle)

tests/unit_test/container/resolve/test_resolve_transient.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,20 @@ class Car(Vehicle):
4646

4747
# assert
4848
self.assertNotEqual(resolved_dependency_1, resolved_dependency_2)
49+
50+
def test_resolve_transient_when_registered_with_dependency_and_implementation_being_the_same_returns_an_instance(
51+
self,
52+
):
53+
# arrange
54+
class Vehicle:
55+
pass
56+
57+
dependency_container = DependencyContainer.get_instance()
58+
dependency = Vehicle
59+
dependency_container.register_transient(dependency)
60+
61+
# act
62+
resolved_dependency = dependency_container.resolve(dependency)
63+
64+
# assert
65+
self.assertIsInstance(resolved_dependency, Vehicle)

0 commit comments

Comments
 (0)