From 6e2a6907d73c284ed86e3579db6e05b18b11bf4a Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 3 Dec 2025 17:45:58 +0900 Subject: [PATCH 1/2] Always import future annotations Signed-off-by: Anuraag Agrawal --- justfile | 2 +- pyproject.toml | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/justfile b/justfile index a13f963..3b2c3d6 100644 --- a/justfile +++ b/justfile @@ -7,7 +7,7 @@ BUF_VERSION := "v1.57.0" # Format Python files format: - uv run ruff check --fix . + uv run ruff check --fix --unsafe-fixes --exit-zero . uv run ruff format . # Lint Python files diff --git a/pyproject.toml b/pyproject.toml index 1518c9f..0e3a25e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,6 +184,8 @@ extend-ignore = [ "PLR", ] +typing-extensions = false + [tool.ruff.lint.per-file-ignores] "conformance/test/**" = ["ANN", "INP", "SLF", "SIM115", "S101", "D"] "example/**" = [ @@ -206,6 +208,7 @@ extend-ignore = [ ] [tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] split-on-trailing-comma = false [tool.ruff] From 6df76320517e4fac62f9712aada0624ec29581ed Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 3 Dec 2025 17:46:43 +0900 Subject: [PATCH 2/2] Apply fixes Signed-off-by: Anuraag Agrawal --- conformance/test/_util.py | 2 ++ conformance/test/client.py | 11 ++++++--- .../conformance/v1/service_connect.py | 18 ++++++++++---- conformance/test/server.py | 10 +++++--- conformance/test/test_client.py | 2 ++ conformance/test/test_server.py | 2 ++ example/example/_eliza.py | 7 +++++- example/example/eliza_client.py | 2 ++ example/example/eliza_client_sync.py | 2 ++ example/example/eliza_connect.py | 18 ++++++++++---- example/example/eliza_service.py | 7 ++++-- example/example/eliza_service_sync.py | 10 ++++++-- noextras/test/test_compression_default.py | 2 ++ .../scripts/generate_wheels.py | 2 ++ src/connectrpc/__init__.py | 2 ++ src/connectrpc/_asyncio_timeout.py | 8 +++++-- src/connectrpc/_client_async.py | 13 ++++++---- src/connectrpc/_client_shared.py | 21 +++++++++------- src/connectrpc/_client_sync.py | 13 ++++++---- src/connectrpc/_codec.py | 2 ++ src/connectrpc/_compression.py | 8 +++++-- src/connectrpc/_envelope.py | 13 ++++++---- src/connectrpc/_headers.py | 2 ++ src/connectrpc/_interceptor_async.py | 10 +++++--- src/connectrpc/_interceptor_sync.py | 10 +++++--- src/connectrpc/_protocol.py | 21 +++++++++------- src/connectrpc/_server_async.py | 19 ++++++++------- src/connectrpc/_server_shared.py | 24 +++++++++++-------- src/connectrpc/_server_sync.py | 4 +++- src/connectrpc/_version.py | 2 ++ src/connectrpc/client.py | 2 ++ src/connectrpc/code.py | 2 ++ src/connectrpc/errors.py | 12 +++++++--- src/connectrpc/interceptor.py | 2 ++ src/connectrpc/method.py | 2 ++ src/connectrpc/request.py | 8 +++++-- src/connectrpc/server.py | 2 ++ test/haberdasher_connect.py | 18 ++++++++++---- test/test_client.py | 2 ++ test/test_details.py | 2 ++ test/test_errors.py | 2 ++ test/test_example.py | 2 ++ test/test_headers.py | 2 ++ test/test_interceptor.py | 8 +++++-- test/test_lifespan.py | 2 ++ test/test_roundtrip.py | 7 +++++- 46 files changed, 252 insertions(+), 90 deletions(-) diff --git a/conformance/test/_util.py b/conformance/test/_util.py index 54facef..1e9d255 100644 --- a/conformance/test/_util.py +++ b/conformance/test/_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import sys diff --git a/conformance/test/client.py b/conformance/test/client.py index a25f3eb..e621496 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -1,12 +1,13 @@ +from __future__ import annotations + import argparse import asyncio import ssl import sys import time import traceback -from collections.abc import AsyncIterator, Iterator from tempfile import NamedTemporaryFile -from typing import Literal, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar import httpx from _util import create_standard_streams @@ -29,7 +30,6 @@ UnaryRequest, UnimplementedRequest, ) -from google.protobuf.any_pb2 import Any from google.protobuf.message import Message from connectrpc.client import ResponseMetadata @@ -37,6 +37,11 @@ from connectrpc.errors import ConnectError from connectrpc.request import Headers +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + + from google.protobuf.any_pb2 import Any + def _convert_code(error: Code) -> ConformanceCode: match error: diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index e7400cd..d651658 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -1,15 +1,13 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: connectrpc/conformance/v1/service.proto +from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping -from typing import Protocol +from typing import TYPE_CHECKING, Protocol from connectrpc.client import ConnectClient, ConnectClientSync from connectrpc.code import Code from connectrpc.errors import ConnectError -from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.method import IdempotencyLevel, MethodInfo -from connectrpc.request import Headers, RequestContext from connectrpc.server import ( ConnectASGIApplication, ConnectWSGIApplication, @@ -19,6 +17,18 @@ from . import service_pb2 as connectrpc_dot_conformance_dot_v1_dot_service__pb2 +if TYPE_CHECKING: + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Iterable, + Iterator, + Mapping, + ) + + from connectrpc.interceptor import Interceptor, InterceptorSync + from connectrpc.request import Headers, RequestContext + class ConformanceService(Protocol): async def unary( diff --git a/conformance/test/server.py b/conformance/test/server.py index b996341..799a58e 100644 --- a/conformance/test/server.py +++ b/conformance/test/server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import os @@ -7,7 +9,6 @@ import ssl import sys import time -from collections.abc import AsyncIterator, Iterator from contextlib import ExitStack, closing from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Literal, TypeVar, get_args @@ -43,14 +44,17 @@ from connectrpc.code import Code from connectrpc.errors import ConnectError -from connectrpc.request import RequestContext if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + from google.protobuf.message import Message + from connectrpc.request import RequestContext + # TODO: Use google.protobuf.any.pack on upgrade to protobuf==6. -def pack(msg: "Message") -> Any: +def pack(msg: Message) -> Any: any_msg = Any() any_msg.Pack(msg) return any_msg diff --git a/conformance/test/test_client.py b/conformance/test/test_client.py index 148c84e..b30586f 100644 --- a/conformance/test/test_client.py +++ b/conformance/test/test_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess import sys from pathlib import Path diff --git a/conformance/test/test_server.py b/conformance/test/test_server.py index 52ac777..459cfb9 100644 --- a/conformance/test/test_server.py +++ b/conformance/test/test_server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess import sys diff --git a/example/example/_eliza.py b/example/example/_eliza.py index 3086a38..ecce44c 100644 --- a/example/example/_eliza.py +++ b/example/example/_eliza.py @@ -1,6 +1,11 @@ +from __future__ import annotations + import random import re -from collections.abc import Sequence +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence # Ported from https://github.com/connectrpc/examples-go # Originally from https://github.com/mattshiel/eliza-go diff --git a/example/example/eliza_client.py b/example/example/eliza_client.py index 428b080..fccdccf 100644 --- a/example/example/eliza_client.py +++ b/example/example/eliza_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from example.eliza_connect import ElizaServiceClient diff --git a/example/example/eliza_client_sync.py b/example/example/eliza_client_sync.py index d4b0af9..e684faa 100644 --- a/example/example/eliza_client_sync.py +++ b/example/example/eliza_client_sync.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from example.eliza_connect import ElizaServiceClientSync from example.eliza_pb2 import IntroduceRequest, SayRequest diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 90eb71b..02e3750 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -1,15 +1,13 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: example/eliza.proto +from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping -from typing import Protocol +from typing import TYPE_CHECKING, Protocol from connectrpc.client import ConnectClient, ConnectClientSync from connectrpc.code import Code from connectrpc.errors import ConnectError -from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.method import IdempotencyLevel, MethodInfo -from connectrpc.request import Headers, RequestContext from connectrpc.server import ( ConnectASGIApplication, ConnectWSGIApplication, @@ -19,6 +17,18 @@ import example.eliza_pb2 as example_dot_eliza__pb2 +if TYPE_CHECKING: + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Iterable, + Iterator, + Mapping, + ) + + from connectrpc.interceptor import Interceptor, InterceptorSync + from connectrpc.request import Headers, RequestContext + class ElizaService(Protocol): async def say( diff --git a/example/example/eliza_service.py b/example/example/eliza_service.py index 3b0b64c..67ac26a 100644 --- a/example/example/eliza_service.py +++ b/example/example/eliza_service.py @@ -1,8 +1,8 @@ +from __future__ import annotations + import asyncio -from collections.abc import AsyncIterator from typing import TYPE_CHECKING, cast -from connectrpc.request import RequestContext from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.routing import Mount, Route @@ -19,6 +19,9 @@ ) if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from connectrpc.request import RequestContext from starlette.types import ASGIApp diff --git a/example/example/eliza_service_sync.py b/example/example/eliza_service_sync.py index e5777f1..2477bee 100644 --- a/example/example/eliza_service_sync.py +++ b/example/example/eliza_service_sync.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import time -from collections.abc import Iterator +from typing import TYPE_CHECKING -from connectrpc.request import RequestContext from flask import Flask from werkzeug.middleware.dispatcher import DispatcherMiddleware @@ -17,6 +18,11 @@ from . import _eliza from .eliza_connect import ElizaServiceSync, ElizaServiceWSGIApplication +if TYPE_CHECKING: + from collections.abc import Iterator + + from connectrpc.request import RequestContext + class DemoElizaServiceSync(ElizaServiceSync): stream_delay_secs: float diff --git a/noextras/test/test_compression_default.py b/noextras/test/test_compression_default.py index dfa3cf6..fbcc09c 100644 --- a/noextras/test/test_compression_default.py +++ b/noextras/test/test_compression_default.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from example.eliza_connect import ( ElizaService, diff --git a/protoc-gen-connect-python/scripts/generate_wheels.py b/protoc-gen-connect-python/scripts/generate_wheels.py index 0f0e57b..8076157 100644 --- a/protoc-gen-connect-python/scripts/generate_wheels.py +++ b/protoc-gen-connect-python/scripts/generate_wheels.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import shutil import subprocess diff --git a/src/connectrpc/__init__.py b/src/connectrpc/__init__.py index 92ad3cc..31350fb 100644 --- a/src/connectrpc/__init__.py +++ b/src/connectrpc/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = ["__version__"] diff --git a/src/connectrpc/_asyncio_timeout.py b/src/connectrpc/_asyncio_timeout.py index 53ac78d..5f60385 100644 --- a/src/connectrpc/_asyncio_timeout.py +++ b/src/connectrpc/_asyncio_timeout.py @@ -2,11 +2,15 @@ # SPDX-License-Identifier: PSF-2.0 # Backport of asyncio.timeout for Python 3.10 +from __future__ import annotations import enum import sys from asyncio import events, exceptions, tasks -from types import TracebackType +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import TracebackType _HAX_EXCEPTION_GROUP = sys.version_info >= (3, 11) @@ -80,7 +84,7 @@ def __repr__(self) -> str: info_str = " ".join(info) return f"" - async def __aenter__(self) -> "Timeout": + async def __aenter__(self) -> Timeout: if self._state is not _State.CREATED: msg = "Timeout has already been entered" raise RuntimeError(msg) diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index e445785..550ae2c 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -1,8 +1,8 @@ +from __future__ import annotations + import asyncio import functools from asyncio import CancelledError, sleep, wait_for -from collections.abc import AsyncIterator, Iterable, Mapping -from types import TracebackType from typing import TYPE_CHECKING, Any, Protocol, TypeVar import httpx @@ -11,7 +11,6 @@ from . import _client_shared from ._asyncio_timeout import timeout as asyncio_timeout from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec -from ._compression import Compression from ._envelope import EnvelopeReader, EnvelopeWriter from ._interceptor_async import ( BidiStreamInterceptor, @@ -24,8 +23,6 @@ from ._protocol import CONNECT_STREAMING_HEADER_COMPRESSION, ConnectWireError from .code import Code from .errors import ConnectError -from .method import MethodInfo -from .request import Headers, RequestContext try: from asyncio import ( @@ -36,6 +33,12 @@ if TYPE_CHECKING: import sys + from collections.abc import AsyncIterator, Iterable, Mapping + from types import TracebackType + + from ._compression import Compression + from .method import MethodInfo + from .request import Headers, RequestContext if sys.version_info >= (3, 11): from typing import Self diff --git a/src/connectrpc/_client_shared.py b/src/connectrpc/_client_shared.py index b20cf6c..2feef32 100644 --- a/src/connectrpc/_client_shared.py +++ b/src/connectrpc/_client_shared.py @@ -1,12 +1,10 @@ +from __future__ import annotations + import base64 import contextlib -from collections.abc import Iterable, Mapping, Sequence from contextvars import ContextVar, Token from http import HTTPStatus -from types import TracebackType -from typing import TypeVar - -from httpx import Headers as HttpxHeaders +from typing import TYPE_CHECKING, TypeVar from . import _compression from ._codec import CODEC_NAME_JSON, CODEC_NAME_JSON_CHARSET_UTF8, Codec @@ -25,9 +23,16 @@ from ._version import __version__ from .code import Code from .errors import ConnectError -from .method import MethodInfo from .request import Headers, RequestContext +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping, Sequence + from types import TracebackType + + from httpx import Headers as HttpxHeaders + + from .method import MethodInfo + _DEFAULT_CONNECT_USER_AGENT = f"connectrpc/{__version__}" REQ = TypeVar("REQ") @@ -247,9 +252,9 @@ class ResponseMetadata: _headers: Headers | None = None _trailers: Headers | None = None - _token: Token["ResponseMetadata"] | None = None + _token: Token[ResponseMetadata] | None = None - def __enter__(self) -> "ResponseMetadata": + def __enter__(self) -> ResponseMetadata: self._token = _current_response.set(self) return self diff --git a/src/connectrpc/_client_sync.py b/src/connectrpc/_client_sync.py index 7e414a8..b37f86f 100644 --- a/src/connectrpc/_client_sync.py +++ b/src/connectrpc/_client_sync.py @@ -1,6 +1,6 @@ +from __future__ import annotations + import functools -from collections.abc import Iterable, Iterator, Mapping -from types import TracebackType from typing import TYPE_CHECKING, Any, Protocol, TypeVar import httpx @@ -8,7 +8,6 @@ from . import _client_shared from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec -from ._compression import Compression from ._envelope import EnvelopeReader, EnvelopeWriter from ._interceptor_sync import ( BidiStreamInterceptorSync, @@ -21,11 +20,15 @@ from ._protocol import CONNECT_STREAMING_HEADER_COMPRESSION, ConnectWireError from .code import Code from .errors import ConnectError -from .method import MethodInfo -from .request import Headers, RequestContext if TYPE_CHECKING: import sys + from collections.abc import Iterable, Iterator, Mapping + from types import TracebackType + + from ._compression import Compression + from .method import MethodInfo + from .request import Headers, RequestContext if sys.version_info >= (3, 11): from typing import Self diff --git a/src/connectrpc/_codec.py b/src/connectrpc/_codec.py index 5f30614..10167cb 100644 --- a/src/connectrpc/_codec.py +++ b/src/connectrpc/_codec.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Protocol, TypeVar from google.protobuf.json_format import MessageToJson diff --git a/src/connectrpc/_compression.py b/src/connectrpc/_compression.py index d863aee..090f043 100644 --- a/src/connectrpc/_compression.py +++ b/src/connectrpc/_compression.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import gzip -from collections.abc import KeysView -from typing import Protocol +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from collections.abc import KeysView class Compression(Protocol): diff --git a/src/connectrpc/_envelope.py b/src/connectrpc/_envelope.py index 3628b83..175b0c5 100644 --- a/src/connectrpc/_envelope.py +++ b/src/connectrpc/_envelope.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import json import struct -from collections.abc import Iterator -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from ._client_shared import handle_response_trailers -from ._codec import Codec from ._compression import Compression, IdentityCompression from ._protocol import ConnectWireError from .code import Code from .errors import ConnectError -from .request import Headers + +if TYPE_CHECKING: + from collections.abc import Iterator + + from ._codec import Codec + from .request import Headers _RES = TypeVar("_RES") _T = TypeVar("_T") diff --git a/src/connectrpc/_headers.py b/src/connectrpc/_headers.py index e962989..9ffb7a1 100644 --- a/src/connectrpc/_headers.py +++ b/src/connectrpc/_headers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import ( ItemsView, Iterator, diff --git a/src/connectrpc/_interceptor_async.py b/src/connectrpc/_interceptor_async.py index 9d77335..5d8e372 100644 --- a/src/connectrpc/_interceptor_async.py +++ b/src/connectrpc/_interceptor_async.py @@ -1,7 +1,11 @@ -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence -from typing import Generic, Protocol, TypeVar, runtime_checkable +from __future__ import annotations -from .request import RequestContext +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence + + from .request import RequestContext REQ = TypeVar("REQ") RES = TypeVar("RES") diff --git a/src/connectrpc/_interceptor_sync.py b/src/connectrpc/_interceptor_sync.py index fac90eb..7b0ca8a 100644 --- a/src/connectrpc/_interceptor_sync.py +++ b/src/connectrpc/_interceptor_sync.py @@ -1,7 +1,11 @@ -from collections.abc import Callable, Iterable, Iterator, Sequence -from typing import Generic, Protocol, TypeVar, runtime_checkable +from __future__ import annotations -from .request import RequestContext +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator, Sequence + + from .request import RequestContext REQ = TypeVar("REQ") RES = TypeVar("RES") diff --git a/src/connectrpc/_protocol.py b/src/connectrpc/_protocol.py index a2c6015..0136d35 100644 --- a/src/connectrpc/_protocol.py +++ b/src/connectrpc/_protocol.py @@ -1,16 +1,21 @@ +from __future__ import annotations + import json from base64 import b64decode, b64encode -from collections.abc import Sequence from dataclasses import dataclass from http import HTTPStatus -from typing import cast +from typing import TYPE_CHECKING, cast -import httpx from google.protobuf.any_pb2 import Any from .code import Code from .errors import ConnectError +if TYPE_CHECKING: + from collections.abc import Sequence + + import httpx + CONNECT_HEADER_PROTOCOL_VERSION = "connect-protocol-version" CONNECT_PROTOCOL_VERSION = "1" CONNECT_UNARY_CONTENT_TYPE_PREFIX = "application/" @@ -31,7 +36,7 @@ class ExtendedHTTPStatus: reason: str @staticmethod - def from_http_status(status: HTTPStatus) -> "ExtendedHTTPStatus": + def from_http_status(status: HTTPStatus) -> ExtendedHTTPStatus: return ExtendedHTTPStatus(code=status.value, reason=status.phrase) @@ -87,13 +92,13 @@ class ConnectWireError: details: Sequence[Any] @staticmethod - def from_exception(exc: Exception) -> "ConnectWireError": + def from_exception(exc: Exception) -> ConnectWireError: if isinstance(exc, ConnectError): return ConnectWireError(exc.code, exc.message, exc.details) return ConnectWireError(Code.UNKNOWN, str(exc), details=()) @staticmethod - def from_response(response: httpx.Response) -> "ConnectWireError": + def from_response(response: httpx.Response) -> ConnectWireError: try: data = response.json() except Exception: @@ -107,7 +112,7 @@ def from_response(response: httpx.Response) -> "ConnectWireError": @staticmethod def from_dict( data: dict, http_status: int, unexpected_code: Code - ) -> "ConnectWireError": + ) -> ConnectWireError: code_str = data.get("code") if code_str: try: @@ -136,7 +141,7 @@ def from_dict( return ConnectWireError(code, message, details) @staticmethod - def from_http_status(status_code: int) -> "ConnectWireError": + def from_http_status(status_code: int) -> ConnectWireError: code = _http_status_code_to_error.get(status_code, Code.UNKNOWN) try: http_status = HTTPStatus(status_code) diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 176841d..1390951 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -1,16 +1,10 @@ +from __future__ import annotations + import base64 import functools import inspect from abc import ABC, abstractmethod from asyncio import CancelledError, sleep -from collections.abc import ( - AsyncGenerator, - AsyncIterator, - Callable, - Iterable, - Mapping, - Sequence, -) from dataclasses import replace from http import HTTPStatus from typing import TYPE_CHECKING, Generic, TypeVar, cast @@ -48,6 +42,15 @@ if TYPE_CHECKING: # We don't use asgiref code so only import from it for type checking + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Callable, + Iterable, + Mapping, + Sequence, + ) + from asgiref.typing import ASGIReceiveCallable, ASGISendCallable, HTTPScope, Scope else: ASGIReceiveCallable = "asgiref.typing.ASGIReceiveCallable" diff --git a/src/connectrpc/_server_shared.py b/src/connectrpc/_server_shared.py index 3500426..edf5504 100644 --- a/src/connectrpc/_server_shared.py +++ b/src/connectrpc/_server_shared.py @@ -1,7 +1,8 @@ -from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from __future__ import annotations + from dataclasses import dataclass from http import HTTPStatus -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from ._protocol import ( CONNECT_HEADER_PROTOCOL_VERSION, @@ -14,6 +15,9 @@ from .method import IdempotencyLevel, MethodInfo from .request import Headers, RequestContext +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Awaitable, Callable, Iterator + REQ = TypeVar("REQ") RES = TypeVar("RES") T = TypeVar("T") @@ -35,28 +39,28 @@ class Endpoint(Generic[REQ, RES]): def unary( method: MethodInfo[T, U], function: Callable[[T, RequestContext[T, U]], Awaitable[U]], - ) -> "EndpointUnary[T, U]": + ) -> EndpointUnary[T, U]: return EndpointUnary(method=method, function=function) @staticmethod def client_stream( method: MethodInfo[T, U], function: Callable[[AsyncIterator[T], RequestContext[T, U]], Awaitable[U]], - ) -> "EndpointClientStream[T, U]": + ) -> EndpointClientStream[T, U]: return EndpointClientStream(method=method, function=function) @staticmethod def server_stream( method: MethodInfo[T, U], function: Callable[[T, RequestContext[T, U]], AsyncIterator[U]], - ) -> "EndpointServerStream[T, U]": + ) -> EndpointServerStream[T, U]: return EndpointServerStream(method=method, function=function) @staticmethod def bidi_stream( method: MethodInfo[T, U], function: Callable[[AsyncIterator[T], RequestContext[T, U]], AsyncIterator[U]], - ) -> "EndpointBidiStream[T, U]": + ) -> EndpointBidiStream[T, U]: return EndpointBidiStream(method=method, function=function) @@ -96,7 +100,7 @@ class EndpointSync(Generic[REQ, RES]): @staticmethod def unary( *, method: MethodInfo[T, U], function: Callable[[T, RequestContext[T, U]], U] - ) -> "EndpointUnarySync[T, U]": + ) -> EndpointUnarySync[T, U]: return EndpointUnarySync(method=method, function=function) @staticmethod @@ -104,7 +108,7 @@ def client_stream( *, method: MethodInfo[T, U], function: Callable[[Iterator[T], RequestContext[T, U]], U], - ) -> "EndpointClientStreamSync[T, U]": + ) -> EndpointClientStreamSync[T, U]: return EndpointClientStreamSync(method=method, function=function) @staticmethod @@ -112,14 +116,14 @@ def server_stream( *, method: MethodInfo[T, U], function: Callable[[T, RequestContext[T, U]], Iterator[U]], - ) -> "EndpointServerStreamSync[T, U]": + ) -> EndpointServerStreamSync[T, U]: return EndpointServerStreamSync(method=method, function=function) @staticmethod def bidi_stream( method: MethodInfo[T, U], function: Callable[[Iterator[T], RequestContext[T, U]], Iterator[U]], - ) -> "EndpointBidiStreamSync[T, U]": + ) -> EndpointBidiStreamSync[T, U]: return EndpointBidiStreamSync(method=method, function=function) diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index c7aa3c8..9beba1a 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import base64 import functools from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import replace from http import HTTPStatus from typing import TYPE_CHECKING, TypeVar @@ -40,6 +41,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable, Iterator, Mapping, Sequence from io import BytesIO if sys.version_info >= (3, 11): diff --git a/src/connectrpc/_version.py b/src/connectrpc/_version.py index ba5d9db..3b4962e 100644 --- a/src/connectrpc/_version.py +++ b/src/connectrpc/_version.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from importlib.metadata import version __version__ = version("connect-python") diff --git a/src/connectrpc/client.py b/src/connectrpc/client.py index 85087b6..a727b63 100644 --- a/src/connectrpc/client.py +++ b/src/connectrpc/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = ["ConnectClient", "ConnectClientSync", "ResponseMetadata"] diff --git a/src/connectrpc/code.py b/src/connectrpc/code.py index 922c5cd..55d9872 100644 --- a/src/connectrpc/code.py +++ b/src/connectrpc/code.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = ["Code"] diff --git a/src/connectrpc/errors.py b/src/connectrpc/errors.py index dfb0b42..ae12caf 100644 --- a/src/connectrpc/errors.py +++ b/src/connectrpc/errors.py @@ -1,12 +1,18 @@ +from __future__ import annotations + __all__ = ["ConnectError"] -from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING from google.protobuf.any_pb2 import Any -from google.protobuf.message import Message -from .code import Code +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from google.protobuf.message import Message + + from .code import Code class ConnectError(Exception): diff --git a/src/connectrpc/interceptor.py b/src/connectrpc/interceptor.py index ba6fbb0..bb961b4 100644 --- a/src/connectrpc/interceptor.py +++ b/src/connectrpc/interceptor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = [ "BidiStreamInterceptor", "BidiStreamInterceptorSync", diff --git a/src/connectrpc/method.py b/src/connectrpc/method.py index 10c169f..fcbca6e 100644 --- a/src/connectrpc/method.py +++ b/src/connectrpc/method.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = ["IdempotencyLevel", "MethodInfo"] diff --git a/src/connectrpc/request.py b/src/connectrpc/request.py index ed50af1..02b92ca 100644 --- a/src/connectrpc/request.py +++ b/src/connectrpc/request.py @@ -1,11 +1,15 @@ +from __future__ import annotations + __all__ = ["Headers", "RequestContext"] import time -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from ._headers import Headers -from .method import MethodInfo + +if TYPE_CHECKING: + from .method import MethodInfo REQ = TypeVar("REQ") RES = TypeVar("RES") diff --git a/src/connectrpc/server.py b/src/connectrpc/server.py index dc7fb37..29881a8 100644 --- a/src/connectrpc/server.py +++ b/src/connectrpc/server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = [ "ConnectASGIApplication", "ConnectWSGIApplication", diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index 34a8537..d83c143 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -1,17 +1,15 @@ # Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT! # source: haberdasher.proto +from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping -from typing import Protocol +from typing import TYPE_CHECKING, Protocol import google.protobuf.empty_pb2 as google_dot_protobuf_dot_empty__pb2 from connectrpc.client import ConnectClient, ConnectClientSync from connectrpc.code import Code from connectrpc.errors import ConnectError -from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.method import IdempotencyLevel, MethodInfo -from connectrpc.request import Headers, RequestContext from connectrpc.server import ( ConnectASGIApplication, ConnectWSGIApplication, @@ -21,6 +19,18 @@ from . import haberdasher_pb2 as haberdasher__pb2 +if TYPE_CHECKING: + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Iterable, + Iterator, + Mapping, + ) + + from connectrpc.interceptor import Interceptor, InterceptorSync + from connectrpc.request import Headers, RequestContext + class Haberdasher(Protocol): async def make_hat( diff --git a/test/test_client.py b/test/test_client.py index 4c56e81..c2df9f5 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from httpx import ASGITransport, AsyncClient, Client, WSGITransport diff --git a/test/test_details.py b/test/test_details.py index 200504b..83a9b67 100644 --- a/test/test_details.py +++ b/test/test_details.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import NoReturn import pytest diff --git a/test/test_errors.py b/test/test_errors.py index 9ba218d..9b48253 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading import time from http import HTTPStatus diff --git a/test/test_example.py b/test/test_example.py index 8ddcdee..c3a4193 100644 --- a/test/test_example.py +++ b/test/test_example.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading from wsgiref.simple_server import WSGIServer, make_server diff --git a/test/test_headers.py b/test/test_headers.py index ae2c7fe..c2f5edc 100644 --- a/test/test_headers.py +++ b/test/test_headers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from connectrpc.request import Headers diff --git a/test/test_interceptor.py b/test/test_interceptor.py index 8b2e6ae..dfd1a11 100644 --- a/test/test_interceptor.py +++ b/test/test_interceptor.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import itertools +from typing import TYPE_CHECKING import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient, Client, WSGITransport -from connectrpc.request import RequestContext - from .haberdasher_connect import ( Haberdasher, HaberdasherASGIApplication, @@ -16,6 +17,9 @@ ) from .haberdasher_pb2 import Hat, Size +if TYPE_CHECKING: + from connectrpc.request import RequestContext + class RequestInterceptor: def __init__(self) -> None: diff --git a/test/test_lifespan.py b/test/test_lifespan.py index b45dd45..90ba14b 100644 --- a/test/test_lifespan.py +++ b/test/test_lifespan.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from collections import Counter from io import StringIO diff --git a/test/test_roundtrip.py b/test/test_roundtrip.py index 27589ce..67f4501 100644 --- a/test/test_roundtrip.py +++ b/test/test_roundtrip.py @@ -1,4 +1,6 @@ -from collections.abc import AsyncIterator, Iterator +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest from httpx import ASGITransport, AsyncClient, Client, WSGITransport @@ -16,6 +18,9 @@ ) from .haberdasher_pb2 import Hat, Size +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Iterator + @pytest.mark.parametrize("proto_json", [False, True]) @pytest.mark.parametrize("compression", ["gzip", "br", "zstd", "identity", None])