
Commit 95c1619

[MMM-19325] Use new object-oriented moderations pipeline (#1401)
Parent: 999be37

2 files changed: +115 -32 lines changed


custom_model_runner/datarobot_drum/drum/adapters/model_adapters/python_model_adapter.py
22 additions, 10 deletions
@@ -107,6 +107,7 @@ def __init__(self, model_dir, target_type=None):
         # New custom task class and instance loaded from custom.py
         self._custom_task_class = None
         self._custom_task_class_instance = None
+        self._mod_pipeline = None
         self._moderation_pipeline = None
         self._moderation_score_hook = None
         self._moderation_chat_hook = None
@@ -132,8 +133,11 @@ def _load_moderation_hooks(self):
         self._logger.info(
             f"Detected {mod_module.__name__} in {mod_module.__file__}.. trying to load hooks"
         )
+        # use the 'moderation_pipeline_factory()' to determine if moderations has integrated pipeline
+        if hasattr(mod_module, "moderation_pipeline_factory"):
+            self._mod_pipeline = mod_module.moderation_pipeline_factory(self._target_type.value)
         # use the 'create_pipeline' to determine if using version that supports VDB
-        if hasattr(mod_module, "create_pipeline"):
+        elif hasattr(mod_module, "create_pipeline"):
             self._moderation_score_hook = mod_module.get_moderations_fn(
                 self._target_type.value, CustomHooks.SCORE
             )
@@ -605,14 +609,21 @@ def _predict_legacy_drum(self, data, model, **kwargs) -> RawPredictResponse:
         if request_labels is not None:
             assert all(isinstance(label, str) for label in request_labels)
         extra_model_output = None
-        if self._custom_hooks.get(CustomHooks.SCORE):
+        score_fn = self._custom_hooks.get(CustomHooks.SCORE)
+        if score_fn:
             try:
-                if self._moderation_pipeline and self._moderation_score_hook:
+                if self._mod_pipeline:
+                    predictions_df = self._mod_pipeline.score(data, model, score_fn, **kwargs)
+                    if self._target_name not in predictions_df:
+                        predictions_df.rename(
+                            columns={"completion": self._target_name}, inplace=True
+                        )
+                elif self._moderation_pipeline and self._moderation_score_hook:
                     predictions_df = self._moderation_score_hook(
                         data,
                         model,
                         self._moderation_pipeline,
-                        self._custom_hooks.get(CustomHooks.SCORE),
+                        score_fn,
                         **kwargs,
                     )
                     if self._target_name not in predictions_df:
@@ -621,9 +632,7 @@ def _predict_legacy_drum(self, data, model, **kwargs) -> RawPredictResponse:
                         )
                 else:
                     # noinspection PyCallingNonCallable
-                    predictions_df = self._custom_hooks.get(CustomHooks.SCORE)(
-                        data, model, **kwargs
-                    )
+                    predictions_df = score_fn(data, model, **kwargs)
             except Exception as exc:
                 self._log_and_raise_final_error(
                     exc, "Model 'score' hook failed to make predictions."
@@ -758,16 +767,19 @@ def predict_unstructured(self, model, data, **kwargs):
         return predictions
 
     def chat(self, completion_create_params, model, association_id):
-        if self._moderation_pipeline and self._moderation_chat_hook:
+        chat_fn = self._custom_hooks.get(CustomHooks.CHAT)
+        if self._mod_pipeline:
+            self._mod_pipeline.chat(completion_create_params, model, chat_fn, association_id)
+        elif self._moderation_pipeline and self._moderation_chat_hook:
             return self._moderation_chat_hook(
                 completion_create_params,
                 model,
                 self._moderation_pipeline,
-                self._custom_hooks.get(CustomHooks.CHAT),
+                chat_fn,
                 association_id,
             )
         else:
-            return self._custom_hooks.get(CustomHooks.CHAT)(completion_create_params, model)
+            return chat_fn(completion_create_params, model)
 
     def get_supported_llm_models(self, model):
         """

tests/unit/datarobot_drum/drum/adapters/model_adapters/test_python_model_adapter.py
93 additions, 22 deletions
@@ -10,6 +10,7 @@
 import logging
 import os
 import random
+import shutil
 import sys
 from dataclasses import dataclass
 
@@ -584,6 +585,43 @@ def set_moderations_lib_content(path: Path, content: str):
     mod_hook_file.write_text(content)
 
 
+def remove_moderations_lib_content(path: Path):
+    """Removes the moderations subdirectory from the specified path."""
+    mod_dir = path / MODERATIONS_LIBRARY_PACKAGE
+    if not mod_dir.exists() or not mod_dir.is_dir():
+        return
+
+    shutil.rmtree(mod_dir)
+
+
+@contextlib.contextmanager
+def mock_moderation_content(path: Path, content: str):
+    """
+    Sets the moderations content in the provided path, makes adjustments to find it, and
+    cleans up the files and modules following test execution.
+    """
+
+    # remove any currently loaded moderations libraries
+    sys.modules.pop(MODERATIONS_HOOK_MODULE, None)
+    sys.modules.pop(MODERATIONS_LIBRARY_PACKAGE, None)
+
+    # put provided path in the search path for modules
+    sys.path.insert(0, str(path))
+
+    # set the content to the provided value
+    set_moderations_lib_content(path, content)
+    try:
+        yield  # let the test run here
+    finally:
+        # remove the moderations subdirectory and remove it from search path
+        remove_moderations_lib_content(path)
+        sys.path.remove(str(path))
+
+        # unload any moderation modules, so they don't get used by another test
+        sys.modules.pop(MODERATIONS_HOOK_MODULE, None)
+        sys.modules.pop(MODERATIONS_LIBRARY_PACKAGE, None)
+
+
 class TestPythonModelAdapterWithGuards:
     """Use cases to test the moderation integration with DRUM"""
 
@@ -597,16 +635,11 @@ def guard_score_wrapper(data, model, pipeline, drum_score_fn, **kwargs):
         def init():
             return Mock()
         """
-        sys.path.insert(0, str(tmp_path))
-        set_moderations_lib_content(tmp_path, textwrap.dedent(guard_hook_contents))
-
         text_generation_target_name = "completion"
-        with patch.dict(os.environ, {"TARGET_NAME": text_generation_target_name}):
-            # Remove any existing cached imports to allow importing the fake guard package.
-            # Existing imports will be there if real moderations library is in python path.
-            sys.modules.pop(MODERATIONS_HOOK_MODULE, None)
-            sys.modules.pop(MODERATIONS_LIBRARY_PACKAGE, None)
-
+        with (
+            patch.dict(os.environ, {"TARGET_NAME": text_generation_target_name}),
+            mock_moderation_content(tmp_path, textwrap.dedent(guard_hook_contents)),
+        ):
             adapter = PythonModelAdapter(tmp_path, TargetType.TEXT_GENERATION)
             assert adapter._moderation_pipeline is not None
             # Ensure that it is Mock as set by guard_hook_contents
@@ -619,7 +652,6 @@ def init():
             assert adapter._moderation_pipeline is None
             assert adapter._moderation_score_hook is None
             assert adapter._moderation_chat_hook is None
-        sys.path.remove(str(tmp_path))
 
     @pytest.mark.parametrize(
         ["target_type", "score_hook_name"],
@@ -657,24 +689,39 @@ def get_moderations_fn(target_type, custom_hook):
         def create_pipeline(target_type):
             return Mock()
         """
-        sys.path.insert(0, str(tmp_path))
-
-        set_moderations_lib_content(tmp_path, textwrap.dedent(moderation_content))
         text_generation_target_name = "completion"
-        with patch.dict(os.environ, {"TARGET_NAME": text_generation_target_name}):
-            # Remove any existing cached imports to allow importing the fake guard package.
-            # Existing imports will be there if real moderations library is in python path.
-            sys.modules.pop(MODERATIONS_HOOK_MODULE, None)
-            sys.modules.pop(MODERATIONS_LIBRARY_PACKAGE, None)
-
+        with (
+            patch.dict(os.environ, {"TARGET_NAME": text_generation_target_name}),
+            mock_moderation_content(tmp_path, textwrap.dedent(moderation_content)),
+        ):
             adapter = PythonModelAdapter(tmp_path, target_type)
             assert adapter._moderation_pipeline is not None
             assert isinstance(adapter._moderation_pipeline, Mock)
             assert score_hook_name in str(adapter._moderation_score_hook)
             # would be nice to check chat_hook, but having a stub function causes other problems
             assert adapter._moderation_chat_hook is None
 
-        sys.path.remove(str(tmp_path))
+    @pytest.mark.parametrize(
+        ["target_type"],
+        [
+            pytest.param(TargetType.TEXT_GENERATION, id="textgen"),
+            pytest.param(TargetType.VECTOR_DATABASE, id="vectordb"),
+        ],
+    )
+    def test_loading_moderations_pipeline(self, target_type, tmp_path):
+        moderation_content = """
+        from unittest.mock import Mock
+
+        def moderation_pipeline_factory(target_type):
+            return Mock()
+        """
+        target_name = "completion"
+        with (
+            patch.dict(os.environ, {"TARGET_NAME": target_name}),
+            mock_moderation_content(tmp_path, textwrap.dedent(moderation_content)),
+        ):
+            adapter = PythonModelAdapter(tmp_path, target_type)
+            assert adapter._mod_pipeline is not None
 
     @pytest.mark.parametrize(
         "guard_hook_present, expected_predictions",
@@ -781,12 +828,36 @@ def custom_score(data, model, **kwargs):
 
         df = pd.DataFrame({"text": ["abc", "def"]})
         data = bytes(df.to_csv(index=False), encoding="utf-8")
-        text_generation_target_name = "completion"
-        with patch.dict(os.environ, {"TARGET_NAME": text_generation_target_name}):
+        target_name = "completion"
+        with patch.dict(os.environ, {"TARGET_NAME": target_name}):
             adapter = PythonModelAdapter(tmp_path, TargetType.VECTOR_DATABASE)
             adapter._moderation_pipeline = Mock()
             adapter._moderation_score_hook = Mock(return_value=df)
             adapter._custom_hooks["score"] = custom_score
 
             adapter.predict(binary_data=data)
             assert adapter._moderation_score_hook.call_count == 1
+
+    def test_vdb_moderation_pipeline(self, tmp_path):
+        def custom_score(data, model, **kwargs):
+            """Dummy score method just for the purpose of unit test"""
+            return data
+
+        class TestModPipeline:
+            def __init__(self):
+                self.call_count = 0
+
+            def score(self, data, model, score_fn, **kwargs):
+                self.call_count += 1
+                return score_fn(data, model, **kwargs)
+
+        df = pd.DataFrame({"text": ["abc", "def"]})
+        data = bytes(df.to_csv(index=False), encoding="utf-8")
+        target_name = "completion"
+        with patch.dict(os.environ, {"TARGET_NAME": target_name}):
+            adapter = PythonModelAdapter(tmp_path, TargetType.VECTOR_DATABASE)
+            adapter._mod_pipeline = TestModPipeline()
+            adapter._custom_hooks["score"] = custom_score
+
+            adapter.predict(binary_data=data)
+            assert adapter._mod_pipeline.call_count == 1
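
The new tests exercise the pipeline's score path end to end. A chat-path double could follow the same shape; the sketch below is hypothetical (not part of this commit), with the chat(...) signature mirroring the adapter's call site in the first file's diff.

# Hypothetical test double for the chat path (not in this commit); the
# chat(...) signature mirrors PythonModelAdapter.chat() in the diff above.
class FakeChatModPipeline:
    def __init__(self):
        self.call_count = 0

    def chat(self, completion_create_params, model, chat_fn, association_id):
        self.call_count += 1
        return chat_fn(completion_create_params, model)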
