diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 9e659abb..0f8eab6d 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,28 +1,25 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python +# This workflow will install Python dependencies, run tests and lint with multiple versions of Python # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - name: Python application - on: push: branches: [ "master", "develop" ] pull_request: branches: [ "master", "develop" ] - permissions: contents: read - jobs: build: - runs-on: ubuntu-latest - + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: - python-version: "3.9" + python-version: ${{ matrix.python-version }} - name: Setup timezone uses: zcong1993/setup-timezone@master with: diff --git a/README.md b/README.md index f0bc9742..f6f2ca12 100644 --- a/README.md +++ b/README.md @@ -35,26 +35,39 @@ See how is ExaFS integrated into the network in the picture below. ## System overview ![ExaFS schema](./docs/app_schema_en.png) +The core component of ExaFS is a web application written in Python using the Flask framework. It provides a user interface for managing ExaBGP rules (CRUD operations) and also exposes a REST API with similar functionality. The web application uses Shibboleth for authentication, while the REST API relies on token-based authentication. -The central part of the ExaFS is a web application, written in Python3.6 with Flask framework. It provides a user interface for ExaBGP rule CRUD operations. The application also provides the REST API with CRUD operations for the configuration rules. The web app uses Shibboleth authorization; the REST API is using token-based authorization. +The application generates ExaBGP commands and forwards them to the ExaBGP process. All rules are thoroughly validated—only valid rules are stored in the database and sent to the ExaBGP connector. -The app creates the ExaBGP commands and forwards them to ExaBGP process. All rules are carefully validated, and only valid rules are stored in the database and sent to the ExaBGP connector. - -This second part of the system is another application that replicates the received command to the stdout. The connection between ExaBGP daemon and stdout of ExaAPI (ExaBGP process) is specified in the ExaBGP config. +The second component of the system is a separate application that replicates received commands to `stdout`. The connection between the ExaBGP daemon and the `stdout` of the ExaAPI (ExaBGP process) is defined in the ExaBGP configuration. -This API was a part of the project, but now has been moved to own repository. You can use [pip package exabgp-process](https://pypi.org/project/exabgp-process/) or clone the git repo. Or you can create your own version. - -Every time this process gets a command from ExaFS, it replicates this command to the ExaBGP service through the stdout. The registered service then updates the ExaBGP table – create, modify or remove the rule from command. +This API was originally part of the same project but has since been moved to its own repository. You can use the [exabgp-process pip package](https://pypi.org/project/exabgp-process/), clone the Git repository, or develop your own implementation. -You may also need to monitor the ExaBGP and renew the commands after restart / shutdown. In docs you can find and example of system service named Guarda. This systemctl service is running in the host system and gets a notification on each restart of ExaBGP service via systemctl WantedBy config option. For every restart of ExaBGP the Guarda service will put all the valid and active rules to the ExaBGP rules table again. +Each time this process receives a command from ExaFS, it outputs it to `stdout`, allowing the ExaBGP service to process the command and update its routing table—creating, modifying, or removing rules accordingly. + +It may also be necessary to monitor ExaBGP and re-announce rules after a restart or shutdown. This can be handled via the ExaBGP service configuration, or by using an example system service called **Guarda**, described in the documentation. In either case, the key mechanism is calling the application endpoint `/rules/announce_all`. This endpoint is only accessible from `localhost`; a local IP address must be configured in the application settings. ## DOCS +### Instalation related +* [ExaFS Ansible deploy](https://github.com/CESNET/ExaFS-deploy) - repository with Ansbile playbook for deploying ExaFS with Docker Compose. * [Install notes](./docs/INSTALL.md) -* [API documentation ](https://exafs.docs.apiary.io/#) * [Database backup configuration](./docs/DB_BACKUP.md) * [Local database instalation notes](./docs/DB_LOCAL.md) +### API +The REST API is documented using Swagger (OpenAPI). After installing and running the application, the API documentation is available locally at the /apidocs/ endpoint. This interactive documentation provides details about all available endpoints, request and response formats, and supported operations, making it easier to integrate and test the API. + + ## Change Log +- 1.1.1 - Machine API Key rewrited. + - API keys for machines are now tied to one of the existing users. If there is a need to have API access for machine, first create service user, and set the access rights. Then create machine key as Admin and assign it to this user. +- 1.1.0 - Major Architecture Refactoring and Whitelist Integration + - Code Organization and Architecture Improvements. Significant architectural refactoring focused on better separation of concerns and improved maintainability. The most notable change is the introduction of a dedicated **services layer** that extracts business logic from view controllers. Key service modules include `rule_service.py` for rule management operations, `whitelist_service.py` for whitelist functionality, and `whitelist_common.py` for shared whitelist utilities. + - The **models structure** has been reorganized with better separation into logical modules. Rule models are now organized under `flowapp/models/rules/` with separate files for different rule types (`flowspec.py`, `rtbh.py`, `whitelist.py`), while maintaining backward compatibility through the main models `__init__.py`. Form handling has also been improved with better organization under `flowapp/forms/` and enhanced validation logic. + - **RTBH Whitelist Integration** This system automatically evaluates new RTBH rules against existing whitelists and can automatically modify or block rules that conflict with whitelisted networks. When an RTBH rule is created that intersects with a whitelist entry, the system can: + - **Automatically whitelist** rules that exactly match or are contained within whitelisted networks + - **Create subnet rules** when RTBH rules are supersets of whitelisted networks, automatically generating the non-whitelisted portions + - **Maintain rule cache** that tracks relationships between rules and whitelists for proper cleanup - 1.0.2 - fixed bug in IPv6 Flowspec messages - 1.0.1 . minor bug fixes - 1.0.0 . Major changes diff --git a/docs/guarda-service/README.md b/docs/guarda-service/README.md index 45f53446..e8086b44 100644 --- a/docs/guarda-service/README.md +++ b/docs/guarda-service/README.md @@ -1,5 +1,4 @@ # Guarda Service for ExaBGP - This is a systemd service designed to monitor ExaBGP and reapply commands after a restart or shutdown. The guarda.service runs on the host system and is triggered whenever the ExaBGP service restarts, thanks to the WantedBy configuration in systemd. After each restart, the Guarda service will reapply all valid and active rules to the ExaBGP rules table. ## Usage (as root) diff --git a/flowapp/__about__.py b/flowapp/__about__.py index d10af327..c7f183a6 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1 +1 @@ -__version__ = "1.0.2" +__version__ = "1.1.1" diff --git a/flowapp/__init__.py b/flowapp/__init__.py index f029c37d..61543e78 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -1,10 +1,6 @@ # -*- coding: utf-8 -*- -import babel -import logging -from loguru import logger +from flask import Flask, redirect, render_template, session, url_for -from flask import Flask, redirect, render_template, session, url_for, request -from flask.logging import default_handler from flask_sso import SSO from flask_sqlalchemy import SQLAlchemy from flask_wtf.csrf import CSRFProtect @@ -26,16 +22,16 @@ swagger = Swagger(template_file="static/swagger.yml") -class InterceptHandler(logging.Handler): - - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) - logger_opt.log(record.levelname, record.getMessage()) - - def create_app(config_object=None): app = Flask(__name__) + # Load the default configuration for dashboard and main menu + app.config.from_object(InstanceConfig) + if config_object: + app.config.from_object(config_object) + + app.config.setdefault("VERSION", __version__) + # SSO configuration SSO_ATTRIBUTE_MAP = { "eppn": (True, "eppn"), @@ -47,13 +43,6 @@ def create_app(config_object=None): migrate.init_app(app, db) csrf.init_app(app) - # Load the default configuration for dashboard and main menu - app.config.from_object(InstanceConfig) - if config_object: - app.config.from_object(config_object) - - app.config.setdefault("VERSION", __version__) - # Init SSO ext.init_app(app) @@ -64,76 +53,28 @@ def create_app(config_object=None): if app.config.get("BEHIND_PROXY", False): app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) - from flowapp import models, constants, validators - from .views.admin import admin - from .views.rules import rules - from .views.api_v1 import api as api_v1 - from .views.api_v2 import api as api_v2 - from .views.api_v3 import api as api_v3 - from .views.api_keys import api_keys + from flowapp import models, constants from .auth import auth_required - from .views.dashboard import dashboard - - # no need for csrf on api because we use JWT - csrf.exempt(api_v1) - csrf.exempt(api_v2) - csrf.exempt(api_v3) - - app.register_blueprint(admin, url_prefix="/admin") - app.register_blueprint(rules, url_prefix="/rules") - app.register_blueprint(api_keys, url_prefix="/api_keys") - app.register_blueprint(api_v1, url_prefix="/api/v1") - app.register_blueprint(api_v2, url_prefix="/api/v2") - app.register_blueprint(api_v3, url_prefix="/api/v3") - app.register_blueprint(dashboard, url_prefix="/dashboard") - - # register loguru as handler - app.logger.removeHandler(default_handler) - app.logger.addHandler(InterceptHandler()) - - @ext.login_handler - def login(user_info): - try: - uuid = user_info.get("eppn") - except KeyError: - uuid = False - return render_template("errors/401.html") - return _handle_login(uuid) + # Register blueprints + from .utils import register_blueprints - @app.route("/logout") - def logout(): - session["user_uuid"] = False - session["user_id"] = False - session.clear() - return redirect(app.config.get("LOGOUT_URL")) + register_blueprints(app, csrf) - @app.route("/ext-login") - def ext_login(): - header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") - if header_name not in request.headers: - return render_template("errors/401.html") + # configure logging + from .utils import configure_logging - uuid = request.headers.get(header_name) - if not uuid: - return render_template("errors/401.html") + configure_logging(app) - return _handle_login(uuid) + # register error handlers + from .utils import register_error_handlers - @app.route("/local-login") - def local_login(): - print("Local login started") - if not app.config.get("LOCAL_AUTH", False): - print("Local auth not enabled") - return render_template("errors/401.html") + register_error_handlers(app) - uuid = app.config.get("LOCAL_USER_UUID", False) - if not uuid: - print("Local user not set") - return render_template("errors/401.html") + # register auth handlers + from .utils import register_auth_handlers - print(f"Local login with {uuid}") - return _handle_login(uuid) + register_auth_handlers(app, ext) @app.route("/") @auth_required @@ -192,89 +133,10 @@ def select_org(org_id=None): def shutdown_session(exception=None): db.session.remove() - # HTTP error handling - @app.errorhandler(404) - def not_found(error): - return render_template("errors/404.html"), 404 - - @app.errorhandler(500) - def internal_error(exception): - app.logger.exception(exception) - return render_template("errors/500.html"), 500 - - @app.context_processor - def utility_processor(): - def editable_rule(rule): - if rule: - validators.editable_range(rule, models.get_user_nets(session["user_id"])) - return True - return False - - return dict(editable_rule=editable_rule) - - @app.context_processor - def inject_main_menu(): - """ - inject main menu config to templates - used in default template to create main menu - """ - return {"main_menu": app.config.get("MAIN_MENU")} - - @app.context_processor - def inject_dashboard(): - """ - inject dashboard config to templates - used in submenu dashboard to create dashboard tables - """ - return {"dashboard": app.config.get("DASHBOARD")} - - @app.template_filter("strftime") - def format_datetime(value): - if value is None: - return app.config.get("MISSING_DATETIME_MESSAGE", "Never") - - format = "y/MM/dd HH:mm" - return babel.dates.format_datetime(value, format) - - @app.template_filter("unlimited") - def unlimited_filter(value): - return "unlimited" if value == 0 else value - - def _handle_login(uuid: str): - """ - handles rest of login process - """ - multiple_orgs = False - try: - user, multiple_orgs = _register_user_to_session(uuid) - except AttributeError as e: - app.logger.exception(e) - return render_template("errors/401.html") - - if multiple_orgs: - return redirect(url_for("select_org", org_id=None)) + # register context processors and template filters + from .utils import register_context_processors, register_template_filters - # set user org to session - user_org = user.organization.first() - session["user_org"] = user_org.name - session["user_org_id"] = user_org.id - - return redirect("/") - - def _register_user_to_session(uuid: str): - print(f"Registering user {uuid} to session") - user = db.session.query(models.User).filter_by(uuid=uuid).first() - print(f"Got user {user} from DB") - session["user_uuid"] = user.uuid - session["user_email"] = user.uuid - session["user_name"] = user.name - session["user_id"] = user.id - session["user_roles"] = [role.name for role in user.role.all()] - session["user_role_ids"] = [role.id for role in user.role.all()] - roles = [i > 1 for i in session["user_role_ids"]] - session["can_edit"] = True if all(roles) and roles else [] - # check if user has multiple organizations and return True if so - print(f"DEBUG SESSION {session}") - return user, len(user.organization.all()) > 1 + register_context_processors(app) + register_template_filters(app) return app diff --git a/flowapp/constants.py b/flowapp/constants.py index 5aa52834..975f99b7 100644 --- a/flowapp/constants.py +++ b/flowapp/constants.py @@ -2,6 +2,7 @@ This module contains constant values used in application """ +from enum import Enum from operator import ge, lt DEFAULT_SORT = "expires" @@ -59,7 +60,12 @@ FORM_TIME_PATTERN = "%Y-%m-%dT%H:%M" -class RuleTypes: +class RuleTypes(Enum): RTBH = 1 IPv4 = 4 IPv6 = 6 + + +class RuleOrigin(Enum): + USER = 1 + WHITELIST = 2 diff --git a/flowapp/flowspec.py b/flowapp/flowspec.py index e0ce35a5..ae357ca8 100644 --- a/flowapp/flowspec.py +++ b/flowapp/flowspec.py @@ -21,6 +21,19 @@ def translate_sequence(sequence, max_val=MAX_PORT): return "[{}]".format(result) +def check_limit(value, max_value, min_value=0): + """ + test if the value is within valid range (min_value to max_value inclusive) + raise exception otherwise + """ + value = int(value) + if value > max_value: + raise ValueError("Invalid value number: {} is too big. Max is {}.".format(value, max_value)) + if value < min_value: + raise ValueError("Invalid value number: {} is too small. Min is {}.".format(value, min_value)) + return value + + def to_exabgp_string(value_string, max_val): """ Translate form string to flowspec value or packet size rule @@ -42,35 +55,30 @@ def to_exabgp_string(value_string, max_val): return "={}".format(check_limit(value_string, max_val)) elif RANGE.match(value_string): m = RANGE.match(value_string) - return ">={}&<={}".format( - check_limit(m.group(1), max_val), check_limit(m.group(2), max_val) - ) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif NOTRAN.match(value_string): - return value_string + m = NOTRAN.match(value_string) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif GREATER.match(value_string): m = GREATER.match(value_string) return ">={}&<={}".format(check_limit(m.group(1), max_val), max_val) elif LOWER.match(value_string): m = LOWER.match(value_string) - return ">=0&<={}".format(check_limit(m.group(1), max_val)) + # Even for lower bound expressions, validate that the value itself is within range + end = check_limit(m.group(1), max_val) + return ">=0&<={}".format(end) else: raise ValueError("string {} can not be converted".format(value_string)) -def check_limit(value, max_value): - """ - test if the value is lower than max_value - raise exception otherwise - """ - value = int(value) - if value > max_value: - raise ValueError( - "Invalid value number: {} is too big. Max is {}.".format(value, max_value) - ) - else: - return value - - def filter_rules_action(user_actions, rules): """ Divide the list of rules by user_actions to editable and viewonly subsets diff --git a/flowapp/forms.py b/flowapp/forms.py deleted file mode 100644 index 4b914dce..00000000 --- a/flowapp/forms.py +++ /dev/null @@ -1,650 +0,0 @@ -import csv -from io import StringIO - - -from flask_wtf import FlaskForm -from wtforms import widgets -from wtforms import ( - BooleanField, - DateTimeField, - HiddenField, - IntegerField, - SelectField, - SelectMultipleField, - StringField, - TextAreaField, -) -from wtforms.validators import ( - ValidationError, - DataRequired, - Email, - InputRequired, - IPAddress, - Length, - NumberRange, - Optional, -) - -from flowapp.constants import ( - IPV4_FRAGMENT, - IPV4_PROTOCOL, - IPV6_NEXT_HEADER, - TCP_FLAGS, - FORM_TIME_PATTERN, -) -from flowapp.validators import ( - IPv4Address, - IPv6Address, - NetRangeString, - PortString, - address_in_range, - address_with_mask, - network_in_range, - whole_world_range, -) - -from flowapp.utils import parse_api_time - - -class MultiFormatDateTimeLocalField(DateTimeField): - """ - Same as :class:`~wtforms.fields.DateTimeField`, but represents an - ````. - - Custom implementation uses default HTML5 format for parsing the field. - It's possible to use multiple formats - used in API. - - """ - - widget = widgets.DateTimeLocalInput() - - def __init__(self, *args, **kwargs): - kwargs.setdefault("format", "%Y-%m-%dT%H:%M") - self.unlimited = kwargs.pop("unlimited", False) - self.pref_format = None - super().__init__(*args, **kwargs) - - def process_formdata(self, valuelist): - if not valuelist: - return None - # with unlimited field we do not need to parse the empty value - if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: - self.data = None - return None - - date_str = " ".join((str(val) for val in valuelist)) - result, pref_format = parse_api_time(date_str) - if result: - self.data = result - self.pref_format = pref_format - else: - self.data = None - self.pref_format = None - raise ValueError(self.gettext("Not a valid datetime value.")) - - -class UserForm(FlaskForm): - """ - User Form object - used in Admin - """ - - uuid = StringField( - "Unique User ID", - validators=[ - InputRequired("Please provide UUID"), - Email("Please provide valid email"), - ], - ) - - email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) - - comment = StringField("Notice", validators=[Optional()]) - - name = StringField("Name", validators=[Optional()]) - - phone = StringField("Contact phone", validators=[Optional()]) - - role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) - - org_ids = SelectMultipleField( - "Organization", - coerce=int, - validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], - ) - - -class BulkUserForm(FlaskForm): - """ - Bulk User Form object - used in Admin - """ - - users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) - - def __init__(self, *args, **kwargs): - super(BulkUserForm, self).__init__(*args, **kwargs) - self.roles = None - self.organizations = None - self.uuids = None - - # Custom validator for CSV data - def validate_users(self, field): - csv_data = field.data - - # Parse CSV data - csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") - - # List to keep track of failed validation rows - errors = 0 - for row_num, row in enumerate(csv_reader, start=1): - try: - # check if the user not already exists - if row["uuid-eppn"] in self.uuids: - field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") - errors += 1 - - # Check if role exists in the database - role_id = int(row["role"]) # Convert role field to integer - if role_id not in self.roles: - field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") - errors += 1 - - # Check if organization exists in the database - org_id = int(row["organizace"]) # Convert organization field to integer - if org_id not in self.organizations: - field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") - errors += 1 - - except (KeyError, ValueError) as e: - field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") - - if errors > 0: - # Raise validation error if any invalid rows found - raise ValidationError("Invalid CSV Data - check the errors above.") - - -class ApiKeyForm(FlaskForm): - """ - ApiKey for User - Each key / machine pair is unique - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class MachineApiKeyForm(FlaskForm): - """ - ApiKey for Machines - Each key / machine pair is unique - Only Admin can create new these keys - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class OrganizationForm(FlaskForm): - """ - Organization form object - used in Admin - """ - - name = StringField("Organization name", validators=[Optional(), Length(max=150)]) - - limit_flowspec4 = IntegerField( - "Maximum number of IPv4 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_flowspec6 = IntegerField( - "Maximum number of IPv6 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_rtbh = IntegerField( - "Maximum number of RTBH rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - arange = TextAreaField( - "Organization Adress Range - one range per row", - validators=[Optional(), NetRangeString()], - ) - - -class ActionForm(FlaskForm): - """ - Action form object - used in Admin - """ - - name = StringField("Action short name", validators=[Length(max=150)]) - - command = StringField("ExaBGP command", validators=[Length(max=150)]) - - description = StringField("Action description") - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - -class ASPathForm(FlaskForm): - """ - AS Path form object - used in Admin - """ - - prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) - - as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) - - -class CommunityForm(FlaskForm): - """ - Community form object - used in Admin - """ - - name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) - - comm = StringField("Community value", validators=[Length(max=2046)]) - - larcomm = StringField("Large community value", validators=[Length(max=2046)]) - - extcomm = StringField("Extended community value", validators=[Length(max=2046)]) - - description = StringField("Community description", validators=[Length(max=255)]) - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - as_path = BooleanField("add AS-path (checked = true)") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - if not self.comm.data and not self.extcomm.data and not self.larcomm.data: - err_message = "At last one of those values could not be empty" - self.comm.errors.append(err_message) - self.larcomm.errors.append(err_message) - self.extcomm.errors.append(err_message) - result = False - - return result - - -class RTBHForm(FlaskForm): - """ - RoadToBlackHole rule form - """ - - def __init__(self, *args, **kwargs): - super(RTBHForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - ipv4 = StringField( - "IPv4 address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - ipv4_mask = IntegerField( - "IPv4 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), - ], - ) - - ipv6 = StringField( - "IPv6 address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - ipv6_mask = IntegerField( - "IPv6 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), - ], - ) - - community = SelectField( - "Community", - coerce=int, - validators=[ - DataRequired(message="Please select a community for the rule."), - ], - ) - - expires = MultiFormatDateTimeLocalField( - "Expires", - format=FORM_TIME_PATTERN, - validators=[DataRequired(), InputRequired()], - ) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - # ipv4 and ipv6 are mutually exclusive - # if both are set, validation fails - # if none is set, validation fails - # if one is set, validation passes - if self.ipv4.data and self.ipv6.data: - self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - result = False - - if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): - self.ipv4.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) - ) - result = False - - if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): - self.ipv6.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) - ) - result = False - - ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) - ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) - - if not (ipv6_in_range or ipv4_in_range): - self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - result = False - - return result - - -class IPForm(FlaskForm): - """ - Base class for IPv4 and IPv6 rules - """ - - def __init__(self, *args, **kwargs): - super(IPForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = None - source = None - source_mask = None - dest = None - dest_mask = None - flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) - - source_port = StringField( - "Source port(s) - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - dest_port = StringField( - "Destination port(s) - ; separated", - validators=[Optional(), Length(max=255), PortString()], - ) - - packet_len = StringField( - "Packet length - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - action = SelectField( - "Action", - coerce=int, - validators=[DataRequired(message="Please select an action for the rule.")], - ) - - expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - - result = True - if not FlaskForm.validate(self): - result = False - - source = self.validate_source_address() - dest = self.validate_dest_address() - ranges = self.validate_address_ranges() - ips = self.validate_ipv_specific() - - return result and source and dest and ranges and ips - - def validate_source_address(self): - """ - validate source address, set error message if validation fails - :return: boolean validation result - """ - if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): - self.source.errors.append( - "This is not valid combination of address {} and mask {}.".format( - self.source.data, self.source_mask.data - ) - ) - return False - - return True - - def validate_dest_address(self): - """ - validate dest address, set error message if validation fails - :return: boolean validation result - """ - if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): - self.dest.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) - ) - return False - - return True - - def validate_address_ranges(self): - """ - validates if the address of source is in the user range - if the source and dest address are empty, check if the user - is member of whole world organization - :return: boolean validation result - """ - if not (self.source.data or self.dest.data): - whole_world_member = whole_world_range(self.net_ranges, self.zero_address) - if not whole_world_member: - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - else: - source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) - dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) - if not (source_in_range or dest_in_range): - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - - return True - - def validate_ipv_specific(self): - """ - abstract method must be implemented in the subclass - """ - pass - - -class IPv4Form(IPForm): - """ - IPv4 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv4Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "0.0.0.0" - source = StringField( - "Source address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - source_mask = IntegerField( - "Source mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - dest_mask = IntegerField( - "Destination mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - protocol = SelectField( - "Protocol", - choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], - validators=[DataRequired()], - ) - - fragment = SelectMultipleField( - "Fragment", - choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], - validators=[Optional()], - ) - - def validate_ipv_specific(self): - """ - validate protocol and flags, set error message if validation fails - :return: boolean validation result - """ - - if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": - self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) - return False - return True - - -class IPv6Form(IPForm): - """ - IPv6 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv6Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "::" - source = StringField( - "Source address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - source_mask = IntegerField( - "Source prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - dest_mask = IntegerField( - "Destination prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - next_header = SelectField( - "Next Header", - choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], - validators=[DataRequired()], - ) - - def validate_ipv_specific(self): - """ - validate next header and flags, set error message if validation fails - :return: boolean validation result - """ - if len(self.flags.data) > 0 and self.next_header.data != "tcp": - self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) - return False - - return True diff --git a/flowapp/forms/__init__.py b/flowapp/forms/__init__.py new file mode 100644 index 00000000..a9d7d887 --- /dev/null +++ b/flowapp/forms/__init__.py @@ -0,0 +1,40 @@ +""" +Forms module for the flowapp application. +This file imports and re-exports all form classes from the module. +""" + +# Base field +from .base import MultiFormatDateTimeLocalField + +# User forms +from .user import UserForm, BulkUserForm + +# API key forms +from .api import ApiKeyForm, MachineApiKeyForm + +# Organization forms +from .organization import OrganizationForm + +# Action, ASPath, and Community forms +from .choices import ActionForm, ASPathForm, CommunityForm + +# Rule forms +from .rules import IPForm, IPv4Form, IPv6Form, RTBHForm, WhitelistForm + + +__all__ = [ + "MultiFormatDateTimeLocalField", + "UserForm", + "BulkUserForm", + "ApiKeyForm", + "MachineApiKeyForm", + "OrganizationForm", + "ActionForm", + "ASPathForm", + "CommunityForm", + "RTBHForm", + "IPForm", + "IPv4Form", + "IPv6Form", + "WhitelistForm", +] diff --git a/flowapp/forms/api.py b/flowapp/forms/api.py new file mode 100644 index 00000000..c9af19a0 --- /dev/null +++ b/flowapp/forms/api.py @@ -0,0 +1,68 @@ +""" +API key forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import SelectField, StringField, TextAreaField, BooleanField, HiddenField +from wtforms.validators import DataRequired, IPAddress, Optional, Length + +from .base import MultiFormatDateTimeLocalField +from ..constants import FORM_TIME_PATTERN + + +class ApiKeyForm(FlaskForm): + """ + ApiKey for User + Each key / machine pair is unique + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=False) + + key = HiddenField("GeneratedKey") + + +class MachineApiKeyForm(FlaskForm): + """ + ApiKey for Services / No login users. + Each key / machine pair is unique + This form is used by Admin to create api key for services or users with no Shibboleth login + User must be created first and must have an organization + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=True) + + user = SelectField( + "User", + coerce=int, + validators=[DataRequired("Select user")], + ) + + key = HiddenField("GeneratedKey") diff --git a/flowapp/forms/base.py b/flowapp/forms/base.py new file mode 100644 index 00000000..f99db23d --- /dev/null +++ b/flowapp/forms/base.py @@ -0,0 +1,50 @@ +""" +Base form fields for the flowapp application. +""" + +from wtforms import widgets +from wtforms.fields import DateTimeField +from ..utils import parse_api_time + + +class MultiFormatDateTimeLocalField(DateTimeField): + """ + Same as :class:`~wtforms.fields.DateTimeField`, but represents an + ````. + + Custom implementation uses default HTML5 format for parsing the field. + It's possible to use multiple formats - used in API. + + """ + + widget = widgets.DateTimeLocalInput() + + def __init__(self, *args, **kwargs): + kwargs.setdefault("format", "%Y-%m-%dT%H:%M") + self.unlimited = kwargs.pop("unlimited", False) + self.pref_format = None + super().__init__(*args, **kwargs) + + def process_formdata(self, valuelist): + if not valuelist or (len(valuelist) == 1 and not valuelist[0]): + return None + + # with unlimited field we do not need to parse the empty value + if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: + self.data = None + return None + + date_str = " ".join((str(val) for val in valuelist)) + + try: + result, pref_format = parse_api_time(date_str) + except TypeError: + raise ValueError(self.gettext("Not a valid datetime value.")) + + if result: + self.data = result + self.pref_format = pref_format + else: + self.data = None + self.pref_format = None + raise ValueError(self.gettext("Not a valid datetime value.")) diff --git a/flowapp/forms/choices.py b/flowapp/forms/choices.py new file mode 100644 index 00000000..424b6b80 --- /dev/null +++ b/flowapp/forms/choices.py @@ -0,0 +1,81 @@ +""" +Action, ASPath, and Community forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, SelectField, BooleanField +from wtforms.validators import Length, DataRequired + + +class ActionForm(FlaskForm): + """ + Action form object + used in Admin + """ + + name = StringField("Action short name", validators=[Length(max=150)]) + + command = StringField("ExaBGP command", validators=[Length(max=150)]) + + description = StringField("Action description") + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + +class ASPathForm(FlaskForm): + """ + AS Path form object + used in Admin + """ + + prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) + + as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) + + +class CommunityForm(FlaskForm): + """ + Community form object + used in Admin + """ + + name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) + + comm = StringField("Community value", validators=[Length(max=2046)]) + + larcomm = StringField("Large community value", validators=[Length(max=2046)]) + + extcomm = StringField("Extended community value", validators=[Length(max=2046)]) + + description = StringField("Community description", validators=[Length(max=255)]) + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + as_path = BooleanField("add AS-path (checked = true)") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + if not self.comm.data and not self.extcomm.data and not self.larcomm.data: + err_message = "At last one of those values could not be empty" + self.comm.errors.append(err_message) + self.larcomm.errors.append(err_message) + self.extcomm.errors.append(err_message) + result = False + + return result diff --git a/flowapp/forms/organization.py b/flowapp/forms/organization.py new file mode 100644 index 00000000..37594ab0 --- /dev/null +++ b/flowapp/forms/organization.py @@ -0,0 +1,47 @@ +""" +Organization form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, Length, NumberRange + +from ..validators import NetRangeString + + +class OrganizationForm(FlaskForm): + """ + Organization form object + used in Admin + """ + + name = StringField("Organization name", validators=[Optional(), Length(max=150)]) + + limit_flowspec4 = IntegerField( + "Maximum number of IPv4 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_flowspec6 = IntegerField( + "Maximum number of IPv6 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_rtbh = IntegerField( + "Maximum number of RTBH rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + arange = TextAreaField( + "Organization Adress Range - one range per row", + validators=[Optional(), NetRangeString()], + ) diff --git a/flowapp/forms/rules/__init__.py b/flowapp/forms/rules/__init__.py new file mode 100644 index 00000000..9796e8ce --- /dev/null +++ b/flowapp/forms/rules/__init__.py @@ -0,0 +1,17 @@ +""" +Rule forms for the flowapp application. +""" + +from .base import IPForm +from .ipv4 import IPv4Form +from .ipv6 import IPv6Form +from .rtbh import RTBHForm +from .whitelist import WhitelistForm + +__all__ = [ + "IPForm", + "IPv4Form", + "IPv6Form", + "RTBHForm", + "WhitelistForm", +] diff --git a/flowapp/forms/rules/base.py b/flowapp/forms/rules/base.py new file mode 100644 index 00000000..0e67776e --- /dev/null +++ b/flowapp/forms/rules/base.py @@ -0,0 +1,127 @@ +""" +Base rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import SelectMultipleField, StringField, SelectField, TextAreaField +from wtforms.validators import Optional, Length, DataRequired, InputRequired + +from ..base import MultiFormatDateTimeLocalField +from ...constants import TCP_FLAGS +from ...validators import PortString, address_with_mask, network_in_range, whole_world_range + + +class IPForm(FlaskForm): + """ + Base class for IPv4 and IPv6 rules + """ + + def __init__(self, *args, **kwargs): + super(IPForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = None + source = None + source_mask = None + dest = None + dest_mask = None + flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) + + source_port = StringField( + "Source port(s) - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + dest_port = StringField( + "Destination port(s) - ; separated", + validators=[Optional(), Length(max=255), PortString()], + ) + + packet_len = StringField( + "Packet length - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + action = SelectField( + "Action", + coerce=int, + validators=[DataRequired(message="Please select an action for the rule.")], + ) + + expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + + result = True + if not FlaskForm.validate(self): + result = False + + source = self.validate_source_address() + dest = self.validate_dest_address() + ranges = self.validate_address_ranges() + ips = self.validate_ipv_specific() + + return result and source and dest and ranges and ips + + def validate_source_address(self): + """ + validate source address, set error message if validation fails + :return: boolean validation result + """ + if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): + self.source.errors.append( + "This is not valid combination of address {} and mask {}.".format( + self.source.data, self.source_mask.data + ) + ) + return False + + return True + + def validate_dest_address(self): + """ + validate dest address, set error message if validation fails + :return: boolean validation result + """ + if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): + self.dest.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) + ) + return False + + return True + + def validate_address_ranges(self): + """ + validates if the address of source is in the user range + if the source and dest address are empty, check if the user + is member of whole world organization + :return: boolean validation result + """ + if not (self.source.data or self.dest.data): + whole_world_member = whole_world_range(self.net_ranges, self.zero_address) + if not whole_world_member: + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + else: + source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) + dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) + if not (source_in_range or dest_in_range): + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + + return True + + def validate_ipv_specific(self): + """ + abstract method must be implemented in the subclass + """ + pass diff --git a/flowapp/forms/rules/ipv4.py b/flowapp/forms/rules/ipv4.py new file mode 100644 index 00000000..07f27a8c --- /dev/null +++ b/flowapp/forms/rules/ipv4.py @@ -0,0 +1,70 @@ +""" +IPv4 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField, SelectMultipleField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV4_PROTOCOL, IPV4_FRAGMENT +from ...validators import IPv4Address +from .base import IPForm + + +class IPv4Form(IPForm): + """ + IPv4 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv4Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "0.0.0.0" + source = StringField( + "Source address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + source_mask = IntegerField( + "Source mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + dest_mask = IntegerField( + "Destination mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + protocol = SelectField( + "Protocol", + choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], + validators=[DataRequired()], + ) + + fragment = SelectMultipleField( + "Fragment", + choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], + validators=[Optional()], + ) + + def validate_ipv_specific(self): + """ + validate protocol and flags, set error message if validation fails + :return: boolean validation result + """ + + if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": + self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) + return False + return True diff --git a/flowapp/forms/rules/ipv6.py b/flowapp/forms/rules/ipv6.py new file mode 100644 index 00000000..aabd1b2f --- /dev/null +++ b/flowapp/forms/rules/ipv6.py @@ -0,0 +1,64 @@ +""" +IPv6 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV6_NEXT_HEADER +from ...validators import IPv6Address +from .base import IPForm + + +class IPv6Form(IPForm): + """ + IPv6 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv6Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "::" + source = StringField( + "Source address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + source_mask = IntegerField( + "Source prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + dest_mask = IntegerField( + "Destination prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + next_header = SelectField( + "Next Header", + choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], + validators=[DataRequired()], + ) + + def validate_ipv_specific(self): + """ + validate next header and flags, set error message if validation fails + :return: boolean validation result + """ + if len(self.flags.data) > 0 and self.next_header.data != "tcp": + self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) + return False + + return True diff --git a/flowapp/forms/rules/rtbh.py b/flowapp/forms/rules/rtbh.py new file mode 100644 index 00000000..cef79483 --- /dev/null +++ b/flowapp/forms/rules/rtbh.py @@ -0,0 +1,104 @@ +""" +RTBH rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, SelectField, TextAreaField +from wtforms.validators import Optional, NumberRange, DataRequired, InputRequired + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPv6Address, address_with_mask, address_in_range, IPv4Address +from ..base import MultiFormatDateTimeLocalField + + +class RTBHForm(FlaskForm): + """ + RoadToBlackHole rule form + """ + + def __init__(self, *args, **kwargs): + super(RTBHForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ipv4 = StringField( + "IPv4 address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + ipv4_mask = IntegerField( + "IPv4 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), + ], + ) + + ipv6 = StringField( + "IPv6 address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + ipv6_mask = IntegerField( + "IPv6 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), + ], + ) + + community = SelectField( + "Community", + coerce=int, + validators=[ + DataRequired(message="Please select a community for the rule."), + ], + ) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # ipv4 and ipv6 are mutually exclusive + # if both are set, validation fails + # if none is set, validation fails + # if one is set, validation passes + if self.ipv4.data and self.ipv6.data: + self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + result = False + + if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): + self.ipv4.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) + ) + result = False + + if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): + self.ipv6.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) + ) + result = False + + ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) + ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) + + if not (ipv6_in_range or ipv4_in_range): + self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/rules/whitelist.py b/flowapp/forms/rules/whitelist.py new file mode 100644 index 00000000..8aa88b41 --- /dev/null +++ b/flowapp/forms/rules/whitelist.py @@ -0,0 +1,66 @@ +""" +Whitelist form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, DataRequired, InputRequired, Length + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPAddressValidator, NetworkValidator, network_in_range +from ..base import MultiFormatDateTimeLocalField + + +class WhitelistForm(FlaskForm): + """ + Whitelist form object + Used for creating and editing whitelist entries + Supports both IPv4 and IPv6 addresses + """ + + def __init__(self, *args, **kwargs): + super(WhitelistForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ip = StringField( + "IP address", + validators=[ + DataRequired(message="Please provide an IP address"), + IPAddressValidator(message="Please provide a valid IP address: {}"), + NetworkValidator(mask_field_name="mask"), + ], + ) + + mask = IntegerField( + "Network mask (bits)", + validators=[ + DataRequired(message="Please provide a network mask"), + ], + ) + + comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + def validate(self): + """ + Custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # Validate IP is in organization range + if self.ip.data and self.mask.data and self.net_ranges: + ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) + if not ip_in_range: + self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/user.py b/flowapp/forms/user.py new file mode 100644 index 00000000..5944afd1 --- /dev/null +++ b/flowapp/forms/user.py @@ -0,0 +1,91 @@ +""" +User-related forms for the flowapp application. +""" + +import csv +from io import StringIO + +from flask_wtf import FlaskForm +from wtforms import StringField, TextAreaField, SelectMultipleField +from wtforms.validators import ValidationError, DataRequired, Email, InputRequired, Optional + + +class UserForm(FlaskForm): + """ + User Form object + used in Admin + """ + + uuid = StringField( + "Unique User ID", + validators=[ + InputRequired("Please provide UUID"), + Email("Please provide valid email"), + ], + ) + + email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) + + comment = StringField("Notice", validators=[Optional()]) + + name = StringField("Name", validators=[Optional()]) + + phone = StringField("Contact phone", validators=[Optional()]) + + role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) + + org_ids = SelectMultipleField( + "Organization", + coerce=int, + validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], + ) + + +class BulkUserForm(FlaskForm): + """ + Bulk User Form object + used in Admin + """ + + users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) + + def __init__(self, *args, **kwargs): + super(BulkUserForm, self).__init__(*args, **kwargs) + self.roles = None + self.organizations = None + self.uuids = None + + # Custom validator for CSV data + def validate_users(self, field): + csv_data = field.data + + # Parse CSV data + csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") + + # List to keep track of failed validation rows + errors = 0 + for row_num, row in enumerate(csv_reader, start=1): + try: + # check if the user not already exists + if row["uuid-eppn"] in self.uuids: + field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") + errors += 1 + + # Check if role exists in the database + role_id = int(row["role"]) # Convert role field to integer + if role_id not in self.roles: + field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") + errors += 1 + + # Check if organization exists in the database + org_id = int(row["organizace"]) # Convert organization field to integer + if org_id not in self.organizations: + field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") + errors += 1 + + except (KeyError, ValueError) as e: + field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") + + if errors > 0: + # Raise validation error if any invalid rows found + raise ValidationError("Invalid CSV Data - check the errors above.") diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 1ffcb1f8..4136530e 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -2,12 +2,19 @@ # column names for tables RTBH_COLUMNS = ( - ("ipv4", "IP adress (v4 or v6)"), + ("ipv4", "IP address (v4 or v6)"), ("community_id", "Community"), ("expires", "Expires"), ("user_id", "User"), ) +WHITELIST_COLUMNS = ( + ("address", "IP address / network (v4 or v6)"), + ("expires", "Expires"), + ("user_id", "User"), +) + + RULES_COLUMNS_V4 = ( ("source", "Source addr."), ("source_port", "S port"), @@ -74,6 +81,7 @@ class InstanceConfig: {"name": "Add IPv4", "url": "rules.ipv4_rule"}, {"name": "Add IPv6", "url": "rules.ipv6_rule"}, {"name": "Add RTBH", "url": "rules.rtbh_rule"}, + {"name": "Add Whitelist", "url": "whitelist.add"}, {"name": "API Key", "url": "api_keys.all"}, ], "admin": [ @@ -123,6 +131,14 @@ class InstanceConfig: "table_colspan": 5, "table_columns": RTBH_COLUMNS, }, + "whitelist": { + "name": "Whitelist", + "macro_file": "macros.html", + "macro_tbody": "build_whitelist_tbody", + "macro_thead": "build_rules_thead", + "table_colspan": 4, + "table_columns": WHITELIST_COLUMNS, + }, } - COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0} + COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0, "whitelist": 0} diff --git a/flowapp/models.py b/flowapp/models.py deleted file mode 100644 index 4e1291f3..00000000 --- a/flowapp/models.py +++ /dev/null @@ -1,1064 +0,0 @@ -import json -from sqlalchemy import event -from datetime import datetime -from flowapp import db, utils -from flowapp.constants import RuleTypes -from flask import current_app - -# models and tables - -user_role = db.Table( - "user_role", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "role_id"), -) - -user_organization = db.Table( - "user_organization", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "organization_id"), -) - - -class User(db.Model): - """ - App User - """ - - id = db.Column(db.Integer, primary_key=True) - uuid = db.Column(db.String(180), unique=True) - comment = db.Column(db.String(500)) - email = db.Column(db.String(255)) - name = db.Column(db.String(255)) - phone = db.Column(db.String(255)) - apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") - machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") - role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") - - organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") - - def __init__(self, uuid, name=None, phone=None, email=None, comment=None): - self.uuid = uuid - self.phone = phone - self.name = name - self.email = email - self.comment = comment - - def update(self, form): - """ - update the user with values from form object - :param form: flask form from request - :return: None - """ - self.uuid = form.uuid.data - self.name = form.name.data - self.email = form.email.data - self.phone = form.phone.data - self.comment = form.comment.data - - # first clear existing roles and orgs - for role in self.role: - self.role.remove(role) - for org in self.organization: - self.organization.remove(org) - - for role_id in form.role_ids.data: - my_role = db.session.query(Role).filter_by(id=role_id).first() - if my_role not in self.role: - self.role.append(my_role) - - for org_id in form.org_ids.data: - my_org = db.session.query(Organization).filter_by(id=org_id).first() - if my_org not in self.organization: - self.organization.append(my_org) - - db.session.commit() - - -class ApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=False) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="apikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="apikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class MachineApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=True) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="machineapikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="machineapikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class Role(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(20), unique=True) - description = db.Column(db.String(260)) - - def __init__(self, name, description): - self.name = name - self.description = description - - def __repr__(self): - return self.name - - -class Organization(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(150), unique=True) - arange = db.Column(db.Text) - limit_flowspec4 = db.Column(db.Integer, default=0) - limit_flowspec6 = db.Column(db.Integer, default=0) - limit_rtbh = db.Column(db.Integer, default=0) - - def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): - self.name = name - self.arange = arange - self.limit_flowspec4 = limit_flowspec4 - self.limit_flowspec6 = limit_flowspec6 - self.limit_rtbh = limit_rtbh - - def __repr__(self): - return self.name - - def get_users(self): - """ - Returns all users associated with this organization. - """ - # self.user is the backref from the user_organization relationship - return self.user - - -class ASPath(db.Model): - id = db.Column(db.Integer, primary_key=True) - prefix = db.Column(db.String(120), unique=True) - as_path = db.Column(db.String(250)) - - def __init__(self, prefix, as_path): - self.prefix = prefix - self.as_path = as_path - - def __repr__(self): - return f"{self.prefix} : {self.as_path}" - - -class Action(db.Model): - """ - Action for rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - command = db.Column(db.String(120), unique=True) - description = db.Column(db.String(260)) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="action") - - def __init__(self, name, command, description, role_id=2): - self.name = name - self.command = command - self.description = description - self.role_id = role_id - - -class Community(db.Model): - """ - Community for RTBH rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - comm = db.Column(db.String(2047)) - larcomm = db.Column(db.String(2047)) - extcomm = db.Column(db.String(2047)) - description = db.Column(db.String(255)) - as_path = db.Column(db.Boolean, default=False) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="community") - - def __init__(self, name, comm, larcomm, extcomm, description, as_path=False, role_id=2): - self.name = name - self.comm = comm - self.larcomm = larcomm - self.extcomm = extcomm - self.description = description - self.as_path = as_path - self.role_id = role_id - - -class Rstate(db.Model): - """ - State for Rules - """ - - id = db.Column(db.Integer, primary_key=True) - description = db.Column(db.String(260)) - - def __init__(self, description): - self.description = description - - -class RTBH(db.Model): - __tablename__ = "RTBH" - - id = db.Column(db.Integer, primary_key=True) - ipv4 = db.Column(db.String(255)) - ipv4_mask = db.Column(db.Integer) - ipv6 = db.Column(db.String(255)) - ipv6_mask = db.Column(db.Integer) - community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) - community = db.relationship("Community", backref="rtbh") - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="rtbh") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="rtbh") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="RTBH") - - def __init__( - self, - ipv4, - ipv4_mask, - ipv6, - ipv6_mask, - community_id, - expires, - user_id, - org_id, - comment=None, - created=None, - rstate_id=1, - ): - self.ipv4 = ipv4 - self.ipv4_mask = ipv4_mask - self.ipv6 = ipv6 - self.ipv6_mask = ipv6_mask - self.community_id = community_id - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.comment = comment - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - return ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - compars = ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def update_time(self, form): - self.expires = utils.webpicker_to_datetime(form.expire_date.data) - db.session.commit() - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict used in API - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "ipv4": self.ipv4, - "ipv4_mask": self.ipv4_mask, - "ipv6": self.ipv6, - "ipv6_mask": self.ipv6_mask, - "community": self.community.name, - "comment": self.comment, - "expires": expires, - "created": created, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec4(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - protocol = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - fragment = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec4") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec4") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec4") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec4") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - protocol, - flags, - packet_len, - fragment, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.protocol = protocol - self.flags = flags - self.packet_len = packet_len - self.fragment = fragment - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - compars = ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "protocol": self.protocol, - "flags": self.flags, - "packet_len": self.packet_len, - "fragment": self.fragment, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec6(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - next_header = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec6") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec6") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec6") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec6") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - next_header, - flags, - packet_len, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.next_header = next_header - self.flags = flags - self.packet_len = packet_len - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.next_header == other.next_header - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": str(self.id), - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "next_header": self.next_header, - "flags": self.flags, - "packet_len": self.packet_len, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Log(db.Model): - id = db.Column(db.Integer, primary_key=True) - time = db.Column(db.DateTime) - task = db.Column(db.String(1000)) - author = db.Column(db.String(1000)) - rule_type = db.Column(db.Integer) - rule_id = db.Column(db.Integer) - user_id = db.Column(db.Integer) - org_id = db.Column(db.Integer, nullable=True) - - def __init__(self, time, task, user_id, rule_type, rule_id, author, org_id=None): - self.time = time - self.task = task - self.rule_type = rule_type - self.rule_id = rule_id - self.user_id = user_id - self.author = author - self.org_id = org_id - - -# DDL -# default values for tables inserted after create - - -@event.listens_for(Action.__table__, "after_create") -def insert_initial_actions(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="QoS 100 kbps", - command="rate-limit 12800", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 1Mbps", - command="rate-limit 13107200", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 10Mbps", - command="rate-limit 131072000", - description="QoS", - role_id=2, - ) - ) - conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) - - -@event.listens_for(Community.__table__, "after_create") -def insert_initial_communities(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="65535:65283", - comm="65535:65283", - larcomm="", - extcomm="", - description="local-as", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64496:64511", - comm="64496:64511", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64497:64510", - comm="64497:64510", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - - -@event.listens_for(Role.__table__, "after_create") -def insert_initial_roles(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="view", description="just view, no edit")) - conn.execute(table.insert().values(name="user", description="can edit")) - conn.execute(table.insert().values(name="admin", description="admin")) - - -@event.listens_for(Organization.__table__, "after_create") -def insert_initial_organizations(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) - - -@event.listens_for(Rstate.__table__, "after_create") -def insert_initial_rulestates(table, conn, *args, **kwargs): - conn.execute(table.insert().values(description="active rule")) - conn.execute(table.insert().values(description="withdrawed rule")) - conn.execute(table.insert().values(description="deleted rule")) - - -# Misc functions -def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: - """ - Check if the organization has reached the rule limit - :param org_id: integer organization id - :param rule_type: RuleType rule type - :return: boolean - """ - flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the organization limits - org = Organization.query.filter_by(id=org_id).first() - if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: - count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec4 or fs4 >= flowspec_limit - if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: - count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec6 or fs6 >= flowspec_limit - if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: - count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_rtbh or rtbh >= rtbh_limit - - -def check_global_rule_limit(rule_type: RuleTypes) -> bool: - flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) - flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the global limits if the organization limits are not set - - if rule_type == RuleTypes.IPv4: - return fs4 >= flowspec4_limit - if rule_type == RuleTypes.IPv6: - return fs6 >= flowspec6_limit - if rule_type == RuleTypes.RTBH: - return rtbh >= rtbh_limit - - -def get_ipv4_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec4) - .filter( - Flowspec4.source == form_data["source"], - Flowspec4.source_mask == form_data["source_mask"], - Flowspec4.source_port == form_data["source_port"], - Flowspec4.dest == form_data["dest"], - Flowspec4.dest_mask == form_data["dest_mask"], - Flowspec4.dest_port == form_data["dest_port"], - Flowspec4.protocol == form_data["protocol"], - Flowspec4.flags == ";".join(form_data["flags"]), - Flowspec4.packet_len == form_data["packet_len"], - Flowspec4.action_id == form_data["action"], - Flowspec4.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_ipv6_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec6) - .filter( - Flowspec6.source == form_data["source"], - Flowspec6.source_mask == form_data["source_mask"], - Flowspec6.source_port == form_data["source_port"], - Flowspec6.dest == form_data["dest"], - Flowspec6.dest_mask == form_data["dest_mask"], - Flowspec6.dest_port == form_data["dest_port"], - Flowspec6.next_header == form_data["next_header"], - Flowspec6.flags == ";".join(form_data["flags"]), - Flowspec6.packet_len == form_data["packet_len"], - Flowspec6.action_id == form_data["action"], - Flowspec6.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_rtbh_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - - record = ( - db.session.query(RTBH) - .filter( - RTBH.ipv4 == form_data["ipv4"], - RTBH.ipv4_mask == form_data["ipv4_mask"], - RTBH.ipv6 == form_data["ipv6"], - RTBH.ipv6_mask == form_data["ipv6_mask"], - RTBH.community_id == form_data["community"], - RTBH.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def insert_users(users): - """ - inser list of users {name: string, role_id: integer} to db - """ - for user in users: - r = Role.query.filter_by(id=user["role_id"]).first() - o = Organization.query.filter_by(id=user["org_id"]).first() - u = User(uuid=user["name"]) - u.role.append(r) - u.organization.append(o) - db.session.add(u) - - db.session.commit() - - -def insert_user( - uuid: str, - role_ids: list, - org_ids: list, - name: str = None, - phone: str = None, - email: str = None, - comment: str = None, -): - """ - insert new user with multiple roles and organizations - :param uuid: string unique user id (eppn or similar) - :param phone: string phone number - :param name: string user name - :param email: string email - :param comment: string comment / notice - :param role_ids: list of roles - :param org_ids: list of orgs - :return: None - """ - u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) - - for role_id in role_ids: - r = Role.query.filter_by(id=role_id).first() - u.role.append(r) - - for org_id in org_ids: - o = Organization.query.filter_by(id=org_id).first() - u.organization.append(o) - - db.session.add(u) - db.session.commit() - - -def get_user_nets(user_id): - """ - Return list of network ranges for all user organization - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - result = [] - for org in orgs: - result.extend(org.arange.split()) - - return result - - -def get_user_orgs_choices(user_id): - """ - Return list of orgs as choices for form - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - - return [(g.id, g.name) for g in orgs] - - -def get_user_actions(user_roles): - """ - Return list of actions based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - actions = db.session.query(Action).order_by("id") - else: - actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in actions] - - -def get_user_communities(user_roles): - """ - Return list of communities based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - communities = db.session.query(Community).order_by("id") - else: - communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in communities] - - -def get_existing_action(name=None, command=None): - """ - return Action with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - action = Action.query.filter((Action.name == name) | (Action.command == command)).first() - return action.id if hasattr(action, "id") else None - - -def get_existing_community(name=None): - """ - return Community with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - community = Community.query.filter(Community.name == name).first() - return community.id if hasattr(community, "id") else None - - -def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): - """ - Returns list of rules sorted by sort column ordered asc or desc - :param sort: sorting column - :param order: asc or desc - :return: list - """ - - today = datetime.now() - comp_func = utils.get_comp_func(rule_state) - - if rule_type == "ipv4": - sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) - sorting_ip4 = getattr(sorter_ip4, order) - if comp_func: - rules4 = ( - db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() - ) - else: - rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() - - return rules4 - - if rule_type == "ipv6": - sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) - sorting_ip6 = getattr(sorter_ip6, order) - if comp_func: - rules6 = ( - db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() - ) - else: - rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() - - return rules6 - - if rule_type == "rtbh": - sorter_rtbh = getattr(RTBH, sort, RTBH.id) - sorting_rtbh = getattr(sorter_rtbh, order) - - if comp_func: - rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() - - else: - rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() - - return rules_rtbh - - -def get_user_rules_ids(user_id, rule_type): - """ - Returns list of rule ids belonging to user - :param user_id: user id - :param rule_type: ipv4, ipv6 or rtbh - :return: list - """ - - if rule_type == "ipv4": - rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules4] - - if rule_type == "ipv6": - rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() - return [int(x[0]) for x in rules6] - - if rule_type == "rtbh": - rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/models/__init__.py b/flowapp/models/__init__.py new file mode 100644 index 00000000..64055059 --- /dev/null +++ b/flowapp/models/__init__.py @@ -0,0 +1,71 @@ +# Import and re-export all models and functions for backward compatibility +from .base import db, user_role, user_organization + +# User-related models +from .user import User, Role +from .api import ApiKey, MachineApiKey +from .organization import Organization + +# Rule-related models +from .rules import Flowspec4, Flowspec6, RTBH, Rstate, Action, Whitelist, RuleWhitelistCache + +# Other models +from .community import Community, ASPath, insert_initial_communities +from .log import Log + +# Helper functions +from .utils import ( + get_user_nets, + get_user_actions, + get_user_communities, + get_existing_action, + get_existing_community, + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + get_whitelist_model_if_exists, + get_ip_rules, + get_user_rules_ids, + insert_users, + insert_user, + check_rule_limit, + check_global_rule_limit, +) + +# Ensure all models are registered properly +__all__ = [ + "db", + "User", + "Role", + "user_role", + "ApiKey", + "MachineApiKey", + "Organization", + "user_organization", + "Action", + "Flowspec4", + "Flowspec6", + "RTBH", + "Rstate", + "Community", + "ASPath", + "Log", + "Whitelist", + "RuleWhitelistCache", + "get_user_nets", + "get_user_actions", + "get_user_communities", + "get_existing_action", + "get_existing_community", + "get_ipv4_model_if_exists", + "get_ipv6_model_if_exists", + "get_rtbh_model_if_exists", + "get_whitelist_model_if_exists", + "get_ip_rules", + "get_user_rules_ids", + "insert_users", + "insert_user", + "check_rule_limit", + "check_global_rule_limit", + "insert_initial_communities", +] diff --git a/flowapp/models/api.py b/flowapp/models/api.py new file mode 100644 index 00000000..ed506850 --- /dev/null +++ b/flowapp/models/api.py @@ -0,0 +1,40 @@ +from datetime import datetime +from .base import db + + +class ApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=False) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="apikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="apikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() + + +class MachineApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=True) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="machineapikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="machineapikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() diff --git a/flowapp/models/base.py b/flowapp/models/base.py new file mode 100644 index 00000000..a934ce83 --- /dev/null +++ b/flowapp/models/base.py @@ -0,0 +1,16 @@ +from flowapp import db + +# Define shared tables +user_role = db.Table( + "user_role", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "role_id"), +) + +user_organization = db.Table( + "user_organization", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "organization_id"), +) diff --git a/flowapp/models/community.py b/flowapp/models/community.py new file mode 100644 index 00000000..880a837a --- /dev/null +++ b/flowapp/models/community.py @@ -0,0 +1,79 @@ +from sqlalchemy import event +from .base import db + + +class Community(db.Model): + """Community for RTBH rule""" + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + comm = db.Column(db.String(2047)) + larcomm = db.Column(db.String(2047)) + extcomm = db.Column(db.String(2047)) + description = db.Column(db.String(255)) + as_path = db.Column(db.Boolean, default=False) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="community") + + def __init__(self, name, comm, larcomm, extcomm, description, as_path, role_id): + self.name = name + self.comm = comm + self.larcomm = larcomm + self.extcomm = extcomm + self.description = description + self.as_path = as_path + self.role_id = role_id + + @classmethod + def get_whitelistable_communities(cls, id_list): + return cls.query.filter(cls.id.in_(id_list)).all() + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.name}" + + +class ASPath(db.Model): + """AS Path for RTBH rules""" + + id = db.Column(db.Integer, primary_key=True) + prefix = db.Column(db.String(120), unique=True) + as_path = db.Column(db.String(250)) + + # Methods and initializer + + +@event.listens_for(Community.__table__, "after_create") +def insert_initial_communities(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="65535:65283", + comm="65535:65283", + larcomm="", + extcomm="", + description="local-as", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64496:64511", + comm="64496:64511", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64497:64510", + comm="64497:64510", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) diff --git a/flowapp/models/log.py b/flowapp/models/log.py new file mode 100644 index 00000000..16cf5b13 --- /dev/null +++ b/flowapp/models/log.py @@ -0,0 +1,39 @@ +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes +from .base import db + + +class Log(db.Model): + """Log model for system actions""" + + id = db.Column(db.Integer, primary_key=True) + time = db.Column(db.DateTime) + task = db.Column(db.String(1000)) + author = db.Column(db.String(1000)) + rule_type = db.Column(db.Integer) + rule_id = db.Column(db.Integer) + user_id = db.Column(db.Integer) + + def __init__(self, time, task, user_id, rule_type, rule_id, author): + self.time = time + self.task = task + self.rule_type = rule_type + self.rule_id = rule_id + self.user_id = user_id + self.author = author + + @classmethod + def delete_old(cls, days: int = 30): + """Delete logs older than :param days from the database""" + cls.query.filter(cls.time < datetime.now() - timedelta(days=days)).delete() + db.session.commit() + + def __repr__(self): + return f"" + + def __str__(self): + """ + {"author": "vrany@cesnet.cz / Cel\u00fd sv\u011bt", "source": "UI", "command": "cmd"} + """ + return f"{self.author} - {RuleTypes(self.rule_type).name}({self.rule_id}) - {self.task}" diff --git a/flowapp/models/organization.py b/flowapp/models/organization.py new file mode 100644 index 00000000..baf0ec1f --- /dev/null +++ b/flowapp/models/organization.py @@ -0,0 +1,35 @@ +from sqlalchemy import event +from .base import db + + +class Organization(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(150), unique=True) + arange = db.Column(db.Text) + limit_flowspec4 = db.Column(db.Integer, default=0) + limit_flowspec6 = db.Column(db.Integer, default=0) + limit_rtbh = db.Column(db.Integer, default=0) + + def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): + self.name = name + self.arange = arange + self.limit_flowspec4 = limit_flowspec4 + self.limit_flowspec6 = limit_flowspec6 + self.limit_rtbh = limit_rtbh + + def __repr__(self): + return self.name + + def get_users(self): + """ + Returns all users associated with this organization. + """ + # self.user is the backref from the user_organization relationship + return self.user + + +# Event listeners for Organization +@event.listens_for(Organization.__table__, "after_create") +def insert_initial_organizations(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="TU Liberec", arange="147.230.0.0/16\n2001:718:1c01::/48")) + conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) diff --git a/flowapp/models/rules/__init__.py b/flowapp/models/rules/__init__.py new file mode 100644 index 00000000..da12ba45 --- /dev/null +++ b/flowapp/models/rules/__init__.py @@ -0,0 +1,6 @@ +from .flowspec import Flowspec4, Flowspec6 +from .rtbh import RTBH +from .base import Rstate, Action +from .whitelist import Whitelist, RuleWhitelistCache + +__all__ = ["Flowspec4", "Flowspec6", "RTBH", "Rstate", "Action", "Whitelist", "RuleWhitelistCache"] diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py new file mode 100644 index 00000000..22fbc089 --- /dev/null +++ b/flowapp/models/rules/base.py @@ -0,0 +1,69 @@ +from sqlalchemy import event +from ..base import db + + +class Rstate(db.Model): + """State for Rules""" + + id = db.Column(db.Integer, primary_key=True) + description = db.Column(db.String(260)) + + def __init__(self, description): + self.description = description + + +class Action(db.Model): + """ + Action for rule + """ + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + command = db.Column(db.String(120), unique=True) + description = db.Column(db.String(260)) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="action") + + def __init__(self, name, command, description, role_id=2): + self.name = name + self.command = command + self.description = description + self.role_id = role_id + + +# Event listeners for Rstate +@event.listens_for(Rstate.__table__, "after_create") +def insert_initial_rulestates(table, conn, *args, **kwargs): + conn.execute(table.insert().values(description="active rule")) + conn.execute(table.insert().values(description="withdrawed rule")) + conn.execute(table.insert().values(description="deleted rule")) + conn.execute(table.insert().values(description="whitelisted rule")) + + +@event.listens_for(Action.__table__, "after_create") +def insert_initial_actions(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="QoS 100 kbps", + command="rate-limit 12800", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 1Mbps", + command="rate-limit 13107200", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 10Mbps", + command="rate-limit 131072000", + description="QoS", + role_id=2, + ) + ) + conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) diff --git a/flowapp/models/rules/flowspec.py b/flowapp/models/rules/flowspec.py new file mode 100644 index 00000000..9823aba8 --- /dev/null +++ b/flowapp/models/rules/flowspec.py @@ -0,0 +1,294 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class Flowspec4(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + protocol = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + fragment = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec4") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec4") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec4") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec4") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + protocol, + flags, + packet_len, + fragment, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.protocol = protocol + self.flags = flags + self.packet_len = packet_len + self.fragment = fragment + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + compars = ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "protocol": self.protocol, + "flags": self.flags, + "packet_len": self.packet_len, + "fragment": self.fragment, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) + + +class Flowspec6(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + next_header = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec6") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec6") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec6") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec6") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + next_header, + flags, + packet_len, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.next_header = next_header + self.flags = flags + self.packet_len = packet_len + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.next_header == other.next_header + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": str(self.id), + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "next_header": self.next_header, + "flags": self.flags, + "packet_len": self.packet_len, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py new file mode 100644 index 00000000..943fef0f --- /dev/null +++ b/flowapp/models/rules/rtbh.py @@ -0,0 +1,153 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class RTBH(db.Model): + __tablename__ = "RTBH" + + id = db.Column(db.Integer, primary_key=True) + ipv4 = db.Column(db.String(255)) + ipv4_mask = db.Column(db.Integer) + ipv6 = db.Column(db.String(255)) + ipv6_mask = db.Column(db.Integer) + community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) + community = db.relationship("Community", backref="rtbh") + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="rtbh") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="rtbh") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="RTBH") + + def __init__( + self, + ipv4, + ipv4_mask, + ipv6, + ipv6_mask, + community_id, + expires, + user_id, + org_id, + comment=None, + created=None, + rstate_id=1, + ): + self.ipv4 = ipv4 + self.ipv4_mask = ipv4_mask + self.ipv6 = ipv6 + self.ipv6_mask = ipv6_mask + self.community_id = community_id + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + return ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + compars = ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def update_time(self, form): + self.expires = utils.webpicker_to_datetime(form.expire_date.data) + db.session.commit() + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict used in API + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ipv4": self.ipv4, + "ipv4_mask": self.ipv4_mask, + "ipv6": self.ipv6, + "ipv6_mask": self.ipv6_mask, + "community": self.community.name, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) + + def __repr__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"" + if not self.ipv4 and not self.ipv4_mask: + return f"" + + return f"" + + def __str__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"{self.ipv4}/{self.ipv4_mask}" + if not self.ipv4 and not self.ipv4_mask: + return f"{self.ipv6}/{self.ipv6_mask}" + + return f"{self.ipv4}/{self.ipv4_mask} {self.ipv6}/{self.ipv6_mask}" + + def get_author(self): + return f"{self.user.email} / {self.org}" diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py new file mode 100644 index 00000000..83c131f6 --- /dev/null +++ b/flowapp/models/rules/whitelist.py @@ -0,0 +1,172 @@ +from flowapp import utils +from ..base import db +from datetime import datetime +from flowapp.constants import RuleTypes, RuleOrigin + + +class Whitelist(db.Model): + id = db.Column(db.Integer, primary_key=True) + ip = db.Column(db.String(255)) + mask = db.Column(db.Integer) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="whitelist") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="whitelist") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="whitelist") + + def __init__( + self, + ip, + mask, + expires, + user_id, + org_id, + created=None, + comment=None, + rstate_id=1, + ): + self.ip = ip + self.mask = mask + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two whitelists are equal if all the network parameters equals. + User_id, org, comment and time fields can differ. + :param other: other Whitelist instance + :return: boolean + """ + return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ip": self.ip, + "mask": self.mask, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.ip}/{self.mask}" + + +class RuleWhitelistCache(db.Model): + """ + Cache for whitelisted rules + For each rule we store id and type + Rule origin determines if the rule was created by user or by whitelist + """ + + id = db.Column(db.Integer, primary_key=True) + rid = db.Column(db.Integer) + rtype = db.Column(db.Integer) + rorigin = db.Column(db.Integer) + whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey + whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") + + def __init__(self, rid: int, rtype: RuleTypes, whitelist_id: int, rorigin: RuleOrigin = RuleOrigin.USER): + self.rid = rid + self.rtype = rtype.value + self.rorigin = rorigin.value + self.whitelist_id = whitelist_id + + @classmethod + def get_by_whitelist_id(cls, whitelist_id: int): + """ + Get all cache items with the given whitelist ID + + Args: + whitelist_id (int): The ID of the whitelist to filter by + + Returns: + list: All RuleWhitelistCache objects with the specified whitelist_id + """ + return cls.query.filter_by(whitelist_id=whitelist_id).all() + + @classmethod + def clean_by_whitelist_id(cls, whitelist_id: int): + """ + Delete all cache entries with the given whitelist ID from the database + + Args: + whitelist_id (int): The ID of the whitelist to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(whitelist_id=whitelist_id).delete() + db.session.commit() + return deleted + + @classmethod + def delete_by_rule_id(cls, rule_id: int): + """ + Delete all cache entries with the given rule ID from the database + + Args: + rule_id (int): The ID of the rule to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(rid=rule_id).delete() + db.session.commit() + return deleted + + @classmethod + def count_by_rule(cls, rule_id: int, rule_type: RuleTypes): + """ + Count the number of cache entries for the given rule + + Args: + rule_id (int): The ID of the rule to count + rule_type (RuleTypes): The type of the rule + + Returns: + int: Number of cache entries + """ + return cls.query.filter_by(rid=rule_id, rtype=rule_type.value).count() + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.rid} {self.rtype} {self.rorigin}" diff --git a/flowapp/models/user.py b/flowapp/models/user.py new file mode 100644 index 00000000..dcb2d7eb --- /dev/null +++ b/flowapp/models/user.py @@ -0,0 +1,79 @@ +from sqlalchemy import event +from .base import db, user_role, user_organization +from .organization import Organization + + +class User(db.Model): + """ + App User + """ + + id = db.Column(db.Integer, primary_key=True) + uuid = db.Column(db.String(180), unique=True) + comment = db.Column(db.String(500)) + email = db.Column(db.String(255)) + name = db.Column(db.String(255)) + phone = db.Column(db.String(255)) + apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") + machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") + role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") + + organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") + + def __init__(self, uuid, name=None, phone=None, email=None, comment=None): + self.uuid = uuid + self.phone = phone + self.name = name + self.email = email + self.comment = comment + + def update(self, form): + """ + update the user with values from form object + :param form: flask form from request + :return: None + """ + self.uuid = form.uuid.data + self.name = form.name.data + self.email = form.email.data + self.phone = form.phone.data + self.comment = form.comment.data + + # first clear existing roles and orgs + for role in self.role: + self.role.remove(role) + for org in self.organization: + self.organization.remove(org) + + for role_id in form.role_ids.data: + my_role = db.session.query(Role).filter_by(id=role_id).first() + if my_role not in self.role: + self.role.append(my_role) + + for org_id in form.org_ids.data: + my_org = db.session.query(Organization).filter_by(id=org_id).first() + if my_org not in self.organization: + self.organization.append(my_org) + + db.session.commit() + + +class Role(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(20), unique=True) + description = db.Column(db.String(260)) + + def __init__(self, name, description): + self.name = name + self.description = description + + def __repr__(self): + return self.name + + +# Event listeners for Role +@event.listens_for(Role.__table__, "after_create") +def insert_initial_roles(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="view", description="just view, no edit")) + conn.execute(table.insert().values(name="user", description="can edit")) + conn.execute(table.insert().values(name="admin", description="admin")) diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py new file mode 100644 index 00000000..d0bf1208 --- /dev/null +++ b/flowapp/models/utils.py @@ -0,0 +1,377 @@ +"""Utility functions for models""" + +from datetime import datetime +from flowapp import utils +from flowapp.constants import RuleTypes +from flask import current_app + +from flowapp.models.rules.whitelist import Whitelist +from .base import db +from .user import User, Role +from .organization import Organization +from .community import Community +from .rules.flowspec import Flowspec4, Flowspec6 +from .rules.rtbh import RTBH +from .rules.base import Action + + +def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: + """ + Check if the organization has reached the rule limit + :param org_id: integer organization id + :param rule_type: RuleType rule type + :return: boolean + """ + flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the organization limits + org = Organization.query.filter_by(id=org_id).first() + if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: + count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec4 or fs4 >= flowspec_limit + if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: + count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec6 or fs6 >= flowspec_limit + if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: + count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_rtbh or rtbh >= rtbh_limit + + return False + + +def check_global_rule_limit(rule_type: RuleTypes) -> bool: + flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) + flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the global limits if the organization limits are not set + + if rule_type == RuleTypes.IPv4: + return fs4 >= flowspec4_limit + if rule_type == RuleTypes.IPv6: + return fs6 >= flowspec6_limit + if rule_type == RuleTypes.RTBH: + return rtbh >= rtbh_limit + + return False + + +def get_whitelist_model_if_exists(form_data): + """ + Check if the record in database exist + ip, mask should match + expires, rstate_id, user_id, org_id, created, comment can be different + """ + record = ( + db.session.query(Whitelist) + .filter( + Whitelist.ip == form_data["ip"], + Whitelist.mask == form_data["mask"], + ) + .first() + ) + + if record: + return record + + return False + + +def get_ipv4_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + Source and destination addresses, protocol, flags, action and packet_len should match + Other fields can be different + """ + record = ( + db.session.query(Flowspec4) + .filter( + Flowspec4.source == form_data["source"], + Flowspec4.source_mask == form_data["source_mask"], + Flowspec4.source_port == form_data["source_port"], + Flowspec4.dest == form_data["dest"], + Flowspec4.dest_mask == form_data["dest_mask"], + Flowspec4.dest_port == form_data["dest_port"], + Flowspec4.protocol == form_data["protocol"], + Flowspec4.flags == ";".join(form_data["flags"]), + Flowspec4.packet_len == form_data["packet_len"], + Flowspec4.action_id == form_data["action"], + Flowspec4.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_ipv6_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + record = ( + db.session.query(Flowspec6) + .filter( + Flowspec6.source == form_data["source"], + Flowspec6.source_mask == form_data["source_mask"], + Flowspec6.source_port == form_data["source_port"], + Flowspec6.dest == form_data["dest"], + Flowspec6.dest_mask == form_data["dest_mask"], + Flowspec6.dest_port == form_data["dest_port"], + Flowspec6.next_header == form_data["next_header"], + Flowspec6.flags == ";".join(form_data["flags"]), + Flowspec6.packet_len == form_data["packet_len"], + Flowspec6.action_id == form_data["action"], + Flowspec6.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_rtbh_model_if_exists(form_data): + """ + Check RTBH rule exist in database + IPv4, IPv6 and community_id should match + Rule can be in any state and have different expires, user_id, org_id, created, comment + """ + + record = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == form_data["ipv4"], + RTBH.ipv4_mask == form_data["ipv4_mask"], + RTBH.ipv6 == form_data["ipv6"], + RTBH.ipv6_mask == form_data["ipv6_mask"], + RTBH.community_id == form_data["community"], + ) + .first() + ) + + if record: + return record + + return False + + +def insert_users(users): + """ + inser list of users {name: string, role_id: integer} to db + """ + for user in users: + r = Role.query.filter_by(id=user["role_id"]).first() + o = Organization.query.filter_by(id=user["org_id"]).first() + u = User(uuid=user["name"]) + u.role.append(r) + u.organization.append(o) + db.session.add(u) + + db.session.commit() + + +def insert_user( + uuid: str, + role_ids: list, + org_ids: list, + name: str = None, + phone: str = None, + email: str = None, + comment: str = None, +): + """ + insert new user with multiple roles and organizations + :param uuid: string unique user id (eppn or similar) + :param phone: string phone number + :param name: string user name + :param email: string email + :param comment: string comment / notice + :param role_ids: list of roles + :param org_ids: list of orgs + :return: None + """ + u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) + + for role_id in role_ids: + r = Role.query.filter_by(id=role_id).first() + u.role.append(r) + + for org_id in org_ids: + o = Organization.query.filter_by(id=org_id).first() + u.organization.append(o) + + db.session.add(u) + db.session.commit() + + +def get_user_nets(user_id): + """ + Return list of network ranges for all user organization + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + result = [] + for org in orgs: + result.extend(org.arange.split()) + + return result + + +def get_user_orgs_choices(user_id): + """ + Return list of orgs as choices for form + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + + return [(g.id, g.name) for g in orgs] + + +def get_user_actions(user_roles): + """ + Return list of actions based on current user role + """ + max_role = max(user_roles) + print(max_role) + if max_role == 3: + actions = db.session.query(Action).order_by("id").all() + else: + actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id").all() + result = [(g.id, g.name) for g in actions] + print(actions, result) + return result + + +def get_user_communities(user_roles): + """ + Return list of communities based on current user role + """ + max_role = max(user_roles) + if max_role == 3: + communities = db.session.query(Community).order_by("id") + else: + communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") + + return [(g.id, g.name) for g in communities] + + +def get_existing_action(name=None, command=None): + """ + return Action with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + action = Action.query.filter((Action.name == name) | (Action.command == command)).first() + return action.id if hasattr(action, "id") else None + + +def get_existing_community(name=None): + """ + return Community with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + community = Community.query.filter(Community.name == name).first() + return community.id if hasattr(community, "id") else None + + +def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): + """ + Returns list of rules sorted by sort column ordered asc or desc + :param sort: sorting column + :param order: asc or desc + :return: list + """ + + today = datetime.now() + comp_func = utils.get_comp_func(rule_state) + + if rule_type == "ipv4": + sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) + sorting_ip4 = getattr(sorter_ip4, order) + if comp_func: + rules4 = ( + db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() + ) + else: + rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() + + return rules4 + + if rule_type == "ipv6": + sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) + sorting_ip6 = getattr(sorter_ip6, order) + if comp_func: + rules6 = ( + db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() + ) + else: + rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() + + return rules6 + + if rule_type == "rtbh": + sorter_rtbh = getattr(RTBH, sort, RTBH.id) + sorting_rtbh = getattr(sorter_rtbh, order) + + if comp_func: + rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() + + else: + rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() + + return rules_rtbh + + if rule_type == "whitelist": + sorter_whitelist = getattr(Whitelist, sort, Whitelist.id) + sorting_whitelist = getattr(sorter_whitelist, order) + + if comp_func: + rules_whitelist = ( + db.session.query(Whitelist) + .filter(comp_func(Whitelist.expires, today)) + .order_by(sorting_whitelist()) + .all() + ) + + else: + rules_whitelist = db.session.query(Whitelist).order_by(sorting_whitelist()).all() + + return rules_whitelist + + +def get_user_rules_ids(user_id, rule_type): + """ + Returns list of rule ids belonging to user + :param user_id: user id + :param rule_type: ipv4, ipv6 or rtbh + :return: list + """ + + if rule_type == "ipv4": + rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules4] + + if rule_type == "ipv6": + rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() + return [int(x[0]) for x in rules6] + + if rule_type == "rtbh": + rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/output.py b/flowapp/output.py index 3dde8221..0d023671 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, asdict from datetime import datetime +from typing import Dict, Union, Callable, Any import requests import pika @@ -11,9 +12,12 @@ from flask import current_app from flowapp import db, messages +from flowapp.constants import RuleTypes from flowapp.models import Log +from flowapp.models import Flowspec4, Flowspec6, RTBH -ROUTE_MODELS = { + +ROUTE_MODELS: Dict[int, Callable[[Any], str]] = { 1: messages.create_rtbh, 4: messages.create_ipv4, 6: messages.create_ipv6, @@ -21,21 +25,21 @@ class RouteSources: - UI = "UI" - API = "API" + UI: str = "UI" + API: str = "API" @dataclass class Route: author: str - source: RouteSources + source: str # Using str instead of RouteSources for flexibility command: str - def __dict__(self): + def __dict__(self) -> Dict[str, str]: return asdict(self) -def announce_route(route: Route): +def announce_route(route: Route) -> None: """ Dispatch route as dict to ExaBGP API API must be set in app config.py @@ -48,7 +52,7 @@ def announce_route(route: Route): announce_to_http(asdict(route)) -def announce_to_http(route): +def announce_to_http(route: Dict[str, str]) -> None: """ Announce route to ExaBGP HTTP API process """ @@ -64,7 +68,7 @@ def announce_to_http(route): current_app.logger.debug(f"Testing: {route}") -def announce_to_rabbitmq(route): +def announce_to_rabbitmq(route: Dict[str, str]) -> None: """ Announce rout to ExaBGP RabbitMQ API process """ @@ -80,48 +84,57 @@ def announce_to_rabbitmq(route): credentials, ) - connection = pika.BlockingConnection(parameters) - channel = connection.channel() - channel.queue_declare(queue=queue) - channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) + with pika.BlockingConnection(parameters) as connection: + channel = connection.channel() + channel.queue_declare(queue=queue) + channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) else: - current_app.logger.debug("Testing: {route}") + current_app.logger.debug(f"Testing: {route}") -def log_route(user_id, route_model, rule_type, author): +def log_route(user_id: int, route_model: Union[RTBH, Flowspec4, Flowspec6], rule_type: RuleTypes, author: str) -> None: """ Convert route to EXAFS message and log it to database :param user_id : int curent user - :param route_model: model with route object - :param rule_type: string + :param route_model: model with route object (RTBH, Flowspec4, or Flowspec6) + :param rule_type: RuleTypes enum + :param author: str name of the author :return: None """ - converter = ROUTE_MODELS[rule_type] + rule_type_value = rule_type.value + converter = ROUTE_MODELS[rule_type_value] task = converter(route_model) log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type_value, rule_id=route_model.id, user_id=user_id, author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() -def log_withdraw(user_id, task, rule_type, deleted_id, author): +def log_withdraw(user_id: int, task: str, rule_type: RuleTypes, deleted_id: int, author: str) -> None: """ Log the withdraw command to database - :param task: command message + :param user_id: int user ID + :param task: str command message + :param rule_type: RuleTypes enum + :param deleted_id: int ID of deleted rule + :param author: str name of the author + :return: None """ log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type.value, rule_id=deleted_id, user_id=user_id, author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py new file mode 100644 index 00000000..be4b838b --- /dev/null +++ b/flowapp/services/__init__.py @@ -0,0 +1,21 @@ +from .rule_service import ( + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, + reactivate_rule, +) + +from .whitelist_service import create_or_update_whitelist, delete_whitelist, delete_expired_whitelists + +from .base import announce_all_routes + +__all__ = [ + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, + create_or_update_whitelist, + delete_whitelist, + delete_expired_whitelists, + announce_all_routes, + reactivate_rule, +] diff --git a/flowapp/services/base.py b/flowapp/services/base.py new file mode 100644 index 00000000..77adeb07 --- /dev/null +++ b/flowapp/services/base.py @@ -0,0 +1,93 @@ +from datetime import datetime +from operator import ge, lt +from flowapp import constants, db, messages +from flowapp.constants import ANNOUNCE, WITHDRAW +from flowapp.models import RTBH, Flowspec4, Flowspec6 +from flowapp.output import Route, RouteSources, announce_route + + +def announce_rtbh_route(model: RTBH, author: str) -> None: + """ + Announce RTBH route if rule is in active state + """ + if model.rstate_id == 1: + command = messages.create_rtbh(model, ANNOUNCE) + route = Route( + author=author, + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + +def withdraw_rtbh_route(model: RTBH) -> None: + """ + Withdraw RTBH route if rule is in whitelist state + """ + if model.rstate_id == 4: + command = messages.create_rtbh(model, WITHDRAW) + route = Route( + author=model.get_author(), + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + +def announce_all_routes(action=constants.ANNOUNCE): + """ + get routes from db and send it to ExaBGB api + + @TODO take the request away, use some kind of messaging (maybe celery?) + :param action: action with routes - announce valid routes or withdraw expired routes + """ + today = datetime.now() + comp_func = ge if action == constants.ANNOUNCE else lt + + rules4 = ( + db.session.query(Flowspec4) + .filter(Flowspec4.rstate_id == 1) + .filter(comp_func(Flowspec4.expires, today)) + .order_by(Flowspec4.expires.desc()) + .all() + ) + rules6 = ( + db.session.query(Flowspec6) + .filter(Flowspec6.rstate_id == 1) + .filter(comp_func(Flowspec6.expires, today)) + .order_by(Flowspec6.expires.desc()) + .all() + ) + rules_rtbh = ( + db.session.query(RTBH) + .filter(RTBH.rstate_id == 1) + .filter(comp_func(RTBH.expires, today)) + .order_by(RTBH.expires.desc()) + .all() + ) + + messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] + messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] + messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] + + messages_all = [] + messages_all.extend(messages_v4) + messages_all.extend(messages_v6) + messages_all.extend(messages_rtbh) + + author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" + + for command in messages_all: + route = Route( + author=f"System call / {author_action} rules", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + if action == constants.WITHDRAW: + for ruleset in [rules4, rules6, rules_rtbh]: + for rule in ruleset: + rule.rstate_id = 2 + + db.session.commit() diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py new file mode 100644 index 00000000..fca7719a --- /dev/null +++ b/flowapp/services/rule_service.py @@ -0,0 +1,533 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" + +from datetime import datetime, timedelta +from typing import Dict, List, Tuple, Union + +from flask import current_app + +from flowapp import db, messages +from flowapp.constants import WITHDRAW, RuleOrigin, RuleTypes, ANNOUNCE +from flowapp.models import ( + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + Flowspec4, + Flowspec6, + RTBH, + Whitelist, +) +from flowapp.models.rules.whitelist import RuleWhitelistCache +from flowapp.models.utils import check_global_rule_limit, check_rule_limit +from flowapp.output import ROUTE_MODELS, Route, announce_route, log_route, RouteSources, log_withdraw +from flowapp.services.base import announce_rtbh_route +from flowapp.services.whitelist_common import ( + Relation, + add_rtbh_rule_to_cache, + create_rtbh_from_whitelist_parts, + subtract_network, + whitelist_rtbh_rule, + check_rule_against_whitelists, +) +from flowapp.services.whitelist_service import create_or_update_whitelist +from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent + + +def reactivate_rule( + rule_type: RuleTypes, + rule_id: int, + expires: datetime, + comment: str, + user_id: int, + org_id: int, + user_email: str, + org_name: str, +) -> Tuple[Union[RTBH, Flowspec4, Flowspec6], List[str]]: + """ + Reactivate a rule by setting a new expiration time. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to reactivate + expires: New expiration datetime + comment: Updated comment + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, messages) + """ + model_name = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + model = db.session.get(model_name, rule_id) + if not model: + return None, ["Rule not found"] + + flashes = [] + + # Check if rule will be reactivated + state = get_state_by_time(expires) + + # Check global limit + if state == 1 and check_global_rule_limit(rule_type.value): + return model, ["global_limit_reached"] + + # Check org limit + if state == 1 and check_rule_limit(org_id, rule_type=rule_type.value): + return model, ["limit_reached"] + + # Set new expiration date + model.expires = expires + # Set again the active state, if the rule is not whitelisted RTBH + if rule_type == RuleTypes.RTBH: + model = check_rtbh_whitelisted(model, user_id, flashes, f"{user_email} / {org_name}") + else: + model.rstate_id = state + + model.comment = comment + db.session.commit() + + route_model = ROUTE_MODELS[rule_type.value] + + if model.rstate_id == 1: + flashes.append("Rule successfully updated, state set to active.") + # Announce route + command = route_model(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_route( + user_id, + model, + rule_type, + f"{user_email} / {org_name}", + ) + else: + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + if model.rstate_id == 4: + flashes.append("Rule successfully updated, state set to whitelisted.") + else: + flashes.append("Rule successfully updated, state set to inactive.") + + return model, flashes + + +def create_or_update_ipv4_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec4, str]: + """ + Create a new IPv4 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv4_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec4( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + protocol=form_data["protocol"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + fragment=";".join(form_data["fragment"]), + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv4 Rule saved" + + db.session.commit() + + # Announce route if model is in active state + if model.rstate_id == 1: + command = messages.create_ipv4(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv4, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_ipv6_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec6, str]: + """ + Create a new IPv6 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv6_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec6( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + next_header=form_data["next_header"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv6 Rule saved" + + db.session.commit() + + # Announce routes + if model.rstate_id == 1: + command = messages.create_ipv6(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv6, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_rtbh_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[RTBH, List]: + """ + Create a new RTBH rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_rtbh_model_if_exists(form_data) + flashes = [] + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flashes.append("Existing RTBH Rule found. Expiration time was updated to new value.") + else: + # Create new model + model = RTBH( + ipv4=form_data["ipv4"], + ipv4_mask=form_data["ipv4_mask"], + ipv6=form_data["ipv6"], + ipv6_mask=form_data["ipv6_mask"], + community_id=form_data["community"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flashes.append("RTBH Rule saved") + + db.session.commit() + + # rule author for logging and announcing + author = f"{user_email} / {org_name}" + # Check if rule is whitelisted + model = check_rtbh_whitelisted(model, user_id, flashes, author) + + announce_rtbh_route(model, author=author) + # Log changes + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + + +def check_rtbh_whitelisted(model: RTBH, user_id: int, flashes: List[str], author: str) -> None: + # Check if rule is whitelisted + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + if model.community_id in allowed_communities: + # get all not expired whitelists + whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + wl_cache = map_whitelists_to_strings(whitelists) + results = check_rule_against_whitelists(str(model), wl_cache.keys()) + # check rule against whitelists + model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) + return model + + +def evaluate_rtbh_against_whitelists_check_results( + user_id: int, + model: RTBH, + flashes: List[str], + author: str, + wl_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> RTBH: + """ + Evaluate RTBH rule against whitelist check results. + Process all results for cases where rule is whitelisted by several whitelists. + """ + for rule, whitelist_key, relation in results: + if relation == Relation.EQUAL: + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + msg = f"RTBH Rule {model.id} {model} is equal to active whitelist {whitelist_key}. Rule is whitelisted." + flashes.append(msg) + current_app.logger.info(msg) + elif relation == Relation.SUBNET: + parts = subtract_network(target=str(model), whitelist=whitelist_key) + wl_id = wl_cache[whitelist_key].id + msg = f"RTBH Rule {model.id} {model} is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + flashes.append(msg) + current_app.logger.info(msg) + for network in parts: + new_rule = create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + msg = f"Created RTBH rule {new_rule.id} {new_rule} for {network} parted by whitelist {whitelist_key}" + flashes.append(msg) + current_app.logger.info(msg) + model.rstate_id = 4 + add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) + db.session.commit() + elif relation == Relation.SUPERNET: + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + msg = f"RTBH Rule {model.id} {model} is subnet of active whitelist {whitelist_key}. Rule is whitelisted." + current_app.logger.info(msg) + flashes.append(msg) + return model + + +def map_whitelists_to_strings(whitelists: List[Whitelist]) -> Dict[str, Whitelist]: + return {str(w): w for w in whitelists} + + +def delete_rule( + rule_type: RuleTypes, rule_id: int, user_id: int, user_email: str, org_name: str, allowed_rule_ids: List[int] = None +) -> Tuple[bool, str]: + """ + Delete a rule with the given id and type. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to delete + user_id: Current user ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete, None means no restriction + + Returns: + Tuple containing (success, message) + """ + model_class = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + route_model = ROUTE_MODELS[rule_type.value] + + model = db.session.get(model_class, rule_id) + if not model: + return False, "Rule not found" + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, "You cannot delete this rule" + + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log withdrawal + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + + # Special handling for RTBH rules + if rule_type == RuleTypes.RTBH: + current_app.logger.debug(f"Deleting RTBH rule {rule_id} from cache") + RuleWhitelistCache.delete_by_rule_id(rule_id) + + # Delete from database + db.session.delete(model) + db.session.commit() + + return True, "Rule deleted successfully" + + +def delete_rtbh_and_create_whitelist( + rule_id: int, + user_id: int, + org_id: int, + user_email: str, + org_name: str, + allowed_rule_ids: List[int] = None, + whitelist_expires: datetime = None, +) -> Tuple[bool, List[str], Union[Whitelist, None]]: + """ + Delete an RTBH rule and create a whitelist entry from it. + + Args: + rule_id: ID of the RTBH rule to delete + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete + whitelist_expires: Expiration time for the whitelist entry (default: 7 days from now) + + Returns: + Tuple containing (success, messages, whitelist_model) + """ + messages = [] + + # First get the RTBH rule to extract its data + model = db.session.get(RTBH, rule_id) + if not model: + return False, ["RTBH rule not found"], None + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, ["You cannot delete this rule"], None + + # Extract data for whitelist + if model.ipv4: + ip = model.ipv4 + mask = model.ipv4_mask + elif model.ipv6: + ip = model.ipv6 + mask = model.ipv6_mask + else: + return False, ["RTBH rule has no IP address"], None + + # Set default whitelist expiration time if not provided + if whitelist_expires is None: + whitelist_expires = datetime.now() + timedelta(days=7) + + # Prepare whitelist data + # Create base comment + comment_text = f"Created from RTBH rule {model} {rule_id}" + # Append the rule's comment only if it exists + if model.comment: + comment_text += f": {model.comment}" + + whitelist_data = { + "ip": ip, + "mask": mask, + "expires": whitelist_expires, + "comment": comment_text, + } + + # Delete the RTBH rule + success, delete_message = delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=rule_id, + user_id=user_id, + user_email=user_email, + org_name=org_name, + allowed_rule_ids=allowed_rule_ids, + ) + + if not success: + return False, [delete_message], None + + messages.append(delete_message) + + # Create the whitelist entry + try: + whitelist_model, whitelist_messages = create_or_update_whitelist( + form_data=whitelist_data, user_id=user_id, org_id=org_id, user_email=user_email, org_name=org_name + ) + messages.extend(whitelist_messages) + return True, messages, whitelist_model + except Exception as e: + current_app.logger.exception(f"Error creating whitelist entry: {e}") + messages.append(f"Rule deleted but failed to create whitelist: {str(e)}") + return False, messages, None diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py new file mode 100644 index 00000000..b9b9e8e3 --- /dev/null +++ b/flowapp/services/whitelist_common.py @@ -0,0 +1,244 @@ +from enum import Enum, auto +from functools import lru_cache +import ipaddress +from typing import List, Tuple, Union +from flowapp import db +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import RTBH, RuleWhitelistCache, Whitelist +from flowapp.output import log_route +from flowapp.services.base import announce_rtbh_route + + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + """ + Add RTBH rule to whitelist cache + """ + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() + + +def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: + """ + Whitelist RTBH rule. + Set rule state to 4 - whitelisted rule, do not announce later + Add to whitelist cache + """ + model.rstate_id = 4 + db.session.commit() + add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) + return model + + +class Relation(Enum): + """ + Enum to represent the relation between Whitelist to Rule Relation + Subnet: Whitelist is a subnet of the rule + Supernet: Whitelist is a supernet of the rule + Equal: Whitelist is equal to the rule + Different: Whitelist is different from the rule + """ + + SUBNET = auto() + SUPERNET = auto() + EQUAL = auto() + DIFFERENT = auto() + + +def _is_same_ip_version(addr1: str, addr2: str) -> bool: + """ + Check if two IP addresses/networks are of the same IP version. + + :param addr1: First IP address or network string + :param addr2: Second IP address or network string + :return: True if both addresses are of the same IP version (IPv4 or IPv6), False otherwise + """ + is_ipv4_1 = "." in addr1 + is_ipv4_2 = "." in addr2 + return is_ipv4_1 == is_ipv4_2 + + +@lru_cache(maxsize=1024) +def get_network(address: str) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]: + """ + Create and cache an IP network object. + + :param address: IP address or network in string format + :return: Cached IP network object + """ + return ipaddress.ip_network(address, strict=False) + + +def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: + """ + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Uses cached network objects for better performance. + + If the rule and whitelist are different IP versions (IPv4 vs IPv6), + they are treated as different networks. + + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") + :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") + :return: Relation between the two networks + """ + # First check if IP versions are the same + if not _is_same_ip_version(rule, whitelist_entry): + return Relation.DIFFERENT + + try: + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + except (ValueError, TypeError): + # Handle any other errors that might occur during comparison + return Relation.DIFFERENT + + +def subtract_network(target: str, whitelist: str) -> List[str]: + """ + Computes the remaining parts of a network after removing the whitelist subnet. + Uses cached network objects for better performance. + + If the target and whitelist are different IP versions (IPv4 vs IPv6), + the original target is returned unchanged. + + :param target: The main network (e.g., "192.168.1.0/24") + :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") + :return: A list of remaining subnets as strings + """ + # First check if IP versions are the same + if not _is_same_ip_version(target, whitelist): + return [target] + + try: + target_net = get_network(target) + whitelist_net = get_network(whitelist) + + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet + + remaining = [] + + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) + + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + except (ValueError, TypeError): + # Return the original target in case of any error + return [target] + + +def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: + """ + Helper function to check a single rule against multiple whitelist entries. + Creates a cached rule network object for better performance. + Reduces list of whitelists, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check + :param whitelists: List of whitelist networks to check against + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + try: + get_network(rule) + except (ValueError, TypeError): + # Return empty list if rule is not a valid network + return [] + + items = [] + for whitelist in whitelists: + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this whitelist if there's an error comparing + continue + return items + + +def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: + """ + Helper function to check if any whitelist entry is a subnet of the rule. + Creates a cached rule network object for better performance. + Reduces list of rules, where the Relation is not DIFFERENT + + :param rules: List of rule networks to check against + :param whitelist: The whitelist network to check + :return: tuple of rule, whitelist and relation for each rules that is not DIFFERENT + """ + # Pre-cache the whitelist network since it will be used multiple times + try: + get_network(whitelist) + except (ValueError, TypeError): + # Return empty list if whitelist is not a valid network + return [] + + items = [] + for rule in rules: + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this rule if there's an error comparing + continue + return items + + +def clear_network_cache() -> None: + """ + Clear the network object cache. + Useful when processing a large number of networks to prevent memory growth. + """ + get_network.cache_clear() + + +def create_rtbh_from_whitelist_parts( + model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str = "", user_id: int = 0 +) -> RTBH: + # default values from model + rule_owner = rule_owner or model.get_author() + user_id = user_id or model.user_id + + net_ip, net_mask = network.split("/") + new_model = RTBH( + ipv4=net_ip, + ipv4_mask=net_mask, + ipv6=model.ipv6, + ipv6_mask=model.ipv6_mask, + community_id=model.community_id, + expires=model.expires, + comment=model.comment, + user_id=model.user_id, + org_id=model.org_id, + rstate_id=1, + ) + db.session.add(new_model) + db.session.commit() + + add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) + announce_rtbh_route(new_model, rule_owner) + log_route(user_id, model, RuleTypes.RTBH, rule_owner) + + return new_model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py new file mode 100644 index 00000000..bf281af7 --- /dev/null +++ b/flowapp/services/whitelist_service.py @@ -0,0 +1,198 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" +from flask import current_app +from typing import Dict, Tuple, List + +import sqlalchemy + +from flowapp import db +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists +from flowapp.models.rules.flowspec import Flowspec4, Flowspec6 +from flowapp.models.rules.rtbh import RTBH +from flowapp.services.base import announce_rtbh_route, withdraw_rtbh_route +from flowapp.services.whitelist_common import add_rtbh_rule_to_cache, create_rtbh_from_whitelist_parts +from flowapp.services.whitelist_common import ( + Relation, + check_whitelist_against_rules, + subtract_network, + whitelist_rtbh_rule, +) +from flowapp.utils import round_to_ten_minutes, quote_to_ent + + +def create_or_update_whitelist( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Whitelist, str]: + """ + Create a new Whitelist rule or update expiration time of an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (whitelist_model, message) + """ + # Check for existing model + model = get_whitelist_model_if_exists(form_data) + flashes = [] + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flashes.append("Existing Whitelist found. Expiration time was updated to new value.") + else: + # Create new model + model = Whitelist( + ip=form_data["ip"], + mask=form_data["mask"], + expires=round_to_ten_minutes(form_data["expires"]), + user_id=user_id, + org_id=org_id, + comment=quote_to_ent(form_data["comment"]), + ) + db.session.add(model) + flashes.append("Whitelist saved") + + db.session.commit() + + # check RTBH rules against whitelist + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + # filter out RTBH rules that are not active or whitelisted and not in allowed communities + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id.in_([1, 4]), RTBH.community_id.in_(allowed_communities)).all() + rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) + result = check_whitelist_against_rules(rtbh_rules_map, str(model)) + current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") + model = evaluate_whitelist_against_rtbh_check_results(model, flashes, rtbh_rules_map, result) + + return model, flashes + + +def evaluate_whitelist_against_rtbh_check_results( + whitelist_model: Whitelist, + flashes: List[str], + rtbh_rule_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> Whitelist: + + for rule_key, whitelist_key, relation in results: + current_app.logger.info(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") + if relation == Relation.EQUAL: + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + msg = "Existing active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is now whitelisted." + flashes.append(msg) + current_app.logger.info(msg) + elif relation == Relation.SUBNET: + parts = subtract_network(target=rule_key, whitelist=whitelist_key) + wl_id = whitelist_model.id + msg = f"Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules will be created." + flashes.append(msg) + current_app.logger.info(msg) + for network in parts: + rule_model = rtbh_rule_cache[rule_key] + create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) + msg = f"Created RTBH rule from {rule_model.id} {network} parted by whitelist {whitelist_key}." + flashes.append(msg) + current_app.logger.info(msg) + rule_model.rstate_id = 4 + add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) + db.session.commit() + elif relation == Relation.SUPERNET: + + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + msg = f"Existing active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is now whitelisted." + current_app.logger.info(msg) + flashes.append(msg) + + return whitelist_model + + +def map_rtbh_rules_to_strings(all_rtbh_rules: List[RTBH]) -> Dict[str, RTBH]: + return {str(rule): rule for rule in all_rtbh_rules} + + +def delete_expired_whitelists() -> List[str]: + """ + Delete all expired whitelist entries from the database. + + Returns: + List of messages for the user + """ + expired_whitelists = Whitelist.query.filter(Whitelist.expires < db.func.now()).all() + flashes = [] + for model in expired_whitelists: + flashes.extend(delete_whitelist(model.id)) + return flashes + + +def delete_whitelist(whitelist_id: int) -> List[str]: + """ + Delete a whitelist entry from the database. + + Args: + whitelist_id: The ID of the whitelist to delete + """ + model = db.session.get(Whitelist, whitelist_id) + flashes = [] + if model: + cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) + current_app.logger.info( + f"Deleting whitelist {whitelist_id} {model}. Found {len(cached_rules)} cached rules to process." + ) + for cached_rule in cached_rules: + rule_model_type = RuleTypes(cached_rule.rtype) + cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) + if rule_model_type == RuleTypes.IPv4: + rule_model = db.session.get(Flowspec4, cached_rule.rid) + elif rule_model_type == RuleTypes.IPv6: + rule_model = db.session.get(Flowspec6, cached_rule.rid) + elif rule_model_type == RuleTypes.RTBH: + rule_model = db.session.get(RTBH, cached_rule.rid) + rorigin_type = RuleOrigin(cached_rule.rorigin) + current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") + if rorigin_type == RuleOrigin.WHITELIST: + msg = f"Deleted {rule_model_type} rule {rule_model} created by whitelist {model}" + current_app.logger.info(msg) + flashes.append(msg) + try: + db.session.delete(rule_model) + except sqlalchemy.orm.exc.UnmappedInstanceError: + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) + + elif rorigin_type == RuleOrigin.USER and cache_entries_count == 1: + msg = f"Set rule {rule_model} back to state 'Active'" + current_app.logger.info(msg) + flashes.append(msg) + try: + rule_model.rstate_id = 1 # Set rule state to "Active" again + except AttributeError: + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) + else: + author = f"{model.user.email} ({model.user.organization})" + announce_rtbh_route(rule_model, author) + + msg = f"Deleted cache entries for whitelist {whitelist_id} {model}" + current_app.logger.info(msg) + flashes.append(msg) + RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) + + db.session.delete(model) + db.session.commit() + msg = f"Deleted whitelist {whitelist_id} {model}" + flashes.append(msg) + current_app.logger.info(msg) + + return flashes diff --git a/flowapp/templates/forms/bulk_user_form.html b/flowapp/templates/forms/bulk_user_form.html index c41389ea..b15b3ff8 100644 --- a/flowapp/templates/forms/bulk_user_form.html +++ b/flowapp/templates/forms/bulk_user_form.html @@ -42,9 +42,9 @@

