diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 9e659abb..ad479220 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -19,10 +19,10 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: "3.9" + python-version: "3.11" - name: Setup timezone uses: zcong1993/setup-timezone@master with: diff --git a/flowapp/__about__.py b/flowapp/__about__.py index d10af327..c7f183a6 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1 +1 @@ -__version__ = "1.0.2" +__version__ = "1.1.1" diff --git a/flowapp/__init__.py b/flowapp/__init__.py index f029c37d..61543e78 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -1,10 +1,6 @@ # -*- coding: utf-8 -*- -import babel -import logging -from loguru import logger +from flask import Flask, redirect, render_template, session, url_for -from flask import Flask, redirect, render_template, session, url_for, request -from flask.logging import default_handler from flask_sso import SSO from flask_sqlalchemy import SQLAlchemy from flask_wtf.csrf import CSRFProtect @@ -26,16 +22,16 @@ swagger = Swagger(template_file="static/swagger.yml") -class InterceptHandler(logging.Handler): - - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) - logger_opt.log(record.levelname, record.getMessage()) - - def create_app(config_object=None): app = Flask(__name__) + # Load the default configuration for dashboard and main menu + app.config.from_object(InstanceConfig) + if config_object: + app.config.from_object(config_object) + + app.config.setdefault("VERSION", __version__) + # SSO configuration SSO_ATTRIBUTE_MAP = { "eppn": (True, "eppn"), @@ -47,13 +43,6 @@ def create_app(config_object=None): migrate.init_app(app, db) csrf.init_app(app) - # Load the default configuration for dashboard and main menu - app.config.from_object(InstanceConfig) - if config_object: - app.config.from_object(config_object) - - app.config.setdefault("VERSION", __version__) - # Init SSO ext.init_app(app) @@ -64,76 +53,28 @@ def create_app(config_object=None): if app.config.get("BEHIND_PROXY", False): app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) - from flowapp import models, constants, validators - from .views.admin import admin - from .views.rules import rules - from .views.api_v1 import api as api_v1 - from .views.api_v2 import api as api_v2 - from .views.api_v3 import api as api_v3 - from .views.api_keys import api_keys + from flowapp import models, constants from .auth import auth_required - from .views.dashboard import dashboard - - # no need for csrf on api because we use JWT - csrf.exempt(api_v1) - csrf.exempt(api_v2) - csrf.exempt(api_v3) - - app.register_blueprint(admin, url_prefix="/admin") - app.register_blueprint(rules, url_prefix="/rules") - app.register_blueprint(api_keys, url_prefix="/api_keys") - app.register_blueprint(api_v1, url_prefix="/api/v1") - app.register_blueprint(api_v2, url_prefix="/api/v2") - app.register_blueprint(api_v3, url_prefix="/api/v3") - app.register_blueprint(dashboard, url_prefix="/dashboard") - - # register loguru as handler - app.logger.removeHandler(default_handler) - app.logger.addHandler(InterceptHandler()) - - @ext.login_handler - def login(user_info): - try: - uuid = user_info.get("eppn") - except KeyError: - uuid = False - return render_template("errors/401.html") - return _handle_login(uuid) + # Register blueprints + from .utils import register_blueprints - @app.route("/logout") - def logout(): - session["user_uuid"] = False - session["user_id"] = False - session.clear() - return redirect(app.config.get("LOGOUT_URL")) + register_blueprints(app, csrf) - @app.route("/ext-login") - def ext_login(): - header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") - if header_name not in request.headers: - return render_template("errors/401.html") + # configure logging + from .utils import configure_logging - uuid = request.headers.get(header_name) - if not uuid: - return render_template("errors/401.html") + configure_logging(app) - return _handle_login(uuid) + # register error handlers + from .utils import register_error_handlers - @app.route("/local-login") - def local_login(): - print("Local login started") - if not app.config.get("LOCAL_AUTH", False): - print("Local auth not enabled") - return render_template("errors/401.html") + register_error_handlers(app) - uuid = app.config.get("LOCAL_USER_UUID", False) - if not uuid: - print("Local user not set") - return render_template("errors/401.html") + # register auth handlers + from .utils import register_auth_handlers - print(f"Local login with {uuid}") - return _handle_login(uuid) + register_auth_handlers(app, ext) @app.route("/") @auth_required @@ -192,89 +133,10 @@ def select_org(org_id=None): def shutdown_session(exception=None): db.session.remove() - # HTTP error handling - @app.errorhandler(404) - def not_found(error): - return render_template("errors/404.html"), 404 - - @app.errorhandler(500) - def internal_error(exception): - app.logger.exception(exception) - return render_template("errors/500.html"), 500 - - @app.context_processor - def utility_processor(): - def editable_rule(rule): - if rule: - validators.editable_range(rule, models.get_user_nets(session["user_id"])) - return True - return False - - return dict(editable_rule=editable_rule) - - @app.context_processor - def inject_main_menu(): - """ - inject main menu config to templates - used in default template to create main menu - """ - return {"main_menu": app.config.get("MAIN_MENU")} - - @app.context_processor - def inject_dashboard(): - """ - inject dashboard config to templates - used in submenu dashboard to create dashboard tables - """ - return {"dashboard": app.config.get("DASHBOARD")} - - @app.template_filter("strftime") - def format_datetime(value): - if value is None: - return app.config.get("MISSING_DATETIME_MESSAGE", "Never") - - format = "y/MM/dd HH:mm" - return babel.dates.format_datetime(value, format) - - @app.template_filter("unlimited") - def unlimited_filter(value): - return "unlimited" if value == 0 else value - - def _handle_login(uuid: str): - """ - handles rest of login process - """ - multiple_orgs = False - try: - user, multiple_orgs = _register_user_to_session(uuid) - except AttributeError as e: - app.logger.exception(e) - return render_template("errors/401.html") - - if multiple_orgs: - return redirect(url_for("select_org", org_id=None)) + # register context processors and template filters + from .utils import register_context_processors, register_template_filters - # set user org to session - user_org = user.organization.first() - session["user_org"] = user_org.name - session["user_org_id"] = user_org.id - - return redirect("/") - - def _register_user_to_session(uuid: str): - print(f"Registering user {uuid} to session") - user = db.session.query(models.User).filter_by(uuid=uuid).first() - print(f"Got user {user} from DB") - session["user_uuid"] = user.uuid - session["user_email"] = user.uuid - session["user_name"] = user.name - session["user_id"] = user.id - session["user_roles"] = [role.name for role in user.role.all()] - session["user_role_ids"] = [role.id for role in user.role.all()] - roles = [i > 1 for i in session["user_role_ids"]] - session["can_edit"] = True if all(roles) and roles else [] - # check if user has multiple organizations and return True if so - print(f"DEBUG SESSION {session}") - return user, len(user.organization.all()) > 1 + register_context_processors(app) + register_template_filters(app) return app diff --git a/flowapp/constants.py b/flowapp/constants.py index 5aa52834..975f99b7 100644 --- a/flowapp/constants.py +++ b/flowapp/constants.py @@ -2,6 +2,7 @@ This module contains constant values used in application """ +from enum import Enum from operator import ge, lt DEFAULT_SORT = "expires" @@ -59,7 +60,12 @@ FORM_TIME_PATTERN = "%Y-%m-%dT%H:%M" -class RuleTypes: +class RuleTypes(Enum): RTBH = 1 IPv4 = 4 IPv6 = 6 + + +class RuleOrigin(Enum): + USER = 1 + WHITELIST = 2 diff --git a/flowapp/flowspec.py b/flowapp/flowspec.py index e0ce35a5..ae357ca8 100644 --- a/flowapp/flowspec.py +++ b/flowapp/flowspec.py @@ -21,6 +21,19 @@ def translate_sequence(sequence, max_val=MAX_PORT): return "[{}]".format(result) +def check_limit(value, max_value, min_value=0): + """ + test if the value is within valid range (min_value to max_value inclusive) + raise exception otherwise + """ + value = int(value) + if value > max_value: + raise ValueError("Invalid value number: {} is too big. Max is {}.".format(value, max_value)) + if value < min_value: + raise ValueError("Invalid value number: {} is too small. Min is {}.".format(value, min_value)) + return value + + def to_exabgp_string(value_string, max_val): """ Translate form string to flowspec value or packet size rule @@ -42,35 +55,30 @@ def to_exabgp_string(value_string, max_val): return "={}".format(check_limit(value_string, max_val)) elif RANGE.match(value_string): m = RANGE.match(value_string) - return ">={}&<={}".format( - check_limit(m.group(1), max_val), check_limit(m.group(2), max_val) - ) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif NOTRAN.match(value_string): - return value_string + m = NOTRAN.match(value_string) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif GREATER.match(value_string): m = GREATER.match(value_string) return ">={}&<={}".format(check_limit(m.group(1), max_val), max_val) elif LOWER.match(value_string): m = LOWER.match(value_string) - return ">=0&<={}".format(check_limit(m.group(1), max_val)) + # Even for lower bound expressions, validate that the value itself is within range + end = check_limit(m.group(1), max_val) + return ">=0&<={}".format(end) else: raise ValueError("string {} can not be converted".format(value_string)) -def check_limit(value, max_value): - """ - test if the value is lower than max_value - raise exception otherwise - """ - value = int(value) - if value > max_value: - raise ValueError( - "Invalid value number: {} is too big. Max is {}.".format(value, max_value) - ) - else: - return value - - def filter_rules_action(user_actions, rules): """ Divide the list of rules by user_actions to editable and viewonly subsets diff --git a/flowapp/forms.py b/flowapp/forms.py deleted file mode 100644 index 4b914dce..00000000 --- a/flowapp/forms.py +++ /dev/null @@ -1,650 +0,0 @@ -import csv -from io import StringIO - - -from flask_wtf import FlaskForm -from wtforms import widgets -from wtforms import ( - BooleanField, - DateTimeField, - HiddenField, - IntegerField, - SelectField, - SelectMultipleField, - StringField, - TextAreaField, -) -from wtforms.validators import ( - ValidationError, - DataRequired, - Email, - InputRequired, - IPAddress, - Length, - NumberRange, - Optional, -) - -from flowapp.constants import ( - IPV4_FRAGMENT, - IPV4_PROTOCOL, - IPV6_NEXT_HEADER, - TCP_FLAGS, - FORM_TIME_PATTERN, -) -from flowapp.validators import ( - IPv4Address, - IPv6Address, - NetRangeString, - PortString, - address_in_range, - address_with_mask, - network_in_range, - whole_world_range, -) - -from flowapp.utils import parse_api_time - - -class MultiFormatDateTimeLocalField(DateTimeField): - """ - Same as :class:`~wtforms.fields.DateTimeField`, but represents an - ````. - - Custom implementation uses default HTML5 format for parsing the field. - It's possible to use multiple formats - used in API. - - """ - - widget = widgets.DateTimeLocalInput() - - def __init__(self, *args, **kwargs): - kwargs.setdefault("format", "%Y-%m-%dT%H:%M") - self.unlimited = kwargs.pop("unlimited", False) - self.pref_format = None - super().__init__(*args, **kwargs) - - def process_formdata(self, valuelist): - if not valuelist: - return None - # with unlimited field we do not need to parse the empty value - if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: - self.data = None - return None - - date_str = " ".join((str(val) for val in valuelist)) - result, pref_format = parse_api_time(date_str) - if result: - self.data = result - self.pref_format = pref_format - else: - self.data = None - self.pref_format = None - raise ValueError(self.gettext("Not a valid datetime value.")) - - -class UserForm(FlaskForm): - """ - User Form object - used in Admin - """ - - uuid = StringField( - "Unique User ID", - validators=[ - InputRequired("Please provide UUID"), - Email("Please provide valid email"), - ], - ) - - email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) - - comment = StringField("Notice", validators=[Optional()]) - - name = StringField("Name", validators=[Optional()]) - - phone = StringField("Contact phone", validators=[Optional()]) - - role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) - - org_ids = SelectMultipleField( - "Organization", - coerce=int, - validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], - ) - - -class BulkUserForm(FlaskForm): - """ - Bulk User Form object - used in Admin - """ - - users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) - - def __init__(self, *args, **kwargs): - super(BulkUserForm, self).__init__(*args, **kwargs) - self.roles = None - self.organizations = None - self.uuids = None - - # Custom validator for CSV data - def validate_users(self, field): - csv_data = field.data - - # Parse CSV data - csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") - - # List to keep track of failed validation rows - errors = 0 - for row_num, row in enumerate(csv_reader, start=1): - try: - # check if the user not already exists - if row["uuid-eppn"] in self.uuids: - field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") - errors += 1 - - # Check if role exists in the database - role_id = int(row["role"]) # Convert role field to integer - if role_id not in self.roles: - field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") - errors += 1 - - # Check if organization exists in the database - org_id = int(row["organizace"]) # Convert organization field to integer - if org_id not in self.organizations: - field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") - errors += 1 - - except (KeyError, ValueError) as e: - field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") - - if errors > 0: - # Raise validation error if any invalid rows found - raise ValidationError("Invalid CSV Data - check the errors above.") - - -class ApiKeyForm(FlaskForm): - """ - ApiKey for User - Each key / machine pair is unique - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class MachineApiKeyForm(FlaskForm): - """ - ApiKey for Machines - Each key / machine pair is unique - Only Admin can create new these keys - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class OrganizationForm(FlaskForm): - """ - Organization form object - used in Admin - """ - - name = StringField("Organization name", validators=[Optional(), Length(max=150)]) - - limit_flowspec4 = IntegerField( - "Maximum number of IPv4 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_flowspec6 = IntegerField( - "Maximum number of IPv6 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_rtbh = IntegerField( - "Maximum number of RTBH rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - arange = TextAreaField( - "Organization Adress Range - one range per row", - validators=[Optional(), NetRangeString()], - ) - - -class ActionForm(FlaskForm): - """ - Action form object - used in Admin - """ - - name = StringField("Action short name", validators=[Length(max=150)]) - - command = StringField("ExaBGP command", validators=[Length(max=150)]) - - description = StringField("Action description") - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - -class ASPathForm(FlaskForm): - """ - AS Path form object - used in Admin - """ - - prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) - - as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) - - -class CommunityForm(FlaskForm): - """ - Community form object - used in Admin - """ - - name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) - - comm = StringField("Community value", validators=[Length(max=2046)]) - - larcomm = StringField("Large community value", validators=[Length(max=2046)]) - - extcomm = StringField("Extended community value", validators=[Length(max=2046)]) - - description = StringField("Community description", validators=[Length(max=255)]) - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - as_path = BooleanField("add AS-path (checked = true)") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - if not self.comm.data and not self.extcomm.data and not self.larcomm.data: - err_message = "At last one of those values could not be empty" - self.comm.errors.append(err_message) - self.larcomm.errors.append(err_message) - self.extcomm.errors.append(err_message) - result = False - - return result - - -class RTBHForm(FlaskForm): - """ - RoadToBlackHole rule form - """ - - def __init__(self, *args, **kwargs): - super(RTBHForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - ipv4 = StringField( - "IPv4 address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - ipv4_mask = IntegerField( - "IPv4 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), - ], - ) - - ipv6 = StringField( - "IPv6 address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - ipv6_mask = IntegerField( - "IPv6 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), - ], - ) - - community = SelectField( - "Community", - coerce=int, - validators=[ - DataRequired(message="Please select a community for the rule."), - ], - ) - - expires = MultiFormatDateTimeLocalField( - "Expires", - format=FORM_TIME_PATTERN, - validators=[DataRequired(), InputRequired()], - ) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - # ipv4 and ipv6 are mutually exclusive - # if both are set, validation fails - # if none is set, validation fails - # if one is set, validation passes - if self.ipv4.data and self.ipv6.data: - self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - result = False - - if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): - self.ipv4.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) - ) - result = False - - if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): - self.ipv6.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) - ) - result = False - - ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) - ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) - - if not (ipv6_in_range or ipv4_in_range): - self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - result = False - - return result - - -class IPForm(FlaskForm): - """ - Base class for IPv4 and IPv6 rules - """ - - def __init__(self, *args, **kwargs): - super(IPForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = None - source = None - source_mask = None - dest = None - dest_mask = None - flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) - - source_port = StringField( - "Source port(s) - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - dest_port = StringField( - "Destination port(s) - ; separated", - validators=[Optional(), Length(max=255), PortString()], - ) - - packet_len = StringField( - "Packet length - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - action = SelectField( - "Action", - coerce=int, - validators=[DataRequired(message="Please select an action for the rule.")], - ) - - expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - - result = True - if not FlaskForm.validate(self): - result = False - - source = self.validate_source_address() - dest = self.validate_dest_address() - ranges = self.validate_address_ranges() - ips = self.validate_ipv_specific() - - return result and source and dest and ranges and ips - - def validate_source_address(self): - """ - validate source address, set error message if validation fails - :return: boolean validation result - """ - if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): - self.source.errors.append( - "This is not valid combination of address {} and mask {}.".format( - self.source.data, self.source_mask.data - ) - ) - return False - - return True - - def validate_dest_address(self): - """ - validate dest address, set error message if validation fails - :return: boolean validation result - """ - if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): - self.dest.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) - ) - return False - - return True - - def validate_address_ranges(self): - """ - validates if the address of source is in the user range - if the source and dest address are empty, check if the user - is member of whole world organization - :return: boolean validation result - """ - if not (self.source.data or self.dest.data): - whole_world_member = whole_world_range(self.net_ranges, self.zero_address) - if not whole_world_member: - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - else: - source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) - dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) - if not (source_in_range or dest_in_range): - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - - return True - - def validate_ipv_specific(self): - """ - abstract method must be implemented in the subclass - """ - pass - - -class IPv4Form(IPForm): - """ - IPv4 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv4Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "0.0.0.0" - source = StringField( - "Source address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - source_mask = IntegerField( - "Source mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - dest_mask = IntegerField( - "Destination mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - protocol = SelectField( - "Protocol", - choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], - validators=[DataRequired()], - ) - - fragment = SelectMultipleField( - "Fragment", - choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], - validators=[Optional()], - ) - - def validate_ipv_specific(self): - """ - validate protocol and flags, set error message if validation fails - :return: boolean validation result - """ - - if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": - self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) - return False - return True - - -class IPv6Form(IPForm): - """ - IPv6 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv6Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "::" - source = StringField( - "Source address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - source_mask = IntegerField( - "Source prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - dest_mask = IntegerField( - "Destination prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - next_header = SelectField( - "Next Header", - choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], - validators=[DataRequired()], - ) - - def validate_ipv_specific(self): - """ - validate next header and flags, set error message if validation fails - :return: boolean validation result - """ - if len(self.flags.data) > 0 and self.next_header.data != "tcp": - self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) - return False - - return True diff --git a/flowapp/forms/__init__.py b/flowapp/forms/__init__.py new file mode 100644 index 00000000..a9d7d887 --- /dev/null +++ b/flowapp/forms/__init__.py @@ -0,0 +1,40 @@ +""" +Forms module for the flowapp application. +This file imports and re-exports all form classes from the module. +""" + +# Base field +from .base import MultiFormatDateTimeLocalField + +# User forms +from .user import UserForm, BulkUserForm + +# API key forms +from .api import ApiKeyForm, MachineApiKeyForm + +# Organization forms +from .organization import OrganizationForm + +# Action, ASPath, and Community forms +from .choices import ActionForm, ASPathForm, CommunityForm + +# Rule forms +from .rules import IPForm, IPv4Form, IPv6Form, RTBHForm, WhitelistForm + + +__all__ = [ + "MultiFormatDateTimeLocalField", + "UserForm", + "BulkUserForm", + "ApiKeyForm", + "MachineApiKeyForm", + "OrganizationForm", + "ActionForm", + "ASPathForm", + "CommunityForm", + "RTBHForm", + "IPForm", + "IPv4Form", + "IPv6Form", + "WhitelistForm", +] diff --git a/flowapp/forms/api.py b/flowapp/forms/api.py new file mode 100644 index 00000000..c9af19a0 --- /dev/null +++ b/flowapp/forms/api.py @@ -0,0 +1,68 @@ +""" +API key forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import SelectField, StringField, TextAreaField, BooleanField, HiddenField +from wtforms.validators import DataRequired, IPAddress, Optional, Length + +from .base import MultiFormatDateTimeLocalField +from ..constants import FORM_TIME_PATTERN + + +class ApiKeyForm(FlaskForm): + """ + ApiKey for User + Each key / machine pair is unique + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=False) + + key = HiddenField("GeneratedKey") + + +class MachineApiKeyForm(FlaskForm): + """ + ApiKey for Services / No login users. + Each key / machine pair is unique + This form is used by Admin to create api key for services or users with no Shibboleth login + User must be created first and must have an organization + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=True) + + user = SelectField( + "User", + coerce=int, + validators=[DataRequired("Select user")], + ) + + key = HiddenField("GeneratedKey") diff --git a/flowapp/forms/base.py b/flowapp/forms/base.py new file mode 100644 index 00000000..f99db23d --- /dev/null +++ b/flowapp/forms/base.py @@ -0,0 +1,50 @@ +""" +Base form fields for the flowapp application. +""" + +from wtforms import widgets +from wtforms.fields import DateTimeField +from ..utils import parse_api_time + + +class MultiFormatDateTimeLocalField(DateTimeField): + """ + Same as :class:`~wtforms.fields.DateTimeField`, but represents an + ````. + + Custom implementation uses default HTML5 format for parsing the field. + It's possible to use multiple formats - used in API. + + """ + + widget = widgets.DateTimeLocalInput() + + def __init__(self, *args, **kwargs): + kwargs.setdefault("format", "%Y-%m-%dT%H:%M") + self.unlimited = kwargs.pop("unlimited", False) + self.pref_format = None + super().__init__(*args, **kwargs) + + def process_formdata(self, valuelist): + if not valuelist or (len(valuelist) == 1 and not valuelist[0]): + return None + + # with unlimited field we do not need to parse the empty value + if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: + self.data = None + return None + + date_str = " ".join((str(val) for val in valuelist)) + + try: + result, pref_format = parse_api_time(date_str) + except TypeError: + raise ValueError(self.gettext("Not a valid datetime value.")) + + if result: + self.data = result + self.pref_format = pref_format + else: + self.data = None + self.pref_format = None + raise ValueError(self.gettext("Not a valid datetime value.")) diff --git a/flowapp/forms/choices.py b/flowapp/forms/choices.py new file mode 100644 index 00000000..424b6b80 --- /dev/null +++ b/flowapp/forms/choices.py @@ -0,0 +1,81 @@ +""" +Action, ASPath, and Community forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, SelectField, BooleanField +from wtforms.validators import Length, DataRequired + + +class ActionForm(FlaskForm): + """ + Action form object + used in Admin + """ + + name = StringField("Action short name", validators=[Length(max=150)]) + + command = StringField("ExaBGP command", validators=[Length(max=150)]) + + description = StringField("Action description") + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + +class ASPathForm(FlaskForm): + """ + AS Path form object + used in Admin + """ + + prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) + + as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) + + +class CommunityForm(FlaskForm): + """ + Community form object + used in Admin + """ + + name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) + + comm = StringField("Community value", validators=[Length(max=2046)]) + + larcomm = StringField("Large community value", validators=[Length(max=2046)]) + + extcomm = StringField("Extended community value", validators=[Length(max=2046)]) + + description = StringField("Community description", validators=[Length(max=255)]) + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + as_path = BooleanField("add AS-path (checked = true)") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + if not self.comm.data and not self.extcomm.data and not self.larcomm.data: + err_message = "At last one of those values could not be empty" + self.comm.errors.append(err_message) + self.larcomm.errors.append(err_message) + self.extcomm.errors.append(err_message) + result = False + + return result diff --git a/flowapp/forms/organization.py b/flowapp/forms/organization.py new file mode 100644 index 00000000..37594ab0 --- /dev/null +++ b/flowapp/forms/organization.py @@ -0,0 +1,47 @@ +""" +Organization form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, Length, NumberRange + +from ..validators import NetRangeString + + +class OrganizationForm(FlaskForm): + """ + Organization form object + used in Admin + """ + + name = StringField("Organization name", validators=[Optional(), Length(max=150)]) + + limit_flowspec4 = IntegerField( + "Maximum number of IPv4 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_flowspec6 = IntegerField( + "Maximum number of IPv6 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_rtbh = IntegerField( + "Maximum number of RTBH rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + arange = TextAreaField( + "Organization Adress Range - one range per row", + validators=[Optional(), NetRangeString()], + ) diff --git a/flowapp/forms/rules/__init__.py b/flowapp/forms/rules/__init__.py new file mode 100644 index 00000000..9796e8ce --- /dev/null +++ b/flowapp/forms/rules/__init__.py @@ -0,0 +1,17 @@ +""" +Rule forms for the flowapp application. +""" + +from .base import IPForm +from .ipv4 import IPv4Form +from .ipv6 import IPv6Form +from .rtbh import RTBHForm +from .whitelist import WhitelistForm + +__all__ = [ + "IPForm", + "IPv4Form", + "IPv6Form", + "RTBHForm", + "WhitelistForm", +] diff --git a/flowapp/forms/rules/base.py b/flowapp/forms/rules/base.py new file mode 100644 index 00000000..0e67776e --- /dev/null +++ b/flowapp/forms/rules/base.py @@ -0,0 +1,127 @@ +""" +Base rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import SelectMultipleField, StringField, SelectField, TextAreaField +from wtforms.validators import Optional, Length, DataRequired, InputRequired + +from ..base import MultiFormatDateTimeLocalField +from ...constants import TCP_FLAGS +from ...validators import PortString, address_with_mask, network_in_range, whole_world_range + + +class IPForm(FlaskForm): + """ + Base class for IPv4 and IPv6 rules + """ + + def __init__(self, *args, **kwargs): + super(IPForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = None + source = None + source_mask = None + dest = None + dest_mask = None + flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) + + source_port = StringField( + "Source port(s) - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + dest_port = StringField( + "Destination port(s) - ; separated", + validators=[Optional(), Length(max=255), PortString()], + ) + + packet_len = StringField( + "Packet length - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + action = SelectField( + "Action", + coerce=int, + validators=[DataRequired(message="Please select an action for the rule.")], + ) + + expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + + result = True + if not FlaskForm.validate(self): + result = False + + source = self.validate_source_address() + dest = self.validate_dest_address() + ranges = self.validate_address_ranges() + ips = self.validate_ipv_specific() + + return result and source and dest and ranges and ips + + def validate_source_address(self): + """ + validate source address, set error message if validation fails + :return: boolean validation result + """ + if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): + self.source.errors.append( + "This is not valid combination of address {} and mask {}.".format( + self.source.data, self.source_mask.data + ) + ) + return False + + return True + + def validate_dest_address(self): + """ + validate dest address, set error message if validation fails + :return: boolean validation result + """ + if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): + self.dest.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) + ) + return False + + return True + + def validate_address_ranges(self): + """ + validates if the address of source is in the user range + if the source and dest address are empty, check if the user + is member of whole world organization + :return: boolean validation result + """ + if not (self.source.data or self.dest.data): + whole_world_member = whole_world_range(self.net_ranges, self.zero_address) + if not whole_world_member: + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + else: + source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) + dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) + if not (source_in_range or dest_in_range): + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + + return True + + def validate_ipv_specific(self): + """ + abstract method must be implemented in the subclass + """ + pass diff --git a/flowapp/forms/rules/ipv4.py b/flowapp/forms/rules/ipv4.py new file mode 100644 index 00000000..07f27a8c --- /dev/null +++ b/flowapp/forms/rules/ipv4.py @@ -0,0 +1,70 @@ +""" +IPv4 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField, SelectMultipleField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV4_PROTOCOL, IPV4_FRAGMENT +from ...validators import IPv4Address +from .base import IPForm + + +class IPv4Form(IPForm): + """ + IPv4 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv4Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "0.0.0.0" + source = StringField( + "Source address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + source_mask = IntegerField( + "Source mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + dest_mask = IntegerField( + "Destination mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + protocol = SelectField( + "Protocol", + choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], + validators=[DataRequired()], + ) + + fragment = SelectMultipleField( + "Fragment", + choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], + validators=[Optional()], + ) + + def validate_ipv_specific(self): + """ + validate protocol and flags, set error message if validation fails + :return: boolean validation result + """ + + if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": + self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) + return False + return True diff --git a/flowapp/forms/rules/ipv6.py b/flowapp/forms/rules/ipv6.py new file mode 100644 index 00000000..aabd1b2f --- /dev/null +++ b/flowapp/forms/rules/ipv6.py @@ -0,0 +1,64 @@ +""" +IPv6 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV6_NEXT_HEADER +from ...validators import IPv6Address +from .base import IPForm + + +class IPv6Form(IPForm): + """ + IPv6 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv6Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "::" + source = StringField( + "Source address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + source_mask = IntegerField( + "Source prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + dest_mask = IntegerField( + "Destination prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + next_header = SelectField( + "Next Header", + choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], + validators=[DataRequired()], + ) + + def validate_ipv_specific(self): + """ + validate next header and flags, set error message if validation fails + :return: boolean validation result + """ + if len(self.flags.data) > 0 and self.next_header.data != "tcp": + self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) + return False + + return True diff --git a/flowapp/forms/rules/rtbh.py b/flowapp/forms/rules/rtbh.py new file mode 100644 index 00000000..cef79483 --- /dev/null +++ b/flowapp/forms/rules/rtbh.py @@ -0,0 +1,104 @@ +""" +RTBH rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, SelectField, TextAreaField +from wtforms.validators import Optional, NumberRange, DataRequired, InputRequired + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPv6Address, address_with_mask, address_in_range, IPv4Address +from ..base import MultiFormatDateTimeLocalField + + +class RTBHForm(FlaskForm): + """ + RoadToBlackHole rule form + """ + + def __init__(self, *args, **kwargs): + super(RTBHForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ipv4 = StringField( + "IPv4 address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + ipv4_mask = IntegerField( + "IPv4 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), + ], + ) + + ipv6 = StringField( + "IPv6 address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + ipv6_mask = IntegerField( + "IPv6 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), + ], + ) + + community = SelectField( + "Community", + coerce=int, + validators=[ + DataRequired(message="Please select a community for the rule."), + ], + ) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # ipv4 and ipv6 are mutually exclusive + # if both are set, validation fails + # if none is set, validation fails + # if one is set, validation passes + if self.ipv4.data and self.ipv6.data: + self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + result = False + + if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): + self.ipv4.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) + ) + result = False + + if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): + self.ipv6.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) + ) + result = False + + ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) + ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) + + if not (ipv6_in_range or ipv4_in_range): + self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/rules/whitelist.py b/flowapp/forms/rules/whitelist.py new file mode 100644 index 00000000..8aa88b41 --- /dev/null +++ b/flowapp/forms/rules/whitelist.py @@ -0,0 +1,66 @@ +""" +Whitelist form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, DataRequired, InputRequired, Length + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPAddressValidator, NetworkValidator, network_in_range +from ..base import MultiFormatDateTimeLocalField + + +class WhitelistForm(FlaskForm): + """ + Whitelist form object + Used for creating and editing whitelist entries + Supports both IPv4 and IPv6 addresses + """ + + def __init__(self, *args, **kwargs): + super(WhitelistForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ip = StringField( + "IP address", + validators=[ + DataRequired(message="Please provide an IP address"), + IPAddressValidator(message="Please provide a valid IP address: {}"), + NetworkValidator(mask_field_name="mask"), + ], + ) + + mask = IntegerField( + "Network mask (bits)", + validators=[ + DataRequired(message="Please provide a network mask"), + ], + ) + + comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + def validate(self): + """ + Custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # Validate IP is in organization range + if self.ip.data and self.mask.data and self.net_ranges: + ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) + if not ip_in_range: + self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/user.py b/flowapp/forms/user.py new file mode 100644 index 00000000..5944afd1 --- /dev/null +++ b/flowapp/forms/user.py @@ -0,0 +1,91 @@ +""" +User-related forms for the flowapp application. +""" + +import csv +from io import StringIO + +from flask_wtf import FlaskForm +from wtforms import StringField, TextAreaField, SelectMultipleField +from wtforms.validators import ValidationError, DataRequired, Email, InputRequired, Optional + + +class UserForm(FlaskForm): + """ + User Form object + used in Admin + """ + + uuid = StringField( + "Unique User ID", + validators=[ + InputRequired("Please provide UUID"), + Email("Please provide valid email"), + ], + ) + + email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) + + comment = StringField("Notice", validators=[Optional()]) + + name = StringField("Name", validators=[Optional()]) + + phone = StringField("Contact phone", validators=[Optional()]) + + role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) + + org_ids = SelectMultipleField( + "Organization", + coerce=int, + validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], + ) + + +class BulkUserForm(FlaskForm): + """ + Bulk User Form object + used in Admin + """ + + users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) + + def __init__(self, *args, **kwargs): + super(BulkUserForm, self).__init__(*args, **kwargs) + self.roles = None + self.organizations = None + self.uuids = None + + # Custom validator for CSV data + def validate_users(self, field): + csv_data = field.data + + # Parse CSV data + csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") + + # List to keep track of failed validation rows + errors = 0 + for row_num, row in enumerate(csv_reader, start=1): + try: + # check if the user not already exists + if row["uuid-eppn"] in self.uuids: + field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") + errors += 1 + + # Check if role exists in the database + role_id = int(row["role"]) # Convert role field to integer + if role_id not in self.roles: + field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") + errors += 1 + + # Check if organization exists in the database + org_id = int(row["organizace"]) # Convert organization field to integer + if org_id not in self.organizations: + field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") + errors += 1 + + except (KeyError, ValueError) as e: + field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") + + if errors > 0: + # Raise validation error if any invalid rows found + raise ValidationError("Invalid CSV Data - check the errors above.") diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 1ffcb1f8..4136530e 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -2,12 +2,19 @@ # column names for tables RTBH_COLUMNS = ( - ("ipv4", "IP adress (v4 or v6)"), + ("ipv4", "IP address (v4 or v6)"), ("community_id", "Community"), ("expires", "Expires"), ("user_id", "User"), ) +WHITELIST_COLUMNS = ( + ("address", "IP address / network (v4 or v6)"), + ("expires", "Expires"), + ("user_id", "User"), +) + + RULES_COLUMNS_V4 = ( ("source", "Source addr."), ("source_port", "S port"), @@ -74,6 +81,7 @@ class InstanceConfig: {"name": "Add IPv4", "url": "rules.ipv4_rule"}, {"name": "Add IPv6", "url": "rules.ipv6_rule"}, {"name": "Add RTBH", "url": "rules.rtbh_rule"}, + {"name": "Add Whitelist", "url": "whitelist.add"}, {"name": "API Key", "url": "api_keys.all"}, ], "admin": [ @@ -123,6 +131,14 @@ class InstanceConfig: "table_colspan": 5, "table_columns": RTBH_COLUMNS, }, + "whitelist": { + "name": "Whitelist", + "macro_file": "macros.html", + "macro_tbody": "build_whitelist_tbody", + "macro_thead": "build_rules_thead", + "table_colspan": 4, + "table_columns": WHITELIST_COLUMNS, + }, } - COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0} + COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0, "whitelist": 0} diff --git a/flowapp/models.py b/flowapp/models.py deleted file mode 100644 index 4e1291f3..00000000 --- a/flowapp/models.py +++ /dev/null @@ -1,1064 +0,0 @@ -import json -from sqlalchemy import event -from datetime import datetime -from flowapp import db, utils -from flowapp.constants import RuleTypes -from flask import current_app - -# models and tables - -user_role = db.Table( - "user_role", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "role_id"), -) - -user_organization = db.Table( - "user_organization", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "organization_id"), -) - - -class User(db.Model): - """ - App User - """ - - id = db.Column(db.Integer, primary_key=True) - uuid = db.Column(db.String(180), unique=True) - comment = db.Column(db.String(500)) - email = db.Column(db.String(255)) - name = db.Column(db.String(255)) - phone = db.Column(db.String(255)) - apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") - machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") - role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") - - organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") - - def __init__(self, uuid, name=None, phone=None, email=None, comment=None): - self.uuid = uuid - self.phone = phone - self.name = name - self.email = email - self.comment = comment - - def update(self, form): - """ - update the user with values from form object - :param form: flask form from request - :return: None - """ - self.uuid = form.uuid.data - self.name = form.name.data - self.email = form.email.data - self.phone = form.phone.data - self.comment = form.comment.data - - # first clear existing roles and orgs - for role in self.role: - self.role.remove(role) - for org in self.organization: - self.organization.remove(org) - - for role_id in form.role_ids.data: - my_role = db.session.query(Role).filter_by(id=role_id).first() - if my_role not in self.role: - self.role.append(my_role) - - for org_id in form.org_ids.data: - my_org = db.session.query(Organization).filter_by(id=org_id).first() - if my_org not in self.organization: - self.organization.append(my_org) - - db.session.commit() - - -class ApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=False) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="apikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="apikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class MachineApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=True) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="machineapikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="machineapikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class Role(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(20), unique=True) - description = db.Column(db.String(260)) - - def __init__(self, name, description): - self.name = name - self.description = description - - def __repr__(self): - return self.name - - -class Organization(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(150), unique=True) - arange = db.Column(db.Text) - limit_flowspec4 = db.Column(db.Integer, default=0) - limit_flowspec6 = db.Column(db.Integer, default=0) - limit_rtbh = db.Column(db.Integer, default=0) - - def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): - self.name = name - self.arange = arange - self.limit_flowspec4 = limit_flowspec4 - self.limit_flowspec6 = limit_flowspec6 - self.limit_rtbh = limit_rtbh - - def __repr__(self): - return self.name - - def get_users(self): - """ - Returns all users associated with this organization. - """ - # self.user is the backref from the user_organization relationship - return self.user - - -class ASPath(db.Model): - id = db.Column(db.Integer, primary_key=True) - prefix = db.Column(db.String(120), unique=True) - as_path = db.Column(db.String(250)) - - def __init__(self, prefix, as_path): - self.prefix = prefix - self.as_path = as_path - - def __repr__(self): - return f"{self.prefix} : {self.as_path}" - - -class Action(db.Model): - """ - Action for rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - command = db.Column(db.String(120), unique=True) - description = db.Column(db.String(260)) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="action") - - def __init__(self, name, command, description, role_id=2): - self.name = name - self.command = command - self.description = description - self.role_id = role_id - - -class Community(db.Model): - """ - Community for RTBH rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - comm = db.Column(db.String(2047)) - larcomm = db.Column(db.String(2047)) - extcomm = db.Column(db.String(2047)) - description = db.Column(db.String(255)) - as_path = db.Column(db.Boolean, default=False) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="community") - - def __init__(self, name, comm, larcomm, extcomm, description, as_path=False, role_id=2): - self.name = name - self.comm = comm - self.larcomm = larcomm - self.extcomm = extcomm - self.description = description - self.as_path = as_path - self.role_id = role_id - - -class Rstate(db.Model): - """ - State for Rules - """ - - id = db.Column(db.Integer, primary_key=True) - description = db.Column(db.String(260)) - - def __init__(self, description): - self.description = description - - -class RTBH(db.Model): - __tablename__ = "RTBH" - - id = db.Column(db.Integer, primary_key=True) - ipv4 = db.Column(db.String(255)) - ipv4_mask = db.Column(db.Integer) - ipv6 = db.Column(db.String(255)) - ipv6_mask = db.Column(db.Integer) - community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) - community = db.relationship("Community", backref="rtbh") - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="rtbh") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="rtbh") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="RTBH") - - def __init__( - self, - ipv4, - ipv4_mask, - ipv6, - ipv6_mask, - community_id, - expires, - user_id, - org_id, - comment=None, - created=None, - rstate_id=1, - ): - self.ipv4 = ipv4 - self.ipv4_mask = ipv4_mask - self.ipv6 = ipv6 - self.ipv6_mask = ipv6_mask - self.community_id = community_id - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.comment = comment - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - return ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - compars = ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def update_time(self, form): - self.expires = utils.webpicker_to_datetime(form.expire_date.data) - db.session.commit() - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict used in API - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "ipv4": self.ipv4, - "ipv4_mask": self.ipv4_mask, - "ipv6": self.ipv6, - "ipv6_mask": self.ipv6_mask, - "community": self.community.name, - "comment": self.comment, - "expires": expires, - "created": created, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec4(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - protocol = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - fragment = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec4") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec4") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec4") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec4") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - protocol, - flags, - packet_len, - fragment, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.protocol = protocol - self.flags = flags - self.packet_len = packet_len - self.fragment = fragment - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - compars = ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "protocol": self.protocol, - "flags": self.flags, - "packet_len": self.packet_len, - "fragment": self.fragment, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec6(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - next_header = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec6") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec6") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec6") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec6") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - next_header, - flags, - packet_len, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.next_header = next_header - self.flags = flags - self.packet_len = packet_len - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.next_header == other.next_header - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": str(self.id), - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "next_header": self.next_header, - "flags": self.flags, - "packet_len": self.packet_len, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Log(db.Model): - id = db.Column(db.Integer, primary_key=True) - time = db.Column(db.DateTime) - task = db.Column(db.String(1000)) - author = db.Column(db.String(1000)) - rule_type = db.Column(db.Integer) - rule_id = db.Column(db.Integer) - user_id = db.Column(db.Integer) - org_id = db.Column(db.Integer, nullable=True) - - def __init__(self, time, task, user_id, rule_type, rule_id, author, org_id=None): - self.time = time - self.task = task - self.rule_type = rule_type - self.rule_id = rule_id - self.user_id = user_id - self.author = author - self.org_id = org_id - - -# DDL -# default values for tables inserted after create - - -@event.listens_for(Action.__table__, "after_create") -def insert_initial_actions(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="QoS 100 kbps", - command="rate-limit 12800", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 1Mbps", - command="rate-limit 13107200", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 10Mbps", - command="rate-limit 131072000", - description="QoS", - role_id=2, - ) - ) - conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) - - -@event.listens_for(Community.__table__, "after_create") -def insert_initial_communities(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="65535:65283", - comm="65535:65283", - larcomm="", - extcomm="", - description="local-as", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64496:64511", - comm="64496:64511", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64497:64510", - comm="64497:64510", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - - -@event.listens_for(Role.__table__, "after_create") -def insert_initial_roles(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="view", description="just view, no edit")) - conn.execute(table.insert().values(name="user", description="can edit")) - conn.execute(table.insert().values(name="admin", description="admin")) - - -@event.listens_for(Organization.__table__, "after_create") -def insert_initial_organizations(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) - - -@event.listens_for(Rstate.__table__, "after_create") -def insert_initial_rulestates(table, conn, *args, **kwargs): - conn.execute(table.insert().values(description="active rule")) - conn.execute(table.insert().values(description="withdrawed rule")) - conn.execute(table.insert().values(description="deleted rule")) - - -# Misc functions -def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: - """ - Check if the organization has reached the rule limit - :param org_id: integer organization id - :param rule_type: RuleType rule type - :return: boolean - """ - flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the organization limits - org = Organization.query.filter_by(id=org_id).first() - if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: - count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec4 or fs4 >= flowspec_limit - if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: - count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec6 or fs6 >= flowspec_limit - if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: - count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_rtbh or rtbh >= rtbh_limit - - -def check_global_rule_limit(rule_type: RuleTypes) -> bool: - flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) - flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the global limits if the organization limits are not set - - if rule_type == RuleTypes.IPv4: - return fs4 >= flowspec4_limit - if rule_type == RuleTypes.IPv6: - return fs6 >= flowspec6_limit - if rule_type == RuleTypes.RTBH: - return rtbh >= rtbh_limit - - -def get_ipv4_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec4) - .filter( - Flowspec4.source == form_data["source"], - Flowspec4.source_mask == form_data["source_mask"], - Flowspec4.source_port == form_data["source_port"], - Flowspec4.dest == form_data["dest"], - Flowspec4.dest_mask == form_data["dest_mask"], - Flowspec4.dest_port == form_data["dest_port"], - Flowspec4.protocol == form_data["protocol"], - Flowspec4.flags == ";".join(form_data["flags"]), - Flowspec4.packet_len == form_data["packet_len"], - Flowspec4.action_id == form_data["action"], - Flowspec4.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_ipv6_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec6) - .filter( - Flowspec6.source == form_data["source"], - Flowspec6.source_mask == form_data["source_mask"], - Flowspec6.source_port == form_data["source_port"], - Flowspec6.dest == form_data["dest"], - Flowspec6.dest_mask == form_data["dest_mask"], - Flowspec6.dest_port == form_data["dest_port"], - Flowspec6.next_header == form_data["next_header"], - Flowspec6.flags == ";".join(form_data["flags"]), - Flowspec6.packet_len == form_data["packet_len"], - Flowspec6.action_id == form_data["action"], - Flowspec6.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_rtbh_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - - record = ( - db.session.query(RTBH) - .filter( - RTBH.ipv4 == form_data["ipv4"], - RTBH.ipv4_mask == form_data["ipv4_mask"], - RTBH.ipv6 == form_data["ipv6"], - RTBH.ipv6_mask == form_data["ipv6_mask"], - RTBH.community_id == form_data["community"], - RTBH.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def insert_users(users): - """ - inser list of users {name: string, role_id: integer} to db - """ - for user in users: - r = Role.query.filter_by(id=user["role_id"]).first() - o = Organization.query.filter_by(id=user["org_id"]).first() - u = User(uuid=user["name"]) - u.role.append(r) - u.organization.append(o) - db.session.add(u) - - db.session.commit() - - -def insert_user( - uuid: str, - role_ids: list, - org_ids: list, - name: str = None, - phone: str = None, - email: str = None, - comment: str = None, -): - """ - insert new user with multiple roles and organizations - :param uuid: string unique user id (eppn or similar) - :param phone: string phone number - :param name: string user name - :param email: string email - :param comment: string comment / notice - :param role_ids: list of roles - :param org_ids: list of orgs - :return: None - """ - u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) - - for role_id in role_ids: - r = Role.query.filter_by(id=role_id).first() - u.role.append(r) - - for org_id in org_ids: - o = Organization.query.filter_by(id=org_id).first() - u.organization.append(o) - - db.session.add(u) - db.session.commit() - - -def get_user_nets(user_id): - """ - Return list of network ranges for all user organization - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - result = [] - for org in orgs: - result.extend(org.arange.split()) - - return result - - -def get_user_orgs_choices(user_id): - """ - Return list of orgs as choices for form - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - - return [(g.id, g.name) for g in orgs] - - -def get_user_actions(user_roles): - """ - Return list of actions based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - actions = db.session.query(Action).order_by("id") - else: - actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in actions] - - -def get_user_communities(user_roles): - """ - Return list of communities based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - communities = db.session.query(Community).order_by("id") - else: - communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in communities] - - -def get_existing_action(name=None, command=None): - """ - return Action with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - action = Action.query.filter((Action.name == name) | (Action.command == command)).first() - return action.id if hasattr(action, "id") else None - - -def get_existing_community(name=None): - """ - return Community with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - community = Community.query.filter(Community.name == name).first() - return community.id if hasattr(community, "id") else None - - -def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): - """ - Returns list of rules sorted by sort column ordered asc or desc - :param sort: sorting column - :param order: asc or desc - :return: list - """ - - today = datetime.now() - comp_func = utils.get_comp_func(rule_state) - - if rule_type == "ipv4": - sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) - sorting_ip4 = getattr(sorter_ip4, order) - if comp_func: - rules4 = ( - db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() - ) - else: - rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() - - return rules4 - - if rule_type == "ipv6": - sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) - sorting_ip6 = getattr(sorter_ip6, order) - if comp_func: - rules6 = ( - db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() - ) - else: - rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() - - return rules6 - - if rule_type == "rtbh": - sorter_rtbh = getattr(RTBH, sort, RTBH.id) - sorting_rtbh = getattr(sorter_rtbh, order) - - if comp_func: - rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() - - else: - rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() - - return rules_rtbh - - -def get_user_rules_ids(user_id, rule_type): - """ - Returns list of rule ids belonging to user - :param user_id: user id - :param rule_type: ipv4, ipv6 or rtbh - :return: list - """ - - if rule_type == "ipv4": - rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules4] - - if rule_type == "ipv6": - rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() - return [int(x[0]) for x in rules6] - - if rule_type == "rtbh": - rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/models/__init__.py b/flowapp/models/__init__.py new file mode 100644 index 00000000..64055059 --- /dev/null +++ b/flowapp/models/__init__.py @@ -0,0 +1,71 @@ +# Import and re-export all models and functions for backward compatibility +from .base import db, user_role, user_organization + +# User-related models +from .user import User, Role +from .api import ApiKey, MachineApiKey +from .organization import Organization + +# Rule-related models +from .rules import Flowspec4, Flowspec6, RTBH, Rstate, Action, Whitelist, RuleWhitelistCache + +# Other models +from .community import Community, ASPath, insert_initial_communities +from .log import Log + +# Helper functions +from .utils import ( + get_user_nets, + get_user_actions, + get_user_communities, + get_existing_action, + get_existing_community, + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + get_whitelist_model_if_exists, + get_ip_rules, + get_user_rules_ids, + insert_users, + insert_user, + check_rule_limit, + check_global_rule_limit, +) + +# Ensure all models are registered properly +__all__ = [ + "db", + "User", + "Role", + "user_role", + "ApiKey", + "MachineApiKey", + "Organization", + "user_organization", + "Action", + "Flowspec4", + "Flowspec6", + "RTBH", + "Rstate", + "Community", + "ASPath", + "Log", + "Whitelist", + "RuleWhitelistCache", + "get_user_nets", + "get_user_actions", + "get_user_communities", + "get_existing_action", + "get_existing_community", + "get_ipv4_model_if_exists", + "get_ipv6_model_if_exists", + "get_rtbh_model_if_exists", + "get_whitelist_model_if_exists", + "get_ip_rules", + "get_user_rules_ids", + "insert_users", + "insert_user", + "check_rule_limit", + "check_global_rule_limit", + "insert_initial_communities", +] diff --git a/flowapp/models/api.py b/flowapp/models/api.py new file mode 100644 index 00000000..ed506850 --- /dev/null +++ b/flowapp/models/api.py @@ -0,0 +1,40 @@ +from datetime import datetime +from .base import db + + +class ApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=False) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="apikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="apikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() + + +class MachineApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=True) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="machineapikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="machineapikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() diff --git a/flowapp/models/base.py b/flowapp/models/base.py new file mode 100644 index 00000000..a934ce83 --- /dev/null +++ b/flowapp/models/base.py @@ -0,0 +1,16 @@ +from flowapp import db + +# Define shared tables +user_role = db.Table( + "user_role", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "role_id"), +) + +user_organization = db.Table( + "user_organization", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "organization_id"), +) diff --git a/flowapp/models/community.py b/flowapp/models/community.py new file mode 100644 index 00000000..880a837a --- /dev/null +++ b/flowapp/models/community.py @@ -0,0 +1,79 @@ +from sqlalchemy import event +from .base import db + + +class Community(db.Model): + """Community for RTBH rule""" + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + comm = db.Column(db.String(2047)) + larcomm = db.Column(db.String(2047)) + extcomm = db.Column(db.String(2047)) + description = db.Column(db.String(255)) + as_path = db.Column(db.Boolean, default=False) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="community") + + def __init__(self, name, comm, larcomm, extcomm, description, as_path, role_id): + self.name = name + self.comm = comm + self.larcomm = larcomm + self.extcomm = extcomm + self.description = description + self.as_path = as_path + self.role_id = role_id + + @classmethod + def get_whitelistable_communities(cls, id_list): + return cls.query.filter(cls.id.in_(id_list)).all() + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.name}" + + +class ASPath(db.Model): + """AS Path for RTBH rules""" + + id = db.Column(db.Integer, primary_key=True) + prefix = db.Column(db.String(120), unique=True) + as_path = db.Column(db.String(250)) + + # Methods and initializer + + +@event.listens_for(Community.__table__, "after_create") +def insert_initial_communities(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="65535:65283", + comm="65535:65283", + larcomm="", + extcomm="", + description="local-as", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64496:64511", + comm="64496:64511", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64497:64510", + comm="64497:64510", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) diff --git a/flowapp/models/log.py b/flowapp/models/log.py new file mode 100644 index 00000000..16cf5b13 --- /dev/null +++ b/flowapp/models/log.py @@ -0,0 +1,39 @@ +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes +from .base import db + + +class Log(db.Model): + """Log model for system actions""" + + id = db.Column(db.Integer, primary_key=True) + time = db.Column(db.DateTime) + task = db.Column(db.String(1000)) + author = db.Column(db.String(1000)) + rule_type = db.Column(db.Integer) + rule_id = db.Column(db.Integer) + user_id = db.Column(db.Integer) + + def __init__(self, time, task, user_id, rule_type, rule_id, author): + self.time = time + self.task = task + self.rule_type = rule_type + self.rule_id = rule_id + self.user_id = user_id + self.author = author + + @classmethod + def delete_old(cls, days: int = 30): + """Delete logs older than :param days from the database""" + cls.query.filter(cls.time < datetime.now() - timedelta(days=days)).delete() + db.session.commit() + + def __repr__(self): + return f"" + + def __str__(self): + """ + {"author": "vrany@cesnet.cz / Cel\u00fd sv\u011bt", "source": "UI", "command": "cmd"} + """ + return f"{self.author} - {RuleTypes(self.rule_type).name}({self.rule_id}) - {self.task}" diff --git a/flowapp/models/organization.py b/flowapp/models/organization.py new file mode 100644 index 00000000..baf0ec1f --- /dev/null +++ b/flowapp/models/organization.py @@ -0,0 +1,35 @@ +from sqlalchemy import event +from .base import db + + +class Organization(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(150), unique=True) + arange = db.Column(db.Text) + limit_flowspec4 = db.Column(db.Integer, default=0) + limit_flowspec6 = db.Column(db.Integer, default=0) + limit_rtbh = db.Column(db.Integer, default=0) + + def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): + self.name = name + self.arange = arange + self.limit_flowspec4 = limit_flowspec4 + self.limit_flowspec6 = limit_flowspec6 + self.limit_rtbh = limit_rtbh + + def __repr__(self): + return self.name + + def get_users(self): + """ + Returns all users associated with this organization. + """ + # self.user is the backref from the user_organization relationship + return self.user + + +# Event listeners for Organization +@event.listens_for(Organization.__table__, "after_create") +def insert_initial_organizations(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="TU Liberec", arange="147.230.0.0/16\n2001:718:1c01::/48")) + conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) diff --git a/flowapp/models/rules/__init__.py b/flowapp/models/rules/__init__.py new file mode 100644 index 00000000..da12ba45 --- /dev/null +++ b/flowapp/models/rules/__init__.py @@ -0,0 +1,6 @@ +from .flowspec import Flowspec4, Flowspec6 +from .rtbh import RTBH +from .base import Rstate, Action +from .whitelist import Whitelist, RuleWhitelistCache + +__all__ = ["Flowspec4", "Flowspec6", "RTBH", "Rstate", "Action", "Whitelist", "RuleWhitelistCache"] diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py new file mode 100644 index 00000000..22fbc089 --- /dev/null +++ b/flowapp/models/rules/base.py @@ -0,0 +1,69 @@ +from sqlalchemy import event +from ..base import db + + +class Rstate(db.Model): + """State for Rules""" + + id = db.Column(db.Integer, primary_key=True) + description = db.Column(db.String(260)) + + def __init__(self, description): + self.description = description + + +class Action(db.Model): + """ + Action for rule + """ + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + command = db.Column(db.String(120), unique=True) + description = db.Column(db.String(260)) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="action") + + def __init__(self, name, command, description, role_id=2): + self.name = name + self.command = command + self.description = description + self.role_id = role_id + + +# Event listeners for Rstate +@event.listens_for(Rstate.__table__, "after_create") +def insert_initial_rulestates(table, conn, *args, **kwargs): + conn.execute(table.insert().values(description="active rule")) + conn.execute(table.insert().values(description="withdrawed rule")) + conn.execute(table.insert().values(description="deleted rule")) + conn.execute(table.insert().values(description="whitelisted rule")) + + +@event.listens_for(Action.__table__, "after_create") +def insert_initial_actions(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="QoS 100 kbps", + command="rate-limit 12800", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 1Mbps", + command="rate-limit 13107200", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 10Mbps", + command="rate-limit 131072000", + description="QoS", + role_id=2, + ) + ) + conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) diff --git a/flowapp/models/rules/flowspec.py b/flowapp/models/rules/flowspec.py new file mode 100644 index 00000000..9823aba8 --- /dev/null +++ b/flowapp/models/rules/flowspec.py @@ -0,0 +1,294 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class Flowspec4(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + protocol = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + fragment = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec4") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec4") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec4") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec4") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + protocol, + flags, + packet_len, + fragment, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.protocol = protocol + self.flags = flags + self.packet_len = packet_len + self.fragment = fragment + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + compars = ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "protocol": self.protocol, + "flags": self.flags, + "packet_len": self.packet_len, + "fragment": self.fragment, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) + + +class Flowspec6(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + next_header = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec6") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec6") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec6") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec6") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + next_header, + flags, + packet_len, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.next_header = next_header + self.flags = flags + self.packet_len = packet_len + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.next_header == other.next_header + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": str(self.id), + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "next_header": self.next_header, + "flags": self.flags, + "packet_len": self.packet_len, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py new file mode 100644 index 00000000..943fef0f --- /dev/null +++ b/flowapp/models/rules/rtbh.py @@ -0,0 +1,153 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class RTBH(db.Model): + __tablename__ = "RTBH" + + id = db.Column(db.Integer, primary_key=True) + ipv4 = db.Column(db.String(255)) + ipv4_mask = db.Column(db.Integer) + ipv6 = db.Column(db.String(255)) + ipv6_mask = db.Column(db.Integer) + community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) + community = db.relationship("Community", backref="rtbh") + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="rtbh") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="rtbh") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="RTBH") + + def __init__( + self, + ipv4, + ipv4_mask, + ipv6, + ipv6_mask, + community_id, + expires, + user_id, + org_id, + comment=None, + created=None, + rstate_id=1, + ): + self.ipv4 = ipv4 + self.ipv4_mask = ipv4_mask + self.ipv6 = ipv6 + self.ipv6_mask = ipv6_mask + self.community_id = community_id + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + return ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + compars = ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def update_time(self, form): + self.expires = utils.webpicker_to_datetime(form.expire_date.data) + db.session.commit() + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict used in API + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ipv4": self.ipv4, + "ipv4_mask": self.ipv4_mask, + "ipv6": self.ipv6, + "ipv6_mask": self.ipv6_mask, + "community": self.community.name, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) + + def __repr__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"" + if not self.ipv4 and not self.ipv4_mask: + return f"" + + return f"" + + def __str__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"{self.ipv4}/{self.ipv4_mask}" + if not self.ipv4 and not self.ipv4_mask: + return f"{self.ipv6}/{self.ipv6_mask}" + + return f"{self.ipv4}/{self.ipv4_mask} {self.ipv6}/{self.ipv6_mask}" + + def get_author(self): + return f"{self.user.email} / {self.org}" diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py new file mode 100644 index 00000000..83c131f6 --- /dev/null +++ b/flowapp/models/rules/whitelist.py @@ -0,0 +1,172 @@ +from flowapp import utils +from ..base import db +from datetime import datetime +from flowapp.constants import RuleTypes, RuleOrigin + + +class Whitelist(db.Model): + id = db.Column(db.Integer, primary_key=True) + ip = db.Column(db.String(255)) + mask = db.Column(db.Integer) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="whitelist") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="whitelist") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="whitelist") + + def __init__( + self, + ip, + mask, + expires, + user_id, + org_id, + created=None, + comment=None, + rstate_id=1, + ): + self.ip = ip + self.mask = mask + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two whitelists are equal if all the network parameters equals. + User_id, org, comment and time fields can differ. + :param other: other Whitelist instance + :return: boolean + """ + return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ip": self.ip, + "mask": self.mask, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.ip}/{self.mask}" + + +class RuleWhitelistCache(db.Model): + """ + Cache for whitelisted rules + For each rule we store id and type + Rule origin determines if the rule was created by user or by whitelist + """ + + id = db.Column(db.Integer, primary_key=True) + rid = db.Column(db.Integer) + rtype = db.Column(db.Integer) + rorigin = db.Column(db.Integer) + whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey + whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") + + def __init__(self, rid: int, rtype: RuleTypes, whitelist_id: int, rorigin: RuleOrigin = RuleOrigin.USER): + self.rid = rid + self.rtype = rtype.value + self.rorigin = rorigin.value + self.whitelist_id = whitelist_id + + @classmethod + def get_by_whitelist_id(cls, whitelist_id: int): + """ + Get all cache items with the given whitelist ID + + Args: + whitelist_id (int): The ID of the whitelist to filter by + + Returns: + list: All RuleWhitelistCache objects with the specified whitelist_id + """ + return cls.query.filter_by(whitelist_id=whitelist_id).all() + + @classmethod + def clean_by_whitelist_id(cls, whitelist_id: int): + """ + Delete all cache entries with the given whitelist ID from the database + + Args: + whitelist_id (int): The ID of the whitelist to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(whitelist_id=whitelist_id).delete() + db.session.commit() + return deleted + + @classmethod + def delete_by_rule_id(cls, rule_id: int): + """ + Delete all cache entries with the given rule ID from the database + + Args: + rule_id (int): The ID of the rule to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(rid=rule_id).delete() + db.session.commit() + return deleted + + @classmethod + def count_by_rule(cls, rule_id: int, rule_type: RuleTypes): + """ + Count the number of cache entries for the given rule + + Args: + rule_id (int): The ID of the rule to count + rule_type (RuleTypes): The type of the rule + + Returns: + int: Number of cache entries + """ + return cls.query.filter_by(rid=rule_id, rtype=rule_type.value).count() + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.rid} {self.rtype} {self.rorigin}" diff --git a/flowapp/models/user.py b/flowapp/models/user.py new file mode 100644 index 00000000..dcb2d7eb --- /dev/null +++ b/flowapp/models/user.py @@ -0,0 +1,79 @@ +from sqlalchemy import event +from .base import db, user_role, user_organization +from .organization import Organization + + +class User(db.Model): + """ + App User + """ + + id = db.Column(db.Integer, primary_key=True) + uuid = db.Column(db.String(180), unique=True) + comment = db.Column(db.String(500)) + email = db.Column(db.String(255)) + name = db.Column(db.String(255)) + phone = db.Column(db.String(255)) + apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") + machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") + role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") + + organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") + + def __init__(self, uuid, name=None, phone=None, email=None, comment=None): + self.uuid = uuid + self.phone = phone + self.name = name + self.email = email + self.comment = comment + + def update(self, form): + """ + update the user with values from form object + :param form: flask form from request + :return: None + """ + self.uuid = form.uuid.data + self.name = form.name.data + self.email = form.email.data + self.phone = form.phone.data + self.comment = form.comment.data + + # first clear existing roles and orgs + for role in self.role: + self.role.remove(role) + for org in self.organization: + self.organization.remove(org) + + for role_id in form.role_ids.data: + my_role = db.session.query(Role).filter_by(id=role_id).first() + if my_role not in self.role: + self.role.append(my_role) + + for org_id in form.org_ids.data: + my_org = db.session.query(Organization).filter_by(id=org_id).first() + if my_org not in self.organization: + self.organization.append(my_org) + + db.session.commit() + + +class Role(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(20), unique=True) + description = db.Column(db.String(260)) + + def __init__(self, name, description): + self.name = name + self.description = description + + def __repr__(self): + return self.name + + +# Event listeners for Role +@event.listens_for(Role.__table__, "after_create") +def insert_initial_roles(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="view", description="just view, no edit")) + conn.execute(table.insert().values(name="user", description="can edit")) + conn.execute(table.insert().values(name="admin", description="admin")) diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py new file mode 100644 index 00000000..d0bf1208 --- /dev/null +++ b/flowapp/models/utils.py @@ -0,0 +1,377 @@ +"""Utility functions for models""" + +from datetime import datetime +from flowapp import utils +from flowapp.constants import RuleTypes +from flask import current_app + +from flowapp.models.rules.whitelist import Whitelist +from .base import db +from .user import User, Role +from .organization import Organization +from .community import Community +from .rules.flowspec import Flowspec4, Flowspec6 +from .rules.rtbh import RTBH +from .rules.base import Action + + +def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: + """ + Check if the organization has reached the rule limit + :param org_id: integer organization id + :param rule_type: RuleType rule type + :return: boolean + """ + flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the organization limits + org = Organization.query.filter_by(id=org_id).first() + if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: + count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec4 or fs4 >= flowspec_limit + if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: + count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec6 or fs6 >= flowspec_limit + if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: + count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_rtbh or rtbh >= rtbh_limit + + return False + + +def check_global_rule_limit(rule_type: RuleTypes) -> bool: + flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) + flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the global limits if the organization limits are not set + + if rule_type == RuleTypes.IPv4: + return fs4 >= flowspec4_limit + if rule_type == RuleTypes.IPv6: + return fs6 >= flowspec6_limit + if rule_type == RuleTypes.RTBH: + return rtbh >= rtbh_limit + + return False + + +def get_whitelist_model_if_exists(form_data): + """ + Check if the record in database exist + ip, mask should match + expires, rstate_id, user_id, org_id, created, comment can be different + """ + record = ( + db.session.query(Whitelist) + .filter( + Whitelist.ip == form_data["ip"], + Whitelist.mask == form_data["mask"], + ) + .first() + ) + + if record: + return record + + return False + + +def get_ipv4_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + Source and destination addresses, protocol, flags, action and packet_len should match + Other fields can be different + """ + record = ( + db.session.query(Flowspec4) + .filter( + Flowspec4.source == form_data["source"], + Flowspec4.source_mask == form_data["source_mask"], + Flowspec4.source_port == form_data["source_port"], + Flowspec4.dest == form_data["dest"], + Flowspec4.dest_mask == form_data["dest_mask"], + Flowspec4.dest_port == form_data["dest_port"], + Flowspec4.protocol == form_data["protocol"], + Flowspec4.flags == ";".join(form_data["flags"]), + Flowspec4.packet_len == form_data["packet_len"], + Flowspec4.action_id == form_data["action"], + Flowspec4.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_ipv6_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + record = ( + db.session.query(Flowspec6) + .filter( + Flowspec6.source == form_data["source"], + Flowspec6.source_mask == form_data["source_mask"], + Flowspec6.source_port == form_data["source_port"], + Flowspec6.dest == form_data["dest"], + Flowspec6.dest_mask == form_data["dest_mask"], + Flowspec6.dest_port == form_data["dest_port"], + Flowspec6.next_header == form_data["next_header"], + Flowspec6.flags == ";".join(form_data["flags"]), + Flowspec6.packet_len == form_data["packet_len"], + Flowspec6.action_id == form_data["action"], + Flowspec6.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_rtbh_model_if_exists(form_data): + """ + Check RTBH rule exist in database + IPv4, IPv6 and community_id should match + Rule can be in any state and have different expires, user_id, org_id, created, comment + """ + + record = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == form_data["ipv4"], + RTBH.ipv4_mask == form_data["ipv4_mask"], + RTBH.ipv6 == form_data["ipv6"], + RTBH.ipv6_mask == form_data["ipv6_mask"], + RTBH.community_id == form_data["community"], + ) + .first() + ) + + if record: + return record + + return False + + +def insert_users(users): + """ + inser list of users {name: string, role_id: integer} to db + """ + for user in users: + r = Role.query.filter_by(id=user["role_id"]).first() + o = Organization.query.filter_by(id=user["org_id"]).first() + u = User(uuid=user["name"]) + u.role.append(r) + u.organization.append(o) + db.session.add(u) + + db.session.commit() + + +def insert_user( + uuid: str, + role_ids: list, + org_ids: list, + name: str = None, + phone: str = None, + email: str = None, + comment: str = None, +): + """ + insert new user with multiple roles and organizations + :param uuid: string unique user id (eppn or similar) + :param phone: string phone number + :param name: string user name + :param email: string email + :param comment: string comment / notice + :param role_ids: list of roles + :param org_ids: list of orgs + :return: None + """ + u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) + + for role_id in role_ids: + r = Role.query.filter_by(id=role_id).first() + u.role.append(r) + + for org_id in org_ids: + o = Organization.query.filter_by(id=org_id).first() + u.organization.append(o) + + db.session.add(u) + db.session.commit() + + +def get_user_nets(user_id): + """ + Return list of network ranges for all user organization + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + result = [] + for org in orgs: + result.extend(org.arange.split()) + + return result + + +def get_user_orgs_choices(user_id): + """ + Return list of orgs as choices for form + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + + return [(g.id, g.name) for g in orgs] + + +def get_user_actions(user_roles): + """ + Return list of actions based on current user role + """ + max_role = max(user_roles) + print(max_role) + if max_role == 3: + actions = db.session.query(Action).order_by("id").all() + else: + actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id").all() + result = [(g.id, g.name) for g in actions] + print(actions, result) + return result + + +def get_user_communities(user_roles): + """ + Return list of communities based on current user role + """ + max_role = max(user_roles) + if max_role == 3: + communities = db.session.query(Community).order_by("id") + else: + communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") + + return [(g.id, g.name) for g in communities] + + +def get_existing_action(name=None, command=None): + """ + return Action with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + action = Action.query.filter((Action.name == name) | (Action.command == command)).first() + return action.id if hasattr(action, "id") else None + + +def get_existing_community(name=None): + """ + return Community with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + community = Community.query.filter(Community.name == name).first() + return community.id if hasattr(community, "id") else None + + +def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): + """ + Returns list of rules sorted by sort column ordered asc or desc + :param sort: sorting column + :param order: asc or desc + :return: list + """ + + today = datetime.now() + comp_func = utils.get_comp_func(rule_state) + + if rule_type == "ipv4": + sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) + sorting_ip4 = getattr(sorter_ip4, order) + if comp_func: + rules4 = ( + db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() + ) + else: + rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() + + return rules4 + + if rule_type == "ipv6": + sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) + sorting_ip6 = getattr(sorter_ip6, order) + if comp_func: + rules6 = ( + db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() + ) + else: + rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() + + return rules6 + + if rule_type == "rtbh": + sorter_rtbh = getattr(RTBH, sort, RTBH.id) + sorting_rtbh = getattr(sorter_rtbh, order) + + if comp_func: + rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() + + else: + rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() + + return rules_rtbh + + if rule_type == "whitelist": + sorter_whitelist = getattr(Whitelist, sort, Whitelist.id) + sorting_whitelist = getattr(sorter_whitelist, order) + + if comp_func: + rules_whitelist = ( + db.session.query(Whitelist) + .filter(comp_func(Whitelist.expires, today)) + .order_by(sorting_whitelist()) + .all() + ) + + else: + rules_whitelist = db.session.query(Whitelist).order_by(sorting_whitelist()).all() + + return rules_whitelist + + +def get_user_rules_ids(user_id, rule_type): + """ + Returns list of rule ids belonging to user + :param user_id: user id + :param rule_type: ipv4, ipv6 or rtbh + :return: list + """ + + if rule_type == "ipv4": + rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules4] + + if rule_type == "ipv6": + rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() + return [int(x[0]) for x in rules6] + + if rule_type == "rtbh": + rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/output.py b/flowapp/output.py index 3dde8221..0d023671 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, asdict from datetime import datetime +from typing import Dict, Union, Callable, Any import requests import pika @@ -11,9 +12,12 @@ from flask import current_app from flowapp import db, messages +from flowapp.constants import RuleTypes from flowapp.models import Log +from flowapp.models import Flowspec4, Flowspec6, RTBH -ROUTE_MODELS = { + +ROUTE_MODELS: Dict[int, Callable[[Any], str]] = { 1: messages.create_rtbh, 4: messages.create_ipv4, 6: messages.create_ipv6, @@ -21,21 +25,21 @@ class RouteSources: - UI = "UI" - API = "API" + UI: str = "UI" + API: str = "API" @dataclass class Route: author: str - source: RouteSources + source: str # Using str instead of RouteSources for flexibility command: str - def __dict__(self): + def __dict__(self) -> Dict[str, str]: return asdict(self) -def announce_route(route: Route): +def announce_route(route: Route) -> None: """ Dispatch route as dict to ExaBGP API API must be set in app config.py @@ -48,7 +52,7 @@ def announce_route(route: Route): announce_to_http(asdict(route)) -def announce_to_http(route): +def announce_to_http(route: Dict[str, str]) -> None: """ Announce route to ExaBGP HTTP API process """ @@ -64,7 +68,7 @@ def announce_to_http(route): current_app.logger.debug(f"Testing: {route}") -def announce_to_rabbitmq(route): +def announce_to_rabbitmq(route: Dict[str, str]) -> None: """ Announce rout to ExaBGP RabbitMQ API process """ @@ -80,48 +84,57 @@ def announce_to_rabbitmq(route): credentials, ) - connection = pika.BlockingConnection(parameters) - channel = connection.channel() - channel.queue_declare(queue=queue) - channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) + with pika.BlockingConnection(parameters) as connection: + channel = connection.channel() + channel.queue_declare(queue=queue) + channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) else: - current_app.logger.debug("Testing: {route}") + current_app.logger.debug(f"Testing: {route}") -def log_route(user_id, route_model, rule_type, author): +def log_route(user_id: int, route_model: Union[RTBH, Flowspec4, Flowspec6], rule_type: RuleTypes, author: str) -> None: """ Convert route to EXAFS message and log it to database :param user_id : int curent user - :param route_model: model with route object - :param rule_type: string + :param route_model: model with route object (RTBH, Flowspec4, or Flowspec6) + :param rule_type: RuleTypes enum + :param author: str name of the author :return: None """ - converter = ROUTE_MODELS[rule_type] + rule_type_value = rule_type.value + converter = ROUTE_MODELS[rule_type_value] task = converter(route_model) log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type_value, rule_id=route_model.id, user_id=user_id, author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() -def log_withdraw(user_id, task, rule_type, deleted_id, author): +def log_withdraw(user_id: int, task: str, rule_type: RuleTypes, deleted_id: int, author: str) -> None: """ Log the withdraw command to database - :param task: command message + :param user_id: int user ID + :param task: str command message + :param rule_type: RuleTypes enum + :param deleted_id: int ID of deleted rule + :param author: str name of the author + :return: None """ log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type.value, rule_id=deleted_id, user_id=user_id, author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py new file mode 100644 index 00000000..be4b838b --- /dev/null +++ b/flowapp/services/__init__.py @@ -0,0 +1,21 @@ +from .rule_service import ( + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, + reactivate_rule, +) + +from .whitelist_service import create_or_update_whitelist, delete_whitelist, delete_expired_whitelists + +from .base import announce_all_routes + +__all__ = [ + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, + create_or_update_whitelist, + delete_whitelist, + delete_expired_whitelists, + announce_all_routes, + reactivate_rule, +] diff --git a/flowapp/services/base.py b/flowapp/services/base.py new file mode 100644 index 00000000..77adeb07 --- /dev/null +++ b/flowapp/services/base.py @@ -0,0 +1,93 @@ +from datetime import datetime +from operator import ge, lt +from flowapp import constants, db, messages +from flowapp.constants import ANNOUNCE, WITHDRAW +from flowapp.models import RTBH, Flowspec4, Flowspec6 +from flowapp.output import Route, RouteSources, announce_route + + +def announce_rtbh_route(model: RTBH, author: str) -> None: + """ + Announce RTBH route if rule is in active state + """ + if model.rstate_id == 1: + command = messages.create_rtbh(model, ANNOUNCE) + route = Route( + author=author, + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + +def withdraw_rtbh_route(model: RTBH) -> None: + """ + Withdraw RTBH route if rule is in whitelist state + """ + if model.rstate_id == 4: + command = messages.create_rtbh(model, WITHDRAW) + route = Route( + author=model.get_author(), + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + +def announce_all_routes(action=constants.ANNOUNCE): + """ + get routes from db and send it to ExaBGB api + + @TODO take the request away, use some kind of messaging (maybe celery?) + :param action: action with routes - announce valid routes or withdraw expired routes + """ + today = datetime.now() + comp_func = ge if action == constants.ANNOUNCE else lt + + rules4 = ( + db.session.query(Flowspec4) + .filter(Flowspec4.rstate_id == 1) + .filter(comp_func(Flowspec4.expires, today)) + .order_by(Flowspec4.expires.desc()) + .all() + ) + rules6 = ( + db.session.query(Flowspec6) + .filter(Flowspec6.rstate_id == 1) + .filter(comp_func(Flowspec6.expires, today)) + .order_by(Flowspec6.expires.desc()) + .all() + ) + rules_rtbh = ( + db.session.query(RTBH) + .filter(RTBH.rstate_id == 1) + .filter(comp_func(RTBH.expires, today)) + .order_by(RTBH.expires.desc()) + .all() + ) + + messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] + messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] + messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] + + messages_all = [] + messages_all.extend(messages_v4) + messages_all.extend(messages_v6) + messages_all.extend(messages_rtbh) + + author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" + + for command in messages_all: + route = Route( + author=f"System call / {author_action} rules", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + if action == constants.WITHDRAW: + for ruleset in [rules4, rules6, rules_rtbh]: + for rule in ruleset: + rule.rstate_id = 2 + + db.session.commit() diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py new file mode 100644 index 00000000..997bd19c --- /dev/null +++ b/flowapp/services/rule_service.py @@ -0,0 +1,538 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" + +from datetime import datetime, timedelta +from typing import Dict, List, Tuple, Union + +from flask import current_app + +from flowapp import db, messages +from flowapp.constants import WITHDRAW, RuleOrigin, RuleTypes, ANNOUNCE +from flowapp.models import ( + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + Flowspec4, + Flowspec6, + RTBH, + Whitelist, +) +from flowapp.models.rules.whitelist import RuleWhitelistCache +from flowapp.models.utils import check_global_rule_limit, check_rule_limit +from flowapp.output import ROUTE_MODELS, Route, announce_route, log_route, RouteSources, log_withdraw +from flowapp.services.base import announce_rtbh_route +from flowapp.services.whitelist_common import ( + Relation, + add_rtbh_rule_to_cache, + create_rtbh_from_whitelist_parts, + subtract_network, + whitelist_rtbh_rule, + check_rule_against_whitelists, +) +from flowapp.services.whitelist_service import create_or_update_whitelist +from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent + + +def reactivate_rule( + rule_type: RuleTypes, + rule_id: int, + expires: datetime, + comment: str, + user_id: int, + org_id: int, + user_email: str, + org_name: str, +) -> Tuple[Union[RTBH, Flowspec4, Flowspec6], List[str]]: + """ + Reactivate a rule by setting a new expiration time. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to reactivate + expires: New expiration datetime + comment: Updated comment + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, messages) + """ + model_name = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + model = db.session.get(model_name, rule_id) + if not model: + return None, ["Rule not found"] + + flashes = [] + + # Check if rule will be reactivated + state = get_state_by_time(expires) + + # Check global limit + if state == 1 and check_global_rule_limit(rule_type.value): + return model, ["global_limit_reached"] + + # Check org limit + if state == 1 and check_rule_limit(org_id, rule_type=rule_type.value): + return model, ["limit_reached"] + + # Set new expiration date + model.expires = expires + # Set again the active state, if the rule is not whitelisted RTBH + if rule_type == RuleTypes.RTBH: + model = check_rtbh_whitelisted(model, user_id, flashes, f"{user_email} / {org_name}") + else: + model.rstate_id = state + + model.comment = comment + db.session.commit() + + route_model = ROUTE_MODELS[rule_type.value] + + if model.rstate_id == 1: + flashes.append("Rule successfully updated, state set to active.") + # Announce route + command = route_model(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_route( + user_id, + model, + rule_type, + f"{user_email} / {org_name}", + ) + else: + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + if model.rstate_id == 4: + flashes.append("Rule successfully updated, state set to whitelisted.") + else: + flashes.append("Rule successfully updated, state set to inactive.") + + return model, flashes + + +def create_or_update_ipv4_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec4, str]: + """ + Create a new IPv4 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv4_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec4( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + protocol=form_data["protocol"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + fragment=";".join(form_data["fragment"]), + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv4 Rule saved" + + db.session.commit() + + # Announce route if model is in active state + if model.rstate_id == 1: + command = messages.create_ipv4(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv4, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_ipv6_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec6, str]: + """ + Create a new IPv6 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv6_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec6( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + next_header=form_data["next_header"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv6 Rule saved" + + db.session.commit() + + # Announce routes + if model.rstate_id == 1: + command = messages.create_ipv6(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv6, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_rtbh_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[RTBH, List]: + """ + Create a new RTBH rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_rtbh_model_if_exists(form_data) + flashes = [] + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flashes.append("Existing RTBH Rule found. Expiration time was updated to new value.") + else: + # Create new model + model = RTBH( + ipv4=form_data["ipv4"], + ipv4_mask=form_data["ipv4_mask"], + ipv6=form_data["ipv6"], + ipv6_mask=form_data["ipv6_mask"], + community_id=form_data["community"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flashes.append("RTBH Rule saved") + + db.session.commit() + + # rule author for logging and announcing + author = f"{user_email} / {org_name}" + # Check if rule is whitelisted + model = check_rtbh_whitelisted(model, user_id, flashes, author) + + announce_rtbh_route(model, author=author) + # Log changes + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + + +def check_rtbh_whitelisted(model: RTBH, user_id: int, flashes: List[str], author: str) -> None: + # Check if rule is whitelisted + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + if model.community_id in allowed_communities: + # get all not expired whitelists + whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + wl_cache = map_whitelists_to_strings(whitelists) + results = check_rule_against_whitelists(str(model), wl_cache.keys()) + # check rule against whitelists + model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) + return model + + +def evaluate_rtbh_against_whitelists_check_results( + user_id: int, + model: RTBH, + flashes: List[str], + author: str, + wl_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> RTBH: + """ + Evaluate RTBH rule against whitelist check results. + Process all results for cases where rule is whitelisted by several whitelists. + """ + for rule, whitelist_key, relation in results: + match relation: + case Relation.EQUAL: + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + msg = f"RTBH Rule {model.id} {model} is equal to active whitelist {whitelist_key}. Rule is whitelisted." + flashes.append(msg) + current_app.logger.info(msg) + case Relation.SUBNET: + parts = subtract_network(target=str(model), whitelist=whitelist_key) + wl_id = wl_cache[whitelist_key].id + msg = f"RTBH Rule {model.id} {model} is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + flashes.append(msg) + current_app.logger.info(msg) + for network in parts: + new_rule = create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + msg = ( + f"Created RTBH rule {new_rule.id} {new_rule} for {network} parted by whitelist {whitelist_key}" + ) + flashes.append(msg) + current_app.logger.info(msg) + model.rstate_id = 4 + add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) + db.session.commit() + case Relation.SUPERNET: + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + msg = ( + f"RTBH Rule {model.id} {model} is subnet of active whitelist {whitelist_key}. Rule is whitelisted." + ) + current_app.logger.info(msg) + flashes.append(msg) + return model + + +def map_whitelists_to_strings(whitelists: List[Whitelist]) -> Dict[str, Whitelist]: + return {str(w): w for w in whitelists} + + +def delete_rule( + rule_type: RuleTypes, rule_id: int, user_id: int, user_email: str, org_name: str, allowed_rule_ids: List[int] = None +) -> Tuple[bool, str]: + """ + Delete a rule with the given id and type. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to delete + user_id: Current user ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete, None means no restriction + + Returns: + Tuple containing (success, message) + """ + model_class = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + route_model = ROUTE_MODELS[rule_type.value] + + model = db.session.get(model_class, rule_id) + if not model: + return False, "Rule not found" + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, "You cannot delete this rule" + + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log withdrawal + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + + # Special handling for RTBH rules + if rule_type == RuleTypes.RTBH: + current_app.logger.debug(f"Deleting RTBH rule {rule_id} from cache") + RuleWhitelistCache.delete_by_rule_id(rule_id) + + # Delete from database + db.session.delete(model) + db.session.commit() + + return True, "Rule deleted successfully" + + +def delete_rtbh_and_create_whitelist( + rule_id: int, + user_id: int, + org_id: int, + user_email: str, + org_name: str, + allowed_rule_ids: List[int] = None, + whitelist_expires: datetime = None, +) -> Tuple[bool, List[str], Union[Whitelist, None]]: + """ + Delete an RTBH rule and create a whitelist entry from it. + + Args: + rule_id: ID of the RTBH rule to delete + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete + whitelist_expires: Expiration time for the whitelist entry (default: 7 days from now) + + Returns: + Tuple containing (success, messages, whitelist_model) + """ + messages = [] + + # First get the RTBH rule to extract its data + model = db.session.get(RTBH, rule_id) + if not model: + return False, ["RTBH rule not found"], None + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, ["You cannot delete this rule"], None + + # Extract data for whitelist + if model.ipv4: + ip = model.ipv4 + mask = model.ipv4_mask + elif model.ipv6: + ip = model.ipv6 + mask = model.ipv6_mask + else: + return False, ["RTBH rule has no IP address"], None + + # Set default whitelist expiration time if not provided + if whitelist_expires is None: + whitelist_expires = datetime.now() + timedelta(days=7) + + # Prepare whitelist data + # Create base comment + comment_text = f"Created from RTBH rule {model} {rule_id}" + # Append the rule's comment only if it exists + if model.comment: + comment_text += f": {model.comment}" + + whitelist_data = { + "ip": ip, + "mask": mask, + "expires": whitelist_expires, + "comment": comment_text, + } + + # Delete the RTBH rule + success, delete_message = delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=rule_id, + user_id=user_id, + user_email=user_email, + org_name=org_name, + allowed_rule_ids=allowed_rule_ids, + ) + + if not success: + return False, [delete_message], None + + messages.append(delete_message) + + # Create the whitelist entry + try: + whitelist_model, whitelist_messages = create_or_update_whitelist( + form_data=whitelist_data, user_id=user_id, org_id=org_id, user_email=user_email, org_name=org_name + ) + messages.extend(whitelist_messages) + return True, messages, whitelist_model + except Exception as e: + current_app.logger.exception(f"Error creating whitelist entry: {e}") + messages.append(f"Rule deleted but failed to create whitelist: {str(e)}") + return False, messages, None diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py new file mode 100644 index 00000000..3807dca6 --- /dev/null +++ b/flowapp/services/whitelist_common.py @@ -0,0 +1,244 @@ +from enum import Enum, auto +from functools import lru_cache +import ipaddress +from typing import List, Tuple +from flowapp import db +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import RTBH, RuleWhitelistCache, Whitelist +from flowapp.output import log_route +from flowapp.services.base import announce_rtbh_route + + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + """ + Add RTBH rule to whitelist cache + """ + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() + + +def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: + """ + Whitelist RTBH rule. + Set rule state to 4 - whitelisted rule, do not announce later + Add to whitelist cache + """ + model.rstate_id = 4 + db.session.commit() + add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) + return model + + +class Relation(Enum): + """ + Enum to represent the relation between Whitelist to Rule Relation + Subnet: Whitelist is a subnet of the rule + Supernet: Whitelist is a supernet of the rule + Equal: Whitelist is equal to the rule + Different: Whitelist is different from the rule + """ + + SUBNET = auto() + SUPERNET = auto() + EQUAL = auto() + DIFFERENT = auto() + + +def _is_same_ip_version(addr1: str, addr2: str) -> bool: + """ + Check if two IP addresses/networks are of the same IP version. + + :param addr1: First IP address or network string + :param addr2: Second IP address or network string + :return: True if both addresses are of the same IP version (IPv4 or IPv6), False otherwise + """ + is_ipv4_1 = "." in addr1 + is_ipv4_2 = "." in addr2 + return is_ipv4_1 == is_ipv4_2 + + +@lru_cache(maxsize=1024) +def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: + """ + Create and cache an IP network object. + + :param address: IP address or network in string format + :return: Cached IP network object + """ + return ipaddress.ip_network(address, strict=False) + + +def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: + """ + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Uses cached network objects for better performance. + + If the rule and whitelist are different IP versions (IPv4 vs IPv6), + they are treated as different networks. + + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") + :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") + :return: Relation between the two networks + """ + # First check if IP versions are the same + if not _is_same_ip_version(rule, whitelist_entry): + return Relation.DIFFERENT + + try: + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + except (ValueError, TypeError): + # Handle any other errors that might occur during comparison + return Relation.DIFFERENT + + +def subtract_network(target: str, whitelist: str) -> List[str]: + """ + Computes the remaining parts of a network after removing the whitelist subnet. + Uses cached network objects for better performance. + + If the target and whitelist are different IP versions (IPv4 vs IPv6), + the original target is returned unchanged. + + :param target: The main network (e.g., "192.168.1.0/24") + :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") + :return: A list of remaining subnets as strings + """ + # First check if IP versions are the same + if not _is_same_ip_version(target, whitelist): + return [target] + + try: + target_net = get_network(target) + whitelist_net = get_network(whitelist) + + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet + + remaining = [] + + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) + + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + except (ValueError, TypeError): + # Return the original target in case of any error + return [target] + + +def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: + """ + Helper function to check a single rule against multiple whitelist entries. + Creates a cached rule network object for better performance. + Reduces list of whitelists, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check + :param whitelists: List of whitelist networks to check against + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + try: + get_network(rule) + except (ValueError, TypeError): + # Return empty list if rule is not a valid network + return [] + + items = [] + for whitelist in whitelists: + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this whitelist if there's an error comparing + continue + return items + + +def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: + """ + Helper function to check if any whitelist entry is a subnet of the rule. + Creates a cached rule network object for better performance. + Reduces list of rules, where the Relation is not DIFFERENT + + :param rules: List of rule networks to check against + :param whitelist: The whitelist network to check + :return: tuple of rule, whitelist and relation for each rules that is not DIFFERENT + """ + # Pre-cache the whitelist network since it will be used multiple times + try: + get_network(whitelist) + except (ValueError, TypeError): + # Return empty list if whitelist is not a valid network + return [] + + items = [] + for rule in rules: + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this rule if there's an error comparing + continue + return items + + +def clear_network_cache() -> None: + """ + Clear the network object cache. + Useful when processing a large number of networks to prevent memory growth. + """ + get_network.cache_clear() + + +def create_rtbh_from_whitelist_parts( + model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str = "", user_id: int = 0 +) -> RTBH: + # default values from model + rule_owner = rule_owner or model.get_author() + user_id = user_id or model.user_id + + net_ip, net_mask = network.split("/") + new_model = RTBH( + ipv4=net_ip, + ipv4_mask=net_mask, + ipv6=model.ipv6, + ipv6_mask=model.ipv6_mask, + community_id=model.community_id, + expires=model.expires, + comment=model.comment, + user_id=model.user_id, + org_id=model.org_id, + rstate_id=1, + ) + db.session.add(new_model) + db.session.commit() + + add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) + announce_rtbh_route(new_model, rule_owner) + log_route(user_id, model, RuleTypes.RTBH, rule_owner) + + return new_model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py new file mode 100644 index 00000000..a934cc8c --- /dev/null +++ b/flowapp/services/whitelist_service.py @@ -0,0 +1,202 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" +from flask import current_app +from typing import Dict, Tuple, List + +import sqlalchemy + +from flowapp import db +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists +from flowapp.models.rules.flowspec import Flowspec4, Flowspec6 +from flowapp.models.rules.rtbh import RTBH +from flowapp.services.base import announce_rtbh_route, withdraw_rtbh_route +from flowapp.services.whitelist_common import add_rtbh_rule_to_cache, create_rtbh_from_whitelist_parts +from flowapp.services.whitelist_common import ( + Relation, + check_whitelist_against_rules, + subtract_network, + whitelist_rtbh_rule, +) +from flowapp.utils import round_to_ten_minutes, quote_to_ent + + +def create_or_update_whitelist( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Whitelist, str]: + """ + Create a new Whitelist rule or update expiration time of an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (whitelist_model, message) + """ + # Check for existing model + model = get_whitelist_model_if_exists(form_data) + flashes = [] + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flashes.append("Existing Whitelist found. Expiration time was updated to new value.") + else: + # Create new model + model = Whitelist( + ip=form_data["ip"], + mask=form_data["mask"], + expires=round_to_ten_minutes(form_data["expires"]), + user_id=user_id, + org_id=org_id, + comment=quote_to_ent(form_data["comment"]), + ) + db.session.add(model) + flashes.append("Whitelist saved") + + db.session.commit() + + # check RTBH rules against whitelist + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + # filter out RTBH rules that are not active or whitelisted and not in allowed communities + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id.in_([1, 4]), RTBH.community_id.in_(allowed_communities)).all() + rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) + result = check_whitelist_against_rules(rtbh_rules_map, str(model)) + current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") + model = evaluate_whitelist_against_rtbh_check_results(model, flashes, rtbh_rules_map, result) + + return model, flashes + + +def evaluate_whitelist_against_rtbh_check_results( + whitelist_model: Whitelist, + flashes: List[str], + rtbh_rule_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> Whitelist: + + for rule_key, whitelist_key, relation in results: + current_app.logger.info(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") + match relation: + case Relation.EQUAL: + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + msg = "Existing active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is now whitelisted." + flashes.append(msg) + current_app.logger.info(msg) + case Relation.SUBNET: + parts = subtract_network(target=rule_key, whitelist=whitelist_key) + wl_id = whitelist_model.id + msg = f"Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules will be created." + flashes.append(msg) + current_app.logger.info(msg) + for network in parts: + rule_model = rtbh_rule_cache[rule_key] + create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) + msg = f"Created RTBH rule from {rule_model.id} {network} parted by whitelist {whitelist_key}." + flashes.append(msg) + current_app.logger.info(msg) + rule_model.rstate_id = 4 + add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) + db.session.commit() + case Relation.SUPERNET: + + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + msg = ( + f"Existing active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is now whitelisted." + ) + current_app.logger.info(msg) + flashes.append(msg) + + return whitelist_model + + +def map_rtbh_rules_to_strings(all_rtbh_rules: List[RTBH]) -> Dict[str, RTBH]: + return {str(rule): rule for rule in all_rtbh_rules} + + +def delete_expired_whitelists() -> List[str]: + """ + Delete all expired whitelist entries from the database. + + Returns: + List of messages for the user + """ + expired_whitelists = Whitelist.query.filter(Whitelist.expires < db.func.now()).all() + flashes = [] + for model in expired_whitelists: + flashes.extend(delete_whitelist(model.id)) + return flashes + + +def delete_whitelist(whitelist_id: int) -> List[str]: + """ + Delete a whitelist entry from the database. + + Args: + whitelist_id: The ID of the whitelist to delete + """ + model = db.session.get(Whitelist, whitelist_id) + flashes = [] + if model: + cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) + current_app.logger.info( + f"Deleting whitelist {whitelist_id} {model}. Found {len(cached_rules)} cached rules to process." + ) + for cached_rule in cached_rules: + rule_model_type = RuleTypes(cached_rule.rtype) + cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) + match rule_model_type: + case RuleTypes.IPv4: + rule_model = db.session.get(Flowspec4, cached_rule.rid) + case RuleTypes.IPv6: + rule_model = db.session.get(Flowspec6, cached_rule.rid) + case RuleTypes.RTBH: + rule_model = db.session.get(RTBH, cached_rule.rid) + rorigin_type = RuleOrigin(cached_rule.rorigin) + current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") + if rorigin_type == RuleOrigin.WHITELIST: + msg = f"Deleted {rule_model_type} rule {rule_model} created by whitelist {model}" + current_app.logger.info(msg) + flashes.append(msg) + try: + db.session.delete(rule_model) + except sqlalchemy.orm.exc.UnmappedInstanceError: + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) + + elif rorigin_type == RuleOrigin.USER and cache_entries_count == 1: + msg = f"Set rule {rule_model} back to state 'Active'" + current_app.logger.info(msg) + flashes.append(msg) + try: + rule_model.rstate_id = 1 # Set rule state to "Active" again + except AttributeError: + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) + else: + author = f"{model.user.email} ({model.user.organization})" + announce_rtbh_route(rule_model, author) + + msg = f"Deleted cache entries for whitelist {whitelist_id} {model}" + current_app.logger.info(msg) + flashes.append(msg) + RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) + + db.session.delete(model) + db.session.commit() + msg = f"Deleted whitelist {whitelist_id} {model}" + flashes.append(msg) + current_app.logger.info(msg) + + return flashes diff --git a/flowapp/templates/forms/bulk_user_form.html b/flowapp/templates/forms/bulk_user_form.html index c41389ea..b15b3ff8 100644 --- a/flowapp/templates/forms/bulk_user_form.html +++ b/flowapp/templates/forms/bulk_user_form.html @@ -42,9 +42,9 @@

