Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tirith/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ def start_policy_evaluation(

with open(input_path) as f:
if input_path.endswith(".yaml") or input_path.endswith(".yml"):
# safe_load_all returns a generator, we need to convert it into a
# dictionary because start_policy_evaluation_from_dict expects a dictionary
input_data = dict(yamls=list(yaml.safe_load_all(f)))
input_data = list(yaml.safe_load_all(f))
if len(input_data) == 1:
input_data = input_data[0]
else:
input_data = json.load(f)
# TODO: validate input_data using the optionally available validate function in provider
Expand Down
158 changes: 154 additions & 4 deletions src/tirith/providers/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,163 @@
from typing import Dict
import pydash

from typing import Dict, Any


def create_result_dict(value=None, meta=None, err=None) -> Dict:
    """
    Bundle a provider result into a plain dictionary.

    :param value: The value to store under the ``"value"`` key.
    :param meta: The value to store under the ``"meta"`` key.
    :param err: The value to store under the ``"err"`` key.
    :return: A dict with exactly the keys ``value``, ``meta`` and ``err``.
    :rtype: Dict
    """
    return {"value": value, "meta": meta, "err": err}


def get_path_value_from_dict(key_path: str, input_dict: dict, get_path_value_from_dict_func):
    """
    Split ``key_path`` on the ``".*."`` separator and delegate the lookup.

    :param key_path: Path expression; each ``".*."`` occurrence separates two segments.
    :param input_dict: The data handed unchanged to the resolver.
    :param get_path_value_from_dict_func: Callable invoked as
        ``func(list_of_segments, input_dict)``; its return value is returned as-is.
    :return: Whatever ``get_path_value_from_dict_func`` returns.
    """
    return get_path_value_from_dict_func(key_path.split(".*."), input_dict)
class PydashPathNotFound:
    """
    Sentinel passed as the ``default`` to ``pydash.get``.

    The class object itself (never an instance) is compared by identity, which
    lets callers tell "path is missing" apart from a stored ``None`` value.
    """


def _get_path_value_from_input_internal(splitted_paths, input_data, place_none_if_not_found=False):

if not splitted_paths:
return [input_data] if input_data is not PydashPathNotFound else ([None] if place_none_if_not_found else [])

final_data = []
expression = splitted_paths[0]
remaining_paths = splitted_paths[1:]

# Handle wildcard at the beginning (e.g., "*.something")
if expression == "":
if isinstance(input_data, list):
for item in input_data:
if remaining_paths:
results = _get_path_value_from_input_internal(remaining_paths, item, place_none_if_not_found)
final_data.extend(results)
else:
final_data.append(item)
elif isinstance(input_data, dict):
for value in input_data.values():
if remaining_paths:
results = _get_path_value_from_input_internal(remaining_paths, value, place_none_if_not_found)
final_data.extend(results)
else:
final_data.append(value)
else:
# For primitive values with empty expression (wildcard match)
# Just return the value if no more paths to traverse
if not remaining_paths:
final_data.append(input_data)
return final_data

# Get the value at the current path
intermediate_val = pydash.get(input_data, expression, default=PydashPathNotFound)

if intermediate_val is PydashPathNotFound:
return [None] if place_none_if_not_found else []

# If there are more paths to traverse
if remaining_paths:
if isinstance(intermediate_val, list) and remaining_paths[0] == "":
# For lists with a wildcard marker, iterate over list items
# Skip the wildcard marker since iteration is implicit for lists
paths_to_apply = remaining_paths[1:]
for val in intermediate_val:
results = _get_path_value_from_input_internal(paths_to_apply, val, place_none_if_not_found)
final_data.extend(results)
elif isinstance(intermediate_val, dict) and remaining_paths[0] == "":
# If it's a dict and next path is a wildcard, iterate over dict values
# Skip the wildcard marker and apply remaining paths to each value
for value in intermediate_val.values():
results = _get_path_value_from_input_internal(remaining_paths[1:], value, place_none_if_not_found)
final_data.extend(results)
else:
# For non-wildcard paths, continue traversal without iteration
results = _get_path_value_from_input_internal(remaining_paths, intermediate_val, place_none_if_not_found)
final_data.extend(results)
else:
# This is the final path segment
final_data.append(intermediate_val)

return final_data


def get_path_value_from_input(key_path: str, input: Any, place_none_if_not_found: bool = False):
    """
    Collect every value reachable through ``key_path`` in a nested structure.

    The path is split on ``"."``. A ``*`` segment is a wildcard that fans out
    over the items of a list or the values of a dict. Note that an empty
    segment (for example from ``"a..b"`` or a trailing dot) is treated exactly
    like ``*``, because the wildcard is represented internally as ``""``.

    :param key_path: Dot-separated path; a falsy path selects the input itself.
    :type key_path: str
    :param input: The data to search (dict, list, or primitive).
    :type input: Any
    :param place_none_if_not_found: Forwarded to the resolver; when True an
        unresolved path yields ``[None]``, otherwise ``[]``.
    :type place_none_if_not_found: bool
    :return: A list of the matched values.
    :rtype: list

    **Examples:**

    Plain traversal::

        >>> get_path_value_from_input("user.name", {"user": {"name": "Alice"}})
        ['Alice']

    Wildcard over list items and over dict values::

        >>> get_path_value_from_input("users.*.name", {"users": [{"name": "Alice"}, {"name": "Bob"}]})
        ['Alice', 'Bob']
        >>> get_path_value_from_input("c.*.capital", {"c": {"US": {"capital": "Washington"}}})
        ['Washington']

    Leading and repeated wildcards::

        >>> get_path_value_from_input("*.name", [{"name": "Alice"}, {"name": "Bob"}])
        ['Alice', 'Bob']
        >>> get_path_value_from_input("groups.*.*.id", {"groups": [[{"id": 1}, {"id": 2}], [{"id": 3}]]})
        [1, 2, 3]

    Trailing wildcard on a primitive::

        >>> get_path_value_from_input("*", 42)
        [42]

    Empty path and missing path::

        >>> get_path_value_from_input("", {"key": "value"})
        [{'key': 'value'}]
        >>> get_path_value_from_input("missing.path", {"user": {}})
        []
        >>> get_path_value_from_input("missing.path", {"user": {}}, place_none_if_not_found=True)
        [None]
    """
    if key_path:
        # "users.*.name" -> ["users", "", "name"]; "" is the resolver's wildcard marker.
        segments = ["" if segment == "*" else segment for segment in key_path.split(".")]
        return _get_path_value_from_input_internal(segments, input, place_none_if_not_found)

    # No path at all: the input itself is the single result.
    return [input]


class ProviderError:
Expand Down
30 changes: 2 additions & 28 deletions src/tirith/providers/json/handler.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,12 @@
import pydash

from typing import Callable, Dict, List
from ..common import create_result_dict, ProviderError, get_path_value_from_dict


class PydashPathNotFound:
pass


def _get_path_value_from_dict(splitted_paths, input_dict):
    # Legacy resolver for the JSON provider: walks the path expressions in
    # ``splitted_paths`` (the key path already split on ".*.") and returns a
    # flat list of the values found in ``input_dict``. Paths that are missing
    # contribute nothing to the result.
    final_data = []
    for i, expression in enumerate(splitted_paths):
        # The sentinel default tells a missing path apart from a stored None.
        intermediate_val = pydash.get(input_dict, expression, default=PydashPathNotFound)
        if isinstance(intermediate_val, list) and i < len(splitted_paths) - 1:
            # Not the last expression and the value is a list: resolve the
            # rest of the path against every element and flatten the results.
            # NOTE(review): the recursion always drops only the first segment
            # (``splitted_paths[1:]``) regardless of ``i``, and every loop
            # iteration looks up against the same ``input_dict`` — confirm
            # this is intended for paths with more than two segments.
            for val in intermediate_val:
                final_attributes = _get_path_value_from_dict(splitted_paths[1:], val)
                for final_attribute in final_attributes:
                    final_data.append(final_attribute)
        elif i == len(splitted_paths) - 1 and intermediate_val is not PydashPathNotFound:
            # Last expression and the path exists: keep the value as-is.
            final_data.append(intermediate_val)
        elif ".*" in expression:
            # Expression still contains ".*" (e.g. a trailing wildcard): look up
            # the part before the first ".*" and, if it is a list, emit its items.
            intermediate_exp = expression.split(".*")
            intermediate_data = pydash.get(input_dict, intermediate_exp[0], default=PydashPathNotFound)
            if intermediate_data is not PydashPathNotFound and isinstance(intermediate_data, list):
                for val in intermediate_data:
                    final_data.append(val)
    return final_data
from ..common import create_result_dict, ProviderError, get_path_value_from_input


def get_value(provider_args: Dict, input_data: Dict) -> List[dict]:
# Must be validated first whether the provider args are valid for this op type
key_path: str = provider_args["key_path"]

values = get_path_value_from_dict(key_path, input_data, _get_path_value_from_dict)
values = get_path_value_from_input(key_path, input_data)

if len(values) == 0:
severity_value = 2
Expand Down
42 changes: 6 additions & 36 deletions src/tirith/providers/kubernetes/handler.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
import pydash
from typing import Callable, Dict, List, Union
from ..common import create_result_dict, ProviderError, get_path_value_from_input

from typing import Callable, Dict, List
from ..common import create_result_dict, ProviderError, get_path_value_from_dict


class PydashPathNotFound:
pass


def _get_path_value_from_dict(splitted_paths, input_dict):
    """
    Legacy resolver for the Kubernetes provider.

    Applies the first expression of ``splitted_paths`` to ``input_dict`` and
    returns a flat list of matches. A missing path yields ``[None]``.

    :param splitted_paths: Non-empty list of path expressions (key path split on ``".*."``).
    :param input_dict: The data the first expression is looked up in.
    :return: List of matched values.
    :rtype: list
    """
    head, tail = splitted_paths[0], splitted_paths[1:]
    found = pydash.get(input_dict, head, default=PydashPathNotFound)

    if tail and isinstance(found, list):
        # More expressions follow: resolve them against each list element.
        collected = []
        for element in found:
            collected.extend(_get_path_value_from_dict(tail, element))
        return collected

    if found is PydashPathNotFound:
        # Missing path is reported as a single None placeholder.
        return [None]

    if not tail:
        return [found]

    if ".*" in head:
        # Expression still carries ".*": emit the items of the list that sits
        # at the part before the first ".*", if there is one.
        prefix = head.split(".*")[0]
        listing = pydash.get(input_dict, prefix, default=PydashPathNotFound)
        if listing is not PydashPathNotFound and isinstance(listing, list):
            return list(listing)

    return []


def get_value(provider_args: Dict, input_data: Dict, outputs: list) -> Dict:
def get_value(provider_args: Dict, input_data: Union[Dict, List], outputs: list) -> Dict:
# Must be validated first whether the provider args are valid for this op type
target_kind: str = provider_args.get("kubernetes_kind")
attribute_path: str = provider_args.get("attribute_path", "")
Expand All @@ -42,15 +12,15 @@ def get_value(provider_args: Dict, input_data: Dict, outputs: list) -> Dict:
if attribute_path == "":
return create_result_dict(value=ProviderError(severity_value=99), err="attribute_path must be provided")

kubernetes_resources = input_data["yamls"]
kubernetes_resources = input_data
is_kind_found = False

for resource in kubernetes_resources:
if resource["kind"] != target_kind:
continue
is_kind_found = True
values = get_path_value_from_dict(attribute_path, resource, _get_path_value_from_dict)
if ".*." not in attribute_path:
values = get_path_value_from_input(attribute_path, resource, place_none_if_not_found=True)
if "*" not in attribute_path:
# If there's no * in the attribute path, the values always have 1 member
values = values[0]
outputs.append(create_result_dict(value=values))
Expand Down
44 changes: 44 additions & 0 deletions tests/providers/json/playbook.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
- name: Provision EC2 instance and set up MySQL
hosts: localhost
gather_facts: false
become: True
vars:
region: "your_aws_region"
instance_type: "t2.micro"
ami_id: "your_ami_id"
key_name: "your_key_name"
security_group: "your_security_group_id"
subnet_id: "your_subnet_id"
mysql_root_password: "your_mysql_root_password"
package_list:
- unauthorized-app
tasks:
- name: Create EC2 instance
amazon.aws.ec2_instance:
region: "{{ region }}"
key_name: "{{ key_name }}"
instance_type: "{{ instance_type }}"
image_id: "{{ ami_id }}"
security_group: "{{ security_group }}"
subnet_id: "{{ subnet_id }}"
assign_public_ip: true
wait: yes
count: 1
instance_tags:
Name: "MySQLInstance"
register: ec2

- name: Install Unauthorized App
become: true
ansible.builtin.package:
name: "{{ package_list }}"
state: present

- name: Set MySQL root password [using unauthorized collection]
community.mysql.mysql_user:
name: root
password: "{{ mysql_root_password }}"
host: "{{ item }}"
login_unix_socket: yes
with_items: ["localhost", "127.0.0.1", "::1"]

20 changes: 20 additions & 0 deletions tests/providers/json/policy_playbook.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"meta": {
"version": "v1",
"required_provider": "stackguardian/json"
},
"evaluators": [
{
"id": "check0",
"provider_args": {
"operation_type": "get_value",
"key_path": "*.vars.region"
},
"condition": {
"type": "Equals",
"value": "your_aws_region"
}
}
],
"eval_expression": "check0"
}
14 changes: 12 additions & 2 deletions tests/providers/json/test_get_value.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os

