Skip to content

Commit 19dd095

Browse files
committed
Add support for registering with factory.
1 parent 9b43b0e commit 19dd095

File tree

5 files changed

+177
-4
lines changed

5 files changed

+177
-4
lines changed

src/dependency_injection/container.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22

3-
from typing import Any, Dict, List, Optional, TypeVar, Type
3+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Type
44

55
from dependency_injection.registration import Registration
66
from dependency_injection.scope import DEFAULT_SCOPE_NAME, Scope
@@ -50,6 +50,11 @@ def register_singleton(self, dependency: Type, implementation: Optional[Type] =
5050
raise ValueError(f"Dependency {dependency} is already registered.")
5151
self._registrations[dependency] = Registration(dependency, implementation, Scope.SINGLETON, tags, constructor_args)
5252

53+
def register_factory(self, dependency: Type, factory: Callable[[Any], Any], factory_args: Optional[Dict[str, Any]] = None, tags: Optional[set] = None) -> None:
54+
if dependency in self._registrations:
55+
raise ValueError(f"Dependency {dependency} is already registered.")
56+
self._registrations[dependency] = Registration(dependency, None, Scope.FACTORY, None, tags, factory, factory_args)
57+
5358
def register_instance(self, dependency: Type, instance: Any, tags: Optional[set] = None) -> None:
5459
if dependency in self._registrations:
5560
raise ValueError(f"Dependency {dependency} is already registered.")
@@ -93,11 +98,15 @@ def resolve(self, dependency: Type, scope_name: str = DEFAULT_SCOPE_NAME) -> Typ
9398
)
9499
)
95100
return self._singleton_instances[dependency]
101+
elif scope == Scope.FACTORY:
102+
factory = registration.factory
103+
factory_args = registration.factory_args or {}
104+
return factory(**factory_args)
96105

97106
raise ValueError(f"Invalid dependency scope: {scope}")
98107

99108
def resolve_all(self, tags: Optional[set] = None) -> List[Any]:
100-
tags = [] if not tags else tags
109+
tags = tags or []
101110
resolved_dependencies = []
102111
for registration in self._registrations.values():
103112
if not len(tags) or tags.intersection(registration.tags):
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
from typing import Any, Dict, Optional, Type
1+
from typing import Any, Callable, Dict, Optional, Type
22

33
from dependency_injection.scope import Scope
44

55

66
class Registration():
77

8-
def __init__(self, dependency: Type, implementation: Type, scope: Scope, tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None):
8+
def __init__(self, dependency: Type, implementation: Optional[Type], scope: Scope, tags: Optional[set] = None, constructor_args: Optional[Dict[str, Any]] = None, factory: Optional[Callable[[Any], Any]] = None, factory_args: Optional[Dict[str, Any]] = None):
99
self.dependency = dependency
1010
self.implementation = implementation
1111
self.scope = scope
1212
self.tags = tags or set()
1313
self.constructor_args = constructor_args or {}
14+
self.factory = factory
15+
self.factory_args = factory_args or {}
16+
17+
if not any([self.implementation, self.factory]):
18+
raise Exception("There must be either an implementation or a factory.")

