Skip to content

Commit d54a003

Browse files
committed
v2: More validation
* Check priors * Check observables * Fix missing prior parameters after v1->v2 conversion of uniform priors * Fix style
1 parent 84fbdba commit d54a003

File tree

3 files changed

+155
-15
lines changed

3 files changed

+155
-15
lines changed

petab/v2/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,9 @@ def _validate_id(cls, v):
890890
@field_validator("prior_parameters", mode="before")
891891
@classmethod
892892
def _validate_prior_parameters(cls, v):
893+
if isinstance(v, float) and np.isnan(v):
894+
return []
895+
893896
if isinstance(v, str):
894897
v = v.split(C.PARAMETER_SEPARATOR)
895898
elif not isinstance(v, Sequence):

petab/v2/lint.py

Lines changed: 135 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import pandas as pd
1515
import sympy as sp
1616

17+
from ..v1.visualize.lint import validate_visualization_df
18+
from ..v2.C import *
19+
from .core import PriorDistribution
1720
from .problem import Problem
1821

1922
logger = logging.getLogger(__name__)
@@ -37,6 +40,8 @@
3740
"CheckUnusedExperiments",
3841
"CheckObservablesDoNotShadowModelEntities",
3942
"CheckUnusedConditions",
43+
"CheckAllObservablesDefined",
44+
"CheckPriorDistribution",
4045
"lint_problem",
4146
"default_validation_tasks",
4247
]
@@ -77,8 +82,12 @@ def __post_init__(self):
7782
def __str__(self):
7883
return f"{self.level.name}: {self.message}"
7984

80-
def _get_task_name(self):
81-
"""Get the name of the ValidationTask that raised this error."""
85+
@staticmethod
86+
def _get_task_name() -> str | None:
87+
"""Get the name of the ValidationTask that raised this error.
88+
89+
Expected to be called from below a `ValidationTask.run`.
90+
"""
8291
import inspect
8392

8493
# walk up the stack until we find the ValidationTask.run method
@@ -88,6 +97,7 @@ def _get_task_name(self):
8897
task = frame.f_locals["self"]
8998
if isinstance(task, ValidationTask):
9099
return task.__class__.__name__
100+
return None
91101

92102

93103
@dataclass
@@ -222,6 +232,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
222232
f"Missing files: {', '.join(missing_files)}"
223233
)
224234

235+
return None
236+
225237

226238
class CheckModel(ValidationTask):
227239
"""A task to validate the model of a PEtab problem."""
@@ -234,6 +246,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
234246
# TODO get actual model validation messages
235247
return ValidationError("Model is invalid.")
236248

249+
return None
250+
237251

238252
class CheckMeasuredObservablesDefined(ValidationTask):
239253
"""A task to check that all observables referenced by the measurements
@@ -252,10 +266,13 @@ def run(self, problem: Problem) -> ValidationIssue | None:
252266
"measurement table but not defined in observable table."
253267
)
254268

269+
return None
270+
255271

256272
class CheckOverridesMatchPlaceholders(ValidationTask):
257273
"""A task to check that the number of observable/noise parameters
258-
in the measurements match the number of placeholders in the observables."""
274+
in the measurements matches the number of placeholders in the observables.
275+
"""
259276

260277
def run(self, problem: Problem) -> ValidationIssue | None:
261278
observable_parameters_count = {
@@ -320,18 +337,20 @@ def run(self, problem: Problem) -> ValidationIssue | None:
320337
if messages:
321338
return ValidationError("\n".join(messages))
322339

340+
return None
341+
323342

324343
class CheckPosLogMeasurements(ValidationTask):
325344
"""Check that measurements for observables with
326345
log-transformation are positive."""
327346

328347
def run(self, problem: Problem) -> ValidationIssue | None:
329-
from .core import NoiseDistribution as nd
348+
from .core import NoiseDistribution as ND # noqa: N813
330349

331350
log_observables = {
332351
o.id
333352
for o in problem.observable_table.observables
334-
if o.noise_distribution in [nd.LOG_NORMAL, nd.LOG_LAPLACE]
353+
if o.noise_distribution in [ND.LOG_NORMAL, ND.LOG_LAPLACE]
335354
}
336355
if log_observables:
337356
for m in problem.measurement_table.measurements:
@@ -342,6 +361,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
342361
f"positive, but {m.measurement} <= 0 for {m}"
343362
)
344363

364+
return None
365+
345366

