From 94d03cc4f0cebfc35c01a14436ecdfbdacacc429 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Feb 2025 14:22:29 +0100 Subject: [PATCH 01/36] first prototype of whitelist model and form --- flowapp/constants.py | 8 ++- flowapp/forms.py | 57 +++++++++++++++++++++ flowapp/models.py | 99 ++++++++++++++++++++++++++++++++++++- flowapp/output.py | 1 + flowapp/validators.py | 63 +++++++++++++++++++++++ flowapp/views/api_common.py | 6 +-- run.example.py | 27 ++-------- 7 files changed, 233 insertions(+), 28 deletions(-) diff --git a/flowapp/constants.py b/flowapp/constants.py index 5aa5283..975f99b 100644 --- a/flowapp/constants.py +++ b/flowapp/constants.py @@ -2,6 +2,7 @@ This module contains constant values used in application """ +from enum import Enum from operator import ge, lt DEFAULT_SORT = "expires" @@ -59,7 +60,12 @@ FORM_TIME_PATTERN = "%Y-%m-%dT%H:%M" -class RuleTypes: +class RuleTypes(Enum): RTBH = 1 IPv4 = 4 IPv6 = 6 + + +class RuleOrigin(Enum): + USER = 1 + WHITELIST = 2 diff --git a/flowapp/forms.py b/flowapp/forms.py index 4b914dc..66cf43f 100644 --- a/flowapp/forms.py +++ b/flowapp/forms.py @@ -33,9 +33,11 @@ FORM_TIME_PATTERN, ) from flowapp.validators import ( + IPAddressValidator, IPv4Address, IPv6Address, NetRangeString, + NetworkValidator, PortString, address_in_range, address_with_mask, @@ -648,3 +650,58 @@ def validate_ipv_specific(self): return False return True + + +class WhitelistForm(FlaskForm): + """ + Whitelist form object + Used for creating and editing whitelist entries + Supports both IPv4 and IPv6 addresses + """ + + def __init__(self, *args, **kwargs): + super(WhitelistForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ip = StringField( + "IP address", + validators=[ + DataRequired(message="Please provide an IP address"), + IPAddressValidator(message="Please provide a valid IP address: {}"), + NetworkValidator(mask_field_name="mask"), + ], + ) + + mask = IntegerField( + "Network mask (bits)", + validators=[ + DataRequired(message="Please provide a network mask"), + ], + ) + + comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + def validate(self): + """ + Custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # Validate IP is in organization range + if self.ip.data and self.mask.data and self.net_ranges: + ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) + if not ip_in_range: + self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/models.py b/flowapp/models.py index 4e1291f..54c50a4 100644 --- a/flowapp/models.py +++ b/flowapp/models.py @@ -662,10 +662,105 @@ def __init__(self, time, task, user_id, rule_type, rule_id, author, org_id=None) self.org_id = org_id -# DDL -# default values for tables inserted after create +class Whitelist(db.Model): + id = db.Column(db.Integer, primary_key=True) + ip = db.Column(db.String(255)) + mask = db.Column(db.Integer) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="whitelist") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="whitelist") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="whitelist") + + def __init__( + self, + ip, + mask, + expires, + user_id, + org_id, + created=None, + comment=None, + rstate_id=1, + ): + self.ip = ip + self.mask = mask + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two whitelists are equal if all the network parameters equals. User_id and time fields can differ. + :param other: other Whitelist instance + :return: boolean + """ + return ( + self.ip == other.ip + and self.mask == other.mask + and self.expires == other.expires + and self.user_id == other.user_id + and self.org_id == other.org_id + and self.rstate_id == other.rstate_id + ) + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + return { + "id": self.id, + "ip": self.ip, + "mask": self.mask, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + +class RuleWhitelistCache(db.Model): + """ + Cache for whitelisted rules + For each rule we store id and type + Rule origin determines if the rule was created by user or by whitelist + """ + + id = db.Column(db.Integer, primary_key=True) + rid = db.Column(db.Integer) + rtype = db.Column(db.Integer) + rorigin = db.Column(db.Integer) + whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey + whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") + + def __init__(self, rid, rtype, rorigin, whitelist_id): + self.rid = rid + self.rtype = rtype + self.rorigin = rorigin + self.whitelist_id = whitelist_id + + +# DDL +# default values for tables inserted after create @event.listens_for(Action.__table__, "after_create") def insert_initial_actions(table, conn, *args, **kwargs): conn.execute( diff --git a/flowapp/output.py b/flowapp/output.py index 3dde822..b0cc7b4 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -96,6 +96,7 @@ def log_route(user_id, route_model, rule_type, author): :param rule_type: string :return: None """ + print(rule_type) converter = ROUTE_MODELS[rule_type] task = converter(route_model) log = Log( diff --git a/flowapp/validators.py b/flowapp/validators.py index 6306106..3b4f706 100644 --- a/flowapp/validators.py +++ b/flowapp/validators.py @@ -352,3 +352,66 @@ def subnet_of(net_a, net_b): def supernet_of(net_a, net_b): """Return True if this network is a supernet of other.""" return _is_subnet_of(net_b, net_a) + + +class IPAddressValidator: + """ + Universal validator that accepts both IPv4 and IPv6 addresses. + """ + + def __init__(self, message=None): + self.message = message or "Invalid IP address: {}" + + def __call__(self, form, field): + if not field.data: + return + + try: + ipaddress.ip_address(field.data) + except ValueError: + raise ValidationError(self.message.format(field.data)) + + +class NetworkValidator: + """ + Validates that an IP address and mask form a valid network. + Works with both IPv4 and IPv6 addresses. + """ + + def __init__(self, mask_field_name, message=None): + self.mask_field_name = mask_field_name + self.message = message or "Invalid network: address {}, mask {}" + + def __call__(self, form, field): + if not field.data: + return + + mask_field = form._fields.get(self.mask_field_name) + if not mask_field or not mask_field.data: + return + + try: + # Determine IP version + ip = ipaddress.ip_address(field.data) + mask = int(mask_field.data) + + # Validate mask range based on IP version + if isinstance(ip, ipaddress.IPv4Address): + if not 0 <= mask <= 32: + raise ValidationError(f"Invalid IPv4 mask: {mask}") + else: # IPv6 + if not 0 <= mask <= 128: + raise ValidationError(f"Invalid IPv6 mask: {mask}") + + # Try to create network to validate the combination + network = ipaddress.ip_network(f"{field.data}/{mask}", strict=False) + + # Check if the original IP is the correct network address + if str(network.network_address) != str(ip): + raise ValidationError( + f"Invalid network address: {field.data}/{mask}. " + f"Network address should be: {network.network_address}" + ) + + except ValueError: + raise ValidationError(self.message.format(field.data, mask_field.data)) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 6ce24bf..eeaade6 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -292,7 +292,7 @@ def create_ipv4(current_user): log_route( current_user["id"], model, - RuleTypes.IPv4, + RuleTypes.IPv4.value, f"{current_user['uuid']} / {current_user['org']}", ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -368,7 +368,7 @@ def create_ipv6(current_user): log_route( current_user["id"], model, - RuleTypes.IPv6, + RuleTypes.IPv6.value, f"{current_user['uuid']} / {current_user['org']}", ) @@ -441,7 +441,7 @@ def create_rtbh(current_user): log_route( current_user["id"], model, - RuleTypes.RTBH, + RuleTypes.RTBH.value, f"{current_user['uuid']} / {current_user['org']}", ) diff --git a/run.example.py b/run.example.py index b3dd0c4..ad6f97e 100644 --- a/run.example.py +++ b/run.example.py @@ -1,30 +1,19 @@ """ -This is an example of how to run the application. -First copy the file as run.py (or whatever you want) -Then edit the file to match your needs. - -From version 0.8.1 the application is using Flask-Session -stored in DB using SQL Alchemy driver. This can be configured for other -drivers, however server side session is required for the application. - -In general you should not need to edit this example file. -Only if you want to configure the application main menu and -dashboard. - -Or in case that you want to add extensions etc. +This is run py for the application. Copied to the container on build. """ from os import environ -from flowapp import create_app, db, sess +from flowapp import create_app, db import config # Configurations -env = environ.get("EXAFS_ENV", "Production") +exafs_env = environ.get("EXAFS_ENV", "Production") +exafs_env = exafs_env.lower() # Call app factory -if env == "devel": +if exafs_env == "devel" or exafs_env == "development": app = create_app(config.DevelopmentConfig) else: app = create_app(config.ProductionConfig) @@ -32,12 +21,6 @@ # init database object db.init_app(app) -# init session -app.config.update(SESSION_TYPE="sqlalchemy") -app.config.update(SESSION_SQLALCHEMY=db) -sess.init_app(app) - - # run app if __name__ == "__main__": app.run(host="::", port=8080, debug=True) From 073c2b4f8c3056d4ad8c88c423d4d2527e4fad3d Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Feb 2025 14:33:52 +0100 Subject: [PATCH 02/36] fixed api after convetting RuleTypes to Enum class --- flowapp/views/api_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index eeaade6..d04e40d 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -191,7 +191,7 @@ def all_communities(current_user): def limit_reached(count, rule_type, org_id): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] org = db.session.get(Organization, org_id) if rule_type == RuleTypes.IPv4: limit = org.limit_flowspec4 @@ -207,7 +207,7 @@ def limit_reached(count, rule_type, org_id): def global_limit_reached(count, rule_type): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] if rule_type == RuleTypes.IPv4 or rule_type == RuleTypes.IPv6: limit = current_app.config.get("FLOWSPEC_MAX_RULES") elif rule_type == RuleTypes.RTBH: From f6672dc4adbd30d5943378a5c30b856bd574ec0b Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Feb 2025 15:12:35 +0100 Subject: [PATCH 03/36] add more test for validators, fix edge cases validation in flowspec and validators modules --- flowapp/flowspec.py | 46 +++--- flowapp/tests/test_validators.py | 258 ++++++++++++++++++++++++++++--- flowapp/validators.py | 5 +- 3 files changed, 270 insertions(+), 39 deletions(-) diff --git a/flowapp/flowspec.py b/flowapp/flowspec.py index e0ce35a..ae357ca 100644 --- a/flowapp/flowspec.py +++ b/flowapp/flowspec.py @@ -21,6 +21,19 @@ def translate_sequence(sequence, max_val=MAX_PORT): return "[{}]".format(result) +def check_limit(value, max_value, min_value=0): + """ + test if the value is within valid range (min_value to max_value inclusive) + raise exception otherwise + """ + value = int(value) + if value > max_value: + raise ValueError("Invalid value number: {} is too big. Max is {}.".format(value, max_value)) + if value < min_value: + raise ValueError("Invalid value number: {} is too small. Min is {}.".format(value, min_value)) + return value + + def to_exabgp_string(value_string, max_val): """ Translate form string to flowspec value or packet size rule @@ -42,35 +55,30 @@ def to_exabgp_string(value_string, max_val): return "={}".format(check_limit(value_string, max_val)) elif RANGE.match(value_string): m = RANGE.match(value_string) - return ">={}&<={}".format( - check_limit(m.group(1), max_val), check_limit(m.group(2), max_val) - ) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif NOTRAN.match(value_string): - return value_string + m = NOTRAN.match(value_string) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif GREATER.match(value_string): m = GREATER.match(value_string) return ">={}&<={}".format(check_limit(m.group(1), max_val), max_val) elif LOWER.match(value_string): m = LOWER.match(value_string) - return ">=0&<={}".format(check_limit(m.group(1), max_val)) + # Even for lower bound expressions, validate that the value itself is within range + end = check_limit(m.group(1), max_val) + return ">=0&<={}".format(end) else: raise ValueError("string {} can not be converted".format(value_string)) -def check_limit(value, max_value): - """ - test if the value is lower than max_value - raise exception otherwise - """ - value = int(value) - if value > max_value: - raise ValueError( - "Invalid value number: {} is too big. Max is {}.".format(value, max_value) - ) - else: - return value - - def filter_rules_action(user_actions, rules): """ Divide the list of rules by user_actions to editable and viewonly subsets diff --git a/flowapp/tests/test_validators.py b/flowapp/tests/test_validators.py index 33e757c..cab089f 100644 --- a/flowapp/tests/test_validators.py +++ b/flowapp/tests/test_validators.py @@ -1,11 +1,28 @@ import pytest -import flowapp.validators +from flowapp.validators import ( + PortString, + PacketString, + NetRangeString, + NetInRange, + IPAddress, + IPAddressValidator, + NetworkValidator, + ValidationError, + address_in_range, + address_with_mask, + IPv4Address, + IPv6Address, + DateNotExpired, + editable_range, + network_in_range, + range_in_network, +) def test_port_string_len_raises(field): - port = flowapp.validators.PortString() + port = PortString() field.data = "1;2;3;4;5;6;7;8" - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): port(None, field) @@ -20,12 +37,12 @@ def test_port_string_len_raises(field): ], ) def test_is_valid_address_with_mask(address, mask, expected): - assert flowapp.validators.address_with_mask(address, mask) == expected + assert address_with_mask(address, mask) == expected @pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"]) def test_ip4address_passes(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address adr(None, field) @@ -38,7 +55,7 @@ def test_ip4address_passes(field, address): ], ) def test_ip6address_passes(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address adr(None, field) @@ -51,19 +68,17 @@ def test_ip6address_passes(field, address): ], ) def test_bad_ip6address_raises(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) -@pytest.mark.parametrize( - "expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"] -) +@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]) def test_expired_date_raises(field, expired): - adr = flowapp.validators.DateNotExpired() + adr = DateNotExpired() field.data = expired - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -75,9 +90,9 @@ def test_expired_date_raises(field, expired): ], ) def test_ipaddress_raises(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -91,7 +106,7 @@ def test_ipaddress_raises(field, address): def test_editable_rule(rule, address, mask, ranges, expected): rule.source = address rule.source_mask = mask - assert flowapp.validators.editable_range(rule, ranges) == expected + assert editable_range(rule, ranges) == expected @pytest.mark.parametrize( @@ -104,7 +119,7 @@ def test_editable_rule(rule, address, mask, ranges, expected): ], ) def test_address_in_range(address, mask, ranges, expected): - assert flowapp.validators.address_in_range(address, ranges) == expected + assert address_in_range(address, ranges) == expected @pytest.mark.parametrize( @@ -123,7 +138,7 @@ def test_address_in_range(address, mask, ranges, expected): ], ) def test_network_in_range(address, mask, ranges, expected): - assert flowapp.validators.network_in_range(address, mask, ranges) == expected + assert network_in_range(address, mask, ranges) == expected @pytest.mark.parametrize( @@ -133,4 +148,209 @@ def test_network_in_range(address, mask, ranges, expected): ], ) def test_range_in_network(address, mask, ranges, expected): - assert flowapp.validators.range_in_network(address, mask, ranges) == expected + assert range_in_network(address, mask, ranges) == expected + + +### new tests + + +# New tests for previously uncovered validators + + +# Tests for PortString validator syntax +@pytest.mark.parametrize( + "port_data", + [ + "80", # Simple number + "80-443", # Range with hyphen + ">=80&<=443", # Explicit range + ">80", # Greater than + ">=80", # Greater than or equal + "<443", # Less than + "<=443", # Less than or equal + "80;443;8080", # Multiple port expressions + ], +) +def test_port_string_valid_syntax(field, port_data): + validator = PortString() + field.data = port_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_port_data", + [ + "80,443", # Comma not supported, should use semicolon + "80..443", # Invalid range syntax (should be 80-443) + ">65536", # Port number too high + "-1", # Negative number + "<-1", # Port number too low + ">=80<=443", # Invalid range syntax (missing &) + "!=80", # Not equals not supported + "abc", # Non-numeric + "80-", # Incomplete range + "-80", # Invalid range + "80&&443", # Invalid operator + "443-80", # End less than start + ">=443&<=80", # End less than start in explicit range + "0-65537", # Range exceeding max + ], +) +def test_port_string_invalid_syntax(field, invalid_port_data): + validator = PortString() + field.data = invalid_port_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for PacketString validator +@pytest.mark.parametrize( + "packet_data", + ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions +) +def test_packet_string_valid(field, packet_data): + validator = PacketString() + field.data = packet_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_packet_data", + [ + "65536", # Too large + "-1", # Negative + "1500..", # Invalid range syntax + "abc", # Non-numeric + "!!1500", # Invalid operator + "1500-", # Incomplete range + "9000-1500", # End less than start + ], +) +def test_packet_string_invalid(field, invalid_packet_data): + validator = PacketString() + field.data = invalid_packet_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetRangeString validator +@pytest.mark.parametrize( + "net_range", + [ + "192.168.0.0/24", + "10.0.0.0/8\n172.16.0.0/12", + "2001:db8::/32", + "192.168.0.0/24 10.0.0.0/8", + "2001:db8::/32 2001:db8:1::/48", + ], +) +def test_net_range_string_valid(field, net_range): + validator = NetRangeString() + field.data = net_range + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_net_range", + [ + "192.168.0.0/33", # Invalid mask + "256.256.256.0/24", # Invalid IP + "2001:xyz::/32", # Invalid IPv6 + "192.168.0.0/24/", # Malformed + "not-an-ip/24", # Invalid format + ], +) +def test_net_range_string_invalid(field, invalid_net_range): + validator = NetRangeString() + field.data = invalid_net_range + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetInRange validator +def test_net_in_range_valid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "192.168.1.1/24" + validator(None, field) # Should not raise + + +def test_net_in_range_invalid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "172.16.1.1/24" + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for base IPAddress validator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"]) +def test_ip_address_valid(field, ip_addr): + validator = IPAddress() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"]) +def test_ip_address_invalid(field, invalid_ip): + validator = IPAddress() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for universal IPAddressValidator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed +def test_ip_address_validator_valid(field, ip_addr): + validator = IPAddressValidator() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"]) +def test_ip_address_validator_invalid(field, invalid_ip): + validator = IPAddressValidator() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetworkValidator +class MockForm: + def __init__(self, address, mask): + self._fields = {"mask": type("MockField", (), {"data": mask})()} + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.0.0", "24"), + ("10.0.0.0", "8"), + ("2001:db8::", "32"), + ("", None), # Empty values should pass + (None, None), # None values should pass + ], +) +def test_network_validator_valid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + validator(form, field) # Should not raise + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.1.1", "24"), # Not a network address + ("192.168.0.0", "33"), # Invalid IPv4 mask + ("2001:db8::", "129"), # Invalid IPv6 mask + ("256.256.256.0", "24"), # Invalid IP + ("2001:xyz::", "32"), # Invalid IPv6 + ], +) +def test_network_validator_invalid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + with pytest.raises(ValidationError): + validator(form, field) diff --git a/flowapp/validators.py b/flowapp/validators.py index 3b4f706..890bdac 100644 --- a/flowapp/validators.py +++ b/flowapp/validators.py @@ -253,7 +253,10 @@ def __call__(self, form, field): result = False for address in field.data.split("/"): for adr_range in self.net_ranges: - result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + try: + result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + except ValueError as e: + raise ValidationError(self.message + str(e.args[0])) if not result: raise ValidationError(self.message) From 4e2a6a80e4c547e44a7cf98b4725854feef1fa56 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Feb 2025 09:35:30 +0100 Subject: [PATCH 04/36] increased test coverage for forms, fixed minor edge cases validation issues in forms --- flowapp/forms.py | 10 +- flowapp/tests/test_flowspec.py | 120 ++++++- flowapp/tests/test_forms_cl.py | 615 +++++++++++++++++++++++++++++++++ 3 files changed, 733 insertions(+), 12 deletions(-) create mode 100644 flowapp/tests/test_forms_cl.py diff --git a/flowapp/forms.py b/flowapp/forms.py index 66cf43f..5baca78 100644 --- a/flowapp/forms.py +++ b/flowapp/forms.py @@ -67,15 +67,21 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def process_formdata(self, valuelist): - if not valuelist: + if not valuelist or (len(valuelist) == 1 and not valuelist[0]): return None + # with unlimited field we do not need to parse the empty value if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: self.data = None return None date_str = " ".join((str(val) for val in valuelist)) - result, pref_format = parse_api_time(date_str) + + try: + result, pref_format = parse_api_time(date_str) + except TypeError: + raise ValueError(self.gettext("Not a valid datetime value.")) + if result: self.data = result self.pref_format = pref_format diff --git a/flowapp/tests/test_flowspec.py b/flowapp/tests/test_flowspec.py index e7a9c3e..70000b8 100644 --- a/flowapp/tests/test_flowspec.py +++ b/flowapp/tests/test_flowspec.py @@ -1,12 +1,12 @@ import pytest -import flowapp.flowspec +from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit def test_translate_number(): """ tests for x (integer) to =x """ - assert "[=10]" == flowapp.flowspec.translate_sequence("10") + assert "[=10]" == translate_sequence("10") def test_raises(): @@ -14,7 +14,7 @@ def test_raises(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("ahoj") + translate_sequence("ahoj") def test_raises_bad_number(): @@ -22,46 +22,146 @@ def test_raises_bad_number(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("75555") + translate_sequence("75555") def test_translate_range(): """ tests for x-y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence("10-20") + assert "[>=10&<=20]" == translate_sequence("10-20") def test_exact_rule(): """ test for >=x&<=y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence(">=10&<=20") + assert "[>=10&<=20]" == translate_sequence(">=10&<=20") def test_greater_than(): """ test for >x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">10") + assert "[>=10&<=65535]" == translate_sequence(">10") def test_greater_equal_than(): """ test for >=x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">=10") + assert "[>=10&<=65535]" == translate_sequence(">=10") def test_lower_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<10") + assert "[>=0&<=10]" == translate_sequence("<10") def test_lower_equal_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<=10") + assert "[>=0&<=10]" == translate_sequence("<=10") + + +# new tests + + +def test_multiple_sequences(): + """Test multiple sequences separated by semicolons""" + assert "[=10 >=20&<=30]" == translate_sequence("10;20-30") + assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30") + + +def test_empty_sequence(): + """Test empty sequence and sequences with empty parts""" + assert "[]" == translate_sequence("") + assert "[=10]" == translate_sequence("10;;") + assert "[=10 =20]" == translate_sequence("10;;20") + + +def test_range_edge_cases(): + """Test edge cases for ranges""" + # Same numbers in range + assert "[>=10&<=10]" == translate_sequence("10-10") + + # Invalid range (start > end) + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence("20-10") + + # Invalid range in exact rule + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence(">=20&<=10") + + +def test_check_limit_validation(): + """Test the check_limit function""" + # Test valid cases + assert check_limit(10, 100) == 10 + assert check_limit(0, 100, 0) == 0 + assert check_limit(100, 100, 0) == 100 + + # Test invalid cases + with pytest.raises(ValueError, match="Invalid value number: .* is too big"): + check_limit(101, 100) + + with pytest.raises(ValueError, match="Invalid value number: .* is too small"): + check_limit(-1, 100, 0) + + +def test_invalid_inputs(): + """Test various invalid inputs""" + invalid_inputs = [ + "abc", # Non-numeric + "10.5", # Decimal + "10,20", # Wrong separator + "10&20", # Wrong format + ">", # Incomplete + ">=", # Incomplete + "<", # Incomplete + "<=", # Incomplete + ">>10", # Invalid operator + "<<10", # Invalid operator + ">=10&", # Incomplete range + "&<=10", # Incomplete range + ] + + for invalid_input in invalid_inputs: + with pytest.raises(ValueError): + translate_sequence(invalid_input) + + +class MockRule: + def __init__(self, action_id): + self.action_id = action_id + + +def test_filter_rules_action(): + """Test the filter_rules_action function""" + # Create test rules + rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)] + + # Test with empty allowed actions + editable, viewonly = filter_rules_action([], rules) + assert len(editable) == 0 + assert len(viewonly) == 4 + + # Test with some allowed actions + editable, viewonly = filter_rules_action([1, 3], rules) + assert len(editable) == 2 + assert len(viewonly) == 2 + assert all(rule.action_id in [1, 3] for rule in editable) + assert all(rule.action_id in [2, 4] for rule in viewonly) + + # Test with all actions allowed + editable, viewonly = filter_rules_action([1, 2, 3, 4], rules) + assert len(editable) == 4 + assert len(viewonly) == 0 + + # Test with empty rules list + editable, viewonly = filter_rules_action([1, 2], []) + assert len(editable) == 0 + assert len(viewonly) == 0 diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py new file mode 100644 index 0000000..d5373ac --- /dev/null +++ b/flowapp/tests/test_forms_cl.py @@ -0,0 +1,615 @@ +import pytest +from datetime import datetime, timedelta +from werkzeug.datastructures import MultiDict +from flask import Flask +from flowapp.forms import ( + UserForm, + BulkUserForm, + ApiKeyForm, + MachineApiKeyForm, + OrganizationForm, + ActionForm, + ASPathForm, + CommunityForm, + RTBHForm, + IPv4Form, + IPv6Form, + WhitelistForm, +) + + +def create_form_data(data): + """Helper function to create proper form data format""" + processed_data = {} + for key, value in data.items(): + if isinstance(value, list): + processed_data[key] = [str(v) for v in value] + else: + processed_data[key] = value + return MultiDict(processed_data) + + +@pytest.fixture() +def app(): + """Create Flask app with CSRF disabled for testing""" + app = Flask(__name__) + app.config.update(SECRET_KEY="test_secret", WTF_CSRF_ENABLED=False, TESTING=True) + return app + + +@pytest.fixture +def valid_datetime(): + return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") + + +@pytest.fixture +def sample_network_ranges(): + return ["192.168.0.0/16", "2001:db8::/32"] + + +class TestUserForm: + @pytest.fixture + def mock_choices(self): + return { + "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")], + "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")], + } + + @pytest.fixture + def valid_user_data(self): + return { + "uuid": "test@example.com", + "email": "user@example.com", + "name": "Test User", + "phone": "123456789", + "role_ids": ["2"], + "org_ids": ["1"], + } + + def test_valid_user_form(self, app, valid_user_data, mock_choices): + with app.test_request_context(): + form_data = create_form_data(valid_user_data) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_email(self, app, mock_choices): + with app.test_request_context(): + form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]}) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + assert not form.validate() + assert "Please provide valid email" in form.uuid.errors + + +class TestRTBHForm: + @pytest.fixture + def mock_community_choices(self): + return [(1, "Community1"), (2, "Community2")] + + @pytest.fixture + def valid_rtbh_data(self, valid_datetime): + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "community": "1", + "expires": valid_datetime, + "comment": "Test RTBH rule", + } + + def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices): + with app.test_request_context(): + form_data = create_form_data(valid_rtbh_data) + form = RTBHForm(formdata=form_data) + form.net_ranges = sample_network_ranges + form.community.choices = mock_community_choices + assert form.validate() + + +class TestIPv4Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + +class TestIPv6Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return { + "source": "2001:db8::", # Network aligned address within allowed range + "source_mask": "32", # Matching the organization's prefix length + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation fails when address is outside allowed ranges""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db9::", # Different prefix + "source_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + assert any("must be in organization range" in error for error in form.source.errors) + + def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with destination address instead of source""" + with app.test_request_context(): + form_data = create_form_data( + { + "dest": "2001:db8::", + "dest_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with both source and destination addresses""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "dest": "2001:db8:1::", + "dest_mask": "48", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with TCP flags (should be valid with TCP)""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "flags": ["SYN", "ACK"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"]) + def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data): + """Test validation with various valid port formats""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "source_port": port_data, + "dest_port": port_data, + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + +class TestBulkUserForm: + @pytest.fixture + def valid_csv_data(self): + return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1" + + def test_valid_bulk_import(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} # Mock available roles + form.organizations = {1} # Mock available organizations + form.uuids = set() # Mock existing UUIDs (empty set) + assert form.validate() + + def test_duplicate_uuid(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = {"user1@example.com"} # UUID already exists + assert not form.validate() + assert any("already exists" in error for error in form.users.errors) + + def test_invalid_role(self, app): + csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID + with app.test_request_context(): + form_data = create_form_data({"users": csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = set() + assert not form.validate() + assert any("does not exist" in error for error in form.users.errors) + + +class TestApiKeyForm: + @pytest.fixture + def valid_api_key_data(self, valid_datetime): + return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"} + + def test_valid_api_key(self, app, valid_api_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_api_key_data) + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + def test_invalid_ip(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime}) + form = ApiKeyForm(formdata=form_data) + assert not form.validate() + assert "provide valid IP address" in form.machine.errors + + def test_unlimited_expiration(self, app): + with app.test_request_context(): + form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestMachineApiKeyForm: + # Similar to ApiKeyForm, but might have different validation rules + @pytest.fixture + def valid_machine_key_data(self, valid_datetime): + return { + "machine": "192.168.1.1", + "comment": "Test machine API key", + "expires": valid_datetime, + "readonly": "true", + } + + def test_valid_machine_key(self, app, valid_machine_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_machine_key_data) + form = MachineApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestOrganizationForm: + @pytest.fixture + def valid_org_data(self): + return { + "name": "Test Organization", + "limit_flowspec4": "100", + "limit_flowspec6": "100", + "limit_rtbh": "50", + "arange": "192.168.0.0/16\n2001:db8::/32", + } + + def test_valid_organization(self, app, valid_org_data): + with app.test_request_context(): + form_data = create_form_data(valid_org_data) + form = OrganizationForm(formdata=form_data) + assert form.validate() + + def test_invalid_ranges(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"}) + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + def test_invalid_limits(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + +class TestActionForm: + @pytest.fixture + def valid_action_data(self): + return { + "name": "test_action", + "command": "announce route", + "description": "Test action description", + "role_id": "2", + } + + def test_valid_action(self, app, valid_action_data): + with app.test_request_context(): + form_data = create_form_data(valid_action_data) + form = ActionForm(formdata=form_data) + assert form.validate() + + def test_invalid_role(self, app, valid_action_data): + with app.test_request_context(): + invalid_data = dict(valid_action_data) + invalid_data["role_id"] = "4" # Invalid role + form_data = create_form_data(invalid_data) + form = ActionForm(formdata=form_data) + assert not form.validate() + + +class TestASPathForm: + @pytest.fixture + def valid_aspath_data(self): + return {"prefix": "AS64512", "as_path": "64512 64513 64514"} + + def test_valid_aspath(self, app, valid_aspath_data): + with app.test_request_context(): + form_data = create_form_data(valid_aspath_data) + form = ASPathForm(formdata=form_data) + assert form.validate() + + def test_empty_fields(self, app): + with app.test_request_context(): + form_data = create_form_data({}) + form = ASPathForm(formdata=form_data) + assert not form.validate() + + +class TestCommunityForm: + @pytest.fixture + def valid_community_data(self): + return { + "name": "test_community", + "comm": "64512:100", + "description": "Test community", + "role_id": "2", + "as_path": "true", + } + + def test_valid_community(self, app, valid_community_data): + with app.test_request_context(): + form_data = create_form_data(valid_community_data) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_missing_all_community_types(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert not form.validate() + assert any("could not be empty" in error for error in form.comm.errors) + + def test_valid_with_extended_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_valid_with_large_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + +class TestWhitelistForm: + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime} + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime} + + def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert any("valid IP address" in error for error in form.ip.errors) + + def test_ip_outside_range(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16 + assert not form.validate() + assert any("must be in organization range" in error for error in form.ip.errors) + + def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_missing_required_fields(self, app, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert form.ip.errors # Should have error for missing IP + assert form.mask.errors # Should have error for missing mask + assert form.expires.errors # Should have error for missing expiration + + def test_comment_length(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + { + "ip": "192.168.0.0", + "mask": "24", + "expires": valid_datetime, + "comment": "x" * 256, # Comment longer than 255 chars + } + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + @pytest.mark.parametrize( + "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date + ) + def test_invalid_expiration(self, app, expires, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + if not form.validate(): + print("Validation errors:", form.errors) + + assert not form.validate() + + def test_network_alignment(self, app, valid_datetime, sample_network_ranges): + """Test that IP addresses must be properly network-aligned""" + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() From c3b2bbb73906d94372774df3c9d359988d57bf87 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Feb 2025 10:10:36 +0100 Subject: [PATCH 05/36] Refactor equality checks and add model unit tests - Updated equality checks in Flowspec4 and Whitelist models: - Adjusted conditions to compare only relevant fields. - Excluded user_id, org_id, comment, and time fields from Whitelist equality checks. - Improved formatting in docstrings for clarity. - Added unit tests for: - User creation and relationships. - ApiKey and MachineApiKey expiration logic. - Organization user retrieval. - Flowspec6 and Whitelist equality comparisons. - Whitelist serialization with to_dict. These changes enhance test coverage and refine model behavior. --- flowapp/models.py | 15 +- flowapp/tests/test_models.py | 263 ++++++++++++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 11 deletions(-) diff --git a/flowapp/models.py b/flowapp/models.py index 54c50a4..90a3a9b 100644 --- a/flowapp/models.py +++ b/flowapp/models.py @@ -419,7 +419,8 @@ def __init__( def __eq__(self, other): """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. + Two models are equal if all the network parameters equals. + User_id and time fields can differ. :param other: other Flowspec4 instance :return: boolean """ @@ -700,18 +701,12 @@ def __init__( def __eq__(self, other): """ - Two whitelists are equal if all the network parameters equals. User_id and time fields can differ. + Two whitelists are equal if all the network parameters equals. + User_id, org, comment and time fields can differ. :param other: other Whitelist instance :return: boolean """ - return ( - self.ip == other.ip - and self.mask == other.mask - and self.expires == other.expires - and self.user_id == other.user_id - and self.org_id == other.org_id - and self.rstate_id == other.rstate_id - ) + return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id def to_dict(self, prefered_format="yearfirst"): """ diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py index 8959945..82dff8e 100644 --- a/flowapp/tests/test_models.py +++ b/flowapp/tests/test_models.py @@ -1,4 +1,16 @@ -from datetime import datetime +from datetime import datetime, timedelta +from flowapp.models import ( + User, + Organization, + Role, + ApiKey, + MachineApiKey, + Rstate, + Community, + Action, + Flowspec6, + Whitelist, +) import flowapp.models as models @@ -232,3 +244,252 @@ def test_rtbj_eq(db): ) assert model_A == model_B + + +def test_user_creation(db): + """Test basic user creation and relationships""" + # Create test role and org first + role = Role(name="test_role", description="Test Role") + org = Organization(name="test_org", arange="10.0.0.0/8") + db.session.add_all([role, org]) + db.session.commit() + + # Create user with relationships + user = User( + uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment" + ) + user.role.append(role) + user.organization.append(org) + db.session.add(user) + db.session.commit() + + # Verify user and relationships + assert user.uuid == "test-user-123" + assert user.name == "Test User" + assert len(user.role.all()) == 1 + assert len(user.organization.all()) == 1 + assert user.role.first().name == "test_role" + assert user.organization.first().name == "test_org" + + +def test_api_key_expiration(db): + """Test ApiKey expiration logic""" + user = User(uuid="test-user") + org = Organization(name="test-org", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = ApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = ApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + # Create future key + future_key = ApiKey( + machine="test-machine-3", + key="key3", + readonly=True, + expires=datetime.now() + timedelta(days=1), + comment="Future key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key, future_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + assert not future_key.is_expired() + + +def test_machine_api_key_expiration(db): + """Test MachineApiKey expiration logic""" + user = User(uuid="test-user-machine") + org = Organization(name="test-org-machine", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = MachineApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = MachineApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + + +def test_organization_get_users(db): + """Test Organization's get_users method""" + org = Organization( + name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100 + ) + uuid1 = "test-org-get-user" + uuid2 = "test-org-get-user2" + user1 = User(uuid=uuid1) + user2 = User(uuid=uuid2) + + db.session.add(org) + db.session.add_all([user1, user2]) + db.session.commit() + + org.user.append(user1) + org.user.append(user2) + db.session.commit() + + users = org.get_users() + assert len(users) == 2 + assert all(isinstance(user, User) for user in users) + assert {user.uuid for user in users} == {uuid1, uuid2} + + +def test_flowspec6_equality(db): + """Test Flowspec6 equality comparison""" + model_a = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + # Same network parameters but different timestamps + model_b = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + action_id=1, + ) + + # Different network parameters + model_c = Flowspec6( + source="2001:db8::3", + source_mask=128, + source_port="80", + destination="2001:db8::4", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different network parameters + + +def test_whitelist_equality(db): + """Test Whitelist equality comparison""" + model_a = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Same IP/mask but different timestamps + model_b = Whitelist( + ip="192.168.1.1", + mask=32, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + comment="Different comment", + ) + + # Different IP + model_c = Whitelist( + ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different IP + + +def test_whitelist_to_dict(db): + """Test Whitelist to_dict serialization""" + whitelist = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Create required related objects + user = User(uuid="test-user-whitelist") + rstate = Rstate(description="active") + db.session.add_all([user, rstate]) + db.session.commit() + + whitelist.user = user + whitelist.rstate_id = rstate.id + db.session.add(whitelist) + db.session.commit() + + # Test timestamp format + dict_timestamp = whitelist.to_dict(prefered_format="timestamp") + assert isinstance(dict_timestamp["expires"], int) + assert isinstance(dict_timestamp["created"], int) + + # Test yearfirst format + dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst") + assert isinstance(dict_yearfirst["expires"], str) + assert isinstance(dict_yearfirst["created"], str) + + # Check basic fields + assert dict_timestamp["ip"] == "192.168.1.1" + assert dict_timestamp["mask"] == 32 + assert dict_timestamp["comment"] == "Test whitelist" + assert dict_timestamp["user"] == "test-user-whitelist" From d86634e3eedad7f5467291cf70880365ea53da91 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 27 Feb 2025 13:02:42 +0100 Subject: [PATCH 06/36] whitelist view stub - WIP --- flowapp/views/whitelist.py | 94 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 flowapp/views/whitelist.py diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py new file mode 100644 index 0000000..adf3129 --- /dev/null +++ b/flowapp/views/whitelist.py @@ -0,0 +1,94 @@ +from datetime import datetime, timedelta +from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for + +from flowapp import constants, db, messages +from flowapp.auth import ( + auth_required, + user_or_admin_required, +) +from flowapp.constants import RuleTypes +from flowapp.forms import WhitelistForm +from flowapp.models import ( + Whitelist, + get_user_nets, +) +from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route +from flowapp.utils import ( + flash_errors, + get_state_by_time, + quote_to_ent, + round_to_ten_minutes, +) + +whitelist = Blueprint("whitelist", __name__, template_folder="templates") + + +@whitelist.route("/add", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def add(): + net_ranges = get_user_nets(session["user_id"]) + form = WhitelistForm(request.form) + + form.net_ranges = net_ranges + + if request.method == "POST" and form.validate(): + model = get_ipv4_model_if_exists(form.data, 1) + + if model: + model.expires = round_to_ten_minutes(form.expires.data) + flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." + else: + model = Flowspec4( + source=form.source.data, + source_mask=form.source_mask.data, + source_port=form.source_port.data, + destination=form.dest.data, + destination_mask=form.dest_mask.data, + destination_port=form.dest_port.data, + protocol=form.protocol.data, + flags=";".join(form.flags.data), + packet_len=form.packet_len.data, + fragment=";".join(form.fragment.data), + expires=round_to_ten_minutes(form.expires.data), + comment=quote_to_ent(form.comment.data), + action_id=form.action.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + rstate_id=get_state_by_time(form.expires.data), + ) + flash_message = "IPv4 Rule saved" + db.session.add(model) + + db.session.commit() + flash(flash_message, "alert-success") + + # announce route if model is in active state + if model.rstate_id == 1: + command = messages.create_ipv4(model, constants.ANNOUNCE) + route = Route( + author=f"{session['user_email']} / {session['user_org']}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # log changes + log_route( + session["user_id"], + model, + RuleTypes.IPv4, + f"{session['user_email']} / {session['user_org']}", + ) + + return redirect(url_for("index")) + else: + for field, errors in form.errors.items(): + for error in errors: + current_app.logger.debug("Error in the %s field - %s" % (getattr(form, field).label.text, error)) + + print("NOW", datetime.now()) + default_expires = datetime.now() + timedelta(hours=1) + form.expires.data = default_expires + + return render_template("forms/ipv4_rule.html", form=form, action_url=url_for("rules.ipv4_rule")) From af4befc0e78ee6b4b3afffcc0f6a6ba41aae58cc Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 27 Feb 2025 14:48:22 +0100 Subject: [PATCH 07/36] Refactor models into a separate directory and update references - Moved models from flowapp/models.py into a dedicated directory - Updated flowapp/output.py to handle rule_type as an enum value - Fixed a return value issue in flowapp/utils.py's parse_api_time function - Added a debug print statement in flowapp/views/rules.py for user actions UI --- flowapp/models.py | 1154 ----------------------------- flowapp/models/__init__.py | 69 ++ flowapp/models/api.py | 40 + flowapp/models/base.py | 16 + flowapp/models/community.py | 62 ++ flowapp/models/log.py | 21 + flowapp/models/organization.py | 35 + flowapp/models/rules/__init__.py | 6 + flowapp/models/rules/base.py | 68 ++ flowapp/models/rules/flowspec.py | 294 ++++++++ flowapp/models/rules/rtbh.py | 134 ++++ flowapp/models/rules/whitelist.py | 94 +++ flowapp/models/user.py | 79 ++ flowapp/models/utils.py | 330 +++++++++ flowapp/output.py | 4 +- flowapp/utils.py | 3 +- 16 files changed, 1251 insertions(+), 1158 deletions(-) delete mode 100644 flowapp/models.py create mode 100644 flowapp/models/__init__.py create mode 100644 flowapp/models/api.py create mode 100644 flowapp/models/base.py create mode 100644 flowapp/models/community.py create mode 100644 flowapp/models/log.py create mode 100644 flowapp/models/organization.py create mode 100644 flowapp/models/rules/__init__.py create mode 100644 flowapp/models/rules/base.py create mode 100644 flowapp/models/rules/flowspec.py create mode 100644 flowapp/models/rules/rtbh.py create mode 100644 flowapp/models/rules/whitelist.py create mode 100644 flowapp/models/user.py create mode 100644 flowapp/models/utils.py diff --git a/flowapp/models.py b/flowapp/models.py deleted file mode 100644 index 90a3a9b..0000000 --- a/flowapp/models.py +++ /dev/null @@ -1,1154 +0,0 @@ -import json -from sqlalchemy import event -from datetime import datetime -from flowapp import db, utils -from flowapp.constants import RuleTypes -from flask import current_app - -# models and tables - -user_role = db.Table( - "user_role", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "role_id"), -) - -user_organization = db.Table( - "user_organization", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "organization_id"), -) - - -class User(db.Model): - """ - App User - """ - - id = db.Column(db.Integer, primary_key=True) - uuid = db.Column(db.String(180), unique=True) - comment = db.Column(db.String(500)) - email = db.Column(db.String(255)) - name = db.Column(db.String(255)) - phone = db.Column(db.String(255)) - apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") - machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") - role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") - - organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") - - def __init__(self, uuid, name=None, phone=None, email=None, comment=None): - self.uuid = uuid - self.phone = phone - self.name = name - self.email = email - self.comment = comment - - def update(self, form): - """ - update the user with values from form object - :param form: flask form from request - :return: None - """ - self.uuid = form.uuid.data - self.name = form.name.data - self.email = form.email.data - self.phone = form.phone.data - self.comment = form.comment.data - - # first clear existing roles and orgs - for role in self.role: - self.role.remove(role) - for org in self.organization: - self.organization.remove(org) - - for role_id in form.role_ids.data: - my_role = db.session.query(Role).filter_by(id=role_id).first() - if my_role not in self.role: - self.role.append(my_role) - - for org_id in form.org_ids.data: - my_org = db.session.query(Organization).filter_by(id=org_id).first() - if my_org not in self.organization: - self.organization.append(my_org) - - db.session.commit() - - -class ApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=False) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="apikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="apikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class MachineApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=True) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="machineapikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="machineapikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class Role(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(20), unique=True) - description = db.Column(db.String(260)) - - def __init__(self, name, description): - self.name = name - self.description = description - - def __repr__(self): - return self.name - - -class Organization(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(150), unique=True) - arange = db.Column(db.Text) - limit_flowspec4 = db.Column(db.Integer, default=0) - limit_flowspec6 = db.Column(db.Integer, default=0) - limit_rtbh = db.Column(db.Integer, default=0) - - def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): - self.name = name - self.arange = arange - self.limit_flowspec4 = limit_flowspec4 - self.limit_flowspec6 = limit_flowspec6 - self.limit_rtbh = limit_rtbh - - def __repr__(self): - return self.name - - def get_users(self): - """ - Returns all users associated with this organization. - """ - # self.user is the backref from the user_organization relationship - return self.user - - -class ASPath(db.Model): - id = db.Column(db.Integer, primary_key=True) - prefix = db.Column(db.String(120), unique=True) - as_path = db.Column(db.String(250)) - - def __init__(self, prefix, as_path): - self.prefix = prefix - self.as_path = as_path - - def __repr__(self): - return f"{self.prefix} : {self.as_path}" - - -class Action(db.Model): - """ - Action for rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - command = db.Column(db.String(120), unique=True) - description = db.Column(db.String(260)) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="action") - - def __init__(self, name, command, description, role_id=2): - self.name = name - self.command = command - self.description = description - self.role_id = role_id - - -class Community(db.Model): - """ - Community for RTBH rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - comm = db.Column(db.String(2047)) - larcomm = db.Column(db.String(2047)) - extcomm = db.Column(db.String(2047)) - description = db.Column(db.String(255)) - as_path = db.Column(db.Boolean, default=False) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="community") - - def __init__(self, name, comm, larcomm, extcomm, description, as_path=False, role_id=2): - self.name = name - self.comm = comm - self.larcomm = larcomm - self.extcomm = extcomm - self.description = description - self.as_path = as_path - self.role_id = role_id - - -class Rstate(db.Model): - """ - State for Rules - """ - - id = db.Column(db.Integer, primary_key=True) - description = db.Column(db.String(260)) - - def __init__(self, description): - self.description = description - - -class RTBH(db.Model): - __tablename__ = "RTBH" - - id = db.Column(db.Integer, primary_key=True) - ipv4 = db.Column(db.String(255)) - ipv4_mask = db.Column(db.Integer) - ipv6 = db.Column(db.String(255)) - ipv6_mask = db.Column(db.Integer) - community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) - community = db.relationship("Community", backref="rtbh") - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="rtbh") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="rtbh") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="RTBH") - - def __init__( - self, - ipv4, - ipv4_mask, - ipv6, - ipv6_mask, - community_id, - expires, - user_id, - org_id, - comment=None, - created=None, - rstate_id=1, - ): - self.ipv4 = ipv4 - self.ipv4_mask = ipv4_mask - self.ipv6 = ipv6 - self.ipv6_mask = ipv6_mask - self.community_id = community_id - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.comment = comment - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - return ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - compars = ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def update_time(self, form): - self.expires = utils.webpicker_to_datetime(form.expire_date.data) - db.session.commit() - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict used in API - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "ipv4": self.ipv4, - "ipv4_mask": self.ipv4_mask, - "ipv6": self.ipv6, - "ipv6_mask": self.ipv6_mask, - "community": self.community.name, - "comment": self.comment, - "expires": expires, - "created": created, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec4(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - protocol = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - fragment = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec4") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec4") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec4") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec4") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - protocol, - flags, - packet_len, - fragment, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.protocol = protocol - self.flags = flags - self.packet_len = packet_len - self.fragment = fragment - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. - User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - compars = ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "protocol": self.protocol, - "flags": self.flags, - "packet_len": self.packet_len, - "fragment": self.fragment, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec6(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - next_header = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec6") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec6") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec6") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec6") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - next_header, - flags, - packet_len, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.next_header = next_header - self.flags = flags - self.packet_len = packet_len - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.next_header == other.next_header - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": str(self.id), - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "next_header": self.next_header, - "flags": self.flags, - "packet_len": self.packet_len, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Log(db.Model): - id = db.Column(db.Integer, primary_key=True) - time = db.Column(db.DateTime) - task = db.Column(db.String(1000)) - author = db.Column(db.String(1000)) - rule_type = db.Column(db.Integer) - rule_id = db.Column(db.Integer) - user_id = db.Column(db.Integer) - org_id = db.Column(db.Integer, nullable=True) - - def __init__(self, time, task, user_id, rule_type, rule_id, author, org_id=None): - self.time = time - self.task = task - self.rule_type = rule_type - self.rule_id = rule_id - self.user_id = user_id - self.author = author - self.org_id = org_id - - -class Whitelist(db.Model): - id = db.Column(db.Integer, primary_key=True) - ip = db.Column(db.String(255)) - mask = db.Column(db.Integer) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="whitelist") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="whitelist") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="whitelist") - - def __init__( - self, - ip, - mask, - expires, - user_id, - org_id, - created=None, - comment=None, - rstate_id=1, - ): - self.ip = ip - self.mask = mask - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.comment = comment - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two whitelists are equal if all the network parameters equals. - User_id, org, comment and time fields can differ. - :param other: other Whitelist instance - :return: boolean - """ - return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "ip": self.ip, - "mask": self.mask, - "comment": self.comment, - "expires": expires, - "created": created, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - -class RuleWhitelistCache(db.Model): - """ - Cache for whitelisted rules - For each rule we store id and type - Rule origin determines if the rule was created by user or by whitelist - """ - - id = db.Column(db.Integer, primary_key=True) - rid = db.Column(db.Integer) - rtype = db.Column(db.Integer) - rorigin = db.Column(db.Integer) - whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey - whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") - - def __init__(self, rid, rtype, rorigin, whitelist_id): - self.rid = rid - self.rtype = rtype - self.rorigin = rorigin - self.whitelist_id = whitelist_id - - -# DDL -# default values for tables inserted after create -@event.listens_for(Action.__table__, "after_create") -def insert_initial_actions(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="QoS 100 kbps", - command="rate-limit 12800", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 1Mbps", - command="rate-limit 13107200", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 10Mbps", - command="rate-limit 131072000", - description="QoS", - role_id=2, - ) - ) - conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) - - -@event.listens_for(Community.__table__, "after_create") -def insert_initial_communities(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="65535:65283", - comm="65535:65283", - larcomm="", - extcomm="", - description="local-as", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64496:64511", - comm="64496:64511", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64497:64510", - comm="64497:64510", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - - -@event.listens_for(Role.__table__, "after_create") -def insert_initial_roles(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="view", description="just view, no edit")) - conn.execute(table.insert().values(name="user", description="can edit")) - conn.execute(table.insert().values(name="admin", description="admin")) - - -@event.listens_for(Organization.__table__, "after_create") -def insert_initial_organizations(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) - - -@event.listens_for(Rstate.__table__, "after_create") -def insert_initial_rulestates(table, conn, *args, **kwargs): - conn.execute(table.insert().values(description="active rule")) - conn.execute(table.insert().values(description="withdrawed rule")) - conn.execute(table.insert().values(description="deleted rule")) - - -# Misc functions -def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: - """ - Check if the organization has reached the rule limit - :param org_id: integer organization id - :param rule_type: RuleType rule type - :return: boolean - """ - flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the organization limits - org = Organization.query.filter_by(id=org_id).first() - if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: - count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec4 or fs4 >= flowspec_limit - if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: - count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec6 or fs6 >= flowspec_limit - if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: - count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_rtbh or rtbh >= rtbh_limit - - -def check_global_rule_limit(rule_type: RuleTypes) -> bool: - flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) - flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the global limits if the organization limits are not set - - if rule_type == RuleTypes.IPv4: - return fs4 >= flowspec4_limit - if rule_type == RuleTypes.IPv6: - return fs6 >= flowspec6_limit - if rule_type == RuleTypes.RTBH: - return rtbh >= rtbh_limit - - -def get_ipv4_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec4) - .filter( - Flowspec4.source == form_data["source"], - Flowspec4.source_mask == form_data["source_mask"], - Flowspec4.source_port == form_data["source_port"], - Flowspec4.dest == form_data["dest"], - Flowspec4.dest_mask == form_data["dest_mask"], - Flowspec4.dest_port == form_data["dest_port"], - Flowspec4.protocol == form_data["protocol"], - Flowspec4.flags == ";".join(form_data["flags"]), - Flowspec4.packet_len == form_data["packet_len"], - Flowspec4.action_id == form_data["action"], - Flowspec4.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_ipv6_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec6) - .filter( - Flowspec6.source == form_data["source"], - Flowspec6.source_mask == form_data["source_mask"], - Flowspec6.source_port == form_data["source_port"], - Flowspec6.dest == form_data["dest"], - Flowspec6.dest_mask == form_data["dest_mask"], - Flowspec6.dest_port == form_data["dest_port"], - Flowspec6.next_header == form_data["next_header"], - Flowspec6.flags == ";".join(form_data["flags"]), - Flowspec6.packet_len == form_data["packet_len"], - Flowspec6.action_id == form_data["action"], - Flowspec6.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_rtbh_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - - record = ( - db.session.query(RTBH) - .filter( - RTBH.ipv4 == form_data["ipv4"], - RTBH.ipv4_mask == form_data["ipv4_mask"], - RTBH.ipv6 == form_data["ipv6"], - RTBH.ipv6_mask == form_data["ipv6_mask"], - RTBH.community_id == form_data["community"], - RTBH.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def insert_users(users): - """ - inser list of users {name: string, role_id: integer} to db - """ - for user in users: - r = Role.query.filter_by(id=user["role_id"]).first() - o = Organization.query.filter_by(id=user["org_id"]).first() - u = User(uuid=user["name"]) - u.role.append(r) - u.organization.append(o) - db.session.add(u) - - db.session.commit() - - -def insert_user( - uuid: str, - role_ids: list, - org_ids: list, - name: str = None, - phone: str = None, - email: str = None, - comment: str = None, -): - """ - insert new user with multiple roles and organizations - :param uuid: string unique user id (eppn or similar) - :param phone: string phone number - :param name: string user name - :param email: string email - :param comment: string comment / notice - :param role_ids: list of roles - :param org_ids: list of orgs - :return: None - """ - u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) - - for role_id in role_ids: - r = Role.query.filter_by(id=role_id).first() - u.role.append(r) - - for org_id in org_ids: - o = Organization.query.filter_by(id=org_id).first() - u.organization.append(o) - - db.session.add(u) - db.session.commit() - - -def get_user_nets(user_id): - """ - Return list of network ranges for all user organization - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - result = [] - for org in orgs: - result.extend(org.arange.split()) - - return result - - -def get_user_orgs_choices(user_id): - """ - Return list of orgs as choices for form - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - - return [(g.id, g.name) for g in orgs] - - -def get_user_actions(user_roles): - """ - Return list of actions based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - actions = db.session.query(Action).order_by("id") - else: - actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in actions] - - -def get_user_communities(user_roles): - """ - Return list of communities based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - communities = db.session.query(Community).order_by("id") - else: - communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in communities] - - -def get_existing_action(name=None, command=None): - """ - return Action with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - action = Action.query.filter((Action.name == name) | (Action.command == command)).first() - return action.id if hasattr(action, "id") else None - - -def get_existing_community(name=None): - """ - return Community with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - community = Community.query.filter(Community.name == name).first() - return community.id if hasattr(community, "id") else None - - -def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): - """ - Returns list of rules sorted by sort column ordered asc or desc - :param sort: sorting column - :param order: asc or desc - :return: list - """ - - today = datetime.now() - comp_func = utils.get_comp_func(rule_state) - - if rule_type == "ipv4": - sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) - sorting_ip4 = getattr(sorter_ip4, order) - if comp_func: - rules4 = ( - db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() - ) - else: - rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() - - return rules4 - - if rule_type == "ipv6": - sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) - sorting_ip6 = getattr(sorter_ip6, order) - if comp_func: - rules6 = ( - db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() - ) - else: - rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() - - return rules6 - - if rule_type == "rtbh": - sorter_rtbh = getattr(RTBH, sort, RTBH.id) - sorting_rtbh = getattr(sorter_rtbh, order) - - if comp_func: - rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() - - else: - rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() - - return rules_rtbh - - -def get_user_rules_ids(user_id, rule_type): - """ - Returns list of rule ids belonging to user - :param user_id: user id - :param rule_type: ipv4, ipv6 or rtbh - :return: list - """ - - if rule_type == "ipv4": - rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules4] - - if rule_type == "ipv6": - rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() - return [int(x[0]) for x in rules6] - - if rule_type == "rtbh": - rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/models/__init__.py b/flowapp/models/__init__.py new file mode 100644 index 0000000..6d74ffe --- /dev/null +++ b/flowapp/models/__init__.py @@ -0,0 +1,69 @@ +# Import and re-export all models and functions for backward compatibility +from .base import db, user_role, user_organization + +# User-related models +from .user import User, Role +from .api import ApiKey, MachineApiKey +from .organization import Organization + +# Rule-related models +from .rules import Flowspec4, Flowspec6, RTBH, Rstate, Action, Whitelist, RuleWhitelistCache + +# Other models +from .community import Community, ASPath, insert_initial_communities +from .log import Log + +# Helper functions +from .utils import ( + get_user_nets, + get_user_actions, + get_user_communities, + get_existing_action, + get_existing_community, + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + get_ip_rules, + get_user_rules_ids, + insert_users, + insert_user, + check_rule_limit, + check_global_rule_limit, +) + +# Ensure all models are registered properly +__all__ = [ + "db", + "User", + "Role", + "user_role", + "ApiKey", + "MachineApiKey", + "Organization", + "user_organization", + "Action", + "Flowspec4", + "Flowspec6", + "RTBH", + "Rstate", + "Community", + "ASPath", + "Log", + "Whitelist", + "RuleWhitelistCache", + "get_user_nets", + "get_user_actions", + "get_user_communities", + "get_existing_action", + "get_existing_community", + "get_ipv4_model_if_exists", + "get_ipv6_model_if_exists", + "get_rtbh_model_if_exists", + "get_ip_rules", + "get_user_rules_ids", + "insert_users", + "insert_user", + "check_rule_limit", + "check_global_rule_limit", + "insert_initial_communities", +] diff --git a/flowapp/models/api.py b/flowapp/models/api.py new file mode 100644 index 0000000..ed50685 --- /dev/null +++ b/flowapp/models/api.py @@ -0,0 +1,40 @@ +from datetime import datetime +from .base import db + + +class ApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=False) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="apikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="apikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() + + +class MachineApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=True) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="machineapikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="machineapikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() diff --git a/flowapp/models/base.py b/flowapp/models/base.py new file mode 100644 index 0000000..a934ce8 --- /dev/null +++ b/flowapp/models/base.py @@ -0,0 +1,16 @@ +from flowapp import db + +# Define shared tables +user_role = db.Table( + "user_role", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "role_id"), +) + +user_organization = db.Table( + "user_organization", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "organization_id"), +) diff --git a/flowapp/models/community.py b/flowapp/models/community.py new file mode 100644 index 0000000..afb2c36 --- /dev/null +++ b/flowapp/models/community.py @@ -0,0 +1,62 @@ +from sqlalchemy import event +from .base import db + + +class Community(db.Model): + """Community for RTBH rule""" + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + comm = db.Column(db.String(2047)) + larcomm = db.Column(db.String(2047)) + extcomm = db.Column(db.String(2047)) + description = db.Column(db.String(255)) + as_path = db.Column(db.Boolean, default=False) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="community") + + # Methods and initializer + + +class ASPath(db.Model): + """AS Path for RTBH rules""" + + id = db.Column(db.Integer, primary_key=True) + prefix = db.Column(db.String(120), unique=True) + as_path = db.Column(db.String(250)) + + # Methods and initializer + + +@event.listens_for(Community.__table__, "after_create") +def insert_initial_communities(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="65535:65283", + comm="65535:65283", + larcomm="", + extcomm="", + description="local-as", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64496:64511", + comm="64496:64511", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64497:64510", + comm="64497:64510", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) diff --git a/flowapp/models/log.py b/flowapp/models/log.py new file mode 100644 index 0000000..e32bb30 --- /dev/null +++ b/flowapp/models/log.py @@ -0,0 +1,21 @@ +from .base import db + + +class Log(db.Model): + """Log model for system actions""" + + id = db.Column(db.Integer, primary_key=True) + time = db.Column(db.DateTime) + task = db.Column(db.String(1000)) + author = db.Column(db.String(1000)) + rule_type = db.Column(db.Integer) + rule_id = db.Column(db.Integer) + user_id = db.Column(db.Integer) + + def __init__(self, time, task, user_id, rule_type, rule_id, author): + self.time = time + self.task = task + self.rule_type = rule_type + self.rule_id = rule_id + self.user_id = user_id + self.author = author diff --git a/flowapp/models/organization.py b/flowapp/models/organization.py new file mode 100644 index 0000000..baf0ec1 --- /dev/null +++ b/flowapp/models/organization.py @@ -0,0 +1,35 @@ +from sqlalchemy import event +from .base import db + + +class Organization(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(150), unique=True) + arange = db.Column(db.Text) + limit_flowspec4 = db.Column(db.Integer, default=0) + limit_flowspec6 = db.Column(db.Integer, default=0) + limit_rtbh = db.Column(db.Integer, default=0) + + def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): + self.name = name + self.arange = arange + self.limit_flowspec4 = limit_flowspec4 + self.limit_flowspec6 = limit_flowspec6 + self.limit_rtbh = limit_rtbh + + def __repr__(self): + return self.name + + def get_users(self): + """ + Returns all users associated with this organization. + """ + # self.user is the backref from the user_organization relationship + return self.user + + +# Event listeners for Organization +@event.listens_for(Organization.__table__, "after_create") +def insert_initial_organizations(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="TU Liberec", arange="147.230.0.0/16\n2001:718:1c01::/48")) + conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) diff --git a/flowapp/models/rules/__init__.py b/flowapp/models/rules/__init__.py new file mode 100644 index 0000000..da12ba4 --- /dev/null +++ b/flowapp/models/rules/__init__.py @@ -0,0 +1,6 @@ +from .flowspec import Flowspec4, Flowspec6 +from .rtbh import RTBH +from .base import Rstate, Action +from .whitelist import Whitelist, RuleWhitelistCache + +__all__ = ["Flowspec4", "Flowspec6", "RTBH", "Rstate", "Action", "Whitelist", "RuleWhitelistCache"] diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py new file mode 100644 index 0000000..ab61f5e --- /dev/null +++ b/flowapp/models/rules/base.py @@ -0,0 +1,68 @@ +from sqlalchemy import event +from ..base import db + + +class Rstate(db.Model): + """State for Rules""" + + id = db.Column(db.Integer, primary_key=True) + description = db.Column(db.String(260)) + + def __init__(self, description): + self.description = description + + +class Action(db.Model): + """ + Action for rule + """ + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + command = db.Column(db.String(120), unique=True) + description = db.Column(db.String(260)) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="action") + + def __init__(self, name, command, description, role_id=2): + self.name = name + self.command = command + self.description = description + self.role_id = role_id + + +# Event listeners for Rstate +@event.listens_for(Rstate.__table__, "after_create") +def insert_initial_rulestates(table, conn, *args, **kwargs): + conn.execute(table.insert().values(description="active rule")) + conn.execute(table.insert().values(description="withdrawed rule")) + conn.execute(table.insert().values(description="deleted rule")) + + +@event.listens_for(Action.__table__, "after_create") +def insert_initial_actions(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="QoS 100 kbps", + command="rate-limit 12800", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 1Mbps", + command="rate-limit 13107200", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 10Mbps", + command="rate-limit 131072000", + description="QoS", + role_id=2, + ) + ) + conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) diff --git a/flowapp/models/rules/flowspec.py b/flowapp/models/rules/flowspec.py new file mode 100644 index 0000000..9823aba --- /dev/null +++ b/flowapp/models/rules/flowspec.py @@ -0,0 +1,294 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class Flowspec4(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + protocol = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + fragment = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec4") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec4") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec4") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec4") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + protocol, + flags, + packet_len, + fragment, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.protocol = protocol + self.flags = flags + self.packet_len = packet_len + self.fragment = fragment + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + compars = ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "protocol": self.protocol, + "flags": self.flags, + "packet_len": self.packet_len, + "fragment": self.fragment, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) + + +class Flowspec6(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + next_header = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec6") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec6") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec6") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec6") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + next_header, + flags, + packet_len, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.next_header = next_header + self.flags = flags + self.packet_len = packet_len + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.next_header == other.next_header + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": str(self.id), + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "next_header": self.next_header, + "flags": self.flags, + "packet_len": self.packet_len, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py new file mode 100644 index 0000000..d75220c --- /dev/null +++ b/flowapp/models/rules/rtbh.py @@ -0,0 +1,134 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class RTBH(db.Model): + __tablename__ = "RTBH" + + id = db.Column(db.Integer, primary_key=True) + ipv4 = db.Column(db.String(255)) + ipv4_mask = db.Column(db.Integer) + ipv6 = db.Column(db.String(255)) + ipv6_mask = db.Column(db.Integer) + community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) + community = db.relationship("Community", backref="rtbh") + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="rtbh") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="rtbh") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="RTBH") + + def __init__( + self, + ipv4, + ipv4_mask, + ipv6, + ipv6_mask, + community_id, + expires, + user_id, + org_id, + comment=None, + created=None, + rstate_id=1, + ): + self.ipv4 = ipv4 + self.ipv4_mask = ipv4_mask + self.ipv6 = ipv6 + self.ipv6_mask = ipv6_mask + self.community_id = community_id + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + return ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + compars = ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def update_time(self, form): + self.expires = utils.webpicker_to_datetime(form.expire_date.data) + db.session.commit() + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict used in API + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ipv4": self.ipv4, + "ipv4_mask": self.ipv4_mask, + "ipv6": self.ipv6, + "ipv6_mask": self.ipv6_mask, + "community": self.community.name, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py new file mode 100644 index 0000000..5e93714 --- /dev/null +++ b/flowapp/models/rules/whitelist.py @@ -0,0 +1,94 @@ +from flowapp import utils +from ..base import db +from datetime import datetime + + +class Whitelist(db.Model): + id = db.Column(db.Integer, primary_key=True) + ip = db.Column(db.String(255)) + mask = db.Column(db.Integer) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="whitelist") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="whitelist") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="whitelist") + + def __init__( + self, + ip, + mask, + expires, + user_id, + org_id, + created=None, + comment=None, + rstate_id=1, + ): + self.ip = ip + self.mask = mask + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two whitelists are equal if all the network parameters equals. + User_id, org, comment and time fields can differ. + :param other: other Whitelist instance + :return: boolean + """ + return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ip": self.ip, + "mask": self.mask, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + +class RuleWhitelistCache(db.Model): + """ + Cache for whitelisted rules + For each rule we store id and type + Rule origin determines if the rule was created by user or by whitelist + """ + + id = db.Column(db.Integer, primary_key=True) + rid = db.Column(db.Integer) + rtype = db.Column(db.Integer) + rorigin = db.Column(db.Integer) + whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey + whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") + + def __init__(self, rid, rtype, rorigin, whitelist_id): + self.rid = rid + self.rtype = rtype + self.rorigin = rorigin + self.whitelist_id = whitelist_id diff --git a/flowapp/models/user.py b/flowapp/models/user.py new file mode 100644 index 0000000..dcb2d7e --- /dev/null +++ b/flowapp/models/user.py @@ -0,0 +1,79 @@ +from sqlalchemy import event +from .base import db, user_role, user_organization +from .organization import Organization + + +class User(db.Model): + """ + App User + """ + + id = db.Column(db.Integer, primary_key=True) + uuid = db.Column(db.String(180), unique=True) + comment = db.Column(db.String(500)) + email = db.Column(db.String(255)) + name = db.Column(db.String(255)) + phone = db.Column(db.String(255)) + apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") + machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") + role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") + + organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") + + def __init__(self, uuid, name=None, phone=None, email=None, comment=None): + self.uuid = uuid + self.phone = phone + self.name = name + self.email = email + self.comment = comment + + def update(self, form): + """ + update the user with values from form object + :param form: flask form from request + :return: None + """ + self.uuid = form.uuid.data + self.name = form.name.data + self.email = form.email.data + self.phone = form.phone.data + self.comment = form.comment.data + + # first clear existing roles and orgs + for role in self.role: + self.role.remove(role) + for org in self.organization: + self.organization.remove(org) + + for role_id in form.role_ids.data: + my_role = db.session.query(Role).filter_by(id=role_id).first() + if my_role not in self.role: + self.role.append(my_role) + + for org_id in form.org_ids.data: + my_org = db.session.query(Organization).filter_by(id=org_id).first() + if my_org not in self.organization: + self.organization.append(my_org) + + db.session.commit() + + +class Role(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(20), unique=True) + description = db.Column(db.String(260)) + + def __init__(self, name, description): + self.name = name + self.description = description + + def __repr__(self): + return self.name + + +# Event listeners for Role +@event.listens_for(Role.__table__, "after_create") +def insert_initial_roles(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="view", description="just view, no edit")) + conn.execute(table.insert().values(name="user", description="can edit")) + conn.execute(table.insert().values(name="admin", description="admin")) diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py new file mode 100644 index 0000000..e2129b2 --- /dev/null +++ b/flowapp/models/utils.py @@ -0,0 +1,330 @@ +"""Utility functions for models""" + +from datetime import datetime +from flowapp import utils +from flowapp.constants import RuleTypes +from flask import current_app +from .base import db +from .user import User, Role +from .organization import Organization +from .community import Community +from .rules.flowspec import Flowspec4, Flowspec6 +from .rules.rtbh import RTBH +from .rules.base import Action + + +def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: + """ + Check if the organization has reached the rule limit + :param org_id: integer organization id + :param rule_type: RuleType rule type + :return: boolean + """ + flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the organization limits + org = Organization.query.filter_by(id=org_id).first() + if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: + count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec4 or fs4 >= flowspec_limit + if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: + count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec6 or fs6 >= flowspec_limit + if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: + count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_rtbh or rtbh >= rtbh_limit + + +def check_global_rule_limit(rule_type: RuleTypes) -> bool: + flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) + flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the global limits if the organization limits are not set + + if rule_type == RuleTypes.IPv4: + return fs4 >= flowspec4_limit + if rule_type == RuleTypes.IPv6: + return fs6 >= flowspec6_limit + if rule_type == RuleTypes.RTBH: + return rtbh >= rtbh_limit + + +def get_ipv4_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + record = ( + db.session.query(Flowspec4) + .filter( + Flowspec4.source == form_data["source"], + Flowspec4.source_mask == form_data["source_mask"], + Flowspec4.source_port == form_data["source_port"], + Flowspec4.dest == form_data["dest"], + Flowspec4.dest_mask == form_data["dest_mask"], + Flowspec4.dest_port == form_data["dest_port"], + Flowspec4.protocol == form_data["protocol"], + Flowspec4.flags == ";".join(form_data["flags"]), + Flowspec4.packet_len == form_data["packet_len"], + Flowspec4.action_id == form_data["action"], + Flowspec4.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_ipv6_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + record = ( + db.session.query(Flowspec6) + .filter( + Flowspec6.source == form_data["source"], + Flowspec6.source_mask == form_data["source_mask"], + Flowspec6.source_port == form_data["source_port"], + Flowspec6.dest == form_data["dest"], + Flowspec6.dest_mask == form_data["dest_mask"], + Flowspec6.dest_port == form_data["dest_port"], + Flowspec6.next_header == form_data["next_header"], + Flowspec6.flags == ";".join(form_data["flags"]), + Flowspec6.packet_len == form_data["packet_len"], + Flowspec6.action_id == form_data["action"], + Flowspec6.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_rtbh_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + + record = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == form_data["ipv4"], + RTBH.ipv4_mask == form_data["ipv4_mask"], + RTBH.ipv6 == form_data["ipv6"], + RTBH.ipv6_mask == form_data["ipv6_mask"], + RTBH.community_id == form_data["community"], + RTBH.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def insert_users(users): + """ + inser list of users {name: string, role_id: integer} to db + """ + for user in users: + r = Role.query.filter_by(id=user["role_id"]).first() + o = Organization.query.filter_by(id=user["org_id"]).first() + u = User(uuid=user["name"]) + u.role.append(r) + u.organization.append(o) + db.session.add(u) + + db.session.commit() + + +def insert_user( + uuid: str, + role_ids: list, + org_ids: list, + name: str = None, + phone: str = None, + email: str = None, + comment: str = None, +): + """ + insert new user with multiple roles and organizations + :param uuid: string unique user id (eppn or similar) + :param phone: string phone number + :param name: string user name + :param email: string email + :param comment: string comment / notice + :param role_ids: list of roles + :param org_ids: list of orgs + :return: None + """ + u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) + + for role_id in role_ids: + r = Role.query.filter_by(id=role_id).first() + u.role.append(r) + + for org_id in org_ids: + o = Organization.query.filter_by(id=org_id).first() + u.organization.append(o) + + db.session.add(u) + db.session.commit() + + +def get_user_nets(user_id): + """ + Return list of network ranges for all user organization + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + result = [] + for org in orgs: + result.extend(org.arange.split()) + + return result + + +def get_user_orgs_choices(user_id): + """ + Return list of orgs as choices for form + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + + return [(g.id, g.name) for g in orgs] + + +def get_user_actions(user_roles): + """ + Return list of actions based on current user role + """ + max_role = max(user_roles) + print(max_role) + if max_role == 3: + actions = db.session.query(Action).order_by("id").all() + else: + actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id").all() + result = [(g.id, g.name) for g in actions] + print(actions, result) + return result + + +def get_user_communities(user_roles): + """ + Return list of communities based on current user role + """ + max_role = max(user_roles) + if max_role == 3: + communities = db.session.query(Community).order_by("id") + else: + communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") + + return [(g.id, g.name) for g in communities] + + +def get_existing_action(name=None, command=None): + """ + return Action with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + action = Action.query.filter((Action.name == name) | (Action.command == command)).first() + return action.id if hasattr(action, "id") else None + + +def get_existing_community(name=None): + """ + return Community with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + community = Community.query.filter(Community.name == name).first() + return community.id if hasattr(community, "id") else None + + +def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): + """ + Returns list of rules sorted by sort column ordered asc or desc + :param sort: sorting column + :param order: asc or desc + :return: list + """ + + today = datetime.now() + comp_func = utils.get_comp_func(rule_state) + + if rule_type == "ipv4": + sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) + sorting_ip4 = getattr(sorter_ip4, order) + if comp_func: + rules4 = ( + db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() + ) + else: + rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() + + return rules4 + + if rule_type == "ipv6": + sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) + sorting_ip6 = getattr(sorter_ip6, order) + if comp_func: + rules6 = ( + db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() + ) + else: + rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() + + return rules6 + + if rule_type == "rtbh": + sorter_rtbh = getattr(RTBH, sort, RTBH.id) + sorting_rtbh = getattr(sorter_rtbh, order) + + if comp_func: + rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() + + else: + rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() + + return rules_rtbh + + +def get_user_rules_ids(user_id, rule_type): + """ + Returns list of rule ids belonging to user + :param user_id: user id + :param rule_type: ipv4, ipv6 or rtbh + :return: list + """ + + if rule_type == "ipv4": + rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules4] + + if rule_type == "ipv6": + rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() + return [int(x[0]) for x in rules6] + + if rule_type == "rtbh": + rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/output.py b/flowapp/output.py index b0cc7b4..8322544 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -96,7 +96,7 @@ def log_route(user_id, route_model, rule_type, author): :param rule_type: string :return: None """ - print(rule_type) + rule_type = rule_type.value converter = ROUTE_MODELS[rule_type] task = converter(route_model) log = Log( @@ -119,7 +119,7 @@ def log_withdraw(user_id, task, rule_type, deleted_id, author): log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type.value, rule_id=deleted_id, user_id=user_id, author=author, diff --git a/flowapp/utils.py b/flowapp/utils.py index ad15b61..512f7cc 100644 --- a/flowapp/utils.py +++ b/flowapp/utils.py @@ -1,4 +1,3 @@ -from operator import ge, lt from datetime import datetime, timedelta from flask import flash from flowapp.constants import ( @@ -73,7 +72,7 @@ def parse_api_time(apitime): except ValueError: mytime = False - return False + return mytime def quote_to_ent(comment): From 6fa4bb49ab488804c422d9c02e06bb5aebbea385 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 28 Feb 2025 08:53:40 +0100 Subject: [PATCH 08/36] bugfix - RuleTypes use in api_common and output --- flowapp/output.py | 44 +++++++++++++++++++++++-------------- flowapp/views/api_common.py | 12 +++++----- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/flowapp/output.py b/flowapp/output.py index 8322544..b8774bf 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, asdict from datetime import datetime +from typing import Dict, Union, Callable, Any import requests import pika @@ -11,9 +12,12 @@ from flask import current_app from flowapp import db, messages +from flowapp.constants import RuleTypes from flowapp.models import Log +from flowapp.models import Flowspec4, Flowspec6, RTBH -ROUTE_MODELS = { + +ROUTE_MODELS: Dict[int, Callable[[Any], str]] = { 1: messages.create_rtbh, 4: messages.create_ipv4, 6: messages.create_ipv6, @@ -21,21 +25,21 @@ class RouteSources: - UI = "UI" - API = "API" + UI: str = "UI" + API: str = "API" @dataclass class Route: author: str - source: RouteSources + source: str # Using str instead of RouteSources for flexibility command: str - def __dict__(self): + def __dict__(self) -> Dict[str, str]: return asdict(self) -def announce_route(route: Route): +def announce_route(route: Route) -> None: """ Dispatch route as dict to ExaBGP API API must be set in app config.py @@ -48,7 +52,7 @@ def announce_route(route: Route): announce_to_http(asdict(route)) -def announce_to_http(route): +def announce_to_http(route: Dict[str, str]) -> None: """ Announce route to ExaBGP HTTP API process """ @@ -64,7 +68,7 @@ def announce_to_http(route): current_app.logger.debug(f"Testing: {route}") -def announce_to_rabbitmq(route): +def announce_to_rabbitmq(route: Dict[str, str]) -> None: """ Announce rout to ExaBGP RabbitMQ API process """ @@ -85,24 +89,25 @@ def announce_to_rabbitmq(route): channel.queue_declare(queue=queue) channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) else: - current_app.logger.debug("Testing: {route}") + current_app.logger.debug(f"Testing: {route}") -def log_route(user_id, route_model, rule_type, author): +def log_route(user_id: int, route_model: Union[RTBH, Flowspec4, Flowspec6], rule_type: RuleTypes, author: str) -> None: """ Convert route to EXAFS message and log it to database :param user_id : int curent user - :param route_model: model with route object - :param rule_type: string + :param route_model: model with route object (RTBH, Flowspec4, or Flowspec6) + :param rule_type: RuleTypes enum + :param author: str name of the author :return: None """ - rule_type = rule_type.value - converter = ROUTE_MODELS[rule_type] + rule_type_value = rule_type.value + converter = ROUTE_MODELS[rule_type_value] task = converter(route_model) log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type_value, rule_id=route_model.id, user_id=user_id, author=author, @@ -111,10 +116,15 @@ def log_route(user_id, route_model, rule_type, author): db.session.commit() -def log_withdraw(user_id, task, rule_type, deleted_id, author): +def log_withdraw(user_id: int, task: str, rule_type: RuleTypes, deleted_id: int, author: str) -> None: """ Log the withdraw command to database - :param task: command message + :param user_id: int user ID + :param task: str command message + :param rule_type: RuleTypes enum + :param deleted_id: int ID of deleted rule + :param author: str name of the author + :return: None """ log = Log( time=datetime.now(), diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index d04e40d..b421eea 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -292,7 +292,7 @@ def create_ipv4(current_user): log_route( current_user["id"], model, - RuleTypes.IPv4.value, + RuleTypes.IPv4, f"{current_user['uuid']} / {current_user['org']}", ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -368,7 +368,7 @@ def create_ipv6(current_user): log_route( current_user["id"], model, - RuleTypes.IPv6.value, + RuleTypes.IPv6, f"{current_user['uuid']} / {current_user['org']}", ) @@ -441,7 +441,7 @@ def create_rtbh(current_user): log_route( current_user["id"], model, - RuleTypes.RTBH.value, + RuleTypes.RTBH, f"{current_user['uuid']} / {current_user['org']}", ) @@ -506,7 +506,7 @@ def delete_v4_rule(current_user, rule_id): """ model_name = Flowspec4 route_model = messages.create_ipv4 - return delete_rule(current_user, rule_id, model_name, route_model, 4) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv4) def delete_v6_rule(current_user, rule_id): @@ -516,7 +516,7 @@ def delete_v6_rule(current_user, rule_id): """ model_name = Flowspec6 route_model = messages.create_ipv6 - return delete_rule(current_user, rule_id, model_name, route_model, 6) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv6) def delete_rtbh_rule(current_user, rule_id): @@ -526,7 +526,7 @@ def delete_rtbh_rule(current_user, rule_id): """ model_name = RTBH route_model = messages.create_rtbh - return delete_rule(current_user, rule_id, model_name, route_model, 1) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.RTBH) def delete_rule(current_user, rule_id, model_name, route_model, rule_type): From 24253830d94e817a17c403b31dab198f533e301f Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 28 Feb 2025 11:29:02 +0100 Subject: [PATCH 09/36] refactoring forms.py into directory forms for better readibility and clean code --- flowapp/forms.py | 713 ------------------------------- flowapp/forms/__init__.py | 40 ++ flowapp/forms/api.py | 61 +++ flowapp/forms/base.py | 50 +++ flowapp/forms/choices.py | 81 ++++ flowapp/forms/organization.py | 47 ++ flowapp/forms/rules/__init__.py | 17 + flowapp/forms/rules/base.py | 127 ++++++ flowapp/forms/rules/ipv4.py | 70 +++ flowapp/forms/rules/ipv6.py | 64 +++ flowapp/forms/rules/rtbh.py | 104 +++++ flowapp/forms/rules/whitelist.py | 66 +++ flowapp/forms/user.py | 91 ++++ 13 files changed, 818 insertions(+), 713 deletions(-) delete mode 100644 flowapp/forms.py create mode 100644 flowapp/forms/__init__.py create mode 100644 flowapp/forms/api.py create mode 100644 flowapp/forms/base.py create mode 100644 flowapp/forms/choices.py create mode 100644 flowapp/forms/organization.py create mode 100644 flowapp/forms/rules/__init__.py create mode 100644 flowapp/forms/rules/base.py create mode 100644 flowapp/forms/rules/ipv4.py create mode 100644 flowapp/forms/rules/ipv6.py create mode 100644 flowapp/forms/rules/rtbh.py create mode 100644 flowapp/forms/rules/whitelist.py create mode 100644 flowapp/forms/user.py diff --git a/flowapp/forms.py b/flowapp/forms.py deleted file mode 100644 index 5baca78..0000000 --- a/flowapp/forms.py +++ /dev/null @@ -1,713 +0,0 @@ -import csv -from io import StringIO - - -from flask_wtf import FlaskForm -from wtforms import widgets -from wtforms import ( - BooleanField, - DateTimeField, - HiddenField, - IntegerField, - SelectField, - SelectMultipleField, - StringField, - TextAreaField, -) -from wtforms.validators import ( - ValidationError, - DataRequired, - Email, - InputRequired, - IPAddress, - Length, - NumberRange, - Optional, -) - -from flowapp.constants import ( - IPV4_FRAGMENT, - IPV4_PROTOCOL, - IPV6_NEXT_HEADER, - TCP_FLAGS, - FORM_TIME_PATTERN, -) -from flowapp.validators import ( - IPAddressValidator, - IPv4Address, - IPv6Address, - NetRangeString, - NetworkValidator, - PortString, - address_in_range, - address_with_mask, - network_in_range, - whole_world_range, -) - -from flowapp.utils import parse_api_time - - -class MultiFormatDateTimeLocalField(DateTimeField): - """ - Same as :class:`~wtforms.fields.DateTimeField`, but represents an - ````. - - Custom implementation uses default HTML5 format for parsing the field. - It's possible to use multiple formats - used in API. - - """ - - widget = widgets.DateTimeLocalInput() - - def __init__(self, *args, **kwargs): - kwargs.setdefault("format", "%Y-%m-%dT%H:%M") - self.unlimited = kwargs.pop("unlimited", False) - self.pref_format = None - super().__init__(*args, **kwargs) - - def process_formdata(self, valuelist): - if not valuelist or (len(valuelist) == 1 and not valuelist[0]): - return None - - # with unlimited field we do not need to parse the empty value - if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: - self.data = None - return None - - date_str = " ".join((str(val) for val in valuelist)) - - try: - result, pref_format = parse_api_time(date_str) - except TypeError: - raise ValueError(self.gettext("Not a valid datetime value.")) - - if result: - self.data = result - self.pref_format = pref_format - else: - self.data = None - self.pref_format = None - raise ValueError(self.gettext("Not a valid datetime value.")) - - -class UserForm(FlaskForm): - """ - User Form object - used in Admin - """ - - uuid = StringField( - "Unique User ID", - validators=[ - InputRequired("Please provide UUID"), - Email("Please provide valid email"), - ], - ) - - email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) - - comment = StringField("Notice", validators=[Optional()]) - - name = StringField("Name", validators=[Optional()]) - - phone = StringField("Contact phone", validators=[Optional()]) - - role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) - - org_ids = SelectMultipleField( - "Organization", - coerce=int, - validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], - ) - - -class BulkUserForm(FlaskForm): - """ - Bulk User Form object - used in Admin - """ - - users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) - - def __init__(self, *args, **kwargs): - super(BulkUserForm, self).__init__(*args, **kwargs) - self.roles = None - self.organizations = None - self.uuids = None - - # Custom validator for CSV data - def validate_users(self, field): - csv_data = field.data - - # Parse CSV data - csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") - - # List to keep track of failed validation rows - errors = 0 - for row_num, row in enumerate(csv_reader, start=1): - try: - # check if the user not already exists - if row["uuid-eppn"] in self.uuids: - field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") - errors += 1 - - # Check if role exists in the database - role_id = int(row["role"]) # Convert role field to integer - if role_id not in self.roles: - field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") - errors += 1 - - # Check if organization exists in the database - org_id = int(row["organizace"]) # Convert organization field to integer - if org_id not in self.organizations: - field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") - errors += 1 - - except (KeyError, ValueError) as e: - field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") - - if errors > 0: - # Raise validation error if any invalid rows found - raise ValidationError("Invalid CSV Data - check the errors above.") - - -class ApiKeyForm(FlaskForm): - """ - ApiKey for User - Each key / machine pair is unique - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class MachineApiKeyForm(FlaskForm): - """ - ApiKey for Machines - Each key / machine pair is unique - Only Admin can create new these keys - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class OrganizationForm(FlaskForm): - """ - Organization form object - used in Admin - """ - - name = StringField("Organization name", validators=[Optional(), Length(max=150)]) - - limit_flowspec4 = IntegerField( - "Maximum number of IPv4 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_flowspec6 = IntegerField( - "Maximum number of IPv6 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_rtbh = IntegerField( - "Maximum number of RTBH rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - arange = TextAreaField( - "Organization Adress Range - one range per row", - validators=[Optional(), NetRangeString()], - ) - - -class ActionForm(FlaskForm): - """ - Action form object - used in Admin - """ - - name = StringField("Action short name", validators=[Length(max=150)]) - - command = StringField("ExaBGP command", validators=[Length(max=150)]) - - description = StringField("Action description") - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - -class ASPathForm(FlaskForm): - """ - AS Path form object - used in Admin - """ - - prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) - - as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) - - -class CommunityForm(FlaskForm): - """ - Community form object - used in Admin - """ - - name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) - - comm = StringField("Community value", validators=[Length(max=2046)]) - - larcomm = StringField("Large community value", validators=[Length(max=2046)]) - - extcomm = StringField("Extended community value", validators=[Length(max=2046)]) - - description = StringField("Community description", validators=[Length(max=255)]) - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - as_path = BooleanField("add AS-path (checked = true)") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - if not self.comm.data and not self.extcomm.data and not self.larcomm.data: - err_message = "At last one of those values could not be empty" - self.comm.errors.append(err_message) - self.larcomm.errors.append(err_message) - self.extcomm.errors.append(err_message) - result = False - - return result - - -class RTBHForm(FlaskForm): - """ - RoadToBlackHole rule form - """ - - def __init__(self, *args, **kwargs): - super(RTBHForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - ipv4 = StringField( - "IPv4 address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - ipv4_mask = IntegerField( - "IPv4 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), - ], - ) - - ipv6 = StringField( - "IPv6 address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - ipv6_mask = IntegerField( - "IPv6 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), - ], - ) - - community = SelectField( - "Community", - coerce=int, - validators=[ - DataRequired(message="Please select a community for the rule."), - ], - ) - - expires = MultiFormatDateTimeLocalField( - "Expires", - format=FORM_TIME_PATTERN, - validators=[DataRequired(), InputRequired()], - ) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - # ipv4 and ipv6 are mutually exclusive - # if both are set, validation fails - # if none is set, validation fails - # if one is set, validation passes - if self.ipv4.data and self.ipv6.data: - self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - result = False - - if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): - self.ipv4.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) - ) - result = False - - if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): - self.ipv6.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) - ) - result = False - - ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) - ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) - - if not (ipv6_in_range or ipv4_in_range): - self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - result = False - - return result - - -class IPForm(FlaskForm): - """ - Base class for IPv4 and IPv6 rules - """ - - def __init__(self, *args, **kwargs): - super(IPForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = None - source = None - source_mask = None - dest = None - dest_mask = None - flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) - - source_port = StringField( - "Source port(s) - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - dest_port = StringField( - "Destination port(s) - ; separated", - validators=[Optional(), Length(max=255), PortString()], - ) - - packet_len = StringField( - "Packet length - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - action = SelectField( - "Action", - coerce=int, - validators=[DataRequired(message="Please select an action for the rule.")], - ) - - expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - - result = True - if not FlaskForm.validate(self): - result = False - - source = self.validate_source_address() - dest = self.validate_dest_address() - ranges = self.validate_address_ranges() - ips = self.validate_ipv_specific() - - return result and source and dest and ranges and ips - - def validate_source_address(self): - """ - validate source address, set error message if validation fails - :return: boolean validation result - """ - if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): - self.source.errors.append( - "This is not valid combination of address {} and mask {}.".format( - self.source.data, self.source_mask.data - ) - ) - return False - - return True - - def validate_dest_address(self): - """ - validate dest address, set error message if validation fails - :return: boolean validation result - """ - if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): - self.dest.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) - ) - return False - - return True - - def validate_address_ranges(self): - """ - validates if the address of source is in the user range - if the source and dest address are empty, check if the user - is member of whole world organization - :return: boolean validation result - """ - if not (self.source.data or self.dest.data): - whole_world_member = whole_world_range(self.net_ranges, self.zero_address) - if not whole_world_member: - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - else: - source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) - dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) - if not (source_in_range or dest_in_range): - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - - return True - - def validate_ipv_specific(self): - """ - abstract method must be implemented in the subclass - """ - pass - - -class IPv4Form(IPForm): - """ - IPv4 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv4Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "0.0.0.0" - source = StringField( - "Source address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - source_mask = IntegerField( - "Source mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - dest_mask = IntegerField( - "Destination mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - protocol = SelectField( - "Protocol", - choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], - validators=[DataRequired()], - ) - - fragment = SelectMultipleField( - "Fragment", - choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], - validators=[Optional()], - ) - - def validate_ipv_specific(self): - """ - validate protocol and flags, set error message if validation fails - :return: boolean validation result - """ - - if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": - self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) - return False - return True - - -class IPv6Form(IPForm): - """ - IPv6 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv6Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "::" - source = StringField( - "Source address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - source_mask = IntegerField( - "Source prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - dest_mask = IntegerField( - "Destination prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - next_header = SelectField( - "Next Header", - choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], - validators=[DataRequired()], - ) - - def validate_ipv_specific(self): - """ - validate next header and flags, set error message if validation fails - :return: boolean validation result - """ - if len(self.flags.data) > 0 and self.next_header.data != "tcp": - self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) - return False - - return True - - -class WhitelistForm(FlaskForm): - """ - Whitelist form object - Used for creating and editing whitelist entries - Supports both IPv4 and IPv6 addresses - """ - - def __init__(self, *args, **kwargs): - super(WhitelistForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - ip = StringField( - "IP address", - validators=[ - DataRequired(message="Please provide an IP address"), - IPAddressValidator(message="Please provide a valid IP address: {}"), - NetworkValidator(mask_field_name="mask"), - ], - ) - - mask = IntegerField( - "Network mask (bits)", - validators=[ - DataRequired(message="Please provide a network mask"), - ], - ) - - comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Expires", - format=FORM_TIME_PATTERN, - validators=[DataRequired(), InputRequired()], - ) - - def validate(self): - """ - Custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - # Validate IP is in organization range - if self.ip.data and self.mask.data and self.net_ranges: - ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) - if not ip_in_range: - self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) - result = False - - return result diff --git a/flowapp/forms/__init__.py b/flowapp/forms/__init__.py new file mode 100644 index 0000000..a9d7d88 --- /dev/null +++ b/flowapp/forms/__init__.py @@ -0,0 +1,40 @@ +""" +Forms module for the flowapp application. +This file imports and re-exports all form classes from the module. +""" + +# Base field +from .base import MultiFormatDateTimeLocalField + +# User forms +from .user import UserForm, BulkUserForm + +# API key forms +from .api import ApiKeyForm, MachineApiKeyForm + +# Organization forms +from .organization import OrganizationForm + +# Action, ASPath, and Community forms +from .choices import ActionForm, ASPathForm, CommunityForm + +# Rule forms +from .rules import IPForm, IPv4Form, IPv6Form, RTBHForm, WhitelistForm + + +__all__ = [ + "MultiFormatDateTimeLocalField", + "UserForm", + "BulkUserForm", + "ApiKeyForm", + "MachineApiKeyForm", + "OrganizationForm", + "ActionForm", + "ASPathForm", + "CommunityForm", + "RTBHForm", + "IPForm", + "IPv4Form", + "IPv6Form", + "WhitelistForm", +] diff --git a/flowapp/forms/api.py b/flowapp/forms/api.py new file mode 100644 index 0000000..a1d9d20 --- /dev/null +++ b/flowapp/forms/api.py @@ -0,0 +1,61 @@ +""" +API key forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, TextAreaField, BooleanField, HiddenField +from wtforms.validators import DataRequired, IPAddress, Optional, Length + +from .base import MultiFormatDateTimeLocalField +from ..constants import FORM_TIME_PATTERN + + +class ApiKeyForm(FlaskForm): + """ + ApiKey for User + Each key / machine pair is unique + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=False) + + key = HiddenField("GeneratedKey") + + +class MachineApiKeyForm(FlaskForm): + """ + ApiKey for Machines + Each key / machine pair is unique + Only Admin can create new these keys + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=False) + + key = HiddenField("GeneratedKey") diff --git a/flowapp/forms/base.py b/flowapp/forms/base.py new file mode 100644 index 0000000..f99db23 --- /dev/null +++ b/flowapp/forms/base.py @@ -0,0 +1,50 @@ +""" +Base form fields for the flowapp application. +""" + +from wtforms import widgets +from wtforms.fields import DateTimeField +from ..utils import parse_api_time + + +class MultiFormatDateTimeLocalField(DateTimeField): + """ + Same as :class:`~wtforms.fields.DateTimeField`, but represents an + ````. + + Custom implementation uses default HTML5 format for parsing the field. + It's possible to use multiple formats - used in API. + + """ + + widget = widgets.DateTimeLocalInput() + + def __init__(self, *args, **kwargs): + kwargs.setdefault("format", "%Y-%m-%dT%H:%M") + self.unlimited = kwargs.pop("unlimited", False) + self.pref_format = None + super().__init__(*args, **kwargs) + + def process_formdata(self, valuelist): + if not valuelist or (len(valuelist) == 1 and not valuelist[0]): + return None + + # with unlimited field we do not need to parse the empty value + if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: + self.data = None + return None + + date_str = " ".join((str(val) for val in valuelist)) + + try: + result, pref_format = parse_api_time(date_str) + except TypeError: + raise ValueError(self.gettext("Not a valid datetime value.")) + + if result: + self.data = result + self.pref_format = pref_format + else: + self.data = None + self.pref_format = None + raise ValueError(self.gettext("Not a valid datetime value.")) diff --git a/flowapp/forms/choices.py b/flowapp/forms/choices.py new file mode 100644 index 0000000..424b6b8 --- /dev/null +++ b/flowapp/forms/choices.py @@ -0,0 +1,81 @@ +""" +Action, ASPath, and Community forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, SelectField, BooleanField +from wtforms.validators import Length, DataRequired + + +class ActionForm(FlaskForm): + """ + Action form object + used in Admin + """ + + name = StringField("Action short name", validators=[Length(max=150)]) + + command = StringField("ExaBGP command", validators=[Length(max=150)]) + + description = StringField("Action description") + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + +class ASPathForm(FlaskForm): + """ + AS Path form object + used in Admin + """ + + prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) + + as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) + + +class CommunityForm(FlaskForm): + """ + Community form object + used in Admin + """ + + name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) + + comm = StringField("Community value", validators=[Length(max=2046)]) + + larcomm = StringField("Large community value", validators=[Length(max=2046)]) + + extcomm = StringField("Extended community value", validators=[Length(max=2046)]) + + description = StringField("Community description", validators=[Length(max=255)]) + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + as_path = BooleanField("add AS-path (checked = true)") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + if not self.comm.data and not self.extcomm.data and not self.larcomm.data: + err_message = "At last one of those values could not be empty" + self.comm.errors.append(err_message) + self.larcomm.errors.append(err_message) + self.extcomm.errors.append(err_message) + result = False + + return result diff --git a/flowapp/forms/organization.py b/flowapp/forms/organization.py new file mode 100644 index 0000000..37594ab --- /dev/null +++ b/flowapp/forms/organization.py @@ -0,0 +1,47 @@ +""" +Organization form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, Length, NumberRange + +from ..validators import NetRangeString + + +class OrganizationForm(FlaskForm): + """ + Organization form object + used in Admin + """ + + name = StringField("Organization name", validators=[Optional(), Length(max=150)]) + + limit_flowspec4 = IntegerField( + "Maximum number of IPv4 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_flowspec6 = IntegerField( + "Maximum number of IPv6 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_rtbh = IntegerField( + "Maximum number of RTBH rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + arange = TextAreaField( + "Organization Adress Range - one range per row", + validators=[Optional(), NetRangeString()], + ) diff --git a/flowapp/forms/rules/__init__.py b/flowapp/forms/rules/__init__.py new file mode 100644 index 0000000..9796e8c --- /dev/null +++ b/flowapp/forms/rules/__init__.py @@ -0,0 +1,17 @@ +""" +Rule forms for the flowapp application. +""" + +from .base import IPForm +from .ipv4 import IPv4Form +from .ipv6 import IPv6Form +from .rtbh import RTBHForm +from .whitelist import WhitelistForm + +__all__ = [ + "IPForm", + "IPv4Form", + "IPv6Form", + "RTBHForm", + "WhitelistForm", +] diff --git a/flowapp/forms/rules/base.py b/flowapp/forms/rules/base.py new file mode 100644 index 0000000..0e67776 --- /dev/null +++ b/flowapp/forms/rules/base.py @@ -0,0 +1,127 @@ +""" +Base rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import SelectMultipleField, StringField, SelectField, TextAreaField +from wtforms.validators import Optional, Length, DataRequired, InputRequired + +from ..base import MultiFormatDateTimeLocalField +from ...constants import TCP_FLAGS +from ...validators import PortString, address_with_mask, network_in_range, whole_world_range + + +class IPForm(FlaskForm): + """ + Base class for IPv4 and IPv6 rules + """ + + def __init__(self, *args, **kwargs): + super(IPForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = None + source = None + source_mask = None + dest = None + dest_mask = None + flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) + + source_port = StringField( + "Source port(s) - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + dest_port = StringField( + "Destination port(s) - ; separated", + validators=[Optional(), Length(max=255), PortString()], + ) + + packet_len = StringField( + "Packet length - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + action = SelectField( + "Action", + coerce=int, + validators=[DataRequired(message="Please select an action for the rule.")], + ) + + expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + + result = True + if not FlaskForm.validate(self): + result = False + + source = self.validate_source_address() + dest = self.validate_dest_address() + ranges = self.validate_address_ranges() + ips = self.validate_ipv_specific() + + return result and source and dest and ranges and ips + + def validate_source_address(self): + """ + validate source address, set error message if validation fails + :return: boolean validation result + """ + if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): + self.source.errors.append( + "This is not valid combination of address {} and mask {}.".format( + self.source.data, self.source_mask.data + ) + ) + return False + + return True + + def validate_dest_address(self): + """ + validate dest address, set error message if validation fails + :return: boolean validation result + """ + if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): + self.dest.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) + ) + return False + + return True + + def validate_address_ranges(self): + """ + validates if the address of source is in the user range + if the source and dest address are empty, check if the user + is member of whole world organization + :return: boolean validation result + """ + if not (self.source.data or self.dest.data): + whole_world_member = whole_world_range(self.net_ranges, self.zero_address) + if not whole_world_member: + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + else: + source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) + dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) + if not (source_in_range or dest_in_range): + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + + return True + + def validate_ipv_specific(self): + """ + abstract method must be implemented in the subclass + """ + pass diff --git a/flowapp/forms/rules/ipv4.py b/flowapp/forms/rules/ipv4.py new file mode 100644 index 0000000..07f27a8 --- /dev/null +++ b/flowapp/forms/rules/ipv4.py @@ -0,0 +1,70 @@ +""" +IPv4 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField, SelectMultipleField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV4_PROTOCOL, IPV4_FRAGMENT +from ...validators import IPv4Address +from .base import IPForm + + +class IPv4Form(IPForm): + """ + IPv4 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv4Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "0.0.0.0" + source = StringField( + "Source address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + source_mask = IntegerField( + "Source mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + dest_mask = IntegerField( + "Destination mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + protocol = SelectField( + "Protocol", + choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], + validators=[DataRequired()], + ) + + fragment = SelectMultipleField( + "Fragment", + choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], + validators=[Optional()], + ) + + def validate_ipv_specific(self): + """ + validate protocol and flags, set error message if validation fails + :return: boolean validation result + """ + + if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": + self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) + return False + return True diff --git a/flowapp/forms/rules/ipv6.py b/flowapp/forms/rules/ipv6.py new file mode 100644 index 0000000..aabd1b2 --- /dev/null +++ b/flowapp/forms/rules/ipv6.py @@ -0,0 +1,64 @@ +""" +IPv6 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV6_NEXT_HEADER +from ...validators import IPv6Address +from .base import IPForm + + +class IPv6Form(IPForm): + """ + IPv6 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv6Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "::" + source = StringField( + "Source address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + source_mask = IntegerField( + "Source prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + dest_mask = IntegerField( + "Destination prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + next_header = SelectField( + "Next Header", + choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], + validators=[DataRequired()], + ) + + def validate_ipv_specific(self): + """ + validate next header and flags, set error message if validation fails + :return: boolean validation result + """ + if len(self.flags.data) > 0 and self.next_header.data != "tcp": + self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) + return False + + return True diff --git a/flowapp/forms/rules/rtbh.py b/flowapp/forms/rules/rtbh.py new file mode 100644 index 0000000..cef7948 --- /dev/null +++ b/flowapp/forms/rules/rtbh.py @@ -0,0 +1,104 @@ +""" +RTBH rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, SelectField, TextAreaField +from wtforms.validators import Optional, NumberRange, DataRequired, InputRequired + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPv6Address, address_with_mask, address_in_range, IPv4Address +from ..base import MultiFormatDateTimeLocalField + + +class RTBHForm(FlaskForm): + """ + RoadToBlackHole rule form + """ + + def __init__(self, *args, **kwargs): + super(RTBHForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ipv4 = StringField( + "IPv4 address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + ipv4_mask = IntegerField( + "IPv4 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), + ], + ) + + ipv6 = StringField( + "IPv6 address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + ipv6_mask = IntegerField( + "IPv6 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), + ], + ) + + community = SelectField( + "Community", + coerce=int, + validators=[ + DataRequired(message="Please select a community for the rule."), + ], + ) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # ipv4 and ipv6 are mutually exclusive + # if both are set, validation fails + # if none is set, validation fails + # if one is set, validation passes + if self.ipv4.data and self.ipv6.data: + self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + result = False + + if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): + self.ipv4.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) + ) + result = False + + if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): + self.ipv6.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) + ) + result = False + + ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) + ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) + + if not (ipv6_in_range or ipv4_in_range): + self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/rules/whitelist.py b/flowapp/forms/rules/whitelist.py new file mode 100644 index 0000000..8aa88b4 --- /dev/null +++ b/flowapp/forms/rules/whitelist.py @@ -0,0 +1,66 @@ +""" +Whitelist form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, DataRequired, InputRequired, Length + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPAddressValidator, NetworkValidator, network_in_range +from ..base import MultiFormatDateTimeLocalField + + +class WhitelistForm(FlaskForm): + """ + Whitelist form object + Used for creating and editing whitelist entries + Supports both IPv4 and IPv6 addresses + """ + + def __init__(self, *args, **kwargs): + super(WhitelistForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ip = StringField( + "IP address", + validators=[ + DataRequired(message="Please provide an IP address"), + IPAddressValidator(message="Please provide a valid IP address: {}"), + NetworkValidator(mask_field_name="mask"), + ], + ) + + mask = IntegerField( + "Network mask (bits)", + validators=[ + DataRequired(message="Please provide a network mask"), + ], + ) + + comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + def validate(self): + """ + Custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # Validate IP is in organization range + if self.ip.data and self.mask.data and self.net_ranges: + ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) + if not ip_in_range: + self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/user.py b/flowapp/forms/user.py new file mode 100644 index 0000000..5944afd --- /dev/null +++ b/flowapp/forms/user.py @@ -0,0 +1,91 @@ +""" +User-related forms for the flowapp application. +""" + +import csv +from io import StringIO + +from flask_wtf import FlaskForm +from wtforms import StringField, TextAreaField, SelectMultipleField +from wtforms.validators import ValidationError, DataRequired, Email, InputRequired, Optional + + +class UserForm(FlaskForm): + """ + User Form object + used in Admin + """ + + uuid = StringField( + "Unique User ID", + validators=[ + InputRequired("Please provide UUID"), + Email("Please provide valid email"), + ], + ) + + email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) + + comment = StringField("Notice", validators=[Optional()]) + + name = StringField("Name", validators=[Optional()]) + + phone = StringField("Contact phone", validators=[Optional()]) + + role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) + + org_ids = SelectMultipleField( + "Organization", + coerce=int, + validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], + ) + + +class BulkUserForm(FlaskForm): + """ + Bulk User Form object + used in Admin + """ + + users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) + + def __init__(self, *args, **kwargs): + super(BulkUserForm, self).__init__(*args, **kwargs) + self.roles = None + self.organizations = None + self.uuids = None + + # Custom validator for CSV data + def validate_users(self, field): + csv_data = field.data + + # Parse CSV data + csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") + + # List to keep track of failed validation rows + errors = 0 + for row_num, row in enumerate(csv_reader, start=1): + try: + # check if the user not already exists + if row["uuid-eppn"] in self.uuids: + field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") + errors += 1 + + # Check if role exists in the database + role_id = int(row["role"]) # Convert role field to integer + if role_id not in self.roles: + field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") + errors += 1 + + # Check if organization exists in the database + org_id = int(row["organizace"]) # Convert organization field to integer + if org_id not in self.organizations: + field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") + errors += 1 + + except (KeyError, ValueError) as e: + field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") + + if errors > 0: + # Raise validation error if any invalid rows found + raise ValidationError("Invalid CSV Data - check the errors above.") From 500d60a914179aba2f57737d4f1d6c30b249445e Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 28 Feb 2025 13:16:49 +0100 Subject: [PATCH 10/36] refactoring of app factory in flowap/__init__.py module. --- flowapp/__init__.py | 174 +++-------------------- flowapp/utils/__init__.py | 46 +++++++ flowapp/utils/app_factory.py | 205 ++++++++++++++++++++++++++++ flowapp/{utils.py => utils/base.py} | 0 4 files changed, 269 insertions(+), 156 deletions(-) create mode 100644 flowapp/utils/__init__.py create mode 100644 flowapp/utils/app_factory.py rename flowapp/{utils.py => utils/base.py} (100%) diff --git a/flowapp/__init__.py b/flowapp/__init__.py index f029c37..b201f25 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -1,10 +1,6 @@ # -*- coding: utf-8 -*- -import babel -import logging -from loguru import logger +from flask import Flask, redirect, render_template, session, url_for -from flask import Flask, redirect, render_template, session, url_for, request -from flask.logging import default_handler from flask_sso import SSO from flask_sqlalchemy import SQLAlchemy from flask_wtf.csrf import CSRFProtect @@ -26,13 +22,6 @@ swagger = Swagger(template_file="static/swagger.yml") -class InterceptHandler(logging.Handler): - - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) - logger_opt.log(record.levelname, record.getMessage()) - - def create_app(config_object=None): app = Flask(__name__) @@ -64,76 +53,28 @@ def create_app(config_object=None): if app.config.get("BEHIND_PROXY", False): app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) - from flowapp import models, constants, validators - from .views.admin import admin - from .views.rules import rules - from .views.api_v1 import api as api_v1 - from .views.api_v2 import api as api_v2 - from .views.api_v3 import api as api_v3 - from .views.api_keys import api_keys + from flowapp import models, constants from .auth import auth_required - from .views.dashboard import dashboard - - # no need for csrf on api because we use JWT - csrf.exempt(api_v1) - csrf.exempt(api_v2) - csrf.exempt(api_v3) - - app.register_blueprint(admin, url_prefix="/admin") - app.register_blueprint(rules, url_prefix="/rules") - app.register_blueprint(api_keys, url_prefix="/api_keys") - app.register_blueprint(api_v1, url_prefix="/api/v1") - app.register_blueprint(api_v2, url_prefix="/api/v2") - app.register_blueprint(api_v3, url_prefix="/api/v3") - app.register_blueprint(dashboard, url_prefix="/dashboard") - - # register loguru as handler - app.logger.removeHandler(default_handler) - app.logger.addHandler(InterceptHandler()) - - @ext.login_handler - def login(user_info): - try: - uuid = user_info.get("eppn") - except KeyError: - uuid = False - return render_template("errors/401.html") - return _handle_login(uuid) + # Register blueprints + from .utils import register_blueprints - @app.route("/logout") - def logout(): - session["user_uuid"] = False - session["user_id"] = False - session.clear() - return redirect(app.config.get("LOGOUT_URL")) + register_blueprints(app, csrf) - @app.route("/ext-login") - def ext_login(): - header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") - if header_name not in request.headers: - return render_template("errors/401.html") + # configure logging + from .utils import configure_logging - uuid = request.headers.get(header_name) - if not uuid: - return render_template("errors/401.html") + configure_logging(app) - return _handle_login(uuid) + # register error handlers + from .utils import register_error_handlers - @app.route("/local-login") - def local_login(): - print("Local login started") - if not app.config.get("LOCAL_AUTH", False): - print("Local auth not enabled") - return render_template("errors/401.html") + register_error_handlers(app) - uuid = app.config.get("LOCAL_USER_UUID", False) - if not uuid: - print("Local user not set") - return render_template("errors/401.html") + # register auth handlers + from .utils import register_auth_handlers - print(f"Local login with {uuid}") - return _handle_login(uuid) + register_auth_handlers(app, ext) @app.route("/") @auth_required @@ -192,89 +133,10 @@ def select_org(org_id=None): def shutdown_session(exception=None): db.session.remove() - # HTTP error handling - @app.errorhandler(404) - def not_found(error): - return render_template("errors/404.html"), 404 - - @app.errorhandler(500) - def internal_error(exception): - app.logger.exception(exception) - return render_template("errors/500.html"), 500 - - @app.context_processor - def utility_processor(): - def editable_rule(rule): - if rule: - validators.editable_range(rule, models.get_user_nets(session["user_id"])) - return True - return False - - return dict(editable_rule=editable_rule) - - @app.context_processor - def inject_main_menu(): - """ - inject main menu config to templates - used in default template to create main menu - """ - return {"main_menu": app.config.get("MAIN_MENU")} - - @app.context_processor - def inject_dashboard(): - """ - inject dashboard config to templates - used in submenu dashboard to create dashboard tables - """ - return {"dashboard": app.config.get("DASHBOARD")} - - @app.template_filter("strftime") - def format_datetime(value): - if value is None: - return app.config.get("MISSING_DATETIME_MESSAGE", "Never") - - format = "y/MM/dd HH:mm" - return babel.dates.format_datetime(value, format) - - @app.template_filter("unlimited") - def unlimited_filter(value): - return "unlimited" if value == 0 else value - - def _handle_login(uuid: str): - """ - handles rest of login process - """ - multiple_orgs = False - try: - user, multiple_orgs = _register_user_to_session(uuid) - except AttributeError as e: - app.logger.exception(e) - return render_template("errors/401.html") - - if multiple_orgs: - return redirect(url_for("select_org", org_id=None)) - - # set user org to session - user_org = user.organization.first() - session["user_org"] = user_org.name - session["user_org_id"] = user_org.id + # register context processors and template filters + from .utils import register_context_processors, register_template_filters - return redirect("/") - - def _register_user_to_session(uuid: str): - print(f"Registering user {uuid} to session") - user = db.session.query(models.User).filter_by(uuid=uuid).first() - print(f"Got user {user} from DB") - session["user_uuid"] = user.uuid - session["user_email"] = user.uuid - session["user_name"] = user.name - session["user_id"] = user.id - session["user_roles"] = [role.name for role in user.role.all()] - session["user_role_ids"] = [role.id for role in user.role.all()] - roles = [i > 1 for i in session["user_role_ids"]] - session["can_edit"] = True if all(roles) and roles else [] - # check if user has multiple organizations and return True if so - print(f"DEBUG SESSION {session}") - return user, len(user.organization.all()) > 1 + register_context_processors(app) + register_template_filters(app) return app diff --git a/flowapp/utils/__init__.py b/flowapp/utils/__init__.py new file mode 100644 index 0000000..99ca325 --- /dev/null +++ b/flowapp/utils/__init__.py @@ -0,0 +1,46 @@ +from .base import ( + other_rtypes, + output_date_format, + parse_api_time, + quote_to_ent, + webpicker_to_datetime, + datetime_to_webpicker, + get_state_by_time, + round_to_ten_minutes, + flash_errors, + active_css_rstate, + get_comp_func, +) + +from .app_factory import ( + # configure_app, + configure_logging, + register_blueprints, + register_error_handlers, + register_context_processors, + register_template_filters, + register_auth_handlers, + # register_org_routes, +) + +__all__ = [ + "other_rtypes", + "output_date_format", + "parse_api_time", + "quote_to_ent", + "webpicker_to_datetime", + "datetime_to_webpicker", + "get_state_by_time", + "round_to_ten_minutes", + "flash_errors", + "active_css_rstate", + "get_comp_func", + # "configure_app", + "configure_logging", + "register_blueprints", + "register_error_handlers", + "register_context_processors", + "register_template_filters", + "register_auth_handlers", + # "register_org_routes", +] diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py new file mode 100644 index 0000000..b53b23b --- /dev/null +++ b/flowapp/utils/app_factory.py @@ -0,0 +1,205 @@ +import logging +import babel +from flask import redirect, render_template, request, session, url_for +from flask.logging import default_handler +from loguru import logger + + +def register_blueprints(app, csrf=None): + """Register Flask blueprints.""" + from flowapp.views.admin import admin + from flowapp.views.rules import rules + from flowapp.views.api_v1 import api as api_v1 + from flowapp.views.api_v2 import api as api_v2 + from flowapp.views.api_v3 import api as api_v3 + from flowapp.views.api_keys import api_keys + from flowapp.views.dashboard import dashboard + from flowapp.views.whitelist import whitelist + + # Configure CSRF exemption for API routes + if csrf: + csrf.exempt(api_v1) + csrf.exempt(api_v2) + csrf.exempt(api_v3) + + # Register blueprints with URL prefixes + app.register_blueprint(admin, url_prefix="/admin") + app.register_blueprint(rules, url_prefix="/rules") + app.register_blueprint(api_keys, url_prefix="/api_keys") + app.register_blueprint(api_v1, url_prefix="/api/v1") + app.register_blueprint(api_v2, url_prefix="/api/v2") + app.register_blueprint(api_v3, url_prefix="/api/v3") + app.register_blueprint(dashboard, url_prefix="/dashboard") + app.register_blueprint(whitelist, url_prefix="/whitelist") + + return app + + +class InterceptHandler(logging.Handler): + + def emit(self, record): + logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) + logger_opt.log(record.levelname, record.getMessage()) + + +def configure_logging(app): + """Configure logging for the application.""" + # register loguru as handler + app.logger.removeHandler(default_handler) + app.logger.addHandler(InterceptHandler()) + + return app + + +def register_error_handlers(app): + """Register error handlers.""" + + @app.errorhandler(404) + def not_found(error): + return render_template("errors/404.html"), 404 + + @app.errorhandler(500) + def internal_error(exception): + app.logger.exception(exception) + return render_template("errors/500.html"), 500 + + return app + + +def register_context_processors(app): + """Register template context processors.""" + + @app.context_processor + def utility_processor(): + def editable_rule(rule): + if rule: + from flowapp.validators import editable_range + from flowapp.models import get_user_nets + + editable_range(rule, get_user_nets(session["user_id"])) + return True + return False + + return dict(editable_rule=editable_rule) + + @app.context_processor + def inject_main_menu(): + """Inject main menu config to templates.""" + return {"main_menu": app.config.get("MAIN_MENU")} + + @app.context_processor + def inject_dashboard(): + """Inject dashboard config to templates.""" + return {"dashboard": app.config.get("DASHBOARD")} + + return app + + +def register_template_filters(app): + """Register custom template filters.""" + + @app.template_filter("strftime") + def format_datetime(value): + if value is None: + return app.config.get("MISSING_DATETIME_MESSAGE", "Never") + + format = "y/MM/dd HH:mm" + return babel.dates.format_datetime(value, format) + + @app.template_filter("unlimited") + def unlimited_filter(value): + return "unlimited" if value == 0 else value + + return app + + +def register_auth_handlers(app, ext): + """Register authentication handlers.""" + + @ext.login_handler + def login(user_info): + try: + uuid = user_info.get("eppn") + except KeyError: + uuid = False + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/logout") + def logout(): + session["user_uuid"] = False + session["user_id"] = False + session.clear() + return redirect(app.config.get("LOGOUT_URL")) + + @app.route("/ext-login") + def ext_login(): + header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") + if header_name not in request.headers: + return render_template("errors/401.html") + + uuid = request.headers.get(header_name) + if not uuid: + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/local-login") + def local_login(): + print("Local login started") + if not app.config.get("LOCAL_AUTH", False): + print("Local auth not enabled") + return render_template("errors/401.html") + + uuid = app.config.get("LOCAL_USER_UUID", False) + if not uuid: + print("Local user not set") + return render_template("errors/401.html") + + print(f"Local login with {uuid}") + return _handle_login(uuid, app) + + return app + + +def _handle_login(uuid, app): + """Handle login process for all authentication methods.""" + from flowapp import db + + multiple_orgs = False + try: + user, multiple_orgs = _register_user_to_session(uuid, db) + except AttributeError as e: + app.logger.exception(e) + return render_template("errors/401.html") + + if multiple_orgs: + return redirect(url_for("select_org", org_id=None)) + + # set user org to session + user_org = user.organization.first() + session["user_org"] = user_org.name + session["user_org_id"] = user_org.id + + return redirect("/") + + +def _register_user_to_session(uuid, db): + """Register user information to session.""" + from flowapp.models import User + + print(f"Registering user {uuid} to session") + user = db.session.query(User).filter_by(uuid=uuid).first() + print(f"Got user {user} from DB") + session["user_uuid"] = user.uuid + session["user_email"] = user.uuid + session["user_name"] = user.name + session["user_id"] = user.id + session["user_roles"] = [role.name for role in user.role.all()] + session["user_role_ids"] = [role.id for role in user.role.all()] + roles = [i > 1 for i in session["user_role_ids"]] + session["can_edit"] = True if all(roles) and roles else [] + # check if user has multiple organizations and return True if so + print(f"DEBUG SESSION {session}") + return user, len(user.organization.all()) > 1 diff --git a/flowapp/utils.py b/flowapp/utils/base.py similarity index 100% rename from flowapp/utils.py rename to flowapp/utils/base.py From 9b379135e68d5622a46e91a3bea82fa5c51388be Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 10 Mar 2025 09:48:17 +0100 Subject: [PATCH 11/36] Refactor rule creation logic to use rule_service Rule service is now used both in rules and api views Replaced direct database operations with calls to rule_service.create_or_update_*_rule for IPv4, IPv6, and RTBH rules Streamlined route announcements and logging by delegating responsibility to the service layer Improved maintainability and modularity by centralizing rule handling logic in rule_service --- flowapp/services/__init__.py | 11 ++ flowapp/services/rule_service.py | 219 +++++++++++++++++++++++++++++++ flowapp/views/api_common.py | 156 ++++------------------ flowapp/views/rules.py | 157 ++++------------------ 4 files changed, 277 insertions(+), 266 deletions(-) create mode 100644 flowapp/services/__init__.py create mode 100644 flowapp/services/rule_service.py diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py new file mode 100644 index 0000000..b8fc6e4 --- /dev/null +++ b/flowapp/services/__init__.py @@ -0,0 +1,11 @@ +from .rule_service import ( + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, +) + +__all__ = [ + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, +] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py new file mode 100644 index 0000000..0505ab1 --- /dev/null +++ b/flowapp/services/rule_service.py @@ -0,0 +1,219 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" + +from typing import Dict, Tuple + +from flowapp import db, messages +from flowapp.constants import RuleTypes, ANNOUNCE +from flowapp.models import ( + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + Flowspec4, + Flowspec6, + RTBH, +) +from flowapp.output import Route, announce_route, log_route, RouteSources +from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent + + +def create_or_update_ipv4_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec4, str]: + """ + Create a new IPv4 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv4_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec4( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + protocol=form_data["protocol"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + fragment=";".join(form_data["fragment"]), + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv4 Rule saved" + + db.session.commit() + + # Announce route if model is in active state + if model.rstate_id == 1: + command = messages.create_ipv4(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv4, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_ipv6_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec6, str]: + """ + Create a new IPv6 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv6_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec6( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + next_header=form_data["next_header"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv6 Rule saved" + + db.session.commit() + + # Announce routes + if model.rstate_id == 1: + command = messages.create_ipv6(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv6, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_rtbh_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[RTBH, str]: + """ + Create a new RTBH rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_rtbh_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." + else: + # Create new model + model = RTBH( + ipv4=form_data["ipv4"], + ipv4_mask=form_data["ipv4_mask"], + ipv6=form_data["ipv6"], + ipv6_mask=form_data["ipv6_mask"], + community_id=form_data["community"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "RTBH Rule saved" + + db.session.commit() + + # Announce routes + if model.rstate_id == 1: + command = messages.create_rtbh(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.RTBH, + f"{user_email} / {org_name}", + ) + + return model, flash_message diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index b421eea..39b7077 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -5,7 +5,7 @@ from functools import wraps from datetime import datetime, timedelta -from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, ANNOUNCE, TIME_FORMAT_ARG, RuleTypes +from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, TIME_FORMAT_ARG, RuleTypes from flowapp.models import ( RTBH, Flowspec4, @@ -18,20 +18,16 @@ check_rule_limit, get_user_nets, get_user_actions, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, insert_initial_communities, get_user_communities, - get_rtbh_model_if_exists, ) from flowapp.forms import IPv4Form, IPv6Form, RTBHForm +from flowapp.services import rule_service from flowapp.utils import ( - quote_to_ent, - get_state_by_time, output_date_format, ) from flowapp.auth import check_access_rights -from flowapp.output import announce_route, log_route, log_withdraw, Route, RouteSources +from flowapp.output import announce_route, log_withdraw, Route, RouteSources from flowapp import db, validators, flowspec, messages @@ -249,52 +245,15 @@ def create_ipv4(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv4, - f"{current_user['uuid']} / {current_user['org']}", + # Use the service to create/update the rule + model, flash_message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) + pref_format = output_date_format(json_request_data, form.expires.pref_format) response = {"message": flash_message, "rule": model.to_dict(pref_format)} return jsonify(response), 201 @@ -326,50 +285,12 @@ def create_ipv6(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv6, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -406,43 +327,12 @@ def create_rtbh(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - # log changes - log_route( - current_user["id"], - model, - RuleTypes.RTBH, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 5618789..f0acfb4 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -24,19 +24,16 @@ Organization, check_global_rule_limit, check_rule_limit, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, - get_rtbh_model_if_exists, get_user_actions, get_user_communities, get_user_nets, insert_initial_communities, ) from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route +from flowapp.services import rule_service from flowapp.utils import ( flash_errors, get_state_by_time, - quote_to_ent, round_to_ten_minutes, ) @@ -476,54 +473,17 @@ def ipv4_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv4, - f"{session['user_email']} / {session['user_org']}", + # Use the service to create/update the rule + _model, message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") + return redirect(url_for("index")) else: for field, errors in form.errors.items(): @@ -560,52 +520,14 @@ def ipv6_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv6, - f"{session['user_email']} / {session['user_org']}", + _model, message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -644,45 +566,14 @@ def rtbh_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - flash(flash_message, "alert-success") - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_route( - session["user_id"], - model, - RuleTypes.RTBH, - f"{session['user_email']} / {session['user_org']}", + _model, message = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: From 74bfefe6ae4508776338d4191d7e7c72013a0fa9 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 11 Mar 2025 12:55:30 +0100 Subject: [PATCH 12/36] Done - add whitelist, display whitelist in dashboard. WIP - basic whitelist handling when RTBH rule is created. --- flowapp/instance_config.py | 20 ++- flowapp/models/__init__.py | 2 + flowapp/models/rules/base.py | 1 + flowapp/models/rules/rtbh.py | 16 +++ flowapp/models/rules/whitelist.py | 13 +- flowapp/models/utils.py | 50 ++++++- flowapp/services/__init__.py | 3 + flowapp/services/rule_service.py | 44 +++++- flowapp/services/whitelist_service.py | 178 ++++++++++++++++++++++++ flowapp/templates/forms/whitelist.html | 62 +++++++++ flowapp/templates/macros.html | 48 ++++++- flowapp/tests/test_models.py | 3 + flowapp/tests/test_whitelist_service.py | 146 +++++++++++++++++++ flowapp/views/rules.py | 33 +++-- flowapp/views/whitelist.py | 66 ++------- 15 files changed, 603 insertions(+), 82 deletions(-) create mode 100644 flowapp/services/whitelist_service.py create mode 100644 flowapp/templates/forms/whitelist.html create mode 100644 flowapp/tests/test_whitelist_service.py diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 1ffcb1f..86bc5fb 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -2,12 +2,19 @@ # column names for tables RTBH_COLUMNS = ( - ("ipv4", "IP adress (v4 or v6)"), + ("ipv4", "IP address (v4 or v6)"), ("community_id", "Community"), ("expires", "Expires"), ("user_id", "User"), ) +WHITELIST_COLUMNS = ( + ("address", "IP address / network (v4 or v6)"), + ("expires", "Expires"), + ("user_id", "User"), +) + + RULES_COLUMNS_V4 = ( ("source", "Source addr."), ("source_port", "S port"), @@ -75,6 +82,7 @@ class InstanceConfig: {"name": "Add IPv6", "url": "rules.ipv6_rule"}, {"name": "Add RTBH", "url": "rules.rtbh_rule"}, {"name": "API Key", "url": "api_keys.all"}, + {"name": "Add Whitelist", "url": "whitelist.add"}, ], "admin": [ {"name": "Commands Log", "url": "admin.log"}, @@ -123,6 +131,14 @@ class InstanceConfig: "table_colspan": 5, "table_columns": RTBH_COLUMNS, }, + "whitelist": { + "name": "Whitelist", + "macro_file": "macros.html", + "macro_tbody": "build_whitelist_tbody", + "macro_thead": "build_rules_thead", + "table_colspan": 4, + "table_columns": WHITELIST_COLUMNS, + }, } - COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0} + COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0, "whitelist": 0} diff --git a/flowapp/models/__init__.py b/flowapp/models/__init__.py index 6d74ffe..6405505 100644 --- a/flowapp/models/__init__.py +++ b/flowapp/models/__init__.py @@ -23,6 +23,7 @@ get_ipv4_model_if_exists, get_ipv6_model_if_exists, get_rtbh_model_if_exists, + get_whitelist_model_if_exists, get_ip_rules, get_user_rules_ids, insert_users, @@ -59,6 +60,7 @@ "get_ipv4_model_if_exists", "get_ipv6_model_if_exists", "get_rtbh_model_if_exists", + "get_whitelist_model_if_exists", "get_ip_rules", "get_user_rules_ids", "insert_users", diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py index ab61f5e..22fbc08 100644 --- a/flowapp/models/rules/base.py +++ b/flowapp/models/rules/base.py @@ -37,6 +37,7 @@ def insert_initial_rulestates(table, conn, *args, **kwargs): conn.execute(table.insert().values(description="active rule")) conn.execute(table.insert().values(description="withdrawed rule")) conn.execute(table.insert().values(description="deleted rule")) + conn.execute(table.insert().values(description="whitelisted rule")) @event.listens_for(Action.__table__, "after_create") diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py index d75220c..f183b81 100644 --- a/flowapp/models/rules/rtbh.py +++ b/flowapp/models/rules/rtbh.py @@ -132,3 +132,19 @@ def json(self, prefered_format="yearfirst"): :returns: json """ return json.dumps(self.to_dict()) + + def __repr__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"" + if not self.ipv4 and not self.ipv4_mask: + return f"" + + return f"" + + def __str__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"{self.ipv4}/{self.ipv4_mask}" + if not self.ipv4 and not self.ipv4_mask: + return f"{self.ipv6}/{self.ipv6_mask}" + + return f"{self.ipv4}/{self.ipv4_mask} {self.ipv6}/{self.ipv6_mask}" diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 5e93714..f4c87ac 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -1,6 +1,7 @@ from flowapp import utils from ..base import db from datetime import datetime +from flowapp.constants import RuleTypes, RuleOrigin class Whitelist(db.Model): @@ -72,6 +73,12 @@ def to_dict(self, prefered_format="yearfirst"): "rstate": self.rstate.description, } + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.ip}/{self.mask}" + class RuleWhitelistCache(db.Model): """ @@ -87,8 +94,8 @@ class RuleWhitelistCache(db.Model): whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") - def __init__(self, rid, rtype, rorigin, whitelist_id): + def __init__(self, rid: int, rtype: RuleTypes, whitelist_id: int, rorigin: RuleOrigin = RuleOrigin.USER): self.rid = rid - self.rtype = rtype - self.rorigin = rorigin + self.rtype = rtype.value + self.rorigin = rorigin.value self.whitelist_id = whitelist_id diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index e2129b2..80a1af6 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -4,6 +4,8 @@ from flowapp import utils from flowapp.constants import RuleTypes from flask import current_app + +from flowapp.models.rules.whitelist import Whitelist from .base import db from .user import User, Role from .organization import Organization @@ -57,9 +59,33 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: return rtbh >= rtbh_limit +def get_whitelist_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + ip, mask, rstate_id should match + expires, user_id, org_id, created, comment can be different + """ + record = ( + db.session.query(Whitelist) + .filter( + Whitelist.ip == form_data["ip"], + Whitelist.mask == form_data["mask"], + Whitelist.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + def get_ipv4_model_if_exists(form_data, rstate_id=1): """ Check if the record in database exist + Source and destination addresses, protocol, flags, action and packet_len should match + Other fields can be different """ record = ( db.session.query(Flowspec4) @@ -113,9 +139,11 @@ def get_ipv6_model_if_exists(form_data, rstate_id=1): return False -def get_rtbh_model_if_exists(form_data, rstate_id=1): +def get_rtbh_model_if_exists(form_data): """ - Check if the record in database exist + Check RTBH rule exist in database + IPv4, IPv6 and community_id should match + Rule can be in any state and have different expires, user_id, org_id, created, comment """ record = ( @@ -126,7 +154,6 @@ def get_rtbh_model_if_exists(form_data, rstate_id=1): RTBH.ipv6 == form_data["ipv6"], RTBH.ipv6_mask == form_data["ipv6_mask"], RTBH.community_id == form_data["community"], - RTBH.rstate_id == rstate_id, ) .first() ) @@ -308,6 +335,23 @@ def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): return rules_rtbh + if rule_type == "whitelist": + sorter_whitelist = getattr(Whitelist, sort, Whitelist.id) + sorting_whitelist = getattr(sorter_whitelist, order) + + if comp_func: + rules_whitelist = ( + db.session.query(Whitelist) + .filter(comp_func(Whitelist.expires, today)) + .order_by(sorting_whitelist()) + .all() + ) + + else: + rules_whitelist = db.session.query(Whitelist).order_by(sorting_whitelist()).all() + + return rules_whitelist + def get_user_rules_ids(user_id, rule_type): """ diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index b8fc6e4..11766b3 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -4,8 +4,11 @@ create_or_update_rtbh_rule, ) +from .whitelist_service import create_or_update_whitelist + __all__ = [ create_or_update_ipv4_rule, create_or_update_ipv6_rule, create_or_update_rtbh_rule, + create_or_update_whitelist, ] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 0505ab1..5636c53 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -2,14 +2,15 @@ """ Service module for rule operations. -This module provides business logic functions for creating, updating, +This module provides business logic functions for creating, updating, and managing flow rules, separating these concerns from HTTP handling. """ +from datetime import datetime from typing import Dict, Tuple from flowapp import db, messages -from flowapp.constants import RuleTypes, ANNOUNCE +from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE from flowapp.models import ( get_ipv4_model_if_exists, get_ipv6_model_if_exists, @@ -17,9 +18,12 @@ Flowspec4, Flowspec6, RTBH, + Whitelist, + RuleWhitelistCache, ) from flowapp.output import Route, announce_route, log_route, RouteSources from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent +from .whitelist_service import check_rule_against_whitelists, Relation def create_or_update_ipv4_rule( @@ -174,7 +178,7 @@ def create_or_update_rtbh_rule( Tuple containing (rule_model, message) """ # Check for existing model - model = get_rtbh_model_if_exists(form_data, 1) + model = get_rtbh_model_if_exists(form_data) if model: model.expires = round_to_ten_minutes(form_data["expires"]) @@ -198,6 +202,24 @@ def create_or_update_rtbh_rule( db.session.commit() + # Check if rule is whitelisted + # get all not expired whitelists + whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + whitelists = {str(w): w for w in whitelists} + results = check_rule_against_whitelists(str(model), whitelists.keys()) + # check rule against whitelists, stop search when rule is whitelisted first time + for rule, whitelist_key, relation in results: + match relation: + case Relation.EQUAL: + model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + break + case Relation.SUBNET: + print("WL is subnet of rule") + break + case Relation.SUPERNET: + model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + break + # Announce routes if model.rstate_id == 1: command = messages.create_rtbh(model, ANNOUNCE) @@ -217,3 +239,19 @@ def create_or_update_rtbh_rule( ) return model, flash_message + + +def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: + """ + Whitelist RTBH rule. + set state to 4 - whitelisted rule, do not announce + Add to whitelist cache + """ + model.rstate_id = 4 + db.session.commit() + # add to cache + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist.id, rorigin=RuleOrigin.USER) + db.session.add(cache) + db.session.commit() + + return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py new file mode 100644 index 0000000..9652922 --- /dev/null +++ b/flowapp/services/whitelist_service.py @@ -0,0 +1,178 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" + +from typing import Dict, Tuple, List +from enum import Enum, auto +import ipaddress +from functools import lru_cache + +from flowapp import db +from flowapp.models import Whitelist, get_whitelist_model_if_exists +from flowapp.utils import round_to_ten_minutes, quote_to_ent + + +def create_or_update_whitelist( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Whitelist, str]: + """ + Create a new Whitelist rule or update expiration time of an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (whitelist_model, message) + """ + # Check for existing model + model = get_whitelist_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing Whitelist found. Expiration time was updated to new value." + else: + # Create new model + model = Whitelist( + ip=form_data["ip"], + mask=form_data["mask"], + expires=round_to_ten_minutes(form_data["expires"]), + user_id=user_id, + org_id=org_id, + comment=quote_to_ent(form_data["comment"]), + ) + db.session.add(model) + flash_message = "Whitelist saved" + + db.session.commit() + + return model, flash_message + + +class Relation(Enum): + SUBNET = auto() + SUPERNET = auto() + EQUAL = auto() + DIFFERENT = auto() + + +@lru_cache(maxsize=1024) +def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: + """ + Create and cache an IP network object. + + :param address: IP address or network in string format + :return: Cached IP network object + """ + return ipaddress.ip_network(address, strict=False) + + +def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: + """ + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Uses cached network objects for better performance. + + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") + :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") + :return: Relation between the two networks + """ + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + + +def subtract_network(target: str, whitelist: str) -> List[str]: + """ + Computes the remaining parts of a network after removing the whitelist subnet. + Uses cached network objects for better performance. + + :param target: The main network (e.g., "192.168.1.0/24") + :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") + :return: A list of remaining subnets as strings + """ + target_net = get_network(target) + whitelist_net = get_network(whitelist) + + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet + + remaining = [] + + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) + + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + + +def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: + """ + Helper function to check a single rule against multiple whitelist entries. + Creates a cached rule network object for better performance. + Reduces list of whitelists, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check + :param whitelists: List of whitelist networks to check against + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(rule) + items = [] + for whitelist in whitelists: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: + """ + Helper function to check if any whitelist entry is a subnet of the rule. + Creates a cached rule network object for better performance. + Reduces list of rules, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check against + :param whitelists: List of whitelist networks to check + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(whitelist) + items = [] + for rule in rules: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def clear_network_cache() -> None: + """ + Clear the network object cache. + Useful when processing a large number of networks to prevent memory growth. + """ + get_network.cache_clear() diff --git a/flowapp/templates/forms/whitelist.html b/flowapp/templates/forms/whitelist.html new file mode 100644 index 0000000..35559d7 --- /dev/null +++ b/flowapp/templates/forms/whitelist.html @@ -0,0 +1,62 @@ +{% extends 'layouts/default.html' %} +{% from 'forms/macros.html' import render_field %} +{% block title %}Add Whitelist{% endblock %} +{% block content %} +