Example CSV data

             
 uuid-eppn,name,telefon,email,role,organizace,poznamka
-view@example.com,Test View,123,view@example.com,1,1,View
-user@example.com,Test User,123456,user@example.com,2,1,User
-admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,Admin
+view@example.com,Test View,123,view@example.com,1,1,user with view role (1)
+user@example.com,Test User,123456,user@example.com,2,1,regular user (role 2)
+admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,user with admin role (3)
             
         

Role

diff --git a/flowapp/templates/forms/machine_api_key.html b/flowapp/templates/forms/machine_api_key.html index be6fcaad..1b646ad3 100644 --- a/flowapp/templates/forms/machine_api_key.html +++ b/flowapp/templates/forms/machine_api_key.html @@ -18,25 +18,29 @@
Machine Api Key: {{ generated_key }}
{{ form.hidden_tag() if form.hidden_tag }}
-
+
{{ render_field(form.machine) }}
{{ render_checkbox_field(form.readonly, checked="checked") }}
-
+
+
+
+ {{ render_field(form.user) }} +
+
+
{{ render_field(form.expires) }}
-
- -
+
-
+
{{ render_field(form.comment) }}
-
+
diff --git a/flowapp/templates/forms/rtbh_rule.html b/flowapp/templates/forms/rtbh_rule.html index 986c081b..8c91ef22 100644 --- a/flowapp/templates/forms/rtbh_rule.html +++ b/flowapp/templates/forms/rtbh_rule.html @@ -28,6 +28,15 @@