src/dependency_injection/scope.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ class Scope(Enum):
77
TRANSIENT = "transient"
88
SCOPED = "scoped"
99
SINGLETON = "singleton"
10+
FACTORY = "factory"
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import pytest
2+
3+
from dependency_injection.container import DependencyContainer
4+
from unit_test.unit_test_case import UnitTestCase
5+
6+
7+
class TestRegisterFactory(UnitTestCase):
8+
9+
def test_register_factory_succeeds_when_not_previously_registered(
10+
self,
11+
):
12+
# arrange
13+
class Vehicle:
14+
pass
15+
16+
class Car(Vehicle):
17+
pass
18+
19+
class CarFactory:
20+
@classmethod
21+
def create(cls) -> Car:
22+
return Car()
23+
24+
dependency_container = DependencyContainer.get_instance()
25+
26+
# act
27+
dependency_container.register_factory(Vehicle, factory=CarFactory.create)
28+
29+
# assert
30+
# (no exception thrown)
31+
32+
def test_register_with_factory_args(
33+
self,
34+
):
35+
# arrange
36+
class Vehicle:
37+
pass
38+
39+
class Car(Vehicle):
40+
def __init__(self, color: str, mileage: int):
41+
self.color = color
42+
self.mileage = mileage
43+
44+
class CarFactory:
45+
@classmethod
46+
def create(cls, color: str, mileage: int) -> Car:
47+
return Car(color=color, mileage=mileage)
48+
49+
dependency_container = DependencyContainer.get_instance()
50+
51+
# act + assert (no exception)
52+
dependency_container.register_factory(Vehicle, factory=CarFactory.create, factory_args={"color": "red", "mileage": 3800})
53+
54+
def test_register_instance_fails_when_already_registered(
55+
self,
56+
):
57+
# arrange
58+
class Vehicle:
59+
pass
60+
61+
class Car(Vehicle):
62+
pass
63+
64+
class CarFactory:
65+
@classmethod
66+
def create(cls) -> Car:
67+
return Car()
68+
69+
dependency_container = DependencyContainer.get_instance()
70+
71+
# act
72+
dependency_container.register_factory(Vehicle, factory=CarFactory.create)
73+
74+
# act + assert
75+
with pytest.raises(ValueError, match="is already registered"):
76+
dependency_container.register_factory(Vehicle, factory=CarFactory.create)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from dependency_injection.container import DependencyContainer
2+
from unit_test.unit_test_case import UnitTestCase
3+
4+
5+
class TestResolveTransient(UnitTestCase):
6+
7+
def test_resolve_factory_returns_an_instance(
8+
self,
9+
):
10+
# arrange
11+
class Vehicle:
12+
pass
13+
14+
class Car(Vehicle):
15+
pass
16+
17+
class CarFactory:
18+
@classmethod
19+
def create(cls) -> Car:
20+
return Car()
21+
22+
dependency_container = DependencyContainer.get_instance()
23+
dependency_container.register_factory(Vehicle, factory=CarFactory.create)
24+
25+
# act
26+
resolved_dependency = dependency_container.resolve(Vehicle)
27+
28+
# assert
29+
self.assertIsInstance(resolved_dependency, Car)
30+
31+
def test_resolve_factory_twice_returns_different_instances(
32+
self,
33+
):
34+
# arrange
35+
class Vehicle:
36+
pass
37+
38+
class Car(Vehicle):
39+
pass
40+
41+
class CarFactory:
42+
@classmethod
43+
def create(cls) -> Car:
44+
return Car()
45+
46+
dependency_container = DependencyContainer.get_instance()
47+
dependency_container.register_factory(Vehicle, factory=CarFactory.create)
48+
49+
# act
50+
resolved_dependency_1 = dependency_container.resolve(Vehicle)
51+
resolved_dependency_2 = dependency_container.resolve(Vehicle)
52+
53+
# assert
54+
self.assertNotEqual(resolved_dependency_1, resolved_dependency_2)
55+
56+
def test_resolve_factory_with_args_passes_args(
57+
self,
58+
):
59+
# arrange
60+
class Vehicle:
61+
pass
62+
63+
class Car(Vehicle):
64+
def __init__(self, color: str, mileage: int):
65+
self.color = color
66+
self.mileage = mileage
67+
68+
class CarFactory:
69+
@classmethod
70+
def create(cls, color: str, mileage: int) -> Car:
71+
return Car(color=color, mileage=mileage)
72+
73+
dependency_container = DependencyContainer.get_instance()
74+
dependency_container.register_factory(Vehicle, factory=CarFactory.create, factory_args={"color": "red", "mileage": 6327})
75+
76+
# act
77+
resolved_dependency = dependency_container.resolve(Vehicle)
78+
79+
# assert
80+
self.assertIsInstance(resolved_dependency, Car)
81+
self.assertEqual("red", resolved_dependency.color)
82+
self.assertEqual(6327, resolved_dependency.mileage)

0 commit comments

Comments
 (0)