{{ title or 'New'}} Whitelist

+
+ {{ form.hidden_tag() if form.hidden_tag }} +
+
+
+
+ {{ render_field(form.ip) }} +
+
+ {{ render_field(form.mask) }} +
+
+
+
+
+ +
+ +
+
+
+ +
+ + {{ form.expires(class_='form-control') }} + + + +
+ + {% if form.expires.errors %} + {% for e in form.expires.errors %} +

{{ e }}

+ {% endfor %} + {% endif %} +
+
+
+
+
+ +
+
+ {{ render_field(form.comment) }} +
+
+ +
+
+ +
+
+
+{% endblock %} \ No newline at end of file diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 9e0b4b7..9d51d3d 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -8,7 +8,7 @@ {% set rtype_int = 4 %} {% endif %} - + {{ rule.source }}{% if rule.source_mask != none %}{{ '/' if rule.source_mask >= 0 else '' }}{{ rule.source_mask if rule.source_mask >= 0 else '' }}{% endif %} @@ -76,7 +76,10 @@ {% macro build_rtbh_tbody(rules, today, editable=True, group_op=True) %} {% for rule in rules %} - + {% if rule.ipv4 %} {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} @@ -122,6 +125,47 @@ {% endmacro %} +{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True) %} + +{% for rule in rules %} + + + {{ rule.ip }}{{ '/' if rule.mask else '' }}{{rule.mask|default("", True)}} + + + {{ rule.expires|strftime }} + + + {{ rule.user.name }} + + + {% if editable %} + + + + + + + {% endif %} + {% if rule.comment %} + + {% endif %} + + {% if editable and group_op %} + + + + {% endif %} + + +{% endfor %} + +{% endmacro %} + + + {% macro build_rules_thead(rules_columns, rtype, rstate, sort_key, sort_order, search_query='', group_op=True) %} diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py index 82dff8e..b352a5f 100644 --- a/flowapp/tests/test_models.py +++ b/flowapp/tests/test_models.py @@ -473,6 +473,9 @@ def test_whitelist_to_dict(db): db.session.add_all([user, rstate]) db.session.commit() + db.session.add(whitelist) + db.session.commit() + whitelist.user = user whitelist.rstate_id = rstate.id db.session.add(whitelist) diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py new file mode 100644 index 0000000..7ae4990 --- /dev/null +++ b/flowapp/tests/test_whitelist_service.py @@ -0,0 +1,146 @@ +import pytest +from flowapp.services.whitelist_service import ( + Relation, + check_whitelist_to_rule_relation, + subtract_network, + check_rule_against_whitelists, + check_whitelist_against_rules, + clear_network_cache, +) + + +# Tests for core function that checks network relations +@pytest.mark.parametrize( + "rule,whitelist,expected", + [ + # IPv4 test cases + ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet + ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks + ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet + ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet + ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks + # IPv6 test cases + ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks + ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet + ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks + ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ], +) +def test_check_whitelist_to_rule_relation(rule, whitelist, expected): + clear_network_cache() + assert check_whitelist_to_rule_relation(rule, whitelist) == expected + + +@pytest.mark.parametrize( + "target,whitelist,expected", + [ + # Basic IPv4 cases + ( + "192.168.1.0/24", + "192.168.1.128/25", + ["192.168.1.0/25"], + ), # One remaining subnet + ( + "192.168.1.0/24", + "192.168.1.64/26", + ["192.168.1.0/26", "192.168.1.128/25"], + ), # Two remaining subnets + ( + "192.168.1.0/24", + "192.168.1.0/24", + ["192.168.1.0/24"], + ), # Equal networks - return original network + ( + "192.168.1.0/24", + "192.168.2.0/24", + ["192.168.1.0/24"], + ), # No overlap + ], +) +def test_subtract_network(target, whitelist, expected): + clear_network_cache() + result = subtract_network(target, whitelist) + assert sorted(result) == sorted(expected) + + +def test_check_rule_against_whitelists(): + clear_network_cache() + + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # SUPERNET + "172.16.0.0/12", # DIFFERENT + "192.168.1.0/24", # EQUAL + "192.168.1.128/25", # SUBNET + ] + + results = check_rule_against_whitelists(rule, whitelists) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + (rule, "192.168.0.0/16", Relation.SUPERNET), + (rule, "192.168.1.0/24", Relation.EQUAL), + (rule, "192.168.1.128/25", Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +def test_check_whitelist_against_rules(): + clear_network_cache() + + whitelist = "192.168.1.128/25" + rules = [ + "192.168.0.0/16", # Whitelist is SUBNET of this larger network + "172.16.0.0/12", # DIFFERENT + "192.168.1.128/25", # EQUAL + "192.168.1.0/24", # Whitelist is SUBNET of this network + ] + + results = check_whitelist_against_rules(rules, whitelist) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + ("192.168.0.0/16", whitelist, Relation.SUBNET), + ("192.168.1.128/25", whitelist, Relation.EQUAL), + ("192.168.1.0/24", whitelist, Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +# Tests for edge cases and error handling +def test_host_addresses(): + clear_network_cache() + # Test with host addresses (no explicit subnet mask) + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT + + +def test_single_ip_as_network(): + clear_network_cache() + # Test with /32 (single IPv4 address as network) + assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET + # Test with /128 (single IPv6 address as network) + assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET + + +def test_invalid_input(): + clear_network_cache() + with pytest.raises(ValueError): + check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") + + with pytest.raises(ValueError): + check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") + + with pytest.raises(ValueError): + subtract_network("invalid", "192.168.1.0/24") + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index f0acfb4..4805e44 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -58,9 +58,13 @@ def reactivate_rule(rule_type, rule_id): """ Set new time for the rule of given type identified by id - :param rule_type: string - type of rule + :param rule_type: integer - type of rule, corresponds to RuleTypes enum value :param rule_id: integer - id of the rule """ + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + + # Now use the enum value where needed but the integer for dictionary lookups model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] @@ -72,14 +76,14 @@ def reactivate_rule(rule_type, rule_id): form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")] form.action.data = model.action_id - if rule_type == 1: + if rule_type == RuleTypes.RTBH.value: form.community.choices = get_user_communities(session["user_role_ids"]) form.community.data = model.community_id - if rule_type == 4: + if rule_type == RuleTypes.IPv4.value: form.protocol.data = model.protocol - if rule_type == 6: + if rule_type == RuleTypes.IPv6.value: form.next_header.data = model.next_header # do not need to validate - all is readonly @@ -114,11 +118,11 @@ def reactivate_rule(rule_type, rule_id): command=command, ) announce_route(route) - # log changes + # log changes - Use the enum value here log_route( session["user_id"], model, - rule_type, + enum_rule_type, # Pass the enum instead of integer f"{session['user_email']} / {session['user_org']}", ) else: @@ -130,11 +134,11 @@ def reactivate_rule(rule_type, rule_id): command=command, ) announce_route(route) - # log changes + # log changes - Use the enum value here log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, # Pass the enum instead of integer model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -183,6 +187,9 @@ def delete_rule(rule_type, rule_id): model_name = DATA_MODELS[rule_type] route_model = ROUTE_MODELS[rule_type] + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + model = db.session.get(model_name, rule_id) if model.id in session[constants.RULES_KEY]: # withdraw route @@ -197,7 +204,7 @@ def delete_rule(rule_type, rule_id): log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -257,6 +264,7 @@ def group_delete(): rule_type = session[constants.TYPE_ARG] model_name = DATA_MODELS_NAMED[rule_type] rule_type_int = constants.RULE_TYPES_DICT[rule_type] + enum_rule_type = RuleTypes(rule_type_int) route_model = ROUTE_MODELS[rule_type_int] rules = [str(x) for x in session[constants.RULES_KEY]] to_delete = request.form.getlist("delete-id") @@ -276,7 +284,7 @@ def group_delete(): log_withdraw( session["user_id"], route.command, - rule_type_int, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -373,6 +381,7 @@ def group_update_save(rule_type): model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] + enum_rule_type = RuleTypes(rule_type) form = form_name(request.form) @@ -414,7 +423,7 @@ def group_update_save(rule_type): log_route( session["user_id"], model, - rule_type, + enum_rule_type, f"{session['user_email']} / {session['user_org']}", ) else: @@ -430,7 +439,7 @@ def group_update_save(rule_type): log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index adf3129..723416e 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -1,24 +1,15 @@ from datetime import datetime, timedelta from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for -from flowapp import constants, db, messages from flowapp.auth import ( auth_required, user_or_admin_required, ) -from flowapp.constants import RuleTypes from flowapp.forms import WhitelistForm from flowapp.models import ( - Whitelist, get_user_nets, ) -from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route -from flowapp.utils import ( - flash_errors, - get_state_by_time, - quote_to_ent, - round_to_ten_minutes, -) +from flowapp.services import create_or_update_whitelist whitelist = Blueprint("whitelist", __name__, template_folder="templates") @@ -33,53 +24,14 @@ def add(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv4, - f"{session['user_email']} / {session['user_org']}", + model, message = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -91,4 +43,4 @@ def add(): default_expires = datetime.now() + timedelta(hours=1) form.expires.data = default_expires - return render_template("forms/ipv4_rule.html", form=form, action_url=url_for("rules.ipv4_rule")) + return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add")) From 5b22513fab7c0238ef94bccc0ea430cf4a37bca8 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 12 Mar 2025 17:56:17 +0100 Subject: [PATCH 13/36] whitelist time update and delete completed, whitelist cache clean implemented --- flowapp/models/rules/whitelist.py | 34 ++++++++++ flowapp/models/utils.py | 7 +- flowapp/services/__init__.py | 3 +- flowapp/services/base.py | 18 +++++ flowapp/services/rule_service.py | 96 ++++++++++++++++++--------- flowapp/services/whitelist_service.py | 47 ++++++++++++- flowapp/templates/macros.html | 4 +- flowapp/views/rules.py | 5 +- flowapp/views/whitelist.py | 88 ++++++++++++++++++++++-- 9 files changed, 254 insertions(+), 48 deletions(-) create mode 100644 flowapp/services/base.py diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index f4c87ac..e1857cc 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -99,3 +99,37 @@ def __init__(self, rid: int, rtype: RuleTypes, whitelist_id: int, rorigin: RuleO self.rtype = rtype.value self.rorigin = rorigin.value self.whitelist_id = whitelist_id + + @classmethod + def get_by_whitelist_id(cls, whitelist_id: int): + """ + Get all cache items with the given whitelist ID + + Args: + whitelist_id (int): The ID of the whitelist to filter by + + Returns: + list: All RuleWhitelistCache objects with the specified whitelist_id + """ + return cls.query.filter_by(whitelist_id=whitelist_id).all() + + @classmethod + def clean_by_whitelist_id(cls, whitelist_id: int): + """ + Delete all cache entries with the given whitelist ID from the database + + Args: + whitelist_id (int): The ID of the whitelist to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(whitelist_id=whitelist_id).delete() + db.session.commit() + return deleted + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.rid} {self.rtype} {self.rorigin}" diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index 80a1af6..7472d72 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -59,18 +59,17 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: return rtbh >= rtbh_limit -def get_whitelist_model_if_exists(form_data, rstate_id=1): +def get_whitelist_model_if_exists(form_data): """ Check if the record in database exist - ip, mask, rstate_id should match - expires, user_id, org_id, created, comment can be different + ip, mask should match + expires, rstate_id, user_id, org_id, created, comment can be different """ record = ( db.session.query(Whitelist) .filter( Whitelist.ip == form_data["ip"], Whitelist.mask == form_data["mask"], - Whitelist.rstate_id == rstate_id, ) .first() ) diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index 11766b3..05b9695 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -4,11 +4,12 @@ create_or_update_rtbh_rule, ) -from .whitelist_service import create_or_update_whitelist +from .whitelist_service import create_or_update_whitelist, delete_whitelist __all__ = [ create_or_update_ipv4_rule, create_or_update_ipv6_rule, create_or_update_rtbh_rule, create_or_update_whitelist, + delete_whitelist, ] diff --git a/flowapp/services/base.py b/flowapp/services/base.py new file mode 100644 index 0000000..2f4bc26 --- /dev/null +++ b/flowapp/services/base.py @@ -0,0 +1,18 @@ +from flowapp import messages +from flowapp.constants import ANNOUNCE +from flowapp.models import RTBH +from flowapp.output import Route, RouteSources, announce_route + + +def announce_rtbh_route(model: RTBH, author: str) -> None: + """ + Announce RTBH route if rule is in active state + """ + if model.rstate_id == 1: + command = messages.create_rtbh(model, ANNOUNCE) + route = Route( + author=author, + source=RouteSources.UI, + command=command, + ) + announce_route(route) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 5636c53..b8db5f7 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -7,7 +7,7 @@ """ from datetime import datetime -from typing import Dict, Tuple +from typing import Dict, List, Tuple from flowapp import db, messages from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE @@ -22,8 +22,9 @@ RuleWhitelistCache, ) from flowapp.output import Route, announce_route, log_route, RouteSources +from flowapp.services.base import announce_rtbh_route from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent -from .whitelist_service import check_rule_against_whitelists, Relation +from .whitelist_service import check_rule_against_whitelists, Relation, subtract_network def create_or_update_ipv4_rule( @@ -163,7 +164,7 @@ def create_or_update_ipv6_rule( def create_or_update_rtbh_rule( form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str -) -> Tuple[RTBH, str]: +) -> Tuple[RTBH, List]: """ Create a new RTBH rule or update an existing one. @@ -179,10 +180,10 @@ def create_or_update_rtbh_rule( """ # Check for existing model model = get_rtbh_model_if_exists(form_data) - + flashes = [] if model: model.expires = round_to_ten_minutes(form_data["expires"]) - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." + flashes.append("Existing RTBH Rule found. Expiration time was updated to new value.") else: # Create new model model = RTBH( @@ -198,60 +199,89 @@ def create_or_update_rtbh_rule( rstate_id=get_state_by_time(form_data["expires"]), ) db.session.add(model) - flash_message = "RTBH Rule saved" + flashes.append("RTBH Rule saved") db.session.commit() + # rule author for logging and announcing + author = f"{user_email} / {org_name}" + # Check if rule is whitelisted # get all not expired whitelists whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() - whitelists = {str(w): w for w in whitelists} - results = check_rule_against_whitelists(str(model), whitelists.keys()) + wl_cache = {str(w): w for w in whitelists} + results = check_rule_against_whitelists(str(model), wl_cache.keys()) # check rule against whitelists, stop search when rule is whitelisted first time for rule, whitelist_key, relation in results: match relation: case Relation.EQUAL: - model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") break case Relation.SUBNET: - print("WL is subnet of rule") + # split subnet into parts + parts = subtract_network(target=str(model), whitelist=whitelist_key) + wl_id = wl_cache[whitelist_key].id + flashes.append( + f" Rule is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + ) + for network in parts: + create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + + model.rstate_id = 4 + add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) + db.session.commit() break case Relation.SUPERNET: - model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") break - # Announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, ANNOUNCE) - route = Route( - author=f"{user_email} / {org_name}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - + announce_rtbh_route(model, author=author) # Log changes - log_route( - user_id, - model, - RuleTypes.RTBH, - f"{user_email} / {org_name}", + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + + +def create_rtbh_from_whitelist_parts( + model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str, user_id: int +) -> None: + net_ip, net_mask = network.split("/") + new_model = RTBH( + ipv4=net_ip, + ipv4_mask=net_mask, + ipv6=model.ipv6, + ipv6_mask=model.ipv6_mask, + community_id=model.community_id, + expires=model.expires, + comment=model.comment, + user_id=model.user_id, + org_id=model.org_id, + rstate_id=1, ) + db.session.add(new_model) + db.session.commit() + add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) + announce_rtbh_route(new_model, rule_owner) + log_route(user_id, model, RuleTypes.RTBH, rule_owner) - return model, flash_message + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + # add to cache + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: """ Whitelist RTBH rule. - set state to 4 - whitelisted rule, do not announce + Set rule state to 4 - whitelisted rule, do not announce later Add to whitelist cache """ model.rstate_id = 4 db.session.commit() - # add to cache - cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist.id, rorigin=RuleOrigin.USER) - db.session.add(cache) - db.session.commit() - + add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 9652922..0f8096c 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -12,7 +12,11 @@ from functools import lru_cache from flowapp import db -from flowapp.models import Whitelist, get_whitelist_model_if_exists +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists +from flowapp.models.rules.flowspec import Flowspec4, Flowspec6 +from flowapp.models.rules.rtbh import RTBH +from flowapp.services.base import announce_rtbh_route from flowapp.utils import round_to_ten_minutes, quote_to_ent @@ -33,7 +37,7 @@ def create_or_update_whitelist( Tuple containing (whitelist_model, message) """ # Check for existing model - model = get_whitelist_model_if_exists(form_data, 1) + model = get_whitelist_model_if_exists(form_data) if model: model.expires = round_to_ten_minutes(form_data["expires"]) @@ -56,6 +60,45 @@ def create_or_update_whitelist( return model, flash_message +def delete_whitelist(whitelist_id: int) -> List[str]: + """ + Delete a whitelist entry from the database. + + Args: + whitelist_id: The ID of the whitelist to delete + """ + model = db.session.get(Whitelist, whitelist_id) + flashes = [] + if model: + cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) + for cached_rule in cached_rules: + rule_model_type = RuleTypes(cached_rule.rtype) + match rule_model_type: + case RuleTypes.IPv4: + rule_model = db.session.get(Flowspec4, cached_rule.rid) + case RuleTypes.IPv6: + rule_model = db.session.get(Flowspec6, cached_rule.rid) + case RuleTypes.RTBH: + rule_model = db.session.get(RTBH, cached_rule.rid) + rorigin_type = RuleOrigin(cached_rule.rorigin) + if rorigin_type == RuleOrigin.WHITELIST: + flashes.append(f"Deleted rule {rule_model} created by this whitelist") + db.session.delete(rule_model) + elif rorigin_type == RuleOrigin.USER: + flashes.append(f"Set rule {rule_model} back to state 'Active'") + rule_model.rstate_id = 1 # Set rule state to "Active" again + author = f"{model.user.email} ({model.user.organization})" + announce_rtbh_route(rule_model, author) + + flashes.append(f"Deleted cache entries for whitelist {whitelist_id}") + RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) + + db.session.delete(model) + db.session.commit() + + return flashes + + class Relation(Enum): SUBNET = auto() SUPERNET = auto() diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 9d51d3d..653cc81 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -140,10 +140,10 @@ {% if editable %} - + - + {% endif %} diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 4805e44..fc787cf 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -575,14 +575,15 @@ def rtbh_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - _model, message = rule_service.create_or_update_rtbh_rule( + _model, messages = rule_service.create_or_update_rtbh_rule( form_data=form.data, user_id=session["user_id"], org_id=session["user_org_id"], user_email=session["user_email"], org_name=session["user_org"], ) - flash(message, "alert-success") + for message in messages: + flash(message, "alert-success") return redirect(url_for("index")) else: diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index 723416e..9f2bcda 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -5,11 +5,11 @@ auth_required, user_or_admin_required, ) +from flowapp import constants, db from flowapp.forms import WhitelistForm -from flowapp.models import ( - get_user_nets, -) -from flowapp.services import create_or_update_whitelist +from flowapp.models import get_user_nets, Whitelist +from flowapp.services import create_or_update_whitelist, delete_whitelist +from flowapp.utils.base import flash_errors whitelist = Blueprint("whitelist", __name__, template_folder="templates") @@ -44,3 +44,83 @@ def add(): form.expires.data = default_expires return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add")) + + +@whitelist.route("/reactivate/", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def reactivate(wl_id): + """ + Set new time for whitelist + :param wl_id: int - id of the whitelist + """ + + model = db.session.get(Whitelist, wl_id) + form = WhitelistForm(request.form, obj=model) + form.net_ranges = get_user_nets(session["user_id"]) + + # do not need to validate - all is readonly + if request.method == "POST": + model = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) + flash("Whitelist updated", "alert-success") + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) + else: + flash_errors(form) + + form.expires.data = model.expires + for field in form: + if field.name not in ["expires", "csrf_token", "comment"]: + field.render_kw = {"disabled": "disabled"} + + action_url = url_for("whitelist.reactivate", wl_id=wl_id) + + return render_template( + "forms/whitelist.html", + form=form, + action_url=action_url, + editing=True, + title="Update", + ) + + +@whitelist.route("/delete/", methods=["GET"]) +@auth_required +@user_or_admin_required +def delete(wl_id): + """ + Delete whitelist + :param wl_id: integer - id of the whitelist + """ + if wl_id in session[constants.RULES_KEY]: + messages = delete_whitelist(wl_id) + flash(f"Whitelist {wl_id} deleted", "alert-success") + for message in messages: + flash(message, "alert-info") + else: + flash("You can not delete this Whitelist", "alert-warning") + + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) From 356e1f332d24cf685012cd2d224b57ef61aeac37 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 12 Mar 2025 18:38:15 +0100 Subject: [PATCH 14/36] refactoring of whitelist services / split into common lib for rules and whitelist handling services to avoid ciruclar imports --- flowapp/instance_config.py | 2 +- flowapp/services/base.py | 15 +- flowapp/services/rule_service.py | 23 +-- flowapp/services/whitelist_common.py | 150 ++++++++++++++++++ flowapp/services/whitelist_service.py | 125 --------------- ...st_service.py => test_whitelist_common.py} | 6 +- 6 files changed, 168 insertions(+), 153 deletions(-) create mode 100644 flowapp/services/whitelist_common.py rename flowapp/tests/{test_whitelist_service.py => test_whitelist_common.py} (99%) diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 86bc5fb..4136530 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -81,8 +81,8 @@ class InstanceConfig: {"name": "Add IPv4", "url": "rules.ipv4_rule"}, {"name": "Add IPv6", "url": "rules.ipv6_rule"}, {"name": "Add RTBH", "url": "rules.rtbh_rule"}, - {"name": "API Key", "url": "api_keys.all"}, {"name": "Add Whitelist", "url": "whitelist.add"}, + {"name": "API Key", "url": "api_keys.all"}, ], "admin": [ {"name": "Commands Log", "url": "admin.log"}, diff --git a/flowapp/services/base.py b/flowapp/services/base.py index 2f4bc26..d829eee 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,6 +1,6 @@ -from flowapp import messages -from flowapp.constants import ANNOUNCE -from flowapp.models import RTBH +from flowapp import db, messages +from flowapp.constants import ANNOUNCE, RuleOrigin, RuleTypes +from flowapp.models import RTBH, RuleWhitelistCache from flowapp.output import Route, RouteSources, announce_route @@ -16,3 +16,12 @@ def announce_rtbh_route(model: RTBH, author: str) -> None: command=command, ) announce_route(route) + + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + """ + Add RTBH rule to whitelist cache + """ + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index b8db5f7..0d0602c 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -19,12 +19,12 @@ Flowspec6, RTBH, Whitelist, - RuleWhitelistCache, ) from flowapp.output import Route, announce_route, log_route, RouteSources from flowapp.services.base import announce_rtbh_route +from flowapp.services.whitelist_common import Relation, add_rtbh_rule_to_cache, subtract_network, whitelist_rtbh_rule from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent -from .whitelist_service import check_rule_against_whitelists, Relation, subtract_network +from .whitelist_common import check_rule_against_whitelists def create_or_update_ipv4_rule( @@ -266,22 +266,3 @@ def create_rtbh_from_whitelist_parts( add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) announce_rtbh_route(new_model, rule_owner) log_route(user_id, model, RuleTypes.RTBH, rule_owner) - - -def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: - # add to cache - cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) - db.session.add(cache) - db.session.commit() - - -def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: - """ - Whitelist RTBH rule. - Set rule state to 4 - whitelisted rule, do not announce later - Add to whitelist cache - """ - model.rstate_id = 4 - db.session.commit() - add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) - return model diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py new file mode 100644 index 0000000..9c4b49a --- /dev/null +++ b/flowapp/services/whitelist_common.py @@ -0,0 +1,150 @@ +from enum import Enum, auto +from functools import lru_cache +import ipaddress +from typing import List, Tuple +from flowapp import db +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import RTBH, RuleWhitelistCache, Whitelist + + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + """ + Add RTBH rule to whitelist cache + """ + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() + + +def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: + """ + Whitelist RTBH rule. + Set rule state to 4 - whitelisted rule, do not announce later + Add to whitelist cache + """ + model.rstate_id = 4 + db.session.commit() + add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) + return model + + +class Relation(Enum): + SUBNET = auto() + SUPERNET = auto() + EQUAL = auto() + DIFFERENT = auto() + + +@lru_cache(maxsize=1024) +def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: + """ + Create and cache an IP network object. + + :param address: IP address or network in string format + :return: Cached IP network object + """ + return ipaddress.ip_network(address, strict=False) + + +def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: + """ + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Uses cached network objects for better performance. + + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") + :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") + :return: Relation between the two networks + """ + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + + +def subtract_network(target: str, whitelist: str) -> List[str]: + """ + Computes the remaining parts of a network after removing the whitelist subnet. + Uses cached network objects for better performance. + + :param target: The main network (e.g., "192.168.1.0/24") + :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") + :return: A list of remaining subnets as strings + """ + target_net = get_network(target) + whitelist_net = get_network(whitelist) + + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet + + remaining = [] + + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) + + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + + +def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: + """ + Helper function to check a single rule against multiple whitelist entries. + Creates a cached rule network object for better performance. + Reduces list of whitelists, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check + :param whitelists: List of whitelist networks to check against + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(rule) + items = [] + for whitelist in whitelists: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: + """ + Helper function to check if any whitelist entry is a subnet of the rule. + Creates a cached rule network object for better performance. + Reduces list of rules, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check against + :param whitelists: List of whitelist networks to check + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(whitelist) + items = [] + for rule in rules: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def clear_network_cache() -> None: + """ + Clear the network object cache. + Useful when processing a large number of networks to prevent memory growth. + """ + get_network.cache_clear() diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 0f8096c..b94772f 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -7,9 +7,6 @@ """ from typing import Dict, Tuple, List -from enum import Enum, auto -import ipaddress -from functools import lru_cache from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes @@ -97,125 +94,3 @@ def delete_whitelist(whitelist_id: int) -> List[str]: db.session.commit() return flashes - - -class Relation(Enum): - SUBNET = auto() - SUPERNET = auto() - EQUAL = auto() - DIFFERENT = auto() - - -@lru_cache(maxsize=1024) -def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: - """ - Create and cache an IP network object. - - :param address: IP address or network in string format - :return: Cached IP network object - """ - return ipaddress.ip_network(address, strict=False) - - -def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: - """ - Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. - Uses cached network objects for better performance. - - :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") - :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") - :return: Relation between the two networks - """ - rule_net = get_network(rule) - whitelist_net = get_network(whitelist_entry) - if whitelist_net == rule_net: - return Relation.EQUAL - if whitelist_net.supernet_of(rule_net): - return Relation.SUPERNET - if whitelist_net.subnet_of(rule_net): - return Relation.SUBNET - else: - return Relation.DIFFERENT - - -def subtract_network(target: str, whitelist: str) -> List[str]: - """ - Computes the remaining parts of a network after removing the whitelist subnet. - Uses cached network objects for better performance. - - :param target: The main network (e.g., "192.168.1.0/24") - :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") - :return: A list of remaining subnets as strings - """ - target_net = get_network(target) - whitelist_net = get_network(whitelist) - - # Check if the whitelist is actually a subnet - if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: - return [target] # Return the full network if whitelist isn't a valid subnet - - remaining = [] - - # Compute ranges before and after the whitelist - if whitelist_net.network_address > target_net.network_address: - # Before the whitelist - start = target_net.network_address - end = whitelist_net.network_address - 1 - remaining.extend(ipaddress.summarize_address_range(start, end)) - - if whitelist_net.broadcast_address < target_net.broadcast_address: - # After the whitelist - start = whitelist_net.broadcast_address + 1 - end = target_net.broadcast_address - remaining.extend(ipaddress.summarize_address_range(start, end)) - - # Convert to string format - return [str(net) for net in remaining] - - -def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: - """ - Helper function to check a single rule against multiple whitelist entries. - Creates a cached rule network object for better performance. - Reduces list of whitelists, where the Relation is not DIFFERENT - - :param rule: The IP address or network to check - :param whitelists: List of whitelist networks to check against - :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT - """ - # Pre-cache the rule network since it will be used multiple times - get_network(rule) - items = [] - for whitelist in whitelists: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) - return items - - -def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: - """ - Helper function to check if any whitelist entry is a subnet of the rule. - Creates a cached rule network object for better performance. - Reduces list of rules, where the Relation is not DIFFERENT - - :param rule: The IP address or network to check against - :param whitelists: List of whitelist networks to check - :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT - """ - # Pre-cache the rule network since it will be used multiple times - get_network(whitelist) - items = [] - for rule in rules: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) - return items - - -def clear_network_cache() -> None: - """ - Clear the network object cache. - Useful when processing a large number of networks to prevent memory growth. - """ - get_network.cache_clear() diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_common.py similarity index 99% rename from flowapp/tests/test_whitelist_service.py rename to flowapp/tests/test_whitelist_common.py index 7ae4990..9b98e9a 100644 --- a/flowapp/tests/test_whitelist_service.py +++ b/flowapp/tests/test_whitelist_common.py @@ -1,10 +1,10 @@ import pytest -from flowapp.services.whitelist_service import ( +from flowapp.services.whitelist_common import ( Relation, - check_whitelist_to_rule_relation, - subtract_network, check_rule_against_whitelists, check_whitelist_against_rules, + check_whitelist_to_rule_relation, + subtract_network, clear_network_cache, ) From b765603f12cfaa17002c1b2163b734c6b10a2960 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 13 Mar 2025 08:25:21 +0100 Subject: [PATCH 15/36] refactoring of services - broken tests for rule and whitelist services --- flowapp/models/rules/rtbh.py | 3 + flowapp/services/base.py | 21 +- flowapp/services/rule_service.py | 59 ++- flowapp/services/whitelist_common.py | 38 ++ flowapp/services/whitelist_service.py | 75 ++- flowapp/tests/test_rule_service.py | 605 ++++++++++++++++++++++++ flowapp/tests/test_whitelist_service.py | 234 +++++++++ flowapp/views/whitelist.py | 5 +- 8 files changed, 991 insertions(+), 49 deletions(-) create mode 100644 flowapp/tests/test_rule_service.py create mode 100644 flowapp/tests/test_whitelist_service.py diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py index f183b81..943fef0 100644 --- a/flowapp/models/rules/rtbh.py +++ b/flowapp/models/rules/rtbh.py @@ -148,3 +148,6 @@ def __str__(self): return f"{self.ipv6}/{self.ipv6_mask}" return f"{self.ipv4}/{self.ipv4_mask} {self.ipv6}/{self.ipv6_mask}" + + def get_author(self): + return f"{self.user.email} / {self.org}" diff --git a/flowapp/services/base.py b/flowapp/services/base.py index d829eee..f06f16f 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,6 +1,6 @@ -from flowapp import db, messages -from flowapp.constants import ANNOUNCE, RuleOrigin, RuleTypes -from flowapp.models import RTBH, RuleWhitelistCache +from flowapp import messages +from flowapp.constants import ANNOUNCE, WITHDRAW +from flowapp.models import RTBH from flowapp.output import Route, RouteSources, announce_route @@ -18,10 +18,15 @@ def announce_rtbh_route(model: RTBH, author: str) -> None: announce_route(route) -def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: +def withdraw_rtbh_route(model: RTBH) -> None: """ - Add RTBH rule to whitelist cache + Withdraw RTBH route if rule is in whitelist state """ - cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) - db.session.add(cache) - db.session.commit() + if model.rstate_id == 4: + command = messages.create_rtbh(model, WITHDRAW) + route = Route( + author=model.get_author(), + source=RouteSources.UI, + command=command, + ) + announce_route(route) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 0d0602c..85c177a 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -22,7 +22,13 @@ ) from flowapp.output import Route, announce_route, log_route, RouteSources from flowapp.services.base import announce_rtbh_route -from flowapp.services.whitelist_common import Relation, add_rtbh_rule_to_cache, subtract_network, whitelist_rtbh_rule +from flowapp.services.whitelist_common import ( + Relation, + add_rtbh_rule_to_cache, + create_rtbh_from_whitelist_parts, + subtract_network, + whitelist_rtbh_rule, +) from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent from .whitelist_common import check_rule_against_whitelists @@ -209,9 +215,26 @@ def create_or_update_rtbh_rule( # Check if rule is whitelisted # get all not expired whitelists whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() - wl_cache = {str(w): w for w in whitelists} + wl_cache = map_whitelists_to_strings(whitelists) results = check_rule_against_whitelists(str(model), wl_cache.keys()) # check rule against whitelists, stop search when rule is whitelisted first time + model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) + + announce_rtbh_route(model, author=author) + # Log changes + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + + +def evaluate_rtbh_against_whitelists_check_results( + user_id: int, + model: RTBH, + flashes: List[str], + author: str, + wl_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> RTBH: for rule, whitelist_key, relation in results: match relation: case Relation.EQUAL: @@ -219,7 +242,6 @@ def create_or_update_rtbh_rule( flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") break case Relation.SUBNET: - # split subnet into parts parts = subtract_network(target=str(model), whitelist=whitelist_key) wl_id = wl_cache[whitelist_key].id flashes.append( @@ -228,7 +250,6 @@ def create_or_update_rtbh_rule( for network in parts: create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") - model.rstate_id = 4 add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) db.session.commit() @@ -237,32 +258,8 @@ def create_or_update_rtbh_rule( model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") break + return model - announce_rtbh_route(model, author=author) - # Log changes - log_route(user_id, model, RuleTypes.RTBH, author) - return model, flashes - - -def create_rtbh_from_whitelist_parts( - model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str, user_id: int -) -> None: - net_ip, net_mask = network.split("/") - new_model = RTBH( - ipv4=net_ip, - ipv4_mask=net_mask, - ipv6=model.ipv6, - ipv6_mask=model.ipv6_mask, - community_id=model.community_id, - expires=model.expires, - comment=model.comment, - user_id=model.user_id, - org_id=model.org_id, - rstate_id=1, - ) - db.session.add(new_model) - db.session.commit() - add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) - announce_rtbh_route(new_model, rule_owner) - log_route(user_id, model, RuleTypes.RTBH, rule_owner) +def map_whitelists_to_strings(whitelists: List[Whitelist]) -> Dict[str, Whitelist]: + return {str(w): w for w in whitelists} diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 9c4b49a..1730402 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -5,6 +5,8 @@ from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes from flowapp.models import RTBH, RuleWhitelistCache, Whitelist +from flowapp.output import log_route +from flowapp.services.base import announce_rtbh_route def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: @@ -29,6 +31,14 @@ def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: class Relation(Enum): + """ + Enum to represent the relation between Whitelist to Rule Relation + Subnet: Whitelist is a subnet of the rule + Supernet: Whitelist is a supernet of the rule + Equal: Whitelist is equal to the rule + Different: Whitelist is different from the rule + """ + SUBNET = auto() SUPERNET = auto() EQUAL = auto() @@ -148,3 +158,31 @@ def clear_network_cache() -> None: Useful when processing a large number of networks to prevent memory growth. """ get_network.cache_clear() + + +def create_rtbh_from_whitelist_parts( + model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str = "", user_id: int = 0 +) -> None: + # default values from model + rule_owner = rule_owner or model.get_author() + user_id = user_id or model.user_id + + net_ip, net_mask = network.split("/") + new_model = RTBH( + ipv4=net_ip, + ipv4_mask=net_mask, + ipv6=model.ipv6, + ipv6_mask=model.ipv6_mask, + community_id=model.community_id, + expires=model.expires, + comment=model.comment, + user_id=model.user_id, + org_id=model.org_id, + rstate_id=1, + ) + db.session.add(new_model) + db.session.commit() + + add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) + announce_rtbh_route(new_model, rule_owner) + log_route(user_id, model, RuleTypes.RTBH, rule_owner) diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index b94772f..3f57f7e 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -13,7 +13,14 @@ from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists from flowapp.models.rules.flowspec import Flowspec4, Flowspec6 from flowapp.models.rules.rtbh import RTBH -from flowapp.services.base import announce_rtbh_route +from flowapp.services.base import announce_rtbh_route, withdraw_rtbh_route +from flowapp.services.whitelist_common import add_rtbh_rule_to_cache, create_rtbh_from_whitelist_parts +from flowapp.services.whitelist_common import ( + Relation, + check_whitelist_against_rules, + subtract_network, + whitelist_rtbh_rule, +) from flowapp.utils import round_to_ten_minutes, quote_to_ent @@ -35,10 +42,10 @@ def create_or_update_whitelist( """ # Check for existing model model = get_whitelist_model_if_exists(form_data) - + flashes = [] if model: model.expires = round_to_ten_minutes(form_data["expires"]) - flash_message = "Existing Whitelist found. Expiration time was updated to new value." + flashes.append("Existing Whitelist found. Expiration time was updated to new value.") else: # Create new model model = Whitelist( @@ -50,11 +57,59 @@ def create_or_update_whitelist( comment=quote_to_ent(form_data["comment"]), ) db.session.add(model) - flash_message = "Whitelist saved" + flashes.append("Whitelist saved") db.session.commit() - return model, flash_message + # check RTBH rules against whitelist + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1).all() + print(f"Found {len(all_rtbh_rules)} active RTBH rules") + rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) + result = check_whitelist_against_rules(rtbh_rules_map, str(model)) + print(f"Found {len(result)} matching RTBH rules") + model = evaluate_whitelist_against_rtbh_check_results(model, flashes, rtbh_rules_map, result) + + return model, flashes + + +def evaluate_whitelist_against_rtbh_check_results( + whitelist_model: Whitelist, + flashes: List[str], + rtbh_rule_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> Whitelist: + + for rule_key, whitelist_key, relation in results: + print(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") + match relation: + case Relation.EQUAL: + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + flashes.append(f"Active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is whitelisted.") + case Relation.SUBNET: + parts = subtract_network(target=rule_key, whitelist=whitelist_key) + wl_id = whitelist_model.id + flashes.append( + f" Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + ) + for network in parts: + rule_model = rtbh_rule_cache[rule_key] + create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) + flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + rule_model.rstate_id = 4 + add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) + db.session.commit() + case Relation.SUPERNET: + + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + flashes.append(f"Active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is whitelisted.") + + return whitelist_model + + +def map_rtbh_rules_to_strings(all_rtbh_rules: List[RTBH]) -> Dict[str, RTBH]: + return {str(rule): rule for rule in all_rtbh_rules} def delete_whitelist(whitelist_id: int) -> List[str]: @@ -83,9 +138,13 @@ def delete_whitelist(whitelist_id: int) -> List[str]: db.session.delete(rule_model) elif rorigin_type == RuleOrigin.USER: flashes.append(f"Set rule {rule_model} back to state 'Active'") - rule_model.rstate_id = 1 # Set rule state to "Active" again - author = f"{model.user.email} ({model.user.organization})" - announce_rtbh_route(rule_model, author) + try: + rule_model.rstate_id = 1 # Set rule state to "Active" again + except AttributeError: + print(f"Rule {rule_model} does not exist, cache anomaly?. Skipping.") + else: + author = f"{model.user.email} ({model.user.organization})" + announce_rtbh_route(rule_model, author) flashes.append(f"Deleted cache entries for whitelist {whitelist_id}") RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py new file mode 100644 index 0000000..94ab8b6 --- /dev/null +++ b/flowapp/tests/test_rule_service.py @@ -0,0 +1,605 @@ +""" +Tests for rule_service.py module. + +This test suite verifies the functionality of the rule service after refactoring, +which manages creation, updating, and processing of flow rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleOrigin, ANNOUNCE +from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist +from flowapp.services import rule_service +from flowapp.services.whitelist_common import Relation +import flowapp.services.whitelist_common + + +@pytest.fixture +def ipv4_form_data(): + """Sample valid IPv4 form data""" + return { + "source": "192.168.1.0", + "source_mask": 24, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "protocol": "tcp", + "flags": ["SYN"], + "packet_len": "", + "fragment": ["dont-fragment"], + "comment": "Test IPv4 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def ipv6_form_data(): + """Sample valid IPv6 form data""" + return { + "source": "2001:db8::", + "source_mask": 32, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "next_header": "tcp", + "flags": ["SYN"], + "packet_len": "", + "comment": "Test IPv6 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def rtbh_form_data(): + """Sample valid RTBH form data""" + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "ipv6": "", + "ipv6_mask": None, + "community": 1, + "comment": "Test RTBH rule", + "expires": datetime.now() + timedelta(hours=1), + } + + +@pytest.fixture +def whitelist_fixture(): + """Create a whitelist fixture""" + whitelist = MagicMock(spec=Whitelist) + whitelist.id = 1 + return whitelist + + +class TestCreateOrUpdateIPv4Rule: + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test creating a new IPv4 rule""" + # Mock the get_ipv4_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv4.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv4_form_data["source"] + assert model.source_mask == ipv4_form_data["source_mask"] + assert model.protocol == ipv4_form_data["protocol"] + assert model.flags == "SYN" # form data flags are joined with ";" + assert model.action_id == ipv4_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message + assert message == "IPv4 Rule saved" + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test updating an existing IPv4 rule""" + # Create an existing model to return + existing_model = Flowspec4( + source=ipv4_form_data["source"], + source_mask=ipv4_form_data["source_mask"], + source_port=ipv4_form_data["source_port"], + destination=ipv4_form_data["dest"] or "", + destination_mask=ipv4_form_data["dest_mask"], + destination_port=ipv4_form_data["dest_port"] or "", + protocol=ipv4_form_data["protocol"], + flags=";".join(ipv4_form_data["flags"]), + packet_len=ipv4_form_data["packet_len"] or "", + fragment=";".join(ipv4_form_data["fragment"]), + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv4.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv4_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message + assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateIPv6Rule: + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test creating a new IPv6 rule""" + # Mock get_ipv6_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv6.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv6_form_data["source"] + assert model.source_mask == ipv6_form_data["source_mask"] + assert model.next_header == ipv6_form_data["next_header"] + assert model.flags == "SYN" + assert model.action_id == ipv6_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message + assert message == "IPv6 Rule saved" + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateRTBHRule: + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test creating a new RTBH rule""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list (no matches) + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Call the service function + with app.app_context(): + model, messages = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ipv4 == rtbh_form_data["ipv4"] + assert model.ipv4_mask == rtbh_form_data["ipv4_mask"] + assert model.ipv6 == rtbh_form_data["ipv6"] + assert model.ipv6_mask == rtbh_form_data["ipv6_mask"] + assert model.community_id == rtbh_form_data["community"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify messages + assert "RTBH Rule saved" in messages[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + # Verify evaluate function was called + mock_evaluate.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test updating an existing RTBH rule""" + # Create an existing model to return + existing_model = RTBH( + ipv4=rtbh_form_data["ipv4"], + ipv4_mask=rtbh_form_data["ipv4_mask"], + ipv6=rtbh_form_data["ipv6"] or "", + ipv6_mask=rtbh_form_data["ipv6_mask"], + community_id=rtbh_form_data["community"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + rtbh_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, messages = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify messages + assert "Existing RTBH Rule found" in messages[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestEvaluateRtbhRuleAgainstWhitelists: + def test_equal_relation(self, app, whitelist_fixture): + """Test evaluating a rule with an EQUAL relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.0/24" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.EQUAL)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert any("Rule is equal to active whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == model + + def test_subnet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUBNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.128/25" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUBNET)] + + # Call the function with mocked dependencies + with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch( + "flowapp.services.rule_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.rule_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + with app.app_context(): + _result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("Rule is supernet of active whitelist" in msg for msg in flashes) + assert any("Created RTBH rule for" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert model.rstate_id == 4 + + def test_supernet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUPERNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.0.0/16" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUPERNET)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert any("Rule is subnet of active whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == model + + def test_no_relation(self, app): + """Test evaluating a rule with no relation to any whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + wl_cache = {} + results = [] + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify no changes to the model and no messages + assert result == model + assert not flashes + + +class TestCreateRtbhFromWhitelistParts: + @patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_rtbh_from_whitelist_parts(self, mock_log, mock_announce, mock_add_cache, app, db): + """Test creating RTBH rules from whitelist parts""" + # Create an initial RTBH rule model + model = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Set up test data + network = "192.168.1.128/25" + whitelist_id = 1 + whitelist_key = "192.168.1.128/25" + rule_owner = "test@example.com / Test Org" + user_id = 1 + + # Call the function + with app.app_context(): + db.session.add(model) + db.session.commit() + + flowapp.services.whitelist_common.create_rtbh_from_whitelist_parts( + model, whitelist_id, whitelist_key, network, rule_owner, user_id + ) + + # Verify a new model was created and saved + net_ip, net_mask = network.split("/") + subnet_rules = db.session.query(RTBH).filter(RTBH.ipv4 == net_ip, RTBH.ipv4_mask == net_mask).all() + + assert len(subnet_rules) == 1 + assert subnet_rules[0].rstate_id == 1 # Active state + assert subnet_rules[0].ipv4 == "192.168.1.128" + assert subnet_rules[0].ipv4_mask == 25 + + # Verify caching and announcement + mock_add_cache.assert_called_once() + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestMapWhitelistsToStrings: + def test_map_whitelists_to_strings(self): + """Test mapping whitelist objects to strings""" + # Create mock whitelists + whitelist1 = MagicMock(spec=Whitelist) + whitelist1.__str__.return_value = "192.168.1.0/24" + + whitelist2 = MagicMock(spec=Whitelist) + whitelist2.__str__.return_value = "10.0.0.0/8" + + whitelists = [whitelist1, whitelist2] + + # Call the function + result = rule_service.map_whitelists_to_strings(whitelists) + + # Verify the result + assert len(result) == 2 + assert "192.168.1.0/24" in result + assert "10.0.0.0/8" in result + assert result["192.168.1.0/24"] == whitelist1 + assert result["10.0.0.0/8"] == whitelist2 diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py new file mode 100644 index 0000000..1f5ffd0 --- /dev/null +++ b/flowapp/tests/test_whitelist_service.py @@ -0,0 +1,234 @@ +""" +Tests for whitelist_service.py module. + +This test suite verifies the functionality of the whitelist service, +which manages creation, updating, and handling of whitelist rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import Whitelist, RuleWhitelistCache, RTBH +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_form_data(): + """Sample valid whitelist form data""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist entry", + "expires": datetime.now() + timedelta(hours=1), + } + + +class TestCreateOrUpdateWhitelist: + def test_create_new_whitelist(self, app, db, whitelist_form_data): + """Test creating a new whitelist entry""" + # Mock the get_whitelist_model_if_exists to return False (not found) + with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=False): + # Call the service function + with app.app_context(): + model, message = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ip == whitelist_form_data["ip"] + assert model.mask == whitelist_form_data["mask"] + assert model.comment == whitelist_form_data["comment"] + assert model.user_id == 1 + assert model.org_id == 1 + assert model.rstate_id == 1 # Active state + + # Verify message + assert message == "Whitelist saved" + + # Verify the whitelist was saved to the database + saved_whitelist = db.session.query(Whitelist).filter_by(ip=whitelist_form_data["ip"]).first() + assert saved_whitelist is not None + assert saved_whitelist.id == model.id + + def test_update_existing_whitelist(self, app, db, whitelist_form_data): + """Test updating an existing whitelist entry""" + # Create an existing whitelist + existing_model = Whitelist( + ip=whitelist_form_data["ip"], + mask=whitelist_form_data["mask"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Mock to return the existing model + with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=existing_model): + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + # Check that expiration date was updated (rounded to 10 minutes) + # We can't compare exact timestamps, so check date parts + assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() + + # Verify message + assert message == "Existing Whitelist found. Expiration time was updated to new value." + + +class TestDeleteWhitelist: + @patch("flowapp.services.whitelist_service.announce_rtbh_route") + def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): + """Test deleting a whitelist that has user-created rules attached to it""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by user, whitelisted) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=4, # Whitelisted state + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.USER, + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set whitelist ID now that it has been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Mock the get_by_whitelist_id to return our cache entry + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + messages = whitelist_service.delete_whitelist(whitelist.id) + + # Verify the rule state was changed back to active + rtbh_rule = db.session.get(RTBH, rtbh_rule.id) + assert rtbh_rule.rstate_id == 1 # Active state + + # Verify announcement was made + mock_announce.assert_called_once() + + # Verify messages + assert any("Set rule" in msg for msg in messages) + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id") + def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db): + """Test deleting a whitelist that has rules created by the whitelist""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by whitelist) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set IDs now that they have been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Create a mock session that can get our rule + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + messages = whitelist_service.delete_whitelist(whitelist.id) + + # Verify messages + assert any("Deleted rule" in msg for msg in messages) + + # Verify the rule was deleted + assert db.session.get(RTBH, rtbh_rule.id) is None + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + # Verify cache cleanup was called + mock_clean.assert_called_once_with(whitelist.id) + + def test_delete_nonexistent_whitelist(self, app, db): + """Test deleting a whitelist that doesn't exist""" + with app.app_context(): + # Call the service function with a non-existent ID + messages = whitelist_service.delete_whitelist(999) + + # Should return empty list of messages, as no whitelist was found + assert len(messages) == 0 diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index 9f2bcda..39a9352 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -24,14 +24,15 @@ def add(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model, message = create_or_update_whitelist( + model, messages = create_or_update_whitelist( form.data, user_id=session["user_id"], org_id=session["user_org_id"], user_email=session["user_email"], org_name=session["user_org"], ) - flash(message, "alert-success") + for message in messages: + flash(message, "alert-success") return redirect(url_for("index")) else: From 8b499bc851bc881e4e35625058f13b3a3c6fc045 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 13 Mar 2025 08:43:33 +0100 Subject: [PATCH 16/36] refactoring of services - fixed tests for rule and whitelist services, all tests passing --- flowapp/tests/test_rule_service.py | 266 ++++++++++++---------- flowapp/tests/test_whitelist_service.py | 287 +++++++++++++++++++++--- 2 files changed, 402 insertions(+), 151 deletions(-) diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py index 94ab8b6..16daf10 100644 --- a/flowapp/tests/test_rule_service.py +++ b/flowapp/tests/test_rule_service.py @@ -13,7 +13,6 @@ from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist from flowapp.services import rule_service from flowapp.services.whitelist_common import Relation -import flowapp.services.whitelist_common @pytest.fixture @@ -113,7 +112,7 @@ def test_create_new_ipv4_rule( assert model.user_id == 1 assert model.org_id == 1 - # Verify message + # Verify message is still a string for IPv4 rules assert message == "IPv4 Rule saved" # Verify route was announced @@ -171,7 +170,7 @@ def test_update_existing_ipv4_rule( assert model == existing_model assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - # Verify message + # Verify message is still a string for IPv4 rules assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." # Verify route was announced @@ -216,7 +215,7 @@ def test_create_new_ipv6_rule( assert model.user_id == 1 assert model.org_id == 1 - # Verify message + # Verify message is still a string for IPv6 rules assert message == "IPv6 Rule saved" # Verify route was announced @@ -224,69 +223,12 @@ def test_create_new_ipv6_rule( mock_announce.assert_called_once() mock_log.assert_called_once() - @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") - @patch("flowapp.services.rule_service.messages") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_update_existing_ipv6_rule( - self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data - ): - """Test updating an existing IPv6 rule""" - # Create an existing model to return - existing_model = Flowspec6( - source=ipv6_form_data["source"], - source_mask=ipv6_form_data["source_mask"], - source_port=ipv6_form_data["source_port"] or "", - destination=ipv6_form_data["dest"] or "", - destination_mask=ipv6_form_data["dest_mask"], - destination_port=ipv6_form_data["dest_port"] or "", - next_header=ipv6_form_data["next_header"], - flags=";".join(ipv6_form_data["flags"]), - packet_len=ipv6_form_data["packet_len"] or "", - expires=datetime.now(), - user_id=1, - org_id=1, - action_id=1, - rstate_id=1, - ) - mock_get_model.return_value = existing_model - mock_messages.create_ipv6.return_value = "mock command" - - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - ipv6_form_data["expires"] = new_expires - - # Call the service function - with app.app_context(): - db.session.add(existing_model) - db.session.commit() - - model, message = rule_service.create_or_update_ipv6_rule( - form_data=ipv6_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was updated - assert model == existing_model - assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - - # Verify message - assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." - - # Verify route was announced - mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) - mock_announce.assert_called_once() - mock_log.assert_called_once() - class TestCreateOrUpdateRTBHRule: @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") @patch("flowapp.services.rule_service.db.session.query") @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") @patch("flowapp.services.rule_service.announce_rtbh_route") @patch("flowapp.services.rule_service.log_route") def test_create_new_rtbh_rule( @@ -308,7 +250,7 @@ def test_create_new_rtbh_rule( # Call the service function with app.app_context(): - model, messages = rule_service.create_or_update_rtbh_rule( + model, flashes = rule_service.create_or_update_rtbh_rule( form_data=rtbh_form_data, user_id=1, org_id=1, @@ -327,10 +269,11 @@ def test_create_new_rtbh_rule( assert model.user_id == 1 assert model.org_id == 1 - # Verify messages - assert "RTBH Rule saved" in messages[0] + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] - # Verify route was announced + # Verify rule was announced mock_announce.assert_called_once() mock_log.assert_called_once() @@ -340,7 +283,7 @@ def test_create_new_rtbh_rule( @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") @patch("flowapp.services.rule_service.db.session.query") @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") @patch("flowapp.services.rule_service.announce_rtbh_route") @patch("flowapp.services.rule_service.log_route") def test_update_existing_rtbh_rule( @@ -380,7 +323,7 @@ def test_update_existing_rtbh_rule( db.session.add(existing_model) db.session.commit() - model, messages = rule_service.create_or_update_rtbh_rule( + model, flashes = rule_service.create_or_update_rtbh_rule( form_data=rtbh_form_data, user_id=1, org_id=1, @@ -392,15 +335,89 @@ def test_update_existing_rtbh_rule( assert model == existing_model assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - # Verify messages - assert "Existing RTBH Rule found" in messages[0] + # Verify flash messages + assert isinstance(flashes, list) + assert "Existing RTBH Rule found" in flashes[0] # Verify route was announced mock_announce.assert_called_once() mock_log.assert_called_once() + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.map_whitelists_to_strings") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_rtbh_rule_with_whitelists( + self, + mock_log, + mock_announce, + mock_evaluate, + mock_check, + mock_map, + mock_query, + mock_get_model, + app, + db, + rtbh_form_data, + whitelist_fixture, + ): + """Test creating a RTBH rule that interacts with whitelists""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Create a mock whitelist + mock_whitelist = whitelist_fixture + mock_whitelists = [mock_whitelist] + + # Setup mock query to return our whitelist + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_whitelists + mock_query.return_value = mock_query_result + + # Setup map function to return our whitelist in a dict + whitelist_key = "192.168.1.0/24" + mock_map.return_value = {whitelist_key: mock_whitelist} + + # Setup check function to return a relation + mock_rtbh = MagicMock(spec=RTBH) + mock_rtbh.__str__.return_value = "192.168.1.0/24" + + # Setup check result + rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)] + mock_check.return_value = rule_relation + + # Setup evaluate function to add a flash message + def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results): + flashes.append("Rule is equal to whitelist") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages show both rule creation and whitelist check + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + assert "Rule is equal to whitelist" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + -class TestEvaluateRtbhRuleAgainstWhitelists: +class TestEvaluateRtbhAgainstWhitelistsCheckResults: def test_equal_relation(self, app, whitelist_fixture): """Test evaluating a rule with an EQUAL relation to a whitelist""" # Create a model @@ -532,56 +549,6 @@ def test_no_relation(self, app): assert not flashes -class TestCreateRtbhFromWhitelistParts: - @patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") - @patch("flowapp.services.rule_service.announce_rtbh_route") - @patch("flowapp.services.rule_service.log_route") - def test_create_rtbh_from_whitelist_parts(self, mock_log, mock_announce, mock_add_cache, app, db): - """Test creating RTBH rules from whitelist parts""" - # Create an initial RTBH rule model - model = RTBH( - ipv4="192.168.1.0", - ipv4_mask=24, - ipv6="", - ipv6_mask=None, - community_id=1, - expires=datetime.now() + timedelta(hours=1), - user_id=1, - org_id=1, - rstate_id=1, - ) - - # Set up test data - network = "192.168.1.128/25" - whitelist_id = 1 - whitelist_key = "192.168.1.128/25" - rule_owner = "test@example.com / Test Org" - user_id = 1 - - # Call the function - with app.app_context(): - db.session.add(model) - db.session.commit() - - flowapp.services.whitelist_common.create_rtbh_from_whitelist_parts( - model, whitelist_id, whitelist_key, network, rule_owner, user_id - ) - - # Verify a new model was created and saved - net_ip, net_mask = network.split("/") - subnet_rules = db.session.query(RTBH).filter(RTBH.ipv4 == net_ip, RTBH.ipv4_mask == net_mask).all() - - assert len(subnet_rules) == 1 - assert subnet_rules[0].rstate_id == 1 # Active state - assert subnet_rules[0].ipv4 == "192.168.1.128" - assert subnet_rules[0].ipv4_mask == 25 - - # Verify caching and announcement - mock_add_cache.assert_called_once() - mock_announce.assert_called_once() - mock_log.assert_called_once() - - class TestMapWhitelistsToStrings: def test_map_whitelists_to_strings(self): """Test mapping whitelist objects to strings""" @@ -603,3 +570,60 @@ def test_map_whitelists_to_strings(self): assert "10.0.0.0/8" in result assert result["192.168.1.0/24"] == whitelist1 assert result["10.0.0.0/8"] == whitelist2 + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv6 rules + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py index 1f5ffd0..519c2aa 100644 --- a/flowapp/tests/test_whitelist_service.py +++ b/flowapp/tests/test_whitelist_service.py @@ -7,11 +7,12 @@ import pytest from datetime import datetime, timedelta -from unittest.mock import patch +from unittest.mock import patch, MagicMock from flowapp.constants import RuleTypes, RuleOrigin from flowapp.models import Whitelist, RuleWhitelistCache, RTBH from flowapp.services import whitelist_service +from flowapp.services.whitelist_common import Relation @pytest.fixture @@ -26,13 +27,31 @@ def whitelist_form_data(): class TestCreateOrUpdateWhitelist: - def test_create_new_whitelist(self, app, db, whitelist_form_data): + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_new_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): """Test creating a new whitelist entry""" # Mock the get_whitelist_model_if_exists to return False (not found) - with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=False): - # Call the service function + mock_get_model.return_value = False + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): with app.app_context(): - model, message = whitelist_service.create_or_update_whitelist( + model, flashes = whitelist_service.create_or_update_whitelist( form_data=whitelist_form_data, user_id=1, org_id=1, @@ -49,15 +68,16 @@ def test_create_new_whitelist(self, app, db, whitelist_form_data): assert model.org_id == 1 assert model.rstate_id == 1 # Active state - # Verify message - assert message == "Whitelist saved" + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Whitelist saved" in flashes[0] - # Verify the whitelist was saved to the database - saved_whitelist = db.session.query(Whitelist).filter_by(ip=whitelist_form_data["ip"]).first() - assert saved_whitelist is not None - assert saved_whitelist.id == model.id - - def test_update_existing_whitelist(self, app, db, whitelist_form_data): + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_update_existing_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): """Test updating an existing whitelist entry""" # Create an existing whitelist existing_model = Whitelist( @@ -70,17 +90,30 @@ def test_update_existing_whitelist(self, app, db, whitelist_form_data): ) # Mock to return the existing model - with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=existing_model): - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - whitelist_form_data["expires"] = new_expires + mock_get_model.return_value = existing_model + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] - # Call the service function + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): with app.app_context(): db.session.add(existing_model) db.session.commit() - model, message = whitelist_service.create_or_update_whitelist( + model, flashes = whitelist_service.create_or_update_whitelist( form_data=whitelist_form_data, user_id=1, org_id=1, @@ -94,8 +127,70 @@ def test_update_existing_whitelist(self, app, db, whitelist_form_data): # We can't compare exact timestamps, so check date parts assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() - # Verify message - assert message == "Existing Whitelist found. Expiration time was updated to new value." + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Existing Whitelist found" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.db.session.query") + @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_whitelist_with_matching_rules( + self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a whitelist that affects existing rules""" + # Mock get_whitelist_model_if_exists to return False (new whitelist) + mock_get_model.return_value = False + + # Create a mock RTBH rule + mock_rtbh_rule = MagicMock(spec=RTBH) + mock_rtbh_rule.__str__.return_value = "192.168.1.0/24" + mock_rtbh_rules = [mock_rtbh_rule] + + # Setup mock query to return our rule + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules + mock_query.return_value = mock_query_result + + # Setup map function to return our rule in a dict + mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule} + + # Setup check function to return a relation + rule_relation = [ + ( + str(mock_rtbh_rule), + str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]), + Relation.EQUAL, + ) + ] + mock_check.return_value = rule_relation + + # Setup evaluate function to modify flashes + def evaluate_side_effect(model, flashes, rtbh_rule_cache, results): + flashes.append("Rule was whitelisted") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages includes both whitelist creation and rule whitelisting + assert "Whitelist saved" in flashes[0] + assert "Rule was whitelisted" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() class TestDeleteWhitelist: @@ -147,7 +242,7 @@ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): # Mock the get_by_whitelist_id to return our cache entry with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): # Call the service function - messages = whitelist_service.delete_whitelist(whitelist.id) + flashes = whitelist_service.delete_whitelist(whitelist.id) # Verify the rule state was changed back to active rtbh_rule = db.session.get(RTBH, rtbh_rule.id) @@ -156,8 +251,9 @@ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): # Verify announcement was made mock_announce.assert_called_once() - # Verify messages - assert any("Set rule" in msg for msg in messages) + # Verify flash messages + assert isinstance(flashes, list) + assert any("Set rule" in msg for msg in flashes) # Verify the whitelist was deleted assert db.session.get(Whitelist, whitelist.id) is None @@ -210,10 +306,11 @@ def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db # Create a mock session that can get our rule with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): # Call the service function - messages = whitelist_service.delete_whitelist(whitelist.id) + flashes = whitelist_service.delete_whitelist(whitelist.id) - # Verify messages - assert any("Deleted rule" in msg for msg in messages) + # Verify flash messages + assert isinstance(flashes, list) + assert any("Deleted rule" in msg for msg in flashes) # Verify the rule was deleted assert db.session.get(RTBH, rtbh_rule.id) is None @@ -228,7 +325,137 @@ def test_delete_nonexistent_whitelist(self, app, db): """Test deleting a whitelist that doesn't exist""" with app.app_context(): # Call the service function with a non-existent ID - messages = whitelist_service.delete_whitelist(999) + flashes = whitelist_service.delete_whitelist(999) + + # Should return empty list of flash messages, as no whitelist was found + assert isinstance(flashes, list) + assert len(flashes) == 0 + + +class TestEvaluateWhitelistAgainstRtbhResults: + def test_equal_relation(self, app): + """Test evaluating a whitelist with an EQUAL relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.0/24" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.EQUAL)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("equal to whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == whitelist_model + + def test_subnet_relation(self, app): + """Test evaluating a whitelist with a SUBNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.128/25" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUBNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch( + "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.whitelist_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + # Call the function + with app.app_context(): + _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("supernet of whitelist" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert rtbh_rule.rstate_id == 4 + + def test_supernet_relation(self, app): + """Test evaluating a whitelist with a SUPERNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.0.0/16" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUPERNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("subnet of whitelist" in msg for msg in flashes) - # Should return empty list of messages, as no whitelist was found - assert len(messages) == 0 + # Verify the correct model was returned + assert result == whitelist_model From 2fcdb65bd18e7a6514abd347e9d127c934bf8522 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 08:03:50 +0100 Subject: [PATCH 17/36] Improve rule management, logging, and RabbitMQ connection handling Added delete_by_rule_id method to RuleWhitelistCache for efficient cache entry removal. Used with statement for RabbitMQ connection handling in announce_to_rabbitmq to ensure proper cleanup. Introduced ALLOWED_COMMUNITIES check in rule and whitelist services for better filtering. Replaced print statements with structured logging using current_app.logger for better debugging. Improved logging in whitelist deletion to handle anomalies and prevent errors. Refactored logging configuration to use a standardized handler instead of loguru. Updated bulk_user_form.html example data to clarify user roles. --- flowapp/models/rules/whitelist.py | 15 ++++++++++ flowapp/output.py | 8 +++--- flowapp/services/rule_service.py | 16 +++++++---- flowapp/services/whitelist_service.py | 31 ++++++++++++++++----- flowapp/templates/forms/bulk_user_form.html | 6 ++-- flowapp/utils/app_factory.py | 26 +++++++++-------- flowapp/views/rules.py | 4 +++ 7 files changed, 75 insertions(+), 31 deletions(-) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index e1857cc..6384bdd 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -128,6 +128,21 @@ def clean_by_whitelist_id(cls, whitelist_id: int): db.session.commit() return deleted + @classmethod + def delete_by_rule_id(cls, rule_id: int): + """ + Delete all cache entries with the given rule ID from the database + + Args: + rule_id (int): The ID of the rule to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(rid=rule_id).delete() + db.session.commit() + return deleted + def __repr__(self): return f"" diff --git a/flowapp/output.py b/flowapp/output.py index b8774bf..fcb385f 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -84,10 +84,10 @@ def announce_to_rabbitmq(route: Dict[str, str]) -> None: credentials, ) - connection = pika.BlockingConnection(parameters) - channel = connection.channel() - channel.queue_declare(queue=queue) - channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) + with pika.BlockingConnection(parameters) as connection: + channel = connection.channel() + channel.queue_declare(queue=queue) + channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) else: current_app.logger.debug(f"Testing: {route}") diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 85c177a..1611df0 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -9,6 +9,8 @@ from datetime import datetime from typing import Dict, List, Tuple +from flask import current_app + from flowapp import db, messages from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE from flowapp.models import ( @@ -213,12 +215,14 @@ def create_or_update_rtbh_rule( author = f"{user_email} / {org_name}" # Check if rule is whitelisted - # get all not expired whitelists - whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() - wl_cache = map_whitelists_to_strings(whitelists) - results = check_rule_against_whitelists(str(model), wl_cache.keys()) - # check rule against whitelists, stop search when rule is whitelisted first time - model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + if model.community_id in allowed_communities: + # get all not expired whitelists + whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + wl_cache = map_whitelists_to_strings(whitelists) + results = check_rule_against_whitelists(str(model), wl_cache.keys()) + # check rule against whitelists, stop search when rule is whitelisted first time + model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) announce_rtbh_route(model, author=author) # Log changes diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 3f57f7e..d82fd29 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -5,9 +5,11 @@ This module provides business logic functions for creating, updating, and managing flow rules, separating these concerns from HTTP handling. """ - +from flask import current_app from typing import Dict, Tuple, List +import sqlalchemy + from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists @@ -62,11 +64,13 @@ def create_or_update_whitelist( db.session.commit() # check RTBH rules against whitelist - all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1).all() - print(f"Found {len(all_rtbh_rules)} active RTBH rules") + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + current_app.logger.info(f"allowed communities: {allowed_communities}") + # filter out RTBH rules that are not active and not in allowed communities + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1, RTBH.community_id.in_(allowed_communities)).all() rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) result = check_whitelist_against_rules(rtbh_rules_map, str(model)) - print(f"Found {len(result)} matching RTBH rules") + current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") model = evaluate_whitelist_against_rtbh_check_results(model, flashes, rtbh_rules_map, result) return model, flashes @@ -80,7 +84,7 @@ def evaluate_whitelist_against_rtbh_check_results( ) -> Whitelist: for rule_key, whitelist_key, relation in results: - print(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") + current_app.logger.info(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") match relation: case Relation.EQUAL: whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) @@ -123,7 +127,11 @@ def delete_whitelist(whitelist_id: int) -> List[str]: flashes = [] if model: cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) + current_app.logger.info( + f"Deleting whitelist {whitelist_id}. Found {len(cached_rules)} cached rules to process." + ) for cached_rule in cached_rules: + current_app.logger.debug(f"Processing cached rule {cached_rule}") rule_model_type = RuleTypes(cached_rule.rtype) match rule_model_type: case RuleTypes.IPv4: @@ -133,15 +141,24 @@ def delete_whitelist(whitelist_id: int) -> List[str]: case RuleTypes.RTBH: rule_model = db.session.get(RTBH, cached_rule.rid) rorigin_type = RuleOrigin(cached_rule.rorigin) + current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") if rorigin_type == RuleOrigin.WHITELIST: flashes.append(f"Deleted rule {rule_model} created by this whitelist") - db.session.delete(rule_model) + try: + db.session.delete(rule_model) + except sqlalchemy.orm.exc.UnmappedInstanceError: + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) + elif rorigin_type == RuleOrigin.USER: flashes.append(f"Set rule {rule_model} back to state 'Active'") try: rule_model.rstate_id = 1 # Set rule state to "Active" again except AttributeError: - print(f"Rule {rule_model} does not exist, cache anomaly?. Skipping.") + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) else: author = f"{model.user.email} ({model.user.organization})" announce_rtbh_route(rule_model, author) diff --git a/flowapp/templates/forms/bulk_user_form.html b/flowapp/templates/forms/bulk_user_form.html index c41389e..b15b3ff 100644 --- a/flowapp/templates/forms/bulk_user_form.html +++ b/flowapp/templates/forms/bulk_user_form.html @@ -42,9 +42,9 @@

