1010import logging
1111import os
1212import random
13+ import shutil
1314import sys
1415from dataclasses import dataclass
1516
@@ -584,6 +585,43 @@ def set_moderations_lib_content(path: Path, content: str):
584585 mod_hook_file .write_text (content )
585586
586587
def remove_moderations_lib_content(path: Path):
    """Delete the moderations package directory located directly under ``path``.

    Does nothing when that entry is missing or is not a directory.
    """
    package_dir = path / MODERATIONS_LIBRARY_PACKAGE
    # is_dir() is already False for a missing entry, so one check covers both cases
    if package_dir.is_dir():
        shutil.rmtree(package_dir)
595+
596+
@contextlib.contextmanager
def mock_moderation_content(path: Path, content: str):
    """
    Sets the moderations content in the provided path, makes adjustments to find it, and
    cleans up the files and modules following test execution.

    The cleanup also runs when writing the content fails, and each cleanup step is
    attempted even if an earlier one raises, so a failing test cannot leave a stale
    ``sys.path`` entry or fake moderation modules behind for the next test.

    :param path: directory that receives the moderations package and is temporarily
        placed at the front of ``sys.path``
    :param content: source text handed to ``set_moderations_lib_content``
    """
    search_path_entry = str(path)

    # remove any currently loaded moderations libraries
    sys.modules.pop(MODERATIONS_HOOK_MODULE, None)
    sys.modules.pop(MODERATIONS_LIBRARY_PACKAGE, None)

    # put provided path in the search path for modules
    sys.path.insert(0, search_path_entry)
    try:
        # set the content to the provided value; kept inside the try so that a failure
        # here still triggers the cleanup below
        set_moderations_lib_content(path, content)
        yield  # let the test run here
    finally:
        try:
            # remove the moderations subdirectory
            remove_moderations_lib_content(path)
        finally:
            # unload any moderation modules, so they don't get used by another test
            sys.modules.pop(MODERATIONS_HOOK_MODULE, None)
            sys.modules.pop(MODERATIONS_LIBRARY_PACKAGE, None)

            # the test body may already have dropped the entry; that must not turn
            # teardown into a failure
            with contextlib.suppress(ValueError):
                sys.path.remove(search_path_entry)
623+
624+
587625class TestPythonModelAdapterWithGuards :
588626 """Use cases to test the moderation integration with DRUM"""
589627
@@ -597,16 +635,11 @@ def guard_score_wrapper(data, model, pipeline, drum_score_fn, **kwargs):
597635 def init():
598636 return Mock()
599637 """
600- sys .path .insert (0 , str (tmp_path ))
601- set_moderations_lib_content (tmp_path , textwrap .dedent (guard_hook_contents ))
602-
603638 text_generation_target_name = "completion"
604- with patch .dict (os .environ , {"TARGET_NAME" : text_generation_target_name }):
605- # Remove any existing cached imports to allow importing the fake guard package.
606- # Existing imports will be there if real moderations library is in python path.
607- sys .modules .pop (MODERATIONS_HOOK_MODULE , None )
608- sys .modules .pop (MODERATIONS_LIBRARY_PACKAGE , None )
609-
639+ with (
640+ patch .dict (os .environ , {"TARGET_NAME" : text_generation_target_name }),
641+ mock_moderation_content (tmp_path , textwrap .dedent (guard_hook_contents )),
642+ ):
610643 adapter = PythonModelAdapter (tmp_path , TargetType .TEXT_GENERATION )
611644 assert adapter ._moderation_pipeline is not None
612645 # Ensure that it is Mock as set by guard_hook_contents
@@ -619,7 +652,6 @@ def init():
619652 assert adapter ._moderation_pipeline is None
620653 assert adapter ._moderation_score_hook is None
621654 assert adapter ._moderation_chat_hook is None
622- sys .path .remove (str (tmp_path ))
623655
624656 @pytest .mark .parametrize (
625657 ["target_type" , "score_hook_name" ],
@@ -657,24 +689,39 @@ def get_moderations_fn(target_type, custom_hook):
657689 def create_pipeline(target_type):
658690 return Mock()
659691 """
660- sys .path .insert (0 , str (tmp_path ))
661-
662- set_moderations_lib_content (tmp_path , textwrap .dedent (moderation_content ))
663692 text_generation_target_name = "completion"
664- with patch .dict (os .environ , {"TARGET_NAME" : text_generation_target_name }):
665- # Remove any existing cached imports to allow importing the fake guard package.
666- # Existing imports will be there if real moderations library is in python path.
667- sys .modules .pop (MODERATIONS_HOOK_MODULE , None )
668- sys .modules .pop (MODERATIONS_LIBRARY_PACKAGE , None )
669-
693+ with (
694+ patch .dict (os .environ , {"TARGET_NAME" : text_generation_target_name }),
695+ mock_moderation_content (tmp_path , textwrap .dedent (moderation_content )),
696+ ):
670697 adapter = PythonModelAdapter (tmp_path , target_type )
671698 assert adapter ._moderation_pipeline is not None
672699 assert isinstance (adapter ._moderation_pipeline , Mock )
673700 assert score_hook_name in str (adapter ._moderation_score_hook )
674701 # would be nice to check chat_hook, but having a stub function causes other problems
675702 assert adapter ._moderation_chat_hook is None
676703
677- sys .path .remove (str (tmp_path ))
@pytest.mark.parametrize(
    ["target_type"],
    [
        pytest.param(TargetType.TEXT_GENERATION, id="textgen"),
        pytest.param(TargetType.VECTOR_DATABASE, id="vectordb"),
    ],
)
def test_loading_moderations_pipeline(self, target_type, tmp_path):
    """
    Build an adapter while a fake moderations module that defines
    ``moderation_pipeline_factory`` is importable from ``tmp_path``, and check that
    the adapter ends up with a pipeline.
    """
    fake_module_source = textwrap.dedent(
        """
        from unittest.mock import Mock

        def moderation_pipeline_factory(target_type):
            return Mock()
        """
    )
    env_overrides = {"TARGET_NAME": "completion"}
    with (
        patch.dict(os.environ, env_overrides),
        mock_moderation_content(tmp_path, fake_module_source),
    ):
        adapter = PythonModelAdapter(tmp_path, target_type)
        assert adapter._mod_pipeline is not None
