|
- {% if rule.ipv4 %}
- {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}}
- {% endif %}
- {% if rule.ipv6 %}
- {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}}
- {% endif %}
- |
-
- {{ rule.community.name }}
+ {{ rule.ip }}{{ '/' if rule.mask else '' }}{{rule.mask|default("", True)}}
|
-
-
+ |
{{ rule.expires|strftime }}
|
@@ -97,10 +150,10 @@
|
{% if editable %}
-
+
-
+
{% endif %}
@@ -122,6 +175,7 @@
{% endmacro %}
+
{% macro build_rules_thead(rules_columns, rtype, rstate, sort_key, sort_order, search_query='', group_op=True) %}
diff --git a/flowapp/templates/pages/machine_api_key.html b/flowapp/templates/pages/machine_api_key.html
index 52eb478a..c2ced223 100644
--- a/flowapp/templates/pages/machine_api_key.html
+++ b/flowapp/templates/pages/machine_api_key.html
@@ -11,8 +11,8 @@ Machines and ApiKeys
| Machine address |
ApiKey |
- Created by |
Created for |
+ Created by |
Expires |
Read/Write ? |
Action |
diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py
index 65919af1..3a988ef4 100644
--- a/flowapp/tests/conftest.py
+++ b/flowapp/tests/conftest.py
@@ -12,6 +12,7 @@
from flowapp import db as _db
from datetime import datetime
import flowapp.models
+from flowapp.models.organization import Organization
TESTDB = "test_project.db"
@@ -67,6 +68,8 @@ def app(request):
SECRET_KEY="testkeysession",
LOCAL_USER_UUID="jiri.vrany@cesnet.cz",
LOCAL_AUTH=True,
+ ALLOWED_COMMUNITIES=[1, 2, 3],
+ WTF_CSRF_ENABLED=False,
)
print("\n----- CREATE FLASK APPLICATION\n")
@@ -114,6 +117,12 @@ def db(app, request):
print("#: inserting users")
flowapp.models.insert_users(users)
+ org = _db.session.query(Organization).filter_by(id=1).first()
+ # Update the organization address range to include our test networks
+ org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8"
+ _db.session.commit()
+ print("\n----- UPDATED TEST ORG 1 \n", org)
+
def teardown():
_db.session.commit()
_db.drop_all()
@@ -143,6 +152,26 @@ def jwt_token(client, app, db, request):
return data["token"]
+@pytest.fixture(scope="session")
+def machine_api_token(client, app, db, request):
+ """
+ Get the test_client from the app, for the whole test session.
+ """
+ mkey = "machinetestkey"
+
+ with app.app_context():
+ model = flowapp.models.MachineApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1)
+ db.session.add(model)
+ db.session.commit()
+
+ print("\n----- GET MACHINE API KEY TEST TOKEN\n")
+ url = "/api/v3/auth"
+ headers = {"x-api-key": mkey}
+ token = client.get(url, headers=headers)
+ data = json.loads(token.data)
+ return data["token"]
+
+
@pytest.fixture(scope="session")
def expired_auth_token(client, app, db, request):
"""
@@ -185,3 +214,19 @@ def auth_client(client):
print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n")
client.get("/local-login")
return client
+
+
+@pytest.fixture(autouse=True)
+def reset_org_limits(db, app):
+ """
+ Automatically reset organization limits after each test that modifies them.
+ """
+ yield # Allow test execution
+
+ with app.app_context():
+ org = db.session.query(Organization).filter_by(id=1).first()
+ if org:
+ org.limit_flowspec4 = 0
+ org.limit_flowspec6 = 0
+ org.limit_rtbh = 0
+ db.session.commit()
diff --git a/flowapp/tests/rule_service_integration.py b/flowapp/tests/rule_service_integration.py
new file mode 100644
index 00000000..106000ab
--- /dev/null
+++ b/flowapp/tests/rule_service_integration.py
@@ -0,0 +1,332 @@
+"""Integration tests for rule_service module."""
+
+import pytest
+from datetime import datetime, timedelta
+from unittest.mock import patch
+
+from flowapp import db, create_app
+from flowapp.constants import RuleTypes
+from flowapp.models import (
+ RTBH,
+ Flowspec4,
+ Flowspec6,
+ Whitelist,
+ User,
+ Organization,
+ Role,
+ Community,
+ Action,
+ Rstate,
+)
+from flowapp.services import rule_service
+
+
+@pytest.fixture(scope="module")
+def app():
+ """Create and configure a Flask app for testing."""
+ app = create_app("testing")
+
+ # Create a test context
+ with app.app_context():
+ # Create database tables
+ db.create_all()
+
+ # Create test data
+ _create_test_data()
+
+ yield app
+
+ # Clean up after tests
+ db.session.remove()
+ db.drop_all()
+
+
+@pytest.fixture(scope="function")
+def app_context(app):
+ """Provide an application context for each test."""
+ with app.app_context():
+ yield
+
+
+@pytest.fixture(scope="function")
+def mock_announce_route():
+ """Mock the announce_route function."""
+ with patch("flowapp.services.rule_service.announce_route") as mock:
+ yield mock
+
+
+def _create_test_data():
+ """Create test data in the database."""
+ # Create roles
+ admin_role = Role(name="admin", description="Administrator")
+ user_role = Role(name="user", description="Normal user")
+ view_role = Role(name="view", description="View only")
+ db.session.add_all([admin_role, user_role, view_role])
+ db.session.flush()
+
+ # Create organizations
+ org1 = Organization(name="Test Org 1", arange="192.168.1.0/24\n2001:db8::/64")
+ org2 = Organization(name="Test Org 2", arange="10.0.0.0/8")
+ db.session.add_all([org1, org2])
+ db.session.flush()
+
+ # Create users
+ user1 = User(uuid="user1@example.com", name="User One")
+ user1.role.append(user_role)
+ user1.organization.append(org1)
+
+ admin1 = User(uuid="admin@example.com", name="Admin User")
+ admin1.role.append(admin_role)
+ admin1.organization.append(org2)
+
+ db.session.add_all([user1, admin1])
+ db.session.flush()
+
+ # Create rule states
+ state_active = Rstate(description="active rule")
+ state_withdrawn = Rstate(description="withdrawed rule")
+ state_deleted = Rstate(description="deleted rule")
+ state_whitelisted = Rstate(description="whitelisted rule")
+ db.session.add_all([state_active, state_withdrawn, state_deleted, state_whitelisted])
+ db.session.flush()
+
+ # Create actions
+ action1 = Action(name="Discard", command="discard", description="Drop traffic", role_id=user_role.id)
+ action2 = Action(name="Rate limit", command="rate-limit 1000000", description="Limit traffic", role_id=user_role.id)
+ db.session.add_all([action1, action2])
+ db.session.flush()
+
+ # Create communities
+ community1 = Community(
+ name="65000:1",
+ comm="65000:1",
+ larcomm="",
+ extcomm="",
+ description="Test community",
+ role_id=user_role.id,
+ as_path=False,
+ )
+ db.session.add(community1)
+ db.session.flush()
+
+ # Create some rules
+ # IPv4 rule
+ ipv4_rule = Flowspec4(
+ source="192.168.1.1",
+ source_mask=32,
+ source_port="",
+ dest="192.168.2.1",
+ dest_mask=32,
+ dest_port="",
+ protocol="tcp",
+ flags="",
+ packet_len="",
+ fragment="",
+ expires=datetime.now() + timedelta(hours=1),
+ comment="Test IPv4 rule",
+ action_id=action1.id,
+ user_id=user1.id,
+ org_id=org1.id,
+ rstate_id=state_active.id,
+ )
+ db.session.add(ipv4_rule)
+
+ # IPv6 rule
+ ipv6_rule = Flowspec6(
+ source="2001:db8::1",
+ source_mask=128,
+ source_port="",
+ dest="2001:db8:1::1",
+ dest_mask=128,
+ dest_port="",
+ next_header="tcp",
+ flags="",
+ packet_len="",
+ expires=datetime.now() + timedelta(hours=1),
+ comment="Test IPv6 rule",
+ action_id=action1.id,
+ user_id=user1.id,
+ org_id=org1.id,
+ rstate_id=state_active.id,
+ )
+ db.session.add(ipv6_rule)
+
+ # RTBH rule
+ rtbh_rule = RTBH(
+ ipv4="192.168.1.100",
+ ipv4_mask=32,
+ ipv6=None,
+ ipv6_mask=None,
+ community_id=community1.id,
+ expires=datetime.now() + timedelta(hours=1),
+ comment="Test RTBH rule",
+ user_id=user1.id,
+ org_id=org1.id,
+ rstate_id=state_active.id,
+ )
+ db.session.add(rtbh_rule)
+
+ # Save IDs as class variables for tests to use
+ _create_test_data.user_id = user1.id
+ _create_test_data.admin_id = admin1.id
+ _create_test_data.org1_id = org1.id
+ _create_test_data.org2_id = org2.id
+ _create_test_data.ipv4_rule_id = ipv4_rule.id
+ _create_test_data.ipv6_rule_id = ipv6_rule.id
+ _create_test_data.rtbh_rule_id = rtbh_rule.id
+ _create_test_data.action_id = action1.id
+ _create_test_data.community_id = community1.id
+
+ db.session.commit()
+
+
+class TestRuleServiceIntegration:
+ """Integration tests for rule_service functions."""
+
+ def test_reactivate_rule_ipv4(self, app_context, mock_announce_route):
+ """Test reactivating an IPv4 rule."""
+ # Test data
+ rule_id = _create_test_data.ipv4_rule_id
+ user_id = _create_test_data.user_id
+ org_id = _create_test_data.org1_id
+
+ # Set new expiration time (2 hours in the future)
+ new_expires = datetime.now() + timedelta(hours=2)
+ new_comment = "Updated test comment"
+
+ # Call the function
+ model, messages = rule_service.reactivate_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=rule_id,
+ expires=new_expires,
+ comment=new_comment,
+ user_id=user_id,
+ org_id=org_id,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the rule was updated
+ assert model is not None
+ assert model.id == rule_id
+ assert model.comment == new_comment
+ assert model.expires == new_expires
+ assert model.rstate_id == 1 # active state
+
+ # Verify route announcement was attempted
+ assert mock_announce_route.called
+
+ # Verify message
+ assert messages == ["Rule successfully updated"]
+
+ def test_reactivate_rule_inactive(self, app_context, mock_announce_route):
+ """Test reactivating a rule to inactive state."""
+ # Test data
+ rule_id = _create_test_data.ipv4_rule_id
+ user_id = _create_test_data.user_id
+ org_id = _create_test_data.org1_id
+
+ # Set past expiration time
+ past_expires = datetime.now() - timedelta(hours=1)
+
+ # Call the function
+ model, messages = rule_service.reactivate_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=rule_id,
+ expires=past_expires,
+ comment="Expired comment",
+ user_id=user_id,
+ org_id=org_id,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the rule was updated
+ assert model is not None
+ assert model.id == rule_id
+ assert model.expires == past_expires
+ assert model.rstate_id == 2 # inactive/withdrawn state
+
+ # Verify route announcement was attempted
+ assert mock_announce_route.called
+
+ def test_delete_rule(self, app_context, mock_announce_route):
+ """Test deleting a rule."""
+ # Test data
+ rule_id = _create_test_data.ipv6_rule_id
+ user_id = _create_test_data.user_id
+
+ # Call the function
+ success, message = rule_service.delete_rule(
+ rule_type=RuleTypes.IPv6,
+ rule_id=rule_id,
+ user_id=user_id,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the rule was deleted
+ assert success is True
+ assert message == "Rule deleted successfully"
+
+ # Verify the rule no longer exists in the database
+ rule = db.session.get(Flowspec6, rule_id)
+ assert rule is None
+
+ # Verify route withdrawal was attempted
+ assert mock_announce_route.called
+
+ def test_delete_rule_not_found(self, app_context):
+ """Test deleting a non-existent rule."""
+ # Use a non-existent rule ID
+ non_existent_id = 9999
+
+ # Call the function
+ success, message = rule_service.delete_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=non_existent_id,
+ user_id=_create_test_data.user_id,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the operation failed
+ assert success is False
+ assert message == "Rule not found"
+
+ @patch("flowapp.services.rule_service.create_or_update_whitelist")
+ def test_delete_rtbh_and_create_whitelist(self, mock_create_whitelist, app_context, mock_announce_route):
+ """Test deleting an RTBH rule and creating a whitelist entry."""
+ # Test data
+ rule_id = _create_test_data.rtbh_rule_id
+ user_id = _create_test_data.user_id
+ org_id = _create_test_data.org1_id
+
+ # Mock whitelist creation
+ mock_whitelist = Whitelist(
+ ip="192.168.1.100", mask=32, expires=datetime.now() + timedelta(days=7), user_id=user_id, org_id=org_id
+ )
+ mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"])
+
+ # Call the function
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=rule_id, user_id=user_id, org_id=org_id, user_email="test@example.com", org_name="Test Org"
+ )
+
+ # Verify success
+ assert success is True
+ assert len(messages) == 2
+ assert "Rule deleted successfully" in messages[0]
+ assert "Whitelist created" in messages[1]
+
+ # Verify the rule was deleted
+ rule = db.session.get(RTBH, rule_id)
+ assert rule is None
+
+ # Verify create_or_update_whitelist was called with correct data
+ mock_create_whitelist.assert_called_once()
+ args, kwargs = mock_create_whitelist.call_args
+ form_data = kwargs.get("form_data", args[0] if args else None)
+ assert form_data["ip"] == "192.168.1.100"
+ assert form_data["mask"] == 32
+ assert "Created from RTBH rule" in form_data["comment"]
diff --git a/flowapp/tests/test_api_auth.py b/flowapp/tests/test_api_auth.py
index 5733346b..8d08a7ae 100644
--- a/flowapp/tests/test_api_auth.py
+++ b/flowapp/tests/test_api_auth.py
@@ -11,6 +11,15 @@ def test_token(client, jwt_token):
assert req.status_code == 200
+def test_machine_token(client, machine_api_token):
+ """
+ test that token authorization works
+ """
+ req = client.get("/api/v3/test_token", headers={"x-access-token": machine_api_token})
+
+ assert req.status_code == 200
+
+
def test_expired_token(client, expired_auth_token):
"""
test that expired token authorization return 401
@@ -37,7 +46,7 @@ def test_readonly_token(client, readonly_jwt_token):
assert req.status_code == 200
data = json.loads(req.data)
- assert data['readonly']
+ assert data["readonly"]
def test_readonly_token_ipv4_create(client, db, readonly_jwt_token):
diff --git a/flowapp/tests/test_api_v3.py b/flowapp/tests/test_api_v3.py
index 96b24786..87f70949 100644
--- a/flowapp/tests/test_api_v3.py
+++ b/flowapp/tests/test_api_v3.py
@@ -1,10 +1,37 @@
import json
+
from flowapp.models import Flowspec4, Organization
V_PREFIX = "/api/v3"
+def test_create_rtbh_only(client, app, db, jwt_token):
+ """
+ Test creating an RTBH rule through API that exactly matches an existing whitelist.
+
+ The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry
+ should be created linking the rule to the whitelist.
+ """
+
+ # Now create the RTBH rule via API
+ res = client.post(
+ f"{V_PREFIX}/rules/rtbh",
+ headers={"x-access-token": jwt_token},
+ json={
+ "community": 1,
+ "ipv4": "147.230.17.17",
+ "ipv4_mask": 32,
+ "expires": "10/25/2050 14:46",
+ },
+ )
+
+ # Verify response is successful
+ assert res.status_code == 201
+ data = json.loads(res.data)
+ assert data["rule"] is not None
+
+
def test_token(client, jwt_token):
"""
test that token authorization works
diff --git a/flowapp/tests/test_api_whitelist_integration.py b/flowapp/tests/test_api_whitelist_integration.py
new file mode 100644
index 00000000..0391a762
--- /dev/null
+++ b/flowapp/tests/test_api_whitelist_integration.py
@@ -0,0 +1,273 @@
+"""
+Integration test for the API view when interacting with whitelists.
+
+This test suite verifies the behavior when creating RTBH rules through the API
+when there are existing whitelists that could affect the rules.
+"""
+
+import json
+import pytest
+from datetime import datetime, timedelta
+
+from flowapp.constants import RuleTypes, RuleOrigin
+from flowapp.models import RTBH, RuleWhitelistCache, Organization
+from flowapp.services import whitelist_service
+
+
+@pytest.fixture
+def whitelist_data():
+ """Create whitelist data for testing"""
+ return {
+ "ip": "192.168.1.0",
+ "mask": 24,
+ "comment": "Test whitelist for API integration test",
+ "expires": datetime.now() + timedelta(days=1),
+ }
+
+
+@pytest.fixture
+def rtbh_api_payload():
+ """Create RTBH rule data for API testing"""
+ return {
+ "community": 1,
+ "ipv4": "192.168.1.0",
+ "ipv4_mask": 24,
+ "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"),
+ "comment": "Test RTBH rule via API",
+ }
+
+
+def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload):
+ """
+ Test creating an RTBH rule through API that exactly matches an existing whitelist.
+
+ The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry
+ should be created linking the rule to the whitelist.
+ """
+ # First, configure app to include community ID 1 in allowed communities
+ app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]})
+
+ # Create the whitelist directly using the service
+ with app.app_context():
+ # Create user and organization if needed for the whitelist
+ org = db.session.query(Organization).first()
+
+ # Create the whitelist
+ whitelist_model, _ = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_data,
+ user_id=1, # Using user ID 1 from test fixtures
+ org_id=org.id,
+ user_email="test@example.com",
+ org_name=org.name,
+ )
+
+ # Verify whitelist was created
+ assert whitelist_model.id is not None
+ assert whitelist_model.ip == whitelist_data["ip"]
+ assert whitelist_model.mask == whitelist_data["mask"]
+
+ # Now create the RTBH rule via API
+ response = client.post(
+ "/api/v3/rules/rtbh",
+ headers={"x-access-token": jwt_token},
+ json=rtbh_api_payload,
+ )
+
+ # Verify response is successful
+ assert response.status_code == 201
+ data = json.loads(response.data)
+ assert data["rule"] is not None
+ rule_id = data["rule"]["id"]
+
+ # Now verify the rule was created but marked as whitelisted
+ with app.app_context():
+ rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first()
+ assert rtbh_rule is not None
+ assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state
+
+ # Verify a cache entry was created
+ cache_entry = RuleWhitelistCache.query.filter_by(
+ rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id
+ ).first()
+
+ assert cache_entry is not None
+ assert cache_entry.rorigin == RuleOrigin.USER.value
+
+
+def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload):
+ """
+ Test creating an RTBH rule through API that is a supernet of an existing whitelist.
+
+ The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules
+ should be created for the non-whitelisted parts.
+ """
+ # Configure app with allowed communities
+ app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]})
+
+ # First create a whitelist for a subnet
+ whitelist_data["ip"] = "192.168.1.128"
+ whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24
+
+ # Create the whitelist directly using the service
+ with app.app_context():
+ org = db.session.query(Organization).first()
+ whitelist_model, _ = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name
+ )
+
+ # Verify whitelist was created
+ assert whitelist_model.id is not None
+
+ # Now create the RTBH rule via API that covers both the whitelist and additional space
+ headers = {"x-access-token": jwt_token}
+ response = client.post(
+ "/api/v3/rules/rtbh",
+ headers=headers,
+ json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist
+ )
+
+ # Verify response is successful
+ assert response.status_code == 201
+ data = json.loads(response.data)
+ assert data["rule"] is not None
+ rule_id = data["rule"]["id"]
+
+ # Now verify the rule was created and marked as whitelisted
+ with app.app_context():
+ # Check the original rule
+ rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first()
+ assert rtbh_rule is not None
+ assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state
+
+ # Check if a new subnet rule was created for the non-whitelisted part
+ subnet_rule = (
+ db.session.query(RTBH)
+ .filter(
+ RTBH.ipv4 == "192.168.1.0",
+ RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist
+ )
+ .first()
+ )
+
+ assert subnet_rule is not None
+ assert subnet_rule.rstate_id == 1 # Active status
+
+ # Verify cache entries
+ # Main rule should be cached as a USER rule
+ user_cache = RuleWhitelistCache.query.filter_by(
+ rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value
+ ).first()
+ assert user_cache is not None
+
+ # Subnet rule should be cached as a WHITELIST rule
+ whitelist_cache = RuleWhitelistCache.query.filter_by(
+ rid=subnet_rule.id,
+ rtype=RuleTypes.RTBH.value,
+ whitelist_id=whitelist_model.id,
+ rorigin=RuleOrigin.WHITELIST.value,
+ ).first()
+ assert whitelist_cache is not None
+
+
+def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload):
+ """
+ Test creating an RTBH rule through API that is contained within an existing whitelist.
+
+ The rule should be created but immediately marked as whitelisted (rstate_id=4).
+ """
+ # Configure app with allowed communities
+ app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]})
+
+ # First create a whitelist for a supernet
+ whitelist_data["ip"] = "192.168.0.0"
+ whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule
+
+ # Create the whitelist directly using the service
+ with app.app_context():
+ all_rtbh_rules_before = db.session.query(RTBH).count()
+
+ org = db.session.query(Organization).first()
+ whitelist_model, _ = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name
+ )
+
+ # Verify whitelist was created
+ assert whitelist_model.id is not None
+
+ # Now create the RTBH rule via API that is inside the whitelist
+ headers = {"x-access-token": jwt_token}
+ response = client.post(
+ "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist
+ )
+
+ # Verify response is successful
+ assert response.status_code == 201
+ data = json.loads(response.data)
+ assert data["rule"] is not None
+ rule_id = data["rule"]["id"]
+
+ # Now verify the rule was created but marked as whitelisted
+ with app.app_context():
+ rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first()
+ assert rtbh_rule is not None
+ assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state
+
+ # Verify a cache entry was created
+ cache_entry = RuleWhitelistCache.query.filter_by(
+ rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id
+ ).first()
+
+ assert cache_entry is not None
+ assert cache_entry.rorigin == RuleOrigin.USER.value
+
+ # Verify no additional rules were created
+ all_rtbh_rules = db.session.query(RTBH).count()
+ assert all_rtbh_rules - all_rtbh_rules_before == 1
+
+
+def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload):
+ """
+ Test creating an RTBH rule through API that has no relation to any existing whitelist.
+
+ The rule should be created normally with active state (rstate_id=1).
+ """
+ # Configure app with allowed communities
+ app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]})
+
+ # First create a whitelist for a completely different network
+ whitelist_data["ip"] = "10.0.0.0"
+ whitelist_data["mask"] = 8
+
+ # Create the whitelist directly using the service
+ with app.app_context():
+ org = db.session.query(Organization).first()
+ whitelist_model, _ = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name
+ )
+
+ # Verify whitelist was created
+ assert whitelist_model.id is not None
+
+ # Now create the RTBH rule via API for a network not covered by any of the whitelists
+ rtbh_api_payload["ipv4"] = "147.230.17.185"
+ rtbh_api_payload["ipv4_mask"] = 32
+
+ headers = {"x-access-token": jwt_token}
+ response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24
+
+ # Verify response is successful
+ assert response.status_code == 201
+ data = json.loads(response.data)
+ assert data["rule"] is not None
+ rule_id = data["rule"]["id"]
+
+ # Now verify the rule was created with active state
+ with app.app_context():
+ rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first()
+ assert rtbh_rule is not None
+ assert rtbh_rule.rstate_id == 1 # 1 = active state
+
+ # Verify no cache entry was created
+ cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first()
+
+ assert cache_entry is None
diff --git a/flowapp/tests/test_flowspec.py b/flowapp/tests/test_flowspec.py
index e7a9c3e9..70000b81 100644
--- a/flowapp/tests/test_flowspec.py
+++ b/flowapp/tests/test_flowspec.py
@@ -1,12 +1,12 @@
import pytest
-import flowapp.flowspec
+from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit
def test_translate_number():
"""
tests for x (integer) to =x
"""
- assert "[=10]" == flowapp.flowspec.translate_sequence("10")
+ assert "[=10]" == translate_sequence("10")
def test_raises():
@@ -14,7 +14,7 @@ def test_raises():
tests for translator
"""
with pytest.raises(ValueError):
- flowapp.flowspec.translate_sequence("ahoj")
+ translate_sequence("ahoj")
def test_raises_bad_number():
@@ -22,46 +22,146 @@ def test_raises_bad_number():
tests for translator
"""
with pytest.raises(ValueError):
- flowapp.flowspec.translate_sequence("75555")
+ translate_sequence("75555")
def test_translate_range():
"""
tests for x-y to >=x&<=y
"""
- assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence("10-20")
+ assert "[>=10&<=20]" == translate_sequence("10-20")
def test_exact_rule():
"""
test for >=x&<=y to >=x&<=y
"""
- assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence(">=10&<=20")
+ assert "[>=10&<=20]" == translate_sequence(">=10&<=20")
def test_greater_than():
"""
test for >x to >=x&<=65535
"""
- assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">10")
+ assert "[>=10&<=65535]" == translate_sequence(">10")
def test_greater_equal_than():
"""
test for >=x to >=x&<=65535
"""
- assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">=10")
+ assert "[>=10&<=65535]" == translate_sequence(">=10")
def test_lower_than():
"""
test for =0&<=0
"""
- assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<10")
+ assert "[>=0&<=10]" == translate_sequence("<10")
def test_lower_equal_than():
"""
test for =0&<=0
"""
- assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<=10")
+ assert "[>=0&<=10]" == translate_sequence("<=10")
+
+
+# new tests
+
+
+def test_multiple_sequences():
+ """Test multiple sequences separated by semicolons"""
+ assert "[=10 >=20&<=30]" == translate_sequence("10;20-30")
+ assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30")
+
+
+def test_empty_sequence():
+ """Test empty sequence and sequences with empty parts"""
+ assert "[]" == translate_sequence("")
+ assert "[=10]" == translate_sequence("10;;")
+ assert "[=10 =20]" == translate_sequence("10;;20")
+
+
+def test_range_edge_cases():
+ """Test edge cases for ranges"""
+ # Same numbers in range
+ assert "[>=10&<=10]" == translate_sequence("10-10")
+
+ # Invalid range (start > end)
+ with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"):
+ translate_sequence("20-10")
+
+ # Invalid range in exact rule
+ with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"):
+ translate_sequence(">=20&<=10")
+
+
+def test_check_limit_validation():
+ """Test the check_limit function"""
+ # Test valid cases
+ assert check_limit(10, 100) == 10
+ assert check_limit(0, 100, 0) == 0
+ assert check_limit(100, 100, 0) == 100
+
+ # Test invalid cases
+ with pytest.raises(ValueError, match="Invalid value number: .* is too big"):
+ check_limit(101, 100)
+
+ with pytest.raises(ValueError, match="Invalid value number: .* is too small"):
+ check_limit(-1, 100, 0)
+
+
+def test_invalid_inputs():
+ """Test various invalid inputs"""
+ invalid_inputs = [
+ "abc", # Non-numeric
+ "10.5", # Decimal
+ "10,20", # Wrong separator
+ "10&20", # Wrong format
+ ">", # Incomplete
+ ">=", # Incomplete
+ "<", # Incomplete
+ "<=", # Incomplete
+ ">>10", # Invalid operator
+ "<<10", # Invalid operator
+ ">=10&", # Incomplete range
+ "&<=10", # Incomplete range
+ ]
+
+ for invalid_input in invalid_inputs:
+ with pytest.raises(ValueError):
+ translate_sequence(invalid_input)
+
+
+class MockRule:
+ def __init__(self, action_id):
+ self.action_id = action_id
+
+
+def test_filter_rules_action():
+ """Test the filter_rules_action function"""
+ # Create test rules
+ rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)]
+
+ # Test with empty allowed actions
+ editable, viewonly = filter_rules_action([], rules)
+ assert len(editable) == 0
+ assert len(viewonly) == 4
+
+ # Test with some allowed actions
+ editable, viewonly = filter_rules_action([1, 3], rules)
+ assert len(editable) == 2
+ assert len(viewonly) == 2
+ assert all(rule.action_id in [1, 3] for rule in editable)
+ assert all(rule.action_id in [2, 4] for rule in viewonly)
+
+ # Test with all actions allowed
+ editable, viewonly = filter_rules_action([1, 2, 3, 4], rules)
+ assert len(editable) == 4
+ assert len(viewonly) == 0
+
+ # Test with empty rules list
+ editable, viewonly = filter_rules_action([1, 2], [])
+ assert len(editable) == 0
+ assert len(viewonly) == 0
diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py
index 76c98960..e9620d8d 100644
--- a/flowapp/tests/test_forms.py
+++ b/flowapp/tests/test_forms.py
@@ -1,15 +1,7 @@
import pytest
-from flask import Flask
import flowapp.forms
-@pytest.fixture()
-def app():
- app = Flask(__name__)
- app.secret_key = "test"
- return app
-
-
@pytest.fixture()
def ip_form(app, field_class):
with app.test_request_context(): # Push the request context
diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py
new file mode 100644
index 00000000..db0a569f
--- /dev/null
+++ b/flowapp/tests/test_forms_cl.py
@@ -0,0 +1,608 @@
+import pytest
+from datetime import datetime, timedelta
+from werkzeug.datastructures import MultiDict
+from flowapp.forms import (
+ UserForm,
+ BulkUserForm,
+ ApiKeyForm,
+ MachineApiKeyForm,
+ OrganizationForm,
+ ActionForm,
+ ASPathForm,
+ CommunityForm,
+ RTBHForm,
+ IPv4Form,
+ IPv6Form,
+ WhitelistForm,
+)
+
+
+def create_form_data(data):
+ """Helper function to create proper form data format"""
+ processed_data = {}
+ for key, value in data.items():
+ if isinstance(value, list):
+ processed_data[key] = [str(v) for v in value]
+ else:
+ processed_data[key] = value
+ return MultiDict(processed_data)
+
+
+@pytest.fixture
+def valid_datetime():
+ return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M")
+
+
+@pytest.fixture
+def sample_network_ranges():
+ return ["192.168.0.0/16", "2001:db8::/32"]
+
+
+class TestUserForm:
+ @pytest.fixture
+ def mock_choices(self):
+ return {
+ "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")],
+ "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")],
+ }
+
+ @pytest.fixture
+ def valid_user_data(self):
+ return {
+ "uuid": "test@example.com",
+ "email": "user@example.com",
+ "name": "Test User",
+ "phone": "123456789",
+ "role_ids": ["2"],
+ "org_ids": ["1"],
+ }
+
+ def test_valid_user_form(self, app, valid_user_data, mock_choices):
+ with app.test_request_context():
+ form_data = create_form_data(valid_user_data)
+ form = UserForm(formdata=form_data)
+ form.role_ids.choices = mock_choices["role_ids"]
+ form.org_ids.choices = mock_choices["org_ids"]
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_invalid_email(self, app, mock_choices):
+ with app.test_request_context():
+ form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]})
+ form = UserForm(formdata=form_data)
+ form.role_ids.choices = mock_choices["role_ids"]
+ form.org_ids.choices = mock_choices["org_ids"]
+
+ assert not form.validate()
+ assert "Please provide valid email" in form.uuid.errors
+
+
+class TestRTBHForm:
+ @pytest.fixture
+ def mock_community_choices(self):
+ return [(1, "Community1"), (2, "Community2")]
+
+ @pytest.fixture
+ def valid_rtbh_data(self, valid_datetime):
+ return {
+ "ipv4": "192.168.1.0",
+ "ipv4_mask": 24,
+ "community": "1",
+ "expires": valid_datetime,
+ "comment": "Test RTBH rule",
+ }
+
+ def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices):
+ with app.test_request_context():
+ form_data = create_form_data(valid_rtbh_data)
+ form = RTBHForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.community.choices = mock_community_choices
+ assert form.validate()
+
+
+class TestIPv4Form:
+ @pytest.fixture
+ def mock_action_choices(self):
+ return [(1, "Accept"), (2, "Drop"), (3, "Reject")]
+
+ @pytest.fixture
+ def valid_ipv4_data(self, valid_datetime):
+ return {
+ "source": "192.168.1.0",
+ "source_mask": "24",
+ "protocol": "tcp",
+ "action": "1",
+ "expires": valid_datetime,
+ }
+
+ def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices):
+ with app.test_request_context():
+ form_data = create_form_data(valid_ipv4_data)
+ form = IPv4Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices):
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "source": "192.168.1.0",
+ "source_mask": "24",
+ "protocol": "udp",
+ "flags": ["syn"],
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv4Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+ assert not form.validate()
+
+
+class TestIPv6Form:
+ @pytest.fixture
+ def mock_action_choices(self):
+ return [(1, "Accept"), (2, "Drop"), (3, "Reject")]
+
+ @pytest.fixture
+ def valid_ipv6_data(self, valid_datetime):
+ return {
+ "source": "2001:db8::", # Network aligned address within allowed range
+ "source_mask": "32", # Matching the organization's prefix length
+ "next_header": "tcp",
+ "action": "1",
+ "expires": valid_datetime,
+ }
+
+ def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices):
+ with app.test_request_context():
+ form_data = create_form_data(valid_ipv6_data)
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices):
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "source": "2001:db8::",
+ "source_mask": "32",
+ "next_header": "udp",
+ "flags": ["syn"],
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+ assert not form.validate()
+
+ def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices):
+ """Test validation fails when address is outside allowed ranges"""
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "source": "2001:db9::", # Different prefix
+ "source_mask": "32",
+ "next_header": "tcp",
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+ assert not form.validate()
+ assert any("must be in organization range" in error for error in form.source.errors)
+
+ def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices):
+ """Test validation with destination address instead of source"""
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "dest": "2001:db8::",
+ "dest_mask": "32",
+ "next_header": "tcp",
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices):
+ """Test validation with both source and destination addresses"""
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "source": "2001:db8::",
+ "source_mask": "32",
+ "dest": "2001:db8:1::",
+ "dest_mask": "48",
+ "next_header": "tcp",
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices):
+ """Test validation with TCP flags (should be valid with TCP)"""
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "source": "2001:db8::",
+ "source_mask": "32",
+ "next_header": "tcp",
+ "flags": ["SYN", "ACK"],
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"])
+ def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data):
+ """Test validation with various valid port formats"""
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "source": "2001:db8::",
+ "source_mask": "32",
+ "next_header": "tcp",
+ "source_port": port_data,
+ "dest_port": port_data,
+ "action": "1",
+ "expires": valid_datetime,
+ }
+ )
+ form = IPv6Form(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ form.action.choices = mock_action_choices
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+
+class TestBulkUserForm:
+ @pytest.fixture
+ def valid_csv_data(self):
+ return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1"
+
+ def test_valid_bulk_import(self, app, valid_csv_data):
+ with app.test_request_context():
+ form_data = create_form_data({"users": valid_csv_data})
+ form = BulkUserForm(formdata=form_data)
+ form.roles = {2} # Mock available roles
+ form.organizations = {1} # Mock available organizations
+ form.uuids = set() # Mock existing UUIDs (empty set)
+ assert form.validate()
+
+ def test_duplicate_uuid(self, app, valid_csv_data):
+ with app.test_request_context():
+ form_data = create_form_data({"users": valid_csv_data})
+ form = BulkUserForm(formdata=form_data)
+ form.roles = {2}
+ form.organizations = {1}
+ form.uuids = {"user1@example.com"} # UUID already exists
+ assert not form.validate()
+ assert any("already exists" in error for error in form.users.errors)
+
+ def test_invalid_role(self, app):
+ csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID
+ with app.test_request_context():
+ form_data = create_form_data({"users": csv_data})
+ form = BulkUserForm(formdata=form_data)
+ form.roles = {2}
+ form.organizations = {1}
+ form.uuids = set()
+ assert not form.validate()
+ assert any("does not exist" in error for error in form.users.errors)
+
+
+class TestApiKeyForm:
+ @pytest.fixture
+ def valid_api_key_data(self, valid_datetime):
+ return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"}
+
+ def test_valid_api_key(self, app, valid_api_key_data):
+ with app.test_request_context():
+ form_data = create_form_data(valid_api_key_data)
+ form = ApiKeyForm(formdata=form_data)
+ assert form.validate()
+
+ def test_invalid_ip(self, app, valid_datetime):
+ with app.test_request_context():
+ form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime})
+ form = ApiKeyForm(formdata=form_data)
+ assert not form.validate()
+ assert "provide valid IP address" in form.machine.errors
+
+ def test_unlimited_expiration(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited
+ form = ApiKeyForm(formdata=form_data)
+ assert form.validate()
+
+
+class TestMachineApiKeyForm:
+ # Similar to ApiKeyForm, but might have different validation rules
+ @pytest.fixture
+ def valid_machine_key_data(self, valid_datetime):
+ return {
+ "machine": "192.168.1.1",
+ "comment": "Test machine API key",
+ "expires": valid_datetime,
+ "readonly": "true",
+ "user": 1,
+ }
+
+ def test_valid_machine_key(self, app, valid_machine_key_data):
+ with app.test_request_context():
+ form_data = create_form_data(valid_machine_key_data)
+ form = MachineApiKeyForm(formdata=form_data)
+ form.user.choices = [(1, "g.name"), (2, "test")]
+ assert form.validate()
+
+
+class TestOrganizationForm:
+ @pytest.fixture
+ def valid_org_data(self):
+ return {
+ "name": "Test Organization",
+ "limit_flowspec4": "100",
+ "limit_flowspec6": "100",
+ "limit_rtbh": "50",
+ "arange": "192.168.0.0/16\n2001:db8::/32",
+ }
+
+ def test_valid_organization(self, app, valid_org_data):
+ with app.test_request_context():
+ form_data = create_form_data(valid_org_data)
+ form = OrganizationForm(formdata=form_data)
+ assert form.validate()
+
+ def test_invalid_ranges(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"})
+ form = OrganizationForm(formdata=form_data)
+ assert not form.validate()
+
+ def test_invalid_limits(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value
+ form = OrganizationForm(formdata=form_data)
+ assert not form.validate()
+
+
+class TestActionForm:
+ @pytest.fixture
+ def valid_action_data(self):
+ return {
+ "name": "test_action",
+ "command": "announce route",
+ "description": "Test action description",
+ "role_id": "2",
+ }
+
+ def test_valid_action(self, app, valid_action_data):
+ with app.test_request_context():
+ form_data = create_form_data(valid_action_data)
+ form = ActionForm(formdata=form_data)
+ assert form.validate()
+
+ def test_invalid_role(self, app, valid_action_data):
+ with app.test_request_context():
+ invalid_data = dict(valid_action_data)
+ invalid_data["role_id"] = "4" # Invalid role
+ form_data = create_form_data(invalid_data)
+ form = ActionForm(formdata=form_data)
+ assert not form.validate()
+
+
+class TestASPathForm:
+ @pytest.fixture
+ def valid_aspath_data(self):
+ return {"prefix": "AS64512", "as_path": "64512 64513 64514"}
+
+ def test_valid_aspath(self, app, valid_aspath_data):
+ with app.test_request_context():
+ form_data = create_form_data(valid_aspath_data)
+ form = ASPathForm(formdata=form_data)
+ assert form.validate()
+
+ def test_empty_fields(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({})
+ form = ASPathForm(formdata=form_data)
+ assert not form.validate()
+
+
+class TestCommunityForm:
+ @pytest.fixture
+ def valid_community_data(self):
+ return {
+ "name": "test_community",
+ "comm": "64512:100",
+ "description": "Test community",
+ "role_id": "2",
+ "as_path": "true",
+ }
+
+ def test_valid_community(self, app, valid_community_data):
+ with app.test_request_context():
+ form_data = create_form_data(valid_community_data)
+ form = CommunityForm(formdata=form_data)
+ assert form.validate()
+
+ def test_missing_all_community_types(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({"name": "test_community", "role_id": "2"})
+ form = CommunityForm(formdata=form_data)
+ assert not form.validate()
+ assert any("could not be empty" in error for error in form.comm.errors)
+
+ def test_valid_with_extended_community(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"})
+ form = CommunityForm(formdata=form_data)
+ assert form.validate()
+
+ def test_valid_with_large_community(self, app):
+ with app.test_request_context():
+ form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"})
+ form = CommunityForm(formdata=form_data)
+ assert form.validate()
+
+
+class TestWhitelistForm:
+ @pytest.fixture
+ def valid_ipv4_data(self, valid_datetime):
+ return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime}
+
+ @pytest.fixture
+ def valid_ipv6_data(self, valid_datetime):
+ return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime}
+
+ def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data(valid_ipv4_data)
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data(valid_ipv6_data)
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert form.validate()
+
+ def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime})
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ assert not form.validate()
+ assert any("valid IP address" in error for error in form.ip.errors)
+
+ def test_ip_outside_range(self, app, valid_datetime):
+ with app.test_request_context():
+ form_data = create_form_data(
+ {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range
+ )
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16
+ assert not form.validate()
+ assert any("must be in organization range" in error for error in form.ip.errors)
+
+ def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data(
+ {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4
+ )
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ assert not form.validate()
+
+ def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data(
+ {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6
+ )
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ assert not form.validate()
+
+ def test_missing_required_fields(self, app, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data({})
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ assert not form.validate()
+ assert form.ip.errors # Should have error for missing IP
+ assert form.mask.errors # Should have error for missing mask
+ assert form.expires.errors # Should have error for missing expiration
+
+ def test_comment_length(self, app, valid_datetime, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data(
+ {
+ "ip": "192.168.0.0",
+ "mask": "24",
+ "expires": valid_datetime,
+ "comment": "x" * 256, # Comment longer than 255 chars
+ }
+ )
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ assert not form.validate()
+
+ @pytest.mark.parametrize(
+ "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date
+ )
+ def test_invalid_expiration(self, app, expires, sample_network_ranges):
+ with app.test_request_context():
+ form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires})
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ if not form.validate():
+ print("Validation errors:", form.errors)
+
+ assert not form.validate()
+
+ def test_network_alignment(self, app, valid_datetime, sample_network_ranges):
+ """Test that IP addresses must be properly network-aligned"""
+ with app.test_request_context():
+ form_data = create_form_data(
+ {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary
+ )
+ form = WhitelistForm(formdata=form_data)
+ form.net_ranges = sample_network_ranges
+ assert not form.validate()
diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py
index 89599453..b352a5f2 100644
--- a/flowapp/tests/test_models.py
+++ b/flowapp/tests/test_models.py
@@ -1,4 +1,16 @@
-from datetime import datetime
+from datetime import datetime, timedelta
+from flowapp.models import (
+ User,
+ Organization,
+ Role,
+ ApiKey,
+ MachineApiKey,
+ Rstate,
+ Community,
+ Action,
+ Flowspec6,
+ Whitelist,
+)
import flowapp.models as models
@@ -232,3 +244,255 @@ def test_rtbj_eq(db):
)
assert model_A == model_B
+
+
+def test_user_creation(db):
+ """Test basic user creation and relationships"""
+ # Create test role and org first
+ role = Role(name="test_role", description="Test Role")
+ org = Organization(name="test_org", arange="10.0.0.0/8")
+ db.session.add_all([role, org])
+ db.session.commit()
+
+ # Create user with relationships
+ user = User(
+ uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment"
+ )
+ user.role.append(role)
+ user.organization.append(org)
+ db.session.add(user)
+ db.session.commit()
+
+ # Verify user and relationships
+ assert user.uuid == "test-user-123"
+ assert user.name == "Test User"
+ assert len(user.role.all()) == 1
+ assert len(user.organization.all()) == 1
+ assert user.role.first().name == "test_role"
+ assert user.organization.first().name == "test_org"
+
+
+def test_api_key_expiration(db):
+ """Test ApiKey expiration logic"""
+ user = User(uuid="test-user")
+ org = Organization(name="test-org", arange="10.0.0.0/8")
+ db.session.add_all([user, org])
+ db.session.commit()
+
+ # Create non-expiring key
+ non_expiring_key = ApiKey(
+ machine="test-machine-1",
+ key="key1",
+ readonly=True,
+ expires=None,
+ comment="Non-expiring key",
+ user_id=user.id,
+ org_id=org.id,
+ )
+
+ # Create expired key
+ expired_key = ApiKey(
+ machine="test-machine-2",
+ key="key2",
+ readonly=True,
+ expires=datetime.now() - timedelta(days=1),
+ comment="Expired key",
+ user_id=user.id,
+ org_id=org.id,
+ )
+
+ # Create future key
+ future_key = ApiKey(
+ machine="test-machine-3",
+ key="key3",
+ readonly=True,
+ expires=datetime.now() + timedelta(days=1),
+ comment="Future key",
+ user_id=user.id,
+ org_id=org.id,
+ )
+
+ db.session.add_all([non_expiring_key, expired_key, future_key])
+ db.session.commit()
+
+ assert not non_expiring_key.is_expired()
+ assert expired_key.is_expired()
+ assert not future_key.is_expired()
+
+
+def test_machine_api_key_expiration(db):
+ """Test MachineApiKey expiration logic"""
+ user = User(uuid="test-user-machine")
+ org = Organization(name="test-org-machine", arange="10.0.0.0/8")
+ db.session.add_all([user, org])
+ db.session.commit()
+
+ # Create non-expiring key
+ non_expiring_key = MachineApiKey(
+ machine="test-machine-1",
+ key="key1",
+ readonly=True,
+ expires=None,
+ comment="Non-expiring key",
+ user_id=user.id,
+ org_id=org.id,
+ )
+
+ # Create expired key
+ expired_key = MachineApiKey(
+ machine="test-machine-2",
+ key="key2",
+ readonly=True,
+ expires=datetime.now() - timedelta(days=1),
+ comment="Expired key",
+ user_id=user.id,
+ org_id=org.id,
+ )
+
+ db.session.add_all([non_expiring_key, expired_key])
+ db.session.commit()
+
+ assert not non_expiring_key.is_expired()
+ assert expired_key.is_expired()
+
+
+def test_organization_get_users(db):
+ """Test Organization's get_users method"""
+ org = Organization(
+ name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100
+ )
+ uuid1 = "test-org-get-user"
+ uuid2 = "test-org-get-user2"
+ user1 = User(uuid=uuid1)
+ user2 = User(uuid=uuid2)
+
+ db.session.add(org)
+ db.session.add_all([user1, user2])
+ db.session.commit()
+
+ org.user.append(user1)
+ org.user.append(user2)
+ db.session.commit()
+
+ users = org.get_users()
+ assert len(users) == 2
+ assert all(isinstance(user, User) for user in users)
+ assert {user.uuid for user in users} == {uuid1, uuid2}
+
+
+def test_flowspec6_equality(db):
+ """Test Flowspec6 equality comparison"""
+ model_a = Flowspec6(
+ source="2001:db8::1",
+ source_mask=128,
+ source_port="80",
+ destination="2001:db8::2",
+ destination_mask=128,
+ destination_port="443",
+ next_header="tcp",
+ flags="",
+ packet_len="",
+ expires=datetime.now(),
+ user_id=1,
+ org_id=1,
+ action_id=1,
+ )
+
+ # Same network parameters but different timestamps
+ model_b = Flowspec6(
+ source="2001:db8::1",
+ source_mask=128,
+ source_port="80",
+ destination="2001:db8::2",
+ destination_mask=128,
+ destination_port="443",
+ next_header="tcp",
+ flags="",
+ packet_len="",
+ expires=datetime.now() + timedelta(days=1),
+ user_id=1,
+ org_id=1,
+ action_id=1,
+ )
+
+ # Different network parameters
+ model_c = Flowspec6(
+ source="2001:db8::3",
+ source_mask=128,
+ source_port="80",
+ destination="2001:db8::4",
+ destination_mask=128,
+ destination_port="443",
+ next_header="tcp",
+ flags="",
+ packet_len="",
+ expires=datetime.now(),
+ user_id=1,
+ org_id=1,
+ action_id=1,
+ )
+
+ assert model_a == model_b # Should be equal despite different timestamps
+ assert model_a != model_c # Should be different due to different network parameters
+
+
+def test_whitelist_equality(db):
+ """Test Whitelist equality comparison"""
+ model_a = Whitelist(
+ ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist"
+ )
+
+ # Same IP/mask but different timestamps
+ model_b = Whitelist(
+ ip="192.168.1.1",
+ mask=32,
+ expires=datetime.now() + timedelta(days=1),
+ user_id=1,
+ org_id=1,
+ comment="Different comment",
+ )
+
+ # Different IP
+ model_c = Whitelist(
+ ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist"
+ )
+
+ assert model_a == model_b # Should be equal despite different timestamps
+ assert model_a != model_c # Should be different due to different IP
+
+
+def test_whitelist_to_dict(db):
+ """Test Whitelist to_dict serialization"""
+ whitelist = Whitelist(
+ ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist"
+ )
+
+ # Create required related objects
+ user = User(uuid="test-user-whitelist")
+ rstate = Rstate(description="active")
+ db.session.add_all([user, rstate])
+ db.session.commit()
+
+ db.session.add(whitelist)
+ db.session.commit()
+
+ whitelist.user = user
+ whitelist.rstate_id = rstate.id
+ db.session.add(whitelist)
+ db.session.commit()
+
+ # Test timestamp format
+ dict_timestamp = whitelist.to_dict(prefered_format="timestamp")
+ assert isinstance(dict_timestamp["expires"], int)
+ assert isinstance(dict_timestamp["created"], int)
+
+ # Test yearfirst format
+ dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst")
+ assert isinstance(dict_yearfirst["expires"], str)
+ assert isinstance(dict_yearfirst["created"], str)
+
+ # Check basic fields
+ assert dict_timestamp["ip"] == "192.168.1.1"
+ assert dict_timestamp["mask"] == 32
+ assert dict_timestamp["comment"] == "Test whitelist"
+ assert dict_timestamp["user"] == "test-user-whitelist"
diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py
new file mode 100644
index 00000000..490d2c2f
--- /dev/null
+++ b/flowapp/tests/test_rule_service.py
@@ -0,0 +1,628 @@
+"""
+Tests for rule_service.py module.
+
+This test suite verifies the functionality of the rule service after refactoring,
+which manages creation, updating, and processing of flow rules.
+"""
+
+import pytest
+from datetime import datetime, timedelta
+from unittest.mock import patch, MagicMock
+
+from flowapp.constants import RuleOrigin, ANNOUNCE
+from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist
+from flowapp.services import rule_service
+from flowapp.services.whitelist_common import Relation
+
+
+@pytest.fixture
+def ipv4_form_data():
+ """Sample valid IPv4 form data"""
+ return {
+ "source": "192.168.1.0",
+ "source_mask": 24,
+ "source_port": "80",
+ "dest": "",
+ "dest_mask": None,
+ "dest_port": "",
+ "protocol": "tcp",
+ "flags": ["SYN"],
+ "packet_len": "",
+ "fragment": ["dont-fragment"],
+ "comment": "Test IPv4 rule",
+ "expires": datetime.now() + timedelta(hours=1),
+ "action": 1,
+ }
+
+
+@pytest.fixture
+def ipv6_form_data():
+ """Sample valid IPv6 form data"""
+ return {
+ "source": "2001:db8::",
+ "source_mask": 32,
+ "source_port": "80",
+ "dest": "",
+ "dest_mask": None,
+ "dest_port": "",
+ "next_header": "tcp",
+ "flags": ["SYN"],
+ "packet_len": "",
+ "comment": "Test IPv6 rule",
+ "expires": datetime.now() + timedelta(hours=1),
+ "action": 1,
+ }
+
+
+@pytest.fixture
+def rtbh_form_data():
+ """Sample valid RTBH form data"""
+ return {
+ "ipv4": "192.168.1.0",
+ "ipv4_mask": 24,
+ "ipv6": "",
+ "ipv6_mask": None,
+ "community": 1,
+ "comment": "Test RTBH rule",
+ "expires": datetime.now() + timedelta(hours=1),
+ }
+
+
+@pytest.fixture
+def whitelist_fixture():
+ """Create a whitelist fixture"""
+ whitelist = MagicMock(spec=Whitelist)
+ whitelist.id = 1
+ return whitelist
+
+
+class TestCreateOrUpdateIPv4Rule:
+ @patch("flowapp.services.rule_service.get_ipv4_model_if_exists")
+ @patch("flowapp.services.rule_service.messages")
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_create_new_ipv4_rule(
+ self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data
+ ):
+ """Test creating a new IPv4 rule"""
+ # Mock the get_ipv4_model_if_exists to return False (not found)
+ mock_get_model.return_value = False
+
+ # Mock the announce route behavior
+ mock_messages.create_ipv4.return_value = "mock command"
+
+ # Call the service function
+ with app.app_context():
+ model, message = rule_service.create_or_update_ipv4_rule(
+ form_data=ipv4_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was created with correct attributes
+ assert model is not None
+ assert model.source == ipv4_form_data["source"]
+ assert model.source_mask == ipv4_form_data["source_mask"]
+ assert model.protocol == ipv4_form_data["protocol"]
+ assert model.flags == "SYN" # form data flags are joined with ";"
+ assert model.action_id == ipv4_form_data["action"]
+ assert model.rstate_id == 1 # Active state
+ assert model.user_id == 1
+ assert model.org_id == 1
+
+ # Verify message is still a string for IPv4 rules
+ assert message == "IPv4 Rule saved"
+
+ # Verify route was announced
+ mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE)
+ mock_announce.assert_called_once()
+ mock_log.assert_called_once()
+
+ @patch("flowapp.services.rule_service.get_ipv4_model_if_exists")
+ @patch("flowapp.services.rule_service.messages")
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_update_existing_ipv4_rule(
+ self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data
+ ):
+ """Test updating an existing IPv4 rule"""
+ # Create an existing model to return
+ existing_model = Flowspec4(
+ source=ipv4_form_data["source"],
+ source_mask=ipv4_form_data["source_mask"],
+ source_port=ipv4_form_data["source_port"],
+ destination=ipv4_form_data["dest"] or "",
+ destination_mask=ipv4_form_data["dest_mask"],
+ destination_port=ipv4_form_data["dest_port"] or "",
+ protocol=ipv4_form_data["protocol"],
+ flags=";".join(ipv4_form_data["flags"]),
+ packet_len=ipv4_form_data["packet_len"] or "",
+ fragment=";".join(ipv4_form_data["fragment"]),
+ expires=datetime.now(),
+ user_id=1,
+ org_id=1,
+ action_id=1,
+ rstate_id=1,
+ )
+ mock_get_model.return_value = existing_model
+ mock_messages.create_ipv4.return_value = "mock command"
+
+ # Set a new expiration time
+ new_expires = datetime.now() + timedelta(days=1)
+ ipv4_form_data["expires"] = new_expires
+
+ # Call the service function
+ with app.app_context():
+ db.session.add(existing_model)
+ db.session.commit()
+
+ model, message = rule_service.create_or_update_ipv4_rule(
+ form_data=ipv4_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was updated
+ assert model == existing_model
+ assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date()
+
+ # Verify message is still a string for IPv4 rules
+ assert message == "Existing IPv4 Rule found. Expiration time was updated to new value."
+
+ # Verify route was announced
+ mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE)
+ mock_announce.assert_called_once()
+ mock_log.assert_called_once()
+
+
+class TestCreateOrUpdateIPv6Rule:
+ @patch("flowapp.services.rule_service.get_ipv6_model_if_exists")
+ @patch("flowapp.services.rule_service.messages")
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_create_new_ipv6_rule(
+ self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data
+ ):
+ """Test creating a new IPv6 rule"""
+ # Mock get_ipv6_model_if_exists to return False
+ mock_get_model.return_value = False
+
+ # Mock the announce route behavior
+ mock_messages.create_ipv6.return_value = "mock command"
+
+ # Call the service function
+ with app.app_context():
+ model, message = rule_service.create_or_update_ipv6_rule(
+ form_data=ipv6_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was created with correct attributes
+ assert model is not None
+ assert model.source == ipv6_form_data["source"]
+ assert model.source_mask == ipv6_form_data["source_mask"]
+ assert model.next_header == ipv6_form_data["next_header"]
+ assert model.flags == "SYN"
+ assert model.action_id == ipv6_form_data["action"]
+ assert model.rstate_id == 1 # Active state
+ assert model.user_id == 1
+ assert model.org_id == 1
+
+ # Verify message is still a string for IPv6 rules
+ assert message == "IPv6 Rule saved"
+
+ # Verify route was announced
+ mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE)
+ mock_announce.assert_called_once()
+ mock_log.assert_called_once()
+
+
+class TestCreateOrUpdateRTBHRule:
+ @patch("flowapp.services.rule_service.get_rtbh_model_if_exists")
+ @patch("flowapp.services.rule_service.db.session.query")
+ @patch("flowapp.services.rule_service.check_rule_against_whitelists")
+ @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results")
+ @patch("flowapp.services.rule_service.announce_rtbh_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_create_new_rtbh_rule(
+ self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data
+ ):
+ """Test creating a new RTBH rule"""
+ # Mock get_rtbh_model_if_exists to return False
+ mock_get_model.return_value = False
+
+ # Mock the whitelist query
+ mock_whitelists = []
+ mock_query.return_value.filter.return_value.all.return_value = mock_whitelists
+
+ # Mock check_rule_against_whitelists to return empty list (no matches)
+ mock_check.return_value = []
+
+ # Mock evaluate function to return the model unchanged
+ mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model
+
+ # Call the service function
+ with app.app_context():
+ model, flashes = rule_service.create_or_update_rtbh_rule(
+ form_data=rtbh_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was created with correct attributes
+ assert model is not None
+ assert model.ipv4 == rtbh_form_data["ipv4"]
+ assert model.ipv4_mask == rtbh_form_data["ipv4_mask"]
+ assert model.ipv6 == rtbh_form_data["ipv6"]
+ assert model.ipv6_mask == rtbh_form_data["ipv6_mask"]
+ assert model.community_id == rtbh_form_data["community"]
+ assert model.rstate_id == 1 # Active state
+ assert model.user_id == 1
+ assert model.org_id == 1
+
+ # Verify flash messages - now a list instead of a string
+ assert isinstance(flashes, list)
+ assert "RTBH Rule saved" in flashes[0]
+
+ # Verify rule was announced
+ mock_announce.assert_called_once()
+ mock_log.assert_called_once()
+
+ # Verify evaluate function was called
+ mock_evaluate.assert_called_once()
+
+ @patch("flowapp.services.rule_service.get_rtbh_model_if_exists")
+ @patch("flowapp.services.rule_service.db.session.query")
+ @patch("flowapp.services.rule_service.check_rule_against_whitelists")
+ @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results")
+ @patch("flowapp.services.rule_service.announce_rtbh_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_update_existing_rtbh_rule(
+ self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data
+ ):
+ """Test updating an existing RTBH rule"""
+ # Create an existing model to return
+ existing_model = RTBH(
+ ipv4=rtbh_form_data["ipv4"],
+ ipv4_mask=rtbh_form_data["ipv4_mask"],
+ ipv6=rtbh_form_data["ipv6"] or "",
+ ipv6_mask=rtbh_form_data["ipv6_mask"],
+ community_id=rtbh_form_data["community"],
+ expires=datetime.now(),
+ user_id=1,
+ org_id=1,
+ rstate_id=1,
+ )
+ mock_get_model.return_value = existing_model
+
+ # Mock the whitelist query
+ mock_whitelists = []
+ mock_query.return_value.filter.return_value.all.return_value = mock_whitelists
+
+ # Mock check_rule_against_whitelists to return empty list
+ mock_check.return_value = []
+
+ # Mock evaluate function to return the model unchanged
+ mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model
+
+ # Set a new expiration time
+ new_expires = datetime.now() + timedelta(days=1)
+ rtbh_form_data["expires"] = new_expires
+
+ # Call the service function
+ with app.app_context():
+ db.session.add(existing_model)
+ db.session.commit()
+
+ model, flashes = rule_service.create_or_update_rtbh_rule(
+ form_data=rtbh_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was updated
+ assert model == existing_model
+ assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date()
+
+ # Verify flash messages
+ assert isinstance(flashes, list)
+ assert "Existing RTBH Rule found" in flashes[0]
+
+ # Verify route was announced
+ mock_announce.assert_called_once()
+ mock_log.assert_called_once()
+
+ @patch("flowapp.services.rule_service.get_rtbh_model_if_exists")
+ @patch("flowapp.services.rule_service.db.session.query")
+ @patch("flowapp.services.rule_service.map_whitelists_to_strings")
+ @patch("flowapp.services.rule_service.check_rule_against_whitelists")
+ @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results")
+ @patch("flowapp.services.rule_service.announce_rtbh_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_rtbh_rule_with_whitelists(
+ self,
+ mock_log,
+ mock_announce,
+ mock_evaluate,
+ mock_check,
+ mock_map,
+ mock_query,
+ mock_get_model,
+ app,
+ db,
+ rtbh_form_data,
+ whitelist_fixture,
+ ):
+ """Test creating a RTBH rule that interacts with whitelists"""
+ # Mock get_rtbh_model_if_exists to return False
+ mock_get_model.return_value = False
+
+ # Create a mock whitelist
+ mock_whitelist = whitelist_fixture
+ mock_whitelists = [mock_whitelist]
+
+ # Setup mock query to return our whitelist
+ mock_query_result = MagicMock()
+ mock_query_result.filter.return_value.all.return_value = mock_whitelists
+ mock_query.return_value = mock_query_result
+
+ # Setup map function to return our whitelist in a dict
+ whitelist_key = "192.168.1.0/24"
+ mock_map.return_value = {whitelist_key: mock_whitelist}
+
+ # Setup check function to return a relation
+ mock_rtbh = MagicMock(spec=RTBH)
+ mock_rtbh.__str__.return_value = "192.168.1.0/24"
+
+ # Setup check result
+ rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)]
+ mock_check.return_value = rule_relation
+
+ # Setup evaluate function to add a flash message
+ def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results):
+ flashes.append("Rule is equal to whitelist")
+ return model
+
+ mock_evaluate.side_effect = evaluate_side_effect
+
+ # Call the service function
+ with app.app_context():
+ model, flashes = rule_service.create_or_update_rtbh_rule(
+ form_data=rtbh_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify flash messages show both rule creation and whitelist check
+ assert isinstance(flashes, list)
+ assert "RTBH Rule saved" in flashes[0]
+ assert "Rule is equal to whitelist" in flashes[1]
+
+ # Verify interactions
+ mock_map.assert_called_once()
+ mock_check.assert_called_once()
+ mock_evaluate.assert_called_once()
+
+
+class TestEvaluateRtbhAgainstWhitelistsCheckResults:
+ def test_equal_relation(self, app, whitelist_fixture):
+ """Test evaluating a rule with an EQUAL relation to a whitelist"""
+ # Create a model
+ model = MagicMock(spec=RTBH)
+
+ # Create test data
+ flashes = []
+ user_id = 1
+ author = "test@example.com / Test Org"
+ whitelist_key = "192.168.1.0/24"
+ wl_cache = {whitelist_key: whitelist_fixture}
+ results = [(str(model), whitelist_key, Relation.EQUAL)]
+
+ # Call the function with mocked whitelist_rtbh_rule
+ with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule:
+ mock_whitelist_rule.return_value = model
+
+ with app.app_context():
+ result = rule_service.evaluate_rtbh_against_whitelists_check_results(
+ user_id, model, flashes, author, wl_cache, results
+ )
+
+ # Verify the rule was whitelisted
+ mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture)
+
+ # Verify the flash message
+ assert flashes
+
+ # Verify the correct model was returned
+ assert result == model
+
+ def test_subnet_relation(self, app, whitelist_fixture):
+ """Test evaluating a rule with a SUBNET relation to a whitelist"""
+ # Create a model
+ model = MagicMock(spec=RTBH)
+
+ # Create test data
+ flashes = []
+ user_id = 1
+ author = "test@example.com / Test Org"
+ whitelist_key = "192.168.1.128/25"
+ wl_cache = {whitelist_key: whitelist_fixture}
+ results = [(str(model), whitelist_key, Relation.SUBNET)]
+
+ # Call the function with mocked dependencies
+ with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch(
+ "flowapp.services.rule_service.create_rtbh_from_whitelist_parts"
+ ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch(
+ "flowapp.services.rule_service.db.session.commit"
+ ) as mock_commit:
+
+ # Mock subtract_network to return some subnets
+ mock_subtract.return_value = ["192.168.1.0/25"]
+
+ with app.app_context():
+ _result = rule_service.evaluate_rtbh_against_whitelists_check_results(
+ user_id, model, flashes, author, wl_cache, results
+ )
+
+ # Verify subnet calculation was performed
+ mock_subtract.assert_called_once()
+
+ # Verify new rules were created for the subnets
+ mock_create.assert_called_once()
+
+ # Verify the original rule was cached
+ mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER)
+
+ # Verify transaction was committed
+ mock_commit.assert_called_once()
+
+ # Verify the flash messages
+ assert flashes
+
+ # Verify model was updated to whitelisted state
+ assert model.rstate_id == 4
+
+ def test_supernet_relation(self, app, whitelist_fixture):
+ """Test evaluating a rule with a SUPERNET relation to a whitelist"""
+ # Create a model
+ model = MagicMock(spec=RTBH)
+
+ # Create test data
+ flashes = []
+ user_id = 1
+ author = "test@example.com / Test Org"
+ whitelist_key = "192.168.0.0/16"
+ wl_cache = {whitelist_key: whitelist_fixture}
+ results = [(str(model), whitelist_key, Relation.SUPERNET)]
+
+ # Call the function with mocked whitelist_rtbh_rule
+ with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule:
+ mock_whitelist_rule.return_value = model
+
+ with app.app_context():
+ result = rule_service.evaluate_rtbh_against_whitelists_check_results(
+ user_id, model, flashes, author, wl_cache, results
+ )
+
+ # Verify the rule was whitelisted
+ mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture)
+
+ # Verify the flash message
+ assert flashes
+
+ # Verify the correct model was returned
+ assert result == model
+
+ def test_no_relation(self, app):
+ """Test evaluating a rule with no relation to any whitelist"""
+ # Create a model
+ model = MagicMock(spec=RTBH)
+
+ # Create test data
+ flashes = []
+ user_id = 1
+ author = "test@example.com / Test Org"
+ wl_cache = {}
+ results = []
+
+ with app.app_context():
+ result = rule_service.evaluate_rtbh_against_whitelists_check_results(
+ user_id, model, flashes, author, wl_cache, results
+ )
+
+ # Verify no changes to the model and no messages
+ assert result == model
+ assert not flashes
+
+
+class TestMapWhitelistsToStrings:
+ def test_map_whitelists_to_strings(self):
+ """Test mapping whitelist objects to strings"""
+ # Create mock whitelists
+ whitelist1 = MagicMock(spec=Whitelist)
+ whitelist1.__str__.return_value = "192.168.1.0/24"
+
+ whitelist2 = MagicMock(spec=Whitelist)
+ whitelist2.__str__.return_value = "10.0.0.0/8"
+
+ whitelists = [whitelist1, whitelist2]
+
+ # Call the function
+ result = rule_service.map_whitelists_to_strings(whitelists)
+
+ # Verify the result
+ assert len(result) == 2
+ assert "192.168.1.0/24" in result
+ assert "10.0.0.0/8" in result
+ assert result["192.168.1.0/24"] == whitelist1
+ assert result["10.0.0.0/8"] == whitelist2
+
+ @patch("flowapp.services.rule_service.get_ipv6_model_if_exists")
+ @patch("flowapp.services.rule_service.messages")
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_update_existing_ipv6_rule(
+ self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data
+ ):
+ """Test updating an existing IPv6 rule"""
+ # Create an existing model to return
+ existing_model = Flowspec6(
+ source=ipv6_form_data["source"],
+ source_mask=ipv6_form_data["source_mask"],
+ source_port=ipv6_form_data["source_port"] or "",
+ destination=ipv6_form_data["dest"] or "",
+ destination_mask=ipv6_form_data["dest_mask"],
+ destination_port=ipv6_form_data["dest_port"] or "",
+ next_header=ipv6_form_data["next_header"],
+ flags=";".join(ipv6_form_data["flags"]),
+ packet_len=ipv6_form_data["packet_len"] or "",
+ expires=datetime.now(),
+ user_id=1,
+ org_id=1,
+ action_id=1,
+ rstate_id=1,
+ )
+ mock_get_model.return_value = existing_model
+ mock_messages.create_ipv6.return_value = "mock command"
+
+ # Set a new expiration time
+ new_expires = datetime.now() + timedelta(days=1)
+ ipv6_form_data["expires"] = new_expires
+
+ # Call the service function
+ with app.app_context():
+ db.session.add(existing_model)
+ db.session.commit()
+
+ model, message = rule_service.create_or_update_ipv6_rule(
+ form_data=ipv6_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was updated
+ assert model == existing_model
+ assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date()
+
+ # Verify message is still a string for IPv6 rules
+ assert message == "Existing IPv6 Rule found. Expiration time was updated to new value."
+
+ # Verify route was announced
+ mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE)
+ mock_announce.assert_called_once()
+ mock_log.assert_called_once()
diff --git a/flowapp/tests/test_rule_service_reactivate_delete.py b/flowapp/tests/test_rule_service_reactivate_delete.py
new file mode 100644
index 00000000..22a3f9f4
--- /dev/null
+++ b/flowapp/tests/test_rule_service_reactivate_delete.py
@@ -0,0 +1,527 @@
+"""Tests for rule_service module."""
+
+import pytest
+from datetime import datetime, timedelta
+from unittest.mock import patch
+
+from flowapp.constants import RuleTypes
+from flowapp.models import (
+ RTBH,
+ Flowspec4,
+ Flowspec6,
+ Whitelist,
+)
+from flowapp.output import RouteSources
+from flowapp.services import rule_service
+
+
+@pytest.fixture
+def test_data(app, db):
+ """Fixture providing test data for rule service tests."""
+ current_time = datetime.now()
+
+ # Create a test Flowspec4 rule
+ ipv4_rule = Flowspec4(
+ source="192.168.1.1",
+ source_mask=32,
+ source_port="",
+ destination="192.168.2.1",
+ destination_mask=32,
+ destination_port="",
+ protocol="tcp",
+ flags="",
+ packet_len="",
+ fragment="",
+ expires=current_time + timedelta(hours=1),
+ comment="Test IPv4 rule",
+ action_id=1, # Using action ID 1 from test database
+ user_id=1, # Using user ID 1 from test database
+ org_id=1, # Using org ID 1 from test database
+ rstate_id=1, # Active state
+ )
+ db.session.add(ipv4_rule)
+
+ # Create a test Flowspec6 rule
+ ipv6_rule = Flowspec6(
+ source="2001:db8::1",
+ source_mask=128,
+ source_port="",
+ destination="2001:db8:1::1",
+ destination_mask=128,
+ destination_port="",
+ next_header="tcp",
+ flags="",
+ packet_len="",
+ expires=current_time + timedelta(hours=1),
+ comment="Test IPv6 rule",
+ action_id=1, # Using action ID 1 from test database
+ user_id=1, # Using user ID 1 from test database
+ org_id=1, # Using org ID 1 from test database
+ rstate_id=1, # Active state
+ )
+ db.session.add(ipv6_rule)
+
+ # Create a test RTBH rule
+ rtbh_rule = RTBH(
+ ipv4="192.168.1.100",
+ ipv4_mask=32,
+ ipv6=None,
+ ipv6_mask=None,
+ community_id=1, # Using community ID 1 from test database
+ expires=current_time + timedelta(hours=1),
+ comment="Test RTBH rule",
+ user_id=1, # Using user ID 1 from test database
+ org_id=1, # Using org ID 1 from test database
+ rstate_id=1, # Active state
+ )
+ db.session.add(rtbh_rule)
+
+ # Create a test Whitelist
+ whitelist = Whitelist(
+ ip="192.168.2.1",
+ mask=24,
+ expires=current_time + timedelta(days=7),
+ comment="Test whitelist",
+ user_id=1,
+ org_id=1,
+ rstate_id=1,
+ )
+ db.session.add(whitelist)
+
+ db.session.commit()
+
+ # Return data that will be useful for tests
+ return {
+ "user_id": 1,
+ "org_id": 1,
+ "user_email": "test@example.com",
+ "org_name": "Test Org",
+ "ipv4_rule_id": ipv4_rule.id,
+ "ipv6_rule_id": ipv6_rule.id,
+ "rtbh_rule_id": rtbh_rule.id,
+ "whitelist_id": whitelist.id,
+ "current_time": current_time,
+ "future_time": current_time + timedelta(hours=1),
+ "past_time": current_time - timedelta(hours=1),
+ "comment": "Test comment",
+ }
+
+
+class TestReactivateRule:
+ """Tests for the reactivate_rule function."""
+
+ @patch("flowapp.services.rule_service.check_global_rule_limit")
+ @patch("flowapp.services.rule_service.check_rule_limit")
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_route")
+ def test_reactivate_rule_active(
+ self, mock_log_route, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db
+ ):
+ """Test reactivating a rule with future expiration (active state)."""
+ # Setup mocks
+ mock_check_global_limit.return_value = False
+ mock_check_rule_limit.return_value = False
+
+ # The rule will be active (state=1) because expiration is in the future
+ expires = test_data["future_time"]
+
+ # Call the function
+ model, messages = rule_service.reactivate_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=test_data["ipv4_rule_id"],
+ expires=expires,
+ comment=test_data["comment"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ )
+
+ # Assertions
+ mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value)
+ mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value)
+
+ # Verify model was updated
+ assert model.expires == expires
+ assert model.comment == test_data["comment"]
+ assert model.rstate_id == 1 # Active state
+
+ # Verify route announcement was made
+ mock_announce_route.assert_called_once()
+ args, _ = mock_announce_route.call_args
+ assert args[0].source == RouteSources.UI
+
+ # Verify logging
+ mock_log_route.assert_called_once()
+
+ # Check returned values
+ assert messages
+ assert messages[0].startswith("Rule ")
+
+ @patch("flowapp.services.rule_service.check_global_rule_limit")
+ @patch("flowapp.services.rule_service.check_rule_limit")
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_withdraw")
+ def test_reactivate_rule_inactive(
+ self, mock_log_withdraw, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db
+ ):
+ """Test reactivating a rule with past expiration (inactive state)."""
+ # Setup mocks
+ mock_check_global_limit.return_value = False
+ mock_check_rule_limit.return_value = False
+
+ # The rule will be inactive (state=2) because expiration is in the past
+ expires = test_data["past_time"]
+
+ # Call the function
+ model, messages = rule_service.reactivate_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=test_data["ipv4_rule_id"],
+ expires=expires,
+ comment=test_data["comment"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ )
+
+ # Verify model was updated
+ assert model.expires == expires
+ assert model.comment == test_data["comment"]
+ assert model.rstate_id == 2 # Inactive state
+
+ # Verify route withdrawal was made
+ mock_announce_route.assert_called_once()
+ args, _ = mock_announce_route.call_args
+ assert args[0].source == RouteSources.UI
+
+ # Verify logging
+ mock_log_withdraw.assert_called_once()
+
+ # Check returned values
+ assert messages
+ assert messages[0].startswith("Rule ")
+
+ @patch("flowapp.services.rule_service.check_global_rule_limit")
+ @patch("flowapp.services.rule_service.check_rule_limit")
+ def test_reactivate_rule_global_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db):
+ """Test reactivating a rule when global limit is reached."""
+ # Setup mocks
+ mock_check_global_limit.return_value = True # Global limit reached
+ mock_check_rule_limit.return_value = False
+
+ # The rule would be active, but global limit is reached
+ expires = test_data["future_time"]
+
+ # Call the function
+ model, messages = rule_service.reactivate_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=test_data["ipv4_rule_id"],
+ expires=expires,
+ comment=test_data["comment"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ )
+
+ # Assertions
+ mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value)
+
+ # Check returned values
+ assert messages == ["global_limit_reached"]
+
+ @patch("flowapp.services.rule_service.check_global_rule_limit")
+ @patch("flowapp.services.rule_service.check_rule_limit")
+ def test_reactivate_rule_org_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db):
+ """Test reactivating a rule when organization limit is reached."""
+ # Setup mocks
+ mock_check_global_limit.return_value = False
+ mock_check_rule_limit.return_value = True # Org limit reached
+
+ # The rule would be active, but org limit is reached
+ expires = test_data["future_time"]
+
+ # Call the function
+ model, messages = rule_service.reactivate_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=test_data["ipv4_rule_id"],
+ expires=expires,
+ comment=test_data["comment"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ )
+
+ # Assertions
+ mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value)
+ mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value)
+
+ # Check returned values
+ assert messages == ["limit_reached"]
+
+
+class TestDeleteRule:
+ """Tests for the delete_rule function."""
+
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_withdraw")
+ def test_delete_rule_success(self, mock_log_withdraw, mock_announce_route, test_data, db):
+ """Test successful rule deletion."""
+ # Call the function
+ success, message = rule_service.delete_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=test_data["ipv4_rule_id"],
+ user_id=test_data["user_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[test_data["ipv4_rule_id"]], # Rule is in allowed list
+ )
+
+ # Verify route withdrawal was made
+ mock_announce_route.assert_called_once()
+ args, _ = mock_announce_route.call_args
+ assert args[0].source == RouteSources.UI
+
+ # Verify logging
+ mock_log_withdraw.assert_called_once()
+
+ # Verify database operations - rule should be deleted
+ rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"])
+ assert rule is None
+
+ # Check returned values
+ assert success is True
+ assert message == "Rule deleted successfully"
+
+ def test_delete_rule_not_allowed(self, test_data, db):
+ """Test rule deletion when rule is not in allowed list."""
+ # Call the function
+ success, message = rule_service.delete_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=test_data["ipv4_rule_id"],
+ user_id=test_data["user_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[999], # Rule is not in allowed list
+ )
+
+ # Verify rule still exists
+ rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"])
+ assert rule is not None
+
+ # Check returned values
+ assert success is False
+ assert message == "You cannot delete this rule"
+
+ def test_delete_rule_not_found(self, test_data, db):
+ """Test rule deletion when rule is not found."""
+ # Call the function with non-existent ID
+ success, message = rule_service.delete_rule(
+ rule_type=RuleTypes.IPv4,
+ rule_id=9999, # Non-existent ID
+ user_id=test_data["user_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ )
+
+ # Check returned values
+ assert success is False
+ assert message == "Rule not found"
+
+ @patch("flowapp.services.rule_service.announce_route")
+ @patch("flowapp.services.rule_service.log_withdraw")
+ @patch("flowapp.services.rule_service.RuleWhitelistCache")
+ def test_delete_rtbh_rule(self, mock_whitelist_cache, mock_log_withdraw, mock_announce_route, test_data, db):
+ """Test deleting an RTBH rule."""
+ # Call the function
+ success, message = rule_service.delete_rule(
+ rule_type=RuleTypes.RTBH,
+ rule_id=test_data["rtbh_rule_id"],
+ user_id=test_data["user_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[test_data["rtbh_rule_id"]],
+ )
+
+ # Verify whitelist cache was cleaned
+ mock_whitelist_cache.delete_by_rule_id.assert_called_once_with(test_data["rtbh_rule_id"])
+
+ # Verify route withdrawal and logging
+ mock_announce_route.assert_called_once()
+ mock_log_withdraw.assert_called_once()
+
+ # Verify rule was deleted
+ rule = db.session.get(RTBH, test_data["rtbh_rule_id"])
+ assert rule is None
+
+ # Check returned values
+ assert success is True
+ assert message == "Rule deleted successfully"
+
+
+class TestDeleteRtbhAndCreateWhitelist:
+ """Tests for the delete_rtbh_and_create_whitelist function."""
+
+ @patch("flowapp.services.rule_service.delete_rule")
+ @patch("flowapp.services.rule_service.create_or_update_whitelist")
+ def test_delete_rtbh_and_create_whitelist_success(self, mock_create_whitelist, mock_delete_rule, test_data, db):
+ """Test successful RTBH deletion and whitelist creation."""
+ # Setup mock for delete_rule service
+ mock_delete_rule.return_value = (True, "Rule deleted successfully")
+
+ # Setup mock for create_or_update_whitelist service
+ mock_whitelist = Whitelist(
+ ip="192.168.1.100",
+ mask=32,
+ expires=datetime.now() + timedelta(days=7),
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ )
+ mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"])
+
+ # Call the function
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=test_data["rtbh_rule_id"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[test_data["rtbh_rule_id"]],
+ )
+
+ # Verify delete_rule was called
+ mock_delete_rule.assert_called_once_with(
+ rule_type=RuleTypes.RTBH,
+ rule_id=test_data["rtbh_rule_id"],
+ user_id=test_data["user_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[test_data["rtbh_rule_id"]],
+ )
+
+ # Verify create_or_update_whitelist was called with correct data
+ mock_create_whitelist.assert_called_once()
+ args, kwargs = mock_create_whitelist.call_args
+ form_data = kwargs.get("form_data", args[0] if args else None)
+ assert form_data["ip"] == "192.168.1.100"
+ assert form_data["mask"] == 32
+ assert "Created from RTBH rule" in form_data["comment"]
+
+ # Check returned values
+ assert success is True
+ assert len(messages) == 2
+ assert messages[0] == "Rule deleted successfully"
+ assert messages[1] == "Whitelist created"
+ assert whitelist == mock_whitelist
+
+ def test_delete_rtbh_and_create_whitelist_rule_not_found(self, test_data, db):
+ """Test when the RTBH rule to convert is not found."""
+ # Call the function with non-existent ID
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=9999, # Non-existent ID
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ )
+
+ # Check returned values
+ assert success is False
+ assert messages == ["RTBH rule not found"]
+ assert whitelist is None
+
+ def test_delete_rtbh_and_create_whitelist_not_allowed(self, test_data, db):
+ """Test when the user is not allowed to delete the rule."""
+ # Call the function
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=test_data["rtbh_rule_id"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[999], # Rule not in allowed list
+ )
+
+ # Check returned values
+ assert success is False
+ assert messages == ["You cannot delete this rule"]
+ assert whitelist is None
+
+ @patch("flowapp.services.rule_service.delete_rule")
+ def test_delete_rtbh_and_create_whitelist_delete_fails(self, mock_delete_rule, test_data, db):
+ """Test when the RTBH deletion fails."""
+ # Setup mock for delete_rule service to fail
+ mock_delete_rule.return_value = (False, "Error deleting rule")
+
+ # Call the function
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=test_data["rtbh_rule_id"],
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[test_data["rtbh_rule_id"]],
+ )
+
+ # Verify delete_rule was called
+ mock_delete_rule.assert_called_once()
+
+ # Check returned values
+ assert success is False
+ assert messages == ["Error deleting rule"]
+ assert whitelist is None
+
+ @patch("flowapp.services.rule_service.delete_rule")
+ @patch("flowapp.services.rule_service.create_or_update_whitelist")
+ def test_rtbh_with_ipv6(self, mock_create_whitelist, mock_delete_rule, test_data, db):
+ """Test with an IPv6 RTBH rule."""
+ # Create an IPv6 RTBH rule
+ ipv6_rtbh = RTBH(
+ ipv4=None,
+ ipv4_mask=None,
+ ipv6="2001:db8::1",
+ ipv6_mask=64,
+ community_id=1,
+ expires=datetime.now() + timedelta(hours=1),
+ comment="IPv6 RTBH rule",
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ rstate_id=1,
+ )
+ db.session.add(ipv6_rtbh)
+ db.session.commit()
+
+ # Setup mocks
+ mock_delete_rule.return_value = (True, "Rule deleted successfully")
+ mock_whitelist = Whitelist(
+ ip="2001:db8::1",
+ mask=64,
+ expires=datetime.now() + timedelta(days=7),
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ )
+ mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"])
+
+ # Call the function
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=ipv6_rtbh.id,
+ user_id=test_data["user_id"],
+ org_id=test_data["org_id"],
+ user_email=test_data["user_email"],
+ org_name=test_data["org_name"],
+ allowed_rule_ids=[ipv6_rtbh.id],
+ )
+
+ # Verify create_or_update_whitelist was called with correct IPv6 data
+ mock_create_whitelist.assert_called_once()
+ args, kwargs = mock_create_whitelist.call_args
+ form_data = kwargs.get("form_data", args[0] if args else None)
+ assert form_data["ip"] == "2001:db8::1"
+ assert form_data["mask"] == 64
+
+ # Check returned values
+ assert success is True
+ assert len(messages) == 2
+ assert whitelist == mock_whitelist
diff --git a/flowapp/tests/test_validators.py b/flowapp/tests/test_validators.py
index 33e757cf..cab089f0 100644
--- a/flowapp/tests/test_validators.py
+++ b/flowapp/tests/test_validators.py
@@ -1,11 +1,28 @@
import pytest
-import flowapp.validators
+from flowapp.validators import (
+ PortString,
+ PacketString,
+ NetRangeString,
+ NetInRange,
+ IPAddress,
+ IPAddressValidator,
+ NetworkValidator,
+ ValidationError,
+ address_in_range,
+ address_with_mask,
+ IPv4Address,
+ IPv6Address,
+ DateNotExpired,
+ editable_range,
+ network_in_range,
+ range_in_network,
+)
def test_port_string_len_raises(field):
- port = flowapp.validators.PortString()
+ port = PortString()
field.data = "1;2;3;4;5;6;7;8"
- with pytest.raises(flowapp.validators.ValidationError):
+ with pytest.raises(ValidationError):
port(None, field)
@@ -20,12 +37,12 @@ def test_port_string_len_raises(field):
],
)
def test_is_valid_address_with_mask(address, mask, expected):
- assert flowapp.validators.address_with_mask(address, mask) == expected
+ assert address_with_mask(address, mask) == expected
@pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"])
def test_ip4address_passes(field, address):
- adr = flowapp.validators.IPv4Address()
+ adr = IPv4Address()
field.data = address
adr(None, field)
@@ -38,7 +55,7 @@ def test_ip4address_passes(field, address):
],
)
def test_ip6address_passes(field, address):
- adr = flowapp.validators.IPv6Address()
+ adr = IPv6Address()
field.data = address
adr(None, field)
@@ -51,19 +68,17 @@ def test_ip6address_passes(field, address):
],
)
def test_bad_ip6address_raises(field, address):
- adr = flowapp.validators.IPv4Address()
+ adr = IPv4Address()
field.data = address
- with pytest.raises(flowapp.validators.ValidationError):
+ with pytest.raises(ValidationError):
adr(None, field)
-@pytest.mark.parametrize(
- "expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]
-)
+@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"])
def test_expired_date_raises(field, expired):
- adr = flowapp.validators.DateNotExpired()
+ adr = DateNotExpired()
field.data = expired
- with pytest.raises(flowapp.validators.ValidationError):
+ with pytest.raises(ValidationError):
adr(None, field)
@@ -75,9 +90,9 @@ def test_expired_date_raises(field, expired):
],
)
def test_ipaddress_raises(field, address):
- adr = flowapp.validators.IPv6Address()
+ adr = IPv6Address()
field.data = address
- with pytest.raises(flowapp.validators.ValidationError):
+ with pytest.raises(ValidationError):
adr(None, field)
@@ -91,7 +106,7 @@ def test_ipaddress_raises(field, address):
def test_editable_rule(rule, address, mask, ranges, expected):
rule.source = address
rule.source_mask = mask
- assert flowapp.validators.editable_range(rule, ranges) == expected
+ assert editable_range(rule, ranges) == expected
@pytest.mark.parametrize(
@@ -104,7 +119,7 @@ def test_editable_rule(rule, address, mask, ranges, expected):
],
)
def test_address_in_range(address, mask, ranges, expected):
- assert flowapp.validators.address_in_range(address, ranges) == expected
+ assert address_in_range(address, ranges) == expected
@pytest.mark.parametrize(
@@ -123,7 +138,7 @@ def test_address_in_range(address, mask, ranges, expected):
],
)
def test_network_in_range(address, mask, ranges, expected):
- assert flowapp.validators.network_in_range(address, mask, ranges) == expected
+ assert network_in_range(address, mask, ranges) == expected
@pytest.mark.parametrize(
@@ -133,4 +148,209 @@ def test_network_in_range(address, mask, ranges, expected):
],
)
def test_range_in_network(address, mask, ranges, expected):
- assert flowapp.validators.range_in_network(address, mask, ranges) == expected
+ assert range_in_network(address, mask, ranges) == expected
+
+
+### new tests
+
+
+# New tests for previously uncovered validators
+
+
+# Tests for PortString validator syntax
+@pytest.mark.parametrize(
+ "port_data",
+ [
+ "80", # Simple number
+ "80-443", # Range with hyphen
+ ">=80&<=443", # Explicit range
+ ">80", # Greater than
+ ">=80", # Greater than or equal
+ "<443", # Less than
+ "<=443", # Less than or equal
+ "80;443;8080", # Multiple port expressions
+ ],
+)
+def test_port_string_valid_syntax(field, port_data):
+ validator = PortString()
+ field.data = port_data
+ validator(None, field) # Should not raise
+
+
+@pytest.mark.parametrize(
+ "invalid_port_data",
+ [
+ "80,443", # Comma not supported, should use semicolon
+ "80..443", # Invalid range syntax (should be 80-443)
+ ">65536", # Port number too high
+ "-1", # Negative number
+ "<-1", # Port number too low
+ ">=80<=443", # Invalid range syntax (missing &)
+ "!=80", # Not equals not supported
+ "abc", # Non-numeric
+ "80-", # Incomplete range
+ "-80", # Invalid range
+ "80&&443", # Invalid operator
+ "443-80", # End less than start
+ ">=443&<=80", # End less than start in explicit range
+ "0-65537", # Range exceeding max
+ ],
+)
+def test_port_string_invalid_syntax(field, invalid_port_data):
+ validator = PortString()
+ field.data = invalid_port_data
+ with pytest.raises(ValidationError):
+ validator(None, field)
+
+
+# Tests for PacketString validator
+@pytest.mark.parametrize(
+ "packet_data",
+ ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions
+)
+def test_packet_string_valid(field, packet_data):
+ validator = PacketString()
+ field.data = packet_data
+ validator(None, field) # Should not raise
+
+
+@pytest.mark.parametrize(
+ "invalid_packet_data",
+ [
+ "65536", # Too large
+ "-1", # Negative
+ "1500..", # Invalid range syntax
+ "abc", # Non-numeric
+ "!!1500", # Invalid operator
+ "1500-", # Incomplete range
+ "9000-1500", # End less than start
+ ],
+)
+def test_packet_string_invalid(field, invalid_packet_data):
+ validator = PacketString()
+ field.data = invalid_packet_data
+ with pytest.raises(ValidationError):
+ validator(None, field)
+
+
+# Tests for NetRangeString validator
+@pytest.mark.parametrize(
+ "net_range",
+ [
+ "192.168.0.0/24",
+ "10.0.0.0/8\n172.16.0.0/12",
+ "2001:db8::/32",
+ "192.168.0.0/24 10.0.0.0/8",
+ "2001:db8::/32 2001:db8:1::/48",
+ ],
+)
+def test_net_range_string_valid(field, net_range):
+ validator = NetRangeString()
+ field.data = net_range
+ validator(None, field) # Should not raise
+
+
+@pytest.mark.parametrize(
+ "invalid_net_range",
+ [
+ "192.168.0.0/33", # Invalid mask
+ "256.256.256.0/24", # Invalid IP
+ "2001:xyz::/32", # Invalid IPv6
+ "192.168.0.0/24/", # Malformed
+ "not-an-ip/24", # Invalid format
+ ],
+)
+def test_net_range_string_invalid(field, invalid_net_range):
+ validator = NetRangeString()
+ field.data = invalid_net_range
+ with pytest.raises(ValidationError):
+ validator(None, field)
+
+
+# Tests for NetInRange validator
+def test_net_in_range_valid(field):
+ net_ranges = ["192.168.0.0/16", "10.0.0.0/8"]
+ validator = NetInRange(net_ranges)
+ field.data = "192.168.1.1/24"
+ validator(None, field) # Should not raise
+
+
+def test_net_in_range_invalid(field):
+ net_ranges = ["192.168.0.0/16", "10.0.0.0/8"]
+ validator = NetInRange(net_ranges)
+ field.data = "172.16.1.1/24"
+ with pytest.raises(ValidationError):
+ validator(None, field)
+
+
+# Tests for base IPAddress validator
+@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"])
+def test_ip_address_valid(field, ip_addr):
+ validator = IPAddress()
+ field.data = ip_addr
+ validator(None, field) # Should not raise
+
+
+@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"])
+def test_ip_address_invalid(field, invalid_ip):
+ validator = IPAddress()
+ field.data = invalid_ip
+ with pytest.raises(ValidationError):
+ validator(None, field)
+
+
+# Tests for universal IPAddressValidator
+@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed
+def test_ip_address_validator_valid(field, ip_addr):
+ validator = IPAddressValidator()
+ field.data = ip_addr
+ validator(None, field) # Should not raise
+
+
+@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"])
+def test_ip_address_validator_invalid(field, invalid_ip):
+ validator = IPAddressValidator()
+ field.data = invalid_ip
+ with pytest.raises(ValidationError):
+ validator(None, field)
+
+
+# Tests for NetworkValidator
+class MockForm:
+ def __init__(self, address, mask):
+ self._fields = {"mask": type("MockField", (), {"data": mask})()}
+
+
+@pytest.mark.parametrize(
+ "address,mask",
+ [
+ ("192.168.0.0", "24"),
+ ("10.0.0.0", "8"),
+ ("2001:db8::", "32"),
+ ("", None), # Empty values should pass
+ (None, None), # None values should pass
+ ],
+)
+def test_network_validator_valid(field, address, mask):
+ validator = NetworkValidator("mask")
+ field.data = address
+ form = MockForm(address, mask)
+ validator(form, field) # Should not raise
+
+
+@pytest.mark.parametrize(
+ "address,mask",
+ [
+ ("192.168.1.1", "24"), # Not a network address
+ ("192.168.0.0", "33"), # Invalid IPv4 mask
+ ("2001:db8::", "129"), # Invalid IPv6 mask
+ ("256.256.256.0", "24"), # Invalid IP
+ ("2001:xyz::", "32"), # Invalid IPv6
+ ],
+)
+def test_network_validator_invalid(field, address, mask):
+ validator = NetworkValidator("mask")
+ field.data = address
+ form = MockForm(address, mask)
+ with pytest.raises(ValidationError):
+ validator(form, field)
diff --git a/flowapp/tests/test_whitelist_common.py b/flowapp/tests/test_whitelist_common.py
new file mode 100644
index 00000000..a843aff9
--- /dev/null
+++ b/flowapp/tests/test_whitelist_common.py
@@ -0,0 +1,250 @@
+import pytest
+from flowapp.services.whitelist_common import (
+ Relation,
+ check_rule_against_whitelists,
+ check_whitelist_against_rules,
+ check_whitelist_to_rule_relation,
+ subtract_network,
+ clear_network_cache,
+ _is_same_ip_version, # New helper function
+)
+
+
+# Tests for core function that checks network relations
+@pytest.mark.parametrize(
+ "rule,whitelist,expected",
+ [
+ # IPv4 test cases
+ ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet
+ ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks
+ ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet
+ ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet
+ ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks
+ # IPv6 test cases
+ ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks
+ ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet
+ ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet
+ ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks
+ ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet
+ ],
+)
+def test_check_whitelist_to_rule_relation(rule, whitelist, expected):
+ clear_network_cache()
+ assert check_whitelist_to_rule_relation(rule, whitelist) == expected
+
+
+@pytest.mark.parametrize(
+ "target,whitelist,expected",
+ [
+ # Basic IPv4 cases
+ (
+ "192.168.1.0/24",
+ "192.168.1.128/25",
+ ["192.168.1.0/25"],
+ ), # One remaining subnet
+ (
+ "192.168.1.0/24",
+ "192.168.1.64/26",
+ ["192.168.1.0/26", "192.168.1.128/25"],
+ ), # Two remaining subnets
+ (
+ "192.168.1.0/24",
+ "192.168.1.0/24",
+ ["192.168.1.0/24"],
+ ), # Equal networks - return original network
+ (
+ "192.168.1.0/24",
+ "192.168.2.0/24",
+ ["192.168.1.0/24"],
+ ), # No overlap
+ ],
+)
+def test_subtract_network(target, whitelist, expected):
+ clear_network_cache()
+ result = subtract_network(target, whitelist)
+ assert sorted(result) == sorted(expected)
+
+
+def test_check_rule_against_whitelists():
+ clear_network_cache()
+
+ rule = "192.168.1.0/24"
+ whitelists = [
+ "192.168.0.0/16", # SUPERNET
+ "172.16.0.0/12", # DIFFERENT
+ "192.168.1.0/24", # EQUAL
+ "192.168.1.128/25", # SUBNET
+ ]
+
+ results = check_rule_against_whitelists(rule, whitelists)
+
+ # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations
+ expected = [
+ (rule, "192.168.0.0/16", Relation.SUPERNET),
+ (rule, "192.168.1.0/24", Relation.EQUAL),
+ (rule, "192.168.1.128/25", Relation.SUBNET),
+ ]
+
+ assert len(results) == 3
+ for result, exp in zip(sorted(results), sorted(expected)):
+ assert result == exp
+
+
+def test_check_whitelist_against_rules():
+ clear_network_cache()
+
+ whitelist = "192.168.1.128/25"
+ rules = [
+ "192.168.0.0/16", # Whitelist is SUBNET of this larger network
+ "172.16.0.0/12", # DIFFERENT
+ "192.168.1.128/25", # EQUAL
+ "192.168.1.0/24", # Whitelist is SUBNET of this network
+ ]
+
+ results = check_whitelist_against_rules(rules, whitelist)
+
+ # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations
+ expected = [
+ ("192.168.0.0/16", whitelist, Relation.SUBNET),
+ ("192.168.1.128/25", whitelist, Relation.EQUAL),
+ ("192.168.1.0/24", whitelist, Relation.SUBNET),
+ ]
+
+ assert len(results) == 3
+ for result, exp in zip(sorted(results), sorted(expected)):
+ assert result == exp
+
+
+# Tests for edge cases and error handling
+def test_host_addresses():
+ clear_network_cache()
+ # Test with host addresses (no explicit subnet mask)
+ assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET
+ assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT
+
+
+def test_single_ip_as_network():
+ clear_network_cache()
+ # Test with /32 (single IPv4 address as network)
+ assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET
+ # Test with /128 (single IPv6 address as network)
+ assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET
+
+
+def test_is_same_ip_version():
+ """Test the new helper function to check if two addresses are of the same IP version"""
+ # IPv4 cases
+ assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True
+ assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True
+
+ # IPv6 cases
+ assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True
+ assert _is_same_ip_version("2001:db8::1", "fe80::1") is True
+
+ # Mixed cases
+ assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False
+ assert _is_same_ip_version("192.168.1.1", "fe80::1") is False
+
+
+def test_check_whitelist_to_rule_relation_mixed_versions():
+ """Test the check_whitelist_to_rule_relation function with mixed IP versions"""
+ clear_network_cache()
+
+ # IPv4 rule with IPv6 whitelist
+ assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT
+
+ # IPv6 rule with IPv4 whitelist
+ assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT
+
+ # Invalid IP addresses should return DIFFERENT
+ assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT
+ assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT
+
+
+def test_subtract_network_mixed_versions():
+ """Test the subtract_network function with mixed IP versions"""
+ clear_network_cache()
+
+ # IPv4 target with IPv6 whitelist - should return original target
+ result = subtract_network("192.168.1.0/24", "2001:db8::/32")
+ assert result == ["192.168.1.0/24"]
+
+ # IPv6 target with IPv4 whitelist - should return original target
+ result = subtract_network("2001:db8::/32", "192.168.1.0/24")
+ assert result == ["2001:db8::/32"]
+
+ # Invalid addresses - should return original target
+ result = subtract_network("invalid", "192.168.1.0/24")
+ assert result == ["invalid"]
+ result = subtract_network("192.168.1.0/24", "invalid")
+ assert result == ["192.168.1.0/24"]
+
+
+def test_check_rule_against_whitelists_mixed_versions():
+ """Test the check_rule_against_whitelists function with mixed IP versions"""
+ clear_network_cache()
+
+ # IPv4 rule against mixed whitelist
+ rule = "192.168.1.0/24"
+ whitelists = [
+ "192.168.0.0/16", # IPv4 - should match
+ "2001:db8::/32", # IPv6 - should not match
+ "10.0.0.0/8", # IPv4 - should not match (different network)
+ ]
+
+ results = check_rule_against_whitelists(rule, whitelists)
+ assert len(results) == 1
+ assert results[0][0] == rule
+ assert results[0][1] == "192.168.0.0/16"
+ assert results[0][2] == Relation.SUPERNET
+
+ # IPv6 rule against mixed whitelist
+ rule = "2001:db8:1::/48"
+ whitelists = [
+ "192.168.0.0/16", # IPv4 - should not match
+ "2001:db8::/32", # IPv6 - should match
+ "2002:db8::/32", # IPv6 - should not match (different network)
+ ]
+
+ results = check_rule_against_whitelists(rule, whitelists)
+ assert len(results) == 1
+ assert results[0][0] == rule
+ assert results[0][1] == "2001:db8::/32"
+ assert results[0][2] == Relation.SUPERNET
+
+
+def test_check_whitelist_against_rules_mixed_versions():
+ """Test the check_whitelist_against_rules function with mixed IP versions"""
+ clear_network_cache()
+
+ # IPv4 whitelist against mixed rules
+ whitelist = "192.168.1.0/24"
+ rules = [
+ "192.168.0.0/16", # IPv4 - should match
+ "2001:db8::/32", # IPv6 - should not match
+ "10.0.0.0/8", # IPv4 - should not match (different network)
+ ]
+
+ results = check_whitelist_against_rules(rules, whitelist)
+ assert len(results) == 1
+ assert results[0][0] == "192.168.0.0/16"
+ assert results[0][1] == whitelist
+ assert results[0][2] == Relation.SUBNET
+
+ # IPv6 whitelist against mixed rules
+ whitelist = "2001:db8:1::/48"
+ rules = [
+ "192.168.0.0/16", # IPv4 - should not match
+ "2001:db8::/32", # IPv6 - should match
+ "2002:db8::/32", # IPv6 - should not match (different network)
+ ]
+
+ results = check_whitelist_against_rules(rules, whitelist)
+ assert len(results) == 1
+ assert results[0][0] == "2001:db8::/32"
+ assert results[0][1] == whitelist
+ assert results[0][2] == Relation.SUBNET
+
+
+if __name__ == "__main__":
+ pytest.main(["-v"])
diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py
new file mode 100644
index 00000000..460a172b
--- /dev/null
+++ b/flowapp/tests/test_whitelist_service.py
@@ -0,0 +1,461 @@
+"""
+Tests for whitelist_service.py module.
+
+This test suite verifies the functionality of the whitelist service,
+which manages creation, updating, and handling of whitelist rules.
+"""
+
+import pytest
+from datetime import datetime, timedelta
+from unittest.mock import patch, MagicMock
+
+from flowapp.constants import RuleTypes, RuleOrigin
+from flowapp.models import Whitelist, RuleWhitelistCache, RTBH
+from flowapp.services import whitelist_service
+from flowapp.services.whitelist_common import Relation
+
+
+@pytest.fixture
+def whitelist_form_data():
+ """Sample valid whitelist form data"""
+ return {
+ "ip": "192.168.1.0",
+ "mask": 24,
+ "comment": "Test whitelist entry",
+ "expires": datetime.now() + timedelta(hours=1),
+ }
+
+
+class TestCreateOrUpdateWhitelist:
+ @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists")
+ @patch("flowapp.services.whitelist_service.check_whitelist_against_rules")
+ @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results")
+ def test_create_new_whitelist(
+ self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data
+ ):
+ """Test creating a new whitelist entry"""
+ # Mock the get_whitelist_model_if_exists to return False (not found)
+ mock_get_model.return_value = False
+
+ # Mock RTBH rules and check results
+ mock_rtbh_rules = []
+ mock_query = MagicMock()
+ mock_query.filter.return_value.all.return_value = mock_rtbh_rules
+
+ # Mock check_whitelist_against_rules to return empty list (no matches)
+ mock_check_whitelist.return_value = []
+
+ # Mock evaluate to return the model unchanged
+ mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model
+
+ # Call the service function
+ with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query):
+ with app.app_context():
+ model, flashes = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was created with correct attributes
+ assert model is not None
+ assert model.ip == whitelist_form_data["ip"]
+ assert model.mask == whitelist_form_data["mask"]
+ assert model.comment == whitelist_form_data["comment"]
+ assert model.user_id == 1
+ assert model.org_id == 1
+ assert model.rstate_id == 1 # Active state
+
+ # Verify flash messages - now a list instead of a string
+ assert isinstance(flashes, list)
+ assert "Whitelist saved" in flashes[0]
+
+ @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists")
+ @patch("flowapp.services.whitelist_service.check_whitelist_against_rules")
+ @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results")
+ def test_update_existing_whitelist(
+ self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data
+ ):
+ """Test updating an existing whitelist entry"""
+ # Create an existing whitelist
+ existing_model = Whitelist(
+ ip=whitelist_form_data["ip"],
+ mask=whitelist_form_data["mask"],
+ expires=datetime.now(),
+ user_id=1,
+ org_id=1,
+ rstate_id=1,
+ )
+
+ # Mock to return the existing model
+ mock_get_model.return_value = existing_model
+
+ # Mock RTBH rules and check results
+ mock_rtbh_rules = []
+ mock_query = MagicMock()
+ mock_query.filter.return_value.all.return_value = mock_rtbh_rules
+
+ # Mock check_whitelist_against_rules to return empty list (no matches)
+ mock_check_whitelist.return_value = []
+
+ # Mock evaluate to return the model unchanged
+ mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model
+
+ # Set a new expiration time
+ new_expires = datetime.now() + timedelta(days=1)
+ whitelist_form_data["expires"] = new_expires
+
+ # Call the service function
+ with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query):
+ with app.app_context():
+ db.session.add(existing_model)
+ db.session.commit()
+
+ model, flashes = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify the model was updated
+ assert model == existing_model
+ # Check that expiration date was updated (rounded to 10 minutes)
+ # We can't compare exact timestamps, so check date parts
+ assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date()
+
+ # Verify flash messages - now a list instead of a string
+ assert isinstance(flashes, list)
+ assert "Existing Whitelist found" in flashes[0]
+
+ @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists")
+ @patch("flowapp.services.whitelist_service.db.session.query")
+ @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings")
+ @patch("flowapp.services.whitelist_service.check_whitelist_against_rules")
+ @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results")
+ def test_create_whitelist_with_matching_rules(
+ self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data
+ ):
+ """Test creating a whitelist that affects existing rules"""
+ # Mock get_whitelist_model_if_exists to return False (new whitelist)
+ mock_get_model.return_value = False
+
+ # Create a mock RTBH rule
+ mock_rtbh_rule = MagicMock(spec=RTBH)
+ mock_rtbh_rule.__str__.return_value = "192.168.1.0/24"
+ mock_rtbh_rules = [mock_rtbh_rule]
+
+ # Setup mock query to return our rule
+ mock_query_result = MagicMock()
+ mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules
+ mock_query.return_value = mock_query_result
+
+ # Setup map function to return our rule in a dict
+ mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule}
+
+ # Setup check function to return a relation
+ rule_relation = [
+ (
+ str(mock_rtbh_rule),
+ str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]),
+ Relation.EQUAL,
+ )
+ ]
+ mock_check.return_value = rule_relation
+
+ # Setup evaluate function to modify flashes
+ def evaluate_side_effect(model, flashes, rtbh_rule_cache, results):
+ flashes.append("Rule was whitelisted")
+ return model
+
+ mock_evaluate.side_effect = evaluate_side_effect
+
+ # Call the service function
+ with app.app_context():
+ model, flashes = whitelist_service.create_or_update_whitelist(
+ form_data=whitelist_form_data,
+ user_id=1,
+ org_id=1,
+ user_email="test@example.com",
+ org_name="Test Org",
+ )
+
+ # Verify flash messages includes both whitelist creation and rule whitelisting
+ assert "Whitelist saved" in flashes[0]
+ assert "Rule was whitelisted" in flashes[1]
+
+ # Verify interactions
+ mock_map.assert_called_once()
+ mock_check.assert_called_once()
+ mock_evaluate.assert_called_once()
+
+
+class TestDeleteWhitelist:
+ @patch("flowapp.services.whitelist_service.announce_rtbh_route")
+ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db):
+ """Test deleting a whitelist that has user-created rules attached to it"""
+ # Create a whitelist
+ whitelist = Whitelist(
+ ip="192.168.1.0",
+ mask=24,
+ expires=datetime.now() + timedelta(hours=1),
+ user_id=1,
+ org_id=1,
+ rstate_id=1,
+ )
+
+ # Create a RTBH rule (created by user, whitelisted)
+ rtbh_rule = RTBH(
+ ipv4="192.168.1.0",
+ ipv4_mask=24,
+ ipv6="",
+ ipv6_mask=None,
+ community_id=1,
+ expires=datetime.now() + timedelta(hours=1),
+ user_id=1,
+ org_id=1,
+ rstate_id=4, # Whitelisted state
+ )
+
+ # Create a cache entry linking the rule to the whitelist
+ cache_entry = RuleWhitelistCache(
+ rid=1,
+ rtype=RuleTypes.RTBH,
+ whitelist_id=1,
+ rorigin=RuleOrigin.USER,
+ )
+
+ with app.app_context():
+ db.session.add(whitelist)
+ db.session.add(rtbh_rule)
+ db.session.commit()
+
+ # Set whitelist ID now that it has been saved
+ cache_entry.whitelist_id = whitelist.id
+ cache_entry.rid = rtbh_rule.id
+ db.session.add(cache_entry)
+ db.session.commit()
+
+ # Mock the get_by_whitelist_id to return our cache entry
+ with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]):
+ # Call the service function
+ flashes = whitelist_service.delete_whitelist(whitelist.id)
+
+ # Verify the rule state was changed back to active
+ rtbh_rule = db.session.get(RTBH, rtbh_rule.id)
+ assert rtbh_rule.rstate_id == 1 # Active state
+
+ # Verify announcement was made
+ mock_announce.assert_called_once()
+
+ # Verify flash messages
+ assert isinstance(flashes, list)
+ assert flashes
+
+ # Verify the whitelist was deleted
+ assert db.session.get(Whitelist, whitelist.id) is None
+
+ @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id")
+ def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db):
+ """Test deleting a whitelist that has rules created by the whitelist"""
+ # Create a whitelist
+ whitelist = Whitelist(
+ ip="192.168.1.0",
+ mask=24,
+ expires=datetime.now() + timedelta(hours=1),
+ user_id=1,
+ org_id=1,
+ rstate_id=1,
+ )
+
+ # Create a RTBH rule (created by whitelist)
+ rtbh_rule = RTBH(
+ ipv4="192.168.1.0",
+ ipv4_mask=24,
+ ipv6="",
+ ipv6_mask=None,
+ community_id=1,
+ expires=datetime.now() + timedelta(hours=1),
+ user_id=1,
+ org_id=1,
+ rstate_id=1,
+ )
+
+ # Create a cache entry linking the rule to the whitelist
+ cache_entry = RuleWhitelistCache(
+ rid=1,
+ rtype=RuleTypes.RTBH,
+ whitelist_id=1,
+ rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist
+ )
+
+ with app.app_context():
+ db.session.add(whitelist)
+ db.session.add(rtbh_rule)
+ db.session.commit()
+
+ # Set IDs now that they have been saved
+ cache_entry.whitelist_id = whitelist.id
+ cache_entry.rid = rtbh_rule.id
+ db.session.add(cache_entry)
+ db.session.commit()
+
+ # Create a mock session that can get our rule
+ with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]):
+ # Call the service function
+ flashes = whitelist_service.delete_whitelist(whitelist.id)
+
+ # Verify flash messages
+ assert isinstance(flashes, list)
+ assert flashes
+
+ # Verify the rule was deleted
+ assert db.session.get(RTBH, rtbh_rule.id) is None
+
+ # Verify the whitelist was deleted
+ assert db.session.get(Whitelist, whitelist.id) is None
+
+ # Verify cache cleanup was called
+ mock_clean.assert_called_once_with(whitelist.id)
+
+ def test_delete_nonexistent_whitelist(self, app, db):
+ """Test deleting a whitelist that doesn't exist"""
+ with app.app_context():
+ # Call the service function with a non-existent ID
+ flashes = whitelist_service.delete_whitelist(999)
+
+ # Should return empty list of flash messages, as no whitelist was found
+ assert isinstance(flashes, list)
+ assert len(flashes) == 0
+
+
+class TestEvaluateWhitelistAgainstRtbhResults:
+ def test_equal_relation(self, app):
+ """Test evaluating a whitelist with an EQUAL relation to a rule"""
+ # Create test data
+ whitelist_model = MagicMock(spec=Whitelist)
+ whitelist_model.id = 1
+
+ flashes = []
+
+ rtbh_rule = MagicMock(spec=RTBH)
+ rtbh_rule.rstate_id = 1 # Active state
+
+ rule_key = "192.168.1.0/24"
+ whitelist_key = "192.168.1.0/24"
+ rtbh_rule_cache = {rule_key: rtbh_rule}
+
+ results = [(rule_key, whitelist_key, Relation.EQUAL)]
+
+ # Mock required functions
+ with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch(
+ "flowapp.services.whitelist_service.withdraw_rtbh_route"
+ ) as mock_withdraw:
+
+ # Call the function
+ with app.app_context():
+ result = whitelist_service.evaluate_whitelist_against_rtbh_check_results(
+ whitelist_model, flashes, rtbh_rule_cache, results
+ )
+
+ # Verify the rule was whitelisted and route withdrawn
+ mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model)
+ mock_withdraw.assert_called_once_with(rtbh_rule)
+
+ # Verify the flash message
+ assert flashes
+
+ # Verify the correct model was returned
+ assert result == whitelist_model
+
+ def test_subnet_relation(self, app):
+ """Test evaluating a whitelist with a SUBNET relation to a rule"""
+ # Create test data
+ whitelist_model = MagicMock(spec=Whitelist)
+ whitelist_model.id = 1
+
+ flashes = []
+
+ rtbh_rule = MagicMock(spec=RTBH)
+ rtbh_rule.rstate_id = 1 # Active state
+
+ rule_key = "192.168.1.0/24"
+ whitelist_key = "192.168.1.128/25"
+ rtbh_rule_cache = {rule_key: rtbh_rule}
+
+ results = [(rule_key, whitelist_key, Relation.SUBNET)]
+
+ # Mock required functions
+ with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch(
+ "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts"
+ ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch(
+ "flowapp.services.whitelist_service.db.session.commit"
+ ) as mock_commit:
+
+ # Mock subtract_network to return some subnets
+ mock_subtract.return_value = ["192.168.1.0/25"]
+
+ # Call the function
+ with app.app_context():
+ _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results(
+ whitelist_model, flashes, rtbh_rule_cache, results
+ )
+
+ # Verify subnet calculation was performed
+ mock_subtract.assert_called_once()
+
+ # Verify new rules were created for the subnets
+ mock_create.assert_called_once()
+
+ # Verify the original rule was cached
+ mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER)
+
+ # Verify transaction was committed
+ mock_commit.assert_called_once()
+
+ # Verify the flash messages
+ assert any("supernet of whitelist" in msg for msg in flashes)
+
+ # Verify model was updated to whitelisted state
+ assert rtbh_rule.rstate_id == 4
+
+ def test_supernet_relation(self, app):
+ """Test evaluating a whitelist with a SUPERNET relation to a rule"""
+ # Create test data
+ whitelist_model = MagicMock(spec=Whitelist)
+ whitelist_model.id = 1
+
+ flashes = []
+
+ rtbh_rule = MagicMock(spec=RTBH)
+ rtbh_rule.rstate_id = 1 # Active state
+
+ rule_key = "192.168.1.0/24"
+ whitelist_key = "192.168.0.0/16"
+ rtbh_rule_cache = {rule_key: rtbh_rule}
+
+ results = [(rule_key, whitelist_key, Relation.SUPERNET)]
+
+ # Mock required functions
+ with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch(
+ "flowapp.services.whitelist_service.withdraw_rtbh_route"
+ ) as mock_withdraw:
+
+ # Call the function
+ with app.app_context():
+ result = whitelist_service.evaluate_whitelist_against_rtbh_check_results(
+ whitelist_model, flashes, rtbh_rule_cache, results
+ )
+
+ # Verify the rule was whitelisted and route withdrawn
+ mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model)
+ mock_withdraw.assert_called_once_with(rtbh_rule)
+
+ # Verify the flash message
+ assert any("subnet of whitelist" in msg for msg in flashes)
+
+ # Verify the correct model was returned
+ assert result == whitelist_model
diff --git a/flowapp/utils/__init__.py b/flowapp/utils/__init__.py
new file mode 100644
index 00000000..99ca325a
--- /dev/null
+++ b/flowapp/utils/__init__.py
@@ -0,0 +1,46 @@
+from .base import (
+ other_rtypes,
+ output_date_format,
+ parse_api_time,
+ quote_to_ent,
+ webpicker_to_datetime,
+ datetime_to_webpicker,
+ get_state_by_time,
+ round_to_ten_minutes,
+ flash_errors,
+ active_css_rstate,
+ get_comp_func,
+)
+
+from .app_factory import (
+ # configure_app,
+ configure_logging,
+ register_blueprints,
+ register_error_handlers,
+ register_context_processors,
+ register_template_filters,
+ register_auth_handlers,
+ # register_org_routes,
+)
+
+__all__ = [
+ "other_rtypes",
+ "output_date_format",
+ "parse_api_time",
+ "quote_to_ent",
+ "webpicker_to_datetime",
+ "datetime_to_webpicker",
+ "get_state_by_time",
+ "round_to_ten_minutes",
+ "flash_errors",
+ "active_css_rstate",
+ "get_comp_func",
+ # "configure_app",
+ "configure_logging",
+ "register_blueprints",
+ "register_error_handlers",
+ "register_context_processors",
+ "register_template_filters",
+ "register_auth_handlers",
+ # "register_org_routes",
+]
diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py
new file mode 100644
index 00000000..234522fd
--- /dev/null
+++ b/flowapp/utils/app_factory.py
@@ -0,0 +1,221 @@
+import logging
+import babel
+from flask import redirect, render_template, request, session, url_for
+
+
+def register_blueprints(app, csrf=None):
+ """Register Flask blueprints."""
+ from flowapp.views.admin import admin
+ from flowapp.views.rules import rules
+ from flowapp.views.api_v1 import api as api_v1
+ from flowapp.views.api_v2 import api as api_v2
+ from flowapp.views.api_v3 import api as api_v3
+ from flowapp.views.api_keys import api_keys
+ from flowapp.views.dashboard import dashboard
+ from flowapp.views.whitelist import whitelist
+
+ # Configure CSRF exemption for API routes
+ if csrf:
+ csrf.exempt(api_v1)
+ csrf.exempt(api_v2)
+ csrf.exempt(api_v3)
+
+ # Register blueprints with URL prefixes
+ app.register_blueprint(admin, url_prefix="/admin")
+ app.register_blueprint(rules, url_prefix="/rules")
+ app.register_blueprint(api_keys, url_prefix="/api_keys")
+ app.register_blueprint(api_v1, url_prefix="/api/v1")
+ app.register_blueprint(api_v2, url_prefix="/api/v2")
+ app.register_blueprint(api_v3, url_prefix="/api/v3")
+ app.register_blueprint(dashboard, url_prefix="/dashboard")
+ app.register_blueprint(whitelist, url_prefix="/whitelist")
+
+ return app
+
+
+def configure_logging(app):
+ """Configure logging for the Flask application."""
+
+ # Remove all default handlers
+ for handler in app.logger.handlers[:]:
+ app.logger.removeHandler(handler)
+
+ # Retrieve log level and file name from config
+ log_level = app.config.get("LOG_LEVEL", "DEBUG").upper()
+ log_file = app.config.get("LOG_FILE", "app.log")
+
+ # Define log format
+ log_format = "%(asctime)s | %(levelname)s | %(message)s"
+ log_datefmt = "%Y-%m-%d %H:%M:%S"
+ formatter = logging.Formatter(log_format, datefmt=log_datefmt)
+
+ # Console handler
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(formatter)
+
+ # File handler
+ file_handler = logging.FileHandler(log_file)
+ file_handler.setFormatter(formatter)
+
+ # Set logger level
+ app.logger.setLevel(getattr(logging, log_level, logging.DEBUG))
+
+ # Attach handlers
+ app.logger.addHandler(console_handler)
+ app.logger.addHandler(file_handler)
+
+ return app
+
+
+def register_error_handlers(app):
+ """Register error handlers."""
+
+ @app.errorhandler(404)
+ def not_found(error):
+ return render_template("errors/404.html"), 404
+
+ @app.errorhandler(500)
+ def internal_error(exception):
+ app.logger.exception(exception)
+ return render_template("errors/500.html"), 500
+
+ return app
+
+
+def register_context_processors(app):
+ """Register template context processors."""
+
+ @app.context_processor
+ def utility_processor():
+ def editable_rule(rule):
+ if rule:
+ from flowapp.validators import editable_range
+ from flowapp.models import get_user_nets
+
+ editable_range(rule, get_user_nets(session["user_id"]))
+ return True
+ return False
+
+ return dict(editable_rule=editable_rule)
+
+ @app.context_processor
+ def inject_main_menu():
+ """Inject main menu config to templates."""
+ return {"main_menu": app.config.get("MAIN_MENU")}
+
+ @app.context_processor
+ def inject_dashboard():
+ """Inject dashboard config to templates."""
+ return {"dashboard": app.config.get("DASHBOARD")}
+
+ return app
+
+
+def register_template_filters(app):
+ """Register custom template filters."""
+
+ @app.template_filter("strftime")
+ def format_datetime(value):
+ if value is None:
+ return app.config.get("MISSING_DATETIME_MESSAGE", "Never")
+
+ format = "y/MM/dd HH:mm"
+ return babel.dates.format_datetime(value, format)
+
+ @app.template_filter("unlimited")
+ def unlimited_filter(value):
+ return "unlimited" if value == 0 else value
+
+ return app
+
+
+def register_auth_handlers(app, ext):
+ """Register authentication handlers."""
+
+ @ext.login_handler
+ def login(user_info):
+ try:
+ uuid = user_info.get("eppn")
+ except KeyError:
+ uuid = False
+ return render_template("errors/401.html")
+
+ return _handle_login(uuid, app)
+
+ @app.route("/logout")
+ def logout():
+ session["user_uuid"] = False
+ session["user_id"] = False
+ session.clear()
+ return redirect(app.config.get("LOGOUT_URL"))
+
+ @app.route("/ext-login")
+ def ext_login():
+ header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User")
+ if header_name not in request.headers:
+ return render_template("errors/401.html")
+
+ uuid = request.headers.get(header_name)
+ if not uuid:
+ return render_template("errors/401.html")
+
+ return _handle_login(uuid, app)
+
+ @app.route("/local-login")
+ def local_login():
+ print("Local login started")
+ if not app.config.get("LOCAL_AUTH", False):
+ print("Local auth not enabled")
+ return render_template("errors/401.html")
+
+ uuid = app.config.get("LOCAL_USER_UUID", False)
+ if not uuid:
+ print("Local user not set")
+ return render_template("errors/401.html")
+
+ print(f"Local login with {uuid}")
+ return _handle_login(uuid, app)
+
+ return app
+
+
+def _handle_login(uuid, app):
+ """Handle login process for all authentication methods."""
+ from flowapp import db
+
+ multiple_orgs = False
+ try:
+ user, multiple_orgs = _register_user_to_session(uuid, db)
+ except AttributeError as e:
+ app.logger.exception(e)
+ return render_template("errors/401.html")
+
+ if multiple_orgs:
+ return redirect(url_for("select_org", org_id=None))
+
+ # set user org to session
+ user_org = user.organization.first()
+ session["user_org"] = user_org.name
+ session["user_org_id"] = user_org.id
+
+ return redirect("/")
+
+
+def _register_user_to_session(uuid, db):
+ """Register user information to session."""
+ from flowapp.models import User
+
+ print(f"Registering user {uuid} to session")
+ user = db.session.query(User).filter_by(uuid=uuid).first()
+ print(f"Got user {user} from DB")
+ session["user_uuid"] = user.uuid
+ session["user_email"] = user.uuid
+ session["user_name"] = user.name
+ session["user_id"] = user.id
+ session["user_roles"] = [role.name for role in user.role.all()]
+ session["user_role_ids"] = [role.id for role in user.role.all()]
+ roles = [i > 1 for i in session["user_role_ids"]]
+ session["can_edit"] = True if all(roles) and roles else []
+ # check if user has multiple organizations and return True if so
+ print(f"DEBUG SESSION {session}")
+ return user, len(user.organization.all()) > 1
diff --git a/flowapp/utils.py b/flowapp/utils/base.py
similarity index 98%
rename from flowapp/utils.py
rename to flowapp/utils/base.py
index ad15b615..512f7ccd 100644
--- a/flowapp/utils.py
+++ b/flowapp/utils/base.py
@@ -1,4 +1,3 @@
-from operator import ge, lt
from datetime import datetime, timedelta
from flask import flash
from flowapp.constants import (
@@ -73,7 +72,7 @@ def parse_api_time(apitime):
except ValueError:
mytime = False
- return False
+ return mytime
def quote_to_ent(comment):
diff --git a/flowapp/validators.py b/flowapp/validators.py
index 63061062..890bdacb 100644
--- a/flowapp/validators.py
+++ b/flowapp/validators.py
@@ -253,7 +253,10 @@ def __call__(self, form, field):
result = False
for address in field.data.split("/"):
for adr_range in self.net_ranges:
- result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range)
+ try:
+ result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range)
+ except ValueError as e:
+ raise ValidationError(self.message + str(e.args[0]))
if not result:
raise ValidationError(self.message)
@@ -352,3 +355,66 @@ def subnet_of(net_a, net_b):
def supernet_of(net_a, net_b):
"""Return True if this network is a supernet of other."""
return _is_subnet_of(net_b, net_a)
+
+
+class IPAddressValidator:
+ """
+ Universal validator that accepts both IPv4 and IPv6 addresses.
+ """
+
+ def __init__(self, message=None):
+ self.message = message or "Invalid IP address: {}"
+
+ def __call__(self, form, field):
+ if not field.data:
+ return
+
+ try:
+ ipaddress.ip_address(field.data)
+ except ValueError:
+ raise ValidationError(self.message.format(field.data))
+
+
+class NetworkValidator:
+ """
+ Validates that an IP address and mask form a valid network.
+ Works with both IPv4 and IPv6 addresses.
+ """
+
+ def __init__(self, mask_field_name, message=None):
+ self.mask_field_name = mask_field_name
+ self.message = message or "Invalid network: address {}, mask {}"
+
+ def __call__(self, form, field):
+ if not field.data:
+ return
+
+ mask_field = form._fields.get(self.mask_field_name)
+ if not mask_field or not mask_field.data:
+ return
+
+ try:
+ # Determine IP version
+ ip = ipaddress.ip_address(field.data)
+ mask = int(mask_field.data)
+
+ # Validate mask range based on IP version
+ if isinstance(ip, ipaddress.IPv4Address):
+ if not 0 <= mask <= 32:
+ raise ValidationError(f"Invalid IPv4 mask: {mask}")
+ else: # IPv6
+ if not 0 <= mask <= 128:
+ raise ValidationError(f"Invalid IPv6 mask: {mask}")
+
+ # Try to create network to validate the combination
+ network = ipaddress.ip_network(f"{field.data}/{mask}", strict=False)
+
+ # Check if the original IP is the correct network address
+ if str(network.network_address) != str(ip):
+ raise ValidationError(
+ f"Invalid network address: {field.data}/{mask}. "
+ f"Network address should be: {network.network_address}"
+ )
+
+ except ValueError:
+ raise ValidationError(self.message.format(field.data, mask_field.data))
diff --git a/flowapp/views/admin.py b/flowapp/views/admin.py
index d014c87c..7912e86a 100644
--- a/flowapp/views/admin.py
+++ b/flowapp/views/admin.py
@@ -73,15 +73,22 @@ def add_machine_key():
"""
generated = secrets.token_hex(24)
form = MachineApiKeyForm(request.form, key=generated)
+ form.user.choices = [(g.id, g.name) for g in db.session.query(User).order_by("name")]
if request.method == "POST" and form.validate():
+ target_user = db.session.get(User, form.user.data)
+ target_org = target_user.organization.first() if target_user else None
+ current_user = session.get("user_name")
+ curent_email = session.get("user_uuid")
+ comment = f"created by: {current_user}/{curent_email}, comment: {form.comment.data}"
model = MachineApiKey(
machine=form.machine.data,
key=form.key.data,
expires=form.expires.data,
readonly=form.readonly.data,
- comment=form.comment.data,
- user_id=session["user_id"],
+ comment=comment,
+ user_id=target_user.id,
+ org_id=target_org.id,
)
db.session.add(model)
diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py
index 6ce24bf6..3951b093 100644
--- a/flowapp/views/api_common.py
+++ b/flowapp/views/api_common.py
@@ -5,7 +5,7 @@
from functools import wraps
from datetime import datetime, timedelta
-from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, ANNOUNCE, TIME_FORMAT_ARG, RuleTypes
+from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, TIME_FORMAT_ARG, RuleTypes
from flowapp.models import (
RTBH,
Flowspec4,
@@ -18,20 +18,16 @@
check_rule_limit,
get_user_nets,
get_user_actions,
- get_ipv4_model_if_exists,
- get_ipv6_model_if_exists,
insert_initial_communities,
get_user_communities,
- get_rtbh_model_if_exists,
)
from flowapp.forms import IPv4Form, IPv6Form, RTBHForm
+from flowapp.services import rule_service
from flowapp.utils import (
- quote_to_ent,
- get_state_by_time,
output_date_format,
)
from flowapp.auth import check_access_rights
-from flowapp.output import announce_route, log_route, log_withdraw, Route, RouteSources
+from flowapp.output import announce_route, log_withdraw, Route, RouteSources
from flowapp import db, validators, flowspec, messages
@@ -67,7 +63,6 @@ def authorize(user_key):
:return: page with token
"""
jwt_key = current_app.config.get("JWT_SECRET")
-
# try normal user key first
model = db.session.query(ApiKey).filter_by(key=user_key).first()
# if not found try machine key
@@ -100,7 +95,7 @@ def authorize(user_key):
return jsonify({"token": encoded})
else:
- return jsonify({"message": "auth token is invalid"}), 403
+ return jsonify({"message": f"auth token is not valid from machine {request.remote_addr}"}), 403
def check_readonly(func):
@@ -191,7 +186,7 @@ def all_communities(current_user):
def limit_reached(count, rule_type, org_id):
- rule_name = RULE_NAMES_DICT[int(rule_type)]
+ rule_name = RULE_NAMES_DICT[rule_type.value]
org = db.session.get(Organization, org_id)
if rule_type == RuleTypes.IPv4:
limit = org.limit_flowspec4
@@ -207,7 +202,7 @@ def limit_reached(count, rule_type, org_id):
def global_limit_reached(count, rule_type):
- rule_name = RULE_NAMES_DICT[int(rule_type)]
+ rule_name = RULE_NAMES_DICT[rule_type.value]
if rule_type == RuleTypes.IPv4 or rule_type == RuleTypes.IPv6:
limit = current_app.config.get("FLOWSPEC_MAX_RULES")
elif rule_type == RuleTypes.RTBH:
@@ -249,52 +244,15 @@ def create_ipv4(current_user):
if form_errors:
return jsonify(form_errors), 400
- model = get_ipv4_model_if_exists(form.data, 1)
-
- if model:
- model.expires = form.expires.data
- flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value."
- else:
- model = Flowspec4(
- source=form.source.data,
- source_mask=form.source_mask.data,
- source_port=form.source_port.data,
- destination=form.dest.data,
- destination_mask=form.dest_mask.data,
- destination_port=form.dest_port.data,
- protocol=form.protocol.data,
- flags=";".join(form.flags.data),
- packet_len=form.packet_len.data,
- fragment=";".join(form.fragment.data),
- expires=form.expires.data,
- comment=quote_to_ent(form.comment.data),
- action_id=form.action.data,
- user_id=current_user["id"],
- org_id=current_user["org_id"],
- rstate_id=get_state_by_time(form.expires.data),
- )
- flash_message = "IPv4 Rule saved"
- db.session.add(model)
-
- db.session.commit()
-
- # announce route if model is in active state
- if model.rstate_id == 1:
- command = messages.create_ipv4(model, ANNOUNCE)
- route = Route(
- author=f"{current_user['uuid']} / {current_user['org']}",
- source=RouteSources.API,
- command=command,
- )
- announce_route(route)
-
- # log changes
- log_route(
- current_user["id"],
- model,
- RuleTypes.IPv4,
- f"{current_user['uuid']} / {current_user['org']}",
+ # Use the service to create/update the rule
+ model, flash_message = rule_service.create_or_update_ipv4_rule(
+ form_data=form.data,
+ user_id=current_user["id"],
+ org_id=current_user["org_id"],
+ user_email=current_user["uuid"],
+ org_name=current_user["org"],
)
+
pref_format = output_date_format(json_request_data, form.expires.pref_format)
response = {"message": flash_message, "rule": model.to_dict(pref_format)}
return jsonify(response), 201
@@ -326,50 +284,12 @@ def create_ipv6(current_user):
if form_errors:
return jsonify(form_errors), 400
- model = get_ipv6_model_if_exists(form.data, 1)
-
- if model:
- model.expires = form.expires.data
- flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value."
- else:
- model = Flowspec6(
- source=form.source.data,
- source_mask=form.source_mask.data,
- source_port=form.source_port.data,
- destination=form.dest.data,
- destination_mask=form.dest_mask.data,
- destination_port=form.dest_port.data,
- next_header=form.next_header.data,
- flags=";".join(form.flags.data),
- packet_len=form.packet_len.data,
- expires=form.expires.data,
- comment=quote_to_ent(form.comment.data),
- action_id=form.action.data,
- user_id=current_user["id"],
- org_id=current_user["org_id"],
- rstate_id=get_state_by_time(form.expires.data),
- )
- flash_message = "IPv6 Rule saved"
- db.session.add(model)
-
- db.session.commit()
-
- # announce routes
- if model.rstate_id == 1:
- command = messages.create_ipv6(model, ANNOUNCE)
- route = Route(
- author=f"{current_user['uuid']} / {current_user['org']}",
- source=RouteSources.API,
- command=command,
- )
- announce_route(route)
-
- # log changes
- log_route(
- current_user["id"],
- model,
- RuleTypes.IPv6,
- f"{current_user['uuid']} / {current_user['org']}",
+ model, flash_message = rule_service.create_or_update_ipv6_rule(
+ form_data=form.data,
+ user_id=current_user["id"],
+ org_id=current_user["org_id"],
+ user_email=current_user["uuid"],
+ org_name=current_user["org"],
)
pref_format = output_date_format(json_request_data, form.expires.pref_format)
@@ -384,7 +304,6 @@ def create_rtbh(current_user):
count = db.session.query(RTBH).filter_by(rstate_id=1).count()
return global_limit_reached(count=count, rule_type=RuleTypes.RTBH)
- # check limit
if check_rule_limit(current_user["org_id"], RuleTypes.RTBH):
count = db.session.query(RTBH).filter_by(rstate_id=1, org_id=current_user["org_id"]).count()
return limit_reached(count=count, rule_type=RuleTypes.RTBH, org_id=current_user["org_id"])
@@ -406,43 +325,12 @@ def create_rtbh(current_user):
if form_errors:
return jsonify(form_errors), 400
- model = get_rtbh_model_if_exists(form.data, 1)
-
- if model:
- model.expires = form.expires.data
- flash_message = "Existing RTBH Rule found. Expiration time was updated to new value."
- else:
- model = RTBH(
- ipv4=form.ipv4.data,
- ipv4_mask=form.ipv4_mask.data,
- ipv6=form.ipv6.data,
- ipv6_mask=form.ipv6_mask.data,
- community_id=form.community.data,
- expires=form.expires.data,
- comment=quote_to_ent(form.comment.data),
- user_id=current_user["id"],
- org_id=current_user["org_id"],
- rstate_id=get_state_by_time(form.expires.data),
- )
- db.session.add(model)
- db.session.commit()
- flash_message = "RTBH Rule saved"
-
- # announce routes
- if model.rstate_id == 1:
- command = messages.create_rtbh(model, ANNOUNCE)
- route = Route(
- author=f"{current_user['uuid']} / {current_user['org']}",
- source=RouteSources.API,
- command=command,
- )
- announce_route(route)
- # log changes
- log_route(
- current_user["id"],
- model,
- RuleTypes.RTBH,
- f"{current_user['uuid']} / {current_user['org']}",
+ model, flash_message = rule_service.create_or_update_rtbh_rule(
+ form_data=form.data,
+ user_id=current_user["id"],
+ org_id=current_user["org_id"],
+ user_email=current_user["uuid"],
+ org_name=current_user["org"],
)
pref_format = output_date_format(json_request_data, form.expires.pref_format)
@@ -506,7 +394,7 @@ def delete_v4_rule(current_user, rule_id):
"""
model_name = Flowspec4
route_model = messages.create_ipv4
- return delete_rule(current_user, rule_id, model_name, route_model, 4)
+ return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv4)
def delete_v6_rule(current_user, rule_id):
@@ -516,7 +404,7 @@ def delete_v6_rule(current_user, rule_id):
"""
model_name = Flowspec6
route_model = messages.create_ipv6
- return delete_rule(current_user, rule_id, model_name, route_model, 6)
+ return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv6)
def delete_rtbh_rule(current_user, rule_id):
@@ -526,7 +414,7 @@ def delete_rtbh_rule(current_user, rule_id):
"""
model_name = RTBH
route_model = messages.create_rtbh
- return delete_rule(current_user, rule_id, model_name, route_model, 1)
+ return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.RTBH)
def delete_rule(current_user, rule_id, model_name, route_model, rule_type):
diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py
index 8ed262f0..c709f11f 100644
--- a/flowapp/views/dashboard.py
+++ b/flowapp/views/dashboard.py
@@ -48,7 +48,7 @@ def whois(ip_address):
def index(rtype=None, rstate="active"):
"""
dispatcher object for the dashboard
- :param rtype: ipv4, ipv6, rtbh
+ :param rtype: ipv4, ipv6, rtbh, whitelist
:param rstate:
:return: view from view factory
"""
@@ -84,7 +84,6 @@ def index(rtype=None, rstate="active"):
data_handler_module = current_app.config["DASHBOARD"].get(rtype).get("data_handler", models)
data_handler_method = current_app.config["DASHBOARD"].get(rtype).get("data_handler_method", "get_ip_rules")
-
# get search query, sort order and sort key from request or session
get_search_query = request.args.get(SEARCH_ARG, session.get(SEARCH_ARG, ""))
get_sort_key = request.args.get(SORT_ARG, session.get(SORT_ARG, DEFAULT_SORT))
@@ -109,6 +108,10 @@ def index(rtype=None, rstate="active"):
# get the handler and the data
handler = getattr(data_handler_module, data_handler_method)
rules = handler(rtype, rstate, get_sort_key, get_sort_order)
+
+ # Enrich rules with whitelist information
+ rules, whitelist_rule_ids = enrich_rules_with_whitelist_info(rules, rtype)
+
session[RULES_KEY] = [rule.id for rule in rules]
# search rules
if get_search_query:
@@ -123,6 +126,8 @@ def index(rtype=None, rstate="active"):
else:
count_match = ""
+ allowed_communities = current_app.config["ALLOWED_COMMUNITIES"]
+
return view_factory(
rtype=rtype,
rstate=rstate,
@@ -138,6 +143,8 @@ def index(rtype=None, rstate="active"):
macro_tbody=macro_tbody,
macro_thead=macro_thead,
macro_tfoot=macro_tfoot,
+ whitelist_rule_ids=whitelist_rule_ids,
+ allowed_communities=allowed_communities,
)
@@ -148,18 +155,25 @@ def create_dashboard_table_body(
group_op=True,
macro_file="macros.html",
macro_name="build_ip_tbody",
+ whitelist_rule_ids=None,
+ allowed_communities=None,
):
"""
create the table body for the dashboard using a jinja2 macro
:param rules: list of rules
:param rtype: ipv4, ipv6, rtbh
+ :param editable: whether rules can be edited
+ :param group_op: whether group operations are allowed
:param macro_file: the file where the macro is defined
:param macro_name: the name of the macro
+ :param whitelist_rule_ids: set of rule IDs that were created by a whitelist
"""
tstring = "{% "
tstring = tstring + f"from '{macro_file}' import {macro_name}"
tstring = tstring + " %} {{"
- tstring = tstring + f" {macro_name}(rules, today, editable, group_op) " + "}}"
+ tstring = (
+ tstring + f" {macro_name}(rules, today, editable, group_op, whitelist_rule_ids, allowed_communities) " + "}}"
+ )
dashboard_table_body = render_template_string(
tstring,
@@ -167,6 +181,8 @@ def create_dashboard_table_body(
today=datetime.now(),
editable=editable,
group_op=group_op,
+ whitelist_rule_ids=whitelist_rule_ids or set(),
+ allowed_communities=allowed_communities or [],
)
return dashboard_table_body
@@ -246,6 +262,8 @@ def create_admin_response(
macro_tbody="build_ip_tbody",
macro_thead="build_rules_thead",
macro_tfoot="build_group_buttons_tfoot",
+ whitelist_rule_ids=None,
+ allowed_communities=None,
):
"""
Admin can see and edit any rules
@@ -256,8 +274,17 @@ def create_admin_response(
:param sort_order:
:return:
"""
+ group_op = True if rtype != "whitelist" else False
- dashboard_table_body = create_dashboard_table_body(rules, rtype, macro_file=macro_file, macro_name=macro_tbody)
+ dashboard_table_body = create_dashboard_table_body(
+ rules,
+ rtype,
+ macro_file=macro_file,
+ macro_name=macro_tbody,
+ group_op=group_op,
+ whitelist_rule_ids=whitelist_rule_ids,
+ allowed_communities=allowed_communities,
+ )
dashboard_table_head = create_dashboard_table_head(
rules_columns=table_columns,
@@ -266,15 +293,18 @@ def create_admin_response(
sort_key=sort_key,
sort_order=sort_order,
search_query=search_query,
+ group_op=group_op,
macro_file=macro_file,
macro_name=macro_thead,
)
-
- dashboard_table_foot = create_dashboard_table_foot(
- table_colspan,
- macro_file=macro_file,
- macro_name=macro_tfoot,
- )
+ if group_op:
+ dashboard_table_foot = create_dashboard_table_foot(
+ table_colspan,
+ macro_file=macro_file,
+ macro_name=macro_tfoot,
+ )
+ else:
+ dashboard_table_foot = ""
res = make_response(
render_template(
@@ -313,6 +343,8 @@ def create_user_response(
macro_tbody="build_ip_tbody",
macro_thead="build_rules_thead",
macro_tfoot="build_rules_tfoot",
+ whitelist_rule_ids=None,
+ allowed_communities=None,
):
"""
Filter out the rules for normal users
@@ -343,9 +375,20 @@ def create_user_response(
group_op=False,
macro_file=macro_file,
macro_name=macro_tbody,
+ whitelist_rule_ids=whitelist_rule_ids,
+ allowed_communities=allowed_communities,
)
+
+ group_op = True if rtype != "whitelist" else False
+
dashboard_table_editable = create_dashboard_table_body(
- rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody
+ rules_editable,
+ rtype,
+ macro_file=macro_file,
+ macro_name=macro_tbody,
+ group_op=group_op,
+ whitelist_rule_ids=whitelist_rule_ids,
+ allowed_communities=allowed_communities,
)
dashboard_table_editable_head = create_dashboard_table_head(
rules_columns=table_columns,
@@ -354,7 +397,7 @@ def create_user_response(
sort_key=sort_key,
sort_order=sort_order,
search_query=search_query,
- group_op=True,
+ group_op=group_op,
macro_file=macro_file,
macro_name=macro_thead,
)
@@ -370,11 +413,14 @@ def create_user_response(
macro_name=macro_thead,
)
- dashboard_table_foot = create_dashboard_table_foot(
- table_colspan,
- macro_file=macro_file,
- macro_name=macro_tfoot,
- )
+ if group_op:
+ dashboard_table_foot = create_dashboard_table_foot(
+ table_colspan,
+ macro_file=macro_file,
+ macro_name=macro_tfoot,
+ )
+ else:
+ dashboard_table_foot = ""
display_editable = len(rules_editable)
display_readonly = len(read_only_rules)
@@ -419,6 +465,8 @@ def create_view_response(
macro_tbody="build_ip_tbody",
macro_thead="build_rules_thead",
macro_tfoot="build_rules_tfoot",
+ whitelist_rule_ids=None,
+ allowed_communities=None,
):
"""
Filter out the rules for normal users
@@ -433,6 +481,8 @@ def create_view_response(
group_op=False,
macro_file=macro_file,
macro_name=macro_tbody,
+ whitelist_rule_ids=whitelist_rule_ids,
+ allowed_communities=allowed_communities,
)
dashboard_table_head = create_dashboard_table_head(
@@ -482,3 +532,39 @@ def filter_rules(rules, get_search_query):
result.append(rules[idx])
return result
+
+
+def enrich_rules_with_whitelist_info(rules, rule_type):
+ """
+ Enrich rules with whitelist information from RuleWhitelistCache.
+
+ Args:
+ rules: List of rule objects (Flowspec4, Flowspec6, RTBH)
+ rule_type: String identifier of rule type ("ipv4", "ipv6", "rtbh")
+
+ Returns:
+ Tuple of (rules, whitelist_rule_ids) where whitelist_rule_ids is a set of
+ rule IDs that were created by a whitelist.
+ """
+ from flowapp.models.rules.whitelist import RuleWhitelistCache
+ from flowapp.constants import RuleTypes, RuleOrigin
+
+ # Map rule type string to enum value
+ rule_type_map = {"ipv4": RuleTypes.IPv4.value, "ipv6": RuleTypes.IPv6.value, "rtbh": RuleTypes.RTBH.value}
+
+ # Get all rule IDs
+ rule_ids = [rule.id for rule in rules]
+
+ # No rules to process
+ if not rule_ids:
+ return rules, set()
+
+ # Query the cache for these rule IDs
+ cache_entries = RuleWhitelistCache.query.filter(
+ RuleWhitelistCache.rid.in_(rule_ids), RuleWhitelistCache.rtype == rule_type_map.get(rule_type)
+ ).all()
+
+ # Create a set of rule IDs that were created by a whitelist
+ whitelist_rule_ids = {entry.rid for entry in cache_entries if entry.rorigin == RuleOrigin.WHITELIST.value}
+
+ return rules, whitelist_rule_ids
diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py
index 5618789f..e3b82a1a 100644
--- a/flowapp/views/rules.py
+++ b/flowapp/views/rules.py
@@ -1,11 +1,10 @@
# flowapp/views/admin.py
from datetime import datetime, timedelta
-from operator import ge, lt
from collections import namedtuple
from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for
-from flowapp import constants, db, messages
+from flowapp import constants, db
from flowapp.auth import (
admin_required,
auth_required,
@@ -24,19 +23,17 @@
Organization,
check_global_rule_limit,
check_rule_limit,
- get_ipv4_model_if_exists,
- get_ipv6_model_if_exists,
- get_rtbh_model_if_exists,
get_user_actions,
get_user_communities,
get_user_nets,
insert_initial_communities,
)
+from flowapp.models.log import Log
from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route
+from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists
from flowapp.utils import (
flash_errors,
get_state_by_time,
- quote_to_ent,
round_to_ten_minutes,
)
@@ -61,13 +58,21 @@
def reactivate_rule(rule_type, rule_id):
"""
Set new time for the rule of given type identified by id
- :param rule_type: string - type of rule
+ :param rule_type: integer - type of rule, corresponds to RuleTypes enum value
:param rule_id: integer - id of the rule
"""
+ # Convert the integer rule_type to RuleTypes enum
+ enum_rule_type = RuleTypes(rule_type)
+
+ # Now use the enum value where needed but the integer for dictionary lookups
model_name = DATA_MODELS[rule_type]
form_name = DATA_FORMS[rule_type]
model = db.session.get(model_name, rule_id)
+ if not model:
+ flash("Rule not found", "alert-danger")
+ return redirect(url_for("index"))
+
form = form_name(request.form, obj=model)
form.net_ranges = get_user_nets(session["user_id"])
@@ -75,72 +80,41 @@ def reactivate_rule(rule_type, rule_id):
form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")]
form.action.data = model.action_id
- if rule_type == 1:
+ if rule_type == RuleTypes.RTBH.value:
form.community.choices = get_user_communities(session["user_role_ids"])
form.community.data = model.community_id
- if rule_type == 4:
+ if rule_type == RuleTypes.IPv4.value:
form.protocol.data = model.protocol
- if rule_type == 6:
+ if rule_type == RuleTypes.IPv6.value:
form.next_header.data = model.next_header
- # do not need to validate - all is readonly
+ # Process form submission
if request.method == "POST":
- # check if rule will be reactivated
- state = get_state_by_time(form.expires.data)
+ # Round expiration time to 10 minutes
+ expires = round_to_ten_minutes(form.expires.data)
+
+ # Use the service to reactivate the rule
+ _, messages = rule_service.reactivate_rule(
+ rule_type=enum_rule_type,
+ rule_id=rule_id,
+ expires=expires,
+ comment=form.comment.data,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
+ )
- # check global limit
- check_gl = check_global_rule_limit(rule_type)
- if state == 1 and check_gl:
+ # Handle special messages (redirects)
+ if "global_limit_reached" in messages:
return redirect(url_for("rules.global_limit_reached", rule_type=rule_type))
- # check org limit
- if state == 1 and check_rule_limit(session["user_org_id"], rule_type=rule_type):
+ if "limit_reached" in messages:
return redirect(url_for("rules.limit_reached", rule_type=rule_type))
- # set new expiration date
- model.expires = round_to_ten_minutes(form.expires.data)
- # set again the active state
- model.rstate_id = get_state_by_time(form.expires.data)
- model.comment = form.comment.data
- db.session.commit()
- flash("Rule successfully updated", "alert-success")
-
- route_model = ROUTE_MODELS[rule_type]
-
- if model.rstate_id == 1:
- # announce route
- command = route_model(model, constants.ANNOUNCE)
- route = Route(
- author=f"{session['user_email']} / {session['user_org']}",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
- # log changes
- log_route(
- session["user_id"],
- model,
- rule_type,
- f"{session['user_email']} / {session['user_org']}",
- )
- else:
- # withdraw route
- command = route_model(model, constants.WITHDRAW)
- route = Route(
- author=f"{session['user_email']} / {session['user_org']}",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
- # log changes
- log_withdraw(
- session["user_id"],
- route.command,
- rule_type,
- model.id,
- f"{session['user_email']} / {session['user_org']}",
- )
+ for message in messages:
+ flash(message, "alert-success")
return redirect(
url_for(
@@ -155,6 +129,7 @@ def reactivate_rule(rule_type, rule_id):
else:
flash_errors(form)
+ # For GET requests, prepare the form for display
form.expires.data = model.expires
for field in form:
if field.name not in ["expires", "csrf_token", "comment"]:
@@ -177,42 +152,74 @@ def reactivate_rule(rule_type, rule_id):
def delete_rule(rule_type, rule_id):
"""
Delete rule with given id and type
- :param sort_key:
- :param filter_text:
- :param rstate:
- :param rule_type: string - type of rule to be deleted
+ :param rule_type: integer - type of rule to be deleted
:param rule_id: integer - rule id
"""
- model_name = DATA_MODELS[rule_type]
- route_model = ROUTE_MODELS[rule_type]
+ # Convert the integer rule_type to RuleTypes enum
+ enum_rule_type = RuleTypes(rule_type)
+
+ # Use the service to delete the rule
+ success, message = rule_service.delete_rule(
+ rule_type=enum_rule_type,
+ rule_id=rule_id,
+ user_id=session["user_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
+ allowed_rule_ids=session.get(constants.RULES_KEY, []),
+ )
- model = db.session.get(model_name, rule_id)
- if model.id in session[constants.RULES_KEY]:
- # withdraw route
- command = route_model(model, constants.WITHDRAW)
- route = Route(
- author=f"{session['user_email']} / {session['user_org']}",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
-
- log_withdraw(
- session["user_id"],
- route.command,
- rule_type,
- model.id,
- f"{session['user_email']} / {session['user_org']}",
+ # Flash appropriate message based on result
+ flash(message, "alert-success" if success else "alert-warning")
+
+ # Redirect back to dashboard
+ return redirect(
+ url_for(
+ "dashboard.index",
+ rtype=session[constants.TYPE_ARG],
+ rstate=session[constants.RULE_ARG],
+ sort=session[constants.SORT_ARG],
+ squery=session[constants.SEARCH_ARG],
+ order=session[constants.ORDER_ARG],
)
+ )
- # delete from db
- db.session.delete(model)
- db.session.commit()
- flash("Rule deleted", "alert-success")
- else:
- flash("You can not delete this rule", "alert-warning")
+@rules.route("/delete_and_whitelist//", methods=["GET"])
+@auth_required
+@user_or_admin_required
+def delete_and_whitelist(rule_type, rule_id):
+ """
+ Delete an RTBH rule and create a whitelist entry from it.
+ :param rule_id: integer - id of the RTBH rule
+ """
+ if rule_type != RuleTypes.RTBH.value:
+ flash("Only RTBH rules can be converted to whitelists", "alert-warning")
+ return redirect(url_for("index"))
+
+ # Set whitelist expiration to 7 days from now by default
+ whitelist_expires = datetime.now() + timedelta(days=7)
+
+ # Use the service to delete RTBH and create whitelist
+ success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist(
+ rule_id=rule_id,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
+ allowed_rule_ids=session.get(constants.RULES_KEY, []),
+ whitelist_expires=whitelist_expires,
+ )
+
+ # Flash all messages
+ for message in messages:
+ flash(message, "alert-success" if success else "alert-warning")
+
+ # If successful, flash additional message about whitelist
+ if success and whitelist:
+ flash(f"Created whitelist entry ID {whitelist.id} from RTBH rule", "alert-info")
+
+ # Redirect back to dashboard
return redirect(
url_for(
"dashboard.index",
@@ -260,6 +267,7 @@ def group_delete():
rule_type = session[constants.TYPE_ARG]
model_name = DATA_MODELS_NAMED[rule_type]
rule_type_int = constants.RULE_TYPES_DICT[rule_type]
+ enum_rule_type = RuleTypes(rule_type_int)
route_model = ROUTE_MODELS[rule_type_int]
rules = [str(x) for x in session[constants.RULES_KEY]]
to_delete = request.form.getlist("delete-id")
@@ -279,7 +287,7 @@ def group_delete():
log_withdraw(
session["user_id"],
route.command,
- rule_type_int,
+ enum_rule_type,
model.id,
f"{session['user_email']} / {session['user_org']}",
)
@@ -376,6 +384,7 @@ def group_update_save(rule_type):
model_name = DATA_MODELS[rule_type]
form_name = DATA_FORMS[rule_type]
+ enum_rule_type = RuleTypes(rule_type)
form = form_name(request.form)
@@ -417,7 +426,7 @@ def group_update_save(rule_type):
log_route(
session["user_id"],
model,
- rule_type,
+ enum_rule_type,
f"{session['user_email']} / {session['user_org']}",
)
else:
@@ -433,7 +442,7 @@ def group_update_save(rule_type):
log_withdraw(
session["user_id"],
route.command,
- rule_type,
+ enum_rule_type,
model.id,
f"{session['user_email']} / {session['user_org']}",
)
@@ -476,54 +485,17 @@ def ipv4_rule():
form.net_ranges = net_ranges
if request.method == "POST" and form.validate():
- model = get_ipv4_model_if_exists(form.data, 1)
-
- if model:
- model.expires = round_to_ten_minutes(form.expires.data)
- flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value."
- else:
- model = Flowspec4(
- source=form.source.data,
- source_mask=form.source_mask.data,
- source_port=form.source_port.data,
- destination=form.dest.data,
- destination_mask=form.dest_mask.data,
- destination_port=form.dest_port.data,
- protocol=form.protocol.data,
- flags=";".join(form.flags.data),
- packet_len=form.packet_len.data,
- fragment=";".join(form.fragment.data),
- expires=round_to_ten_minutes(form.expires.data),
- comment=quote_to_ent(form.comment.data),
- action_id=form.action.data,
- user_id=session["user_id"],
- org_id=session["user_org_id"],
- rstate_id=get_state_by_time(form.expires.data),
- )
- flash_message = "IPv4 Rule saved"
- db.session.add(model)
-
- db.session.commit()
- flash(flash_message, "alert-success")
-
- # announce route if model is in active state
- if model.rstate_id == 1:
- command = messages.create_ipv4(model, constants.ANNOUNCE)
- route = Route(
- author=f"{session['user_email']} / {session['user_org']}",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
-
- # log changes
- log_route(
- session["user_id"],
- model,
- RuleTypes.IPv4,
- f"{session['user_email']} / {session['user_org']}",
+ # Use the service to create/update the rule
+ _model, message = rule_service.create_or_update_ipv4_rule(
+ form_data=form.data,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
)
+ flash(message, "alert-success")
+
return redirect(url_for("index"))
else:
for field, errors in form.errors.items():
@@ -560,52 +532,14 @@ def ipv6_rule():
form.net_ranges = net_ranges
if request.method == "POST" and form.validate():
- model = get_ipv6_model_if_exists(form.data, 1)
-
- if model:
- model.expires = round_to_ten_minutes(form.expires.data)
- flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value."
- else:
- model = Flowspec6(
- source=form.source.data,
- source_mask=form.source_mask.data,
- source_port=form.source_port.data,
- destination=form.dest.data,
- destination_mask=form.dest_mask.data,
- destination_port=form.dest_port.data,
- next_header=form.next_header.data,
- flags=";".join(form.flags.data),
- packet_len=form.packet_len.data,
- expires=round_to_ten_minutes(form.expires.data),
- comment=quote_to_ent(form.comment.data),
- action_id=form.action.data,
- user_id=session["user_id"],
- org_id=session["user_org_id"],
- rstate_id=get_state_by_time(form.expires.data),
- )
- flash_message = "IPv6 Rule saved"
- db.session.add(model)
-
- db.session.commit()
- flash(flash_message, "alert-success")
-
- # announce routes
- if model.rstate_id == 1:
- command = messages.create_ipv6(model, constants.ANNOUNCE)
- route = Route(
- author=f"{session['user_email']} / {session['user_org']}",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
-
- # log changes
- log_route(
- session["user_id"],
- model,
- RuleTypes.IPv6,
- f"{session['user_email']} / {session['user_org']}",
+ _model, message = rule_service.create_or_update_ipv6_rule(
+ form_data=form.data,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
)
+ flash(message, "alert-success")
return redirect(url_for("index"))
else:
@@ -642,47 +576,18 @@ def rtbh_rule():
] + user_communities
form.community.choices = user_communities
form.net_ranges = net_ranges
+ whitelistable = Community.get_whitelistable_communities(current_app.config["ALLOWED_COMMUNITIES"])
if request.method == "POST" and form.validate():
- model = get_rtbh_model_if_exists(form.data, 1)
-
- if model:
- model.expires = round_to_ten_minutes(form.expires.data)
- flash_message = "Existing RTBH Rule found. Expiration time was updated to new value."
- else:
- model = RTBH(
- ipv4=form.ipv4.data,
- ipv4_mask=form.ipv4_mask.data,
- ipv6=form.ipv6.data,
- ipv6_mask=form.ipv6_mask.data,
- community_id=form.community.data,
- expires=round_to_ten_minutes(form.expires.data),
- comment=quote_to_ent(form.comment.data),
- user_id=session["user_id"],
- org_id=session["user_org_id"],
- rstate_id=get_state_by_time(form.expires.data),
- )
- db.session.add(model)
- db.session.commit()
- flash_message = "RTBH Rule saved"
-
- flash(flash_message, "alert-success")
- # announce routes
- if model.rstate_id == 1:
- command = messages.create_rtbh(model, constants.ANNOUNCE)
- route = Route(
- author=f"{session['user_email']} / {session['user_org']}",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
- # log changes
- log_route(
- session["user_id"],
- model,
- RuleTypes.RTBH,
- f"{session['user_email']} / {session['user_org']}",
+ _model, messages = rule_service.create_or_update_rtbh_rule(
+ form_data=form.data,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
)
+ for message in messages:
+ flash(message, "alert-success")
return redirect(url_for("index"))
else:
@@ -693,7 +598,11 @@ def rtbh_rule():
default_expires = datetime.now() + timedelta(days=7)
form.expires.data = default_expires
- return render_template("forms/rtbh_rule.html", form=form, action_url=url_for("rules.rtbh_rule"))
+ print(whitelistable)
+
+ return render_template(
+ "forms/rtbh_rule.html", form=form, action_url=url_for("rules.rtbh_rule"), whitelistable=whitelistable
+ )
@rules.route("/limit_reached/")
@@ -775,64 +684,13 @@ def announce_all():
@rules.route("/withdraw_expired", methods=["GET"])
@localhost_only
def withdraw_expired():
- announce_all_routes(constants.WITHDRAW)
- return " "
-
-
-def announce_all_routes(action=constants.ANNOUNCE):
"""
- get routes from db and send it to ExaBGB api
-
- @TODO take the request away, use some kind of messaging (maybe celery?)
- :param action: action with routes - announce valid routes or withdraw expired routes
+ cleaning endpoint
+ deletes expired whitelists
+ withdraws all expired routes from ExaBGP
+ deletes logs older than 30 days
"""
- today = datetime.now()
- comp_func = ge if action == constants.ANNOUNCE else lt
-
- rules4 = (
- db.session.query(Flowspec4)
- .filter(Flowspec4.rstate_id == 1)
- .filter(comp_func(Flowspec4.expires, today))
- .order_by(Flowspec4.expires.desc())
- .all()
- )
- rules6 = (
- db.session.query(Flowspec6)
- .filter(Flowspec6.rstate_id == 1)
- .filter(comp_func(Flowspec6.expires, today))
- .order_by(Flowspec6.expires.desc())
- .all()
- )
- rules_rtbh = (
- db.session.query(RTBH)
- .filter(RTBH.rstate_id == 1)
- .filter(comp_func(RTBH.expires, today))
- .order_by(RTBH.expires.desc())
- .all()
- )
-
- messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4]
- messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6]
- messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh]
-
- messages_all = []
- messages_all.extend(messages_v4)
- messages_all.extend(messages_v6)
- messages_all.extend(messages_rtbh)
-
- author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired"
-
- for command in messages_all:
- route = Route(
- author=f"System call / {author_action} rules",
- source=RouteSources.UI,
- command=command,
- )
- announce_route(route)
-
- if action == constants.WITHDRAW:
- for ruleset in [rules4, rules6, rules_rtbh]:
- for rule in ruleset:
- rule.rstate_id = 2
-
- db.session.commit()
+ delete_expired_whitelists()
+ announce_all_routes(constants.WITHDRAW)
+ Log.delete_old()
+ return " "
diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py
new file mode 100644
index 00000000..39a9352a
--- /dev/null
+++ b/flowapp/views/whitelist.py
@@ -0,0 +1,127 @@
+from datetime import datetime, timedelta
+from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for
+
+from flowapp.auth import (
+ auth_required,
+ user_or_admin_required,
+)
+from flowapp import constants, db
+from flowapp.forms import WhitelistForm
+from flowapp.models import get_user_nets, Whitelist
+from flowapp.services import create_or_update_whitelist, delete_whitelist
+from flowapp.utils.base import flash_errors
+
+whitelist = Blueprint("whitelist", __name__, template_folder="templates")
+
+
+@whitelist.route("/add", methods=["GET", "POST"])
+@auth_required
+@user_or_admin_required
+def add():
+ net_ranges = get_user_nets(session["user_id"])
+ form = WhitelistForm(request.form)
+
+ form.net_ranges = net_ranges
+
+ if request.method == "POST" and form.validate():
+ model, messages = create_or_update_whitelist(
+ form.data,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
+ )
+ for message in messages:
+ flash(message, "alert-success")
+
+ return redirect(url_for("index"))
+ else:
+ for field, errors in form.errors.items():
+ for error in errors:
+ current_app.logger.debug("Error in the %s field - %s" % (getattr(form, field).label.text, error))
+
+ print("NOW", datetime.now())
+ default_expires = datetime.now() + timedelta(hours=1)
+ form.expires.data = default_expires
+
+ return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add"))
+
+
+@whitelist.route("/reactivate/", methods=["GET", "POST"])
+@auth_required
+@user_or_admin_required
+def reactivate(wl_id):
+ """
+ Set new time for whitelist
+ :param wl_id: int - id of the whitelist
+ """
+
+ model = db.session.get(Whitelist, wl_id)
+ form = WhitelistForm(request.form, obj=model)
+ form.net_ranges = get_user_nets(session["user_id"])
+
+ # do not need to validate - all is readonly
+ if request.method == "POST":
+ model = create_or_update_whitelist(
+ form.data,
+ user_id=session["user_id"],
+ org_id=session["user_org_id"],
+ user_email=session["user_email"],
+ org_name=session["user_org"],
+ )
+ flash("Whitelist updated", "alert-success")
+ return redirect(
+ url_for(
+ "dashboard.index",
+ rtype=session[constants.TYPE_ARG],
+ rstate=session[constants.RULE_ARG],
+ sort=session[constants.SORT_ARG],
+ squery=session[constants.SEARCH_ARG],
+ order=session[constants.ORDER_ARG],
+ )
+ )
+ else:
+ flash_errors(form)
+
+ form.expires.data = model.expires
+ for field in form:
+ if field.name not in ["expires", "csrf_token", "comment"]:
+ field.render_kw = {"disabled": "disabled"}
+
+ action_url = url_for("whitelist.reactivate", wl_id=wl_id)
+
+ return render_template(
+ "forms/whitelist.html",
+ form=form,
+ action_url=action_url,
+ editing=True,
+ title="Update",
+ )
+
+
+@whitelist.route("/delete/", methods=["GET"])
+@auth_required
+@user_or_admin_required
+def delete(wl_id):
+ """
+ Delete whitelist
+ :param wl_id: integer - id of the whitelist
+ """
+ if wl_id in session[constants.RULES_KEY]:
+ messages = delete_whitelist(wl_id)
+ flash(f"Whitelist {wl_id} deleted", "alert-success")
+ for message in messages:
+ flash(message, "alert-info")
+ else:
+ flash("You can not delete this Whitelist", "alert-warning")
+
+ return redirect(
+ url_for(
+ "dashboard.index",
+ rtype=session[constants.TYPE_ARG],
+ rstate=session[constants.RULE_ARG],
+ sort=session[constants.SORT_ARG],
+ squery=session[constants.SEARCH_ARG],
+ order=session[constants.ORDER_ARG],
+ )
+ )
diff --git a/run.example.py b/run.example.py
index b3dd0c4e..ad6f97eb 100644
--- a/run.example.py
+++ b/run.example.py
@@ -1,30 +1,19 @@
"""
-This is an example of how to run the application.
-First copy the file as run.py (or whatever you want)
-Then edit the file to match your needs.
-
-From version 0.8.1 the application is using Flask-Session
-stored in DB using SQL Alchemy driver. This can be configured for other
-drivers, however server side session is required for the application.
-
-In general you should not need to edit this example file.
-Only if you want to configure the application main menu and
-dashboard.
-
-Or in case that you want to add extensions etc.
+This is run py for the application. Copied to the container on build.
"""
from os import environ
-from flowapp import create_app, db, sess
+from flowapp import create_app, db
import config
# Configurations
-env = environ.get("EXAFS_ENV", "Production")
+exafs_env = environ.get("EXAFS_ENV", "Production")
+exafs_env = exafs_env.lower()
# Call app factory
-if env == "devel":
+if exafs_env == "devel" or exafs_env == "development":
app = create_app(config.DevelopmentConfig)
else:
app = create_app(config.ProductionConfig)
@@ -32,12 +21,6 @@
# init database object
db.init_app(app)
-# init session
-app.config.update(SESSION_TYPE="sqlalchemy")
-app.config.update(SESSION_SQLALCHEMY=db)
-sess.init_app(app)
-
-
# run app
if __name__ == "__main__":
app.run(host="::", port=8080, debug=True)
diff --git a/withdraw_expired b/withdraw_expired
new file mode 100644
index 00000000..0519ecba
--- /dev/null
+++ b/withdraw_expired
@@ -0,0 +1 @@
+
\ No newline at end of file
|