From 94d03cc4f0cebfc35c01a14436ecdfbdacacc429 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Feb 2025 14:22:29 +0100 Subject: [PATCH 01/46] first prototype of whitelist model and form --- flowapp/constants.py | 8 ++- flowapp/forms.py | 57 +++++++++++++++++++++ flowapp/models.py | 99 ++++++++++++++++++++++++++++++++++++- flowapp/output.py | 1 + flowapp/validators.py | 63 +++++++++++++++++++++++ flowapp/views/api_common.py | 6 +-- run.example.py | 27 ++-------- 7 files changed, 233 insertions(+), 28 deletions(-) diff --git a/flowapp/constants.py b/flowapp/constants.py index 5aa5283..975f99b 100644 --- a/flowapp/constants.py +++ b/flowapp/constants.py @@ -2,6 +2,7 @@ This module contains constant values used in application """ +from enum import Enum from operator import ge, lt DEFAULT_SORT = "expires" @@ -59,7 +60,12 @@ FORM_TIME_PATTERN = "%Y-%m-%dT%H:%M" -class RuleTypes: +class RuleTypes(Enum): RTBH = 1 IPv4 = 4 IPv6 = 6 + + +class RuleOrigin(Enum): + USER = 1 + WHITELIST = 2 diff --git a/flowapp/forms.py b/flowapp/forms.py index 4b914dc..66cf43f 100644 --- a/flowapp/forms.py +++ b/flowapp/forms.py @@ -33,9 +33,11 @@ FORM_TIME_PATTERN, ) from flowapp.validators import ( + IPAddressValidator, IPv4Address, IPv6Address, NetRangeString, + NetworkValidator, PortString, address_in_range, address_with_mask, @@ -648,3 +650,58 @@ def validate_ipv_specific(self): return False return True + + +class WhitelistForm(FlaskForm): + """ + Whitelist form object + Used for creating and editing whitelist entries + Supports both IPv4 and IPv6 addresses + """ + + def __init__(self, *args, **kwargs): + super(WhitelistForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ip = StringField( + "IP address", + validators=[ + DataRequired(message="Please provide an IP address"), + IPAddressValidator(message="Please provide a valid IP address: {}"), + NetworkValidator(mask_field_name="mask"), + ], + ) + + mask = IntegerField( + "Network mask (bits)", + validators=[ + DataRequired(message="Please provide a network mask"), + ], + ) + + comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + def validate(self): + """ + Custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # Validate IP is in organization range + if self.ip.data and self.mask.data and self.net_ranges: + ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) + if not ip_in_range: + self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/models.py b/flowapp/models.py index 4e1291f..54c50a4 100644 --- a/flowapp/models.py +++ b/flowapp/models.py @@ -662,10 +662,105 @@ def __init__(self, time, task, user_id, rule_type, rule_id, author, org_id=None) self.org_id = org_id -# DDL -# default values for tables inserted after create +class Whitelist(db.Model): + id = db.Column(db.Integer, primary_key=True) + ip = db.Column(db.String(255)) + mask = db.Column(db.Integer) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="whitelist") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="whitelist") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="whitelist") + + def __init__( + self, + ip, + mask, + expires, + user_id, + org_id, + created=None, + comment=None, + rstate_id=1, + ): + self.ip = ip + self.mask = mask + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two whitelists are equal if all the network parameters equals. User_id and time fields can differ. + :param other: other Whitelist instance + :return: boolean + """ + return ( + self.ip == other.ip + and self.mask == other.mask + and self.expires == other.expires + and self.user_id == other.user_id + and self.org_id == other.org_id + and self.rstate_id == other.rstate_id + ) + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + return { + "id": self.id, + "ip": self.ip, + "mask": self.mask, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + +class RuleWhitelistCache(db.Model): + """ + Cache for whitelisted rules + For each rule we store id and type + Rule origin determines if the rule was created by user or by whitelist + """ + + id = db.Column(db.Integer, primary_key=True) + rid = db.Column(db.Integer) + rtype = db.Column(db.Integer) + rorigin = db.Column(db.Integer) + whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey + whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") + + def __init__(self, rid, rtype, rorigin, whitelist_id): + self.rid = rid + self.rtype = rtype + self.rorigin = rorigin + self.whitelist_id = whitelist_id + + +# DDL +# default values for tables inserted after create @event.listens_for(Action.__table__, "after_create") def insert_initial_actions(table, conn, *args, **kwargs): conn.execute( diff --git a/flowapp/output.py b/flowapp/output.py index 3dde822..b0cc7b4 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -96,6 +96,7 @@ def log_route(user_id, route_model, rule_type, author): :param rule_type: string :return: None """ + print(rule_type) converter = ROUTE_MODELS[rule_type] task = converter(route_model) log = Log( diff --git a/flowapp/validators.py b/flowapp/validators.py index 6306106..3b4f706 100644 --- a/flowapp/validators.py +++ b/flowapp/validators.py @@ -352,3 +352,66 @@ def subnet_of(net_a, net_b): def supernet_of(net_a, net_b): """Return True if this network is a supernet of other.""" return _is_subnet_of(net_b, net_a) + + +class IPAddressValidator: + """ + Universal validator that accepts both IPv4 and IPv6 addresses. + """ + + def __init__(self, message=None): + self.message = message or "Invalid IP address: {}" + + def __call__(self, form, field): + if not field.data: + return + + try: + ipaddress.ip_address(field.data) + except ValueError: + raise ValidationError(self.message.format(field.data)) + + +class NetworkValidator: + """ + Validates that an IP address and mask form a valid network. + Works with both IPv4 and IPv6 addresses. + """ + + def __init__(self, mask_field_name, message=None): + self.mask_field_name = mask_field_name + self.message = message or "Invalid network: address {}, mask {}" + + def __call__(self, form, field): + if not field.data: + return + + mask_field = form._fields.get(self.mask_field_name) + if not mask_field or not mask_field.data: + return + + try: + # Determine IP version + ip = ipaddress.ip_address(field.data) + mask = int(mask_field.data) + + # Validate mask range based on IP version + if isinstance(ip, ipaddress.IPv4Address): + if not 0 <= mask <= 32: + raise ValidationError(f"Invalid IPv4 mask: {mask}") + else: # IPv6 + if not 0 <= mask <= 128: + raise ValidationError(f"Invalid IPv6 mask: {mask}") + + # Try to create network to validate the combination + network = ipaddress.ip_network(f"{field.data}/{mask}", strict=False) + + # Check if the original IP is the correct network address + if str(network.network_address) != str(ip): + raise ValidationError( + f"Invalid network address: {field.data}/{mask}. " + f"Network address should be: {network.network_address}" + ) + + except ValueError: + raise ValidationError(self.message.format(field.data, mask_field.data)) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 6ce24bf..eeaade6 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -292,7 +292,7 @@ def create_ipv4(current_user): log_route( current_user["id"], model, - RuleTypes.IPv4, + RuleTypes.IPv4.value, f"{current_user['uuid']} / {current_user['org']}", ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -368,7 +368,7 @@ def create_ipv6(current_user): log_route( current_user["id"], model, - RuleTypes.IPv6, + RuleTypes.IPv6.value, f"{current_user['uuid']} / {current_user['org']}", ) @@ -441,7 +441,7 @@ def create_rtbh(current_user): log_route( current_user["id"], model, - RuleTypes.RTBH, + RuleTypes.RTBH.value, f"{current_user['uuid']} / {current_user['org']}", ) diff --git a/run.example.py b/run.example.py index b3dd0c4..ad6f97e 100644 --- a/run.example.py +++ b/run.example.py @@ -1,30 +1,19 @@ """ -This is an example of how to run the application. -First copy the file as run.py (or whatever you want) -Then edit the file to match your needs. - -From version 0.8.1 the application is using Flask-Session -stored in DB using SQL Alchemy driver. This can be configured for other -drivers, however server side session is required for the application. - -In general you should not need to edit this example file. -Only if you want to configure the application main menu and -dashboard. - -Or in case that you want to add extensions etc. +This is run py for the application. Copied to the container on build. """ from os import environ -from flowapp import create_app, db, sess +from flowapp import create_app, db import config # Configurations -env = environ.get("EXAFS_ENV", "Production") +exafs_env = environ.get("EXAFS_ENV", "Production") +exafs_env = exafs_env.lower() # Call app factory -if env == "devel": +if exafs_env == "devel" or exafs_env == "development": app = create_app(config.DevelopmentConfig) else: app = create_app(config.ProductionConfig) @@ -32,12 +21,6 @@ # init database object db.init_app(app) -# init session -app.config.update(SESSION_TYPE="sqlalchemy") -app.config.update(SESSION_SQLALCHEMY=db) -sess.init_app(app) - - # run app if __name__ == "__main__": app.run(host="::", port=8080, debug=True) From 073c2b4f8c3056d4ad8c88c423d4d2527e4fad3d Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Feb 2025 14:33:52 +0100 Subject: [PATCH 02/46] fixed api after convetting RuleTypes to Enum class --- flowapp/views/api_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index eeaade6..d04e40d 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -191,7 +191,7 @@ def all_communities(current_user): def limit_reached(count, rule_type, org_id): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] org = db.session.get(Organization, org_id) if rule_type == RuleTypes.IPv4: limit = org.limit_flowspec4 @@ -207,7 +207,7 @@ def limit_reached(count, rule_type, org_id): def global_limit_reached(count, rule_type): - rule_name = RULE_NAMES_DICT[int(rule_type)] + rule_name = RULE_NAMES_DICT[rule_type.value] if rule_type == RuleTypes.IPv4 or rule_type == RuleTypes.IPv6: limit = current_app.config.get("FLOWSPEC_MAX_RULES") elif rule_type == RuleTypes.RTBH: From f6672dc4adbd30d5943378a5c30b856bd574ec0b Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Feb 2025 15:12:35 +0100 Subject: [PATCH 03/46] add more test for validators, fix edge cases validation in flowspec and validators modules --- flowapp/flowspec.py | 46 +++--- flowapp/tests/test_validators.py | 258 ++++++++++++++++++++++++++++--- flowapp/validators.py | 5 +- 3 files changed, 270 insertions(+), 39 deletions(-) diff --git a/flowapp/flowspec.py b/flowapp/flowspec.py index e0ce35a..ae357ca 100644 --- a/flowapp/flowspec.py +++ b/flowapp/flowspec.py @@ -21,6 +21,19 @@ def translate_sequence(sequence, max_val=MAX_PORT): return "[{}]".format(result) +def check_limit(value, max_value, min_value=0): + """ + test if the value is within valid range (min_value to max_value inclusive) + raise exception otherwise + """ + value = int(value) + if value > max_value: + raise ValueError("Invalid value number: {} is too big. Max is {}.".format(value, max_value)) + if value < min_value: + raise ValueError("Invalid value number: {} is too small. Min is {}.".format(value, min_value)) + return value + + def to_exabgp_string(value_string, max_val): """ Translate form string to flowspec value or packet size rule @@ -42,35 +55,30 @@ def to_exabgp_string(value_string, max_val): return "={}".format(check_limit(value_string, max_val)) elif RANGE.match(value_string): m = RANGE.match(value_string) - return ">={}&<={}".format( - check_limit(m.group(1), max_val), check_limit(m.group(2), max_val) - ) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif NOTRAN.match(value_string): - return value_string + m = NOTRAN.match(value_string) + start = check_limit(m.group(1), max_val) + end = check_limit(m.group(2), max_val) + if start > end: + raise ValueError("Invalid range: start value cannot be greater than end value") + return ">={}&<={}".format(start, end) elif GREATER.match(value_string): m = GREATER.match(value_string) return ">={}&<={}".format(check_limit(m.group(1), max_val), max_val) elif LOWER.match(value_string): m = LOWER.match(value_string) - return ">=0&<={}".format(check_limit(m.group(1), max_val)) + # Even for lower bound expressions, validate that the value itself is within range + end = check_limit(m.group(1), max_val) + return ">=0&<={}".format(end) else: raise ValueError("string {} can not be converted".format(value_string)) -def check_limit(value, max_value): - """ - test if the value is lower than max_value - raise exception otherwise - """ - value = int(value) - if value > max_value: - raise ValueError( - "Invalid value number: {} is too big. Max is {}.".format(value, max_value) - ) - else: - return value - - def filter_rules_action(user_actions, rules): """ Divide the list of rules by user_actions to editable and viewonly subsets diff --git a/flowapp/tests/test_validators.py b/flowapp/tests/test_validators.py index 33e757c..cab089f 100644 --- a/flowapp/tests/test_validators.py +++ b/flowapp/tests/test_validators.py @@ -1,11 +1,28 @@ import pytest -import flowapp.validators +from flowapp.validators import ( + PortString, + PacketString, + NetRangeString, + NetInRange, + IPAddress, + IPAddressValidator, + NetworkValidator, + ValidationError, + address_in_range, + address_with_mask, + IPv4Address, + IPv6Address, + DateNotExpired, + editable_range, + network_in_range, + range_in_network, +) def test_port_string_len_raises(field): - port = flowapp.validators.PortString() + port = PortString() field.data = "1;2;3;4;5;6;7;8" - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): port(None, field) @@ -20,12 +37,12 @@ def test_port_string_len_raises(field): ], ) def test_is_valid_address_with_mask(address, mask, expected): - assert flowapp.validators.address_with_mask(address, mask) == expected + assert address_with_mask(address, mask) == expected @pytest.mark.parametrize("address", ["147.230.23.25", "147.230.23.0"]) def test_ip4address_passes(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address adr(None, field) @@ -38,7 +55,7 @@ def test_ip4address_passes(field, address): ], ) def test_ip6address_passes(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address adr(None, field) @@ -51,19 +68,17 @@ def test_ip6address_passes(field, address): ], ) def test_bad_ip6address_raises(field, address): - adr = flowapp.validators.IPv4Address() + adr = IPv4Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) -@pytest.mark.parametrize( - "expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"] -) +@pytest.mark.parametrize("expired", ["2018/10/25 14:46", "2018/12/20 9:46", "2019/05/22 12:33"]) def test_expired_date_raises(field, expired): - adr = flowapp.validators.DateNotExpired() + adr = DateNotExpired() field.data = expired - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -75,9 +90,9 @@ def test_expired_date_raises(field, expired): ], ) def test_ipaddress_raises(field, address): - adr = flowapp.validators.IPv6Address() + adr = IPv6Address() field.data = address - with pytest.raises(flowapp.validators.ValidationError): + with pytest.raises(ValidationError): adr(None, field) @@ -91,7 +106,7 @@ def test_ipaddress_raises(field, address): def test_editable_rule(rule, address, mask, ranges, expected): rule.source = address rule.source_mask = mask - assert flowapp.validators.editable_range(rule, ranges) == expected + assert editable_range(rule, ranges) == expected @pytest.mark.parametrize( @@ -104,7 +119,7 @@ def test_editable_rule(rule, address, mask, ranges, expected): ], ) def test_address_in_range(address, mask, ranges, expected): - assert flowapp.validators.address_in_range(address, ranges) == expected + assert address_in_range(address, ranges) == expected @pytest.mark.parametrize( @@ -123,7 +138,7 @@ def test_address_in_range(address, mask, ranges, expected): ], ) def test_network_in_range(address, mask, ranges, expected): - assert flowapp.validators.network_in_range(address, mask, ranges) == expected + assert network_in_range(address, mask, ranges) == expected @pytest.mark.parametrize( @@ -133,4 +148,209 @@ def test_network_in_range(address, mask, ranges, expected): ], ) def test_range_in_network(address, mask, ranges, expected): - assert flowapp.validators.range_in_network(address, mask, ranges) == expected + assert range_in_network(address, mask, ranges) == expected + + +### new tests + + +# New tests for previously uncovered validators + + +# Tests for PortString validator syntax +@pytest.mark.parametrize( + "port_data", + [ + "80", # Simple number + "80-443", # Range with hyphen + ">=80&<=443", # Explicit range + ">80", # Greater than + ">=80", # Greater than or equal + "<443", # Less than + "<=443", # Less than or equal + "80;443;8080", # Multiple port expressions + ], +) +def test_port_string_valid_syntax(field, port_data): + validator = PortString() + field.data = port_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_port_data", + [ + "80,443", # Comma not supported, should use semicolon + "80..443", # Invalid range syntax (should be 80-443) + ">65536", # Port number too high + "-1", # Negative number + "<-1", # Port number too low + ">=80<=443", # Invalid range syntax (missing &) + "!=80", # Not equals not supported + "abc", # Non-numeric + "80-", # Incomplete range + "-80", # Invalid range + "80&&443", # Invalid operator + "443-80", # End less than start + ">=443&<=80", # End less than start in explicit range + "0-65537", # Range exceeding max + ], +) +def test_port_string_invalid_syntax(field, invalid_port_data): + validator = PortString() + field.data = invalid_port_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for PacketString validator +@pytest.mark.parametrize( + "packet_data", + ["1500", ">1000", "<9000", "1000-1500", ">=1000&<=1500", "1500;9000"], # Multiple packet size expressions +) +def test_packet_string_valid(field, packet_data): + validator = PacketString() + field.data = packet_data + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_packet_data", + [ + "65536", # Too large + "-1", # Negative + "1500..", # Invalid range syntax + "abc", # Non-numeric + "!!1500", # Invalid operator + "1500-", # Incomplete range + "9000-1500", # End less than start + ], +) +def test_packet_string_invalid(field, invalid_packet_data): + validator = PacketString() + field.data = invalid_packet_data + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetRangeString validator +@pytest.mark.parametrize( + "net_range", + [ + "192.168.0.0/24", + "10.0.0.0/8\n172.16.0.0/12", + "2001:db8::/32", + "192.168.0.0/24 10.0.0.0/8", + "2001:db8::/32 2001:db8:1::/48", + ], +) +def test_net_range_string_valid(field, net_range): + validator = NetRangeString() + field.data = net_range + validator(None, field) # Should not raise + + +@pytest.mark.parametrize( + "invalid_net_range", + [ + "192.168.0.0/33", # Invalid mask + "256.256.256.0/24", # Invalid IP + "2001:xyz::/32", # Invalid IPv6 + "192.168.0.0/24/", # Malformed + "not-an-ip/24", # Invalid format + ], +) +def test_net_range_string_invalid(field, invalid_net_range): + validator = NetRangeString() + field.data = invalid_net_range + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetInRange validator +def test_net_in_range_valid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "192.168.1.1/24" + validator(None, field) # Should not raise + + +def test_net_in_range_invalid(field): + net_ranges = ["192.168.0.0/16", "10.0.0.0/8"] + validator = NetInRange(net_ranges) + field.data = "172.16.1.1/24" + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for base IPAddress validator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "10.0.0.1", "2001:db8::1", "fe80::1"]) +def test_ip_address_valid(field, ip_addr): + validator = IPAddress() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip", "192.168.1"]) +def test_ip_address_invalid(field, invalid_ip): + validator = IPAddress() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for universal IPAddressValidator +@pytest.mark.parametrize("ip_addr", ["192.168.1.1", "2001:db8::1", "", None]) # Empty is allowed # None is allowed +def test_ip_address_validator_valid(field, ip_addr): + validator = IPAddressValidator() + field.data = ip_addr + validator(None, field) # Should not raise + + +@pytest.mark.parametrize("invalid_ip", ["256.256.256.256", "2001:xyz::1", "not-an-ip"]) +def test_ip_address_validator_invalid(field, invalid_ip): + validator = IPAddressValidator() + field.data = invalid_ip + with pytest.raises(ValidationError): + validator(None, field) + + +# Tests for NetworkValidator +class MockForm: + def __init__(self, address, mask): + self._fields = {"mask": type("MockField", (), {"data": mask})()} + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.0.0", "24"), + ("10.0.0.0", "8"), + ("2001:db8::", "32"), + ("", None), # Empty values should pass + (None, None), # None values should pass + ], +) +def test_network_validator_valid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + validator(form, field) # Should not raise + + +@pytest.mark.parametrize( + "address,mask", + [ + ("192.168.1.1", "24"), # Not a network address + ("192.168.0.0", "33"), # Invalid IPv4 mask + ("2001:db8::", "129"), # Invalid IPv6 mask + ("256.256.256.0", "24"), # Invalid IP + ("2001:xyz::", "32"), # Invalid IPv6 + ], +) +def test_network_validator_invalid(field, address, mask): + validator = NetworkValidator("mask") + field.data = address + form = MockForm(address, mask) + with pytest.raises(ValidationError): + validator(form, field) diff --git a/flowapp/validators.py b/flowapp/validators.py index 3b4f706..890bdac 100644 --- a/flowapp/validators.py +++ b/flowapp/validators.py @@ -253,7 +253,10 @@ def __call__(self, form, field): result = False for address in field.data.split("/"): for adr_range in self.net_ranges: - result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + try: + result = result or ipaddress.ip_address(address) in ipaddress.ip_network(adr_range) + except ValueError as e: + raise ValidationError(self.message + str(e.args[0])) if not result: raise ValidationError(self.message) From 4e2a6a80e4c547e44a7cf98b4725854feef1fa56 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Feb 2025 09:35:30 +0100 Subject: [PATCH 04/46] increased test coverage for forms, fixed minor edge cases validation issues in forms --- flowapp/forms.py | 10 +- flowapp/tests/test_flowspec.py | 120 ++++++- flowapp/tests/test_forms_cl.py | 615 +++++++++++++++++++++++++++++++++ 3 files changed, 733 insertions(+), 12 deletions(-) create mode 100644 flowapp/tests/test_forms_cl.py diff --git a/flowapp/forms.py b/flowapp/forms.py index 66cf43f..5baca78 100644 --- a/flowapp/forms.py +++ b/flowapp/forms.py @@ -67,15 +67,21 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def process_formdata(self, valuelist): - if not valuelist: + if not valuelist or (len(valuelist) == 1 and not valuelist[0]): return None + # with unlimited field we do not need to parse the empty value if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: self.data = None return None date_str = " ".join((str(val) for val in valuelist)) - result, pref_format = parse_api_time(date_str) + + try: + result, pref_format = parse_api_time(date_str) + except TypeError: + raise ValueError(self.gettext("Not a valid datetime value.")) + if result: self.data = result self.pref_format = pref_format diff --git a/flowapp/tests/test_flowspec.py b/flowapp/tests/test_flowspec.py index e7a9c3e..70000b8 100644 --- a/flowapp/tests/test_flowspec.py +++ b/flowapp/tests/test_flowspec.py @@ -1,12 +1,12 @@ import pytest -import flowapp.flowspec +from flowapp.flowspec import translate_sequence, filter_rules_action, check_limit def test_translate_number(): """ tests for x (integer) to =x """ - assert "[=10]" == flowapp.flowspec.translate_sequence("10") + assert "[=10]" == translate_sequence("10") def test_raises(): @@ -14,7 +14,7 @@ def test_raises(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("ahoj") + translate_sequence("ahoj") def test_raises_bad_number(): @@ -22,46 +22,146 @@ def test_raises_bad_number(): tests for translator """ with pytest.raises(ValueError): - flowapp.flowspec.translate_sequence("75555") + translate_sequence("75555") def test_translate_range(): """ tests for x-y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence("10-20") + assert "[>=10&<=20]" == translate_sequence("10-20") def test_exact_rule(): """ test for >=x&<=y to >=x&<=y """ - assert "[>=10&<=20]" == flowapp.flowspec.translate_sequence(">=10&<=20") + assert "[>=10&<=20]" == translate_sequence(">=10&<=20") def test_greater_than(): """ test for >x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">10") + assert "[>=10&<=65535]" == translate_sequence(">10") def test_greater_equal_than(): """ test for >=x to >=x&<=65535 """ - assert "[>=10&<=65535]" == flowapp.flowspec.translate_sequence(">=10") + assert "[>=10&<=65535]" == translate_sequence(">=10") def test_lower_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<10") + assert "[>=0&<=10]" == translate_sequence("<10") def test_lower_equal_than(): """ test for =0&<=0 """ - assert "[>=0&<=10]" == flowapp.flowspec.translate_sequence("<=10") + assert "[>=0&<=10]" == translate_sequence("<=10") + + +# new tests + + +def test_multiple_sequences(): + """Test multiple sequences separated by semicolons""" + assert "[=10 >=20&<=30]" == translate_sequence("10;20-30") + assert "[=10 >=0&<=20 >=30&<=65535]" == translate_sequence("10;<=20;>30") + + +def test_empty_sequence(): + """Test empty sequence and sequences with empty parts""" + assert "[]" == translate_sequence("") + assert "[=10]" == translate_sequence("10;;") + assert "[=10 =20]" == translate_sequence("10;;20") + + +def test_range_edge_cases(): + """Test edge cases for ranges""" + # Same numbers in range + assert "[>=10&<=10]" == translate_sequence("10-10") + + # Invalid range (start > end) + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence("20-10") + + # Invalid range in exact rule + with pytest.raises(ValueError, match="Invalid range: start value cannot be greater than end value"): + translate_sequence(">=20&<=10") + + +def test_check_limit_validation(): + """Test the check_limit function""" + # Test valid cases + assert check_limit(10, 100) == 10 + assert check_limit(0, 100, 0) == 0 + assert check_limit(100, 100, 0) == 100 + + # Test invalid cases + with pytest.raises(ValueError, match="Invalid value number: .* is too big"): + check_limit(101, 100) + + with pytest.raises(ValueError, match="Invalid value number: .* is too small"): + check_limit(-1, 100, 0) + + +def test_invalid_inputs(): + """Test various invalid inputs""" + invalid_inputs = [ + "abc", # Non-numeric + "10.5", # Decimal + "10,20", # Wrong separator + "10&20", # Wrong format + ">", # Incomplete + ">=", # Incomplete + "<", # Incomplete + "<=", # Incomplete + ">>10", # Invalid operator + "<<10", # Invalid operator + ">=10&", # Incomplete range + "&<=10", # Incomplete range + ] + + for invalid_input in invalid_inputs: + with pytest.raises(ValueError): + translate_sequence(invalid_input) + + +class MockRule: + def __init__(self, action_id): + self.action_id = action_id + + +def test_filter_rules_action(): + """Test the filter_rules_action function""" + # Create test rules + rules = [MockRule(1), MockRule(2), MockRule(3), MockRule(4)] + + # Test with empty allowed actions + editable, viewonly = filter_rules_action([], rules) + assert len(editable) == 0 + assert len(viewonly) == 4 + + # Test with some allowed actions + editable, viewonly = filter_rules_action([1, 3], rules) + assert len(editable) == 2 + assert len(viewonly) == 2 + assert all(rule.action_id in [1, 3] for rule in editable) + assert all(rule.action_id in [2, 4] for rule in viewonly) + + # Test with all actions allowed + editable, viewonly = filter_rules_action([1, 2, 3, 4], rules) + assert len(editable) == 4 + assert len(viewonly) == 0 + + # Test with empty rules list + editable, viewonly = filter_rules_action([1, 2], []) + assert len(editable) == 0 + assert len(viewonly) == 0 diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py new file mode 100644 index 0000000..d5373ac --- /dev/null +++ b/flowapp/tests/test_forms_cl.py @@ -0,0 +1,615 @@ +import pytest +from datetime import datetime, timedelta +from werkzeug.datastructures import MultiDict +from flask import Flask +from flowapp.forms import ( + UserForm, + BulkUserForm, + ApiKeyForm, + MachineApiKeyForm, + OrganizationForm, + ActionForm, + ASPathForm, + CommunityForm, + RTBHForm, + IPv4Form, + IPv6Form, + WhitelistForm, +) + + +def create_form_data(data): + """Helper function to create proper form data format""" + processed_data = {} + for key, value in data.items(): + if isinstance(value, list): + processed_data[key] = [str(v) for v in value] + else: + processed_data[key] = value + return MultiDict(processed_data) + + +@pytest.fixture() +def app(): + """Create Flask app with CSRF disabled for testing""" + app = Flask(__name__) + app.config.update(SECRET_KEY="test_secret", WTF_CSRF_ENABLED=False, TESTING=True) + return app + + +@pytest.fixture +def valid_datetime(): + return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") + + +@pytest.fixture +def sample_network_ranges(): + return ["192.168.0.0/16", "2001:db8::/32"] + + +class TestUserForm: + @pytest.fixture + def mock_choices(self): + return { + "role_ids": [(1, "Admin"), (2, "User"), (3, "Guest")], + "org_ids": [(1, "Org1"), (2, "Org2"), (3, "Org3")], + } + + @pytest.fixture + def valid_user_data(self): + return { + "uuid": "test@example.com", + "email": "user@example.com", + "name": "Test User", + "phone": "123456789", + "role_ids": ["2"], + "org_ids": ["1"], + } + + def test_valid_user_form(self, app, valid_user_data, mock_choices): + with app.test_request_context(): + form_data = create_form_data(valid_user_data) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_email(self, app, mock_choices): + with app.test_request_context(): + form_data = create_form_data({"uuid": "invalid-email", "role_ids": ["2"], "org_ids": ["1"]}) + form = UserForm(formdata=form_data) + form.role_ids.choices = mock_choices["role_ids"] + form.org_ids.choices = mock_choices["org_ids"] + + assert not form.validate() + assert "Please provide valid email" in form.uuid.errors + + +class TestRTBHForm: + @pytest.fixture + def mock_community_choices(self): + return [(1, "Community1"), (2, "Community2")] + + @pytest.fixture + def valid_rtbh_data(self, valid_datetime): + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "community": "1", + "expires": valid_datetime, + "comment": "Test RTBH rule", + } + + def test_valid_ipv4_rtbh(self, app, valid_rtbh_data, sample_network_ranges, mock_community_choices): + with app.test_request_context(): + form_data = create_form_data(valid_rtbh_data) + form = RTBHForm(formdata=form_data) + form.net_ranges = sample_network_ranges + form.community.choices = mock_community_choices + assert form.validate() + + +class TestIPv4Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv4_rule(self, app, valid_ipv4_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_protocol_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "192.168.1.0", + "source_mask": "24", + "protocol": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv4Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + +class TestIPv6Form: + @pytest.fixture + def mock_action_choices(self): + return [(1, "Accept"), (2, "Drop"), (3, "Reject")] + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return { + "source": "2001:db8::", # Network aligned address within allowed range + "source_mask": "32", # Matching the organization's prefix length + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + + def test_valid_ipv6_rule(self, app, valid_ipv6_data, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_next_header_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "udp", + "flags": ["syn"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + + def test_address_outside_range(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation fails when address is outside allowed ranges""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db9::", # Different prefix + "source_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + assert not form.validate() + assert any("must be in organization range" in error for error in form.source.errors) + + def test_destination_address(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with destination address instead of source""" + with app.test_request_context(): + form_data = create_form_data( + { + "dest": "2001:db8::", + "dest_mask": "32", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_both_source_and_dest(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with both source and destination addresses""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "dest": "2001:db8:1::", + "dest_mask": "48", + "next_header": "tcp", + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_tcp_flags(self, app, valid_datetime, sample_network_ranges, mock_action_choices): + """Test validation with TCP flags (should be valid with TCP)""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "flags": ["SYN", "ACK"], + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + @pytest.mark.parametrize("port_data", ["80", "80;443", "1024-2048"]) + def test_valid_ports(self, app, valid_datetime, sample_network_ranges, mock_action_choices, port_data): + """Test validation with various valid port formats""" + with app.test_request_context(): + form_data = create_form_data( + { + "source": "2001:db8::", + "source_mask": "32", + "next_header": "tcp", + "source_port": port_data, + "dest_port": port_data, + "action": "1", + "expires": valid_datetime, + } + ) + form = IPv6Form(formdata=form_data) + form.net_ranges = sample_network_ranges + form.action.choices = mock_action_choices + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + +class TestBulkUserForm: + @pytest.fixture + def valid_csv_data(self): + return "uuid-eppn,role,organizace\nuser1@example.com,2,1\nuser2@example.com,2,1" + + def test_valid_bulk_import(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} # Mock available roles + form.organizations = {1} # Mock available organizations + form.uuids = set() # Mock existing UUIDs (empty set) + assert form.validate() + + def test_duplicate_uuid(self, app, valid_csv_data): + with app.test_request_context(): + form_data = create_form_data({"users": valid_csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = {"user1@example.com"} # UUID already exists + assert not form.validate() + assert any("already exists" in error for error in form.users.errors) + + def test_invalid_role(self, app): + csv_data = "uuid-eppn,role,organizace\nuser1@example.com,999,1" # Invalid role ID + with app.test_request_context(): + form_data = create_form_data({"users": csv_data}) + form = BulkUserForm(formdata=form_data) + form.roles = {2} + form.organizations = {1} + form.uuids = set() + assert not form.validate() + assert any("does not exist" in error for error in form.users.errors) + + +class TestApiKeyForm: + @pytest.fixture + def valid_api_key_data(self, valid_datetime): + return {"machine": "192.168.1.1", "comment": "Test API key", "expires": valid_datetime, "readonly": "true"} + + def test_valid_api_key(self, app, valid_api_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_api_key_data) + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + def test_invalid_ip(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data({"machine": "invalid_ip", "expires": valid_datetime}) + form = ApiKeyForm(formdata=form_data) + assert not form.validate() + assert "provide valid IP address" in form.machine.errors + + def test_unlimited_expiration(self, app): + with app.test_request_context(): + form_data = create_form_data({"machine": "192.168.1.1", "expires": ""}) # Empty expiration for unlimited + form = ApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestMachineApiKeyForm: + # Similar to ApiKeyForm, but might have different validation rules + @pytest.fixture + def valid_machine_key_data(self, valid_datetime): + return { + "machine": "192.168.1.1", + "comment": "Test machine API key", + "expires": valid_datetime, + "readonly": "true", + } + + def test_valid_machine_key(self, app, valid_machine_key_data): + with app.test_request_context(): + form_data = create_form_data(valid_machine_key_data) + form = MachineApiKeyForm(formdata=form_data) + assert form.validate() + + +class TestOrganizationForm: + @pytest.fixture + def valid_org_data(self): + return { + "name": "Test Organization", + "limit_flowspec4": "100", + "limit_flowspec6": "100", + "limit_rtbh": "50", + "arange": "192.168.0.0/16\n2001:db8::/32", + } + + def test_valid_organization(self, app, valid_org_data): + with app.test_request_context(): + form_data = create_form_data(valid_org_data) + form = OrganizationForm(formdata=form_data) + assert form.validate() + + def test_invalid_ranges(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "arange": "invalid_range\n192.168.0.0/16"}) + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + def test_invalid_limits(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "Test Org", "limit_flowspec4": "1001"}) # Exceeds max value + form = OrganizationForm(formdata=form_data) + assert not form.validate() + + +class TestActionForm: + @pytest.fixture + def valid_action_data(self): + return { + "name": "test_action", + "command": "announce route", + "description": "Test action description", + "role_id": "2", + } + + def test_valid_action(self, app, valid_action_data): + with app.test_request_context(): + form_data = create_form_data(valid_action_data) + form = ActionForm(formdata=form_data) + assert form.validate() + + def test_invalid_role(self, app, valid_action_data): + with app.test_request_context(): + invalid_data = dict(valid_action_data) + invalid_data["role_id"] = "4" # Invalid role + form_data = create_form_data(invalid_data) + form = ActionForm(formdata=form_data) + assert not form.validate() + + +class TestASPathForm: + @pytest.fixture + def valid_aspath_data(self): + return {"prefix": "AS64512", "as_path": "64512 64513 64514"} + + def test_valid_aspath(self, app, valid_aspath_data): + with app.test_request_context(): + form_data = create_form_data(valid_aspath_data) + form = ASPathForm(formdata=form_data) + assert form.validate() + + def test_empty_fields(self, app): + with app.test_request_context(): + form_data = create_form_data({}) + form = ASPathForm(formdata=form_data) + assert not form.validate() + + +class TestCommunityForm: + @pytest.fixture + def valid_community_data(self): + return { + "name": "test_community", + "comm": "64512:100", + "description": "Test community", + "role_id": "2", + "as_path": "true", + } + + def test_valid_community(self, app, valid_community_data): + with app.test_request_context(): + form_data = create_form_data(valid_community_data) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_missing_all_community_types(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert not form.validate() + assert any("could not be empty" in error for error in form.comm.errors) + + def test_valid_with_extended_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "extcomm": "rt:64512:100", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + def test_valid_with_large_community(self, app): + with app.test_request_context(): + form_data = create_form_data({"name": "test_community", "larcomm": "64512:100:200", "role_id": "2"}) + form = CommunityForm(formdata=form_data) + assert form.validate() + + +class TestWhitelistForm: + @pytest.fixture + def valid_ipv4_data(self, valid_datetime): + return {"ip": "192.168.0.0", "mask": "24", "comment": "Test whitelist entry", "expires": valid_datetime} + + @pytest.fixture + def valid_ipv6_data(self, valid_datetime): + return {"ip": "2001:db8::", "mask": "32", "comment": "IPv6 whitelist entry", "expires": valid_datetime} + + def test_valid_ipv4_entry(self, app, valid_ipv4_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv4_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_valid_ipv6_entry(self, app, valid_ipv6_data, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data(valid_ipv6_data) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + + if not form.validate(): + print("Validation errors:", form.errors) + + assert form.validate() + + def test_invalid_ip_format(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "invalid_ip", "mask": "24", "expires": valid_datetime}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert any("valid IP address" in error for error in form.ip.errors) + + def test_ip_outside_range(self, app, valid_datetime): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "10.0.0.0", "mask": "24", "expires": valid_datetime} # IP outside allowed range + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = ["192.168.0.0/16"] # Only allow 192.168.0.0/16 + assert not form.validate() + assert any("must be in organization range" in error for error in form.ip.errors) + + def test_invalid_mask_ipv4(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.0.0", "mask": "33", "expires": valid_datetime} # Invalid mask for IPv4 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_invalid_mask_ipv6(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + {"ip": "2001:db8::", "mask": "129", "expires": valid_datetime} # Invalid mask for IPv6 + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + def test_missing_required_fields(self, app, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + assert form.ip.errors # Should have error for missing IP + assert form.mask.errors # Should have error for missing mask + assert form.expires.errors # Should have error for missing expiration + + def test_comment_length(self, app, valid_datetime, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data( + { + "ip": "192.168.0.0", + "mask": "24", + "expires": valid_datetime, + "comment": "x" * 256, # Comment longer than 255 chars + } + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() + + @pytest.mark.parametrize( + "expires", ["", "invalid_date", "2020-33-42T00:00"] # Empty expiration # Invalid date format # Past date + ) + def test_invalid_expiration(self, app, expires, sample_network_ranges): + with app.test_request_context(): + form_data = create_form_data({"ip": "192.168.0.0", "mask": "24", "expires": expires}) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + if not form.validate(): + print("Validation errors:", form.errors) + + assert not form.validate() + + def test_network_alignment(self, app, valid_datetime, sample_network_ranges): + """Test that IP addresses must be properly network-aligned""" + with app.test_request_context(): + form_data = create_form_data( + {"ip": "192.168.1.1", "mask": "24", "expires": valid_datetime} # Not aligned to /24 boundary + ) + form = WhitelistForm(formdata=form_data) + form.net_ranges = sample_network_ranges + assert not form.validate() From c3b2bbb73906d94372774df3c9d359988d57bf87 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Feb 2025 10:10:36 +0100 Subject: [PATCH 05/46] Refactor equality checks and add model unit tests - Updated equality checks in Flowspec4 and Whitelist models: - Adjusted conditions to compare only relevant fields. - Excluded user_id, org_id, comment, and time fields from Whitelist equality checks. - Improved formatting in docstrings for clarity. - Added unit tests for: - User creation and relationships. - ApiKey and MachineApiKey expiration logic. - Organization user retrieval. - Flowspec6 and Whitelist equality comparisons. - Whitelist serialization with to_dict. These changes enhance test coverage and refine model behavior. --- flowapp/models.py | 15 +- flowapp/tests/test_models.py | 263 ++++++++++++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 11 deletions(-) diff --git a/flowapp/models.py b/flowapp/models.py index 54c50a4..90a3a9b 100644 --- a/flowapp/models.py +++ b/flowapp/models.py @@ -419,7 +419,8 @@ def __init__( def __eq__(self, other): """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. + Two models are equal if all the network parameters equals. + User_id and time fields can differ. :param other: other Flowspec4 instance :return: boolean """ @@ -700,18 +701,12 @@ def __init__( def __eq__(self, other): """ - Two whitelists are equal if all the network parameters equals. User_id and time fields can differ. + Two whitelists are equal if all the network parameters equals. + User_id, org, comment and time fields can differ. :param other: other Whitelist instance :return: boolean """ - return ( - self.ip == other.ip - and self.mask == other.mask - and self.expires == other.expires - and self.user_id == other.user_id - and self.org_id == other.org_id - and self.rstate_id == other.rstate_id - ) + return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id def to_dict(self, prefered_format="yearfirst"): """ diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py index 8959945..82dff8e 100644 --- a/flowapp/tests/test_models.py +++ b/flowapp/tests/test_models.py @@ -1,4 +1,16 @@ -from datetime import datetime +from datetime import datetime, timedelta +from flowapp.models import ( + User, + Organization, + Role, + ApiKey, + MachineApiKey, + Rstate, + Community, + Action, + Flowspec6, + Whitelist, +) import flowapp.models as models @@ -232,3 +244,252 @@ def test_rtbj_eq(db): ) assert model_A == model_B + + +def test_user_creation(db): + """Test basic user creation and relationships""" + # Create test role and org first + role = Role(name="test_role", description="Test Role") + org = Organization(name="test_org", arange="10.0.0.0/8") + db.session.add_all([role, org]) + db.session.commit() + + # Create user with relationships + user = User( + uuid="test-user-123", name="Test User", phone="1234567890", email="test@example.com", comment="Test comment" + ) + user.role.append(role) + user.organization.append(org) + db.session.add(user) + db.session.commit() + + # Verify user and relationships + assert user.uuid == "test-user-123" + assert user.name == "Test User" + assert len(user.role.all()) == 1 + assert len(user.organization.all()) == 1 + assert user.role.first().name == "test_role" + assert user.organization.first().name == "test_org" + + +def test_api_key_expiration(db): + """Test ApiKey expiration logic""" + user = User(uuid="test-user") + org = Organization(name="test-org", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = ApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = ApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + # Create future key + future_key = ApiKey( + machine="test-machine-3", + key="key3", + readonly=True, + expires=datetime.now() + timedelta(days=1), + comment="Future key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key, future_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + assert not future_key.is_expired() + + +def test_machine_api_key_expiration(db): + """Test MachineApiKey expiration logic""" + user = User(uuid="test-user-machine") + org = Organization(name="test-org-machine", arange="10.0.0.0/8") + db.session.add_all([user, org]) + db.session.commit() + + # Create non-expiring key + non_expiring_key = MachineApiKey( + machine="test-machine-1", + key="key1", + readonly=True, + expires=None, + comment="Non-expiring key", + user_id=user.id, + org_id=org.id, + ) + + # Create expired key + expired_key = MachineApiKey( + machine="test-machine-2", + key="key2", + readonly=True, + expires=datetime.now() - timedelta(days=1), + comment="Expired key", + user_id=user.id, + org_id=org.id, + ) + + db.session.add_all([non_expiring_key, expired_key]) + db.session.commit() + + assert not non_expiring_key.is_expired() + assert expired_key.is_expired() + + +def test_organization_get_users(db): + """Test Organization's get_users method""" + org = Organization( + name="test-org-get-user", arange="10.0.0.0/8", limit_flowspec4=100, limit_flowspec6=100, limit_rtbh=100 + ) + uuid1 = "test-org-get-user" + uuid2 = "test-org-get-user2" + user1 = User(uuid=uuid1) + user2 = User(uuid=uuid2) + + db.session.add(org) + db.session.add_all([user1, user2]) + db.session.commit() + + org.user.append(user1) + org.user.append(user2) + db.session.commit() + + users = org.get_users() + assert len(users) == 2 + assert all(isinstance(user, User) for user in users) + assert {user.uuid for user in users} == {uuid1, uuid2} + + +def test_flowspec6_equality(db): + """Test Flowspec6 equality comparison""" + model_a = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + # Same network parameters but different timestamps + model_b = Flowspec6( + source="2001:db8::1", + source_mask=128, + source_port="80", + destination="2001:db8::2", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + action_id=1, + ) + + # Different network parameters + model_c = Flowspec6( + source="2001:db8::3", + source_mask=128, + source_port="80", + destination="2001:db8::4", + destination_mask=128, + destination_port="443", + next_header="tcp", + flags="", + packet_len="", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different network parameters + + +def test_whitelist_equality(db): + """Test Whitelist equality comparison""" + model_a = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Same IP/mask but different timestamps + model_b = Whitelist( + ip="192.168.1.1", + mask=32, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + comment="Different comment", + ) + + # Different IP + model_c = Whitelist( + ip="192.168.1.2", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + assert model_a == model_b # Should be equal despite different timestamps + assert model_a != model_c # Should be different due to different IP + + +def test_whitelist_to_dict(db): + """Test Whitelist to_dict serialization""" + whitelist = Whitelist( + ip="192.168.1.1", mask=32, expires=datetime.now(), user_id=1, org_id=1, comment="Test whitelist" + ) + + # Create required related objects + user = User(uuid="test-user-whitelist") + rstate = Rstate(description="active") + db.session.add_all([user, rstate]) + db.session.commit() + + whitelist.user = user + whitelist.rstate_id = rstate.id + db.session.add(whitelist) + db.session.commit() + + # Test timestamp format + dict_timestamp = whitelist.to_dict(prefered_format="timestamp") + assert isinstance(dict_timestamp["expires"], int) + assert isinstance(dict_timestamp["created"], int) + + # Test yearfirst format + dict_yearfirst = whitelist.to_dict(prefered_format="yearfirst") + assert isinstance(dict_yearfirst["expires"], str) + assert isinstance(dict_yearfirst["created"], str) + + # Check basic fields + assert dict_timestamp["ip"] == "192.168.1.1" + assert dict_timestamp["mask"] == 32 + assert dict_timestamp["comment"] == "Test whitelist" + assert dict_timestamp["user"] == "test-user-whitelist" From d86634e3eedad7f5467291cf70880365ea53da91 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 27 Feb 2025 13:02:42 +0100 Subject: [PATCH 06/46] whitelist view stub - WIP --- flowapp/views/whitelist.py | 94 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 flowapp/views/whitelist.py diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py new file mode 100644 index 0000000..adf3129 --- /dev/null +++ b/flowapp/views/whitelist.py @@ -0,0 +1,94 @@ +from datetime import datetime, timedelta +from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for + +from flowapp import constants, db, messages +from flowapp.auth import ( + auth_required, + user_or_admin_required, +) +from flowapp.constants import RuleTypes +from flowapp.forms import WhitelistForm +from flowapp.models import ( + Whitelist, + get_user_nets, +) +from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route +from flowapp.utils import ( + flash_errors, + get_state_by_time, + quote_to_ent, + round_to_ten_minutes, +) + +whitelist = Blueprint("whitelist", __name__, template_folder="templates") + + +@whitelist.route("/add", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def add(): + net_ranges = get_user_nets(session["user_id"]) + form = WhitelistForm(request.form) + + form.net_ranges = net_ranges + + if request.method == "POST" and form.validate(): + model = get_ipv4_model_if_exists(form.data, 1) + + if model: + model.expires = round_to_ten_minutes(form.expires.data) + flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." + else: + model = Flowspec4( + source=form.source.data, + source_mask=form.source_mask.data, + source_port=form.source_port.data, + destination=form.dest.data, + destination_mask=form.dest_mask.data, + destination_port=form.dest_port.data, + protocol=form.protocol.data, + flags=";".join(form.flags.data), + packet_len=form.packet_len.data, + fragment=";".join(form.fragment.data), + expires=round_to_ten_minutes(form.expires.data), + comment=quote_to_ent(form.comment.data), + action_id=form.action.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + rstate_id=get_state_by_time(form.expires.data), + ) + flash_message = "IPv4 Rule saved" + db.session.add(model) + + db.session.commit() + flash(flash_message, "alert-success") + + # announce route if model is in active state + if model.rstate_id == 1: + command = messages.create_ipv4(model, constants.ANNOUNCE) + route = Route( + author=f"{session['user_email']} / {session['user_org']}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # log changes + log_route( + session["user_id"], + model, + RuleTypes.IPv4, + f"{session['user_email']} / {session['user_org']}", + ) + + return redirect(url_for("index")) + else: + for field, errors in form.errors.items(): + for error in errors: + current_app.logger.debug("Error in the %s field - %s" % (getattr(form, field).label.text, error)) + + print("NOW", datetime.now()) + default_expires = datetime.now() + timedelta(hours=1) + form.expires.data = default_expires + + return render_template("forms/ipv4_rule.html", form=form, action_url=url_for("rules.ipv4_rule")) From af4befc0e78ee6b4b3afffcc0f6a6ba41aae58cc Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 27 Feb 2025 14:48:22 +0100 Subject: [PATCH 07/46] Refactor models into a separate directory and update references - Moved models from flowapp/models.py into a dedicated directory - Updated flowapp/output.py to handle rule_type as an enum value - Fixed a return value issue in flowapp/utils.py's parse_api_time function - Added a debug print statement in flowapp/views/rules.py for user actions UI --- flowapp/models.py | 1154 ----------------------------- flowapp/models/__init__.py | 69 ++ flowapp/models/api.py | 40 + flowapp/models/base.py | 16 + flowapp/models/community.py | 62 ++ flowapp/models/log.py | 21 + flowapp/models/organization.py | 35 + flowapp/models/rules/__init__.py | 6 + flowapp/models/rules/base.py | 68 ++ flowapp/models/rules/flowspec.py | 294 ++++++++ flowapp/models/rules/rtbh.py | 134 ++++ flowapp/models/rules/whitelist.py | 94 +++ flowapp/models/user.py | 79 ++ flowapp/models/utils.py | 330 +++++++++ flowapp/output.py | 4 +- flowapp/utils.py | 3 +- 16 files changed, 1251 insertions(+), 1158 deletions(-) delete mode 100644 flowapp/models.py create mode 100644 flowapp/models/__init__.py create mode 100644 flowapp/models/api.py create mode 100644 flowapp/models/base.py create mode 100644 flowapp/models/community.py create mode 100644 flowapp/models/log.py create mode 100644 flowapp/models/organization.py create mode 100644 flowapp/models/rules/__init__.py create mode 100644 flowapp/models/rules/base.py create mode 100644 flowapp/models/rules/flowspec.py create mode 100644 flowapp/models/rules/rtbh.py create mode 100644 flowapp/models/rules/whitelist.py create mode 100644 flowapp/models/user.py create mode 100644 flowapp/models/utils.py diff --git a/flowapp/models.py b/flowapp/models.py deleted file mode 100644 index 90a3a9b..0000000 --- a/flowapp/models.py +++ /dev/null @@ -1,1154 +0,0 @@ -import json -from sqlalchemy import event -from datetime import datetime -from flowapp import db, utils -from flowapp.constants import RuleTypes -from flask import current_app - -# models and tables - -user_role = db.Table( - "user_role", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "role_id"), -) - -user_organization = db.Table( - "user_organization", - db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), - db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), - db.PrimaryKeyConstraint("user_id", "organization_id"), -) - - -class User(db.Model): - """ - App User - """ - - id = db.Column(db.Integer, primary_key=True) - uuid = db.Column(db.String(180), unique=True) - comment = db.Column(db.String(500)) - email = db.Column(db.String(255)) - name = db.Column(db.String(255)) - phone = db.Column(db.String(255)) - apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") - machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") - role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") - - organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") - - def __init__(self, uuid, name=None, phone=None, email=None, comment=None): - self.uuid = uuid - self.phone = phone - self.name = name - self.email = email - self.comment = comment - - def update(self, form): - """ - update the user with values from form object - :param form: flask form from request - :return: None - """ - self.uuid = form.uuid.data - self.name = form.name.data - self.email = form.email.data - self.phone = form.phone.data - self.comment = form.comment.data - - # first clear existing roles and orgs - for role in self.role: - self.role.remove(role) - for org in self.organization: - self.organization.remove(org) - - for role_id in form.role_ids.data: - my_role = db.session.query(Role).filter_by(id=role_id).first() - if my_role not in self.role: - self.role.append(my_role) - - for org_id in form.org_ids.data: - my_org = db.session.query(Organization).filter_by(id=org_id).first() - if my_org not in self.organization: - self.organization.append(my_org) - - db.session.commit() - - -class ApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=False) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="apikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="apikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class MachineApiKey(db.Model): - id = db.Column(db.Integer, primary_key=True) - machine = db.Column(db.String(255)) - key = db.Column(db.String(255)) - readonly = db.Column(db.Boolean, default=True) - expires = db.Column(db.DateTime, nullable=True) - comment = db.Column(db.String(255)) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", back_populates="machineapikeys") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="machineapikey") - - def is_expired(self): - if self.expires is None: - return False # Non-expiring key - else: - return self.expires < datetime.now() - - -class Role(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(20), unique=True) - description = db.Column(db.String(260)) - - def __init__(self, name, description): - self.name = name - self.description = description - - def __repr__(self): - return self.name - - -class Organization(db.Model): - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(150), unique=True) - arange = db.Column(db.Text) - limit_flowspec4 = db.Column(db.Integer, default=0) - limit_flowspec6 = db.Column(db.Integer, default=0) - limit_rtbh = db.Column(db.Integer, default=0) - - def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): - self.name = name - self.arange = arange - self.limit_flowspec4 = limit_flowspec4 - self.limit_flowspec6 = limit_flowspec6 - self.limit_rtbh = limit_rtbh - - def __repr__(self): - return self.name - - def get_users(self): - """ - Returns all users associated with this organization. - """ - # self.user is the backref from the user_organization relationship - return self.user - - -class ASPath(db.Model): - id = db.Column(db.Integer, primary_key=True) - prefix = db.Column(db.String(120), unique=True) - as_path = db.Column(db.String(250)) - - def __init__(self, prefix, as_path): - self.prefix = prefix - self.as_path = as_path - - def __repr__(self): - return f"{self.prefix} : {self.as_path}" - - -class Action(db.Model): - """ - Action for rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - command = db.Column(db.String(120), unique=True) - description = db.Column(db.String(260)) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="action") - - def __init__(self, name, command, description, role_id=2): - self.name = name - self.command = command - self.description = description - self.role_id = role_id - - -class Community(db.Model): - """ - Community for RTBH rule - """ - - id = db.Column(db.Integer, primary_key=True) - name = db.Column(db.String(120), unique=True) - comm = db.Column(db.String(2047)) - larcomm = db.Column(db.String(2047)) - extcomm = db.Column(db.String(2047)) - description = db.Column(db.String(255)) - as_path = db.Column(db.Boolean, default=False) - role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) - role = db.relationship("Role", backref="community") - - def __init__(self, name, comm, larcomm, extcomm, description, as_path=False, role_id=2): - self.name = name - self.comm = comm - self.larcomm = larcomm - self.extcomm = extcomm - self.description = description - self.as_path = as_path - self.role_id = role_id - - -class Rstate(db.Model): - """ - State for Rules - """ - - id = db.Column(db.Integer, primary_key=True) - description = db.Column(db.String(260)) - - def __init__(self, description): - self.description = description - - -class RTBH(db.Model): - __tablename__ = "RTBH" - - id = db.Column(db.Integer, primary_key=True) - ipv4 = db.Column(db.String(255)) - ipv4_mask = db.Column(db.Integer) - ipv6 = db.Column(db.String(255)) - ipv6_mask = db.Column(db.Integer) - community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) - community = db.relationship("Community", backref="rtbh") - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="rtbh") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="rtbh") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="RTBH") - - def __init__( - self, - ipv4, - ipv4_mask, - ipv6, - ipv6_mask, - community_id, - expires, - user_id, - org_id, - comment=None, - created=None, - rstate_id=1, - ): - self.ipv4 = ipv4 - self.ipv4_mask = ipv4_mask - self.ipv6 = ipv6 - self.ipv6_mask = ipv6_mask - self.community_id = community_id - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.comment = comment - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - return ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other RTBH instance - :return: boolean - """ - compars = ( - self.ipv4 == other.ipv4 - and self.ipv4_mask == other.ipv4_mask - and self.ipv6 == other.ipv6 - and self.ipv6_mask == other.ipv6_mask - and self.community_id == other.community_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def update_time(self, form): - self.expires = utils.webpicker_to_datetime(form.expire_date.data) - db.session.commit() - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict used in API - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "ipv4": self.ipv4, - "ipv4_mask": self.ipv4_mask, - "ipv6": self.ipv6, - "ipv6_mask": self.ipv6_mask, - "community": self.community.name, - "comment": self.comment, - "expires": expires, - "created": created, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec4(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - protocol = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - fragment = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec4") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec4") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec4") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec4") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - protocol, - flags, - packet_len, - fragment, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.protocol = protocol - self.flags = flags - self.packet_len = packet_len - self.fragment = fragment - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. - User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def __ne__(self, other): - """ - Two models are not equal if all the network parameters are not equal. - User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - compars = ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.protocol == other.protocol - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.fragment == other.fragment - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - return not compars - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :return: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "protocol": self.protocol, - "flags": self.flags, - "packet_len": self.packet_len, - "fragment": self.fragment, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Flowspec6(db.Model): - id = db.Column(db.Integer, primary_key=True) - source = db.Column(db.String(255)) - source_mask = db.Column(db.Integer) - source_port = db.Column(db.String(255)) - dest = db.Column(db.String(255)) - dest_mask = db.Column(db.Integer) - dest_port = db.Column(db.String(255)) - next_header = db.Column(db.String(255)) - flags = db.Column(db.String(255)) - packet_len = db.Column(db.String(255)) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) - action = db.relationship("Action", backref="flowspec6") - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="flowspec6") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="flowspec6") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="flowspec6") - - def __init__( - self, - source, - source_mask, - source_port, - destination, - destination_mask, - destination_port, - next_header, - flags, - packet_len, - expires, - user_id, - org_id, - action_id, - created=None, - comment=None, - rstate_id=1, - ): - self.source = source - self.source_mask = source_mask - self.dest = destination - self.dest_mask = destination_mask - self.source_port = source_port - self.dest_port = destination_port - self.next_header = next_header - self.flags = flags - self.packet_len = packet_len - self.comment = comment - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.action_id = action_id - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two models are equal if all the network parameters equals. User_id and time fields can differ. - :param other: other Flowspec4 instance - :return: boolean - """ - return ( - self.source == other.source - and self.source_mask == other.source_mask - and self.dest == other.dest - and self.dest_mask == other.dest_mask - and self.source_port == other.source_port - and self.dest_port == other.dest_port - and self.next_header == other.next_header - and self.flags == other.flags - and self.packet_len == other.packet_len - and self.action_id == other.action_id - and self.rstate_id == other.rstate_id - ) - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": str(self.id), - "source": self.source, - "source_mask": self.source_mask, - "source_port": self.source_port, - "dest": self.dest, - "dest_mask": self.dest_mask, - "dest_port": self.dest_port, - "next_header": self.next_header, - "flags": self.flags, - "packet_len": self.packet_len, - "comment": self.comment, - "expires": expires, - "created": created, - "action": self.action.name, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - def dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - return self.to_dict(prefered_format) - - def json(self, prefered_format="yearfirst"): - """ - Serialize to json - :param prefered_format: string with prefered time format - :returns: json - """ - return json.dumps(self.to_dict()) - - -class Log(db.Model): - id = db.Column(db.Integer, primary_key=True) - time = db.Column(db.DateTime) - task = db.Column(db.String(1000)) - author = db.Column(db.String(1000)) - rule_type = db.Column(db.Integer) - rule_id = db.Column(db.Integer) - user_id = db.Column(db.Integer) - org_id = db.Column(db.Integer, nullable=True) - - def __init__(self, time, task, user_id, rule_type, rule_id, author, org_id=None): - self.time = time - self.task = task - self.rule_type = rule_type - self.rule_id = rule_id - self.user_id = user_id - self.author = author - self.org_id = org_id - - -class Whitelist(db.Model): - id = db.Column(db.Integer, primary_key=True) - ip = db.Column(db.String(255)) - mask = db.Column(db.Integer) - comment = db.Column(db.Text) - expires = db.Column(db.DateTime) - created = db.Column(db.DateTime) - user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) - user = db.relationship("User", backref="whitelist") - org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) - org = db.relationship("Organization", backref="whitelist") - rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) - rstate = db.relationship("Rstate", backref="whitelist") - - def __init__( - self, - ip, - mask, - expires, - user_id, - org_id, - created=None, - comment=None, - rstate_id=1, - ): - self.ip = ip - self.mask = mask - self.expires = expires - self.user_id = user_id - self.org_id = org_id - self.comment = comment - if created is None: - created = datetime.now() - self.created = created - self.rstate_id = rstate_id - - def __eq__(self, other): - """ - Two whitelists are equal if all the network parameters equals. - User_id, org, comment and time fields can differ. - :param other: other Whitelist instance - :return: boolean - """ - return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id - - def to_dict(self, prefered_format="yearfirst"): - """ - Serialize to dict - :param prefered_format: string with prefered time format - :returns: dictionary - """ - if prefered_format == "timestamp": - expires = int(datetime.timestamp(self.expires)) - created = int(datetime.timestamp(self.created)) - else: - expires = utils.datetime_to_webpicker(self.expires, prefered_format) - created = utils.datetime_to_webpicker(self.created, prefered_format) - - return { - "id": self.id, - "ip": self.ip, - "mask": self.mask, - "comment": self.comment, - "expires": expires, - "created": created, - "user": self.user.uuid, - "rstate": self.rstate.description, - } - - -class RuleWhitelistCache(db.Model): - """ - Cache for whitelisted rules - For each rule we store id and type - Rule origin determines if the rule was created by user or by whitelist - """ - - id = db.Column(db.Integer, primary_key=True) - rid = db.Column(db.Integer) - rtype = db.Column(db.Integer) - rorigin = db.Column(db.Integer) - whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey - whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") - - def __init__(self, rid, rtype, rorigin, whitelist_id): - self.rid = rid - self.rtype = rtype - self.rorigin = rorigin - self.whitelist_id = whitelist_id - - -# DDL -# default values for tables inserted after create -@event.listens_for(Action.__table__, "after_create") -def insert_initial_actions(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="QoS 100 kbps", - command="rate-limit 12800", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 1Mbps", - command="rate-limit 13107200", - description="QoS", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="QoS 10Mbps", - command="rate-limit 131072000", - description="QoS", - role_id=2, - ) - ) - conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) - - -@event.listens_for(Community.__table__, "after_create") -def insert_initial_communities(table, conn, *args, **kwargs): - conn.execute( - table.insert().values( - name="65535:65283", - comm="65535:65283", - larcomm="", - extcomm="", - description="local-as", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64496:64511", - comm="64496:64511", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - conn.execute( - table.insert().values( - name="64497:64510", - comm="64497:64510", - larcomm="", - extcomm="", - description="", - role_id=2, - ) - ) - - -@event.listens_for(Role.__table__, "after_create") -def insert_initial_roles(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="view", description="just view, no edit")) - conn.execute(table.insert().values(name="user", description="can edit")) - conn.execute(table.insert().values(name="admin", description="admin")) - - -@event.listens_for(Organization.__table__, "after_create") -def insert_initial_organizations(table, conn, *args, **kwargs): - conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) - - -@event.listens_for(Rstate.__table__, "after_create") -def insert_initial_rulestates(table, conn, *args, **kwargs): - conn.execute(table.insert().values(description="active rule")) - conn.execute(table.insert().values(description="withdrawed rule")) - conn.execute(table.insert().values(description="deleted rule")) - - -# Misc functions -def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: - """ - Check if the organization has reached the rule limit - :param org_id: integer organization id - :param rule_type: RuleType rule type - :return: boolean - """ - flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the organization limits - org = Organization.query.filter_by(id=org_id).first() - if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: - count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec4 or fs4 >= flowspec_limit - if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: - count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_flowspec6 or fs6 >= flowspec_limit - if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: - count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() - return count >= org.limit_rtbh or rtbh >= rtbh_limit - - -def check_global_rule_limit(rule_type: RuleTypes) -> bool: - flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) - flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) - rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() - - # check the global limits if the organization limits are not set - - if rule_type == RuleTypes.IPv4: - return fs4 >= flowspec4_limit - if rule_type == RuleTypes.IPv6: - return fs6 >= flowspec6_limit - if rule_type == RuleTypes.RTBH: - return rtbh >= rtbh_limit - - -def get_ipv4_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec4) - .filter( - Flowspec4.source == form_data["source"], - Flowspec4.source_mask == form_data["source_mask"], - Flowspec4.source_port == form_data["source_port"], - Flowspec4.dest == form_data["dest"], - Flowspec4.dest_mask == form_data["dest_mask"], - Flowspec4.dest_port == form_data["dest_port"], - Flowspec4.protocol == form_data["protocol"], - Flowspec4.flags == ";".join(form_data["flags"]), - Flowspec4.packet_len == form_data["packet_len"], - Flowspec4.action_id == form_data["action"], - Flowspec4.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_ipv6_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - record = ( - db.session.query(Flowspec6) - .filter( - Flowspec6.source == form_data["source"], - Flowspec6.source_mask == form_data["source_mask"], - Flowspec6.source_port == form_data["source_port"], - Flowspec6.dest == form_data["dest"], - Flowspec6.dest_mask == form_data["dest_mask"], - Flowspec6.dest_port == form_data["dest_port"], - Flowspec6.next_header == form_data["next_header"], - Flowspec6.flags == ";".join(form_data["flags"]), - Flowspec6.packet_len == form_data["packet_len"], - Flowspec6.action_id == form_data["action"], - Flowspec6.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def get_rtbh_model_if_exists(form_data, rstate_id=1): - """ - Check if the record in database exist - """ - - record = ( - db.session.query(RTBH) - .filter( - RTBH.ipv4 == form_data["ipv4"], - RTBH.ipv4_mask == form_data["ipv4_mask"], - RTBH.ipv6 == form_data["ipv6"], - RTBH.ipv6_mask == form_data["ipv6_mask"], - RTBH.community_id == form_data["community"], - RTBH.rstate_id == rstate_id, - ) - .first() - ) - - if record: - return record - - return False - - -def insert_users(users): - """ - inser list of users {name: string, role_id: integer} to db - """ - for user in users: - r = Role.query.filter_by(id=user["role_id"]).first() - o = Organization.query.filter_by(id=user["org_id"]).first() - u = User(uuid=user["name"]) - u.role.append(r) - u.organization.append(o) - db.session.add(u) - - db.session.commit() - - -def insert_user( - uuid: str, - role_ids: list, - org_ids: list, - name: str = None, - phone: str = None, - email: str = None, - comment: str = None, -): - """ - insert new user with multiple roles and organizations - :param uuid: string unique user id (eppn or similar) - :param phone: string phone number - :param name: string user name - :param email: string email - :param comment: string comment / notice - :param role_ids: list of roles - :param org_ids: list of orgs - :return: None - """ - u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) - - for role_id in role_ids: - r = Role.query.filter_by(id=role_id).first() - u.role.append(r) - - for org_id in org_ids: - o = Organization.query.filter_by(id=org_id).first() - u.organization.append(o) - - db.session.add(u) - db.session.commit() - - -def get_user_nets(user_id): - """ - Return list of network ranges for all user organization - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - result = [] - for org in orgs: - result.extend(org.arange.split()) - - return result - - -def get_user_orgs_choices(user_id): - """ - Return list of orgs as choices for form - """ - user = db.session.query(User).filter_by(id=user_id).first() - orgs = user.organization - - return [(g.id, g.name) for g in orgs] - - -def get_user_actions(user_roles): - """ - Return list of actions based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - actions = db.session.query(Action).order_by("id") - else: - actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in actions] - - -def get_user_communities(user_roles): - """ - Return list of communities based on current user role - """ - max_role = max(user_roles) - if max_role == 3: - communities = db.session.query(Community).order_by("id") - else: - communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") - - return [(g.id, g.name) for g in communities] - - -def get_existing_action(name=None, command=None): - """ - return Action with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - action = Action.query.filter((Action.name == name) | (Action.command == command)).first() - return action.id if hasattr(action, "id") else None - - -def get_existing_community(name=None): - """ - return Community with given name or command if the action exists - return None if action not exists - :param name: string action name - :param command: string action command - :return: action id - """ - community = Community.query.filter(Community.name == name).first() - return community.id if hasattr(community, "id") else None - - -def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): - """ - Returns list of rules sorted by sort column ordered asc or desc - :param sort: sorting column - :param order: asc or desc - :return: list - """ - - today = datetime.now() - comp_func = utils.get_comp_func(rule_state) - - if rule_type == "ipv4": - sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) - sorting_ip4 = getattr(sorter_ip4, order) - if comp_func: - rules4 = ( - db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() - ) - else: - rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() - - return rules4 - - if rule_type == "ipv6": - sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) - sorting_ip6 = getattr(sorter_ip6, order) - if comp_func: - rules6 = ( - db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() - ) - else: - rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() - - return rules6 - - if rule_type == "rtbh": - sorter_rtbh = getattr(RTBH, sort, RTBH.id) - sorting_rtbh = getattr(sorter_rtbh, order) - - if comp_func: - rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() - - else: - rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() - - return rules_rtbh - - -def get_user_rules_ids(user_id, rule_type): - """ - Returns list of rule ids belonging to user - :param user_id: user id - :param rule_type: ipv4, ipv6 or rtbh - :return: list - """ - - if rule_type == "ipv4": - rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules4] - - if rule_type == "ipv6": - rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() - return [int(x[0]) for x in rules6] - - if rule_type == "rtbh": - rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/models/__init__.py b/flowapp/models/__init__.py new file mode 100644 index 0000000..6d74ffe --- /dev/null +++ b/flowapp/models/__init__.py @@ -0,0 +1,69 @@ +# Import and re-export all models and functions for backward compatibility +from .base import db, user_role, user_organization + +# User-related models +from .user import User, Role +from .api import ApiKey, MachineApiKey +from .organization import Organization + +# Rule-related models +from .rules import Flowspec4, Flowspec6, RTBH, Rstate, Action, Whitelist, RuleWhitelistCache + +# Other models +from .community import Community, ASPath, insert_initial_communities +from .log import Log + +# Helper functions +from .utils import ( + get_user_nets, + get_user_actions, + get_user_communities, + get_existing_action, + get_existing_community, + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + get_ip_rules, + get_user_rules_ids, + insert_users, + insert_user, + check_rule_limit, + check_global_rule_limit, +) + +# Ensure all models are registered properly +__all__ = [ + "db", + "User", + "Role", + "user_role", + "ApiKey", + "MachineApiKey", + "Organization", + "user_organization", + "Action", + "Flowspec4", + "Flowspec6", + "RTBH", + "Rstate", + "Community", + "ASPath", + "Log", + "Whitelist", + "RuleWhitelistCache", + "get_user_nets", + "get_user_actions", + "get_user_communities", + "get_existing_action", + "get_existing_community", + "get_ipv4_model_if_exists", + "get_ipv6_model_if_exists", + "get_rtbh_model_if_exists", + "get_ip_rules", + "get_user_rules_ids", + "insert_users", + "insert_user", + "check_rule_limit", + "check_global_rule_limit", + "insert_initial_communities", +] diff --git a/flowapp/models/api.py b/flowapp/models/api.py new file mode 100644 index 0000000..ed50685 --- /dev/null +++ b/flowapp/models/api.py @@ -0,0 +1,40 @@ +from datetime import datetime +from .base import db + + +class ApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=False) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="apikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="apikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() + + +class MachineApiKey(db.Model): + id = db.Column(db.Integer, primary_key=True) + machine = db.Column(db.String(255)) + key = db.Column(db.String(255)) + readonly = db.Column(db.Boolean, default=True) + expires = db.Column(db.DateTime, nullable=True) + comment = db.Column(db.String(255)) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", back_populates="machineapikeys") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="machineapikey") + + def is_expired(self): + if self.expires is None: + return False # Non-expiring key + else: + return self.expires < datetime.now() diff --git a/flowapp/models/base.py b/flowapp/models/base.py new file mode 100644 index 0000000..a934ce8 --- /dev/null +++ b/flowapp/models/base.py @@ -0,0 +1,16 @@ +from flowapp import db + +# Define shared tables +user_role = db.Table( + "user_role", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("role_id", db.Integer, db.ForeignKey("role.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "role_id"), +) + +user_organization = db.Table( + "user_organization", + db.Column("user_id", db.Integer, db.ForeignKey("user.id"), nullable=False), + db.Column("organization_id", db.Integer, db.ForeignKey("organization.id"), nullable=False), + db.PrimaryKeyConstraint("user_id", "organization_id"), +) diff --git a/flowapp/models/community.py b/flowapp/models/community.py new file mode 100644 index 0000000..afb2c36 --- /dev/null +++ b/flowapp/models/community.py @@ -0,0 +1,62 @@ +from sqlalchemy import event +from .base import db + + +class Community(db.Model): + """Community for RTBH rule""" + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + comm = db.Column(db.String(2047)) + larcomm = db.Column(db.String(2047)) + extcomm = db.Column(db.String(2047)) + description = db.Column(db.String(255)) + as_path = db.Column(db.Boolean, default=False) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="community") + + # Methods and initializer + + +class ASPath(db.Model): + """AS Path for RTBH rules""" + + id = db.Column(db.Integer, primary_key=True) + prefix = db.Column(db.String(120), unique=True) + as_path = db.Column(db.String(250)) + + # Methods and initializer + + +@event.listens_for(Community.__table__, "after_create") +def insert_initial_communities(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="65535:65283", + comm="65535:65283", + larcomm="", + extcomm="", + description="local-as", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64496:64511", + comm="64496:64511", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="64497:64510", + comm="64497:64510", + larcomm="", + extcomm="", + description="", + role_id=2, + ) + ) diff --git a/flowapp/models/log.py b/flowapp/models/log.py new file mode 100644 index 0000000..e32bb30 --- /dev/null +++ b/flowapp/models/log.py @@ -0,0 +1,21 @@ +from .base import db + + +class Log(db.Model): + """Log model for system actions""" + + id = db.Column(db.Integer, primary_key=True) + time = db.Column(db.DateTime) + task = db.Column(db.String(1000)) + author = db.Column(db.String(1000)) + rule_type = db.Column(db.Integer) + rule_id = db.Column(db.Integer) + user_id = db.Column(db.Integer) + + def __init__(self, time, task, user_id, rule_type, rule_id, author): + self.time = time + self.task = task + self.rule_type = rule_type + self.rule_id = rule_id + self.user_id = user_id + self.author = author diff --git a/flowapp/models/organization.py b/flowapp/models/organization.py new file mode 100644 index 0000000..baf0ec1 --- /dev/null +++ b/flowapp/models/organization.py @@ -0,0 +1,35 @@ +from sqlalchemy import event +from .base import db + + +class Organization(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(150), unique=True) + arange = db.Column(db.Text) + limit_flowspec4 = db.Column(db.Integer, default=0) + limit_flowspec6 = db.Column(db.Integer, default=0) + limit_rtbh = db.Column(db.Integer, default=0) + + def __init__(self, name, arange, limit_flowspec4=0, limit_flowspec6=0, limit_rtbh=0): + self.name = name + self.arange = arange + self.limit_flowspec4 = limit_flowspec4 + self.limit_flowspec6 = limit_flowspec6 + self.limit_rtbh = limit_rtbh + + def __repr__(self): + return self.name + + def get_users(self): + """ + Returns all users associated with this organization. + """ + # self.user is the backref from the user_organization relationship + return self.user + + +# Event listeners for Organization +@event.listens_for(Organization.__table__, "after_create") +def insert_initial_organizations(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="TU Liberec", arange="147.230.0.0/16\n2001:718:1c01::/48")) + conn.execute(table.insert().values(name="Cesnet", arange="147.230.0.0/16\n2001:718:1c01::/48")) diff --git a/flowapp/models/rules/__init__.py b/flowapp/models/rules/__init__.py new file mode 100644 index 0000000..da12ba4 --- /dev/null +++ b/flowapp/models/rules/__init__.py @@ -0,0 +1,6 @@ +from .flowspec import Flowspec4, Flowspec6 +from .rtbh import RTBH +from .base import Rstate, Action +from .whitelist import Whitelist, RuleWhitelistCache + +__all__ = ["Flowspec4", "Flowspec6", "RTBH", "Rstate", "Action", "Whitelist", "RuleWhitelistCache"] diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py new file mode 100644 index 0000000..ab61f5e --- /dev/null +++ b/flowapp/models/rules/base.py @@ -0,0 +1,68 @@ +from sqlalchemy import event +from ..base import db + + +class Rstate(db.Model): + """State for Rules""" + + id = db.Column(db.Integer, primary_key=True) + description = db.Column(db.String(260)) + + def __init__(self, description): + self.description = description + + +class Action(db.Model): + """ + Action for rule + """ + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(120), unique=True) + command = db.Column(db.String(120), unique=True) + description = db.Column(db.String(260)) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), nullable=False) + role = db.relationship("Role", backref="action") + + def __init__(self, name, command, description, role_id=2): + self.name = name + self.command = command + self.description = description + self.role_id = role_id + + +# Event listeners for Rstate +@event.listens_for(Rstate.__table__, "after_create") +def insert_initial_rulestates(table, conn, *args, **kwargs): + conn.execute(table.insert().values(description="active rule")) + conn.execute(table.insert().values(description="withdrawed rule")) + conn.execute(table.insert().values(description="deleted rule")) + + +@event.listens_for(Action.__table__, "after_create") +def insert_initial_actions(table, conn, *args, **kwargs): + conn.execute( + table.insert().values( + name="QoS 100 kbps", + command="rate-limit 12800", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 1Mbps", + command="rate-limit 13107200", + description="QoS", + role_id=2, + ) + ) + conn.execute( + table.insert().values( + name="QoS 10Mbps", + command="rate-limit 131072000", + description="QoS", + role_id=2, + ) + ) + conn.execute(table.insert().values(name="Discard", command="discard", description="Discard", role_id=2)) diff --git a/flowapp/models/rules/flowspec.py b/flowapp/models/rules/flowspec.py new file mode 100644 index 0000000..9823aba --- /dev/null +++ b/flowapp/models/rules/flowspec.py @@ -0,0 +1,294 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class Flowspec4(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + protocol = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + fragment = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec4") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec4") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec4") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec4") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + protocol, + flags, + packet_len, + fragment, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.protocol = protocol + self.flags = flags + self.packet_len = packet_len + self.fragment = fragment + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + compars = ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.protocol == other.protocol + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.fragment == other.fragment + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "protocol": self.protocol, + "flags": self.flags, + "packet_len": self.packet_len, + "fragment": self.fragment, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) + + +class Flowspec6(db.Model): + id = db.Column(db.Integer, primary_key=True) + source = db.Column(db.String(255)) + source_mask = db.Column(db.Integer) + source_port = db.Column(db.String(255)) + dest = db.Column(db.String(255)) + dest_mask = db.Column(db.Integer) + dest_port = db.Column(db.String(255)) + next_header = db.Column(db.String(255)) + flags = db.Column(db.String(255)) + packet_len = db.Column(db.String(255)) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + action_id = db.Column(db.Integer, db.ForeignKey("action.id"), nullable=False) + action = db.relationship("Action", backref="flowspec6") + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="flowspec6") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="flowspec6") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="flowspec6") + + def __init__( + self, + source, + source_mask, + source_port, + destination, + destination_mask, + destination_port, + next_header, + flags, + packet_len, + expires, + user_id, + org_id, + action_id, + created=None, + comment=None, + rstate_id=1, + ): + self.source = source + self.source_mask = source_mask + self.dest = destination + self.dest_mask = destination_mask + self.source_port = source_port + self.dest_port = destination_port + self.next_header = next_header + self.flags = flags + self.packet_len = packet_len + self.comment = comment + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.action_id = action_id + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. User_id and time fields can differ. + :param other: other Flowspec4 instance + :return: boolean + """ + return ( + self.source == other.source + and self.source_mask == other.source_mask + and self.dest == other.dest + and self.dest_mask == other.dest_mask + and self.source_port == other.source_port + and self.dest_port == other.dest_port + and self.next_header == other.next_header + and self.flags == other.flags + and self.packet_len == other.packet_len + and self.action_id == other.action_id + and self.rstate_id == other.rstate_id + ) + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": str(self.id), + "source": self.source, + "source_mask": self.source_mask, + "source_port": self.source_port, + "dest": self.dest, + "dest_mask": self.dest_mask, + "dest_port": self.dest_port, + "next_header": self.next_header, + "flags": self.flags, + "packet_len": self.packet_len, + "comment": self.comment, + "expires": expires, + "created": created, + "action": self.action.name, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py new file mode 100644 index 0000000..d75220c --- /dev/null +++ b/flowapp/models/rules/rtbh.py @@ -0,0 +1,134 @@ +import json +from datetime import datetime +from flowapp import utils +from ..base import db + + +class RTBH(db.Model): + __tablename__ = "RTBH" + + id = db.Column(db.Integer, primary_key=True) + ipv4 = db.Column(db.String(255)) + ipv4_mask = db.Column(db.Integer) + ipv6 = db.Column(db.String(255)) + ipv6_mask = db.Column(db.Integer) + community_id = db.Column(db.Integer, db.ForeignKey("community.id"), nullable=False) + community = db.relationship("Community", backref="rtbh") + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="rtbh") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="rtbh") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="RTBH") + + def __init__( + self, + ipv4, + ipv4_mask, + ipv6, + ipv6_mask, + community_id, + expires, + user_id, + org_id, + comment=None, + created=None, + rstate_id=1, + ): + self.ipv4 = ipv4 + self.ipv4_mask = ipv4_mask + self.ipv6 = ipv6 + self.ipv6_mask = ipv6_mask + self.community_id = community_id + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two models are equal if all the network parameters equals. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + return ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + def __ne__(self, other): + """ + Two models are not equal if all the network parameters are not equal. + User_id and time fields can differ. + :param other: other RTBH instance + :return: boolean + """ + compars = ( + self.ipv4 == other.ipv4 + and self.ipv4_mask == other.ipv4_mask + and self.ipv6 == other.ipv6 + and self.ipv6_mask == other.ipv6_mask + and self.community_id == other.community_id + and self.rstate_id == other.rstate_id + ) + + return not compars + + def update_time(self, form): + self.expires = utils.webpicker_to_datetime(form.expire_date.data) + db.session.commit() + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict used in API + :param prefered_format: string with prefered time format + :return: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ipv4": self.ipv4, + "ipv4_mask": self.ipv4_mask, + "ipv6": self.ipv6, + "ipv6_mask": self.ipv6_mask, + "community": self.community.name, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + + def json(self, prefered_format="yearfirst"): + """ + Serialize to json + :param prefered_format: string with prefered time format + :returns: json + """ + return json.dumps(self.to_dict()) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py new file mode 100644 index 0000000..5e93714 --- /dev/null +++ b/flowapp/models/rules/whitelist.py @@ -0,0 +1,94 @@ +from flowapp import utils +from ..base import db +from datetime import datetime + + +class Whitelist(db.Model): + id = db.Column(db.Integer, primary_key=True) + ip = db.Column(db.String(255)) + mask = db.Column(db.Integer) + comment = db.Column(db.Text) + expires = db.Column(db.DateTime) + created = db.Column(db.DateTime) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) + user = db.relationship("User", backref="whitelist") + org_id = db.Column(db.Integer, db.ForeignKey("organization.id"), nullable=False) + org = db.relationship("Organization", backref="whitelist") + rstate_id = db.Column(db.Integer, db.ForeignKey("rstate.id"), nullable=False) + rstate = db.relationship("Rstate", backref="whitelist") + + def __init__( + self, + ip, + mask, + expires, + user_id, + org_id, + created=None, + comment=None, + rstate_id=1, + ): + self.ip = ip + self.mask = mask + self.expires = expires + self.user_id = user_id + self.org_id = org_id + self.comment = comment + if created is None: + created = datetime.now() + self.created = created + self.rstate_id = rstate_id + + def __eq__(self, other): + """ + Two whitelists are equal if all the network parameters equals. + User_id, org, comment and time fields can differ. + :param other: other Whitelist instance + :return: boolean + """ + return self.ip == other.ip and self.mask == other.mask and self.rstate_id == other.rstate_id + + def to_dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + if prefered_format == "timestamp": + expires = int(datetime.timestamp(self.expires)) + created = int(datetime.timestamp(self.created)) + else: + expires = utils.datetime_to_webpicker(self.expires, prefered_format) + created = utils.datetime_to_webpicker(self.created, prefered_format) + + return { + "id": self.id, + "ip": self.ip, + "mask": self.mask, + "comment": self.comment, + "expires": expires, + "created": created, + "user": self.user.uuid, + "rstate": self.rstate.description, + } + + +class RuleWhitelistCache(db.Model): + """ + Cache for whitelisted rules + For each rule we store id and type + Rule origin determines if the rule was created by user or by whitelist + """ + + id = db.Column(db.Integer, primary_key=True) + rid = db.Column(db.Integer) + rtype = db.Column(db.Integer) + rorigin = db.Column(db.Integer) + whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey + whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") + + def __init__(self, rid, rtype, rorigin, whitelist_id): + self.rid = rid + self.rtype = rtype + self.rorigin = rorigin + self.whitelist_id = whitelist_id diff --git a/flowapp/models/user.py b/flowapp/models/user.py new file mode 100644 index 0000000..dcb2d7e --- /dev/null +++ b/flowapp/models/user.py @@ -0,0 +1,79 @@ +from sqlalchemy import event +from .base import db, user_role, user_organization +from .organization import Organization + + +class User(db.Model): + """ + App User + """ + + id = db.Column(db.Integer, primary_key=True) + uuid = db.Column(db.String(180), unique=True) + comment = db.Column(db.String(500)) + email = db.Column(db.String(255)) + name = db.Column(db.String(255)) + phone = db.Column(db.String(255)) + apikeys = db.relationship("ApiKey", back_populates="user", lazy="dynamic") + machineapikeys = db.relationship("MachineApiKey", back_populates="user", lazy="dynamic") + role = db.relationship("Role", secondary=user_role, lazy="dynamic", backref="user") + + organization = db.relationship("Organization", secondary=user_organization, lazy="dynamic", backref="user") + + def __init__(self, uuid, name=None, phone=None, email=None, comment=None): + self.uuid = uuid + self.phone = phone + self.name = name + self.email = email + self.comment = comment + + def update(self, form): + """ + update the user with values from form object + :param form: flask form from request + :return: None + """ + self.uuid = form.uuid.data + self.name = form.name.data + self.email = form.email.data + self.phone = form.phone.data + self.comment = form.comment.data + + # first clear existing roles and orgs + for role in self.role: + self.role.remove(role) + for org in self.organization: + self.organization.remove(org) + + for role_id in form.role_ids.data: + my_role = db.session.query(Role).filter_by(id=role_id).first() + if my_role not in self.role: + self.role.append(my_role) + + for org_id in form.org_ids.data: + my_org = db.session.query(Organization).filter_by(id=org_id).first() + if my_org not in self.organization: + self.organization.append(my_org) + + db.session.commit() + + +class Role(db.Model): + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(20), unique=True) + description = db.Column(db.String(260)) + + def __init__(self, name, description): + self.name = name + self.description = description + + def __repr__(self): + return self.name + + +# Event listeners for Role +@event.listens_for(Role.__table__, "after_create") +def insert_initial_roles(table, conn, *args, **kwargs): + conn.execute(table.insert().values(name="view", description="just view, no edit")) + conn.execute(table.insert().values(name="user", description="can edit")) + conn.execute(table.insert().values(name="admin", description="admin")) diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py new file mode 100644 index 0000000..e2129b2 --- /dev/null +++ b/flowapp/models/utils.py @@ -0,0 +1,330 @@ +"""Utility functions for models""" + +from datetime import datetime +from flowapp import utils +from flowapp.constants import RuleTypes +from flask import current_app +from .base import db +from .user import User, Role +from .organization import Organization +from .community import Community +from .rules.flowspec import Flowspec4, Flowspec6 +from .rules.rtbh import RTBH +from .rules.base import Action + + +def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: + """ + Check if the organization has reached the rule limit + :param org_id: integer organization id + :param rule_type: RuleType rule type + :return: boolean + """ + flowspec_limit = current_app.config.get("FLOWSPEC_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the organization limits + org = Organization.query.filter_by(id=org_id).first() + if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: + count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec4 or fs4 >= flowspec_limit + if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: + count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_flowspec6 or fs6 >= flowspec_limit + if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: + count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() + return count >= org.limit_rtbh or rtbh >= rtbh_limit + + +def check_global_rule_limit(rule_type: RuleTypes) -> bool: + flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) + flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) + rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) + fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + + # check the global limits if the organization limits are not set + + if rule_type == RuleTypes.IPv4: + return fs4 >= flowspec4_limit + if rule_type == RuleTypes.IPv6: + return fs6 >= flowspec6_limit + if rule_type == RuleTypes.RTBH: + return rtbh >= rtbh_limit + + +def get_ipv4_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + record = ( + db.session.query(Flowspec4) + .filter( + Flowspec4.source == form_data["source"], + Flowspec4.source_mask == form_data["source_mask"], + Flowspec4.source_port == form_data["source_port"], + Flowspec4.dest == form_data["dest"], + Flowspec4.dest_mask == form_data["dest_mask"], + Flowspec4.dest_port == form_data["dest_port"], + Flowspec4.protocol == form_data["protocol"], + Flowspec4.flags == ";".join(form_data["flags"]), + Flowspec4.packet_len == form_data["packet_len"], + Flowspec4.action_id == form_data["action"], + Flowspec4.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_ipv6_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + record = ( + db.session.query(Flowspec6) + .filter( + Flowspec6.source == form_data["source"], + Flowspec6.source_mask == form_data["source_mask"], + Flowspec6.source_port == form_data["source_port"], + Flowspec6.dest == form_data["dest"], + Flowspec6.dest_mask == form_data["dest_mask"], + Flowspec6.dest_port == form_data["dest_port"], + Flowspec6.next_header == form_data["next_header"], + Flowspec6.flags == ";".join(form_data["flags"]), + Flowspec6.packet_len == form_data["packet_len"], + Flowspec6.action_id == form_data["action"], + Flowspec6.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def get_rtbh_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + """ + + record = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == form_data["ipv4"], + RTBH.ipv4_mask == form_data["ipv4_mask"], + RTBH.ipv6 == form_data["ipv6"], + RTBH.ipv6_mask == form_data["ipv6_mask"], + RTBH.community_id == form_data["community"], + RTBH.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + +def insert_users(users): + """ + inser list of users {name: string, role_id: integer} to db + """ + for user in users: + r = Role.query.filter_by(id=user["role_id"]).first() + o = Organization.query.filter_by(id=user["org_id"]).first() + u = User(uuid=user["name"]) + u.role.append(r) + u.organization.append(o) + db.session.add(u) + + db.session.commit() + + +def insert_user( + uuid: str, + role_ids: list, + org_ids: list, + name: str = None, + phone: str = None, + email: str = None, + comment: str = None, +): + """ + insert new user with multiple roles and organizations + :param uuid: string unique user id (eppn or similar) + :param phone: string phone number + :param name: string user name + :param email: string email + :param comment: string comment / notice + :param role_ids: list of roles + :param org_ids: list of orgs + :return: None + """ + u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) + + for role_id in role_ids: + r = Role.query.filter_by(id=role_id).first() + u.role.append(r) + + for org_id in org_ids: + o = Organization.query.filter_by(id=org_id).first() + u.organization.append(o) + + db.session.add(u) + db.session.commit() + + +def get_user_nets(user_id): + """ + Return list of network ranges for all user organization + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + result = [] + for org in orgs: + result.extend(org.arange.split()) + + return result + + +def get_user_orgs_choices(user_id): + """ + Return list of orgs as choices for form + """ + user = db.session.query(User).filter_by(id=user_id).first() + orgs = user.organization + + return [(g.id, g.name) for g in orgs] + + +def get_user_actions(user_roles): + """ + Return list of actions based on current user role + """ + max_role = max(user_roles) + print(max_role) + if max_role == 3: + actions = db.session.query(Action).order_by("id").all() + else: + actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id").all() + result = [(g.id, g.name) for g in actions] + print(actions, result) + return result + + +def get_user_communities(user_roles): + """ + Return list of communities based on current user role + """ + max_role = max(user_roles) + if max_role == 3: + communities = db.session.query(Community).order_by("id") + else: + communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") + + return [(g.id, g.name) for g in communities] + + +def get_existing_action(name=None, command=None): + """ + return Action with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + action = Action.query.filter((Action.name == name) | (Action.command == command)).first() + return action.id if hasattr(action, "id") else None + + +def get_existing_community(name=None): + """ + return Community with given name or command if the action exists + return None if action not exists + :param name: string action name + :param command: string action command + :return: action id + """ + community = Community.query.filter(Community.name == name).first() + return community.id if hasattr(community, "id") else None + + +def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): + """ + Returns list of rules sorted by sort column ordered asc or desc + :param sort: sorting column + :param order: asc or desc + :return: list + """ + + today = datetime.now() + comp_func = utils.get_comp_func(rule_state) + + if rule_type == "ipv4": + sorter_ip4 = getattr(Flowspec4, sort, Flowspec4.id) + sorting_ip4 = getattr(sorter_ip4, order) + if comp_func: + rules4 = ( + db.session.query(Flowspec4).filter(comp_func(Flowspec4.expires, today)).order_by(sorting_ip4()).all() + ) + else: + rules4 = db.session.query(Flowspec4).order_by(sorting_ip4()).all() + + return rules4 + + if rule_type == "ipv6": + sorter_ip6 = getattr(Flowspec6, sort, Flowspec6.id) + sorting_ip6 = getattr(sorter_ip6, order) + if comp_func: + rules6 = ( + db.session.query(Flowspec6).filter(comp_func(Flowspec6.expires, today)).order_by(sorting_ip6()).all() + ) + else: + rules6 = db.session.query(Flowspec6).order_by(sorting_ip6()).all() + + return rules6 + + if rule_type == "rtbh": + sorter_rtbh = getattr(RTBH, sort, RTBH.id) + sorting_rtbh = getattr(sorter_rtbh, order) + + if comp_func: + rules_rtbh = db.session.query(RTBH).filter(comp_func(RTBH.expires, today)).order_by(sorting_rtbh()).all() + + else: + rules_rtbh = db.session.query(RTBH).order_by(sorting_rtbh()).all() + + return rules_rtbh + + +def get_user_rules_ids(user_id, rule_type): + """ + Returns list of rule ids belonging to user + :param user_id: user id + :param rule_type: ipv4, ipv6 or rtbh + :return: list + """ + + if rule_type == "ipv4": + rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules4] + + if rule_type == "ipv6": + rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() + return [int(x[0]) for x in rules6] + + if rule_type == "rtbh": + rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() + return [int(x[0]) for x in rules_rtbh] diff --git a/flowapp/output.py b/flowapp/output.py index b0cc7b4..8322544 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -96,7 +96,7 @@ def log_route(user_id, route_model, rule_type, author): :param rule_type: string :return: None """ - print(rule_type) + rule_type = rule_type.value converter = ROUTE_MODELS[rule_type] task = converter(route_model) log = Log( @@ -119,7 +119,7 @@ def log_withdraw(user_id, task, rule_type, deleted_id, author): log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type.value, rule_id=deleted_id, user_id=user_id, author=author, diff --git a/flowapp/utils.py b/flowapp/utils.py index ad15b61..512f7cc 100644 --- a/flowapp/utils.py +++ b/flowapp/utils.py @@ -1,4 +1,3 @@ -from operator import ge, lt from datetime import datetime, timedelta from flask import flash from flowapp.constants import ( @@ -73,7 +72,7 @@ def parse_api_time(apitime): except ValueError: mytime = False - return False + return mytime def quote_to_ent(comment): From 6fa4bb49ab488804c422d9c02e06bb5aebbea385 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 28 Feb 2025 08:53:40 +0100 Subject: [PATCH 08/46] bugfix - RuleTypes use in api_common and output --- flowapp/output.py | 44 +++++++++++++++++++++++-------------- flowapp/views/api_common.py | 12 +++++----- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/flowapp/output.py b/flowapp/output.py index 8322544..b8774bf 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, asdict from datetime import datetime +from typing import Dict, Union, Callable, Any import requests import pika @@ -11,9 +12,12 @@ from flask import current_app from flowapp import db, messages +from flowapp.constants import RuleTypes from flowapp.models import Log +from flowapp.models import Flowspec4, Flowspec6, RTBH -ROUTE_MODELS = { + +ROUTE_MODELS: Dict[int, Callable[[Any], str]] = { 1: messages.create_rtbh, 4: messages.create_ipv4, 6: messages.create_ipv6, @@ -21,21 +25,21 @@ class RouteSources: - UI = "UI" - API = "API" + UI: str = "UI" + API: str = "API" @dataclass class Route: author: str - source: RouteSources + source: str # Using str instead of RouteSources for flexibility command: str - def __dict__(self): + def __dict__(self) -> Dict[str, str]: return asdict(self) -def announce_route(route: Route): +def announce_route(route: Route) -> None: """ Dispatch route as dict to ExaBGP API API must be set in app config.py @@ -48,7 +52,7 @@ def announce_route(route: Route): announce_to_http(asdict(route)) -def announce_to_http(route): +def announce_to_http(route: Dict[str, str]) -> None: """ Announce route to ExaBGP HTTP API process """ @@ -64,7 +68,7 @@ def announce_to_http(route): current_app.logger.debug(f"Testing: {route}") -def announce_to_rabbitmq(route): +def announce_to_rabbitmq(route: Dict[str, str]) -> None: """ Announce rout to ExaBGP RabbitMQ API process """ @@ -85,24 +89,25 @@ def announce_to_rabbitmq(route): channel.queue_declare(queue=queue) channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) else: - current_app.logger.debug("Testing: {route}") + current_app.logger.debug(f"Testing: {route}") -def log_route(user_id, route_model, rule_type, author): +def log_route(user_id: int, route_model: Union[RTBH, Flowspec4, Flowspec6], rule_type: RuleTypes, author: str) -> None: """ Convert route to EXAFS message and log it to database :param user_id : int curent user - :param route_model: model with route object - :param rule_type: string + :param route_model: model with route object (RTBH, Flowspec4, or Flowspec6) + :param rule_type: RuleTypes enum + :param author: str name of the author :return: None """ - rule_type = rule_type.value - converter = ROUTE_MODELS[rule_type] + rule_type_value = rule_type.value + converter = ROUTE_MODELS[rule_type_value] task = converter(route_model) log = Log( time=datetime.now(), task=task, - rule_type=rule_type, + rule_type=rule_type_value, rule_id=route_model.id, user_id=user_id, author=author, @@ -111,10 +116,15 @@ def log_route(user_id, route_model, rule_type, author): db.session.commit() -def log_withdraw(user_id, task, rule_type, deleted_id, author): +def log_withdraw(user_id: int, task: str, rule_type: RuleTypes, deleted_id: int, author: str) -> None: """ Log the withdraw command to database - :param task: command message + :param user_id: int user ID + :param task: str command message + :param rule_type: RuleTypes enum + :param deleted_id: int ID of deleted rule + :param author: str name of the author + :return: None """ log = Log( time=datetime.now(), diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index d04e40d..b421eea 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -292,7 +292,7 @@ def create_ipv4(current_user): log_route( current_user["id"], model, - RuleTypes.IPv4.value, + RuleTypes.IPv4, f"{current_user['uuid']} / {current_user['org']}", ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -368,7 +368,7 @@ def create_ipv6(current_user): log_route( current_user["id"], model, - RuleTypes.IPv6.value, + RuleTypes.IPv6, f"{current_user['uuid']} / {current_user['org']}", ) @@ -441,7 +441,7 @@ def create_rtbh(current_user): log_route( current_user["id"], model, - RuleTypes.RTBH.value, + RuleTypes.RTBH, f"{current_user['uuid']} / {current_user['org']}", ) @@ -506,7 +506,7 @@ def delete_v4_rule(current_user, rule_id): """ model_name = Flowspec4 route_model = messages.create_ipv4 - return delete_rule(current_user, rule_id, model_name, route_model, 4) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv4) def delete_v6_rule(current_user, rule_id): @@ -516,7 +516,7 @@ def delete_v6_rule(current_user, rule_id): """ model_name = Flowspec6 route_model = messages.create_ipv6 - return delete_rule(current_user, rule_id, model_name, route_model, 6) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.IPv6) def delete_rtbh_rule(current_user, rule_id): @@ -526,7 +526,7 @@ def delete_rtbh_rule(current_user, rule_id): """ model_name = RTBH route_model = messages.create_rtbh - return delete_rule(current_user, rule_id, model_name, route_model, 1) + return delete_rule(current_user, rule_id, model_name, route_model, RuleTypes.RTBH) def delete_rule(current_user, rule_id, model_name, route_model, rule_type): From 24253830d94e817a17c403b31dab198f533e301f Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 28 Feb 2025 11:29:02 +0100 Subject: [PATCH 09/46] refactoring forms.py into directory forms for better readibility and clean code --- flowapp/forms.py | 713 ------------------------------- flowapp/forms/__init__.py | 40 ++ flowapp/forms/api.py | 61 +++ flowapp/forms/base.py | 50 +++ flowapp/forms/choices.py | 81 ++++ flowapp/forms/organization.py | 47 ++ flowapp/forms/rules/__init__.py | 17 + flowapp/forms/rules/base.py | 127 ++++++ flowapp/forms/rules/ipv4.py | 70 +++ flowapp/forms/rules/ipv6.py | 64 +++ flowapp/forms/rules/rtbh.py | 104 +++++ flowapp/forms/rules/whitelist.py | 66 +++ flowapp/forms/user.py | 91 ++++ 13 files changed, 818 insertions(+), 713 deletions(-) delete mode 100644 flowapp/forms.py create mode 100644 flowapp/forms/__init__.py create mode 100644 flowapp/forms/api.py create mode 100644 flowapp/forms/base.py create mode 100644 flowapp/forms/choices.py create mode 100644 flowapp/forms/organization.py create mode 100644 flowapp/forms/rules/__init__.py create mode 100644 flowapp/forms/rules/base.py create mode 100644 flowapp/forms/rules/ipv4.py create mode 100644 flowapp/forms/rules/ipv6.py create mode 100644 flowapp/forms/rules/rtbh.py create mode 100644 flowapp/forms/rules/whitelist.py create mode 100644 flowapp/forms/user.py diff --git a/flowapp/forms.py b/flowapp/forms.py deleted file mode 100644 index 5baca78..0000000 --- a/flowapp/forms.py +++ /dev/null @@ -1,713 +0,0 @@ -import csv -from io import StringIO - - -from flask_wtf import FlaskForm -from wtforms import widgets -from wtforms import ( - BooleanField, - DateTimeField, - HiddenField, - IntegerField, - SelectField, - SelectMultipleField, - StringField, - TextAreaField, -) -from wtforms.validators import ( - ValidationError, - DataRequired, - Email, - InputRequired, - IPAddress, - Length, - NumberRange, - Optional, -) - -from flowapp.constants import ( - IPV4_FRAGMENT, - IPV4_PROTOCOL, - IPV6_NEXT_HEADER, - TCP_FLAGS, - FORM_TIME_PATTERN, -) -from flowapp.validators import ( - IPAddressValidator, - IPv4Address, - IPv6Address, - NetRangeString, - NetworkValidator, - PortString, - address_in_range, - address_with_mask, - network_in_range, - whole_world_range, -) - -from flowapp.utils import parse_api_time - - -class MultiFormatDateTimeLocalField(DateTimeField): - """ - Same as :class:`~wtforms.fields.DateTimeField`, but represents an - ````. - - Custom implementation uses default HTML5 format for parsing the field. - It's possible to use multiple formats - used in API. - - """ - - widget = widgets.DateTimeLocalInput() - - def __init__(self, *args, **kwargs): - kwargs.setdefault("format", "%Y-%m-%dT%H:%M") - self.unlimited = kwargs.pop("unlimited", False) - self.pref_format = None - super().__init__(*args, **kwargs) - - def process_formdata(self, valuelist): - if not valuelist or (len(valuelist) == 1 and not valuelist[0]): - return None - - # with unlimited field we do not need to parse the empty value - if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: - self.data = None - return None - - date_str = " ".join((str(val) for val in valuelist)) - - try: - result, pref_format = parse_api_time(date_str) - except TypeError: - raise ValueError(self.gettext("Not a valid datetime value.")) - - if result: - self.data = result - self.pref_format = pref_format - else: - self.data = None - self.pref_format = None - raise ValueError(self.gettext("Not a valid datetime value.")) - - -class UserForm(FlaskForm): - """ - User Form object - used in Admin - """ - - uuid = StringField( - "Unique User ID", - validators=[ - InputRequired("Please provide UUID"), - Email("Please provide valid email"), - ], - ) - - email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) - - comment = StringField("Notice", validators=[Optional()]) - - name = StringField("Name", validators=[Optional()]) - - phone = StringField("Contact phone", validators=[Optional()]) - - role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) - - org_ids = SelectMultipleField( - "Organization", - coerce=int, - validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], - ) - - -class BulkUserForm(FlaskForm): - """ - Bulk User Form object - used in Admin - """ - - users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) - - def __init__(self, *args, **kwargs): - super(BulkUserForm, self).__init__(*args, **kwargs) - self.roles = None - self.organizations = None - self.uuids = None - - # Custom validator for CSV data - def validate_users(self, field): - csv_data = field.data - - # Parse CSV data - csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") - - # List to keep track of failed validation rows - errors = 0 - for row_num, row in enumerate(csv_reader, start=1): - try: - # check if the user not already exists - if row["uuid-eppn"] in self.uuids: - field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") - errors += 1 - - # Check if role exists in the database - role_id = int(row["role"]) # Convert role field to integer - if role_id not in self.roles: - field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") - errors += 1 - - # Check if organization exists in the database - org_id = int(row["organizace"]) # Convert organization field to integer - if org_id not in self.organizations: - field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") - errors += 1 - - except (KeyError, ValueError) as e: - field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") - - if errors > 0: - # Raise validation error if any invalid rows found - raise ValidationError("Invalid CSV Data - check the errors above.") - - -class ApiKeyForm(FlaskForm): - """ - ApiKey for User - Each key / machine pair is unique - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class MachineApiKeyForm(FlaskForm): - """ - ApiKey for Machines - Each key / machine pair is unique - Only Admin can create new these keys - """ - - machine = StringField( - "Machine address", - validators=[DataRequired(), IPAddress(message="provide valid IP address")], - ) - - comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Key expiration. Leave blank for non expring key (not-recomended).", - format=FORM_TIME_PATTERN, - validators=[Optional()], - unlimited=True, - ) - - readonly = BooleanField("Read only key", default=False) - - key = HiddenField("GeneratedKey") - - -class OrganizationForm(FlaskForm): - """ - Organization form object - used in Admin - """ - - name = StringField("Organization name", validators=[Optional(), Length(max=150)]) - - limit_flowspec4 = IntegerField( - "Maximum number of IPv4 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_flowspec6 = IntegerField( - "Maximum number of IPv6 rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - limit_rtbh = IntegerField( - "Maximum number of RTBH rules, 0 for unlimited", - validators=[ - Optional(), - NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), - ], - ) - - arange = TextAreaField( - "Organization Adress Range - one range per row", - validators=[Optional(), NetRangeString()], - ) - - -class ActionForm(FlaskForm): - """ - Action form object - used in Admin - """ - - name = StringField("Action short name", validators=[Length(max=150)]) - - command = StringField("ExaBGP command", validators=[Length(max=150)]) - - description = StringField("Action description") - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - -class ASPathForm(FlaskForm): - """ - AS Path form object - used in Admin - """ - - prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) - - as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) - - -class CommunityForm(FlaskForm): - """ - Community form object - used in Admin - """ - - name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) - - comm = StringField("Community value", validators=[Length(max=2046)]) - - larcomm = StringField("Large community value", validators=[Length(max=2046)]) - - extcomm = StringField("Extended community value", validators=[Length(max=2046)]) - - description = StringField("Community description", validators=[Length(max=255)]) - - role_id = SelectField( - "Minimal required role", - choices=[("2", "user"), ("3", "admin")], - validators=[DataRequired()], - ) - - as_path = BooleanField("add AS-path (checked = true)") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - if not self.comm.data and not self.extcomm.data and not self.larcomm.data: - err_message = "At last one of those values could not be empty" - self.comm.errors.append(err_message) - self.larcomm.errors.append(err_message) - self.extcomm.errors.append(err_message) - result = False - - return result - - -class RTBHForm(FlaskForm): - """ - RoadToBlackHole rule form - """ - - def __init__(self, *args, **kwargs): - super(RTBHForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - ipv4 = StringField( - "IPv4 address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - ipv4_mask = IntegerField( - "IPv4 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), - ], - ) - - ipv6 = StringField( - "IPv6 address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - ipv6_mask = IntegerField( - "IPv6 mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), - ], - ) - - community = SelectField( - "Community", - coerce=int, - validators=[ - DataRequired(message="Please select a community for the rule."), - ], - ) - - expires = MultiFormatDateTimeLocalField( - "Expires", - format=FORM_TIME_PATTERN, - validators=[DataRequired(), InputRequired()], - ) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - # ipv4 and ipv6 are mutually exclusive - # if both are set, validation fails - # if none is set, validation fails - # if one is set, validation passes - if self.ipv4.data and self.ipv6.data: - self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") - result = False - - if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): - self.ipv4.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) - ) - result = False - - if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): - self.ipv6.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) - ) - result = False - - ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) - ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) - - if not (ipv6_in_range or ipv4_in_range): - self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) - result = False - - return result - - -class IPForm(FlaskForm): - """ - Base class for IPv4 and IPv6 rules - """ - - def __init__(self, *args, **kwargs): - super(IPForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = None - source = None - source_mask = None - dest = None - dest_mask = None - flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) - - source_port = StringField( - "Source port(s) - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - dest_port = StringField( - "Destination port(s) - ; separated", - validators=[Optional(), Length(max=255), PortString()], - ) - - packet_len = StringField( - "Packet length - ; separated ", - validators=[Optional(), Length(max=255), PortString()], - ) - - action = SelectField( - "Action", - coerce=int, - validators=[DataRequired(message="Please select an action for the rule.")], - ) - - expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) - - comment = arange = TextAreaField("Comments") - - def validate(self): - """ - custom validation method - :return: boolean - """ - - result = True - if not FlaskForm.validate(self): - result = False - - source = self.validate_source_address() - dest = self.validate_dest_address() - ranges = self.validate_address_ranges() - ips = self.validate_ipv_specific() - - return result and source and dest and ranges and ips - - def validate_source_address(self): - """ - validate source address, set error message if validation fails - :return: boolean validation result - """ - if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): - self.source.errors.append( - "This is not valid combination of address {} and mask {}.".format( - self.source.data, self.source_mask.data - ) - ) - return False - - return True - - def validate_dest_address(self): - """ - validate dest address, set error message if validation fails - :return: boolean validation result - """ - if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): - self.dest.errors.append( - "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) - ) - return False - - return True - - def validate_address_ranges(self): - """ - validates if the address of source is in the user range - if the source and dest address are empty, check if the user - is member of whole world organization - :return: boolean validation result - """ - if not (self.source.data or self.dest.data): - whole_world_member = whole_world_range(self.net_ranges, self.zero_address) - if not whole_world_member: - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - else: - source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) - dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) - if not (source_in_range or dest_in_range): - self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) - return False - - return True - - def validate_ipv_specific(self): - """ - abstract method must be implemented in the subclass - """ - pass - - -class IPv4Form(IPForm): - """ - IPv4 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv4Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "0.0.0.0" - source = StringField( - "Source address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - source_mask = IntegerField( - "Source mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], - ) - - dest_mask = IntegerField( - "Destination mask (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=32, message="invalid mask value (0-32)"), - ], - ) - - protocol = SelectField( - "Protocol", - choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], - validators=[DataRequired()], - ) - - fragment = SelectMultipleField( - "Fragment", - choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], - validators=[Optional()], - ) - - def validate_ipv_specific(self): - """ - validate protocol and flags, set error message if validation fails - :return: boolean validation result - """ - - if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": - self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) - return False - return True - - -class IPv6Form(IPForm): - """ - IPv6 form object - """ - - def __init__(self, *args, **kwargs): - super(IPv6Form, self).__init__(*args, **kwargs) - self.net_ranges = None - - zero_address = "::" - source = StringField( - "Source address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - source_mask = IntegerField( - "Source prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - dest = StringField( - "Destination address", - validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], - ) - - dest_mask = IntegerField( - "Destination prefix length (bits)", - validators=[ - Optional(), - NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), - ], - ) - - next_header = SelectField( - "Next Header", - choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], - validators=[DataRequired()], - ) - - def validate_ipv_specific(self): - """ - validate next header and flags, set error message if validation fails - :return: boolean validation result - """ - if len(self.flags.data) > 0 and self.next_header.data != "tcp": - self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) - return False - - return True - - -class WhitelistForm(FlaskForm): - """ - Whitelist form object - Used for creating and editing whitelist entries - Supports both IPv4 and IPv6 addresses - """ - - def __init__(self, *args, **kwargs): - super(WhitelistForm, self).__init__(*args, **kwargs) - self.net_ranges = None - - ip = StringField( - "IP address", - validators=[ - DataRequired(message="Please provide an IP address"), - IPAddressValidator(message="Please provide a valid IP address: {}"), - NetworkValidator(mask_field_name="mask"), - ], - ) - - mask = IntegerField( - "Network mask (bits)", - validators=[ - DataRequired(message="Please provide a network mask"), - ], - ) - - comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) - - expires = MultiFormatDateTimeLocalField( - "Expires", - format=FORM_TIME_PATTERN, - validators=[DataRequired(), InputRequired()], - ) - - def validate(self): - """ - Custom validation method - :return: boolean - """ - result = True - - if not FlaskForm.validate(self): - result = False - - # Validate IP is in organization range - if self.ip.data and self.mask.data and self.net_ranges: - ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) - if not ip_in_range: - self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) - result = False - - return result diff --git a/flowapp/forms/__init__.py b/flowapp/forms/__init__.py new file mode 100644 index 0000000..a9d7d88 --- /dev/null +++ b/flowapp/forms/__init__.py @@ -0,0 +1,40 @@ +""" +Forms module for the flowapp application. +This file imports and re-exports all form classes from the module. +""" + +# Base field +from .base import MultiFormatDateTimeLocalField + +# User forms +from .user import UserForm, BulkUserForm + +# API key forms +from .api import ApiKeyForm, MachineApiKeyForm + +# Organization forms +from .organization import OrganizationForm + +# Action, ASPath, and Community forms +from .choices import ActionForm, ASPathForm, CommunityForm + +# Rule forms +from .rules import IPForm, IPv4Form, IPv6Form, RTBHForm, WhitelistForm + + +__all__ = [ + "MultiFormatDateTimeLocalField", + "UserForm", + "BulkUserForm", + "ApiKeyForm", + "MachineApiKeyForm", + "OrganizationForm", + "ActionForm", + "ASPathForm", + "CommunityForm", + "RTBHForm", + "IPForm", + "IPv4Form", + "IPv6Form", + "WhitelistForm", +] diff --git a/flowapp/forms/api.py b/flowapp/forms/api.py new file mode 100644 index 0000000..a1d9d20 --- /dev/null +++ b/flowapp/forms/api.py @@ -0,0 +1,61 @@ +""" +API key forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, TextAreaField, BooleanField, HiddenField +from wtforms.validators import DataRequired, IPAddress, Optional, Length + +from .base import MultiFormatDateTimeLocalField +from ..constants import FORM_TIME_PATTERN + + +class ApiKeyForm(FlaskForm): + """ + ApiKey for User + Each key / machine pair is unique + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=False) + + key = HiddenField("GeneratedKey") + + +class MachineApiKeyForm(FlaskForm): + """ + ApiKey for Machines + Each key / machine pair is unique + Only Admin can create new these keys + """ + + machine = StringField( + "Machine address", + validators=[DataRequired(), IPAddress(message="provide valid IP address")], + ) + + comment = TextAreaField("Your comment for this key", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Key expiration. Leave blank for non expring key (not-recomended).", + format=FORM_TIME_PATTERN, + validators=[Optional()], + unlimited=True, + ) + + readonly = BooleanField("Read only key", default=False) + + key = HiddenField("GeneratedKey") diff --git a/flowapp/forms/base.py b/flowapp/forms/base.py new file mode 100644 index 0000000..f99db23 --- /dev/null +++ b/flowapp/forms/base.py @@ -0,0 +1,50 @@ +""" +Base form fields for the flowapp application. +""" + +from wtforms import widgets +from wtforms.fields import DateTimeField +from ..utils import parse_api_time + + +class MultiFormatDateTimeLocalField(DateTimeField): + """ + Same as :class:`~wtforms.fields.DateTimeField`, but represents an + ````. + + Custom implementation uses default HTML5 format for parsing the field. + It's possible to use multiple formats - used in API. + + """ + + widget = widgets.DateTimeLocalInput() + + def __init__(self, *args, **kwargs): + kwargs.setdefault("format", "%Y-%m-%dT%H:%M") + self.unlimited = kwargs.pop("unlimited", False) + self.pref_format = None + super().__init__(*args, **kwargs) + + def process_formdata(self, valuelist): + if not valuelist or (len(valuelist) == 1 and not valuelist[0]): + return None + + # with unlimited field we do not need to parse the empty value + if self.unlimited and len(valuelist) == 1 and len(valuelist[0]) == 0: + self.data = None + return None + + date_str = " ".join((str(val) for val in valuelist)) + + try: + result, pref_format = parse_api_time(date_str) + except TypeError: + raise ValueError(self.gettext("Not a valid datetime value.")) + + if result: + self.data = result + self.pref_format = pref_format + else: + self.data = None + self.pref_format = None + raise ValueError(self.gettext("Not a valid datetime value.")) diff --git a/flowapp/forms/choices.py b/flowapp/forms/choices.py new file mode 100644 index 0000000..424b6b8 --- /dev/null +++ b/flowapp/forms/choices.py @@ -0,0 +1,81 @@ +""" +Action, ASPath, and Community forms for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, SelectField, BooleanField +from wtforms.validators import Length, DataRequired + + +class ActionForm(FlaskForm): + """ + Action form object + used in Admin + """ + + name = StringField("Action short name", validators=[Length(max=150)]) + + command = StringField("ExaBGP command", validators=[Length(max=150)]) + + description = StringField("Action description") + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + +class ASPathForm(FlaskForm): + """ + AS Path form object + used in Admin + """ + + prefix = StringField("Prefix", validators=[Length(max=120), DataRequired()]) + + as_path = StringField("as-path value", validators=[Length(max=250), DataRequired()]) + + +class CommunityForm(FlaskForm): + """ + Community form object + used in Admin + """ + + name = StringField("Community short name", validators=[Length(max=120), DataRequired()]) + + comm = StringField("Community value", validators=[Length(max=2046)]) + + larcomm = StringField("Large community value", validators=[Length(max=2046)]) + + extcomm = StringField("Extended community value", validators=[Length(max=2046)]) + + description = StringField("Community description", validators=[Length(max=255)]) + + role_id = SelectField( + "Minimal required role", + choices=[("2", "user"), ("3", "admin")], + validators=[DataRequired()], + ) + + as_path = BooleanField("add AS-path (checked = true)") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + if not self.comm.data and not self.extcomm.data and not self.larcomm.data: + err_message = "At last one of those values could not be empty" + self.comm.errors.append(err_message) + self.larcomm.errors.append(err_message) + self.extcomm.errors.append(err_message) + result = False + + return result diff --git a/flowapp/forms/organization.py b/flowapp/forms/organization.py new file mode 100644 index 0000000..37594ab --- /dev/null +++ b/flowapp/forms/organization.py @@ -0,0 +1,47 @@ +""" +Organization form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, Length, NumberRange + +from ..validators import NetRangeString + + +class OrganizationForm(FlaskForm): + """ + Organization form object + used in Admin + """ + + name = StringField("Organization name", validators=[Optional(), Length(max=150)]) + + limit_flowspec4 = IntegerField( + "Maximum number of IPv4 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_flowspec6 = IntegerField( + "Maximum number of IPv6 rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + limit_rtbh = IntegerField( + "Maximum number of RTBH rules, 0 for unlimited", + validators=[ + Optional(), + NumberRange(min=0, max=1000, message="invalid mask value (0-1000)"), + ], + ) + + arange = TextAreaField( + "Organization Adress Range - one range per row", + validators=[Optional(), NetRangeString()], + ) diff --git a/flowapp/forms/rules/__init__.py b/flowapp/forms/rules/__init__.py new file mode 100644 index 0000000..9796e8c --- /dev/null +++ b/flowapp/forms/rules/__init__.py @@ -0,0 +1,17 @@ +""" +Rule forms for the flowapp application. +""" + +from .base import IPForm +from .ipv4 import IPv4Form +from .ipv6 import IPv6Form +from .rtbh import RTBHForm +from .whitelist import WhitelistForm + +__all__ = [ + "IPForm", + "IPv4Form", + "IPv6Form", + "RTBHForm", + "WhitelistForm", +] diff --git a/flowapp/forms/rules/base.py b/flowapp/forms/rules/base.py new file mode 100644 index 0000000..0e67776 --- /dev/null +++ b/flowapp/forms/rules/base.py @@ -0,0 +1,127 @@ +""" +Base rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import SelectMultipleField, StringField, SelectField, TextAreaField +from wtforms.validators import Optional, Length, DataRequired, InputRequired + +from ..base import MultiFormatDateTimeLocalField +from ...constants import TCP_FLAGS +from ...validators import PortString, address_with_mask, network_in_range, whole_world_range + + +class IPForm(FlaskForm): + """ + Base class for IPv4 and IPv6 rules + """ + + def __init__(self, *args, **kwargs): + super(IPForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = None + source = None + source_mask = None + dest = None + dest_mask = None + flags = SelectMultipleField("TCP flag(s)", choices=TCP_FLAGS, validators=[Optional()]) + + source_port = StringField( + "Source port(s) - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + dest_port = StringField( + "Destination port(s) - ; separated", + validators=[Optional(), Length(max=255), PortString()], + ) + + packet_len = StringField( + "Packet length - ; separated ", + validators=[Optional(), Length(max=255), PortString()], + ) + + action = SelectField( + "Action", + coerce=int, + validators=[DataRequired(message="Please select an action for the rule.")], + ) + + expires = MultiFormatDateTimeLocalField("Expires", format="%Y-%m-%dT%H:%M", validators=[InputRequired()]) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + + result = True + if not FlaskForm.validate(self): + result = False + + source = self.validate_source_address() + dest = self.validate_dest_address() + ranges = self.validate_address_ranges() + ips = self.validate_ipv_specific() + + return result and source and dest and ranges and ips + + def validate_source_address(self): + """ + validate source address, set error message if validation fails + :return: boolean validation result + """ + if self.source.data and not address_with_mask(self.source.data, self.source_mask.data): + self.source.errors.append( + "This is not valid combination of address {} and mask {}.".format( + self.source.data, self.source_mask.data + ) + ) + return False + + return True + + def validate_dest_address(self): + """ + validate dest address, set error message if validation fails + :return: boolean validation result + """ + if self.dest.data and not address_with_mask(self.dest.data, self.dest_mask.data): + self.dest.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.dest.data, self.dest_mask.data) + ) + return False + + return True + + def validate_address_ranges(self): + """ + validates if the address of source is in the user range + if the source and dest address are empty, check if the user + is member of whole world organization + :return: boolean validation result + """ + if not (self.source.data or self.dest.data): + whole_world_member = whole_world_range(self.net_ranges, self.zero_address) + if not whole_world_member: + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + else: + source_in_range = network_in_range(self.source.data, self.source_mask.data, self.net_ranges) + dest_in_range = network_in_range(self.dest.data, self.dest_mask.data, self.net_ranges) + if not (source_in_range or dest_in_range): + self.source.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + self.dest.errors.append("Source or dest must be in organization range : {}.".format(self.net_ranges)) + return False + + return True + + def validate_ipv_specific(self): + """ + abstract method must be implemented in the subclass + """ + pass diff --git a/flowapp/forms/rules/ipv4.py b/flowapp/forms/rules/ipv4.py new file mode 100644 index 0000000..07f27a8 --- /dev/null +++ b/flowapp/forms/rules/ipv4.py @@ -0,0 +1,70 @@ +""" +IPv4 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField, SelectMultipleField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV4_PROTOCOL, IPV4_FRAGMENT +from ...validators import IPv4Address +from .base import IPForm + + +class IPv4Form(IPForm): + """ + IPv4 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv4Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "0.0.0.0" + source = StringField( + "Source address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + source_mask = IntegerField( + "Source mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + dest_mask = IntegerField( + "Destination mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid mask value (0-32)"), + ], + ) + + protocol = SelectField( + "Protocol", + choices=[(pr, pr.upper()) for pr in IPV4_PROTOCOL.keys()], + validators=[DataRequired()], + ) + + fragment = SelectMultipleField( + "Fragment", + choices=[(frv, frk.upper()) for frk, frv in IPV4_FRAGMENT.items()], + validators=[Optional()], + ) + + def validate_ipv_specific(self): + """ + validate protocol and flags, set error message if validation fails + :return: boolean validation result + """ + + if self.flags.data and self.protocol.data and len(self.flags.data) > 0 and self.protocol.data != "tcp": + self.flags.errors.append("Can not set TCP flags for protocol {} !".format(self.protocol.data.upper())) + return False + return True diff --git a/flowapp/forms/rules/ipv6.py b/flowapp/forms/rules/ipv6.py new file mode 100644 index 0000000..aabd1b2 --- /dev/null +++ b/flowapp/forms/rules/ipv6.py @@ -0,0 +1,64 @@ +""" +IPv6 rule form for the flowapp application. +""" + +from wtforms import StringField, IntegerField, SelectField +from wtforms.validators import Optional, NumberRange, DataRequired + +from ...constants import IPV6_NEXT_HEADER +from ...validators import IPv6Address +from .base import IPForm + + +class IPv6Form(IPForm): + """ + IPv6 form object + """ + + def __init__(self, *args, **kwargs): + super(IPv6Form, self).__init__(*args, **kwargs) + self.net_ranges = None + + zero_address = "::" + source = StringField( + "Source address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + source_mask = IntegerField( + "Source prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + dest = StringField( + "Destination address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + dest_mask = IntegerField( + "Destination prefix length (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid prefix value (0-128)"), + ], + ) + + next_header = SelectField( + "Next Header", + choices=[(pr, pr.upper()) for pr in IPV6_NEXT_HEADER.keys()], + validators=[DataRequired()], + ) + + def validate_ipv_specific(self): + """ + validate next header and flags, set error message if validation fails + :return: boolean validation result + """ + if len(self.flags.data) > 0 and self.next_header.data != "tcp": + self.flags.errors.append("Can not set TCP flags for next-header {} !".format(self.next_header.data.upper())) + return False + + return True diff --git a/flowapp/forms/rules/rtbh.py b/flowapp/forms/rules/rtbh.py new file mode 100644 index 0000000..cef7948 --- /dev/null +++ b/flowapp/forms/rules/rtbh.py @@ -0,0 +1,104 @@ +""" +RTBH rule form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, SelectField, TextAreaField +from wtforms.validators import Optional, NumberRange, DataRequired, InputRequired + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPv6Address, address_with_mask, address_in_range, IPv4Address +from ..base import MultiFormatDateTimeLocalField + + +class RTBHForm(FlaskForm): + """ + RoadToBlackHole rule form + """ + + def __init__(self, *args, **kwargs): + super(RTBHForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ipv4 = StringField( + "IPv4 address", + validators=[Optional(), IPv4Address(message="provide valid IPv4 adress")], + ) + + ipv4_mask = IntegerField( + "IPv4 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=32, message="invalid IPv4 mask value (0-32)"), + ], + ) + + ipv6 = StringField( + "IPv6 address", + validators=[Optional(), IPv6Address(message="provide valid IPv6 adress")], + ) + + ipv6_mask = IntegerField( + "IPv6 mask (bits)", + validators=[ + Optional(), + NumberRange(min=0, max=128, message="invalid IPv6 mask value (0-128)"), + ], + ) + + community = SelectField( + "Community", + coerce=int, + validators=[ + DataRequired(message="Please select a community for the rule."), + ], + ) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + comment = arange = TextAreaField("Comments") + + def validate(self): + """ + custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # ipv4 and ipv6 are mutually exclusive + # if both are set, validation fails + # if none is set, validation fails + # if one is set, validation passes + if self.ipv4.data and self.ipv6.data: + self.ipv4.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + self.ipv6.errors.append("IPv4 and IPv6 are mutually exclusive in RTBH rule.") + result = False + + if self.ipv4.data and not address_with_mask(self.ipv4.data, self.ipv4_mask.data): + self.ipv4.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv4.data, self.ipv4_mask.data) + ) + result = False + + if self.ipv6.data and not address_with_mask(self.ipv6.data, self.ipv6_mask.data): + self.ipv6.errors.append( + "This is not valid combination of address {} and mask {}.".format(self.ipv6.data, self.ipv6_mask.data) + ) + result = False + + ipv6_in_range = address_in_range(self.ipv6.data, self.net_ranges) + ipv4_in_range = address_in_range(self.ipv4.data, self.net_ranges) + + if not (ipv6_in_range or ipv4_in_range): + self.ipv6.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + self.ipv4.errors.append("IPv4 or IPv6 address must be in organization range : {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/rules/whitelist.py b/flowapp/forms/rules/whitelist.py new file mode 100644 index 0000000..8aa88b4 --- /dev/null +++ b/flowapp/forms/rules/whitelist.py @@ -0,0 +1,66 @@ +""" +Whitelist form for the flowapp application. +""" + +from flask_wtf import FlaskForm +from wtforms import StringField, IntegerField, TextAreaField +from wtforms.validators import Optional, DataRequired, InputRequired, Length + +from ...constants import FORM_TIME_PATTERN +from ...validators import IPAddressValidator, NetworkValidator, network_in_range +from ..base import MultiFormatDateTimeLocalField + + +class WhitelistForm(FlaskForm): + """ + Whitelist form object + Used for creating and editing whitelist entries + Supports both IPv4 and IPv6 addresses + """ + + def __init__(self, *args, **kwargs): + super(WhitelistForm, self).__init__(*args, **kwargs) + self.net_ranges = None + + ip = StringField( + "IP address", + validators=[ + DataRequired(message="Please provide an IP address"), + IPAddressValidator(message="Please provide a valid IP address: {}"), + NetworkValidator(mask_field_name="mask"), + ], + ) + + mask = IntegerField( + "Network mask (bits)", + validators=[ + DataRequired(message="Please provide a network mask"), + ], + ) + + comment = TextAreaField("Comments", validators=[Optional(), Length(max=255)]) + + expires = MultiFormatDateTimeLocalField( + "Expires", + format=FORM_TIME_PATTERN, + validators=[DataRequired(), InputRequired()], + ) + + def validate(self): + """ + Custom validation method + :return: boolean + """ + result = True + + if not FlaskForm.validate(self): + result = False + + # Validate IP is in organization range + if self.ip.data and self.mask.data and self.net_ranges: + ip_in_range = network_in_range(self.ip.data, self.mask.data, self.net_ranges) + if not ip_in_range: + self.ip.errors.append("IP address must be in organization range: {}.".format(self.net_ranges)) + result = False + + return result diff --git a/flowapp/forms/user.py b/flowapp/forms/user.py new file mode 100644 index 0000000..5944afd --- /dev/null +++ b/flowapp/forms/user.py @@ -0,0 +1,91 @@ +""" +User-related forms for the flowapp application. +""" + +import csv +from io import StringIO + +from flask_wtf import FlaskForm +from wtforms import StringField, TextAreaField, SelectMultipleField +from wtforms.validators import ValidationError, DataRequired, Email, InputRequired, Optional + + +class UserForm(FlaskForm): + """ + User Form object + used in Admin + """ + + uuid = StringField( + "Unique User ID", + validators=[ + InputRequired("Please provide UUID"), + Email("Please provide valid email"), + ], + ) + + email = StringField("Email", validators=[Optional(), Email("Please provide valid email")]) + + comment = StringField("Notice", validators=[Optional()]) + + name = StringField("Name", validators=[Optional()]) + + phone = StringField("Contact phone", validators=[Optional()]) + + role_ids = SelectMultipleField("Role", coerce=int, validators=[DataRequired("Select at last one role")]) + + org_ids = SelectMultipleField( + "Organization", + coerce=int, + validators=[DataRequired("We prefer one Organization per user, but it's possible select more")], + ) + + +class BulkUserForm(FlaskForm): + """ + Bulk User Form object + used in Admin + """ + + users = TextAreaField("Users in CSV - see example below", validators=[DataRequired()]) + + def __init__(self, *args, **kwargs): + super(BulkUserForm, self).__init__(*args, **kwargs) + self.roles = None + self.organizations = None + self.uuids = None + + # Custom validator for CSV data + def validate_users(self, field): + csv_data = field.data + + # Parse CSV data + csv_reader = csv.DictReader(StringIO(csv_data), delimiter=",") + + # List to keep track of failed validation rows + errors = 0 + for row_num, row in enumerate(csv_reader, start=1): + try: + # check if the user not already exists + if row["uuid-eppn"] in self.uuids: + field.errors.append(f"Row {row_num}: User with UUID {row['uuid-eppn']} already exists.") + errors += 1 + + # Check if role exists in the database + role_id = int(row["role"]) # Convert role field to integer + if role_id not in self.roles: + field.errors.append(f"Row {row_num}: Role ID {role_id} does not exist.") + errors += 1 + + # Check if organization exists in the database + org_id = int(row["organizace"]) # Convert organization field to integer + if org_id not in self.organizations: + field.errors.append(f"Row {row_num}: Organization ID {org_id} does not exist.") + errors += 1 + + except (KeyError, ValueError) as e: + field.errors.append(f"Row {row_num}: Invalid data / key - {str(e)}. Check CSV head row.") + + if errors > 0: + # Raise validation error if any invalid rows found + raise ValidationError("Invalid CSV Data - check the errors above.") From 500d60a914179aba2f57737d4f1d6c30b249445e Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 28 Feb 2025 13:16:49 +0100 Subject: [PATCH 10/46] refactoring of app factory in flowap/__init__.py module. --- flowapp/__init__.py | 174 +++-------------------- flowapp/utils/__init__.py | 46 +++++++ flowapp/utils/app_factory.py | 205 ++++++++++++++++++++++++++++ flowapp/{utils.py => utils/base.py} | 0 4 files changed, 269 insertions(+), 156 deletions(-) create mode 100644 flowapp/utils/__init__.py create mode 100644 flowapp/utils/app_factory.py rename flowapp/{utils.py => utils/base.py} (100%) diff --git a/flowapp/__init__.py b/flowapp/__init__.py index f029c37..b201f25 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -1,10 +1,6 @@ # -*- coding: utf-8 -*- -import babel -import logging -from loguru import logger +from flask import Flask, redirect, render_template, session, url_for -from flask import Flask, redirect, render_template, session, url_for, request -from flask.logging import default_handler from flask_sso import SSO from flask_sqlalchemy import SQLAlchemy from flask_wtf.csrf import CSRFProtect @@ -26,13 +22,6 @@ swagger = Swagger(template_file="static/swagger.yml") -class InterceptHandler(logging.Handler): - - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) - logger_opt.log(record.levelname, record.getMessage()) - - def create_app(config_object=None): app = Flask(__name__) @@ -64,76 +53,28 @@ def create_app(config_object=None): if app.config.get("BEHIND_PROXY", False): app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) - from flowapp import models, constants, validators - from .views.admin import admin - from .views.rules import rules - from .views.api_v1 import api as api_v1 - from .views.api_v2 import api as api_v2 - from .views.api_v3 import api as api_v3 - from .views.api_keys import api_keys + from flowapp import models, constants from .auth import auth_required - from .views.dashboard import dashboard - - # no need for csrf on api because we use JWT - csrf.exempt(api_v1) - csrf.exempt(api_v2) - csrf.exempt(api_v3) - - app.register_blueprint(admin, url_prefix="/admin") - app.register_blueprint(rules, url_prefix="/rules") - app.register_blueprint(api_keys, url_prefix="/api_keys") - app.register_blueprint(api_v1, url_prefix="/api/v1") - app.register_blueprint(api_v2, url_prefix="/api/v2") - app.register_blueprint(api_v3, url_prefix="/api/v3") - app.register_blueprint(dashboard, url_prefix="/dashboard") - - # register loguru as handler - app.logger.removeHandler(default_handler) - app.logger.addHandler(InterceptHandler()) - - @ext.login_handler - def login(user_info): - try: - uuid = user_info.get("eppn") - except KeyError: - uuid = False - return render_template("errors/401.html") - return _handle_login(uuid) + # Register blueprints + from .utils import register_blueprints - @app.route("/logout") - def logout(): - session["user_uuid"] = False - session["user_id"] = False - session.clear() - return redirect(app.config.get("LOGOUT_URL")) + register_blueprints(app, csrf) - @app.route("/ext-login") - def ext_login(): - header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") - if header_name not in request.headers: - return render_template("errors/401.html") + # configure logging + from .utils import configure_logging - uuid = request.headers.get(header_name) - if not uuid: - return render_template("errors/401.html") + configure_logging(app) - return _handle_login(uuid) + # register error handlers + from .utils import register_error_handlers - @app.route("/local-login") - def local_login(): - print("Local login started") - if not app.config.get("LOCAL_AUTH", False): - print("Local auth not enabled") - return render_template("errors/401.html") + register_error_handlers(app) - uuid = app.config.get("LOCAL_USER_UUID", False) - if not uuid: - print("Local user not set") - return render_template("errors/401.html") + # register auth handlers + from .utils import register_auth_handlers - print(f"Local login with {uuid}") - return _handle_login(uuid) + register_auth_handlers(app, ext) @app.route("/") @auth_required @@ -192,89 +133,10 @@ def select_org(org_id=None): def shutdown_session(exception=None): db.session.remove() - # HTTP error handling - @app.errorhandler(404) - def not_found(error): - return render_template("errors/404.html"), 404 - - @app.errorhandler(500) - def internal_error(exception): - app.logger.exception(exception) - return render_template("errors/500.html"), 500 - - @app.context_processor - def utility_processor(): - def editable_rule(rule): - if rule: - validators.editable_range(rule, models.get_user_nets(session["user_id"])) - return True - return False - - return dict(editable_rule=editable_rule) - - @app.context_processor - def inject_main_menu(): - """ - inject main menu config to templates - used in default template to create main menu - """ - return {"main_menu": app.config.get("MAIN_MENU")} - - @app.context_processor - def inject_dashboard(): - """ - inject dashboard config to templates - used in submenu dashboard to create dashboard tables - """ - return {"dashboard": app.config.get("DASHBOARD")} - - @app.template_filter("strftime") - def format_datetime(value): - if value is None: - return app.config.get("MISSING_DATETIME_MESSAGE", "Never") - - format = "y/MM/dd HH:mm" - return babel.dates.format_datetime(value, format) - - @app.template_filter("unlimited") - def unlimited_filter(value): - return "unlimited" if value == 0 else value - - def _handle_login(uuid: str): - """ - handles rest of login process - """ - multiple_orgs = False - try: - user, multiple_orgs = _register_user_to_session(uuid) - except AttributeError as e: - app.logger.exception(e) - return render_template("errors/401.html") - - if multiple_orgs: - return redirect(url_for("select_org", org_id=None)) - - # set user org to session - user_org = user.organization.first() - session["user_org"] = user_org.name - session["user_org_id"] = user_org.id + # register context processors and template filters + from .utils import register_context_processors, register_template_filters - return redirect("/") - - def _register_user_to_session(uuid: str): - print(f"Registering user {uuid} to session") - user = db.session.query(models.User).filter_by(uuid=uuid).first() - print(f"Got user {user} from DB") - session["user_uuid"] = user.uuid - session["user_email"] = user.uuid - session["user_name"] = user.name - session["user_id"] = user.id - session["user_roles"] = [role.name for role in user.role.all()] - session["user_role_ids"] = [role.id for role in user.role.all()] - roles = [i > 1 for i in session["user_role_ids"]] - session["can_edit"] = True if all(roles) and roles else [] - # check if user has multiple organizations and return True if so - print(f"DEBUG SESSION {session}") - return user, len(user.organization.all()) > 1 + register_context_processors(app) + register_template_filters(app) return app diff --git a/flowapp/utils/__init__.py b/flowapp/utils/__init__.py new file mode 100644 index 0000000..99ca325 --- /dev/null +++ b/flowapp/utils/__init__.py @@ -0,0 +1,46 @@ +from .base import ( + other_rtypes, + output_date_format, + parse_api_time, + quote_to_ent, + webpicker_to_datetime, + datetime_to_webpicker, + get_state_by_time, + round_to_ten_minutes, + flash_errors, + active_css_rstate, + get_comp_func, +) + +from .app_factory import ( + # configure_app, + configure_logging, + register_blueprints, + register_error_handlers, + register_context_processors, + register_template_filters, + register_auth_handlers, + # register_org_routes, +) + +__all__ = [ + "other_rtypes", + "output_date_format", + "parse_api_time", + "quote_to_ent", + "webpicker_to_datetime", + "datetime_to_webpicker", + "get_state_by_time", + "round_to_ten_minutes", + "flash_errors", + "active_css_rstate", + "get_comp_func", + # "configure_app", + "configure_logging", + "register_blueprints", + "register_error_handlers", + "register_context_processors", + "register_template_filters", + "register_auth_handlers", + # "register_org_routes", +] diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py new file mode 100644 index 0000000..b53b23b --- /dev/null +++ b/flowapp/utils/app_factory.py @@ -0,0 +1,205 @@ +import logging +import babel +from flask import redirect, render_template, request, session, url_for +from flask.logging import default_handler +from loguru import logger + + +def register_blueprints(app, csrf=None): + """Register Flask blueprints.""" + from flowapp.views.admin import admin + from flowapp.views.rules import rules + from flowapp.views.api_v1 import api as api_v1 + from flowapp.views.api_v2 import api as api_v2 + from flowapp.views.api_v3 import api as api_v3 + from flowapp.views.api_keys import api_keys + from flowapp.views.dashboard import dashboard + from flowapp.views.whitelist import whitelist + + # Configure CSRF exemption for API routes + if csrf: + csrf.exempt(api_v1) + csrf.exempt(api_v2) + csrf.exempt(api_v3) + + # Register blueprints with URL prefixes + app.register_blueprint(admin, url_prefix="/admin") + app.register_blueprint(rules, url_prefix="/rules") + app.register_blueprint(api_keys, url_prefix="/api_keys") + app.register_blueprint(api_v1, url_prefix="/api/v1") + app.register_blueprint(api_v2, url_prefix="/api/v2") + app.register_blueprint(api_v3, url_prefix="/api/v3") + app.register_blueprint(dashboard, url_prefix="/dashboard") + app.register_blueprint(whitelist, url_prefix="/whitelist") + + return app + + +class InterceptHandler(logging.Handler): + + def emit(self, record): + logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) + logger_opt.log(record.levelname, record.getMessage()) + + +def configure_logging(app): + """Configure logging for the application.""" + # register loguru as handler + app.logger.removeHandler(default_handler) + app.logger.addHandler(InterceptHandler()) + + return app + + +def register_error_handlers(app): + """Register error handlers.""" + + @app.errorhandler(404) + def not_found(error): + return render_template("errors/404.html"), 404 + + @app.errorhandler(500) + def internal_error(exception): + app.logger.exception(exception) + return render_template("errors/500.html"), 500 + + return app + + +def register_context_processors(app): + """Register template context processors.""" + + @app.context_processor + def utility_processor(): + def editable_rule(rule): + if rule: + from flowapp.validators import editable_range + from flowapp.models import get_user_nets + + editable_range(rule, get_user_nets(session["user_id"])) + return True + return False + + return dict(editable_rule=editable_rule) + + @app.context_processor + def inject_main_menu(): + """Inject main menu config to templates.""" + return {"main_menu": app.config.get("MAIN_MENU")} + + @app.context_processor + def inject_dashboard(): + """Inject dashboard config to templates.""" + return {"dashboard": app.config.get("DASHBOARD")} + + return app + + +def register_template_filters(app): + """Register custom template filters.""" + + @app.template_filter("strftime") + def format_datetime(value): + if value is None: + return app.config.get("MISSING_DATETIME_MESSAGE", "Never") + + format = "y/MM/dd HH:mm" + return babel.dates.format_datetime(value, format) + + @app.template_filter("unlimited") + def unlimited_filter(value): + return "unlimited" if value == 0 else value + + return app + + +def register_auth_handlers(app, ext): + """Register authentication handlers.""" + + @ext.login_handler + def login(user_info): + try: + uuid = user_info.get("eppn") + except KeyError: + uuid = False + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/logout") + def logout(): + session["user_uuid"] = False + session["user_id"] = False + session.clear() + return redirect(app.config.get("LOGOUT_URL")) + + @app.route("/ext-login") + def ext_login(): + header_name = app.config.get("AUTH_HEADER_NAME", "X-Authenticated-User") + if header_name not in request.headers: + return render_template("errors/401.html") + + uuid = request.headers.get(header_name) + if not uuid: + return render_template("errors/401.html") + + return _handle_login(uuid, app) + + @app.route("/local-login") + def local_login(): + print("Local login started") + if not app.config.get("LOCAL_AUTH", False): + print("Local auth not enabled") + return render_template("errors/401.html") + + uuid = app.config.get("LOCAL_USER_UUID", False) + if not uuid: + print("Local user not set") + return render_template("errors/401.html") + + print(f"Local login with {uuid}") + return _handle_login(uuid, app) + + return app + + +def _handle_login(uuid, app): + """Handle login process for all authentication methods.""" + from flowapp import db + + multiple_orgs = False + try: + user, multiple_orgs = _register_user_to_session(uuid, db) + except AttributeError as e: + app.logger.exception(e) + return render_template("errors/401.html") + + if multiple_orgs: + return redirect(url_for("select_org", org_id=None)) + + # set user org to session + user_org = user.organization.first() + session["user_org"] = user_org.name + session["user_org_id"] = user_org.id + + return redirect("/") + + +def _register_user_to_session(uuid, db): + """Register user information to session.""" + from flowapp.models import User + + print(f"Registering user {uuid} to session") + user = db.session.query(User).filter_by(uuid=uuid).first() + print(f"Got user {user} from DB") + session["user_uuid"] = user.uuid + session["user_email"] = user.uuid + session["user_name"] = user.name + session["user_id"] = user.id + session["user_roles"] = [role.name for role in user.role.all()] + session["user_role_ids"] = [role.id for role in user.role.all()] + roles = [i > 1 for i in session["user_role_ids"]] + session["can_edit"] = True if all(roles) and roles else [] + # check if user has multiple organizations and return True if so + print(f"DEBUG SESSION {session}") + return user, len(user.organization.all()) > 1 diff --git a/flowapp/utils.py b/flowapp/utils/base.py similarity index 100% rename from flowapp/utils.py rename to flowapp/utils/base.py From 9b379135e68d5622a46e91a3bea82fa5c51388be Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 10 Mar 2025 09:48:17 +0100 Subject: [PATCH 11/46] Refactor rule creation logic to use rule_service Rule service is now used both in rules and api views Replaced direct database operations with calls to rule_service.create_or_update_*_rule for IPv4, IPv6, and RTBH rules Streamlined route announcements and logging by delegating responsibility to the service layer Improved maintainability and modularity by centralizing rule handling logic in rule_service --- flowapp/services/__init__.py | 11 ++ flowapp/services/rule_service.py | 219 +++++++++++++++++++++++++++++++ flowapp/views/api_common.py | 156 ++++------------------ flowapp/views/rules.py | 157 ++++------------------ 4 files changed, 277 insertions(+), 266 deletions(-) create mode 100644 flowapp/services/__init__.py create mode 100644 flowapp/services/rule_service.py diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py new file mode 100644 index 0000000..b8fc6e4 --- /dev/null +++ b/flowapp/services/__init__.py @@ -0,0 +1,11 @@ +from .rule_service import ( + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, +) + +__all__ = [ + create_or_update_ipv4_rule, + create_or_update_ipv6_rule, + create_or_update_rtbh_rule, +] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py new file mode 100644 index 0000000..0505ab1 --- /dev/null +++ b/flowapp/services/rule_service.py @@ -0,0 +1,219 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" + +from typing import Dict, Tuple + +from flowapp import db, messages +from flowapp.constants import RuleTypes, ANNOUNCE +from flowapp.models import ( + get_ipv4_model_if_exists, + get_ipv6_model_if_exists, + get_rtbh_model_if_exists, + Flowspec4, + Flowspec6, + RTBH, +) +from flowapp.output import Route, announce_route, log_route, RouteSources +from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent + + +def create_or_update_ipv4_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec4, str]: + """ + Create a new IPv4 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv4_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec4( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + protocol=form_data["protocol"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + fragment=";".join(form_data["fragment"]), + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv4 Rule saved" + + db.session.commit() + + # Announce route if model is in active state + if model.rstate_id == 1: + command = messages.create_ipv4(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv4, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_ipv6_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Flowspec6, str]: + """ + Create a new IPv6 rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_ipv6_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." + else: + # Create new model + model = Flowspec6( + source=form_data["source"], + source_mask=form_data["source_mask"], + source_port=form_data["source_port"], + destination=form_data["dest"], + destination_mask=form_data["dest_mask"], + destination_port=form_data["dest_port"], + next_header=form_data["next_header"], + flags=";".join(form_data["flags"]), + packet_len=form_data["packet_len"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + action_id=form_data["action"], + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "IPv6 Rule saved" + + db.session.commit() + + # Announce routes + if model.rstate_id == 1: + command = messages.create_ipv6(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.IPv6, + f"{user_email} / {org_name}", + ) + + return model, flash_message + + +def create_or_update_rtbh_rule( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[RTBH, str]: + """ + Create a new RTBH rule or update an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, message) + """ + # Check for existing model + model = get_rtbh_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." + else: + # Create new model + model = RTBH( + ipv4=form_data["ipv4"], + ipv4_mask=form_data["ipv4_mask"], + ipv6=form_data["ipv6"], + ipv6_mask=form_data["ipv6_mask"], + community_id=form_data["community"], + expires=round_to_ten_minutes(form_data["expires"]), + comment=quote_to_ent(form_data["comment"]), + user_id=user_id, + org_id=org_id, + rstate_id=get_state_by_time(form_data["expires"]), + ) + db.session.add(model) + flash_message = "RTBH Rule saved" + + db.session.commit() + + # Announce routes + if model.rstate_id == 1: + command = messages.create_rtbh(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log changes + log_route( + user_id, + model, + RuleTypes.RTBH, + f"{user_email} / {org_name}", + ) + + return model, flash_message diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index b421eea..39b7077 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -5,7 +5,7 @@ from functools import wraps from datetime import datetime, timedelta -from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, ANNOUNCE, TIME_FORMAT_ARG, RuleTypes +from flowapp.constants import RULE_NAMES_DICT, WITHDRAW, TIME_FORMAT_ARG, RuleTypes from flowapp.models import ( RTBH, Flowspec4, @@ -18,20 +18,16 @@ check_rule_limit, get_user_nets, get_user_actions, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, insert_initial_communities, get_user_communities, - get_rtbh_model_if_exists, ) from flowapp.forms import IPv4Form, IPv6Form, RTBHForm +from flowapp.services import rule_service from flowapp.utils import ( - quote_to_ent, - get_state_by_time, output_date_format, ) from flowapp.auth import check_access_rights -from flowapp.output import announce_route, log_route, log_withdraw, Route, RouteSources +from flowapp.output import announce_route, log_withdraw, Route, RouteSources from flowapp import db, validators, flowspec, messages @@ -249,52 +245,15 @@ def create_ipv4(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv4, - f"{current_user['uuid']} / {current_user['org']}", + # Use the service to create/update the rule + model, flash_message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) + pref_format = output_date_format(json_request_data, form.expires.pref_format) response = {"message": flash_message, "rule": model.to_dict(pref_format)} return jsonify(response), 201 @@ -326,50 +285,12 @@ def create_ipv6(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing IPv6 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - - # log changes - log_route( - current_user["id"], - model, - RuleTypes.IPv6, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) @@ -406,43 +327,12 @@ def create_rtbh(current_user): if form_errors: return jsonify(form_errors), 400 - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = form.expires.data - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=form.expires.data, - comment=quote_to_ent(form.comment.data), - user_id=current_user["id"], - org_id=current_user["org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, ANNOUNCE) - route = Route( - author=f"{current_user['uuid']} / {current_user['org']}", - source=RouteSources.API, - command=command, - ) - announce_route(route) - # log changes - log_route( - current_user["id"], - model, - RuleTypes.RTBH, - f"{current_user['uuid']} / {current_user['org']}", + model, flash_message = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=current_user["id"], + org_id=current_user["org_id"], + user_email=current_user["uuid"], + org_name=current_user["org"], ) pref_format = output_date_format(json_request_data, form.expires.pref_format) diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 5618789..f0acfb4 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -24,19 +24,16 @@ Organization, check_global_rule_limit, check_rule_limit, - get_ipv4_model_if_exists, - get_ipv6_model_if_exists, - get_rtbh_model_if_exists, get_user_actions, get_user_communities, get_user_nets, insert_initial_communities, ) from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route +from flowapp.services import rule_service from flowapp.utils import ( flash_errors, get_state_by_time, - quote_to_ent, round_to_ten_minutes, ) @@ -476,54 +473,17 @@ def ipv4_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv4, - f"{session['user_email']} / {session['user_org']}", + # Use the service to create/update the rule + _model, message = rule_service.create_or_update_ipv4_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") + return redirect(url_for("index")) else: for field, errors in form.errors.items(): @@ -560,52 +520,14 @@ def ipv6_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv6_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec6( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - next_header=form.next_header.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv6 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce routes - if model.rstate_id == 1: - command = messages.create_ipv6(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv6, - f"{session['user_email']} / {session['user_org']}", + _model, message = rule_service.create_or_update_ipv6_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -644,45 +566,14 @@ def rtbh_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_rtbh_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." - else: - model = RTBH( - ipv4=form.ipv4.data, - ipv4_mask=form.ipv4_mask.data, - ipv6=form.ipv6.data, - ipv6_mask=form.ipv6_mask.data, - community_id=form.community.data, - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - db.session.add(model) - db.session.commit() - flash_message = "RTBH Rule saved" - - flash(flash_message, "alert-success") - # announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - log_route( - session["user_id"], - model, - RuleTypes.RTBH, - f"{session['user_email']} / {session['user_org']}", + _model, message = rule_service.create_or_update_rtbh_rule( + form_data=form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: From 74bfefe6ae4508776338d4191d7e7c72013a0fa9 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 11 Mar 2025 12:55:30 +0100 Subject: [PATCH 12/46] Done - add whitelist, display whitelist in dashboard. WIP - basic whitelist handling when RTBH rule is created. --- flowapp/instance_config.py | 20 ++- flowapp/models/__init__.py | 2 + flowapp/models/rules/base.py | 1 + flowapp/models/rules/rtbh.py | 16 +++ flowapp/models/rules/whitelist.py | 13 +- flowapp/models/utils.py | 50 ++++++- flowapp/services/__init__.py | 3 + flowapp/services/rule_service.py | 44 +++++- flowapp/services/whitelist_service.py | 178 ++++++++++++++++++++++++ flowapp/templates/forms/whitelist.html | 62 +++++++++ flowapp/templates/macros.html | 48 ++++++- flowapp/tests/test_models.py | 3 + flowapp/tests/test_whitelist_service.py | 146 +++++++++++++++++++ flowapp/views/rules.py | 33 +++-- flowapp/views/whitelist.py | 66 ++------- 15 files changed, 603 insertions(+), 82 deletions(-) create mode 100644 flowapp/services/whitelist_service.py create mode 100644 flowapp/templates/forms/whitelist.html create mode 100644 flowapp/tests/test_whitelist_service.py diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 1ffcb1f..86bc5fb 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -2,12 +2,19 @@ # column names for tables RTBH_COLUMNS = ( - ("ipv4", "IP adress (v4 or v6)"), + ("ipv4", "IP address (v4 or v6)"), ("community_id", "Community"), ("expires", "Expires"), ("user_id", "User"), ) +WHITELIST_COLUMNS = ( + ("address", "IP address / network (v4 or v6)"), + ("expires", "Expires"), + ("user_id", "User"), +) + + RULES_COLUMNS_V4 = ( ("source", "Source addr."), ("source_port", "S port"), @@ -75,6 +82,7 @@ class InstanceConfig: {"name": "Add IPv6", "url": "rules.ipv6_rule"}, {"name": "Add RTBH", "url": "rules.rtbh_rule"}, {"name": "API Key", "url": "api_keys.all"}, + {"name": "Add Whitelist", "url": "whitelist.add"}, ], "admin": [ {"name": "Commands Log", "url": "admin.log"}, @@ -123,6 +131,14 @@ class InstanceConfig: "table_colspan": 5, "table_columns": RTBH_COLUMNS, }, + "whitelist": { + "name": "Whitelist", + "macro_file": "macros.html", + "macro_tbody": "build_whitelist_tbody", + "macro_thead": "build_rules_thead", + "table_colspan": 4, + "table_columns": WHITELIST_COLUMNS, + }, } - COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0} + COUNT_MATCH = {"ipv4": 0, "ipv6": 0, "rtbh": 0, "whitelist": 0} diff --git a/flowapp/models/__init__.py b/flowapp/models/__init__.py index 6d74ffe..6405505 100644 --- a/flowapp/models/__init__.py +++ b/flowapp/models/__init__.py @@ -23,6 +23,7 @@ get_ipv4_model_if_exists, get_ipv6_model_if_exists, get_rtbh_model_if_exists, + get_whitelist_model_if_exists, get_ip_rules, get_user_rules_ids, insert_users, @@ -59,6 +60,7 @@ "get_ipv4_model_if_exists", "get_ipv6_model_if_exists", "get_rtbh_model_if_exists", + "get_whitelist_model_if_exists", "get_ip_rules", "get_user_rules_ids", "insert_users", diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py index ab61f5e..22fbc08 100644 --- a/flowapp/models/rules/base.py +++ b/flowapp/models/rules/base.py @@ -37,6 +37,7 @@ def insert_initial_rulestates(table, conn, *args, **kwargs): conn.execute(table.insert().values(description="active rule")) conn.execute(table.insert().values(description="withdrawed rule")) conn.execute(table.insert().values(description="deleted rule")) + conn.execute(table.insert().values(description="whitelisted rule")) @event.listens_for(Action.__table__, "after_create") diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py index d75220c..f183b81 100644 --- a/flowapp/models/rules/rtbh.py +++ b/flowapp/models/rules/rtbh.py @@ -132,3 +132,19 @@ def json(self, prefered_format="yearfirst"): :returns: json """ return json.dumps(self.to_dict()) + + def __repr__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"" + if not self.ipv4 and not self.ipv4_mask: + return f"" + + return f"" + + def __str__(self): + if not self.ipv6 and not self.ipv6_mask: + return f"{self.ipv4}/{self.ipv4_mask}" + if not self.ipv4 and not self.ipv4_mask: + return f"{self.ipv6}/{self.ipv6_mask}" + + return f"{self.ipv4}/{self.ipv4_mask} {self.ipv6}/{self.ipv6_mask}" diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 5e93714..f4c87ac 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -1,6 +1,7 @@ from flowapp import utils from ..base import db from datetime import datetime +from flowapp.constants import RuleTypes, RuleOrigin class Whitelist(db.Model): @@ -72,6 +73,12 @@ def to_dict(self, prefered_format="yearfirst"): "rstate": self.rstate.description, } + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.ip}/{self.mask}" + class RuleWhitelistCache(db.Model): """ @@ -87,8 +94,8 @@ class RuleWhitelistCache(db.Model): whitelist_id = db.Column(db.Integer, db.ForeignKey("whitelist.id")) # Add ForeignKey whitelist = db.relationship("Whitelist", backref="rulewhitelistcache") - def __init__(self, rid, rtype, rorigin, whitelist_id): + def __init__(self, rid: int, rtype: RuleTypes, whitelist_id: int, rorigin: RuleOrigin = RuleOrigin.USER): self.rid = rid - self.rtype = rtype - self.rorigin = rorigin + self.rtype = rtype.value + self.rorigin = rorigin.value self.whitelist_id = whitelist_id diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index e2129b2..80a1af6 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -4,6 +4,8 @@ from flowapp import utils from flowapp.constants import RuleTypes from flask import current_app + +from flowapp.models.rules.whitelist import Whitelist from .base import db from .user import User, Role from .organization import Organization @@ -57,9 +59,33 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: return rtbh >= rtbh_limit +def get_whitelist_model_if_exists(form_data, rstate_id=1): + """ + Check if the record in database exist + ip, mask, rstate_id should match + expires, user_id, org_id, created, comment can be different + """ + record = ( + db.session.query(Whitelist) + .filter( + Whitelist.ip == form_data["ip"], + Whitelist.mask == form_data["mask"], + Whitelist.rstate_id == rstate_id, + ) + .first() + ) + + if record: + return record + + return False + + def get_ipv4_model_if_exists(form_data, rstate_id=1): """ Check if the record in database exist + Source and destination addresses, protocol, flags, action and packet_len should match + Other fields can be different """ record = ( db.session.query(Flowspec4) @@ -113,9 +139,11 @@ def get_ipv6_model_if_exists(form_data, rstate_id=1): return False -def get_rtbh_model_if_exists(form_data, rstate_id=1): +def get_rtbh_model_if_exists(form_data): """ - Check if the record in database exist + Check RTBH rule exist in database + IPv4, IPv6 and community_id should match + Rule can be in any state and have different expires, user_id, org_id, created, comment """ record = ( @@ -126,7 +154,6 @@ def get_rtbh_model_if_exists(form_data, rstate_id=1): RTBH.ipv6 == form_data["ipv6"], RTBH.ipv6_mask == form_data["ipv6_mask"], RTBH.community_id == form_data["community"], - RTBH.rstate_id == rstate_id, ) .first() ) @@ -308,6 +335,23 @@ def get_ip_rules(rule_type, rule_state, sort="expires", order="desc"): return rules_rtbh + if rule_type == "whitelist": + sorter_whitelist = getattr(Whitelist, sort, Whitelist.id) + sorting_whitelist = getattr(sorter_whitelist, order) + + if comp_func: + rules_whitelist = ( + db.session.query(Whitelist) + .filter(comp_func(Whitelist.expires, today)) + .order_by(sorting_whitelist()) + .all() + ) + + else: + rules_whitelist = db.session.query(Whitelist).order_by(sorting_whitelist()).all() + + return rules_whitelist + def get_user_rules_ids(user_id, rule_type): """ diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index b8fc6e4..11766b3 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -4,8 +4,11 @@ create_or_update_rtbh_rule, ) +from .whitelist_service import create_or_update_whitelist + __all__ = [ create_or_update_ipv4_rule, create_or_update_ipv6_rule, create_or_update_rtbh_rule, + create_or_update_whitelist, ] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 0505ab1..5636c53 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -2,14 +2,15 @@ """ Service module for rule operations. -This module provides business logic functions for creating, updating, +This module provides business logic functions for creating, updating, and managing flow rules, separating these concerns from HTTP handling. """ +from datetime import datetime from typing import Dict, Tuple from flowapp import db, messages -from flowapp.constants import RuleTypes, ANNOUNCE +from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE from flowapp.models import ( get_ipv4_model_if_exists, get_ipv6_model_if_exists, @@ -17,9 +18,12 @@ Flowspec4, Flowspec6, RTBH, + Whitelist, + RuleWhitelistCache, ) from flowapp.output import Route, announce_route, log_route, RouteSources from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent +from .whitelist_service import check_rule_against_whitelists, Relation def create_or_update_ipv4_rule( @@ -174,7 +178,7 @@ def create_or_update_rtbh_rule( Tuple containing (rule_model, message) """ # Check for existing model - model = get_rtbh_model_if_exists(form_data, 1) + model = get_rtbh_model_if_exists(form_data) if model: model.expires = round_to_ten_minutes(form_data["expires"]) @@ -198,6 +202,24 @@ def create_or_update_rtbh_rule( db.session.commit() + # Check if rule is whitelisted + # get all not expired whitelists + whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + whitelists = {str(w): w for w in whitelists} + results = check_rule_against_whitelists(str(model), whitelists.keys()) + # check rule against whitelists, stop search when rule is whitelisted first time + for rule, whitelist_key, relation in results: + match relation: + case Relation.EQUAL: + model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + break + case Relation.SUBNET: + print("WL is subnet of rule") + break + case Relation.SUPERNET: + model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + break + # Announce routes if model.rstate_id == 1: command = messages.create_rtbh(model, ANNOUNCE) @@ -217,3 +239,19 @@ def create_or_update_rtbh_rule( ) return model, flash_message + + +def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: + """ + Whitelist RTBH rule. + set state to 4 - whitelisted rule, do not announce + Add to whitelist cache + """ + model.rstate_id = 4 + db.session.commit() + # add to cache + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist.id, rorigin=RuleOrigin.USER) + db.session.add(cache) + db.session.commit() + + return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py new file mode 100644 index 0000000..9652922 --- /dev/null +++ b/flowapp/services/whitelist_service.py @@ -0,0 +1,178 @@ +# flowapp/services/rule_service.py +""" +Service module for rule operations. + +This module provides business logic functions for creating, updating, +and managing flow rules, separating these concerns from HTTP handling. +""" + +from typing import Dict, Tuple, List +from enum import Enum, auto +import ipaddress +from functools import lru_cache + +from flowapp import db +from flowapp.models import Whitelist, get_whitelist_model_if_exists +from flowapp.utils import round_to_ten_minutes, quote_to_ent + + +def create_or_update_whitelist( + form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str +) -> Tuple[Whitelist, str]: + """ + Create a new Whitelist rule or update expiration time of an existing one. + + Args: + form_data: Validated form data + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (whitelist_model, message) + """ + # Check for existing model + model = get_whitelist_model_if_exists(form_data, 1) + + if model: + model.expires = round_to_ten_minutes(form_data["expires"]) + flash_message = "Existing Whitelist found. Expiration time was updated to new value." + else: + # Create new model + model = Whitelist( + ip=form_data["ip"], + mask=form_data["mask"], + expires=round_to_ten_minutes(form_data["expires"]), + user_id=user_id, + org_id=org_id, + comment=quote_to_ent(form_data["comment"]), + ) + db.session.add(model) + flash_message = "Whitelist saved" + + db.session.commit() + + return model, flash_message + + +class Relation(Enum): + SUBNET = auto() + SUPERNET = auto() + EQUAL = auto() + DIFFERENT = auto() + + +@lru_cache(maxsize=1024) +def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: + """ + Create and cache an IP network object. + + :param address: IP address or network in string format + :return: Cached IP network object + """ + return ipaddress.ip_network(address, strict=False) + + +def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: + """ + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Uses cached network objects for better performance. + + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") + :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") + :return: Relation between the two networks + """ + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + + +def subtract_network(target: str, whitelist: str) -> List[str]: + """ + Computes the remaining parts of a network after removing the whitelist subnet. + Uses cached network objects for better performance. + + :param target: The main network (e.g., "192.168.1.0/24") + :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") + :return: A list of remaining subnets as strings + """ + target_net = get_network(target) + whitelist_net = get_network(whitelist) + + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet + + remaining = [] + + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) + + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + + +def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: + """ + Helper function to check a single rule against multiple whitelist entries. + Creates a cached rule network object for better performance. + Reduces list of whitelists, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check + :param whitelists: List of whitelist networks to check against + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(rule) + items = [] + for whitelist in whitelists: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: + """ + Helper function to check if any whitelist entry is a subnet of the rule. + Creates a cached rule network object for better performance. + Reduces list of rules, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check against + :param whitelists: List of whitelist networks to check + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(whitelist) + items = [] + for rule in rules: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def clear_network_cache() -> None: + """ + Clear the network object cache. + Useful when processing a large number of networks to prevent memory growth. + """ + get_network.cache_clear() diff --git a/flowapp/templates/forms/whitelist.html b/flowapp/templates/forms/whitelist.html new file mode 100644 index 0000000..35559d7 --- /dev/null +++ b/flowapp/templates/forms/whitelist.html @@ -0,0 +1,62 @@ +{% extends 'layouts/default.html' %} +{% from 'forms/macros.html' import render_field %} +{% block title %}Add Whitelist{% endblock %} +{% block content %} +

{{ title or 'New'}} Whitelist

+
+ {{ form.hidden_tag() if form.hidden_tag }} +
+
+
+
+ {{ render_field(form.ip) }} +
+
+ {{ render_field(form.mask) }} +
+
+
+
+
+ +
+ +
+
+
+ +
+ + {{ form.expires(class_='form-control') }} + + + +
+ + {% if form.expires.errors %} + {% for e in form.expires.errors %} +

{{ e }}

+ {% endfor %} + {% endif %} +
+
+
+
+
+ +
+
+ {{ render_field(form.comment) }} +
+
+ +
+
+ +
+
+
+{% endblock %} \ No newline at end of file diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 9e0b4b7..9d51d3d 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -8,7 +8,7 @@ {% set rtype_int = 4 %} {% endif %} - + {{ rule.source }}{% if rule.source_mask != none %}{{ '/' if rule.source_mask >= 0 else '' }}{{ rule.source_mask if rule.source_mask >= 0 else '' }}{% endif %} @@ -76,7 +76,10 @@ {% macro build_rtbh_tbody(rules, today, editable=True, group_op=True) %} {% for rule in rules %} - + {% if rule.ipv4 %} {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} @@ -122,6 +125,47 @@ {% endmacro %} +{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True) %} + +{% for rule in rules %} + + + {{ rule.ip }}{{ '/' if rule.mask else '' }}{{rule.mask|default("", True)}} + + + {{ rule.expires|strftime }} + + + {{ rule.user.name }} + + + {% if editable %} + + + + + + + {% endif %} + {% if rule.comment %} + + {% endif %} + + {% if editable and group_op %} + + + + {% endif %} + + +{% endfor %} + +{% endmacro %} + + + {% macro build_rules_thead(rules_columns, rtype, rstate, sort_key, sort_order, search_query='', group_op=True) %} diff --git a/flowapp/tests/test_models.py b/flowapp/tests/test_models.py index 82dff8e..b352a5f 100644 --- a/flowapp/tests/test_models.py +++ b/flowapp/tests/test_models.py @@ -473,6 +473,9 @@ def test_whitelist_to_dict(db): db.session.add_all([user, rstate]) db.session.commit() + db.session.add(whitelist) + db.session.commit() + whitelist.user = user whitelist.rstate_id = rstate.id db.session.add(whitelist) diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py new file mode 100644 index 0000000..7ae4990 --- /dev/null +++ b/flowapp/tests/test_whitelist_service.py @@ -0,0 +1,146 @@ +import pytest +from flowapp.services.whitelist_service import ( + Relation, + check_whitelist_to_rule_relation, + subtract_network, + check_rule_against_whitelists, + check_whitelist_against_rules, + clear_network_cache, +) + + +# Tests for core function that checks network relations +@pytest.mark.parametrize( + "rule,whitelist,expected", + [ + # IPv4 test cases + ("192.168.1.0/24", "192.168.0.0/16", Relation.SUPERNET), # Whitelist is supernet + ("192.168.1.0/24", "192.168.1.0/24", Relation.EQUAL), # Equal networks + ("192.168.1.0/24", "192.168.1.128/25", Relation.SUBNET), # Whitelist is subnet + ("192.168.1.128/25", "192.168.1.0/24", Relation.SUPERNET), # Whitelist is supernet + ("10.0.0.0/8", "192.168.1.0/24", Relation.DIFFERENT), # Different networks + # IPv6 test cases + ("2001:db8::/32", "2001:db8::/32", Relation.EQUAL), # Equal networks + ("2001:db8:1::/48", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ("2001:db8::/32", "2001:db8:1::/48", Relation.SUBNET), # Whitelist is subnet + ("2001:db8::/32", "2002:db8::/32", Relation.DIFFERENT), # Different networks + ("2001:db8:1:2::/64", "2001:db8::/32", Relation.SUPERNET), # Whitelist is supernet + ], +) +def test_check_whitelist_to_rule_relation(rule, whitelist, expected): + clear_network_cache() + assert check_whitelist_to_rule_relation(rule, whitelist) == expected + + +@pytest.mark.parametrize( + "target,whitelist,expected", + [ + # Basic IPv4 cases + ( + "192.168.1.0/24", + "192.168.1.128/25", + ["192.168.1.0/25"], + ), # One remaining subnet + ( + "192.168.1.0/24", + "192.168.1.64/26", + ["192.168.1.0/26", "192.168.1.128/25"], + ), # Two remaining subnets + ( + "192.168.1.0/24", + "192.168.1.0/24", + ["192.168.1.0/24"], + ), # Equal networks - return original network + ( + "192.168.1.0/24", + "192.168.2.0/24", + ["192.168.1.0/24"], + ), # No overlap + ], +) +def test_subtract_network(target, whitelist, expected): + clear_network_cache() + result = subtract_network(target, whitelist) + assert sorted(result) == sorted(expected) + + +def test_check_rule_against_whitelists(): + clear_network_cache() + + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # SUPERNET + "172.16.0.0/12", # DIFFERENT + "192.168.1.0/24", # EQUAL + "192.168.1.128/25", # SUBNET + ] + + results = check_rule_against_whitelists(rule, whitelists) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + (rule, "192.168.0.0/16", Relation.SUPERNET), + (rule, "192.168.1.0/24", Relation.EQUAL), + (rule, "192.168.1.128/25", Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +def test_check_whitelist_against_rules(): + clear_network_cache() + + whitelist = "192.168.1.128/25" + rules = [ + "192.168.0.0/16", # Whitelist is SUBNET of this larger network + "172.16.0.0/12", # DIFFERENT + "192.168.1.128/25", # EQUAL + "192.168.1.0/24", # Whitelist is SUBNET of this network + ] + + results = check_whitelist_against_rules(rules, whitelist) + + # Should return list of tuples (rule, whitelist, relation) for non-DIFFERENT relations + expected = [ + ("192.168.0.0/16", whitelist, Relation.SUBNET), + ("192.168.1.128/25", whitelist, Relation.EQUAL), + ("192.168.1.0/24", whitelist, Relation.SUBNET), + ] + + assert len(results) == 3 + for result, exp in zip(sorted(results), sorted(expected)): + assert result == exp + + +# Tests for edge cases and error handling +def test_host_addresses(): + clear_network_cache() + # Test with host addresses (no explicit subnet mask) + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.1.0/24") == Relation.SUPERNET + assert check_whitelist_to_rule_relation("192.168.1.1", "192.168.2.0/24") == Relation.DIFFERENT + + +def test_single_ip_as_network(): + clear_network_cache() + # Test with /32 (single IPv4 address as network) + assert check_whitelist_to_rule_relation("192.168.1.1/32", "192.168.1.0/24") == Relation.SUPERNET + # Test with /128 (single IPv6 address as network) + assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET + + +def test_invalid_input(): + clear_network_cache() + with pytest.raises(ValueError): + check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") + + with pytest.raises(ValueError): + check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") + + with pytest.raises(ValueError): + subtract_network("invalid", "192.168.1.0/24") + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index f0acfb4..4805e44 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -58,9 +58,13 @@ def reactivate_rule(rule_type, rule_id): """ Set new time for the rule of given type identified by id - :param rule_type: string - type of rule + :param rule_type: integer - type of rule, corresponds to RuleTypes enum value :param rule_id: integer - id of the rule """ + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + + # Now use the enum value where needed but the integer for dictionary lookups model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] @@ -72,14 +76,14 @@ def reactivate_rule(rule_type, rule_id): form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")] form.action.data = model.action_id - if rule_type == 1: + if rule_type == RuleTypes.RTBH.value: form.community.choices = get_user_communities(session["user_role_ids"]) form.community.data = model.community_id - if rule_type == 4: + if rule_type == RuleTypes.IPv4.value: form.protocol.data = model.protocol - if rule_type == 6: + if rule_type == RuleTypes.IPv6.value: form.next_header.data = model.next_header # do not need to validate - all is readonly @@ -114,11 +118,11 @@ def reactivate_rule(rule_type, rule_id): command=command, ) announce_route(route) - # log changes + # log changes - Use the enum value here log_route( session["user_id"], model, - rule_type, + enum_rule_type, # Pass the enum instead of integer f"{session['user_email']} / {session['user_org']}", ) else: @@ -130,11 +134,11 @@ def reactivate_rule(rule_type, rule_id): command=command, ) announce_route(route) - # log changes + # log changes - Use the enum value here log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, # Pass the enum instead of integer model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -183,6 +187,9 @@ def delete_rule(rule_type, rule_id): model_name = DATA_MODELS[rule_type] route_model = ROUTE_MODELS[rule_type] + # Convert the integer rule_type to RuleTypes enum + enum_rule_type = RuleTypes(rule_type) + model = db.session.get(model_name, rule_id) if model.id in session[constants.RULES_KEY]: # withdraw route @@ -197,7 +204,7 @@ def delete_rule(rule_type, rule_id): log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -257,6 +264,7 @@ def group_delete(): rule_type = session[constants.TYPE_ARG] model_name = DATA_MODELS_NAMED[rule_type] rule_type_int = constants.RULE_TYPES_DICT[rule_type] + enum_rule_type = RuleTypes(rule_type_int) route_model = ROUTE_MODELS[rule_type_int] rules = [str(x) for x in session[constants.RULES_KEY]] to_delete = request.form.getlist("delete-id") @@ -276,7 +284,7 @@ def group_delete(): log_withdraw( session["user_id"], route.command, - rule_type_int, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) @@ -373,6 +381,7 @@ def group_update_save(rule_type): model_name = DATA_MODELS[rule_type] form_name = DATA_FORMS[rule_type] + enum_rule_type = RuleTypes(rule_type) form = form_name(request.form) @@ -414,7 +423,7 @@ def group_update_save(rule_type): log_route( session["user_id"], model, - rule_type, + enum_rule_type, f"{session['user_email']} / {session['user_org']}", ) else: @@ -430,7 +439,7 @@ def group_update_save(rule_type): log_withdraw( session["user_id"], route.command, - rule_type, + enum_rule_type, model.id, f"{session['user_email']} / {session['user_org']}", ) diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index adf3129..723416e 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -1,24 +1,15 @@ from datetime import datetime, timedelta from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for -from flowapp import constants, db, messages from flowapp.auth import ( auth_required, user_or_admin_required, ) -from flowapp.constants import RuleTypes from flowapp.forms import WhitelistForm from flowapp.models import ( - Whitelist, get_user_nets, ) -from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route -from flowapp.utils import ( - flash_errors, - get_state_by_time, - quote_to_ent, - round_to_ten_minutes, -) +from flowapp.services import create_or_update_whitelist whitelist = Blueprint("whitelist", __name__, template_folder="templates") @@ -33,53 +24,14 @@ def add(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model = get_ipv4_model_if_exists(form.data, 1) - - if model: - model.expires = round_to_ten_minutes(form.expires.data) - flash_message = "Existing IPv4 Rule found. Expiration time was updated to new value." - else: - model = Flowspec4( - source=form.source.data, - source_mask=form.source_mask.data, - source_port=form.source_port.data, - destination=form.dest.data, - destination_mask=form.dest_mask.data, - destination_port=form.dest_port.data, - protocol=form.protocol.data, - flags=";".join(form.flags.data), - packet_len=form.packet_len.data, - fragment=";".join(form.fragment.data), - expires=round_to_ten_minutes(form.expires.data), - comment=quote_to_ent(form.comment.data), - action_id=form.action.data, - user_id=session["user_id"], - org_id=session["user_org_id"], - rstate_id=get_state_by_time(form.expires.data), - ) - flash_message = "IPv4 Rule saved" - db.session.add(model) - - db.session.commit() - flash(flash_message, "alert-success") - - # announce route if model is in active state - if model.rstate_id == 1: - command = messages.create_ipv4(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - # log changes - log_route( - session["user_id"], - model, - RuleTypes.IPv4, - f"{session['user_email']} / {session['user_org']}", + model, message = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], ) + flash(message, "alert-success") return redirect(url_for("index")) else: @@ -91,4 +43,4 @@ def add(): default_expires = datetime.now() + timedelta(hours=1) form.expires.data = default_expires - return render_template("forms/ipv4_rule.html", form=form, action_url=url_for("rules.ipv4_rule")) + return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add")) From 5b22513fab7c0238ef94bccc0ea430cf4a37bca8 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 12 Mar 2025 17:56:17 +0100 Subject: [PATCH 13/46] whitelist time update and delete completed, whitelist cache clean implemented --- flowapp/models/rules/whitelist.py | 34 ++++++++++ flowapp/models/utils.py | 7 +- flowapp/services/__init__.py | 3 +- flowapp/services/base.py | 18 +++++ flowapp/services/rule_service.py | 96 ++++++++++++++++++--------- flowapp/services/whitelist_service.py | 47 ++++++++++++- flowapp/templates/macros.html | 4 +- flowapp/views/rules.py | 5 +- flowapp/views/whitelist.py | 88 ++++++++++++++++++++++-- 9 files changed, 254 insertions(+), 48 deletions(-) create mode 100644 flowapp/services/base.py diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index f4c87ac..e1857cc 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -99,3 +99,37 @@ def __init__(self, rid: int, rtype: RuleTypes, whitelist_id: int, rorigin: RuleO self.rtype = rtype.value self.rorigin = rorigin.value self.whitelist_id = whitelist_id + + @classmethod + def get_by_whitelist_id(cls, whitelist_id: int): + """ + Get all cache items with the given whitelist ID + + Args: + whitelist_id (int): The ID of the whitelist to filter by + + Returns: + list: All RuleWhitelistCache objects with the specified whitelist_id + """ + return cls.query.filter_by(whitelist_id=whitelist_id).all() + + @classmethod + def clean_by_whitelist_id(cls, whitelist_id: int): + """ + Delete all cache entries with the given whitelist ID from the database + + Args: + whitelist_id (int): The ID of the whitelist to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(whitelist_id=whitelist_id).delete() + db.session.commit() + return deleted + + def __repr__(self): + return f"" + + def __str__(self): + return f"{self.rid} {self.rtype} {self.rorigin}" diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index 80a1af6..7472d72 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -59,18 +59,17 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: return rtbh >= rtbh_limit -def get_whitelist_model_if_exists(form_data, rstate_id=1): +def get_whitelist_model_if_exists(form_data): """ Check if the record in database exist - ip, mask, rstate_id should match - expires, user_id, org_id, created, comment can be different + ip, mask should match + expires, rstate_id, user_id, org_id, created, comment can be different """ record = ( db.session.query(Whitelist) .filter( Whitelist.ip == form_data["ip"], Whitelist.mask == form_data["mask"], - Whitelist.rstate_id == rstate_id, ) .first() ) diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index 11766b3..05b9695 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -4,11 +4,12 @@ create_or_update_rtbh_rule, ) -from .whitelist_service import create_or_update_whitelist +from .whitelist_service import create_or_update_whitelist, delete_whitelist __all__ = [ create_or_update_ipv4_rule, create_or_update_ipv6_rule, create_or_update_rtbh_rule, create_or_update_whitelist, + delete_whitelist, ] diff --git a/flowapp/services/base.py b/flowapp/services/base.py new file mode 100644 index 0000000..2f4bc26 --- /dev/null +++ b/flowapp/services/base.py @@ -0,0 +1,18 @@ +from flowapp import messages +from flowapp.constants import ANNOUNCE +from flowapp.models import RTBH +from flowapp.output import Route, RouteSources, announce_route + + +def announce_rtbh_route(model: RTBH, author: str) -> None: + """ + Announce RTBH route if rule is in active state + """ + if model.rstate_id == 1: + command = messages.create_rtbh(model, ANNOUNCE) + route = Route( + author=author, + source=RouteSources.UI, + command=command, + ) + announce_route(route) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 5636c53..b8db5f7 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -7,7 +7,7 @@ """ from datetime import datetime -from typing import Dict, Tuple +from typing import Dict, List, Tuple from flowapp import db, messages from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE @@ -22,8 +22,9 @@ RuleWhitelistCache, ) from flowapp.output import Route, announce_route, log_route, RouteSources +from flowapp.services.base import announce_rtbh_route from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent -from .whitelist_service import check_rule_against_whitelists, Relation +from .whitelist_service import check_rule_against_whitelists, Relation, subtract_network def create_or_update_ipv4_rule( @@ -163,7 +164,7 @@ def create_or_update_ipv6_rule( def create_or_update_rtbh_rule( form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str -) -> Tuple[RTBH, str]: +) -> Tuple[RTBH, List]: """ Create a new RTBH rule or update an existing one. @@ -179,10 +180,10 @@ def create_or_update_rtbh_rule( """ # Check for existing model model = get_rtbh_model_if_exists(form_data) - + flashes = [] if model: model.expires = round_to_ten_minutes(form_data["expires"]) - flash_message = "Existing RTBH Rule found. Expiration time was updated to new value." + flashes.append("Existing RTBH Rule found. Expiration time was updated to new value.") else: # Create new model model = RTBH( @@ -198,60 +199,89 @@ def create_or_update_rtbh_rule( rstate_id=get_state_by_time(form_data["expires"]), ) db.session.add(model) - flash_message = "RTBH Rule saved" + flashes.append("RTBH Rule saved") db.session.commit() + # rule author for logging and announcing + author = f"{user_email} / {org_name}" + # Check if rule is whitelisted # get all not expired whitelists whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() - whitelists = {str(w): w for w in whitelists} - results = check_rule_against_whitelists(str(model), whitelists.keys()) + wl_cache = {str(w): w for w in whitelists} + results = check_rule_against_whitelists(str(model), wl_cache.keys()) # check rule against whitelists, stop search when rule is whitelisted first time for rule, whitelist_key, relation in results: match relation: case Relation.EQUAL: - model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") break case Relation.SUBNET: - print("WL is subnet of rule") + # split subnet into parts + parts = subtract_network(target=str(model), whitelist=whitelist_key) + wl_id = wl_cache[whitelist_key].id + flashes.append( + f" Rule is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + ) + for network in parts: + create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + + model.rstate_id = 4 + add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) + db.session.commit() break case Relation.SUPERNET: - model = whitelist_rtbh_rule(model, whitelists[whitelist_key]) + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") break - # Announce routes - if model.rstate_id == 1: - command = messages.create_rtbh(model, ANNOUNCE) - route = Route( - author=f"{user_email} / {org_name}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - + announce_rtbh_route(model, author=author) # Log changes - log_route( - user_id, - model, - RuleTypes.RTBH, - f"{user_email} / {org_name}", + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + + +def create_rtbh_from_whitelist_parts( + model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str, user_id: int +) -> None: + net_ip, net_mask = network.split("/") + new_model = RTBH( + ipv4=net_ip, + ipv4_mask=net_mask, + ipv6=model.ipv6, + ipv6_mask=model.ipv6_mask, + community_id=model.community_id, + expires=model.expires, + comment=model.comment, + user_id=model.user_id, + org_id=model.org_id, + rstate_id=1, ) + db.session.add(new_model) + db.session.commit() + add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) + announce_rtbh_route(new_model, rule_owner) + log_route(user_id, model, RuleTypes.RTBH, rule_owner) - return model, flash_message + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + # add to cache + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: """ Whitelist RTBH rule. - set state to 4 - whitelisted rule, do not announce + Set rule state to 4 - whitelisted rule, do not announce later Add to whitelist cache """ model.rstate_id = 4 db.session.commit() - # add to cache - cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist.id, rorigin=RuleOrigin.USER) - db.session.add(cache) - db.session.commit() - + add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 9652922..0f8096c 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -12,7 +12,11 @@ from functools import lru_cache from flowapp import db -from flowapp.models import Whitelist, get_whitelist_model_if_exists +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists +from flowapp.models.rules.flowspec import Flowspec4, Flowspec6 +from flowapp.models.rules.rtbh import RTBH +from flowapp.services.base import announce_rtbh_route from flowapp.utils import round_to_ten_minutes, quote_to_ent @@ -33,7 +37,7 @@ def create_or_update_whitelist( Tuple containing (whitelist_model, message) """ # Check for existing model - model = get_whitelist_model_if_exists(form_data, 1) + model = get_whitelist_model_if_exists(form_data) if model: model.expires = round_to_ten_minutes(form_data["expires"]) @@ -56,6 +60,45 @@ def create_or_update_whitelist( return model, flash_message +def delete_whitelist(whitelist_id: int) -> List[str]: + """ + Delete a whitelist entry from the database. + + Args: + whitelist_id: The ID of the whitelist to delete + """ + model = db.session.get(Whitelist, whitelist_id) + flashes = [] + if model: + cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) + for cached_rule in cached_rules: + rule_model_type = RuleTypes(cached_rule.rtype) + match rule_model_type: + case RuleTypes.IPv4: + rule_model = db.session.get(Flowspec4, cached_rule.rid) + case RuleTypes.IPv6: + rule_model = db.session.get(Flowspec6, cached_rule.rid) + case RuleTypes.RTBH: + rule_model = db.session.get(RTBH, cached_rule.rid) + rorigin_type = RuleOrigin(cached_rule.rorigin) + if rorigin_type == RuleOrigin.WHITELIST: + flashes.append(f"Deleted rule {rule_model} created by this whitelist") + db.session.delete(rule_model) + elif rorigin_type == RuleOrigin.USER: + flashes.append(f"Set rule {rule_model} back to state 'Active'") + rule_model.rstate_id = 1 # Set rule state to "Active" again + author = f"{model.user.email} ({model.user.organization})" + announce_rtbh_route(rule_model, author) + + flashes.append(f"Deleted cache entries for whitelist {whitelist_id}") + RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) + + db.session.delete(model) + db.session.commit() + + return flashes + + class Relation(Enum): SUBNET = auto() SUPERNET = auto() diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 9d51d3d..653cc81 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -140,10 +140,10 @@ {% if editable %} - + - + {% endif %} diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 4805e44..fc787cf 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -575,14 +575,15 @@ def rtbh_rule(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - _model, message = rule_service.create_or_update_rtbh_rule( + _model, messages = rule_service.create_or_update_rtbh_rule( form_data=form.data, user_id=session["user_id"], org_id=session["user_org_id"], user_email=session["user_email"], org_name=session["user_org"], ) - flash(message, "alert-success") + for message in messages: + flash(message, "alert-success") return redirect(url_for("index")) else: diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index 723416e..9f2bcda 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -5,11 +5,11 @@ auth_required, user_or_admin_required, ) +from flowapp import constants, db from flowapp.forms import WhitelistForm -from flowapp.models import ( - get_user_nets, -) -from flowapp.services import create_or_update_whitelist +from flowapp.models import get_user_nets, Whitelist +from flowapp.services import create_or_update_whitelist, delete_whitelist +from flowapp.utils.base import flash_errors whitelist = Blueprint("whitelist", __name__, template_folder="templates") @@ -44,3 +44,83 @@ def add(): form.expires.data = default_expires return render_template("forms/whitelist.html", form=form, action_url=url_for("whitelist.add")) + + +@whitelist.route("/reactivate/", methods=["GET", "POST"]) +@auth_required +@user_or_admin_required +def reactivate(wl_id): + """ + Set new time for whitelist + :param wl_id: int - id of the whitelist + """ + + model = db.session.get(Whitelist, wl_id) + form = WhitelistForm(request.form, obj=model) + form.net_ranges = get_user_nets(session["user_id"]) + + # do not need to validate - all is readonly + if request.method == "POST": + model = create_or_update_whitelist( + form.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) + flash("Whitelist updated", "alert-success") + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) + else: + flash_errors(form) + + form.expires.data = model.expires + for field in form: + if field.name not in ["expires", "csrf_token", "comment"]: + field.render_kw = {"disabled": "disabled"} + + action_url = url_for("whitelist.reactivate", wl_id=wl_id) + + return render_template( + "forms/whitelist.html", + form=form, + action_url=action_url, + editing=True, + title="Update", + ) + + +@whitelist.route("/delete/", methods=["GET"]) +@auth_required +@user_or_admin_required +def delete(wl_id): + """ + Delete whitelist + :param wl_id: integer - id of the whitelist + """ + if wl_id in session[constants.RULES_KEY]: + messages = delete_whitelist(wl_id) + flash(f"Whitelist {wl_id} deleted", "alert-success") + for message in messages: + flash(message, "alert-info") + else: + flash("You can not delete this Whitelist", "alert-warning") + + return redirect( + url_for( + "dashboard.index", + rtype=session[constants.TYPE_ARG], + rstate=session[constants.RULE_ARG], + sort=session[constants.SORT_ARG], + squery=session[constants.SEARCH_ARG], + order=session[constants.ORDER_ARG], + ) + ) From 356e1f332d24cf685012cd2d224b57ef61aeac37 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 12 Mar 2025 18:38:15 +0100 Subject: [PATCH 14/46] refactoring of whitelist services / split into common lib for rules and whitelist handling services to avoid ciruclar imports --- flowapp/instance_config.py | 2 +- flowapp/services/base.py | 15 +- flowapp/services/rule_service.py | 23 +-- flowapp/services/whitelist_common.py | 150 ++++++++++++++++++ flowapp/services/whitelist_service.py | 125 --------------- ...st_service.py => test_whitelist_common.py} | 6 +- 6 files changed, 168 insertions(+), 153 deletions(-) create mode 100644 flowapp/services/whitelist_common.py rename flowapp/tests/{test_whitelist_service.py => test_whitelist_common.py} (99%) diff --git a/flowapp/instance_config.py b/flowapp/instance_config.py index 86bc5fb..4136530 100644 --- a/flowapp/instance_config.py +++ b/flowapp/instance_config.py @@ -81,8 +81,8 @@ class InstanceConfig: {"name": "Add IPv4", "url": "rules.ipv4_rule"}, {"name": "Add IPv6", "url": "rules.ipv6_rule"}, {"name": "Add RTBH", "url": "rules.rtbh_rule"}, - {"name": "API Key", "url": "api_keys.all"}, {"name": "Add Whitelist", "url": "whitelist.add"}, + {"name": "API Key", "url": "api_keys.all"}, ], "admin": [ {"name": "Commands Log", "url": "admin.log"}, diff --git a/flowapp/services/base.py b/flowapp/services/base.py index 2f4bc26..d829eee 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,6 +1,6 @@ -from flowapp import messages -from flowapp.constants import ANNOUNCE -from flowapp.models import RTBH +from flowapp import db, messages +from flowapp.constants import ANNOUNCE, RuleOrigin, RuleTypes +from flowapp.models import RTBH, RuleWhitelistCache from flowapp.output import Route, RouteSources, announce_route @@ -16,3 +16,12 @@ def announce_rtbh_route(model: RTBH, author: str) -> None: command=command, ) announce_route(route) + + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + """ + Add RTBH rule to whitelist cache + """ + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index b8db5f7..0d0602c 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -19,12 +19,12 @@ Flowspec6, RTBH, Whitelist, - RuleWhitelistCache, ) from flowapp.output import Route, announce_route, log_route, RouteSources from flowapp.services.base import announce_rtbh_route +from flowapp.services.whitelist_common import Relation, add_rtbh_rule_to_cache, subtract_network, whitelist_rtbh_rule from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent -from .whitelist_service import check_rule_against_whitelists, Relation, subtract_network +from .whitelist_common import check_rule_against_whitelists def create_or_update_ipv4_rule( @@ -266,22 +266,3 @@ def create_rtbh_from_whitelist_parts( add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) announce_rtbh_route(new_model, rule_owner) log_route(user_id, model, RuleTypes.RTBH, rule_owner) - - -def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: - # add to cache - cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) - db.session.add(cache) - db.session.commit() - - -def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: - """ - Whitelist RTBH rule. - Set rule state to 4 - whitelisted rule, do not announce later - Add to whitelist cache - """ - model.rstate_id = 4 - db.session.commit() - add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) - return model diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py new file mode 100644 index 0000000..9c4b49a --- /dev/null +++ b/flowapp/services/whitelist_common.py @@ -0,0 +1,150 @@ +from enum import Enum, auto +from functools import lru_cache +import ipaddress +from typing import List, Tuple +from flowapp import db +from flowapp.constants import RuleOrigin, RuleTypes +from flowapp.models import RTBH, RuleWhitelistCache, Whitelist + + +def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: + """ + Add RTBH rule to whitelist cache + """ + cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) + db.session.add(cache) + db.session.commit() + + +def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: + """ + Whitelist RTBH rule. + Set rule state to 4 - whitelisted rule, do not announce later + Add to whitelist cache + """ + model.rstate_id = 4 + db.session.commit() + add_rtbh_rule_to_cache(model, whitelist.id, RuleOrigin.USER) + return model + + +class Relation(Enum): + SUBNET = auto() + SUPERNET = auto() + EQUAL = auto() + DIFFERENT = auto() + + +@lru_cache(maxsize=1024) +def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: + """ + Create and cache an IP network object. + + :param address: IP address or network in string format + :return: Cached IP network object + """ + return ipaddress.ip_network(address, strict=False) + + +def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: + """ + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Uses cached network objects for better performance. + + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") + :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") + :return: Relation between the two networks + """ + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + + +def subtract_network(target: str, whitelist: str) -> List[str]: + """ + Computes the remaining parts of a network after removing the whitelist subnet. + Uses cached network objects for better performance. + + :param target: The main network (e.g., "192.168.1.0/24") + :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") + :return: A list of remaining subnets as strings + """ + target_net = get_network(target) + whitelist_net = get_network(whitelist) + + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet + + remaining = [] + + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) + + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + + +def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: + """ + Helper function to check a single rule against multiple whitelist entries. + Creates a cached rule network object for better performance. + Reduces list of whitelists, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check + :param whitelists: List of whitelist networks to check against + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(rule) + items = [] + for whitelist in whitelists: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: + """ + Helper function to check if any whitelist entry is a subnet of the rule. + Creates a cached rule network object for better performance. + Reduces list of rules, where the Relation is not DIFFERENT + + :param rule: The IP address or network to check against + :param whitelists: List of whitelist networks to check + :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + """ + # Pre-cache the rule network since it will be used multiple times + get_network(whitelist) + items = [] + for rule in rules: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + return items + + +def clear_network_cache() -> None: + """ + Clear the network object cache. + Useful when processing a large number of networks to prevent memory growth. + """ + get_network.cache_clear() diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 0f8096c..b94772f 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -7,9 +7,6 @@ """ from typing import Dict, Tuple, List -from enum import Enum, auto -import ipaddress -from functools import lru_cache from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes @@ -97,125 +94,3 @@ def delete_whitelist(whitelist_id: int) -> List[str]: db.session.commit() return flashes - - -class Relation(Enum): - SUBNET = auto() - SUPERNET = auto() - EQUAL = auto() - DIFFERENT = auto() - - -@lru_cache(maxsize=1024) -def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: - """ - Create and cache an IP network object. - - :param address: IP address or network in string format - :return: Cached IP network object - """ - return ipaddress.ip_network(address, strict=False) - - -def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: - """ - Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. - Uses cached network objects for better performance. - - :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") - :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") - :return: Relation between the two networks - """ - rule_net = get_network(rule) - whitelist_net = get_network(whitelist_entry) - if whitelist_net == rule_net: - return Relation.EQUAL - if whitelist_net.supernet_of(rule_net): - return Relation.SUPERNET - if whitelist_net.subnet_of(rule_net): - return Relation.SUBNET - else: - return Relation.DIFFERENT - - -def subtract_network(target: str, whitelist: str) -> List[str]: - """ - Computes the remaining parts of a network after removing the whitelist subnet. - Uses cached network objects for better performance. - - :param target: The main network (e.g., "192.168.1.0/24") - :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") - :return: A list of remaining subnets as strings - """ - target_net = get_network(target) - whitelist_net = get_network(whitelist) - - # Check if the whitelist is actually a subnet - if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: - return [target] # Return the full network if whitelist isn't a valid subnet - - remaining = [] - - # Compute ranges before and after the whitelist - if whitelist_net.network_address > target_net.network_address: - # Before the whitelist - start = target_net.network_address - end = whitelist_net.network_address - 1 - remaining.extend(ipaddress.summarize_address_range(start, end)) - - if whitelist_net.broadcast_address < target_net.broadcast_address: - # After the whitelist - start = whitelist_net.broadcast_address + 1 - end = target_net.broadcast_address - remaining.extend(ipaddress.summarize_address_range(start, end)) - - # Convert to string format - return [str(net) for net in remaining] - - -def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: - """ - Helper function to check a single rule against multiple whitelist entries. - Creates a cached rule network object for better performance. - Reduces list of whitelists, where the Relation is not DIFFERENT - - :param rule: The IP address or network to check - :param whitelists: List of whitelist networks to check against - :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT - """ - # Pre-cache the rule network since it will be used multiple times - get_network(rule) - items = [] - for whitelist in whitelists: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) - return items - - -def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tuple]: - """ - Helper function to check if any whitelist entry is a subnet of the rule. - Creates a cached rule network object for better performance. - Reduces list of rules, where the Relation is not DIFFERENT - - :param rule: The IP address or network to check against - :param whitelists: List of whitelist networks to check - :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT - """ - # Pre-cache the rule network since it will be used multiple times - get_network(whitelist) - items = [] - for rule in rules: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) - return items - - -def clear_network_cache() -> None: - """ - Clear the network object cache. - Useful when processing a large number of networks to prevent memory growth. - """ - get_network.cache_clear() diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_common.py similarity index 99% rename from flowapp/tests/test_whitelist_service.py rename to flowapp/tests/test_whitelist_common.py index 7ae4990..9b98e9a 100644 --- a/flowapp/tests/test_whitelist_service.py +++ b/flowapp/tests/test_whitelist_common.py @@ -1,10 +1,10 @@ import pytest -from flowapp.services.whitelist_service import ( +from flowapp.services.whitelist_common import ( Relation, - check_whitelist_to_rule_relation, - subtract_network, check_rule_against_whitelists, check_whitelist_against_rules, + check_whitelist_to_rule_relation, + subtract_network, clear_network_cache, ) From b765603f12cfaa17002c1b2163b734c6b10a2960 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 13 Mar 2025 08:25:21 +0100 Subject: [PATCH 15/46] refactoring of services - broken tests for rule and whitelist services --- flowapp/models/rules/rtbh.py | 3 + flowapp/services/base.py | 21 +- flowapp/services/rule_service.py | 59 ++- flowapp/services/whitelist_common.py | 38 ++ flowapp/services/whitelist_service.py | 75 ++- flowapp/tests/test_rule_service.py | 605 ++++++++++++++++++++++++ flowapp/tests/test_whitelist_service.py | 234 +++++++++ flowapp/views/whitelist.py | 5 +- 8 files changed, 991 insertions(+), 49 deletions(-) create mode 100644 flowapp/tests/test_rule_service.py create mode 100644 flowapp/tests/test_whitelist_service.py diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py index f183b81..943fef0 100644 --- a/flowapp/models/rules/rtbh.py +++ b/flowapp/models/rules/rtbh.py @@ -148,3 +148,6 @@ def __str__(self): return f"{self.ipv6}/{self.ipv6_mask}" return f"{self.ipv4}/{self.ipv4_mask} {self.ipv6}/{self.ipv6_mask}" + + def get_author(self): + return f"{self.user.email} / {self.org}" diff --git a/flowapp/services/base.py b/flowapp/services/base.py index d829eee..f06f16f 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,6 +1,6 @@ -from flowapp import db, messages -from flowapp.constants import ANNOUNCE, RuleOrigin, RuleTypes -from flowapp.models import RTBH, RuleWhitelistCache +from flowapp import messages +from flowapp.constants import ANNOUNCE, WITHDRAW +from flowapp.models import RTBH from flowapp.output import Route, RouteSources, announce_route @@ -18,10 +18,15 @@ def announce_rtbh_route(model: RTBH, author: str) -> None: announce_route(route) -def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: +def withdraw_rtbh_route(model: RTBH) -> None: """ - Add RTBH rule to whitelist cache + Withdraw RTBH route if rule is in whitelist state """ - cache = RuleWhitelistCache(rid=model.id, rtype=RuleTypes.RTBH, whitelist_id=whitelist_id, rorigin=rule_origin) - db.session.add(cache) - db.session.commit() + if model.rstate_id == 4: + command = messages.create_rtbh(model, WITHDRAW) + route = Route( + author=model.get_author(), + source=RouteSources.UI, + command=command, + ) + announce_route(route) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 0d0602c..85c177a 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -22,7 +22,13 @@ ) from flowapp.output import Route, announce_route, log_route, RouteSources from flowapp.services.base import announce_rtbh_route -from flowapp.services.whitelist_common import Relation, add_rtbh_rule_to_cache, subtract_network, whitelist_rtbh_rule +from flowapp.services.whitelist_common import ( + Relation, + add_rtbh_rule_to_cache, + create_rtbh_from_whitelist_parts, + subtract_network, + whitelist_rtbh_rule, +) from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent from .whitelist_common import check_rule_against_whitelists @@ -209,9 +215,26 @@ def create_or_update_rtbh_rule( # Check if rule is whitelisted # get all not expired whitelists whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() - wl_cache = {str(w): w for w in whitelists} + wl_cache = map_whitelists_to_strings(whitelists) results = check_rule_against_whitelists(str(model), wl_cache.keys()) # check rule against whitelists, stop search when rule is whitelisted first time + model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) + + announce_rtbh_route(model, author=author) + # Log changes + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + + +def evaluate_rtbh_against_whitelists_check_results( + user_id: int, + model: RTBH, + flashes: List[str], + author: str, + wl_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> RTBH: for rule, whitelist_key, relation in results: match relation: case Relation.EQUAL: @@ -219,7 +242,6 @@ def create_or_update_rtbh_rule( flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") break case Relation.SUBNET: - # split subnet into parts parts = subtract_network(target=str(model), whitelist=whitelist_key) wl_id = wl_cache[whitelist_key].id flashes.append( @@ -228,7 +250,6 @@ def create_or_update_rtbh_rule( for network in parts: create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") - model.rstate_id = 4 add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) db.session.commit() @@ -237,32 +258,8 @@ def create_or_update_rtbh_rule( model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") break + return model - announce_rtbh_route(model, author=author) - # Log changes - log_route(user_id, model, RuleTypes.RTBH, author) - return model, flashes - - -def create_rtbh_from_whitelist_parts( - model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str, user_id: int -) -> None: - net_ip, net_mask = network.split("/") - new_model = RTBH( - ipv4=net_ip, - ipv4_mask=net_mask, - ipv6=model.ipv6, - ipv6_mask=model.ipv6_mask, - community_id=model.community_id, - expires=model.expires, - comment=model.comment, - user_id=model.user_id, - org_id=model.org_id, - rstate_id=1, - ) - db.session.add(new_model) - db.session.commit() - add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) - announce_rtbh_route(new_model, rule_owner) - log_route(user_id, model, RuleTypes.RTBH, rule_owner) +def map_whitelists_to_strings(whitelists: List[Whitelist]) -> Dict[str, Whitelist]: + return {str(w): w for w in whitelists} diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 9c4b49a..1730402 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -5,6 +5,8 @@ from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes from flowapp.models import RTBH, RuleWhitelistCache, Whitelist +from flowapp.output import log_route +from flowapp.services.base import announce_rtbh_route def add_rtbh_rule_to_cache(model: RTBH, whitelist_id: int, rule_origin: RuleOrigin = RuleOrigin.USER) -> None: @@ -29,6 +31,14 @@ def whitelist_rtbh_rule(model: RTBH, whitelist: Whitelist) -> RTBH: class Relation(Enum): + """ + Enum to represent the relation between Whitelist to Rule Relation + Subnet: Whitelist is a subnet of the rule + Supernet: Whitelist is a supernet of the rule + Equal: Whitelist is equal to the rule + Different: Whitelist is different from the rule + """ + SUBNET = auto() SUPERNET = auto() EQUAL = auto() @@ -148,3 +158,31 @@ def clear_network_cache() -> None: Useful when processing a large number of networks to prevent memory growth. """ get_network.cache_clear() + + +def create_rtbh_from_whitelist_parts( + model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str = "", user_id: int = 0 +) -> None: + # default values from model + rule_owner = rule_owner or model.get_author() + user_id = user_id or model.user_id + + net_ip, net_mask = network.split("/") + new_model = RTBH( + ipv4=net_ip, + ipv4_mask=net_mask, + ipv6=model.ipv6, + ipv6_mask=model.ipv6_mask, + community_id=model.community_id, + expires=model.expires, + comment=model.comment, + user_id=model.user_id, + org_id=model.org_id, + rstate_id=1, + ) + db.session.add(new_model) + db.session.commit() + + add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) + announce_rtbh_route(new_model, rule_owner) + log_route(user_id, model, RuleTypes.RTBH, rule_owner) diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index b94772f..3f57f7e 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -13,7 +13,14 @@ from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists from flowapp.models.rules.flowspec import Flowspec4, Flowspec6 from flowapp.models.rules.rtbh import RTBH -from flowapp.services.base import announce_rtbh_route +from flowapp.services.base import announce_rtbh_route, withdraw_rtbh_route +from flowapp.services.whitelist_common import add_rtbh_rule_to_cache, create_rtbh_from_whitelist_parts +from flowapp.services.whitelist_common import ( + Relation, + check_whitelist_against_rules, + subtract_network, + whitelist_rtbh_rule, +) from flowapp.utils import round_to_ten_minutes, quote_to_ent @@ -35,10 +42,10 @@ def create_or_update_whitelist( """ # Check for existing model model = get_whitelist_model_if_exists(form_data) - + flashes = [] if model: model.expires = round_to_ten_minutes(form_data["expires"]) - flash_message = "Existing Whitelist found. Expiration time was updated to new value." + flashes.append("Existing Whitelist found. Expiration time was updated to new value.") else: # Create new model model = Whitelist( @@ -50,11 +57,59 @@ def create_or_update_whitelist( comment=quote_to_ent(form_data["comment"]), ) db.session.add(model) - flash_message = "Whitelist saved" + flashes.append("Whitelist saved") db.session.commit() - return model, flash_message + # check RTBH rules against whitelist + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1).all() + print(f"Found {len(all_rtbh_rules)} active RTBH rules") + rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) + result = check_whitelist_against_rules(rtbh_rules_map, str(model)) + print(f"Found {len(result)} matching RTBH rules") + model = evaluate_whitelist_against_rtbh_check_results(model, flashes, rtbh_rules_map, result) + + return model, flashes + + +def evaluate_whitelist_against_rtbh_check_results( + whitelist_model: Whitelist, + flashes: List[str], + rtbh_rule_cache: Dict[str, Whitelist], + results: List[Tuple[str, str, Relation]], +) -> Whitelist: + + for rule_key, whitelist_key, relation in results: + print(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") + match relation: + case Relation.EQUAL: + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + flashes.append(f"Active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is whitelisted.") + case Relation.SUBNET: + parts = subtract_network(target=rule_key, whitelist=whitelist_key) + wl_id = whitelist_model.id + flashes.append( + f" Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + ) + for network in parts: + rule_model = rtbh_rule_cache[rule_key] + create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) + flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + rule_model.rstate_id = 4 + add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) + db.session.commit() + case Relation.SUPERNET: + + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + flashes.append(f"Active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is whitelisted.") + + return whitelist_model + + +def map_rtbh_rules_to_strings(all_rtbh_rules: List[RTBH]) -> Dict[str, RTBH]: + return {str(rule): rule for rule in all_rtbh_rules} def delete_whitelist(whitelist_id: int) -> List[str]: @@ -83,9 +138,13 @@ def delete_whitelist(whitelist_id: int) -> List[str]: db.session.delete(rule_model) elif rorigin_type == RuleOrigin.USER: flashes.append(f"Set rule {rule_model} back to state 'Active'") - rule_model.rstate_id = 1 # Set rule state to "Active" again - author = f"{model.user.email} ({model.user.organization})" - announce_rtbh_route(rule_model, author) + try: + rule_model.rstate_id = 1 # Set rule state to "Active" again + except AttributeError: + print(f"Rule {rule_model} does not exist, cache anomaly?. Skipping.") + else: + author = f"{model.user.email} ({model.user.organization})" + announce_rtbh_route(rule_model, author) flashes.append(f"Deleted cache entries for whitelist {whitelist_id}") RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py new file mode 100644 index 0000000..94ab8b6 --- /dev/null +++ b/flowapp/tests/test_rule_service.py @@ -0,0 +1,605 @@ +""" +Tests for rule_service.py module. + +This test suite verifies the functionality of the rule service after refactoring, +which manages creation, updating, and processing of flow rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock + +from flowapp.constants import RuleOrigin, ANNOUNCE +from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist +from flowapp.services import rule_service +from flowapp.services.whitelist_common import Relation +import flowapp.services.whitelist_common + + +@pytest.fixture +def ipv4_form_data(): + """Sample valid IPv4 form data""" + return { + "source": "192.168.1.0", + "source_mask": 24, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "protocol": "tcp", + "flags": ["SYN"], + "packet_len": "", + "fragment": ["dont-fragment"], + "comment": "Test IPv4 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def ipv6_form_data(): + """Sample valid IPv6 form data""" + return { + "source": "2001:db8::", + "source_mask": 32, + "source_port": "80", + "dest": "", + "dest_mask": None, + "dest_port": "", + "next_header": "tcp", + "flags": ["SYN"], + "packet_len": "", + "comment": "Test IPv6 rule", + "expires": datetime.now() + timedelta(hours=1), + "action": 1, + } + + +@pytest.fixture +def rtbh_form_data(): + """Sample valid RTBH form data""" + return { + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "ipv6": "", + "ipv6_mask": None, + "community": 1, + "comment": "Test RTBH rule", + "expires": datetime.now() + timedelta(hours=1), + } + + +@pytest.fixture +def whitelist_fixture(): + """Create a whitelist fixture""" + whitelist = MagicMock(spec=Whitelist) + whitelist.id = 1 + return whitelist + + +class TestCreateOrUpdateIPv4Rule: + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test creating a new IPv4 rule""" + # Mock the get_ipv4_model_if_exists to return False (not found) + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv4.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv4_form_data["source"] + assert model.source_mask == ipv4_form_data["source_mask"] + assert model.protocol == ipv4_form_data["protocol"] + assert model.flags == "SYN" # form data flags are joined with ";" + assert model.action_id == ipv4_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message + assert message == "IPv4 Rule saved" + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv4_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv4_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv4_form_data + ): + """Test updating an existing IPv4 rule""" + # Create an existing model to return + existing_model = Flowspec4( + source=ipv4_form_data["source"], + source_mask=ipv4_form_data["source_mask"], + source_port=ipv4_form_data["source_port"], + destination=ipv4_form_data["dest"] or "", + destination_mask=ipv4_form_data["dest_mask"], + destination_port=ipv4_form_data["dest_port"] or "", + protocol=ipv4_form_data["protocol"], + flags=";".join(ipv4_form_data["flags"]), + packet_len=ipv4_form_data["packet_len"] or "", + fragment=";".join(ipv4_form_data["fragment"]), + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv4.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv4_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv4_rule( + form_data=ipv4_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message + assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv4.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateIPv6Rule: + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test creating a new IPv6 rule""" + # Mock get_ipv6_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the announce route behavior + mock_messages.create_ipv6.return_value = "mock command" + + # Call the service function + with app.app_context(): + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.source == ipv6_form_data["source"] + assert model.source_mask == ipv6_form_data["source_mask"] + assert model.next_header == ipv6_form_data["next_header"] + assert model.flags == "SYN" + assert model.action_id == ipv6_form_data["action"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify message + assert message == "IPv6 Rule saved" + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestCreateOrUpdateRTBHRule: + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_new_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test creating a new RTBH rule""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list (no matches) + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Call the service function + with app.app_context(): + model, messages = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ipv4 == rtbh_form_data["ipv4"] + assert model.ipv4_mask == rtbh_form_data["ipv4_mask"] + assert model.ipv6 == rtbh_form_data["ipv6"] + assert model.ipv6_mask == rtbh_form_data["ipv6_mask"] + assert model.community_id == rtbh_form_data["community"] + assert model.rstate_id == 1 # Active state + assert model.user_id == 1 + assert model.org_id == 1 + + # Verify messages + assert "RTBH Rule saved" in messages[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + # Verify evaluate function was called + mock_evaluate.assert_called_once() + + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_rtbh_rule( + self, mock_log, mock_announce, mock_evaluate, mock_check, mock_query, mock_get_model, app, db, rtbh_form_data + ): + """Test updating an existing RTBH rule""" + # Create an existing model to return + existing_model = RTBH( + ipv4=rtbh_form_data["ipv4"], + ipv4_mask=rtbh_form_data["ipv4_mask"], + ipv6=rtbh_form_data["ipv6"] or "", + ipv6_mask=rtbh_form_data["ipv6_mask"], + community_id=rtbh_form_data["community"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + + # Mock the whitelist query + mock_whitelists = [] + mock_query.return_value.filter.return_value.all.return_value = mock_whitelists + + # Mock check_rule_against_whitelists to return empty list + mock_check.return_value = [] + + # Mock evaluate function to return the model unchanged + mock_evaluate.side_effect = lambda user_id, model, flashes, author, wl_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + rtbh_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, messages = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify messages + assert "Existing RTBH Rule found" in messages[0] + + # Verify route was announced + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestEvaluateRtbhRuleAgainstWhitelists: + def test_equal_relation(self, app, whitelist_fixture): + """Test evaluating a rule with an EQUAL relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.0/24" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.EQUAL)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert any("Rule is equal to active whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == model + + def test_subnet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUBNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.1.128/25" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUBNET)] + + # Call the function with mocked dependencies + with patch("flowapp.services.rule_service.subtract_network") as mock_subtract, patch( + "flowapp.services.rule_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.rule_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + with app.app_context(): + _result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(model, whitelist_fixture.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("Rule is supernet of active whitelist" in msg for msg in flashes) + assert any("Created RTBH rule for" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert model.rstate_id == 4 + + def test_supernet_relation(self, app, whitelist_fixture): + """Test evaluating a rule with a SUPERNET relation to a whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + whitelist_key = "192.168.0.0/16" + wl_cache = {whitelist_key: whitelist_fixture} + results = [(str(model), whitelist_key, Relation.SUPERNET)] + + # Call the function with mocked whitelist_rtbh_rule + with patch("flowapp.services.rule_service.whitelist_rtbh_rule") as mock_whitelist_rule: + mock_whitelist_rule.return_value = model + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify the rule was whitelisted + mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) + + # Verify the flash message + assert any("Rule is subnet of active whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == model + + def test_no_relation(self, app): + """Test evaluating a rule with no relation to any whitelist""" + # Create a model + model = MagicMock(spec=RTBH) + + # Create test data + flashes = [] + user_id = 1 + author = "test@example.com / Test Org" + wl_cache = {} + results = [] + + with app.app_context(): + result = rule_service.evaluate_rtbh_against_whitelists_check_results( + user_id, model, flashes, author, wl_cache, results + ) + + # Verify no changes to the model and no messages + assert result == model + assert not flashes + + +class TestCreateRtbhFromWhitelistParts: + @patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_create_rtbh_from_whitelist_parts(self, mock_log, mock_announce, mock_add_cache, app, db): + """Test creating RTBH rules from whitelist parts""" + # Create an initial RTBH rule model + model = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Set up test data + network = "192.168.1.128/25" + whitelist_id = 1 + whitelist_key = "192.168.1.128/25" + rule_owner = "test@example.com / Test Org" + user_id = 1 + + # Call the function + with app.app_context(): + db.session.add(model) + db.session.commit() + + flowapp.services.whitelist_common.create_rtbh_from_whitelist_parts( + model, whitelist_id, whitelist_key, network, rule_owner, user_id + ) + + # Verify a new model was created and saved + net_ip, net_mask = network.split("/") + subnet_rules = db.session.query(RTBH).filter(RTBH.ipv4 == net_ip, RTBH.ipv4_mask == net_mask).all() + + assert len(subnet_rules) == 1 + assert subnet_rules[0].rstate_id == 1 # Active state + assert subnet_rules[0].ipv4 == "192.168.1.128" + assert subnet_rules[0].ipv4_mask == 25 + + # Verify caching and announcement + mock_add_cache.assert_called_once() + mock_announce.assert_called_once() + mock_log.assert_called_once() + + +class TestMapWhitelistsToStrings: + def test_map_whitelists_to_strings(self): + """Test mapping whitelist objects to strings""" + # Create mock whitelists + whitelist1 = MagicMock(spec=Whitelist) + whitelist1.__str__.return_value = "192.168.1.0/24" + + whitelist2 = MagicMock(spec=Whitelist) + whitelist2.__str__.return_value = "10.0.0.0/8" + + whitelists = [whitelist1, whitelist2] + + # Call the function + result = rule_service.map_whitelists_to_strings(whitelists) + + # Verify the result + assert len(result) == 2 + assert "192.168.1.0/24" in result + assert "10.0.0.0/8" in result + assert result["192.168.1.0/24"] == whitelist1 + assert result["10.0.0.0/8"] == whitelist2 diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py new file mode 100644 index 0000000..1f5ffd0 --- /dev/null +++ b/flowapp/tests/test_whitelist_service.py @@ -0,0 +1,234 @@ +""" +Tests for whitelist_service.py module. + +This test suite verifies the functionality of the whitelist service, +which manages creation, updating, and handling of whitelist rules. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import Whitelist, RuleWhitelistCache, RTBH +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_form_data(): + """Sample valid whitelist form data""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist entry", + "expires": datetime.now() + timedelta(hours=1), + } + + +class TestCreateOrUpdateWhitelist: + def test_create_new_whitelist(self, app, db, whitelist_form_data): + """Test creating a new whitelist entry""" + # Mock the get_whitelist_model_if_exists to return False (not found) + with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=False): + # Call the service function + with app.app_context(): + model, message = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was created with correct attributes + assert model is not None + assert model.ip == whitelist_form_data["ip"] + assert model.mask == whitelist_form_data["mask"] + assert model.comment == whitelist_form_data["comment"] + assert model.user_id == 1 + assert model.org_id == 1 + assert model.rstate_id == 1 # Active state + + # Verify message + assert message == "Whitelist saved" + + # Verify the whitelist was saved to the database + saved_whitelist = db.session.query(Whitelist).filter_by(ip=whitelist_form_data["ip"]).first() + assert saved_whitelist is not None + assert saved_whitelist.id == model.id + + def test_update_existing_whitelist(self, app, db, whitelist_form_data): + """Test updating an existing whitelist entry""" + # Create an existing whitelist + existing_model = Whitelist( + ip=whitelist_form_data["ip"], + mask=whitelist_form_data["mask"], + expires=datetime.now(), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Mock to return the existing model + with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=existing_model): + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + # Check that expiration date was updated (rounded to 10 minutes) + # We can't compare exact timestamps, so check date parts + assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() + + # Verify message + assert message == "Existing Whitelist found. Expiration time was updated to new value." + + +class TestDeleteWhitelist: + @patch("flowapp.services.whitelist_service.announce_rtbh_route") + def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): + """Test deleting a whitelist that has user-created rules attached to it""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by user, whitelisted) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=4, # Whitelisted state + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.USER, + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set whitelist ID now that it has been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Mock the get_by_whitelist_id to return our cache entry + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + messages = whitelist_service.delete_whitelist(whitelist.id) + + # Verify the rule state was changed back to active + rtbh_rule = db.session.get(RTBH, rtbh_rule.id) + assert rtbh_rule.rstate_id == 1 # Active state + + # Verify announcement was made + mock_announce.assert_called_once() + + # Verify messages + assert any("Set rule" in msg for msg in messages) + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + @patch("flowapp.services.whitelist_service.RuleWhitelistCache.clean_by_whitelist_id") + def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db): + """Test deleting a whitelist that has rules created by the whitelist""" + # Create a whitelist + whitelist = Whitelist( + ip="192.168.1.0", + mask=24, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a RTBH rule (created by whitelist) + rtbh_rule = RTBH( + ipv4="192.168.1.0", + ipv4_mask=24, + ipv6="", + ipv6_mask=None, + community_id=1, + expires=datetime.now() + timedelta(hours=1), + user_id=1, + org_id=1, + rstate_id=1, + ) + + # Create a cache entry linking the rule to the whitelist + cache_entry = RuleWhitelistCache( + rid=1, + rtype=RuleTypes.RTBH, + whitelist_id=1, + rorigin=RuleOrigin.WHITELIST, # Important: created BY whitelist + ) + + with app.app_context(): + db.session.add(whitelist) + db.session.add(rtbh_rule) + db.session.commit() + + # Set IDs now that they have been saved + cache_entry.whitelist_id = whitelist.id + cache_entry.rid = rtbh_rule.id + db.session.add(cache_entry) + db.session.commit() + + # Create a mock session that can get our rule + with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): + # Call the service function + messages = whitelist_service.delete_whitelist(whitelist.id) + + # Verify messages + assert any("Deleted rule" in msg for msg in messages) + + # Verify the rule was deleted + assert db.session.get(RTBH, rtbh_rule.id) is None + + # Verify the whitelist was deleted + assert db.session.get(Whitelist, whitelist.id) is None + + # Verify cache cleanup was called + mock_clean.assert_called_once_with(whitelist.id) + + def test_delete_nonexistent_whitelist(self, app, db): + """Test deleting a whitelist that doesn't exist""" + with app.app_context(): + # Call the service function with a non-existent ID + messages = whitelist_service.delete_whitelist(999) + + # Should return empty list of messages, as no whitelist was found + assert len(messages) == 0 diff --git a/flowapp/views/whitelist.py b/flowapp/views/whitelist.py index 9f2bcda..39a9352 100644 --- a/flowapp/views/whitelist.py +++ b/flowapp/views/whitelist.py @@ -24,14 +24,15 @@ def add(): form.net_ranges = net_ranges if request.method == "POST" and form.validate(): - model, message = create_or_update_whitelist( + model, messages = create_or_update_whitelist( form.data, user_id=session["user_id"], org_id=session["user_org_id"], user_email=session["user_email"], org_name=session["user_org"], ) - flash(message, "alert-success") + for message in messages: + flash(message, "alert-success") return redirect(url_for("index")) else: From 8b499bc851bc881e4e35625058f13b3a3c6fc045 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 13 Mar 2025 08:43:33 +0100 Subject: [PATCH 16/46] refactoring of services - fixed tests for rule and whitelist services, all tests passing --- flowapp/tests/test_rule_service.py | 266 ++++++++++++---------- flowapp/tests/test_whitelist_service.py | 287 +++++++++++++++++++++--- 2 files changed, 402 insertions(+), 151 deletions(-) diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py index 94ab8b6..16daf10 100644 --- a/flowapp/tests/test_rule_service.py +++ b/flowapp/tests/test_rule_service.py @@ -13,7 +13,6 @@ from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist from flowapp.services import rule_service from flowapp.services.whitelist_common import Relation -import flowapp.services.whitelist_common @pytest.fixture @@ -113,7 +112,7 @@ def test_create_new_ipv4_rule( assert model.user_id == 1 assert model.org_id == 1 - # Verify message + # Verify message is still a string for IPv4 rules assert message == "IPv4 Rule saved" # Verify route was announced @@ -171,7 +170,7 @@ def test_update_existing_ipv4_rule( assert model == existing_model assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - # Verify message + # Verify message is still a string for IPv4 rules assert message == "Existing IPv4 Rule found. Expiration time was updated to new value." # Verify route was announced @@ -216,7 +215,7 @@ def test_create_new_ipv6_rule( assert model.user_id == 1 assert model.org_id == 1 - # Verify message + # Verify message is still a string for IPv6 rules assert message == "IPv6 Rule saved" # Verify route was announced @@ -224,69 +223,12 @@ def test_create_new_ipv6_rule( mock_announce.assert_called_once() mock_log.assert_called_once() - @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") - @patch("flowapp.services.rule_service.messages") - @patch("flowapp.services.rule_service.announce_route") - @patch("flowapp.services.rule_service.log_route") - def test_update_existing_ipv6_rule( - self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data - ): - """Test updating an existing IPv6 rule""" - # Create an existing model to return - existing_model = Flowspec6( - source=ipv6_form_data["source"], - source_mask=ipv6_form_data["source_mask"], - source_port=ipv6_form_data["source_port"] or "", - destination=ipv6_form_data["dest"] or "", - destination_mask=ipv6_form_data["dest_mask"], - destination_port=ipv6_form_data["dest_port"] or "", - next_header=ipv6_form_data["next_header"], - flags=";".join(ipv6_form_data["flags"]), - packet_len=ipv6_form_data["packet_len"] or "", - expires=datetime.now(), - user_id=1, - org_id=1, - action_id=1, - rstate_id=1, - ) - mock_get_model.return_value = existing_model - mock_messages.create_ipv6.return_value = "mock command" - - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - ipv6_form_data["expires"] = new_expires - - # Call the service function - with app.app_context(): - db.session.add(existing_model) - db.session.commit() - - model, message = rule_service.create_or_update_ipv6_rule( - form_data=ipv6_form_data, - user_id=1, - org_id=1, - user_email="test@example.com", - org_name="Test Org", - ) - - # Verify the model was updated - assert model == existing_model - assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - - # Verify message - assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." - - # Verify route was announced - mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) - mock_announce.assert_called_once() - mock_log.assert_called_once() - class TestCreateOrUpdateRTBHRule: @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") @patch("flowapp.services.rule_service.db.session.query") @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") @patch("flowapp.services.rule_service.announce_rtbh_route") @patch("flowapp.services.rule_service.log_route") def test_create_new_rtbh_rule( @@ -308,7 +250,7 @@ def test_create_new_rtbh_rule( # Call the service function with app.app_context(): - model, messages = rule_service.create_or_update_rtbh_rule( + model, flashes = rule_service.create_or_update_rtbh_rule( form_data=rtbh_form_data, user_id=1, org_id=1, @@ -327,10 +269,11 @@ def test_create_new_rtbh_rule( assert model.user_id == 1 assert model.org_id == 1 - # Verify messages - assert "RTBH Rule saved" in messages[0] + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] - # Verify route was announced + # Verify rule was announced mock_announce.assert_called_once() mock_log.assert_called_once() @@ -340,7 +283,7 @@ def test_create_new_rtbh_rule( @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") @patch("flowapp.services.rule_service.db.session.query") @patch("flowapp.services.rule_service.check_rule_against_whitelists") - @patch("flowapp.services.rule_service.evaluate_rtbh_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") @patch("flowapp.services.rule_service.announce_rtbh_route") @patch("flowapp.services.rule_service.log_route") def test_update_existing_rtbh_rule( @@ -380,7 +323,7 @@ def test_update_existing_rtbh_rule( db.session.add(existing_model) db.session.commit() - model, messages = rule_service.create_or_update_rtbh_rule( + model, flashes = rule_service.create_or_update_rtbh_rule( form_data=rtbh_form_data, user_id=1, org_id=1, @@ -392,15 +335,89 @@ def test_update_existing_rtbh_rule( assert model == existing_model assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() - # Verify messages - assert "Existing RTBH Rule found" in messages[0] + # Verify flash messages + assert isinstance(flashes, list) + assert "Existing RTBH Rule found" in flashes[0] # Verify route was announced mock_announce.assert_called_once() mock_log.assert_called_once() + @patch("flowapp.services.rule_service.get_rtbh_model_if_exists") + @patch("flowapp.services.rule_service.db.session.query") + @patch("flowapp.services.rule_service.map_whitelists_to_strings") + @patch("flowapp.services.rule_service.check_rule_against_whitelists") + @patch("flowapp.services.rule_service.evaluate_rtbh_against_whitelists_check_results") + @patch("flowapp.services.rule_service.announce_rtbh_route") + @patch("flowapp.services.rule_service.log_route") + def test_rtbh_rule_with_whitelists( + self, + mock_log, + mock_announce, + mock_evaluate, + mock_check, + mock_map, + mock_query, + mock_get_model, + app, + db, + rtbh_form_data, + whitelist_fixture, + ): + """Test creating a RTBH rule that interacts with whitelists""" + # Mock get_rtbh_model_if_exists to return False + mock_get_model.return_value = False + + # Create a mock whitelist + mock_whitelist = whitelist_fixture + mock_whitelists = [mock_whitelist] + + # Setup mock query to return our whitelist + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_whitelists + mock_query.return_value = mock_query_result + + # Setup map function to return our whitelist in a dict + whitelist_key = "192.168.1.0/24" + mock_map.return_value = {whitelist_key: mock_whitelist} + + # Setup check function to return a relation + mock_rtbh = MagicMock(spec=RTBH) + mock_rtbh.__str__.return_value = "192.168.1.0/24" + + # Setup check result + rule_relation = [(str(mock_rtbh), whitelist_key, Relation.EQUAL)] + mock_check.return_value = rule_relation + + # Setup evaluate function to add a flash message + def evaluate_side_effect(user_id, model, flashes, author, wl_cache, results): + flashes.append("Rule is equal to whitelist") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = rule_service.create_or_update_rtbh_rule( + form_data=rtbh_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages show both rule creation and whitelist check + assert isinstance(flashes, list) + assert "RTBH Rule saved" in flashes[0] + assert "Rule is equal to whitelist" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() + -class TestEvaluateRtbhRuleAgainstWhitelists: +class TestEvaluateRtbhAgainstWhitelistsCheckResults: def test_equal_relation(self, app, whitelist_fixture): """Test evaluating a rule with an EQUAL relation to a whitelist""" # Create a model @@ -532,56 +549,6 @@ def test_no_relation(self, app): assert not flashes -class TestCreateRtbhFromWhitelistParts: - @patch("flowapp.services.rule_service.add_rtbh_rule_to_cache") - @patch("flowapp.services.rule_service.announce_rtbh_route") - @patch("flowapp.services.rule_service.log_route") - def test_create_rtbh_from_whitelist_parts(self, mock_log, mock_announce, mock_add_cache, app, db): - """Test creating RTBH rules from whitelist parts""" - # Create an initial RTBH rule model - model = RTBH( - ipv4="192.168.1.0", - ipv4_mask=24, - ipv6="", - ipv6_mask=None, - community_id=1, - expires=datetime.now() + timedelta(hours=1), - user_id=1, - org_id=1, - rstate_id=1, - ) - - # Set up test data - network = "192.168.1.128/25" - whitelist_id = 1 - whitelist_key = "192.168.1.128/25" - rule_owner = "test@example.com / Test Org" - user_id = 1 - - # Call the function - with app.app_context(): - db.session.add(model) - db.session.commit() - - flowapp.services.whitelist_common.create_rtbh_from_whitelist_parts( - model, whitelist_id, whitelist_key, network, rule_owner, user_id - ) - - # Verify a new model was created and saved - net_ip, net_mask = network.split("/") - subnet_rules = db.session.query(RTBH).filter(RTBH.ipv4 == net_ip, RTBH.ipv4_mask == net_mask).all() - - assert len(subnet_rules) == 1 - assert subnet_rules[0].rstate_id == 1 # Active state - assert subnet_rules[0].ipv4 == "192.168.1.128" - assert subnet_rules[0].ipv4_mask == 25 - - # Verify caching and announcement - mock_add_cache.assert_called_once() - mock_announce.assert_called_once() - mock_log.assert_called_once() - - class TestMapWhitelistsToStrings: def test_map_whitelists_to_strings(self): """Test mapping whitelist objects to strings""" @@ -603,3 +570,60 @@ def test_map_whitelists_to_strings(self): assert "10.0.0.0/8" in result assert result["192.168.1.0/24"] == whitelist1 assert result["10.0.0.0/8"] == whitelist2 + + @patch("flowapp.services.rule_service.get_ipv6_model_if_exists") + @patch("flowapp.services.rule_service.messages") + @patch("flowapp.services.rule_service.announce_route") + @patch("flowapp.services.rule_service.log_route") + def test_update_existing_ipv6_rule( + self, mock_log, mock_announce, mock_messages, mock_get_model, app, db, ipv6_form_data + ): + """Test updating an existing IPv6 rule""" + # Create an existing model to return + existing_model = Flowspec6( + source=ipv6_form_data["source"], + source_mask=ipv6_form_data["source_mask"], + source_port=ipv6_form_data["source_port"] or "", + destination=ipv6_form_data["dest"] or "", + destination_mask=ipv6_form_data["dest_mask"], + destination_port=ipv6_form_data["dest_port"] or "", + next_header=ipv6_form_data["next_header"], + flags=";".join(ipv6_form_data["flags"]), + packet_len=ipv6_form_data["packet_len"] or "", + expires=datetime.now(), + user_id=1, + org_id=1, + action_id=1, + rstate_id=1, + ) + mock_get_model.return_value = existing_model + mock_messages.create_ipv6.return_value = "mock command" + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + ipv6_form_data["expires"] = new_expires + + # Call the service function + with app.app_context(): + db.session.add(existing_model) + db.session.commit() + + model, message = rule_service.create_or_update_ipv6_rule( + form_data=ipv6_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify the model was updated + assert model == existing_model + assert model.expires.date() == rule_service.round_to_ten_minutes(new_expires).date() + + # Verify message is still a string for IPv6 rules + assert message == "Existing IPv6 Rule found. Expiration time was updated to new value." + + # Verify route was announced + mock_messages.create_ipv6.assert_called_once_with(model, ANNOUNCE) + mock_announce.assert_called_once() + mock_log.assert_called_once() diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py index 1f5ffd0..519c2aa 100644 --- a/flowapp/tests/test_whitelist_service.py +++ b/flowapp/tests/test_whitelist_service.py @@ -7,11 +7,12 @@ import pytest from datetime import datetime, timedelta -from unittest.mock import patch +from unittest.mock import patch, MagicMock from flowapp.constants import RuleTypes, RuleOrigin from flowapp.models import Whitelist, RuleWhitelistCache, RTBH from flowapp.services import whitelist_service +from flowapp.services.whitelist_common import Relation @pytest.fixture @@ -26,13 +27,31 @@ def whitelist_form_data(): class TestCreateOrUpdateWhitelist: - def test_create_new_whitelist(self, app, db, whitelist_form_data): + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_new_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): """Test creating a new whitelist entry""" # Mock the get_whitelist_model_if_exists to return False (not found) - with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=False): - # Call the service function + mock_get_model.return_value = False + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] + + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): with app.app_context(): - model, message = whitelist_service.create_or_update_whitelist( + model, flashes = whitelist_service.create_or_update_whitelist( form_data=whitelist_form_data, user_id=1, org_id=1, @@ -49,15 +68,16 @@ def test_create_new_whitelist(self, app, db, whitelist_form_data): assert model.org_id == 1 assert model.rstate_id == 1 # Active state - # Verify message - assert message == "Whitelist saved" + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Whitelist saved" in flashes[0] - # Verify the whitelist was saved to the database - saved_whitelist = db.session.query(Whitelist).filter_by(ip=whitelist_form_data["ip"]).first() - assert saved_whitelist is not None - assert saved_whitelist.id == model.id - - def test_update_existing_whitelist(self, app, db, whitelist_form_data): + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_update_existing_whitelist( + self, mock_evaluate, mock_check_whitelist, mock_get_model, app, db, whitelist_form_data + ): """Test updating an existing whitelist entry""" # Create an existing whitelist existing_model = Whitelist( @@ -70,17 +90,30 @@ def test_update_existing_whitelist(self, app, db, whitelist_form_data): ) # Mock to return the existing model - with patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists", return_value=existing_model): - # Set a new expiration time - new_expires = datetime.now() + timedelta(days=1) - whitelist_form_data["expires"] = new_expires + mock_get_model.return_value = existing_model + + # Mock RTBH rules and check results + mock_rtbh_rules = [] + mock_query = MagicMock() + mock_query.filter.return_value.all.return_value = mock_rtbh_rules + + # Mock check_whitelist_against_rules to return empty list (no matches) + mock_check_whitelist.return_value = [] - # Call the service function + # Mock evaluate to return the model unchanged + mock_evaluate.side_effect = lambda model, flashes, rtbh_rule_cache, results: model + + # Set a new expiration time + new_expires = datetime.now() + timedelta(days=1) + whitelist_form_data["expires"] = new_expires + + # Call the service function + with patch("flowapp.services.whitelist_service.db.session.query", return_value=mock_query): with app.app_context(): db.session.add(existing_model) db.session.commit() - model, message = whitelist_service.create_or_update_whitelist( + model, flashes = whitelist_service.create_or_update_whitelist( form_data=whitelist_form_data, user_id=1, org_id=1, @@ -94,8 +127,70 @@ def test_update_existing_whitelist(self, app, db, whitelist_form_data): # We can't compare exact timestamps, so check date parts assert model.expires.date() == whitelist_service.round_to_ten_minutes(new_expires).date() - # Verify message - assert message == "Existing Whitelist found. Expiration time was updated to new value." + # Verify flash messages - now a list instead of a string + assert isinstance(flashes, list) + assert "Existing Whitelist found" in flashes[0] + + @patch("flowapp.services.whitelist_service.get_whitelist_model_if_exists") + @patch("flowapp.services.whitelist_service.db.session.query") + @patch("flowapp.services.whitelist_service.map_rtbh_rules_to_strings") + @patch("flowapp.services.whitelist_service.check_whitelist_against_rules") + @patch("flowapp.services.whitelist_service.evaluate_whitelist_against_rtbh_check_results") + def test_create_whitelist_with_matching_rules( + self, mock_evaluate, mock_check, mock_map, mock_query, mock_get_model, app, db, whitelist_form_data + ): + """Test creating a whitelist that affects existing rules""" + # Mock get_whitelist_model_if_exists to return False (new whitelist) + mock_get_model.return_value = False + + # Create a mock RTBH rule + mock_rtbh_rule = MagicMock(spec=RTBH) + mock_rtbh_rule.__str__.return_value = "192.168.1.0/24" + mock_rtbh_rules = [mock_rtbh_rule] + + # Setup mock query to return our rule + mock_query_result = MagicMock() + mock_query_result.filter.return_value.all.return_value = mock_rtbh_rules + mock_query.return_value = mock_query_result + + # Setup map function to return our rule in a dict + mock_map.return_value = {"192.168.1.0/24": mock_rtbh_rule} + + # Setup check function to return a relation + rule_relation = [ + ( + str(mock_rtbh_rule), + str(whitelist_form_data["ip"]) + "/" + str(whitelist_form_data["mask"]), + Relation.EQUAL, + ) + ] + mock_check.return_value = rule_relation + + # Setup evaluate function to modify flashes + def evaluate_side_effect(model, flashes, rtbh_rule_cache, results): + flashes.append("Rule was whitelisted") + return model + + mock_evaluate.side_effect = evaluate_side_effect + + # Call the service function + with app.app_context(): + model, flashes = whitelist_service.create_or_update_whitelist( + form_data=whitelist_form_data, + user_id=1, + org_id=1, + user_email="test@example.com", + org_name="Test Org", + ) + + # Verify flash messages includes both whitelist creation and rule whitelisting + assert "Whitelist saved" in flashes[0] + assert "Rule was whitelisted" in flashes[1] + + # Verify interactions + mock_map.assert_called_once() + mock_check.assert_called_once() + mock_evaluate.assert_called_once() class TestDeleteWhitelist: @@ -147,7 +242,7 @@ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): # Mock the get_by_whitelist_id to return our cache entry with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): # Call the service function - messages = whitelist_service.delete_whitelist(whitelist.id) + flashes = whitelist_service.delete_whitelist(whitelist.id) # Verify the rule state was changed back to active rtbh_rule = db.session.get(RTBH, rtbh_rule.id) @@ -156,8 +251,9 @@ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): # Verify announcement was made mock_announce.assert_called_once() - # Verify messages - assert any("Set rule" in msg for msg in messages) + # Verify flash messages + assert isinstance(flashes, list) + assert any("Set rule" in msg for msg in flashes) # Verify the whitelist was deleted assert db.session.get(Whitelist, whitelist.id) is None @@ -210,10 +306,11 @@ def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db # Create a mock session that can get our rule with patch.object(RuleWhitelistCache, "get_by_whitelist_id", return_value=[cache_entry]): # Call the service function - messages = whitelist_service.delete_whitelist(whitelist.id) + flashes = whitelist_service.delete_whitelist(whitelist.id) - # Verify messages - assert any("Deleted rule" in msg for msg in messages) + # Verify flash messages + assert isinstance(flashes, list) + assert any("Deleted rule" in msg for msg in flashes) # Verify the rule was deleted assert db.session.get(RTBH, rtbh_rule.id) is None @@ -228,7 +325,137 @@ def test_delete_nonexistent_whitelist(self, app, db): """Test deleting a whitelist that doesn't exist""" with app.app_context(): # Call the service function with a non-existent ID - messages = whitelist_service.delete_whitelist(999) + flashes = whitelist_service.delete_whitelist(999) + + # Should return empty list of flash messages, as no whitelist was found + assert isinstance(flashes, list) + assert len(flashes) == 0 + + +class TestEvaluateWhitelistAgainstRtbhResults: + def test_equal_relation(self, app): + """Test evaluating a whitelist with an EQUAL relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.0/24" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.EQUAL)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("equal to whitelist" in msg for msg in flashes) + + # Verify the correct model was returned + assert result == whitelist_model + + def test_subnet_relation(self, app): + """Test evaluating a whitelist with a SUBNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.1.128/25" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUBNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.subtract_network") as mock_subtract, patch( + "flowapp.services.whitelist_service.create_rtbh_from_whitelist_parts" + ) as mock_create, patch("flowapp.services.whitelist_service.add_rtbh_rule_to_cache") as mock_add_cache, patch( + "flowapp.services.whitelist_service.db.session.commit" + ) as mock_commit: + + # Mock subtract_network to return some subnets + mock_subtract.return_value = ["192.168.1.0/25"] + + # Call the function + with app.app_context(): + _result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify subnet calculation was performed + mock_subtract.assert_called_once() + + # Verify new rules were created for the subnets + mock_create.assert_called_once() + + # Verify the original rule was cached + mock_add_cache.assert_called_once_with(rtbh_rule, whitelist_model.id, RuleOrigin.USER) + + # Verify transaction was committed + mock_commit.assert_called_once() + + # Verify the flash messages + assert any("supernet of whitelist" in msg for msg in flashes) + + # Verify model was updated to whitelisted state + assert rtbh_rule.rstate_id == 4 + + def test_supernet_relation(self, app): + """Test evaluating a whitelist with a SUPERNET relation to a rule""" + # Create test data + whitelist_model = MagicMock(spec=Whitelist) + whitelist_model.id = 1 + + flashes = [] + + rtbh_rule = MagicMock(spec=RTBH) + rtbh_rule.rstate_id = 1 # Active state + + rule_key = "192.168.1.0/24" + whitelist_key = "192.168.0.0/16" + rtbh_rule_cache = {rule_key: rtbh_rule} + + results = [(rule_key, whitelist_key, Relation.SUPERNET)] + + # Mock required functions + with patch("flowapp.services.whitelist_service.whitelist_rtbh_rule") as mock_whitelist_rule, patch( + "flowapp.services.whitelist_service.withdraw_rtbh_route" + ) as mock_withdraw: + + # Call the function + with app.app_context(): + result = whitelist_service.evaluate_whitelist_against_rtbh_check_results( + whitelist_model, flashes, rtbh_rule_cache, results + ) + + # Verify the rule was whitelisted and route withdrawn + mock_whitelist_rule.assert_called_once_with(rtbh_rule, whitelist_model) + mock_withdraw.assert_called_once_with(rtbh_rule) + + # Verify the flash message + assert any("subnet of whitelist" in msg for msg in flashes) - # Should return empty list of messages, as no whitelist was found - assert len(messages) == 0 + # Verify the correct model was returned + assert result == whitelist_model From 2fcdb65bd18e7a6514abd347e9d127c934bf8522 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 08:03:50 +0100 Subject: [PATCH 17/46] Improve rule management, logging, and RabbitMQ connection handling Added delete_by_rule_id method to RuleWhitelistCache for efficient cache entry removal. Used with statement for RabbitMQ connection handling in announce_to_rabbitmq to ensure proper cleanup. Introduced ALLOWED_COMMUNITIES check in rule and whitelist services for better filtering. Replaced print statements with structured logging using current_app.logger for better debugging. Improved logging in whitelist deletion to handle anomalies and prevent errors. Refactored logging configuration to use a standardized handler instead of loguru. Updated bulk_user_form.html example data to clarify user roles. --- flowapp/models/rules/whitelist.py | 15 ++++++++++ flowapp/output.py | 8 +++--- flowapp/services/rule_service.py | 16 +++++++---- flowapp/services/whitelist_service.py | 31 ++++++++++++++++----- flowapp/templates/forms/bulk_user_form.html | 6 ++-- flowapp/utils/app_factory.py | 26 +++++++++-------- flowapp/views/rules.py | 4 +++ 7 files changed, 75 insertions(+), 31 deletions(-) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index e1857cc..6384bdd 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -128,6 +128,21 @@ def clean_by_whitelist_id(cls, whitelist_id: int): db.session.commit() return deleted + @classmethod + def delete_by_rule_id(cls, rule_id: int): + """ + Delete all cache entries with the given rule ID from the database + + Args: + rule_id (int): The ID of the rule to clean + + Returns: + int: Number of rows deleted + """ + deleted = cls.query.filter_by(rid=rule_id).delete() + db.session.commit() + return deleted + def __repr__(self): return f"" diff --git a/flowapp/output.py b/flowapp/output.py index b8774bf..fcb385f 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -84,10 +84,10 @@ def announce_to_rabbitmq(route: Dict[str, str]) -> None: credentials, ) - connection = pika.BlockingConnection(parameters) - channel = connection.channel() - channel.queue_declare(queue=queue) - channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) + with pika.BlockingConnection(parameters) as connection: + channel = connection.channel() + channel.queue_declare(queue=queue) + channel.basic_publish(exchange="", routing_key=queue, body=json.dumps(route)) else: current_app.logger.debug(f"Testing: {route}") diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 85c177a..1611df0 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -9,6 +9,8 @@ from datetime import datetime from typing import Dict, List, Tuple +from flask import current_app + from flowapp import db, messages from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE from flowapp.models import ( @@ -213,12 +215,14 @@ def create_or_update_rtbh_rule( author = f"{user_email} / {org_name}" # Check if rule is whitelisted - # get all not expired whitelists - whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() - wl_cache = map_whitelists_to_strings(whitelists) - results = check_rule_against_whitelists(str(model), wl_cache.keys()) - # check rule against whitelists, stop search when rule is whitelisted first time - model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + if model.community_id in allowed_communities: + # get all not expired whitelists + whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + wl_cache = map_whitelists_to_strings(whitelists) + results = check_rule_against_whitelists(str(model), wl_cache.keys()) + # check rule against whitelists, stop search when rule is whitelisted first time + model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) announce_rtbh_route(model, author=author) # Log changes diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 3f57f7e..d82fd29 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -5,9 +5,11 @@ This module provides business logic functions for creating, updating, and managing flow rules, separating these concerns from HTTP handling. """ - +from flask import current_app from typing import Dict, Tuple, List +import sqlalchemy + from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes from flowapp.models import Whitelist, RuleWhitelistCache, get_whitelist_model_if_exists @@ -62,11 +64,13 @@ def create_or_update_whitelist( db.session.commit() # check RTBH rules against whitelist - all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1).all() - print(f"Found {len(all_rtbh_rules)} active RTBH rules") + allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] + current_app.logger.info(f"allowed communities: {allowed_communities}") + # filter out RTBH rules that are not active and not in allowed communities + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1, RTBH.community_id.in_(allowed_communities)).all() rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) result = check_whitelist_against_rules(rtbh_rules_map, str(model)) - print(f"Found {len(result)} matching RTBH rules") + current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") model = evaluate_whitelist_against_rtbh_check_results(model, flashes, rtbh_rules_map, result) return model, flashes @@ -80,7 +84,7 @@ def evaluate_whitelist_against_rtbh_check_results( ) -> Whitelist: for rule_key, whitelist_key, relation in results: - print(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") + current_app.logger.info(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") match relation: case Relation.EQUAL: whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) @@ -123,7 +127,11 @@ def delete_whitelist(whitelist_id: int) -> List[str]: flashes = [] if model: cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) + current_app.logger.info( + f"Deleting whitelist {whitelist_id}. Found {len(cached_rules)} cached rules to process." + ) for cached_rule in cached_rules: + current_app.logger.debug(f"Processing cached rule {cached_rule}") rule_model_type = RuleTypes(cached_rule.rtype) match rule_model_type: case RuleTypes.IPv4: @@ -133,15 +141,24 @@ def delete_whitelist(whitelist_id: int) -> List[str]: case RuleTypes.RTBH: rule_model = db.session.get(RTBH, cached_rule.rid) rorigin_type = RuleOrigin(cached_rule.rorigin) + current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") if rorigin_type == RuleOrigin.WHITELIST: flashes.append(f"Deleted rule {rule_model} created by this whitelist") - db.session.delete(rule_model) + try: + db.session.delete(rule_model) + except sqlalchemy.orm.exc.UnmappedInstanceError: + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) + elif rorigin_type == RuleOrigin.USER: flashes.append(f"Set rule {rule_model} back to state 'Active'") try: rule_model.rstate_id = 1 # Set rule state to "Active" again except AttributeError: - print(f"Rule {rule_model} does not exist, cache anomaly?. Skipping.") + current_app.logger.warning( + f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." + ) else: author = f"{model.user.email} ({model.user.organization})" announce_rtbh_route(rule_model, author) diff --git a/flowapp/templates/forms/bulk_user_form.html b/flowapp/templates/forms/bulk_user_form.html index c41389e..b15b3ff 100644 --- a/flowapp/templates/forms/bulk_user_form.html +++ b/flowapp/templates/forms/bulk_user_form.html @@ -42,9 +42,9 @@

Example CSV data

             
 uuid-eppn,name,telefon,email,role,organizace,poznamka
-view@example.com,Test View,123,view@example.com,1,1,View
-user@example.com,Test User,123456,user@example.com,2,1,User
-admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,Admin
+view@example.com,Test View,123,view@example.com,1,1,user with view role (1)
+user@example.com,Test User,123456,user@example.com,2,1,regular user (role 2)
+admin@example.com,Test Admin,+420 111 111 111,admin@example.com,3,1,user with admin role (3)
             
         

Role

diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py index b53b23b..bc6559b 100644 --- a/flowapp/utils/app_factory.py +++ b/flowapp/utils/app_factory.py @@ -1,8 +1,6 @@ import logging import babel from flask import redirect, render_template, request, session, url_for -from flask.logging import default_handler -from loguru import logger def register_blueprints(app, csrf=None): @@ -35,18 +33,24 @@ def register_blueprints(app, csrf=None): return app -class InterceptHandler(logging.Handler): +def configure_logging(app): + """Configure logging for the Flask application.""" - def emit(self, record): - logger_opt = logger.opt(depth=6, exception=record.exc_info, colors=True) - logger_opt.log(record.levelname, record.getMessage()) + # Remove all default handlers + for handler in app.logger.handlers[:]: + app.logger.removeHandler(handler) + # Define log format + log_format = "%(asctime)s | %(levelname)s | %(message)s" + log_datefmt = "%Y-%m-%d %H:%M:%S" -def configure_logging(app): - """Configure logging for the application.""" - # register loguru as handler - app.logger.removeHandler(default_handler) - app.logger.addHandler(InterceptHandler()) + # Create a new handler with the desired format + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt)) + + # Set logger level and attach the handler + app.logger.setLevel(logging.DEBUG) + app.logger.addHandler(console_handler) return app diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index fc787cf..4cad6d2 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -29,6 +29,7 @@ get_user_nets, insert_initial_communities, ) +from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route from flowapp.services import rule_service from flowapp.utils import ( @@ -208,6 +209,9 @@ def delete_rule(rule_type, rule_id): model.id, f"{session['user_email']} / {session['user_org']}", ) + if enum_rule_type == RuleTypes.RTBH: + current_app.logger.debug(f"Deleting RTBH rule {rule_id} from cache") + RuleWhitelistCache.delete_by_rule_id(rule_id) # delete from db db.session.delete(model) From 04b5de126522c960ddf6a266bb16c2b568c69098 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 09:10:02 +0100 Subject: [PATCH 18/46] Enhance rule visibility and styling in dashboard tables Added whitelist_rule_ids parameter to build_ip_tbody and build_rtbh_tbody macros to visually distinguish whitelist-created rules. Expired rules highlighted in table-warning. Whitelisted rules marked with table-success. Rules created by Whitelist marked with in table-secondary. Refactored index view to enrich rule data with whitelist info via enrich_rules_with_whitelist_info. Passed whitelist_rule_ids to dashboard rendering functions to ensure consistent rule styling. --- flowapp/templates/macros.html | 93 ++++++++++++++++++----------------- flowapp/views/dashboard.py | 62 +++++++++++++++++++++-- 2 files changed, 106 insertions(+), 49 deletions(-) diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 653cc81..01a2fca 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -1,5 +1,5 @@ -{% macro build_ip_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_ip_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None) %} {% for rule in rules %} {% if rule.next_header is defined %} @@ -8,7 +8,11 @@ {% set rtype_int = 4 %} {% endif %} - + {{ rule.source }}{% if rule.source_mask != none %}{{ '/' if rule.source_mask >= 0 else '' }}{{ rule.source_mask if rule.source_mask >= 0 else '' }}{% endif %} @@ -73,51 +77,52 @@ {% endmacro %} -{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_rtbh_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None) %} {% for rule in rules %} - - - {% if rule.ipv4 %} - {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} - {% endif %} - {% if rule.ipv6 %} - {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} - {% endif %} - - - {{ rule.community.name }} - - - - {{ rule.expires|strftime }} - - - {{ rule.user.name }} - - - {% if editable %} - - - - - - + + + {% if rule.ipv4 %} + {{ rule.ipv4 }}{{ '/' if rule.ipv4_mask else '' }}{{rule.ipv4_mask|default("", True)}} + {% endif %} + {% if rule.ipv6 %} + {{ rule.ipv6 }}{{ '/' if rule.ipv6_mask else '' }} {{rule.ipv6_mask|default("", True)}} + {% endif %} + + + {{ rule.community.name }} + + + + {{ rule.expires|strftime }} + + + {{ rule.user.name }} + + + {% if editable %} + + + + + + + {% endif %} + {% if rule.comment %} + + {% endif %} + + {% if editable and group_op %} + + + {% endif %} - {% if rule.comment %} - - {% endif %} - - {% if editable and group_op %} - - - - {% endif %} {% endfor %} diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index 8ed262f..3980b88 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -48,7 +48,7 @@ def whois(ip_address): def index(rtype=None, rstate="active"): """ dispatcher object for the dashboard - :param rtype: ipv4, ipv6, rtbh + :param rtype: ipv4, ipv6, rtbh, whitelist :param rstate: :return: view from view factory """ @@ -84,7 +84,6 @@ def index(rtype=None, rstate="active"): data_handler_module = current_app.config["DASHBOARD"].get(rtype).get("data_handler", models) data_handler_method = current_app.config["DASHBOARD"].get(rtype).get("data_handler_method", "get_ip_rules") - # get search query, sort order and sort key from request or session get_search_query = request.args.get(SEARCH_ARG, session.get(SEARCH_ARG, "")) get_sort_key = request.args.get(SORT_ARG, session.get(SORT_ARG, DEFAULT_SORT)) @@ -109,6 +108,10 @@ def index(rtype=None, rstate="active"): # get the handler and the data handler = getattr(data_handler_module, data_handler_method) rules = handler(rtype, rstate, get_sort_key, get_sort_order) + + # Enrich rules with whitelist information + rules, whitelist_rule_ids = enrich_rules_with_whitelist_info(rules, rtype) + session[RULES_KEY] = [rule.id for rule in rules] # search rules if get_search_query: @@ -138,6 +141,7 @@ def index(rtype=None, rstate="active"): macro_tbody=macro_tbody, macro_thead=macro_thead, macro_tfoot=macro_tfoot, + whitelist_rule_ids=whitelist_rule_ids, ) @@ -148,18 +152,22 @@ def create_dashboard_table_body( group_op=True, macro_file="macros.html", macro_name="build_ip_tbody", + whitelist_rule_ids=None, ): """ create the table body for the dashboard using a jinja2 macro :param rules: list of rules :param rtype: ipv4, ipv6, rtbh + :param editable: whether rules can be edited + :param group_op: whether group operations are allowed :param macro_file: the file where the macro is defined :param macro_name: the name of the macro + :param whitelist_rule_ids: set of rule IDs that were created by a whitelist """ tstring = "{% " tstring = tstring + f"from '{macro_file}' import {macro_name}" tstring = tstring + " %} {{" - tstring = tstring + f" {macro_name}(rules, today, editable, group_op) " + "}}" + tstring = tstring + f" {macro_name}(rules, today, editable, group_op, whitelist_rule_ids) " + "}}" dashboard_table_body = render_template_string( tstring, @@ -167,6 +175,7 @@ def create_dashboard_table_body( today=datetime.now(), editable=editable, group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids or set(), ) return dashboard_table_body @@ -246,6 +255,7 @@ def create_admin_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_group_buttons_tfoot", + whitelist_rule_ids=None, ): """ Admin can see and edit any rules @@ -257,7 +267,9 @@ def create_admin_response( :return: """ - dashboard_table_body = create_dashboard_table_body(rules, rtype, macro_file=macro_file, macro_name=macro_tbody) + dashboard_table_body = create_dashboard_table_body( + rules, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids + ) dashboard_table_head = create_dashboard_table_head( rules_columns=table_columns, @@ -313,6 +325,7 @@ def create_user_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, ): """ Filter out the rules for normal users @@ -343,9 +356,10 @@ def create_user_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_editable = create_dashboard_table_body( - rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody + rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids ) dashboard_table_editable_head = create_dashboard_table_head( rules_columns=table_columns, @@ -419,6 +433,7 @@ def create_view_response( macro_tbody="build_ip_tbody", macro_thead="build_rules_thead", macro_tfoot="build_rules_tfoot", + whitelist_rule_ids=None, ): """ Filter out the rules for normal users @@ -433,6 +448,7 @@ def create_view_response( group_op=False, macro_file=macro_file, macro_name=macro_tbody, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_head = create_dashboard_table_head( @@ -482,3 +498,39 @@ def filter_rules(rules, get_search_query): result.append(rules[idx]) return result + + +def enrich_rules_with_whitelist_info(rules, rule_type): + """ + Enrich rules with whitelist information from RuleWhitelistCache. + + Args: + rules: List of rule objects (Flowspec4, Flowspec6, RTBH) + rule_type: String identifier of rule type ("ipv4", "ipv6", "rtbh") + + Returns: + Tuple of (rules, whitelist_rule_ids) where whitelist_rule_ids is a set of + rule IDs that were created by a whitelist. + """ + from flowapp.models.rules.whitelist import RuleWhitelistCache + from flowapp.constants import RuleTypes, RuleOrigin + + # Map rule type string to enum value + rule_type_map = {"ipv4": RuleTypes.IPv4.value, "ipv6": RuleTypes.IPv6.value, "rtbh": RuleTypes.RTBH.value} + + # Get all rule IDs + rule_ids = [rule.id for rule in rules] + + # No rules to process + if not rule_ids: + return rules, set() + + # Query the cache for these rule IDs + cache_entries = RuleWhitelistCache.query.filter( + RuleWhitelistCache.rid.in_(rule_ids), RuleWhitelistCache.rtype == rule_type_map.get(rule_type) + ).all() + + # Create a set of rule IDs that were created by a whitelist + whitelist_rule_ids = {entry.rid for entry in cache_entries if entry.rorigin == RuleOrigin.WHITELIST.value} + + return rules, whitelist_rule_ids From ed594cf0a13e6e06512e83a3490c06b2329a10d6 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 12:05:30 +0100 Subject: [PATCH 19/46] Improve rule-whitelist interactions and UI updates Added count_by_rule method in RuleWhitelistCache to count cache entries per rule. Modified evaluate_rtbh_against_whitelists_check_results to process all whitelist matches instead of stopping at the first. Adjusted RTBH rule filtering to include both active (rstate_id=1) and whitelisted (rstate_id=4) rules when checking against whitelists. Improved whitelist deletion logic to only revert user-created rules to Active if they have no other whitelist references. Updated build_whitelist_tbody macro to accept whitelist_rule_ids for consistent rule highlighting. --- flowapp/models/rules/whitelist.py | 14 ++++++++++++++ flowapp/services/rule_service.py | 7 ++++--- flowapp/services/whitelist_service.py | 8 ++++---- flowapp/templates/macros.html | 2 +- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 6384bdd..29f2b08 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -143,6 +143,20 @@ def delete_by_rule_id(cls, rule_id: int): db.session.commit() return deleted + @classmethod + def count_by_rule(cls, rule_id: int, rule_type: RuleTypes): + """ + Count the number of cache entries for the given rule + + Args: + rule_id (int): The ID of the rule to count + rule_type (RuleTypes): The type of the rule + + Returns: + int: Number of cache entries + """ + return cls.query.filter_by(rid=rule_id, rtype=rule_type.value).count() + def __repr__(self): return f"" diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 1611df0..e0f1149 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -239,12 +239,15 @@ def evaluate_rtbh_against_whitelists_check_results( wl_cache: Dict[str, Whitelist], results: List[Tuple[str, str, Relation]], ) -> RTBH: + """ + Evaluate RTBH rule against whitelist check results. + Process all results for cases where rule is whitelisted by several whitelists. + """ for rule, whitelist_key, relation in results: match relation: case Relation.EQUAL: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") - break case Relation.SUBNET: parts = subtract_network(target=str(model), whitelist=whitelist_key) wl_id = wl_cache[whitelist_key].id @@ -257,11 +260,9 @@ def evaluate_rtbh_against_whitelists_check_results( model.rstate_id = 4 add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) db.session.commit() - break case Relation.SUPERNET: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") - break return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index d82fd29..09df687 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -65,9 +65,8 @@ def create_or_update_whitelist( # check RTBH rules against whitelist allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] - current_app.logger.info(f"allowed communities: {allowed_communities}") - # filter out RTBH rules that are not active and not in allowed communities - all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id == 1, RTBH.community_id.in_(allowed_communities)).all() + # filter out RTBH rules that are not active or whitelisted and not in allowed communities + all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id.in_([1, 4]), RTBH.community_id.in_(allowed_communities)).all() rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) result = check_whitelist_against_rules(rtbh_rules_map, str(model)) current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") @@ -133,6 +132,7 @@ def delete_whitelist(whitelist_id: int) -> List[str]: for cached_rule in cached_rules: current_app.logger.debug(f"Processing cached rule {cached_rule}") rule_model_type = RuleTypes(cached_rule.rtype) + cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) match rule_model_type: case RuleTypes.IPv4: rule_model = db.session.get(Flowspec4, cached_rule.rid) @@ -151,7 +151,7 @@ def delete_whitelist(whitelist_id: int) -> List[str]: f"RuleWhitelistCache Anomaly! Rule {rule_model} does not exist. Type {rule_model_type} RID {cached_rule.rid} ID {cached_rule.id} Skipping." ) - elif rorigin_type == RuleOrigin.USER: + elif rorigin_type == RuleOrigin.USER and cache_entries_count == 1: flashes.append(f"Set rule {rule_model} back to state 'Active'") try: rule_model.rstate_id = 1 # Set rule state to "Active" again diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 01a2fca..2f5dd9f 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -130,7 +130,7 @@ {% endmacro %} -{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True) %} +{% macro build_whitelist_tbody(rules, today, editable=True, group_op=True, whitelist_rule_ids=None) %} {% for rule in rules %} From 7ca0b03a874fccf435bd87054f29cdef7f92f2fa Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 12:10:24 +0100 Subject: [PATCH 20/46] fixed app config for tests --- flowapp/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index 65919af..5960434 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -67,6 +67,7 @@ def app(request): SECRET_KEY="testkeysession", LOCAL_USER_UUID="jiri.vrany@cesnet.cz", LOCAL_AUTH=True, + ALLOWED_COMMUNITIES=[1, 2, 3], ) print("\n----- CREATE FLASK APPLICATION\n") From 72fd0c95382052867a40b91c3ee1dd4af8ed2ae6 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 14 Mar 2025 13:19:15 +0100 Subject: [PATCH 21/46] Enhance IPv4/IPv6 checks by introducing _is_same_ip_version, update relevant functions to handle mixed IP versions gracefully, and add new test cases in test_whitelist_common.py to ensure proper coverage --- flowapp/services/whitelist_common.py | 136 +++++++++++++++++-------- flowapp/tests/test_whitelist_common.py | 118 +++++++++++++++++++-- 2 files changed, 206 insertions(+), 48 deletions(-) diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 1730402..065b79a 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -45,6 +45,19 @@ class Relation(Enum): DIFFERENT = auto() +def _is_same_ip_version(addr1: str, addr2: str) -> bool: + """ + Check if two IP addresses/networks are of the same IP version. + + :param addr1: First IP address or network string + :param addr2: Second IP address or network string + :return: True if both addresses are of the same IP version (IPv4 or IPv6), False otherwise + """ + is_ipv4_1 = "." in addr1 + is_ipv4_2 = "." in addr2 + return is_ipv4_1 == is_ipv4_2 + + @lru_cache(maxsize=1024) def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: """ @@ -58,22 +71,34 @@ def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: def check_whitelist_to_rule_relation(rule: str, whitelist_entry: str) -> Relation: """ - Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. + Checks if the whitelist network is a subnet or supernet or exactly the same as the rule network. Uses cached network objects for better performance. + If the rule and whitelist are different IP versions (IPv4 vs IPv6), + they are treated as different networks. + :param rule: The IP address or network to check (e.g., "192.168.1.1" or "192.168.1.0/24") :param whitelist_entry: The allowed network to compare against (e.g., "192.168.1.0/24") :return: Relation between the two networks """ - rule_net = get_network(rule) - whitelist_net = get_network(whitelist_entry) - if whitelist_net == rule_net: - return Relation.EQUAL - if whitelist_net.supernet_of(rule_net): - return Relation.SUPERNET - if whitelist_net.subnet_of(rule_net): - return Relation.SUBNET - else: + # First check if IP versions are the same + if not _is_same_ip_version(rule, whitelist_entry): + return Relation.DIFFERENT + + try: + rule_net = get_network(rule) + whitelist_net = get_network(whitelist_entry) + + if whitelist_net == rule_net: + return Relation.EQUAL + if whitelist_net.supernet_of(rule_net): + return Relation.SUPERNET + if whitelist_net.subnet_of(rule_net): + return Relation.SUBNET + else: + return Relation.DIFFERENT + except (ValueError, TypeError): + # Handle any other errors that might occur during comparison return Relation.DIFFERENT @@ -82,34 +107,45 @@ def subtract_network(target: str, whitelist: str) -> List[str]: Computes the remaining parts of a network after removing the whitelist subnet. Uses cached network objects for better performance. + If the target and whitelist are different IP versions (IPv4 vs IPv6), + the original target is returned unchanged. + :param target: The main network (e.g., "192.168.1.0/24") :param whitelist: The subnet to remove (e.g., "192.168.1.128/25") :return: A list of remaining subnets as strings """ - target_net = get_network(target) - whitelist_net = get_network(whitelist) + # First check if IP versions are the same + if not _is_same_ip_version(target, whitelist): + return [target] - # Check if the whitelist is actually a subnet - if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: - return [target] # Return the full network if whitelist isn't a valid subnet + try: + target_net = get_network(target) + whitelist_net = get_network(whitelist) - remaining = [] + # Check if the whitelist is actually a subnet + if check_whitelist_to_rule_relation(target, whitelist) != Relation.SUBNET: + return [target] # Return the full network if whitelist isn't a valid subnet - # Compute ranges before and after the whitelist - if whitelist_net.network_address > target_net.network_address: - # Before the whitelist - start = target_net.network_address - end = whitelist_net.network_address - 1 - remaining.extend(ipaddress.summarize_address_range(start, end)) + remaining = [] - if whitelist_net.broadcast_address < target_net.broadcast_address: - # After the whitelist - start = whitelist_net.broadcast_address + 1 - end = target_net.broadcast_address - remaining.extend(ipaddress.summarize_address_range(start, end)) + # Compute ranges before and after the whitelist + if whitelist_net.network_address > target_net.network_address: + # Before the whitelist + start = target_net.network_address + end = whitelist_net.network_address - 1 + remaining.extend(ipaddress.summarize_address_range(start, end)) - # Convert to string format - return [str(net) for net in remaining] + if whitelist_net.broadcast_address < target_net.broadcast_address: + # After the whitelist + start = whitelist_net.broadcast_address + 1 + end = target_net.broadcast_address + remaining.extend(ipaddress.summarize_address_range(start, end)) + + # Convert to string format + return [str(net) for net in remaining] + except (ValueError, TypeError): + # Return the original target in case of any error + return [target] def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tuple]: @@ -123,12 +159,21 @@ def check_rule_against_whitelists(rule: str, whitelists: List[str]) -> List[Tupl :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT """ # Pre-cache the rule network since it will be used multiple times - get_network(rule) + try: + get_network(rule) + except (ValueError, TypeError): + # Return empty list if rule is not a valid network + return [] + items = [] for whitelist in whitelists: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this whitelist if there's an error comparing + continue return items @@ -138,17 +183,26 @@ def check_whitelist_against_rules(rules: List[str], whitelist: str) -> List[Tupl Creates a cached rule network object for better performance. Reduces list of rules, where the Relation is not DIFFERENT - :param rule: The IP address or network to check against - :param whitelists: List of whitelist networks to check - :return: tuple of rule, whitelist and relation for each whitelists that is not DIFFERENT + :param rules: List of rule networks to check against + :param whitelist: The whitelist network to check + :return: tuple of rule, whitelist and relation for each rules that is not DIFFERENT """ - # Pre-cache the rule network since it will be used multiple times - get_network(whitelist) + # Pre-cache the whitelist network since it will be used multiple times + try: + get_network(whitelist) + except (ValueError, TypeError): + # Return empty list if whitelist is not a valid network + return [] + items = [] for rule in rules: - rel = check_whitelist_to_rule_relation(rule, whitelist) - if rel != Relation.DIFFERENT: - items.append((rule, whitelist, rel)) + try: + rel = check_whitelist_to_rule_relation(rule, whitelist) + if rel != Relation.DIFFERENT: + items.append((rule, whitelist, rel)) + except (ValueError, TypeError): + # Skip this rule if there's an error comparing + continue return items diff --git a/flowapp/tests/test_whitelist_common.py b/flowapp/tests/test_whitelist_common.py index 9b98e9a..a843aff 100644 --- a/flowapp/tests/test_whitelist_common.py +++ b/flowapp/tests/test_whitelist_common.py @@ -6,6 +6,7 @@ check_whitelist_to_rule_relation, subtract_network, clear_network_cache, + _is_same_ip_version, # New helper function ) @@ -130,16 +131,119 @@ def test_single_ip_as_network(): assert check_whitelist_to_rule_relation("2001:db8::1/128", "2001:db8::/32") == Relation.SUPERNET -def test_invalid_input(): +def test_is_same_ip_version(): + """Test the new helper function to check if two addresses are of the same IP version""" + # IPv4 cases + assert _is_same_ip_version("192.168.1.0/24", "10.0.0.0/8") is True + assert _is_same_ip_version("192.168.1.1", "10.0.0.1") is True + + # IPv6 cases + assert _is_same_ip_version("2001:db8::/32", "fe80::/64") is True + assert _is_same_ip_version("2001:db8::1", "fe80::1") is True + + # Mixed cases + assert _is_same_ip_version("192.168.1.0/24", "2001:db8::/32") is False + assert _is_same_ip_version("192.168.1.1", "fe80::1") is False + + +def test_check_whitelist_to_rule_relation_mixed_versions(): + """Test the check_whitelist_to_rule_relation function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule with IPv6 whitelist + assert check_whitelist_to_rule_relation("192.168.1.0/24", "2001:db8::/32") == Relation.DIFFERENT + + # IPv6 rule with IPv4 whitelist + assert check_whitelist_to_rule_relation("2001:db8::/32", "192.168.1.0/24") == Relation.DIFFERENT + + # Invalid IP addresses should return DIFFERENT + assert check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") == Relation.DIFFERENT + assert check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") == Relation.DIFFERENT + + +def test_subtract_network_mixed_versions(): + """Test the subtract_network function with mixed IP versions""" + clear_network_cache() + + # IPv4 target with IPv6 whitelist - should return original target + result = subtract_network("192.168.1.0/24", "2001:db8::/32") + assert result == ["192.168.1.0/24"] + + # IPv6 target with IPv4 whitelist - should return original target + result = subtract_network("2001:db8::/32", "192.168.1.0/24") + assert result == ["2001:db8::/32"] + + # Invalid addresses - should return original target + result = subtract_network("invalid", "192.168.1.0/24") + assert result == ["invalid"] + result = subtract_network("192.168.1.0/24", "invalid") + assert result == ["192.168.1.0/24"] + + +def test_check_rule_against_whitelists_mixed_versions(): + """Test the check_rule_against_whitelists function with mixed IP versions""" + clear_network_cache() + + # IPv4 rule against mixed whitelist + rule = "192.168.1.0/24" + whitelists = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "192.168.0.0/16" + assert results[0][2] == Relation.SUPERNET + + # IPv6 rule against mixed whitelist + rule = "2001:db8:1::/48" + whitelists = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_rule_against_whitelists(rule, whitelists) + assert len(results) == 1 + assert results[0][0] == rule + assert results[0][1] == "2001:db8::/32" + assert results[0][2] == Relation.SUPERNET + + +def test_check_whitelist_against_rules_mixed_versions(): + """Test the check_whitelist_against_rules function with mixed IP versions""" clear_network_cache() - with pytest.raises(ValueError): - check_whitelist_to_rule_relation("invalid", "192.168.1.0/24") - with pytest.raises(ValueError): - check_whitelist_to_rule_relation("192.168.1.0/24", "invalid") + # IPv4 whitelist against mixed rules + whitelist = "192.168.1.0/24" + rules = [ + "192.168.0.0/16", # IPv4 - should match + "2001:db8::/32", # IPv6 - should not match + "10.0.0.0/8", # IPv4 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "192.168.0.0/16" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET - with pytest.raises(ValueError): - subtract_network("invalid", "192.168.1.0/24") + # IPv6 whitelist against mixed rules + whitelist = "2001:db8:1::/48" + rules = [ + "192.168.0.0/16", # IPv4 - should not match + "2001:db8::/32", # IPv6 - should match + "2002:db8::/32", # IPv6 - should not match (different network) + ] + + results = check_whitelist_against_rules(rules, whitelist) + assert len(results) == 1 + assert results[0][0] == "2001:db8::/32" + assert results[0][1] == whitelist + assert results[0][2] == Relation.SUBNET if __name__ == "__main__": From c9969c603116ba7f095fa9043cc0641894a609e8 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 17 Mar 2025 12:42:43 +0100 Subject: [PATCH 22/46] improved logging of whitelist service --- flowapp/services/whitelist_service.py | 38 +++++++++++++++++++-------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 09df687..b749b4b 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -88,17 +88,21 @@ def evaluate_whitelist_against_rtbh_check_results( case Relation.EQUAL: whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) withdraw_rtbh_route(rtbh_rule_cache[rule_key]) - flashes.append(f"Active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is whitelisted.") + msg = "Existing active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is now whitelisted." + flashes.append(msg) + current_app.logger.info(msg) case Relation.SUBNET: parts = subtract_network(target=rule_key, whitelist=whitelist_key) wl_id = whitelist_model.id - flashes.append( - f" Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." - ) + msg = f"Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules will be created." + flashes.append(msg) + current_app.logger.info(msg) for network in parts: rule_model = rtbh_rule_cache[rule_key] create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) - flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + msg = f"Created RTBH rule from {rule_model.id} {network} parted by whitelist {whitelist_key}." + flashes.append(msg) + current_app.logger.info(msg) rule_model.rstate_id = 4 add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) db.session.commit() @@ -106,7 +110,11 @@ def evaluate_whitelist_against_rtbh_check_results( whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) withdraw_rtbh_route(rtbh_rule_cache[rule_key]) - flashes.append(f"Active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is whitelisted.") + msg = ( + f"Existing active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is now whitelisted." + ) + current_app.logger.info(msg) + flashes.append(msg) return whitelist_model @@ -127,10 +135,9 @@ def delete_whitelist(whitelist_id: int) -> List[str]: if model: cached_rules = RuleWhitelistCache.get_by_whitelist_id(whitelist_id) current_app.logger.info( - f"Deleting whitelist {whitelist_id}. Found {len(cached_rules)} cached rules to process." + f"Deleting whitelist {whitelist_id} {model}. Found {len(cached_rules)} cached rules to process." ) for cached_rule in cached_rules: - current_app.logger.debug(f"Processing cached rule {cached_rule}") rule_model_type = RuleTypes(cached_rule.rtype) cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) match rule_model_type: @@ -143,7 +150,9 @@ def delete_whitelist(whitelist_id: int) -> List[str]: rorigin_type = RuleOrigin(cached_rule.rorigin) current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") if rorigin_type == RuleOrigin.WHITELIST: - flashes.append(f"Deleted rule {rule_model} created by this whitelist") + msg = f"Deleted {rule_model_type} rule {rule_model} created by whitelist {model}" + current_app.logger.info(msg) + flashes.append(msg) try: db.session.delete(rule_model) except sqlalchemy.orm.exc.UnmappedInstanceError: @@ -152,7 +161,9 @@ def delete_whitelist(whitelist_id: int) -> List[str]: ) elif rorigin_type == RuleOrigin.USER and cache_entries_count == 1: - flashes.append(f"Set rule {rule_model} back to state 'Active'") + msg = f"Set rule {rule_model} back to state 'Active'" + current_app.logger.info(msg) + flashes.append(msg) try: rule_model.rstate_id = 1 # Set rule state to "Active" again except AttributeError: @@ -163,10 +174,15 @@ def delete_whitelist(whitelist_id: int) -> List[str]: author = f"{model.user.email} ({model.user.organization})" announce_rtbh_route(rule_model, author) - flashes.append(f"Deleted cache entries for whitelist {whitelist_id}") + msg = f"Deleted cache entries for whitelist {whitelist_id} {model}" + current_app.logger.info(msg) + flashes.append(msg) RuleWhitelistCache.clean_by_whitelist_id(whitelist_id) db.session.delete(model) db.session.commit() + msg = f"Deleted whitelist {whitelist_id} {model}" + flashes.append(msg) + current_app.logger.info(msg) return flashes From 412dd845fcbd4d831fd6bed5a28b74c28f3845d5 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 17 Mar 2025 19:49:11 +0100 Subject: [PATCH 23/46] Refactor create_app config loading, enhance RTBH logging and flash messages, add file-based logging in app_factory, and update tests to assert flashes instead of specific messages. --- flowapp/__init__.py | 14 +++++++------- flowapp/services/rule_service.py | 24 +++++++++++++++++------- flowapp/services/whitelist_common.py | 4 +++- flowapp/tests/test_rule_service.py | 7 +++---- flowapp/tests/test_whitelist_service.py | 6 +++--- flowapp/utils/app_factory.py | 20 ++++++++++++++++---- withdraw_expired | 1 + 7 files changed, 50 insertions(+), 26 deletions(-) create mode 100644 withdraw_expired diff --git a/flowapp/__init__.py b/flowapp/__init__.py index b201f25..61543e7 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -25,6 +25,13 @@ def create_app(config_object=None): app = Flask(__name__) + # Load the default configuration for dashboard and main menu + app.config.from_object(InstanceConfig) + if config_object: + app.config.from_object(config_object) + + app.config.setdefault("VERSION", __version__) + # SSO configuration SSO_ATTRIBUTE_MAP = { "eppn": (True, "eppn"), @@ -36,13 +43,6 @@ def create_app(config_object=None): migrate.init_app(app, db) csrf.init_app(app) - # Load the default configuration for dashboard and main menu - app.config.from_object(InstanceConfig) - if config_object: - app.config.from_object(config_object) - - app.config.setdefault("VERSION", __version__) - # Init SSO ext.init_app(app) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index e0f1149..6b8896c 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -247,22 +247,32 @@ def evaluate_rtbh_against_whitelists_check_results( match relation: case Relation.EQUAL: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) - flashes.append(f" Rule is equal to active whitelist {whitelist_key}. Rule is whitelisted.") + msg = f"RTBH Rule {model.id} {model} is equal to active whitelist {whitelist_key}. Rule is whitelisted." + flashes.append(msg) + current_app.logger.info(msg) case Relation.SUBNET: parts = subtract_network(target=str(model), whitelist=whitelist_key) wl_id = wl_cache[whitelist_key].id - flashes.append( - f" Rule is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." - ) + msg = f"RTBH Rule {model.id} {model} is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + flashes.append(msg) + current_app.logger.info(msg) for network in parts: - create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) - flashes.append(f"DEBUG: Created RTBH rule for {network}, from whitelist {whitelist_key}") + new_rule = create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + msg = ( + f"Created RTBH rule {new_rule.id} {new_rule} for {network} parted by whitelist {whitelist_key}" + ) + flashes.append(msg) + current_app.logger.info(msg) model.rstate_id = 4 add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) db.session.commit() case Relation.SUPERNET: model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) - flashes.append(f" Rule is subnet of active whitelist {whitelist_key}. Rule is whitelisted.") + msg = ( + f"RTBH Rule {model.id} {model} is subnet of active whitelist {whitelist_key}. Rule is whitelisted." + ) + current_app.logger.info(msg) + flashes.append(msg) return model diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 065b79a..3807dca 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -216,7 +216,7 @@ def clear_network_cache() -> None: def create_rtbh_from_whitelist_parts( model: RTBH, wl_id: int, whitelist_key: str, network: str, rule_owner: str = "", user_id: int = 0 -) -> None: +) -> RTBH: # default values from model rule_owner = rule_owner or model.get_author() user_id = user_id or model.user_id @@ -240,3 +240,5 @@ def create_rtbh_from_whitelist_parts( add_rtbh_rule_to_cache(new_model, wl_id, RuleOrigin.WHITELIST) announce_rtbh_route(new_model, rule_owner) log_route(user_id, model, RuleTypes.RTBH, rule_owner) + + return new_model diff --git a/flowapp/tests/test_rule_service.py b/flowapp/tests/test_rule_service.py index 16daf10..490d2c2 100644 --- a/flowapp/tests/test_rule_service.py +++ b/flowapp/tests/test_rule_service.py @@ -444,7 +444,7 @@ def test_equal_relation(self, app, whitelist_fixture): mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) # Verify the flash message - assert any("Rule is equal to active whitelist" in msg for msg in flashes) + assert flashes # Verify the correct model was returned assert result == model @@ -490,8 +490,7 @@ def test_subnet_relation(self, app, whitelist_fixture): mock_commit.assert_called_once() # Verify the flash messages - assert any("Rule is supernet of active whitelist" in msg for msg in flashes) - assert any("Created RTBH rule for" in msg for msg in flashes) + assert flashes # Verify model was updated to whitelisted state assert model.rstate_id == 4 @@ -522,7 +521,7 @@ def test_supernet_relation(self, app, whitelist_fixture): mock_whitelist_rule.assert_called_once_with(model, whitelist_fixture) # Verify the flash message - assert any("Rule is subnet of active whitelist" in msg for msg in flashes) + assert flashes # Verify the correct model was returned assert result == model diff --git a/flowapp/tests/test_whitelist_service.py b/flowapp/tests/test_whitelist_service.py index 519c2aa..460a172 100644 --- a/flowapp/tests/test_whitelist_service.py +++ b/flowapp/tests/test_whitelist_service.py @@ -253,7 +253,7 @@ def test_delete_whitelist_with_user_rules(self, mock_announce, app, db): # Verify flash messages assert isinstance(flashes, list) - assert any("Set rule" in msg for msg in flashes) + assert flashes # Verify the whitelist was deleted assert db.session.get(Whitelist, whitelist.id) is None @@ -310,7 +310,7 @@ def test_delete_whitelist_with_whitelist_created_rules(self, mock_clean, app, db # Verify flash messages assert isinstance(flashes, list) - assert any("Deleted rule" in msg for msg in flashes) + assert flashes # Verify the rule was deleted assert db.session.get(RTBH, rtbh_rule.id) is None @@ -366,7 +366,7 @@ def test_equal_relation(self, app): mock_withdraw.assert_called_once_with(rtbh_rule) # Verify the flash message - assert any("equal to whitelist" in msg for msg in flashes) + assert flashes # Verify the correct model was returned assert result == whitelist_model diff --git a/flowapp/utils/app_factory.py b/flowapp/utils/app_factory.py index bc6559b..234522f 100644 --- a/flowapp/utils/app_factory.py +++ b/flowapp/utils/app_factory.py @@ -40,17 +40,29 @@ def configure_logging(app): for handler in app.logger.handlers[:]: app.logger.removeHandler(handler) + # Retrieve log level and file name from config + log_level = app.config.get("LOG_LEVEL", "DEBUG").upper() + log_file = app.config.get("LOG_FILE", "app.log") + # Define log format log_format = "%(asctime)s | %(levelname)s | %(message)s" log_datefmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter(log_format, datefmt=log_datefmt) - # Create a new handler with the desired format + # Console handler console_handler = logging.StreamHandler() - console_handler.setFormatter(logging.Formatter(log_format, datefmt=log_datefmt)) + console_handler.setFormatter(formatter) + + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + + # Set logger level + app.logger.setLevel(getattr(logging, log_level, logging.DEBUG)) - # Set logger level and attach the handler - app.logger.setLevel(logging.DEBUG) + # Attach handlers app.logger.addHandler(console_handler) + app.logger.addHandler(file_handler) return app diff --git a/withdraw_expired b/withdraw_expired new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/withdraw_expired @@ -0,0 +1 @@ + \ No newline at end of file From eb4c683f5612f1066a4c92bf938c91d99d778c6d Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 18 Mar 2025 10:23:00 +0100 Subject: [PATCH 24/46] integration test for RTBH api and whitelist, minor bug fix in utils --- flowapp/models/utils.py | 4 + flowapp/tests/conftest.py | 24 ++ flowapp/tests/test_api_v3.py | 27 ++ .../tests/test_api_whitelist_integration.py | 273 ++++++++++++++++++ flowapp/tests/test_forms.py | 7 - flowapp/tests/test_forms_cl.py | 8 - flowapp/views/api_common.py | 1 - 7 files changed, 328 insertions(+), 16 deletions(-) create mode 100644 flowapp/tests/test_api_whitelist_integration.py diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index 7472d72..d0bf120 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -40,6 +40,8 @@ def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() return count >= org.limit_rtbh or rtbh >= rtbh_limit + return False + def check_global_rule_limit(rule_type: RuleTypes) -> bool: flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) @@ -58,6 +60,8 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: if rule_type == RuleTypes.RTBH: return rtbh >= rtbh_limit + return False + def get_whitelist_model_if_exists(form_data): """ diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index 5960434..ae01c6d 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -12,6 +12,7 @@ from flowapp import db as _db from datetime import datetime import flowapp.models +from flowapp.models.organization import Organization TESTDB = "test_project.db" @@ -68,6 +69,7 @@ def app(request): LOCAL_USER_UUID="jiri.vrany@cesnet.cz", LOCAL_AUTH=True, ALLOWED_COMMUNITIES=[1, 2, 3], + WTF_CSRF_ENABLED=False, ) print("\n----- CREATE FLASK APPLICATION\n") @@ -115,6 +117,12 @@ def db(app, request): print("#: inserting users") flowapp.models.insert_users(users) + org = _db.session.query(Organization).filter_by(id=1).first() + # Update the organization address range to include our test networks + org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" + _db.session.commit() + print("\n----- UPDATED TEST ORG 1 \n", org) + def teardown(): _db.session.commit() _db.drop_all() @@ -186,3 +194,19 @@ def auth_client(client): print("\n----- CREATE AUTHENTICATED FLASK TEST CLIENT\n") client.get("/local-login") return client + + +@pytest.fixture(autouse=True) +def reset_org_limits(db, app): + """ + Automatically reset organization limits after each test that modifies them. + """ + yield # Allow test execution + + with app.app_context(): + org = db.session.query(Organization).filter_by(id=1).first() + if org: + org.limit_flowspec4 = 0 + org.limit_flowspec6 = 0 + org.limit_rtbh = 0 + db.session.commit() diff --git a/flowapp/tests/test_api_v3.py b/flowapp/tests/test_api_v3.py index 96b2478..87f7094 100644 --- a/flowapp/tests/test_api_v3.py +++ b/flowapp/tests/test_api_v3.py @@ -1,10 +1,37 @@ import json + from flowapp.models import Flowspec4, Organization V_PREFIX = "/api/v3" +def test_create_rtbh_only(client, app, db, jwt_token): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + + # Now create the RTBH rule via API + res = client.post( + f"{V_PREFIX}/rules/rtbh", + headers={"x-access-token": jwt_token}, + json={ + "community": 1, + "ipv4": "147.230.17.17", + "ipv4_mask": 32, + "expires": "10/25/2050 14:46", + }, + ) + + # Verify response is successful + assert res.status_code == 201 + data = json.loads(res.data) + assert data["rule"] is not None + + def test_token(client, jwt_token): """ test that token authorization works diff --git a/flowapp/tests/test_api_whitelist_integration.py b/flowapp/tests/test_api_whitelist_integration.py new file mode 100644 index 0000000..0391a76 --- /dev/null +++ b/flowapp/tests/test_api_whitelist_integration.py @@ -0,0 +1,273 @@ +""" +Integration test for the API view when interacting with whitelists. + +This test suite verifies the behavior when creating RTBH rules through the API +when there are existing whitelists that could affect the rules. +""" + +import json +import pytest +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes, RuleOrigin +from flowapp.models import RTBH, RuleWhitelistCache, Organization +from flowapp.services import whitelist_service + + +@pytest.fixture +def whitelist_data(): + """Create whitelist data for testing""" + return { + "ip": "192.168.1.0", + "mask": 24, + "comment": "Test whitelist for API integration test", + "expires": datetime.now() + timedelta(days=1), + } + + +@pytest.fixture +def rtbh_api_payload(): + """Create RTBH rule data for API testing""" + return { + "community": 1, + "ipv4": "192.168.1.0", + "ipv4_mask": 24, + "expires": (datetime.now() + timedelta(days=1)).strftime("%m/%d/%Y %H:%M"), + "comment": "Test RTBH rule via API", + } + + +def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that exactly matches an existing whitelist. + + The rule should be created but marked as whitelisted (rstate_id=4), and a cache entry + should be created linking the rule to the whitelist. + """ + # First, configure app to include community ID 1 in allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # Create the whitelist directly using the service + with app.app_context(): + # Create user and organization if needed for the whitelist + org = db.session.query(Organization).first() + + # Create the whitelist + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, + user_id=1, # Using user ID 1 from test fixtures + org_id=org.id, + user_email="test@example.com", + org_name=org.name, + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + assert whitelist_model.ip == whitelist_data["ip"] + assert whitelist_model.mask == whitelist_data["mask"] + + # Now create the RTBH rule via API + response = client.post( + "/api/v3/rules/rtbh", + headers={"x-access-token": jwt_token}, + json=rtbh_api_payload, + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + +def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is a supernet of an existing whitelist. + + The rule should be created with whitelisted state (rstate_id=4) and smaller subnet rules + should be created for the non-whitelisted parts. + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a subnet + whitelist_data["ip"] = "192.168.1.128" + whitelist_data["mask"] = 25 # Subnet of the RTBH rule which will be /24 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that covers both the whitelist and additional space + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", + headers=headers, + json=rtbh_api_payload, # This is a /24, which contains the /25 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created and marked as whitelisted + with app.app_context(): + # Check the original rule + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Check if a new subnet rule was created for the non-whitelisted part + subnet_rule = ( + db.session.query(RTBH) + .filter( + RTBH.ipv4 == "192.168.1.0", + RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist + ) + .first() + ) + + assert subnet_rule is not None + assert subnet_rule.rstate_id == 1 # Active status + + # Verify cache entries + # Main rule should be cached as a USER rule + user_cache = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value + ).first() + assert user_cache is not None + + # Subnet rule should be cached as a WHITELIST rule + whitelist_cache = RuleWhitelistCache.query.filter_by( + rid=subnet_rule.id, + rtype=RuleTypes.RTBH.value, + whitelist_id=whitelist_model.id, + rorigin=RuleOrigin.WHITELIST.value, + ).first() + assert whitelist_cache is not None + + +def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that is contained within an existing whitelist. + + The rule should be created but immediately marked as whitelisted (rstate_id=4). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a supernet + whitelist_data["ip"] = "192.168.0.0" + whitelist_data["mask"] = 16 # Supernet that contains the RTBH rule + + # Create the whitelist directly using the service + with app.app_context(): + all_rtbh_rules_before = db.session.query(RTBH).count() + + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API that is inside the whitelist + headers = {"x-access-token": jwt_token} + response = client.post( + "/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload # This is a /24 inside the /16 whitelist + ) + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created but marked as whitelisted + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state + + # Verify a cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id + ).first() + + assert cache_entry is not None + assert cache_entry.rorigin == RuleOrigin.USER.value + + # Verify no additional rules were created + all_rtbh_rules = db.session.query(RTBH).count() + assert all_rtbh_rules - all_rtbh_rules_before == 1 + + +def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitelist_data, rtbh_api_payload): + """ + Test creating an RTBH rule through API that has no relation to any existing whitelist. + + The rule should be created normally with active state (rstate_id=1). + """ + # Configure app with allowed communities + app.config.update({"ALLOWED_COMMUNITIES": [1, 2, 3]}) + + # First create a whitelist for a completely different network + whitelist_data["ip"] = "10.0.0.0" + whitelist_data["mask"] = 8 + + # Create the whitelist directly using the service + with app.app_context(): + org = db.session.query(Organization).first() + whitelist_model, _ = whitelist_service.create_or_update_whitelist( + form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name + ) + + # Verify whitelist was created + assert whitelist_model.id is not None + + # Now create the RTBH rule via API for a network not covered by any of the whitelists + rtbh_api_payload["ipv4"] = "147.230.17.185" + rtbh_api_payload["ipv4_mask"] = 32 + + headers = {"x-access-token": jwt_token} + response = client.post("/api/v3/rules/rtbh", headers=headers, json=rtbh_api_payload) # This is 192.168.1.0/24 + + # Verify response is successful + assert response.status_code == 201 + data = json.loads(response.data) + assert data["rule"] is not None + rule_id = data["rule"]["id"] + + # Now verify the rule was created with active state + with app.app_context(): + rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + assert rtbh_rule is not None + assert rtbh_rule.rstate_id == 1 # 1 = active state + + # Verify no cache entry was created + cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() + + assert cache_entry is None diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py index 76c9896..a2092ed 100644 --- a/flowapp/tests/test_forms.py +++ b/flowapp/tests/test_forms.py @@ -3,13 +3,6 @@ import flowapp.forms -@pytest.fixture() -def app(): - app = Flask(__name__) - app.secret_key = "test" - return app - - @pytest.fixture() def ip_form(app, field_class): with app.test_request_context(): # Push the request context diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py index d5373ac..c962665 100644 --- a/flowapp/tests/test_forms_cl.py +++ b/flowapp/tests/test_forms_cl.py @@ -29,14 +29,6 @@ def create_form_data(data): return MultiDict(processed_data) -@pytest.fixture() -def app(): - """Create Flask app with CSRF disabled for testing""" - app = Flask(__name__) - app.config.update(SECRET_KEY="test_secret", WTF_CSRF_ENABLED=False, TESTING=True) - return app - - @pytest.fixture def valid_datetime(): return (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%dT%H:%M") diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 39b7077..4fad70e 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -305,7 +305,6 @@ def create_rtbh(current_user): count = db.session.query(RTBH).filter_by(rstate_id=1).count() return global_limit_reached(count=count, rule_type=RuleTypes.RTBH) - # check limit if check_rule_limit(current_user["org_id"], RuleTypes.RTBH): count = db.session.query(RTBH).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() return limit_reached(count=count, rule_type=RuleTypes.RTBH, org_id=current_user["org_id"]) From 354cb5317dbb4d7b4fa2c4d22f3067267335de51 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 18 Mar 2025 12:46:02 +0100 Subject: [PATCH 25/46] =?UTF-8?q?=1B[200~Refactor=20route=20announcement?= =?UTF-8?q?=20and=20whitelist=20expiration=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move announce_all_routes to flowapp.services.base for better modularity. - Add delete_expired_whitelists function in whitelist_service.py to remove expired whitelist entries. - Update withdraw_expired route to call delete_expired_whitelists before withdrawing expired routes. - Import cleanup: remove unused Flask imports in test_forms.py and test_forms_cl.py. - Optimize rules.py by importing and using announce_all_routes and delete_expired_whitelists directly. --- flowapp/services/__init__.py | 6 ++- flowapp/services/base.py | 65 ++++++++++++++++++++++++++- flowapp/services/whitelist_service.py | 14 ++++++ flowapp/tests/test_forms.py | 1 - flowapp/tests/test_forms_cl.py | 1 - flowapp/views/rules.py | 65 ++------------------------- 6 files changed, 85 insertions(+), 67 deletions(-) diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index 05b9695..456eb43 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -4,7 +4,9 @@ create_or_update_rtbh_rule, ) -from .whitelist_service import create_or_update_whitelist, delete_whitelist +from .whitelist_service import create_or_update_whitelist, delete_whitelist, delete_expired_whitelists + +from .base import announce_all_routes __all__ = [ create_or_update_ipv4_rule, @@ -12,4 +14,6 @@ create_or_update_rtbh_rule, create_or_update_whitelist, delete_whitelist, + delete_expired_whitelists, + announce_all_routes, ] diff --git a/flowapp/services/base.py b/flowapp/services/base.py index f06f16f..77adeb0 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,6 +1,8 @@ -from flowapp import messages +from datetime import datetime +from operator import ge, lt +from flowapp import constants, db, messages from flowapp.constants import ANNOUNCE, WITHDRAW -from flowapp.models import RTBH +from flowapp.models import RTBH, Flowspec4, Flowspec6 from flowapp.output import Route, RouteSources, announce_route @@ -30,3 +32,62 @@ def withdraw_rtbh_route(model: RTBH) -> None: command=command, ) announce_route(route) + + +def announce_all_routes(action=constants.ANNOUNCE): + """ + get routes from db and send it to ExaBGB api + + @TODO take the request away, use some kind of messaging (maybe celery?) + :param action: action with routes - announce valid routes or withdraw expired routes + """ + today = datetime.now() + comp_func = ge if action == constants.ANNOUNCE else lt + + rules4 = ( + db.session.query(Flowspec4) + .filter(Flowspec4.rstate_id == 1) + .filter(comp_func(Flowspec4.expires, today)) + .order_by(Flowspec4.expires.desc()) + .all() + ) + rules6 = ( + db.session.query(Flowspec6) + .filter(Flowspec6.rstate_id == 1) + .filter(comp_func(Flowspec6.expires, today)) + .order_by(Flowspec6.expires.desc()) + .all() + ) + rules_rtbh = ( + db.session.query(RTBH) + .filter(RTBH.rstate_id == 1) + .filter(comp_func(RTBH.expires, today)) + .order_by(RTBH.expires.desc()) + .all() + ) + + messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] + messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] + messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] + + messages_all = [] + messages_all.extend(messages_v4) + messages_all.extend(messages_v6) + messages_all.extend(messages_rtbh) + + author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" + + for command in messages_all: + route = Route( + author=f"System call / {author_action} rules", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + if action == constants.WITHDRAW: + for ruleset in [rules4, rules6, rules_rtbh]: + for rule in ruleset: + rule.rstate_id = 2 + + db.session.commit() diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index b749b4b..a934cc8 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -123,6 +123,20 @@ def map_rtbh_rules_to_strings(all_rtbh_rules: List[RTBH]) -> Dict[str, RTBH]: return {str(rule): rule for rule in all_rtbh_rules} +def delete_expired_whitelists() -> List[str]: + """ + Delete all expired whitelist entries from the database. + + Returns: + List of messages for the user + """ + expired_whitelists = Whitelist.query.filter(Whitelist.expires < db.func.now()).all() + flashes = [] + for model in expired_whitelists: + flashes.extend(delete_whitelist(model.id)) + return flashes + + def delete_whitelist(whitelist_id: int) -> List[str]: """ Delete a whitelist entry from the database. diff --git a/flowapp/tests/test_forms.py b/flowapp/tests/test_forms.py index a2092ed..e9620d8 100644 --- a/flowapp/tests/test_forms.py +++ b/flowapp/tests/test_forms.py @@ -1,5 +1,4 @@ import pytest -from flask import Flask import flowapp.forms diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py index c962665..c6b42bc 100644 --- a/flowapp/tests/test_forms_cl.py +++ b/flowapp/tests/test_forms_cl.py @@ -1,7 +1,6 @@ import pytest from datetime import datetime, timedelta from werkzeug.datastructures import MultiDict -from flask import Flask from flowapp.forms import ( UserForm, BulkUserForm, diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 4cad6d2..2ac5981 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -1,11 +1,10 @@ # flowapp/views/admin.py from datetime import datetime, timedelta -from operator import ge, lt from collections import namedtuple from flask import Blueprint, current_app, flash, redirect, render_template, request, session, url_for -from flowapp import constants, db, messages +from flowapp import constants, db from flowapp.auth import ( admin_required, auth_required, @@ -31,7 +30,7 @@ ) from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route -from flowapp.services import rule_service +from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists from flowapp.utils import ( flash_errors, get_state_by_time, @@ -680,64 +679,6 @@ def announce_all(): @rules.route("/withdraw_expired", methods=["GET"]) @localhost_only def withdraw_expired(): + delete_expired_whitelists() announce_all_routes(constants.WITHDRAW) return " " - - -def announce_all_routes(action=constants.ANNOUNCE): - """ - get routes from db and send it to ExaBGB api - - @TODO take the request away, use some kind of messaging (maybe celery?) - :param action: action with routes - announce valid routes or withdraw expired routes - """ - today = datetime.now() - comp_func = ge if action == constants.ANNOUNCE else lt - - rules4 = ( - db.session.query(Flowspec4) - .filter(Flowspec4.rstate_id == 1) - .filter(comp_func(Flowspec4.expires, today)) - .order_by(Flowspec4.expires.desc()) - .all() - ) - rules6 = ( - db.session.query(Flowspec6) - .filter(Flowspec6.rstate_id == 1) - .filter(comp_func(Flowspec6.expires, today)) - .order_by(Flowspec6.expires.desc()) - .all() - ) - rules_rtbh = ( - db.session.query(RTBH) - .filter(RTBH.rstate_id == 1) - .filter(comp_func(RTBH.expires, today)) - .order_by(RTBH.expires.desc()) - .all() - ) - - messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] - messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] - messages_rtbh = [messages.create_rtbh(rule, action) for rule in rules_rtbh] - - messages_all = [] - messages_all.extend(messages_v4) - messages_all.extend(messages_v6) - messages_all.extend(messages_rtbh) - - author_action = "announce all" if action == constants.ANNOUNCE else "withdraw all expired" - - for command in messages_all: - route = Route( - author=f"System call / {author_action} rules", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - - if action == constants.WITHDRAW: - for ruleset in [rules4, rules6, rules_rtbh]: - for rule in ruleset: - rule.rstate_id = 2 - - db.session.commit() From 910309d80e67abc65b60d588375895a1950d3365 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 20 Mar 2025 12:21:22 +0100 Subject: [PATCH 26/46] # Add log cleanup functionality and enhance logging - Add method to Log model to purge logs older than 30 days - Implement string representation methods for Log objects - Add application logging for log entries in output.py - Integrate log cleanup into withdraw_expired endpoint --- flowapp/models/log.py | 18 ++++++++++++++++++ flowapp/output.py | 2 ++ flowapp/views/rules.py | 8 ++++++++ 3 files changed, 28 insertions(+) diff --git a/flowapp/models/log.py b/flowapp/models/log.py index e32bb30..16cf5b1 100644 --- a/flowapp/models/log.py +++ b/flowapp/models/log.py @@ -1,3 +1,6 @@ +from datetime import datetime, timedelta + +from flowapp.constants import RuleTypes from .base import db @@ -19,3 +22,18 @@ def __init__(self, time, task, user_id, rule_type, rule_id, author): self.rule_id = rule_id self.user_id = user_id self.author = author + + @classmethod + def delete_old(cls, days: int = 30): + """Delete logs older than :param days from the database""" + cls.query.filter(cls.time < datetime.now() - timedelta(days=days)).delete() + db.session.commit() + + def __repr__(self): + return f"" + + def __str__(self): + """ + {"author": "vrany@cesnet.cz / Cel\u00fd sv\u011bt", "source": "UI", "command": "cmd"} + """ + return f"{self.author} - {RuleTypes(self.rule_type).name}({self.rule_id}) - {self.task}" diff --git a/flowapp/output.py b/flowapp/output.py index fcb385f..0d02367 100644 --- a/flowapp/output.py +++ b/flowapp/output.py @@ -113,6 +113,7 @@ def log_route(user_id: int, route_model: Union[RTBH, Flowspec4, Flowspec6], rule author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() @@ -135,4 +136,5 @@ def log_withdraw(user_id: int, task: str, rule_type: RuleTypes, deleted_id: int, author=author, ) db.session.add(log) + current_app.logger.info(log) db.session.commit() diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 2ac5981..10fd73f 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -28,6 +28,7 @@ get_user_nets, insert_initial_communities, ) +from flowapp.models.log import Log from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.output import ROUTE_MODELS, announce_route, log_route, log_withdraw, RouteSources, Route from flowapp.services import rule_service, announce_all_routes, delete_expired_whitelists @@ -679,6 +680,13 @@ def announce_all(): @rules.route("/withdraw_expired", methods=["GET"]) @localhost_only def withdraw_expired(): + """ + cleaning endpoint + deletes expired whitelists + withdraws all expired routes from ExaBGP + deletes logs older than 30 days + """ delete_expired_whitelists() announce_all_routes(constants.WITHDRAW) + Log.delete_old() return " " From 0b8a2bf9743b5628661a5bb088ded0db098e2ab1 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Mar 2025 12:59:22 +0100 Subject: [PATCH 27/46] Refactor rule reactivation logic and improve limit handling - Introduce reactivate_rule function in rule_service to centralize rule reactivation logic. - Replace inline limit checks with service-level handling for global and organization-specific limits. - Simplify rule reactivation in rules.py by delegating logic to the service layer. - Add redirects for global_limit_reached and limit_reached messages to improve user experience. - Ensure cleaner and more maintainable code by reducing duplication and improving readability. --- flowapp/services/__init__.py | 2 + flowapp/services/rule_service.py | 99 +++++++++++++++++++++++++++++++- flowapp/views/rules.py | 73 ++++++++--------------- 3 files changed, 121 insertions(+), 53 deletions(-) diff --git a/flowapp/services/__init__.py b/flowapp/services/__init__.py index 456eb43..be4b838 100644 --- a/flowapp/services/__init__.py +++ b/flowapp/services/__init__.py @@ -2,6 +2,7 @@ create_or_update_ipv4_rule, create_or_update_ipv6_rule, create_or_update_rtbh_rule, + reactivate_rule, ) from .whitelist_service import create_or_update_whitelist, delete_whitelist, delete_expired_whitelists @@ -16,4 +17,5 @@ delete_whitelist, delete_expired_whitelists, announce_all_routes, + reactivate_rule, ] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 6b8896c..d671324 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -7,12 +7,12 @@ """ from datetime import datetime -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union from flask import current_app from flowapp import db, messages -from flowapp.constants import RuleOrigin, RuleTypes, ANNOUNCE +from flowapp.constants import WITHDRAW, RuleOrigin, RuleTypes, ANNOUNCE from flowapp.models import ( get_ipv4_model_if_exists, get_ipv6_model_if_exists, @@ -22,7 +22,8 @@ RTBH, Whitelist, ) -from flowapp.output import Route, announce_route, log_route, RouteSources +from flowapp.models.utils import check_global_rule_limit, check_rule_limit +from flowapp.output import ROUTE_MODELS, Route, announce_route, log_route, RouteSources, log_withdraw from flowapp.services.base import announce_rtbh_route from flowapp.services.whitelist_common import ( Relation, @@ -35,6 +36,98 @@ from .whitelist_common import check_rule_against_whitelists +def reactivate_rule( + rule_type: RuleTypes, + rule_id: int, + expires: datetime, + comment: str, + user_id: int, + org_id: int, + user_email: str, + org_name: str, +) -> Tuple[Union[RTBH, Flowspec4, Flowspec6], List[str]]: + """ + Reactivate a rule by setting a new expiration time. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to reactivate + expires: New expiration datetime + comment: Updated comment + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + + Returns: + Tuple containing (rule_model, messages) + """ + model_name = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + model = db.session.get(model_name, rule_id) + if not model: + return None, ["Rule not found"] + + flashes = [] + + # Check if rule will be reactivated + state = get_state_by_time(expires) + + # Check global limit + if state == 1 and check_global_rule_limit(rule_type.value): + return model, ["global_limit_reached"] + + # Check org limit + if state == 1 and check_rule_limit(org_id, rule_type=rule_type.value): + return model, ["limit_reached"] + + # Set new expiration date + model.expires = expires + # Set again the active state + model.rstate_id = state + model.comment = comment + db.session.commit() + flashes.append("Rule successfully updated") + + route_model = ROUTE_MODELS[rule_type.value] + + if model.rstate_id == 1: + # Announce route + command = route_model(model, ANNOUNCE) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_route( + user_id, + model, + rule_type, + f"{user_email} / {org_name}", + ) + else: + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + # Log changes + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + + return model, flashes + + def create_or_update_ipv4_rule( form_data: Dict, user_id: int, org_id: int, user_email: str, org_name: str ) -> Tuple[Flowspec4, str]: diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 10fd73f..4b0d09a 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -70,6 +70,10 @@ def reactivate_rule(rule_type, rule_id): form_name = DATA_FORMS[rule_type] model = db.session.get(model_name, rule_id) + if not model: + flash("Rule not found", "alert-danger") + return redirect(url_for("index")) + form = form_name(request.form, obj=model) form.net_ranges = get_user_nets(session["user_id"]) @@ -87,63 +91,31 @@ def reactivate_rule(rule_type, rule_id): if rule_type == RuleTypes.IPv6.value: form.next_header.data = model.next_header - # do not need to validate - all is readonly + # Process form submission if request.method == "POST": - # check if rule will be reactivated - state = get_state_by_time(form.expires.data) + # Round expiration time to 10 minutes + expires = round_to_ten_minutes(form.expires.data) + + # Use the service to reactivate the rule + _, messages = rule_service.reactivate_rule( + rule_type=enum_rule_type, + rule_id=rule_id, + expires=expires, + comment=form.comment.data, + user_id=session["user_id"], + org_id=session["user_org_id"], + user_email=session["user_email"], + org_name=session["user_org"], + ) - # check global limit - check_gl = check_global_rule_limit(rule_type) - if state == 1 and check_gl: + # Handle special messages (redirects) + if "global_limit_reached" in messages: return redirect(url_for("rules.global_limit_reached", rule_type=rule_type)) - # check org limit - if state == 1 and check_rule_limit(session["user_org_id"], rule_type=rule_type): + if "limit_reached" in messages: return redirect(url_for("rules.limit_reached", rule_type=rule_type)) - # set new expiration date - model.expires = round_to_ten_minutes(form.expires.data) - # set again the active state - model.rstate_id = get_state_by_time(form.expires.data) - model.comment = form.comment.data - db.session.commit() flash("Rule successfully updated", "alert-success") - route_model = ROUTE_MODELS[rule_type] - - if model.rstate_id == 1: - # announce route - command = route_model(model, constants.ANNOUNCE) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - Use the enum value here - log_route( - session["user_id"], - model, - enum_rule_type, # Pass the enum instead of integer - f"{session['user_email']} / {session['user_org']}", - ) - else: - # withdraw route - command = route_model(model, constants.WITHDRAW) - route = Route( - author=f"{session['user_email']} / {session['user_org']}", - source=RouteSources.UI, - command=command, - ) - announce_route(route) - # log changes - Use the enum value here - log_withdraw( - session["user_id"], - route.command, - enum_rule_type, # Pass the enum instead of integer - model.id, - f"{session['user_email']} / {session['user_org']}", - ) - return redirect( url_for( "dashboard.index", @@ -157,6 +129,7 @@ def reactivate_rule(rule_type, rule_id): else: flash_errors(form) + # For GET requests, prepare the form for display form.expires.data = model.expires for field in form: if field.name not in ["expires", "csrf_token", "comment"]: From bcdcc54b37f1496cfc44bce02f0db9ebfa9e72e5 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 21 Mar 2025 13:56:44 +0100 Subject: [PATCH 28/46] modified Rule service to check RTBH rule whitelisting during rule reactivation in /reactivate enpoint --- flowapp/services/rule_service.py | 33 ++++++++++++++++++++++---------- flowapp/views/rules.py | 3 ++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index d671324..04f8160 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -83,15 +83,19 @@ def reactivate_rule( # Set new expiration date model.expires = expires - # Set again the active state - model.rstate_id = state + # Set again the active state, if the rule is not whitelisted RTBH + if rule_type == RuleTypes.RTBH: + model = check_rtbh_whitelisted(model, user_id, flashes, f"{user_email} / {org_name}") + else: + model.rstate_id = state + model.comment = comment db.session.commit() - flashes.append("Rule successfully updated") route_model = ROUTE_MODELS[rule_type.value] if model.rstate_id == 1: + flashes.append("Rule successfully updated, state set to active.") # Announce route command = route_model(model, ANNOUNCE) route = Route( @@ -124,6 +128,10 @@ def reactivate_rule( model.id, f"{user_email} / {org_name}", ) + if model.rstate_id == 4: + flashes.append("Rule successfully updated, state set to whitelisted.") + else: + flashes.append("Rule successfully updated, state set to inactive.") return model, flashes @@ -306,7 +314,17 @@ def create_or_update_rtbh_rule( # rule author for logging and announcing author = f"{user_email} / {org_name}" + # Check if rule is whitelisted + model = check_rtbh_whitelisted(model, user_id, flashes, author) + + announce_rtbh_route(model, author=author) + # Log changes + log_route(user_id, model, RuleTypes.RTBH, author) + + return model, flashes + +def check_rtbh_whitelisted(model: RTBH, user_id: int, flashes: List[str], author: str) -> None: # Check if rule is whitelisted allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] if model.community_id in allowed_communities: @@ -314,14 +332,9 @@ def create_or_update_rtbh_rule( whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() wl_cache = map_whitelists_to_strings(whitelists) results = check_rule_against_whitelists(str(model), wl_cache.keys()) - # check rule against whitelists, stop search when rule is whitelisted first time + # check rule against whitelists model = evaluate_rtbh_against_whitelists_check_results(user_id, model, flashes, author, wl_cache, results) - - announce_rtbh_route(model, author=author) - # Log changes - log_route(user_id, model, RuleTypes.RTBH, author) - - return model, flashes + return model def evaluate_rtbh_against_whitelists_check_results( diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 4b0d09a..2aa7905 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -114,7 +114,8 @@ def reactivate_rule(rule_type, rule_id): if "limit_reached" in messages: return redirect(url_for("rules.limit_reached", rule_type=rule_type)) - flash("Rule successfully updated", "alert-success") + for message in messages: + flash(message, "alert-success") return redirect( url_for( From edde80143384947c27fe504804d212a49ff21949 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Mon, 24 Mar 2025 11:38:19 +0100 Subject: [PATCH 29/46] bugfix - search for whitelist, added dict method to model --- flowapp/models/rules/whitelist.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 29f2b08..83c131f 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -73,6 +73,14 @@ def to_dict(self, prefered_format="yearfirst"): "rstate": self.rstate.description, } + def dict(self, prefered_format="yearfirst"): + """ + Serialize to dict + :param prefered_format: string with prefered time format + :returns: dictionary + """ + return self.to_dict(prefered_format) + def __repr__(self): return f"" From 968388cc9080cdbe988c1f8a887c93b0aef8b91a Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 25 Mar 2025 13:52:47 +0100 Subject: [PATCH 30/46] updated dashboard / no group operations for whitelist --- flowapp/views/dashboard.py | 48 +++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index 3980b88..0a393bb 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -266,9 +266,15 @@ def create_admin_response( :param sort_order: :return: """ + group_op = True if rtype != "whitelist" else False dashboard_table_body = create_dashboard_table_body( - rules, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids + rules, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_head = create_dashboard_table_head( @@ -278,15 +284,18 @@ def create_admin_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) - - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" res = make_response( render_template( @@ -358,8 +367,16 @@ def create_user_response( macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids, ) + + group_op = True if rtype != "whitelist" else False + dashboard_table_editable = create_dashboard_table_body( - rules_editable, rtype, macro_file=macro_file, macro_name=macro_tbody, whitelist_rule_ids=whitelist_rule_ids + rules_editable, + rtype, + macro_file=macro_file, + macro_name=macro_tbody, + group_op=group_op, + whitelist_rule_ids=whitelist_rule_ids, ) dashboard_table_editable_head = create_dashboard_table_head( rules_columns=table_columns, @@ -368,7 +385,7 @@ def create_user_response( sort_key=sort_key, sort_order=sort_order, search_query=search_query, - group_op=True, + group_op=group_op, macro_file=macro_file, macro_name=macro_thead, ) @@ -384,11 +401,14 @@ def create_user_response( macro_name=macro_thead, ) - dashboard_table_foot = create_dashboard_table_foot( - table_colspan, - macro_file=macro_file, - macro_name=macro_tfoot, - ) + if group_op: + dashboard_table_foot = create_dashboard_table_foot( + table_colspan, + macro_file=macro_file, + macro_name=macro_tfoot, + ) + else: + dashboard_table_foot = "" display_editable = len(rules_editable) display_readonly = len(read_only_rules) From 1b99e1a980cd46428014118cf895803002d8316b Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 25 Mar 2025 15:09:58 +0100 Subject: [PATCH 31/46] Add functionality to delete RTBH rules and create whitelist entries - Implement delete_rtbh_and_create_whitelist function in rule_service to handle RTBH rule deletion and whitelist creation. - Add a new route delete_and_whitelist in rules.py to expose this functionality via the UI. - Update macros.html to include a button for converting RTBH rules to whitelist entries. - Enhance user feedback with detailed flash messages for both success and failure scenarios. - Improve code maintainability by centralizing logic in the service layer. --- flowapp/services/rule_service.py | 150 ++++++++++++++++++++++++++++++- flowapp/templates/macros.html | 3 + flowapp/views/rules.py | 93 ++++++++++++------- 3 files changed, 210 insertions(+), 36 deletions(-) diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 04f8160..761f25d 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -6,7 +6,7 @@ and managing flow rules, separating these concerns from HTTP handling. """ -from datetime import datetime +from datetime import datetime, timedelta from typing import Dict, List, Tuple, Union from flask import current_app @@ -22,6 +22,7 @@ RTBH, Whitelist, ) +from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.models.utils import check_global_rule_limit, check_rule_limit from flowapp.output import ROUTE_MODELS, Route, announce_route, log_route, RouteSources, log_withdraw from flowapp.services.base import announce_rtbh_route @@ -31,9 +32,10 @@ create_rtbh_from_whitelist_parts, subtract_network, whitelist_rtbh_rule, + check_rule_against_whitelists, ) +from flowapp.services.whitelist_service import create_or_update_whitelist from flowapp.utils import round_to_ten_minutes, get_state_by_time, quote_to_ent -from .whitelist_common import check_rule_against_whitelists def reactivate_rule( @@ -384,3 +386,147 @@ def evaluate_rtbh_against_whitelists_check_results( def map_whitelists_to_strings(whitelists: List[Whitelist]) -> Dict[str, Whitelist]: return {str(w): w for w in whitelists} + + +def delete_rule( + rule_type: RuleTypes, rule_id: int, user_id: int, user_email: str, org_name: str, allowed_rule_ids: List[int] = None +) -> Tuple[bool, str]: + """ + Delete a rule with the given id and type. + + Args: + rule_type: Type of rule (RTBH, IPv4, IPv6) + rule_id: ID of the rule to delete + user_id: Current user ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete, None means no restriction + + Returns: + Tuple containing (success, message) + """ + model_class = {RuleTypes.RTBH: RTBH, RuleTypes.IPv4: Flowspec4, RuleTypes.IPv6: Flowspec6}[rule_type] + + route_model = ROUTE_MODELS[rule_type.value] + + model = db.session.get(model_class, rule_id) + if not model: + return False, "Rule not found" + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, "You cannot delete this rule" + + # Withdraw route + command = route_model(model, WITHDRAW) + route = Route( + author=f"{user_email} / {org_name}", + source=RouteSources.UI, + command=command, + ) + announce_route(route) + + # Log withdrawal + log_withdraw( + user_id, + route.command, + rule_type, + model.id, + f"{user_email} / {org_name}", + ) + + # Special handling for RTBH rules + if rule_type == RuleTypes.RTBH: + current_app.logger.debug(f"Deleting RTBH rule {rule_id} from cache") + RuleWhitelistCache.delete_by_rule_id(rule_id) + + # Delete from database + db.session.delete(model) + db.session.commit() + + return True, "Rule deleted successfully" + + +def delete_rtbh_and_create_whitelist( + rule_id: int, + user_id: int, + org_id: int, + user_email: str, + org_name: str, + allowed_rule_ids: List[int] = None, + whitelist_expires: datetime = None, +) -> Tuple[bool, List[str], Union[Whitelist, None]]: + """ + Delete an RTBH rule and create a whitelist entry from it. + + Args: + rule_id: ID of the RTBH rule to delete + user_id: Current user ID + org_id: Current organization ID + user_email: User email for logging + org_name: Organization name for logging + allowed_rule_ids: List of rule IDs the user is allowed to delete + whitelist_expires: Expiration time for the whitelist entry (default: 7 days from now) + + Returns: + Tuple containing (success, messages, whitelist_model) + """ + messages = [] + + # First get the RTBH rule to extract its data + model = db.session.get(RTBH, rule_id) + if not model: + return False, ["RTBH rule not found"], None + + # Check permission if allowed_rule_ids is provided + if allowed_rule_ids is not None and model.id not in allowed_rule_ids: + return False, ["You cannot delete this rule"], None + + # Extract data for whitelist + if model.ipv4: + ip = model.ipv4 + mask = model.ipv4_mask + elif model.ipv6: + ip = model.ipv6 + mask = model.ipv6_mask + else: + return False, ["RTBH rule has no IP address"], None + + # Set default whitelist expiration time if not provided + if whitelist_expires is None: + whitelist_expires = datetime.now() + timedelta(days=7) + + # Prepare whitelist data + whitelist_data = { + "ip": ip, + "mask": mask, + "expires": whitelist_expires, + "comment": f"Created from RTBH rule {rule_id}: {model.comment}", + } + + # Delete the RTBH rule + success, delete_message = delete_rule( + rule_type=RuleTypes.RTBH, + rule_id=rule_id, + user_id=user_id, + user_email=user_email, + org_name=org_name, + allowed_rule_ids=allowed_rule_ids, + ) + + if not success: + return False, [delete_message], None + + messages.append(delete_message) + + # Create the whitelist entry + try: + whitelist_model, whitelist_messages = create_or_update_whitelist( + form_data=whitelist_data, user_id=user_id, org_id=org_id, user_email=user_email, org_name=org_name + ) + messages.extend(whitelist_messages) + return True, messages, whitelist_model + except Exception as e: + current_app.logger.exception(f"Error creating whitelist entry: {e}") + messages.append(f"Rule deleted but failed to create whitelist: {str(e)}") + return False, messages, None diff --git a/flowapp/templates/macros.html b/flowapp/templates/macros.html index 2f5dd9f..6f95a13 100644 --- a/flowapp/templates/macros.html +++ b/flowapp/templates/macros.html @@ -111,6 +111,9 @@ + + + {% endif %} {% if rule.comment %} diff --git a/flowapp/templates/pages/machine_api_key.html b/flowapp/templates/pages/machine_api_key.html index 52eb478..c2ced22 100644 --- a/flowapp/templates/pages/machine_api_key.html +++ b/flowapp/templates/pages/machine_api_key.html @@ -11,8 +11,8 @@

Machines and ApiKeys

Machine address ApiKey - Created by Created for + Created by Expires Read/Write ? Action diff --git a/flowapp/tests/conftest.py b/flowapp/tests/conftest.py index ae01c6d..3a988ef 100644 --- a/flowapp/tests/conftest.py +++ b/flowapp/tests/conftest.py @@ -152,6 +152,26 @@ def jwt_token(client, app, db, request): return data["token"] +@pytest.fixture(scope="session") +def machine_api_token(client, app, db, request): + """ + Get the test_client from the app, for the whole test session. + """ + mkey = "machinetestkey" + + with app.app_context(): + model = flowapp.models.MachineApiKey(machine="127.0.0.1", key=mkey, user_id=1, org_id=1) + db.session.add(model) + db.session.commit() + + print("\n----- GET MACHINE API KEY TEST TOKEN\n") + url = "/api/v3/auth" + headers = {"x-api-key": mkey} + token = client.get(url, headers=headers) + data = json.loads(token.data) + return data["token"] + + @pytest.fixture(scope="session") def expired_auth_token(client, app, db, request): """ diff --git a/flowapp/tests/test_api_auth.py b/flowapp/tests/test_api_auth.py index 5733346..8d08a7a 100644 --- a/flowapp/tests/test_api_auth.py +++ b/flowapp/tests/test_api_auth.py @@ -11,6 +11,15 @@ def test_token(client, jwt_token): assert req.status_code == 200 +def test_machine_token(client, machine_api_token): + """ + test that token authorization works + """ + req = client.get("/api/v3/test_token", headers={"x-access-token": machine_api_token}) + + assert req.status_code == 200 + + def test_expired_token(client, expired_auth_token): """ test that expired token authorization return 401 @@ -37,7 +46,7 @@ def test_readonly_token(client, readonly_jwt_token): assert req.status_code == 200 data = json.loads(req.data) - assert data['readonly'] + assert data["readonly"] def test_readonly_token_ipv4_create(client, db, readonly_jwt_token): diff --git a/flowapp/tests/test_forms_cl.py b/flowapp/tests/test_forms_cl.py index c6b42bc..db0a569 100644 --- a/flowapp/tests/test_forms_cl.py +++ b/flowapp/tests/test_forms_cl.py @@ -372,12 +372,14 @@ def valid_machine_key_data(self, valid_datetime): "comment": "Test machine API key", "expires": valid_datetime, "readonly": "true", + "user": 1, } def test_valid_machine_key(self, app, valid_machine_key_data): with app.test_request_context(): form_data = create_form_data(valid_machine_key_data) form = MachineApiKeyForm(formdata=form_data) + form.user.choices = [(1, "g.name"), (2, "test")] assert form.validate() diff --git a/flowapp/views/admin.py b/flowapp/views/admin.py index d014c87..7912e86 100644 --- a/flowapp/views/admin.py +++ b/flowapp/views/admin.py @@ -73,15 +73,22 @@ def add_machine_key(): """ generated = secrets.token_hex(24) form = MachineApiKeyForm(request.form, key=generated) + form.user.choices = [(g.id, g.name) for g in db.session.query(User).order_by("name")] if request.method == "POST" and form.validate(): + target_user = db.session.get(User, form.user.data) + target_org = target_user.organization.first() if target_user else None + current_user = session.get("user_name") + curent_email = session.get("user_uuid") + comment = f"created by: {current_user}/{curent_email}, comment: {form.comment.data}" model = MachineApiKey( machine=form.machine.data, key=form.key.data, expires=form.expires.data, readonly=form.readonly.data, - comment=form.comment.data, - user_id=session["user_id"], + comment=comment, + user_id=target_user.id, + org_id=target_org.id, ) db.session.add(model) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 4fad70e..3951b09 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -63,7 +63,6 @@ def authorize(user_key): :return: page with token """ jwt_key = current_app.config.get("JWT_SECRET") - # try normal user key first model = db.session.query(ApiKey).filter_by(key=user_key).first() # if not found try machine key @@ -96,7 +95,7 @@ def authorize(user_key): return jsonify({"token": encoded}) else: - return jsonify({"message": "auth token is invalid"}), 403 + return jsonify({"message": f"auth token is not valid from machine {request.remote_addr}"}), 403 def check_readonly(func): From b836a1b3946a8b2d37d292104a10d64dc63a7eff Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Tue, 3 Jun 2025 12:09:16 +0200 Subject: [PATCH 38/46] Update python-app.yml updated to python 3.11 used in production env --- .github/workflows/python-app.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 9e659ab..ad47922 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -19,10 +19,10 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: "3.9" + python-version: "3.11" - name: Setup timezone uses: zcong1993/setup-timezone@master with: From cced42e23ea6ca20fc8355572944f7d5d44028d0 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 4 Jun 2025 10:38:27 +0200 Subject: [PATCH 39/46] updated readme changelod with versions 1.1.1 and 1.1.0 --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index f0bc974..db67631 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,15 @@ You may also need to monitor the ExaBGP and renew the commands after restart / s * [Local database instalation notes](./docs/DB_LOCAL.md) ## Change Log +- 1.1.1 - Machine API Key rewrited. + - API keys for machines are now tied to one of the existing users. If there is a need to have API access for machine, first create service user, and set the access rights. Then create machine key as Admin and assign it to this user. +- 1.1.0 - Major Architecture Refactoring and Whitelist Integration + - Code Organization and Architecture Improvements. Significant architectural refactoring focused on better separation of concerns and improved maintainability. The most notable change is the introduction of a dedicated **services layer** that extracts business logic from view controllers. Key service modules include `rule_service.py` for rule management operations, `whitelist_service.py` for whitelist functionality, and `whitelist_common.py` for shared whitelist utilities. + - The **models structure** has been reorganized with better separation into logical modules. Rule models are now organized under `flowapp/models/rules/` with separate files for different rule types (`flowspec.py`, `rtbh.py`, `whitelist.py`), while maintaining backward compatibility through the main models `__init__.py`. Form handling has also been improved with better organization under `flowapp/forms/` and enhanced validation logic. + - **RTBH Whitelist Integration** This system automatically evaluates new RTBH rules against existing whitelists and can automatically modify or block rules that conflict with whitelisted networks. When an RTBH rule is created that intersects with a whitelist entry, the system can: + - **Automatically whitelist** rules that exactly match or are contained within whitelisted networks + - **Create subnet rules** when RTBH rules are supersets of whitelisted networks, automatically generating the non-whitelisted portions + - **Maintain rule cache** that tracks relationships between rules and whitelists for proper cleanup - 1.0.2 - fixed bug in IPv6 Flowspec messages - 1.0.1 . minor bug fixes - 1.0.0 . Major changes From eae1a980d1220bdd2fcf5987188b1c4d6b97896a Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 4 Jun 2025 17:55:45 +0200 Subject: [PATCH 40/46] Update python-app.yml Update to cover Python from 3.9 to 3.12 --- .github/workflows/python-app.yml | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index ad47922..0f8eab6 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,28 +1,25 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python +# This workflow will install Python dependencies, run tests and lint with multiple versions of Python # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - name: Python application - on: push: branches: [ "master", "develop" ] pull_request: branches: [ "master", "develop" ] - permissions: contents: read - jobs: build: - runs-on: ubuntu-latest - + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 - - name: Set up Python 3.11 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: - python-version: "3.11" + python-version: ${{ matrix.python-version }} - name: Setup timezone uses: zcong1993/setup-timezone@master with: From e21a0af804a8e20014f9128a05833f92c7cee3ca Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 4 Jun 2025 18:34:26 +0200 Subject: [PATCH 41/46] avoid use of match case to keep compatibility with python 3.9 --- README.md | 2 +- flowapp/services/rule_service.py | 49 +++++++++++------------- flowapp/services/whitelist_service.py | 55 +++++++++++++-------------- 3 files changed, 49 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index db67631..04fb166 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ See how is ExaFS integrated into the network in the picture below. ![ExaFS schema](./docs/app_schema_en.png) -The central part of the ExaFS is a web application, written in Python3.6 with Flask framework. It provides a user interface for ExaBGP rule CRUD operations. The application also provides the REST API with CRUD operations for the configuration rules. The web app uses Shibboleth authorization; the REST API is using token-based authorization. +The central part of the ExaFS is a web application, written in Python with Flask framework. It provides a user interface for ExaBGP rule CRUD operations. The application also provides the REST API with CRUD operations for the configuration rules. The web app uses Shibboleth authorization; the REST API is using token-based authorization. The app creates the ExaBGP commands and forwards them to ExaBGP process. All rules are carefully validated, and only valid rules are stored in the database and sent to the ExaBGP connector. diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index 997bd19..fca7719 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -352,35 +352,30 @@ def evaluate_rtbh_against_whitelists_check_results( Process all results for cases where rule is whitelisted by several whitelists. """ for rule, whitelist_key, relation in results: - match relation: - case Relation.EQUAL: - model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) - msg = f"RTBH Rule {model.id} {model} is equal to active whitelist {whitelist_key}. Rule is whitelisted." + if relation == Relation.EQUAL: + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + msg = f"RTBH Rule {model.id} {model} is equal to active whitelist {whitelist_key}. Rule is whitelisted." + flashes.append(msg) + current_app.logger.info(msg) + elif relation == Relation.SUBNET: + parts = subtract_network(target=str(model), whitelist=whitelist_key) + wl_id = wl_cache[whitelist_key].id + msg = f"RTBH Rule {model.id} {model} is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." + flashes.append(msg) + current_app.logger.info(msg) + for network in parts: + new_rule = create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) + msg = f"Created RTBH rule {new_rule.id} {new_rule} for {network} parted by whitelist {whitelist_key}" flashes.append(msg) current_app.logger.info(msg) - case Relation.SUBNET: - parts = subtract_network(target=str(model), whitelist=whitelist_key) - wl_id = wl_cache[whitelist_key].id - msg = f"RTBH Rule {model.id} {model} is supernet of active whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules created." - flashes.append(msg) - current_app.logger.info(msg) - for network in parts: - new_rule = create_rtbh_from_whitelist_parts(model, wl_id, whitelist_key, network, author, user_id) - msg = ( - f"Created RTBH rule {new_rule.id} {new_rule} for {network} parted by whitelist {whitelist_key}" - ) - flashes.append(msg) - current_app.logger.info(msg) - model.rstate_id = 4 - add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) - db.session.commit() - case Relation.SUPERNET: - model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) - msg = ( - f"RTBH Rule {model.id} {model} is subnet of active whitelist {whitelist_key}. Rule is whitelisted." - ) - current_app.logger.info(msg) - flashes.append(msg) + model.rstate_id = 4 + add_rtbh_rule_to_cache(model, wl_id, RuleOrigin.USER) + db.session.commit() + elif relation == Relation.SUPERNET: + model = whitelist_rtbh_rule(model, wl_cache[whitelist_key]) + msg = f"RTBH Rule {model.id} {model} is subnet of active whitelist {whitelist_key}. Rule is whitelisted." + current_app.logger.info(msg) + flashes.append(msg) return model diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index a934cc8..7cbef38 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -84,37 +84,34 @@ def evaluate_whitelist_against_rtbh_check_results( for rule_key, whitelist_key, relation in results: current_app.logger.info(f"whitelist {whitelist_key} is {relation} to Rule {rule_key}") - match relation: - case Relation.EQUAL: - whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) - withdraw_rtbh_route(rtbh_rule_cache[rule_key]) - msg = "Existing active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is now whitelisted." + if relation == Relation.EQUAL: + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + msg = "Existing active rule {rule_key} is equal to whitelist {whitelist_key}. Rule is now whitelisted." + flashes.append(msg) + current_app.logger.info(msg) + elif relation == Relation.SUBNET: + parts = subtract_network(target=rule_key, whitelist=whitelist_key) + wl_id = whitelist_model.id + msg = f"Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules will be created." + flashes.append(msg) + current_app.logger.info(msg) + for network in parts: + rule_model = rtbh_rule_cache[rule_key] + create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) + msg = f"Created RTBH rule from {rule_model.id} {network} parted by whitelist {whitelist_key}." flashes.append(msg) current_app.logger.info(msg) - case Relation.SUBNET: - parts = subtract_network(target=rule_key, whitelist=whitelist_key) - wl_id = whitelist_model.id - msg = f"Rule {rule_key} is supernet of whitelist {whitelist_key}. Rule is whitelisted, {len(parts)} subnet rules will be created." - flashes.append(msg) - current_app.logger.info(msg) - for network in parts: - rule_model = rtbh_rule_cache[rule_key] - create_rtbh_from_whitelist_parts(rule_model, wl_id, whitelist_key, network) - msg = f"Created RTBH rule from {rule_model.id} {network} parted by whitelist {whitelist_key}." - flashes.append(msg) - current_app.logger.info(msg) - rule_model.rstate_id = 4 - add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) - db.session.commit() - case Relation.SUPERNET: - - whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) - withdraw_rtbh_route(rtbh_rule_cache[rule_key]) - msg = ( - f"Existing active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is now whitelisted." - ) - current_app.logger.info(msg) - flashes.append(msg) + rule_model.rstate_id = 4 + add_rtbh_rule_to_cache(rule_model, wl_id, RuleOrigin.USER) + db.session.commit() + elif relation == Relation.SUPERNET: + + whitelist_rtbh_rule(rtbh_rule_cache[rule_key], whitelist_model) + withdraw_rtbh_route(rtbh_rule_cache[rule_key]) + msg = f"Existing active rule {rule_key} is subnet of whitelist {whitelist_key}. Rule is now whitelisted." + current_app.logger.info(msg) + flashes.append(msg) return whitelist_model From ba305000aa6eb6f19c7cede43c7f909218951147 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 4 Jun 2025 18:40:00 +0200 Subject: [PATCH 42/46] avoid use of match case to keep compatibility with python 3.9 --- flowapp/services/whitelist_service.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index 7cbef38..bf281af 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -151,13 +151,12 @@ def delete_whitelist(whitelist_id: int) -> List[str]: for cached_rule in cached_rules: rule_model_type = RuleTypes(cached_rule.rtype) cache_entries_count = RuleWhitelistCache.count_by_rule(cached_rule.rid, rule_model_type) - match rule_model_type: - case RuleTypes.IPv4: - rule_model = db.session.get(Flowspec4, cached_rule.rid) - case RuleTypes.IPv6: - rule_model = db.session.get(Flowspec6, cached_rule.rid) - case RuleTypes.RTBH: - rule_model = db.session.get(RTBH, cached_rule.rid) + if rule_model_type == RuleTypes.IPv4: + rule_model = db.session.get(Flowspec4, cached_rule.rid) + elif rule_model_type == RuleTypes.IPv6: + rule_model = db.session.get(Flowspec6, cached_rule.rid) + elif rule_model_type == RuleTypes.RTBH: + rule_model = db.session.get(RTBH, cached_rule.rid) rorigin_type = RuleOrigin(cached_rule.rorigin) current_app.logger.debug(f"Rule {rule_model} has origin {rorigin_type}") if rorigin_type == RuleOrigin.WHITELIST: From 42e93105d9ca52a153156ade8216876de762d24f Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Wed, 4 Jun 2025 18:59:40 +0200 Subject: [PATCH 43/46] Union instead | to be back compatible in type hints --- flowapp/services/whitelist_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flowapp/services/whitelist_common.py b/flowapp/services/whitelist_common.py index 3807dca..b9b9e8e 100644 --- a/flowapp/services/whitelist_common.py +++ b/flowapp/services/whitelist_common.py @@ -1,7 +1,7 @@ from enum import Enum, auto from functools import lru_cache import ipaddress -from typing import List, Tuple +from typing import List, Tuple, Union from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes from flowapp.models import RTBH, RuleWhitelistCache, Whitelist @@ -59,7 +59,7 @@ def _is_same_ip_version(addr1: str, addr2: str) -> bool: @lru_cache(maxsize=1024) -def get_network(address: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network: +def get_network(address: str) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]: """ Create and cache an IP network object. From cf5de0631adb6985d6c3136159d4920a6cf22284 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 5 Jun 2025 09:08:40 +0200 Subject: [PATCH 44/46] updated readme --- README.md | 15 +++++++-------- docs/guarda-service/README.md | 1 - 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 04fb166..6402104 100644 --- a/README.md +++ b/README.md @@ -35,18 +35,17 @@ See how is ExaFS integrated into the network in the picture below. ## System overview ![ExaFS schema](./docs/app_schema_en.png) +The core component of ExaFS is a web application written in Python using the Flask framework. It provides a user interface for managing ExaBGP rules (CRUD operations) and also exposes a REST API with similar functionality. The web application uses Shibboleth for authentication, while the REST API relies on token-based authentication. -The central part of the ExaFS is a web application, written in Python with Flask framework. It provides a user interface for ExaBGP rule CRUD operations. The application also provides the REST API with CRUD operations for the configuration rules. The web app uses Shibboleth authorization; the REST API is using token-based authorization. +The application generates ExaBGP commands and forwards them to the ExaBGP process. All rules are thoroughly validated—only valid rules are stored in the database and sent to the ExaBGP connector. -The app creates the ExaBGP commands and forwards them to ExaBGP process. All rules are carefully validated, and only valid rules are stored in the database and sent to the ExaBGP connector. - -This second part of the system is another application that replicates the received command to the stdout. The connection between ExaBGP daemon and stdout of ExaAPI (ExaBGP process) is specified in the ExaBGP config. +The second component of the system is a separate application that replicates received commands to `stdout`. The connection between the ExaBGP daemon and the `stdout` of the ExaAPI (ExaBGP process) is defined in the ExaBGP configuration. -This API was a part of the project, but now has been moved to own repository. You can use [pip package exabgp-process](https://pypi.org/project/exabgp-process/) or clone the git repo. Or you can create your own version. - -Every time this process gets a command from ExaFS, it replicates this command to the ExaBGP service through the stdout. The registered service then updates the ExaBGP table – create, modify or remove the rule from command. +This API was originally part of the same project but has since been moved to its own repository. You can use the [exabgp-process pip package](https://pypi.org/project/exabgp-process/), clone the Git repository, or develop your own implementation. -You may also need to monitor the ExaBGP and renew the commands after restart / shutdown. In docs you can find and example of system service named Guarda. This systemctl service is running in the host system and gets a notification on each restart of ExaBGP service via systemctl WantedBy config option. For every restart of ExaBGP the Guarda service will put all the valid and active rules to the ExaBGP rules table again. +Each time this process receives a command from ExaFS, it outputs it to `stdout`, allowing the ExaBGP service to process the command and update its routing table—creating, modifying, or removing rules accordingly. + +It may also be necessary to monitor ExaBGP and re-announce rules after a restart or shutdown. This can be handled via the ExaBGP service configuration, or by using an example system service called **Guarda**, described in the documentation. In either case, the key mechanism is calling the application endpoint `/rules/announce_all`. This endpoint is only accessible from `localhost`; a local IP address must be configured in the application settings. ## DOCS * [Install notes](./docs/INSTALL.md) diff --git a/docs/guarda-service/README.md b/docs/guarda-service/README.md index 45f5344..e8086b4 100644 --- a/docs/guarda-service/README.md +++ b/docs/guarda-service/README.md @@ -1,5 +1,4 @@ # Guarda Service for ExaBGP - This is a systemd service designed to monitor ExaBGP and reapply commands after a restart or shutdown. The guarda.service runs on the host system and is triggered whenever the ExaBGP service restarts, thanks to the WantedBy configuration in systemd. After each restart, the Guarda service will reapply all valid and active rules to the ExaBGP rules table. ## Usage (as root) From f10e008959583efb6eef8063b138d7c9a529a993 Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Thu, 5 Jun 2025 09:19:23 +0200 Subject: [PATCH 45/46] link for api docs in admin menu, updated readme --- README.md | 2 ++ flowapp/templates/layouts/default.html | 1 + 2 files changed, 3 insertions(+) diff --git a/README.md b/README.md index 6402104..cdfc622 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,8 @@ It may also be necessary to monitor ExaBGP and re-announce rules after a restart * [Database backup configuration](./docs/DB_BACKUP.md) * [Local database instalation notes](./docs/DB_LOCAL.md) +The REST API is documented using Swagger (OpenAPI). After installing and running the application, the API documentation is available locally at the /apidocs/ endpoint. This interactive documentation provides details about all available endpoints, request and response formats, and supported operations, making it easier to integrate and test the API. + ## Change Log - 1.1.1 - Machine API Key rewrited. - API keys for machines are now tied to one of the existing users. If there is a need to have API access for machine, first create service user, and set the access rights. Then create machine key as Admin and assign it to this user. diff --git a/flowapp/templates/layouts/default.html b/flowapp/templates/layouts/default.html index 47b049c..c6f6e4f 100644 --- a/flowapp/templates/layouts/default.html +++ b/flowapp/templates/layouts/default.html @@ -59,6 +59,7 @@ {% endfor %}
  • ExaFS version {{ session['app_version'] }}
  • +
  • API docs
  • {% endif %} From 8566df990c9062a1e005749596c2e66c4787ffce Mon Sep 17 00:00:00 2001 From: Jiri Vrany Date: Fri, 6 Jun 2025 12:17:07 +0200 Subject: [PATCH 46/46] link to exafs deploy repo in docs --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cdfc622..f6f2ca1 100644 --- a/README.md +++ b/README.md @@ -48,13 +48,16 @@ Each time this process receives a command from ExaFS, it outputs it to `stdout`, It may also be necessary to monitor ExaBGP and re-announce rules after a restart or shutdown. This can be handled via the ExaBGP service configuration, or by using an example system service called **Guarda**, described in the documentation. In either case, the key mechanism is calling the application endpoint `/rules/announce_all`. This endpoint is only accessible from `localhost`; a local IP address must be configured in the application settings. ## DOCS +### Instalation related +* [ExaFS Ansible deploy](https://github.com/CESNET/ExaFS-deploy) - repository with Ansbile playbook for deploying ExaFS with Docker Compose. * [Install notes](./docs/INSTALL.md) -* [API documentation ](https://exafs.docs.apiary.io/#) * [Database backup configuration](./docs/DB_BACKUP.md) * [Local database instalation notes](./docs/DB_LOCAL.md) - +### API The REST API is documented using Swagger (OpenAPI). After installing and running the application, the API documentation is available locally at the /apidocs/ endpoint. This interactive documentation provides details about all available endpoints, request and response formats, and supported operations, making it easier to integrate and test the API. + + ## Change Log - 1.1.1 - Machine API Key rewrited. - API keys for machines are now tied to one of the existing users. If there is a need to have API access for machine, first create service user, and set the access rights. Then create machine key as Admin and assign it to this user.