346367
class CheckMeasuredExperimentsDefined(ValidationTask):
347368
"""A task to check that all experiments referenced by measurements
@@ -369,6 +390,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
369390
+ str(missing_experiments)
370391
)
371392

393+
return None
394+
372395

373396
class CheckValidConditionTargets(ValidationTask):
374397
"""Check that all condition table targets are valid."""
@@ -418,6 +441,32 @@ def run(self, problem: Problem) -> ValidationIssue | None:
418441
f"{invalid} at time {period.time}."
419442
)
420443
period_targets |= condition_targets
444+
return None
445+
446+
447+
class CheckAllObservablesDefined(ValidationTask):
448+
"""A task to validate that all observables in the measurement table are
449+
defined in the observable table."""
450+
451+
def run(self, problem: Problem) -> ValidationIssue | None:
452+
if problem.measurement_df is None:
453+
return None
454+
455+
measurement_df = problem.measurement_df
456+
observable_df = problem.observable_df
457+
used_observables = set(measurement_df[OBSERVABLE_ID].values)
458+
defined_observables = (
459+
set(observable_df.index.values)
460+
if observable_df is not None
461+
else set()
462+
)
463+
if undefined_observables := (used_observables - defined_observables):
464+
return ValidationError(
465+
f"Observables {undefined_observables} are used in the"
466+
"measurements table but are not defined in observables table."
467+
)
468+
469+
return None
421470

422471

423472
class CheckUniquePrimaryKeys(ValidationTask):
@@ -429,37 +478,39 @@ def run(self, problem: Problem) -> ValidationIssue | None:
429478

430479
# check for uniqueness of all primary keys
431480
counter = Counter(c.id for c in problem.condition_table.conditions)
432-
duplicates = {id for id, count in counter.items() if count > 1}
481+
duplicates = {id_ for id_, count in counter.items() if count > 1}
433482

434483
if duplicates:
435484
return ValidationError(
436485
f"Condition table contains duplicate IDs: {duplicates}"
437486
)
438487

439488
counter = Counter(o.id for o in problem.observable_table.observables)
440-
duplicates = {id for id, count in counter.items() if count > 1}
489+
duplicates = {id_ for id_, count in counter.items() if count > 1}
441490

442491
if duplicates:
443492
return ValidationError(
444493
f"Observable table contains duplicate IDs: {duplicates}"
445494
)
446495

447496
counter = Counter(e.id for e in problem.experiment_table.experiments)
448-
duplicates = {id for id, count in counter.items() if count > 1}
497+
duplicates = {id_ for id_, count in counter.items() if count > 1}
449498

450499
if duplicates:
451500
return ValidationError(
452501
f"Experiment table contains duplicate IDs: {duplicates}"
453502
)
454503

455504
counter = Counter(p.id for p in problem.parameter_table.parameters)
456-
duplicates = {id for id, count in counter.items() if count > 1}
505+
duplicates = {id_ for id_, count in counter.items() if count > 1}
457506

458507
if duplicates:
459508
return ValidationError(
460509
f"Parameter table contains duplicate IDs: {duplicates}"
461510
)
462511

512+
return None
513+
463514

464515
class CheckObservablesDoNotShadowModelEntities(ValidationTask):
465516
"""A task to check that observable IDs do not shadow model entities."""
@@ -479,6 +530,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
479530
f"Observable IDs {shadowed_entities} shadow model entities."
480531
)
481532

533+
return None
534+
482535

483536
class CheckExperimentTable(ValidationTask):
484537
"""A task to validate the experiment table of a PEtab problem."""
@@ -498,6 +551,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
498551
if messages:
499552
return ValidationError("\n".join(messages))
500553

554+
return None
555+
501556

502557
class CheckExperimentConditionsExist(ValidationTask):
503558
"""A task to validate that all conditions in the experiment table exist
@@ -526,6 +581,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
526581
if messages:
527582
return ValidationError("\n".join(messages))
528583

584+
return None
585+
529586

530587
class CheckAllParametersPresentInParameterTable(ValidationTask):
531588
"""Ensure all required parameters are contained in the parameter table
@@ -573,6 +630,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
573630
+ str(extraneous)
574631
)
575632

633+
return None
634+
576635

577636
class CheckValidParameterInConditionOrParameterTable(ValidationTask):
578637
"""A task to check that all required and only allowed model parameters are
@@ -646,9 +705,11 @@ def run(self, problem: Problem) -> ValidationIssue | None:
646705
"the condition table and the parameter table."
647706
)
648707

708+
return None
709+
649710

650711
class CheckUnusedExperiments(ValidationTask):
651-
"""A task to check for experiments that are not used in the measurements
712+
"""A task to check for experiments that are not used in the measurement
652713
table."""
653714

654715
def run(self, problem: Problem) -> ValidationIssue | None:
@@ -668,9 +729,11 @@ def run(self, problem: Problem) -> ValidationIssue | None:
668729
"measurements table."
669730
)
670731

732+
return None
733+
671734

672735
class CheckUnusedConditions(ValidationTask):
673-
"""A task to check for conditions that are not used in the experiments
736+
"""A task to check for conditions that are not used in the experiment
674737
table."""
675738

676739
def run(self, problem: Problem) -> ValidationIssue | None:
@@ -692,6 +755,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
692755
"experiments table."
693756
)
694757

758+
return None
759+
695760

