From 16558a87a7a49d27d53ff13fd1723f8a4651b755 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 14:27:05 +0100 Subject: [PATCH 01/10] version 1.2.0-beta1 start --- flowapp/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flowapp/__about__.py b/flowapp/__about__.py index 1ef1e64..2c73c33 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.1.9" +__version__ = "1.2.0-dev" __title__ = "ExaFS" __description__ = "Tool for creation, validation, and execution of ExaBGP messages." __author__ = "CESNET / Jiri Vrany, Petr Adamec, Josef Verich, Jakub Man" From c5b127f8622735481365bd6f29a82e99ff85272e Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 14:43:05 +0100 Subject: [PATCH 02/10] Update for security - org id check in select_org route - config example updated - DEBUG and DEVEL default None / False, needs to be set in Produciton or Development config --- config.example.py | 7 ++----- flowapp/__init__.py | 11 +++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/config.example.py b/config.example.py index 6831062..1682dd4 100644 --- a/config.example.py +++ b/config.example.py @@ -10,11 +10,6 @@ class Config: FLOWSPEC6_MAX_RULES = 9000 RTBH_MAX_RULES = 100000 - # Flask debugging - DEBUG = True - # Flask testing - TESTING = False - # Choose your authentication method and set it to True here or # the production / development config # SSO auth enabled @@ -104,6 +99,8 @@ class DevelopmentConfig(Config): SQLALCHEMY_DATABASE_URI = "Your Local Database URI" LOCAL_IP = "127.0.0.1" LOCAL_IP6 = "::ffff:127.0.0.1" + + # Debug and Devel mode enabled DEBUG = True DEVEL = True diff --git a/flowapp/__init__.py b/flowapp/__init__.py index d1ea49d..5bed6f8 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import os -from flask import Flask, redirect, render_template, session, url_for +from flask import Flask, redirect, render_template, session, url_for, flash from flask_sso import SSO from flask_sqlalchemy import SQLAlchemy @@ -128,11 +127,15 @@ def select_org(org_id=None): user = db.session.query(models.User).filter_by(uuid=uuid).first() if user is None: - return render_template("errors/404.html"), 404 # Handle missing user gracefully + return render_template("errors/404.html"), 404 orgs = user.organization if org_id: - org = db.session.query(models.Organization).filter_by(id=org_id).first() + # Verify user belongs to this organization + org = user.organization.filter_by(id=org_id).first() + if not org: + flash("You don't have access to this organization", "alert-danger") + return redirect(url_for("index")) session["user_org_id"] = org.id session["user_org"] = org.name return redirect("/") From 8ad7f8f45d4613a1db06670f3c8bfeb4ae96f1b5 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 15:17:19 +0100 Subject: [PATCH 03/10] update views and forms for delete operation - POST instead GET to avoid CSRF --- flowapp/templates/macros.html | 40 +++++++++++++------- flowapp/templates/pages/actions.html | 9 +++-- flowapp/templates/pages/as_paths.html | 9 +++-- flowapp/templates/pages/communities.html | 9 +++-- flowapp/templates/pages/machine_api_key.html | 9 +++-- flowapp/templates/pages/orgs.html | 9 +++-- flowapp/templates/pages/users.html | 9 +++-- flowapp/views/admin.py | 27 ++++++++++--- flowapp/views/rules.py | 4 +- flowapp/views/whitelist.py | 2 +- 10 files changed, 86 insertions(+), 41 deletions(-) diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 661be2f..a8495e8 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -52,9 +52,12 @@ - - - +
+ + +
{% endif %} {% if rule.comment %} + + {% if rule.community.id in allowed_communities %} +
+ + +
+ {% endif %} {% endif %} {% if rule.comment %} + {% endif %} {% if rule.comment %} + {% endfor %} diff --git a/flowapp/templates/pages/as_paths.html b/flowapp/templates/pages/as_paths.html index 369fda0..0df9eb9 100644 --- a/flowapp/templates/pages/as_paths.html +++ b/flowapp/templates/pages/as_paths.html @@ -17,9 +17,12 @@ - - - +
+ + +
{% endfor %} diff --git a/flowapp/templates/pages/communities.html b/flowapp/templates/pages/communities.html index b0005bb..45e4426 100644 --- a/flowapp/templates/pages/communities.html +++ b/flowapp/templates/pages/communities.html @@ -31,9 +31,12 @@ - - - +
+ + +
{% endfor %} diff --git a/flowapp/templates/pages/machine_api_key.html b/flowapp/templates/pages/machine_api_key.html index c2ced22..5c6a4a1 100644 --- a/flowapp/templates/pages/machine_api_key.html +++ b/flowapp/templates/pages/machine_api_key.html @@ -42,9 +42,12 @@

Machines and ApiKeys

