Skip to content

Commit 5480e02

Browse files
committed
Only allow decorator on class- and static methods.
1 parent c4db1cd commit 5480e02

File tree

2 files changed

+159
-84
lines changed

2 files changed

+159
-84
lines changed

src/dependency_injection/decorator.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,36 @@
88

99
def inject(container_name=DEFAULT_CONTAINER_NAME, scope_name=DEFAULT_SCOPE_NAME):
1010

11+
def is_instance_method(func):
12+
parameters = inspect.signature(func).parameters
13+
is_instance_method = len(parameters) > 0 and list(parameters.values())[0].name == 'self'
14+
return is_instance_method
15+
1116
def decorator_inject(func):
17+
1218
@functools.wraps(func)
13-
def wrapper_inject(self, *args, **kwargs):
19+
def wrapper_inject(*args, **kwargs):
20+
1421
# Get the parameter names from the function signature
15-
param_names = list(inspect.signature(func).parameters.keys())
22+
sig = inspect.signature(func)
23+
parameter_names = [param.name for param in sig.parameters.values()]
1624

1725
# Iterate over the parameter names and inject dependencies into kwargs
18-
for param_name in param_names:
19-
if param_name != 'self' and param_name not in kwargs:
26+
for parameter_name in parameter_names:
27+
if parameter_name != 'cls' and parameter_name not in kwargs:
2028
# get container
2129
container = DependencyContainer.get_instance(container_name)
2230
# Resolve the dependency based on the parameter name
23-
dependency_type = inspect.signature(func).parameters[param_name].annotation
24-
kwargs[param_name] = container.resolve(dependency_type, scope_name=scope_name)
31+
dependency_type = sig.parameters[parameter_name].annotation
32+
kwargs[parameter_name] = container.resolve(dependency_type, scope_name=scope_name)
2533

2634
# Call the original function with the injected dependencies
27-
return func(self, *args, **kwargs)
35+
return func(*args, **kwargs)
36+
37+
# Not allowed on instance methods
38+
if is_instance_method(func):
39+
raise TypeError(
40+
"@inject decorator can only be applied to class methods or static methods.")
2841

2942
return wrapper_inject
3043

Lines changed: 139 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from dependency_injection.container import DEFAULT_CONTAINER_NAME, \
2-
DependencyContainer
1+
import pytest
2+
3+
from dependency_injection.container import DependencyContainer
34
from dependency_injection.decorator import inject
45
from unit_test.car import Car
56
from unit_test.unit_test_case import UnitTestCase
@@ -8,133 +9,194 @@
89

910
class TestDecorator(UnitTestCase):
1011

11-
def test_injects_dependencies_into_method_signature(
12-
self,
13-
):
12+
def test_decoration_on_class_method(self):
13+
1414
# arrange
15-
dependency_container = DependencyContainer.get_instance(name="test-container")
15+
dependency_container = DependencyContainer.get_instance()
1616
interface = Vehicle
1717
dependency_class = Car
1818

1919
dependency_container.register_transient(interface, dependency_class)
2020

21+
class Garage:
22+
vehicle: Vehicle
23+
24+
@classmethod
25+
@inject()
26+
def park(cls, vehicle: Vehicle):
27+
cls.vehicle = vehicle
28+
2129
# act
22-
foo = Foo()
23-
foo.bar_1()
30+
Garage.park()
2431

2532
# assert
26-
self.assertIsNotNone(foo.vehicle_1)
27-
self.assertIsInstance(foo.vehicle_1, Vehicle)
33+
self.assertIsNotNone(Garage.vehicle)
34+
35+
def test_decoration_on_static_method(self):
2836

29-
def test_injects_dependencies_from_different_scopes_correctly(
30-
self,
31-
):
3237
# arrange
33-
dependency_container = DependencyContainer.get_instance(name="test-container")
38+
dependency_container = DependencyContainer.get_instance()
3439
interface = Vehicle
3540
dependency_class = Car
3641

