Commit 3d51161

[MODEL-20346] Add optional inputSchema to metadata returned by /info/ route (#1661)

* add inputSchema to info/ route
* add docs
* add validation and more test cases
* add more annotations
* add more annotations

1 parent 43b2005

File tree

5 files changed: +223 additions, -3 deletions
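
Per the commit title, the /info/ route now returns the optional inputSchema as part of the model metadata. Below is a minimal client-side sketch of checking for it; the server address and the exact response layout are illustrative assumptions, since this diff only covers the metadata-parsing side:

```python
# Hypothetical check of the /info/ response; assumes a DRUM server running
# locally (e.g. started with `drum server ... --address localhost:6789`).
# The exact JSON layout is an assumption, not confirmed by this diff.
import requests

resp = requests.get("http://localhost:6789/info/")
resp.raise_for_status()
info = resp.json()

# inputSchema is optional: present only if model-metadata.yaml declared one.
print(info.get("inputSchema"))
```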

MODEL-METADATA.md

Lines changed: 6 additions & 0 deletions

```diff
@@ -28,6 +28,12 @@ ignored if modelID is set.
 * majorVersion (optional, default: True): Whether the model version you are creating should be a
 major version update or a minor version update. If the previous model version is 2.3, a major version
 update would create the version 3.0, and a minor version update would create the version 2.4.
+* inputSchema (optional): A schema defining the format of the input data for your model. This is
+required when building unstructured models to serve as tools within MCP servers. The schema follows
+JSON Schema format with three key components: `type` (typically "object" for structured data),
+`properties` (a dictionary defining each field with its type, constraints, and optional default values),
+and `required` (an array listing mandatory fields). This type of schema can be generated by serializing
+the schema of a pydantic model.
 
 ## Options specific to inference models
 NOTE: All options specific to inference models or tasks are ignored if modelID is set- they
```
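
The last added sentence, generating the schema by serializing a pydantic model, corresponds to pydantic v2's model_json_schema(). A minimal sketch with a hypothetical ToolInput model (the field names are illustrative):

```python
# Illustrative: build an inputSchema for model-metadata.yaml from a pydantic model.
from typing import Optional

from pydantic import BaseModel, Field


class ToolInput(BaseModel):  # hypothetical tool input definition
    dataset_id: str = Field(..., description="ID of the dataset to fetch")
    limit: Optional[int] = Field(None, description="Maximum number of rows to return")


schema = ToolInput.model_json_schema()
# schema now has the documented shape:
#   schema["type"] == "object"
#   schema["properties"] holds per-field types, constraints, and defaults
#   schema["required"] == ["dataset_id"]
# Serialize it under the inputSchema key of model-metadata.yaml.
```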

custom_model_runner/datarobot_drum/drum/enum.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -442,6 +442,7 @@ class ModelMetadataKeys(object):
     TRAINING_MODEL = "trainingModel"
     HYPERPARAMETERS = "hyperparameters"
     VALIDATION_SCHEMA = "typeSchema"
+    INPUT_SCHEMA = "inputSchema"
     # customPredictor section is not used by DRUM,
     # it is a place holder if user wants to add some fields and read them on his own
     CUSTOM_PREDICTOR = "customPredictor"
```

custom_model_runner/datarobot_drum/drum/model_metadata.py

Lines changed: 69 additions & 2 deletions

```diff
@@ -12,6 +12,8 @@
 import trafaret as t
 
 from pathlib import Path
+
+from pydantic import create_model
 from ruamel.yaml import YAMLError
 from strictyaml import (
     load,
@@ -25,7 +27,7 @@
     StrictYAMLError,
     YAMLValidationError,
 )
-from typing import Optional as PythonTypingOptional, List, Dict
+from typing import Optional as PythonTypingOptional, List, Dict, Union
 
 from datarobot_drum.drum.common import get_drum_logger
 from datarobot_drum.drum.enum import (
@@ -187,7 +189,7 @@ def read_model_metadata_yaml(code_dir) -> PythonTypingOptional[dict]:
         validate_config_fields(model_config, ModelMetadataKeys.INFERENCE_MODEL)
         validate_config_fields(
             model_config[ModelMetadataKeys.INFERENCE_MODEL],
-            *["positiveClassLabel", "negativeClassLabel"]
+            *["positiveClassLabel", "negativeClassLabel"],
         )
 
     if model_config[ModelMetadataKeys.TARGET_TYPE] == TargetType.MULTICLASS.value:
@@ -230,10 +232,67 @@ def read_model_metadata_yaml(code_dir) -> PythonTypingOptional[dict]:
         if hyper_params:
             validate_model_metadata_hyperparameter(hyper_params)
 
+        input_schema = model_config.get(ModelMetadataKeys.INPUT_SCHEMA)
+        if input_schema:
+            try:
+                model = create_model_from_schema(input_schema)
+            except Exception as e:
+                raise DrumCommonException(
+                    "Error creating pydantic model from input schema: {}".format(e)
+                )
+
         return model_config
     return None
 
 
+def convert_json_type_to_python(prop_def: dict):
+    """Convert JSON Schema type to Python type."""
+
+    # Handle anyOf for union types
+    if "anyOf" in prop_def:
+        types = []
+        for schema in prop_def["anyOf"]:
+            types.append(convert_json_type_to_python(schema))
+        return Union[tuple(types)]
+
+    # Handle regular `type` field of the property
+    json_type = prop_def.get("type", "string")
+
+    type_mapping = {
+        "string": str,
+        "integer": int,
+        "number": float,
+        "boolean": bool,
+        "array": list,
+        "object": dict,
+        "null": type(None),
+    }
+
+    return type_mapping.get(json_type, str)
+
+
+def create_model_from_schema(schema_dict: dict):
+    """Create a Pydantic model from a JSON Schema dictionary."""
+    schema_type = schema_dict.get("type")
+    if schema_type != "object":
+        raise ValueError(f"Only 'object' type schemas are supported, got '{schema_type}'")
+
+    properties = schema_dict.get("properties", {})
+
+    properties_type = type(properties)
+    if properties_type is not dict:
+        raise ValueError(f"'properties' must be a dictionary, got '{properties_type}'")
+
+    fields, required_fields = {}, set(schema_dict.get("required", []))
+
+    for prop_name, prop_def in properties.items():
+        py_type = convert_json_type_to_python(prop_def)
+        default_value = ... if prop_name in required_fields else prop_def.get("default")
+        fields[prop_name] = (py_type, default_value)
+
+    return create_model("InputSchema", **fields)
+
+
 def read_default_model_metadata_yaml() -> PythonTypingOptional[dict]:
     default_type_schema_path = os.path.abspath(
         os.path.join(os.path.dirname(__file__), "..", "resource", "default_typeschema")
@@ -358,5 +417,13 @@ def _validate_multi_parameter(multi_params: Dict):
             Map({"key": Str(), "valueFrom": Str(), Optional("reminder"): Str()})
         ),
         Optional(ModelMetadataKeys.LAZY_LOADING): Any(),
+        Optional(ModelMetadataKeys.INPUT_SCHEMA): Map(
+            {
+                Optional("title"): Str(),
+                "type": Str(),
+                "properties": Any(),
+                Optional("required"): Seq(Str()),
+            }
+        ),
     }
 )
```
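
To make the two new helpers concrete, here is a short walk-through under the diff's own definitions; the example schema dict is illustrative:

```python
# Walk-through of the helpers added above; the schema itself is made up.
from typing import Union

from datarobot_drum.drum.model_metadata import (
    convert_json_type_to_python,
    create_model_from_schema,
)

schema = {
    "type": "object",
    "properties": {
        "dataset_id": {"type": "string"},
        "limit": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None},
    },
    "required": ["dataset_id"],
}

# anyOf members are converted recursively and joined into a typing.Union,
# so `limit` becomes Union[int, None], i.e. Optional[int].
assert convert_json_type_to_python(schema["properties"]["limit"]) == Union[int, type(None)]

# Required fields get `...` (no default); the rest pick up their schema default.
InputSchema = create_model_from_schema(schema)
instance = InputSchema(dataset_id="abc123")
assert instance.limit is None
```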

tests/unit/datarobot_drum/conftest.py

Lines changed: 87 additions & 0 deletions

```diff
@@ -210,6 +210,93 @@ def custom_predictor_metadata_yaml():
     )
 
 
+@pytest.fixture
+def custom_unstructured_tool_with_schema_in_yaml():
+    return dedent(
+        """
+        name: "[Tool] Get Data Registry Dataset"
+        description: |
+          Fetches a dataset from the DataRobot Data Registry.
+
+        type: inference
+        environmentID: 64d2ba178dd3f0b1fa2162f0
+        targetType: unstructured
+        inferenceModel:
+          targetName: target
+        inputSchema:
+          type: object
+          properties:
+            dataset_id:
+              title: Dataset ID
+              description: The ID of the dataset to fetch from the Data Registry.
+              type: string
+            offset:
+              title: Offset
+              description: The number of rows to skip before starting to return rows. Default is 0.
+              type: integer
+              default: 0
+            limit:
+              title: Limit of rows
+              description: The maximum number of rows to return. If not specified, all rows will be returned.
+              anyOf:
+                - type: integer
+                - type: null
+              default: null
+          required:
+            - dataset_id
+        """
+    )
+
+
+@pytest.fixture
+def custom_unstructured_tool_with_invalid_schema1():
+    return dedent(
+        """
+        name: "[Tool] Get Data Registry Dataset"
+        description: |
+          Fetches a dataset from the DataRobot Data Registry.
+
+        type: inference
+        environmentID: 64d2ba178dd3f0b1fa2162f0
+        targetType: unstructured
+        inferenceModel:
+          targetName: target
+        inputSchema:
+          type: unexpected
+          properties:
+            dataset_id:
+              title: Dataset ID
+              type: string
+          required:
+            - dataset_id
+        """
+    )
+
+
+@pytest.fixture
+def custom_unstructured_tool_with_invalid_schema2():
+    return dedent(
+        """
+        name: "[Tool] Get Data Registry Dataset"
+        description: |
+          Fetches a dataset from the DataRobot Data Registry.
+
+        type: inference
+        environmentID: 64d2ba178dd3f0b1fa2162f0
+        targetType: unstructured
+        inferenceModel:
+          targetName: target
+        inputSchema:
+          type: object
+          properties:
+            - list-instead-of-dict
+            - another-item
+          required:
+            - dataset_id
+        """
+    )
+
+
 ###############################################################################
 # HELPER FUNCS
```
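
The two invalid fixtures line up one-to-one with the ValueError branches in create_model_from_schema above: `type: unexpected` trips the object-type check, and the list under `properties` trips the dictionary check. A sketch with hand-built dicts equivalent to what the YAML parses to:

```python
# Both calls raise inside create_model_from_schema; read_model_metadata_yaml
# wraps the errors in a DrumCommonException (see the tests below).
import pytest

from datarobot_drum.drum.model_metadata import create_model_from_schema

bad_type = {"type": "unexpected", "properties": {"dataset_id": {"type": "string"}}}
with pytest.raises(ValueError, match="Only 'object' type schemas are supported"):
    create_model_from_schema(bad_type)

bad_properties = {"type": "object", "properties": ["list-instead-of-dict", "another-item"]}
with pytest.raises(ValueError, match="'properties' must be a dictionary"):
    create_model_from_schema(bad_properties)
```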

tests/unit/datarobot_drum/model_metadata/test_model_metadata.py

Lines changed: 60 additions & 1 deletion

```diff
@@ -11,11 +11,12 @@
 from random import sample
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import List, Union
+from typing import List, Union, Optional
 
 import pytest
 import numpy as np
 import pandas as pd
+from pydantic import BaseModel, Field
 from scipy import sparse
 import yaml
 from strictyaml import load, YAMLValidationError
@@ -1001,6 +1002,16 @@ def _inner(input_dict):
     yield _inner
 
 
+def normalize_schema(obj):
+    if isinstance(obj, dict):
+        return {k: normalize_schema(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [normalize_schema(v) for v in obj]
+    if obj == "null" or obj == "None":
+        return None
+    return obj
+
+
 class TestReadModelMetadata:
     @pytest.fixture
     def minimal_training_metadata(self, environment_id):
@@ -1010,13 +1021,46 @@ def minimal_training_metadata(self, environment_id):
             "targetType": "regression",
             "environmentID": environment_id,
             "validation": {"input": "hello"},
+            "inputSchema": {
+                "type": "object",
+                "properties": {
+                    "foo": {"type": "integer", "default": "42"},
+                    "bar": {"type": "string"},
+                    "baz": {"type": "boolean"},
+                },
+                "required": ["foo", "bar"],
+            },
         }
 
+    @pytest.fixture
+    def metadata_with_pydantic_schema(self, minimal_training_metadata):
+        """An example of using a pydantic model to generate the input schema
+        with rich annotations, which is common practice when defining pydantic
+        models for automated processing using LLMs.
+        """
+
+        class ExampleSchema(BaseModel):
+            foo: int = Field(..., description="A foo field")
+            bar: str = Field(
+                "bar-value", title="The bar field", description="The bar field with a default value"
+            )
+            baz: Optional[bool] = None
+
+        minimal_training_metadata["inputSchema"] = ExampleSchema.model_json_schema()
+        return minimal_training_metadata
+
     def test_minimal_data(self, model_metadata_file_factory, minimal_training_metadata):
         code_dir = model_metadata_file_factory(minimal_training_metadata)
         result = read_model_metadata_yaml(code_dir)
         assert result == minimal_training_metadata
 
+    def test_metadata_with_pydantic_schema(
+        self, model_metadata_file_factory, metadata_with_pydantic_schema
+    ):
+        code_dir = model_metadata_file_factory(metadata_with_pydantic_schema)
+        result = read_model_metadata_yaml(code_dir)
+        assert normalize_schema(result) == normalize_schema(metadata_with_pydantic_schema)
+
     def test_user_credential_specs(self, model_metadata_file_factory, minimal_training_metadata):
         credential_specs = [
             {"key": "HI", "valueFrom": "65170a6bc4b7f4bec89db932", "reminder": "remember"},
@@ -1176,8 +1220,11 @@ def test_validate_model_metadata_output_requirements_r():
         ("inference_binary_metadata_no_label", 2),
         ("inference_multiclass_metadata_yaml_no_labels", 3),
         ("inference_multiclass_metadata_yaml_labels_and_label_file", 4),
+        ("custom_unstructured_tool_with_invalid_schema1", 5),
+        ("custom_unstructured_tool_with_invalid_schema2", 6),
         ("inference_multiclass_metadata_yaml", 100),
         ("inference_multiclass_metadata_yaml_label_file", 100),
+        ("custom_unstructured_tool_with_schema_in_yaml", 100),
     ],
 )
 def test_yaml_metadata_missing_fields(tmp_path, config_yaml, request, test_case_number):
@@ -1211,5 +1258,17 @@ def test_yaml_metadata_missing_fields(tmp_path, config_yaml, request, test_case_number):
             match="Error - for multiclass classification, either the class labels or a class labels file should be provided in model-metadata.yaml file, but not both",
         ):
             read_model_metadata_yaml(tmp_path)
+    elif test_case_number == 5:
+        with pytest.raises(
+            DrumCommonException,
+            match="Error creating pydantic model from input schema: Only 'object' type schemas are supported, got ",
+        ):
+            read_model_metadata_yaml(tmp_path)
+    elif test_case_number == 6:
+        with pytest.raises(
+            DrumCommonException,
+            match="Error creating pydantic model from input schema: 'properties' must be a dictionary, got ",
+        ):
+            read_model_metadata_yaml(tmp_path)
     elif test_case_number == 100:
         read_model_metadata_yaml(tmp_path)
```
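
A note on normalize_schema used by test_metadata_with_pydantic_schema: metadata written to model-metadata.yaml and read back does not round-trip null-ish scalars byte-for-byte (pydantic emits real None defaults and the type name "null", while the strictyaml-based loader presumably hands scalars back as strings), so both sides are normalized before comparison. Small checks of the helper as defined above:

```python
# normalize_schema folds the string spellings of null back to None so a schema
# that went through a YAML round trip compares equal to the in-memory original.
assert normalize_schema({"default": "None"}) == {"default": None}
assert normalize_schema({"anyOf": [{"type": "null"}]}) == {"anyOf": [{"type": None}]}
# Everything else passes through untouched.
assert normalize_schema({"required": ["foo", "bar"]}) == {"required": ["foo", "bar"]}
```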