{% endif %} - - - +
+ + +
{% endfor %} diff --git a/flowapp/templates/pages/orgs.html b/flowapp/templates/pages/orgs.html index 50ed1a1..3533bfe 100644 --- a/flowapp/templates/pages/orgs.html +++ b/flowapp/templates/pages/orgs.html @@ -42,9 +42,12 @@ - - - +
+ + +
{% endfor %} diff --git a/flowapp/templates/pages/users.html b/flowapp/templates/pages/users.html index f230abd..fd910fc 100644 --- a/flowapp/templates/pages/users.html +++ b/flowapp/templates/pages/users.html @@ -33,9 +33,12 @@ - - - +
+ + +
{% endfor %} diff --git a/flowapp/views/admin.py b/flowapp/views/admin.py index add1b67..cf5241d 100644 --- a/flowapp/views/admin.py +++ b/flowapp/views/admin.py @@ -104,7 +104,7 @@ def add_machine_key(): return render_template("forms/machine_api_key.html", form=form, generated_key=generated) -@admin.route("/delete_machine_key/", methods=["GET"]) +@admin.route("/delete_machine_key/", methods=["POST"]) @auth_required @admin_required def delete_machine_key(key_id): @@ -113,6 +113,9 @@ def delete_machine_key(key_id): :param key_id: integer """ model = db.session.get(MachineApiKey, key_id) + if not model: + flash("Key not found", "alert-danger") + return redirect(url_for("admin.machine_keys")) # delete from db db.session.delete(model) db.session.commit() @@ -181,7 +184,7 @@ def edit_user(user_id): ) -@admin.route("/user/delete/", methods=["GET"]) +@admin.route("/user/delete/", methods=["POST"]) @auth_required @admin_required def delete_user(user_id): @@ -387,11 +390,14 @@ def edit_organization(org_id): ) -@admin.route("/organization/delete/", methods=["GET"]) +@admin.route("/organization/delete/", methods=["POST"]) @auth_required @admin_required def delete_organization(org_id): org = db.session.get(Organization, org_id) + if not org: + flash("Organization not found", "alert-danger") + return redirect(url_for("admin.organizations")) aname = org.name db.session.delete(org) message = "Organization {} deleted".format(aname) @@ -465,11 +471,14 @@ def edit_as_path(path_id): ) -@admin.route("/as-path/delete/", methods=["GET"]) +@admin.route("/as-path/delete/", methods=["POST"]) @auth_required @admin_required def delete_as_path(path_id): pth = db.session.get(ASPath, path_id) + if not pth: + flash("AS path not found", "alert-danger") + return redirect(url_for("admin.as_paths")) db.session.delete(pth) message = f"AS path {pth.prefix} : {pth.as_path} deleted" alert_type = "alert-success" @@ -544,11 +553,14 @@ def edit_action(action_id): ) -@admin.route("/action/delete/", methods=["GET"]) +@admin.route("/action/delete/", methods=["POST"]) @auth_required @admin_required def delete_action(action_id): action = db.session.get(Action, action_id) + if not action: + flash("Action not found", "alert-danger") + return redirect(url_for("admin.actions")) aname = action.name db.session.delete(action) @@ -628,11 +640,14 @@ def edit_community(community_id): ) -@admin.route("/community/delete/", methods=["GET"]) +@admin.route("/community/delete/", methods=["POST"]) @auth_required @admin_required def delete_community(community_id): community = db.session.get(Community, community_id) + if not community: + flash("Community not found", "alert-danger") + return redirect(url_for("admin.communities")) aname = community.name db.session.delete(community) message = "Community {} deleted".format(aname) diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index b871eec..0d0df44 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -148,7 +148,7 @@ def reactivate_rule(rule_type, rule_id): ) -@rules.route("/delete//", methods=["GET"]) +@rules.route("/delete//", methods=["POST"]) @auth_required @user_or_admin_required def delete_rule(rule_type, rule_id): @@ -205,7 +205,7 @@ def delete_rule(rule_type, rule_id): ) -@rules.route("/delete_and_whitelist//", methods=["GET"]) +@rules.route("/delete_and_whitelist//", methods=["POST"]) @auth_required @user_or_admin_required def delete_and_whitelist(rule_type, rule_id): diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index 743174c..400aa94 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -100,7 +100,7 @@ def reactivate(wl_id): ) -@whitelist.route("/delete/", methods=["GET"]) +@whitelist.route("/delete/", methods=["POST"]) @auth_required @user_or_admin_required def delete(wl_id): From 96194a11271d33bd54b298f05b1c28cba640cd00 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 18:04:35 +0100 Subject: [PATCH 04/10] fix wrong session key for rules.py is_admin function --- flowapp/views/rules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 0d0df44..d746271 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -301,7 +301,7 @@ def group_delete(): to_delete = request.form.getlist("delete-id") # Check if user has permission to delete these rules - if set(to_delete).issubset(set(allowed_rules_str)) or is_admin(session["user_roles"]): + if set(to_delete).issubset(set(allowed_rules_str)) or is_admin(session["user_role_ids"]): for rule_id in to_delete: # withdraw route model = db.session.get(model_name, rule_id) @@ -357,7 +357,7 @@ def group_update(): allowed_rules_str = [str(x) for x in allowed_rule_ids] # redirect bad request - if not set(to_update).issubset(set(allowed_rules_str)) and not is_admin(session["user_roles"]): + if not set(to_update).issubset(set(allowed_rules_str)) and not is_admin(session["user_role_ids"]): flash("You can't edit these rules!", "alert-danger") return redirect( url_for( From 4e7e5d11ef42342c1f8dcfedacbfb53f6ceee017 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 18:24:51 +0100 Subject: [PATCH 05/10] fix possible wrong behavior of admin_or_user_required auth decorator --- flowapp/auth.py | 4 +++- flowapp/utils/app_factory.py | 6 ------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/flowapp/auth.py b/flowapp/auth.py index a26628b..f484386 100644 --- a/flowapp/auth.py +++ b/flowapp/auth.py @@ -48,11 +48,13 @@ def decorated(*args, **kwargs): def user_or_admin_required(f): """ decorator for admin/user endpoints + Allows access if the user has at least one role with ID > 1 (user or admin) + Role IDs: 1=view (read-only), 2=user (can create/edit), 3=admin """ @wraps(f) def decorated(*args, **kwargs): - if not all(i > 1 for i in session["user_role_ids"]): + if not any(i > 1 for i in session["user_role_ids"]): return redirect(url_for("index")) return f(*args, **kwargs) diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py index 6a4cae3..9e034aa 100644 --- a/flowapp/utils/app_factory.py +++ b/flowapp/utils/app_factory.py @@ -169,17 +169,13 @@ def ext_login(): @app.route("/local-login") def local_login(): - print("Local login started") if not app.config.get("LOCAL_AUTH", False): - print("Local auth not enabled") return render_template("errors/401.html") uuid = app.config.get("LOCAL_USER_UUID", False) if not uuid: - print("Local user not set") return render_template("errors/401.html") - print(f"Local login with {uuid}") return _handle_login(uuid, app) return app @@ -211,9 +207,7 @@ def _register_user_to_session(uuid, db): """Register user information to session.""" from flowapp.models import User - print(f"Registering user {uuid} to session") user = db.session.query(User).filter_by(uuid=uuid).first() - print(f"Got user {user} from DB") session["user_uuid"] = user.uuid session["user_email"] = user.uuid session["user_name"] = user.name From b9fb4517a74afb0d6278eb2a6152d27c6d32af50 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 18:28:45 +0100 Subject: [PATCH 06/10] move tests outside flowapp package --- flowapp/tests/__init__.py | 0 flowapp/tests/conftest.py | 232 ------- flowapp/tests/rule_service_integration.py | 332 --------- flowapp/tests/test_api_auth.py | 71 -- flowapp/tests/test_api_deprecated.py | 28 - flowapp/tests/test_api_v3.py | 611 ----------------- .../tests/test_api_whitelist_integration.py | 273 -------- flowapp/tests/test_flowapp.py | 14 - flowapp/tests/test_flowspec.py | 167 ----- flowapp/tests/test_forms.py | 65 -- flowapp/tests/test_forms_cl.py | 608 ----------------- flowapp/tests/test_login.py | 0 flowapp/tests/test_models.py | 498 -------------- flowapp/tests/test_rule_service.py | 628 ------------------ .../test_rule_service_reactivate_delete.py | 527 --------------- flowapp/tests/test_utils.py | 71 -- flowapp/tests/test_validators.py | 547 --------------- flowapp/tests/test_whitelist_common.py | 250 ------- flowapp/tests/test_whitelist_service.py | 461 ------------- .../tests/test_zzz_api_rtbh_expired_bug.py | 240 ------- 20 files changed, 5623 deletions(-) delete mode 100644 flowapp/tests/__init__.py delete mode 100644 flowapp/tests/conftest.py delete mode 100644 flowapp/tests/rule_service_integration.py delete mode 100644 flowapp/tests/test_api_auth.py delete mode 100644 flowapp/tests/test_api_deprecated.py delete mode 100644 flowapp/tests/test_api_v3.py delete mode 100644 flowapp/tests/test_api_whitelist_integration.py delete mode 100644 flowapp/tests/test_flowapp.py delete mode 100644 flowapp/tests/test_flowspec.py delete mode 100644 flowapp/tests/test_forms.py delete mode 100644 flowapp/tests/test_forms_cl.py delete mode 100644 flowapp/tests/test_login.py delete mode 100644 flowapp/tests/test_models.py delete mode 100644 flowapp/tests/test_rule_service.py delete mode 100644 flowapp/tests/test_rule_service_reactivate_delete.py delete mode 100644 flowapp/tests/test_utils.py delete mode 100644 flowapp/tests/test_validators.py delete mode 100644 flowapp/tests/test_whitelist_common.py delete mode 100644 flowapp/tests/test_whitelist_service.py delete mode 100644 flowapp/tests/test_zzz_api_rtbh_expired_bug.py diff --git a/flowapp/tests/__init__.py b/flowapp/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py deleted file mode 100644 index 3a988ef..0000000 --- a/flowapp/tests/conftest.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -PyTest configuration file for all tests -""" - -import os -import json -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from flowapp import create_app -from flowapp import db as _db -from datetime import datetime -import flowapp.models -from flowapp.models.organization import Organization - - -TESTDB = "test_project.db" -TESTDB_PATH = "/tmp/{}".format(TESTDB) -TEST_DATABASE_URI = "sqlite:///" + TESTDB_PATH - - -class FieldMock: - def __init__(self): - self.data = None - self.errors = [] - - -class RuleMock: - def __init__(self): - self.source = None - self.source_mask = None - self.dest = None - self.dest_mask = None - - -@pytest.fixture -def field(): - return FieldMock() - - -@pytest.fixture -def field_class(): - return FieldMock - - -@pytest.fixture -def rule(): - return RuleMock() - - -@pytest.fixture(scope="session") -def app(request): - """ - Create a Flask app, and override settings, for the whole test session. - """ - - _app = create_app() - - _app.config.update( - EXA_API="HTTP", - EXA_API_URL="http://localhost:5000/", - TESTING=True, - SQLALCHEMY_DATABASE_URI=TEST_DATABASE_URI, - SQLALCHEMY_TRACK_MODIFICATIONS=False, - JWT_SECRET="testing", - API_KEY="testkey", - SECRET_KEY="testkeysession", - LOCAL_USER_UUID="jiri.vrany@cesnet.cz", - LOCAL_AUTH=True, - ALLOWED_COMMUNITIES=[1, 2, 3], - WTF_CSRF_ENABLED=False, - ) - - print("\n----- CREATE FLASK APPLICATION\n") - context = _app.app_context() - context.push() - yield _app - print("\n----- CREATE FLASK APPLICATION CONTEXT\n") - - context.pop() - print("\n----- RELEASE FLASK APPLICATION CONTEXT\n") - - -@pytest.fixture(scope="session") -def client(app, request): - """ - Get the test_client from the app, for the whole test session. - """ - print("\n----- CREATE FLASK TEST CLIENT\n") - return app.test_client() - - -@pytest.fixture(scope="session") -def db(app, request): - """ - Create entire database for every test. - """ - engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) - sessionmaker(bind=engine) - print("\n----- CREATE TEST DB CONNECTION POOL\n") - if os.path.exists(TESTDB_PATH): - os.unlink(TESTDB_PATH) - - with app.app_context(): - _db.init_app(app) - print("#: cleaning database") - _db.reflect() - _db.drop_all() - print("#: creating tables") - _db.create_all() - - users = [ - {"name": "jiri.vrany@cesnet.cz", "role_id": 3, "org_id": 1}, - {"name": "petr.adamec@cesnet.cz", "role_id": 3, "org_id": 1}, - ] - print("#: inserting users") - flowapp.models.insert_users(users) - - org = _db.session.query(Organization).filter_by(id=1).first() - # Update the organization address range to include our test networks - org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" - _db.session.commit() - print("\n----- UPDATED TEST ORG 1 \n", org) - - def teardown(): - _db.session.commit() - _db.drop_all() - os.unlink(TESTDB_PATH) - - request.addfinalizer(teardown) - return _db - - -@pytest.fixture(scope="session") -def jwt_token(client, app, db, request): - """ - Get the test_client from the app, for the whole test session. - """ - mkey = "testkey" - - with app.app_context(): - model = flowapp.models.ApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) - db.session.add(model) - db.session.commit() - - print("\n----- GET JWT TEST TOKEN\n") - url = "/api/v3/auth" - headers = {"x-api-key": mkey} - token = client.get(url, headers=headers) - data = json.loads(token.data) - return data["token"] - - -@pytest.fixture(scope="session") -def machine_api_token(client, app, db, request): - """ - Get the test_client from the app, for the whole test session. - """ - mkey = "machinetestkey" - - with app.app_context(): - model = flowapp.models.MachineApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) - db.session.add(model) - db.session.commit() - - print("\n----- GET MACHINE API KEY TEST TOKEN\n") - url = "/api/v3/auth" - headers = {"x-api-key": mkey} - token = client.get(url, headers=headers) - data = json.loads(token.data) - return data["token"] - - -@pytest.fixture(scope="session") -def expired_auth_token(client, app, db, request): - """ - Get the test_client from the app, for the whole test session. - """ - test_key = "expired_test_key" - expired_date = datetime.strptime("2019-01-01", "%Y-%m-%d") - with app.app_context(): - model = flowapp.models.ApiKey(machine="127.0.0.1", key=test_key, user_id=1, expires=expired_date, org_id=1) - db.session.add(model) - db.session.commit() - - return test_key - - -@pytest.fixture(scope="session") -def readonly_jwt_token(client, app, db, request): - """ - Get the test_client from the app, for the whole test session. - """ - readonly_key = "readonly-testkey" - with app.app_context(): - model = flowapp.models.ApiKey(machine="127.0.0.1", key=readonly_key, user_id=1, readonly=True, org_id=1) - db.session.add(model) - db.session.commit() - - print("\n----- GET JWT TEST TOKEN\n") - url = "/api/v3/auth" - headers = {"x-api-key": readonly_key} - token = client.get(url, headers=headers) - data = json.loads(token.data) - return data["token"] - - -@pytest.fixture(scope="session") -def auth_client(client): - """ - Get the test_client from the app, for the whole test session. - """ - print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n") - client.get("/local-login") - return client - - -@pytest.fixture(autouse=True) -def reset_org_limits(db, app): - """ - Automatically reset organization limits after each test that modifies them. - """ - yield # Allow test execution - - with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() - if org: - org.limit_flowspec4 = 0 - org.limit_flowspec6 = 0 - org.limit_rtbh = 0 - db.session.commit() diff --git a/flowapp/tests/rule_service_integration.py b/flowapp/tests/rule_service_integration.py deleted file mode 100644 index 106000a..0000000 --- a/flowapp/tests/rule_service_integration.py +++ /dev/null @@ -1,332 +0,0 @@ -"""Integration tests for rule_service module.""" - -import pytest -from datetime import datetime, timedelta -from unittest.mock import patch - -from flowapp import db, create_app -from flowapp.constants import RuleTypes -from flowapp.models import ( - RTBH, - Flowspec4, - Flowspec6, - Whitelist, - User, - Organization, - Role, - Community, - Action, - Rstate, -) -from flowapp.services import rule_service - - -@pytest.fixture(scope="module") -def app(): - """Create and configure a Flask app for testing.""" - app = create_app("testing") - - # Create a test context - with app.app_context(): - # Create database tables - db.create_all() - - # Create test data - _create_test_data() - - yield app - - # Clean up after tests - db.session.remove() - db.drop_all() - - -@pytest.fixture(scope="function") -def app_context(app): - """Provide an application context for each test.""" - with app.app_context(): - yield - - -@pytest.fixture(scope="function") -def mock_announce_route(): - """Mock the announce_route function.""" - with patch("flowapp.services.rule_service.announce_route") as mock: - yield mock - - -def _create_test_data(): - """Create test data in the database.""" - # Create roles - admin_role = Role(name="admin", description="Administrator") - user_role = Role(name="user", description="Normal user") - view_role = Role(name="view", description="View only") - db.session.add_all([admin_role, user_role, view_role]) - db.session.flush() - - # Create organizations - org1 = Organization(name="Test Org 1", arange="192.168.1.0/24\n2001:db8::/64") - org2 = Organization(name="Test Org 2", arange="10.0.0.0/8") - db.session.add_all([org1, org2]) - db.session.flush() - - # Create users - user1 = User(uuid="user1@example.com", name="User One") - user1.role.append(user_role) - user1.organization.append(org1) - - admin1 = User(uuid="admin@example.com", name="Admin User") - admin1.role.append(admin_role) - admin1.organization.append(org2) - - db.session.add_all([user1, admin1]) - db.session.flush() - - # Create rule states - state_active = Rstate(description="active rule") - state_withdrawn = Rstate(description="withdrawed rule") - state_deleted = Rstate(description="deleted rule") - state_whitelisted = Rstate(description="whitelisted rule") - db.session.add_all([state_active, state_withdrawn, state_deleted, state_whitelisted]) - db.session.flush() - - # Create actions - action1 = Action(name="Discard", command="discard", description="Drop traffic", role_id=user_role.id) - action2 = Action(name="Rate limit", command="rate-limit 1000000", description="Limit traffic", role_id=user_role.id) - db.session.add_all([action1, action2]) - db.session.flush() - - # Create communities - community1 = Community( - name="65000:1", - comm="65000:1", - larcomm="", - extcomm="", - description="Test community", - role_id=user_role.id, - as_path=False, - ) - db.session.add(community1) - db.session.flush() - - # Create some rules - # IPv4 rule - ipv4_rule = Flowspec4( - source="192.168.1.1", - source_mask=32, - source_port="", - dest="192.168.2.1", - dest_mask=32, - dest_port="", - protocol="tcp", - flags="", - packet_len="", - fragment="", - expires=datetime.now() + timedelta(hours=1), - comment="Test IPv4 rule", - action_id=action1.id, - user_id=user1.id, - org_id=org1.id, - rstate_id=state_active.id, - ) - db.session.add(ipv4_rule) - - # IPv6 rule - ipv6_rule = Flowspec6( - source="2001:db8::1", - source_mask=128, - source_port="", - dest="2001:db8:1::1", - dest_mask=128, - dest_port="", - next_header="tcp", - flags="", - packet_len="", - expires=datetime.now() + timedelta(hours=1), - comment="Test IPv6 rule", - action_id=action1.id, - user_id=user1.id, - org_id=org1.id, - rstate_id=state_active.id, - ) - db.session.add(ipv6_rule) - - # RTBH rule - rtbh_rule = RTBH( - ipv4="192.168.1.100", - ipv4_mask=32, - ipv6=None, - ipv6_mask=None, - community_id=community1.id, - expires=datetime.now() + timedelta(hours=1), - comment="Test RTBH rule", - user_id=user1.id, - org_id=org1.id, - rstate_id=state_active.id, - ) - db.session.add(rtbh_rule) - - # Save IDs as class variables for tests to use - _create_test_data.user_id = user1.id - _create_test_data.admin_id = admin1.id - _create_test_data.org1_id = org1.id - _create_test_data.org2_id = org2.id - _create_test_data.ipv4_rule_id = ipv4_rule.id - _create_test_data.ipv6_rule_id = ipv6_rule.id - _create_test_data.rtbh_rule_id = rtbh_rule.id - _create_test_data.action_id = action1.id - _create_test_data.community_id = community1.id - - db.session.commit() - - -class TestRuleServiceIntegration: - """Integration tests for rule_service functions.""" - - def test_reactivate_rule_ipv4(self, app_context, mock_announce_route): - """Test reactivating an IPv4 rule.""" - # Test data - rule_id = _create_test_data.ipv4_rule_id - user_id = _create_test_data.user_id - org_id = _create_test_data.org1_id - - # Set new expiration time (2 hours in the future) - new_expires = datetime.now() + timedelta(hours=2) - new_comment = "Updated test comment" - - # Call the function - model, messages = rule_service.reactivate_rule( - rule_type=RuleTypes.IPv4, - rule_id=rule_id, - expires=new_expires, - comment=new_comment, - user_id=user_id, - org_id=org_id, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the rule was updated - assert model is not None - assert model.id == rule_id - assert model.comment == new_comment - assert model.expires == new_expires - assert model.rstate_id == 1 # active state - - # Verify route announcement was attempted - assert mock_announce_route.called - - # Verify message - assert messages == ["Rule successfully updated"] - - def test_reactivate_rule_inactive(self, app_context, mock_announce_route): - """Test reactivating a rule to inactive state.""" - # Test data - rule_id = _create_test_data.ipv4_rule_id - user_id = _create_test_data.user_id - org_id = _create_test_data.org1_id - - # Set past expiration time - past_expires = datetime.now() - timedelta(hours=1) - - # Call the function - model, messages = rule_service.reactivate_rule( - rule_type=RuleTypes.IPv4, - rule_id=rule_id, - expires=past_expires, - comment="Expired comment", - user_id=user_id, - org_id=org_id, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the rule was updated - assert model is not None - assert model.id == rule_id - assert model.expires == past_expires - assert model.rstate_id == 2 # inactive/withdrawn state - - # Verify route announcement was attempted - assert mock_announce_route.called - - def test_delete_rule(self, app_context, mock_announce_route): - """Test deleting a rule.""" - # Test data - rule_id = _create_test_data.ipv6_rule_id - user_id = _create_test_data.user_id - - # Call the function - success, message = rule_service.delete_rule( - rule_type=RuleTypes.IPv6, - rule_id=rule_id, - user_id=user_id, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the rule was deleted - assert success is True - assert message == "Rule deleted successfully" - - # Verify the rule no longer exists in the database - rule = db.session.get(Flowspec6, rule_id) - assert rule is None - - # Verify route withdrawal was attempted - assert mock_announce_route.called - - def test_delete_rule_not_found(self, app_context): - """Test deleting a non-existent rule.""" - # Use a non-existent rule ID - non_existent_id = 9999 - - # Call the function - success, message = rule_service.delete_rule( - rule_type=RuleTypes.IPv4, - rule_id=non_existent_id, - user_id=_create_test_data.user_id, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the operation failed - assert success is False - assert message == "Rule not found" - - @patch("flowapp.services.rule_service.create_or_update_whitelist") - def test_delete_rtbh_and_create_whitelist(self, mock_create_whitelist, app_context, mock_announce_route): - """Test deleting an RTBH rule and creating a whitelist entry.""" - # Test data - rule_id = _create_test_data.rtbh_rule_id - user_id = _create_test_data.user_id - org_id = _create_test_data.org1_id - - # Mock whitelist creation - mock_whitelist = Whitelist( - ip="192.168.1.100", mask=32, expires=datetime.now() + timedelta(days=7), user_id=user_id, org_id=org_id - ) - mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) - - # Call the function - success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( - rule_id=rule_id, user_id=user_id, org_id=org_id, user_email="test@example.com", org_name="Test Org" - ) - - # Verify success - assert success is True - assert len(messages) == 2 - assert "Rule deleted successfully" in messages[0] - assert "Whitelist created" in messages[1] - - # Verify the rule was deleted - rule = db.session.get(RTBH, rule_id) - assert rule is None - - # Verify create_or_update_whitelist was called with correct data - mock_create_whitelist.assert_called_once() - args, kwargs = mock_create_whitelist.call_args - form_data = kwargs.get("form_data", args[0] if args else None) - assert form_data["ip"] == "192.168.1.100" - assert form_data["mask"] == 32 - assert "Created from RTBH rule" in form_data["comment"] diff --git a/flowapp/tests/test_api_auth.py b/flowapp/tests/test_api_auth.py deleted file mode 100644 index 8d08a7a..0000000 --- a/flowapp/tests/test_api_auth.py +++ /dev/null @@ -1,71 +0,0 @@ -# Test for api authorization -import json - - -def test_token(client, jwt_token): - """ - test that token authorization works - """ - req = client.get("/api/v3/test_token", headers={"x-access-token": jwt_token}) - - assert req.status_code == 200 - - -def test_machine_token(client, machine_api_token): - """ - test that token authorization works - """ - req = client.get("/api/v3/test_token", headers={"x-access-token": machine_api_token}) - - assert req.status_code == 200 - - -def test_expired_token(client, expired_auth_token): - """ - test that expired token authorization return 401 - """ - req = client.get("/api/v3/auth", headers={"x-api-key": expired_auth_token}) - - assert req.status_code == 401 - - -def test_withnout_token(client): - """ - test that without token authorization return 401 - """ - req = client.get("/api/v3/test_token") - - assert req.status_code == 401 - - -def test_readonly_token(client, readonly_jwt_token): - """ - test that readonly flag is set correctly - """ - req = client.get("/api/v3/test_token", headers={"x-access-token": readonly_jwt_token}) - - assert req.status_code == 200 - data = json.loads(req.data) - assert data["readonly"] - - -def test_readonly_token_ipv4_create(client, db, readonly_jwt_token): - """ - test that readonly token can't create ipv4 rule - """ - headers = {"x-access-token": readonly_jwt_token} - - req = client.post( - "/api/v3/rules/ipv4", - headers=headers, - json={ - "action": 2, - "protocol": "tcp", - "source": "147.230.17.117", - "source_mask": 32, - "source_port": "", - "expires": "1444913400", - }, - ) - - assert req.status_code == 403 diff --git a/flowapp/tests/test_api_deprecated.py b/flowapp/tests/test_api_deprecated.py deleted file mode 100644 index fca9414..0000000 --- a/flowapp/tests/test_api_deprecated.py +++ /dev/null @@ -1,28 +0,0 @@ -V_PREFIX = "/api/v1" - - -def test_token(client, jwt_token): - """ - test that token authorization works - """ - req = client.get(f"{V_PREFIX}/test_token", headers={"x-access-token": jwt_token}) - - assert req.status_code == 400 - - -def test_withnout_token(client): - """ - test that without token authorization return 401 - """ - req = client.get(f"{V_PREFIX}/test_token") - - assert req.status_code == 400 - - -def test_rules(client, db, jwt_token): - """ - test that there is one ipv4 rule created in the first test - """ - req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": jwt_token}) - - assert req.status_code == 400 diff --git a/flowapp/tests/test_api_v3.py b/flowapp/tests/test_api_v3.py deleted file mode 100644 index 0548532..0000000 --- a/flowapp/tests/test_api_v3.py +++ /dev/null @@ -1,611 +0,0 @@ -import json - - -from flowapp.models import Flowspec4, Organization - -V_PREFIX = "/api/v3" - - -def test_create_rtbh_only(client, app, db, jwt_token): - """ - Test creating an RTBH rule through API that exactly matches an existing whitelist. - - The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry - should be created linking the rule to the whitelist. - """ - - # Now create the RTBH rule via API - res = client.post( - f"{V_PREFIX}/rules/rtbh", - headers={"x-access-token": jwt_token}, - json={ - "community": 1, - "ipv4": "147.230.17.17", - "ipv4_mask": 32, - "expires": "10/25/2050 14:46", - }, - ) - - # Verify response is successful - assert res.status_code == 201 - data = json.loads(res.data) - assert data["rule"] is not None - - -def test_token(client, jwt_token): - """ - test that token authorization works - """ - req = client.get(f"{V_PREFIX}/test_token", headers={"x-access-token": jwt_token}) - - assert req.status_code == 200 - - -def test_withnout_token(client): - """ - test that without token authorization return 401 - """ - req = client.get(f"{V_PREFIX}/test_token") - - assert req.status_code == 401 - - -def test_list_actions(client, db, jwt_token): - """ - test that endpoint returns all action in db - """ - req = client.get(f"{V_PREFIX}/actions", headers={"x-access-token": jwt_token}) - - assert req.status_code == 200 - data = json.loads(req.data) - assert len(data) == 4 - - -def test_list_communities(client, db, jwt_token): - """ - test that endpoint returns all action in db - """ - req = client.get(f"{V_PREFIX}/communities", headers={"x-access-token": jwt_token}) - - assert req.status_code == 200 - data = json.loads(req.data) - assert len(data) == 3 - - -def test_create_v4rule(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "protocol": "tcp", - "source": "147.230.17.17", - "source_mask": 32, - "source_port": "", - "expires": "10/15/2050 14:46", - "flags": ["SYN", "RST"], - }, - ) - - assert req.status_code == 201 - data = json.loads(req.data) - assert data["rule"] - assert data["rule"]["id"] == 1 - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - - -def test_delete_v4rule(client, db, jwt_token): - """ - test that creating with valid data returns 201 - that time in the past creates expired rule (state 2) - and that the rule deletion works as expected - """ - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "protocol": "tcp", - "source": "147.230.17.12", - "source_mask": 32, - "source_port": "", - "expires": "10/15/2015 14:46", - }, - ) - - assert req.status_code == 201 - data = json.loads(req.data) - assert data["rule"]["id"] == 2 - assert data["rule"]["rstate"] == "withdrawed rule" - - req2 = client.delete( - f'{V_PREFIX}/rules/ipv4/{data["rule"]["id"]}', - headers={"x-access-token": jwt_token}, - ) - assert req2.status_code == 201 - - -def test_create_rtbh_rule(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/rtbh", - headers={"x-access-token": jwt_token}, - json={ - "community": 1, - "ipv4": "147.230.17.17", - "ipv4_mask": 32, - "expires": "10/25/2050 14:46", - }, - ) - data = json.loads(req.data) - assert req.status_code == 201 - assert data["rule"] - assert data["rule"]["id"] == 1 - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - - -def test_delete_rtbh_rule(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/rtbh", - headers={"x-access-token": jwt_token}, - json={ - "community": 2, - "ipv4": "147.230.17.177", - "ipv4_mask": 32, - "expires": "10/25/2050 14:46", - }, - ) - - assert req.status_code == 201 - data = json.loads(req.data) - assert data["rule"]["id"] == 2 - req2 = client.delete( - f'{V_PREFIX}/rules/rtbh/{data["rule"]["id"]}', - headers={"x-access-token": jwt_token}, - ) - assert req2.status_code == 201 - - -def test_validation_rtbh_rule(client, db, jwt_token): - """ - test that creating with invalid data returns 400 and errors - """ - req = client.post( - f"{V_PREFIX}/rules/rtbh", - headers={"x-access-token": jwt_token}, - json={ - "community": 1, - "ipv4": "147.230.17.17", - "ipv4_mask": 32, - "ipv6": "2001:718:1C01:1111::", - "ipv6_mask": 128, - "expires": "10/25/2050 14:46", - }, - ) - data = json.loads(req.data) - assert req.status_code == 400 - assert data["message"] == "error - invalid request data" - assert type(data["validation_errors"]) is dict - assert "ipv6" in data["validation_errors"] - assert "ipv4" in data["validation_errors"] - - -def test_create_v6rule(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/ipv6", - headers={"x-access-token": jwt_token}, - json={ - "action": 3, - "next_header": "tcp", - "source": "2001:718:1C01:1111::", - "source_mask": 64, - "source_port": "", - "expires": "10/25/2050 14:46", - }, - ) - data = json.loads(req.data) - assert req.status_code == 201 - assert data["rule"] - assert data["rule"]["id"] == "1" - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - - -def test_validation_v4rule(client, db, jwt_token): - """ - test that creating with invalid data returns 400 and errors - """ - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "dest": "200.200.200.32", - "dest_mask": 16, - "protocol": "tcp", - "source": "1.1.1.1", - "source_mask": 32, - "source_port": "", - "expires": "10/15/2050 14:46", - }, - ) - - assert req.status_code == 400 - data = json.loads(req.data) - assert len(data["validation_errors"]) > 0 - assert sorted(data["validation_errors"].keys()) == sorted(["dest", "source"]) - assert len(data["validation_errors"]["dest"]) == 2 - assert data["validation_errors"]["dest"][0].startswith("This is not") - assert data["validation_errors"]["dest"][1].startswith("Source or des") - assert len(data["validation_errors"]["source"]) == 1 - assert data["validation_errors"]["source"][0].startswith("Source or des") - - -def test_all_validation_errors(client, db, jwt_token): - """ - test that creating with invalid data returns 400 and errors - """ - req = client.post(f"{V_PREFIX}/rules/ipv4", headers={"x-access-token": jwt_token}, json={"action": 2}) - assert req.status_code == 400 - - -def test_validate_v6rule(client, db, jwt_token): - """ - test that creating with invalid data returns 400 and errors - """ - req = client.post( - f"{V_PREFIX}/rules/ipv6", - headers={"x-access-token": jwt_token}, - json={ - "action": 32, - "next_header": "abc", - "source": "2011:78:1C01:1111::", - "source_mask": 64, - "source_port": "", - "expires": "10/25/2050 14:46", - }, - ) - data = json.loads(req.data) - assert req.status_code == 400 - assert len(data["validation_errors"]) > 0 - assert sorted(data["validation_errors"].keys()) == sorted(["action", "next_header", "dest", "source"]) - # assert data['validation_errors'][0].startswith('Error in the Action') - # assert data['validation_errors'][1].startswith('Error in the Source') - # assert data['validation_errors'][2].startswith('Error in the Next Header') - - -def test_rules(client, db, jwt_token): - """ - test that there is one ipv4 rule created in the first test - """ - req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": jwt_token}) - - assert req.status_code == 200 - - data = json.loads(req.data) - assert len(data["flowspec_ipv4_rw"]) == 1 - assert len(data["flowspec_ipv6_rw"]) == 1 - - -def test_timestamp_param(client, db, jwt_token): - """ - test that url param for time format works as expected - """ - req = client.get(f"{V_PREFIX}/rules?time_format=timestamp", headers={"x-access-token": jwt_token}) - - assert req.status_code == 200 - - data = json.loads(req.data) - assert data["flowspec_ipv4_rw"][0]["expires"] == 2549451000 - assert data["flowspec_ipv6_rw"][0]["expires"] == 2550315000 - - -def test_update_existing_v4rule_with_timestamp(client, db, jwt_token): - """ - test that update with different data passes - """ - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "protocol": "tcp", - "source": "147.230.17.17", - "source_mask": 32, - "source_port": "", - "expires": "1444913400", - }, - ) - - assert req.status_code == 201 - data = json.loads(req.data) - assert data["rule"] - assert data["rule"]["id"] == 2 - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - assert data["rule"]["expires"] == 1444913400 - - -def test_create_v4rule_with_timestamp(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "protocol": "tcp", - "source": "147.230.17.117", - "source_mask": 32, - "source_port": "", - "expires": "1444913400", - }, - ) - - assert req.status_code == 201 - data = json.loads(req.data) - assert data["rule"] - assert data["rule"]["id"] == 3 - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - assert data["rule"]["expires"] == 1444913400 - - -def test_update_existing_v6rule_with_timestamp(client, db, jwt_token): - """ - test that update with different data passes - """ - req = client.post( - f"{V_PREFIX}/rules/ipv6", - headers={"x-access-token": jwt_token}, - json={ - "action": 3, - "next_header": "tcp", - "source": "2001:718:1C01:1111::", - "source_mask": 64, - "source_port": "", - "expires": "1444913400", - }, - ) - data = json.loads(req.data) - assert req.status_code == 201 - assert data["rule"] - assert data["rule"]["id"] == "1" - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - assert data["rule"]["expires"] == 1444913400 - - -def test_create_v6rule_with_timestamp(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/ipv6", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "next_header": "udp", - "source": "2001:718:1C01:1111::", - "source_mask": 64, - "source_port": "", - "expires": "2549952908", - }, - ) - data = json.loads(req.data) - assert req.status_code == 201 - assert data["rule"] - assert data["rule"]["id"] == "2" - assert data["rule"]["rstate"] == "active rule" - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - assert data["rule"]["expires"] == 2549953200 - - -def test_update_existing_rtbh_rule_with_timestamp(client, db, jwt_token): - """ - test that update with different data passes - """ - req = client.post( - f"{V_PREFIX}/rules/rtbh", - headers={"x-access-token": jwt_token}, - json={ - "community": 1, - "ipv4": "147.230.17.17", - "ipv4_mask": 32, - "expires": "1444913400", - }, - ) - data = json.loads(req.data) - assert req.status_code == 201 - assert data["rule"] - assert data["rule"]["id"] == 1 - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - assert data["rule"]["expires"] == 1444913400 - - -def test_create_rtbh_rule_with_timestamp(client, db, jwt_token): - """ - test that creating with valid data returns 201 - """ - req = client.post( - f"{V_PREFIX}/rules/rtbh", - headers={"x-access-token": jwt_token}, - json={ - "community": 1, - "ipv4": "147.230.17.117", - "ipv4_mask": 32, - "expires": "1444913400", - }, - ) - data = json.loads(req.data) - assert req.status_code == 201 - assert data["rule"] - assert data["rule"]["id"] == 2 - assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" - assert data["rule"]["expires"] == 1444913400 - - -def test_create_v4rule_lmit(client, db, app, jwt_token): - """ - test that limit checkt for v4 works - """ - with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() - org.limit_flowspec4 = 2 - db.session.commit() - - # count - count = db.session.query(Flowspec4).count() - print("COUNT", count) - - sources = ["147.230.42.17", "147.230.42.118"] - codes = [201, 403] - - for source, code in zip(sources, codes): - data = { - "action": 1, - "protocol": "tcp", - "source": source, - "source_mask": 32, - "source_port": "", - "expires": "10/15/2050 14:46", - } - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json=data, - ) - - assert req.status_code == code - - -def test_create_v6rule_lmit(client, db, app, jwt_token): - """ - test that limit check for v6 works - """ - with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() - org.limit_flowspec6 = 3 - db.session.commit() - - sources = ["2001:718:1C01:1111::", "2001:718:1C01:1112::"] - codes = [201, 403] - - for source, code in zip(sources, codes): - data = { - "action": 1, - "next_header": "tcp", - "source": source, - "source_mask": 64, - "source_port": "", - "expires": "10/15/2050 14:46", - } - req = client.post( - f"{V_PREFIX}/rules/ipv6", - headers={"x-access-token": jwt_token}, - json=data, - ) - - assert req.status_code == code - - -def test_create_rtbh_lmit(client, db, app, jwt_token): - """ - test that limit check for v6 works - """ - with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() - org.limit_rtbh = 1 - db.session.commit() - - sources = ["147.230.17.42", "147.230.17.43"] - codes = [201, 403] - - for source, code in zip(sources, codes): - data = { - "community": 1, - "ipv4": source, - "ipv4_mask": 32, - "expires": "10/25/2050 14:46", - } - req = client.post(f"{V_PREFIX}/rules/rtbh", headers={"x-access-token": jwt_token}, json=data) - - assert req.status_code == code - - -def test_update_existing_v4rule_with_timestamp_limit(client, db, app, jwt_token): - """ - test that update with different data passes - """ - with app.app_context(): - # count - count = db.session.query(Flowspec4).filter_by(org_id=1, rstate_id=1).count() - print("COUNT in update", count) - - org = db.session.query(Organization).filter_by(id=1).first() - org.limit_flowspec4 = count - db.session.commit() - - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json={ - "action": 2, - "protocol": "tcp", - "source": "147.230.17.17", - "source_mask": 32, - "source_port": "", - "expires": "2552634908", - }, - ) - - assert req.status_code == 403 - data = json.loads(req.data) - assert data["message"] - assert data["message"].startswith("Rule limit") - - -def test_overall_limit(client, db, app, jwt_token): - """ - test that update with different data passes - """ - app.config.update({"FLOWSPEC4_MAX_RULES": 5, "FLOWSPEC6_MAX_RULES": 5, "RTBH_MAX_RULES": 5}) - - with app.app_context(): - # count - - org = db.session.query(Organization).filter_by(id=1).first() - org.limit_flowspec4 = 20 - db.session.commit() - - sources = ["147.230.42.1", "147.230.42.2", "147.230.42.3", "147.230.42.4"] - codes = [201, 201, 201, 403] - - for source, code in zip(sources, codes): - data = { - "action": 1, - "protocol": "tcp", - "source": source, - "source_mask": 32, - "source_port": "", - "expires": "10/15/2050 14:46", - } - req = client.post( - f"{V_PREFIX}/rules/ipv4", - headers={"x-access-token": jwt_token}, - json=data, - ) - print(source) - assert req.status_code == code - - data = json.loads(req.data) - assert data["message"] - assert data["message"].startswith("System limit") diff --git a/flowapp/tests/test_api_whitelist_integration.py b/flowapp/tests/test_api_whitelist_integration.py deleted file mode 100644 index 0391a76..0000000 --- a/flowapp/tests/test_api_whitelist_integration.py +++ /dev/null @@ -1,273 +0,0 @@ -""" -Integration test for the API view when interacting with whitelists. - -This test suite verifies the behavior when creating RTBH rules through the API -when there are existing whitelists that could affect the rules. -""" - -import json -import pytest -from datetime import datetime, timedelta - -from flowapp.constants import RuleTypes, RuleOrigin -from flowapp.models import RTBH, RuleWhitelistCache, Organization -from flowapp.services import whitelist_service - - -@pytest.fixture -def whitelist_data(): - """Create whitelist data for testing""" - return { - "ip": "192.168.1.0", - "mask": 24, - "comment": "Test whitelist for API integration test", - "expires": datetime.now() + timedelta(days=1), - } - - -@pytest.fixture -def rtbh_api_payload(): - """Create RTBH rule data for API testing""" - return { - "community": 1, - "ipv4": "192.168.1.0", - "ipv4_mask": 24, - "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), - "comment": "Test RTBH rule via API", - } - - -def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): - """ - Test creating an RTBH rule through API that exactly matches an existing whitelist. - - The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry - should be created linking the rule to the whitelist. - """ - # First, configure app to include community ID 1 in allowed communities - app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) - - # Create the whitelist directly using the service - with app.app_context(): - # Create user and organization if needed for the whitelist - org = db.session.query(Organization).first() - - # Create the whitelist - whitelist_model, _ = whitelist_service.create_or_update_whitelist( - form_data=whitelist_data, - user_id=1, # Using user ID 1 from test fixtures - org_id=org.id, - user_email="test@example.com", - org_name=org.name, - ) - - # Verify whitelist was created - assert whitelist_model.id is not None - assert whitelist_model.ip == whitelist_data["ip"] - assert whitelist_model.mask == whitelist_data["mask"] - - # Now create the RTBH rule via API - response = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=rtbh_api_payload, - ) - - # Verify response is successful - assert response.status_code == 201 - data = json.loads(response.data) - assert data["rule"] is not None - rule_id = data["rule"]["id"] - - # Now verify the rule was created but marked as whitelisted - with app.app_context(): - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() - assert rtbh_rule is not None - assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state - - # Verify a cache entry was created - cache_entry = RuleWhitelistCache.query.filter_by( - rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id - ).first() - - assert cache_entry is not None - assert cache_entry.rorigin == RuleOrigin.USER.value - - -def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): - """ - Test creating an RTBH rule through API that is a supernet of an existing whitelist. - - The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules - should be created for the non-whitelisted parts. - """ - # Configure app with allowed communities - app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) - - # First create a whitelist for a subnet - whitelist_data["ip"] = "192.168.1.128" - whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24 - - # Create the whitelist directly using the service - with app.app_context(): - org = db.session.query(Organization).first() - whitelist_model, _ = whitelist_service.create_or_update_whitelist( - form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name - ) - - # Verify whitelist was created - assert whitelist_model.id is not None - - # Now create the RTBH rule via API that covers both the whitelist and additional space - headers = {"x-access-token": jwt_token} - response = client.post( - "/api/v3/rules/rtbh", - headers=headers, - json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist - ) - - # Verify response is successful - assert response.status_code == 201 - data = json.loads(response.data) - assert data["rule"] is not None - rule_id = data["rule"]["id"] - - # Now verify the rule was created and marked as whitelisted - with app.app_context(): - # Check the original rule - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() - assert rtbh_rule is not None - assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state - - # Check if a new subnet rule was created for the non-whitelisted part - subnet_rule = ( - db.session.query(RTBH) - .filter( - RTBH.ipv4 == "192.168.1.0", - RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist - ) - .first() - ) - - assert subnet_rule is not None - assert subnet_rule.rstate_id == 1 # Active status - - # Verify cache entries - # Main rule should be cached as a USER rule - user_cache = RuleWhitelistCache.query.filter_by( - rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value - ).first() - assert user_cache is not None - - # Subnet rule should be cached as a WHITELIST rule - whitelist_cache = RuleWhitelistCache.query.filter_by( - rid=subnet_rule.id, - rtype=RuleTypes.RTBH.value, - whitelist_id=whitelist_model.id, - rorigin=RuleOrigin.WHITELIST.value, - ).first() - assert whitelist_cache is not None - - -def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): - """ - Test creating an RTBH rule through API that is contained within an existing whitelist. - - The rule should be created but immediately marked as whitelisted (rstate_id=4). - """ - # Configure app with allowed communities - app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) - - # First create a whitelist for a supernet - whitelist_data["ip"] = "192.168.0.0" - whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule - - # Create the whitelist directly using the service - with app.app_context(): - all_rtbh_rules_before = db.session.query(RTBH).count() - - org = db.session.query(Organization).first() - whitelist_model, _ = whitelist_service.create_or_update_whitelist( - form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name - ) - - # Verify whitelist was created - assert whitelist_model.id is not None - - # Now create the RTBH rule via API that is inside the whitelist - headers = {"x-access-token": jwt_token} - response = client.post( - "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist - ) - - # Verify response is successful - assert response.status_code == 201 - data = json.loads(response.data) - assert data["rule"] is not None - rule_id = data["rule"]["id"] - - # Now verify the rule was created but marked as whitelisted - with app.app_context(): - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() - assert rtbh_rule is not None - assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state - - # Verify a cache entry was created - cache_entry = RuleWhitelistCache.query.filter_by( - rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id - ).first() - - assert cache_entry is not None - assert cache_entry.rorigin == RuleOrigin.USER.value - - # Verify no additional rules were created - all_rtbh_rules = db.session.query(RTBH).count() - assert all_rtbh_rules - all_rtbh_rules_before == 1 - - -def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): - """ - Test creating an RTBH rule through API that has no relation to any existing whitelist. - - The rule should be created normally with active state (rstate_id=1). - """ - # Configure app with allowed communities - app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) - - # First create a whitelist for a completely different network - whitelist_data["ip"] = "10.0.0.0" - whitelist_data["mask"] = 8 - - # Create the whitelist directly using the service - with app.app_context(): - org = db.session.query(Organization).first() - whitelist_model, _ = whitelist_service.create_or_update_whitelist( - form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name - ) - - # Verify whitelist was created - assert whitelist_model.id is not None - - # Now create the RTBH rule via API for a network not covered by any of the whitelists - rtbh_api_payload["ipv4"] = "147.230.17.185" - rtbh_api_payload["ipv4_mask"] = 32 - - headers = {"x-access-token": jwt_token} - response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24 - - # Verify response is successful - assert response.status_code == 201 - data = json.loads(response.data) - assert data["rule"] is not None - rule_id = data["rule"]["id"] - - # Now verify the rule was created with active state - with app.app_context(): - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() - assert rtbh_rule is not None - assert rtbh_rule.rstate_id == 1 # 1 = active state - - # Verify no cache entry was created - cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() - - assert cache_entry is None diff --git a/flowapp/tests/test_flowapp.py b/flowapp/tests/test_flowapp.py deleted file mode 100644 index f71e4ae..0000000 --- a/flowapp/tests/test_flowapp.py +++ /dev/null @@ -1,14 +0,0 @@ -def test_dashboard_not_auth(client): - - response = client.get("/dashboard/ipv4/active/?sort=expires&order=desc") - - # Expecting a 302 redirect to login - assert response.status_code == 302 - - -def test_dashboard(auth_client): - - response = auth_client.get("/dashboard/ipv4/active/?sort=expires&order=desc") - - # Check that the request is successful and renders the correct template - assert response.status_code == 200 # Expecting a 200 OK if the user is authenticated diff --git a/flowapp/tests/test_flowspec.py b/flowapp/tests/test_flowspec.py deleted file mode 100644 index 70000b8..0000000 --- a/flowapp/tests/test_flowspec.py +++ /dev/null @@ -1,167 +0,0 @@ -import pytest -from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit - - -def test_translate_number(): - """ - tests for x (integer) to =x - """ - assert "[=10]" == translate_sequence("10") - - -def test_raises(): - """ - tests for translator - """ - with pytest.raises(ValueError): - translate_sequence("ahoj") - - -def test_raises_bad_number(): - """ - tests for translator - """ - with pytest.raises(ValueError): - translate_sequence("75555") - - -def test_translate_range(): - """ - tests for x-y to >=x&<=y - """ - assert "[>=10&<=20]" == translate_sequence("10-20") - - -def test_exact_rule(): - """ - test for >=x&<=y to >=x&<=y - """ - assert "[>=10&<=20]" == translate_sequence(">=10&<=20") - - -def test_greater_than(): - """ - test for >x to >=x&<=65535 - """ - assert "[>=10&<=65535]" == translate_sequence(">10") - - -def test_greater_equal_than(): - """ - test for >=x to >=x&<=65535 - """ - assert "[>=10&<=65535]" == translate_sequence(">=10") - - -def test_lower_than(): - """ - test for =0&<=0 - """ - assert "[>=0&<=10]" == translate_sequence("<10") - - -def test_lower_equal_than(): - """ - test for =0&<=0 - """ - assert "[>=0&<=10]" == translate_sequence("<=10") - - -# new tests - - -def test_multiple_sequences(): - """Test multiple sequences separated by semicolons""" - assert "[=10 >=20&<=30]" == translate_sequence("10;20-30") - assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30") - - -def test_empty_sequence(): - """Test empty sequence and sequences with empty parts""" - assert "[]" == translate_sequence("") - assert "[=10]" == translate_sequence("10;;") - assert "[=10 =20]" == translate_sequence("10;;20") - - -def test_range_edge_cases(): - """Test edge cases for ranges""" - # Same numbers in range - assert "[>=10&<=10]" == translate_sequence("10-10") - - # Invalid range (start > end) - with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): - translate_sequence("20-10") - - # Invalid range in exact rule - with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): - translate_sequence(">=20&<=10") - - -def test_check_limit_validation(): - """Test the check_limit function""" - # Test valid cases - assert check_limit(10, 100) == 10 - assert check_limit(0, 100, 0) == 0 - assert check_limit(100, 100, 0) == 100 - - # Test invalid cases - with pytest.raises(ValueError, match="Invalid value number: .* is too big"): - check_limit(101, 100) - - with pytest.raises(ValueError, match="Invalid value number: .* is too small"): - check_limit(-1, 100, 0) - - -def test_invalid_inputs(): - """Test various invalid inputs""" - invalid_inputs = [ - "abc", # Non-numeric - "10.5", # Decimal - "10,20", # Wrong separator - "10&20", # Wrong format - ">", # Incomplete - ">=", # Incomplete - "<", # Incomplete - "<=", # Incomplete - ">>10", # Invalid operator - "<<10", # Invalid operator - ">=10&", # Incomplete range - "&<=10", # Incomplete range - ] - - for invalid_input in invalid_inputs: - with pytest.raises(ValueError): - translate_sequence(invalid_input) - - -class MockRule: - def __init__(self, action_id): - self.action_id = action_id - - -def test_filter_rules_action(): - """Test the filter_rules_action function""" - # Create test rules - rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)] - - # Test with empty allowed actions - editable, viewonly = filter_rules_action([], rules) - assert len(editable) == 0 - assert len(viewonly) == 4 - - # Test with some allowed actions - editable, viewonly = filter_rules_action([1, 3], rules) - assert len(editable) == 2 - assert len(viewonly) == 2 - assert all(rule.action_id in [1, 3] for rule in editable) - assert all(rule.action_id in [2, 4] for rule in viewonly) - - # Test with all actions allowed - editable, viewonly = filter_rules_action([1, 2, 3, 4], rules) - assert len(editable) == 4 - assert len(viewonly) == 0 - - # Test with empty rules list - editable, viewonly = filter_rules_action([1, 2], []) - assert len(editable) == 0 - assert len(viewonly) == 0 diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py deleted file mode 100644 index e9620d8..0000000 --- a/flowapp/tests/test_forms.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -import flowapp.forms - - -@pytest.fixture() -def ip_form(app, field_class): - with app.test_request_context(): # Push the request context - form = flowapp.forms.IPForm() - form.source = field_class() - form.dest = field_class() - form.source_mask = field_class() - form.dest_mask = field_class() - return form - - -def test_ip_form_created(ip_form): - assert ip_form.source.data is None - assert ip_form.source.errors == [] - - -@pytest.mark.parametrize( - "address, mask, expected", - [ - ("147.230.23.25", "24", False), - ("147.230.23.0", "24", True), - ("0.0.0.0", "0", True), - ("2001:718:1C01:1111::1111", "64", False), - ("2001:718:1C01:1111::", "64", True), - ], -) -def test_ip_form_validate_source_address(ip_form, address, mask, expected): - ip_form.source.data = address - ip_form.source_mask.data = mask - assert ip_form.validate_source_address() == expected - - -@pytest.mark.parametrize( - "address, mask, expected", - [ - ("147.230.23.25", "24", False), - ("147.230.23.0", "24", True), - ("0.0.0.0", "0", True), - ("2001:718:1C01:1111::1111", "64", False), - ("2001:718:1C01:1111::", "64", True), - ], -) -def test_ip_form_validate_dest_address(ip_form, address, mask, expected): - ip_form.dest.data = address - ip_form.dest_mask.data = mask - assert ip_form.validate_dest_address() == expected - - -@pytest.mark.parametrize( - "address, mask, ranges, expected", - [ - ("147.230.23.0", "24", ["147.230.0.0/16", "2001:718:1c01::/48"], True), - ("0.0.0.0", "0", ["147.230.0.0/16", "2001:718:1c01::/48"], False), - ("195.113.0.0", "16", ["195.113.0.0/18", "195.113.64.0/21"], False), - ], -) -def test_ip_form_validate_address_mask(ip_form, address, mask, ranges, expected): - ip_form.net_ranges = ranges - ip_form.source.data = address - ip_form.source_mask.data = mask - assert ip_form.validate_address_ranges() == expected diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py deleted file mode 100644 index db0a569..0000000 --- a/flowapp/tests/test_forms_cl.py +++ /dev/null @@ -1,608 +0,0 @@ -import pytest -from datetime import datetime, timedelta -from werkzeug.datastructures import MultiDict -from flowapp.forms import ( - UserForm, - BulkUserForm, - ApiKeyForm, - MachineApiKeyForm, - OrganizationForm, - ActionForm, - ASPathForm, - CommunityForm, - RTBHForm, - IPv4Form, - IPv6Form, - WhitelistForm, -) - - -def create_form_data(data): - """Helper function to create proper form data format""" - processed_data = {} - for key, value in data.items(): - if isinstance(value, list): - processed_data[key] = [str(v) for v in value] - else: - processed_data[key] = value - return MultiDict(processed_data) - - -@pytest.fixture -def valid_datetime(): - return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") - - -@pytest.fixture -def sample_network_ranges(): - return ["192.168.0.0/16", "2001:db8::/32"] - - -class TestUserForm: - @pytest.fixture - def mock_choices(self): - return { - "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")], - "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")], - } - - @pytest.fixture - def valid_user_data(self): - return { - "uuid": "test@example.com", - "email": "user@example.com", - "name": "Test User", - "phone": "123456789", - "role_ids": ["2"], - "org_ids": ["1"], - } - - def test_valid_user_form(self, app, valid_user_data, mock_choices): - with app.test_request_context(): - form_data = create_form_data(valid_user_data) - form = UserForm(formdata=form_data) - form.role_ids.choices = mock_choices["role_ids"] - form.org_ids.choices = mock_choices["org_ids"] - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_invalid_email(self, app, mock_choices): - with app.test_request_context(): - form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]}) - form = UserForm(formdata=form_data) - form.role_ids.choices = mock_choices["role_ids"] - form.org_ids.choices = mock_choices["org_ids"] - - assert not form.validate() - assert "Please provide valid email" in form.uuid.errors - - -class TestRTBHForm: - @pytest.fixture - def mock_community_choices(self): - return [(1, "Community1"), (2, "Community2")] - - @pytest.fixture - def valid_rtbh_data(self, valid_datetime): - return { - "ipv4": "192.168.1.0", - "ipv4_mask": 24, - "community": "1", - "expires": valid_datetime, - "comment": "Test RTBH rule", - } - - def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices): - with app.test_request_context(): - form_data = create_form_data(valid_rtbh_data) - form = RTBHForm(formdata=form_data) - form.net_ranges = sample_network_ranges - form.community.choices = mock_community_choices - assert form.validate() - - -class TestIPv4Form: - @pytest.fixture - def mock_action_choices(self): - return [(1, "Accept"), (2, "Drop"), (3, "Reject")] - - @pytest.fixture - def valid_ipv4_data(self, valid_datetime): - return { - "source": "192.168.1.0", - "source_mask": "24", - "protocol": "tcp", - "action": "1", - "expires": valid_datetime, - } - - def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices): - with app.test_request_context(): - form_data = create_form_data(valid_ipv4_data) - form = IPv4Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): - with app.test_request_context(): - form_data = create_form_data( - { - "source": "192.168.1.0", - "source_mask": "24", - "protocol": "udp", - "flags": ["syn"], - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv4Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - assert not form.validate() - - -class TestIPv6Form: - @pytest.fixture - def mock_action_choices(self): - return [(1, "Accept"), (2, "Drop"), (3, "Reject")] - - @pytest.fixture - def valid_ipv6_data(self, valid_datetime): - return { - "source": "2001:db8::", # Network aligned address within allowed range - "source_mask": "32", # Matching the organization's prefix length - "next_header": "tcp", - "action": "1", - "expires": valid_datetime, - } - - def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices): - with app.test_request_context(): - form_data = create_form_data(valid_ipv6_data) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): - with app.test_request_context(): - form_data = create_form_data( - { - "source": "2001:db8::", - "source_mask": "32", - "next_header": "udp", - "flags": ["syn"], - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - assert not form.validate() - - def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices): - """Test validation fails when address is outside allowed ranges""" - with app.test_request_context(): - form_data = create_form_data( - { - "source": "2001:db9::", # Different prefix - "source_mask": "32", - "next_header": "tcp", - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - assert not form.validate() - assert any("must be in organization range" in error for error in form.source.errors) - - def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices): - """Test validation with destination address instead of source""" - with app.test_request_context(): - form_data = create_form_data( - { - "dest": "2001:db8::", - "dest_mask": "32", - "next_header": "tcp", - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices): - """Test validation with both source and destination addresses""" - with app.test_request_context(): - form_data = create_form_data( - { - "source": "2001:db8::", - "source_mask": "32", - "dest": "2001:db8:1::", - "dest_mask": "48", - "next_header": "tcp", - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): - """Test validation with TCP flags (should be valid with TCP)""" - with app.test_request_context(): - form_data = create_form_data( - { - "source": "2001:db8::", - "source_mask": "32", - "next_header": "tcp", - "flags": ["SYN", "ACK"], - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"]) - def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data): - """Test validation with various valid port formats""" - with app.test_request_context(): - form_data = create_form_data( - { - "source": "2001:db8::", - "source_mask": "32", - "next_header": "tcp", - "source_port": port_data, - "dest_port": port_data, - "action": "1", - "expires": valid_datetime, - } - ) - form = IPv6Form(formdata=form_data) - form.net_ranges = sample_network_ranges - form.action.choices = mock_action_choices - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - -class TestBulkUserForm: - @pytest.fixture - def valid_csv_data(self): - return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1" - - def test_valid_bulk_import(self, app, valid_csv_data): - with app.test_request_context(): - form_data = create_form_data({"users": valid_csv_data}) - form = BulkUserForm(formdata=form_data) - form.roles = {2} # Mock available roles - form.organizations = {1} # Mock available organizations - form.uuids = set() # Mock existing UUIDs (empty set) - assert form.validate() - - def test_duplicate_uuid(self, app, valid_csv_data): - with app.test_request_context(): - form_data = create_form_data({"users": valid_csv_data}) - form = BulkUserForm(formdata=form_data) - form.roles = {2} - form.organizations = {1} - form.uuids = {"user1@example.com"} # UUID already exists - assert not form.validate() - assert any("already exists" in error for error in form.users.errors) - - def test_invalid_role(self, app): - csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID - with app.test_request_context(): - form_data = create_form_data({"users": csv_data}) - form = BulkUserForm(formdata=form_data) - form.roles = {2} - form.organizations = {1} - form.uuids = set() - assert not form.validate() - assert any("does not exist" in error for error in form.users.errors) - - -class TestApiKeyForm: - @pytest.fixture - def valid_api_key_data(self, valid_datetime): - return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"} - - def test_valid_api_key(self, app, valid_api_key_data): - with app.test_request_context(): - form_data = create_form_data(valid_api_key_data) - form = ApiKeyForm(formdata=form_data) - assert form.validate() - - def test_invalid_ip(self, app, valid_datetime): - with app.test_request_context(): - form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime}) - form = ApiKeyForm(formdata=form_data) - assert not form.validate() - assert "provide valid IP address" in form.machine.errors - - def test_unlimited_expiration(self, app): - with app.test_request_context(): - form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited - form = ApiKeyForm(formdata=form_data) - assert form.validate() - - -class TestMachineApiKeyForm: - # Similar to ApiKeyForm, but might have different validation rules - @pytest.fixture - def valid_machine_key_data(self, valid_datetime): - return { - "machine": "192.168.1.1", - "comment": "Test machine API key", - "expires": valid_datetime, - "readonly": "true", - "user": 1, - } - - def test_valid_machine_key(self, app, valid_machine_key_data): - with app.test_request_context(): - form_data = create_form_data(valid_machine_key_data) - form = MachineApiKeyForm(formdata=form_data) - form.user.choices = [(1, "g.name"), (2, "test")] - assert form.validate() - - -class TestOrganizationForm: - @pytest.fixture - def valid_org_data(self): - return { - "name": "Test Organization", - "limit_flowspec4": "100", - "limit_flowspec6": "100", - "limit_rtbh": "50", - "arange": "192.168.0.0/16\n2001:db8::/32", - } - - def test_valid_organization(self, app, valid_org_data): - with app.test_request_context(): - form_data = create_form_data(valid_org_data) - form = OrganizationForm(formdata=form_data) - assert form.validate() - - def test_invalid_ranges(self, app): - with app.test_request_context(): - form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"}) - form = OrganizationForm(formdata=form_data) - assert not form.validate() - - def test_invalid_limits(self, app): - with app.test_request_context(): - form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value - form = OrganizationForm(formdata=form_data) - assert not form.validate() - - -class TestActionForm: - @pytest.fixture - def valid_action_data(self): - return { - "name": "test_action", - "command": "announce route", - "description": "Test action description", - "role_id": "2", - } - - def test_valid_action(self, app, valid_action_data): - with app.test_request_context(): - form_data = create_form_data(valid_action_data) - form = ActionForm(formdata=form_data) - assert form.validate() - - def test_invalid_role(self, app, valid_action_data): - with app.test_request_context(): - invalid_data = dict(valid_action_data) - invalid_data["role_id"] = "4" # Invalid role - form_data = create_form_data(invalid_data) - form = ActionForm(formdata=form_data) - assert not form.validate() - - -class TestASPathForm: - @pytest.fixture - def valid_aspath_data(self): - return {"prefix": "AS64512", "as_path": "64512 64513 64514"} - - def test_valid_aspath(self, app, valid_aspath_data): - with app.test_request_context(): - form_data = create_form_data(valid_aspath_data) - form = ASPathForm(formdata=form_data) - assert form.validate() - - def test_empty_fields(self, app): - with app.test_request_context(): - form_data = create_form_data({}) - form = ASPathForm(formdata=form_data) - assert not form.validate() - - -class TestCommunityForm: - @pytest.fixture - def valid_community_data(self): - return { - "name": "test_community", - "comm": "64512:100", - "description": "Test community", - "role_id": "2", - "as_path": "true", - } - - def test_valid_community(self, app, valid_community_data): - with app.test_request_context(): - form_data = create_form_data(valid_community_data) - form = CommunityForm(formdata=form_data) - assert form.validate() - - def test_missing_all_community_types(self, app): - with app.test_request_context(): - form_data = create_form_data({"name": "test_community", "role_id": "2"}) - form = CommunityForm(formdata=form_data) - assert not form.validate() - assert any("could not be empty" in error for error in form.comm.errors) - - def test_valid_with_extended_community(self, app): - with app.test_request_context(): - form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"}) - form = CommunityForm(formdata=form_data) - assert form.validate() - - def test_valid_with_large_community(self, app): - with app.test_request_context(): - form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"}) - form = CommunityForm(formdata=form_data) - assert form.validate() - - -class TestWhitelistForm: - @pytest.fixture - def valid_ipv4_data(self, valid_datetime): - return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime} - - @pytest.fixture - def valid_ipv6_data(self, valid_datetime): - return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime} - - def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data(valid_ipv4_data) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data(valid_ipv6_data) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - - if not form.validate(): - print("Validation errors:", form.errors) - - assert form.validate() - - def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime}) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - assert not form.validate() - assert any("valid IP address" in error for error in form.ip.errors) - - def test_ip_outside_range(self, app, valid_datetime): - with app.test_request_context(): - form_data = create_form_data( - {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range - ) - form = WhitelistForm(formdata=form_data) - form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16 - assert not form.validate() - assert any("must be in organization range" in error for error in form.ip.errors) - - def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data( - {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4 - ) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - assert not form.validate() - - def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data( - {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6 - ) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - assert not form.validate() - - def test_missing_required_fields(self, app, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data({}) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - assert not form.validate() - assert form.ip.errors # Should have error for missing IP - assert form.mask.errors # Should have error for missing mask - assert form.expires.errors # Should have error for missing expiration - - def test_comment_length(self, app, valid_datetime, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data( - { - "ip": "192.168.0.0", - "mask": "24", - "expires": valid_datetime, - "comment": "x" * 256, # Comment longer than 255 chars - } - ) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - assert not form.validate() - - @pytest.mark.parametrize( - "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date - ) - def test_invalid_expiration(self, app, expires, sample_network_ranges): - with app.test_request_context(): - form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires}) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - if not form.validate(): - print("Validation errors:", form.errors) - - assert not form.validate() - - def test_network_alignment(self, app, valid_datetime, sample_network_ranges): - """Test that IP addresses must be properly network-aligned""" - with app.test_request_context(): - form_data = create_form_data( - {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary - ) - form = WhitelistForm(formdata=form_data) - form.net_ranges = sample_network_ranges - assert not form.validate() diff --git a/flowapp/tests/test_login.py b/flowapp/tests/test_login.py deleted file mode 100644 index e69de29..0000000 diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py deleted file mode 100644 index b352a5f..0000000 --- a/flowapp/tests/test_models.py +++ /dev/null @@ -1,498 +0,0 @@ -from datetime import datetime, timedelta -from flowapp.models import ( - User, - Organization, - Role, - ApiKey, - MachineApiKey, - Rstate, - Community, - Action, - Flowspec6, - Whitelist, -) - -import flowapp.models as models - - -def test_insert_ipv4(db): - """ - test the record can be inserted - :param db: conftest fixture - :return: - """ - model = models.Flowspec4( - source="192.168.1.1", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - protocol="tcp", - flags="", - packet_len="", - fragment="", - action_id=1, - expires=datetime.now(), - user_id=1, - org_id=1, - rstate_id=1, - ) - db.session.add(model) - db.session.commit() - - -def test_get_ipv4_model_if_exists(db): - """ - test if the function find existing model correctly - :param db: conftest fixture - :return: - """ - model = models.Flowspec4( - source="192.168.1.1", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - protocol="tcp", - flags="", - fragment="", - packet_len="", - action_id=1, - expires=datetime.now(), - user_id=1, - org_id=1, - rstate_id=1, - ) - db.session.add(model) - db.session.commit() - - form_data = { - "source": "192.168.1.1", - "source_mask": "32", - "source_port": "80", - "dest": "", - "dest_mask": "", - "dest_port": "", - "protocol": "tcp", - "flags": "", - "packet_len": "", - "action": 1, - } - - result = models.get_ipv4_model_if_exists(form_data, 1) - assert result - assert result == model - - -def test_get_ipv6_model_if_exists(db): - """ - test if the function find existing model correctly - :param db: conftest fixture - :return: - """ - model = models.Flowspec6( - source="2001:0db8:85a3:0000:0000:8a2e:0370:7334", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - next_header="tcp", - flags="", - packet_len="", - action_id=1, - expires=datetime.now(), - user_id=1, - org_id=1, - rstate_id=1, - ) - db.session.add(model) - db.session.commit() - - form_data = { - "source": "2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "source_mask": "32", - "source_port": "80", - "dest": "", - "dest_mask": "", - "dest_port": "", - "next_header": "tcp", - "flags": "", - "packet_len": "", - "action": 1, - } - - result = models.get_ipv6_model_if_exists(form_data, 1) - assert result - assert result == model - - -def test_ipv4_eq(db): - """ - test that creating with valid data returns 201 - """ - model_A = models.Flowspec4( - source="192.168.1.1", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - protocol="tcp", - flags="", - fragment="", - packet_len="", - action_id=1, - expires="123", - user_id=1, - org_id=1, - rstate_id=1, - ) - - model_B = models.Flowspec4( - source="192.168.1.1", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - protocol="tcp", - flags="", - fragment="", - packet_len="", - action_id=1, - expires="123456", - user_id=1, - org_id=1, - rstate_id=1, - ) - - assert model_A == model_B - - -def test_ipv4_ne(db): - """ - test that creating with valid data returns 201 - """ - model_A = models.Flowspec4( - source="192.168.2.2", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - protocol="tcp", - flags="", - fragment="", - packet_len="", - action_id=1, - expires="123", - user_id=1, - org_id=1, - rstate_id=1, - ) - - model_B = models.Flowspec4( - source="192.168.1.1", - source_mask="32", - source_port="80", - destination="", - destination_mask="", - destination_port="", - protocol="tcp", - flags="", - fragment="", - packet_len="", - action_id=1, - expires="123456", - user_id=1, - org_id=1, - rstate_id=1, - ) - - assert model_A != model_B - - -def test_rtbj_eq(db): - """ - test that two equal rtbh rules are equal - """ - model_A = models.RTBH( - ipv4="192.168.1.1", - ipv4_mask="32", - ipv6="", - ipv6_mask="", - community_id=1, - expires="123", - user_id=1, - org_id=1, - rstate_id=1, - ) - - model_B = models.RTBH( - ipv4="192.168.1.1", - ipv4_mask="32", - ipv6="", - ipv6_mask="", - community_id=1, - expires="123456", - user_id=1, - org_id=1, - rstate_id=1, - ) - - assert model_A == model_B - - -def test_user_creation(db): - """Test basic user creation and relationships""" - # Create test role and org first - role = Role(name="test_role", description="Test Role") - org = Organization(name="test_org", arange="10.0.0.0/8") - db.session.add_all([role, org]) - db.session.commit() - - # Create user with relationships - user = User( - uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment" - ) - user.role.append(role) - user.organization.append(org) - db.session.add(user) - db.session.commit() - - # Verify user and relationships - assert user.uuid == "test-user-123" - assert user.name == "Test User" - assert len(user.role.all()) == 1 - assert len(user.organization.all()) == 1 - assert user.role.first().name == "test_role" - assert user.organization.first().name == "test_org" - - -def test_api_key_expiration(db): - """Test ApiKey expiration logic""" - user = User(uuid="test-user") - org = Organization(name="test-org", arange="10.0.0.0/8") - db.session.add_all([user, org]) - db.session.commit() - - # Create non-expiring key - non_expiring_key = ApiKey( - machine="test-machine-1", - key="key1", - readonly=True, - expires=None, - comment="Non-expiring key", - user_id=user.id, - org_id=org.id, - ) - - # Create expired key - expired_key = ApiKey( - machine="test-machine-2", - key="key2", - readonly=True, - expires=datetime.now() - timedelta(days=1), - comment="Expired key", - user_id=user.id, - org_id=org.id, - ) - - # Create future key - future_key = ApiKey( - machine="test-machine-3", - key="key3", - readonly=True, - expires=datetime.now() + timedelta(days=1), - comment="Future key", - user_id=user.id, - org_id=org.id, - ) - - db.session.add_all([non_expiring_key, expired_key, future_key]) - db.session.commit() - - assert not non_expiring_key.is_expired() - assert expired_key.is_expired() - assert not future_key.is_expired() - - -def test_machine_api_key_expiration(db): - """Test MachineApiKey expiration logic""" - user = User(uuid="test-user-machine") - org = Organization(name="test-org-machine", arange="10.0.0.0/8") - db.session.add_all([user, org]) - db.session.commit() - - # Create non-expiring key - non_expiring_key = MachineApiKey( - machine="test-machine-1", - key="key1", - readonly=True, - expires=None, - comment="Non-expiring key", - user_id=user.id, - org_id=org.id, - ) - - # Create expired key - expired_key = MachineApiKey( - machine="test-machine-2", - key="key2", - readonly=True, - expires=datetime.now() - timedelta(days=1), - comment="Expired key", - user_id=user.id, - org_id=org.id, - ) - - db.session.add_all([non_expiring_key, expired_key]) - db.session.commit() - - assert not non_expiring_key.is_expired() - assert expired_key.is_expired() - - -def test_organization_get_users(db): - """Test Organization's get_users method""" - org = Organization( - name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100 - ) - uuid1 = "test-org-get-user" - uuid2 = "test-org-get-user2" - user1 = User(uuid=uuid1) - user2 = User(uuid=uuid2) - - db.session.add(org) - db.session.add_all([user1, user2]) - db.session.commit() - - org.user.append(user1) - org.user.append(user2) - db.session.commit() - - users = org.get_users() - assert len(users) == 2 - assert all(isinstance(user, User) for user in users) - assert {user.uuid for user in users} == {uuid1, uuid2} - - -def test_flowspec6_equality(db): - """Test Flowspec6 equality comparison""" - model_a = Flowspec6( - source="2001:db8::1", - source_mask=128, - source_port="80", - destination="2001:db8::2", - destination_mask=128, - destination_port="443", - next_header="tcp", - flags="", - packet_len="", - expires=datetime.now(), - user_id=1, - org_id=1, - action_id=1, - ) - - # Same network parameters but different timestamps - model_b = Flowspec6( - source="2001:db8::1", - source_mask=128, - source_port="80", - destination="2001:db8::2", - destination_mask=128, - destination_port="443", - next_header="tcp", - flags="", - packet_len="", - expires=datetime.now() + timedelta(days=1), - user_id=1, - org_id=1, - action_id=1, - ) - - # Different network parameters - model_c = Flowspec6( - source="2001:db8::3", - source_mask=128, - source_port="80", - destination="2001:db8::4", - destination_mask=128, - destination_port="443", - next_header="tcp", - flags="", - packet_len="", - expires=datetime.now(), - user_id=1, - org_id=1, - action_id=1, - ) - - assert model_a == model_b # Should be equal despite different timestamps - assert model_a != model_c # Should be different due to different network parameters - - -def test_whitelist_equality(db): - """Test Whitelist equality comparison""" - model_a = Whitelist( - ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" - ) - - # Same IP/mask but different timestamps - model_b = Whitelist( - ip="192.168.1.1", - mask=32, - expires=datetime.now() + timedelta(days=1), - user_id=1, - org_id=1, - comment="Different comment", - ) - - # Different IP - model_c = Whitelist( - ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" - ) - - assert model_a == model_b # Should be equal despite different timestamps - assert model_a != model_c # Should be different due to different IP - - -def test_whitelist_to_dict(db): - """Test Whitelist to_dict serialization""" - whitelist = Whitelist( - ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" - ) - - # Create required related objects - user = User(uuid="test-user-whitelist") - rstate = Rstate(description="active") - db.session.add_all([user, rstate]) - db.session.commit() - - db.session.add(whitelist) - db.session.commit() - - whitelist.user = user - whitelist.rstate_id = rstate.id - db.session.add(whitelist) - db.session.commit() - - # Test timestamp format - dict_timestamp = whitelist.to_dict(prefered_format="timestamp") - assert isinstance(dict_timestamp["expires"], int) - assert isinstance(dict_timestamp["created"], int) - - # Test yearfirst format - dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst") - assert isinstance(dict_yearfirst["expires"], str) - assert isinstance(dict_yearfirst["created"], str) - - # Check basic fields - assert dict_timestamp["ip"] == "192.168.1.1" - assert dict_timestamp["mask"] == 32 - assert dict_timestamp["comment"] == "Test whitelist" - assert dict_timestamp["user"] == "test-user-whitelist" diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py deleted file mode 100644 index 490d2c2..0000000 --- a/flowapp/tests/test_rule_service.py +++ /dev/null @@ -1,628 +0,0 @@ -""" -Tests for rule_service.py module. - -This test suite verifies the functionality of the rule service after refactoring, -which manages creation, updating, and processing of flow rules. -""" - -import pytest -from datetime import datetime, timedelta -from unittest.mock import patch, MagicMock - -from flowapp.constants import RuleOrigin, ANNOUNCE -from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist -from flowapp.services import rule_service -from flowapp.services.whitelist_common import Relation - - -@pytest.fixture -def ipv4_form_data(): - """Sample valid IPv4 form data""" - return { - "source": "192.168.1.0", - "source_mask": 24, - "source_port": "80", - "dest": "", - "dest_mask": None, - "dest_port": "", - "protocol": "tcp", - "flags": ["SYN"], - "packet_len": "", - "fragment": ["dont-fragment"], - "comment": "Test IPv4 rule", - "expires": datetime.now() + timedelta(hours=1), - "action": 1, - } - - -@pytest.fixture -def ipv6_form_data(): - """Sample valid IPv6 form data""" - return { - "source": "2001:db8::", - "source_mask": 32, - "source_port": "80", - "dest": "", - "dest_mask": None, - "dest_port": "", - "next_header": "tcp", - "flags": ["SYN"], - "packet_len": "", - "comment": "Test IPv6 rule", - "expires": datetime.now() + timedelta(hours=1), - "action": 1, - } - - -@pytest.fixture -def rtbh_form_data(): - """Sample valid RTBH form data""" - return { - "ipv4": "192.168.1.0", - "ipv4_mask": 24, - "ipv6": "", - "ipv6_mask": None, - "community": 1, - "comment": "Test RTBH rule", - "expires": datetime.now() + timedelta(hours=1), - } - - -@pytest.fixture -def whitelist_fixture(): - """Create a whitelist fixture""" - whitelist = MagicMock(spec=Whitelist) - whitelist.id = 1 - return whitelist - - -class TestCreateOrUpdateIPv4Rule: - @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") - @patch("flowapp.services.rule_service.messages") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_create_new_ipv4_rule( - self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data - ): - """Test creating a new IPv4 rule""" - # Mock the get_ipv4_model_if_exists to return False (not found) - mock_get_model.return_value = False - - # Mock the announce route behavior - mock_messages.create_ipv4.return_value = "mock command" - - # Call the service function - with app.app_context(): - model, message = rule_service.create_or_update_ipv4_rule( - form_data=ipv4_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was created with correct attributes - assert model is not None - assert model.source == ipv4_form_data["source"] - assert model.source_mask == ipv4_form_data["source_mask"] - assert model.protocol == ipv4_form_data["protocol"] - assert model.flags == "SYN" # form data flags are joined with ";" - assert model.action_id == ipv4_form_data["action"] - assert model.rstate_id == 1 # Active state - assert model.user_id == 1 - assert model.org_id == 1 - - # Verify message is still a string for IPv4 rules - assert message == "IPv4 Rule saved" - - # Verify route was announced - mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) - mock_announce.assert_called_once() - mock_log.assert_called_once() - - @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") - @patch("flowapp.services.rule_service.messages") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_update_existing_ipv4_rule( - self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data - ): - """Test updating an existing IPv4 rule""" - # Create an existing model to return - existing_model = Flowspec4( - source=ipv4_form_data["source"], - source_mask=ipv4_form_data["source_mask"], - source_port=ipv4_form_data["source_port"], - destination=ipv4_form_data["dest"] or "", - destination_mask=ipv4_form_data["dest_mask"], - destination_port=ipv4_form_data["dest_port"] or "", - protocol=ipv4_form_data["protocol"], - flags=";".join(ipv4_form_data["flags"]), - packet_len=ipv4_form_data["packet_len"] or "", - fragment=";".join(ipv4_form_data["fragment"]), - expires=datetime.now(), - user_id=1, - org_id=1, - action_id=1, - rstate_id=1, - ) - mock_get_model.return_value = existing_model - mock_messages.create_ipv4.return_value = "mock command" - - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - ipv4_form_data["expires"] = new_expires - - # Call the service function - with app.app_context(): - db.session.add(existing_model) - db.session.commit() - - model, message = rule_service.create_or_update_ipv4_rule( - form_data=ipv4_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was updated - assert model == existing_model - assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - - # Verify message is still a string for IPv4 rules - assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." - - # Verify route was announced - mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) - mock_announce.assert_called_once() - mock_log.assert_called_once() - - -class TestCreateOrUpdateIPv6Rule: - @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") - @patch("flowapp.services.rule_service.messages") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_create_new_ipv6_rule( - self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data - ): - """Test creating a new IPv6 rule""" - # Mock get_ipv6_model_if_exists to return False - mock_get_model.return_value = False - - # Mock the announce route behavior - mock_messages.create_ipv6.return_value = "mock command" - - # Call the service function - with app.app_context(): - model, message = rule_service.create_or_update_ipv6_rule( - form_data=ipv6_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was created with correct attributes - assert model is not None - assert model.source == ipv6_form_data["source"] - assert model.source_mask == ipv6_form_data["source_mask"] - assert model.next_header == ipv6_form_data["next_header"] - assert model.flags == "SYN" - assert model.action_id == ipv6_form_data["action"] - assert model.rstate_id == 1 # Active state - assert model.user_id == 1 - assert model.org_id == 1 - - # Verify message is still a string for IPv6 rules - assert message == "IPv6 Rule saved" - - # Verify route was announced - mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) - mock_announce.assert_called_once() - mock_log.assert_called_once() - - -class TestCreateOrUpdateRTBHRule: - @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") - @patch("flowapp.services.rule_service.db.session.query") - @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") - @patch("flowapp.services.rule_service.announce_rtbh_route") - @patch("flowapp.services.rule_service.log_route") - def test_create_new_rtbh_rule( - self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data - ): - """Test creating a new RTBH rule""" - # Mock get_rtbh_model_if_exists to return False - mock_get_model.return_value = False - - # Mock the whitelist query - mock_whitelists = [] - mock_query.return_value.filter.return_value.all.return_value = mock_whitelists - - # Mock check_rule_against_whitelists to return empty list (no matches) - mock_check.return_value = [] - - # Mock evaluate function to return the model unchanged - mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model - - # Call the service function - with app.app_context(): - model, flashes = rule_service.create_or_update_rtbh_rule( - form_data=rtbh_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was created with correct attributes - assert model is not None - assert model.ipv4 == rtbh_form_data["ipv4"] - assert model.ipv4_mask == rtbh_form_data["ipv4_mask"] - assert model.ipv6 == rtbh_form_data["ipv6"] - assert model.ipv6_mask == rtbh_form_data["ipv6_mask"] - assert model.community_id == rtbh_form_data["community"] - assert model.rstate_id == 1 # Active state - assert model.user_id == 1 - assert model.org_id == 1 - - # Verify flash messages - now a list instead of a string - assert isinstance(flashes, list) - assert "RTBH Rule saved" in flashes[0] - - # Verify rule was announced - mock_announce.assert_called_once() - mock_log.assert_called_once() - - # Verify evaluate function was called - mock_evaluate.assert_called_once() - - @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") - @patch("flowapp.services.rule_service.db.session.query") - @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") - @patch("flowapp.services.rule_service.announce_rtbh_route") - @patch("flowapp.services.rule_service.log_route") - def test_update_existing_rtbh_rule( - self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data - ): - """Test updating an existing RTBH rule""" - # Create an existing model to return - existing_model = RTBH( - ipv4=rtbh_form_data["ipv4"], - ipv4_mask=rtbh_form_data["ipv4_mask"], - ipv6=rtbh_form_data["ipv6"] or "", - ipv6_mask=rtbh_form_data["ipv6_mask"], - community_id=rtbh_form_data["community"], - expires=datetime.now(), - user_id=1, - org_id=1, - rstate_id=1, - ) - mock_get_model.return_value = existing_model - - # Mock the whitelist query - mock_whitelists = [] - mock_query.return_value.filter.return_value.all.return_value = mock_whitelists - - # Mock check_rule_against_whitelists to return empty list - mock_check.return_value = [] - - # Mock evaluate function to return the model unchanged - mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model - - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - rtbh_form_data["expires"] = new_expires - - # Call the service function - with app.app_context(): - db.session.add(existing_model) - db.session.commit() - - model, flashes = rule_service.create_or_update_rtbh_rule( - form_data=rtbh_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was updated - assert model == existing_model - assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - - # Verify flash messages - assert isinstance(flashes, list) - assert "Existing RTBH Rule found" in flashes[0] - - # Verify route was announced - mock_announce.assert_called_once() - mock_log.assert_called_once() - - @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") - @patch("flowapp.services.rule_service.db.session.query") - @patch("flowapp.services.rule_service.map_whitelists_to_strings") - @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") - @patch("flowapp.services.rule_service.announce_rtbh_route") - @patch("flowapp.services.rule_service.log_route") - def test_rtbh_rule_with_whitelists( - self, - mock_log, - mock_announce, - mock_evaluate, - mock_check, - mock_map, - mock_query, - mock_get_model, - app, - db, - rtbh_form_data, - whitelist_fixture, - ): - """Test creating a RTBH rule that interacts with whitelists""" - # Mock get_rtbh_model_if_exists to return False - mock_get_model.return_value = False - - # Create a mock whitelist - mock_whitelist = whitelist_fixture - mock_whitelists = [mock_whitelist] - - # Setup mock query to return our whitelist - mock_query_result = MagicMock() - mock_query_result.filter.return_value.all.return_value = mock_whitelists - mock_query.return_value = mock_query_result - - # Setup map function to return our whitelist in a dict - whitelist_key = "192.168.1.0/24" - mock_map.return_value = {whitelist_key: mock_whitelist} - - # Setup check function to return a relation - mock_rtbh = MagicMock(spec=RTBH) - mock_rtbh.__str__.return_value = "192.168.1.0/24" - - # Setup check result - rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)] - mock_check.return_value = rule_relation - - # Setup evaluate function to add a flash message - def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results): - flashes.append("Rule is equal to whitelist") - return model - - mock_evaluate.side_effect = evaluate_side_effect - - # Call the service function - with app.app_context(): - model, flashes = rule_service.create_or_update_rtbh_rule( - form_data=rtbh_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify flash messages show both rule creation and whitelist check - assert isinstance(flashes, list) - assert "RTBH Rule saved" in flashes[0] - assert "Rule is equal to whitelist" in flashes[1] - - # Verify interactions - mock_map.assert_called_once() - mock_check.assert_called_once() - mock_evaluate.assert_called_once() - - -class TestEvaluateRtbhAgainstWhitelistsCheckResults: - def test_equal_relation(self, app, whitelist_fixture): - """Test evaluating a rule with an EQUAL relation to a whitelist""" - # Create a model - model = MagicMock(spec=RTBH) - - # Create test data - flashes = [] - user_id = 1 - author = "test@example.com / Test Org" - whitelist_key = "192.168.1.0/24" - wl_cache = {whitelist_key: whitelist_fixture} - results = [(str(model), whitelist_key, Relation.EQUAL)] - - # Call the function with mocked whitelist_rtbh_rule - with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: - mock_whitelist_rule.return_value = model - - with app.app_context(): - result = rule_service.evaluate_rtbh_against_whitelists_check_results( - user_id, model, flashes, author, wl_cache, results - ) - - # Verify the rule was whitelisted - mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) - - # Verify the flash message - assert flashes - - # Verify the correct model was returned - assert result == model - - def test_subnet_relation(self, app, whitelist_fixture): - """Test evaluating a rule with a SUBNET relation to a whitelist""" - # Create a model - model = MagicMock(spec=RTBH) - - # Create test data - flashes = [] - user_id = 1 - author = "test@example.com / Test Org" - whitelist_key = "192.168.1.128/25" - wl_cache = {whitelist_key: whitelist_fixture} - results = [(str(model), whitelist_key, Relation.SUBNET)] - - # Call the function with mocked dependencies - with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch( - "flowapp.services.rule_service.create_rtbh_from_whitelist_parts" - ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( - "flowapp.services.rule_service.db.session.commit" - ) as mock_commit: - - # Mock subtract_network to return some subnets - mock_subtract.return_value = ["192.168.1.0/25"] - - with app.app_context(): - _result = rule_service.evaluate_rtbh_against_whitelists_check_results( - user_id, model, flashes, author, wl_cache, results - ) - - # Verify subnet calculation was performed - mock_subtract.assert_called_once() - - # Verify new rules were created for the subnets - mock_create.assert_called_once() - - # Verify the original rule was cached - mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER) - - # Verify transaction was committed - mock_commit.assert_called_once() - - # Verify the flash messages - assert flashes - - # Verify model was updated to whitelisted state - assert model.rstate_id == 4 - - def test_supernet_relation(self, app, whitelist_fixture): - """Test evaluating a rule with a SUPERNET relation to a whitelist""" - # Create a model - model = MagicMock(spec=RTBH) - - # Create test data - flashes = [] - user_id = 1 - author = "test@example.com / Test Org" - whitelist_key = "192.168.0.0/16" - wl_cache = {whitelist_key: whitelist_fixture} - results = [(str(model), whitelist_key, Relation.SUPERNET)] - - # Call the function with mocked whitelist_rtbh_rule - with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: - mock_whitelist_rule.return_value = model - - with app.app_context(): - result = rule_service.evaluate_rtbh_against_whitelists_check_results( - user_id, model, flashes, author, wl_cache, results - ) - - # Verify the rule was whitelisted - mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) - - # Verify the flash message - assert flashes - - # Verify the correct model was returned - assert result == model - - def test_no_relation(self, app): - """Test evaluating a rule with no relation to any whitelist""" - # Create a model - model = MagicMock(spec=RTBH) - - # Create test data - flashes = [] - user_id = 1 - author = "test@example.com / Test Org" - wl_cache = {} - results = [] - - with app.app_context(): - result = rule_service.evaluate_rtbh_against_whitelists_check_results( - user_id, model, flashes, author, wl_cache, results - ) - - # Verify no changes to the model and no messages - assert result == model - assert not flashes - - -class TestMapWhitelistsToStrings: - def test_map_whitelists_to_strings(self): - """Test mapping whitelist objects to strings""" - # Create mock whitelists - whitelist1 = MagicMock(spec=Whitelist) - whitelist1.__str__.return_value = "192.168.1.0/24" - - whitelist2 = MagicMock(spec=Whitelist) - whitelist2.__str__.return_value = "10.0.0.0/8" - - whitelists = [whitelist1, whitelist2] - - # Call the function - result = rule_service.map_whitelists_to_strings(whitelists) - - # Verify the result - assert len(result) == 2 - assert "192.168.1.0/24" in result - assert "10.0.0.0/8" in result - assert result["192.168.1.0/24"] == whitelist1 - assert result["10.0.0.0/8"] == whitelist2 - - @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") - @patch("flowapp.services.rule_service.messages") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_update_existing_ipv6_rule( - self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data - ): - """Test updating an existing IPv6 rule""" - # Create an existing model to return - existing_model = Flowspec6( - source=ipv6_form_data["source"], - source_mask=ipv6_form_data["source_mask"], - source_port=ipv6_form_data["source_port"] or "", - destination=ipv6_form_data["dest"] or "", - destination_mask=ipv6_form_data["dest_mask"], - destination_port=ipv6_form_data["dest_port"] or "", - next_header=ipv6_form_data["next_header"], - flags=";".join(ipv6_form_data["flags"]), - packet_len=ipv6_form_data["packet_len"] or "", - expires=datetime.now(), - user_id=1, - org_id=1, - action_id=1, - rstate_id=1, - ) - mock_get_model.return_value = existing_model - mock_messages.create_ipv6.return_value = "mock command" - - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - ipv6_form_data["expires"] = new_expires - - # Call the service function - with app.app_context(): - db.session.add(existing_model) - db.session.commit() - - model, message = rule_service.create_or_update_ipv6_rule( - form_data=ipv6_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was updated - assert model == existing_model - assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - - # Verify message is still a string for IPv6 rules - assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." - - # Verify route was announced - mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) - mock_announce.assert_called_once() - mock_log.assert_called_once() diff --git a/flowapp/tests/test_rule_service_reactivate_delete.py b/flowapp/tests/test_rule_service_reactivate_delete.py deleted file mode 100644 index 22a3f9f..0000000 --- a/flowapp/tests/test_rule_service_reactivate_delete.py +++ /dev/null @@ -1,527 +0,0 @@ -"""Tests for rule_service module.""" - -import pytest -from datetime import datetime, timedelta -from unittest.mock import patch - -from flowapp.constants import RuleTypes -from flowapp.models import ( - RTBH, - Flowspec4, - Flowspec6, - Whitelist, -) -from flowapp.output import RouteSources -from flowapp.services import rule_service - - -@pytest.fixture -def test_data(app, db): - """Fixture providing test data for rule service tests.""" - current_time = datetime.now() - - # Create a test Flowspec4 rule - ipv4_rule = Flowspec4( - source="192.168.1.1", - source_mask=32, - source_port="", - destination="192.168.2.1", - destination_mask=32, - destination_port="", - protocol="tcp", - flags="", - packet_len="", - fragment="", - expires=current_time + timedelta(hours=1), - comment="Test IPv4 rule", - action_id=1, # Using action ID 1 from test database - user_id=1, # Using user ID 1 from test database - org_id=1, # Using org ID 1 from test database - rstate_id=1, # Active state - ) - db.session.add(ipv4_rule) - - # Create a test Flowspec6 rule - ipv6_rule = Flowspec6( - source="2001:db8::1", - source_mask=128, - source_port="", - destination="2001:db8:1::1", - destination_mask=128, - destination_port="", - next_header="tcp", - flags="", - packet_len="", - expires=current_time + timedelta(hours=1), - comment="Test IPv6 rule", - action_id=1, # Using action ID 1 from test database - user_id=1, # Using user ID 1 from test database - org_id=1, # Using org ID 1 from test database - rstate_id=1, # Active state - ) - db.session.add(ipv6_rule) - - # Create a test RTBH rule - rtbh_rule = RTBH( - ipv4="192.168.1.100", - ipv4_mask=32, - ipv6=None, - ipv6_mask=None, - community_id=1, # Using community ID 1 from test database - expires=current_time + timedelta(hours=1), - comment="Test RTBH rule", - user_id=1, # Using user ID 1 from test database - org_id=1, # Using org ID 1 from test database - rstate_id=1, # Active state - ) - db.session.add(rtbh_rule) - - # Create a test Whitelist - whitelist = Whitelist( - ip="192.168.2.1", - mask=24, - expires=current_time + timedelta(days=7), - comment="Test whitelist", - user_id=1, - org_id=1, - rstate_id=1, - ) - db.session.add(whitelist) - - db.session.commit() - - # Return data that will be useful for tests - return { - "user_id": 1, - "org_id": 1, - "user_email": "test@example.com", - "org_name": "Test Org", - "ipv4_rule_id": ipv4_rule.id, - "ipv6_rule_id": ipv6_rule.id, - "rtbh_rule_id": rtbh_rule.id, - "whitelist_id": whitelist.id, - "current_time": current_time, - "future_time": current_time + timedelta(hours=1), - "past_time": current_time - timedelta(hours=1), - "comment": "Test comment", - } - - -class TestReactivateRule: - """Tests for the reactivate_rule function.""" - - @patch("flowapp.services.rule_service.check_global_rule_limit") - @patch("flowapp.services.rule_service.check_rule_limit") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_reactivate_rule_active( - self, mock_log_route, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db - ): - """Test reactivating a rule with future expiration (active state).""" - # Setup mocks - mock_check_global_limit.return_value = False - mock_check_rule_limit.return_value = False - - # The rule will be active (state=1) because expiration is in the future - expires = test_data["future_time"] - - # Call the function - model, messages = rule_service.reactivate_rule( - rule_type=RuleTypes.IPv4, - rule_id=test_data["ipv4_rule_id"], - expires=expires, - comment=test_data["comment"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - ) - - # Assertions - mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) - mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) - - # Verify model was updated - assert model.expires == expires - assert model.comment == test_data["comment"] - assert model.rstate_id == 1 # Active state - - # Verify route announcement was made - mock_announce_route.assert_called_once() - args, _ = mock_announce_route.call_args - assert args[0].source == RouteSources.UI - - # Verify logging - mock_log_route.assert_called_once() - - # Check returned values - assert messages - assert messages[0].startswith("Rule ") - - @patch("flowapp.services.rule_service.check_global_rule_limit") - @patch("flowapp.services.rule_service.check_rule_limit") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_withdraw") - def test_reactivate_rule_inactive( - self, mock_log_withdraw, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db - ): - """Test reactivating a rule with past expiration (inactive state).""" - # Setup mocks - mock_check_global_limit.return_value = False - mock_check_rule_limit.return_value = False - - # The rule will be inactive (state=2) because expiration is in the past - expires = test_data["past_time"] - - # Call the function - model, messages = rule_service.reactivate_rule( - rule_type=RuleTypes.IPv4, - rule_id=test_data["ipv4_rule_id"], - expires=expires, - comment=test_data["comment"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - ) - - # Verify model was updated - assert model.expires == expires - assert model.comment == test_data["comment"] - assert model.rstate_id == 2 # Inactive state - - # Verify route withdrawal was made - mock_announce_route.assert_called_once() - args, _ = mock_announce_route.call_args - assert args[0].source == RouteSources.UI - - # Verify logging - mock_log_withdraw.assert_called_once() - - # Check returned values - assert messages - assert messages[0].startswith("Rule ") - - @patch("flowapp.services.rule_service.check_global_rule_limit") - @patch("flowapp.services.rule_service.check_rule_limit") - def test_reactivate_rule_global_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): - """Test reactivating a rule when global limit is reached.""" - # Setup mocks - mock_check_global_limit.return_value = True # Global limit reached - mock_check_rule_limit.return_value = False - - # The rule would be active, but global limit is reached - expires = test_data["future_time"] - - # Call the function - model, messages = rule_service.reactivate_rule( - rule_type=RuleTypes.IPv4, - rule_id=test_data["ipv4_rule_id"], - expires=expires, - comment=test_data["comment"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - ) - - # Assertions - mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) - - # Check returned values - assert messages == ["global_limit_reached"] - - @patch("flowapp.services.rule_service.check_global_rule_limit") - @patch("flowapp.services.rule_service.check_rule_limit") - def test_reactivate_rule_org_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): - """Test reactivating a rule when organization limit is reached.""" - # Setup mocks - mock_check_global_limit.return_value = False - mock_check_rule_limit.return_value = True # Org limit reached - - # The rule would be active, but org limit is reached - expires = test_data["future_time"] - - # Call the function - model, messages = rule_service.reactivate_rule( - rule_type=RuleTypes.IPv4, - rule_id=test_data["ipv4_rule_id"], - expires=expires, - comment=test_data["comment"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - ) - - # Assertions - mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) - mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) - - # Check returned values - assert messages == ["limit_reached"] - - -class TestDeleteRule: - """Tests for the delete_rule function.""" - - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_withdraw") - def test_delete_rule_success(self, mock_log_withdraw, mock_announce_route, test_data, db): - """Test successful rule deletion.""" - # Call the function - success, message = rule_service.delete_rule( - rule_type=RuleTypes.IPv4, - rule_id=test_data["ipv4_rule_id"], - user_id=test_data["user_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[test_data["ipv4_rule_id"]], # Rule is in allowed list - ) - - # Verify route withdrawal was made - mock_announce_route.assert_called_once() - args, _ = mock_announce_route.call_args - assert args[0].source == RouteSources.UI - - # Verify logging - mock_log_withdraw.assert_called_once() - - # Verify database operations - rule should be deleted - rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) - assert rule is None - - # Check returned values - assert success is True - assert message == "Rule deleted successfully" - - def test_delete_rule_not_allowed(self, test_data, db): - """Test rule deletion when rule is not in allowed list.""" - # Call the function - success, message = rule_service.delete_rule( - rule_type=RuleTypes.IPv4, - rule_id=test_data["ipv4_rule_id"], - user_id=test_data["user_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[999], # Rule is not in allowed list - ) - - # Verify rule still exists - rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) - assert rule is not None - - # Check returned values - assert success is False - assert message == "You cannot delete this rule" - - def test_delete_rule_not_found(self, test_data, db): - """Test rule deletion when rule is not found.""" - # Call the function with non-existent ID - success, message = rule_service.delete_rule( - rule_type=RuleTypes.IPv4, - rule_id=9999, # Non-existent ID - user_id=test_data["user_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - ) - - # Check returned values - assert success is False - assert message == "Rule not found" - - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_withdraw") - @patch("flowapp.services.rule_service.RuleWhitelistCache") - def test_delete_rtbh_rule(self, mock_whitelist_cache, mock_log_withdraw, mock_announce_route, test_data, db): - """Test deleting an RTBH rule.""" - # Call the function - success, message = rule_service.delete_rule( - rule_type=RuleTypes.RTBH, - rule_id=test_data["rtbh_rule_id"], - user_id=test_data["user_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[test_data["rtbh_rule_id"]], - ) - - # Verify whitelist cache was cleaned - mock_whitelist_cache.delete_by_rule_id.assert_called_once_with(test_data["rtbh_rule_id"]) - - # Verify route withdrawal and logging - mock_announce_route.assert_called_once() - mock_log_withdraw.assert_called_once() - - # Verify rule was deleted - rule = db.session.get(RTBH, test_data["rtbh_rule_id"]) - assert rule is None - - # Check returned values - assert success is True - assert message == "Rule deleted successfully" - - -class TestDeleteRtbhAndCreateWhitelist: - """Tests for the delete_rtbh_and_create_whitelist function.""" - - @patch("flowapp.services.rule_service.delete_rule") - @patch("flowapp.services.rule_service.create_or_update_whitelist") - def test_delete_rtbh_and_create_whitelist_success(self, mock_create_whitelist, mock_delete_rule, test_data, db): - """Test successful RTBH deletion and whitelist creation.""" - # Setup mock for delete_rule service - mock_delete_rule.return_value = (True, "Rule deleted successfully") - - # Setup mock for create_or_update_whitelist service - mock_whitelist = Whitelist( - ip="192.168.1.100", - mask=32, - expires=datetime.now() + timedelta(days=7), - user_id=test_data["user_id"], - org_id=test_data["org_id"], - ) - mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) - - # Call the function - success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( - rule_id=test_data["rtbh_rule_id"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[test_data["rtbh_rule_id"]], - ) - - # Verify delete_rule was called - mock_delete_rule.assert_called_once_with( - rule_type=RuleTypes.RTBH, - rule_id=test_data["rtbh_rule_id"], - user_id=test_data["user_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[test_data["rtbh_rule_id"]], - ) - - # Verify create_or_update_whitelist was called with correct data - mock_create_whitelist.assert_called_once() - args, kwargs = mock_create_whitelist.call_args - form_data = kwargs.get("form_data", args[0] if args else None) - assert form_data["ip"] == "192.168.1.100" - assert form_data["mask"] == 32 - assert "Created from RTBH rule" in form_data["comment"] - - # Check returned values - assert success is True - assert len(messages) == 2 - assert messages[0] == "Rule deleted successfully" - assert messages[1] == "Whitelist created" - assert whitelist == mock_whitelist - - def test_delete_rtbh_and_create_whitelist_rule_not_found(self, test_data, db): - """Test when the RTBH rule to convert is not found.""" - # Call the function with non-existent ID - success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( - rule_id=9999, # Non-existent ID - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - ) - - # Check returned values - assert success is False - assert messages == ["RTBH rule not found"] - assert whitelist is None - - def test_delete_rtbh_and_create_whitelist_not_allowed(self, test_data, db): - """Test when the user is not allowed to delete the rule.""" - # Call the function - success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( - rule_id=test_data["rtbh_rule_id"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[999], # Rule not in allowed list - ) - - # Check returned values - assert success is False - assert messages == ["You cannot delete this rule"] - assert whitelist is None - - @patch("flowapp.services.rule_service.delete_rule") - def test_delete_rtbh_and_create_whitelist_delete_fails(self, mock_delete_rule, test_data, db): - """Test when the RTBH deletion fails.""" - # Setup mock for delete_rule service to fail - mock_delete_rule.return_value = (False, "Error deleting rule") - - # Call the function - success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( - rule_id=test_data["rtbh_rule_id"], - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[test_data["rtbh_rule_id"]], - ) - - # Verify delete_rule was called - mock_delete_rule.assert_called_once() - - # Check returned values - assert success is False - assert messages == ["Error deleting rule"] - assert whitelist is None - - @patch("flowapp.services.rule_service.delete_rule") - @patch("flowapp.services.rule_service.create_or_update_whitelist") - def test_rtbh_with_ipv6(self, mock_create_whitelist, mock_delete_rule, test_data, db): - """Test with an IPv6 RTBH rule.""" - # Create an IPv6 RTBH rule - ipv6_rtbh = RTBH( - ipv4=None, - ipv4_mask=None, - ipv6="2001:db8::1", - ipv6_mask=64, - community_id=1, - expires=datetime.now() + timedelta(hours=1), - comment="IPv6 RTBH rule", - user_id=test_data["user_id"], - org_id=test_data["org_id"], - rstate_id=1, - ) - db.session.add(ipv6_rtbh) - db.session.commit() - - # Setup mocks - mock_delete_rule.return_value = (True, "Rule deleted successfully") - mock_whitelist = Whitelist( - ip="2001:db8::1", - mask=64, - expires=datetime.now() + timedelta(days=7), - user_id=test_data["user_id"], - org_id=test_data["org_id"], - ) - mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) - - # Call the function - success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( - rule_id=ipv6_rtbh.id, - user_id=test_data["user_id"], - org_id=test_data["org_id"], - user_email=test_data["user_email"], - org_name=test_data["org_name"], - allowed_rule_ids=[ipv6_rtbh.id], - ) - - # Verify create_or_update_whitelist was called with correct IPv6 data - mock_create_whitelist.assert_called_once() - args, kwargs = mock_create_whitelist.call_args - form_data = kwargs.get("form_data", args[0] if args else None) - assert form_data["ip"] == "2001:db8::1" - assert form_data["mask"] == 64 - - # Check returned values - assert success is True - assert len(messages) == 2 - assert whitelist == mock_whitelist diff --git a/flowapp/tests/test_utils.py b/flowapp/tests/test_utils.py deleted file mode 100644 index 47c89c4..0000000 --- a/flowapp/tests/test_utils.py +++ /dev/null @@ -1,71 +0,0 @@ -import ipaddress -import pytest - -from datetime import datetime, timedelta - -from flowapp import utils - - -@pytest.mark.parametrize( - "apitime, preformat", - [ - ("10/15/2015 14:46", "us"), - ("2015/10/15 14:46", "yearfirst"), - ("1444913400", "timestamp"), - (1444913400, "timestamp"), - ], -) -def test_parse_api_time(apitime, preformat): - """ - is the time parsed correctly - """ - result = utils.parse_api_time(apitime) - assert isinstance(result, tuple) - assert result[0] == datetime(2015, 10, 15, 14, 50) - assert result[1] == preformat - - -@pytest.mark.parametrize( - "apitime", ["10/152015 14:46", "201/10/15 14:46", "144123254913400", "abcd"] -) -def test_parse_api_time_bad_time(apitime): - """ - is the time parsed correctly - """ - assert not utils.parse_api_time(apitime) - - -def test_get_rule_state_by_time(): - """ - Test if time in the past returns 2 - """ - past = datetime.now() - timedelta(days=1) - - assert utils.get_state_by_time(past) == 2 - - -def test_round_to_ten(): - """ - Test if the time is rounded correctly - """ - d1 = datetime(2013, 9, 2, 16, 25, 59) - d2 = datetime(2013, 9, 2, 16, 32, 59) - dround = datetime(2013, 9, 2, 16, 30, 00) - - assert utils.round_to_ten_minutes(d1) == dround - assert utils.round_to_ten_minutes(d2) == dround - - -@pytest.mark.parametrize( - "address_a, address_b", - [ - ( - "2001:718:1c01:16:f1c4:c682:817:7e23", - "2001:0718:1c01:0016:f1c4:c682:0817:7e23", - ), - ("2001:718::", "2001:718::0"), - ("2001:718::0", "2001:0718:0000:0000:0000:0000:0000:0000"), - ], -) -def test_ipv6_comparsion(address_a, address_b): - assert ipaddress.ip_address(address_a) == ipaddress.ip_address(address_b) diff --git a/flowapp/tests/test_validators.py b/flowapp/tests/test_validators.py deleted file mode 100644 index 614d3a5..0000000 --- a/flowapp/tests/test_validators.py +++ /dev/null @@ -1,547 +0,0 @@ -import pytest -from flowapp.validators import ( - PortString, - PacketString, - NetRangeString, - NetInRange, - IPAddress, - IPAddressValidator, - NetworkValidator, - ValidationError, - address_in_range, - address_with_mask, - IPv4Address, - IPv6Address, - DateNotExpired, - editable_range, - network_in_range, - range_in_network, - filter_rules_in_network, - split_rules_for_user, - filter_rtbh_rules, - split_rtbh_rules_for_user, -) - - -def test_port_string_len_raises(field): - port = PortString() - field.data = "1;2;3;4;5;6;7;8" - with pytest.raises(ValidationError): - port(None, field) - - -@pytest.mark.parametrize( - "address, mask, expected", - [ - ("147.230.23.25", "24", False), - ("147.230.23.0", "24", True), - ("0.0.0.0", "0", True), - ("2001:718:1C01:1111::1111", "64", False), - ("2001:718:1C01:1111::", "64", True), - ], -) -def test_is_valid_address_with_mask(address, mask, expected): - assert address_with_mask(address, mask) == expected - - -@pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"]) -def test_ip4address_passes(field, address): - adr = IPv4Address() - field.data = address - adr(None, field) - - -@pytest.mark.parametrize( - "address", - [ - "2001:718:1C01:1111::1111", - "2001:718:1C01:1111::", - ], -) -def test_ip6address_passes(field, address): - adr = IPv6Address() - field.data = address - adr(None, field) - - -@pytest.mark.parametrize( - "address", - [ - "2001:718:1C01:1111::1111", - "2001:718:1C01:1111::", - ], -) -def test_bad_ip6address_raises(field, address): - adr = IPv4Address() - field.data = address - with pytest.raises(ValidationError): - adr(None, field) - - -@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]) -def test_expired_date_raises(field, expired): - adr = DateNotExpired() - field.data = expired - with pytest.raises(ValidationError): - adr(None, field) - - -@pytest.mark.parametrize( - "address", - [ - "147.230.1000.25", - "2001:718::::", - ], -) -def test_ipaddress_raises(field, address): - adr = IPv6Address() - field.data = address - with pytest.raises(ValidationError): - adr(None, field) - - -@pytest.mark.parametrize( - "address, mask, ranges, expected", - [ - ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), - ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), - ], -) -def test_editable_rule(rule, address, mask, ranges, expected): - rule.source = address - rule.source_mask = mask - assert editable_range(rule, ranges) == expected - - -@pytest.mark.parametrize( - "address, mask, ranges, expected", - [ - ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), - ("147.233.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], False), - ("147.230.23.0", "24", ["147.230.0.0/16", "2001:718:1c01::/48"], True), - ("195.113.0.0", "16", ["0.0.0.0/0", "::/0"], True), - ], -) -def test_address_in_range(address, mask, ranges, expected): - assert address_in_range(address, ranges) == expected - - -@pytest.mark.parametrize( - "address, mask, ranges, expected", - [ - ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), - ("147.233.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], False), - ("195.113.0.0", "16", ["195.113.0.0/18", "195.113.64.0/21"], False), - ("195.113.0.0", "16", ["0.0.0.0/0", "::/0"], True), - ( - "195.113.0.0", - "16", - ["147.230.0.0/16", "2001:718:1c01::/48", "0.0.0.0/0", "::/0"], - True, - ), - ], -) -def test_network_in_range(address, mask, ranges, expected): - assert network_in_range(address, mask, ranges) == expected - - -@pytest.mark.parametrize( - "address, mask, ranges, expected", - [ - ("195.113.0.0", "16", ["147.230.0.0/16", "195.113.250.0/24"], True), - ], -) -def test_range_in_network(address, mask, ranges, expected): - assert range_in_network(address, mask, ranges) == expected - - -### new tests - - -# New tests for previously uncovered validators - - -# Tests for PortString validator syntax -@pytest.mark.parametrize( - "port_data", - [ - "80", # Simple number - "80-443", # Range with hyphen - ">=80&<=443", # Explicit range - ">80", # Greater than - ">=80", # Greater than or equal - "<443", # Less than - "<=443", # Less than or equal - "80;443;8080", # Multiple port expressions - ], -) -def test_port_string_valid_syntax(field, port_data): - validator = PortString() - field.data = port_data - validator(None, field) # Should not raise - - -@pytest.mark.parametrize( - "invalid_port_data", - [ - "80,443", # Comma not supported, should use semicolon - "80..443", # Invalid range syntax (should be 80-443) - ">65536", # Port number too high - "-1", # Negative number - "<-1", # Port number too low - ">=80<=443", # Invalid range syntax (missing &) - "!=80", # Not equals not supported - "abc", # Non-numeric - "80-", # Incomplete range - "-80", # Invalid range - "80&&443", # Invalid operator - "443-80", # End less than start - ">=443&<=80", # End less than start in explicit range - "0-65537", # Range exceeding max - ], -) -def test_port_string_invalid_syntax(field, invalid_port_data): - validator = PortString() - field.data = invalid_port_data - with pytest.raises(ValidationError): - validator(None, field) - - -# Tests for PacketString validator -@pytest.mark.parametrize( - "packet_data", - ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions -) -def test_packet_string_valid(field, packet_data): - validator = PacketString() - field.data = packet_data - validator(None, field) # Should not raise - - -@pytest.mark.parametrize( - "invalid_packet_data", - [ - "65536", # Too large - "-1", # Negative - "1500..", # Invalid range syntax - "abc", # Non-numeric - "!!1500", # Invalid operator - "1500-", # Incomplete range - "9000-1500", # End less than start - ], -) -def test_packet_string_invalid(field, invalid_packet_data): - validator = PacketString() - field.data = invalid_packet_data - with pytest.raises(ValidationError): - validator(None, field) - - -# Tests for NetRangeString validator -@pytest.mark.parametrize( - "net_range", - [ - "192.168.0.0/24", - "10.0.0.0/8\n172.16.0.0/12", - "2001:db8::/32", - "192.168.0.0/24 10.0.0.0/8", - "2001:db8::/32 2001:db8:1::/48", - ], -) -def test_net_range_string_valid(field, net_range): - validator = NetRangeString() - field.data = net_range - validator(None, field) # Should not raise - - -@pytest.mark.parametrize( - "invalid_net_range", - [ - "192.168.0.0/33", # Invalid mask - "256.256.256.0/24", # Invalid IP - "2001:xyz::/32", # Invalid IPv6 - "192.168.0.0/24/", # Malformed - "not-an-ip/24", # Invalid format - ], -) -def test_net_range_string_invalid(field, invalid_net_range): - validator = NetRangeString() - field.data = invalid_net_range - with pytest.raises(ValidationError): - validator(None, field) - - -# Tests for NetInRange validator -def test_net_in_range_valid(field): - net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] - validator = NetInRange(net_ranges) - field.data = "192.168.1.1/24" - validator(None, field) # Should not raise - - -def test_net_in_range_invalid(field): - net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] - validator = NetInRange(net_ranges) - field.data = "172.16.1.1/24" - with pytest.raises(ValidationError): - validator(None, field) - - -# Tests for base IPAddress validator -@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"]) -def test_ip_address_valid(field, ip_addr): - validator = IPAddress() - field.data = ip_addr - validator(None, field) # Should not raise - - -@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"]) -def test_ip_address_invalid(field, invalid_ip): - validator = IPAddress() - field.data = invalid_ip - with pytest.raises(ValidationError): - validator(None, field) - - -# Tests for universal IPAddressValidator -@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed -def test_ip_address_validator_valid(field, ip_addr): - validator = IPAddressValidator() - field.data = ip_addr - validator(None, field) # Should not raise - - -@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"]) -def test_ip_address_validator_invalid(field, invalid_ip): - validator = IPAddressValidator() - field.data = invalid_ip - with pytest.raises(ValidationError): - validator(None, field) - - -# Tests for NetworkValidator -class MockForm: - def __init__(self, address, mask): - self._fields = {"mask": type("MockField", (), {"data": mask})()} - - -@pytest.mark.parametrize( - "address,mask", - [ - ("192.168.0.0", "24"), - ("10.0.0.0", "8"), - ("2001:db8::", "32"), - ("", None), # Empty values should pass - (None, None), # None values should pass - ], -) -def test_network_validator_valid(field, address, mask): - validator = NetworkValidator("mask") - field.data = address - form = MockForm(address, mask) - validator(form, field) # Should not raise - - -@pytest.mark.parametrize( - "address,mask", - [ - ("192.168.1.1", "24"), # Not a network address - ("192.168.0.0", "33"), # Invalid IPv4 mask - ("2001:db8::", "129"), # Invalid IPv6 mask - ("256.256.256.0", "24"), # Invalid IP - ("2001:xyz::", "32"), # Invalid IPv6 - ], -) -def test_network_validator_invalid(field, address, mask): - validator = NetworkValidator("mask") - field.data = address - form = MockForm(address, mask) - with pytest.raises(ValidationError): - validator(form, field) - - -# Mock rule classes for testing robust attribute handling -class MockRule: - """Mock rule with all expected attributes""" - - def __init__(self, source=None, source_mask=None, dest=None, dest_mask=None): - self.source = source - self.source_mask = source_mask - self.dest = dest - self.dest_mask = dest_mask - - -class MockRuleIncomplete: - """Mock rule with missing attributes""" - - def __init__(self, name=None): - self.name = name - # Intentionally missing source, source_mask, dest, dest_mask attributes - - -class MockRulePartial: - """Mock rule with some attributes""" - - def __init__(self, source=None): - self.source = source - # Missing source_mask, dest, dest_mask attributes - - -class MockRTBHRule: - """Mock RTBH rule with all expected attributes""" - - def __init__(self, ipv4=None, ipv4_mask=None, ipv6=None, ipv6_mask=None): - self.ipv4 = ipv4 - self.ipv4_mask = ipv4_mask - self.ipv6 = ipv6 - self.ipv6_mask = ipv6_mask - - -class MockRTBHRuleIncomplete: - """Mock RTBH rule with missing attributes""" - - def __init__(self, name=None): - self.name = name - # Intentionally missing ipv4, ipv4_mask, ipv6, ipv6_mask attributes - - -# Tests for filter_rules_in_network with robust attribute handling -def test_filter_rules_in_network_normal_rules(): - """Test filter_rules_in_network with normal rule objects""" - net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] - rules = [ - MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Should match - MockRule("172.16.1.0", "24", "172.16.2.0", "24"), # Should not match - MockRule("10.1.0.0", "16", None, None), # Should match (source only) - ] - - filtered = filter_rules_in_network(net_ranges, rules) - assert len(filtered) == 2 - assert rules[0] in filtered # 192.168.x.x rule - assert rules[2] in filtered # 10.x.x.x rule - assert rules[1] not in filtered # 172.16.x.x rule - - -def test_filter_rules_in_network_missing_attributes(): - """Test filter_rules_in_network with rules missing required attributes""" - net_ranges = ["192.168.0.0/16"] - rules = [ - MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Normal rule - should match - MockRuleIncomplete("rule_without_network_attrs"), # Missing attrs - should be included - MockRulePartial("172.16.1.0"), # Partial attrs - should be included - ] - - filtered = filter_rules_in_network(net_ranges, rules) - assert len(filtered) == 3 # All rules should be included - assert all(rule in filtered for rule in rules) - - -def test_filter_rules_in_network_none_values(): - """Test filter_rules_in_network with None values in attributes""" - net_ranges = ["192.168.0.0/16"] - rules = [ - MockRule("192.168.1.0", "24", None, None), # Should match on source - MockRule(None, None, "192.168.2.0", "24"), # Should match on dest - MockRule(None, None, None, None), # Should not match - ] - - filtered = filter_rules_in_network(net_ranges, rules) - assert len(filtered) == 2 - assert rules[0] in filtered - assert rules[1] in filtered - assert rules[2] not in filtered - - -# Tests for split_rules_for_user with robust attribute handling -def test_split_rules_for_user_normal_rules(): - """Test split_rules_for_user with normal rule objects""" - net_ranges = ["192.168.0.0/16"] - rules = [ - MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Should be user rule - MockRule("172.16.1.0", "24", "172.16.2.0", "24"), # Should be rest rule - ] - - user_rules, rest_rules = split_rules_for_user(net_ranges, rules) - assert len(user_rules) == 1 - assert len(rest_rules) == 1 - assert rules[0] in user_rules - assert rules[1] in rest_rules - - -def test_split_rules_for_user_missing_attributes(): - """Test split_rules_for_user with rules missing required attributes""" - net_ranges = ["192.168.0.0/16"] - rules = [ - MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Normal rule - user rule - MockRuleIncomplete("rule_without_attrs"), # Missing attrs - should be user rule - MockRule("172.16.1.0", "24", "172.16.2.0", "24"), # Normal rule - rest rule - ] - - user_rules, rest_rules = split_rules_for_user(net_ranges, rules) - assert len(user_rules) == 2 # Normal matching rule + incomplete rule - assert len(rest_rules) == 1 - assert rules[0] in user_rules # Matching rule - assert rules[1] in user_rules # Incomplete rule treated as editable - assert rules[2] in rest_rules # Non-matching rule - - -# Tests for filter_rtbh_rules with robust attribute handling -def test_filter_rtbh_rules_normal_rules(): - """Test filter_rtbh_rules with normal RTBH rule objects""" - net_ranges = ["192.168.0.0/16", "2001:db8::/32"] - rules = [ - MockRTBHRule("192.168.1.0", "24", None, None), # Should match on IPv4 - MockRTBHRule(None, None, "2001:db8:1::", "48"), # Should match on IPv6 - MockRTBHRule("172.16.1.0", "24", "2001:db9::", "32"), # Should not match - ] - - filtered = filter_rtbh_rules(net_ranges, rules) - assert len(filtered) == 2 - assert rules[0] in filtered - assert rules[1] in filtered - assert rules[2] not in filtered - - -# Tests for split_rtbh_rules_for_user with robust attribute handling -def test_split_rtbh_rules_for_user_normal_rules(): - """Test split_rtbh_rules_for_user with normal RTBH rule objects""" - net_ranges = ["192.168.0.0/16"] - rules = [ - MockRTBHRule("192.168.1.0", "24", None, None), # Should be filtered (user) - MockRTBHRule("172.16.1.0", "24", None, None), # Should be read-only - ] - - filtered, read_only = split_rtbh_rules_for_user(net_ranges, rules) - assert len(filtered) == 1 - assert len(read_only) == 1 - assert rules[0] in filtered - assert rules[1] in read_only - - -# Edge case tests -def test_filter_functions_empty_input(): - """Test all filter functions with empty input""" - net_ranges = ["192.168.0.0/16"] - - # Empty rules list - assert filter_rules_in_network(net_ranges, []) == [] - assert split_rules_for_user(net_ranges, []) == ([], []) - assert filter_rtbh_rules(net_ranges, []) == [] - assert split_rtbh_rules_for_user(net_ranges, []) == ([], []) - - -def test_filter_functions_empty_net_ranges(): - """Test filter functions with empty network ranges""" - rules = [MockRule("192.168.1.0", "24", None, None)] - rtbh_rules = [MockRTBHRule("192.168.1.0", "24", None, None)] - - # Empty network ranges - nothing should match - assert filter_rules_in_network([], rules) == [] - user_rules, rest_rules = split_rules_for_user([], rules) - assert user_rules == [] - assert rest_rules == rules - - assert filter_rtbh_rules([], rtbh_rules) == [] - filtered, read_only = split_rtbh_rules_for_user([], rtbh_rules) - assert filtered == [] - assert read_only == rtbh_rules diff --git a/flowapp/tests/test_whitelist_common.py b/flowapp/tests/test_whitelist_common.py deleted file mode 100644 index a843aff..0000000 --- a/flowapp/tests/test_whitelist_common.py +++ /dev/null @@ -1,250 +0,0 @@ -import pytest -from flowapp.services.whitelist_common import ( - Relation, - check_rule_against_whitelists, - check_whitelist_against_rules, - check_whitelist_to_rule_relation, - subtract_network, - clear_network_cache, - _is_same_ip_version, # New helper function -) - - -# Tests for core function that checks network relations -@pytest.mark.parametrize( - "rule,whitelist,expected", - [ - # IPv4 test cases - ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet - ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks - ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet - ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet - ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks - # IPv6 test cases - ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks - ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet - ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet - ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks - ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet - ], -) -def test_check_whitelist_to_rule_relation(rule, whitelist, expected): - clear_network_cache() - assert check_whitelist_to_rule_relation(rule, whitelist) == expected - - -@pytest.mark.parametrize( - "target,whitelist,expected", - [ - # Basic IPv4 cases - ( - "192.168.1.0/24", - "192.168.1.128/25", - ["192.168.1.0/25"], - ), # One remaining subnet - ( - "192.168.1.0/24", - "192.168.1.64/26", - ["192.168.1.0/26", "192.168.1.128/25"], - ), # Two remaining subnets - ( - "192.168.1.0/24", - "192.168.1.0/24", - ["192.168.1.0/24"], - ), # Equal networks - return original network - ( - "192.168.1.0/24", - "192.168.2.0/24", - ["192.168.1.0/24"], - ), # No overlap - ], -) -def test_subtract_network(target, whitelist, expected): - clear_network_cache() - result = subtract_network(target, whitelist) - assert sorted(result) == sorted(expected) - - -def test_check_rule_against_whitelists(): - clear_network_cache() - - rule = "192.168.1.0/24" - whitelists = [ - "192.168.0.0/16", # SUPERNET - "172.16.0.0/12", # DIFFERENT - "192.168.1.0/24", # EQUAL - "192.168.1.128/25", # SUBNET - ] - - results = check_rule_against_whitelists(rule, whitelists) - - # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations - expected = [ - (rule, "192.168.0.0/16", Relation.SUPERNET), - (rule, "192.168.1.0/24", Relation.EQUAL), - (rule, "192.168.1.128/25", Relation.SUBNET), - ] - - assert len(results) == 3 - for result, exp in zip(sorted(results), sorted(expected)): - assert result == exp - - -def test_check_whitelist_against_rules(): - clear_network_cache() - - whitelist = "192.168.1.128/25" - rules = [ - "192.168.0.0/16", # Whitelist is SUBNET of this larger network - "172.16.0.0/12", # DIFFERENT - "192.168.1.128/25", # EQUAL - "192.168.1.0/24", # Whitelist is SUBNET of this network - ] - - results = check_whitelist_against_rules(rules, whitelist) - - # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations - expected = [ - ("192.168.0.0/16", whitelist, Relation.SUBNET), - ("192.168.1.128/25", whitelist, Relation.EQUAL), - ("192.168.1.0/24", whitelist, Relation.SUBNET), - ] - - assert len(results) == 3 - for result, exp in zip(sorted(results), sorted(expected)): - assert result == exp - - -# Tests for edge cases and error handling -def test_host_addresses(): - clear_network_cache() - # Test with host addresses (no explicit subnet mask) - assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET - assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT - - -def test_single_ip_as_network(): - clear_network_cache() - # Test with /32 (single IPv4 address as network) - assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET - # Test with /128 (single IPv6 address as network) - assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET - - -def test_is_same_ip_version(): - """Test the new helper function to check if two addresses are of the same IP version""" - # IPv4 cases - assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True - assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True - - # IPv6 cases - assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True - assert _is_same_ip_version("2001:db8::1", "fe80::1") is True - - # Mixed cases - assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False - assert _is_same_ip_version("192.168.1.1", "fe80::1") is False - - -def test_check_whitelist_to_rule_relation_mixed_versions(): - """Test the check_whitelist_to_rule_relation function with mixed IP versions""" - clear_network_cache() - - # IPv4 rule with IPv6 whitelist - assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT - - # IPv6 rule with IPv4 whitelist - assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT - - # Invalid IP addresses should return DIFFERENT - assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT - assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT - - -def test_subtract_network_mixed_versions(): - """Test the subtract_network function with mixed IP versions""" - clear_network_cache() - - # IPv4 target with IPv6 whitelist - should return original target - result = subtract_network("192.168.1.0/24", "2001:db8::/32") - assert result == ["192.168.1.0/24"] - - # IPv6 target with IPv4 whitelist - should return original target - result = subtract_network("2001:db8::/32", "192.168.1.0/24") - assert result == ["2001:db8::/32"] - - # Invalid addresses - should return original target - result = subtract_network("invalid", "192.168.1.0/24") - assert result == ["invalid"] - result = subtract_network("192.168.1.0/24", "invalid") - assert result == ["192.168.1.0/24"] - - -def test_check_rule_against_whitelists_mixed_versions(): - """Test the check_rule_against_whitelists function with mixed IP versions""" - clear_network_cache() - - # IPv4 rule against mixed whitelist - rule = "192.168.1.0/24" - whitelists = [ - "192.168.0.0/16", # IPv4 - should match - "2001:db8::/32", # IPv6 - should not match - "10.0.0.0/8", # IPv4 - should not match (different network) - ] - - results = check_rule_against_whitelists(rule, whitelists) - assert len(results) == 1 - assert results[0][0] == rule - assert results[0][1] == "192.168.0.0/16" - assert results[0][2] == Relation.SUPERNET - - # IPv6 rule against mixed whitelist - rule = "2001:db8:1::/48" - whitelists = [ - "192.168.0.0/16", # IPv4 - should not match - "2001:db8::/32", # IPv6 - should match - "2002:db8::/32", # IPv6 - should not match (different network) - ] - - results = check_rule_against_whitelists(rule, whitelists) - assert len(results) == 1 - assert results[0][0] == rule - assert results[0][1] == "2001:db8::/32" - assert results[0][2] == Relation.SUPERNET - - -def test_check_whitelist_against_rules_mixed_versions(): - """Test the check_whitelist_against_rules function with mixed IP versions""" - clear_network_cache() - - # IPv4 whitelist against mixed rules - whitelist = "192.168.1.0/24" - rules = [ - "192.168.0.0/16", # IPv4 - should match - "2001:db8::/32", # IPv6 - should not match - "10.0.0.0/8", # IPv4 - should not match (different network) - ] - - results = check_whitelist_against_rules(rules, whitelist) - assert len(results) == 1 - assert results[0][0] == "192.168.0.0/16" - assert results[0][1] == whitelist - assert results[0][2] == Relation.SUBNET - - # IPv6 whitelist against mixed rules - whitelist = "2001:db8:1::/48" - rules = [ - "192.168.0.0/16", # IPv4 - should not match - "2001:db8::/32", # IPv6 - should match - "2002:db8::/32", # IPv6 - should not match (different network) - ] - - results = check_whitelist_against_rules(rules, whitelist) - assert len(results) == 1 - assert results[0][0] == "2001:db8::/32" - assert results[0][1] == whitelist - assert results[0][2] == Relation.SUBNET - - -if __name__ == "__main__": - pytest.main(["-v"]) diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py deleted file mode 100644 index 460a172..0000000 --- a/flowapp/tests/test_whitelist_service.py +++ /dev/null @@ -1,461 +0,0 @@ -""" -Tests for whitelist_service.py module. - -This test suite verifies the functionality of the whitelist service, -which manages creation, updating, and handling of whitelist rules. -""" - -import pytest -from datetime import datetime, timedelta -from unittest.mock import patch, MagicMock - -from flowapp.constants import RuleTypes, RuleOrigin -from flowapp.models import Whitelist, RuleWhitelistCache, RTBH -from flowapp.services import whitelist_service -from flowapp.services.whitelist_common import Relation - - -@pytest.fixture -def whitelist_form_data(): - """Sample valid whitelist form data""" - return { - "ip": "192.168.1.0", - "mask": 24, - "comment": "Test whitelist entry", - "expires": datetime.now() + timedelta(hours=1), - } - - -class TestCreateOrUpdateWhitelist: - @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") - @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") - @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") - def test_create_new_whitelist( - self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data - ): - """Test creating a new whitelist entry""" - # Mock the get_whitelist_model_if_exists to return False (not found) - mock_get_model.return_value = False - - # Mock RTBH rules and check results - mock_rtbh_rules = [] - mock_query = MagicMock() - mock_query.filter.return_value.all.return_value = mock_rtbh_rules - - # Mock check_whitelist_against_rules to return empty list (no matches) - mock_check_whitelist.return_value = [] - - # Mock evaluate to return the model unchanged - mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model - - # Call the service function - with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): - with app.app_context(): - model, flashes = whitelist_service.create_or_update_whitelist( - form_data=whitelist_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was created with correct attributes - assert model is not None - assert model.ip == whitelist_form_data["ip"] - assert model.mask == whitelist_form_data["mask"] - assert model.comment == whitelist_form_data["comment"] - assert model.user_id == 1 - assert model.org_id == 1 - assert model.rstate_id == 1 # Active state - - # Verify flash messages - now a list instead of a string - assert isinstance(flashes, list) - assert "Whitelist saved" in flashes[0] - - @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") - @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") - @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") - def test_update_existing_whitelist( - self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data - ): - """Test updating an existing whitelist entry""" - # Create an existing whitelist - existing_model = Whitelist( - ip=whitelist_form_data["ip"], - mask=whitelist_form_data["mask"], - expires=datetime.now(), - user_id=1, - org_id=1, - rstate_id=1, - ) - - # Mock to return the existing model - mock_get_model.return_value = existing_model - - # Mock RTBH rules and check results - mock_rtbh_rules = [] - mock_query = MagicMock() - mock_query.filter.return_value.all.return_value = mock_rtbh_rules - - # Mock check_whitelist_against_rules to return empty list (no matches) - mock_check_whitelist.return_value = [] - - # Mock evaluate to return the model unchanged - mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model - - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - whitelist_form_data["expires"] = new_expires - - # Call the service function - with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): - with app.app_context(): - db.session.add(existing_model) - db.session.commit() - - model, flashes = whitelist_service.create_or_update_whitelist( - form_data=whitelist_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was updated - assert model == existing_model - # Check that expiration date was updated (rounded to 10 minutes) - # We can't compare exact timestamps, so check date parts - assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() - - # Verify flash messages - now a list instead of a string - assert isinstance(flashes, list) - assert "Existing Whitelist found" in flashes[0] - - @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") - @patch("flowapp.services.whitelist_service.db.session.query") - @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings") - @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") - @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") - def test_create_whitelist_with_matching_rules( - self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data - ): - """Test creating a whitelist that affects existing rules""" - # Mock get_whitelist_model_if_exists to return False (new whitelist) - mock_get_model.return_value = False - - # Create a mock RTBH rule - mock_rtbh_rule = MagicMock(spec=RTBH) - mock_rtbh_rule.__str__.return_value = "192.168.1.0/24" - mock_rtbh_rules = [mock_rtbh_rule] - - # Setup mock query to return our rule - mock_query_result = MagicMock() - mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules - mock_query.return_value = mock_query_result - - # Setup map function to return our rule in a dict - mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule} - - # Setup check function to return a relation - rule_relation = [ - ( - str(mock_rtbh_rule), - str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]), - Relation.EQUAL, - ) - ] - mock_check.return_value = rule_relation - - # Setup evaluate function to modify flashes - def evaluate_side_effect(model, flashes, rtbh_rule_cache, results): - flashes.append("Rule was whitelisted") - return model - - mock_evaluate.side_effect = evaluate_side_effect - - # Call the service function - with app.app_context(): - model, flashes = whitelist_service.create_or_update_whitelist( - form_data=whitelist_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify flash messages includes both whitelist creation and rule whitelisting - assert "Whitelist saved" in flashes[0] - assert "Rule was whitelisted" in flashes[1] - - # Verify interactions - mock_map.assert_called_once() - mock_check.assert_called_once() - mock_evaluate.assert_called_once() - - -class TestDeleteWhitelist: - @patch("flowapp.services.whitelist_service.announce_rtbh_route") - def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): - """Test deleting a whitelist that has user-created rules attached to it""" - # Create a whitelist - whitelist = Whitelist( - ip="192.168.1.0", - mask=24, - expires=datetime.now() + timedelta(hours=1), - user_id=1, - org_id=1, - rstate_id=1, - ) - - # Create a RTBH rule (created by user, whitelisted) - rtbh_rule = RTBH( - ipv4="192.168.1.0", - ipv4_mask=24, - ipv6="", - ipv6_mask=None, - community_id=1, - expires=datetime.now() + timedelta(hours=1), - user_id=1, - org_id=1, - rstate_id=4, # Whitelisted state - ) - - # Create a cache entry linking the rule to the whitelist - cache_entry = RuleWhitelistCache( - rid=1, - rtype=RuleTypes.RTBH, - whitelist_id=1, - rorigin=RuleOrigin.USER, - ) - - with app.app_context(): - db.session.add(whitelist) - db.session.add(rtbh_rule) - db.session.commit() - - # Set whitelist ID now that it has been saved - cache_entry.whitelist_id = whitelist.id - cache_entry.rid = rtbh_rule.id - db.session.add(cache_entry) - db.session.commit() - - # Mock the get_by_whitelist_id to return our cache entry - with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): - # Call the service function - flashes = whitelist_service.delete_whitelist(whitelist.id) - - # Verify the rule state was changed back to active - rtbh_rule = db.session.get(RTBH, rtbh_rule.id) - assert rtbh_rule.rstate_id == 1 # Active state - - # Verify announcement was made - mock_announce.assert_called_once() - - # Verify flash messages - assert isinstance(flashes, list) - assert flashes - - # Verify the whitelist was deleted - assert db.session.get(Whitelist, whitelist.id) is None - - @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id") - def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db): - """Test deleting a whitelist that has rules created by the whitelist""" - # Create a whitelist - whitelist = Whitelist( - ip="192.168.1.0", - mask=24, - expires=datetime.now() + timedelta(hours=1), - user_id=1, - org_id=1, - rstate_id=1, - ) - - # Create a RTBH rule (created by whitelist) - rtbh_rule = RTBH( - ipv4="192.168.1.0", - ipv4_mask=24, - ipv6="", - ipv6_mask=None, - community_id=1, - expires=datetime.now() + timedelta(hours=1), - user_id=1, - org_id=1, - rstate_id=1, - ) - - # Create a cache entry linking the rule to the whitelist - cache_entry = RuleWhitelistCache( - rid=1, - rtype=RuleTypes.RTBH, - whitelist_id=1, - rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist - ) - - with app.app_context(): - db.session.add(whitelist) - db.session.add(rtbh_rule) - db.session.commit() - - # Set IDs now that they have been saved - cache_entry.whitelist_id = whitelist.id - cache_entry.rid = rtbh_rule.id - db.session.add(cache_entry) - db.session.commit() - - # Create a mock session that can get our rule - with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): - # Call the service function - flashes = whitelist_service.delete_whitelist(whitelist.id) - - # Verify flash messages - assert isinstance(flashes, list) - assert flashes - - # Verify the rule was deleted - assert db.session.get(RTBH, rtbh_rule.id) is None - - # Verify the whitelist was deleted - assert db.session.get(Whitelist, whitelist.id) is None - - # Verify cache cleanup was called - mock_clean.assert_called_once_with(whitelist.id) - - def test_delete_nonexistent_whitelist(self, app, db): - """Test deleting a whitelist that doesn't exist""" - with app.app_context(): - # Call the service function with a non-existent ID - flashes = whitelist_service.delete_whitelist(999) - - # Should return empty list of flash messages, as no whitelist was found - assert isinstance(flashes, list) - assert len(flashes) == 0 - - -class TestEvaluateWhitelistAgainstRtbhResults: - def test_equal_relation(self, app): - """Test evaluating a whitelist with an EQUAL relation to a rule""" - # Create test data - whitelist_model = MagicMock(spec=Whitelist) - whitelist_model.id = 1 - - flashes = [] - - rtbh_rule = MagicMock(spec=RTBH) - rtbh_rule.rstate_id = 1 # Active state - - rule_key = "192.168.1.0/24" - whitelist_key = "192.168.1.0/24" - rtbh_rule_cache = {rule_key: rtbh_rule} - - results = [(rule_key, whitelist_key, Relation.EQUAL)] - - # Mock required functions - with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( - "flowapp.services.whitelist_service.withdraw_rtbh_route" - ) as mock_withdraw: - - # Call the function - with app.app_context(): - result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( - whitelist_model, flashes, rtbh_rule_cache, results - ) - - # Verify the rule was whitelisted and route withdrawn - mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) - mock_withdraw.assert_called_once_with(rtbh_rule) - - # Verify the flash message - assert flashes - - # Verify the correct model was returned - assert result == whitelist_model - - def test_subnet_relation(self, app): - """Test evaluating a whitelist with a SUBNET relation to a rule""" - # Create test data - whitelist_model = MagicMock(spec=Whitelist) - whitelist_model.id = 1 - - flashes = [] - - rtbh_rule = MagicMock(spec=RTBH) - rtbh_rule.rstate_id = 1 # Active state - - rule_key = "192.168.1.0/24" - whitelist_key = "192.168.1.128/25" - rtbh_rule_cache = {rule_key: rtbh_rule} - - results = [(rule_key, whitelist_key, Relation.SUBNET)] - - # Mock required functions - with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch( - "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts" - ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( - "flowapp.services.whitelist_service.db.session.commit" - ) as mock_commit: - - # Mock subtract_network to return some subnets - mock_subtract.return_value = ["192.168.1.0/25"] - - # Call the function - with app.app_context(): - _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( - whitelist_model, flashes, rtbh_rule_cache, results - ) - - # Verify subnet calculation was performed - mock_subtract.assert_called_once() - - # Verify new rules were created for the subnets - mock_create.assert_called_once() - - # Verify the original rule was cached - mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER) - - # Verify transaction was committed - mock_commit.assert_called_once() - - # Verify the flash messages - assert any("supernet of whitelist" in msg for msg in flashes) - - # Verify model was updated to whitelisted state - assert rtbh_rule.rstate_id == 4 - - def test_supernet_relation(self, app): - """Test evaluating a whitelist with a SUPERNET relation to a rule""" - # Create test data - whitelist_model = MagicMock(spec=Whitelist) - whitelist_model.id = 1 - - flashes = [] - - rtbh_rule = MagicMock(spec=RTBH) - rtbh_rule.rstate_id = 1 # Active state - - rule_key = "192.168.1.0/24" - whitelist_key = "192.168.0.0/16" - rtbh_rule_cache = {rule_key: rtbh_rule} - - results = [(rule_key, whitelist_key, Relation.SUPERNET)] - - # Mock required functions - with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( - "flowapp.services.whitelist_service.withdraw_rtbh_route" - ) as mock_withdraw: - - # Call the function - with app.app_context(): - result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( - whitelist_model, flashes, rtbh_rule_cache, results - ) - - # Verify the rule was whitelisted and route withdrawn - mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) - mock_withdraw.assert_called_once_with(rtbh_rule) - - # Verify the flash message - assert any("subnet of whitelist" in msg for msg in flashes) - - # Verify the correct model was returned - assert result == whitelist_model diff --git a/flowapp/tests/test_zzz_api_rtbh_expired_bug.py b/flowapp/tests/test_zzz_api_rtbh_expired_bug.py deleted file mode 100644 index 6fd36a2..0000000 --- a/flowapp/tests/test_zzz_api_rtbh_expired_bug.py +++ /dev/null @@ -1,240 +0,0 @@ -import json -from datetime import datetime, timedelta - -from flowapp.models import RTBH -from flowapp.models.rules.whitelist import Whitelist - - -def test_create_rtbh_after_expired_rule_exists(client, app, db, jwt_token): - """ - Test that demonstrates the bug: creating a new RTBH rule with the same IP - as an expired rule results in the new rule having withdrawn state instead - of active state. - - Test should be run in isolation or as the last in stack. - - Steps: - 1. Create an RTBH rule with expiration in the past (will be withdrawn, rstate_id=2) - 2. Create another RTBH rule with the same IP but expiration in the future - 3. Verify that the second rule should be active (rstate_id=1) but is actually withdrawn (rstate_id=2) - """ - # cleanup any existing RTBH rules to avoid interference - cleanup_before_stack(app, db) - - # Step 1: Create an expired RTBH rule - expired_payload = { - "community": 1, - "ipv4": "192.168.100.50", - "ipv4_mask": 32, - "expires": (datetime.now() - timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), - "comment": "Expired rule that should be in withdrawn state", - } - - response1 = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=expired_payload, - ) - - assert response1.status_code == 201 - data1 = json.loads(response1.data) - rule_id_1 = data1["rule"]["id"] - - # Verify the first rule is in withdrawn state - with app.app_context(): - expired_rule = db.session.query(RTBH).filter_by(id=rule_id_1).first() - assert expired_rule is not None - assert expired_rule.rstate_id == 2, "Expired rule should be in withdrawn state (rstate_id=2)" - assert expired_rule.ipv4 == "192.168.100.50" - assert expired_rule.ipv4_mask == 32 - print(f"✓ First rule created with ID {rule_id_1}, state: {expired_rule.rstate_id} (withdrawn)") - - # Step 2: Create a new RTBH rule with the same IP but future expiration - future_payload = { - "community": 1, - "ipv4": "192.168.100.50", - "ipv4_mask": 32, - "expires": (datetime.now() + timedelta(days=7)).strftime("%m/%d/%Y %H:%M"), - "comment": "New rule that should be active but will be withdrawn due to bug", - } - - response2 = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=future_payload, - ) - - assert response2.status_code == 201 - data2 = json.loads(response2.data) - rule_id_2 = data2["rule"]["id"] - - # Step 3: Verify the bug - the second rule should be active but is withdrawn - with app.app_context(): - # The bug causes the expired rule to be updated instead of creating a new one - # OR if a new rule is created, it has the wrong state - - # Check if it's the same rule (updated) or a new rule - total_rules = db.session.query(RTBH).filter_by(ipv4="192.168.100.50", ipv4_mask=32).count() - - new_rule = db.session.query(RTBH).filter_by(id=rule_id_2).first() - assert new_rule is not None - - print("\n--- Bug Verification ---") - print(f"Total rules with IP 192.168.100.50/32: {total_rules}") - print(f"First rule ID: {rule_id_1}") - print(f"Second rule ID: {rule_id_2}") - print(f"Same rule updated: {rule_id_1 == rule_id_2}") - print(f"Second rule state: {new_rule.rstate_id}") - print(f"Second rule expires: {new_rule.expires}") - print(f"Expiration is in future: {new_rule.expires > datetime.now()}") - - # The bug: even though expiration is in the future, the rule is in withdrawn state - # EXPECTED: rstate_id should be 1 (active) - # ACTUAL: rstate_id is 2 (withdrawn) due to the bug - - # This assertion will FAIL due to the bug, demonstrating the issue - assert new_rule.expires > datetime.now(), "Rule expiration should be in the future" - - # THIS IS THE BUG: The rule has future expiration but is in withdrawn state - try: - assert new_rule.rstate_id == 1, ( - f"BUG DETECTED: Rule with future expiration should be active (rstate_id=1), " - f"but is in state {new_rule.rstate_id}. " - f"This happens because the expired rule was found and updated without resetting the state." - ) - print("✓ Test PASSED - bug is fixed!") - except AssertionError as e: - print(f"✗ Test FAILED - bug confirmed: {e}") - raise - cleanup_rtbh_rule(app, db, rule_id_1) - cleanup_rtbh_rule(app, db, rule_id_2) - - -def test_create_rtbh_after_expired_rule_different_mask(client, app, db, jwt_token): - """ - Test that verifies the bug only occurs when IP AND mask match. - When the mask is different, a new rule should be created successfully. - """ - - # Step 1: Create an expired RTBH rule with /32 mask - expired_payload = { - "community": 1, - "ipv4": "192.168.100.60", - "ipv4_mask": 32, - "expires": (datetime.now() - timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), - "comment": "Expired /32 rule", - } - - response1 = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=expired_payload, - ) - - assert response1.status_code == 201 - - # Step 2: Create a new rule with same IP but different mask (/24) - future_payload = { - "community": 1, - "ipv4": "192.168.100.0", - "ipv4_mask": 24, - "expires": (datetime.now() + timedelta(days=7)).strftime("%m/%d/%Y %H:%M"), - "comment": "New /24 rule - should be active", - } - - response2 = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=future_payload, - ) - - assert response2.status_code == 201 - data2 = json.loads(response2.data) - - # Verify the new rule is active (this should work because IP+mask don't match) - with app.app_context(): - new_rule = db.session.query(RTBH).filter_by(id=data2["rule"]["id"]).first() - assert new_rule is not None - assert new_rule.rstate_id == 1, "New rule with different mask should be active" - print("✓ Different mask creates new active rule correctly") - - cleanup_rtbh_rule(app, db, new_rule.id) - - -def test_create_rtbh_after_active_rule_exists(client, app, db, jwt_token): - """ - Test that when an active rule exists, updating it with a new expiration - maintains the active state (this should work correctly). - """ - - # Step 1: Create an active RTBH rule - active_payload = { - "community": 1, - "ipv4": "192.168.100.70", - "ipv4_mask": 32, - "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), - "comment": "Active rule", - } - - response1 = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=active_payload, - ) - - assert response1.status_code == 201 - data1 = json.loads(response1.data) - rule_id_1 = data1["rule"]["id"] - - # Verify the first rule is active - with app.app_context(): - first_rule = db.session.query(RTBH).filter_by(id=rule_id_1).first() - assert first_rule.rstate_id == 1, "First rule should be active" - - # Step 2: Update the same rule with a new expiration - updated_payload = { - "community": 1, - "ipv4": "192.168.100.70", - "ipv4_mask": 32, - "expires": (datetime.now() + timedelta(days=7)).strftime("%m/%d/%Y %H:%M"), - "comment": "Updated active rule", - } - - response2 = client.post( - "/api/v3/rules/rtbh", - headers={"x-access-token": jwt_token}, - json=updated_payload, - ) - - assert response2.status_code == 201 - data2 = json.loads(response2.data) - - # Verify it maintains active state - with app.app_context(): - updated_rule = db.session.query(RTBH).filter_by(id=data2["rule"]["id"]).first() - assert updated_rule is not None - assert updated_rule.rstate_id == 1, "Updated rule should remain active" - print("✓ Updating active rule maintains active state correctly") - - cleanup_rtbh_rule(app, db, rule_id_1) - - -def cleanup_before_stack(app, db): - """ - Cleanup function to remove all RTBH rules created during tests. - """ - with app.app_context(): - db.session.query(RTBH).delete() - db.session.query(Whitelist).delete() - db.session.commit() - - -def cleanup_rtbh_rule(app, db, rule_id): - """ - Cleanup function to remove RTBH rule created during tests. - """ - with app.app_context(): - rule = db.session.get(RTBH, rule_id) - if rule: - db.session.delete(rule) - db.session.commit() From 09256d546135eb3d963af3c131b21b9639d63c2f Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 15 Jan 2026 18:29:30 +0100 Subject: [PATCH 07/10] move tests outside flowapp package --- tests/__init__.py | 0 tests/conftest.py | 232 +++++++ tests/rule_service_integration.py | 332 ++++++++++ tests/test_api_auth.py | 71 +++ tests/test_api_deprecated.py | 28 + tests/test_api_v3.py | 611 ++++++++++++++++++ tests/test_api_whitelist_integration.py | 273 ++++++++ tests/test_flowapp.py | 14 + tests/test_flowspec.py | 167 +++++ tests/test_forms.py | 65 ++ tests/test_forms_cl.py | 608 ++++++++++++++++++ tests/test_login.py | 0 tests/test_models.py | 498 +++++++++++++++ tests/test_rule_service.py | 628 +++++++++++++++++++ tests/test_rule_service_reactivate_delete.py | 527 ++++++++++++++++ tests/test_utils.py | 71 +++ tests/test_validators.py | 547 ++++++++++++++++ tests/test_whitelist_common.py | 250 ++++++++ tests/test_whitelist_service.py | 461 ++++++++++++++ tests/test_zzz_api_rtbh_expired_bug.py | 240 +++++++ 20 files changed, 5623 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/rule_service_integration.py create mode 100644 tests/test_api_auth.py create mode 100644 tests/test_api_deprecated.py create mode 100644 tests/test_api_v3.py create mode 100644 tests/test_api_whitelist_integration.py create mode 100644 tests/test_flowapp.py create mode 100644 tests/test_flowspec.py create mode 100644 tests/test_forms.py create mode 100644 tests/test_forms_cl.py create mode 100644 tests/test_login.py create mode 100644 tests/test_models.py create mode 100644 tests/test_rule_service.py create mode 100644 tests/test_rule_service_reactivate_delete.py create mode 100644 tests/test_utils.py create mode 100644 tests/test_validators.py create mode 100644 tests/test_whitelist_common.py create mode 100644 tests/test_whitelist_service.py create mode 100644 tests/test_zzz_api_rtbh_expired_bug.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3a988ef --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,232 @@ +""" +PyTest configuration file for all tests +""" + +import os +import json +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from flowapp import create_app +from flowapp import db as _db +from datetime import datetime +import flowapp.models +from flowapp.models.organization import Organization + + +TESTDB = "test_project.db" +TESTDB_PATH = "/tmp/{}".format(TESTDB) +TEST_DATABASE_URI = "sqlite:///" + TESTDB_PATH + + +class FieldMock: + def __init__(self): + self.data = None + self.errors = [] + + +class RuleMock: + def __init__(self): + self.source = None + self.source_mask = None + self.dest = None + self.dest_mask = None + + +@pytest.fixture +def field(): + return FieldMock() + + +@pytest.fixture +def field_class(): + return FieldMock + + +@pytest.fixture +def rule(): + return RuleMock() + + +@pytest.fixture(scope="session") +def app(request): + """ + Create a Flask app, and override settings, for the whole test session. + """ + + _app = create_app() + + _app.config.update( + EXA_API="HTTP", + EXA_API_URL="http://localhost:5000/", + TESTING=True, + SQLALCHEMY_DATABASE_URI=TEST_DATABASE_URI, + SQLALCHEMY_TRACK_MODIFICATIONS=False, + JWT_SECRET="testing", + API_KEY="testkey", + SECRET_KEY="testkeysession", + LOCAL_USER_UUID="jiri.vrany@cesnet.cz", + LOCAL_AUTH=True, + ALLOWED_COMMUNITIES=[1, 2, 3], + WTF_CSRF_ENABLED=False, + ) + + print("\n----- CREATE FLASK APPLICATION\n") + context = _app.app_context() + context.push() + yield _app + print("\n----- CREATE FLASK APPLICATION CONTEXT\n") + + context.pop() + print("\n----- RELEASE FLASK APPLICATION CONTEXT\n") + + +@pytest.fixture(scope="session") +def client(app, request): + """ + Get the test_client from the app, for the whole test session. + """ + print("\n----- CREATE FLASK TEST CLIENT\n") + return app.test_client() + + +@pytest.fixture(scope="session") +def db(app, request): + """ + Create entire database for every test. + """ + engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) + sessionmaker(bind=engine) + print("\n----- CREATE TEST DB CONNECTION POOL\n") + if os.path.exists(TESTDB_PATH): + os.unlink(TESTDB_PATH) + + with app.app_context(): + _db.init_app(app) + print("#: cleaning database") + _db.reflect() + _db.drop_all() + print("#: creating tables") + _db.create_all() + + users = [ + {"name": "jiri.vrany@cesnet.cz", "role_id": 3, "org_id": 1}, + {"name": "petr.adamec@cesnet.cz", "role_id": 3, "org_id": 1}, + ] + print("#: inserting users") + flowapp.models.insert_users(users) + + org = _db.session.query(Organization).filter_by(id=1).first() + # Update the organization address range to include our test networks + org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" + _db.session.commit() + print("\n----- UPDATED TEST ORG 1 \n", org) + + def teardown(): + _db.session.commit() + _db.drop_all() + os.unlink(TESTDB_PATH) + + request.addfinalizer(teardown) + return _db + + +@pytest.fixture(scope="session") +def jwt_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + mkey = "testkey" + + with app.app_context(): + model = flowapp.models.ApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) + db.session.add(model) + db.session.commit() + + print("\n----- GET JWT TEST TOKEN\n") + url = "/api/v3/auth" + headers = {"x-api-key": mkey} + token = client.get(url, headers=headers) + data = json.loads(token.data) + return data["token"] + + +@pytest.fixture(scope="session") +def machine_api_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + mkey = "machinetestkey" + + with app.app_context(): + model = flowapp.models.MachineApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) + db.session.add(model) + db.session.commit() + + print("\n----- GET MACHINE API KEY TEST TOKEN\n") + url = "/api/v3/auth" + headers = {"x-api-key": mkey} + token = client.get(url, headers=headers) + data = json.loads(token.data) + return data["token"] + + +@pytest.fixture(scope="session") +def expired_auth_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + test_key = "expired_test_key" + expired_date = datetime.strptime("2019-01-01", "%Y-%m-%d") + with app.app_context(): + model = flowapp.models.ApiKey(machine="127.0.0.1", key=test_key, user_id=1, expires=expired_date, org_id=1) + db.session.add(model) + db.session.commit() + + return test_key + + +@pytest.fixture(scope="session") +def readonly_jwt_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + readonly_key = "readonly-testkey" + with app.app_context(): + model = flowapp.models.ApiKey(machine="127.0.0.1", key=readonly_key, user_id=1, readonly=True, org_id=1) + db.session.add(model) + db.session.commit() + + print("\n----- GET JWT TEST TOKEN\n") + url = "/api/v3/auth" + headers = {"x-api-key": readonly_key} + token = client.get(url, headers=headers) + data = json.loads(token.data) + return data["token"] + + +@pytest.fixture(scope="session") +def auth_client(client): + """ + Get the test_client from the app, for the whole test session. + """ + print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n") + client.get("/local-login") + return client + + +@pytest.fixture(autouse=True) +def reset_org_limits(db, app): + """ + Automatically reset organization limits after each test that modifies them. + """ + yield # Allow test execution + + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + if org: + org.limit_flowspec4 = 0 + org.limit_flowspec6 = 0 + org.limit_rtbh = 0 + db.session.commit() diff --git a/tests/rule_service_integration.py b/tests/rule_service_integration.py new file mode 100644 index 0000000..106000a --- /dev/null +++ b/tests/rule_service_integration.py @@ -0,0 +1,332 @@ +"""Integration tests for rule_service module.""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp import db, create_app +from flowapp.constants import RuleTypes +from flowapp.models import ( + RTBH, + Flowspec4, + Flowspec6, + Whitelist, + User, + Organization, + Role, + Community, + Action, + Rstate, +) +from flowapp.services import rule_service + + +@pytest.fixture(scope="module") +def app(): + """Create and configure a Flask app for testing.""" + app = create_app("testing") + + # Create a test context + with app.app_context(): + # Create database tables + db.create_all() + + # Create test data + _create_test_data() + + yield app + + # Clean up after tests + db.session.remove() + db.drop_all() + + +@pytest.fixture(scope="function") +def app_context(app): + """Provide an application context for each test.""" + with app.app_context(): + yield + + +@pytest.fixture(scope="function") +def mock_announce_route(): + """Mock the announce_route function.""" + with patch("flowapp.services.rule_service.announce_route") as mock: + yield mock + + +def _create_test_data(): + """Create test data in the database.""" + # Create roles + admin_role = Role(name="admin", description="Administrator") + user_role = Role(name="user", description="Normal user") + view_role = Role(name="view", description="View only") + db.session.add_all([admin_role, user_role, view_role]) + db.session.flush() + + # Create organizations + org1 = Organization(name="Test Org 1", arange="192.168.1.0/24\n2001:db8::/64") + org2 = Organization(name="Test Org 2", arange="10.0.0.0/8") + db.session.add_all([org1, org2]) + db.session.flush() + + # Create users + user1 = User(uuid="user1@example.com", name="User One") + user1.role.append(user_role) + user1.organization.append(org1) + + admin1 = User(uuid="admin@example.com", name="Admin User") + admin1.role.append(admin_role) + admin1.organization.append(org2) + + db.session.add_all([user1, admin1]) + db.session.flush() + + # Create rule states + state_active = Rstate(description="active rule") + state_withdrawn = Rstate(description="withdrawed rule") + state_deleted = Rstate(description="deleted rule") + state_whitelisted = Rstate(description="whitelisted rule") + db.session.add_all([state_active, state_withdrawn, state_deleted, state_whitelisted]) + db.session.flush() + + # Create actions + action1 = Action(name="Discard", command="discard", description="Drop traffic", role_id=user_role.id) + action2 = Action(name="Rate limit", command="rate-limit 1000000", description="Limit traffic", role_id=user_role.id) + db.session.add_all([action1, action2]) + db.session.flush() + + # Create communities + community1 = Community( + name="65000:1", + comm="65000:1", + larcomm="", + extcomm="", + description="Test community", + role_id=user_role.id, + as_path=False, + ) + db.session.add(community1) + db.session.flush() + + # Create some rules + # IPv4 rule + ipv4_rule = Flowspec4( + source="192.168.1.1", + source_mask=32, + source_port="", + dest="192.168.2.1", + dest_mask=32, + dest_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=datetime.now() + timedelta(hours=1), + comment="Test IPv4 rule", + action_id=action1.id, + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(ipv4_rule) + + # IPv6 rule + ipv6_rule = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="", + dest="2001:db8:1::1", + dest_mask=128, + dest_port="", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(hours=1), + comment="Test IPv6 rule", + action_id=action1.id, + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(ipv6_rule) + + # RTBH rule + rtbh_rule = RTBH( + ipv4="192.168.1.100", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=community1.id, + expires=datetime.now() + timedelta(hours=1), + comment="Test RTBH rule", + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(rtbh_rule) + + # Save IDs as class variables for tests to use + _create_test_data.user_id = user1.id + _create_test_data.admin_id = admin1.id + _create_test_data.org1_id = org1.id + _create_test_data.org2_id = org2.id + _create_test_data.ipv4_rule_id = ipv4_rule.id + _create_test_data.ipv6_rule_id = ipv6_rule.id + _create_test_data.rtbh_rule_id = rtbh_rule.id + _create_test_data.action_id = action1.id + _create_test_data.community_id = community1.id + + db.session.commit() + + +class TestRuleServiceIntegration: + """Integration tests for rule_service functions.""" + + def test_reactivate_rule_ipv4(self, app_context, mock_announce_route): + """Test reactivating an IPv4 rule.""" + # Test data + rule_id = _create_test_data.ipv4_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Set new expiration time (2 hours in the future) + new_expires = datetime.now() + timedelta(hours=2) + new_comment = "Updated test comment" + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=rule_id, + expires=new_expires, + comment=new_comment, + user_id=user_id, + org_id=org_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was updated + assert model is not None + assert model.id == rule_id + assert model.comment == new_comment + assert model.expires == new_expires + assert model.rstate_id == 1 # active state + + # Verify route announcement was attempted + assert mock_announce_route.called + + # Verify message + assert messages == ["Rule successfully updated"] + + def test_reactivate_rule_inactive(self, app_context, mock_announce_route): + """Test reactivating a rule to inactive state.""" + # Test data + rule_id = _create_test_data.ipv4_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Set past expiration time + past_expires = datetime.now() - timedelta(hours=1) + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=rule_id, + expires=past_expires, + comment="Expired comment", + user_id=user_id, + org_id=org_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was updated + assert model is not None + assert model.id == rule_id + assert model.expires == past_expires + assert model.rstate_id == 2 # inactive/withdrawn state + + # Verify route announcement was attempted + assert mock_announce_route.called + + def test_delete_rule(self, app_context, mock_announce_route): + """Test deleting a rule.""" + # Test data + rule_id = _create_test_data.ipv6_rule_id + user_id = _create_test_data.user_id + + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv6, + rule_id=rule_id, + user_id=user_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was deleted + assert success is True + assert message == "Rule deleted successfully" + + # Verify the rule no longer exists in the database + rule = db.session.get(Flowspec6, rule_id) + assert rule is None + + # Verify route withdrawal was attempted + assert mock_announce_route.called + + def test_delete_rule_not_found(self, app_context): + """Test deleting a non-existent rule.""" + # Use a non-existent rule ID + non_existent_id = 9999 + + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=non_existent_id, + user_id=_create_test_data.user_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the operation failed + assert success is False + assert message == "Rule not found" + + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_delete_rtbh_and_create_whitelist(self, mock_create_whitelist, app_context, mock_announce_route): + """Test deleting an RTBH rule and creating a whitelist entry.""" + # Test data + rule_id = _create_test_data.rtbh_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Mock whitelist creation + mock_whitelist = Whitelist( + ip="192.168.1.100", mask=32, expires=datetime.now() + timedelta(days=7), user_id=user_id, org_id=org_id + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=rule_id, user_id=user_id, org_id=org_id, user_email="test@example.com", org_name="Test Org" + ) + + # Verify success + assert success is True + assert len(messages) == 2 + assert "Rule deleted successfully" in messages[0] + assert "Whitelist created" in messages[1] + + # Verify the rule was deleted + rule = db.session.get(RTBH, rule_id) + assert rule is None + + # Verify create_or_update_whitelist was called with correct data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "192.168.1.100" + assert form_data["mask"] == 32 + assert "Created from RTBH rule" in form_data["comment"] diff --git a/tests/test_api_auth.py b/tests/test_api_auth.py new file mode 100644 index 0000000..8d08a7a --- /dev/null +++ b/tests/test_api_auth.py @@ -0,0 +1,71 @@ +# Test for api authorization +import json + + +def test_token(client, jwt_token): + """ + test that token authorization works + """ + req = client.get("/api/v3/test_token", headers={"x-access-token": jwt_token}) + + assert req.status_code == 200 + + +def test_machine_token(client, machine_api_token): + """ + test that token authorization works + """ + req = client.get("/api/v3/test_token", headers={"x-access-token": machine_api_token}) + + assert req.status_code == 200 + + +def test_expired_token(client, expired_auth_token): + """ + test that expired token authorization return 401 + """ + req = client.get("/api/v3/auth", headers={"x-api-key": expired_auth_token}) + + assert req.status_code == 401 + + +def test_withnout_token(client): + """ + test that without token authorization return 401 + """ + req = client.get("/api/v3/test_token") + + assert req.status_code == 401 + + +def test_readonly_token(client, readonly_jwt_token): + """ + test that readonly flag is set correctly + """ + req = client.get("/api/v3/test_token", headers={"x-access-token": readonly_jwt_token}) + + assert req.status_code == 200 + data = json.loads(req.data) + assert data["readonly"] + + +def test_readonly_token_ipv4_create(client, db, readonly_jwt_token): + """ + test that readonly token can't create ipv4 rule + """ + headers = {"x-access-token": readonly_jwt_token} + + req = client.post( + "/api/v3/rules/ipv4", + headers=headers, + json={ + "action": 2, + "protocol": "tcp", + "source": "147.230.17.117", + "source_mask": 32, + "source_port": "", + "expires": "1444913400", + }, + ) + + assert req.status_code == 403 diff --git a/tests/test_api_deprecated.py b/tests/test_api_deprecated.py new file mode 100644 index 0000000..fca9414 --- /dev/null +++ b/tests/test_api_deprecated.py @@ -0,0 +1,28 @@ +V_PREFIX = "/api/v1" + + +def test_token(client, jwt_token): + """ + test that token authorization works + """ + req = client.get(f"{V_PREFIX}/test_token", headers={"x-access-token": jwt_token}) + + assert req.status_code == 400 + + +def test_withnout_token(client): + """ + test that without token authorization return 401 + """ + req = client.get(f"{V_PREFIX}/test_token") + + assert req.status_code == 400 + + +def test_rules(client, db, jwt_token): + """ + test that there is one ipv4 rule created in the first test + """ + req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": jwt_token}) + + assert req.status_code == 400 diff --git a/tests/test_api_v3.py b/tests/test_api_v3.py new file mode 100644 index 0000000..0548532 --- /dev/null +++ b/tests/test_api_v3.py @@ -0,0 +1,611 @@ +import json + + +from flowapp.models import Flowspec4, Organization + +V_PREFIX = "/api/v3" + + +def test_create_rtbh_only(client, app, db, jwt_token): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + + # Now create the RTBH rule via API + res = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + + # Verify response is successful + assert res.status_code == 201 + data = json.loads(res.data) + assert data["rule"] is not None + + +def test_token(client, jwt_token): + """ + test that token authorization works + """ + req = client.get(f"{V_PREFIX}/test_token", headers={"x-access-token": jwt_token}) + + assert req.status_code == 200 + + +def test_withnout_token(client): + """ + test that without token authorization return 401 + """ + req = client.get(f"{V_PREFIX}/test_token") + + assert req.status_code == 401 + + +def test_list_actions(client, db, jwt_token): + """ + test that endpoint returns all action in db + """ + req = client.get(f"{V_PREFIX}/actions", headers={"x-access-token": jwt_token}) + + assert req.status_code == 200 + data = json.loads(req.data) + assert len(data) == 4 + + +def test_list_communities(client, db, jwt_token): + """ + test that endpoint returns all action in db + """ + req = client.get(f"{V_PREFIX}/communities", headers={"x-access-token": jwt_token}) + + assert req.status_code == 200 + data = json.loads(req.data) + assert len(data) == 3 + + +def test_create_v4rule(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "protocol": "tcp", + "source": "147.230.17.17", + "source_mask": 32, + "source_port": "", + "expires": "10/15/2050 14:46", + "flags": ["SYN", "RST"], + }, + ) + + assert req.status_code == 201 + data = json.loads(req.data) + assert data["rule"] + assert data["rule"]["id"] == 1 + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + + +def test_delete_v4rule(client, db, jwt_token): + """ + test that creating with valid data returns 201 + that time in the past creates expired rule (state 2) + and that the rule deletion works as expected + """ + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "protocol": "tcp", + "source": "147.230.17.12", + "source_mask": 32, + "source_port": "", + "expires": "10/15/2015 14:46", + }, + ) + + assert req.status_code == 201 + data = json.loads(req.data) + assert data["rule"]["id"] == 2 + assert data["rule"]["rstate"] == "withdrawed rule" + + req2 = client.delete( + f'{V_PREFIX}/rules/ipv4/{data["rule"]["id"]}', + headers={"x-access-token": jwt_token}, + ) + assert req2.status_code == 201 + + +def test_create_rtbh_rule(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + data = json.loads(req.data) + assert req.status_code == 201 + assert data["rule"] + assert data["rule"]["id"] == 1 + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + + +def test_delete_rtbh_rule(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 2, + "ipv4": "147.230.17.177", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + + assert req.status_code == 201 + data = json.loads(req.data) + assert data["rule"]["id"] == 2 + req2 = client.delete( + f'{V_PREFIX}/rules/rtbh/{data["rule"]["id"]}', + headers={"x-access-token": jwt_token}, + ) + assert req2.status_code == 201 + + +def test_validation_rtbh_rule(client, db, jwt_token): + """ + test that creating with invalid data returns 400 and errors + """ + req = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "ipv6": "2001:718:1C01:1111::", + "ipv6_mask": 128, + "expires": "10/25/2050 14:46", + }, + ) + data = json.loads(req.data) + assert req.status_code == 400 + assert data["message"] == "error - invalid request data" + assert type(data["validation_errors"]) is dict + assert "ipv6" in data["validation_errors"] + assert "ipv4" in data["validation_errors"] + + +def test_create_v6rule(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/ipv6", + headers={"x-access-token": jwt_token}, + json={ + "action": 3, + "next_header": "tcp", + "source": "2001:718:1C01:1111::", + "source_mask": 64, + "source_port": "", + "expires": "10/25/2050 14:46", + }, + ) + data = json.loads(req.data) + assert req.status_code == 201 + assert data["rule"] + assert data["rule"]["id"] == "1" + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + + +def test_validation_v4rule(client, db, jwt_token): + """ + test that creating with invalid data returns 400 and errors + """ + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "dest": "200.200.200.32", + "dest_mask": 16, + "protocol": "tcp", + "source": "1.1.1.1", + "source_mask": 32, + "source_port": "", + "expires": "10/15/2050 14:46", + }, + ) + + assert req.status_code == 400 + data = json.loads(req.data) + assert len(data["validation_errors"]) > 0 + assert sorted(data["validation_errors"].keys()) == sorted(["dest", "source"]) + assert len(data["validation_errors"]["dest"]) == 2 + assert data["validation_errors"]["dest"][0].startswith("This is not") + assert data["validation_errors"]["dest"][1].startswith("Source or des") + assert len(data["validation_errors"]["source"]) == 1 + assert data["validation_errors"]["source"][0].startswith("Source or des") + + +def test_all_validation_errors(client, db, jwt_token): + """ + test that creating with invalid data returns 400 and errors + """ + req = client.post(f"{V_PREFIX}/rules/ipv4", headers={"x-access-token": jwt_token}, json={"action": 2}) + assert req.status_code == 400 + + +def test_validate_v6rule(client, db, jwt_token): + """ + test that creating with invalid data returns 400 and errors + """ + req = client.post( + f"{V_PREFIX}/rules/ipv6", + headers={"x-access-token": jwt_token}, + json={ + "action": 32, + "next_header": "abc", + "source": "2011:78:1C01:1111::", + "source_mask": 64, + "source_port": "", + "expires": "10/25/2050 14:46", + }, + ) + data = json.loads(req.data) + assert req.status_code == 400 + assert len(data["validation_errors"]) > 0 + assert sorted(data["validation_errors"].keys()) == sorted(["action", "next_header", "dest", "source"]) + # assert data['validation_errors'][0].startswith('Error in the Action') + # assert data['validation_errors'][1].startswith('Error in the Source') + # assert data['validation_errors'][2].startswith('Error in the Next Header') + + +def test_rules(client, db, jwt_token): + """ + test that there is one ipv4 rule created in the first test + """ + req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": jwt_token}) + + assert req.status_code == 200 + + data = json.loads(req.data) + assert len(data["flowspec_ipv4_rw"]) == 1 + assert len(data["flowspec_ipv6_rw"]) == 1 + + +def test_timestamp_param(client, db, jwt_token): + """ + test that url param for time format works as expected + """ + req = client.get(f"{V_PREFIX}/rules?time_format=timestamp", headers={"x-access-token": jwt_token}) + + assert req.status_code == 200 + + data = json.loads(req.data) + assert data["flowspec_ipv4_rw"][0]["expires"] == 2549451000 + assert data["flowspec_ipv6_rw"][0]["expires"] == 2550315000 + + +def test_update_existing_v4rule_with_timestamp(client, db, jwt_token): + """ + test that update with different data passes + """ + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "protocol": "tcp", + "source": "147.230.17.17", + "source_mask": 32, + "source_port": "", + "expires": "1444913400", + }, + ) + + assert req.status_code == 201 + data = json.loads(req.data) + assert data["rule"] + assert data["rule"]["id"] == 2 + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + assert data["rule"]["expires"] == 1444913400 + + +def test_create_v4rule_with_timestamp(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "protocol": "tcp", + "source": "147.230.17.117", + "source_mask": 32, + "source_port": "", + "expires": "1444913400", + }, + ) + + assert req.status_code == 201 + data = json.loads(req.data) + assert data["rule"] + assert data["rule"]["id"] == 3 + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + assert data["rule"]["expires"] == 1444913400 + + +def test_update_existing_v6rule_with_timestamp(client, db, jwt_token): + """ + test that update with different data passes + """ + req = client.post( + f"{V_PREFIX}/rules/ipv6", + headers={"x-access-token": jwt_token}, + json={ + "action": 3, + "next_header": "tcp", + "source": "2001:718:1C01:1111::", + "source_mask": 64, + "source_port": "", + "expires": "1444913400", + }, + ) + data = json.loads(req.data) + assert req.status_code == 201 + assert data["rule"] + assert data["rule"]["id"] == "1" + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + assert data["rule"]["expires"] == 1444913400 + + +def test_create_v6rule_with_timestamp(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/ipv6", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "next_header": "udp", + "source": "2001:718:1C01:1111::", + "source_mask": 64, + "source_port": "", + "expires": "2549952908", + }, + ) + data = json.loads(req.data) + assert req.status_code == 201 + assert data["rule"] + assert data["rule"]["id"] == "2" + assert data["rule"]["rstate"] == "active rule" + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + assert data["rule"]["expires"] == 2549953200 + + +def test_update_existing_rtbh_rule_with_timestamp(client, db, jwt_token): + """ + test that update with different data passes + """ + req = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "1444913400", + }, + ) + data = json.loads(req.data) + assert req.status_code == 201 + assert data["rule"] + assert data["rule"]["id"] == 1 + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + assert data["rule"]["expires"] == 1444913400 + + +def test_create_rtbh_rule_with_timestamp(client, db, jwt_token): + """ + test that creating with valid data returns 201 + """ + req = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.117", + "ipv4_mask": 32, + "expires": "1444913400", + }, + ) + data = json.loads(req.data) + assert req.status_code == 201 + assert data["rule"] + assert data["rule"]["id"] == 2 + assert data["rule"]["user"] == "jiri.vrany@cesnet.cz" + assert data["rule"]["expires"] == 1444913400 + + +def test_create_v4rule_lmit(client, db, app, jwt_token): + """ + test that limit checkt for v4 works + """ + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + org.limit_flowspec4 = 2 + db.session.commit() + + # count + count = db.session.query(Flowspec4).count() + print("COUNT", count) + + sources = ["147.230.42.17", "147.230.42.118"] + codes = [201, 403] + + for source, code in zip(sources, codes): + data = { + "action": 1, + "protocol": "tcp", + "source": source, + "source_mask": 32, + "source_port": "", + "expires": "10/15/2050 14:46", + } + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json=data, + ) + + assert req.status_code == code + + +def test_create_v6rule_lmit(client, db, app, jwt_token): + """ + test that limit check for v6 works + """ + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + org.limit_flowspec6 = 3 + db.session.commit() + + sources = ["2001:718:1C01:1111::", "2001:718:1C01:1112::"] + codes = [201, 403] + + for source, code in zip(sources, codes): + data = { + "action": 1, + "next_header": "tcp", + "source": source, + "source_mask": 64, + "source_port": "", + "expires": "10/15/2050 14:46", + } + req = client.post( + f"{V_PREFIX}/rules/ipv6", + headers={"x-access-token": jwt_token}, + json=data, + ) + + assert req.status_code == code + + +def test_create_rtbh_lmit(client, db, app, jwt_token): + """ + test that limit check for v6 works + """ + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + org.limit_rtbh = 1 + db.session.commit() + + sources = ["147.230.17.42", "147.230.17.43"] + codes = [201, 403] + + for source, code in zip(sources, codes): + data = { + "community": 1, + "ipv4": source, + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + } + req = client.post(f"{V_PREFIX}/rules/rtbh", headers={"x-access-token": jwt_token}, json=data) + + assert req.status_code == code + + +def test_update_existing_v4rule_with_timestamp_limit(client, db, app, jwt_token): + """ + test that update with different data passes + """ + with app.app_context(): + # count + count = db.session.query(Flowspec4).filter_by(org_id=1, rstate_id=1).count() + print("COUNT in update", count) + + org = db.session.query(Organization).filter_by(id=1).first() + org.limit_flowspec4 = count + db.session.commit() + + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json={ + "action": 2, + "protocol": "tcp", + "source": "147.230.17.17", + "source_mask": 32, + "source_port": "", + "expires": "2552634908", + }, + ) + + assert req.status_code == 403 + data = json.loads(req.data) + assert data["message"] + assert data["message"].startswith("Rule limit") + + +def test_overall_limit(client, db, app, jwt_token): + """ + test that update with different data passes + """ + app.config.update({"FLOWSPEC4_MAX_RULES": 5, "FLOWSPEC6_MAX_RULES": 5, "RTBH_MAX_RULES": 5}) + + with app.app_context(): + # count + + org = db.session.query(Organization).filter_by(id=1).first() + org.limit_flowspec4 = 20 + db.session.commit() + + sources = ["147.230.42.1", "147.230.42.2", "147.230.42.3", "147.230.42.4"] + codes = [201, 201, 201, 403] + + for source, code in zip(sources, codes): + data = { + "action": 1, + "protocol": "tcp", + "source": source, + "source_mask": 32, + "source_port": "", + "expires": "10/15/2050 14:46", + } + req = client.post( + f"{V_PREFIX}/rules/ipv4", + headers={"x-access-token": jwt_token}, + json=data, + ) + print(source) + assert req.status_code == code + + data = json.loads(req.data) + assert data["message"] + assert data["message"].startswith("System limit") diff --git a/tests/test_api_whitelist_integration.py b/tests/test_api_whitelist_integration.py new file mode 100644 index 0000000..0391a76 --- /dev/null +++ b/tests/test_api_whitelist_integration.py @@ -0,0 +1,273 @@ +""" +Integration test for the API view when interacting with whitelists. + +This test suite verifies the behavior when creating RTBH rules through the API +when there are existing whitelists that could affect the rules. +""" + +import json +import pytest +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import RTBH, RuleWhitelistCache, Organization +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_data(): + """Create whitelist data for testing""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist for API integration test", + "expires": datetime.now() + timedelta(days=1), + } + + +@pytest.fixture +def rtbh_api_payload(): + """Create RTBH rule data for API testing""" + return { + "community": 1, + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Test RTBH rule via API", + } + + +def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + # First, configure app to include community ID 1 in allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # Create the whitelist directly using the service + with app.app_context(): + # Create user and organization if needed for the whitelist + org = db.session.query(Organization).first() + + # Create the whitelist + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, + user_id=1, # Using user ID 1 from test fixtures + org_id=org.id, + user_email="test@example.com", + org_name=org.name, + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + assert whitelist_model.ip == whitelist_data["ip"] + assert whitelist_model.mask == whitelist_data["mask"] + + # Now create the RTBH rule via API + response = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=rtbh_api_payload, + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + +def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is a supernet of an existing whitelist. + + The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules + should be created for the non-whitelisted parts. + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a subnet + whitelist_data["ip"] = "192.168.1.128" + whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that covers both the whitelist and additional space + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", + headers=headers, + json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created and marked as whitelisted + with app.app_context(): + # Check the original rule + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Check if a new subnet rule was created for the non-whitelisted part + subnet_rule = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == "192.168.1.0", + RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist + ) + .first() + ) + + assert subnet_rule is not None + assert subnet_rule.rstate_id == 1 # Active status + + # Verify cache entries + # Main rule should be cached as a USER rule + user_cache = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value + ).first() + assert user_cache is not None + + # Subnet rule should be cached as a WHITELIST rule + whitelist_cache = RuleWhitelistCache.query.filter_by( + rid=subnet_rule.id, + rtype=RuleTypes.RTBH.value, + whitelist_id=whitelist_model.id, + rorigin=RuleOrigin.WHITELIST.value, + ).first() + assert whitelist_cache is not None + + +def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is contained within an existing whitelist. + + The rule should be created but immediately marked as whitelisted (rstate_id=4). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a supernet + whitelist_data["ip"] = "192.168.0.0" + whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule + + # Create the whitelist directly using the service + with app.app_context(): + all_rtbh_rules_before = db.session.query(RTBH).count() + + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that is inside the whitelist + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + # Verify no additional rules were created + all_rtbh_rules = db.session.query(RTBH).count() + assert all_rtbh_rules - all_rtbh_rules_before == 1 + + +def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that has no relation to any existing whitelist. + + The rule should be created normally with active state (rstate_id=1). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a completely different network + whitelist_data["ip"] = "10.0.0.0" + whitelist_data["mask"] = 8 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API for a network not covered by any of the whitelists + rtbh_api_payload["ipv4"] = "147.230.17.185" + rtbh_api_payload["ipv4_mask"] = 32 + + headers = {"x-access-token": jwt_token} + response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24 + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created with active state + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 1 # 1 = active state + + # Verify no cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() + + assert cache_entry is None diff --git a/tests/test_flowapp.py b/tests/test_flowapp.py new file mode 100644 index 0000000..f71e4ae --- /dev/null +++ b/tests/test_flowapp.py @@ -0,0 +1,14 @@ +def test_dashboard_not_auth(client): + + response = client.get("/dashboard/ipv4/active/?sort=expires&order=desc") + + # Expecting a 302 redirect to login + assert response.status_code == 302 + + +def test_dashboard(auth_client): + + response = auth_client.get("/dashboard/ipv4/active/?sort=expires&order=desc") + + # Check that the request is successful and renders the correct template + assert response.status_code == 200 # Expecting a 200 OK if the user is authenticated diff --git a/tests/test_flowspec.py b/tests/test_flowspec.py new file mode 100644 index 0000000..70000b8 --- /dev/null +++ b/tests/test_flowspec.py @@ -0,0 +1,167 @@ +import pytest +from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit + + +def test_translate_number(): + """ + tests for x (integer) to =x + """ + assert "[=10]" == translate_sequence("10") + + +def test_raises(): + """ + tests for translator + """ + with pytest.raises(ValueError): + translate_sequence("ahoj") + + +def test_raises_bad_number(): + """ + tests for translator + """ + with pytest.raises(ValueError): + translate_sequence("75555") + + +def test_translate_range(): + """ + tests for x-y to >=x&<=y + """ + assert "[>=10&<=20]" == translate_sequence("10-20") + + +def test_exact_rule(): + """ + test for >=x&<=y to >=x&<=y + """ + assert "[>=10&<=20]" == translate_sequence(">=10&<=20") + + +def test_greater_than(): + """ + test for >x to >=x&<=65535 + """ + assert "[>=10&<=65535]" == translate_sequence(">10") + + +def test_greater_equal_than(): + """ + test for >=x to >=x&<=65535 + """ + assert "[>=10&<=65535]" == translate_sequence(">=10") + + +def test_lower_than(): + """ + test for =0&<=0 + """ + assert "[>=0&<=10]" == translate_sequence("<10") + + +def test_lower_equal_than(): + """ + test for =0&<=0 + """ + assert "[>=0&<=10]" == translate_sequence("<=10") + + +# new tests + + +def test_multiple_sequences(): + """Test multiple sequences separated by semicolons""" + assert "[=10 >=20&<=30]" == translate_sequence("10;20-30") + assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30") + + +def test_empty_sequence(): + """Test empty sequence and sequences with empty parts""" + assert "[]" == translate_sequence("") + assert "[=10]" == translate_sequence("10;;") + assert "[=10 =20]" == translate_sequence("10;;20") + + +def test_range_edge_cases(): + """Test edge cases for ranges""" + # Same numbers in range + assert "[>=10&<=10]" == translate_sequence("10-10") + + # Invalid range (start > end) + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence("20-10") + + # Invalid range in exact rule + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence(">=20&<=10") + + +def test_check_limit_validation(): + """Test the check_limit function""" + # Test valid cases + assert check_limit(10, 100) == 10 + assert check_limit(0, 100, 0) == 0 + assert check_limit(100, 100, 0) == 100 + + # Test invalid cases + with pytest.raises(ValueError, match="Invalid value number: .* is too big"): + check_limit(101, 100) + + with pytest.raises(ValueError, match="Invalid value number: .* is too small"): + check_limit(-1, 100, 0) + + +def test_invalid_inputs(): + """Test various invalid inputs""" + invalid_inputs = [ + "abc", # Non-numeric + "10.5", # Decimal + "10,20", # Wrong separator + "10&20", # Wrong format + ">", # Incomplete + ">=", # Incomplete + "<", # Incomplete + "<=", # Incomplete + ">>10", # Invalid operator + "<<10", # Invalid operator + ">=10&", # Incomplete range + "&<=10", # Incomplete range + ] + + for invalid_input in invalid_inputs: + with pytest.raises(ValueError): + translate_sequence(invalid_input) + + +class MockRule: + def __init__(self, action_id): + self.action_id = action_id + + +def test_filter_rules_action(): + """Test the filter_rules_action function""" + # Create test rules + rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)] + + # Test with empty allowed actions + editable, viewonly = filter_rules_action([], rules) + assert len(editable) == 0 + assert len(viewonly) == 4 + + # Test with some allowed actions + editable, viewonly = filter_rules_action([1, 3], rules) + assert len(editable) == 2 + assert len(viewonly) == 2 + assert all(rule.action_id in [1, 3] for rule in editable) + assert all(rule.action_id in [2, 4] for rule in viewonly) + + # Test with all actions allowed + editable, viewonly = filter_rules_action([1, 2, 3, 4], rules) + assert len(editable) == 4 + assert len(viewonly) == 0 + + # Test with empty rules list + editable, viewonly = filter_rules_action([1, 2], []) + assert len(editable) == 0 + assert len(viewonly) == 0 diff --git a/tests/test_forms.py b/tests/test_forms.py new file mode 100644 index 0000000..e9620d8 --- /dev/null +++ b/tests/test_forms.py @@ -0,0 +1,65 @@ +import pytest +import flowapp.forms + + +@pytest.fixture() +def ip_form(app, field_class): + with app.test_request_context(): # Push the request context + form = flowapp.forms.IPForm() + form.source = field_class() + form.dest = field_class() + form.source_mask = field_class() + form.dest_mask = field_class() + return form + + +def test_ip_form_created(ip_form): + assert ip_form.source.data is None + assert ip_form.source.errors == [] + + +@pytest.mark.parametrize( + "address, mask, expected", + [ + ("147.230.23.25", "24", False), + ("147.230.23.0", "24", True), + ("0.0.0.0", "0", True), + ("2001:718:1C01:1111::1111", "64", False), + ("2001:718:1C01:1111::", "64", True), + ], +) +def test_ip_form_validate_source_address(ip_form, address, mask, expected): + ip_form.source.data = address + ip_form.source_mask.data = mask + assert ip_form.validate_source_address() == expected + + +@pytest.mark.parametrize( + "address, mask, expected", + [ + ("147.230.23.25", "24", False), + ("147.230.23.0", "24", True), + ("0.0.0.0", "0", True), + ("2001:718:1C01:1111::1111", "64", False), + ("2001:718:1C01:1111::", "64", True), + ], +) +def test_ip_form_validate_dest_address(ip_form, address, mask, expected): + ip_form.dest.data = address + ip_form.dest_mask.data = mask + assert ip_form.validate_dest_address() == expected + + +@pytest.mark.parametrize( + "address, mask, ranges, expected", + [ + ("147.230.23.0", "24", ["147.230.0.0/16", "2001:718:1c01::/48"], True), + ("0.0.0.0", "0", ["147.230.0.0/16", "2001:718:1c01::/48"], False), + ("195.113.0.0", "16", ["195.113.0.0/18", "195.113.64.0/21"], False), + ], +) +def test_ip_form_validate_address_mask(ip_form, address, mask, ranges, expected): + ip_form.net_ranges = ranges + ip_form.source.data = address + ip_form.source_mask.data = mask + assert ip_form.validate_address_ranges() == expected diff --git a/tests/test_forms_cl.py b/tests/test_forms_cl.py new file mode 100644 index 0000000..db0a569 --- /dev/null +++ b/tests/test_forms_cl.py @@ -0,0 +1,608 @@ +import pytest +from datetime import datetime, timedelta +from werkzeug.datastructures import MultiDict +from flowapp.forms import ( + UserForm, + BulkUserForm, + ApiKeyForm, + MachineApiKeyForm, + OrganizationForm, + ActionForm, + ASPathForm, + CommunityForm, + RTBHForm, + IPv4Form, + IPv6Form, + WhitelistForm, +) + + +def create_form_data(data): + """Helper function to create proper form data format""" + processed_data = {} + for key, value in data.items(): + if isinstance(value, list): + processed_data[key] = [str(v) for v in value] + else: + processed_data[key] = value + return MultiDict(processed_data) + + +@pytest.fixture +def valid_datetime(): + return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") + + +@pytest.fixture +def sample_network_ranges(): + return ["192.168.0.0/16", "2001:db8::/32"] + + +class TestUserForm: + @pytest.fixture + def mock_choices(self): + return { + "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")], + "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")], + } + + @pytest.fixture + def valid_user_data(self): + return { + "uuid": "test@example.com", + "email": "user@example.com", + "name": "Test User", + "phone": "123456789", + "role_ids": ["2"], + "org_ids": ["1"], + } + + def test_valid_user_form(self, app, valid_user_data, mock_choices): + with app.test_request_context(): + form_data = create_form_data(valid_user_data) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_email(self, app, mock_choices): + with app.test_request_context(): + form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]}) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + assert not form.validate() + assert "Please provide valid email" in form.uuid.errors + + +class TestRTBHForm: + @pytest.fixture + def mock_community_choices(self): + return [(1, "Community1"), (2, "Community2")] + + @pytest.fixture + def valid_rtbh_data(self, valid_datetime): + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "community": "1", + "expires": valid_datetime, + "comment": "Test RTBH rule", + } + + def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices): + with app.test_request_context(): + form_data = create_form_data(valid_rtbh_data) + form = RTBHForm(formdata=form_data) + form.net_ranges = sample_network_ranges + form.community.choices = mock_community_choices + assert form.validate() + + +class TestIPv4Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + +class TestIPv6Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return { + "source": "2001:db8::", # Network aligned address within allowed range + "source_mask": "32", # Matching the organization's prefix length + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation fails when address is outside allowed ranges""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db9::", # Different prefix + "source_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + assert any("must be in organization range" in error for error in form.source.errors) + + def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with destination address instead of source""" + with app.test_request_context(): + form_data = create_form_data( + { + "dest": "2001:db8::", + "dest_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with both source and destination addresses""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "dest": "2001:db8:1::", + "dest_mask": "48", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with TCP flags (should be valid with TCP)""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "flags": ["SYN", "ACK"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"]) + def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data): + """Test validation with various valid port formats""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "source_port": port_data, + "dest_port": port_data, + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + +class TestBulkUserForm: + @pytest.fixture + def valid_csv_data(self): + return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1" + + def test_valid_bulk_import(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} # Mock available roles + form.organizations = {1} # Mock available organizations + form.uuids = set() # Mock existing UUIDs (empty set) + assert form.validate() + + def test_duplicate_uuid(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = {"user1@example.com"} # UUID already exists + assert not form.validate() + assert any("already exists" in error for error in form.users.errors) + + def test_invalid_role(self, app): + csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID + with app.test_request_context(): + form_data = create_form_data({"users": csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = set() + assert not form.validate() + assert any("does not exist" in error for error in form.users.errors) + + +class TestApiKeyForm: + @pytest.fixture + def valid_api_key_data(self, valid_datetime): + return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"} + + def test_valid_api_key(self, app, valid_api_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_api_key_data) + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + def test_invalid_ip(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime}) + form = ApiKeyForm(formdata=form_data) + assert not form.validate() + assert "provide valid IP address" in form.machine.errors + + def test_unlimited_expiration(self, app): + with app.test_request_context(): + form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestMachineApiKeyForm: + # Similar to ApiKeyForm, but might have different validation rules + @pytest.fixture + def valid_machine_key_data(self, valid_datetime): + return { + "machine": "192.168.1.1", + "comment": "Test machine API key", + "expires": valid_datetime, + "readonly": "true", + "user": 1, + } + + def test_valid_machine_key(self, app, valid_machine_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_machine_key_data) + form = MachineApiKeyForm(formdata=form_data) + form.user.choices = [(1, "g.name"), (2, "test")] + assert form.validate() + + +class TestOrganizationForm: + @pytest.fixture + def valid_org_data(self): + return { + "name": "Test Organization", + "limit_flowspec4": "100", + "limit_flowspec6": "100", + "limit_rtbh": "50", + "arange": "192.168.0.0/16\n2001:db8::/32", + } + + def test_valid_organization(self, app, valid_org_data): + with app.test_request_context(): + form_data = create_form_data(valid_org_data) + form = OrganizationForm(formdata=form_data) + assert form.validate() + + def test_invalid_ranges(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"}) + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + def test_invalid_limits(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + +class TestActionForm: + @pytest.fixture + def valid_action_data(self): + return { + "name": "test_action", + "command": "announce route", + "description": "Test action description", + "role_id": "2", + } + + def test_valid_action(self, app, valid_action_data): + with app.test_request_context(): + form_data = create_form_data(valid_action_data) + form = ActionForm(formdata=form_data) + assert form.validate() + + def test_invalid_role(self, app, valid_action_data): + with app.test_request_context(): + invalid_data = dict(valid_action_data) + invalid_data["role_id"] = "4" # Invalid role + form_data = create_form_data(invalid_data) + form = ActionForm(formdata=form_data) + assert not form.validate() + + +class TestASPathForm: + @pytest.fixture + def valid_aspath_data(self): + return {"prefix": "AS64512", "as_path": "64512 64513 64514"} + + def test_valid_aspath(self, app, valid_aspath_data): + with app.test_request_context(): + form_data = create_form_data(valid_aspath_data) + form = ASPathForm(formdata=form_data) + assert form.validate() + + def test_empty_fields(self, app): + with app.test_request_context(): + form_data = create_form_data({}) + form = ASPathForm(formdata=form_data) + assert not form.validate() + + +class TestCommunityForm: + @pytest.fixture + def valid_community_data(self): + return { + "name": "test_community", + "comm": "64512:100", + "description": "Test community", + "role_id": "2", + "as_path": "true", + } + + def test_valid_community(self, app, valid_community_data): + with app.test_request_context(): + form_data = create_form_data(valid_community_data) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_missing_all_community_types(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert not form.validate() + assert any("could not be empty" in error for error in form.comm.errors) + + def test_valid_with_extended_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_valid_with_large_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + +class TestWhitelistForm: + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime} + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime} + + def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert any("valid IP address" in error for error in form.ip.errors) + + def test_ip_outside_range(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16 + assert not form.validate() + assert any("must be in organization range" in error for error in form.ip.errors) + + def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_missing_required_fields(self, app, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert form.ip.errors # Should have error for missing IP + assert form.mask.errors # Should have error for missing mask + assert form.expires.errors # Should have error for missing expiration + + def test_comment_length(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + { + "ip": "192.168.0.0", + "mask": "24", + "expires": valid_datetime, + "comment": "x" * 256, # Comment longer than 255 chars + } + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + @pytest.mark.parametrize( + "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date + ) + def test_invalid_expiration(self, app, expires, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + if not form.validate(): + print("Validation errors:", form.errors) + + assert not form.validate() + + def test_network_alignment(self, app, valid_datetime, sample_network_ranges): + """Test that IP addresses must be properly network-aligned""" + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() diff --git a/tests/test_login.py b/tests/test_login.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..b352a5f --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,498 @@ +from datetime import datetime, timedelta +from flowapp.models import ( + User, + Organization, + Role, + ApiKey, + MachineApiKey, + Rstate, + Community, + Action, + Flowspec6, + Whitelist, +) + +import flowapp.models as models + + +def test_insert_ipv4(db): + """ + test the record can be inserted + :param db: conftest fixture + :return: + """ + model = models.Flowspec4( + source="192.168.1.1", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + action_id=1, + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(model) + db.session.commit() + + +def test_get_ipv4_model_if_exists(db): + """ + test if the function find existing model correctly + :param db: conftest fixture + :return: + """ + model = models.Flowspec4( + source="192.168.1.1", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + protocol="tcp", + flags="", + fragment="", + packet_len="", + action_id=1, + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(model) + db.session.commit() + + form_data = { + "source": "192.168.1.1", + "source_mask": "32", + "source_port": "80", + "dest": "", + "dest_mask": "", + "dest_port": "", + "protocol": "tcp", + "flags": "", + "packet_len": "", + "action": 1, + } + + result = models.get_ipv4_model_if_exists(form_data, 1) + assert result + assert result == model + + +def test_get_ipv6_model_if_exists(db): + """ + test if the function find existing model correctly + :param db: conftest fixture + :return: + """ + model = models.Flowspec6( + source="2001:0db8:85a3:0000:0000:8a2e:0370:7334", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + next_header="tcp", + flags="", + packet_len="", + action_id=1, + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(model) + db.session.commit() + + form_data = { + "source": "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + "source_mask": "32", + "source_port": "80", + "dest": "", + "dest_mask": "", + "dest_port": "", + "next_header": "tcp", + "flags": "", + "packet_len": "", + "action": 1, + } + + result = models.get_ipv6_model_if_exists(form_data, 1) + assert result + assert result == model + + +def test_ipv4_eq(db): + """ + test that creating with valid data returns 201 + """ + model_A = models.Flowspec4( + source="192.168.1.1", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + protocol="tcp", + flags="", + fragment="", + packet_len="", + action_id=1, + expires="123", + user_id=1, + org_id=1, + rstate_id=1, + ) + + model_B = models.Flowspec4( + source="192.168.1.1", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + protocol="tcp", + flags="", + fragment="", + packet_len="", + action_id=1, + expires="123456", + user_id=1, + org_id=1, + rstate_id=1, + ) + + assert model_A == model_B + + +def test_ipv4_ne(db): + """ + test that creating with valid data returns 201 + """ + model_A = models.Flowspec4( + source="192.168.2.2", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + protocol="tcp", + flags="", + fragment="", + packet_len="", + action_id=1, + expires="123", + user_id=1, + org_id=1, + rstate_id=1, + ) + + model_B = models.Flowspec4( + source="192.168.1.1", + source_mask="32", + source_port="80", + destination="", + destination_mask="", + destination_port="", + protocol="tcp", + flags="", + fragment="", + packet_len="", + action_id=1, + expires="123456", + user_id=1, + org_id=1, + rstate_id=1, + ) + + assert model_A != model_B + + +def test_rtbj_eq(db): + """ + test that two equal rtbh rules are equal + """ + model_A = models.RTBH( + ipv4="192.168.1.1", + ipv4_mask="32", + ipv6="", + ipv6_mask="", + community_id=1, + expires="123", + user_id=1, + org_id=1, + rstate_id=1, + ) + + model_B = models.RTBH( + ipv4="192.168.1.1", + ipv4_mask="32", + ipv6="", + ipv6_mask="", + community_id=1, + expires="123456", + user_id=1, + org_id=1, + rstate_id=1, + ) + + assert model_A == model_B + + +def test_user_creation(db): + """Test basic user creation and relationships""" + # Create test role and org first + role = Role(name="test_role", description="Test Role") + org = Organization(name="test_org", arange="10.0.0.0/8") + db.session.add_all([role, org]) + db.session.commit() + + # Create user with relationships + user = User( + uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment" + ) + user.role.append(role) + user.organization.append(org) + db.session.add(user) + db.session.commit() + + # Verify user and relationships + assert user.uuid == "test-user-123" + assert user.name == "Test User" + assert len(user.role.all()) == 1 + assert len(user.organization.all()) == 1 + assert user.role.first().name == "test_role" + assert user.organization.first().name == "test_org" + + +def test_api_key_expiration(db): + """Test ApiKey expiration logic""" + user = User(uuid="test-user") + org = Organization(name="test-org", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = ApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = ApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + # Create future key + future_key = ApiKey( + machine="test-machine-3", + key="key3", + readonly=True, + expires=datetime.now() + timedelta(days=1), + comment="Future key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key, future_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + assert not future_key.is_expired() + + +def test_machine_api_key_expiration(db): + """Test MachineApiKey expiration logic""" + user = User(uuid="test-user-machine") + org = Organization(name="test-org-machine", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = MachineApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = MachineApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + + +def test_organization_get_users(db): + """Test Organization's get_users method""" + org = Organization( + name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100 + ) + uuid1 = "test-org-get-user" + uuid2 = "test-org-get-user2" + user1 = User(uuid=uuid1) + user2 = User(uuid=uuid2) + + db.session.add(org) + db.session.add_all([user1, user2]) + db.session.commit() + + org.user.append(user1) + org.user.append(user2) + db.session.commit() + + users = org.get_users() + assert len(users) == 2 + assert all(isinstance(user, User) for user in users) + assert {user.uuid for user in users} == {uuid1, uuid2} + + +def test_flowspec6_equality(db): + """Test Flowspec6 equality comparison""" + model_a = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + # Same network parameters but different timestamps + model_b = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + action_id=1, + ) + + # Different network parameters + model_c = Flowspec6( + source="2001:db8::3", + source_mask=128, + source_port="80", + destination="2001:db8::4", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different network parameters + + +def test_whitelist_equality(db): + """Test Whitelist equality comparison""" + model_a = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Same IP/mask but different timestamps + model_b = Whitelist( + ip="192.168.1.1", + mask=32, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + comment="Different comment", + ) + + # Different IP + model_c = Whitelist( + ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different IP + + +def test_whitelist_to_dict(db): + """Test Whitelist to_dict serialization""" + whitelist = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Create required related objects + user = User(uuid="test-user-whitelist") + rstate = Rstate(description="active") + db.session.add_all([user, rstate]) + db.session.commit() + + db.session.add(whitelist) + db.session.commit() + + whitelist.user = user + whitelist.rstate_id = rstate.id + db.session.add(whitelist) + db.session.commit() + + # Test timestamp format + dict_timestamp = whitelist.to_dict(prefered_format="timestamp") + assert isinstance(dict_timestamp["expires"], int) + assert isinstance(dict_timestamp["created"], int) + + # Test yearfirst format + dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst") + assert isinstance(dict_yearfirst["expires"], str) + assert isinstance(dict_yearfirst["created"], str) + + # Check basic fields + assert dict_timestamp["ip"] == "192.168.1.1" + assert dict_timestamp["mask"] == 32 + assert dict_timestamp["comment"] == "Test whitelist" + assert dict_timestamp["user"] == "test-user-whitelist" diff --git a/tests/test_rule_service.py b/tests/test_rule_service.py new file mode 100644 index 0000000..490d2c2 --- /dev/null +++ b/tests/test_rule_service.py @@ -0,0 +1,628 @@ +""" +Tests for rule_service.py module. + +This test suite verifies the functionality of the rule service after refactoring, +which manages creation, updating, and processing of flow rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleOrigin, ANNOUNCE +from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist +from flowapp.services import rule_service +from flowapp.services.whitelist_common import Relation + + +@pytest.fixture +def ipv4_form_data(): + """Sample valid IPv4 form data""" + return { + "source": "192.168.1.0", + "source_mask": 24, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "protocol": "tcp", + "flags": ["SYN"], + "packet_len": "", + "fragment": ["dont-fragment"], + "comment": "Test IPv4 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def ipv6_form_data(): + """Sample valid IPv6 form data""" + return { + "source": "2001:db8::", + "source_mask": 32, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "next_header": "tcp", + "flags": ["SYN"], + "packet_len": "", + "comment": "Test IPv6 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def rtbh_form_data(): + """Sample valid RTBH form data""" + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "ipv6": "", + "ipv6_mask": None, + "community": 1, + "comment": "Test RTBH rule", + "expires": datetime.now() + timedelta(hours=1), + } + + +@pytest.fixture +def whitelist_fixture(): + """Create a whitelist fixture""" + whitelist = MagicMock(spec=Whitelist) + whitelist.id = 1 + return whitelist + + +class TestCreateOrUpdateIPv4Rule: + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test creating a new IPv4 rule""" + # Mock the get_ipv4_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv4.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv4_form_data["source"] + assert model.source_mask == ipv4_form_data["source_mask"] + assert model.protocol == ipv4_form_data["protocol"] + assert model.flags == "SYN" # form data flags are joined with ";" + assert model.action_id == ipv4_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message is still a string for IPv4 rules + assert message == "IPv4 Rule saved" + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test updating an existing IPv4 rule""" + # Create an existing model to return + existing_model = Flowspec4( + source=ipv4_form_data["source"], + source_mask=ipv4_form_data["source_mask"], + source_port=ipv4_form_data["source_port"], + destination=ipv4_form_data["dest"] or "", + destination_mask=ipv4_form_data["dest_mask"], + destination_port=ipv4_form_data["dest_port"] or "", + protocol=ipv4_form_data["protocol"], + flags=";".join(ipv4_form_data["flags"]), + packet_len=ipv4_form_data["packet_len"] or "", + fragment=";".join(ipv4_form_data["fragment"]), + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv4.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv4_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv4 rules + assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateIPv6Rule: + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test creating a new IPv6 rule""" + # Mock get_ipv6_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv6.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv6_form_data["source"] + assert model.source_mask == ipv6_form_data["source_mask"] + assert model.next_header == ipv6_form_data["next_header"] + assert model.flags == "SYN" + assert model.action_id == ipv6_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message is still a string for IPv6 rules + assert message == "IPv6 Rule saved" + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateRTBHRule: + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test creating a new RTBH rule""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list (no matches) + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ipv4 == rtbh_form_data["ipv4"] + assert model.ipv4_mask == rtbh_form_data["ipv4_mask"] + assert model.ipv6 == rtbh_form_data["ipv6"] + assert model.ipv6_mask == rtbh_form_data["ipv6_mask"] + assert model.community_id == rtbh_form_data["community"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + + # Verify rule was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + # Verify evaluate function was called + mock_evaluate.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test updating an existing RTBH rule""" + # Create an existing model to return + existing_model = RTBH( + ipv4=rtbh_form_data["ipv4"], + ipv4_mask=rtbh_form_data["ipv4_mask"], + ipv6=rtbh_form_data["ipv6"] or "", + ipv6_mask=rtbh_form_data["ipv6_mask"], + community_id=rtbh_form_data["community"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + rtbh_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify flash messages + assert isinstance(flashes, list) + assert "Existing RTBH Rule found" in flashes[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.map_whitelists_to_strings") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_rtbh_rule_with_whitelists( + self, + mock_log, + mock_announce, + mock_evaluate, + mock_check, + mock_map, + mock_query, + mock_get_model, + app, + db, + rtbh_form_data, + whitelist_fixture, + ): + """Test creating a RTBH rule that interacts with whitelists""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Create a mock whitelist + mock_whitelist = whitelist_fixture + mock_whitelists = [mock_whitelist] + + # Setup mock query to return our whitelist + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_whitelists + mock_query.return_value = mock_query_result + + # Setup map function to return our whitelist in a dict + whitelist_key = "192.168.1.0/24" + mock_map.return_value = {whitelist_key: mock_whitelist} + + # Setup check function to return a relation + mock_rtbh = MagicMock(spec=RTBH) + mock_rtbh.__str__.return_value = "192.168.1.0/24" + + # Setup check result + rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)] + mock_check.return_value = rule_relation + + # Setup evaluate function to add a flash message + def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results): + flashes.append("Rule is equal to whitelist") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages show both rule creation and whitelist check + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + assert "Rule is equal to whitelist" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + + +class TestEvaluateRtbhAgainstWhitelistsCheckResults: + def test_equal_relation(self, app, whitelist_fixture): + """Test evaluating a rule with an EQUAL relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.0/24" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.EQUAL)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == model + + def test_subnet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUBNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.128/25" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUBNET)] + + # Call the function with mocked dependencies + with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch( + "flowapp.services.rule_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.rule_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + with app.app_context(): + _result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert flashes + + # Verify model was updated to whitelisted state + assert model.rstate_id == 4 + + def test_supernet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUPERNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.0.0/16" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUPERNET)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == model + + def test_no_relation(self, app): + """Test evaluating a rule with no relation to any whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + wl_cache = {} + results = [] + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify no changes to the model and no messages + assert result == model + assert not flashes + + +class TestMapWhitelistsToStrings: + def test_map_whitelists_to_strings(self): + """Test mapping whitelist objects to strings""" + # Create mock whitelists + whitelist1 = MagicMock(spec=Whitelist) + whitelist1.__str__.return_value = "192.168.1.0/24" + + whitelist2 = MagicMock(spec=Whitelist) + whitelist2.__str__.return_value = "10.0.0.0/8" + + whitelists = [whitelist1, whitelist2] + + # Call the function + result = rule_service.map_whitelists_to_strings(whitelists) + + # Verify the result + assert len(result) == 2 + assert "192.168.1.0/24" in result + assert "10.0.0.0/8" in result + assert result["192.168.1.0/24"] == whitelist1 + assert result["10.0.0.0/8"] == whitelist2 + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv6 rules + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() diff --git a/tests/test_rule_service_reactivate_delete.py b/tests/test_rule_service_reactivate_delete.py new file mode 100644 index 0000000..22a3f9f --- /dev/null +++ b/tests/test_rule_service_reactivate_delete.py @@ -0,0 +1,527 @@ +"""Tests for rule_service module.""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp.constants import RuleTypes +from flowapp.models import ( + RTBH, + Flowspec4, + Flowspec6, + Whitelist, +) +from flowapp.output import RouteSources +from flowapp.services import rule_service + + +@pytest.fixture +def test_data(app, db): + """Fixture providing test data for rule service tests.""" + current_time = datetime.now() + + # Create a test Flowspec4 rule + ipv4_rule = Flowspec4( + source="192.168.1.1", + source_mask=32, + source_port="", + destination="192.168.2.1", + destination_mask=32, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=current_time + timedelta(hours=1), + comment="Test IPv4 rule", + action_id=1, # Using action ID 1 from test database + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(ipv4_rule) + + # Create a test Flowspec6 rule + ipv6_rule = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="", + destination="2001:db8:1::1", + destination_mask=128, + destination_port="", + next_header="tcp", + flags="", + packet_len="", + expires=current_time + timedelta(hours=1), + comment="Test IPv6 rule", + action_id=1, # Using action ID 1 from test database + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(ipv6_rule) + + # Create a test RTBH rule + rtbh_rule = RTBH( + ipv4="192.168.1.100", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=1, # Using community ID 1 from test database + expires=current_time + timedelta(hours=1), + comment="Test RTBH rule", + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(rtbh_rule) + + # Create a test Whitelist + whitelist = Whitelist( + ip="192.168.2.1", + mask=24, + expires=current_time + timedelta(days=7), + comment="Test whitelist", + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(whitelist) + + db.session.commit() + + # Return data that will be useful for tests + return { + "user_id": 1, + "org_id": 1, + "user_email": "test@example.com", + "org_name": "Test Org", + "ipv4_rule_id": ipv4_rule.id, + "ipv6_rule_id": ipv6_rule.id, + "rtbh_rule_id": rtbh_rule.id, + "whitelist_id": whitelist.id, + "current_time": current_time, + "future_time": current_time + timedelta(hours=1), + "past_time": current_time - timedelta(hours=1), + "comment": "Test comment", + } + + +class TestReactivateRule: + """Tests for the reactivate_rule function.""" + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_reactivate_rule_active( + self, mock_log_route, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db + ): + """Test reactivating a rule with future expiration (active state).""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = False + + # The rule will be active (state=1) because expiration is in the future + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) + + # Verify model was updated + assert model.expires == expires + assert model.comment == test_data["comment"] + assert model.rstate_id == 1 # Active state + + # Verify route announcement was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_route.assert_called_once() + + # Check returned values + assert messages + assert messages[0].startswith("Rule ") + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + def test_reactivate_rule_inactive( + self, mock_log_withdraw, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db + ): + """Test reactivating a rule with past expiration (inactive state).""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = False + + # The rule will be inactive (state=2) because expiration is in the past + expires = test_data["past_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Verify model was updated + assert model.expires == expires + assert model.comment == test_data["comment"] + assert model.rstate_id == 2 # Inactive state + + # Verify route withdrawal was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_withdraw.assert_called_once() + + # Check returned values + assert messages + assert messages[0].startswith("Rule ") + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + def test_reactivate_rule_global_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): + """Test reactivating a rule when global limit is reached.""" + # Setup mocks + mock_check_global_limit.return_value = True # Global limit reached + mock_check_rule_limit.return_value = False + + # The rule would be active, but global limit is reached + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + + # Check returned values + assert messages == ["global_limit_reached"] + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + def test_reactivate_rule_org_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): + """Test reactivating a rule when organization limit is reached.""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = True # Org limit reached + + # The rule would be active, but org limit is reached + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) + + # Check returned values + assert messages == ["limit_reached"] + + +class TestDeleteRule: + """Tests for the delete_rule function.""" + + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + def test_delete_rule_success(self, mock_log_withdraw, mock_announce_route, test_data, db): + """Test successful rule deletion.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["ipv4_rule_id"]], # Rule is in allowed list + ) + + # Verify route withdrawal was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_withdraw.assert_called_once() + + # Verify database operations - rule should be deleted + rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) + assert rule is None + + # Check returned values + assert success is True + assert message == "Rule deleted successfully" + + def test_delete_rule_not_allowed(self, test_data, db): + """Test rule deletion when rule is not in allowed list.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[999], # Rule is not in allowed list + ) + + # Verify rule still exists + rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) + assert rule is not None + + # Check returned values + assert success is False + assert message == "You cannot delete this rule" + + def test_delete_rule_not_found(self, test_data, db): + """Test rule deletion when rule is not found.""" + # Call the function with non-existent ID + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=9999, # Non-existent ID + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Check returned values + assert success is False + assert message == "Rule not found" + + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + @patch("flowapp.services.rule_service.RuleWhitelistCache") + def test_delete_rtbh_rule(self, mock_whitelist_cache, mock_log_withdraw, mock_announce_route, test_data, db): + """Test deleting an RTBH rule.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify whitelist cache was cleaned + mock_whitelist_cache.delete_by_rule_id.assert_called_once_with(test_data["rtbh_rule_id"]) + + # Verify route withdrawal and logging + mock_announce_route.assert_called_once() + mock_log_withdraw.assert_called_once() + + # Verify rule was deleted + rule = db.session.get(RTBH, test_data["rtbh_rule_id"]) + assert rule is None + + # Check returned values + assert success is True + assert message == "Rule deleted successfully" + + +class TestDeleteRtbhAndCreateWhitelist: + """Tests for the delete_rtbh_and_create_whitelist function.""" + + @patch("flowapp.services.rule_service.delete_rule") + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_delete_rtbh_and_create_whitelist_success(self, mock_create_whitelist, mock_delete_rule, test_data, db): + """Test successful RTBH deletion and whitelist creation.""" + # Setup mock for delete_rule service + mock_delete_rule.return_value = (True, "Rule deleted successfully") + + # Setup mock for create_or_update_whitelist service + mock_whitelist = Whitelist( + ip="192.168.1.100", + mask=32, + expires=datetime.now() + timedelta(days=7), + user_id=test_data["user_id"], + org_id=test_data["org_id"], + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify delete_rule was called + mock_delete_rule.assert_called_once_with( + rule_type=RuleTypes.RTBH, + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify create_or_update_whitelist was called with correct data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "192.168.1.100" + assert form_data["mask"] == 32 + assert "Created from RTBH rule" in form_data["comment"] + + # Check returned values + assert success is True + assert len(messages) == 2 + assert messages[0] == "Rule deleted successfully" + assert messages[1] == "Whitelist created" + assert whitelist == mock_whitelist + + def test_delete_rtbh_and_create_whitelist_rule_not_found(self, test_data, db): + """Test when the RTBH rule to convert is not found.""" + # Call the function with non-existent ID + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=9999, # Non-existent ID + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Check returned values + assert success is False + assert messages == ["RTBH rule not found"] + assert whitelist is None + + def test_delete_rtbh_and_create_whitelist_not_allowed(self, test_data, db): + """Test when the user is not allowed to delete the rule.""" + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[999], # Rule not in allowed list + ) + + # Check returned values + assert success is False + assert messages == ["You cannot delete this rule"] + assert whitelist is None + + @patch("flowapp.services.rule_service.delete_rule") + def test_delete_rtbh_and_create_whitelist_delete_fails(self, mock_delete_rule, test_data, db): + """Test when the RTBH deletion fails.""" + # Setup mock for delete_rule service to fail + mock_delete_rule.return_value = (False, "Error deleting rule") + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify delete_rule was called + mock_delete_rule.assert_called_once() + + # Check returned values + assert success is False + assert messages == ["Error deleting rule"] + assert whitelist is None + + @patch("flowapp.services.rule_service.delete_rule") + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_rtbh_with_ipv6(self, mock_create_whitelist, mock_delete_rule, test_data, db): + """Test with an IPv6 RTBH rule.""" + # Create an IPv6 RTBH rule + ipv6_rtbh = RTBH( + ipv4=None, + ipv4_mask=None, + ipv6="2001:db8::1", + ipv6_mask=64, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + comment="IPv6 RTBH rule", + user_id=test_data["user_id"], + org_id=test_data["org_id"], + rstate_id=1, + ) + db.session.add(ipv6_rtbh) + db.session.commit() + + # Setup mocks + mock_delete_rule.return_value = (True, "Rule deleted successfully") + mock_whitelist = Whitelist( + ip="2001:db8::1", + mask=64, + expires=datetime.now() + timedelta(days=7), + user_id=test_data["user_id"], + org_id=test_data["org_id"], + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=ipv6_rtbh.id, + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[ipv6_rtbh.id], + ) + + # Verify create_or_update_whitelist was called with correct IPv6 data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "2001:db8::1" + assert form_data["mask"] == 64 + + # Check returned values + assert success is True + assert len(messages) == 2 + assert whitelist == mock_whitelist diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..47c89c4 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,71 @@ +import ipaddress +import pytest + +from datetime import datetime, timedelta + +from flowapp import utils + + +@pytest.mark.parametrize( + "apitime, preformat", + [ + ("10/15/2015 14:46", "us"), + ("2015/10/15 14:46", "yearfirst"), + ("1444913400", "timestamp"), + (1444913400, "timestamp"), + ], +) +def test_parse_api_time(apitime, preformat): + """ + is the time parsed correctly + """ + result = utils.parse_api_time(apitime) + assert isinstance(result, tuple) + assert result[0] == datetime(2015, 10, 15, 14, 50) + assert result[1] == preformat + + +@pytest.mark.parametrize( + "apitime", ["10/152015 14:46", "201/10/15 14:46", "144123254913400", "abcd"] +) +def test_parse_api_time_bad_time(apitime): + """ + is the time parsed correctly + """ + assert not utils.parse_api_time(apitime) + + +def test_get_rule_state_by_time(): + """ + Test if time in the past returns 2 + """ + past = datetime.now() - timedelta(days=1) + + assert utils.get_state_by_time(past) == 2 + + +def test_round_to_ten(): + """ + Test if the time is rounded correctly + """ + d1 = datetime(2013, 9, 2, 16, 25, 59) + d2 = datetime(2013, 9, 2, 16, 32, 59) + dround = datetime(2013, 9, 2, 16, 30, 00) + + assert utils.round_to_ten_minutes(d1) == dround + assert utils.round_to_ten_minutes(d2) == dround + + +@pytest.mark.parametrize( + "address_a, address_b", + [ + ( + "2001:718:1c01:16:f1c4:c682:817:7e23", + "2001:0718:1c01:0016:f1c4:c682:0817:7e23", + ), + ("2001:718::", "2001:718::0"), + ("2001:718::0", "2001:0718:0000:0000:0000:0000:0000:0000"), + ], +) +def test_ipv6_comparsion(address_a, address_b): + assert ipaddress.ip_address(address_a) == ipaddress.ip_address(address_b) diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..614d3a5 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,547 @@ +import pytest +from flowapp.validators import ( + PortString, + PacketString, + NetRangeString, + NetInRange, + IPAddress, + IPAddressValidator, + NetworkValidator, + ValidationError, + address_in_range, + address_with_mask, + IPv4Address, + IPv6Address, + DateNotExpired, + editable_range, + network_in_range, + range_in_network, + filter_rules_in_network, + split_rules_for_user, + filter_rtbh_rules, + split_rtbh_rules_for_user, +) + + +def test_port_string_len_raises(field): + port = PortString() + field.data = "1;2;3;4;5;6;7;8" + with pytest.raises(ValidationError): + port(None, field) + + +@pytest.mark.parametrize( + "address, mask, expected", + [ + ("147.230.23.25", "24", False), + ("147.230.23.0", "24", True), + ("0.0.0.0", "0", True), + ("2001:718:1C01:1111::1111", "64", False), + ("2001:718:1C01:1111::", "64", True), + ], +) +def test_is_valid_address_with_mask(address, mask, expected): + assert address_with_mask(address, mask) == expected + + +@pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"]) +def test_ip4address_passes(field, address): + adr = IPv4Address() + field.data = address + adr(None, field) + + +@pytest.mark.parametrize( + "address", + [ + "2001:718:1C01:1111::1111", + "2001:718:1C01:1111::", + ], +) +def test_ip6address_passes(field, address): + adr = IPv6Address() + field.data = address + adr(None, field) + + +@pytest.mark.parametrize( + "address", + [ + "2001:718:1C01:1111::1111", + "2001:718:1C01:1111::", + ], +) +def test_bad_ip6address_raises(field, address): + adr = IPv4Address() + field.data = address + with pytest.raises(ValidationError): + adr(None, field) + + +@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]) +def test_expired_date_raises(field, expired): + adr = DateNotExpired() + field.data = expired + with pytest.raises(ValidationError): + adr(None, field) + + +@pytest.mark.parametrize( + "address", + [ + "147.230.1000.25", + "2001:718::::", + ], +) +def test_ipaddress_raises(field, address): + adr = IPv6Address() + field.data = address + with pytest.raises(ValidationError): + adr(None, field) + + +@pytest.mark.parametrize( + "address, mask, ranges, expected", + [ + ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), + ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), + ], +) +def test_editable_rule(rule, address, mask, ranges, expected): + rule.source = address + rule.source_mask = mask + assert editable_range(rule, ranges) == expected + + +@pytest.mark.parametrize( + "address, mask, ranges, expected", + [ + ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), + ("147.233.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], False), + ("147.230.23.0", "24", ["147.230.0.0/16", "2001:718:1c01::/48"], True), + ("195.113.0.0", "16", ["0.0.0.0/0", "::/0"], True), + ], +) +def test_address_in_range(address, mask, ranges, expected): + assert address_in_range(address, ranges) == expected + + +@pytest.mark.parametrize( + "address, mask, ranges, expected", + [ + ("147.230.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], True), + ("147.233.23.0", "24", ["147.230.0.0/16", "147.251.0.0/16"], False), + ("195.113.0.0", "16", ["195.113.0.0/18", "195.113.64.0/21"], False), + ("195.113.0.0", "16", ["0.0.0.0/0", "::/0"], True), + ( + "195.113.0.0", + "16", + ["147.230.0.0/16", "2001:718:1c01::/48", "0.0.0.0/0", "::/0"], + True, + ), + ], +) +def test_network_in_range(address, mask, ranges, expected): + assert network_in_range(address, mask, ranges) == expected + + +@pytest.mark.parametrize( + "address, mask, ranges, expected", + [ + ("195.113.0.0", "16", ["147.230.0.0/16", "195.113.250.0/24"], True), + ], +) +def test_range_in_network(address, mask, ranges, expected): + assert range_in_network(address, mask, ranges) == expected + + +### new tests + + +# New tests for previously uncovered validators + + +# Tests for PortString validator syntax +@pytest.mark.parametrize( + "port_data", + [ + "80", # Simple number + "80-443", # Range with hyphen + ">=80&<=443", # Explicit range + ">80", # Greater than + ">=80", # Greater than or equal + "<443", # Less than + "<=443", # Less than or equal + "80;443;8080", # Multiple port expressions + ], +) +def test_port_string_valid_syntax(field, port_data): + validator = PortString() + field.data = port_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_port_data", + [ + "80,443", # Comma not supported, should use semicolon + "80..443", # Invalid range syntax (should be 80-443) + ">65536", # Port number too high + "-1", # Negative number + "<-1", # Port number too low + ">=80<=443", # Invalid range syntax (missing &) + "!=80", # Not equals not supported + "abc", # Non-numeric + "80-", # Incomplete range + "-80", # Invalid range + "80&&443", # Invalid operator + "443-80", # End less than start + ">=443&<=80", # End less than start in explicit range + "0-65537", # Range exceeding max + ], +) +def test_port_string_invalid_syntax(field, invalid_port_data): + validator = PortString() + field.data = invalid_port_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for PacketString validator +@pytest.mark.parametrize( + "packet_data", + ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions +) +def test_packet_string_valid(field, packet_data): + validator = PacketString() + field.data = packet_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_packet_data", + [ + "65536", # Too large + "-1", # Negative + "1500..", # Invalid range syntax + "abc", # Non-numeric + "!!1500", # Invalid operator + "1500-", # Incomplete range + "9000-1500", # End less than start + ], +) +def test_packet_string_invalid(field, invalid_packet_data): + validator = PacketString() + field.data = invalid_packet_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetRangeString validator +@pytest.mark.parametrize( + "net_range", + [ + "192.168.0.0/24", + "10.0.0.0/8\n172.16.0.0/12", + "2001:db8::/32", + "192.168.0.0/24 10.0.0.0/8", + "2001:db8::/32 2001:db8:1::/48", + ], +) +def test_net_range_string_valid(field, net_range): + validator = NetRangeString() + field.data = net_range + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_net_range", + [ + "192.168.0.0/33", # Invalid mask + "256.256.256.0/24", # Invalid IP + "2001:xyz::/32", # Invalid IPv6 + "192.168.0.0/24/", # Malformed + "not-an-ip/24", # Invalid format + ], +) +def test_net_range_string_invalid(field, invalid_net_range): + validator = NetRangeString() + field.data = invalid_net_range + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetInRange validator +def test_net_in_range_valid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "192.168.1.1/24" + validator(None, field) # Should not raise + + +def test_net_in_range_invalid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "172.16.1.1/24" + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for base IPAddress validator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"]) +def test_ip_address_valid(field, ip_addr): + validator = IPAddress() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"]) +def test_ip_address_invalid(field, invalid_ip): + validator = IPAddress() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for universal IPAddressValidator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed +def test_ip_address_validator_valid(field, ip_addr): + validator = IPAddressValidator() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"]) +def test_ip_address_validator_invalid(field, invalid_ip): + validator = IPAddressValidator() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetworkValidator +class MockForm: + def __init__(self, address, mask): + self._fields = {"mask": type("MockField", (), {"data": mask})()} + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.0.0", "24"), + ("10.0.0.0", "8"), + ("2001:db8::", "32"), + ("", None), # Empty values should pass + (None, None), # None values should pass + ], +) +def test_network_validator_valid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + validator(form, field) # Should not raise + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.1.1", "24"), # Not a network address + ("192.168.0.0", "33"), # Invalid IPv4 mask + ("2001:db8::", "129"), # Invalid IPv6 mask + ("256.256.256.0", "24"), # Invalid IP + ("2001:xyz::", "32"), # Invalid IPv6 + ], +) +def test_network_validator_invalid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + with pytest.raises(ValidationError): + validator(form, field) + + +# Mock rule classes for testing robust attribute handling +class MockRule: + """Mock rule with all expected attributes""" + + def __init__(self, source=None, source_mask=None, dest=None, dest_mask=None): + self.source = source + self.source_mask = source_mask + self.dest = dest + self.dest_mask = dest_mask + + +class MockRuleIncomplete: + """Mock rule with missing attributes""" + + def __init__(self, name=None): + self.name = name + # Intentionally missing source, source_mask, dest, dest_mask attributes + + +class MockRulePartial: + """Mock rule with some attributes""" + + def __init__(self, source=None): + self.source = source + # Missing source_mask, dest, dest_mask attributes + + +class MockRTBHRule: + """Mock RTBH rule with all expected attributes""" + + def __init__(self, ipv4=None, ipv4_mask=None, ipv6=None, ipv6_mask=None): + self.ipv4 = ipv4 + self.ipv4_mask = ipv4_mask + self.ipv6 = ipv6 + self.ipv6_mask = ipv6_mask + + +class MockRTBHRuleIncomplete: + """Mock RTBH rule with missing attributes""" + + def __init__(self, name=None): + self.name = name + # Intentionally missing ipv4, ipv4_mask, ipv6, ipv6_mask attributes + + +# Tests for filter_rules_in_network with robust attribute handling +def test_filter_rules_in_network_normal_rules(): + """Test filter_rules_in_network with normal rule objects""" + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + rules = [ + MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Should match + MockRule("172.16.1.0", "24", "172.16.2.0", "24"), # Should not match + MockRule("10.1.0.0", "16", None, None), # Should match (source only) + ] + + filtered = filter_rules_in_network(net_ranges, rules) + assert len(filtered) == 2 + assert rules[0] in filtered # 192.168.x.x rule + assert rules[2] in filtered # 10.x.x.x rule + assert rules[1] not in filtered # 172.16.x.x rule + + +def test_filter_rules_in_network_missing_attributes(): + """Test filter_rules_in_network with rules missing required attributes""" + net_ranges = ["192.168.0.0/16"] + rules = [ + MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Normal rule - should match + MockRuleIncomplete("rule_without_network_attrs"), # Missing attrs - should be included + MockRulePartial("172.16.1.0"), # Partial attrs - should be included + ] + + filtered = filter_rules_in_network(net_ranges, rules) + assert len(filtered) == 3 # All rules should be included + assert all(rule in filtered for rule in rules) + + +def test_filter_rules_in_network_none_values(): + """Test filter_rules_in_network with None values in attributes""" + net_ranges = ["192.168.0.0/16"] + rules = [ + MockRule("192.168.1.0", "24", None, None), # Should match on source + MockRule(None, None, "192.168.2.0", "24"), # Should match on dest + MockRule(None, None, None, None), # Should not match + ] + + filtered = filter_rules_in_network(net_ranges, rules) + assert len(filtered) == 2 + assert rules[0] in filtered + assert rules[1] in filtered + assert rules[2] not in filtered + + +# Tests for split_rules_for_user with robust attribute handling +def test_split_rules_for_user_normal_rules(): + """Test split_rules_for_user with normal rule objects""" + net_ranges = ["192.168.0.0/16"] + rules = [ + MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Should be user rule + MockRule("172.16.1.0", "24", "172.16.2.0", "24"), # Should be rest rule + ] + + user_rules, rest_rules = split_rules_for_user(net_ranges, rules) + assert len(user_rules) == 1 + assert len(rest_rules) == 1 + assert rules[0] in user_rules + assert rules[1] in rest_rules + + +def test_split_rules_for_user_missing_attributes(): + """Test split_rules_for_user with rules missing required attributes""" + net_ranges = ["192.168.0.0/16"] + rules = [ + MockRule("192.168.1.0", "24", "10.0.1.0", "24"), # Normal rule - user rule + MockRuleIncomplete("rule_without_attrs"), # Missing attrs - should be user rule + MockRule("172.16.1.0", "24", "172.16.2.0", "24"), # Normal rule - rest rule + ] + + user_rules, rest_rules = split_rules_for_user(net_ranges, rules) + assert len(user_rules) == 2 # Normal matching rule + incomplete rule + assert len(rest_rules) == 1 + assert rules[0] in user_rules # Matching rule + assert rules[1] in user_rules # Incomplete rule treated as editable + assert rules[2] in rest_rules # Non-matching rule + + +# Tests for filter_rtbh_rules with robust attribute handling +def test_filter_rtbh_rules_normal_rules(): + """Test filter_rtbh_rules with normal RTBH rule objects""" + net_ranges = ["192.168.0.0/16", "2001:db8::/32"] + rules = [ + MockRTBHRule("192.168.1.0", "24", None, None), # Should match on IPv4 + MockRTBHRule(None, None, "2001:db8:1::", "48"), # Should match on IPv6 + MockRTBHRule("172.16.1.0", "24", "2001:db9::", "32"), # Should not match + ] + + filtered = filter_rtbh_rules(net_ranges, rules) + assert len(filtered) == 2 + assert rules[0] in filtered + assert rules[1] in filtered + assert rules[2] not in filtered + + +# Tests for split_rtbh_rules_for_user with robust attribute handling +def test_split_rtbh_rules_for_user_normal_rules(): + """Test split_rtbh_rules_for_user with normal RTBH rule objects""" + net_ranges = ["192.168.0.0/16"] + rules = [ + MockRTBHRule("192.168.1.0", "24", None, None), # Should be filtered (user) + MockRTBHRule("172.16.1.0", "24", None, None), # Should be read-only + ] + + filtered, read_only = split_rtbh_rules_for_user(net_ranges, rules) + assert len(filtered) == 1 + assert len(read_only) == 1 + assert rules[0] in filtered + assert rules[1] in read_only + + +# Edge case tests +def test_filter_functions_empty_input(): + """Test all filter functions with empty input""" + net_ranges = ["192.168.0.0/16"] + + # Empty rules list + assert filter_rules_in_network(net_ranges, []) == [] + assert split_rules_for_user(net_ranges, []) == ([], []) + assert filter_rtbh_rules(net_ranges, []) == [] + assert split_rtbh_rules_for_user(net_ranges, []) == ([], []) + + +def test_filter_functions_empty_net_ranges(): + """Test filter functions with empty network ranges""" + rules = [MockRule("192.168.1.0", "24", None, None)] + rtbh_rules = [MockRTBHRule("192.168.1.0", "24", None, None)] + + # Empty network ranges - nothing should match + assert filter_rules_in_network([], rules) == [] + user_rules, rest_rules = split_rules_for_user([], rules) + assert user_rules == [] + assert rest_rules == rules + + assert filter_rtbh_rules([], rtbh_rules) == [] + filtered, read_only = split_rtbh_rules_for_user([], rtbh_rules) + assert filtered == [] + assert read_only == rtbh_rules diff --git a/tests/test_whitelist_common.py b/tests/test_whitelist_common.py new file mode 100644 index 0000000..a843aff --- /dev/null +++ b/tests/test_whitelist_common.py @@ -0,0 +1,250 @@ +import pytest +from flowapp.services.whitelist_common import ( + Relation, + check_rule_against_whitelists, + check_whitelist_against_rules, + check_whitelist_to_rule_relation, + subtract_network, + clear_network_cache, + _is_same_ip_version, # New helper function +) + + +# Tests for core function that checks network relations +@pytest.mark.parametrize( + "rule,whitelist,expected", + [ + # IPv4 test cases + ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet + ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks + ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet + ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet + ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks + # IPv6 test cases + ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks + ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet + ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks + ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ], +) +def test_check_whitelist_to_rule_relation(rule, whitelist, expected): + clear_network_cache() + assert check_whitelist_to_rule_relation(rule, whitelist) == expected + + +@pytest.mark.parametrize( + "target,whitelist,expected", + [ + # Basic IPv4 cases + ( + "192.168.1.0/24", + "192.168.1.128/25", + ["192.168.1.0/25"], + ), # One remaining subnet + ( + "192.168.1.0/24", + "192.168.1.64/26", + ["192.168.1.0/26", "192.168.1.128/25"], + ), # Two remaining subnets + ( + "192.168.1.0/24", + "192.168.1.0/24", + ["192.168.1.0/24"], + ), # Equal networks - return original network + ( + "192.168.1.0/24", + "192.168.2.0/24", + ["192.168.1.0/24"], + ), # No overlap + ], +) +def test_subtract_network(target, whitelist, expected): + clear_network_cache() + result = subtract_network(target, whitelist) + assert sorted(result) == sorted(expected) + + +def test_check_rule_against_whitelists(): + clear_network_cache() + + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # SUPERNET + "172.16.0.0/12", # DIFFERENT + "192.168.1.0/24", # EQUAL + "192.168.1.128/25", # SUBNET + ] + + results = check_rule_against_whitelists(rule, whitelists) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + (rule, "192.168.0.0/16", Relation.SUPERNET), + (rule, "192.168.1.0/24", Relation.EQUAL), + (rule, "192.168.1.128/25", Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +def test_check_whitelist_against_rules(): + clear_network_cache() + + whitelist = "192.168.1.128/25" + rules = [ + "192.168.0.0/16", # Whitelist is SUBNET of this larger network + "172.16.0.0/12", # DIFFERENT + "192.168.1.128/25", # EQUAL + "192.168.1.0/24", # Whitelist is SUBNET of this network + ] + + results = check_whitelist_against_rules(rules, whitelist) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + ("192.168.0.0/16", whitelist, Relation.SUBNET), + ("192.168.1.128/25", whitelist, Relation.EQUAL), + ("192.168.1.0/24", whitelist, Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +# Tests for edge cases and error handling +def test_host_addresses(): + clear_network_cache() + # Test with host addresses (no explicit subnet mask) + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT + + +def test_single_ip_as_network(): + clear_network_cache() + # Test with /32 (single IPv4 address as network) + assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET + # Test with /128 (single IPv6 address as network) + assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET + + +def test_is_same_ip_version(): + """Test the new helper function to check if two addresses are of the same IP version""" + # IPv4 cases + assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True + assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True + + # IPv6 cases + assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True + assert _is_same_ip_version("2001:db8::1", "fe80::1") is True + + # Mixed cases + assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False + assert _is_same_ip_version("192.168.1.1", "fe80::1") is False + + +def test_check_whitelist_to_rule_relation_mixed_versions(): + """Test the check_whitelist_to_rule_relation function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule with IPv6 whitelist + assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT + + # IPv6 rule with IPv4 whitelist + assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT + + # Invalid IP addresses should return DIFFERENT + assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT + assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT + + +def test_subtract_network_mixed_versions(): + """Test the subtract_network function with mixed IP versions""" + clear_network_cache() + + # IPv4 target with IPv6 whitelist - should return original target + result = subtract_network("192.168.1.0/24", "2001:db8::/32") + assert result == ["192.168.1.0/24"] + + # IPv6 target with IPv4 whitelist - should return original target + result = subtract_network("2001:db8::/32", "192.168.1.0/24") + assert result == ["2001:db8::/32"] + + # Invalid addresses - should return original target + result = subtract_network("invalid", "192.168.1.0/24") + assert result == ["invalid"] + result = subtract_network("192.168.1.0/24", "invalid") + assert result == ["192.168.1.0/24"] + + +def test_check_rule_against_whitelists_mixed_versions(): + """Test the check_rule_against_whitelists function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule against mixed whitelist + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "192.168.0.0/16" + assert results[0][2] == Relation.SUPERNET + + # IPv6 rule against mixed whitelist + rule = "2001:db8:1::/48" + whitelists = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "2001:db8::/32" + assert results[0][2] == Relation.SUPERNET + + +def test_check_whitelist_against_rules_mixed_versions(): + """Test the check_whitelist_against_rules function with mixed IP versions""" + clear_network_cache() + + # IPv4 whitelist against mixed rules + whitelist = "192.168.1.0/24" + rules = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "192.168.0.0/16" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET + + # IPv6 whitelist against mixed rules + whitelist = "2001:db8:1::/48" + rules = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "2001:db8::/32" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/tests/test_whitelist_service.py b/tests/test_whitelist_service.py new file mode 100644 index 0000000..460a172 --- /dev/null +++ b/tests/test_whitelist_service.py @@ -0,0 +1,461 @@ +""" +Tests for whitelist_service.py module. + +This test suite verifies the functionality of the whitelist service, +which manages creation, updating, and handling of whitelist rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import Whitelist, RuleWhitelistCache, RTBH +from flowapp.services import whitelist_service +from flowapp.services.whitelist_common import Relation + + +@pytest.fixture +def whitelist_form_data(): + """Sample valid whitelist form data""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist entry", + "expires": datetime.now() + timedelta(hours=1), + } + + +class TestCreateOrUpdateWhitelist: + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_new_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a new whitelist entry""" + # Mock the get_whitelist_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ip == whitelist_form_data["ip"] + assert model.mask == whitelist_form_data["mask"] + assert model.comment == whitelist_form_data["comment"] + assert model.user_id == 1 + assert model.org_id == 1 + assert model.rstate_id == 1 # Active state + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Whitelist saved" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_update_existing_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): + """Test updating an existing whitelist entry""" + # Create an existing whitelist + existing_model = Whitelist( + ip=whitelist_form_data["ip"], + mask=whitelist_form_data["mask"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Mock to return the existing model + mock_get_model.return_value = existing_model + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + # Check that expiration date was updated (rounded to 10 minutes) + # We can't compare exact timestamps, so check date parts + assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Existing Whitelist found" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.db.session.query") + @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_whitelist_with_matching_rules( + self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a whitelist that affects existing rules""" + # Mock get_whitelist_model_if_exists to return False (new whitelist) + mock_get_model.return_value = False + + # Create a mock RTBH rule + mock_rtbh_rule = MagicMock(spec=RTBH) + mock_rtbh_rule.__str__.return_value = "192.168.1.0/24" + mock_rtbh_rules = [mock_rtbh_rule] + + # Setup mock query to return our rule + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules + mock_query.return_value = mock_query_result + + # Setup map function to return our rule in a dict + mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule} + + # Setup check function to return a relation + rule_relation = [ + ( + str(mock_rtbh_rule), + str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]), + Relation.EQUAL, + ) + ] + mock_check.return_value = rule_relation + + # Setup evaluate function to modify flashes + def evaluate_side_effect(model, flashes, rtbh_rule_cache, results): + flashes.append("Rule was whitelisted") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages includes both whitelist creation and rule whitelisting + assert "Whitelist saved" in flashes[0] + assert "Rule was whitelisted" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + + +class TestDeleteWhitelist: + @patch("flowapp.services.whitelist_service.announce_rtbh_route") + def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): + """Test deleting a whitelist that has user-created rules attached to it""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by user, whitelisted) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=4, # Whitelisted state + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.USER, + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set whitelist ID now that it has been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Mock the get_by_whitelist_id to return our cache entry + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + flashes = whitelist_service.delete_whitelist(whitelist.id) + + # Verify the rule state was changed back to active + rtbh_rule = db.session.get(RTBH, rtbh_rule.id) + assert rtbh_rule.rstate_id == 1 # Active state + + # Verify announcement was made + mock_announce.assert_called_once() + + # Verify flash messages + assert isinstance(flashes, list) + assert flashes + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id") + def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db): + """Test deleting a whitelist that has rules created by the whitelist""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by whitelist) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set IDs now that they have been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Create a mock session that can get our rule + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + flashes = whitelist_service.delete_whitelist(whitelist.id) + + # Verify flash messages + assert isinstance(flashes, list) + assert flashes + + # Verify the rule was deleted + assert db.session.get(RTBH, rtbh_rule.id) is None + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + # Verify cache cleanup was called + mock_clean.assert_called_once_with(whitelist.id) + + def test_delete_nonexistent_whitelist(self, app, db): + """Test deleting a whitelist that doesn't exist""" + with app.app_context(): + # Call the service function with a non-existent ID + flashes = whitelist_service.delete_whitelist(999) + + # Should return empty list of flash messages, as no whitelist was found + assert isinstance(flashes, list) + assert len(flashes) == 0 + + +class TestEvaluateWhitelistAgainstRtbhResults: + def test_equal_relation(self, app): + """Test evaluating a whitelist with an EQUAL relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.0/24" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.EQUAL)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == whitelist_model + + def test_subnet_relation(self, app): + """Test evaluating a whitelist with a SUBNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.128/25" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUBNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch( + "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.whitelist_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + # Call the function + with app.app_context(): + _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("supernet of whitelist" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert rtbh_rule.rstate_id == 4 + + def test_supernet_relation(self, app): + """Test evaluating a whitelist with a SUPERNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.0.0/16" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUPERNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("subnet of whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == whitelist_model diff --git a/tests/test_zzz_api_rtbh_expired_bug.py b/tests/test_zzz_api_rtbh_expired_bug.py new file mode 100644 index 0000000..6fd36a2 --- /dev/null +++ b/tests/test_zzz_api_rtbh_expired_bug.py @@ -0,0 +1,240 @@ +import json +from datetime import datetime, timedelta + +from flowapp.models import RTBH +from flowapp.models.rules.whitelist import Whitelist + + +def test_create_rtbh_after_expired_rule_exists(client, app, db, jwt_token): + """ + Test that demonstrates the bug: creating a new RTBH rule with the same IP + as an expired rule results in the new rule having withdrawn state instead + of active state. + + Test should be run in isolation or as the last in stack. + + Steps: + 1. Create an RTBH rule with expiration in the past (will be withdrawn, rstate_id=2) + 2. Create another RTBH rule with the same IP but expiration in the future + 3. Verify that the second rule should be active (rstate_id=1) but is actually withdrawn (rstate_id=2) + """ + # cleanup any existing RTBH rules to avoid interference + cleanup_before_stack(app, db) + + # Step 1: Create an expired RTBH rule + expired_payload = { + "community": 1, + "ipv4": "192.168.100.50", + "ipv4_mask": 32, + "expires": (datetime.now() - timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Expired rule that should be in withdrawn state", + } + + response1 = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=expired_payload, + ) + + assert response1.status_code == 201 + data1 = json.loads(response1.data) + rule_id_1 = data1["rule"]["id"] + + # Verify the first rule is in withdrawn state + with app.app_context(): + expired_rule = db.session.query(RTBH).filter_by(id=rule_id_1).first() + assert expired_rule is not None + assert expired_rule.rstate_id == 2, "Expired rule should be in withdrawn state (rstate_id=2)" + assert expired_rule.ipv4 == "192.168.100.50" + assert expired_rule.ipv4_mask == 32 + print(f"✓ First rule created with ID {rule_id_1}, state: {expired_rule.rstate_id} (withdrawn)") + + # Step 2: Create a new RTBH rule with the same IP but future expiration + future_payload = { + "community": 1, + "ipv4": "192.168.100.50", + "ipv4_mask": 32, + "expires": (datetime.now() + timedelta(days=7)).strftime("%m/%d/%Y %H:%M"), + "comment": "New rule that should be active but will be withdrawn due to bug", + } + + response2 = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=future_payload, + ) + + assert response2.status_code == 201 + data2 = json.loads(response2.data) + rule_id_2 = data2["rule"]["id"] + + # Step 3: Verify the bug - the second rule should be active but is withdrawn + with app.app_context(): + # The bug causes the expired rule to be updated instead of creating a new one + # OR if a new rule is created, it has the wrong state + + # Check if it's the same rule (updated) or a new rule + total_rules = db.session.query(RTBH).filter_by(ipv4="192.168.100.50", ipv4_mask=32).count() + + new_rule = db.session.query(RTBH).filter_by(id=rule_id_2).first() + assert new_rule is not None + + print("\n--- Bug Verification ---") + print(f"Total rules with IP 192.168.100.50/32: {total_rules}") + print(f"First rule ID: {rule_id_1}") + print(f"Second rule ID: {rule_id_2}") + print(f"Same rule updated: {rule_id_1 == rule_id_2}") + print(f"Second rule state: {new_rule.rstate_id}") + print(f"Second rule expires: {new_rule.expires}") + print(f"Expiration is in future: {new_rule.expires > datetime.now()}") + + # The bug: even though expiration is in the future, the rule is in withdrawn state + # EXPECTED: rstate_id should be 1 (active) + # ACTUAL: rstate_id is 2 (withdrawn) due to the bug + + # This assertion will FAIL due to the bug, demonstrating the issue + assert new_rule.expires > datetime.now(), "Rule expiration should be in the future" + + # THIS IS THE BUG: The rule has future expiration but is in withdrawn state + try: + assert new_rule.rstate_id == 1, ( + f"BUG DETECTED: Rule with future expiration should be active (rstate_id=1), " + f"but is in state {new_rule.rstate_id}. " + f"This happens because the expired rule was found and updated without resetting the state." + ) + print("✓ Test PASSED - bug is fixed!") + except AssertionError as e: + print(f"✗ Test FAILED - bug confirmed: {e}") + raise + cleanup_rtbh_rule(app, db, rule_id_1) + cleanup_rtbh_rule(app, db, rule_id_2) + + +def test_create_rtbh_after_expired_rule_different_mask(client, app, db, jwt_token): + """ + Test that verifies the bug only occurs when IP AND mask match. + When the mask is different, a new rule should be created successfully. + """ + + # Step 1: Create an expired RTBH rule with /32 mask + expired_payload = { + "community": 1, + "ipv4": "192.168.100.60", + "ipv4_mask": 32, + "expires": (datetime.now() - timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Expired /32 rule", + } + + response1 = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=expired_payload, + ) + + assert response1.status_code == 201 + + # Step 2: Create a new rule with same IP but different mask (/24) + future_payload = { + "community": 1, + "ipv4": "192.168.100.0", + "ipv4_mask": 24, + "expires": (datetime.now() + timedelta(days=7)).strftime("%m/%d/%Y %H:%M"), + "comment": "New /24 rule - should be active", + } + + response2 = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=future_payload, + ) + + assert response2.status_code == 201 + data2 = json.loads(response2.data) + + # Verify the new rule is active (this should work because IP+mask don't match) + with app.app_context(): + new_rule = db.session.query(RTBH).filter_by(id=data2["rule"]["id"]).first() + assert new_rule is not None + assert new_rule.rstate_id == 1, "New rule with different mask should be active" + print("✓ Different mask creates new active rule correctly") + + cleanup_rtbh_rule(app, db, new_rule.id) + + +def test_create_rtbh_after_active_rule_exists(client, app, db, jwt_token): + """ + Test that when an active rule exists, updating it with a new expiration + maintains the active state (this should work correctly). + """ + + # Step 1: Create an active RTBH rule + active_payload = { + "community": 1, + "ipv4": "192.168.100.70", + "ipv4_mask": 32, + "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Active rule", + } + + response1 = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=active_payload, + ) + + assert response1.status_code == 201 + data1 = json.loads(response1.data) + rule_id_1 = data1["rule"]["id"] + + # Verify the first rule is active + with app.app_context(): + first_rule = db.session.query(RTBH).filter_by(id=rule_id_1).first() + assert first_rule.rstate_id == 1, "First rule should be active" + + # Step 2: Update the same rule with a new expiration + updated_payload = { + "community": 1, + "ipv4": "192.168.100.70", + "ipv4_mask": 32, + "expires": (datetime.now() + timedelta(days=7)).strftime("%m/%d/%Y %H:%M"), + "comment": "Updated active rule", + } + + response2 = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=updated_payload, + ) + + assert response2.status_code == 201 + data2 = json.loads(response2.data) + + # Verify it maintains active state + with app.app_context(): + updated_rule = db.session.query(RTBH).filter_by(id=data2["rule"]["id"]).first() + assert updated_rule is not None + assert updated_rule.rstate_id == 1, "Updated rule should remain active" + print("✓ Updating active rule maintains active state correctly") + + cleanup_rtbh_rule(app, db, rule_id_1) + + +def cleanup_before_stack(app, db): + """ + Cleanup function to remove all RTBH rules created during tests. + """ + with app.app_context(): + db.session.query(RTBH).delete() + db.session.query(Whitelist).delete() + db.session.commit() + + +def cleanup_rtbh_rule(app, db, rule_id): + """ + Cleanup function to remove RTBH rule created during tests. + """ + with app.app_context(): + rule = db.session.get(RTBH, rule_id) + if rule: + db.session.delete(rule) + db.session.commit() From b86b840525a590d626a8195b12a49fff6013bad8 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 27 Jan 2026 08:31:17 +0100 Subject: [PATCH 08/10] add as paths to admin menu --- flowapp/__about__.py | 2 +- flowapp/instance_config.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/flowapp/__about__.py b/flowapp/__about__.py index 2c73c33..debeb9a 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.0-dev" +__version__ = "1.2.0b1" __title__ = "ExaFS" __description__ = "Tool for creation, validation, and execution of ExaBGP messages." __author__ = "CESNET / Jiri Vrany, Petr Adamec, Josef Verich, Jakub Man" diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 4136530..bded1ca 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -102,8 +102,14 @@ class InstanceConfig: "divide_before": True, }, {"name": "Add action", "url": "admin.action"}, - {"name": "RTBH Communities", "url": "admin.communities"}, + { + "name": "RTBH Communities", + "url": "admin.communities", + "divide_before": True, + }, {"name": "Add RTBH Comm.", "url": "admin.community"}, + {"name": "AS Paths", "url": "admin.as_paths"}, + {"name": "Add AS Path", "url": "admin.as_path"}, ], } DASHBOARD = { From 0867f03a0439a0bc6e5c9270c50e214a74f553bd Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 29 Jan 2026 12:22:11 +0100 Subject: [PATCH 09/10] add GRE protocol --- flowapp/__about__.py | 2 +- flowapp/constants.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flowapp/__about__.py b/flowapp/__about__.py index debeb9a..b7f92fe 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.0b1" +__version__ = "1.2.0b2" __title__ = "ExaFS" __description__ = "Tool for creation, validation, and execution of ExaBGP messages." __author__ = "CESNET / Jiri Vrany, Petr Adamec, Josef Verich, Jakub Man" diff --git a/flowapp/constants.py b/flowapp/constants.py index 819f685..b37c8a0 100644 --- a/flowapp/constants.py +++ b/flowapp/constants.py @@ -35,9 +35,9 @@ MAX_PORT = 65535 MAX_PACKET = 9216 -IPV6_NEXT_HEADER = {"tcp": "tcp", "udp": "udp", "icmp": "58", "all": ""} +IPV6_NEXT_HEADER = {"tcp": "tcp", "udp": "udp", "icmp": "58", "gre": "gre", "all": ""} -IPV4_PROTOCOL = {"tcp": "tcp", "udp": "udp", "icmp": "icmp", "all": ""} +IPV4_PROTOCOL = {"tcp": "tcp", "udp": "udp", "icmp": "icmp", "gre": "gre", "all": ""} IPV4_FRAGMENT = { "dont": "dont-fragment", From fdab2a43145f749a9e403c54c9b0cdd73ca4665d Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 29 Jan 2026 13:02:50 +0100 Subject: [PATCH 10/10] release 1.2.0 --- flowapp/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flowapp/__about__.py b/flowapp/__about__.py index b7f92fe..da7547d 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.0b2" +__version__ = "1.2.0" __title__ = "ExaFS" __description__ = "Tool for creation, validation, and execution of ExaBGP messages." __author__ = "CESNET / Jiri Vrany, Petr Adamec, Josef Verich, Jakub Man"