37-
dependency_container.register_scoped(interface, dependency_class)
42+
dependency_container.register_transient(interface, dependency_class)
43+
44+
class Garage:
45+
vehicle: Vehicle
46+
47+
@staticmethod
48+
@inject()
49+
def park(vehicle: Vehicle):
50+
Garage.vehicle = vehicle
3851

3952
# act
40-
foo = Foo()
41-
foo.bar_1()
42-
foo.bar_2()
53+
Garage.park()
4354

4455
# assert
45-
self.assertIsNotNone(foo.vehicle_1)
46-
self.assertIsNotNone(foo.vehicle_2)
47-
self.assertNotEqual(foo.vehicle_1, foo.vehicle_2)
56+
self.assertIsNotNone(Garage.vehicle)
4857

49-
def test_injects_using_default_container_and_scope_if_omitted_in_decorator_arguments(
58+
def test_decoration_on_instance_method_raises(
5059
self,
5160
):
5261
# arrange
5362
dependency_container = DependencyContainer.get_instance()
5463
interface = Vehicle
5564
dependency_class = Car
5665

57-
dependency_container.register_scoped(interface, dependency_class)
66+
dependency_container.register_transient(interface, dependency_class)
67+
68+
with pytest.raises(TypeError, match="@inject decorator can only be applied to class methods or static methods."):
69+
class Garage:
70+
@inject()
71+
def park(self, vehicle: Vehicle):
72+
pass
73+
74+
def test_class_method_decorator_container_name_is_honoured(
75+
self,
76+
):
77+
# arrange
78+
interface = Vehicle
79+
dependency_class = Car
80+
81+
dependency_container = DependencyContainer.get_instance()
82+
dependency_container.register_singleton(interface, dependency_class)
83+
84+
second_container = DependencyContainer.get_instance("second")
85+
second_container.register_singleton(interface, dependency_class)
86+
87+
second_container_vehicle = second_container.resolve(interface)
88+
89+
class Garage:
90+
vehicle: Vehicle
91+
92+
@classmethod
93+
@inject(container_name="second")
94+
def park(cls, vehicle: Vehicle):
95+
cls.vehicle = vehicle
5896

5997
# act
60-
foo = Foo()
61-
foo.bar_3()
98+
Garage.park()
6299

63100
# assert
64-
self.assertIsNotNone(foo.vehicle_3)
65-
self.assertIsInstance(foo.vehicle_3, Vehicle)
66-
self.assertEqual(foo.vehicle_3, dependency_container.resolve(interface))
101+
self.assertEquals(second_container_vehicle, Garage.vehicle)
67102

68-
def test_injects_same_scoped_dependency_when_no_container_or_scope_name_in_decorator_arguments(
103+
def test_class_method_decorator_scope_name_is_honoured(
69104
self,
70105
):
71106
# arrange
72-
dependency_container = DependencyContainer.get_instance()
73107
interface = Vehicle
74108
dependency_class = Car
75109

110+
dependency_container = DependencyContainer.get_instance()
76111
dependency_container.register_scoped(interface, dependency_class)
77112

113+
first_scope_vehicle = dependency_container.resolve(interface, scope_name="first_scope")
114+
second_scope_vehicle = dependency_container.resolve(interface, scope_name="second_scope")
115+
116+
class Garage:
117+
first_vehicle: Vehicle
118+
second_vehicle: Vehicle
119+
120+
@classmethod
121+
@inject(scope_name="first_scope")
122+
def park_first(cls, vehicle: Vehicle):
123+
cls.first_vehicle = vehicle
124+
125+
@classmethod
126+
@inject(scope_name="second_scope")
127+
def park_second(cls, vehicle: Vehicle):
128+
cls.second_vehicle = vehicle
129+
78130
# act
79-
foo = Foo()
80-
foo.bar_3()
81-
foo.bar_4()
131+
Garage.park_first()
132+
Garage.park_second()
82133

83134
# assert
84-
self.assertIsNotNone(foo.vehicle_3)
85-
self.assertIsNotNone(foo.vehicle_4)
86-
self.assertEqual(foo.vehicle_3, foo.vehicle_4)
135+
self.assertEqual(first_scope_vehicle, Garage.first_vehicle)
136+
self.assertEqual(second_scope_vehicle, Garage.second_vehicle)
137+
self.assertNotEqual(Garage.first_vehicle, Garage.second_vehicle)
87138

88-
def test_injects_different_scoped_dependencies_when_no_container_but_different_scope_names_in_decorator_arguments(
139+
def test_static_method_decorator_container_name_is_honoured(
89140
self,
90141
):
91142
# arrange
143+
interface = Vehicle
144+
dependency_class = Car
145+
92146
dependency_container = DependencyContainer.get_instance()
147+
dependency_container.register_singleton(interface, dependency_class)
148+
149+
second_container = DependencyContainer.get_instance("second")
150+
second_container.register_singleton(interface, dependency_class)
151+
152+
second_container_vehicle = second_container.resolve(interface)
153+
154+
class Garage:
155+
vehicle: Vehicle
156+
157+
@staticmethod
158+
@inject(container_name="second")
159+
def park(vehicle: Vehicle):
160+
Garage.vehicle = vehicle
161+
162+
# act
163+
Garage.park()
164+
165+
# assert
166+
self.assertEquals(second_container_vehicle, Garage.vehicle)
167+
168+
def test_static_method_decorator_scope_name_is_honoured(
169+
self,
170+
):
171+
# arrange
93172
interface = Vehicle
94173
dependency_class = Car
95174

175+
dependency_container = DependencyContainer.get_instance()
96176
dependency_container.register_scoped(interface, dependency_class)
97177

178+
first_scope_vehicle = dependency_container.resolve(interface, scope_name="first_scope")
179+
second_scope_vehicle = dependency_container.resolve(interface, scope_name="second_scope")
180+
181+
class Garage:
182+
first_vehicle: Vehicle
183+
second_vehicle: Vehicle
184+
185+
@staticmethod
186+
@inject(scope_name="first_scope")
187+
def park_first(vehicle: Vehicle):
188+
Garage.first_vehicle = vehicle
189+
190+
@staticmethod
191+
@inject(scope_name="second_scope")
192+
def park_second(vehicle: Vehicle):
193+
Garage.second_vehicle = vehicle
194+
98195
# act
99-
foo = Foo()
100-
foo.bar_5()
101-
foo.bar_6()
196+
Garage.park_first()
197+
Garage.park_second()
102198

103199
# assert
104-
self.assertIsNotNone(foo.vehicle_5)
105-
self.assertIsNotNone(foo.vehicle_6)
106-
self.assertNotEqual(foo.vehicle_5, foo.vehicle_6)
107-
108-
class Foo:
109-
110-
def __init__(self):
111-
self.vehicle_1 = None
112-
self.vehicle_2 = None
113-
self.vehicle_3 = None
114-
self.vehicle_4 = None
115-
self.vehicle_5 = None
116-
self.vehicle_6 = None
117-
118-
@inject(container_name="test-container", scope_name="test-scope-1")
119-
def bar_1(self, vehicle: Vehicle):
120-
self.vehicle_1 = vehicle
121-
122-
@inject(container_name="test-container", scope_name="test-scope-2")
123-
def bar_2(self, vehicle: Vehicle):
124-
self.vehicle_2 = vehicle
125-
126-
@inject()
127-
def bar_3(self, vehicle: Vehicle):
128-
self.vehicle_3 = vehicle
129-
130-
@inject()
131-
def bar_4(self, vehicle: Vehicle):
132-
self.vehicle_4 = vehicle
133-
134-
@inject(scope_name="scope-5")
135-
def bar_5(self, vehicle: Vehicle):
136-
self.vehicle_5 = vehicle
137-
138-
@inject(scope_name="scope-6")
139-
def bar_6(self, vehicle: Vehicle):
140-
self.vehicle_6 = vehicle
200+
self.assertEqual(first_scope_vehicle, Garage.first_vehicle)
201+
self.assertEqual(second_scope_vehicle, Garage.second_vehicle)
202+
self.assertNotEqual(Garage.first_vehicle, Garage.second_vehicle)

0 commit comments

Comments
 (0)