Example CSV data

             
 uuid-eppn,name,telefon,email,role,organizace,poznamka
-view@example.com,Test View,123,view@example.com,1,1,View
-user@example.com,Test User,123456,user@example.com,2,1,User
-admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,Admin
+view@example.com,Test View,123,view@example.com,1,1,user with view role (1)
+user@example.com,Test User,123456,user@example.com,2,1,regular user (role 2)
+admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,user with admin role (3)
             
         

Role

diff --git a/flowapp/templates/forms/machine_api_key.html b/flowapp/templates/forms/machine_api_key.html index be6fcaad..1b646ad3 100644 --- a/flowapp/templates/forms/machine_api_key.html +++ b/flowapp/templates/forms/machine_api_key.html @@ -18,25 +18,29 @@
Machine Api Key: {{ generated_key }}
{{ form.hidden_tag() if form.hidden_tag }}
-
+
{{ render_field(form.machine) }}
{{ render_checkbox_field(form.readonly, checked="checked") }}
-
+
+
+
+ {{ render_field(form.user) }} +
+
+
{{ render_field(form.expires) }}
-
- -
+
-
+
{{ render_field(form.comment) }}
-
+
diff --git a/flowapp/templates/forms/rtbh_rule.html b/flowapp/templates/forms/rtbh_rule.html index 986c081b..8c91ef22 100644 --- a/flowapp/templates/forms/rtbh_rule.html +++ b/flowapp/templates/forms/rtbh_rule.html @@ -28,6 +28,15 @@