Example CSV data

             
 uuid-eppn,name,telefon,email,role,organizace,poznamka
-view@example.com,Test View,123,view@example.com,1,1,View
-user@example.com,Test User,123456,user@example.com,2,1,User
-admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,Admin
+view@example.com,Test View,123,view@example.com,1,1,user with view role (1)
+user@example.com,Test User,123456,user@example.com,2,1,regular user (role 2)
+admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,user with admin role (3)
             
         

Role

diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py index b53b23b..bc6559b 100644 --- a/flowapp/utils/app_factory.py +++ b/flowapp/utils/app_factory.py @@ -1,8 +1,6 @@ import logging import babel from flask import redirect, render_template, request, session, url_for -from flask.logging import default_handler -from loguru import logger def register_blueprints(app, csrf=None): @@ -35,18 +33,24 @@ def register_blueprints(app, csrf=None): return app -class InterceptHandler(logging.Handler): +def configure_logging(app): + """Configure logging for the Flask application.""" - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) - logger_opt.log(record.levelname, record.getMessage()) + # Remove all default handlers + for handler in app.logger.handlers[:]: + app.logger.removeHandler(handler) + # Define log format + log_format = "%(asctime)s | %(levelname)s | %(message)s" + log_datefmt = "%Y-%m-%d %H:%M:%S" -def configure_logging(app): - """Configure logging for the application.""" - # register loguru as handler - app.logger.removeHandler(default_handler) - app.logger.addHandler(InterceptHandler()) + # Create a new handler with the desired format + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt)) + + # Set logger level and attach the handler + app.logger.setLevel(logging.DEBUG) + app.logger.addHandler(console_handler) return app diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index fc787cf..4cad6d2 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -29,6 +29,7 @@ get_user_nets, insert_initial_communities, ) +from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route from flowapp.services import rule_service from flowapp.utils import ( @@ -208,6 +209,9 @@ def delete_rule(rule_type, rule_id): model.id, f"{session['user_email']} / {session['user_org']}", ) + if enum_rule_type == RuleTypes.RTBH: + current_app.logger.debug(f"Deleting RTBH rule {rule_id} from cache") + RuleWhitelistCache.delete_by_rule_id(rule_id) # delete from db db.session.delete(model) From 04b5de126522c960ddf6a266bb16c2b568c69098 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 09:10:02 +0100 Subject: [PATCH 18/36] Enhance rule visibility and styling in dashboard tables Added whitelist_rule_ids parameter to build_ip_tbody and build_rtbh_tbody macros to visually distinguish whitelist-created rules. Expired rules highlighted in table-warning. Whitelisted rules marked with table-success. Rules created by Whitelist marked with in table-secondary. Refactored index view to enrich rule data with whitelist info via enrich_rules_with_whitelist_info. Passed whitelist_rule_ids to dashboard rendering functions to ensure consistent rule styling. --- flowapp/templates/macros.html | 93 ++++++++++++++++++----------------- flowapp/views/dashboard.py | 62 +++++++++++++++++++++-- 2 files changed, 106 insertions(+), 49 deletions(-) diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 653cc81..01a2fca 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -1,5 +1,5 @@ -{% macro build_ip_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_ip_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None) %} {% for rule in rules %} {% if rule.next_header is defined %} @@ -8,7 +8,11 @@ {% set rtype_int = 4 %} {% endif %} - + {{ rule.source }}{% if rule.source_mask != none %}{{ '/' if rule.source_mask >= 0 else '' }}{{ rule.source_mask if rule.source_mask >= 0 else '' }}{% endif %} @@ -73,51 +77,52 @@ {% endmacro %} -{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None) %} {% for rule in rules %} - - - {% if rule.ipv4 %} - {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} - {% endif %} - {% if rule.ipv6 %} - {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} - {% endif %} - - - {{ rule.community.name }} - - - - {{ rule.expires|strftime }} - - - {{ rule.user.name }} - - - {% if editable %} - - - - - - + + + {% if rule.ipv4 %} + {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} + {% endif %} + {% if rule.ipv6 %} + {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} + {% endif %} + + + {{ rule.community.name }} + + + + {{ rule.expires|strftime }} + + + {{ rule.user.name }} + + + {% if editable %} + + + + + + + {% endif %} + {% if rule.comment %} + + {% endif %} + + {% if editable and group_op %} + + + {% endif %} - {% if rule.comment %} - - {% endif %} - - {% if editable and group_op %} - - - - {% endif %} {% endfor %} diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index 8ed262f..3980b88 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -48,7 +48,7 @@ def whois(ip_address): def index(rtype=None, rstate="active"): """ dispatcher object for the dashboard - :param rtype: ipv4, ipv6, rtbh + :param rtype: ipv4, ipv6, rtbh, whitelist :param rstate: :return: view from view factory """ @@ -84,7 +84,6 @@ def index(rtype=None, rstate="active"): data_handler_module = current_app.config["DASHBOARD"].get(rtype).get("data_handler", models) data_handler_method = current_app.config["DASHBOARD"].get(rtype).get("data_handler_method", "get_ip_rules") - # get search query, sort order and sort key from request or session get_search_query = request.args.get(SEARCH_ARG, session.get(SEARCH_ARG, "")) get_sort_key = request.args.get(SORT_ARG, session.get(SORT_ARG, DEFAULT_SORT)) @@ -109,6 +108,10 @@ def index(rtype=None, rstate="active"): # get the handler and the data handler = getattr(data_handler_module, data_handler_method) rules = handler(rtype, rstate, get_sort_key, get_sort_order) + + # Enrich rules with whitelist information + rules, whitelist_rule_ids = enrich_rules_with_whitelist_info(rules, rtype) + session[RULES_KEY] = [rule.id for rule in rules] # search rules if get_search_query: @@ -138,6 +141,7 @@ def index(rtype=None, rstate="active"): macro_tbody=macro_tbody, macro_thead=macro_thead, macro_tfoot=macro_tfoot, + whitelist_rule_ids=whitelist_rule_ids, ) @@ -148,18 +152,22 @@ def create_dashboard_table_body( group_op=True, macro_file="macros.html", macro_name="build_ip_tbody", + whitelist_rule_ids=None, ): """ create the table body for the dashboard using a jinja2 macro :param rules: list of rules :param rtype: ipv4, ipv6, rtbh + :param editable: whether rules can be edited + :param group_op: whether group operations are allowed :param macro_file: the file where the macro is defined :param macro_name: the name of the macro + :param whitelist_rule_ids: set of rule IDs that were created by a whitelist """ tstring = "{% " tstring = tstring + f"from '{macro_file}' import {macro_name}" tstring = tstring + " %} {{" - tstring = tstring + f" {macro_name}(rules, today, editable, group_op) " + "}}" + tstring = tstring + f" {macro_name}(rules, today, editable, group_op, whitelist_rule_ids) " + "}}" dashboard_table_body = render_template_string( tstring, @@ -167,6 +175,7 @@ def create_dashboard_table_body( today=datetime.now(), editable=editable, group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids or set(), ) return dashboard_table_body @@ -246,6 +255,7 @@ def create_admin_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_group_buttons_tfoot", + whitelist_rule_ids=None, ): """ Admin can see and edit any rules @@ -257,7 +267,9 @@ def create_admin_response( :return: """ - dashboard_table_body = create_dashboard_table_body(rules, rtype, macro_file=macro_file, macro_name=macro_tbody) + dashboard_table_body = create_dashboard_table_body( + rules, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids + ) dashboard_table_head = create_dashboard_table_head( rules_columns=table_columns, @@ -313,6 +325,7 @@ def create_user_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, ): """ Filter out the rules for normal users @@ -343,9 +356,10 @@ def create_user_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_editable = create_dashboard_table_body( - rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody + rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids ) dashboard_table_editable_head = create_dashboard_table_head( rules_columns=table_columns, @@ -419,6 +433,7 @@ def create_view_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, ): """ Filter out the rules for normal users @@ -433,6 +448,7 @@ def create_view_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_head = create_dashboard_table_head( @@ -482,3 +498,39 @@ def filter_rules(rules, get_search_query): result.append(rules[idx]) return result + + +def enrich_rules_with_whitelist_info(rules, rule_type): + """ + Enrich rules with whitelist information from RuleWhitelistCache. + + Args: + rules: List of rule objects (Flowspec4, Flowspec6, RTBH) + rule_type: String identifier of rule type ("ipv4", "ipv6", "rtbh") + + Returns: + Tuple of (rules, whitelist_rule_ids) where whitelist_rule_ids is a set of + rule IDs that were created by a whitelist. + """ + from flowapp.models.rules.whitelist import RuleWhitelistCache + from flowapp.constants import RuleTypes, RuleOrigin + + # Map rule type string to enum value + rule_type_map = {"ipv4": RuleTypes.IPv4.value, "ipv6": RuleTypes.IPv6.value, "rtbh": RuleTypes.RTBH.value} + + # Get all rule IDs + rule_ids = [rule.id for rule in rules] + + # No rules to process + if not rule_ids: + return rules, set() + + # Query the cache for these rule IDs + cache_entries = RuleWhitelistCache.query.filter( + RuleWhitelistCache.rid.in_(rule_ids), RuleWhitelistCache.rtype == rule_type_map.get(rule_type) + ).all() + + # Create a set of rule IDs that were created by a whitelist + whitelist_rule_ids = {entry.rid for entry in cache_entries if entry.rorigin == RuleOrigin.WHITELIST.value} + + return rules, whitelist_rule_ids From ed594cf0a13e6e06512e83a3490c06b2329a10d6 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 12:05:30 +0100 Subject: [PATCH 19/36] Improve rule-whitelist interactions and UI updates Added count_by_rule method in RuleWhitelistCache to count cache entries per rule. Modified evaluate_rtbh_against_whitelists_check_results to process all whitelist matches instead of stopping at the first. Adjusted RTBH rule filtering to include both active (rstate_id=1) and whitelisted (rstate_id=4) rules when checking against whitelists. Improved whitelist deletion logic to only revert user-created rules to Active if they have no other whitelist references. Updated build_whitelist_tbody macro to accept whitelist_rule_ids for consistent rule highlighting. --- flowapp/models/rules/whitelist.py | 14 ++++++++++++++ flowapp/services/rule_service.py | 7 ++++--- flowapp/services/whitelist_service.py | 8 ++++---- flowapp/templates/macros.html | 2 +- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 6384bdd..29f2b08 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -143,6 +143,20 @@ def delete_by_rule_id(cls, rule_id: int): db.session.commit() return deleted + @classmethod + def count_by_rule(cls, rule_id: int, rule_type: RuleTypes): + """ + Count the number of cache entries for the given rule + + Args: + rule_id (int): The ID of the rule to count + rule_type (RuleTypes): The type of the rule + + Returns: + int: Number of cache entries + """ + return cls.query.filter_by(rid=rule_id, rtype=rule_type.value).count() + def __repr__(self): return f"" diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 1611df0..e0f1149 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -239,12 +239,15 @@ def evaluate_rtbh_against_whitelists_check_results( wl_cache: Dict[str, Whitelist], results: List[Tuple[str, str, Relation]], ) -> RTBH: + """ + Evaluate RTBH rule against whitelist check results. + Process all results for cases where rule is whitelisted by several whitelists. + """ for rule, whitelist_key, relation in results: match relation: case Relation.EQUAL: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") - break case Relation.SUBNET: parts = subtract_network(target=str(model), whitelist=whitelist_key) wl_id = wl_cache[whitelist_key].id @@ -257,11 +260,9 @@ def evaluate_rtbh_against_whitelists_check_results( model.rstate_id = 4 add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) db.session.commit() - break case Relation.SUPERNET: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") - break return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index d82fd29..09df687 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -65,9 +65,8 @@ def create_or_update_whitelist( # check RTBH rules against whitelist allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] - current_app.logger.info(f"allowed communities: {allowed_communities}") - # filter out RTBH rules that are not active and not in allowed communities - all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1, RTBH.community_id.in_(allowed_communities)).all() + # filter out RTBH rules that are not active or whitelisted and not in allowed communities + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id.in_([1, 4]), RTBH.community_id.in_(allowed_communities)).all() rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) result = check_whitelist_against_rules(rtbh_rules_map, str(model)) current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") @@ -133,6 +132,7 @@ def delete_whitelist(whitelist_id: int) -> List[str]: for cached_rule in cached_rules: current_app.logger.debug(f"Processing cached rule {cached_rule}") rule_model_type = RuleTypes(cached_rule.rtype) + cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) match rule_model_type: case RuleTypes.IPv4: rule_model = db.session.get(Flowspec4, cached_rule.rid) @@ -151,7 +151,7 @@ def delete_whitelist(whitelist_id: int) -> List[str]: f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." ) - elif rorigin_type == RuleOrigin.USER: + elif rorigin_type == RuleOrigin.USER and cache_entries_count == 1: flashes.append(f"Set rule {rule_model} back to state 'Active'") try: rule_model.rstate_id = 1 # Set rule state to "Active" again diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 01a2fca..2f5dd9f 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -130,7 +130,7 @@ {% endmacro %} -{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None) %} {% for rule in rules %} From 7ca0b03a874fccf435bd87054f29cdef7f92f2fa Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 12:10:24 +0100 Subject: [PATCH 20/36] fixed app config for tests --- flowapp/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index 65919af..5960434 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -67,6 +67,7 @@ def app(request): SECRET_KEY="testkeysession", LOCAL_USER_UUID="jiri.vrany@cesnet.cz", LOCAL_AUTH=True, + ALLOWED_COMMUNITIES=[1, 2, 3], ) print("\n----- CREATE FLASK APPLICATION\n") From 72fd0c95382052867a40b91c3ee1dd4af8ed2ae6 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 13:19:15 +0100 Subject: [PATCH 21/36] Enhance IPv4/IPv6 checks by introducing _is_same_ip_version, update relevant functions to handle mixed IP versions gracefully, and add new test cases in test_whitelist_common.py to ensure proper coverage --- flowapp/services/whitelist_common.py | 136 +++++++++++++++++-------- flowapp/tests/test_whitelist_common.py | 118 +++++++++++++++++++-- 2 files changed, 206 insertions(+), 48 deletions(-) diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 1730402..065b79a 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -45,6 +45,19 @@ class Relation(Enum): DIFFERENT = auto() +def _is_same_ip_version(addr1: str, addr2: str) -> bool: + """ + Check if two IP addresses/networks are of the same IP version. + + :param addr1: First IP address or network string + :param addr2: Second IP address or network string + :return: True if both addresses are of the same IP version (IPv4 or IPv6), False otherwise + """ + is_ipv4_1 = "." in addr1 + is_ipv4_2 = "." in addr2 + return is_ipv4_1 == is_ipv4_2 + + @lru_cache(maxsize=1024) def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: """ @@ -58,22 +71,34 @@ def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: """ - Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. Uses cached network objects for better performance. + If the rule and whitelist are different IP versions (IPv4 vs IPv6), + they are treated as different networks. + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") :return: Relation between the two networks """ - rule_net = get_network(rule) - whitelist_net = get_network(whitelist_entry) - if whitelist_net == rule_net: - return Relation.EQUAL - if whitelist_net.supernet_of(rule_net): - return Relation.SUPERNET - if whitelist_net.subnet_of(rule_net): - return Relation.SUBNET - else: + # First check if IP versions are the same + if not _is_same_ip_version(rule, whitelist_entry): + return Relation.DIFFERENT + + try: + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + except (ValueError, TypeError): + # Handle any other errors that might occur during comparison return Relation.DIFFERENT @@ -82,34 +107,45 @@ def subtract_network(target: str, whitelist: str) -> List[str]: Computes the remaining parts of a network after removing the whitelist subnet. Uses cached network objects for better performance. + If the target and whitelist are different IP versions (IPv4 vs IPv6), + the original target is returned unchanged. + :param target: The main network (e.g., "192.168.1.0/24") :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") :return: A list of remaining subnets as strings """ - target_net = get_network(target) - whitelist_net = get_network(whitelist) + # First check if IP versions are the same + if not _is_same_ip_version(target, whitelist): + return [target] - # Check if the whitelist is actually a subnet - if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: - return [target] # Return the full network if whitelist isn't a valid subnet + try: + target_net = get_network(target) + whitelist_net = get_network(whitelist) - remaining = [] + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet - # Compute ranges before and after the whitelist - if whitelist_net.network_address > target_net.network_address: - # Before the whitelist - start = target_net.network_address - end = whitelist_net.network_address - 1 - remaining.extend(ipaddress.summarize_address_range(start, end)) + remaining = [] - if whitelist_net.broadcast_address < target_net.broadcast_address: - # After the whitelist - start = whitelist_net.broadcast_address + 1 - end = target_net.broadcast_address - remaining.extend(ipaddress.summarize_address_range(start, end)) + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) - # Convert to string format - return [str(net) for net in remaining] + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + except (ValueError, TypeError): + # Return the original target in case of any error + return [target] def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: @@ -123,12 +159,21 @@ def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tupl :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT """ # Pre-cache the rule network since it will be used multiple times - get_network(rule) + try: + get_network(rule) + except (ValueError, TypeError): + # Return empty list if rule is not a valid network + return [] + items = [] for whitelist in whitelists: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this whitelist if there's an error comparing + continue return items @@ -138,17 +183,26 @@ def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tupl Creates a cached rule network object for better performance. Reduces list of rules, where the Relation is not DIFFERENT - :param rule: The IP address or network to check against - :param whitelists: List of whitelist networks to check - :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + :param rules: List of rule networks to check against + :param whitelist: The whitelist network to check + :return: tuple of rule, whitelist and relation for each rules that is not DIFFERENT """ - # Pre-cache the rule network since it will be used multiple times - get_network(whitelist) + # Pre-cache the whitelist network since it will be used multiple times + try: + get_network(whitelist) + except (ValueError, TypeError): + # Return empty list if whitelist is not a valid network + return [] + items = [] for rule in rules: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this rule if there's an error comparing + continue return items diff --git a/flowapp/tests/test_whitelist_common.py b/flowapp/tests/test_whitelist_common.py index 9b98e9a..a843aff 100644 --- a/flowapp/tests/test_whitelist_common.py +++ b/flowapp/tests/test_whitelist_common.py @@ -6,6 +6,7 @@ check_whitelist_to_rule_relation, subtract_network, clear_network_cache, + _is_same_ip_version, # New helper function ) @@ -130,16 +131,119 @@ def test_single_ip_as_network(): assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET -def test_invalid_input(): +def test_is_same_ip_version(): + """Test the new helper function to check if two addresses are of the same IP version""" + # IPv4 cases + assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True + assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True + + # IPv6 cases + assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True + assert _is_same_ip_version("2001:db8::1", "fe80::1") is True + + # Mixed cases + assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False + assert _is_same_ip_version("192.168.1.1", "fe80::1") is False + + +def test_check_whitelist_to_rule_relation_mixed_versions(): + """Test the check_whitelist_to_rule_relation function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule with IPv6 whitelist + assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT + + # IPv6 rule with IPv4 whitelist + assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT + + # Invalid IP addresses should return DIFFERENT + assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT + assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT + + +def test_subtract_network_mixed_versions(): + """Test the subtract_network function with mixed IP versions""" + clear_network_cache() + + # IPv4 target with IPv6 whitelist - should return original target + result = subtract_network("192.168.1.0/24", "2001:db8::/32") + assert result == ["192.168.1.0/24"] + + # IPv6 target with IPv4 whitelist - should return original target + result = subtract_network("2001:db8::/32", "192.168.1.0/24") + assert result == ["2001:db8::/32"] + + # Invalid addresses - should return original target + result = subtract_network("invalid", "192.168.1.0/24") + assert result == ["invalid"] + result = subtract_network("192.168.1.0/24", "invalid") + assert result == ["192.168.1.0/24"] + + +def test_check_rule_against_whitelists_mixed_versions(): + """Test the check_rule_against_whitelists function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule against mixed whitelist + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "192.168.0.0/16" + assert results[0][2] == Relation.SUPERNET + + # IPv6 rule against mixed whitelist + rule = "2001:db8:1::/48" + whitelists = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "2001:db8::/32" + assert results[0][2] == Relation.SUPERNET + + +def test_check_whitelist_against_rules_mixed_versions(): + """Test the check_whitelist_against_rules function with mixed IP versions""" clear_network_cache() - with pytest.raises(ValueError): - check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") - with pytest.raises(ValueError): - check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") + # IPv4 whitelist against mixed rules + whitelist = "192.168.1.0/24" + rules = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "192.168.0.0/16" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET - with pytest.raises(ValueError): - subtract_network("invalid", "192.168.1.0/24") + # IPv6 whitelist against mixed rules + whitelist = "2001:db8:1::/48" + rules = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "2001:db8::/32" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET if __name__ == "__main__": From c9969c603116ba7f095fa9043cc0641894a609e8 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 17 Mar 2025 12:42:43 +0100 Subject: [PATCH 22/36] improved logging of whitelist service --- flowapp/services/whitelist_service.py | 38 +++++++++++++++++++-------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 09df687..b749b4b 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -88,17 +88,21 @@ def evaluate_whitelist_against_rtbh_check_results( case Relation.EQUAL: whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) withdraw_rtbh_route(rtbh_rule_cache[rule_key]) - flashes.append(f"Active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is whitelisted.") + msg = "Existing active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is now whitelisted." + flashes.append(msg) + current_app.logger.info(msg) case Relation.SUBNET: parts = subtract_network(target=rule_key, whitelist=whitelist_key) wl_id = whitelist_model.id - flashes.append( - f" Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." - ) + msg = f"Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules will be created." + flashes.append(msg) + current_app.logger.info(msg) for network in parts: rule_model = rtbh_rule_cache[rule_key] create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) - flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + msg = f"Created RTBH rule from {rule_model.id} {network} parted by whitelist {whitelist_key}." + flashes.append(msg) + current_app.logger.info(msg) rule_model.rstate_id = 4 add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) db.session.commit() @@ -106,7 +110,11 @@ def evaluate_whitelist_against_rtbh_check_results( whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) withdraw_rtbh_route(rtbh_rule_cache[rule_key]) - flashes.append(f"Active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is whitelisted.") + msg = ( + f"Existing active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is now whitelisted." + ) + current_app.logger.info(msg) + flashes.append(msg) return whitelist_model @@ -127,10 +135,9 @@ def delete_whitelist(whitelist_id: int) -> List[str]: if model: cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) current_app.logger.info( - f"Deleting whitelist {whitelist_id}. Found {len(cached_rules)} cached rules to process." + f"Deleting whitelist {whitelist_id} {model}. Found {len(cached_rules)} cached rules to process." ) for cached_rule in cached_rules: - current_app.logger.debug(f"Processing cached rule {cached_rule}") rule_model_type = RuleTypes(cached_rule.rtype) cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) match rule_model_type: @@ -143,7 +150,9 @@ def delete_whitelist(whitelist_id: int) -> List[str]: rorigin_type = RuleOrigin(cached_rule.rorigin) current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") if rorigin_type == RuleOrigin.WHITELIST: - flashes.append(f"Deleted rule {rule_model} created by this whitelist") + msg = f"Deleted {rule_model_type} rule {rule_model} created by whitelist {model}" + current_app.logger.info(msg) + flashes.append(msg) try: db.session.delete(rule_model) except sqlalchemy.orm.exc.UnmappedInstanceError: @@ -152,7 +161,9 @@ def delete_whitelist(whitelist_id: int) -> List[str]: ) elif rorigin_type == RuleOrigin.USER and cache_entries_count == 1: - flashes.append(f"Set rule {rule_model} back to state 'Active'") + msg = f"Set rule {rule_model} back to state 'Active'" + current_app.logger.info(msg) + flashes.append(msg) try: rule_model.rstate_id = 1 # Set rule state to "Active" again except AttributeError: @@ -163,10 +174,15 @@ def delete_whitelist(whitelist_id: int) -> List[str]: author = f"{model.user.email} ({model.user.organization})" announce_rtbh_route(rule_model, author) - flashes.append(f"Deleted cache entries for whitelist {whitelist_id}") + msg = f"Deleted cache entries for whitelist {whitelist_id} {model}" + current_app.logger.info(msg) + flashes.append(msg) RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) db.session.delete(model) db.session.commit() + msg = f"Deleted whitelist {whitelist_id} {model}" + flashes.append(msg) + current_app.logger.info(msg) return flashes From 412dd845fcbd4d831fd6bed5a28b74c28f3845d5 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 17 Mar 2025 19:49:11 +0100 Subject: [PATCH 23/36] Refactor create_app config loading, enhance RTBH logging and flash messages, add file-based logging in app_factory, and update tests to assert flashes instead of specific messages. --- flowapp/__init__.py | 14 +++++++------- flowapp/services/rule_service.py | 24 +++++++++++++++++------- flowapp/services/whitelist_common.py | 4 +++- flowapp/tests/test_rule_service.py | 7 +++---- flowapp/tests/test_whitelist_service.py | 6 +++--- flowapp/utils/app_factory.py | 20 ++++++++++++++++---- withdraw_expired | 1 + 7 files changed, 50 insertions(+), 26 deletions(-) create mode 100644 withdraw_expired diff --git a/flowapp/__init__.py b/flowapp/__init__.py index b201f25..61543e7 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -25,6 +25,13 @@ def create_app(config_object=None): app = Flask(__name__) + # Load the default configuration for dashboard and main menu + app.config.from_object(InstanceConfig) + if config_object: + app.config.from_object(config_object) + + app.config.setdefault("VERSION", __version__) + # SSO configuration SSO_ATTRIBUTE_MAP = { "eppn": (True, "eppn"), @@ -36,13 +43,6 @@ def create_app(config_object=None): migrate.init_app(app, db) csrf.init_app(app) - # Load the default configuration for dashboard and main menu - app.config.from_object(InstanceConfig) - if config_object: - app.config.from_object(config_object) - - app.config.setdefault("VERSION", __version__) - # Init SSO ext.init_app(app) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index e0f1149..6b8896c 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -247,22 +247,32 @@ def evaluate_rtbh_against_whitelists_check_results( match relation: case Relation.EQUAL: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) - flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") + msg = f"RTBH Rule {model.id} {model} is equal to active whitelist {whitelist_key}. Rule is whitelisted." + flashes.append(msg) + current_app.logger.info(msg) case Relation.SUBNET: parts = subtract_network(target=str(model), whitelist=whitelist_key) wl_id = wl_cache[whitelist_key].id - flashes.append( - f" Rule is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." - ) + msg = f"RTBH Rule {model.id} {model} is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + flashes.append(msg) + current_app.logger.info(msg) for network in parts: - create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) - flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + new_rule = create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + msg = ( + f"Created RTBH rule {new_rule.id} {new_rule} for {network} parted by whitelist {whitelist_key}" + ) + flashes.append(msg) + current_app.logger.info(msg) model.rstate_id = 4 add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) db.session.commit() case Relation.SUPERNET: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) - flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") + msg = ( + f"RTBH Rule {model.id} {model} is subnet of active whitelist {whitelist_key}. Rule is whitelisted." + ) + current_app.logger.info(msg) + flashes.append(msg) return model diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 065b79a..3807dca 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -216,7 +216,7 @@ def clear_network_cache() -> None: def create_rtbh_from_whitelist_parts( model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str = "", user_id: int = 0 -) -> None: +) -> RTBH: # default values from model rule_owner = rule_owner or model.get_author() user_id = user_id or model.user_id @@ -240,3 +240,5 @@ def create_rtbh_from_whitelist_parts( add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) announce_rtbh_route(new_model, rule_owner) log_route(user_id, model, RuleTypes.RTBH, rule_owner) + + return new_model diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py index 16daf10..490d2c2 100644 --- a/flowapp/tests/test_rule_service.py +++ b/flowapp/tests/test_rule_service.py @@ -444,7 +444,7 @@ def test_equal_relation(self, app, whitelist_fixture): mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) # Verify the flash message - assert any("Rule is equal to active whitelist" in msg for msg in flashes) + assert flashes # Verify the correct model was returned assert result == model @@ -490,8 +490,7 @@ def test_subnet_relation(self, app, whitelist_fixture): mock_commit.assert_called_once() # Verify the flash messages - assert any("Rule is supernet of active whitelist" in msg for msg in flashes) - assert any("Created RTBH rule for" in msg for msg in flashes) + assert flashes # Verify model was updated to whitelisted state assert model.rstate_id == 4 @@ -522,7 +521,7 @@ def test_supernet_relation(self, app, whitelist_fixture): mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) # Verify the flash message - assert any("Rule is subnet of active whitelist" in msg for msg in flashes) + assert flashes # Verify the correct model was returned assert result == model diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py index 519c2aa..460a172 100644 --- a/flowapp/tests/test_whitelist_service.py +++ b/flowapp/tests/test_whitelist_service.py @@ -253,7 +253,7 @@ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): # Verify flash messages assert isinstance(flashes, list) - assert any("Set rule" in msg for msg in flashes) + assert flashes # Verify the whitelist was deleted assert db.session.get(Whitelist, whitelist.id) is None @@ -310,7 +310,7 @@ def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db # Verify flash messages assert isinstance(flashes, list) - assert any("Deleted rule" in msg for msg in flashes) + assert flashes # Verify the rule was deleted assert db.session.get(RTBH, rtbh_rule.id) is None @@ -366,7 +366,7 @@ def test_equal_relation(self, app): mock_withdraw.assert_called_once_with(rtbh_rule) # Verify the flash message - assert any("equal to whitelist" in msg for msg in flashes) + assert flashes # Verify the correct model was returned assert result == whitelist_model diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py index bc6559b..234522f 100644 --- a/flowapp/utils/app_factory.py +++ b/flowapp/utils/app_factory.py @@ -40,17 +40,29 @@ def configure_logging(app): for handler in app.logger.handlers[:]: app.logger.removeHandler(handler) + # Retrieve log level and file name from config + log_level = app.config.get("LOG_LEVEL", "DEBUG").upper() + log_file = app.config.get("LOG_FILE", "app.log") + # Define log format log_format = "%(asctime)s | %(levelname)s | %(message)s" log_datefmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter(log_format, datefmt=log_datefmt) - # Create a new handler with the desired format + # Console handler console_handler = logging.StreamHandler() - console_handler.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt)) + console_handler.setFormatter(formatter) + + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + + # Set logger level + app.logger.setLevel(getattr(logging, log_level, logging.DEBUG)) - # Set logger level and attach the handler - app.logger.setLevel(logging.DEBUG) + # Attach handlers app.logger.addHandler(console_handler) + app.logger.addHandler(file_handler) return app diff --git a/withdraw_expired b/withdraw_expired new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/withdraw_expired @@ -0,0 +1 @@ + \ No newline at end of file From eb4c683f5612f1066a4c92bf938c91d99d778c6d Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 18 Mar 2025 10:23:00 +0100 Subject: [PATCH 24/36] integration test for RTBH api and whitelist, minor bug fix in utils --- flowapp/models/utils.py | 4 + flowapp/tests/conftest.py | 24 ++ flowapp/tests/test_api_v3.py | 27 ++ .../tests/test_api_whitelist_integration.py | 273 ++++++++++++++++++ flowapp/tests/test_forms.py | 7 - flowapp/tests/test_forms_cl.py | 8 - flowapp/views/api_common.py | 1 - 7 files changed, 328 insertions(+), 16 deletions(-) create mode 100644 flowapp/tests/test_api_whitelist_integration.py diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index 7472d72..d0bf120 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -40,6 +40,8 @@ def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() return count >= org.limit_rtbh or rtbh >= rtbh_limit + return False + def check_global_rule_limit(rule_type: RuleTypes) -> bool: flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) @@ -58,6 +60,8 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: if rule_type == RuleTypes.RTBH: return rtbh >= rtbh_limit + return False + def get_whitelist_model_if_exists(form_data): """ diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index 5960434..ae01c6d 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -12,6 +12,7 @@ from flowapp import db as _db from datetime import datetime import flowapp.models +from flowapp.models.organization import Organization TESTDB = "test_project.db" @@ -68,6 +69,7 @@ def app(request): LOCAL_USER_UUID="jiri.vrany@cesnet.cz", LOCAL_AUTH=True, ALLOWED_COMMUNITIES=[1, 2, 3], + WTF_CSRF_ENABLED=False, ) print("\n----- CREATE FLASK APPLICATION\n") @@ -115,6 +117,12 @@ def db(app, request): print("#: inserting users") flowapp.models.insert_users(users) + org = _db.session.query(Organization).filter_by(id=1).first() + # Update the organization address range to include our test networks + org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" + _db.session.commit() + print("\n----- UPDATED TEST ORG 1 \n", org) + def teardown(): _db.session.commit() _db.drop_all() @@ -186,3 +194,19 @@ def auth_client(client): print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n") client.get("/local-login") return client + + +@pytest.fixture(autouse=True) +def reset_org_limits(db, app): + """ + Automatically reset organization limits after each test that modifies them. + """ + yield # Allow test execution + + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + if org: + org.limit_flowspec4 = 0 + org.limit_flowspec6 = 0 + org.limit_rtbh = 0 + db.session.commit() diff --git a/flowapp/tests/test_api_v3.py b/flowapp/tests/test_api_v3.py index 96b2478..87f7094 100644 --- a/flowapp/tests/test_api_v3.py +++ b/flowapp/tests/test_api_v3.py @@ -1,10 +1,37 @@ import json + from flowapp.models import Flowspec4, Organization V_PREFIX = "/api/v3" +def test_create_rtbh_only(client, app, db, jwt_token): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + + # Now create the RTBH rule via API + res = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + + # Verify response is successful + assert res.status_code == 201 + data = json.loads(res.data) + assert data["rule"] is not None + + def test_token(client, jwt_token): """ test that token authorization works diff --git a/flowapp/tests/test_api_whitelist_integration.py b/flowapp/tests/test_api_whitelist_integration.py new file mode 100644 index 0000000..0391a76 --- /dev/null +++ b/flowapp/tests/test_api_whitelist_integration.py @@ -0,0 +1,273 @@ +""" +Integration test for the API view when interacting with whitelists. + +This test suite verifies the behavior when creating RTBH rules through the API +when there are existing whitelists that could affect the rules. +""" + +import json +import pytest +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import RTBH, RuleWhitelistCache, Organization +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_data(): + """Create whitelist data for testing""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist for API integration test", + "expires": datetime.now() + timedelta(days=1), + } + + +@pytest.fixture +def rtbh_api_payload(): + """Create RTBH rule data for API testing""" + return { + "community": 1, + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Test RTBH rule via API", + } + + +def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + # First, configure app to include community ID 1 in allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # Create the whitelist directly using the service + with app.app_context(): + # Create user and organization if needed for the whitelist + org = db.session.query(Organization).first() + + # Create the whitelist + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, + user_id=1, # Using user ID 1 from test fixtures + org_id=org.id, + user_email="test@example.com", + org_name=org.name, + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + assert whitelist_model.ip == whitelist_data["ip"] + assert whitelist_model.mask == whitelist_data["mask"] + + # Now create the RTBH rule via API + response = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=rtbh_api_payload, + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + +def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is a supernet of an existing whitelist. + + The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules + should be created for the non-whitelisted parts. + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a subnet + whitelist_data["ip"] = "192.168.1.128" + whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that covers both the whitelist and additional space + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", + headers=headers, + json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created and marked as whitelisted + with app.app_context(): + # Check the original rule + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Check if a new subnet rule was created for the non-whitelisted part + subnet_rule = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == "192.168.1.0", + RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist + ) + .first() + ) + + assert subnet_rule is not None + assert subnet_rule.rstate_id == 1 # Active status + + # Verify cache entries + # Main rule should be cached as a USER rule + user_cache = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value + ).first() + assert user_cache is not None + + # Subnet rule should be cached as a WHITELIST rule + whitelist_cache = RuleWhitelistCache.query.filter_by( + rid=subnet_rule.id, + rtype=RuleTypes.RTBH.value, + whitelist_id=whitelist_model.id, + rorigin=RuleOrigin.WHITELIST.value, + ).first() + assert whitelist_cache is not None + + +def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is contained within an existing whitelist. + + The rule should be created but immediately marked as whitelisted (rstate_id=4). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a supernet + whitelist_data["ip"] = "192.168.0.0" + whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule + + # Create the whitelist directly using the service + with app.app_context(): + all_rtbh_rules_before = db.session.query(RTBH).count() + + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that is inside the whitelist + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + # Verify no additional rules were created + all_rtbh_rules = db.session.query(RTBH).count() + assert all_rtbh_rules - all_rtbh_rules_before == 1 + + +def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that has no relation to any existing whitelist. + + The rule should be created normally with active state (rstate_id=1). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a completely different network + whitelist_data["ip"] = "10.0.0.0" + whitelist_data["mask"] = 8 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API for a network not covered by any of the whitelists + rtbh_api_payload["ipv4"] = "147.230.17.185" + rtbh_api_payload["ipv4_mask"] = 32 + + headers = {"x-access-token": jwt_token} + response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24 + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created with active state + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 1 # 1 = active state + + # Verify no cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() + + assert cache_entry is None diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py index 76c9896..a2092ed 100644 --- a/flowapp/tests/test_forms.py +++ b/flowapp/tests/test_forms.py @@ -3,13 +3,6 @@ import flowapp.forms -@pytest.fixture() -def app(): - app = Flask(__name__) - app.secret_key = "test" - return app - - @pytest.fixture() def ip_form(app, field_class): with app.test_request_context(): # Push the request context diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py index d5373ac..c962665 100644 --- a/flowapp/tests/test_forms_cl.py +++ b/flowapp/tests/test_forms_cl.py @@ -29,14 +29,6 @@ def create_form_data(data): return MultiDict(processed_data) -@pytest.fixture() -def app(): - """Create Flask app with CSRF disabled for testing""" - app = Flask(__name__) - app.config.update(SECRET_KEY="test_secret", WTF_CSRF_ENABLED=False, TESTING=True) - return app - - @pytest.fixture def valid_datetime(): return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 39b7077..4fad70e 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -305,7 +305,6 @@ def create_rtbh(current_user): count = db.session.query(RTBH).filter_by(rstate_id=1).count() return global_limit_reached(count=count, rule_type=RuleTypes.RTBH) - # check limit if check_rule_limit(current_user["org_id"], RuleTypes.RTBH): count = db.session.query(RTBH).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() return limit_reached(count=count, rule_type=RuleTypes.RTBH, org_id=current_user["org_id"]) From 354cb5317dbb4d7b4fa2c4d22f3067267335de51 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 18 Mar 2025 12:46:02 +0100 Subject: [PATCH 25/36] =?UTF-8?q?=1B[200~Refactor=20route=20announcement?= =?UTF-8?q?=20and=20whitelist=20expiration=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move announce_all_routes to flowapp.services.base for better modularity. - Add delete_expired_whitelists function in whitelist_service.py to remove expired whitelist entries. - Update withdraw_expired route to call delete_expired_whitelists before withdrawing expired routes. - Import cleanup: remove unused Flask imports in test_forms.py and test_forms_cl.py. - Optimize rules.py by importing and using announce_all_routes and delete_expired_whitelists directly. --- flowapp/services/__init__.py | 6 ++- flowapp/services/base.py | 65 ++++++++++++++++++++++++++- flowapp/services/whitelist_service.py | 14 ++++++ flowapp/tests/test_forms.py | 1 - flowapp/tests/test_forms_cl.py | 1 - flowapp/views/rules.py | 65 ++------------------------- 6 files changed, 85 insertions(+), 67 deletions(-) diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index 05b9695..456eb43 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -4,7 +4,9 @@ create_or_update_rtbh_rule, ) -from .whitelist_service import create_or_update_whitelist, delete_whitelist +from .whitelist_service import create_or_update_whitelist, delete_whitelist, delete_expired_whitelists + +from .base import announce_all_routes __all__ = [ create_or_update_ipv4_rule, @@ -12,4 +14,6 @@ create_or_update_rtbh_rule, create_or_update_whitelist, delete_whitelist, + delete_expired_whitelists, + announce_all_routes, ] diff --git a/flowapp/services/base.py b/flowapp/services/base.py index f06f16f..77adeb0 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,6 +1,8 @@ -from flowapp import messages +from datetime import datetime +from operator import ge, lt +from flowapp import constants, db, messages from flowapp.constants import ANNOUNCE, WITHDRAW -from flowapp.models import RTBH +from flowapp.models import RTBH, Flowspec4, Flowspec6 from flowapp.output import Route, RouteSources, announce_route @@ -30,3 +32,62 @@ def withdraw_rtbh_route(model: RTBH) -> None: command=command, ) announce_route(route) + + +def announce_all_routes(action=constants.ANNOUNCE): + """ + get routes from db and send it to ExaBGB api + + @TODO take the request away, use some kind of messaging (maybe celery?) + :param action: action with routes - announce valid routes or withdraw expired routes + """ + today = datetime.now() + comp_func = ge if action == constants.ANNOUNCE else lt + + rules4 = ( + db.session.query(Flowspec4) + .filter(Flowspec4.rstate_id == 1) + .filter(comp_func(Flowspec4.expires, today)) + .order_by(Flowspec4.expires.desc()) + .all() + ) + rules6 = ( + db.session.query(Flowspec6) + .filter(Flowspec6.rstate_id == 1) + .filter(comp_func(Flowspec6.expires, today)) + .order_by(Flowspec6.expires.desc()) + .all() + ) + rules_rtbh = ( + db.session.query(RTBH) + .filter(RTBH.rstate_id == 1) + .filter(comp_func(RTBH.expires, today)) + .order_by(RTBH.expires.desc()) + .all() + ) + + messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] + messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] + messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] + + messages_all = [] + messages_all.extend(messages_v4) + messages_all.extend(messages_v6) + messages_all.extend(messages_rtbh) + + author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" + + for command in messages_all: + route = Route( + author=f"System call / {author_action} rules", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + if action == constants.WITHDRAW: + for ruleset in [rules4, rules6, rules_rtbh]: + for rule in ruleset: + rule.rstate_id = 2 + + db.session.commit() diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index b749b4b..a934cc8 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -123,6 +123,20 @@ def map_rtbh_rules_to_strings(all_rtbh_rules: List[RTBH]) -> Dict[str, RTBH]: return {str(rule): rule for rule in all_rtbh_rules} +def delete_expired_whitelists() -> List[str]: + """ + Delete all expired whitelist entries from the database. + + Returns: + List of messages for the user + """ + expired_whitelists = Whitelist.query.filter(Whitelist.expires < db.func.now()).all() + flashes = [] + for model in expired_whitelists: + flashes.extend(delete_whitelist(model.id)) + return flashes + + def delete_whitelist(whitelist_id: int) -> List[str]: """ Delete a whitelist entry from the database. diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py index a2092ed..e9620d8 100644 --- a/flowapp/tests/test_forms.py +++ b/flowapp/tests/test_forms.py @@ -1,5 +1,4 @@ import pytest -from flask import Flask import flowapp.forms diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py index c962665..c6b42bc 100644 --- a/flowapp/tests/test_forms_cl.py +++ b/flowapp/tests/test_forms_cl.py @@ -1,7 +1,6 @@ import pytest from datetime import datetime, timedelta from werkzeug.datastructures import MultiDict -from flask import Flask from flowapp.forms import ( UserForm, BulkUserForm, diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 4cad6d2..2ac5981 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -1,11 +1,10 @@ # flowapp/views/admin.py from datetime import datetime, timedelta -from operator import ge, lt from collections import namedtuple from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for -from flowapp import constants, db, messages +from flowapp import constants, db from flowapp.auth import ( admin_required, auth_required, @@ -31,7 +30,7 @@ ) from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route -from flowapp.services import rule_service +from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists from flowapp.utils import ( flash_errors, get_state_by_time, @@ -680,64 +679,6 @@ def announce_all(): @rules.route("/withdraw_expired", methods=["GET"]) @localhost_only def withdraw_expired(): + delete_expired_whitelists() announce_all_routes(constants.WITHDRAW) return " " - - -def announce_all_routes(action=constants.ANNOUNCE): - """ - get routes from db and send it to ExaBGB api - - @TODO take the request away, use some kind of messaging (maybe celery?) - :param action: action with routes - announce valid routes or withdraw expired routes - """ - today = datetime.now() - comp_func = ge if action == constants.ANNOUNCE else lt - - rules4 = ( - db.session.query(Flowspec4) - .filter(Flowspec4.rstate_id == 1) - .filter(comp_func(Flowspec4.expires, today)) - .order_by(Flowspec4.expires.desc()) - .all() - ) - rules6 = ( - db.session.query(Flowspec6) - .filter(Flowspec6.rstate_id == 1) - .filter(comp_func(Flowspec6.expires, today)) - .order_by(Flowspec6.expires.desc()) - .all() - ) - rules_rtbh = ( - db.session.query(RTBH) - .filter(RTBH.rstate_id == 1) - .filter(comp_func(RTBH.expires, today)) - .order_by(RTBH.expires.desc()) - .all() - ) - - messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] - messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] - messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] - - messages_all = [] - messages_all.extend(messages_v4) - messages_all.extend(messages_v6) - messages_all.extend(messages_rtbh) - - author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" - - for command in messages_all: - route = Route( - author=f"System call / {author_action} rules", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - if action == constants.WITHDRAW: - for ruleset in [rules4, rules6, rules_rtbh]: - for rule in ruleset: - rule.rstate_id = 2 - - db.session.commit() From 910309d80e67abc65b60d588375895a1950d3365 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Mar 2025 12:21:22 +0100 Subject: [PATCH 26/36] # Add log cleanup functionality and enhance logging - Add method to Log model to purge logs older than 30 days - Implement string representation methods for Log objects - Add application logging for log entries in output.py - Integrate log cleanup into withdraw_expired endpoint --- flowapp/models/log.py | 18 ++++++++++++++++++ flowapp/output.py | 2 ++ flowapp/views/rules.py | 8 ++++++++ 3 files changed, 28 insertions(+) diff --git a/flowapp/models/log.py b/flowapp/models/log.py index e32bb30..16cf5b1 100644 --- a/flowapp/models/log.py +++ b/flowapp/models/log.py @@ -1,3 +1,6 @@ +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes from .base import db @@ -19,3 +22,18 @@ def __init__(self, time, task, user_id, rule_type, rule_id, author): self.rule_id = rule_id self.user_id = user_id self.author = author + + @classmethod + def delete_old(cls, days: int = 30): + """Delete logs older than :param days from the database""" + cls.query.filter(cls.time < datetime.now() - timedelta(days=days)).delete() + db.session.commit() + + def __repr__(self): + return f"" + + def __str__(self): + """ + {"author": "vrany@cesnet.cz / Cel\u00fd sv\u011bt", "source": "UI", "command": "cmd"} + """ + return f"{self.author} - {RuleTypes(self.rule_type).name}({self.rule_id}) - {self.task}" diff --git a/flowapp/output.py b/flowapp/output.py index fcb385f..0d02367 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -113,6 +113,7 @@ def log_route(user_id: int, route_model: Union[RTBH, Flowspec4, Flowspec6], rule author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() @@ -135,4 +136,5 @@ def log_withdraw(user_id: int, task: str, rule_type: RuleTypes, deleted_id: int, author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 2ac5981..10fd73f 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -28,6 +28,7 @@ get_user_nets, insert_initial_communities, ) +from flowapp.models.log import Log from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists @@ -679,6 +680,13 @@ def announce_all(): @rules.route("/withdraw_expired", methods=["GET"]) @localhost_only def withdraw_expired(): + """ + cleaning endpoint + deletes expired whitelists + withdraws all expired routes from ExaBGP + deletes logs older than 30 days + """ delete_expired_whitelists() announce_all_routes(constants.WITHDRAW) + Log.delete_old() return " " From 0b8a2bf9743b5628661a5bb088ded0db098e2ab1 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Mar 2025 12:59:22 +0100 Subject: [PATCH 27/36] Refactor rule reactivation logic and improve limit handling - Introduce reactivate_rule function in rule_service to centralize rule reactivation logic. - Replace inline limit checks with service-level handling for global and organization-specific limits. - Simplify rule reactivation in rules.py by delegating logic to the service layer. - Add redirects for global_limit_reached and limit_reached messages to improve user experience. - Ensure cleaner and more maintainable code by reducing duplication and improving readability. --- flowapp/services/__init__.py | 2 + flowapp/services/rule_service.py | 99 +++++++++++++++++++++++++++++++- flowapp/views/rules.py | 73 ++++++++--------------- 3 files changed, 121 insertions(+), 53 deletions(-) diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index 456eb43..be4b838 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -2,6 +2,7 @@ create_or_update_ipv4_rule, create_or_update_ipv6_rule, create_or_update_rtbh_rule, + reactivate_rule, ) from .whitelist_service import create_or_update_whitelist, delete_whitelist, delete_expired_whitelists @@ -16,4 +17,5 @@ delete_whitelist, delete_expired_whitelists, announce_all_routes, + reactivate_rule, ] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 6b8896c..d671324 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -7,12 +7,12 @@ """ from datetime import datetime -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union from flask import current_app from flowapp import db, messages -from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE +from flowapp.constants import WITHDRAW, RuleOrigin, RuleTypes, ANNOUNCE from flowapp.models import ( get_ipv4_model_if_exists, get_ipv6_model_if_exists, @@ -22,7 +22,8 @@ RTBH, Whitelist, ) -from flowapp.output import Route, announce_route, log_route, RouteSources +from flowapp.models.utils import check_global_rule_limit, check_rule_limit +from flowapp.output import ROUTE_MODELS, Route, announce_route, log_route, RouteSources, log_withdraw from flowapp.services.base import announce_rtbh_route from flowapp.services.whitelist_common import ( Relation, @@ -35,6 +36,98 @@ from .whitelist_common import check_rule_against_whitelists +def reactivate_rule( + rule_type: RuleTypes, + rule_id: int, + expires: datetime, + comment: str, + user_id: int, + org_id: int, + user_email: str, + org_name: str, +) -> Tuple[Union[RTBH, Flowspec4, Flowspec6], List[str]]: + """ + Reactivate a rule by setting a new expiration time. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to reactivate + expires: New expiration datetime + comment: Updated comment + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, messages) + """ + model_name = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + model = db.session.get(model_name, rule_id) + if not model: + return None, ["Rule not found"] + + flashes = [] + + # Check if rule will be reactivated + state = get_state_by_time(expires) + + # Check global limit + if state == 1 and check_global_rule_limit(rule_type.value): + return model, ["global_limit_reached"] + + # Check org limit + if state == 1 and check_rule_limit(org_id, rule_type=rule_type.value): + return model, ["limit_reached"] + + # Set new expiration date + model.expires = expires + # Set again the active state + model.rstate_id = state + model.comment = comment + db.session.commit() + flashes.append("Rule successfully updated") + + route_model = ROUTE_MODELS[rule_type.value] + + if model.rstate_id == 1: + # Announce route + command = route_model(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_route( + user_id, + model, + rule_type, + f"{user_email} / {org_name}", + ) + else: + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + + return model, flashes + + def create_or_update_ipv4_rule( form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str ) -> Tuple[Flowspec4, str]: diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 10fd73f..4b0d09a 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -70,6 +70,10 @@ def reactivate_rule(rule_type, rule_id): form_name = DATA_FORMS[rule_type] model = db.session.get(model_name, rule_id) + if not model: + flash("Rule not found", "alert-danger") + return redirect(url_for("index")) + form = form_name(request.form, obj=model) form.net_ranges = get_user_nets(session["user_id"]) @@ -87,63 +91,31 @@ def reactivate_rule(rule_type, rule_id): if rule_type == RuleTypes.IPv6.value: form.next_header.data = model.next_header - # do not need to validate - all is readonly + # Process form submission if request.method == "POST": - # check if rule will be reactivated - state = get_state_by_time(form.expires.data) + # Round expiration time to 10 minutes + expires = round_to_ten_minutes(form.expires.data) + + # Use the service to reactivate the rule + _, messages = rule_service.reactivate_rule( + rule_type=enum_rule_type, + rule_id=rule_id, + expires=expires, + comment=form.comment.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) - # check global limit - check_gl = check_global_rule_limit(rule_type) - if state == 1 and check_gl: + # Handle special messages (redirects) + if "global_limit_reached" in messages: return redirect(url_for("rules.global_limit_reached", rule_type=rule_type)) - # check org limit - if state == 1 and check_rule_limit(session["user_org_id"], rule_type=rule_type): + if "limit_reached" in messages: return redirect(url_for("rules.limit_reached", rule_type=rule_type)) - # set new expiration date - model.expires = round_to_ten_minutes(form.expires.data) - # set again the active state - model.rstate_id = get_state_by_time(form.expires.data) - model.comment = form.comment.data - db.session.commit() flash("Rule successfully updated", "alert-success") - route_model = ROUTE_MODELS[rule_type] - - if model.rstate_id == 1: - # announce route - command = route_model(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - Use the enum value here - log_route( - session["user_id"], - model, - enum_rule_type, # Pass the enum instead of integer - f"{session['user_email']} / {session['user_org']}", - ) - else: - # withdraw route - command = route_model(model, constants.WITHDRAW) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - Use the enum value here - log_withdraw( - session["user_id"], - route.command, - enum_rule_type, # Pass the enum instead of integer - model.id, - f"{session['user_email']} / {session['user_org']}", - ) - return redirect( url_for( "dashboard.index", @@ -157,6 +129,7 @@ def reactivate_rule(rule_type, rule_id): else: flash_errors(form) + # For GET requests, prepare the form for display form.expires.data = model.expires for field in form: if field.name not in ["expires", "csrf_token", "comment"]: From bcdcc54b37f1496cfc44bce02f0db9ebfa9e72e5 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Mar 2025 13:56:44 +0100 Subject: [PATCH 28/36] modified Rule service to check RTBH rule whitelisting during rule reactivation in /reactivate enpoint --- flowapp/services/rule_service.py | 33 ++++++++++++++++++++++---------- flowapp/views/rules.py | 3 ++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index d671324..04f8160 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -83,15 +83,19 @@ def reactivate_rule( # Set new expiration date model.expires = expires - # Set again the active state - model.rstate_id = state + # Set again the active state, if the rule is not whitelisted RTBH + if rule_type == RuleTypes.RTBH: + model = check_rtbh_whitelisted(model, user_id, flashes, f"{user_email} / {org_name}") + else: + model.rstate_id = state + model.comment = comment db.session.commit() - flashes.append("Rule successfully updated") route_model = ROUTE_MODELS[rule_type.value] if model.rstate_id == 1: + flashes.append("Rule successfully updated, state set to active.") # Announce route command = route_model(model, ANNOUNCE) route = Route( @@ -124,6 +128,10 @@ def reactivate_rule( model.id, f"{user_email} / {org_name}", ) + if model.rstate_id == 4: + flashes.append("Rule successfully updated, state set to whitelisted.") + else: + flashes.append("Rule successfully updated, state set to inactive.") return model, flashes @@ -306,7 +314,17 @@ def create_or_update_rtbh_rule( # rule author for logging and announcing author = f"{user_email} / {org_name}" + # Check if rule is whitelisted + model = check_rtbh_whitelisted(model, user_id, flashes, author) + + announce_rtbh_route(model, author=author) + # Log changes + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + +def check_rtbh_whitelisted(model: RTBH, user_id: int, flashes: List[str], author: str) -> None: # Check if rule is whitelisted allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] if model.community_id in allowed_communities: @@ -314,14 +332,9 @@ def create_or_update_rtbh_rule( whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() wl_cache = map_whitelists_to_strings(whitelists) results = check_rule_against_whitelists(str(model), wl_cache.keys()) - # check rule against whitelists, stop search when rule is whitelisted first time + # check rule against whitelists model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) - - announce_rtbh_route(model, author=author) - # Log changes - log_route(user_id, model, RuleTypes.RTBH, author) - - return model, flashes + return model def evaluate_rtbh_against_whitelists_check_results( diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 4b0d09a..2aa7905 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -114,7 +114,8 @@ def reactivate_rule(rule_type, rule_id): if "limit_reached" in messages: return redirect(url_for("rules.limit_reached", rule_type=rule_type)) - flash("Rule successfully updated", "alert-success") + for message in messages: + flash(message, "alert-success") return redirect( url_for( From edde80143384947c27fe504804d212a49ff21949 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 24 Mar 2025 11:38:19 +0100 Subject: [PATCH 29/36] bugfix - search for whitelist, added dict method to model --- flowapp/models/rules/whitelist.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 29f2b08..83c131f 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -73,6 +73,14 @@ def to_dict(self, prefered_format="yearfirst"): "rstate": self.rstate.description, } + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + def __repr__(self): return f"" From 968388cc9080cdbe988c1f8a887c93b0aef8b91a Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 25 Mar 2025 13:52:47 +0100 Subject: [PATCH 30/36] updated dashboard / no group operations for whitelist --- flowapp/views/dashboard.py | 48 +++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index 3980b88..0a393bb 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -266,9 +266,15 @@ def create_admin_response( :param sort_order: :return: """ + group_op = True if rtype != "whitelist" else False dashboard_table_body = create_dashboard_table_body( - rules, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids + rules, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_head = create_dashboard_table_head( @@ -278,15 +284,18 @@ def create_admin_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) - - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" res = make_response( render_template( @@ -358,8 +367,16 @@ def create_user_response( macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids, ) + + group_op = True if rtype != "whitelist" else False + dashboard_table_editable = create_dashboard_table_body( - rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids + rules_editable, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_editable_head = create_dashboard_table_head( rules_columns=table_columns, @@ -368,7 +385,7 @@ def create_user_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, - group_op=True, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) @@ -384,11 +401,14 @@ def create_user_response( macro_name=macro_thead, ) - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" display_editable = len(rules_editable) display_readonly = len(read_only_rules) From 1b99e1a980cd46428014118cf895803002d8316b Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 25 Mar 2025 15:09:58 +0100 Subject: [PATCH 31/36] Add functionality to delete RTBH rules and create whitelist entries - Implement delete_rtbh_and_create_whitelist function in rule_service to handle RTBH rule deletion and whitelist creation. - Add a new route delete_and_whitelist in rules.py to expose this functionality via the UI. - Update macros.html to include a button for converting RTBH rules to whitelist entries. - Enhance user feedback with detailed flash messages for both success and failure scenarios. - Improve code maintainability by centralizing logic in the service layer. --- flowapp/services/rule_service.py | 150 ++++++++++++++++++++++++++++++- flowapp/templates/macros.html | 3 + flowapp/views/rules.py | 93 ++++++++++++------- 3 files changed, 210 insertions(+), 36 deletions(-) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 04f8160..761f25d 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -6,7 +6,7 @@ and managing flow rules, separating these concerns from HTTP handling. """ -from datetime import datetime +from datetime import datetime, timedelta from typing import Dict, List, Tuple, Union from flask import current_app @@ -22,6 +22,7 @@ RTBH, Whitelist, ) +from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.models.utils import check_global_rule_limit, check_rule_limit from flowapp.output import ROUTE_MODELS, Route, announce_route, log_route, RouteSources, log_withdraw from flowapp.services.base import announce_rtbh_route @@ -31,9 +32,10 @@ create_rtbh_from_whitelist_parts, subtract_network, whitelist_rtbh_rule, + check_rule_against_whitelists, ) +from flowapp.services.whitelist_service import create_or_update_whitelist from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent -from .whitelist_common import check_rule_against_whitelists def reactivate_rule( @@ -384,3 +386,147 @@ def evaluate_rtbh_against_whitelists_check_results( def map_whitelists_to_strings(whitelists: List[Whitelist]) -> Dict[str, Whitelist]: return {str(w): w for w in whitelists} + + +def delete_rule( + rule_type: RuleTypes, rule_id: int, user_id: int, user_email: str, org_name: str, allowed_rule_ids: List[int] = None +) -> Tuple[bool, str]: + """ + Delete a rule with the given id and type. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to delete + user_id: Current user ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete, None means no restriction + + Returns: + Tuple containing (success, message) + """ + model_class = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + route_model = ROUTE_MODELS[rule_type.value] + + model = db.session.get(model_class, rule_id) + if not model: + return False, "Rule not found" + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, "You cannot delete this rule" + + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log withdrawal + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + + # Special handling for RTBH rules + if rule_type == RuleTypes.RTBH: + current_app.logger.debug(f"Deleting RTBH rule {rule_id} from cache") + RuleWhitelistCache.delete_by_rule_id(rule_id) + + # Delete from database + db.session.delete(model) + db.session.commit() + + return True, "Rule deleted successfully" + + +def delete_rtbh_and_create_whitelist( + rule_id: int, + user_id: int, + org_id: int, + user_email: str, + org_name: str, + allowed_rule_ids: List[int] = None, + whitelist_expires: datetime = None, +) -> Tuple[bool, List[str], Union[Whitelist, None]]: + """ + Delete an RTBH rule and create a whitelist entry from it. + + Args: + rule_id: ID of the RTBH rule to delete + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete + whitelist_expires: Expiration time for the whitelist entry (default: 7 days from now) + + Returns: + Tuple containing (success, messages, whitelist_model) + """ + messages = [] + + # First get the RTBH rule to extract its data + model = db.session.get(RTBH, rule_id) + if not model: + return False, ["RTBH rule not found"], None + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, ["You cannot delete this rule"], None + + # Extract data for whitelist + if model.ipv4: + ip = model.ipv4 + mask = model.ipv4_mask + elif model.ipv6: + ip = model.ipv6 + mask = model.ipv6_mask + else: + return False, ["RTBH rule has no IP address"], None + + # Set default whitelist expiration time if not provided + if whitelist_expires is None: + whitelist_expires = datetime.now() + timedelta(days=7) + + # Prepare whitelist data + whitelist_data = { + "ip": ip, + "mask": mask, + "expires": whitelist_expires, + "comment": f"Created from RTBH rule {rule_id}: {model.comment}", + } + + # Delete the RTBH rule + success, delete_message = delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=rule_id, + user_id=user_id, + user_email=user_email, + org_name=org_name, + allowed_rule_ids=allowed_rule_ids, + ) + + if not success: + return False, [delete_message], None + + messages.append(delete_message) + + # Create the whitelist entry + try: + whitelist_model, whitelist_messages = create_or_update_whitelist( + form_data=whitelist_data, user_id=user_id, org_id=org_id, user_email=user_email, org_name=org_name + ) + messages.extend(whitelist_messages) + return True, messages, whitelist_model + except Exception as e: + current_app.logger.exception(f"Error creating whitelist entry: {e}") + messages.append(f"Rule deleted but failed to create whitelist: {str(e)}") + return False, messages, None diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 2f5dd9f..6f95a13 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -111,6 +111,9 @@ + + + {% endif %} {% if rule.comment %}