Skip to content

Commit 0c165a3

Browse files
authored
Fix kagglehub fake API server (#1522)
Starting with `kagglehub` 0.4.0, a new generated Python client is used which sends POST requests instead of GET. The endpoint is now controlled via the `KAGGLE_API_ENVIRONMENT` env variable. http://b/475292408
1 parent a85fe20 commit 0c165a3

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

tests/utils/kagglehub.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import threading
34
import re
@@ -7,6 +8,8 @@
78
from test.support.os_helper import EnvironmentVarGuard
89
from http.server import BaseHTTPRequestHandler, HTTPServer
910

11+
from kagglesdk.kaggle_env import get_endpoint, get_env
12+
1013
class KaggleAPIHandler(BaseHTTPRequestHandler):
1114
"""
1215
Fake Kaggle API server supporting the download endpoint.
@@ -15,15 +18,18 @@ class KaggleAPIHandler(BaseHTTPRequestHandler):
1518
def do_HEAD(self):
1619
self.send_response(200)
1720

18-
def do_GET(self):
19-
m = re.match("^/api/v1/models/(.+)/download/(.+)$", self.path)
20-
if not m:
21+
def do_POST(self):
22+
content_length = int(self.headers.get('Content-Length', 0))
23+
body_bytes = self.rfile.read(content_length)
24+
request_body = json.loads(body_bytes.decode('utf-8'))
25+
26+
if self.path != "/api/v1/models.ModelApiService/DownloadModelInstanceVersion":
2127
self.send_response(404)
2228
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))
2329
return
2430

25-
model_handle = m.group(1)
26-
path = m.group(2)
31+
model_handle = f"{request_body["ownerSlug"]}/{request_body["modelSlug"]}/keras/{request_body["instanceSlug"]}/{request_body["versionNumber"]}"
32+
path = request_body["path"]
2733
filepath = f"/input/tests/data/kagglehub/models/{model_handle}/{path}"
2834
if not os.path.isfile(filepath):
2935
self.send_error(404, "Internet is disabled in our tests "
@@ -41,14 +47,12 @@ def do_GET(self):
4147

4248
@contextmanager
4349
def create_test_kagglehub_server():
44-
endpoint = 'http://localhost:7777'
4550
env = EnvironmentVarGuard()
46-
env.set('KAGGLE_API_ENDPOINT', endpoint)
47-
test_server_address = urlparse(endpoint)
51+
env.set('KAGGLE_API_ENVIRONMENT', 'TEST')
4852
with env:
49-
if not test_server_address.hostname or not test_server_address.port:
50-
msg = f"Invalid test server address: {endpoint}. You must specify a hostname & port"
51-
raise ValueError(msg)
53+
endpoint = get_endpoint(get_env())
54+
test_server_address = urlparse(endpoint)
55+
5256
with HTTPServer((test_server_address.hostname, test_server_address.port), KaggleAPIHandler) as httpd:
5357
threading.Thread(target=httpd.serve_forever).start()
5458

0 commit comments

Comments
 (0)