1+ import json
12import os
23import threading
34import re
78from test .support .os_helper import EnvironmentVarGuard
89from http .server import BaseHTTPRequestHandler , HTTPServer
910
11+ from kagglesdk .kaggle_env import get_endpoint , get_env
12+
1013class KaggleAPIHandler (BaseHTTPRequestHandler ):
1114 """
1215 Fake Kaggle API server supporting the download endpoint.
@@ -15,15 +18,18 @@ class KaggleAPIHandler(BaseHTTPRequestHandler):
1518 def do_HEAD (self ):
1619 self .send_response (200 )
1720
18- def do_GET (self ):
19- m = re .match ("^/api/v1/models/(.+)/download/(.+)$" , self .path )
20- if not m :
21+ def do_POST (self ):
22+ content_length = int (self .headers .get ('Content-Length' , 0 ))
23+ body_bytes = self .rfile .read (content_length )
24+ request_body = json .loads (body_bytes .decode ('utf-8' ))
25+
26+ if self .path != "/api/v1/models.ModelApiService/DownloadModelInstanceVersion" :
2127 self .send_response (404 )
2228 self .wfile .write (bytes (f"Unhandled path: { self .path } " , "utf-8" ))
2329 return
2430
25- model_handle = m . group ( 1 )
26- path = m . group ( 2 )
31+ model_handle = f" { request_body [ "ownerSlug" ] } / { request_body [ "modelSlug" ] } /keras/ { request_body [ "instanceSlug" ] } / { request_body [ "versionNumber" ] } "
32+ path = request_body [ "path" ]
2733 filepath = f"/input/tests/data/kagglehub/models/{ model_handle } /{ path } "
2834 if not os .path .isfile (filepath ):
2935 self .send_error (404 , "Internet is disabled in our tests "
@@ -41,14 +47,12 @@ def do_GET(self):
4147
4248@contextmanager
4349def create_test_kagglehub_server ():
44- endpoint = 'http://localhost:7777'
4550 env = EnvironmentVarGuard ()
46- env .set ('KAGGLE_API_ENDPOINT' , endpoint )
47- test_server_address = urlparse (endpoint )
51+ env .set ('KAGGLE_API_ENVIRONMENT' , 'TEST' )
4852 with env :
49- if not test_server_address . hostname or not test_server_address . port :
50- msg = f"Invalid test server address: { endpoint } . You must specify a hostname & port"
51- raise ValueError ( msg )
53+ endpoint = get_endpoint ( get_env ())
54+ test_server_address = urlparse ( endpoint )
55+
5256 with HTTPServer ((test_server_address .hostname , test_server_address .port ), KaggleAPIHandler ) as httpd :
5357 threading .Thread (target = httpd .serve_forever ).start ()
5458
0 commit comments