1616# under the License.
1717"""FileIO implementation for reading and writing table files that uses fsspec compatible filesystems."""
1818
19+ import abc
1920import errno
2021import json
2122import logging
2223import os
2324import threading
2425from copy import copy
25- from functools import lru_cache , partial
26+ from functools import lru_cache
2627from typing import (
2728 TYPE_CHECKING ,
2829 Any ,
2930 Callable ,
3031 Dict ,
32+ Type ,
3133 Union ,
3234)
3335from urllib .parse import urlparse
9597 from botocore .awsrequest import AWSRequest
9698
9799
98- def s3v4_rest_signer (properties : Properties , request : "AWSRequest" , ** _ : Any ) -> "AWSRequest" :
99- signer_url = properties .get (S3_SIGNER_URI , properties [URI ]).rstrip ("/" ) # type: ignore
100- signer_endpoint = properties .get (S3_SIGNER_ENDPOINT , S3_SIGNER_ENDPOINT_DEFAULT )
100+ class S3RequestSigner (abc .ABC ):
101+ """Abstract base class for S3 request signers."""
101102
102- signer_headers = {}
103- if token := properties .get (TOKEN ):
104- signer_headers = {"Authorization" : f"Bearer { token } " }
105- signer_headers .update (get_header_properties (properties ))
103+ properties : Properties
106104
107- signer_body = {
108- "method" : request .method ,
109- "region" : request .context ["client_region" ],
110- "uri" : request .url ,
111- "headers" : {key : [val ] for key , val in request .headers .items ()},
112- }
105+ def __init__ (self , properties : Properties ) -> None :
106+ self .properties = properties
107+
108+ @abc .abstractmethod
109+ def __call__ (self , request : "AWSRequest" , ** _ : Any ) -> None :
110+ pass
111+
112+
113+ class S3V4RestSigner (S3RequestSigner ):
114+ """An S3 request signer that uses an external REST signing service to sign requests."""
115+
116+ session : requests .Session
113117
114- response = requests .post (f"{ signer_url } /{ signer_endpoint .strip ()} " , headers = signer_headers , json = signer_body )
115- try :
116- response .raise_for_status ()
117- response_json = response .json ()
118- except HTTPError as e :
119- raise SignError (f"Failed to sign request { response .status_code } : { signer_body } " ) from e
118+ def __init__ (self , properties : Properties ) -> None :
119+ super ().__init__ (properties )
120+ self .session = requests .Session ()
120121
121- for key , value in response_json ["headers" ].items ():
122- request .headers .add_header (key , ", " .join (value ))
122+ def __call__ (self , request : "AWSRequest" , ** _ : Any ) -> None :
123+ signer_url = self .properties .get (S3_SIGNER_URI , self .properties [URI ]).rstrip ("/" ) # type: ignore
124+ signer_endpoint = self .properties .get (S3_SIGNER_ENDPOINT , S3_SIGNER_ENDPOINT_DEFAULT )
125+
126+ signer_headers = {}
127+ if token := self .properties .get (TOKEN ):
128+ signer_headers = {"Authorization" : f"Bearer { token } " }
129+ signer_headers .update (get_header_properties (self .properties ))
130+
131+ signer_body = {
132+ "method" : request .method ,
133+ "region" : request .context ["client_region" ],
134+ "uri" : request .url ,
135+ "headers" : {key : [val ] for key , val in request .headers .items ()},
136+ }
137+
138+ response = self .session .post (f"{ signer_url } /{ signer_endpoint .strip ()} " , headers = signer_headers ,
139+ json = signer_body )
140+ try :
141+ response .raise_for_status ()
142+ response_json = response .json ()
143+ except HTTPError as e :
144+ raise SignError (f"Failed to sign request { response .status_code } : { signer_body } " ) from e
123145
124- request .url = response_json ["uri" ]
146+ for key , value in response_json ["headers" ].items ():
147+ request .headers .add_header (key , ", " .join (value ))
125148
126- return request
149+ request . url = response_json [ "uri" ]
127150
128151
129- SIGNERS : Dict [str , Callable [[ Properties , "AWSRequest" ], "AWSRequest" ]] = {"S3V4RestSigner" : s3v4_rest_signer }
152+ SIGNERS : Dict [str , Type [ S3RequestSigner ]] = {"S3V4RestSigner" : S3V4RestSigner }
130153
131154
132155def _file (_ : Properties ) -> LocalFileSystem :
@@ -144,13 +167,13 @@ def _s3(properties: Properties) -> AbstractFileSystem:
144167 "region_name" : get_first_property_value (properties , S3_REGION , AWS_REGION ),
145168 }
146169 config_kwargs = {}
147- register_events : Dict [str , Callable [[Properties ], None ]] = {}
170+ register_events : Dict [str , Callable [[AWSRequest ], None ]] = {}
148171
149172 if signer := properties .get (S3_SIGNER ):
150173 logger .info ("Loading signer %s" , signer )
151- if signer_func := SIGNERS .get (signer ):
152- signer_func_with_properties = partial ( signer_func , properties )
153- register_events ["before-sign.s3" ] = signer_func_with_properties
174+ if signer_cls := SIGNERS .get (signer ):
175+ signer = signer_cls ( properties )
176+ register_events ["before-sign.s3" ] = signer
154177
155178 # Disable the AWS Signer
156179 from botocore import UNSIGNED
0 commit comments