{{ title or 'New'}} RTBH rule

{{ render_field(form.community) }}
+
+
Following communities can be whitelisted:
+
    + {% for com in whitelistable %} +
  • + {{ com }} +
  • + {% endfor %} +
@@ -36,9 +45,6 @@

{{ title or 'New'}} RTBH rule

- {{ form.expires(class_='form-control') }} diff --git a/flowapp/templates/forms/whitelist.html b/flowapp/templates/forms/whitelist.html new file mode 100644 index 00000000..35559d79 --- /dev/null +++ b/flowapp/templates/forms/whitelist.html @@ -0,0 +1,62 @@ +{% extends 'layouts/default.html' %} +{% from 'forms/macros.html' import render_field %} +{% block title %}Add Whitelist{% endblock %} +{% block content %} +

{{ title or 'New'}} Whitelist

+ + {{ form.hidden_tag() if form.hidden_tag }} +
+
+
+
+ {{ render_field(form.ip) }} +
+
+ {{ render_field(form.mask) }} +
+
+
+
+
+ +
+ +
+
+
+ +
+ + {{ form.expires(class_='form-control') }} + + + +
+ + {% if form.expires.errors %} + {% for e in form.expires.errors %} +

{{ e }}

+ {% endfor %} + {% endif %} +
+
+
+
+
+ +
+
+ {{ render_field(form.comment) }} +
+
+ +
+
+ +
+
+ +{% endblock %} \ No newline at end of file diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 9e0b4b79..661be2f7 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -1,5 +1,5 @@ -{% macro build_ip_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_ip_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None, allowed_communities=None) %} {% for rule in rules %} {% if rule.next_header is defined %} @@ -8,7 +8,11 @@ {% set rtype_int = 4 %} {% endif %} - + {{ rule.source }}{% if rule.source_mask != none %}{{ '/' if rule.source_mask >= 0 else '' }}{{ rule.source_mask if rule.source_mask >= 0 else '' }}{% endif %} @@ -45,10 +49,10 @@ {% if editable %} - + - + {% endif %} @@ -73,23 +77,72 @@ {% endmacro %} -{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None, allowed_communities=None) %} + +{% for rule in rules %} + + + {% if rule.ipv4 %} + {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} + {% endif %} + {% if rule.ipv6 %} + {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} + {% endif %} + + + {{ rule.community.name }} + + + + {{ rule.expires|strftime }} + + + {{ rule.user.name }} + + + {% if editable %} + + + + + + + {% if rule.community.id in allowed_communities %} + + + + {% endif %} + {% endif %} + {% if rule.comment %} + + {% endif %} + + {% if editable and group_op %} + + + + {% endif %} + + +{% endfor %} + +{% endmacro %} + + +{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None, allowed_communities=None) %} {% for rule in rules %} - {% if rule.ipv4 %} - {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} - {% endif %} - {% if rule.ipv6 %} - {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} - {% endif %} - - - {{ rule.community.name }} + {{ rule.ip }}{{ '/' if rule.mask else '' }}{{rule.mask|default("", True)}} - - + {{ rule.expires|strftime }} @@ -97,10 +150,10 @@ {% if editable %} - + - + {% endif %} @@ -122,6 +175,7 @@ {% endmacro %} + {% macro build_rules_thead(rules_columns, rtype, rstate, sort_key, sort_order, search_query='', group_op=True) %} diff --git a/flowapp/templates/pages/machine_api_key.html b/flowapp/templates/pages/machine_api_key.html index 52eb478a..c2ced223 100644 --- a/flowapp/templates/pages/machine_api_key.html +++ b/flowapp/templates/pages/machine_api_key.html @@ -11,8 +11,8 @@

