1414import pandas as pd
1515import sympy as sp
1616
17+ from ..v1 .visualize .lint import validate_visualization_df
18+ from ..v2 .C import *
19+ from .core import PriorDistribution
1720from .problem import Problem
1821
1922logger = logging .getLogger (__name__ )
3740 "CheckUnusedExperiments" ,
3841 "CheckObservablesDoNotShadowModelEntities" ,
3942 "CheckUnusedConditions" ,
43+ "CheckAllObservablesDefined" ,
44+ "CheckPriorDistribution" ,
4045 "lint_problem" ,
4146 "default_validation_tasks" ,
4247]
@@ -77,8 +82,12 @@ def __post_init__(self):
7782 def __str__ (self ):
7883 return f"{ self .level .name } : { self .message } "
7984
80- def _get_task_name (self ):
81- """Get the name of the ValidationTask that raised this error."""
85+ @staticmethod
86+ def _get_task_name () -> str | None :
87+ """Get the name of the ValidationTask that raised this error.
88+
89+ Expected to be called from below a `ValidationTask.run`.
90+ """
8291 import inspect
8392
8493 # walk up the stack until we find the ValidationTask.run method
@@ -88,6 +97,7 @@ def _get_task_name(self):
8897 task = frame .f_locals ["self" ]
8998 if isinstance (task , ValidationTask ):
9099 return task .__class__ .__name__
100+ return None
91101
92102
93103@dataclass
@@ -222,6 +232,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
222232 f"Missing files: { ', ' .join (missing_files )} "
223233 )
224234
235+ return None
236+
225237
226238class CheckModel (ValidationTask ):
227239 """A task to validate the model of a PEtab problem."""
@@ -234,6 +246,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
234246 # TODO get actual model validation messages
235247 return ValidationError ("Model is invalid." )
236248
249+ return None
250+
237251
238252class CheckMeasuredObservablesDefined (ValidationTask ):
239253 """A task to check that all observables referenced by the measurements
@@ -252,10 +266,13 @@ def run(self, problem: Problem) -> ValidationIssue | None:
252266 "measurement table but not defined in observable table."
253267 )
254268
269+ return None
270+
255271
256272class CheckOverridesMatchPlaceholders (ValidationTask ):
257273 """A task to check that the number of observable/noise parameters
258- in the measurements match the number of placeholders in the observables."""
274+ in the measurements matches the number of placeholders in the observables.
275+ """
259276
260277 def run (self , problem : Problem ) -> ValidationIssue | None :
261278 observable_parameters_count = {
@@ -320,18 +337,20 @@ def run(self, problem: Problem) -> ValidationIssue | None:
320337 if messages :
321338 return ValidationError ("\n " .join (messages ))
322339
340+ return None
341+
323342
324343class CheckPosLogMeasurements (ValidationTask ):
325344 """Check that measurements for observables with
326345 log-transformation are positive."""
327346
328347 def run (self , problem : Problem ) -> ValidationIssue | None :
329- from .core import NoiseDistribution as nd
348+ from .core import NoiseDistribution as ND # noqa: N813
330349
331350 log_observables = {
332351 o .id
333352 for o in problem .observable_table .observables
334- if o .noise_distribution in [nd .LOG_NORMAL , nd .LOG_LAPLACE ]
353+ if o .noise_distribution in [ND .LOG_NORMAL , ND .LOG_LAPLACE ]
335354 }
336355 if log_observables :
337356 for m in problem .measurement_table .measurements :
@@ -342,6 +361,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
342361 f"positive, but { m .measurement } <= 0 for { m } "
343362 )
344363
364+ return None
365+
345366
346367class CheckMeasuredExperimentsDefined (ValidationTask ):
347368 """A task to check that all experiments referenced by measurements
@@ -369,6 +390,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
369390 + str (missing_experiments )
370391 )
371392
393+ return None
394+
372395
373396class CheckValidConditionTargets (ValidationTask ):
374397 """Check that all condition table targets are valid."""
@@ -418,6 +441,32 @@ def run(self, problem: Problem) -> ValidationIssue | None:
418441 f"{ invalid } at time { period .time } ."
419442 )
420443 period_targets |= condition_targets
444+ return None
445+
446+
447+ class CheckAllObservablesDefined (ValidationTask ):
448+ """A task to validate that all observables in the measurement table are
449+ defined in the observable table."""
450+
451+ def run (self , problem : Problem ) -> ValidationIssue | None :
452+ if problem .measurement_df is None :
453+ return None
454+
455+ measurement_df = problem .measurement_df
456+ observable_df = problem .observable_df
457+ used_observables = set (measurement_df [OBSERVABLE_ID ].values )
458+ defined_observables = (
459+ set (observable_df .index .values )
460+ if observable_df is not None
461+ else set ()
462+ )
463+ if undefined_observables := (used_observables - defined_observables ):
464+ return ValidationError (
465+ f"Observables { undefined_observables } are used in the"
466+ "measurements table but are not defined in observables table."
467+ )
468+
469+ return None
421470
422471
423472class CheckUniquePrimaryKeys (ValidationTask ):
@@ -429,37 +478,39 @@ def run(self, problem: Problem) -> ValidationIssue | None:
429478
430479 # check for uniqueness of all primary keys
431480 counter = Counter (c .id for c in problem .condition_table .conditions )
432- duplicates = {id for id , count in counter .items () if count > 1 }
481+ duplicates = {id_ for id_ , count in counter .items () if count > 1 }
433482
434483 if duplicates :
435484 return ValidationError (
436485 f"Condition table contains duplicate IDs: { duplicates } "
437486 )
438487
439488 counter = Counter (o .id for o in problem .observable_table .observables )
440- duplicates = {id for id , count in counter .items () if count > 1 }
489+ duplicates = {id_ for id_ , count in counter .items () if count > 1 }
441490
442491 if duplicates :
443492 return ValidationError (
444493 f"Observable table contains duplicate IDs: { duplicates } "
445494 )
446495
447496 counter = Counter (e .id for e in problem .experiment_table .experiments )
448- duplicates = {id for id , count in counter .items () if count > 1 }
497+ duplicates = {id_ for id_ , count in counter .items () if count > 1 }
449498
450499 if duplicates :
451500 return ValidationError (
452501 f"Experiment table contains duplicate IDs: { duplicates } "
453502 )
454503
455504 counter = Counter (p .id for p in problem .parameter_table .parameters )
456- duplicates = {id for id , count in counter .items () if count > 1 }
505+ duplicates = {id_ for id_ , count in counter .items () if count > 1 }
457506
458507 if duplicates :
459508 return ValidationError (
460509 f"Parameter table contains duplicate IDs: { duplicates } "
461510 )
462511
512+ return None
513+
463514
464515class CheckObservablesDoNotShadowModelEntities (ValidationTask ):
465516 """A task to check that observable IDs do not shadow model entities."""
@@ -479,6 +530,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
479530 f"Observable IDs { shadowed_entities } shadow model entities."
480531 )
481532
533+ return None
534+
482535
483536class CheckExperimentTable (ValidationTask ):
484537 """A task to validate the experiment table of a PEtab problem."""
@@ -498,6 +551,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
498551 if messages :
499552 return ValidationError ("\n " .join (messages ))
500553
554+ return None
555+
501556
502557class CheckExperimentConditionsExist (ValidationTask ):
503558 """A task to validate that all conditions in the experiment table exist
@@ -526,6 +581,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
526581 if messages :
527582 return ValidationError ("\n " .join (messages ))
528583
584+ return None
585+
529586
530587class CheckAllParametersPresentInParameterTable (ValidationTask ):
531588 """Ensure all required parameters are contained in the parameter table
@@ -573,6 +630,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
573630 + str (extraneous )
574631 )
575632
633+ return None
634+
576635
577636class CheckValidParameterInConditionOrParameterTable (ValidationTask ):
578637 """A task to check that all required and only allowed model parameters are
@@ -646,9 +705,11 @@ def run(self, problem: Problem) -> ValidationIssue | None:
646705 "the condition table and the parameter table."
647706 )
648707
708+ return None
709+
649710
650711class CheckUnusedExperiments (ValidationTask ):
651- """A task to check for experiments that are not used in the measurements
712+ """A task to check for experiments that are not used in the measurement
652713 table."""
653714
654715 def run (self , problem : Problem ) -> ValidationIssue | None :
@@ -668,9 +729,11 @@ def run(self, problem: Problem) -> ValidationIssue | None:
668729 "measurements table."
669730 )
670731
732+ return None
733+
671734
672735class CheckUnusedConditions (ValidationTask ):
673- """A task to check for conditions that are not used in the experiments
736+ """A task to check for conditions that are not used in the experiment
674737 table."""
675738
676739 def run (self , problem : Problem ) -> ValidationIssue | None :
@@ -692,6 +755,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
692755 "experiments table."
693756 )
694757
758+ return None
759+
695760
696761class CheckVisualizationTable (ValidationTask ):
697762 """A task to validate the visualization table of a PEtab problem."""
@@ -700,14 +765,64 @@ def run(self, problem: Problem) -> ValidationIssue | None:
700765 if problem .visualization_df is None :
701766 return None
702767
703- from ..v1 .visualize .lint import validate_visualization_df
704-
705768 if validate_visualization_df (problem ):
706769 return ValidationIssue (
707770 level = ValidationIssueSeverity .ERROR ,
708771 message = "Visualization table is invalid." ,
709772 )
710773
774+ return None
775+
776+
777+ class CheckPriorDistribution (ValidationTask ):
778+ """A task to validate the prior distribution of a PEtab problem."""
779+
780+ _num_pars = {
781+ PriorDistribution .CAUCHY : 2 ,
782+ PriorDistribution .CHI_SQUARED : 1 ,
783+ PriorDistribution .EXPONENTIAL : 1 ,
784+ PriorDistribution .GAMMA : 2 ,
785+ PriorDistribution .LAPLACE : 2 ,
786+ PriorDistribution .LOG10_NORMAL : 2 ,
787+ PriorDistribution .LOG_LAPLACE : 2 ,
788+ PriorDistribution .LOG_NORMAL : 2 ,
789+ PriorDistribution .LOG_UNIFORM : 2 ,
790+ PriorDistribution .NORMAL : 2 ,
791+ PriorDistribution .RAYLEIGH : 1 ,
792+ PriorDistribution .UNIFORM : 2 ,
793+ }
794+
795+ def run (self , problem : Problem ) -> ValidationIssue | None :
796+ messages = []
797+ for parameter in problem .parameter_table .parameters :
798+ if parameter .prior_distribution is None :
799+ continue
800+
801+ if parameter .prior_distribution not in PRIOR_DISTRIBUTIONS :
802+ messages .append (
803+ f"Prior distribution `{ parameter .prior_distribution } ' "
804+ f"for parameter `{ parameter .id } ' is not valid."
805+ )
806+ continue
807+
808+ if (
809+ exp_num_par := self ._num_pars [parameter .prior_distribution ]
810+ ) != len (parameter .prior_parameters ):
811+ messages .append (
812+ f"Prior distribution `{ parameter .prior_distribution } ' "
813+ f"for parameter `{ parameter .id } ' requires "
814+ f"{ exp_num_par } parameters, but got "
815+ f"{ len (parameter .prior_parameters )} "
816+ f"({ parameter .prior_parameters } )."
817+ )
818+
819+ # TODO: check distribution parameter domains
820+
821+ if messages :
822+ return ValidationError ("\n " .join (messages ))
823+
824+ return None
825+
711826
712827def get_valid_parameters_for_parameter_table (
713828 problem : Problem ,
@@ -752,7 +867,7 @@ def get_valid_parameters_for_parameter_table(
752867 if mapping .model_id and mapping .model_id in parameter_ids .keys ():
753868 parameter_ids [mapping .petab_id ] = None
754869
755- # add output parameters from observables table
870+ # add output parameters from observable table
756871 output_parameters = get_output_parameters (problem )
757872 for p in output_parameters :
758873 if p not in invalid :
@@ -781,7 +896,7 @@ def get_required_parameters_for_parameter_table(
781896 problem : Problem ,
782897) -> Set [str ]:
783898 """
784- Get set of parameters which need to go into the parameter table
899+ Get the set of parameters that need to go into the parameter table
785900
786901 Arguments:
787902 problem: The PEtab problem
@@ -965,4 +1080,9 @@ def get_placeholders(
9651080 # TODO: atomize checks, update to long condition table, re-enable
9661081 # CheckVisualizationTable(),
9671082 # TODO validate mapping table
1083+ CheckValidParameterInConditionOrParameterTable (),
1084+ CheckAllObservablesDefined (),
1085+ CheckAllParametersPresentInParameterTable (),
1086+ CheckValidConditionTargets (),
1087+ CheckPriorDistribution (),
9681088]
0 commit comments