678725
679726 @pytest .mark .parametrize (
680727 "guard_hook_present, expected_predictions" ,
@@ -781,12 +828,36 @@ def custom_score(data, model, **kwargs):
781828
782829 df = pd .DataFrame ({"text" : ["abc" , "def" ]})
783830 data = bytes (df .to_csv (index = False ), encoding = "utf-8" )
784- text_generation_target_name = "completion"
785- with patch .dict (os .environ , {"TARGET_NAME" : text_generation_target_name }):
831+ target_name = "completion"
832+ with patch .dict (os .environ , {"TARGET_NAME" : target_name }):
786833 adapter = PythonModelAdapter (tmp_path , TargetType .VECTOR_DATABASE )
787834 adapter ._moderation_pipeline = Mock ()
788835 adapter ._moderation_score_hook = Mock (return_value = df )
789836 adapter ._custom_hooks ["score" ] = custom_score
790837
791838 adapter .predict (binary_data = data )
792839 assert adapter ._moderation_score_hook .call_count == 1
840+
def test_vdb_moderation_pipeline(self, tmp_path):
    """
    Run ``predict`` on a vector-database adapter whose ``_mod_pipeline`` is a counting
    stand-in, and check that the pipeline's ``score`` was invoked exactly once.
    """

    class TestModPipeline:
        """Stand-in pipeline that counts ``score`` calls and delegates to ``score_fn``."""

        def __init__(self):
            self.call_count = 0

        def score(self, data, model, score_fn, **kwargs):
            self.call_count += 1
            return score_fn(data, model, **kwargs)

    def custom_score(data, model, **kwargs):
        """Dummy score method just for the purpose of unit test"""
        return data

    frame = pd.DataFrame({"text": ["abc", "def"]})
    payload = frame.to_csv(index=False).encode("utf-8")
    env_overrides = {"TARGET_NAME": "completion"}
    with patch.dict(os.environ, env_overrides):
        adapter = PythonModelAdapter(tmp_path, TargetType.VECTOR_DATABASE)
        adapter._mod_pipeline = TestModPipeline()
        adapter._custom_hooks["score"] = custom_score

        adapter.predict(binary_data=payload)
        assert adapter._mod_pipeline.call_count == 1
0 commit comments