From e2f8914d1aeb30fab41e5df3472e0ac686d0f07d Mon Sep 17 00:00:00 2001 From: Raja Sekhar Rao Dheekonda Date: Tue, 27 Jan 2026 17:03:24 -0800 Subject: [PATCH] feat: add compliance framework tagging for AI red teaming Add OWASP, ATLAS, SAIF, and NIST compliance tags to attacks and transforms. Attacks tagged with core jailbreak technique (LLM01), transforms tagged with specific vulnerability categories. Includes comprehensive test coverage. --- dreadnode/airt/__init__.py | 19 +- dreadnode/airt/attack/base.py | 3 + dreadnode/airt/attack/crescendo.py | 22 +++ dreadnode/airt/attack/goat.py | 23 +++ dreadnode/airt/attack/prompt.py | 15 ++ dreadnode/airt/attack/tap.py | 32 +++- dreadnode/airt/compliance/__init__.py | 215 +++++++++++++++++++++ dreadnode/airt/compliance/atlas.py | 132 +++++++++++++ dreadnode/airt/compliance/nist.py | 63 +++++++ dreadnode/airt/compliance/owasp.py | 86 +++++++++ dreadnode/airt/compliance/saif.py | 69 +++++++ dreadnode/transforms/base.py | 3 + dreadnode/transforms/cipher.py | 41 ++-- dreadnode/transforms/constitutional.py | 27 ++- dreadnode/transforms/encoding.py | 57 +++--- dreadnode/transforms/language.py | 21 ++- dreadnode/transforms/perturbation.py | 83 ++++---- dreadnode/transforms/pii_extraction.py | 23 ++- dreadnode/transforms/refine.py | 16 +- dreadnode/transforms/stylistic.py | 17 +- dreadnode/transforms/substitution.py | 34 +++- dreadnode/transforms/swap.py | 15 +- dreadnode/transforms/text.py | 44 +++-- tests/airt/__init__.py | 1 + tests/airt/test_attack_compliance_tags.py | 120 ++++++++++++ tests/airt/test_compliance.py | 148 +++++++++++++++ tests/test_transform_compliance_tags.py | 218 ++++++++++++++++++++++ 27 files changed, 1432 insertions(+), 115 deletions(-) create mode 100644 dreadnode/airt/compliance/__init__.py create mode 100644 dreadnode/airt/compliance/atlas.py create mode 100644 dreadnode/airt/compliance/nist.py create mode 100644 dreadnode/airt/compliance/owasp.py create mode 100644 dreadnode/airt/compliance/saif.py create mode 100644 tests/airt/__init__.py create mode 100644 tests/airt/test_attack_compliance_tags.py create mode 100644 tests/airt/test_compliance.py create mode 100644 tests/test_transform_compliance_tags.py diff --git a/dreadnode/airt/__init__.py b/dreadnode/airt/__init__.py index 0c4820be..aacb9373 100644 --- a/dreadnode/airt/__init__.py +++ b/dreadnode/airt/__init__.py @@ -1,4 +1,4 @@ -from dreadnode.airt import attack, search +from dreadnode.airt import attack, compliance, search from dreadnode.airt.attack import ( Attack, goat_attack, @@ -9,20 +9,37 @@ tap_attack, zoo_attack, ) +from dreadnode.airt.compliance import ( + ATTACK_MAPPINGS, + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, + tag_transform, +) from dreadnode.airt.target import CustomTarget, LLMTarget, Target __all__ = [ + "ATTACK_MAPPINGS", + "ATLASTechnique", "Attack", "CustomTarget", "LLMTarget", + "NISTAIRMFFunction", + "OWASPCategory", + "SAIFCategory", "Target", "attack", + "compliance", "goat_attack", "hop_skip_jump_attack", "nes_attack", "prompt_attack", "search", "simba_attack", + "tag_attack", + "tag_transform", "tap_attack", "target", "zoo_attack", diff --git a/dreadnode/airt/attack/base.py b/dreadnode/airt/attack/base.py index ac12dbff..ad9610ad 100644 --- a/dreadnode/airt/attack/base.py +++ b/dreadnode/airt/attack/base.py @@ -26,6 +26,9 @@ class Attack(Study[In, Out]): tags: list[str] = Config(default_factory=lambda: ["attack"]) """A list of tags associated with the attack for logging.""" + compliance_tags: dict[str, t.Any] = Config(default_factory=dict) + """Compliance framework tags (OWASP, ATLAS, SAIF, NIST) for this attack.""" + hooks: list[EvalHook] = Field(default_factory=list, exclude=True, repr=False) """Hooks to run at various points in the attack lifecycle.""" diff --git a/dreadnode/airt/attack/crescendo.py b/dreadnode/airt/attack/crescendo.py index a0d74c57..5dd53f0a 100644 --- a/dreadnode/airt/attack/crescendo.py +++ b/dreadnode/airt/attack/crescendo.py @@ -3,6 +3,13 @@ import yaml from dreadnode.airt.attack import Attack +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, +) from dreadnode.airt.target.base import Target from dreadnode.constants import CRESCENDO_VARIANT_1 from dreadnode.data_types.message import Message as DnMessage @@ -14,6 +21,20 @@ from dreadnode.transforms.base import Transform from dreadnode.transforms.refine import adapt_prompt_trials, llm_refine +# Compliance framework tags for Crescendo attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", +) + def crescendo_attack( goal: str, @@ -179,6 +200,7 @@ async def crescendo_refiner(trials: list[Trial[DnMessage]]) -> DnMessage: "objective": objective_judge, }, hooks=hooks or [], + compliance_tags=COMPLIANCE_TAGS, ) # Add stop condition based on early_stopping_score diff --git a/dreadnode/airt/attack/goat.py b/dreadnode/airt/attack/goat.py index 33f15738..6c598ea0 100644 --- a/dreadnode/airt/attack/goat.py +++ b/dreadnode/airt/attack/goat.py @@ -1,6 +1,13 @@ import typing as t from dreadnode.airt.attack import Attack +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, +) from dreadnode.data_types.message import Message as DnMessage from dreadnode.meta.context import TrialCandidate from dreadnode.optimization.search.graph import graph_neighborhood_search @@ -18,6 +25,21 @@ from dreadnode.optimization.trial import Trial +# Compliance framework tags for GOAT attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", +) + + def goat_attack( goal: str, target: "Target[DnMessage, DnMessage]", @@ -121,6 +143,7 @@ async def message_refiner(trials: list["Trial[DnMessage]"]) -> DnMessage: }, constraints=[topic_constraint], hooks=hooks or [], + compliance_tags=COMPLIANCE_TAGS, ) if early_stopping_score is not None: diff --git a/dreadnode/airt/attack/prompt.py b/dreadnode/airt/attack/prompt.py index 28c7f7eb..a33febb9 100644 --- a/dreadnode/airt/attack/prompt.py +++ b/dreadnode/airt/attack/prompt.py @@ -3,6 +3,7 @@ import rigging as rg from dreadnode.airt.attack.base import Attack +from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_attack from dreadnode.data_types.message import Message as DnMessage from dreadnode.meta import TrialCandidate from dreadnode.optimization.search.graph import beam_search @@ -18,6 +19,19 @@ from dreadnode.optimization.trial import Trial +# Compliance framework tags for prompt attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, +) + + def prompt_attack( goal: str, target: "Target[DnMessage, DnMessage]", @@ -117,6 +131,7 @@ async def message_refiner(trials: list["Trial[DnMessage]"]) -> DnMessage: "prompt_judge": prompt_judge, }, hooks=hooks or [], + compliance_tags=COMPLIANCE_TAGS, ) if early_stopping_score is not None: diff --git a/dreadnode/airt/attack/tap.py b/dreadnode/airt/attack/tap.py index 9aa98278..a24e9ca4 100644 --- a/dreadnode/airt/attack/tap.py +++ b/dreadnode/airt/attack/tap.py @@ -2,6 +2,13 @@ from dreadnode.airt.attack import Attack from dreadnode.airt.attack.prompt import prompt_attack +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, +) from dreadnode.data_types.message import Message as DnMessage from dreadnode.scorers.judge import llm_judge @@ -10,6 +17,21 @@ from dreadnode.eval.hooks.base import EvalHook +# Compliance framework tags for TAP attack +# Core jailbreak technique tags - specific vulnerability categories (LLM02, LLM07, etc.) +# are added when transforms targeting those categories are used +COMPLIANCE_TAGS = tag_attack( + atlas=[ + ATLASTechnique.PROMPT_INJECTION_DIRECT, + ATLASTechnique.LLM_JAILBREAK, + ], + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", +) + + def tap_attack( goal: str, target: "Target[DnMessage, DnMessage]", @@ -45,7 +67,7 @@ def tap_attack( topic_constraint = llm_judge(evaluator_model, ON_TOPIC_RUBRIC.format(goal=goal)) - return prompt_attack( + base_attack = prompt_attack( goal, target, attacker_model, @@ -58,7 +80,13 @@ def tap_attack( branching_factor=branching_factor, context_depth=context_depth, hooks=hooks or [], - ).with_(constraints={"on_topic": topic_constraint}) + ) + + # Set compliance tags before cloning + base_attack.compliance_tags = COMPLIANCE_TAGS + + # Add constraint and return + return base_attack.with_(constraints={"on_topic": topic_constraint}) REFINE_GUIDANCE = """\ diff --git a/dreadnode/airt/compliance/__init__.py b/dreadnode/airt/compliance/__init__.py new file mode 100644 index 00000000..bc2b3f7f --- /dev/null +++ b/dreadnode/airt/compliance/__init__.py @@ -0,0 +1,215 @@ +""" +Compliance framework tagging for AI red teaming. + +Maps attacks, transforms, and security tests to industry-standard frameworks: +- MITRE ATLAS: AI/ML attack taxonomy +- OWASP Top 10 for LLM Applications: Security vulnerabilities +- Google SAIF: Secure AI Framework categories +- NIST AI RMF: Risk management functions + +Example: + ```python + import dreadnode as dn + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, tag_attack + + # Tag an attack + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_DIRECT, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + ) + + # Tags appear in run metadata + with dn.run("jailbreak-test", **tags): + result = await my_attack.run() + ``` +""" + +import typing as t + +from dreadnode.airt.compliance.atlas import ATLASTechnique +from dreadnode.airt.compliance.nist import NIST_SUBCATEGORIES, NISTAIRMFFunction +from dreadnode.airt.compliance.owasp import OWASPCategory +from dreadnode.airt.compliance.saif import SAIFCategory + + +def tag_attack( + *, + atlas: ATLASTechnique | list[ATLASTechnique] | None = None, + owasp: OWASPCategory | list[OWASPCategory] | None = None, + saif: SAIFCategory | list[SAIFCategory] | None = None, + nist_function: NISTAIRMFFunction | None = None, + nist_subcategory: str | None = None, +) -> dict[str, t.Any]: + """ + Tag an attack with compliance framework mappings. + + Returns a dictionary suitable for run metadata or span attributes. + All parameters are optional - provide only relevant frameworks. + + Args: + atlas: MITRE ATLAS technique ID(s) + owasp: OWASP LLM Application category/categories + saif: Google SAIF security category/categories + nist_function: NIST AI RMF core function + nist_subcategory: NIST AI RMF subcategory code (e.g., "MS-2.7") + + Returns: + Dictionary with framework tags suitable for run metadata + + Example: + ```python + # Single framework + tags = tag_attack(owasp=OWASPCategory.LLM01_PROMPT_INJECTION) + + # Multiple frameworks + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_DIRECT, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ) + + # Multiple categories from same framework + tags = tag_attack( + owasp=[ + OWASPCategory.LLM01_PROMPT_INJECTION, + OWASPCategory.LLM06_EXCESSIVE_AGENCY, + ] + ) + + # Use in run context + with dn.run("my-attack", **tags): + result = await attack.run() + ``` + """ + tags: dict[str, t.Any] = {} + + if atlas is not None: + atlas_list = [atlas] if isinstance(atlas, (str, ATLASTechnique)) else atlas + tags["atlas_techniques"] = [str(t) for t in atlas_list] + + if owasp is not None: + owasp_list = [owasp] if isinstance(owasp, (str, OWASPCategory)) else owasp + tags["owasp_categories"] = [str(c) for c in owasp_list] + + if saif is not None: + saif_list = [saif] if isinstance(saif, (str, SAIFCategory)) else saif + tags["saif_categories"] = [str(c) for c in saif_list] + + if nist_function is not None: + tags["nist_ai_rmf_function"] = str(nist_function) + if nist_subcategory: + tags["nist_ai_rmf_subcategory"] = nist_subcategory + + return tags + + +def tag_transform( + *, + atlas: ATLASTechnique | list[ATLASTechnique] | None = None, + owasp: OWASPCategory | list[OWASPCategory] | None = None, + saif: SAIFCategory | list[SAIFCategory] | None = None, +) -> dict[str, t.Any]: + """ + Tag a transform with compliance framework mappings. + + Similar to tag_attack() but for transforms. Transforms typically don't + map to NIST RMF functions (which are organizational processes). + + Args: + atlas: MITRE ATLAS technique ID(s) + owasp: OWASP LLM Application category/categories + saif: Google SAIF security category/categories + + Returns: + Dictionary with framework tags + + Example: + ```python + from dreadnode.transforms.pii_extraction import repeat_word_divergence + + # Tags are stored in transform metadata + transform = repeat_word_divergence() + transform.compliance_tags = tag_transform( + atlas=ATLASTechnique.INFER_TRAINING_DATA, + owasp=OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + ) + ``` + """ + return tag_attack(atlas=atlas, owasp=owasp, saif=saif) + + +# Pre-defined mappings for common attack patterns +ATTACK_MAPPINGS = { + "jailbreak": tag_attack( + atlas=ATLASTechnique.LLM_JAILBREAK, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ), + "prompt_injection_direct": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_DIRECT, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ), + "prompt_injection_indirect": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION_INDIRECT, + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM03_SUPPLY_CHAIN], + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ), + "tool_misuse": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM06_EXCESSIVE_AGENCY, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "pii_extraction": tag_attack( + atlas=[ATLASTechnique.MODEL_INVERSION, ATLASTechnique.MEMBERSHIP_INFERENCE], + owasp=OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.8", + ), + "system_prompt_leakage": tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM07_SYSTEM_PROMPT_LEAKAGE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "model_extraction": tag_attack( + atlas=ATLASTechnique.MODEL_EXTRACTION, + saif=SAIFCategory.MODEL_THEFT, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "denial_of_service": tag_attack( + atlas=ATLASTechnique.DENIAL_OF_ML_SERVICE, + owasp=OWASPCategory.LLM10_UNBOUNDED_CONSUMPTION, + saif=SAIFCategory.AVAILABILITY_ATTACKS, + nist_function=NISTAIRMFFunction.MEASURE, + ), + "data_poisoning": tag_attack( + atlas=ATLASTechnique.POISON_TRAINING_DATA, + owasp=OWASPCategory.LLM04_DATA_MODEL_POISONING, + saif=SAIFCategory.DATA_POISONING, + nist_function=NISTAIRMFFunction.MEASURE, + ), +} + + +__all__ = [ + "ATTACK_MAPPINGS", + "NIST_SUBCATEGORIES", + "ATLASTechnique", + "NISTAIRMFFunction", + "OWASPCategory", + "SAIFCategory", + "tag_attack", + "tag_transform", +] diff --git a/dreadnode/airt/compliance/atlas.py b/dreadnode/airt/compliance/atlas.py new file mode 100644 index 00000000..8c782bf0 --- /dev/null +++ b/dreadnode/airt/compliance/atlas.py @@ -0,0 +1,132 @@ +""" +MITRE ATLAS (Adversarial Threat Landscape for Artificial-Intelligence Systems). + +AI/ML-specific attack techniques organized by tactics similar to MITRE ATT&CK. + +Reference: https://atlas.mitre.org/ +""" + +from enum import Enum + + +class ATLASTechnique(str, Enum): + """ + MITRE ATLAS AI/ML Attack Technique IDs. + + Techniques are organized by tactics: Reconnaissance, Resource Development, + Initial Access, Execution, Persistence, Privilege Escalation, Defense Evasion, + Discovery, Collection, ML Attack Staging, Exfiltration, and Impact. + + Reference: https://atlas.mitre.org/techniques/ + """ + + # === Reconnaissance === + DISCOVER_ML_ARTIFACTS = "AML.T0002" + """Discover ML artifacts and metadata in public or target systems.""" + + DISCOVER_MODEL_ONTOLOGY = "AML.T0003" + """Discover model architecture, capabilities, and ontology.""" + + # === Resource Development === + ACQUIRE_PUBLIC_ML_ARTIFACTS = "AML.T0000" + """Acquire public ML artifacts like pre-trained models or datasets.""" + + DEVELOP_ADVERSARIAL_ML_MODEL = "AML.T0001" + """Develop models designed to evade or attack target systems.""" + + POISON_TRAINING_DATA = "AML.T0020" + """Inject malicious data into training sets to corrupt model behavior.""" + + # === Initial Access === + PROMPT_INJECTION = "AML.T0051" + """Manipulate LLM inputs to override instructions or execute unintended actions.""" + + PROMPT_INJECTION_DIRECT = "AML.T0051.000" + """Direct prompt injection via user-controlled input.""" + + PROMPT_INJECTION_INDIRECT = "AML.T0051.001" + """Indirect prompt injection via external data sources (emails, documents, web).""" + + SUPPLY_CHAIN_COMPROMISE = "AML.T0010" + """Compromise ML supply chain through malicious models, datasets, or dependencies.""" + + # === Execution === + UNSAFE_ML_ARTIFACT = "AML.T0018" + """Execute unsafe ML artifacts like poisoned models or malicious code.""" + + # === Persistence === + BACKDOOR_ML_MODEL = "AML.T0019" + """Embed backdoors in ML models that activate on specific triggers.""" + + # === Privilege Escalation === + # (Uses techniques from other tactics) + + # === Defense Evasion === + EVADE_ML_MODEL = "AML.T0043" + """Craft inputs that evade detection or classification by ML models.""" + + ADVERSARIAL_PERTURBATION = "AML.T0043.001" + """Add imperceptible perturbations to inputs to cause misclassification.""" + + TRANSFER_ATTACK = "AML.T0043.002" + """Transfer adversarial examples from surrogate to target model.""" + + OBFUSCATE_ARTIFACTS = "AML.T0044" + """Obfuscate malicious content to evade ML-based detection.""" + + # === Credential Access === + # (Uses techniques from traditional ATT&CK) + + # === Discovery === + DISCOVER_TRAINING_DATA = "AML.T0052" + """Infer characteristics or contents of training data.""" + + DISCOVER_MODEL_FAMILY = "AML.T0053" + """Determine model architecture family (transformer, CNN, etc.).""" + + # === Lateral Movement === + # (Uses techniques from traditional ATT&CK) + + # === Collection === + INFER_TRAINING_DATA = "AML.T0024" + """Extract or infer training data through model inversion or membership inference.""" + + MODEL_INVERSION = "AML.T0024.000" + """Reconstruct training data from model outputs (e.g., recover faces, text).""" + + MEMBERSHIP_INFERENCE = "AML.T0024.001" + """Determine if specific data was in the training set.""" + + # === ML Attack Staging === + CRAFT_ADVERSARIAL_DATA = "AML.T0049" + """Generate adversarial examples optimized to fool target models.""" + + VERIFY_ATTACK = "AML.T0042" + """Test adversarial inputs against surrogate or target models.""" + + # === Command and Control === + # (Uses techniques from traditional ATT&CK) + + # === Exfiltration === + EXFILTRATION_VIA_ML_INFERENCE = "AML.T0026" + """Extract sensitive data through repeated model queries and inference.""" + + MODEL_EXTRACTION = "AML.T0040" + """Steal model functionality by querying and replicating behavior.""" + + # === Impact === + ERODE_ML_MODEL_INTEGRITY = "AML.T0048" + """Degrade model accuracy, fairness, or reliability.""" + + LLM_JAILBREAK = "AML.T0054" + """Bypass LLM safety mechanisms to generate prohibited content.""" + + DENIAL_OF_ML_SERVICE = "AML.T0029" + """Exhaust model resources through adversarial queries or sponge examples.""" + + def __str__(self) -> str: + """Return the technique ID.""" + return self.value + + +__all__ = ["ATLASTechnique"] diff --git a/dreadnode/airt/compliance/nist.py b/dreadnode/airt/compliance/nist.py new file mode 100644 index 00000000..8e21731a --- /dev/null +++ b/dreadnode/airt/compliance/nist.py @@ -0,0 +1,63 @@ +""" +NIST AI Risk Management Framework (AI RMF). + +Risk management functions and categories for AI systems. + +Reference: https://www.nist.gov/itl/ai-risk-management-framework +""" + +from enum import Enum + + +class NISTAIRMFFunction(str, Enum): + """ + NIST AI Risk Management Framework Core Functions. + + The AI RMF organizes risk management activities into four core functions + that work together to manage AI risks throughout the system lifecycle. + + Reference: https://www.nist.gov/itl/ai-risk-management-framework + """ + + GOVERN = "GOVERN" + """ + Govern: Cultivate and manage organizational culture, processes, and structures + for responsible AI development and deployment. Includes policies, accountability, + and risk governance. + """ + + MAP = "MAP" + """ + Map: Establish context and understand risks. Includes categorizing AI systems, + identifying stakeholders, and mapping potential risks and impacts. + """ + + MEASURE = "MEASURE" + """ + Measure: Analyze, assess, benchmark, and monitor AI risks and impacts. + Includes testing, evaluation, auditing, and continuous monitoring. + """ + + MANAGE = "MANAGE" + """ + Manage: Allocate resources to prioritize and respond to AI risks. Includes + risk mitigation, treatment, incident response, and continuous improvement. + """ + + def __str__(self) -> str: + """Return the function name.""" + return self.value + + +# Common NIST AI RMF subcategories for reference +NIST_SUBCATEGORIES = { + "MS-2.7": "AI system reliability and robustness under adversarial conditions", + "MS-2.8": "Privacy risks from AI systems", + "MS-2.9": "Security vulnerabilities in AI systems", + "MG-3.1": "AI risks are prioritized and treated", + "MG-3.2": "Adverse events are documented and monitored", + "GV-1.1": "Legal and regulatory requirements are understood and documented", +} + + +__all__ = ["NIST_SUBCATEGORIES", "NISTAIRMFFunction"] diff --git a/dreadnode/airt/compliance/owasp.py b/dreadnode/airt/compliance/owasp.py new file mode 100644 index 00000000..b4088625 --- /dev/null +++ b/dreadnode/airt/compliance/owasp.py @@ -0,0 +1,86 @@ +""" +OWASP Top 10 for LLM Applications 2025. + +Reference: https://genai.owasp.org/llm-top-10/ +""" + +from enum import Enum + + +class OWASPCategory(str, Enum): + """ + OWASP Top 10 for LLM Applications 2025. + + Each category represents a critical security vulnerability class specific + to Large Language Model applications. + + Reference: https://genai.owasp.org/llm-top-10/ + """ + + LLM01_PROMPT_INJECTION = "LLM01:2025" + """ + Prompt Injection: Manipulating LLM inputs to override system instructions, + execute unintended actions, or access unauthorized data. Includes both direct + (user input) and indirect (external data sources) injection vectors. + """ + + LLM02_SENSITIVE_INFORMATION_DISCLOSURE = "LLM02:2025" + """ + Sensitive Information Disclosure: Exposing confidential data through LLM outputs, + including PII, credentials, proprietary information, or training data memorization. + """ + + LLM03_SUPPLY_CHAIN = "LLM03:2025" + """ + Supply Chain Vulnerabilities: Risks from third-party models, datasets, plugins, + or dependencies that may be compromised, outdated, or malicious. + """ + + LLM04_DATA_MODEL_POISONING = "LLM04:2025" + """ + Data and Model Poisoning: Manipulation of training data or fine-tuning processes + to inject backdoors, biases, or vulnerabilities into the model. + """ + + LLM05_IMPROPER_OUTPUT_HANDLING = "LLM05:2025" + """ + Improper Output Handling: Insufficient validation of LLM outputs before downstream + use, leading to injection attacks (XSS, SQL injection) or code execution. + """ + + LLM06_EXCESSIVE_AGENCY = "LLM06:2025" + """ + Excessive Agency: LLM systems with too much autonomy or permissions, enabling + unintended actions, privilege escalation, or unauthorized system modifications. + """ + + LLM07_SYSTEM_PROMPT_LEAKAGE = "LLM07:2025" + """ + System Prompt Leakage: Disclosure of system prompts, instructions, or configuration + details that reveal security mechanisms or enable targeted attacks. + """ + + LLM08_VECTOR_EMBEDDING_WEAKNESSES = "LLM08:2025" + """ + Vector and Embedding Weaknesses: Vulnerabilities in RAG systems, vector databases, + or embedding models that enable data poisoning or unauthorized access. + """ + + LLM09_MISINFORMATION = "LLM09:2025" + """ + Misinformation: Generation of false, misleading, or fabricated information + (hallucinations) that appears credible but lacks factual grounding. + """ + + LLM10_UNBOUNDED_CONSUMPTION = "LLM10:2025" + """ + Unbounded Consumption: Resource exhaustion through excessive LLM requests, + context window abuse, or denial-of-service attacks targeting inference costs. + """ + + def __str__(self) -> str: + """Return the category ID.""" + return self.value + + +__all__ = ["OWASPCategory"] diff --git a/dreadnode/airt/compliance/saif.py b/dreadnode/airt/compliance/saif.py new file mode 100644 index 00000000..c2ef18d5 --- /dev/null +++ b/dreadnode/airt/compliance/saif.py @@ -0,0 +1,69 @@ +""" +Google SAIF (Secure AI Framework). + +Security categories for AI/ML systems aligned with Google's security principles. + +Reference: https://blog.google/technology/safety-security/google-secure-ai-framework/ +""" + +from enum import Enum + + +class SAIFCategory(str, Enum): + """ + Google SAIF (Secure AI Framework) Security Categories. + + Organizes AI security risks into actionable categories aligned with + traditional security controls and threat modeling. + + Reference: https://blog.google/technology/safety-security/google-secure-ai-framework/ + """ + + INPUT_MANIPULATION = "INPUT_MANIPULATION" + """ + Input Manipulation: Adversarial inputs designed to manipulate model behavior, + including prompt injection, adversarial examples, and input perturbations. + """ + + OUTPUT_MANIPULATION = "OUTPUT_MANIPULATION" + """ + Output Manipulation: Attacks targeting model outputs, including response + poisoning, hallucination exploitation, and output handling vulnerabilities. + """ + + MODEL_THEFT = "MODEL_THEFT" + """ + Model Theft: Stealing model functionality or intellectual property through + model extraction, knowledge distillation, or architecture inference. + """ + + DATA_POISONING = "DATA_POISONING" + """ + Data Poisoning: Corruption of training data to inject backdoors, biases, + or vulnerabilities into the model during training or fine-tuning. + """ + + SUPPLY_CHAIN_COMPROMISE = "SUPPLY_CHAIN_COMPROMISE" + """ + Supply Chain Compromise: Attacks targeting the ML supply chain including + malicious dependencies, poisoned datasets, or compromised pre-trained models. + """ + + PRIVACY_LEAKAGE = "PRIVACY_LEAKAGE" + """ + Privacy Leakage: Disclosure of sensitive information through model outputs, + including PII extraction, training data memorization, and membership inference. + """ + + AVAILABILITY_ATTACKS = "AVAILABILITY_ATTACKS" + """ + Availability Attacks: Denial of service, resource exhaustion, or system + degradation through adversarial queries or sponge examples. + """ + + def __str__(self) -> str: + """Return the category name.""" + return self.value + + +__all__ = ["SAIFCategory"] diff --git a/dreadnode/transforms/base.py b/dreadnode/transforms/base.py index c1e8619c..e41dae74 100644 --- a/dreadnode/transforms/base.py +++ b/dreadnode/transforms/base.py @@ -43,6 +43,7 @@ def __init__( catch: bool = False, config: dict[str, ConfigInfo] | None = None, context: dict[str, Context] | None = None, + compliance_tags: dict[str, t.Any] | None = None, ): super().__init__( t.cast("t.Callable[[In], Out]", func), name=name, config=config, context=context @@ -55,6 +56,8 @@ def __init__( If True, catches exceptions during the transform and attempts to return the original, unmodified object from the input. If False, exceptions are raised. """ + self.compliance_tags = compliance_tags or {} + """Compliance framework tags (OWASP, ATLAS, SAIF, NIST) for this transform.""" @classmethod def fit(cls, transform: "TransformLike[In, Out]") -> "Transform[In, Out]": diff --git a/dreadnode/transforms/cipher.py b/dreadnode/transforms/cipher.py index d56a45ca..4e2bd365 100644 --- a/dreadnode/transforms/cipher.py +++ b/dreadnode/transforms/cipher.py @@ -1,4 +1,5 @@ import codecs +import functools import random import string import typing as t @@ -7,6 +8,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def atbash_cipher(*, name: str = "atbash") -> Transform[str, str]: """Encodes text using the Atbash cipher.""" @@ -19,7 +32,7 @@ def transform(text: str) -> str: translation_table = str.maketrans("".join(alphabet), "".join(reversed_alphabet)) return text.translate(translation_table) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def caesar_cipher(offset: int, *, name: str = "caesar") -> Transform[str, str]: @@ -39,7 +52,7 @@ def shift(alphabet: str) -> str: translation_table = str.maketrans("".join(alphabet), "".join(shifted_alphabet)) return text.translate(translation_table) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def rot13_cipher(*, name: str = "rot13") -> Transform[str, str]: @@ -48,7 +61,7 @@ def rot13_cipher(*, name: str = "rot13") -> Transform[str, str]: def transform(text: str) -> str: return codecs.encode(text, "rot13") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def rot47_cipher(*, name: str = "rot47") -> Transform[str, str]: @@ -67,7 +80,7 @@ def transform(text: str) -> str: transformed.append(char) return "".join(transformed) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def vigenere_cipher( @@ -115,7 +128,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def substitution_cipher( @@ -155,7 +168,7 @@ def transform(text: str) -> str: ) return text.translate(translation_table) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def xor_cipher( @@ -201,7 +214,7 @@ def transform( # raw return xored.decode("latin-1") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def rail_fence_cipher( @@ -246,7 +259,7 @@ def transform( # Read off the rails return "".join("".join(rail) for rail in fence) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def columnar_transposition( @@ -297,7 +310,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def playfair_cipher( @@ -379,7 +392,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def affine_cipher( @@ -428,7 +441,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def bacon_cipher( @@ -522,7 +535,7 @@ def transform( return " ".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def autokey_cipher( @@ -571,7 +584,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def beaufort_cipher( @@ -622,4 +635,4 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) diff --git a/dreadnode/transforms/constitutional.py b/dreadnode/transforms/constitutional.py index ac92525b..2e250d46 100644 --- a/dreadnode/transforms/constitutional.py +++ b/dreadnode/transforms/constitutional.py @@ -22,6 +22,7 @@ - Hybrid: Static fallback to LLM for unknown terms """ +import functools import json import random import re @@ -34,6 +35,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_constitutional_evasion_tags() -> dict[str, t.Any]: + """Get compliance tags for constitutional_evasion transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=[ATLASTechnique.EVADE_ML_MODEL, ATLASTechnique.OBFUSCATE_ARTIFACTS], + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM05_IMPROPER_OUTPUT_HANDLING], + saif=[SAIFCategory.INPUT_MANIPULATION, SAIFCategory.OUTPUT_MANIPULATION], + ) + + @dataclass class MetaphorMapping: """Mapping of technical terms to metaphorical language.""" @@ -431,7 +444,7 @@ def transform(text: str) -> str: return "\n".join(code_lines) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) def _format_as_markdown(fragments: list[str], labels: list[str], *, add_metadata: bool) -> str: @@ -579,7 +592,7 @@ def transform(text: str) -> str: return text - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) def multi_turn_fragmentation( @@ -627,7 +640,7 @@ def transform(text: str) -> list[str]: return turns - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) # ============================================================================ @@ -777,7 +790,7 @@ async def transform(text: str) -> str: return encoded - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) def character_separation( @@ -819,7 +832,7 @@ def transform(text: str) -> str: return separated - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) # Static riddle templates for riddle_encoding @@ -897,7 +910,7 @@ async def transform(text: str) -> str: # Generic riddle return f"Solve this puzzle: {text}\n(Express your answer in metaphorical terms)" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) # Default chemistry substitution mappings for contextual_substitution @@ -972,4 +985,4 @@ def transform(text: str) -> str: return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_constitutional_evasion_tags()) diff --git a/dreadnode/transforms/encoding.py b/dreadnode/transforms/encoding.py index 06699b3a..93d9894b 100644 --- a/dreadnode/transforms/encoding.py +++ b/dreadnode/transforms/encoding.py @@ -1,4 +1,5 @@ import base64 +import functools import html import json import random @@ -9,13 +10,25 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def ascii85_encode(*, name: str = "ascii85") -> Transform[str, str]: """Encodes text to ASCII85.""" def transform(text: str) -> str: return base64.a85encode(text.encode("utf-8")).decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base32_encode(*, name: str = "base32") -> Transform[str, str]: @@ -24,7 +37,7 @@ def base32_encode(*, name: str = "base32") -> Transform[str, str]: def transform(text: str) -> str: return base64.b32encode(text.encode("utf-8")).decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base64_encode(*, name: str = "base64") -> Transform[str, str]: @@ -33,7 +46,7 @@ def base64_encode(*, name: str = "base64") -> Transform[str, str]: def transform(text: str) -> str: return base64.b64encode(text.encode("utf-8")).decode("utf-8") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def binary_encode(bits_per_char: int = 16, *, name: str = "binary") -> Transform[str, str]: @@ -52,7 +65,7 @@ def transform( ) return " ".join(format(ord(char), f"0{bits_per_char}b") for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def hex_encode(*, name: str = "hex") -> Transform[str, str]: @@ -61,7 +74,7 @@ def hex_encode(*, name: str = "hex") -> Transform[str, str]: def transform(text: str) -> str: return text.encode("utf-8").hex().upper() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def html_escape(*, name: str = "html_escape") -> Transform[str, str]: @@ -70,7 +83,7 @@ def html_escape(*, name: str = "html_escape") -> Transform[str, str]: def transform(text: str) -> str: return html.escape(text, quote=True) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def url_encode(*, name: str = "url_encode") -> Transform[str, str]: @@ -79,7 +92,7 @@ def url_encode(*, name: str = "url_encode") -> Transform[str, str]: def transform(text: str) -> str: return urllib.parse.quote(text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def unicode_escape( @@ -122,7 +135,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def json_encode( @@ -148,7 +161,7 @@ def transform( ) -> str: return json.dumps(text, ensure_ascii=ensure_ascii) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def punycode_encode(*, name: str = "punycode") -> Transform[str, str]: @@ -161,7 +174,7 @@ def punycode_encode(*, name: str = "punycode") -> Transform[str, str]: def transform(text: str) -> str: return text.encode("punycode").decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def quoted_printable_encode(*, name: str = "quoted_printable") -> Transform[str, str]: @@ -175,7 +188,7 @@ def quoted_printable_encode(*, name: str = "quoted_printable") -> Transform[str, def transform(text: str) -> str: return quopri.encodestring(text.encode("utf-8")).decode("ascii") - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base58_encode(*, name: str = "base58") -> Transform[str, str]: @@ -210,7 +223,7 @@ def transform(text: str) -> str: return "".join(reversed(result)) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def percent_encoding( @@ -241,7 +254,7 @@ def transform( encoded = urllib.parse.quote(encoded, safe="") return encoded - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def html_entity_encode( @@ -285,7 +298,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def octal_encode(*, name: str = "octal") -> Transform[str, str]: @@ -298,7 +311,7 @@ def octal_encode(*, name: str = "octal") -> Transform[str, str]: def transform(text: str) -> str: return "".join(f"\\{ord(char):03o}" for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def utf7_encode(*, name: str = "utf7") -> Transform[str, str]: @@ -323,7 +336,7 @@ def transform(text: str) -> str: result.append(f"+{base64.b64encode(bytes([byte])).decode('ascii').rstrip('=')}-") return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def base91_encode(*, name: str = "base91") -> Transform[str, str]: @@ -366,7 +379,7 @@ def transform(text: str) -> str: return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def mixed_case_hex(*, name: str = "mixed_case_hex") -> Transform[str, str]: @@ -386,7 +399,7 @@ def transform(text: str) -> str: result.append(mixed) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def backslash_escape( @@ -417,7 +430,7 @@ def transform( result.append(char) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def zero_width_encode( @@ -471,7 +484,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def leetspeak_encoding( @@ -526,7 +539,7 @@ def transform( result.append(char) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def morse_encode( @@ -595,4 +608,4 @@ def transform( return " ".join(morse_chars) return "".join(morse_chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) diff --git a/dreadnode/transforms/language.py b/dreadnode/transforms/language.py index 1151892f..92c16b4e 100644 --- a/dreadnode/transforms/language.py +++ b/dreadnode/transforms/language.py @@ -1,3 +1,4 @@ +import functools import typing as t import rigging as rg @@ -7,6 +8,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_style_manipulation_tags() -> dict[str, t.Any]: + """Get compliance tags for style_manipulation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def adapt_language( target_language: str, *, @@ -141,7 +154,7 @@ async def transform( return adapted_text - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def transliterate( @@ -409,7 +422,7 @@ def transform( result.append(char) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def code_switch( @@ -515,7 +528,7 @@ async def transform( return result_text.strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def dialectal_variation( @@ -616,4 +629,4 @@ async def transform( return result_text.strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) diff --git a/dreadnode/transforms/perturbation.py b/dreadnode/transforms/perturbation.py index 9ab17f9b..02df8d2c 100644 --- a/dreadnode/transforms/perturbation.py +++ b/dreadnode/transforms/perturbation.py @@ -1,3 +1,4 @@ +import functools import random import re import string @@ -10,6 +11,18 @@ from dreadnode.util import catch_import_error +@functools.lru_cache(maxsize=1) +def _get_perturbation_tags() -> dict[str, t.Any]: + """Get compliance tags for perturbation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.ADVERSARIAL_PERTURBATION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def random_capitalization( *, ratio: float = 0.2, @@ -45,7 +58,7 @@ def transform( chars[i] = chars[i].upper() return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def insert_punctuation( @@ -95,7 +108,7 @@ def transform( words[i] = words[i] + punc return " ".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def diacritic( @@ -133,7 +146,9 @@ def transform( for char in text ) - return Transform(transform, name=name or f"diacritic_{accent}") + return Transform( + transform, name=name or f"diacritic_{accent}", compliance_tags=_get_perturbation_tags() + ) def underline(*, name: str = "underline") -> Transform[str, str]: @@ -142,7 +157,7 @@ def underline(*, name: str = "underline") -> Transform[str, str]: def transform(text: str) -> str: return "".join(char + "\u0332" for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def character_space(*, name: str = "character_space") -> Transform[str, str]: @@ -153,7 +168,7 @@ def transform(text: str) -> str: text_no_punc = text.translate(punctuation_to_remove) return " ".join(text_no_punc) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def zero_width(*, name: str = "zero_width") -> Transform[str, str]: @@ -162,7 +177,7 @@ def zero_width(*, name: str = "zero_width") -> Transform[str, str]: def transform(text: str) -> str: return "\u200b".join(text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def zalgo( @@ -214,7 +229,9 @@ def transform( return "".join(chars) - return Transform(transform, name=name or f"zalgo_{intensity}") + return Transform( + transform, name=name or f"zalgo_{intensity}", compliance_tags=_get_perturbation_tags() + ) def unicode_confusable( @@ -269,7 +286,7 @@ def transform( chars[i] = rand.choice(replacement_options) return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def unicode_replacement( @@ -289,7 +306,7 @@ def transform(text: str) -> str: result = result.replace("\\u0020", " ") return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def unicode_substitution( @@ -306,7 +323,7 @@ def unicode_substitution( def transform(text: str) -> str: return "".join(chr(start_value + ord(ch)) for ch in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def repeat_token( @@ -355,7 +372,7 @@ def transform( return f"{text}{repeated}" return repeated - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def emoji_substitution( @@ -469,7 +486,7 @@ def transform( return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def token_smuggling( @@ -507,7 +524,7 @@ def transform( # both return f"{smuggled_text}{separator}{text}{separator}{smuggled_text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def semantic_preserving_perturbation( @@ -569,7 +586,7 @@ def transform( result = "".join(words) return re.sub(r'\s([?.!,\'"])', r"\1", result).strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def instruction_hierarchy_confusion( @@ -608,7 +625,7 @@ def transform( hidden = f"\n\n" return f"{text}{hidden}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def context_overflow( @@ -638,7 +655,7 @@ def transform( overflow = filler_text * repetitions return f"{overflow}\n\n{text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def adversarial_suffix( @@ -664,7 +681,7 @@ def transform( ) -> str: return f"{text} {suffix}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def gradient_based_perturbation( @@ -721,7 +738,7 @@ def transform( return "".join(chars) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def multilingual_mixing( @@ -799,7 +816,7 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def cognitive_hacking( @@ -844,7 +861,7 @@ def transform( template = frames[framing] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def payload_splitting( @@ -895,7 +912,7 @@ def transform( return delimiter.join(parts) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def attention_diversion( @@ -930,7 +947,7 @@ def transform( # sandwich return f"{distractor_text}\n\n{text}\n\n{distractor_text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def style_injection( @@ -1008,7 +1025,7 @@ def transform( return " ".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def implicit_continuation( @@ -1034,7 +1051,7 @@ def transform( ) -> str: return f"{continuation_prompt} {text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def authority_exploitation( @@ -1070,7 +1087,7 @@ def transform( template = authority_frames[authority_source] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def linguistic_camouflage( @@ -1139,7 +1156,7 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def temporal_misdirection( @@ -1174,7 +1191,7 @@ def transform( template = temporal_frames[tense] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def complexity_amplification( @@ -1217,7 +1234,7 @@ def transform( return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def error_injection( @@ -1289,7 +1306,7 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def encoding_nesting( @@ -1330,7 +1347,7 @@ def transform(text: str) -> str: result = apply_encoding(result, encoding) return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def token_boundary_manipulation( @@ -1393,7 +1410,7 @@ def transform( return " ".join(w for w in words if w) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def meta_instruction_injection( @@ -1431,7 +1448,7 @@ def transform( encoded = base64.b64encode(meta_instruction.encode()).decode() return f"[Encoded]: {encoded}\n\n{text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) def sentiment_inversion( @@ -1489,4 +1506,4 @@ def transform( return "".join(words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_perturbation_tags()) diff --git a/dreadnode/transforms/pii_extraction.py b/dreadnode/transforms/pii_extraction.py index 0e8551c7..6e2c97bd 100644 --- a/dreadnode/transforms/pii_extraction.py +++ b/dreadnode/transforms/pii_extraction.py @@ -1,8 +1,21 @@ +import functools import typing as t from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_pii_extraction_tags() -> dict[str, t.Any]: + """Get compliance tags for PII extraction transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=[ATLASTechnique.INFER_TRAINING_DATA, ATLASTechnique.MODEL_INVERSION], + owasp=OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE, + saif=SAIFCategory.PRIVACY_LEAKAGE, + ) + + def repeat_word_divergence( word: str = "company", *, @@ -38,7 +51,7 @@ def transform( repeated = " ".join([word] * repetitions) return f"{text} {repeated}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def continue_exact_text( @@ -75,7 +88,7 @@ def transform( return f"{instruction}\n\n{prefix}\n\nContext: {text}" return f"{instruction}\n\n{prefix}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def complete_from_internet( @@ -113,7 +126,7 @@ def transform( return f"{context}\n\n{partial_text}\n\nAdditional context: {text}" return f"{context}\n\n{partial_text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def partial_pii_completion( @@ -170,7 +183,7 @@ def transform( return f"{prompt}\n\nContext: {text}" return prompt - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) def public_figure_pii_probe( @@ -224,4 +237,4 @@ def transform( return f"{prompt}\n\nAdditional context: {text}" return prompt - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_pii_extraction_tags()) diff --git a/dreadnode/transforms/refine.py b/dreadnode/transforms/refine.py index 822affe8..d17c6103 100644 --- a/dreadnode/transforms/refine.py +++ b/dreadnode/transforms/refine.py @@ -1,3 +1,4 @@ +import functools import typing as t from collections import defaultdict from textwrap import dedent, indent @@ -9,6 +10,19 @@ from dreadnode.meta import Config from dreadnode.transforms.base import Transform + +@functools.lru_cache(maxsize=1) +def _get_refinement_tags() -> dict[str, t.Any]: + """Get compliance tags for refinement transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.CRAFT_ADVERSARIAL_DATA, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + if t.TYPE_CHECKING: from ulid import ULID @@ -73,7 +87,7 @@ async def transform( refinement = await refine.bind(generator)(refiner_input) return refinement.prompt - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_refinement_tags()) def adapt_prompt_trials(trials: "list[Trial[DnMessage]]") -> str: diff --git a/dreadnode/transforms/stylistic.py b/dreadnode/transforms/stylistic.py index 79608352..b7c7993e 100644 --- a/dreadnode/transforms/stylistic.py +++ b/dreadnode/transforms/stylistic.py @@ -1,3 +1,4 @@ +import functools import typing as t from dreadnode.meta import Config @@ -5,6 +6,18 @@ from dreadnode.util import catch_import_error +@functools.lru_cache(maxsize=1) +def _get_style_manipulation_tags() -> dict[str, t.Any]: + """Get compliance tags for style_manipulation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def ascii_art(font: str = "rand", *, name: str = "ascii_art") -> Transform[str, str]: """Converts text into ASCII art using the 'art' library.""" @@ -14,7 +27,7 @@ def ascii_art(font: str = "rand", *, name: str = "ascii_art") -> Transform[str, def transform(text: str, *, font: str = Config(font, help="The font to use")) -> str: return str(text2art(text, font=font)) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) def role_play_wrapper( @@ -69,4 +82,4 @@ def transform( } return templates[scenario] - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_style_manipulation_tags()) diff --git a/dreadnode/transforms/substitution.py b/dreadnode/transforms/substitution.py index 1ab5fa4e..902c1856 100644 --- a/dreadnode/transforms/substitution.py +++ b/dreadnode/transforms/substitution.py @@ -1,9 +1,23 @@ +import functools import random import re import typing as t from dreadnode.transforms.base import Transform + +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + # ruff: noqa: RUF001 @@ -57,7 +71,7 @@ def get_replacement(item: str) -> str: result = " ".join(substituted_words) return re.sub(r'\s([?.!,"\'`])', r"\1", result).strip() - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -87,7 +101,7 @@ def transform(text: str) -> str: result.append(BRAILLE_MAP.get(char, char)) return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -192,7 +206,7 @@ def transform(text: str) -> str: i += 1 return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -227,7 +241,7 @@ def transform(text: str) -> str: i += 1 return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -293,7 +307,7 @@ def small_caps(*, name: str = "small_caps") -> Transform[str, str]: def transform(text: str) -> str: return "".join(SMALL_CAPS_MAP.get(char.lower(), char) for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -313,7 +327,7 @@ def wingdings(*, name: str = "wingdings") -> Transform[str, str]: def transform(text: str) -> str: return "".join(WINGDINGS_MAP.get(char.upper(), char) for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -339,7 +353,7 @@ def transform(text: str) -> str: text_clean = " ".join([line.strip() for line in str.splitlines(text)]) return " ".join([MORSE_MAP.get(char, MORSE_ERROR) for char in text_clean.upper()]) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -361,7 +375,7 @@ def nato_phonetic(*, name: str = "nato_phonetic") -> Transform[str, str]: def transform(text: str) -> str: return " ".join(NATO_MAP.get(char.upper(), char) for char in text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -385,7 +399,7 @@ def transform(text: str) -> str: reversed_text = text[::-1] return "".join(MIRROR_MAP.get(char, char) for char in reversed_text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) # fmt: off @@ -431,4 +445,4 @@ def transform(text: str) -> str: words = re.findall(r"\w+|[^\w\s]", text) return "".join(_to_pig_latin_word(word) for word in words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) diff --git a/dreadnode/transforms/swap.py b/dreadnode/transforms/swap.py index 2bfbf6ce..13d0ca49 100644 --- a/dreadnode/transforms/swap.py +++ b/dreadnode/transforms/swap.py @@ -1,3 +1,4 @@ +import functools import random import re import typing as t @@ -6,6 +7,18 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_obfuscation_tags() -> dict[str, t.Any]: + """Get compliance tags for obfuscation transforms (cached).""" + from dreadnode.airt.compliance import ATLASTechnique, OWASPCategory, SAIFCategory, tag_transform + + return tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def swap( *, unit: t.Literal["char", "word"] = "char", @@ -62,7 +75,7 @@ def transform( return re.sub(r'\s([?.!,"\'`])', r"\1", result).strip() return result - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_obfuscation_tags()) def adjacent_char_swap( diff --git a/dreadnode/transforms/text.py b/dreadnode/transforms/text.py index f1802b4f..5417c209 100644 --- a/dreadnode/transforms/text.py +++ b/dreadnode/transforms/text.py @@ -1,3 +1,4 @@ +import functools import random import re import typing as t @@ -6,13 +7,30 @@ from dreadnode.transforms.base import Transform +@functools.lru_cache(maxsize=1) +def _get_text_manipulation_tags() -> dict[str, t.Any]: + """Get compliance tags for text manipulation transforms (cached).""" + from dreadnode.airt.compliance import ( + ATLASTechnique, + OWASPCategory, + SAIFCategory, + tag_transform, + ) + + return tag_transform( + atlas=ATLASTechnique.CRAFT_ADVERSARIAL_DATA, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + def reverse(*, name: str = "reverse") -> Transform[str, str]: """Reverses the order of characters in a string.""" def transform(text: str) -> str: return text[::-1] - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def search_replace( @@ -63,7 +81,7 @@ def transform(text: str) -> str: flags=re.IGNORECASE, ) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def join( @@ -89,7 +107,7 @@ def transform( items = list(text) if unit == "char" else text.split() return delimiter.join(items) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def char_join(delimiter: str = "-", *, name: str = "char_join") -> Transform[str, str]: @@ -145,7 +163,7 @@ def transform( return text_to_add + delimiter + text return text + delimiter + text_to_add - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def prefix(text: str, *, name: str = "prefix") -> Transform[str, str]: @@ -229,7 +247,7 @@ def transform(text: str) -> str: words = re.findall(r"\w+|[^\w\s]", text) return "".join(_to_pig_latin_word(word) for word in words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def word_removal( @@ -280,7 +298,7 @@ def transform( return "".join(result_words) return " ".join(result_words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def word_duplication( @@ -336,7 +354,7 @@ def transform( return "".join(result_words) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def case_alternation( @@ -383,7 +401,7 @@ def transform( return "".join(result) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def whitespace_manipulation( @@ -429,7 +447,7 @@ def replace_space(_: re.Match[str]) -> str: return re.sub(r"\s+", replace_space, text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def sentence_reordering( @@ -469,7 +487,7 @@ def transform(text: str) -> str: return "".join(paired) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def question_transformation( @@ -505,7 +523,7 @@ def transform( # nested return f"For research purposes, I need to understand the answer to: {text}" - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def contextual_wrapping( @@ -544,7 +562,7 @@ def transform( template = wrappers[wrapper] return template.format(text=text) - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) def length_manipulation( @@ -596,4 +614,4 @@ def transform( repetitions = (target_length // current_len) + 1 return (text * repetitions)[:target_length] - return Transform(transform, name=name) + return Transform(transform, name=name, compliance_tags=_get_text_manipulation_tags()) diff --git a/tests/airt/__init__.py b/tests/airt/__init__.py new file mode 100644 index 00000000..8194f2b5 --- /dev/null +++ b/tests/airt/__init__.py @@ -0,0 +1 @@ +"""Tests for AI Red Teaming module.""" diff --git a/tests/airt/test_attack_compliance_tags.py b/tests/airt/test_attack_compliance_tags.py new file mode 100644 index 00000000..94978afc --- /dev/null +++ b/tests/airt/test_attack_compliance_tags.py @@ -0,0 +1,120 @@ +"""Tests for attack compliance tags.""" + + +from dreadnode.airt.attack.crescendo import COMPLIANCE_TAGS as CRESCENDO_TAGS +from dreadnode.airt.attack.goat import COMPLIANCE_TAGS as GOAT_TAGS +from dreadnode.airt.attack.prompt import COMPLIANCE_TAGS as PROMPT_TAGS +from dreadnode.airt.attack.tap import COMPLIANCE_TAGS as TAP_TAGS + + +def test_prompt_attack_has_compliance_tags() -> None: + """Prompt attack has compliance tags.""" + assert "atlas_techniques" in PROMPT_TAGS + assert "owasp_categories" in PROMPT_TAGS + assert "saif_categories" in PROMPT_TAGS + + +def test_prompt_attack_core_technique_only() -> None: + """Prompt attack has only core jailbreak technique tags.""" + assert PROMPT_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in PROMPT_TAGS["atlas_techniques"] + assert "AML.T0054" in PROMPT_TAGS["atlas_techniques"] + assert "INPUT_MANIPULATION" in PROMPT_TAGS["saif_categories"] + + +def test_prompt_attack_no_vulnerability_categories() -> None: + """Prompt attack does not include specific vulnerability categories.""" + owasp = PROMPT_TAGS["owasp_categories"] + assert "LLM02:2025" not in owasp + assert "LLM07:2025" not in owasp + assert "LLM09:2025" not in owasp + assert "LLM10:2025" not in owasp + + +def test_tap_attack_has_compliance_tags() -> None: + """TAP attack has compliance tags.""" + assert "atlas_techniques" in TAP_TAGS + assert "owasp_categories" in TAP_TAGS + assert "saif_categories" in TAP_TAGS + + +def test_tap_attack_core_technique_only() -> None: + """TAP attack has only core jailbreak technique tags.""" + assert TAP_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in TAP_TAGS["atlas_techniques"] + assert "AML.T0054" in TAP_TAGS["atlas_techniques"] + + +def test_tap_attack_has_nist() -> None: + """TAP attack includes NIST AI RMF tags.""" + assert "nist_ai_rmf_function" in TAP_TAGS + assert TAP_TAGS["nist_ai_rmf_function"] == "MEASURE" + assert "nist_ai_rmf_subcategory" in TAP_TAGS + assert TAP_TAGS["nist_ai_rmf_subcategory"] == "MS-2.7" + + +def test_goat_attack_has_compliance_tags() -> None: + """GOAT attack has compliance tags.""" + assert "atlas_techniques" in GOAT_TAGS + assert "owasp_categories" in GOAT_TAGS + assert "saif_categories" in GOAT_TAGS + + +def test_goat_attack_core_technique_only() -> None: + """GOAT attack has only core jailbreak technique tags.""" + assert GOAT_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in GOAT_TAGS["atlas_techniques"] + assert "AML.T0054" in GOAT_TAGS["atlas_techniques"] + + +def test_goat_attack_has_nist() -> None: + """GOAT attack includes NIST AI RMF tags.""" + assert "nist_ai_rmf_function" in GOAT_TAGS + assert GOAT_TAGS["nist_ai_rmf_function"] == "MEASURE" + + +def test_crescendo_attack_has_compliance_tags() -> None: + """Crescendo attack has compliance tags.""" + assert "atlas_techniques" in CRESCENDO_TAGS + assert "owasp_categories" in CRESCENDO_TAGS + assert "saif_categories" in CRESCENDO_TAGS + + +def test_crescendo_attack_core_technique_only() -> None: + """Crescendo attack has only core jailbreak technique tags.""" + assert CRESCENDO_TAGS["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in CRESCENDO_TAGS["atlas_techniques"] + assert "AML.T0054" in CRESCENDO_TAGS["atlas_techniques"] + + +def test_crescendo_attack_has_nist() -> None: + """Crescendo attack includes NIST AI RMF tags.""" + assert "nist_ai_rmf_function" in CRESCENDO_TAGS + assert CRESCENDO_TAGS["nist_ai_rmf_function"] == "MEASURE" + + +def test_all_jailbreak_attacks_consistent() -> None: + """All jailbreak attacks have consistent core tags.""" + attacks = [PROMPT_TAGS, TAP_TAGS, GOAT_TAGS, CRESCENDO_TAGS] + + for tags in attacks: + assert tags["owasp_categories"] == ["LLM01:2025"] + assert "AML.T0051.000" in tags["atlas_techniques"] + assert "AML.T0054" in tags["atlas_techniques"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + + +def test_attacks_do_not_duplicate_transform_tags() -> None: + """Attacks do not include tags that should come from transforms.""" + attacks = [PROMPT_TAGS, TAP_TAGS, GOAT_TAGS, CRESCENDO_TAGS] + + for tags in attacks: + owasp = tags["owasp_categories"] + atlas = tags["atlas_techniques"] + + # Should not include PII extraction tags + assert "LLM02:2025" not in owasp + assert "AML.T0024" not in atlas + + # Should not include system prompt leakage tags + assert "LLM07:2025" not in owasp diff --git a/tests/airt/test_compliance.py b/tests/airt/test_compliance.py new file mode 100644 index 00000000..c1a36860 --- /dev/null +++ b/tests/airt/test_compliance.py @@ -0,0 +1,148 @@ +"""Tests for compliance framework tags.""" + + +from dreadnode.airt.compliance import ( + ATLASTechnique, + NISTAIRMFFunction, + OWASPCategory, + SAIFCategory, + tag_attack, + tag_transform, +) + + +def test_owasp_categories_exist() -> None: + """All OWASP Top 10 categories are defined.""" + assert OWASPCategory.LLM01_PROMPT_INJECTION.value == "LLM01:2025" + assert OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE.value == "LLM02:2025" + assert OWASPCategory.LLM03_SUPPLY_CHAIN.value == "LLM03:2025" + assert OWASPCategory.LLM04_DATA_MODEL_POISONING.value == "LLM04:2025" + assert OWASPCategory.LLM05_IMPROPER_OUTPUT_HANDLING.value == "LLM05:2025" + assert OWASPCategory.LLM06_EXCESSIVE_AGENCY.value == "LLM06:2025" + assert OWASPCategory.LLM07_SYSTEM_PROMPT_LEAKAGE.value == "LLM07:2025" + assert OWASPCategory.LLM08_VECTOR_EMBEDDING_WEAKNESSES.value == "LLM08:2025" + assert OWASPCategory.LLM09_MISINFORMATION.value == "LLM09:2025" + assert OWASPCategory.LLM10_UNBOUNDED_CONSUMPTION.value == "LLM10:2025" + + +def test_atlas_techniques_exist() -> None: + """ATLAS techniques are defined.""" + assert ATLASTechnique.PROMPT_INJECTION.value == "AML.T0051" + assert ATLASTechnique.PROMPT_INJECTION_DIRECT.value == "AML.T0051.000" + assert ATLASTechnique.LLM_JAILBREAK.value == "AML.T0054" + assert ATLASTechnique.OBFUSCATE_ARTIFACTS.value == "AML.T0044" + assert ATLASTechnique.ADVERSARIAL_PERTURBATION.value == "AML.T0043.001" + assert ATLASTechnique.INFER_TRAINING_DATA.value == "AML.T0024" + assert ATLASTechnique.MODEL_INVERSION.value == "AML.T0024.000" + + +def test_saif_categories_exist() -> None: + """SAIF categories are defined.""" + assert SAIFCategory.INPUT_MANIPULATION.value == "INPUT_MANIPULATION" + assert SAIFCategory.OUTPUT_MANIPULATION.value == "OUTPUT_MANIPULATION" + assert SAIFCategory.PRIVACY_LEAKAGE.value == "PRIVACY_LEAKAGE" + + +def test_nist_functions_exist() -> None: + """NIST AI RMF functions are defined.""" + assert NISTAIRMFFunction.GOVERN.value == "GOVERN" + assert NISTAIRMFFunction.MAP.value == "MAP" + assert NISTAIRMFFunction.MEASURE.value == "MEASURE" + assert NISTAIRMFFunction.MANAGE.value == "MANAGE" + + +def test_tag_attack_single_values() -> None: + """Tag attack with single values.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert tags["atlas_techniques"] == ["AML.T0051"] + assert tags["owasp_categories"] == ["LLM01:2025"] + assert tags["saif_categories"] == ["INPUT_MANIPULATION"] + + +def test_tag_attack_multiple_values() -> None: + """Tag attack with multiple values.""" + tags = tag_attack( + atlas=[ATLASTechnique.PROMPT_INJECTION, ATLASTechnique.LLM_JAILBREAK], + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM02_SENSITIVE_INFORMATION_DISCLOSURE], + saif=[SAIFCategory.INPUT_MANIPULATION, SAIFCategory.PRIVACY_LEAKAGE], + ) + + assert tags["atlas_techniques"] == ["AML.T0051", "AML.T0054"] + assert tags["owasp_categories"] == ["LLM01:2025", "LLM02:2025"] + assert tags["saif_categories"] == ["INPUT_MANIPULATION", "PRIVACY_LEAKAGE"] + + +def test_tag_attack_with_nist() -> None: + """Tag attack with NIST AI RMF.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + nist_function=NISTAIRMFFunction.MEASURE, + nist_subcategory="MS-2.7", + ) + + assert "nist_ai_rmf_function" in tags + assert tags["nist_ai_rmf_function"] == "MEASURE" + assert "nist_ai_rmf_subcategory" in tags + assert tags["nist_ai_rmf_subcategory"] == "MS-2.7" + + +def test_tag_attack_optional_parameters() -> None: + """Tag attack with only required parameters.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert "atlas_techniques" in tags + assert "owasp_categories" in tags + assert "saif_categories" in tags + assert "nist_ai_rmf_function" not in tags + assert "nist_ai_rmf_subcategory" not in tags + + +def test_tag_transform_single_values() -> None: + """Tag transform with single values.""" + tags = tag_transform( + atlas=ATLASTechnique.OBFUSCATE_ARTIFACTS, + owasp=OWASPCategory.LLM01_PROMPT_INJECTION, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert tags["atlas_techniques"] == ["AML.T0044"] + assert tags["owasp_categories"] == ["LLM01:2025"] + assert tags["saif_categories"] == ["INPUT_MANIPULATION"] + + +def test_tag_transform_multiple_values() -> None: + """Tag transform with multiple values.""" + tags = tag_transform( + atlas=[ATLASTechnique.EVADE_ML_MODEL, ATLASTechnique.OBFUSCATE_ARTIFACTS], + owasp=[OWASPCategory.LLM01_PROMPT_INJECTION, OWASPCategory.LLM05_IMPROPER_OUTPUT_HANDLING], + saif=[SAIFCategory.INPUT_MANIPULATION, SAIFCategory.OUTPUT_MANIPULATION], + ) + + assert "AML.T0043" in tags["atlas_techniques"] + assert "AML.T0044" in tags["atlas_techniques"] + assert "LLM01:2025" in tags["owasp_categories"] + assert "LLM05:2025" in tags["owasp_categories"] + + +def test_tag_attack_none_values() -> None: + """Tag attack handles None values.""" + tags = tag_attack( + atlas=ATLASTechnique.PROMPT_INJECTION, + owasp=None, + saif=SAIFCategory.INPUT_MANIPULATION, + ) + + assert "atlas_techniques" in tags + assert "owasp_categories" not in tags + assert "saif_categories" in tags diff --git a/tests/test_transform_compliance_tags.py b/tests/test_transform_compliance_tags.py new file mode 100644 index 00000000..d581d3d8 --- /dev/null +++ b/tests/test_transform_compliance_tags.py @@ -0,0 +1,218 @@ +"""Tests for transform compliance tags.""" + +from dreadnode.transforms.cipher import caesar_cipher +from dreadnode.transforms.constitutional import code_fragmentation +from dreadnode.transforms.encoding import base64_encode +from dreadnode.transforms.language import adapt_language +from dreadnode.transforms.perturbation import adversarial_suffix +from dreadnode.transforms.pii_extraction import repeat_word_divergence +from dreadnode.transforms.refine import llm_refine +from dreadnode.transforms.stylistic import role_play_wrapper +from dreadnode.transforms.substitution import braille +from dreadnode.transforms.swap import adjacent_char_swap +from dreadnode.transforms.text import reverse + + +def test_pii_transform_has_compliance_tags() -> None: + """PII extraction transforms have compliance tags.""" + transform = repeat_word_divergence() + + assert hasattr(transform, "compliance_tags") + assert isinstance(transform.compliance_tags, dict) + assert "atlas_techniques" in transform.compliance_tags + assert "owasp_categories" in transform.compliance_tags + assert "saif_categories" in transform.compliance_tags + + +def test_pii_transform_has_correct_tags() -> None: + """PII extraction transforms have correct vulnerability tags.""" + transform = repeat_word_divergence() + tags = transform.compliance_tags + + assert "LLM02:2025" in tags["owasp_categories"] + assert "PRIVACY_LEAKAGE" in tags["saif_categories"] + assert any("AML.T0024" in t for t in tags["atlas_techniques"]) + + +def test_cipher_transform_has_compliance_tags() -> None: + """Cipher transforms have compliance tags.""" + transform = caesar_cipher(offset=3) + + assert hasattr(transform, "compliance_tags") + assert isinstance(transform.compliance_tags, dict) + + +def test_cipher_transform_has_obfuscation_tags() -> None: + """Cipher transforms have obfuscation tags.""" + transform = caesar_cipher(offset=3) + tags = transform.compliance_tags + + assert "LLM01:2025" in tags["owasp_categories"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + assert "AML.T0044" in tags["atlas_techniques"] + + +def test_encoding_transform_has_compliance_tags() -> None: + """Encoding transforms have compliance tags.""" + transform = base64_encode() + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_encoding_transform_has_obfuscation_tags() -> None: + """Encoding transforms have obfuscation tags.""" + transform = base64_encode() + tags = transform.compliance_tags + + assert "LLM01:2025" in tags["owasp_categories"] + assert "AML.T0044" in tags["atlas_techniques"] + + +def test_perturbation_transform_has_compliance_tags() -> None: + """Perturbation transforms have compliance tags.""" + transform = adversarial_suffix() + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_perturbation_transform_has_adversarial_tags() -> None: + """Perturbation transforms have adversarial perturbation tags.""" + transform = adversarial_suffix() + tags = transform.compliance_tags + + assert "AML.T0043.001" in tags["atlas_techniques"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + + +def test_constitutional_transform_has_compliance_tags() -> None: + """Constitutional evasion transforms have compliance tags.""" + transform = code_fragmentation() + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_constitutional_transform_has_evasion_tags() -> None: + """Constitutional evasion transforms have multiple OWASP tags.""" + transform = code_fragmentation() + tags = transform.compliance_tags + + assert "LLM01:2025" in tags["owasp_categories"] + assert "LLM05:2025" in tags["owasp_categories"] + assert "INPUT_MANIPULATION" in tags["saif_categories"] + assert "OUTPUT_MANIPULATION" in tags["saif_categories"] + + +def test_stylistic_transform_has_compliance_tags() -> None: + """Stylistic transforms have compliance tags.""" + transform = role_play_wrapper(scenario="educational", character="researcher") + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_language_transform_has_compliance_tags() -> None: + """Language transforms have compliance tags.""" + transform = adapt_language(target_language="es", adapter_model="gpt-4") + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_text_transform_has_compliance_tags() -> None: + """Text manipulation transforms have compliance tags.""" + transform = reverse() + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_refine_transform_has_compliance_tags() -> None: + """Refinement transforms have compliance tags.""" + transform = llm_refine(model="gpt-4", guidance="test") + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_substitution_transform_has_compliance_tags() -> None: + """Substitution transforms have compliance tags.""" + transform = braille() + + assert hasattr(transform, "compliance_tags") + assert "atlas_techniques" in transform.compliance_tags + + +def test_swap_transform_has_compliance_tags() -> None: + """Swap transforms have compliance tags.""" + transform = adjacent_char_swap() + + assert hasattr(transform, "compliance_tags") + assert "owasp_categories" in transform.compliance_tags + + +def test_all_transforms_have_required_keys() -> None: + """All transforms have required compliance tag keys.""" + transforms = [ + caesar_cipher(offset=3), + base64_encode(), + adversarial_suffix(), + repeat_word_divergence(), + code_fragmentation(), + role_play_wrapper(scenario="educational", character="researcher"), + adapt_language(target_language="es", adapter_model="gpt-4"), + reverse(), + llm_refine(model="gpt-4", guidance="test"), + braille(), + adjacent_char_swap(), + ] + + for transform in transforms: + assert "atlas_techniques" in transform.compliance_tags + assert "owasp_categories" in transform.compliance_tags + assert "saif_categories" in transform.compliance_tags + + +def test_transform_tags_are_lists() -> None: + """Transform tag values are lists.""" + transform = repeat_word_divergence() + tags = transform.compliance_tags + + assert isinstance(tags["atlas_techniques"], list) + assert isinstance(tags["owasp_categories"], list) + assert isinstance(tags["saif_categories"], list) + + +def test_transform_tags_not_empty() -> None: + """Transform tags contain at least one value.""" + transforms = [ + caesar_cipher(offset=3), + repeat_word_divergence(), + adversarial_suffix(), + ] + + for transform in transforms: + tags = transform.compliance_tags + assert len(tags["atlas_techniques"]) > 0 + assert len(tags["owasp_categories"]) > 0 + assert len(tags["saif_categories"]) > 0 + + +def test_pii_and_obfuscation_different_tags() -> None: + """PII and obfuscation transforms have different vulnerability tags.""" + pii = repeat_word_divergence() + cipher = caesar_cipher(offset=3) + + pii_owasp = pii.compliance_tags["owasp_categories"] + cipher_owasp = cipher.compliance_tags["owasp_categories"] + + # PII targets sensitive info disclosure + assert "LLM02:2025" in pii_owasp + + # Cipher targets prompt injection + assert "LLM01:2025" in cipher_owasp + + # Different vulnerability categories + assert pii_owasp != cipher_owasp