Skip to content

Commit f35f441

Browse files
committed
Add support for registering with, and resolving by, tags.
1 parent b13e1db commit f35f441

File tree

4 files changed

+177
-9
lines changed

4 files changed

+177
-9
lines changed

src/dependency_injection/container.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22

3-
from typing import Any, Dict, Optional, TypeVar, Type
3+
from typing import Any, Dict, List, Optional, TypeVar, Type
44

55
from dependency_injection.registration import Registration
66
from dependency_injection.scope import DEFAULT_SCOPE_NAME, Scope
@@ -29,26 +29,26 @@ def get_instance(cls, name: str=None) -> Self:
2929

3030
return cls._instances[(cls, name)]
3131

32-
def register_transient(self, dependency: Type, implementation: Optional[Type] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
32+
def register_transient(self, dependency: Type, implementation: Optional[Type] = None, tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
3333
if implementation is None:
3434
implementation = dependency
3535
if dependency in self._registrations:
3636
raise ValueError(f"Dependency {dependency} is already registered.")
37-
self._registrations[dependency] = Registration(dependency, implementation, Scope.TRANSIENT, constructor_args)
37+
self._registrations[dependency] = Registration(dependency, implementation, Scope.TRANSIENT, tags, constructor_args)
3838

39-
def register_scoped(self, dependency: Type, implementation: Optional[Type] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
39+
def register_scoped(self, dependency: Type, implementation: Optional[Type] = None, tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
4040
if implementation is None:
4141
implementation = dependency
4242
if dependency in self._registrations:
4343
raise ValueError(f"Dependency {dependency} is already registered.")
44-
self._registrations[dependency] = Registration(dependency, implementation, Scope.SCOPED, constructor_args)
44+
self._registrations[dependency] = Registration(dependency, implementation, Scope.SCOPED, tags, constructor_args)
4545

46-
def register_singleton(self, dependency: Type, implementation: Optional[Type] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
46+
def register_singleton(self, dependency: Type, implementation: Optional[Type] = None, tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None) -> None:
4747
if implementation is None:
4848
implementation = dependency
4949
if dependency in self._registrations:
5050
raise ValueError(f"Dependency {dependency} is already registered.")
51-
self._registrations[dependency] = Registration(dependency, implementation, Scope.SINGLETON, constructor_args)
51+
self._registrations[dependency] = Registration(dependency, implementation, Scope.SINGLETON, tags, constructor_args)
5252

5353
def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Type:
5454
if scope_name not in self._scoped_instances:
@@ -90,6 +90,15 @@ def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Typ
9090

9191
raise ValueError(f"Invalid dependency scope: {scope}")
9292

93+
def resolve_all(self, tags: Optional[set] = None) -> List[Any]:
94+
tags = [] if not tags else tags
95+
resolved_dependencies = []
96+
for registration in self._registrations.values():
97+
if not len(tags) or tags.intersection(registration.tags):
98+
resolved_dependencies.append(
99+
self.resolve(registration.dependency))
100+
return resolved_dependencies
101+
93102
def _validate_constructor_args(self, constructor_args: Dict[str, Any], implementation: Type) -> None:
94103
constructor = inspect.signature(implementation.__init__).parameters
95104

src/dependency_injection/registration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
class Registration():
77

8-
def __init__(self, dependency: Type, implementation: Type, scope: Scope, constructor_args: Optional[Dict[str, Any]] = None):
8+
def __init__(self, dependency: Type, implementation: Type, scope: Scope, tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None):
99
self.dependency = dependency
1010
self.implementation = implementation
1111
self.scope = scope
12+
self.tags = tags or set()
1213
self.constructor_args = constructor_args or {}

tests/unit_test/container/register/test_register_transient.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Car(Vehicle):
4444
with pytest.raises(ValueError, match="is already registered"):
4545
dependency_container.register_transient(dependency, implementation)
4646

47-
def test_register_transient_when_dependency_and_implementation_being_the_same(
47+
def test_register_transient_success_when_dependency_and_implementation_same(
4848
self,
4949
):
5050
# arrange
@@ -58,3 +58,38 @@ class Vehicle:
5858
dependency_container.register_transient(dependency)
5959

6060
# assert (no exception thrown)
61+
62+
def test_register_transient_fails_when_already_registered_and_dependency_and_implementation_same(
63+
self,
64+
):
65+
# arrange
66+
class Vehicle:
67+
pass
68+
69+
class Car(Vehicle):
70+
pass
71+
72+
dependency_container = DependencyContainer.get_instance()
73+
dependency_container.register_transient(Vehicle, Car)
74+
75+
# act + assert
76+
with pytest.raises(ValueError, match="is already registered"):
77+
dependency_container.register_transient(Vehicle)
78+
79+
def test_register_transient_success_when_other_dependency_registered_of_implementation_ancestor_class(
80+
self,
81+
):
82+
# arrange
83+
class Vehicle:
84+
pass
85+
86+
class Car(Vehicle):
87+
pass
88+
89+
dependency_container = DependencyContainer.get_instance()
90+
dependency_container.register_transient(Vehicle, Car)
91+
92+
# act
93+
dependency_container.register_transient(Car)
94+
95+
# assert (no exception thrown)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from dependency_injection.container import DependencyContainer
2+
from unit_test.unit_test_case import UnitTestCase
3+
4+
5+
class TestResolveAll(UnitTestCase):
6+
7+
def test_returns_dependency_with_tag(
8+
self,
9+
):
10+
# arrange
11+
class Driveable:
12+
pass
13+
14+
class Vehicle:
15+
pass
16+
17+
class Car(Vehicle):
18+
pass
19+
20+
dependency_container = DependencyContainer.get_instance()
21+
dependency_container.register_transient(Vehicle, Car, tags={Driveable})
22+
23+
# act
24+
resolved_dependencies = dependency_container.resolve_all(tags={Driveable})
25+
26+
# assert
27+
self.assertEqual(len(resolved_dependencies), 1)
28+
self.assertIsInstance(resolved_dependencies[0], Car)
29+
30+
def test_returns_all_dependencies_with_tag(
31+
self,
32+
):
33+
# arrange
34+
class Driveable:
35+
pass
36+
37+
class Vehicle:
38+
pass
39+
40+
class Car(Vehicle):
41+
pass
42+
43+
class Innovation:
44+
pass
45+
46+
dependency_container = DependencyContainer.get_instance()
47+
dependency_container.register_transient(Vehicle, tags={Driveable})
48+
dependency_container.register_transient(Car, tags={Driveable})
49+
dependency_container.register_transient(Innovation, tags={Driveable})
50+
51+
# act
52+
resolved_dependencies = dependency_container.resolve_all(tags={Driveable})
53+
54+
# assert
55+
self.assertEqual(len(resolved_dependencies), 3)
56+
self.assertTrue(any(isinstance(dependency, Vehicle) for dependency in resolved_dependencies))
57+
self.assertTrue(any(isinstance(dependency, Car) for dependency in resolved_dependencies))
58+
self.assertTrue(any(isinstance(dependency, Innovation) for dependency in resolved_dependencies))
59+
60+
def test_does_not_return_dependency_without_tag(
61+
self,
62+
):
63+
# arrange
64+
class Driveable:
65+
pass
66+
67+
class Refuelable:
68+
pass
69+
70+
class Vehicle:
71+
pass
72+
73+
class Car(Vehicle):
74+
pass
75+
76+
class Innovation:
77+
pass
78+
79+
dependency_container = DependencyContainer.get_instance()
80+
dependency_container.register_transient(Vehicle, tags={Driveable, Refuelable})
81+
dependency_container.register_transient(Car, tags={Driveable, Refuelable})
82+
dependency_container.register_transient(Innovation, tags={Driveable})
83+
84+
# act
85+
resolved_dependencies = dependency_container.resolve_all(tags={Refuelable})
86+
87+
# assert
88+
self.assertEqual(len(resolved_dependencies), 2)
89+
self.assertTrue(any(isinstance(dependency, Vehicle) for dependency in resolved_dependencies))
90+
self.assertTrue(any(isinstance(dependency, Car) for dependency in resolved_dependencies))
91+
92+
def test_returns_all_dependencies_when_no_tag_specified(
93+
self,
94+
):
95+
# arrange
96+
class Driveable:
97+
pass
98+
99+
class Refuelable:
100+
pass
101+
102+
class Vehicle:
103+
pass
104+
105+
class Car(Vehicle):
106+
pass
107+
108+
class Innovation:
109+
pass
110+
111+
dependency_container = DependencyContainer.get_instance()
112+
dependency_container.register_transient(Vehicle, tags={Driveable, Refuelable})
113+
dependency_container.register_transient(Car, tags={Driveable, Refuelable})
114+
dependency_container.register_transient(Innovation, tags={Driveable})
115+
116+
# act
117+
resolved_dependencies = dependency_container.resolve_all()
118+
119+
# assert
120+
self.assertEqual(len(resolved_dependencies), 3)
121+
self.assertTrue(any(isinstance(dependency, Vehicle) for dependency in resolved_dependencies))
122+
self.assertTrue(any(isinstance(dependency, Car) for dependency in resolved_dependencies))
123+
self.assertTrue(any(isinstance(dependency, Innovation) for dependency in resolved_dependencies))

0 commit comments

Comments
 (0)