from tirith.core.core import start_policy_evaluation_from_dict
from tirith.core.core import start_policy_evaluation, start_policy_evaluation_from_dict


# TODO: Need to split this into multiple tests
Expand All @@ -13,4 +13,14 @@ def test_get_value():
policy = json.load(f)

result = start_policy_evaluation_from_dict(policy, input_data)
assert result["final_result"] == True
assert result["final_result"] is True


def test_get_value_playbook():
    """Evaluate the wildcard-path policy against the playbook YAML fixture."""
    here = os.path.dirname(os.path.realpath(__file__))

    outcome = start_policy_evaluation(
        policy_path=os.path.join(here, "policy_playbook.json"),
        input_path=os.path.join(here, "playbook.yml"),
    )

    assert outcome["final_result"] is True
12 changes: 12 additions & 0 deletions tests/providers/kubernetes/test_attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os

from tirith.core.core import start_policy_evaluation


def test_get_value():
    """Evaluate ``policy.json`` against ``input.yml``; the policy is expected to fail."""
    here = os.path.dirname(os.path.realpath(__file__))

    outcome = start_policy_evaluation(
        policy_path=os.path.join(here, "policy.json"),
        input_path=os.path.join(here, "input.yml"),
    )

    assert outcome["final_result"] is False
Loading
Loading