@@ -202,6 +202,24 @@ def loop_check(callback, item):
202202 for each in item :
203203 callback (each )
204204
205+ class CheckInputTypeWrapper (object ):
206+ def __init__ (self , generator , input_types , logger ):
207+ self .generator = generator
208+ self .input_types = input_types
209+ self .logger = logger
210+
211+ def __call__ (self , obj , filename ):
212+ for items in self .generator (obj , filename ):
213+ try :
214+ # dict type is required for input_types when item is dict type
215+ assert (isinstance (items , dict ) and \
216+ not isinstance (self .input_types , dict ))== False
217+ yield items
218+ except AssertionError as e :
219+ self .logger .error (
220+ "%s type is required for input type but got %s" %
221+ (repr (type (items )), repr (type (self .input_types ))))
222+ raise
205223
206224def provider (input_types = None ,
207225 should_shuffle = None ,
@@ -355,6 +373,9 @@ def __init__(self, file_list, **kwargs):
355373 if use_dynamic_order :
356374 self .generator = InputOrderWrapper (self .generator ,
357375 self .input_order )
376+ else :
377+ self .generator = CheckInputTypeWrapper (self .generator , self .slots ,
378+ self .logger )
358379 if self .check :
359380 self .generator = CheckWrapper (self .generator , self .slots ,
360381 check_fail_continue ,
0 commit comments