Machines and ApiKeys

Machine address ApiKey - Created by Created for + Created by Expires Read/Write ? Action diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index 65919af1..3a988ef4 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -12,6 +12,7 @@ from flowapp import db as _db from datetime import datetime import flowapp.models +from flowapp.models.organization import Organization TESTDB = "test_project.db" @@ -67,6 +68,8 @@ def app(request): SECRET_KEY="testkeysession", LOCAL_USER_UUID="jiri.vrany@cesnet.cz", LOCAL_AUTH=True, + ALLOWED_COMMUNITIES=[1, 2, 3], + WTF_CSRF_ENABLED=False, ) print("\n----- CREATE FLASK APPLICATION\n") @@ -114,6 +117,12 @@ def db(app, request): print("#: inserting users") flowapp.models.insert_users(users) + org = _db.session.query(Organization).filter_by(id=1).first() + # Update the organization address range to include our test networks + org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" + _db.session.commit() + print("\n----- UPDATED TEST ORG 1 \n", org) + def teardown(): _db.session.commit() _db.drop_all() @@ -143,6 +152,26 @@ def jwt_token(client, app, db, request): return data["token"] +@pytest.fixture(scope="session") +def machine_api_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + mkey = "machinetestkey" + + with app.app_context(): + model = flowapp.models.MachineApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) + db.session.add(model) + db.session.commit() + + print("\n----- GET MACHINE API KEY TEST TOKEN\n") + url = "/api/v3/auth" + headers = {"x-api-key": mkey} + token = client.get(url, headers=headers) + data = json.loads(token.data) + return data["token"] + + @pytest.fixture(scope="session") def expired_auth_token(client, app, db, request): """ @@ -185,3 +214,19 @@ def auth_client(client): print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n") client.get("/local-login") return client + + +@pytest.fixture(autouse=True) +def reset_org_limits(db, app): + """ + Automatically reset organization limits after each test that modifies them. + """ + yield # Allow test execution + + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + if org: + org.limit_flowspec4 = 0 + org.limit_flowspec6 = 0 + org.limit_rtbh = 0 + db.session.commit() diff --git a/flowapp/tests/rule_service_integration.py b/flowapp/tests/rule_service_integration.py new file mode 100644 index 00000000..106000ab --- /dev/null +++ b/flowapp/tests/rule_service_integration.py @@ -0,0 +1,332 @@ +"""Integration tests for rule_service module.""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp import db, create_app +from flowapp.constants import RuleTypes +from flowapp.models import ( + RTBH, + Flowspec4, + Flowspec6, + Whitelist, + User, + Organization, + Role, + Community, + Action, + Rstate, +) +from flowapp.services import rule_service + + +@pytest.fixture(scope="module") +def app(): + """Create and configure a Flask app for testing.""" + app = create_app("testing") + + # Create a test context + with app.app_context(): + # Create database tables + db.create_all() + + # Create test data + _create_test_data() + + yield app + + # Clean up after tests + db.session.remove() + db.drop_all() + + +@pytest.fixture(scope="function") +def app_context(app): + """Provide an application context for each test.""" + with app.app_context(): + yield + + +@pytest.fixture(scope="function") +def mock_announce_route(): + """Mock the announce_route function.""" + with patch("flowapp.services.rule_service.announce_route") as mock: + yield mock + + +def _create_test_data(): + """Create test data in the database.""" + # Create roles + admin_role = Role(name="admin", description="Administrator") + user_role = Role(name="user", description="Normal user") + view_role = Role(name="view", description="View only") + db.session.add_all([admin_role, user_role, view_role]) + db.session.flush() + + # Create organizations + org1 = Organization(name="Test Org 1", arange="192.168.1.0/24\n2001:db8::/64") + org2 = Organization(name="Test Org 2", arange="10.0.0.0/8") + db.session.add_all([org1, org2]) + db.session.flush() + + # Create users + user1 = User(uuid="user1@example.com", name="User One") + user1.role.append(user_role) + user1.organization.append(org1) + + admin1 = User(uuid="admin@example.com", name="Admin User") + admin1.role.append(admin_role) + admin1.organization.append(org2) + + db.session.add_all([user1, admin1]) + db.session.flush() + + # Create rule states + state_active = Rstate(description="active rule") + state_withdrawn = Rstate(description="withdrawed rule") + state_deleted = Rstate(description="deleted rule") + state_whitelisted = Rstate(description="whitelisted rule") + db.session.add_all([state_active, state_withdrawn, state_deleted, state_whitelisted]) + db.session.flush() + + # Create actions + action1 = Action(name="Discard", command="discard", description="Drop traffic", role_id=user_role.id) + action2 = Action(name="Rate limit", command="rate-limit 1000000", description="Limit traffic", role_id=user_role.id) + db.session.add_all([action1, action2]) + db.session.flush() + + # Create communities + community1 = Community( + name="65000:1", + comm="65000:1", + larcomm="", + extcomm="", + description="Test community", + role_id=user_role.id, + as_path=False, + ) + db.session.add(community1) + db.session.flush() + + # Create some rules + # IPv4 rule + ipv4_rule = Flowspec4( + source="192.168.1.1", + source_mask=32, + source_port="", + dest="192.168.2.1", + dest_mask=32, + dest_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=datetime.now() + timedelta(hours=1), + comment="Test IPv4 rule", + action_id=action1.id, + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(ipv4_rule) + + # IPv6 rule + ipv6_rule = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="", + dest="2001:db8:1::1", + dest_mask=128, + dest_port="", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(hours=1), + comment="Test IPv6 rule", + action_id=action1.id, + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(ipv6_rule) + + # RTBH rule + rtbh_rule = RTBH( + ipv4="192.168.1.100", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=community1.id, + expires=datetime.now() + timedelta(hours=1), + comment="Test RTBH rule", + user_id=user1.id, + org_id=org1.id, + rstate_id=state_active.id, + ) + db.session.add(rtbh_rule) + + # Save IDs as class variables for tests to use + _create_test_data.user_id = user1.id + _create_test_data.admin_id = admin1.id + _create_test_data.org1_id = org1.id + _create_test_data.org2_id = org2.id + _create_test_data.ipv4_rule_id = ipv4_rule.id + _create_test_data.ipv6_rule_id = ipv6_rule.id + _create_test_data.rtbh_rule_id = rtbh_rule.id + _create_test_data.action_id = action1.id + _create_test_data.community_id = community1.id + + db.session.commit() + + +class TestRuleServiceIntegration: + """Integration tests for rule_service functions.""" + + def test_reactivate_rule_ipv4(self, app_context, mock_announce_route): + """Test reactivating an IPv4 rule.""" + # Test data + rule_id = _create_test_data.ipv4_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Set new expiration time (2 hours in the future) + new_expires = datetime.now() + timedelta(hours=2) + new_comment = "Updated test comment" + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=rule_id, + expires=new_expires, + comment=new_comment, + user_id=user_id, + org_id=org_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was updated + assert model is not None + assert model.id == rule_id + assert model.comment == new_comment + assert model.expires == new_expires + assert model.rstate_id == 1 # active state + + # Verify route announcement was attempted + assert mock_announce_route.called + + # Verify message + assert messages == ["Rule successfully updated"] + + def test_reactivate_rule_inactive(self, app_context, mock_announce_route): + """Test reactivating a rule to inactive state.""" + # Test data + rule_id = _create_test_data.ipv4_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Set past expiration time + past_expires = datetime.now() - timedelta(hours=1) + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=rule_id, + expires=past_expires, + comment="Expired comment", + user_id=user_id, + org_id=org_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was updated + assert model is not None + assert model.id == rule_id + assert model.expires == past_expires + assert model.rstate_id == 2 # inactive/withdrawn state + + # Verify route announcement was attempted + assert mock_announce_route.called + + def test_delete_rule(self, app_context, mock_announce_route): + """Test deleting a rule.""" + # Test data + rule_id = _create_test_data.ipv6_rule_id + user_id = _create_test_data.user_id + + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv6, + rule_id=rule_id, + user_id=user_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the rule was deleted + assert success is True + assert message == "Rule deleted successfully" + + # Verify the rule no longer exists in the database + rule = db.session.get(Flowspec6, rule_id) + assert rule is None + + # Verify route withdrawal was attempted + assert mock_announce_route.called + + def test_delete_rule_not_found(self, app_context): + """Test deleting a non-existent rule.""" + # Use a non-existent rule ID + non_existent_id = 9999 + + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=non_existent_id, + user_id=_create_test_data.user_id, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the operation failed + assert success is False + assert message == "Rule not found" + + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_delete_rtbh_and_create_whitelist(self, mock_create_whitelist, app_context, mock_announce_route): + """Test deleting an RTBH rule and creating a whitelist entry.""" + # Test data + rule_id = _create_test_data.rtbh_rule_id + user_id = _create_test_data.user_id + org_id = _create_test_data.org1_id + + # Mock whitelist creation + mock_whitelist = Whitelist( + ip="192.168.1.100", mask=32, expires=datetime.now() + timedelta(days=7), user_id=user_id, org_id=org_id + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=rule_id, user_id=user_id, org_id=org_id, user_email="test@example.com", org_name="Test Org" + ) + + # Verify success + assert success is True + assert len(messages) == 2 + assert "Rule deleted successfully" in messages[0] + assert "Whitelist created" in messages[1] + + # Verify the rule was deleted + rule = db.session.get(RTBH, rule_id) + assert rule is None + + # Verify create_or_update_whitelist was called with correct data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "192.168.1.100" + assert form_data["mask"] == 32 + assert "Created from RTBH rule" in form_data["comment"] diff --git a/flowapp/tests/test_api_auth.py b/flowapp/tests/test_api_auth.py index 5733346b..8d08a7ae 100644 --- a/flowapp/tests/test_api_auth.py +++ b/flowapp/tests/test_api_auth.py @@ -11,6 +11,15 @@ def test_token(client, jwt_token): assert req.status_code == 200 +def test_machine_token(client, machine_api_token): + """ + test that token authorization works + """ + req = client.get("/api/v3/test_token", headers={"x-access-token": machine_api_token}) + + assert req.status_code == 200 + + def test_expired_token(client, expired_auth_token): """ test that expired token authorization return 401 @@ -37,7 +46,7 @@ def test_readonly_token(client, readonly_jwt_token): assert req.status_code == 200 data = json.loads(req.data) - assert data['readonly'] + assert data["readonly"] def test_readonly_token_ipv4_create(client, db, readonly_jwt_token): diff --git a/flowapp/tests/test_api_v3.py b/flowapp/tests/test_api_v3.py index 96b24786..87f70949 100644 --- a/flowapp/tests/test_api_v3.py +++ b/flowapp/tests/test_api_v3.py @@ -1,10 +1,37 @@ import json + from flowapp.models import Flowspec4, Organization V_PREFIX = "/api/v3" +def test_create_rtbh_only(client, app, db, jwt_token): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + + # Now create the RTBH rule via API + res = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + + # Verify response is successful + assert res.status_code == 201 + data = json.loads(res.data) + assert data["rule"] is not None + + def test_token(client, jwt_token): """ test that token authorization works diff --git a/flowapp/tests/test_api_whitelist_integration.py b/flowapp/tests/test_api_whitelist_integration.py new file mode 100644 index 00000000..0391a762 --- /dev/null +++ b/flowapp/tests/test_api_whitelist_integration.py @@ -0,0 +1,273 @@ +""" +Integration test for the API view when interacting with whitelists. + +This test suite verifies the behavior when creating RTBH rules through the API +when there are existing whitelists that could affect the rules. +""" + +import json +import pytest +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import RTBH, RuleWhitelistCache, Organization +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_data(): + """Create whitelist data for testing""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist for API integration test", + "expires": datetime.now() + timedelta(days=1), + } + + +@pytest.fixture +def rtbh_api_payload(): + """Create RTBH rule data for API testing""" + return { + "community": 1, + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Test RTBH rule via API", + } + + +def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + # First, configure app to include community ID 1 in allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # Create the whitelist directly using the service + with app.app_context(): + # Create user and organization if needed for the whitelist + org = db.session.query(Organization).first() + + # Create the whitelist + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, + user_id=1, # Using user ID 1 from test fixtures + org_id=org.id, + user_email="test@example.com", + org_name=org.name, + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + assert whitelist_model.ip == whitelist_data["ip"] + assert whitelist_model.mask == whitelist_data["mask"] + + # Now create the RTBH rule via API + response = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=rtbh_api_payload, + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + +def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is a supernet of an existing whitelist. + + The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules + should be created for the non-whitelisted parts. + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a subnet + whitelist_data["ip"] = "192.168.1.128" + whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that covers both the whitelist and additional space + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", + headers=headers, + json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created and marked as whitelisted + with app.app_context(): + # Check the original rule + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Check if a new subnet rule was created for the non-whitelisted part + subnet_rule = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == "192.168.1.0", + RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist + ) + .first() + ) + + assert subnet_rule is not None + assert subnet_rule.rstate_id == 1 # Active status + + # Verify cache entries + # Main rule should be cached as a USER rule + user_cache = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value + ).first() + assert user_cache is not None + + # Subnet rule should be cached as a WHITELIST rule + whitelist_cache = RuleWhitelistCache.query.filter_by( + rid=subnet_rule.id, + rtype=RuleTypes.RTBH.value, + whitelist_id=whitelist_model.id, + rorigin=RuleOrigin.WHITELIST.value, + ).first() + assert whitelist_cache is not None + + +def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is contained within an existing whitelist. + + The rule should be created but immediately marked as whitelisted (rstate_id=4). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a supernet + whitelist_data["ip"] = "192.168.0.0" + whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule + + # Create the whitelist directly using the service + with app.app_context(): + all_rtbh_rules_before = db.session.query(RTBH).count() + + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that is inside the whitelist + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + # Verify no additional rules were created + all_rtbh_rules = db.session.query(RTBH).count() + assert all_rtbh_rules - all_rtbh_rules_before == 1 + + +def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that has no relation to any existing whitelist. + + The rule should be created normally with active state (rstate_id=1). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a completely different network + whitelist_data["ip"] = "10.0.0.0" + whitelist_data["mask"] = 8 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API for a network not covered by any of the whitelists + rtbh_api_payload["ipv4"] = "147.230.17.185" + rtbh_api_payload["ipv4_mask"] = 32 + + headers = {"x-access-token": jwt_token} + response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24 + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created with active state + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 1 # 1 = active state + + # Verify no cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() + + assert cache_entry is None diff --git a/flowapp/tests/test_flowspec.py b/flowapp/tests/test_flowspec.py index e7a9c3e9..70000b81 100644 --- a/flowapp/tests/test_flowspec.py +++ b/flowapp/tests/test_flowspec.py @@ -1,12 +1,12 @@ import pytest -import flowapp.flowspec +from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit def test_translate_number(): """ tests for x (integer) to =x """ - assert "[=10]" == flowapp.flowspec.translate_sequence("10") + assert "[=10]" == translate_sequence("10") def test_raises(): @@ -14,7 +14,7 @@ def test_raises(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("ahoj") + translate_sequence("ahoj") def test_raises_bad_number(): @@ -22,46 +22,146 @@ def test_raises_bad_number(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("75555") + translate_sequence("75555") def test_translate_range(): """ tests for x-y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence("10-20") + assert "[>=10&<=20]" == translate_sequence("10-20") def test_exact_rule(): """ test for >=x&<=y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence(">=10&<=20") + assert "[>=10&<=20]" == translate_sequence(">=10&<=20") def test_greater_than(): """ test for >x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">10") + assert "[>=10&<=65535]" == translate_sequence(">10") def test_greater_equal_than(): """ test for >=x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">=10") + assert "[>=10&<=65535]" == translate_sequence(">=10") def test_lower_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<10") + assert "[>=0&<=10]" == translate_sequence("<10") def test_lower_equal_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<=10") + assert "[>=0&<=10]" == translate_sequence("<=10") + + +# new tests + + +def test_multiple_sequences(): + """Test multiple sequences separated by semicolons""" + assert "[=10 >=20&<=30]" == translate_sequence("10;20-30") + assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30") + + +def test_empty_sequence(): + """Test empty sequence and sequences with empty parts""" + assert "[]" == translate_sequence("") + assert "[=10]" == translate_sequence("10;;") + assert "[=10 =20]" == translate_sequence("10;;20") + + +def test_range_edge_cases(): + """Test edge cases for ranges""" + # Same numbers in range + assert "[>=10&<=10]" == translate_sequence("10-10") + + # Invalid range (start > end) + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence("20-10") + + # Invalid range in exact rule + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence(">=20&<=10") + + +def test_check_limit_validation(): + """Test the check_limit function""" + # Test valid cases + assert check_limit(10, 100) == 10 + assert check_limit(0, 100, 0) == 0 + assert check_limit(100, 100, 0) == 100 + + # Test invalid cases + with pytest.raises(ValueError, match="Invalid value number: .* is too big"): + check_limit(101, 100) + + with pytest.raises(ValueError, match="Invalid value number: .* is too small"): + check_limit(-1, 100, 0) + + +def test_invalid_inputs(): + """Test various invalid inputs""" + invalid_inputs = [ + "abc", # Non-numeric + "10.5", # Decimal + "10,20", # Wrong separator + "10&20", # Wrong format + ">", # Incomplete + ">=", # Incomplete + "<", # Incomplete + "<=", # Incomplete + ">>10", # Invalid operator + "<<10", # Invalid operator + ">=10&", # Incomplete range + "&<=10", # Incomplete range + ] + + for invalid_input in invalid_inputs: + with pytest.raises(ValueError): + translate_sequence(invalid_input) + + +class MockRule: + def __init__(self, action_id): + self.action_id = action_id + + +def test_filter_rules_action(): + """Test the filter_rules_action function""" + # Create test rules + rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)] + + # Test with empty allowed actions + editable, viewonly = filter_rules_action([], rules) + assert len(editable) == 0 + assert len(viewonly) == 4 + + # Test with some allowed actions + editable, viewonly = filter_rules_action([1, 3], rules) + assert len(editable) == 2 + assert len(viewonly) == 2 + assert all(rule.action_id in [1, 3] for rule in editable) + assert all(rule.action_id in [2, 4] for rule in viewonly) + + # Test with all actions allowed + editable, viewonly = filter_rules_action([1, 2, 3, 4], rules) + assert len(editable) == 4 + assert len(viewonly) == 0 + + # Test with empty rules list + editable, viewonly = filter_rules_action([1, 2], []) + assert len(editable) == 0 + assert len(viewonly) == 0 diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py index 76c98960..e9620d8d 100644 --- a/flowapp/tests/test_forms.py +++ b/flowapp/tests/test_forms.py @@ -1,15 +1,7 @@ import pytest -from flask import Flask import flowapp.forms -@pytest.fixture() -def app(): - app = Flask(__name__) - app.secret_key = "test" - return app - - @pytest.fixture() def ip_form(app, field_class): with app.test_request_context(): # Push the request context diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py new file mode 100644 index 00000000..db0a569f --- /dev/null +++ b/flowapp/tests/test_forms_cl.py @@ -0,0 +1,608 @@ +import pytest +from datetime import datetime, timedelta +from werkzeug.datastructures import MultiDict +from flowapp.forms import ( + UserForm, + BulkUserForm, + ApiKeyForm, + MachineApiKeyForm, + OrganizationForm, + ActionForm, + ASPathForm, + CommunityForm, + RTBHForm, + IPv4Form, + IPv6Form, + WhitelistForm, +) + + +def create_form_data(data): + """Helper function to create proper form data format""" + processed_data = {} + for key, value in data.items(): + if isinstance(value, list): + processed_data[key] = [str(v) for v in value] + else: + processed_data[key] = value + return MultiDict(processed_data) + + +@pytest.fixture +def valid_datetime(): + return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") + + +@pytest.fixture +def sample_network_ranges(): + return ["192.168.0.0/16", "2001:db8::/32"] + + +class TestUserForm: + @pytest.fixture + def mock_choices(self): + return { + "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")], + "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")], + } + + @pytest.fixture + def valid_user_data(self): + return { + "uuid": "test@example.com", + "email": "user@example.com", + "name": "Test User", + "phone": "123456789", + "role_ids": ["2"], + "org_ids": ["1"], + } + + def test_valid_user_form(self, app, valid_user_data, mock_choices): + with app.test_request_context(): + form_data = create_form_data(valid_user_data) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_email(self, app, mock_choices): + with app.test_request_context(): + form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]}) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + assert not form.validate() + assert "Please provide valid email" in form.uuid.errors + + +class TestRTBHForm: + @pytest.fixture + def mock_community_choices(self): + return [(1, "Community1"), (2, "Community2")] + + @pytest.fixture + def valid_rtbh_data(self, valid_datetime): + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "community": "1", + "expires": valid_datetime, + "comment": "Test RTBH rule", + } + + def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices): + with app.test_request_context(): + form_data = create_form_data(valid_rtbh_data) + form = RTBHForm(formdata=form_data) + form.net_ranges = sample_network_ranges + form.community.choices = mock_community_choices + assert form.validate() + + +class TestIPv4Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + +class TestIPv6Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return { + "source": "2001:db8::", # Network aligned address within allowed range + "source_mask": "32", # Matching the organization's prefix length + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation fails when address is outside allowed ranges""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db9::", # Different prefix + "source_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + assert any("must be in organization range" in error for error in form.source.errors) + + def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with destination address instead of source""" + with app.test_request_context(): + form_data = create_form_data( + { + "dest": "2001:db8::", + "dest_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with both source and destination addresses""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "dest": "2001:db8:1::", + "dest_mask": "48", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with TCP flags (should be valid with TCP)""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "flags": ["SYN", "ACK"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"]) + def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data): + """Test validation with various valid port formats""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "source_port": port_data, + "dest_port": port_data, + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + +class TestBulkUserForm: + @pytest.fixture + def valid_csv_data(self): + return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1" + + def test_valid_bulk_import(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} # Mock available roles + form.organizations = {1} # Mock available organizations + form.uuids = set() # Mock existing UUIDs (empty set) + assert form.validate() + + def test_duplicate_uuid(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = {"user1@example.com"} # UUID already exists + assert not form.validate() + assert any("already exists" in error for error in form.users.errors) + + def test_invalid_role(self, app): + csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID + with app.test_request_context(): + form_data = create_form_data({"users": csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = set() + assert not form.validate() + assert any("does not exist" in error for error in form.users.errors) + + +class TestApiKeyForm: + @pytest.fixture + def valid_api_key_data(self, valid_datetime): + return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"} + + def test_valid_api_key(self, app, valid_api_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_api_key_data) + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + def test_invalid_ip(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime}) + form = ApiKeyForm(formdata=form_data) + assert not form.validate() + assert "provide valid IP address" in form.machine.errors + + def test_unlimited_expiration(self, app): + with app.test_request_context(): + form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestMachineApiKeyForm: + # Similar to ApiKeyForm, but might have different validation rules + @pytest.fixture + def valid_machine_key_data(self, valid_datetime): + return { + "machine": "192.168.1.1", + "comment": "Test machine API key", + "expires": valid_datetime, + "readonly": "true", + "user": 1, + } + + def test_valid_machine_key(self, app, valid_machine_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_machine_key_data) + form = MachineApiKeyForm(formdata=form_data) + form.user.choices = [(1, "g.name"), (2, "test")] + assert form.validate() + + +class TestOrganizationForm: + @pytest.fixture + def valid_org_data(self): + return { + "name": "Test Organization", + "limit_flowspec4": "100", + "limit_flowspec6": "100", + "limit_rtbh": "50", + "arange": "192.168.0.0/16\n2001:db8::/32", + } + + def test_valid_organization(self, app, valid_org_data): + with app.test_request_context(): + form_data = create_form_data(valid_org_data) + form = OrganizationForm(formdata=form_data) + assert form.validate() + + def test_invalid_ranges(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"}) + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + def test_invalid_limits(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + +class TestActionForm: + @pytest.fixture + def valid_action_data(self): + return { + "name": "test_action", + "command": "announce route", + "description": "Test action description", + "role_id": "2", + } + + def test_valid_action(self, app, valid_action_data): + with app.test_request_context(): + form_data = create_form_data(valid_action_data) + form = ActionForm(formdata=form_data) + assert form.validate() + + def test_invalid_role(self, app, valid_action_data): + with app.test_request_context(): + invalid_data = dict(valid_action_data) + invalid_data["role_id"] = "4" # Invalid role + form_data = create_form_data(invalid_data) + form = ActionForm(formdata=form_data) + assert not form.validate() + + +class TestASPathForm: + @pytest.fixture + def valid_aspath_data(self): + return {"prefix": "AS64512", "as_path": "64512 64513 64514"} + + def test_valid_aspath(self, app, valid_aspath_data): + with app.test_request_context(): + form_data = create_form_data(valid_aspath_data) + form = ASPathForm(formdata=form_data) + assert form.validate() + + def test_empty_fields(self, app): + with app.test_request_context(): + form_data = create_form_data({}) + form = ASPathForm(formdata=form_data) + assert not form.validate() + + +class TestCommunityForm: + @pytest.fixture + def valid_community_data(self): + return { + "name": "test_community", + "comm": "64512:100", + "description": "Test community", + "role_id": "2", + "as_path": "true", + } + + def test_valid_community(self, app, valid_community_data): + with app.test_request_context(): + form_data = create_form_data(valid_community_data) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_missing_all_community_types(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert not form.validate() + assert any("could not be empty" in error for error in form.comm.errors) + + def test_valid_with_extended_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_valid_with_large_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + +class TestWhitelistForm: + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime} + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime} + + def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert any("valid IP address" in error for error in form.ip.errors) + + def test_ip_outside_range(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16 + assert not form.validate() + assert any("must be in organization range" in error for error in form.ip.errors) + + def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_missing_required_fields(self, app, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert form.ip.errors # Should have error for missing IP + assert form.mask.errors # Should have error for missing mask + assert form.expires.errors # Should have error for missing expiration + + def test_comment_length(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + { + "ip": "192.168.0.0", + "mask": "24", + "expires": valid_datetime, + "comment": "x" * 256, # Comment longer than 255 chars + } + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + @pytest.mark.parametrize( + "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date + ) + def test_invalid_expiration(self, app, expires, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + if not form.validate(): + print("Validation errors:", form.errors) + + assert not form.validate() + + def test_network_alignment(self, app, valid_datetime, sample_network_ranges): + """Test that IP addresses must be properly network-aligned""" + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py index 89599453..b352a5f2 100644 --- a/flowapp/tests/test_models.py +++ b/flowapp/tests/test_models.py @@ -1,4 +1,16 @@ -from datetime import datetime +from datetime import datetime, timedelta +from flowapp.models import ( + User, + Organization, + Role, + ApiKey, + MachineApiKey, + Rstate, + Community, + Action, + Flowspec6, + Whitelist, +) import flowapp.models as models @@ -232,3 +244,255 @@ def test_rtbj_eq(db): ) assert model_A == model_B + + +def test_user_creation(db): + """Test basic user creation and relationships""" + # Create test role and org first + role = Role(name="test_role", description="Test Role") + org = Organization(name="test_org", arange="10.0.0.0/8") + db.session.add_all([role, org]) + db.session.commit() + + # Create user with relationships + user = User( + uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment" + ) + user.role.append(role) + user.organization.append(org) + db.session.add(user) + db.session.commit() + + # Verify user and relationships + assert user.uuid == "test-user-123" + assert user.name == "Test User" + assert len(user.role.all()) == 1 + assert len(user.organization.all()) == 1 + assert user.role.first().name == "test_role" + assert user.organization.first().name == "test_org" + + +def test_api_key_expiration(db): + """Test ApiKey expiration logic""" + user = User(uuid="test-user") + org = Organization(name="test-org", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = ApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = ApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + # Create future key + future_key = ApiKey( + machine="test-machine-3", + key="key3", + readonly=True, + expires=datetime.now() + timedelta(days=1), + comment="Future key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key, future_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + assert not future_key.is_expired() + + +def test_machine_api_key_expiration(db): + """Test MachineApiKey expiration logic""" + user = User(uuid="test-user-machine") + org = Organization(name="test-org-machine", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = MachineApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = MachineApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + + +def test_organization_get_users(db): + """Test Organization's get_users method""" + org = Organization( + name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100 + ) + uuid1 = "test-org-get-user" + uuid2 = "test-org-get-user2" + user1 = User(uuid=uuid1) + user2 = User(uuid=uuid2) + + db.session.add(org) + db.session.add_all([user1, user2]) + db.session.commit() + + org.user.append(user1) + org.user.append(user2) + db.session.commit() + + users = org.get_users() + assert len(users) == 2 + assert all(isinstance(user, User) for user in users) + assert {user.uuid for user in users} == {uuid1, uuid2} + + +def test_flowspec6_equality(db): + """Test Flowspec6 equality comparison""" + model_a = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + # Same network parameters but different timestamps + model_b = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + action_id=1, + ) + + # Different network parameters + model_c = Flowspec6( + source="2001:db8::3", + source_mask=128, + source_port="80", + destination="2001:db8::4", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different network parameters + + +def test_whitelist_equality(db): + """Test Whitelist equality comparison""" + model_a = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Same IP/mask but different timestamps + model_b = Whitelist( + ip="192.168.1.1", + mask=32, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + comment="Different comment", + ) + + # Different IP + model_c = Whitelist( + ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different IP + + +def test_whitelist_to_dict(db): + """Test Whitelist to_dict serialization""" + whitelist = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Create required related objects + user = User(uuid="test-user-whitelist") + rstate = Rstate(description="active") + db.session.add_all([user, rstate]) + db.session.commit() + + db.session.add(whitelist) + db.session.commit() + + whitelist.user = user + whitelist.rstate_id = rstate.id + db.session.add(whitelist) + db.session.commit() + + # Test timestamp format + dict_timestamp = whitelist.to_dict(prefered_format="timestamp") + assert isinstance(dict_timestamp["expires"], int) + assert isinstance(dict_timestamp["created"], int) + + # Test yearfirst format + dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst") + assert isinstance(dict_yearfirst["expires"], str) + assert isinstance(dict_yearfirst["created"], str) + + # Check basic fields + assert dict_timestamp["ip"] == "192.168.1.1" + assert dict_timestamp["mask"] == 32 + assert dict_timestamp["comment"] == "Test whitelist" + assert dict_timestamp["user"] == "test-user-whitelist" diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py new file mode 100644 index 00000000..490d2c2f --- /dev/null +++ b/flowapp/tests/test_rule_service.py @@ -0,0 +1,628 @@ +""" +Tests for rule_service.py module. + +This test suite verifies the functionality of the rule service after refactoring, +which manages creation, updating, and processing of flow rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleOrigin, ANNOUNCE +from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist +from flowapp.services import rule_service +from flowapp.services.whitelist_common import Relation + + +@pytest.fixture +def ipv4_form_data(): + """Sample valid IPv4 form data""" + return { + "source": "192.168.1.0", + "source_mask": 24, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "protocol": "tcp", + "flags": ["SYN"], + "packet_len": "", + "fragment": ["dont-fragment"], + "comment": "Test IPv4 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def ipv6_form_data(): + """Sample valid IPv6 form data""" + return { + "source": "2001:db8::", + "source_mask": 32, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "next_header": "tcp", + "flags": ["SYN"], + "packet_len": "", + "comment": "Test IPv6 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def rtbh_form_data(): + """Sample valid RTBH form data""" + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "ipv6": "", + "ipv6_mask": None, + "community": 1, + "comment": "Test RTBH rule", + "expires": datetime.now() + timedelta(hours=1), + } + + +@pytest.fixture +def whitelist_fixture(): + """Create a whitelist fixture""" + whitelist = MagicMock(spec=Whitelist) + whitelist.id = 1 + return whitelist + + +class TestCreateOrUpdateIPv4Rule: + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test creating a new IPv4 rule""" + # Mock the get_ipv4_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv4.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv4_form_data["source"] + assert model.source_mask == ipv4_form_data["source_mask"] + assert model.protocol == ipv4_form_data["protocol"] + assert model.flags == "SYN" # form data flags are joined with ";" + assert model.action_id == ipv4_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message is still a string for IPv4 rules + assert message == "IPv4 Rule saved" + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test updating an existing IPv4 rule""" + # Create an existing model to return + existing_model = Flowspec4( + source=ipv4_form_data["source"], + source_mask=ipv4_form_data["source_mask"], + source_port=ipv4_form_data["source_port"], + destination=ipv4_form_data["dest"] or "", + destination_mask=ipv4_form_data["dest_mask"], + destination_port=ipv4_form_data["dest_port"] or "", + protocol=ipv4_form_data["protocol"], + flags=";".join(ipv4_form_data["flags"]), + packet_len=ipv4_form_data["packet_len"] or "", + fragment=";".join(ipv4_form_data["fragment"]), + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv4.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv4_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv4 rules + assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateIPv6Rule: + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test creating a new IPv6 rule""" + # Mock get_ipv6_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv6.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv6_form_data["source"] + assert model.source_mask == ipv6_form_data["source_mask"] + assert model.next_header == ipv6_form_data["next_header"] + assert model.flags == "SYN" + assert model.action_id == ipv6_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message is still a string for IPv6 rules + assert message == "IPv6 Rule saved" + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateRTBHRule: + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test creating a new RTBH rule""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list (no matches) + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ipv4 == rtbh_form_data["ipv4"] + assert model.ipv4_mask == rtbh_form_data["ipv4_mask"] + assert model.ipv6 == rtbh_form_data["ipv6"] + assert model.ipv6_mask == rtbh_form_data["ipv6_mask"] + assert model.community_id == rtbh_form_data["community"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + + # Verify rule was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + # Verify evaluate function was called + mock_evaluate.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test updating an existing RTBH rule""" + # Create an existing model to return + existing_model = RTBH( + ipv4=rtbh_form_data["ipv4"], + ipv4_mask=rtbh_form_data["ipv4_mask"], + ipv6=rtbh_form_data["ipv6"] or "", + ipv6_mask=rtbh_form_data["ipv6_mask"], + community_id=rtbh_form_data["community"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + rtbh_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify flash messages + assert isinstance(flashes, list) + assert "Existing RTBH Rule found" in flashes[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.map_whitelists_to_strings") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_rtbh_rule_with_whitelists( + self, + mock_log, + mock_announce, + mock_evaluate, + mock_check, + mock_map, + mock_query, + mock_get_model, + app, + db, + rtbh_form_data, + whitelist_fixture, + ): + """Test creating a RTBH rule that interacts with whitelists""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Create a mock whitelist + mock_whitelist = whitelist_fixture + mock_whitelists = [mock_whitelist] + + # Setup mock query to return our whitelist + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_whitelists + mock_query.return_value = mock_query_result + + # Setup map function to return our whitelist in a dict + whitelist_key = "192.168.1.0/24" + mock_map.return_value = {whitelist_key: mock_whitelist} + + # Setup check function to return a relation + mock_rtbh = MagicMock(spec=RTBH) + mock_rtbh.__str__.return_value = "192.168.1.0/24" + + # Setup check result + rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)] + mock_check.return_value = rule_relation + + # Setup evaluate function to add a flash message + def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results): + flashes.append("Rule is equal to whitelist") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages show both rule creation and whitelist check + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + assert "Rule is equal to whitelist" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + + +class TestEvaluateRtbhAgainstWhitelistsCheckResults: + def test_equal_relation(self, app, whitelist_fixture): + """Test evaluating a rule with an EQUAL relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.0/24" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.EQUAL)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == model + + def test_subnet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUBNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.128/25" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUBNET)] + + # Call the function with mocked dependencies + with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch( + "flowapp.services.rule_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.rule_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + with app.app_context(): + _result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert flashes + + # Verify model was updated to whitelisted state + assert model.rstate_id == 4 + + def test_supernet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUPERNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.0.0/16" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUPERNET)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == model + + def test_no_relation(self, app): + """Test evaluating a rule with no relation to any whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + wl_cache = {} + results = [] + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify no changes to the model and no messages + assert result == model + assert not flashes + + +class TestMapWhitelistsToStrings: + def test_map_whitelists_to_strings(self): + """Test mapping whitelist objects to strings""" + # Create mock whitelists + whitelist1 = MagicMock(spec=Whitelist) + whitelist1.__str__.return_value = "192.168.1.0/24" + + whitelist2 = MagicMock(spec=Whitelist) + whitelist2.__str__.return_value = "10.0.0.0/8" + + whitelists = [whitelist1, whitelist2] + + # Call the function + result = rule_service.map_whitelists_to_strings(whitelists) + + # Verify the result + assert len(result) == 2 + assert "192.168.1.0/24" in result + assert "10.0.0.0/8" in result + assert result["192.168.1.0/24"] == whitelist1 + assert result["10.0.0.0/8"] == whitelist2 + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv6 rules + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() diff --git a/flowapp/tests/test_rule_service_reactivate_delete.py b/flowapp/tests/test_rule_service_reactivate_delete.py new file mode 100644 index 00000000..22a3f9f4 --- /dev/null +++ b/flowapp/tests/test_rule_service_reactivate_delete.py @@ -0,0 +1,527 @@ +"""Tests for rule_service module.""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp.constants import RuleTypes +from flowapp.models import ( + RTBH, + Flowspec4, + Flowspec6, + Whitelist, +) +from flowapp.output import RouteSources +from flowapp.services import rule_service + + +@pytest.fixture +def test_data(app, db): + """Fixture providing test data for rule service tests.""" + current_time = datetime.now() + + # Create a test Flowspec4 rule + ipv4_rule = Flowspec4( + source="192.168.1.1", + source_mask=32, + source_port="", + destination="192.168.2.1", + destination_mask=32, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=current_time + timedelta(hours=1), + comment="Test IPv4 rule", + action_id=1, # Using action ID 1 from test database + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(ipv4_rule) + + # Create a test Flowspec6 rule + ipv6_rule = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="", + destination="2001:db8:1::1", + destination_mask=128, + destination_port="", + next_header="tcp", + flags="", + packet_len="", + expires=current_time + timedelta(hours=1), + comment="Test IPv6 rule", + action_id=1, # Using action ID 1 from test database + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(ipv6_rule) + + # Create a test RTBH rule + rtbh_rule = RTBH( + ipv4="192.168.1.100", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=1, # Using community ID 1 from test database + expires=current_time + timedelta(hours=1), + comment="Test RTBH rule", + user_id=1, # Using user ID 1 from test database + org_id=1, # Using org ID 1 from test database + rstate_id=1, # Active state + ) + db.session.add(rtbh_rule) + + # Create a test Whitelist + whitelist = Whitelist( + ip="192.168.2.1", + mask=24, + expires=current_time + timedelta(days=7), + comment="Test whitelist", + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(whitelist) + + db.session.commit() + + # Return data that will be useful for tests + return { + "user_id": 1, + "org_id": 1, + "user_email": "test@example.com", + "org_name": "Test Org", + "ipv4_rule_id": ipv4_rule.id, + "ipv6_rule_id": ipv6_rule.id, + "rtbh_rule_id": rtbh_rule.id, + "whitelist_id": whitelist.id, + "current_time": current_time, + "future_time": current_time + timedelta(hours=1), + "past_time": current_time - timedelta(hours=1), + "comment": "Test comment", + } + + +class TestReactivateRule: + """Tests for the reactivate_rule function.""" + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_reactivate_rule_active( + self, mock_log_route, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db + ): + """Test reactivating a rule with future expiration (active state).""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = False + + # The rule will be active (state=1) because expiration is in the future + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) + + # Verify model was updated + assert model.expires == expires + assert model.comment == test_data["comment"] + assert model.rstate_id == 1 # Active state + + # Verify route announcement was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_route.assert_called_once() + + # Check returned values + assert messages + assert messages[0].startswith("Rule ") + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + def test_reactivate_rule_inactive( + self, mock_log_withdraw, mock_announce_route, mock_check_rule_limit, mock_check_global_limit, test_data, db + ): + """Test reactivating a rule with past expiration (inactive state).""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = False + + # The rule will be inactive (state=2) because expiration is in the past + expires = test_data["past_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Verify model was updated + assert model.expires == expires + assert model.comment == test_data["comment"] + assert model.rstate_id == 2 # Inactive state + + # Verify route withdrawal was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_withdraw.assert_called_once() + + # Check returned values + assert messages + assert messages[0].startswith("Rule ") + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + def test_reactivate_rule_global_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): + """Test reactivating a rule when global limit is reached.""" + # Setup mocks + mock_check_global_limit.return_value = True # Global limit reached + mock_check_rule_limit.return_value = False + + # The rule would be active, but global limit is reached + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + + # Check returned values + assert messages == ["global_limit_reached"] + + @patch("flowapp.services.rule_service.check_global_rule_limit") + @patch("flowapp.services.rule_service.check_rule_limit") + def test_reactivate_rule_org_limit_reached(self, mock_check_rule_limit, mock_check_global_limit, test_data, db): + """Test reactivating a rule when organization limit is reached.""" + # Setup mocks + mock_check_global_limit.return_value = False + mock_check_rule_limit.return_value = True # Org limit reached + + # The rule would be active, but org limit is reached + expires = test_data["future_time"] + + # Call the function + model, messages = rule_service.reactivate_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + expires=expires, + comment=test_data["comment"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Assertions + mock_check_global_limit.assert_called_once_with(RuleTypes.IPv4.value) + mock_check_rule_limit.assert_called_once_with(test_data["org_id"], rule_type=RuleTypes.IPv4.value) + + # Check returned values + assert messages == ["limit_reached"] + + +class TestDeleteRule: + """Tests for the delete_rule function.""" + + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + def test_delete_rule_success(self, mock_log_withdraw, mock_announce_route, test_data, db): + """Test successful rule deletion.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["ipv4_rule_id"]], # Rule is in allowed list + ) + + # Verify route withdrawal was made + mock_announce_route.assert_called_once() + args, _ = mock_announce_route.call_args + assert args[0].source == RouteSources.UI + + # Verify logging + mock_log_withdraw.assert_called_once() + + # Verify database operations - rule should be deleted + rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) + assert rule is None + + # Check returned values + assert success is True + assert message == "Rule deleted successfully" + + def test_delete_rule_not_allowed(self, test_data, db): + """Test rule deletion when rule is not in allowed list.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=test_data["ipv4_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[999], # Rule is not in allowed list + ) + + # Verify rule still exists + rule = db.session.get(Flowspec4, test_data["ipv4_rule_id"]) + assert rule is not None + + # Check returned values + assert success is False + assert message == "You cannot delete this rule" + + def test_delete_rule_not_found(self, test_data, db): + """Test rule deletion when rule is not found.""" + # Call the function with non-existent ID + success, message = rule_service.delete_rule( + rule_type=RuleTypes.IPv4, + rule_id=9999, # Non-existent ID + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Check returned values + assert success is False + assert message == "Rule not found" + + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_withdraw") + @patch("flowapp.services.rule_service.RuleWhitelistCache") + def test_delete_rtbh_rule(self, mock_whitelist_cache, mock_log_withdraw, mock_announce_route, test_data, db): + """Test deleting an RTBH rule.""" + # Call the function + success, message = rule_service.delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify whitelist cache was cleaned + mock_whitelist_cache.delete_by_rule_id.assert_called_once_with(test_data["rtbh_rule_id"]) + + # Verify route withdrawal and logging + mock_announce_route.assert_called_once() + mock_log_withdraw.assert_called_once() + + # Verify rule was deleted + rule = db.session.get(RTBH, test_data["rtbh_rule_id"]) + assert rule is None + + # Check returned values + assert success is True + assert message == "Rule deleted successfully" + + +class TestDeleteRtbhAndCreateWhitelist: + """Tests for the delete_rtbh_and_create_whitelist function.""" + + @patch("flowapp.services.rule_service.delete_rule") + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_delete_rtbh_and_create_whitelist_success(self, mock_create_whitelist, mock_delete_rule, test_data, db): + """Test successful RTBH deletion and whitelist creation.""" + # Setup mock for delete_rule service + mock_delete_rule.return_value = (True, "Rule deleted successfully") + + # Setup mock for create_or_update_whitelist service + mock_whitelist = Whitelist( + ip="192.168.1.100", + mask=32, + expires=datetime.now() + timedelta(days=7), + user_id=test_data["user_id"], + org_id=test_data["org_id"], + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify delete_rule was called + mock_delete_rule.assert_called_once_with( + rule_type=RuleTypes.RTBH, + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify create_or_update_whitelist was called with correct data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "192.168.1.100" + assert form_data["mask"] == 32 + assert "Created from RTBH rule" in form_data["comment"] + + # Check returned values + assert success is True + assert len(messages) == 2 + assert messages[0] == "Rule deleted successfully" + assert messages[1] == "Whitelist created" + assert whitelist == mock_whitelist + + def test_delete_rtbh_and_create_whitelist_rule_not_found(self, test_data, db): + """Test when the RTBH rule to convert is not found.""" + # Call the function with non-existent ID + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=9999, # Non-existent ID + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + ) + + # Check returned values + assert success is False + assert messages == ["RTBH rule not found"] + assert whitelist is None + + def test_delete_rtbh_and_create_whitelist_not_allowed(self, test_data, db): + """Test when the user is not allowed to delete the rule.""" + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[999], # Rule not in allowed list + ) + + # Check returned values + assert success is False + assert messages == ["You cannot delete this rule"] + assert whitelist is None + + @patch("flowapp.services.rule_service.delete_rule") + def test_delete_rtbh_and_create_whitelist_delete_fails(self, mock_delete_rule, test_data, db): + """Test when the RTBH deletion fails.""" + # Setup mock for delete_rule service to fail + mock_delete_rule.return_value = (False, "Error deleting rule") + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=test_data["rtbh_rule_id"], + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[test_data["rtbh_rule_id"]], + ) + + # Verify delete_rule was called + mock_delete_rule.assert_called_once() + + # Check returned values + assert success is False + assert messages == ["Error deleting rule"] + assert whitelist is None + + @patch("flowapp.services.rule_service.delete_rule") + @patch("flowapp.services.rule_service.create_or_update_whitelist") + def test_rtbh_with_ipv6(self, mock_create_whitelist, mock_delete_rule, test_data, db): + """Test with an IPv6 RTBH rule.""" + # Create an IPv6 RTBH rule + ipv6_rtbh = RTBH( + ipv4=None, + ipv4_mask=None, + ipv6="2001:db8::1", + ipv6_mask=64, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + comment="IPv6 RTBH rule", + user_id=test_data["user_id"], + org_id=test_data["org_id"], + rstate_id=1, + ) + db.session.add(ipv6_rtbh) + db.session.commit() + + # Setup mocks + mock_delete_rule.return_value = (True, "Rule deleted successfully") + mock_whitelist = Whitelist( + ip="2001:db8::1", + mask=64, + expires=datetime.now() + timedelta(days=7), + user_id=test_data["user_id"], + org_id=test_data["org_id"], + ) + mock_create_whitelist.return_value = (mock_whitelist, ["Whitelist created"]) + + # Call the function + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=ipv6_rtbh.id, + user_id=test_data["user_id"], + org_id=test_data["org_id"], + user_email=test_data["user_email"], + org_name=test_data["org_name"], + allowed_rule_ids=[ipv6_rtbh.id], + ) + + # Verify create_or_update_whitelist was called with correct IPv6 data + mock_create_whitelist.assert_called_once() + args, kwargs = mock_create_whitelist.call_args + form_data = kwargs.get("form_data", args[0] if args else None) + assert form_data["ip"] == "2001:db8::1" + assert form_data["mask"] == 64 + + # Check returned values + assert success is True + assert len(messages) == 2 + assert whitelist == mock_whitelist diff --git a/flowapp/tests/test_validators.py b/flowapp/tests/test_validators.py index 33e757cf..cab089f0 100644 --- a/flowapp/tests/test_validators.py +++ b/flowapp/tests/test_validators.py @@ -1,11 +1,28 @@ import pytest -import flowapp.validators +from flowapp.validators import ( + PortString, + PacketString, + NetRangeString, + NetInRange, + IPAddress, + IPAddressValidator, + NetworkValidator, + ValidationError, + address_in_range, + address_with_mask, + IPv4Address, + IPv6Address, + DateNotExpired, + editable_range, + network_in_range, + range_in_network, +) def test_port_string_len_raises(field): - port = flowapp.validators.PortString() + port = PortString() field.data = "1;2;3;4;5;6;7;8" - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): port(None, field) @@ -20,12 +37,12 @@ def test_port_string_len_raises(field): ], ) def test_is_valid_address_with_mask(address, mask, expected): - assert flowapp.validators.address_with_mask(address, mask) == expected + assert address_with_mask(address, mask) == expected @pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"]) def test_ip4address_passes(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address adr(None, field) @@ -38,7 +55,7 @@ def test_ip4address_passes(field, address): ], ) def test_ip6address_passes(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address adr(None, field) @@ -51,19 +68,17 @@ def test_ip6address_passes(field, address): ], ) def test_bad_ip6address_raises(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) -@pytest.mark.parametrize( - "expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"] -) +@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]) def test_expired_date_raises(field, expired): - adr = flowapp.validators.DateNotExpired() + adr = DateNotExpired() field.data = expired - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -75,9 +90,9 @@ def test_expired_date_raises(field, expired): ], ) def test_ipaddress_raises(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -91,7 +106,7 @@ def test_ipaddress_raises(field, address): def test_editable_rule(rule, address, mask, ranges, expected): rule.source = address rule.source_mask = mask - assert flowapp.validators.editable_range(rule, ranges) == expected + assert editable_range(rule, ranges) == expected @pytest.mark.parametrize( @@ -104,7 +119,7 @@ def test_editable_rule(rule, address, mask, ranges, expected): ], ) def test_address_in_range(address, mask, ranges, expected): - assert flowapp.validators.address_in_range(address, ranges) == expected + assert address_in_range(address, ranges) == expected @pytest.mark.parametrize( @@ -123,7 +138,7 @@ def test_address_in_range(address, mask, ranges, expected): ], ) def test_network_in_range(address, mask, ranges, expected): - assert flowapp.validators.network_in_range(address, mask, ranges) == expected + assert network_in_range(address, mask, ranges) == expected @pytest.mark.parametrize( @@ -133,4 +148,209 @@ def test_network_in_range(address, mask, ranges, expected): ], ) def test_range_in_network(address, mask, ranges, expected): - assert flowapp.validators.range_in_network(address, mask, ranges) == expected + assert range_in_network(address, mask, ranges) == expected + + +### new tests + + +# New tests for previously uncovered validators + + +# Tests for PortString validator syntax +@pytest.mark.parametrize( + "port_data", + [ + "80", # Simple number + "80-443", # Range with hyphen + ">=80&<=443", # Explicit range + ">80", # Greater than + ">=80", # Greater than or equal + "<443", # Less than + "<=443", # Less than or equal + "80;443;8080", # Multiple port expressions + ], +) +def test_port_string_valid_syntax(field, port_data): + validator = PortString() + field.data = port_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_port_data", + [ + "80,443", # Comma not supported, should use semicolon + "80..443", # Invalid range syntax (should be 80-443) + ">65536", # Port number too high + "-1", # Negative number + "<-1", # Port number too low + ">=80<=443", # Invalid range syntax (missing &) + "!=80", # Not equals not supported + "abc", # Non-numeric + "80-", # Incomplete range + "-80", # Invalid range + "80&&443", # Invalid operator + "443-80", # End less than start + ">=443&<=80", # End less than start in explicit range + "0-65537", # Range exceeding max + ], +) +def test_port_string_invalid_syntax(field, invalid_port_data): + validator = PortString() + field.data = invalid_port_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for PacketString validator +@pytest.mark.parametrize( + "packet_data", + ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions +) +def test_packet_string_valid(field, packet_data): + validator = PacketString() + field.data = packet_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_packet_data", + [ + "65536", # Too large + "-1", # Negative + "1500..", # Invalid range syntax + "abc", # Non-numeric + "!!1500", # Invalid operator + "1500-", # Incomplete range + "9000-1500", # End less than start + ], +) +def test_packet_string_invalid(field, invalid_packet_data): + validator = PacketString() + field.data = invalid_packet_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetRangeString validator +@pytest.mark.parametrize( + "net_range", + [ + "192.168.0.0/24", + "10.0.0.0/8\n172.16.0.0/12", + "2001:db8::/32", + "192.168.0.0/24 10.0.0.0/8", + "2001:db8::/32 2001:db8:1::/48", + ], +) +def test_net_range_string_valid(field, net_range): + validator = NetRangeString() + field.data = net_range + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_net_range", + [ + "192.168.0.0/33", # Invalid mask + "256.256.256.0/24", # Invalid IP + "2001:xyz::/32", # Invalid IPv6 + "192.168.0.0/24/", # Malformed + "not-an-ip/24", # Invalid format + ], +) +def test_net_range_string_invalid(field, invalid_net_range): + validator = NetRangeString() + field.data = invalid_net_range + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetInRange validator +def test_net_in_range_valid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "192.168.1.1/24" + validator(None, field) # Should not raise + + +def test_net_in_range_invalid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "172.16.1.1/24" + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for base IPAddress validator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"]) +def test_ip_address_valid(field, ip_addr): + validator = IPAddress() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"]) +def test_ip_address_invalid(field, invalid_ip): + validator = IPAddress() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for universal IPAddressValidator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed +def test_ip_address_validator_valid(field, ip_addr): + validator = IPAddressValidator() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"]) +def test_ip_address_validator_invalid(field, invalid_ip): + validator = IPAddressValidator() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetworkValidator +class MockForm: + def __init__(self, address, mask): + self._fields = {"mask": type("MockField", (), {"data": mask})()} + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.0.0", "24"), + ("10.0.0.0", "8"), + ("2001:db8::", "32"), + ("", None), # Empty values should pass + (None, None), # None values should pass + ], +) +def test_network_validator_valid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + validator(form, field) # Should not raise + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.1.1", "24"), # Not a network address + ("192.168.0.0", "33"), # Invalid IPv4 mask + ("2001:db8::", "129"), # Invalid IPv6 mask + ("256.256.256.0", "24"), # Invalid IP + ("2001:xyz::", "32"), # Invalid IPv6 + ], +) +def test_network_validator_invalid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + with pytest.raises(ValidationError): + validator(form, field) diff --git a/flowapp/tests/test_whitelist_common.py b/flowapp/tests/test_whitelist_common.py new file mode 100644 index 00000000..a843aff9 --- /dev/null +++ b/flowapp/tests/test_whitelist_common.py @@ -0,0 +1,250 @@ +import pytest +from flowapp.services.whitelist_common import ( + Relation, + check_rule_against_whitelists, + check_whitelist_against_rules, + check_whitelist_to_rule_relation, + subtract_network, + clear_network_cache, + _is_same_ip_version, # New helper function +) + + +# Tests for core function that checks network relations +@pytest.mark.parametrize( + "rule,whitelist,expected", + [ + # IPv4 test cases + ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet + ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks + ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet + ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet + ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks + # IPv6 test cases + ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks + ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet + ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks + ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ], +) +def test_check_whitelist_to_rule_relation(rule, whitelist, expected): + clear_network_cache() + assert check_whitelist_to_rule_relation(rule, whitelist) == expected + + +@pytest.mark.parametrize( + "target,whitelist,expected", + [ + # Basic IPv4 cases + ( + "192.168.1.0/24", + "192.168.1.128/25", + ["192.168.1.0/25"], + ), # One remaining subnet + ( + "192.168.1.0/24", + "192.168.1.64/26", + ["192.168.1.0/26", "192.168.1.128/25"], + ), # Two remaining subnets + ( + "192.168.1.0/24", + "192.168.1.0/24", + ["192.168.1.0/24"], + ), # Equal networks - return original network + ( + "192.168.1.0/24", + "192.168.2.0/24", + ["192.168.1.0/24"], + ), # No overlap + ], +) +def test_subtract_network(target, whitelist, expected): + clear_network_cache() + result = subtract_network(target, whitelist) + assert sorted(result) == sorted(expected) + + +def test_check_rule_against_whitelists(): + clear_network_cache() + + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # SUPERNET + "172.16.0.0/12", # DIFFERENT + "192.168.1.0/24", # EQUAL + "192.168.1.128/25", # SUBNET + ] + + results = check_rule_against_whitelists(rule, whitelists) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + (rule, "192.168.0.0/16", Relation.SUPERNET), + (rule, "192.168.1.0/24", Relation.EQUAL), + (rule, "192.168.1.128/25", Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +def test_check_whitelist_against_rules(): + clear_network_cache() + + whitelist = "192.168.1.128/25" + rules = [ + "192.168.0.0/16", # Whitelist is SUBNET of this larger network + "172.16.0.0/12", # DIFFERENT + "192.168.1.128/25", # EQUAL + "192.168.1.0/24", # Whitelist is SUBNET of this network + ] + + results = check_whitelist_against_rules(rules, whitelist) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + ("192.168.0.0/16", whitelist, Relation.SUBNET), + ("192.168.1.128/25", whitelist, Relation.EQUAL), + ("192.168.1.0/24", whitelist, Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +# Tests for edge cases and error handling +def test_host_addresses(): + clear_network_cache() + # Test with host addresses (no explicit subnet mask) + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT + + +def test_single_ip_as_network(): + clear_network_cache() + # Test with /32 (single IPv4 address as network) + assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET + # Test with /128 (single IPv6 address as network) + assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET + + +def test_is_same_ip_version(): + """Test the new helper function to check if two addresses are of the same IP version""" + # IPv4 cases + assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True + assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True + + # IPv6 cases + assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True + assert _is_same_ip_version("2001:db8::1", "fe80::1") is True + + # Mixed cases + assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False + assert _is_same_ip_version("192.168.1.1", "fe80::1") is False + + +def test_check_whitelist_to_rule_relation_mixed_versions(): + """Test the check_whitelist_to_rule_relation function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule with IPv6 whitelist + assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT + + # IPv6 rule with IPv4 whitelist + assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT + + # Invalid IP addresses should return DIFFERENT + assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT + assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT + + +def test_subtract_network_mixed_versions(): + """Test the subtract_network function with mixed IP versions""" + clear_network_cache() + + # IPv4 target with IPv6 whitelist - should return original target + result = subtract_network("192.168.1.0/24", "2001:db8::/32") + assert result == ["192.168.1.0/24"] + + # IPv6 target with IPv4 whitelist - should return original target + result = subtract_network("2001:db8::/32", "192.168.1.0/24") + assert result == ["2001:db8::/32"] + + # Invalid addresses - should return original target + result = subtract_network("invalid", "192.168.1.0/24") + assert result == ["invalid"] + result = subtract_network("192.168.1.0/24", "invalid") + assert result == ["192.168.1.0/24"] + + +def test_check_rule_against_whitelists_mixed_versions(): + """Test the check_rule_against_whitelists function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule against mixed whitelist + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "192.168.0.0/16" + assert results[0][2] == Relation.SUPERNET + + # IPv6 rule against mixed whitelist + rule = "2001:db8:1::/48" + whitelists = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "2001:db8::/32" + assert results[0][2] == Relation.SUPERNET + + +def test_check_whitelist_against_rules_mixed_versions(): + """Test the check_whitelist_against_rules function with mixed IP versions""" + clear_network_cache() + + # IPv4 whitelist against mixed rules + whitelist = "192.168.1.0/24" + rules = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "192.168.0.0/16" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET + + # IPv6 whitelist against mixed rules + whitelist = "2001:db8:1::/48" + rules = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "2001:db8::/32" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py new file mode 100644 index 00000000..460a172b --- /dev/null +++ b/flowapp/tests/test_whitelist_service.py @@ -0,0 +1,461 @@ +""" +Tests for whitelist_service.py module. + +This test suite verifies the functionality of the whitelist service, +which manages creation, updating, and handling of whitelist rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import Whitelist, RuleWhitelistCache, RTBH +from flowapp.services import whitelist_service +from flowapp.services.whitelist_common import Relation + + +@pytest.fixture +def whitelist_form_data(): + """Sample valid whitelist form data""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist entry", + "expires": datetime.now() + timedelta(hours=1), + } + + +class TestCreateOrUpdateWhitelist: + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_new_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a new whitelist entry""" + # Mock the get_whitelist_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ip == whitelist_form_data["ip"] + assert model.mask == whitelist_form_data["mask"] + assert model.comment == whitelist_form_data["comment"] + assert model.user_id == 1 + assert model.org_id == 1 + assert model.rstate_id == 1 # Active state + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Whitelist saved" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_update_existing_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): + """Test updating an existing whitelist entry""" + # Create an existing whitelist + existing_model = Whitelist( + ip=whitelist_form_data["ip"], + mask=whitelist_form_data["mask"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Mock to return the existing model + mock_get_model.return_value = existing_model + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + # Check that expiration date was updated (rounded to 10 minutes) + # We can't compare exact timestamps, so check date parts + assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() + + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Existing Whitelist found" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.db.session.query") + @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_whitelist_with_matching_rules( + self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a whitelist that affects existing rules""" + # Mock get_whitelist_model_if_exists to return False (new whitelist) + mock_get_model.return_value = False + + # Create a mock RTBH rule + mock_rtbh_rule = MagicMock(spec=RTBH) + mock_rtbh_rule.__str__.return_value = "192.168.1.0/24" + mock_rtbh_rules = [mock_rtbh_rule] + + # Setup mock query to return our rule + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules + mock_query.return_value = mock_query_result + + # Setup map function to return our rule in a dict + mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule} + + # Setup check function to return a relation + rule_relation = [ + ( + str(mock_rtbh_rule), + str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]), + Relation.EQUAL, + ) + ] + mock_check.return_value = rule_relation + + # Setup evaluate function to modify flashes + def evaluate_side_effect(model, flashes, rtbh_rule_cache, results): + flashes.append("Rule was whitelisted") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages includes both whitelist creation and rule whitelisting + assert "Whitelist saved" in flashes[0] + assert "Rule was whitelisted" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + + +class TestDeleteWhitelist: + @patch("flowapp.services.whitelist_service.announce_rtbh_route") + def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): + """Test deleting a whitelist that has user-created rules attached to it""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by user, whitelisted) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=4, # Whitelisted state + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.USER, + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set whitelist ID now that it has been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Mock the get_by_whitelist_id to return our cache entry + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + flashes = whitelist_service.delete_whitelist(whitelist.id) + + # Verify the rule state was changed back to active + rtbh_rule = db.session.get(RTBH, rtbh_rule.id) + assert rtbh_rule.rstate_id == 1 # Active state + + # Verify announcement was made + mock_announce.assert_called_once() + + # Verify flash messages + assert isinstance(flashes, list) + assert flashes + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id") + def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db): + """Test deleting a whitelist that has rules created by the whitelist""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by whitelist) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set IDs now that they have been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Create a mock session that can get our rule + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + flashes = whitelist_service.delete_whitelist(whitelist.id) + + # Verify flash messages + assert isinstance(flashes, list) + assert flashes + + # Verify the rule was deleted + assert db.session.get(RTBH, rtbh_rule.id) is None + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + # Verify cache cleanup was called + mock_clean.assert_called_once_with(whitelist.id) + + def test_delete_nonexistent_whitelist(self, app, db): + """Test deleting a whitelist that doesn't exist""" + with app.app_context(): + # Call the service function with a non-existent ID + flashes = whitelist_service.delete_whitelist(999) + + # Should return empty list of flash messages, as no whitelist was found + assert isinstance(flashes, list) + assert len(flashes) == 0 + + +class TestEvaluateWhitelistAgainstRtbhResults: + def test_equal_relation(self, app): + """Test evaluating a whitelist with an EQUAL relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.0/24" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.EQUAL)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert flashes + + # Verify the correct model was returned + assert result == whitelist_model + + def test_subnet_relation(self, app): + """Test evaluating a whitelist with a SUBNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.128/25" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUBNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch( + "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.whitelist_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + # Call the function + with app.app_context(): + _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("supernet of whitelist" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert rtbh_rule.rstate_id == 4 + + def test_supernet_relation(self, app): + """Test evaluating a whitelist with a SUPERNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.0.0/16" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUPERNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("subnet of whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == whitelist_model diff --git a/flowapp/utils/__init__.py b/flowapp/utils/__init__.py new file mode 100644 index 00000000..99ca325a --- /dev/null +++ b/flowapp/utils/__init__.py @@ -0,0 +1,46 @@ +from .base import ( + other_rtypes, + output_date_format, + parse_api_time, + quote_to_ent, + webpicker_to_datetime, + datetime_to_webpicker, + get_state_by_time, + round_to_ten_minutes, + flash_errors, + active_css_rstate, + get_comp_func, +) + +from .app_factory import ( + # configure_app, + configure_logging, + register_blueprints, + register_error_handlers, + register_context_processors, + register_template_filters, + register_auth_handlers, + # register_org_routes, +) + +__all__ = [ + "other_rtypes", + "output_date_format", + "parse_api_time", + "quote_to_ent", + "webpicker_to_datetime", + "datetime_to_webpicker", + "get_state_by_time", + "round_to_ten_minutes", + "flash_errors", + "active_css_rstate", + "get_comp_func", + # "configure_app", + "configure_logging", + "register_blueprints", + "register_error_handlers", + "register_context_processors", + "register_template_filters", + "register_auth_handlers", + # "register_org_routes", +] diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py new file mode 100644 index 00000000..234522fd --- /dev/null +++ b/flowapp/utils/app_factory.py @@ -0,0 +1,221 @@ +import logging +import babel +from flask import redirect, render_template, request, session, url_for + + +def register_blueprints(app, csrf=None): + """Register Flask blueprints.""" + from flowapp.views.admin import admin + from flowapp.views.rules import rules + from flowapp.views.api_v1 import api as api_v1 + from flowapp.views.api_v2 import api as api_v2 + from flowapp.views.api_v3 import api as api_v3 + from flowapp.views.api_keys import api_keys + from flowapp.views.dashboard import dashboard + from flowapp.views.whitelist import whitelist + + # Configure CSRF exemption for API routes + if csrf: + csrf.exempt(api_v1) + csrf.exempt(api_v2) + csrf.exempt(api_v3) + + # Register blueprints with URL prefixes + app.register_blueprint(admin, url_prefix="/admin") + app.register_blueprint(rules, url_prefix="/rules") + app.register_blueprint(api_keys, url_prefix="/api_keys") + app.register_blueprint(api_v1, url_prefix="/api/v1") + app.register_blueprint(api_v2, url_prefix="/api/v2") + app.register_blueprint(api_v3, url_prefix="/api/v3") + app.register_blueprint(dashboard, url_prefix="/dashboard") + app.register_blueprint(whitelist, url_prefix="/whitelist") + + return app + + +def configure_logging(app): + """Configure logging for the Flask application.""" + + # Remove all default handlers + for handler in app.logger.handlers[:]: + app.logger.removeHandler(handler) + + # Retrieve log level and file name from config + log_level = app.config.get("LOG_LEVEL", "DEBUG").upper() + log_file = app.config.get("LOG_FILE", "app.log") + + # Define log format + log_format = "%(asctime)s | %(levelname)s | %(message)s" + log_datefmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter(log_format, datefmt=log_datefmt) + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + + # Set logger level + app.logger.setLevel(getattr(logging, log_level, logging.DEBUG)) + + # Attach handlers + app.logger.addHandler(console_handler) + app.logger.addHandler(file_handler) + + return app + + +def register_error_handlers(app): + """Register error handlers.""" + + @app.errorhandler(404) + def not_found(error): + return render_template("errors/404.html"), 404 + + @app.errorhandler(500) + def internal_error(exception): + app.logger.exception(exception) + return render_template("errors/500.html"), 500 + + return app + + +def register_context_processors(app): + """Register template context processors.""" + + @app.context_processor + def utility_processor(): + def editable_rule(rule): + if rule: + from flowapp.validators import editable_range + from flowapp.models import get_user_nets + + editable_range(rule, get_user_nets(session["user_id"])) + return True + return False + + return dict(editable_rule=editable_rule) + + @app.context_processor + def inject_main_menu(): + """Inject main menu config to templates.""" + return {"main_menu": app.config.get("MAIN_MENU")} + + @app.context_processor + def inject_dashboard(): + """Inject dashboard config to templates.""" + return {"dashboard": app.config.get("DASHBOARD")} + + return app + + +def register_template_filters(app): + """Register custom template filters.""" + + @app.template_filter("strftime") + def format_datetime(value): + if value is None: + return app.config.get("MISSING_DATETIME_MESSAGE", "Never") + + format = "y/MM/dd HH:mm" + return babel.dates.format_datetime(value, format) + + @app.template_filter("unlimited") + def unlimited_filter(value): + return "unlimited" if value == 0 else value + + return app + + +def register_auth_handlers(app, ext): + """Register authentication handlers.""" + + @ext.login_handler + def login(user_info): + try: + uuid = user_info.get("eppn") + except KeyError: + uuid = False + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/logout") + def logout(): + session["user_uuid"] = False + session["user_id"] = False + session.clear() + return redirect(app.config.get("LOGOUT_URL")) + + @app.route("/ext-login") + def ext_login(): + header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") + if header_name not in request.headers: + return render_template("errors/401.html") + + uuid = request.headers.get(header_name) + if not uuid: + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/local-login") + def local_login(): + print("Local login started") + if not app.config.get("LOCAL_AUTH", False): + print("Local auth not enabled") + return render_template("errors/401.html") + + uuid = app.config.get("LOCAL_USER_UUID", False) + if not uuid: + print("Local user not set") + return render_template("errors/401.html") + + print(f"Local login with {uuid}") + return _handle_login(uuid, app) + + return app + + +def _handle_login(uuid, app): + """Handle login process for all authentication methods.""" + from flowapp import db + + multiple_orgs = False + try: + user, multiple_orgs = _register_user_to_session(uuid, db) + except AttributeError as e: + app.logger.exception(e) + return render_template("errors/401.html") + + if multiple_orgs: + return redirect(url_for("select_org", org_id=None)) + + # set user org to session + user_org = user.organization.first() + session["user_org"] = user_org.name + session["user_org_id"] = user_org.id + + return redirect("/") + + +def _register_user_to_session(uuid, db): + """Register user information to session.""" + from flowapp.models import User + + print(f"Registering user {uuid} to session") + user = db.session.query(User).filter_by(uuid=uuid).first() + print(f"Got user {user} from DB") + session["user_uuid"] = user.uuid + session["user_email"] = user.uuid + session["user_name"] = user.name + session["user_id"] = user.id + session["user_roles"] = [role.name for role in user.role.all()] + session["user_role_ids"] = [role.id for role in user.role.all()] + roles = [i > 1 for i in session["user_role_ids"]] + session["can_edit"] = True if all(roles) and roles else [] + # check if user has multiple organizations and return True if so + print(f"DEBUG SESSION {session}") + return user, len(user.organization.all()) > 1 diff --git a/flowapp/utils.py b/flowapp/utils/base.py similarity index 98% rename from flowapp/utils.py rename to flowapp/utils/base.py index ad15b615..512f7ccd 100644 --- a/flowapp/utils.py +++ b/flowapp/utils/base.py @@ -1,4 +1,3 @@ -from operator import ge, lt from datetime import datetime, timedelta from flask import flash from flowapp.constants import ( @@ -73,7 +72,7 @@ def parse_api_time(apitime): except ValueError: mytime = False - return False + return mytime def quote_to_ent(comment): diff --git a/flowapp/validators.py b/flowapp/validators.py index 63061062..890bdacb 100644 --- a/flowapp/validators.py +++ b/flowapp/validators.py @@ -253,7 +253,10 @@ def __call__(self, form, field): result = False for address in field.data.split("/"): for adr_range in self.net_ranges: - result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + try: + result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + except ValueError as e: + raise ValidationError(self.message + str(e.args[0])) if not result: raise ValidationError(self.message) @@ -352,3 +355,66 @@ def subnet_of(net_a, net_b): def supernet_of(net_a, net_b): """Return True if this network is a supernet of other.""" return _is_subnet_of(net_b, net_a) + + +class IPAddressValidator: + """ + Universal validator that accepts both IPv4 and IPv6 addresses. + """ + + def __init__(self, message=None): + self.message = message or "Invalid IP address: {}" + + def __call__(self, form, field): + if not field.data: + return + + try: + ipaddress.ip_address(field.data) + except ValueError: + raise ValidationError(self.message.format(field.data)) + + +class NetworkValidator: + """ + Validates that an IP address and mask form a valid network. + Works with both IPv4 and IPv6 addresses. + """ + + def __init__(self, mask_field_name, message=None): + self.mask_field_name = mask_field_name + self.message = message or "Invalid network: address {}, mask {}" + + def __call__(self, form, field): + if not field.data: + return + + mask_field = form._fields.get(self.mask_field_name) + if not mask_field or not mask_field.data: + return + + try: + # Determine IP version + ip = ipaddress.ip_address(field.data) + mask = int(mask_field.data) + + # Validate mask range based on IP version + if isinstance(ip, ipaddress.IPv4Address): + if not 0 <= mask <= 32: + raise ValidationError(f"Invalid IPv4 mask: {mask}") + else: # IPv6 + if not 0 <= mask <= 128: + raise ValidationError(f"Invalid IPv6 mask: {mask}") + + # Try to create network to validate the combination + network = ipaddress.ip_network(f"{field.data}/{mask}", strict=False) + + # Check if the original IP is the correct network address + if str(network.network_address) != str(ip): + raise ValidationError( + f"Invalid network address: {field.data}/{mask}. " + f"Network address should be: {network.network_address}" + ) + + except ValueError: + raise ValidationError(self.message.format(field.data, mask_field.data)) diff --git a/flowapp/views/admin.py b/flowapp/views/admin.py index d014c87c..7912e86a 100644 --- a/flowapp/views/admin.py +++ b/flowapp/views/admin.py @@ -73,15 +73,22 @@ def add_machine_key(): """ generated = secrets.token_hex(24) form = MachineApiKeyForm(request.form, key=generated) + form.user.choices = [(g.id, g.name) for g in db.session.query(User).order_by("name")] if request.method == "POST" and form.validate(): + target_user = db.session.get(User, form.user.data) + target_org = target_user.organization.first() if target_user else None + current_user = session.get("user_name") + curent_email = session.get("user_uuid") + comment = f"created by: {current_user}/{curent_email}, comment: {form.comment.data}" model = MachineApiKey( machine=form.machine.data, key=form.key.data, expires=form.expires.data, readonly=form.readonly.data, - comment=form.comment.data, - user_id=session["user_id"], + comment=comment, + user_id=target_user.id, + org_id=target_org.id, ) db.session.add(model) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 6ce24bf6..3951b093 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -5,7 +5,7 @@ from functools import wraps from datetime import datetime, timedelta -from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, ANNOUNCE, TIME_FORMAT_ARG, RuleTypes +from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, TIME_FORMAT_ARG, RuleTypes from flowapp.models import ( RTBH, Flowspec4, @@ -18,20 +18,16 @@ check_rule_limit, get_user_nets, get_user_actions, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, insert_initial_communities, get_user_communities, - get_rtbh_model_if_exists, ) from flowapp.forms import IPv4Form, IPv6Form, RTBHForm +from flowapp.services import rule_service from flowapp.utils import ( - quote_to_ent, - get_state_by_time, output_date_format, ) from flowapp.auth import check_access_rights -from flowapp.output import announce_route, log_route, log_withdraw, Route, RouteSources +from flowapp.output import announce_route, log_withdraw, Route, RouteSources from flowapp import db, validators, flowspec, messages @@ -67,7 +63,6 @@ def authorize(user_key): :return: page with token """ jwt_key = current_app.config.get("JWT_SECRET") - # try normal user key first model = db.session.query(ApiKey).filter_by(key=user_key).first() # if not found try machine key @@ -100,7 +95,7 @@ def authorize(user_key): return jsonify({"token": encoded}) else: - return jsonify({"message": "auth token is invalid"}), 403 + return jsonify({"message": f"auth token is not valid from machine {request.remote_addr}"}), 403 def check_readonly(func): @@ -191,7 +186,7 @@ def all_communities(current_user): def limit_reached(count, rule_type, org_id): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] org = db.session.get(Organization, org_id) if rule_type == RuleTypes.IPv4: limit = org.limit_flowspec4 @@ -207,7 +202,7 @@ def limit_reached(count, rule_type, org_id): def global_limit_reached(count, rule_type): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] if rule_type == RuleTypes.IPv4 or rule_type == RuleTypes.IPv6: limit = current_app.config.get("FLOWSPEC_MAX_RULES") elif rule_type == RuleTypes.RTBH: @@ -249,52 +244,15 @@ def create_ipv4(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv4, - f"{current_user['uuid']} / {current_user['org']}", + # Use the service to create/update the rule + model, flash_message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) + pref_format = output_date_format(json_request_data, form.expires.pref_format) response = {"message": flash_message, "rule": model.to_dict(pref_format)} return jsonify(response), 201 @@ -326,50 +284,12 @@ def create_ipv6(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv6, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -384,7 +304,6 @@ def create_rtbh(current_user): count = db.session.query(RTBH).filter_by(rstate_id=1).count() return global_limit_reached(count=count, rule_type=RuleTypes.RTBH) - # check limit if check_rule_limit(current_user["org_id"], RuleTypes.RTBH): count = db.session.query(RTBH).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() return limit_reached(count=count, rule_type=RuleTypes.RTBH, org_id=current_user["org_id"]) @@ -406,43 +325,12 @@ def create_rtbh(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - # log changes - log_route( - current_user["id"], - model, - RuleTypes.RTBH, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -506,7 +394,7 @@ def delete_v4_rule(current_user, rule_id): """ model_name = Flowspec4 route_model = messages.create_ipv4 - return delete_rule(current_user, rule_id, model_name, route_model, 4) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv4) def delete_v6_rule(current_user, rule_id): @@ -516,7 +404,7 @@ def delete_v6_rule(current_user, rule_id): """ model_name = Flowspec6 route_model = messages.create_ipv6 - return delete_rule(current_user, rule_id, model_name, route_model, 6) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv6) def delete_rtbh_rule(current_user, rule_id): @@ -526,7 +414,7 @@ def delete_rtbh_rule(current_user, rule_id): """ model_name = RTBH route_model = messages.create_rtbh - return delete_rule(current_user, rule_id, model_name, route_model, 1) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.RTBH) def delete_rule(current_user, rule_id, model_name, route_model, rule_type): diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index 8ed262f0..c709f11f 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -48,7 +48,7 @@ def whois(ip_address): def index(rtype=None, rstate="active"): """ dispatcher object for the dashboard - :param rtype: ipv4, ipv6, rtbh + :param rtype: ipv4, ipv6, rtbh, whitelist :param rstate: :return: view from view factory """ @@ -84,7 +84,6 @@ def index(rtype=None, rstate="active"): data_handler_module = current_app.config["DASHBOARD"].get(rtype).get("data_handler", models) data_handler_method = current_app.config["DASHBOARD"].get(rtype).get("data_handler_method", "get_ip_rules") - # get search query, sort order and sort key from request or session get_search_query = request.args.get(SEARCH_ARG, session.get(SEARCH_ARG, "")) get_sort_key = request.args.get(SORT_ARG, session.get(SORT_ARG, DEFAULT_SORT)) @@ -109,6 +108,10 @@ def index(rtype=None, rstate="active"): # get the handler and the data handler = getattr(data_handler_module, data_handler_method) rules = handler(rtype, rstate, get_sort_key, get_sort_order) + + # Enrich rules with whitelist information + rules, whitelist_rule_ids = enrich_rules_with_whitelist_info(rules, rtype) + session[RULES_KEY] = [rule.id for rule in rules] # search rules if get_search_query: @@ -123,6 +126,8 @@ def index(rtype=None, rstate="active"): else: count_match = "" + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + return view_factory( rtype=rtype, rstate=rstate, @@ -138,6 +143,8 @@ def index(rtype=None, rstate="active"): macro_tbody=macro_tbody, macro_thead=macro_thead, macro_tfoot=macro_tfoot, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) @@ -148,18 +155,25 @@ def create_dashboard_table_body( group_op=True, macro_file="macros.html", macro_name="build_ip_tbody", + whitelist_rule_ids=None, + allowed_communities=None, ): """ create the table body for the dashboard using a jinja2 macro :param rules: list of rules :param rtype: ipv4, ipv6, rtbh + :param editable: whether rules can be edited + :param group_op: whether group operations are allowed :param macro_file: the file where the macro is defined :param macro_name: the name of the macro + :param whitelist_rule_ids: set of rule IDs that were created by a whitelist """ tstring = "{% " tstring = tstring + f"from '{macro_file}' import {macro_name}" tstring = tstring + " %} {{" - tstring = tstring + f" {macro_name}(rules, today, editable, group_op) " + "}}" + tstring = ( + tstring + f" {macro_name}(rules, today, editable, group_op, whitelist_rule_ids, allowed_communities) " + "}}" + ) dashboard_table_body = render_template_string( tstring, @@ -167,6 +181,8 @@ def create_dashboard_table_body( today=datetime.now(), editable=editable, group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids or set(), + allowed_communities=allowed_communities or [], ) return dashboard_table_body @@ -246,6 +262,8 @@ def create_admin_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_group_buttons_tfoot", + whitelist_rule_ids=None, + allowed_communities=None, ): """ Admin can see and edit any rules @@ -256,8 +274,17 @@ def create_admin_response( :param sort_order: :return: """ + group_op = True if rtype != "whitelist" else False - dashboard_table_body = create_dashboard_table_body(rules, rtype, macro_file=macro_file, macro_name=macro_tbody) + dashboard_table_body = create_dashboard_table_body( + rules, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, + ) dashboard_table_head = create_dashboard_table_head( rules_columns=table_columns, @@ -266,15 +293,18 @@ def create_admin_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) - - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" res = make_response( render_template( @@ -313,6 +343,8 @@ def create_user_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, + allowed_communities=None, ): """ Filter out the rules for normal users @@ -343,9 +375,20 @@ def create_user_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) + + group_op = True if rtype != "whitelist" else False + dashboard_table_editable = create_dashboard_table_body( - rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody + rules_editable, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) dashboard_table_editable_head = create_dashboard_table_head( rules_columns=table_columns, @@ -354,7 +397,7 @@ def create_user_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, - group_op=True, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) @@ -370,11 +413,14 @@ def create_user_response( macro_name=macro_thead, ) - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" display_editable = len(rules_editable) display_readonly = len(read_only_rules) @@ -419,6 +465,8 @@ def create_view_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, + allowed_communities=None, ): """ Filter out the rules for normal users @@ -433,6 +481,8 @@ def create_view_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, + allowed_communities=allowed_communities, ) dashboard_table_head = create_dashboard_table_head( @@ -482,3 +532,39 @@ def filter_rules(rules, get_search_query): result.append(rules[idx]) return result + + +def enrich_rules_with_whitelist_info(rules, rule_type): + """ + Enrich rules with whitelist information from RuleWhitelistCache. + + Args: + rules: List of rule objects (Flowspec4, Flowspec6, RTBH) + rule_type: String identifier of rule type ("ipv4", "ipv6", "rtbh") + + Returns: + Tuple of (rules, whitelist_rule_ids) where whitelist_rule_ids is a set of + rule IDs that were created by a whitelist. + """ + from flowapp.models.rules.whitelist import RuleWhitelistCache + from flowapp.constants import RuleTypes, RuleOrigin + + # Map rule type string to enum value + rule_type_map = {"ipv4": RuleTypes.IPv4.value, "ipv6": RuleTypes.IPv6.value, "rtbh": RuleTypes.RTBH.value} + + # Get all rule IDs + rule_ids = [rule.id for rule in rules] + + # No rules to process + if not rule_ids: + return rules, set() + + # Query the cache for these rule IDs + cache_entries = RuleWhitelistCache.query.filter( + RuleWhitelistCache.rid.in_(rule_ids), RuleWhitelistCache.rtype == rule_type_map.get(rule_type) + ).all() + + # Create a set of rule IDs that were created by a whitelist + whitelist_rule_ids = {entry.rid for entry in cache_entries if entry.rorigin == RuleOrigin.WHITELIST.value} + + return rules, whitelist_rule_ids diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 5618789f..e3b82a1a 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -1,11 +1,10 @@ # flowapp/views/admin.py from datetime import datetime, timedelta -from operator import ge, lt from collections import namedtuple from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for -from flowapp import constants, db, messages +from flowapp import constants, db from flowapp.auth import ( admin_required, auth_required, @@ -24,19 +23,17 @@ Organization, check_global_rule_limit, check_rule_limit, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, - get_rtbh_model_if_exists, get_user_actions, get_user_communities, get_user_nets, insert_initial_communities, ) +from flowapp.models.log import Log from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route +from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists from flowapp.utils import ( flash_errors, get_state_by_time, - quote_to_ent, round_to_ten_minutes, ) @@ -61,13 +58,21 @@ def reactivate_rule(rule_type, rule_id): """ Set new time for the rule of given type identified by id - :param rule_type: string - type of rule + :param rule_type: integer - type of rule, corresponds to RuleTypes enum value :param rule_id: integer - id of the rule """ + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + + # Now use the enum value where needed but the integer for dictionary lookups model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] model = db.session.get(model_name, rule_id) + if not model: + flash("Rule not found", "alert-danger") + return redirect(url_for("index")) + form = form_name(request.form, obj=model) form.net_ranges = get_user_nets(session["user_id"]) @@ -75,72 +80,41 @@ def reactivate_rule(rule_type, rule_id): form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")] form.action.data = model.action_id - if rule_type == 1: + if rule_type == RuleTypes.RTBH.value: form.community.choices = get_user_communities(session["user_role_ids"]) form.community.data = model.community_id - if rule_type == 4: + if rule_type == RuleTypes.IPv4.value: form.protocol.data = model.protocol - if rule_type == 6: + if rule_type == RuleTypes.IPv6.value: form.next_header.data = model.next_header - # do not need to validate - all is readonly + # Process form submission if request.method == "POST": - # check if rule will be reactivated - state = get_state_by_time(form.expires.data) + # Round expiration time to 10 minutes + expires = round_to_ten_minutes(form.expires.data) + + # Use the service to reactivate the rule + _, messages = rule_service.reactivate_rule( + rule_type=enum_rule_type, + rule_id=rule_id, + expires=expires, + comment=form.comment.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) - # check global limit - check_gl = check_global_rule_limit(rule_type) - if state == 1 and check_gl: + # Handle special messages (redirects) + if "global_limit_reached" in messages: return redirect(url_for("rules.global_limit_reached", rule_type=rule_type)) - # check org limit - if state == 1 and check_rule_limit(session["user_org_id"], rule_type=rule_type): + if "limit_reached" in messages: return redirect(url_for("rules.limit_reached", rule_type=rule_type)) - # set new expiration date - model.expires = round_to_ten_minutes(form.expires.data) - # set again the active state - model.rstate_id = get_state_by_time(form.expires.data) - model.comment = form.comment.data - db.session.commit() - flash("Rule successfully updated", "alert-success") - - route_model = ROUTE_MODELS[rule_type] - - if model.rstate_id == 1: - # announce route - command = route_model(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_route( - session["user_id"], - model, - rule_type, - f"{session['user_email']} / {session['user_org']}", - ) - else: - # withdraw route - command = route_model(model, constants.WITHDRAW) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_withdraw( - session["user_id"], - route.command, - rule_type, - model.id, - f"{session['user_email']} / {session['user_org']}", - ) + for message in messages: + flash(message, "alert-success") return redirect( url_for( @@ -155,6 +129,7 @@ def reactivate_rule(rule_type, rule_id): else: flash_errors(form) + # For GET requests, prepare the form for display form.expires.data = model.expires for field in form: if field.name not in ["expires", "csrf_token", "comment"]: @@ -177,42 +152,74 @@ def reactivate_rule(rule_type, rule_id): def delete_rule(rule_type, rule_id): """ Delete rule with given id and type - :param sort_key: - :param filter_text: - :param rstate: - :param rule_type: string - type of rule to be deleted + :param rule_type: integer - type of rule to be deleted :param rule_id: integer - rule id """ - model_name = DATA_MODELS[rule_type] - route_model = ROUTE_MODELS[rule_type] + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + + # Use the service to delete the rule + success, message = rule_service.delete_rule( + rule_type=enum_rule_type, + rule_id=rule_id, + user_id=session["user_id"], + user_email=session["user_email"], + org_name=session["user_org"], + allowed_rule_ids=session.get(constants.RULES_KEY, []), + ) - model = db.session.get(model_name, rule_id) - if model.id in session[constants.RULES_KEY]: - # withdraw route - command = route_model(model, constants.WITHDRAW) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - log_withdraw( - session["user_id"], - route.command, - rule_type, - model.id, - f"{session['user_email']} / {session['user_org']}", + # Flash appropriate message based on result + flash(message, "alert-success" if success else "alert-warning") + + # Redirect back to dashboard + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], ) + ) - # delete from db - db.session.delete(model) - db.session.commit() - flash("Rule deleted", "alert-success") - else: - flash("You can not delete this rule", "alert-warning") +@rules.route("/delete_and_whitelist//", methods=["GET"]) +@auth_required +@user_or_admin_required +def delete_and_whitelist(rule_type, rule_id): + """ + Delete an RTBH rule and create a whitelist entry from it. + :param rule_id: integer - id of the RTBH rule + """ + if rule_type != RuleTypes.RTBH.value: + flash("Only RTBH rules can be converted to whitelists", "alert-warning") + return redirect(url_for("index")) + + # Set whitelist expiration to 7 days from now by default + whitelist_expires = datetime.now() + timedelta(days=7) + + # Use the service to delete RTBH and create whitelist + success, messages, whitelist = rule_service.delete_rtbh_and_create_whitelist( + rule_id=rule_id, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + allowed_rule_ids=session.get(constants.RULES_KEY, []), + whitelist_expires=whitelist_expires, + ) + + # Flash all messages + for message in messages: + flash(message, "alert-success" if success else "alert-warning") + + # If successful, flash additional message about whitelist + if success and whitelist: + flash(f"Created whitelist entry ID {whitelist.id} from RTBH rule", "alert-info") + + # Redirect back to dashboard return redirect( url_for( "dashboard.index", @@ -260,6 +267,7 @@ def group_delete(): rule_type = session[constants.TYPE_ARG] model_name = DATA_MODELS_NAMED[rule_type] rule_type_int = constants.RULE_TYPES_DICT[rule_type] + enum_rule_type = RuleTypes(rule_type_int) route_model = ROUTE_MODELS[rule_type_int] rules = [str(x) for x in session[constants.RULES_KEY]] to_delete = request.form.getlist("delete-id") @@ -279,7 +287,7 @@ def group_delete(): log_withdraw( session["user_id"], route.command, - rule_type_int, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -376,6 +384,7 @@ def group_update_save(rule_type): model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] + enum_rule_type = RuleTypes(rule_type) form = form_name(request.form) @@ -417,7 +426,7 @@ def group_update_save(rule_type): log_route( session["user_id"], model, - rule_type, + enum_rule_type, f"{session['user_email']} / {session['user_org']}", ) else: @@ -433,7 +442,7 @@ def group_update_save(rule_type): log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -476,54 +485,17 @@ def ipv4_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv4, - f"{session['user_email']} / {session['user_org']}", + # Use the service to create/update the rule + _model, message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") + return redirect(url_for("index")) else: for field, errors in form.errors.items(): @@ -560,52 +532,14 @@ def ipv6_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv6, - f"{session['user_email']} / {session['user_org']}", + _model, message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -642,47 +576,18 @@ def rtbh_rule(): ] + user_communities form.community.choices = user_communities form.net_ranges = net_ranges + whitelistable = Community.get_whitelistable_communities(current_app.config["ALLOWED_COMMUNITIES"]) if request.method == "POST" and form.validate(): - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - flash(flash_message, "alert-success") - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_route( - session["user_id"], - model, - RuleTypes.RTBH, - f"{session['user_email']} / {session['user_org']}", + _model, messages = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + for message in messages: + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -693,7 +598,11 @@ def rtbh_rule(): default_expires = datetime.now() + timedelta(days=7) form.expires.data = default_expires - return render_template("forms/rtbh_rule.html", form=form, action_url=url_for("rules.rtbh_rule")) + print(whitelistable) + + return render_template( + "forms/rtbh_rule.html", form=form, action_url=url_for("rules.rtbh_rule"), whitelistable=whitelistable + ) @rules.route("/limit_reached/") @@ -775,64 +684,13 @@ def announce_all(): @rules.route("/withdraw_expired", methods=["GET"]) @localhost_only def withdraw_expired(): - announce_all_routes(constants.WITHDRAW) - return " " - - -def announce_all_routes(action=constants.ANNOUNCE): """ - get routes from db and send it to ExaBGB api - - @TODO take the request away, use some kind of messaging (maybe celery?) - :param action: action with routes - announce valid routes or withdraw expired routes + cleaning endpoint + deletes expired whitelists + withdraws all expired routes from ExaBGP + deletes logs older than 30 days """ - today = datetime.now() - comp_func = ge if action == constants.ANNOUNCE else lt - - rules4 = ( - db.session.query(Flowspec4) - .filter(Flowspec4.rstate_id == 1) - .filter(comp_func(Flowspec4.expires, today)) - .order_by(Flowspec4.expires.desc()) - .all() - ) - rules6 = ( - db.session.query(Flowspec6) - .filter(Flowspec6.rstate_id == 1) - .filter(comp_func(Flowspec6.expires, today)) - .order_by(Flowspec6.expires.desc()) - .all() - ) - rules_rtbh = ( - db.session.query(RTBH) - .filter(RTBH.rstate_id == 1) - .filter(comp_func(RTBH.expires, today)) - .order_by(RTBH.expires.desc()) - .all() - ) - - messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] - messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] - messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] - - messages_all = [] - messages_all.extend(messages_v4) - messages_all.extend(messages_v6) - messages_all.extend(messages_rtbh) - - author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" - - for command in messages_all: - route = Route( - author=f"System call / {author_action} rules", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - if action == constants.WITHDRAW: - for ruleset in [rules4, rules6, rules_rtbh]: - for rule in ruleset: - rule.rstate_id = 2 - - db.session.commit() + delete_expired_whitelists() + announce_all_routes(constants.WITHDRAW) + Log.delete_old() + return " " diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py new file mode 100644 index 00000000..39a9352a --- /dev/null +++ b/flowapp/views/whitelist.py @@ -0,0 +1,127 @@ +from datetime import datetime, timedelta +from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for + +from flowapp.auth import ( + auth_required, + user_or_admin_required, +) +from flowapp import constants, db +from flowapp.forms import WhitelistForm +from flowapp.models import get_user_nets, Whitelist +from flowapp.services import create_or_update_whitelist, delete_whitelist +from flowapp.utils.base import flash_errors + +whitelist = Blueprint("whitelist", __name__, template_folder="templates") + + +@whitelist.route("/add", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def add(): + net_ranges = get_user_nets(session["user_id"]) + form = WhitelistForm(request.form) + + form.net_ranges = net_ranges + + if request.method == "POST" and form.validate(): + model, messages = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) + for message in messages: + flash(message, "alert-success") + + return redirect(url_for("index")) + else: + for field, errors in form.errors.items(): + for error in errors: + current_app.logger.debug("Error in the %s field - %s" % (getattr(form, field).label.text, error)) + + print("NOW", datetime.now()) + default_expires = datetime.now() + timedelta(hours=1) + form.expires.data = default_expires + + return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add")) + + +@whitelist.route("/reactivate/", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def reactivate(wl_id): + """ + Set new time for whitelist + :param wl_id: int - id of the whitelist + """ + + model = db.session.get(Whitelist, wl_id) + form = WhitelistForm(request.form, obj=model) + form.net_ranges = get_user_nets(session["user_id"]) + + # do not need to validate - all is readonly + if request.method == "POST": + model = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) + flash("Whitelist updated", "alert-success") + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) + else: + flash_errors(form) + + form.expires.data = model.expires + for field in form: + if field.name not in ["expires", "csrf_token", "comment"]: + field.render_kw = {"disabled": "disabled"} + + action_url = url_for("whitelist.reactivate", wl_id=wl_id) + + return render_template( + "forms/whitelist.html", + form=form, + action_url=action_url, + editing=True, + title="Update", + ) + + +@whitelist.route("/delete/", methods=["GET"]) +@auth_required +@user_or_admin_required +def delete(wl_id): + """ + Delete whitelist + :param wl_id: integer - id of the whitelist + """ + if wl_id in session[constants.RULES_KEY]: + messages = delete_whitelist(wl_id) + flash(f"Whitelist {wl_id} deleted", "alert-success") + for message in messages: + flash(message, "alert-info") + else: + flash("You can not delete this Whitelist", "alert-warning") + + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) diff --git a/run.example.py b/run.example.py index b3dd0c4e..ad6f97eb 100644 --- a/run.example.py +++ b/run.example.py @@ -1,30 +1,19 @@ """ -This is an example of how to run the application. -First copy the file as run.py (or whatever you want) -Then edit the file to match your needs. - -From version 0.8.1 the application is using Flask-Session -stored in DB using SQL Alchemy driver. This can be configured for other -drivers, however server side session is required for the application. - -In general you should not need to edit this example file. -Only if you want to configure the application main menu and -dashboard. - -Or in case that you want to add extensions etc. +This is run py for the application. Copied to the container on build. """ from os import environ -from flowapp import create_app, db, sess +from flowapp import create_app, db import config # Configurations -env = environ.get("EXAFS_ENV", "Production") +exafs_env = environ.get("EXAFS_ENV", "Production") +exafs_env = exafs_env.lower() # Call app factory -if env == "devel": +if exafs_env == "devel" or exafs_env == "development": app = create_app(config.DevelopmentConfig) else: app = create_app(config.ProductionConfig) @@ -32,12 +21,6 @@ # init database object db.init_app(app) -# init session -app.config.update(SESSION_TYPE="sqlalchemy") -app.config.update(SESSION_SQLALCHEMY=db) -sess.init_app(app) - - # run app if __name__ == "__main__": app.run(host="::", port=8080, debug=True) diff --git a/withdraw_expired b/withdraw_expired new file mode 100644 index 00000000..0519ecba --- /dev/null +++ b/withdraw_expired @@ -0,0 +1 @@ + \ No newline at end of file