696761
class CheckVisualizationTable(ValidationTask):
697762
"""A task to validate the visualization table of a PEtab problem."""
@@ -700,14 +765,64 @@ def run(self, problem: Problem) -> ValidationIssue | None:
700765
if problem.visualization_df is None:
701766
return None
702767

703-
from ..v1.visualize.lint import validate_visualization_df
704-
705768
if validate_visualization_df(problem):
706769
return ValidationIssue(
707770
level=ValidationIssueSeverity.ERROR,
708771
message="Visualization table is invalid.",
709772
)
710773

774+
return None
775+
776+
777+
class CheckPriorDistribution(ValidationTask):
778+
"""A task to validate the prior distribution of a PEtab problem."""
779+
780+
_num_pars = {
781+
PriorDistribution.CAUCHY: 2,
782+
PriorDistribution.CHI_SQUARED: 1,
783+
PriorDistribution.EXPONENTIAL: 1,
784+
PriorDistribution.GAMMA: 2,
785+
PriorDistribution.LAPLACE: 2,
786+
PriorDistribution.LOG10_NORMAL: 2,
787+
PriorDistribution.LOG_LAPLACE: 2,
788+
PriorDistribution.LOG_NORMAL: 2,
789+
PriorDistribution.LOG_UNIFORM: 2,
790+
PriorDistribution.NORMAL: 2,
791+
PriorDistribution.RAYLEIGH: 1,
792+
PriorDistribution.UNIFORM: 2,
793+
}
794+
795+
def run(self, problem: Problem) -> ValidationIssue | None:
796+
messages = []
797+
for parameter in problem.parameter_table.parameters:
798+
if parameter.prior_distribution is None:
799+
continue
800+
801+
if parameter.prior_distribution not in PRIOR_DISTRIBUTIONS:
802+
messages.append(
803+
f"Prior distribution `{parameter.prior_distribution}' "
804+
f"for parameter `{parameter.id}' is not valid."
805+
)
806+
continue
807+
808+
if (
809+
exp_num_par := self._num_pars[parameter.prior_distribution]
810+
) != len(parameter.prior_parameters):
811+
messages.append(
812+
f"Prior distribution `{parameter.prior_distribution}' "
813+
f"for parameter `{parameter.id}' requires "
814+
f"{exp_num_par} parameters, but got "
815+
f"{len(parameter.prior_parameters)} "
816+
f"({parameter.prior_parameters})."
817+
)
818+
819+
# TODO: check distribution parameter domains
820+
821+
if messages:
822+
return ValidationError("\n".join(messages))
823+
824+
return None
825+
711826

712827
def get_valid_parameters_for_parameter_table(
713828
problem: Problem,
@@ -752,7 +867,7 @@ def get_valid_parameters_for_parameter_table(
752867
if mapping.model_id and mapping.model_id in parameter_ids.keys():
753868
parameter_ids[mapping.petab_id] = None
754869

755-
# add output parameters from observables table
870+
# add output parameters from observable table
756871
output_parameters = get_output_parameters(problem)
757872
for p in output_parameters:
758873
if p not in invalid:
@@ -781,7 +896,7 @@ def get_required_parameters_for_parameter_table(
781896
problem: Problem,
782897
) -> Set[str]:
783898
"""
784-
Get set of parameters which need to go into the parameter table
899+
Get the set of parameters that need to go into the parameter table
785900
786901
Arguments:
787902
problem: The PEtab problem
@@ -965,4 +1080,9 @@ def get_placeholders(
9651080
# TODO: atomize checks, update to long condition table, re-enable
9661081
# CheckVisualizationTable(),
9671082
# TODO validate mapping table
1083+
CheckValidParameterInConditionOrParameterTable(),
1084+
CheckAllObservablesDefined(),
1085+
CheckAllParametersPresentInParameterTable(),
1086+
CheckValidConditionTargets(),
1087+
CheckPriorDistribution(),
9681088
]

petab/v2/petab1to2.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,4 +455,21 @@ def update_prior(row):
455455
errors="ignore",
456456
)
457457

458+
# if uniform, we need to explicitly set the parameters
459+
def update_prior_pars(row):
460+
prior_type = row.get(v2.C.PRIOR_DISTRIBUTION)
461+
prior_pars = row.get(v2.C.PRIOR_PARAMETERS)
462+
463+
if prior_type not in (v2.C.UNIFORM, v2.C.LOG_UNIFORM) or not pd.isna(
464+
prior_pars
465+
):
466+
return prior_pars
467+
468+
return (
469+
f"{row[v2.C.LOWER_BOUND]}{v2.C.PARAMETER_SEPARATOR}"
470+
f"{row[v2.C.UPPER_BOUND]}"
471+
)
472+
473+
df[v2.C.PRIOR_PARAMETERS] = df.apply(update_prior_pars, axis=1)
474+
458475
return df

0 commit comments

Comments
 (0)