{{ title or 'New'}} RTBH rule

{{ render_field(form.community) }}
+
+
Following communities can be whitelisted:
+
    + {% for com in whitelistable %} +
  • + {{ com }} +
  • + {% endfor %} +
@@ -36,9 +45,6 @@

{{ title or 'New'}} RTBH rule

- {{ form.expires(class_='form-control') }} diff --git a/flowapp/templates/forms/whitelist.html b/flowapp/templates/forms/whitelist.html new file mode 100644 index 00000000..35559d79 --- /dev/null +++ b/flowapp/templates/forms/whitelist.html @@ -0,0 +1,62 @@ +{% extends 'layouts/default.html' %} +{% from 'forms/macros.html' import render_field %} +{% block title %}Add Whitelist{% endblock %} +{% block content %} +

{{ title or 'New'}} Whitelist

+ + {{ form.hidden_tag() if form.hidden_tag }} +
+
+
+
+ {{ render_field(form.ip) }} +
+
+ {{ render_field(form.mask) }} +
+
+
+
+
+ +
+ +
+
+
+ +
+ + {{ form.expires(class_='form-control') }} + + + +
+ + {% if form.expires.errors %} + {% for e in form.expires.errors %} +

{{ e }}

+ {% endfor %} + {% endif %} +
+
+
+
+
+ +
+
+ {{ render_field(form.comment) }} +
+
+ +
+
+ +
+
+ +{% endblock %} \ No newline at end of file diff --git a/flowapp/templates/layouts/default.html b/flowapp/templates/layouts/default.html index 47b049c4..c6f6e4fe 100644 --- a/flowapp/templates/layouts/default.html +++ b/flowapp/templates/layouts/default.html @@ -59,6 +59,7 @@ {% endfor %}
  • ExaFS version {{ session['app_version'] }}
  • +
  • API docs
  • {% endif %} diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 9e0b4b79..661be2f7 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -1,5 +1,5 @@ -{% macro build_ip_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_ip_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None, allowed_communities=None) %} {% for rule in rules %} {% if rule.next_header is defined %} @@ -8,7 +8,11 @@ {% set rtype_int = 4 %} {% endif %} - + {{ rule.source }}{% if rule.source_mask != none %}{{ '/' if rule.source_mask >= 0 else '' }}{{ rule.source_mask if rule.source_mask >= 0 else '' }}{% endif %} @@ -45,10 +49,10 @@ {% if editable %} - + - + {% endif %} @@ -73,23 +77,72 @@ {% endmacro %} -{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None, allowed_communities=None) %} + +{% for rule in rules %} + + + {% if rule.ipv4 %} + {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} + {% endif %} + {% if rule.ipv6 %} + {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} + {% endif %} + + + {{ rule.community.name }} + + + + {{ rule.expires|strftime }} + + + {{ rule.user.name }} + + + {% if editable %} + + + + + + + {% if rule.community.id in allowed_communities %} + + + + {% endif %} + {% endif %} + {% if rule.comment %} + + {% endif %} + + {% if editable and group_op %} + + + + {% endif %} + + +{% endfor %} + +{% endmacro %} + + +{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None, allowed_communities=None) %} {% for rule in rules %} - {% if rule.ipv4 %} - {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} - {% endif %} - {% if rule.ipv6 %} - {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} - {% endif %} - - - {{ rule.community.name }} + {{ rule.ip }}{{ '/' if rule.mask else '' }}{{rule.mask|default("", True)}} - - + {{ rule.expires|strftime }} @@ -97,10 +150,10 @@ {% if editable %} - + - + {% endif %} @@ -122,6 +175,7 @@ {% endmacro %} + {% macro build_rules_thead(rules_columns, rtype, rstate, sort_key, sort_order, search_query='', group_op=True) %} diff --git a/flowapp/templates/pages/machine_api_key.html b/flowapp/templates/pages/machine_api_key.html index 52eb478a..c2ced223 100644 --- a/flowapp/templates/pages/machine_api_key.html +++ b/flowapp/templates/pages/machine_api_key.html @@ -11,8 +11,8 @@

    Machines and ApiKeys

    Machine address ApiKey - Created by Created for + Created by Expires Read/Write ? Action diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index 65919af1..3a988ef4 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -12,6 +12,7 @@ from flowapp import db as _db from datetime import datetime import flowapp.models +from flowapp.models.organization import Organization TESTDB = "test_project.db" @@ -67,6 +68,8 @@ def app(request): SECRET_KEY="testkeysession", LOCAL_USER_UUID="jiri.vrany@cesnet.cz", LOCAL_AUTH=True, + ALLOWED_COMMUNITIES=[1, 2, 3], + WTF_CSRF_ENABLED=False, ) print("\n----- CREATE FLASK APPLICATION\n") @@ -114,6 +117,12 @@ def db(app, request): print("#: inserting users") flowapp.models.insert_users(users) + org = _db.session.query(Organization).filter_by(id=1).first() + # Update the organization address range to include our test networks + org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" + _db.session.commit() + print("\n----- UPDATED TEST ORG 1 \n", org) + def teardown(): _db.session.commit() _db.drop_all() @@ -143,6 +152,26 @@ def jwt_token(client, app, db, request): return data["token"] +@pytest.fixture(scope="session") +def machine_api_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + mkey = "machinetestkey" + + with app.app_context(): + model = flowapp.models.MachineApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) + db.session.add(model) + db.session.commit() + + print("\n----- GET MACHINE API KEY TEST TOKEN\n") + url = "/api/v3/auth" + headers = {"x-api-key": mkey} + token = client.get(url, headers=headers) + data = json.loads(token.data) + return data["token"] + + @pytest.fixture(scope="session") def expired_auth_token(client, app, db, request): """ @@ -185,3 +214,19 @@ def auth_client(client): print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n") client.get("/local-login") return client + + +@pytest.fixture(autouse=True) +def reset_org_limits(db, app): + """ + Automatically reset organization limits after each test that modifies them. + """ + yield # Allow test execution + + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + if org: + org.limit_flowspec4 = 0 + org.limit_flowspec6 = 0 + org.limit_rtbh = 0 + db.session.commit() diff --git a/flowapp/tests/rule_service_integration.py b/flowapp/tests/rule_service_integration.py new file mode 100644 index 00000000..106000ab --- /dev/null +++ b/flowapp/tests/rule_service_integration.py @@ -0,0 +1,332 @@ +"""Integration tests for rule_service module.""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp import db, create_app +from flowapp.constants import RuleTypes +from flowapp.models import ( + RTBH, + Flowspec4, + Flowspec6, + Whitelist, + User, + Organization, + Role, + Community, + Action, + Rstate, +) +from flowapp.services import rule_service + + +@pytest.fixture(scope="module") +def app(): + """Create and configure a Flask app for testing.""" + app = create_app("testing") + + # Create a test context + with app.app_context(): + # Create database tables + db.create_all() + + # Create test data + _create_test_data() + + yield app + + # Clean up after tests + db.session.remove() + db.drop_all() + + +@pytest.fixture(scope="function") +def app_context(app): + """Provide an application context for each test.""" + with app.app_context(): + yield + + +@pytest.fixture(scope="function") +def mock_announce_route(): + """Mock the announce_route function.""" + with patch("flowapp.services.rule_service.announce_route") as mock: + yield mock + + +def _create_test_data(): + """Create test data in the database.""" + # Create roles + admin_role = Role(name="admin", description="Administrator") + user_role = Role(name="user", description="Normal user") + view_role = Role(name="view", description="View only") + db.session.add_all([admin_role, user_role, view_role]) + db.session.flush() + + # Create organizations + org1 = Organization(name="Test Org 1", arange="192.168.1.0/24\n2001:db8::/64") + org2 = Organization(name="Test Org 2", arange="10.0.0.0/8") + db.session.add_all([org1, org2]) + db.session.flush() + + # Create users + user1 = User(uuid="user1@example.com", name="User One") + user1.role.append(user_role) + user1.organization.append(org1) + + admin1 = User(uuid="admin@example.com", name="Admin User") + admin1.role.append(admin_role) + admin1.organization.append(org2) + + db.session.add_all([user1, admin1]) + db.session.flush() + + # Create rule states + state_active = Rstate(description="active rule") + state_withdrawn = Rstate(description="withdrawed rule") + state_deleted = Rstate(description="deleted rule") + state_whitelisted = Rstate(description="whitelisted rule") + db.session.add_all([state_active, state_withdrawn, state_deleted, state_whitelisted]) + db.session.flush() + + # Create actions + action1 = Action(name="Discard", command="discard", description="Drop traffic", role_id=user_role.id) + action2 = Action(name="Rate limit", command="rate-limit 1000000", description="Limit traffic", role_id=user_role.id) + db.session.add_all([action1, action2]) + db.session.flush() + + # Create communities + community1 = Community( + name="65000:1", + comm="65000:1", + larcomm="", + extcomm="", + description="Test community", + role_id=user_role.id, + as_path=False, + ) + db.session.add(community1) + db.session.flush() + + # Create some rules + # IPv4 rule + ipv4_rule = Flowspec4( + source="192.168.1.1", + source_mask=32, + source_port="", + dest="192.168.2.1", + dest_mask=32, + dest_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=datetime.now() + timedelta(hours=1), + comment="Test IPv4 rule", + action_id=action1.id, + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(ipv4_rule) + + # IPv6 rule + ipv6_rule = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="", + dest="2001:db8:1::1", + dest_mask=128, + dest_port="", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(hours=1), + comment="Test IPv6 rule", + action_id=action1.id, + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(ipv6_rule) + + # RTBH rule + rtbh_rule = RTBH( + ipv4="192.168.1.100", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=community1.id, + expires=datetime.now() + timedelta(hours=1), + comment="Test RTBH rule", + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(rtbh_rule) + + # Save IDs as class variables for tests to use + _create_test_data.user_id = user1.id + _create_test_data.admin_id = admin1.id + _create_test_data.org1_id = org1.id + _create_test_data.org2_id = org2.id + _create_test_data.ipv4_rule_id = ipv4_rule.id + _create_test_data.ipv6_rule_id = ipv6_rule.id + _create_test_data.rtbh_rule_id = rtbh_rule.id + _create_test_data.action_id = action1.id + _create_test_data.community_id = community1.id + + db.session.commit() + + +class TestRuleServiceIntegration: + """Integration tests for rule_service functions.""" + + def test_reactivate_rule_ipv4(self, app_context, mock_announce_route): + """Test reactivating an IPv4 rule.""" + # Test data + rule_id = _create_test_data.ipv4_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Set new expiration time (2 hours in the future) + new_expires = datetime.now() + timedelta(hours=2) + new_comment = "Updated test comment" + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=rule_id, + expires=new_expires, + comment=new_comment, + user_id=user_id, + org_id=org_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was updated + assert model is not None + assert model.id == rule_id + assert model.comment == new_comment + assert model.expires == new_expires + assert model.rstate_id == 1 # active state + + # Verify route announcement was attempted + assert mock_announce_route.called + + # Verify message + assert messages == ["Rule successfully updated"] + + def test_reactivate_rule_inactive(self, app_context, mock_announce_route): + """Test reactivating a rule to inactive state.""" + # Test data + rule_id = _create_test_data.ipv4_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Set past expiration time + past_expires = datetime.now() - timedelta(hours=1) + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=rule_id, + expires=past_expires, + comment="Expired comment", + user_id=user_id, + org_id=org_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was updated + assert model is not None + assert model.id == rule_id + assert model.expires == past_expires + assert model.rstate_id == 2 # inactive/withdrawn state + + # Verify route announcement was attempted + assert mock_announce_route.called + + def test_delete_rule(self, app_context, mock_announce_route): + """Test deleting a rule.""" + # Test data + rule_id = _create_test_data.ipv6_rule_id + user_id = _create_test_data.user_id + + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv6, + rule_id=rule_id, + user_id=user_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was deleted + assert success is True + assert message == "Rule deleted successfully" + + # Verify the rule no longer exists in the database + rule = db.session.get(Flowspec6, rule_id) + assert rule is None + + # Verify route withdrawal was attempted + assert mock_announce_route.called + + def test_delete_rule_not_found(self, app_context): + """Test deleting a non-existent rule.""" + # Use a non-existent rule ID + non_existent_id = 9999 + + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=non_existent_id, + user_id=_create_test_data.user_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the operation failed + assert success is False + assert message == "Rule not found" + + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_delete_rtbh_and_create_whitelist(self, mock_create_whitelist, app_context, mock_announce_route): + """Test deleting an RTBH rule and creating a whitelist entry.""" + # Test data + rule_id = _create_test_data.rtbh_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Mock whitelist creation + mock_whitelist = Whitelist( + ip="192.168.1.100", mask=32, expires=datetime.now() + timedelta(days=7), user_id=user_id, org_id=org_id + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=rule_id, user_id=user_id, org_id=org_id, user_email="test@example.com", org_name="Test Org" + ) + + # Verify success + assert success is True + assert len(messages) == 2 + assert "Rule deleted successfully" in messages[0] + assert "Whitelist created" in messages[1] + + # Verify the rule was deleted + rule = db.session.get(RTBH, rule_id) + assert rule is None + + # Verify create_or_update_whitelist was called with correct data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "192.168.1.100" + assert form_data["mask"] == 32 + assert "Created from RTBH rule" in form_data["comment"] diff --git a/flowapp/tests/test_api_auth.py b/flowapp/tests/test_api_auth.py index 5733346b..8d08a7ae 100644 --- a/flowapp/tests/test_api_auth.py +++ b/flowapp/tests/test_api_auth.py @@ -11,6 +11,15 @@ def test_token(client, jwt_token): assert req.status_code == 200 +def test_machine_token(client, machine_api_token): + """ + test that token authorization works + """ + req = client.get("/api/v3/test_token", headers={"x-access-token": machine_api_token}) + + assert req.status_code == 200 + + def test_expired_token(client, expired_auth_token): """ test that expired token authorization return 401 @@ -37,7 +46,7 @@ def test_readonly_token(client, readonly_jwt_token): assert req.status_code == 200 data = json.loads(req.data) - assert data['readonly'] + assert data["readonly"] def test_readonly_token_ipv4_create(client, db, readonly_jwt_token): diff --git a/flowapp/tests/test_api_v3.py b/flowapp/tests/test_api_v3.py index 96b24786..87f70949 100644 --- a/flowapp/tests/test_api_v3.py +++ b/flowapp/tests/test_api_v3.py @@ -1,10 +1,37 @@ import json + from flowapp.models import Flowspec4, Organization V_PREFIX = "/api/v3" +def test_create_rtbh_only(client, app, db, jwt_token): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + + # Now create the RTBH rule via API + res = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + + # Verify response is successful + assert res.status_code == 201 + data = json.loads(res.data) + assert data["rule"] is not None + + def test_token(client, jwt_token): """ test that token authorization works diff --git a/flowapp/tests/test_api_whitelist_integration.py b/flowapp/tests/test_api_whitelist_integration.py new file mode 100644 index 00000000..0391a762 --- /dev/null +++ b/flowapp/tests/test_api_whitelist_integration.py @@ -0,0 +1,273 @@ +""" +Integration test for the API view when interacting with whitelists. + +This test suite verifies the behavior when creating RTBH rules through the API +when there are existing whitelists that could affect the rules. +""" + +import json +import pytest +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import RTBH, RuleWhitelistCache, Organization +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_data(): + """Create whitelist data for testing""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist for API integration test", + "expires": datetime.now() + timedelta(days=1), + } + + +@pytest.fixture +def rtbh_api_payload(): + """Create RTBH rule data for API testing""" + return { + "community": 1, + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Test RTBH rule via API", + } + + +def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + # First, configure app to include community ID 1 in allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # Create the whitelist directly using the service + with app.app_context(): + # Create user and organization if needed for the whitelist + org = db.session.query(Organization).first() + + # Create the whitelist + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, + user_id=1, # Using user ID 1 from test fixtures + org_id=org.id, + user_email="test@example.com", + org_name=org.name, + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + assert whitelist_model.ip == whitelist_data["ip"] + assert whitelist_model.mask == whitelist_data["mask"] + + # Now create the RTBH rule via API + response = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=rtbh_api_payload, + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + +def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is a supernet of an existing whitelist. + + The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules + should be created for the non-whitelisted parts. + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a subnet + whitelist_data["ip"] = "192.168.1.128" + whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that covers both the whitelist and additional space + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", + headers=headers, + json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created and marked as whitelisted + with app.app_context(): + # Check the original rule + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Check if a new subnet rule was created for the non-whitelisted part + subnet_rule = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == "192.168.1.0", + RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist + ) + .first() + ) + + assert subnet_rule is not None + assert subnet_rule.rstate_id == 1 # Active status + + # Verify cache entries + # Main rule should be cached as a USER rule + user_cache = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value + ).first() + assert user_cache is not None + + # Subnet rule should be cached as a WHITELIST rule + whitelist_cache = RuleWhitelistCache.query.filter_by( + rid=subnet_rule.id, + rtype=RuleTypes.RTBH.value, + whitelist_id=whitelist_model.id, + rorigin=RuleOrigin.WHITELIST.value, + ).first() + assert whitelist_cache is not None + + +def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is contained within an existing whitelist. + + The rule should be created but immediately marked as whitelisted (rstate_id=4). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a supernet + whitelist_data["ip"] = "192.168.0.0" + whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule + + # Create the whitelist directly using the service + with app.app_context(): + all_rtbh_rules_before = db.session.query(RTBH).count() + + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that is inside the whitelist + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + # Verify no additional rules were created + all_rtbh_rules = db.session.query(RTBH).count() + assert all_rtbh_rules - all_rtbh_rules_before == 1 + + +def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that has no relation to any existing whitelist. + + The rule should be created normally with active state (rstate_id=1). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a completely different network + whitelist_data["ip"] = "10.0.0.0" + whitelist_data["mask"] = 8 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API for a network not covered by any of the whitelists + rtbh_api_payload["ipv4"] = "147.230.17.185" + rtbh_api_payload["ipv4_mask"] = 32 + + headers = {"x-access-token": jwt_token} + response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24 + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created with active state + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 1 # 1 = active state + + # Verify no cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() + + assert cache_entry is None diff --git a/flowapp/tests/test_flowspec.py b/flowapp/tests/test_flowspec.py index e7a9c3e9..70000b81 100644 --- a/flowapp/tests/test_flowspec.py +++ b/flowapp/tests/test_flowspec.py @@ -1,12 +1,12 @@ import pytest -import flowapp.flowspec +from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit def test_translate_number(): """ tests for x (integer) to =x """ - assert "[=10]" == flowapp.flowspec.translate_sequence("10") + assert "[=10]" == translate_sequence("10") def test_raises(): @@ -14,7 +14,7 @@ def test_raises(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("ahoj") + translate_sequence("ahoj") def test_raises_bad_number(): @@ -22,46 +22,146 @@ def test_raises_bad_number(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("75555") + translate_sequence("75555") def test_translate_range(): """ tests for x-y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence("10-20") + assert "[>=10&<=20]" == translate_sequence("10-20") def test_exact_rule(): """ test for >=x&<=y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence(">=10&<=20") + assert "[>=10&<=20]" == translate_sequence(">=10&<=20") def test_greater_than(): """ test for >x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">10") + assert "[>=10&<=65535]" == translate_sequence(">10") def test_greater_equal_than(): """ test for >=x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">=10") + assert "[>=10&<=65535]" == translate_sequence(">=10") def test_lower_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<10") + assert "[>=0&<=10]" == translate_sequence("<10") def test_lower_equal_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<=10") + assert "[>=0&<=10]" == translate_sequence("<=10") + + +# new tests + + +def test_multiple_sequences(): + """Test multiple sequences separated by semicolons""" + assert "[=10 >=20&<=30]" == translate_sequence("10;20-30") + assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30") + + +def test_empty_sequence(): + """Test empty sequence and sequences with empty parts""" + assert "[]" == translate_sequence("") + assert "[=10]" == translate_sequence("10;;") + assert "[=10 =20]" == translate_sequence("10;;20") + + +def test_range_edge_cases(): + """Test edge cases for ranges""" + # Same numbers in range + assert "[>=10&<=10]" == translate_sequence("10-10") + + # Invalid range (start > end) + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence("20-10") + + # Invalid range in exact rule + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence(">=20&<=10") + + +def test_check_limit_validation(): + """Test the check_limit function""" + # Test valid cases + assert check_limit(10, 100) == 10 + assert check_limit(0, 100, 0) == 0 + assert check_limit(100, 100, 0) == 100 + + # Test invalid cases + with pytest.raises(ValueError, match="Invalid value number: .* is too big"): + check_limit(101, 100) + + with pytest.raises(ValueError, match="Invalid value number: .* is too small"): + check_limit(-1, 100, 0) + + +def test_invalid_inputs(): + """Test various invalid inputs""" + invalid_inputs = [ + "abc", # Non-numeric + "10.5", # Decimal + "10,20", # Wrong separator + "10&20", # Wrong format + ">", # Incomplete + ">=", # Incomplete + "<", # Incomplete + "<=", # Incomplete + ">>10", # Invalid operator + "<<10", # Invalid operator + ">=10&", # Incomplete range + "&<=10", # Incomplete range + ] + + for invalid_input in invalid_inputs: + with pytest.raises(ValueError): + translate_sequence(invalid_input) + + +class MockRule: + def __init__(self, action_id): + self.action_id = action_id + + +def test_filter_rules_action(): + """Test the filter_rules_action function""" + # Create test rules + rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)] + + # Test with empty allowed actions + editable, viewonly = filter_rules_action([], rules) + assert len(editable) == 0 + assert len(viewonly) == 4 + + # Test with some allowed actions + editable, viewonly = filter_rules_action([1, 3], rules) + assert len(editable) == 2 + assert len(viewonly) == 2 + assert all(rule.action_id in [1, 3] for rule in editable) + assert all(rule.action_id in [2, 4] for rule in viewonly) + + # Test with all actions allowed + editable, viewonly = filter_rules_action([1, 2, 3, 4], rules) + assert len(editable) == 4 + assert len(viewonly) == 0 + + # Test with empty rules list + editable, viewonly = filter_rules_action([1, 2], []) + assert len(editable) == 0 + assert len(viewonly) == 0 diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py index 76c98960..e9620d8d 100644 --- a/flowapp/tests/test_forms.py +++ b/flowapp/tests/test_forms.py @@ -1,15 +1,7 @@ import pytest -from flask import Flask import flowapp.forms -@pytest.fixture() -def app(): - app = Flask(__name__) - app.secret_key = "test" - return app - - @pytest.fixture() def ip_form(app, field_class): with app.test_request_context(): # Push the request context diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py new file mode 100644 index 00000000..db0a569f --- /dev/null +++ b/flowapp/tests/test_forms_cl.py @@ -0,0 +1,608 @@ +import pytest +from datetime import datetime, timedelta +from werkzeug.datastructures import MultiDict +from flowapp.forms import ( + UserForm, + BulkUserForm, + ApiKeyForm, + MachineApiKeyForm, + OrganizationForm, + ActionForm, + ASPathForm, + CommunityForm, + RTBHForm, + IPv4Form, + IPv6Form, + WhitelistForm, +) + + +def create_form_data(data): + """Helper function to create proper form data format""" + processed_data = {} + for key, value in data.items(): + if isinstance(value, list): + processed_data[key] = [str(v) for v in value] + else: + processed_data[key] = value + return MultiDict(processed_data) + + +@pytest.fixture +def valid_datetime(): + return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") + + +@pytest.fixture +def sample_network_ranges(): + return ["192.168.0.0/16", "2001:db8::/32"] + + +class TestUserForm: + @pytest.fixture + def mock_choices(self): + return { + "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")], + "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")], + } + + @pytest.fixture + def valid_user_data(self): + return { + "uuid": "test@example.com", + "email": "user@example.com", + "name": "Test User", + "phone": "123456789", + "role_ids": ["2"], + "org_ids": ["1"], + } + + def test_valid_user_form(self, app, valid_user_data, mock_choices): + with app.test_request_context(): + form_data = create_form_data(valid_user_data) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_email(self, app, mock_choices): + with app.test_request_context(): + form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]}) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + assert not form.validate() + assert "Please provide valid email" in form.uuid.errors + + +class TestRTBHForm: + @pytest.fixture + def mock_community_choices(self): + return [(1, "Community1"), (2, "Community2")] + + @pytest.fixture + def valid_rtbh_data(self, valid_datetime): + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "community": "1", + "expires": valid_datetime, + "comment": "Test RTBH rule", + } + + def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices): + with app.test_request_context(): + form_data = create_form_data(valid_rtbh_data) + form = RTBHForm(formdata=form_data) + form.net_ranges = sample_network_ranges + form.community.choices = mock_community_choices + assert form.validate() + + +class TestIPv4Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + +class TestIPv6Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return { + "source": "2001:db8::", # Network aligned address within allowed range + "source_mask": "32", # Matching the organization's prefix length + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation fails when address is outside allowed ranges""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db9::", # Different prefix + "source_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + assert any("must be in organization range" in error for error in form.source.errors) + + def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with destination address instead of source""" + with app.test_request_context(): + form_data = create_form_data( + { + "dest": "2001:db8::", + "dest_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with both source and destination addresses""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "dest": "2001:db8:1::", + "dest_mask": "48", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with TCP flags (should be valid with TCP)""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "flags": ["SYN", "ACK"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"]) + def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data): + """Test validation with various valid port formats""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "source_port": port_data, + "dest_port": port_data, + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + +class TestBulkUserForm: + @pytest.fixture + def valid_csv_data(self): + return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1" + + def test_valid_bulk_import(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} # Mock available roles + form.organizations = {1} # Mock available organizations + form.uuids = set() # Mock existing UUIDs (empty set) + assert form.validate() + + def test_duplicate_uuid(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = {"user1@example.com"} # UUID already exists + assert not form.validate() + assert any("already exists" in error for error in form.users.errors) + + def test_invalid_role(self, app): + csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID + with app.test_request_context(): + form_data = create_form_data({"users": csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = set() + assert not form.validate() + assert any("does not exist" in error for error in form.users.errors) + + +class TestApiKeyForm: + @pytest.fixture + def valid_api_key_data(self, valid_datetime): + return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"} + + def test_valid_api_key(self, app, valid_api_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_api_key_data) + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + def test_invalid_ip(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime}) + form = ApiKeyForm(formdata=form_data) + assert not form.validate() + assert "provide valid IP address" in form.machine.errors + + def test_unlimited_expiration(self, app): + with app.test_request_context(): + form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestMachineApiKeyForm: + # Similar to ApiKeyForm, but might have different validation rules + @pytest.fixture + def valid_machine_key_data(self, valid_datetime): + return { + "machine": "192.168.1.1", + "comment": "Test machine API key", + "expires": valid_datetime, + "readonly": "true", + "user": 1, + } + + def test_valid_machine_key(self, app, valid_machine_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_machine_key_data) + form = MachineApiKeyForm(formdata=form_data) + form.user.choices = [(1, "g.name"), (2, "test")] + assert form.validate() + + +class TestOrganizationForm: + @pytest.fixture + def valid_org_data(self): + return { + "name": "Test Organization", + "limit_flowspec4": "100", + "limit_flowspec6": "100", + "limit_rtbh": "50", + "arange": "192.168.0.0/16\n2001:db8::/32", + } + + def test_valid_organization(self, app, valid_org_data): + with app.test_request_context(): + form_data = create_form_data(valid_org_data) + form = OrganizationForm(formdata=form_data) + assert form.validate() + + def test_invalid_ranges(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"}) + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + def test_invalid_limits(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + +class TestActionForm: + @pytest.fixture + def valid_action_data(self): + return { + "name": "test_action", + "command": "announce route", + "description": "Test action description", + "role_id": "2", + } + + def test_valid_action(self, app, valid_action_data): + with app.test_request_context(): + form_data = create_form_data(valid_action_data) + form = ActionForm(formdata=form_data) + assert form.validate() + + def test_invalid_role(self, app, valid_action_data): + with app.test_request_context(): + invalid_data = dict(valid_action_data) + invalid_data["role_id"] = "4" # Invalid role + form_data = create_form_data(invalid_data) + form = ActionForm(formdata=form_data) + assert not form.validate() + + +class TestASPathForm: + @pytest.fixture + def valid_aspath_data(self): + return {"prefix": "AS64512", "as_path": "64512 64513 64514"} + + def test_valid_aspath(self, app, valid_aspath_data): + with app.test_request_context(): + form_data = create_form_data(valid_aspath_data) + form = ASPathForm(formdata=form_data) + assert form.validate() + + def test_empty_fields(self, app): + with app.test_request_context(): + form_data = create_form_data({}) + form = ASPathForm(formdata=form_data) + assert not form.validate() + + +class TestCommunityForm: + @pytest.fixture + def valid_community_data(self): + return { + "name": "test_community", + "comm": "64512:100", + "description": "Test community", + "role_id": "2", + "as_path": "true", + } + + def test_valid_community(self, app, valid_community_data): + with app.test_request_context(): + form_data = create_form_data(valid_community_data) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_missing_all_community_types(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert not form.validate() + assert any("could not be empty" in error for error in form.comm.errors) + + def test_valid_with_extended_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_valid_with_large_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + +class TestWhitelistForm: + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime} + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime} + + def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert any("valid IP address" in error for error in form.ip.errors) + + def test_ip_outside_range(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16 + assert not form.validate() + assert any("must be in organization range" in error for error in form.ip.errors) + + def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_missing_required_fields(self, app, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert form.ip.errors # Should have error for missing IP + assert form.mask.errors # Should have error for missing mask + assert form.expires.errors # Should have error for missing expiration + + def test_comment_length(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + { + "ip": "192.168.0.0", + "mask": "24", + "expires": valid_datetime, + "comment": "x" * 256, # Comment longer than 255 chars + } + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + @pytest.mark.parametrize( + "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date + ) + def test_invalid_expiration(self, app, expires, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + if not form.validate(): + print("Validation errors:", form.errors) + + assert not form.validate() + + def test_network_alignment(self, app, valid_datetime, sample_network_ranges): + """Test that IP addresses must be properly network-aligned""" + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py index 89599453..b352a5f2 100644 --- a/flowapp/tests/test_models.py +++ b/flowapp/tests/test_models.py @@ -1,4 +1,16 @@ -from datetime import datetime +from datetime import datetime, timedelta +from flowapp.models import ( + User, + Organization, + Role, + ApiKey, + MachineApiKey, + Rstate, + Community, + Action, + Flowspec6, + Whitelist, +) import flowapp.models as models @@ -232,3 +244,255 @@ def test_rtbj_eq(db): ) assert model_A == model_B + + +def test_user_creation(db): + """Test basic user creation and relationships""" + # Create test role and org first + role = Role(name="test_role", description="Test Role") + org = Organization(name="test_org", arange="10.0.0.0/8") + db.session.add_all([role, org]) + db.session.commit() + + # Create user with relationships + user = User( + uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment" + ) + user.role.append(role) + user.organization.append(org) + db.session.add(user) + db.session.commit() + + # Verify user and relationships + assert user.uuid == "test-user-123" + assert user.name == "Test User" + assert len(user.role.all()) == 1 + assert len(user.organization.all()) == 1 + assert user.role.first().name == "test_role" + assert user.organization.first().name == "test_org" + + +def test_api_key_expiration(db): + """Test ApiKey expiration logic""" + user = User(uuid="test-user") + org = Organization(name="test-org", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = ApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = ApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + # Create future key + future_key = ApiKey( + machine="test-machine-3", + key="key3", + readonly=True, + expires=datetime.now() + timedelta(days=1), + comment="Future key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key, future_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + assert not future_key.is_expired() + + +def test_machine_api_key_expiration(db): + """Test MachineApiKey expiration logic""" + user = User(uuid="test-user-machine") + org = Organization(name="test-org-machine", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = MachineApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = MachineApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + + +def test_organization_get_users(db): + """Test Organization's get_users method""" + org = Organization( + name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100 + ) + uuid1 = "test-org-get-user" + uuid2 = "test-org-get-user2" + user1 = User(uuid=uuid1) + user2 = User(uuid=uuid2) + + db.session.add(org) + db.session.add_all([user1, user2]) + db.session.commit() + + org.user.append(user1) + org.user.append(user2) + db.session.commit() + + users = org.get_users() + assert len(users) == 2 + assert all(isinstance(user, User) for user in users) + assert {user.uuid for user in users} == {uuid1, uuid2} + + +def test_flowspec6_equality(db): + """Test Flowspec6 equality comparison""" + model_a = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + # Same network parameters but different timestamps + model_b = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + action_id=1, + ) + + # Different network parameters + model_c = Flowspec6( + source="2001:db8::3", + source_mask=128, + source_port="80", + destination="2001:db8::4", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different network parameters + + +def test_whitelist_equality(db): + """Test Whitelist equality comparison""" + model_a = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Same IP/mask but different timestamps + model_b = Whitelist( + ip="192.168.1.1", + mask=32, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + comment="Different comment", + ) + + # Different IP + model_c = Whitelist( + ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different IP + + +def test_whitelist_to_dict(db): + """Test Whitelist to_dict serialization""" + whitelist = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Create required related objects + user = User(uuid="test-user-whitelist") + rstate = Rstate(description="active") + db.session.add_all([user, rstate]) + db.session.commit() + + db.session.add(whitelist) + db.session.commit() + + whitelist.user = user + whitelist.rstate_id = rstate.id + db.session.add(whitelist) + db.session.commit() + + # Test timestamp format + dict_timestamp = whitelist.to_dict(prefered_format="timestamp") + assert isinstance(dict_timestamp["expires"], int) + assert isinstance(dict_timestamp["created"], int) + + # Test yearfirst format + dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst") + assert isinstance(dict_yearfirst["expires"], str) + assert isinstance(dict_yearfirst["created"], str) + + # Check basic fields + assert dict_timestamp["ip"] == "192.168.1.1" + assert dict_timestamp["mask"] == 32 + assert dict_timestamp["comment"] == "Test whitelist" + assert dict_timestamp["user"] == "test-user-whitelist" diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py new file mode 100644 index 00000000..490d2c2f --- /dev/null +++ b/flowapp/tests/test_rule_service.py @@ -0,0 +1,628 @@ +""" +Tests for rule_service.py module. + +This test suite verifies the functionality of the rule service after refactoring, +which manages creation, updating, and processing of flow rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleOrigin, ANNOUNCE +from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist +from flowapp.services import rule_service +from flowapp.services.whitelist_common import Relation + + +@pytest.fixture +def ipv4_form_data(): + """Sample valid IPv4 form data""" + return { + "source": "192.168.1.0", + "source_mask": 24, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "protocol": "tcp", + "flags": ["SYN"], + "packet_len": "", + "fragment": ["dont-fragment"], + "comment": "Test IPv4 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def ipv6_form_data(): + """Sample valid IPv6 form data""" + return { + "source": "2001:db8::", + "source_mask": 32, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "next_header": "tcp", + "flags": ["SYN"], + "packet_len": "", + "comment": "Test IPv6 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def rtbh_form_data(): + """Sample valid RTBH form data""" + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "ipv6": "", + "ipv6_mask": None, + "community": 1, + "comment": "Test RTBH rule", + "expires": datetime.now() + timedelta(hours=1), + } + + +@pytest.fixture +def whitelist_fixture(): + """Create a whitelist fixture""" + whitelist = MagicMock(spec=Whitelist) + whitelist.id = 1 + return whitelist + + +class TestCreateOrUpdateIPv4Rule: + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test creating a new IPv4 rule""" + # Mock the get_ipv4_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv4.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv4_form_data["source"] + assert model.source_mask == ipv4_form_data["source_mask"] + assert model.protocol == ipv4_form_data["protocol"] + assert model.flags == "SYN" # form data flags are joined with ";" + assert model.action_id == ipv4_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message is still a string for IPv4 rules + assert message == "IPv4 Rule saved" + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test updating an existing IPv4 rule""" + # Create an existing model to return + existing_model = Flowspec4( + source=ipv4_form_data["source"], + source_mask=ipv4_form_data["source_mask"], + source_port=ipv4_form_data["source_port"], + destination=ipv4_form_data["dest"] or "", + destination_mask=ipv4_form_data["dest_mask"], + destination_port=ipv4_form_data["dest_port"] or "", + protocol=ipv4_form_data["protocol"], + flags=";".join(ipv4_form_data["flags"]), + packet_len=ipv4_form_data["packet_len"] or "", + fragment=";".join(ipv4_form_data["fragment"]), + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv4.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv4_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv4 rules + assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateIPv6Rule: + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test creating a new IPv6 rule""" + # Mock get_ipv6_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv6.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv6_form_data["source"] + assert model.source_mask == ipv6_form_data["source_mask"] + assert model.next_header == ipv6_form_data["next_header"] + assert model.flags == "SYN" + assert model.action_id == ipv6_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message is still a string for IPv6 rules + assert message == "IPv6 Rule saved" + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateRTBHRule: + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test creating a new RTBH rule""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list (no matches) + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ipv4 == rtbh_form_data["ipv4"] + assert model.ipv4_mask == rtbh_form_data["ipv4_mask"] + assert model.ipv6 == rtbh_form_data["ipv6"] + assert model.ipv6_mask == rtbh_form_data["ipv6_mask"] + assert model.community_id == rtbh_form_data["community"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + + # Verify rule was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + # Verify evaluate function was called + mock_evaluate.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test updating an existing RTBH rule""" + # Create an existing model to return + existing_model = RTBH( + ipv4=rtbh_form_data["ipv4"], + ipv4_mask=rtbh_form_data["ipv4_mask"], + ipv6=rtbh_form_data["ipv6"] or "", + ipv6_mask=rtbh_form_data["ipv6_mask"], + community_id=rtbh_form_data["community"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + rtbh_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify flash messages + assert isinstance(flashes, list) + assert "Existing RTBH Rule found" in flashes[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.map_whitelists_to_strings") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_rtbh_rule_with_whitelists( + self, + mock_log, + mock_announce, + mock_evaluate, + mock_check, + mock_map, + mock_query, + mock_get_model, + app, + db, + rtbh_form_data, + whitelist_fixture, + ): + """Test creating a RTBH rule that interacts with whitelists""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Create a mock whitelist + mock_whitelist = whitelist_fixture + mock_whitelists = [mock_whitelist] + + # Setup mock query to return our whitelist + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_whitelists + mock_query.return_value = mock_query_result + + # Setup map function to return our whitelist in a dict + whitelist_key = "192.168.1.0/24" + mock_map.return_value = {whitelist_key: mock_whitelist} + + # Setup check function to return a relation + mock_rtbh = MagicMock(spec=RTBH) + mock_rtbh.__str__.return_value = "192.168.1.0/24" + + # Setup check result + rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)] + mock_check.return_value = rule_relation + + # Setup evaluate function to add a flash message + def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results): + flashes.append("Rule is equal to whitelist") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages show both rule creation and whitelist check + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + assert "Rule is equal to whitelist" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + + +class TestEvaluateRtbhAgainstWhitelistsCheckResults: + def test_equal_relation(self, app, whitelist_fixture): + """Test evaluating a rule with an EQUAL relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.0/24" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.EQUAL)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == model + + def test_subnet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUBNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.128/25" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUBNET)] + + # Call the function with mocked dependencies + with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch( + "flowapp.services.rule_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.rule_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + with app.app_context(): + _result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert flashes + + # Verify model was updated to whitelisted state + assert model.rstate_id == 4 + + def test_supernet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUPERNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.0.0/16" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUPERNET)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == model + + def test_no_relation(self, app): + """Test evaluating a rule with no relation to any whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + wl_cache = {} + results = [] + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify no changes to the model and no messages + assert result == model + assert not flashes + + +class TestMapWhitelistsToStrings: + def test_map_whitelists_to_strings(self): + """Test mapping whitelist objects to strings""" + # Create mock whitelists + whitelist1 = MagicMock(spec=Whitelist) + whitelist1.__str__.return_value = "192.168.1.0/24" + + whitelist2 = MagicMock(spec=Whitelist) + whitelist2.__str__.return_value = "10.0.0.0/8" + + whitelists = [whitelist1, whitelist2] + + # Call the function + result = rule_service.map_whitelists_to_strings(whitelists) + + # Verify the result + assert len(result) == 2 + assert "192.168.1.0/24" in result + assert "10.0.0.0/8" in result + assert result["192.168.1.0/24"] == whitelist1 + assert result["10.0.0.0/8"] == whitelist2 + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv6 rules + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() diff --git a/flowapp/tests/test_rule_service_reactivate_delete.py b/flowapp/tests/test_rule_service_reactivate_delete.py new file mode 100644 index 00000000..22a3f9f4 --- /dev/null +++ b/flowapp/tests/test_rule_service_reactivate_delete.py @@ -0,0 +1,527 @@ +"""Tests for rule_service module.""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp.constants import RuleTypes +from flowapp.models import ( + RTBH, + Flowspec4, + Flowspec6, + Whitelist, +) +from flowapp.output import RouteSources +from flowapp.services import rule_service + + +@pytest.fixture +def test_data(app, db): + """Fixture providing test data for rule service tests.""" + current_time = datetime.now() + + # Create a test Flowspec4 rule + ipv4_rule = Flowspec4( + source="192.168.1.1", + source_mask=32, + source_port="", + destination="192.168.2.1", + destination_mask=32, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=current_time + timedelta(hours=1), + comment="Test IPv4 rule", + action_id=1, # Using action ID 1 from test database + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(ipv4_rule) + + # Create a test Flowspec6 rule + ipv6_rule = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="", + destination="2001:db8:1::1", + destination_mask=128, + destination_port="", + next_header="tcp", + flags="", + packet_len="", + expires=current_time + timedelta(hours=1), + comment="Test IPv6 rule", + action_id=1, # Using action ID 1 from test database + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(ipv6_rule) + + # Create a test RTBH rule + rtbh_rule = RTBH( + ipv4="192.168.1.100", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=1, # Using community ID 1 from test database + expires=current_time + timedelta(hours=1), + comment="Test RTBH rule", + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(rtbh_rule) + + # Create a test Whitelist + whitelist = Whitelist( + ip="192.168.2.1", + mask=24, + expires=current_time + timedelta(days=7), + comment="Test whitelist", + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(whitelist) + + db.session.commit() + + # Return data that will be useful for tests + return { + "user_id": 1, + "org_id": 1, + "user_email": "test@example.com", + "org_name": "Test Org", + "ipv4_rule_id": ipv4_rule.id, + "ipv6_rule_id": ipv6_rule.id, + "rtbh_rule_id": rtbh_rule.id, + "whitelist_id": whitelist.id, + "current_time": current_time, + "future_time": current_time + timedelta(hours=1), + "past_time": current_time - timedelta(hours=1), + "comment": "Test comment", + } + + +class TestReactivateRule: + """Tests for the reactivate_rule function.""" + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_reactivate_rule_active( + self, mock_log_route, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db + ): + """Test reactivating a rule with future expiration (active state).""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = False + + # The rule will be active (state=1) because expiration is in the future + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) + + # Verify model was updated + assert model.expires == expires + assert model.comment == test_data["comment"] + assert model.rstate_id == 1 # Active state + + # Verify route announcement was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_route.assert_called_once() + + # Check returned values + assert messages + assert messages[0].startswith("Rule ") + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + def test_reactivate_rule_inactive( + self, mock_log_withdraw, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db + ): + """Test reactivating a rule with past expiration (inactive state).""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = False + + # The rule will be inactive (state=2) because expiration is in the past + expires = test_data["past_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Verify model was updated + assert model.expires == expires + assert model.comment == test_data["comment"] + assert model.rstate_id == 2 # Inactive state + + # Verify route withdrawal was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_withdraw.assert_called_once() + + # Check returned values + assert messages + assert messages[0].startswith("Rule ") + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + def test_reactivate_rule_global_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): + """Test reactivating a rule when global limit is reached.""" + # Setup mocks + mock_check_global_limit.return_value = True # Global limit reached + mock_check_rule_limit.return_value = False + + # The rule would be active, but global limit is reached + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + + # Check returned values + assert messages == ["global_limit_reached"] + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + def test_reactivate_rule_org_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): + """Test reactivating a rule when organization limit is reached.""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = True # Org limit reached + + # The rule would be active, but org limit is reached + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) + + # Check returned values + assert messages == ["limit_reached"] + + +class TestDeleteRule: + """Tests for the delete_rule function.""" + + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + def test_delete_rule_success(self, mock_log_withdraw, mock_announce_route, test_data, db): + """Test successful rule deletion.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["ipv4_rule_id"]], # Rule is in allowed list + ) + + # Verify route withdrawal was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_withdraw.assert_called_once() + + # Verify database operations - rule should be deleted + rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) + assert rule is None + + # Check returned values + assert success is True + assert message == "Rule deleted successfully" + + def test_delete_rule_not_allowed(self, test_data, db): + """Test rule deletion when rule is not in allowed list.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[999], # Rule is not in allowed list + ) + + # Verify rule still exists + rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) + assert rule is not None + + # Check returned values + assert success is False + assert message == "You cannot delete this rule" + + def test_delete_rule_not_found(self, test_data, db): + """Test rule deletion when rule is not found.""" + # Call the function with non-existent ID + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=9999, # Non-existent ID + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Check returned values + assert success is False + assert message == "Rule not found" + + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + @patch("flowapp.services.rule_service.RuleWhitelistCache") + def test_delete_rtbh_rule(self, mock_whitelist_cache, mock_log_withdraw, mock_announce_route, test_data, db): + """Test deleting an RTBH rule.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify whitelist cache was cleaned + mock_whitelist_cache.delete_by_rule_id.assert_called_once_with(test_data["rtbh_rule_id"]) + + # Verify route withdrawal and logging + mock_announce_route.assert_called_once() + mock_log_withdraw.assert_called_once() + + # Verify rule was deleted + rule = db.session.get(RTBH, test_data["rtbh_rule_id"]) + assert rule is None + + # Check returned values + assert success is True + assert message == "Rule deleted successfully" + + +class TestDeleteRtbhAndCreateWhitelist: + """Tests for the delete_rtbh_and_create_whitelist function.""" + + @patch("flowapp.services.rule_service.delete_rule") + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_delete_rtbh_and_create_whitelist_success(self, mock_create_whitelist, mock_delete_rule, test_data, db): + """Test successful RTBH deletion and whitelist creation.""" + # Setup mock for delete_rule service + mock_delete_rule.return_value = (True, "Rule deleted successfully") + + # Setup mock for create_or_update_whitelist service + mock_whitelist = Whitelist( + ip="192.168.1.100", + mask=32, + expires=datetime.now() + timedelta(days=7), + user_id=test_data["user_id"], + org_id=test_data["org_id"], + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify delete_rule was called + mock_delete_rule.assert_called_once_with( + rule_type=RuleTypes.RTBH, + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify create_or_update_whitelist was called with correct data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "192.168.1.100" + assert form_data["mask"] == 32 + assert "Created from RTBH rule" in form_data["comment"] + + # Check returned values + assert success is True + assert len(messages) == 2 + assert messages[0] == "Rule deleted successfully" + assert messages[1] == "Whitelist created" + assert whitelist == mock_whitelist + + def test_delete_rtbh_and_create_whitelist_rule_not_found(self, test_data, db): + """Test when the RTBH rule to convert is not found.""" + # Call the function with non-existent ID + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=9999, # Non-existent ID + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Check returned values + assert success is False + assert messages == ["RTBH rule not found"] + assert whitelist is None + + def test_delete_rtbh_and_create_whitelist_not_allowed(self, test_data, db): + """Test when the user is not allowed to delete the rule.""" + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[999], # Rule not in allowed list + ) + + # Check returned values + assert success is False + assert messages == ["You cannot delete this rule"] + assert whitelist is None + + @patch("flowapp.services.rule_service.delete_rule") + def test_delete_rtbh_and_create_whitelist_delete_fails(self, mock_delete_rule, test_data, db): + """Test when the RTBH deletion fails.""" + # Setup mock for delete_rule service to fail + mock_delete_rule.return_value = (False, "Error deleting rule") + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify delete_rule was called + mock_delete_rule.assert_called_once() + + # Check returned values + assert success is False + assert messages == ["Error deleting rule"] + assert whitelist is None + + @patch("flowapp.services.rule_service.delete_rule") + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_rtbh_with_ipv6(self, mock_create_whitelist, mock_delete_rule, test_data, db): + """Test with an IPv6 RTBH rule.""" + # Create an IPv6 RTBH rule + ipv6_rtbh = RTBH( + ipv4=None, + ipv4_mask=None, + ipv6="2001:db8::1", + ipv6_mask=64, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + comment="IPv6 RTBH rule", + user_id=test_data["user_id"], + org_id=test_data["org_id"], + rstate_id=1, + ) + db.session.add(ipv6_rtbh) + db.session.commit() + + # Setup mocks + mock_delete_rule.return_value = (True, "Rule deleted successfully") + mock_whitelist = Whitelist( + ip="2001:db8::1", + mask=64, + expires=datetime.now() + timedelta(days=7), + user_id=test_data["user_id"], + org_id=test_data["org_id"], + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=ipv6_rtbh.id, + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[ipv6_rtbh.id], + ) + + # Verify create_or_update_whitelist was called with correct IPv6 data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "2001:db8::1" + assert form_data["mask"] == 64 + + # Check returned values + assert success is True + assert len(messages) == 2 + assert whitelist == mock_whitelist diff --git a/flowapp/tests/test_validators.py b/flowapp/tests/test_validators.py index 33e757cf..cab089f0 100644 --- a/flowapp/tests/test_validators.py +++ b/flowapp/tests/test_validators.py @@ -1,11 +1,28 @@ import pytest -import flowapp.validators +from flowapp.validators import ( + PortString, + PacketString, + NetRangeString, + NetInRange, + IPAddress, + IPAddressValidator, + NetworkValidator, + ValidationError, + address_in_range, + address_with_mask, + IPv4Address, + IPv6Address, + DateNotExpired, + editable_range, + network_in_range, + range_in_network, +) def test_port_string_len_raises(field): - port = flowapp.validators.PortString() + port = PortString() field.data = "1;2;3;4;5;6;7;8" - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): port(None, field) @@ -20,12 +37,12 @@ def test_port_string_len_raises(field): ], ) def test_is_valid_address_with_mask(address, mask, expected): - assert flowapp.validators.address_with_mask(address, mask) == expected + assert address_with_mask(address, mask) == expected @pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"]) def test_ip4address_passes(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address adr(None, field) @@ -38,7 +55,7 @@ def test_ip4address_passes(field, address): ], ) def test_ip6address_passes(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address adr(None, field) @@ -51,19 +68,17 @@ def test_ip6address_passes(field, address): ], ) def test_bad_ip6address_raises(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) -@pytest.mark.parametrize( - "expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"] -) +@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]) def test_expired_date_raises(field, expired): - adr = flowapp.validators.DateNotExpired() + adr = DateNotExpired() field.data = expired - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -75,9 +90,9 @@ def test_expired_date_raises(field, expired): ], ) def test_ipaddress_raises(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -91,7 +106,7 @@ def test_ipaddress_raises(field, address): def test_editable_rule(rule, address, mask, ranges, expected): rule.source = address rule.source_mask = mask - assert flowapp.validators.editable_range(rule, ranges) == expected + assert editable_range(rule, ranges) == expected @pytest.mark.parametrize( @@ -104,7 +119,7 @@ def test_editable_rule(rule, address, mask, ranges, expected): ], ) def test_address_in_range(address, mask, ranges, expected): - assert flowapp.validators.address_in_range(address, ranges) == expected + assert address_in_range(address, ranges) == expected @pytest.mark.parametrize( @@ -123,7 +138,7 @@ def test_address_in_range(address, mask, ranges, expected): ], ) def test_network_in_range(address, mask, ranges, expected): - assert flowapp.validators.network_in_range(address, mask, ranges) == expected + assert network_in_range(address, mask, ranges) == expected @pytest.mark.parametrize( @@ -133,4 +148,209 @@ def test_network_in_range(address, mask, ranges, expected): ], ) def test_range_in_network(address, mask, ranges, expected): - assert flowapp.validators.range_in_network(address, mask, ranges) == expected + assert range_in_network(address, mask, ranges) == expected + + +### new tests + + +# New tests for previously uncovered validators + + +# Tests for PortString validator syntax +@pytest.mark.parametrize( + "port_data", + [ + "80", # Simple number + "80-443", # Range with hyphen + ">=80&<=443", # Explicit range + ">80", # Greater than + ">=80", # Greater than or equal + "<443", # Less than + "<=443", # Less than or equal + "80;443;8080", # Multiple port expressions + ], +) +def test_port_string_valid_syntax(field, port_data): + validator = PortString() + field.data = port_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_port_data", + [ + "80,443", # Comma not supported, should use semicolon + "80..443", # Invalid range syntax (should be 80-443) + ">65536", # Port number too high + "-1", # Negative number + "<-1", # Port number too low + ">=80<=443", # Invalid range syntax (missing &) + "!=80", # Not equals not supported + "abc", # Non-numeric + "80-", # Incomplete range + "-80", # Invalid range + "80&&443", # Invalid operator + "443-80", # End less than start + ">=443&<=80", # End less than start in explicit range + "0-65537", # Range exceeding max + ], +) +def test_port_string_invalid_syntax(field, invalid_port_data): + validator = PortString() + field.data = invalid_port_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for PacketString validator +@pytest.mark.parametrize( + "packet_data", + ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions +) +def test_packet_string_valid(field, packet_data): + validator = PacketString() + field.data = packet_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_packet_data", + [ + "65536", # Too large + "-1", # Negative + "1500..", # Invalid range syntax + "abc", # Non-numeric + "!!1500", # Invalid operator + "1500-", # Incomplete range + "9000-1500", # End less than start + ], +) +def test_packet_string_invalid(field, invalid_packet_data): + validator = PacketString() + field.data = invalid_packet_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetRangeString validator +@pytest.mark.parametrize( + "net_range", + [ + "192.168.0.0/24", + "10.0.0.0/8\n172.16.0.0/12", + "2001:db8::/32", + "192.168.0.0/24 10.0.0.0/8", + "2001:db8::/32 2001:db8:1::/48", + ], +) +def test_net_range_string_valid(field, net_range): + validator = NetRangeString() + field.data = net_range + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_net_range", + [ + "192.168.0.0/33", # Invalid mask + "256.256.256.0/24", # Invalid IP + "2001:xyz::/32", # Invalid IPv6 + "192.168.0.0/24/", # Malformed + "not-an-ip/24", # Invalid format + ], +) +def test_net_range_string_invalid(field, invalid_net_range): + validator = NetRangeString() + field.data = invalid_net_range + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetInRange validator +def test_net_in_range_valid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "192.168.1.1/24" + validator(None, field) # Should not raise + + +def test_net_in_range_invalid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "172.16.1.1/24" + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for base IPAddress validator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"]) +def test_ip_address_valid(field, ip_addr): + validator = IPAddress() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"]) +def test_ip_address_invalid(field, invalid_ip): + validator = IPAddress() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for universal IPAddressValidator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed +def test_ip_address_validator_valid(field, ip_addr): + validator = IPAddressValidator() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"]) +def test_ip_address_validator_invalid(field, invalid_ip): + validator = IPAddressValidator() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetworkValidator +class MockForm: + def __init__(self, address, mask): + self._fields = {"mask": type("MockField", (), {"data": mask})()} + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.0.0", "24"), + ("10.0.0.0", "8"), + ("2001:db8::", "32"), + ("", None), # Empty values should pass + (None, None), # None values should pass + ], +) +def test_network_validator_valid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + validator(form, field) # Should not raise + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.1.1", "24"), # Not a network address + ("192.168.0.0", "33"), # Invalid IPv4 mask + ("2001:db8::", "129"), # Invalid IPv6 mask + ("256.256.256.0", "24"), # Invalid IP + ("2001:xyz::", "32"), # Invalid IPv6 + ], +) +def test_network_validator_invalid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + with pytest.raises(ValidationError): + validator(form, field) diff --git a/flowapp/tests/test_whitelist_common.py b/flowapp/tests/test_whitelist_common.py new file mode 100644 index 00000000..a843aff9 --- /dev/null +++ b/flowapp/tests/test_whitelist_common.py @@ -0,0 +1,250 @@ +import pytest +from flowapp.services.whitelist_common import ( + Relation, + check_rule_against_whitelists, + check_whitelist_against_rules, + check_whitelist_to_rule_relation, + subtract_network, + clear_network_cache, + _is_same_ip_version, # New helper function +) + + +# Tests for core function that checks network relations +@pytest.mark.parametrize( + "rule,whitelist,expected", + [ + # IPv4 test cases + ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet + ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks + ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet + ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet + ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks + # IPv6 test cases + ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks + ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet + ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks + ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ], +) +def test_check_whitelist_to_rule_relation(rule, whitelist, expected): + clear_network_cache() + assert check_whitelist_to_rule_relation(rule, whitelist) == expected + + +@pytest.mark.parametrize( + "target,whitelist,expected", + [ + # Basic IPv4 cases + ( + "192.168.1.0/24", + "192.168.1.128/25", + ["192.168.1.0/25"], + ), # One remaining subnet + ( + "192.168.1.0/24", + "192.168.1.64/26", + ["192.168.1.0/26", "192.168.1.128/25"], + ), # Two remaining subnets + ( + "192.168.1.0/24", + "192.168.1.0/24", + ["192.168.1.0/24"], + ), # Equal networks - return original network + ( + "192.168.1.0/24", + "192.168.2.0/24", + ["192.168.1.0/24"], + ), # No overlap + ], +) +def test_subtract_network(target, whitelist, expected): + clear_network_cache() + result = subtract_network(target, whitelist) + assert sorted(result) == sorted(expected) + + +def test_check_rule_against_whitelists(): + clear_network_cache() + + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # SUPERNET + "172.16.0.0/12", # DIFFERENT + "192.168.1.0/24", # EQUAL + "192.168.1.128/25", # SUBNET + ] + + results = check_rule_against_whitelists(rule, whitelists) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + (rule, "192.168.0.0/16", Relation.SUPERNET), + (rule, "192.168.1.0/24", Relation.EQUAL), + (rule, "192.168.1.128/25", Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +def test_check_whitelist_against_rules(): + clear_network_cache() + + whitelist = "192.168.1.128/25" + rules = [ + "192.168.0.0/16", # Whitelist is SUBNET of this larger network + "172.16.0.0/12", # DIFFERENT + "192.168.1.128/25", # EQUAL + "192.168.1.0/24", # Whitelist is SUBNET of this network + ] + + results = check_whitelist_against_rules(rules, whitelist) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + ("192.168.0.0/16", whitelist, Relation.SUBNET), + ("192.168.1.128/25", whitelist, Relation.EQUAL), + ("192.168.1.0/24", whitelist, Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +# Tests for edge cases and error handling +def test_host_addresses(): + clear_network_cache() + # Test with host addresses (no explicit subnet mask) + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT + + +def test_single_ip_as_network(): + clear_network_cache() + # Test with /32 (single IPv4 address as network) + assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET + # Test with /128 (single IPv6 address as network) + assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET + + +def test_is_same_ip_version(): + """Test the new helper function to check if two addresses are of the same IP version""" + # IPv4 cases + assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True + assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True + + # IPv6 cases + assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True + assert _is_same_ip_version("2001:db8::1", "fe80::1") is True + + # Mixed cases + assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False + assert _is_same_ip_version("192.168.1.1", "fe80::1") is False + + +def test_check_whitelist_to_rule_relation_mixed_versions(): + """Test the check_whitelist_to_rule_relation function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule with IPv6 whitelist + assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT + + # IPv6 rule with IPv4 whitelist + assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT + + # Invalid IP addresses should return DIFFERENT + assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT + assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT + + +def test_subtract_network_mixed_versions(): + """Test the subtract_network function with mixed IP versions""" + clear_network_cache() + + # IPv4 target with IPv6 whitelist - should return original target + result = subtract_network("192.168.1.0/24", "2001:db8::/32") + assert result == ["192.168.1.0/24"] + + # IPv6 target with IPv4 whitelist - should return original target + result = subtract_network("2001:db8::/32", "192.168.1.0/24") + assert result == ["2001:db8::/32"] + + # Invalid addresses - should return original target + result = subtract_network("invalid", "192.168.1.0/24") + assert result == ["invalid"] + result = subtract_network("192.168.1.0/24", "invalid") + assert result == ["192.168.1.0/24"] + + +def test_check_rule_against_whitelists_mixed_versions(): + """Test the check_rule_against_whitelists function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule against mixed whitelist + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "192.168.0.0/16" + assert results[0][2] == Relation.SUPERNET + + # IPv6 rule against mixed whitelist + rule = "2001:db8:1::/48" + whitelists = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "2001:db8::/32" + assert results[0][2] == Relation.SUPERNET + + +def test_check_whitelist_against_rules_mixed_versions(): + """Test the check_whitelist_against_rules function with mixed IP versions""" + clear_network_cache() + + # IPv4 whitelist against mixed rules + whitelist = "192.168.1.0/24" + rules = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "192.168.0.0/16" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET + + # IPv6 whitelist against mixed rules + whitelist = "2001:db8:1::/48" + rules = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "2001:db8::/32" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py new file mode 100644 index 00000000..460a172b --- /dev/null +++ b/flowapp/tests/test_whitelist_service.py @@ -0,0 +1,461 @@ +""" +Tests for whitelist_service.py module. + +This test suite verifies the functionality of the whitelist service, +which manages creation, updating, and handling of whitelist rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import Whitelist, RuleWhitelistCache, RTBH +from flowapp.services import whitelist_service +from flowapp.services.whitelist_common import Relation + + +@pytest.fixture +def whitelist_form_data(): + """Sample valid whitelist form data""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist entry", + "expires": datetime.now() + timedelta(hours=1), + } + + +class TestCreateOrUpdateWhitelist: + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_new_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a new whitelist entry""" + # Mock the get_whitelist_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ip == whitelist_form_data["ip"] + assert model.mask == whitelist_form_data["mask"] + assert model.comment == whitelist_form_data["comment"] + assert model.user_id == 1 + assert model.org_id == 1 + assert model.rstate_id == 1 # Active state + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Whitelist saved" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_update_existing_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): + """Test updating an existing whitelist entry""" + # Create an existing whitelist + existing_model = Whitelist( + ip=whitelist_form_data["ip"], + mask=whitelist_form_data["mask"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Mock to return the existing model + mock_get_model.return_value = existing_model + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + # Check that expiration date was updated (rounded to 10 minutes) + # We can't compare exact timestamps, so check date parts + assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Existing Whitelist found" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.db.session.query") + @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_whitelist_with_matching_rules( + self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a whitelist that affects existing rules""" + # Mock get_whitelist_model_if_exists to return False (new whitelist) + mock_get_model.return_value = False + + # Create a mock RTBH rule + mock_rtbh_rule = MagicMock(spec=RTBH) + mock_rtbh_rule.__str__.return_value = "192.168.1.0/24" + mock_rtbh_rules = [mock_rtbh_rule] + + # Setup mock query to return our rule + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules + mock_query.return_value = mock_query_result + + # Setup map function to return our rule in a dict + mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule} + + # Setup check function to return a relation + rule_relation = [ + ( + str(mock_rtbh_rule), + str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]), + Relation.EQUAL, + ) + ] + mock_check.return_value = rule_relation + + # Setup evaluate function to modify flashes + def evaluate_side_effect(model, flashes, rtbh_rule_cache, results): + flashes.append("Rule was whitelisted") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages includes both whitelist creation and rule whitelisting + assert "Whitelist saved" in flashes[0] + assert "Rule was whitelisted" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + + +class TestDeleteWhitelist: + @patch("flowapp.services.whitelist_service.announce_rtbh_route") + def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): + """Test deleting a whitelist that has user-created rules attached to it""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by user, whitelisted) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=4, # Whitelisted state + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.USER, + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set whitelist ID now that it has been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Mock the get_by_whitelist_id to return our cache entry + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + flashes = whitelist_service.delete_whitelist(whitelist.id) + + # Verify the rule state was changed back to active + rtbh_rule = db.session.get(RTBH, rtbh_rule.id) + assert rtbh_rule.rstate_id == 1 # Active state + + # Verify announcement was made + mock_announce.assert_called_once() + + # Verify flash messages + assert isinstance(flashes, list) + assert flashes + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id") + def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db): + """Test deleting a whitelist that has rules created by the whitelist""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by whitelist) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set IDs now that they have been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Create a mock session that can get our rule + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + flashes = whitelist_service.delete_whitelist(whitelist.id) + + # Verify flash messages + assert isinstance(flashes, list) + assert flashes + + # Verify the rule was deleted + assert db.session.get(RTBH, rtbh_rule.id) is None + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + # Verify cache cleanup was called + mock_clean.assert_called_once_with(whitelist.id) + + def test_delete_nonexistent_whitelist(self, app, db): + """Test deleting a whitelist that doesn't exist""" + with app.app_context(): + # Call the service function with a non-existent ID + flashes = whitelist_service.delete_whitelist(999) + + # Should return empty list of flash messages, as no whitelist was found + assert isinstance(flashes, list) + assert len(flashes) == 0 + + +class TestEvaluateWhitelistAgainstRtbhResults: + def test_equal_relation(self, app): + """Test evaluating a whitelist with an EQUAL relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.0/24" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.EQUAL)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == whitelist_model + + def test_subnet_relation(self, app): + """Test evaluating a whitelist with a SUBNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.128/25" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUBNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch( + "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.whitelist_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + # Call the function + with app.app_context(): + _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("supernet of whitelist" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert rtbh_rule.rstate_id == 4 + + def test_supernet_relation(self, app): + """Test evaluating a whitelist with a SUPERNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.0.0/16" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUPERNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("subnet of whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == whitelist_model diff --git a/flowapp/utils/__init__.py b/flowapp/utils/__init__.py new file mode 100644 index 00000000..99ca325a --- /dev/null +++ b/flowapp/utils/__init__.py @@ -0,0 +1,46 @@ +from .base import ( + other_rtypes, + output_date_format, + parse_api_time, + quote_to_ent, + webpicker_to_datetime, + datetime_to_webpicker, + get_state_by_time, + round_to_ten_minutes, + flash_errors, + active_css_rstate, + get_comp_func, +) + +from .app_factory import ( + # configure_app, + configure_logging, + register_blueprints, + register_error_handlers, + register_context_processors, + register_template_filters, + register_auth_handlers, + # register_org_routes, +) + +__all__ = [ + "other_rtypes", + "output_date_format", + "parse_api_time", + "quote_to_ent", + "webpicker_to_datetime", + "datetime_to_webpicker", + "get_state_by_time", + "round_to_ten_minutes", + "flash_errors", + "active_css_rstate", + "get_comp_func", + # "configure_app", + "configure_logging", + "register_blueprints", + "register_error_handlers", + "register_context_processors", + "register_template_filters", + "register_auth_handlers", + # "register_org_routes", +] diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py new file mode 100644 index 00000000..234522fd --- /dev/null +++ b/flowapp/utils/app_factory.py @@ -0,0 +1,221 @@ +import logging +import babel +from flask import redirect, render_template, request, session, url_for + + +def register_blueprints(app, csrf=None): + """Register Flask blueprints.""" + from flowapp.views.admin import admin + from flowapp.views.rules import rules + from flowapp.views.api_v1 import api as api_v1 + from flowapp.views.api_v2 import api as api_v2 + from flowapp.views.api_v3 import api as api_v3 + from flowapp.views.api_keys import api_keys + from flowapp.views.dashboard import dashboard + from flowapp.views.whitelist import whitelist + + # Configure CSRF exemption for API routes + if csrf: + csrf.exempt(api_v1) + csrf.exempt(api_v2) + csrf.exempt(api_v3) + + # Register blueprints with URL prefixes + app.register_blueprint(admin, url_prefix="/admin") + app.register_blueprint(rules, url_prefix="/rules") + app.register_blueprint(api_keys, url_prefix="/api_keys") + app.register_blueprint(api_v1, url_prefix="/api/v1") + app.register_blueprint(api_v2, url_prefix="/api/v2") + app.register_blueprint(api_v3, url_prefix="/api/v3") + app.register_blueprint(dashboard, url_prefix="/dashboard") + app.register_blueprint(whitelist, url_prefix="/whitelist") + + return app + + +def configure_logging(app): + """Configure logging for the Flask application.""" + + # Remove all default handlers + for handler in app.logger.handlers[:]: + app.logger.removeHandler(handler) + + # Retrieve log level and file name from config + log_level = app.config.get("LOG_LEVEL", "DEBUG").upper() + log_file = app.config.get("LOG_FILE", "app.log") + + # Define log format + log_format = "%(asctime)s | %(levelname)s | %(message)s" + log_datefmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter(log_format, datefmt=log_datefmt) + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + + # Set logger level + app.logger.setLevel(getattr(logging, log_level, logging.DEBUG)) + + # Attach handlers + app.logger.addHandler(console_handler) + app.logger.addHandler(file_handler) + + return app + + +def register_error_handlers(app): + """Register error handlers.""" + + @app.errorhandler(404) + def not_found(error): + return render_template("errors/404.html"), 404 + + @app.errorhandler(500) + def internal_error(exception): + app.logger.exception(exception) + return render_template("errors/500.html"), 500 + + return app + + +def register_context_processors(app): + """Register template context processors.""" + + @app.context_processor + def utility_processor(): + def editable_rule(rule): + if rule: + from flowapp.validators import editable_range + from flowapp.models import get_user_nets + + editable_range(rule, get_user_nets(session["user_id"])) + return True + return False + + return dict(editable_rule=editable_rule) + + @app.context_processor + def inject_main_menu(): + """Inject main menu config to templates.""" + return {"main_menu": app.config.get("MAIN_MENU")} + + @app.context_processor + def inject_dashboard(): + """Inject dashboard config to templates.""" + return {"dashboard": app.config.get("DASHBOARD")} + + return app + + +def register_template_filters(app): + """Register custom template filters.""" + + @app.template_filter("strftime") + def format_datetime(value): + if value is None: + return app.config.get("MISSING_DATETIME_MESSAGE", "Never") + + format = "y/MM/dd HH:mm" + return babel.dates.format_datetime(value, format) + + @app.template_filter("unlimited") + def unlimited_filter(value): + return "unlimited" if value == 0 else value + + return app + + +def register_auth_handlers(app, ext): + """Register authentication handlers.""" + + @ext.login_handler + def login(user_info): + try: + uuid = user_info.get("eppn") + except KeyError: + uuid = False + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/logout") + def logout(): + session["user_uuid"] = False + session["user_id"] = False + session.clear() + return redirect(app.config.get("LOGOUT_URL")) + + @app.route("/ext-login") + def ext_login(): + header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") + if header_name not in request.headers: + return render_template("errors/401.html") + + uuid = request.headers.get(header_name) + if not uuid: + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/local-login") + def local_login(): + print("Local login started") + if not app.config.get("LOCAL_AUTH", False): + print("Local auth not enabled") + return render_template("errors/401.html") + + uuid = app.config.get("LOCAL_USER_UUID", False) + if not uuid: + print("Local user not set") + return render_template("errors/401.html") + + print(f"Local login with {uuid}") + return _handle_login(uuid, app) + + return app + + +def _handle_login(uuid, app): + """Handle login process for all authentication methods.""" + from flowapp import db + + multiple_orgs = False + try: + user, multiple_orgs = _register_user_to_session(uuid, db) + except AttributeError as e: + app.logger.exception(e) + return render_template("errors/401.html") + + if multiple_orgs: + return redirect(url_for("select_org", org_id=None)) + + # set user org to session + user_org = user.organization.first() + session["user_org"] = user_org.name + session["user_org_id"] = user_org.id + + return redirect("/") + + +def _register_user_to_session(uuid, db): + """Register user information to session.""" + from flowapp.models import User + + print(f"Registering user {uuid} to session") + user = db.session.query(User).filter_by(uuid=uuid).first() + print(f"Got user {user} from DB") + session["user_uuid"] = user.uuid + session["user_email"] = user.uuid + session["user_name"] = user.name + session["user_id"] = user.id + session["user_roles"] = [role.name for role in user.role.all()] + session["user_role_ids"] = [role.id for role in user.role.all()] + roles = [i > 1 for i in session["user_role_ids"]] + session["can_edit"] = True if all(roles) and roles else [] + # check if user has multiple organizations and return True if so + print(f"DEBUG SESSION {session}") + return user, len(user.organization.all()) > 1 diff --git a/flowapp/utils.py b/flowapp/utils/base.py similarity index 98% rename from flowapp/utils.py rename to flowapp/utils/base.py index ad15b615..512f7ccd 100644 --- a/flowapp/utils.py +++ b/flowapp/utils/base.py @@ -1,4 +1,3 @@ -from operator import ge, lt from datetime import datetime, timedelta from flask import flash from flowapp.constants import ( @@ -73,7 +72,7 @@ def parse_api_time(apitime): except ValueError: mytime = False - return False + return mytime def quote_to_ent(comment): diff --git a/flowapp/validators.py b/flowapp/validators.py index 63061062..890bdacb 100644 --- a/flowapp/validators.py +++ b/flowapp/validators.py @@ -253,7 +253,10 @@ def __call__(self, form, field): result = False for address in field.data.split("/"): for adr_range in self.net_ranges: - result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + try: + result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + except ValueError as e: + raise ValidationError(self.message + str(e.args[0])) if not result: raise ValidationError(self.message) @@ -352,3 +355,66 @@ def subnet_of(net_a, net_b): def supernet_of(net_a, net_b): """Return True if this network is a supernet of other.""" return _is_subnet_of(net_b, net_a) + + +class IPAddressValidator: + """ + Universal validator that accepts both IPv4 and IPv6 addresses. + """ + + def __init__(self, message=None): + self.message = message or "Invalid IP address: {}" + + def __call__(self, form, field): + if not field.data: + return + + try: + ipaddress.ip_address(field.data) + except ValueError: + raise ValidationError(self.message.format(field.data)) + + +class NetworkValidator: + """ + Validates that an IP address and mask form a valid network. + Works with both IPv4 and IPv6 addresses. + """ + + def __init__(self, mask_field_name, message=None): + self.mask_field_name = mask_field_name + self.message = message or "Invalid network: address {}, mask {}" + + def __call__(self, form, field): + if not field.data: + return + + mask_field = form._fields.get(self.mask_field_name) + if not mask_field or not mask_field.data: + return + + try: + # Determine IP version + ip = ipaddress.ip_address(field.data) + mask = int(mask_field.data) + + # Validate mask range based on IP version + if isinstance(ip, ipaddress.IPv4Address): + if not 0 <= mask <= 32: + raise ValidationError(f"Invalid IPv4 mask: {mask}") + else: # IPv6 + if not 0 <= mask <= 128: + raise ValidationError(f"Invalid IPv6 mask: {mask}") + + # Try to create network to validate the combination + network = ipaddress.ip_network(f"{field.data}/{mask}", strict=False) + + # Check if the original IP is the correct network address + if str(network.network_address) != str(ip): + raise ValidationError( + f"Invalid network address: {field.data}/{mask}. " + f"Network address should be: {network.network_address}" + ) + + except ValueError: + raise ValidationError(self.message.format(field.data, mask_field.data)) diff --git a/flowapp/views/admin.py b/flowapp/views/admin.py index d014c87c..7912e86a 100644 --- a/flowapp/views/admin.py +++ b/flowapp/views/admin.py @@ -73,15 +73,22 @@ def add_machine_key(): """ generated = secrets.token_hex(24) form = MachineApiKeyForm(request.form, key=generated) + form.user.choices = [(g.id, g.name) for g in db.session.query(User).order_by("name")] if request.method == "POST" and form.validate(): + target_user = db.session.get(User, form.user.data) + target_org = target_user.organization.first() if target_user else None + current_user = session.get("user_name") + curent_email = session.get("user_uuid") + comment = f"created by: {current_user}/{curent_email}, comment: {form.comment.data}" model = MachineApiKey( machine=form.machine.data, key=form.key.data, expires=form.expires.data, readonly=form.readonly.data, - comment=form.comment.data, - user_id=session["user_id"], + comment=comment, + user_id=target_user.id, + org_id=target_org.id, ) db.session.add(model) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 6ce24bf6..3951b093 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -5,7 +5,7 @@ from functools import wraps from datetime import datetime, timedelta -from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, ANNOUNCE, TIME_FORMAT_ARG, RuleTypes +from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, TIME_FORMAT_ARG, RuleTypes from flowapp.models import ( RTBH, Flowspec4, @@ -18,20 +18,16 @@ check_rule_limit, get_user_nets, get_user_actions, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, insert_initial_communities, get_user_communities, - get_rtbh_model_if_exists, ) from flowapp.forms import IPv4Form, IPv6Form, RTBHForm +from flowapp.services import rule_service from flowapp.utils import ( - quote_to_ent, - get_state_by_time, output_date_format, ) from flowapp.auth import check_access_rights -from flowapp.output import announce_route, log_route, log_withdraw, Route, RouteSources +from flowapp.output import announce_route, log_withdraw, Route, RouteSources from flowapp import db, validators, flowspec, messages @@ -67,7 +63,6 @@ def authorize(user_key): :return: page with token """ jwt_key = current_app.config.get("JWT_SECRET") - # try normal user key first model = db.session.query(ApiKey).filter_by(key=user_key).first() # if not found try machine key @@ -100,7 +95,7 @@ def authorize(user_key): return jsonify({"token": encoded}) else: - return jsonify({"message": "auth token is invalid"}), 403 + return jsonify({"message": f"auth token is not valid from machine {request.remote_addr}"}), 403 def check_readonly(func): @@ -191,7 +186,7 @@ def all_communities(current_user): def limit_reached(count, rule_type, org_id): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] org = db.session.get(Organization, org_id) if rule_type == RuleTypes.IPv4: limit = org.limit_flowspec4 @@ -207,7 +202,7 @@ def limit_reached(count, rule_type, org_id): def global_limit_reached(count, rule_type): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] if rule_type == RuleTypes.IPv4 or rule_type == RuleTypes.IPv6: limit = current_app.config.get("FLOWSPEC_MAX_RULES") elif rule_type == RuleTypes.RTBH: @@ -249,52 +244,15 @@ def create_ipv4(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv4, - f"{current_user['uuid']} / {current_user['org']}", + # Use the service to create/update the rule + model, flash_message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) + pref_format = output_date_format(json_request_data, form.expires.pref_format) response = {"message": flash_message, "rule": model.to_dict(pref_format)} return jsonify(response), 201 @@ -326,50 +284,12 @@ def create_ipv6(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv6, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -384,7 +304,6 @@ def create_rtbh(current_user): count = db.session.query(RTBH).filter_by(rstate_id=1).count() return global_limit_reached(count=count, rule_type=RuleTypes.RTBH) - # check limit if check_rule_limit(current_user["org_id"], RuleTypes.RTBH): count = db.session.query(RTBH).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() return limit_reached(count=count, rule_type=RuleTypes.RTBH, org_id=current_user["org_id"]) @@ -406,43 +325,12 @@ def create_rtbh(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - # log changes - log_route( - current_user["id"], - model, - RuleTypes.RTBH, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -506,7 +394,7 @@ def delete_v4_rule(current_user, rule_id): """ model_name = Flowspec4 route_model = messages.create_ipv4 - return delete_rule(current_user, rule_id, model_name, route_model, 4) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv4) def delete_v6_rule(current_user, rule_id): @@ -516,7 +404,7 @@ def delete_v6_rule(current_user, rule_id): """ model_name = Flowspec6 route_model = messages.create_ipv6 - return delete_rule(current_user, rule_id, model_name, route_model, 6) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv6) def delete_rtbh_rule(current_user, rule_id): @@ -526,7 +414,7 @@ def delete_rtbh_rule(current_user, rule_id): """ model_name = RTBH route_model = messages.create_rtbh - return delete_rule(current_user, rule_id, model_name, route_model, 1) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.RTBH) def delete_rule(current_user, rule_id, model_name, route_model, rule_type): diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index 8ed262f0..c709f11f 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -48,7 +48,7 @@ def whois(ip_address): def index(rtype=None, rstate="active"): """ dispatcher object for the dashboard - :param rtype: ipv4, ipv6, rtbh + :param rtype: ipv4, ipv6, rtbh, whitelist :param rstate: :return: view from view factory """ @@ -84,7 +84,6 @@ def index(rtype=None, rstate="active"): data_handler_module = current_app.config["DASHBOARD"].get(rtype).get("data_handler", models) data_handler_method = current_app.config["DASHBOARD"].get(rtype).get("data_handler_method", "get_ip_rules") - # get search query, sort order and sort key from request or session get_search_query = request.args.get(SEARCH_ARG, session.get(SEARCH_ARG, "")) get_sort_key = request.args.get(SORT_ARG, session.get(SORT_ARG, DEFAULT_SORT)) @@ -109,6 +108,10 @@ def index(rtype=None, rstate="active"): # get the handler and the data handler = getattr(data_handler_module, data_handler_method) rules = handler(rtype, rstate, get_sort_key, get_sort_order) + + # Enrich rules with whitelist information + rules, whitelist_rule_ids = enrich_rules_with_whitelist_info(rules, rtype) + session[RULES_KEY] = [rule.id for rule in rules] # search rules if get_search_query: @@ -123,6 +126,8 @@ def index(rtype=None, rstate="active"): else: count_match = "" + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + return view_factory( rtype=rtype, rstate=rstate, @@ -138,6 +143,8 @@ def index(rtype=None, rstate="active"): macro_tbody=macro_tbody, macro_thead=macro_thead, macro_tfoot=macro_tfoot, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) @@ -148,18 +155,25 @@ def create_dashboard_table_body( group_op=True, macro_file="macros.html", macro_name="build_ip_tbody", + whitelist_rule_ids=None, + allowed_communities=None, ): """ create the table body for the dashboard using a jinja2 macro :param rules: list of rules :param rtype: ipv4, ipv6, rtbh + :param editable: whether rules can be edited + :param group_op: whether group operations are allowed :param macro_file: the file where the macro is defined :param macro_name: the name of the macro + :param whitelist_rule_ids: set of rule IDs that were created by a whitelist """ tstring = "{% " tstring = tstring + f"from '{macro_file}' import {macro_name}" tstring = tstring + " %} {{" - tstring = tstring + f" {macro_name}(rules, today, editable, group_op) " + "}}" + tstring = ( + tstring + f" {macro_name}(rules, today, editable, group_op, whitelist_rule_ids, allowed_communities) " + "}}" + ) dashboard_table_body = render_template_string( tstring, @@ -167,6 +181,8 @@ def create_dashboard_table_body( today=datetime.now(), editable=editable, group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids or set(), + allowed_communities=allowed_communities or [], ) return dashboard_table_body @@ -246,6 +262,8 @@ def create_admin_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_group_buttons_tfoot", + whitelist_rule_ids=None, + allowed_communities=None, ): """ Admin can see and edit any rules @@ -256,8 +274,17 @@ def create_admin_response( :param sort_order: :return: """ + group_op = True if rtype != "whitelist" else False - dashboard_table_body = create_dashboard_table_body(rules, rtype, macro_file=macro_file, macro_name=macro_tbody) + dashboard_table_body = create_dashboard_table_body( + rules, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, + ) dashboard_table_head = create_dashboard_table_head( rules_columns=table_columns, @@ -266,15 +293,18 @@ def create_admin_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) - - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" res = make_response( render_template( @@ -313,6 +343,8 @@ def create_user_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, + allowed_communities=None, ): """ Filter out the rules for normal users @@ -343,9 +375,20 @@ def create_user_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) + + group_op = True if rtype != "whitelist" else False + dashboard_table_editable = create_dashboard_table_body( - rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody + rules_editable, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) dashboard_table_editable_head = create_dashboard_table_head( rules_columns=table_columns, @@ -354,7 +397,7 @@ def create_user_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, - group_op=True, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) @@ -370,11 +413,14 @@ def create_user_response( macro_name=macro_thead, ) - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" display_editable = len(rules_editable) display_readonly = len(read_only_rules) @@ -419,6 +465,8 @@ def create_view_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, + allowed_communities=None, ): """ Filter out the rules for normal users @@ -433,6 +481,8 @@ def create_view_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) dashboard_table_head = create_dashboard_table_head( @@ -482,3 +532,39 @@ def filter_rules(rules, get_search_query): result.append(rules[idx]) return result + + +def enrich_rules_with_whitelist_info(rules, rule_type): + """ + Enrich rules with whitelist information from RuleWhitelistCache. + + Args: + rules: List of rule objects (Flowspec4, Flowspec6, RTBH) + rule_type: String identifier of rule type ("ipv4", "ipv6", "rtbh") + + Returns: + Tuple of (rules, whitelist_rule_ids) where whitelist_rule_ids is a set of + rule IDs that were created by a whitelist. + """ + from flowapp.models.rules.whitelist import RuleWhitelistCache + from flowapp.constants import RuleTypes, RuleOrigin + + # Map rule type string to enum value + rule_type_map = {"ipv4": RuleTypes.IPv4.value, "ipv6": RuleTypes.IPv6.value, "rtbh": RuleTypes.RTBH.value} + + # Get all rule IDs + rule_ids = [rule.id for rule in rules] + + # No rules to process + if not rule_ids: + return rules, set() + + # Query the cache for these rule IDs + cache_entries = RuleWhitelistCache.query.filter( + RuleWhitelistCache.rid.in_(rule_ids), RuleWhitelistCache.rtype == rule_type_map.get(rule_type) + ).all() + + # Create a set of rule IDs that were created by a whitelist + whitelist_rule_ids = {entry.rid for entry in cache_entries if entry.rorigin == RuleOrigin.WHITELIST.value} + + return rules, whitelist_rule_ids diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 5618789f..e3b82a1a 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -1,11 +1,10 @@ # flowapp/views/admin.py from datetime import datetime, timedelta -from operator import ge, lt from collections import namedtuple from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for -from flowapp import constants, db, messages +from flowapp import constants, db from flowapp.auth import ( admin_required, auth_required, @@ -24,19 +23,17 @@ Organization, check_global_rule_limit, check_rule_limit, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, - get_rtbh_model_if_exists, get_user_actions, get_user_communities, get_user_nets, insert_initial_communities, ) +from flowapp.models.log import Log from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route +from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists from flowapp.utils import ( flash_errors, get_state_by_time, - quote_to_ent, round_to_ten_minutes, ) @@ -61,13 +58,21 @@ def reactivate_rule(rule_type, rule_id): """ Set new time for the rule of given type identified by id - :param rule_type: string - type of rule + :param rule_type: integer - type of rule, corresponds to RuleTypes enum value :param rule_id: integer - id of the rule """ + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + + # Now use the enum value where needed but the integer for dictionary lookups model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] model = db.session.get(model_name, rule_id) + if not model: + flash("Rule not found", "alert-danger") + return redirect(url_for("index")) + form = form_name(request.form, obj=model) form.net_ranges = get_user_nets(session["user_id"]) @@ -75,72 +80,41 @@ def reactivate_rule(rule_type, rule_id): form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")] form.action.data = model.action_id - if rule_type == 1: + if rule_type == RuleTypes.RTBH.value: form.community.choices = get_user_communities(session["user_role_ids"]) form.community.data = model.community_id - if rule_type == 4: + if rule_type == RuleTypes.IPv4.value: form.protocol.data = model.protocol - if rule_type == 6: + if rule_type == RuleTypes.IPv6.value: form.next_header.data = model.next_header - # do not need to validate - all is readonly + # Process form submission if request.method == "POST": - # check if rule will be reactivated - state = get_state_by_time(form.expires.data) + # Round expiration time to 10 minutes + expires = round_to_ten_minutes(form.expires.data) + + # Use the service to reactivate the rule + _, messages = rule_service.reactivate_rule( + rule_type=enum_rule_type, + rule_id=rule_id, + expires=expires, + comment=form.comment.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) - # check global limit - check_gl = check_global_rule_limit(rule_type) - if state == 1 and check_gl: + # Handle special messages (redirects) + if "global_limit_reached" in messages: return redirect(url_for("rules.global_limit_reached", rule_type=rule_type)) - # check org limit - if state == 1 and check_rule_limit(session["user_org_id"], rule_type=rule_type): + if "limit_reached" in messages: return redirect(url_for("rules.limit_reached", rule_type=rule_type)) - # set new expiration date - model.expires = round_to_ten_minutes(form.expires.data) - # set again the active state - model.rstate_id = get_state_by_time(form.expires.data) - model.comment = form.comment.data - db.session.commit() - flash("Rule successfully updated", "alert-success") - - route_model = ROUTE_MODELS[rule_type] - - if model.rstate_id == 1: - # announce route - command = route_model(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_route( - session["user_id"], - model, - rule_type, - f"{session['user_email']} / {session['user_org']}", - ) - else: - # withdraw route - command = route_model(model, constants.WITHDRAW) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_withdraw( - session["user_id"], - route.command, - rule_type, - model.id, - f"{session['user_email']} / {session['user_org']}", - ) + for message in messages: + flash(message, "alert-success") return redirect( url_for( @@ -155,6 +129,7 @@ def reactivate_rule(rule_type, rule_id): else: flash_errors(form) + # For GET requests, prepare the form for display form.expires.data = model.expires for field in form: if field.name not in ["expires", "csrf_token", "comment"]: @@ -177,42 +152,74 @@ def reactivate_rule(rule_type, rule_id): def delete_rule(rule_type, rule_id): """ Delete rule with given id and type - :param sort_key: - :param filter_text: - :param rstate: - :param rule_type: string - type of rule to be deleted + :param rule_type: integer - type of rule to be deleted :param rule_id: integer - rule id """ - model_name = DATA_MODELS[rule_type] - route_model = ROUTE_MODELS[rule_type] + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + + # Use the service to delete the rule + success, message = rule_service.delete_rule( + rule_type=enum_rule_type, + rule_id=rule_id, + user_id=session["user_id"], + user_email=session["user_email"], + org_name=session["user_org"], + allowed_rule_ids=session.get(constants.RULES_KEY, []), + ) - model = db.session.get(model_name, rule_id) - if model.id in session[constants.RULES_KEY]: - # withdraw route - command = route_model(model, constants.WITHDRAW) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - log_withdraw( - session["user_id"], - route.command, - rule_type, - model.id, - f"{session['user_email']} / {session['user_org']}", + # Flash appropriate message based on result + flash(message, "alert-success" if success else "alert-warning") + + # Redirect back to dashboard + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], ) + ) - # delete from db - db.session.delete(model) - db.session.commit() - flash("Rule deleted", "alert-success") - else: - flash("You can not delete this rule", "alert-warning") +@rules.route("/delete_and_whitelist//", methods=["GET"]) +@auth_required +@user_or_admin_required +def delete_and_whitelist(rule_type, rule_id): + """ + Delete an RTBH rule and create a whitelist entry from it. + :param rule_id: integer - id of the RTBH rule + """ + if rule_type != RuleTypes.RTBH.value: + flash("Only RTBH rules can be converted to whitelists", "alert-warning") + return redirect(url_for("index")) + + # Set whitelist expiration to 7 days from now by default + whitelist_expires = datetime.now() + timedelta(days=7) + + # Use the service to delete RTBH and create whitelist + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=rule_id, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + allowed_rule_ids=session.get(constants.RULES_KEY, []), + whitelist_expires=whitelist_expires, + ) + + # Flash all messages + for message in messages: + flash(message, "alert-success" if success else "alert-warning") + + # If successful, flash additional message about whitelist + if success and whitelist: + flash(f"Created whitelist entry ID {whitelist.id} from RTBH rule", "alert-info") + + # Redirect back to dashboard return redirect( url_for( "dashboard.index", @@ -260,6 +267,7 @@ def group_delete(): rule_type = session[constants.TYPE_ARG] model_name = DATA_MODELS_NAMED[rule_type] rule_type_int = constants.RULE_TYPES_DICT[rule_type] + enum_rule_type = RuleTypes(rule_type_int) route_model = ROUTE_MODELS[rule_type_int] rules = [str(x) for x in session[constants.RULES_KEY]] to_delete = request.form.getlist("delete-id") @@ -279,7 +287,7 @@ def group_delete(): log_withdraw( session["user_id"], route.command, - rule_type_int, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -376,6 +384,7 @@ def group_update_save(rule_type): model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] + enum_rule_type = RuleTypes(rule_type) form = form_name(request.form) @@ -417,7 +426,7 @@ def group_update_save(rule_type): log_route( session["user_id"], model, - rule_type, + enum_rule_type, f"{session['user_email']} / {session['user_org']}", ) else: @@ -433,7 +442,7 @@ def group_update_save(rule_type): log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -476,54 +485,17 @@ def ipv4_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv4, - f"{session['user_email']} / {session['user_org']}", + # Use the service to create/update the rule + _model, message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") + return redirect(url_for("index")) else: for field, errors in form.errors.items(): @@ -560,52 +532,14 @@ def ipv6_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv6, - f"{session['user_email']} / {session['user_org']}", + _model, message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -642,47 +576,18 @@ def rtbh_rule(): ] + user_communities form.community.choices = user_communities form.net_ranges = net_ranges + whitelistable = Community.get_whitelistable_communities(current_app.config["ALLOWED_COMMUNITIES"]) if request.method == "POST" and form.validate(): - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - flash(flash_message, "alert-success") - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_route( - session["user_id"], - model, - RuleTypes.RTBH, - f"{session['user_email']} / {session['user_org']}", + _model, messages = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + for message in messages: + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -693,7 +598,11 @@ def rtbh_rule(): default_expires = datetime.now() + timedelta(days=7) form.expires.data = default_expires - return render_template("forms/rtbh_rule.html", form=form, action_url=url_for("rules.rtbh_rule")) + print(whitelistable) + + return render_template( + "forms/rtbh_rule.html", form=form, action_url=url_for("rules.rtbh_rule"), whitelistable=whitelistable + ) @rules.route("/limit_reached/") @@ -775,64 +684,13 @@ def announce_all(): @rules.route("/withdraw_expired", methods=["GET"]) @localhost_only def withdraw_expired(): - announce_all_routes(constants.WITHDRAW) - return " " - - -def announce_all_routes(action=constants.ANNOUNCE): """ - get routes from db and send it to ExaBGB api - - @TODO take the request away, use some kind of messaging (maybe celery?) - :param action: action with routes - announce valid routes or withdraw expired routes + cleaning endpoint + deletes expired whitelists + withdraws all expired routes from ExaBGP + deletes logs older than 30 days """ - today = datetime.now() - comp_func = ge if action == constants.ANNOUNCE else lt - - rules4 = ( - db.session.query(Flowspec4) - .filter(Flowspec4.rstate_id == 1) - .filter(comp_func(Flowspec4.expires, today)) - .order_by(Flowspec4.expires.desc()) - .all() - ) - rules6 = ( - db.session.query(Flowspec6) - .filter(Flowspec6.rstate_id == 1) - .filter(comp_func(Flowspec6.expires, today)) - .order_by(Flowspec6.expires.desc()) - .all() - ) - rules_rtbh = ( - db.session.query(RTBH) - .filter(RTBH.rstate_id == 1) - .filter(comp_func(RTBH.expires, today)) - .order_by(RTBH.expires.desc()) - .all() - ) - - messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] - messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] - messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] - - messages_all = [] - messages_all.extend(messages_v4) - messages_all.extend(messages_v6) - messages_all.extend(messages_rtbh) - - author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" - - for command in messages_all: - route = Route( - author=f"System call / {author_action} rules", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - if action == constants.WITHDRAW: - for ruleset in [rules4, rules6, rules_rtbh]: - for rule in ruleset: - rule.rstate_id = 2 - - db.session.commit() + delete_expired_whitelists() + announce_all_routes(constants.WITHDRAW) + Log.delete_old() + return " " diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py new file mode 100644 index 00000000..39a9352a --- /dev/null +++ b/flowapp/views/whitelist.py @@ -0,0 +1,127 @@ +from datetime import datetime, timedelta +from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for + +from flowapp.auth import ( + auth_required, + user_or_admin_required, +) +from flowapp import constants, db +from flowapp.forms import WhitelistForm +from flowapp.models import get_user_nets, Whitelist +from flowapp.services import create_or_update_whitelist, delete_whitelist +from flowapp.utils.base import flash_errors + +whitelist = Blueprint("whitelist", __name__, template_folder="templates") + + +@whitelist.route("/add", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def add(): + net_ranges = get_user_nets(session["user_id"]) + form = WhitelistForm(request.form) + + form.net_ranges = net_ranges + + if request.method == "POST" and form.validate(): + model, messages = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) + for message in messages: + flash(message, "alert-success") + + return redirect(url_for("index")) + else: + for field, errors in form.errors.items(): + for error in errors: + current_app.logger.debug("Error in the %s field - %s" % (getattr(form, field).label.text, error)) + + print("NOW", datetime.now()) + default_expires = datetime.now() + timedelta(hours=1) + form.expires.data = default_expires + + return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add")) + + +@whitelist.route("/reactivate/", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def reactivate(wl_id): + """ + Set new time for whitelist + :param wl_id: int - id of the whitelist + """ + + model = db.session.get(Whitelist, wl_id) + form = WhitelistForm(request.form, obj=model) + form.net_ranges = get_user_nets(session["user_id"]) + + # do not need to validate - all is readonly + if request.method == "POST": + model = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) + flash("Whitelist updated", "alert-success") + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) + else: + flash_errors(form) + + form.expires.data = model.expires + for field in form: + if field.name not in ["expires", "csrf_token", "comment"]: + field.render_kw = {"disabled": "disabled"} + + action_url = url_for("whitelist.reactivate", wl_id=wl_id) + + return render_template( + "forms/whitelist.html", + form=form, + action_url=action_url, + editing=True, + title="Update", + ) + + +@whitelist.route("/delete/", methods=["GET"]) +@auth_required +@user_or_admin_required +def delete(wl_id): + """ + Delete whitelist + :param wl_id: integer - id of the whitelist + """ + if wl_id in session[constants.RULES_KEY]: + messages = delete_whitelist(wl_id) + flash(f"Whitelist {wl_id} deleted", "alert-success") + for message in messages: + flash(message, "alert-info") + else: + flash("You can not delete this Whitelist", "alert-warning") + + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) diff --git a/run.example.py b/run.example.py index b3dd0c4e..ad6f97eb 100644 --- a/run.example.py +++ b/run.example.py @@ -1,30 +1,19 @@ """ -This is an example of how to run the application. -First copy the file as run.py (or whatever you want) -Then edit the file to match your needs. - -From version 0.8.1 the application is using Flask-Session -stored in DB using SQL Alchemy driver. This can be configured for other -drivers, however server side session is required for the application. - -In general you should not need to edit this example file. -Only if you want to configure the application main menu and -dashboard. - -Or in case that you want to add extensions etc. +This is run py for the application. Copied to the container on build. """ from os import environ -from flowapp import create_app, db, sess +from flowapp import create_app, db import config # Configurations -env = environ.get("EXAFS_ENV", "Production") +exafs_env = environ.get("EXAFS_ENV", "Production") +exafs_env = exafs_env.lower() # Call app factory -if env == "devel": +if exafs_env == "devel" or exafs_env == "development": app = create_app(config.DevelopmentConfig) else: app = create_app(config.ProductionConfig) @@ -32,12 +21,6 @@ # init database object db.init_app(app) -# init session -app.config.update(SESSION_TYPE="sqlalchemy") -app.config.update(SESSION_SQLALCHEMY=db) -sess.init_app(app) - - # run app if __name__ == "__main__": app.run(host="::", port=8080, debug=True) diff --git a/withdraw_expired b/withdraw_expired new file mode 100644 index 00000000..0519ecba --- /dev/null +++ b/withdraw_expired @@ -0,0 +1 @@ + \ No newline at end of file