Skip to content

Commit 9e3ec03

Browse files
authored
feat: Add headers argument to allow setting additional user headers (#601)
## Description This PR adds a `headers` argument to `ApifyClient` and `ApifyClientAsync` that lets you set custom headers on the underlying HTTP client. These headers are sent with every request and take precedence over the client's default headers. This solves the problem described in #416 (comment). ## Testing Tests that check the custom headers have been added.
1 parent 25ff4e5 commit 9e3ec03

File tree

3 files changed

+186
-6
lines changed

3 files changed

+186
-6
lines changed

src/apify_client/_http_client.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,32 @@ def __init__(
3838
min_delay_between_retries_millis: int = 500,
3939
timeout_secs: int = 360,
4040
stats: Statistics | None = None,
41+
headers: dict | None = None,
4142
) -> None:
4243
self.max_retries = max_retries
4344
self.min_delay_between_retries_millis = min_delay_between_retries_millis
4445
self.timeout_secs = timeout_secs
4546

46-
headers = {'Accept': 'application/json, */*'}
47+
default_headers = {'Accept': 'application/json, */*'}
4748

4849
workflow_key = os.getenv('APIFY_WORKFLOW_KEY')
4950
if workflow_key is not None:
50-
headers['X-Apify-Workflow-Key'] = workflow_key
51+
default_headers['X-Apify-Workflow-Key'] = workflow_key
5152

5253
is_at_home = 'APIFY_IS_AT_HOME' in os.environ
5354
python_version = '.'.join([str(x) for x in sys.version_info[:3]])
5455
client_version = metadata.version('apify-client')
5556

5657
user_agent = f'ApifyClient/{client_version} ({sys.platform}; Python/{python_version}); isAtHome/{is_at_home}'
57-
headers['User-Agent'] = user_agent
58+
default_headers['User-Agent'] = user_agent
5859

5960
if token is not None:
60-
headers['Authorization'] = f'Bearer {token}'
61+
default_headers['Authorization'] = f'Bearer {token}'
62+
63+
init_headers = {**default_headers, **(headers or {})}
6164

62-
self.impit_client = impit.Client(headers=headers, follow_redirects=True, timeout=timeout_secs)
63-
self.impit_async_client = impit.AsyncClient(headers=headers, follow_redirects=True, timeout=timeout_secs)
65+
self.impit_client = impit.Client(headers=init_headers, follow_redirects=True, timeout=timeout_secs)
66+
self.impit_async_client = impit.AsyncClient(headers=init_headers, follow_redirects=True, timeout=timeout_secs)
6467

6568
self.stats = stats or Statistics()
6669

src/apify_client/client.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import warnings
4+
35
from apify_client._http_client import HTTPClient, HTTPClientAsync
46
from apify_client._statistics import Statistics
57
from apify_client.clients import (
@@ -98,6 +100,18 @@ def _options(self) -> dict:
98100
'http_client': self.http_client,
99101
}
100102

103+
def _check_custom_headers(self, headers: dict) -> None:
    """Warn when user-supplied headers replace headers the client sets by default.

    Header names are compared case-insensitively. Nothing is modified or rejected here; the method only
    emits a `UserWarning` naming the overridden headers.

    Args:
        headers: Custom headers passed to the client constructor.
    """
    # A sorted tuple rather than a set, so the list printed in the warning is identical on every run
    # (set iteration order of strings changes with hash randomization).
    default_headers = ('Accept', 'Accept-Encoding', 'Authorization', 'User-Agent')
    protected_names = {name.lower() for name in default_headers}

    # Keep the user's own spelling and order in the message so they can find the keys in their code.
    overwrite_headers = [key for key in headers if key.lower() in protected_names]
    if not overwrite_headers:
        return

    warnings.warn(
        f'{", ".join(overwrite_headers)} headers of {self.__class__.__name__} was overridden with an '
        'explicit value. A wrong header value can lead to API errors, it is recommended to use the default '
        f'value for following headers: {", ".join(default_headers)}.',
        category=UserWarning,
        # This helper is called from the client constructors, so level 3 attributes the warning to the
        # code that instantiated the client instead of to the constructor itself.
        stacklevel=3,
    )
114+
101115

102116
class ApifyClient(_BaseApifyClient):
103117
"""The Apify API client."""
@@ -113,6 +127,7 @@ def __init__(
113127
max_retries: int | None = 8,
114128
min_delay_between_retries_millis: int | None = 500,
115129
timeout_secs: int | None = DEFAULT_TIMEOUT,
130+
headers: dict | None = None,
116131
) -> None:
117132
"""Initialize a new instance.
118133
@@ -126,6 +141,7 @@ def __init__(
126141
min_delay_between_retries_millis: How long will the client wait between retrying requests
127142
(increases exponentially from this value).
128143
timeout_secs: The socket timeout of the HTTP requests sent to the Apify API.
144+
headers: Set headers to client for all requests.
129145
"""
130146
super().__init__(
131147
token,
@@ -137,12 +153,17 @@ def __init__(
137153
)
138154

139155
self.stats = Statistics()
156+
157+
if headers:
158+
self._check_custom_headers(headers)
159+
140160
self.http_client = HTTPClient(
141161
token=token,
142162
max_retries=self.max_retries,
143163
min_delay_between_retries_millis=self.min_delay_between_retries_millis,
144164
timeout_secs=self.timeout_secs,
145165
stats=self.stats,
166+
headers=headers,
146167
)
147168

148169
def actor(self, actor_id: str) -> ActorClient:
@@ -301,6 +322,7 @@ def __init__(
301322
max_retries: int | None = 8,
302323
min_delay_between_retries_millis: int | None = 500,
303324
timeout_secs: int | None = DEFAULT_TIMEOUT,
325+
headers: dict | None = None,
304326
) -> None:
305327
"""Initialize a new instance.
306328
@@ -314,6 +336,7 @@ def __init__(
314336
min_delay_between_retries_millis: How long will the client wait between retrying requests
315337
(increases exponentially from this value).
316338
timeout_secs: The socket timeout of the HTTP requests sent to the Apify API.
339+
headers: Set headers to client for all requests.
317340
"""
318341
super().__init__(
319342
token,
@@ -325,12 +348,17 @@ def __init__(
325348
)
326349

327350
self.stats = Statistics()
351+
352+
if headers:
353+
self._check_custom_headers(headers)
354+
328355
self.http_client = HTTPClientAsync(
329356
token=token,
330357
max_retries=self.max_retries,
331358
min_delay_between_retries_millis=self.min_delay_between_retries_millis,
332359
timeout_secs=self.timeout_secs,
333360
stats=self.stats,
361+
headers=headers,
334362
)
335363

336364
def actor(self, actor_id: str) -> ActorClientAsync:

tests/unit/test_client_headers.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import os
5+
import sys
6+
from importlib import metadata
7+
from typing import TYPE_CHECKING
8+
9+
import pytest
10+
from werkzeug import Request, Response
11+
12+
from apify_client import ApifyClient, ApifyClientAsync
13+
from apify_client._http_client import HTTPClient, HTTPClientAsync
14+
15+
if TYPE_CHECKING:
16+
from pytest_httpserver import HTTPServer
17+
18+
19+
def _header_handler(request: Request) -> Response:
    """Echo the headers of the incoming request back to the caller as JSON.

    The response body is a JSON object with a single `received_headers` key, which lets a test inspect
    the headers the client actually sent.
    """
    payload = {'received_headers': dict(request.headers)}
    body = json.dumps(payload)
    return Response(response=body, status=200, headers={})
25+
26+
27+
def _get_user_agent() -> str:
    """Build the User-Agent value the tests expect the client to send by default."""
    version_parts = map(str, sys.version_info[:3])
    python_version = '.'.join(version_parts)
    client_version = metadata.version('apify-client')
    is_at_home = 'APIFY_IS_AT_HOME' in os.environ
    return f'ApifyClient/{client_version} ({sys.platform}; Python/{python_version}); isAtHome/{is_at_home}'
32+
33+
34+
async def test_default_headers_async(httpserver: HTTPServer) -> None:
    """Check that the async HTTP client sends exactly the default headers when none are customized."""
    http_client = HTTPClientAsync(token='placeholder_token')
    httpserver.expect_request('/').respond_with_handler(_header_handler)
    base_url = httpserver.url_for('/').removesuffix('/')

    reply = await http_client.call(method='GET', url=f'{base_url}/')
    received = json.loads(reply.text)['received_headers']

    # `Accept-Encoding` and `Host` are not set by HTTPClientAsync's default headers, yet the server
    # receives them, so they are part of the expected set.
    expected = {
        'User-Agent': _get_user_agent(),
        'Accept': 'application/json, */*',
        'Authorization': 'Bearer placeholder_token',
        'Accept-Encoding': 'gzip, br, zstd, deflate',
        'Host': f'{httpserver.host}:{httpserver.port}',
    }
    assert received == expected
52+
53+
54+
def test_default_headers_sync(httpserver: HTTPServer) -> None:
    """Check that the sync HTTP client sends exactly the default headers when none are customized."""
    http_client = HTTPClient(token='placeholder_token')
    httpserver.expect_request('/').respond_with_handler(_header_handler)
    base_url = httpserver.url_for('/').removesuffix('/')

    reply = http_client.call(method='GET', url=f'{base_url}/')
    received = json.loads(reply.text)['received_headers']

    # `Accept-Encoding` and `Host` are not set by HTTPClient's default headers, yet the server
    # receives them, so they are part of the expected set.
    expected = {
        'User-Agent': _get_user_agent(),
        'Accept': 'application/json, */*',
        'Authorization': 'Bearer placeholder_token',
        'Accept-Encoding': 'gzip, br, zstd, deflate',
        'Host': f'{httpserver.host}:{httpserver.port}',
    }
    assert received == expected
72+
73+
74+
async def test_headers_async(httpserver: HTTPServer) -> None:
    """Check that custom headers reach the server and replace the async client's defaults on conflict."""
    custom_headers = {
        'Test-Header': 'blah',
        'User-Agent': 'CustomUserAgent/1.0',
        'Authorization': 'strange_value',
    }
    http_client = HTTPClientAsync(token='placeholder_token', headers=custom_headers)
    httpserver.expect_request('/').respond_with_handler(_header_handler)
    base_url = httpserver.url_for('/').removesuffix('/')

    reply = await http_client.call(method='GET', url=f'{base_url}/')
    received = json.loads(reply.text)['received_headers']

    # The custom `User-Agent` and `Authorization` win over the defaults; `Accept` keeps its default.
    expected = {
        'Test-Header': 'blah',
        'User-Agent': 'CustomUserAgent/1.0',
        'Accept': 'application/json, */*',
        'Authorization': 'strange_value',
        'Accept-Encoding': 'gzip, br, zstd, deflate',
        'Host': f'{httpserver.host}:{httpserver.port}',
    }
    assert received == expected
96+
97+
98+
def test_headers_sync(httpserver: HTTPServer) -> None:
    """Check that custom headers reach the server and replace the sync client's defaults on conflict."""
    custom_headers = {
        'Test-Header': 'blah',
        'User-Agent': 'CustomUserAgent/1.0',
        'Authorization': 'strange_value',
    }
    http_client = HTTPClient(token='placeholder_token', headers=custom_headers)
    httpserver.expect_request('/').respond_with_handler(_header_handler)
    base_url = httpserver.url_for('/').removesuffix('/')

    reply = http_client.call(method='GET', url=f'{base_url}/')
    received = json.loads(reply.text)['received_headers']

    # The custom `User-Agent` and `Authorization` win over the defaults; `Accept` keeps its default.
    expected = {
        'Test-Header': 'blah',
        'User-Agent': 'CustomUserAgent/1.0',
        'Accept': 'application/json, */*',
        'Authorization': 'strange_value',
        'Accept-Encoding': 'gzip, br, zstd, deflate',
        'Host': f'{httpserver.host}:{httpserver.port}',
    }
    assert received == expected
124+
125+
126+
def test_warning_on_overridden_headers_sync() -> None:
    """Check that `ApifyClient` emits a `UserWarning` naming the default headers the caller overrode."""
    custom_headers = {
        'User-Agent': 'CustomUserAgent/1.0',
        'Authorization': 'strange_value',
    }
    expected_message = 'User-Agent, Authorization headers of ApifyClient'

    with pytest.warns(UserWarning, match=expected_message):
        ApifyClient(token='placeholder_token', headers=custom_headers)
137+
138+
139+
async def test_warning_on_overridden_headers_async() -> None:
    """Check that `ApifyClientAsync` emits a `UserWarning` naming the default headers the caller overrode."""
    custom_headers = {
        'User-Agent': 'CustomUserAgent/1.0',
        'Authorization': 'strange_value',
    }
    expected_message = 'User-Agent, Authorization headers of ApifyClientAsync'

    with pytest.warns(UserWarning, match=expected_message):
        ApifyClientAsync(token='placeholder_token', headers=custom_headers)

0 commit comments

Comments
 (0)