Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/openfeature-provider-flagd/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dev = [
"pytest>=8.4.0,<9.0.0",
"pytest-bdd>=8.1.0,<9.0.0",
"testcontainers>=4.12.0,<5.0.0",
"types-grpcio>=1.0.0,<2.0.0",
"types-protobuf>=6.30.0,<7.0.0",
"types-pyyaml>=6.0.0,<7.0.0",
]
Expand Down Expand Up @@ -97,7 +98,6 @@ disallow_any_generics = false

[[tool.mypy.overrides]]
module = [
"grpc.*",
"json_logic.*",
]
ignore_missing_imports = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from ..config import CacheType, Config
from ..flag_type import FlagType
from .types import GrpcMultiCallableArgs

if typing.TYPE_CHECKING:
from google.protobuf.message import Message
Expand Down Expand Up @@ -121,15 +122,16 @@ def _generate_channel(self, config: Config) -> grpc.Channel:
),
]
if config.tls:
channel_args = {
"options": options,
"credentials": grpc.ssl_channel_credentials(),
}
credentials = grpc.ssl_channel_credentials()
if config.cert_path:
with open(config.cert_path, "rb") as f:
channel_args["credentials"] = grpc.ssl_channel_credentials(f.read())
credentials = grpc.ssl_channel_credentials(f.read())

channel = grpc.secure_channel(target, **channel_args)
channel = grpc.secure_channel(
target,
credentials=credentials,
options=options,
)

else:
channel = grpc.insecure_channel(
Expand Down Expand Up @@ -220,20 +222,16 @@ def emit_error(self) -> None:

def listen(self) -> None:
logger.debug("gRPC starting listener thread")
call_args = (
{"timeout": self.streamline_deadline_seconds}
if self.streamline_deadline_seconds > 0
else {}
)
call_args: GrpcMultiCallableArgs = {"wait_for_ready": True}
if self.streamline_deadline_seconds > 0:
call_args["timeout"] = self.streamline_deadline_seconds
request = evaluation_pb2.EventStreamRequest()

# defining a never ending loop to recreate the stream
while self.active:
try:
logger.debug("Setting up gRPC sync flags connection")
for message in self.stub.EventStream(
request, wait_for_ready=True, **call_args
):
for message in self.stub.EventStream(request, **call_args):
if message.type == "provider_ready":
self.emit_provider_ready(
ProviderEventDetails(
Expand Down Expand Up @@ -309,20 +307,72 @@ def resolve_object_details(
]:
return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)

@typing.overload
def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[bool]: ...

@typing.overload
def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[int]: ...

@typing.overload
def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[float]: ...

@typing.overload
def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[str]: ...

@typing.overload
def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: typing.Union[
typing.Sequence[FlagValueType], typing.Mapping[str, FlagValueType]
],
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[
typing.Union[typing.Sequence[FlagValueType], typing.Mapping[str, FlagValueType]]
]: ...

def _resolve( # noqa: PLR0915 C901
self,
flag_key: str,
flag_type: FlagType,
default_value: T,
default_value: FlagValueType,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[T]:
) -> FlagResolutionDetails[FlagValueType]:
if self.cache is not None and flag_key in self.cache:
cached_flag: FlagResolutionDetails[T] = self.cache[flag_key]
cached_flag: FlagResolutionDetails[FlagValueType] = self.cache[flag_key]
cached_flag.reason = Reason.CACHED
return cached_flag

context = self._convert_context(evaluation_context)
call_args = {"timeout": self.deadline, "wait_for_ready": True}
call_args: GrpcMultiCallableArgs = {
"timeout": self.deadline,
"wait_for_ready": True,
}
try:
request: Message
if flag_type == FlagType.BOOLEAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)

from ....config import Config
from ...types import GrpcMultiCallableArgs
from ..connector import FlagStateConnector
from ..flags import FlagStore

Expand Down Expand Up @@ -105,22 +106,23 @@ def _generate_channel(self, config: Config) -> grpc.Channel:
options.append(("grpc.default_authority", config.default_authority))

if config.channel_credentials is not None:
channel_args = {
"options": options,
"credentials": config.channel_credentials,
}
channel = grpc.secure_channel(target, **channel_args)
channel = grpc.secure_channel(
target,
credentials=config.channel_credentials,
options=options,
)

elif config.tls:
channel_args = {
"options": options,
"credentials": grpc.ssl_channel_credentials(),
}
credentials = grpc.ssl_channel_credentials()
if config.cert_path:
with open(config.cert_path, "rb") as f:
channel_args["credentials"] = grpc.ssl_channel_credentials(f.read())
credentials = grpc.ssl_channel_credentials(f.read())

channel = grpc.secure_channel(target, **channel_args)
channel = grpc.secure_channel(
target,
credentials=credentials,
options=options,
)

else:
channel = grpc.insecure_channel(
Expand Down Expand Up @@ -227,12 +229,10 @@ def _fetch_metadata(self) -> typing.Optional[sync_pb2.GetMetadataResponse]:
else:
raise e

def listen(self) -> None:
call_args = (
{"timeout": self.streamline_deadline_seconds}
if self.streamline_deadline_seconds > 0
else {}
)
def listen(self) -> None: # noqa: C901
call_args: GrpcMultiCallableArgs = {"wait_for_ready": True}
if self.streamline_deadline_seconds > 0:
call_args["timeout"] = self.streamline_deadline_seconds
request_args = self._create_request_args()

while self.active:
Expand All @@ -242,9 +242,7 @@ def listen(self) -> None:
request = sync_pb2.SyncFlagsRequest(**request_args)

logger.debug("Setting up gRPC sync flags connection")
for flag_rsp in self.stub.SyncFlags(
request, wait_for_ready=True, **call_args
):
for flag_rsp in self.stub.SyncFlags(request, **call_args):
flag_str = flag_rsp.flag_configuration
logger.debug(
f"Received flag configuration - {abs(hash(flag_str)) % (10**8)}"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import typing


class GrpcMultiCallableArgs(typing.TypedDict, total=False):
timeout: typing.Optional[float]
wait_for_ready: typing.Optional[bool]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
[dependency-groups]
dev = [
"coverage[toml]>=7.10.0,<8.0.0",
"mypy[faster-cache]>=1.17.0,<2.0.0",
"mypy>=1.18.0,<2.0.0",
"pytest>=8.4.0,<9.0.0",
]

Expand Down
Loading