diff --git a/archinstall/applications/firewall.py b/archinstall/applications/firewall.py new file mode 100644 index 0000000000..3100650b95 --- /dev/null +++ b/archinstall/applications/firewall.py @@ -0,0 +1,33 @@ +from typing import TYPE_CHECKING + +from archinstall.lib.models.application import Firewall, FirewallConfiguration +from archinstall.lib.output import debug + +if TYPE_CHECKING: + from archinstall.lib.installer import Installer + + +class FirewallApp: + @property + def ufw_packages(self) -> list[str]: + return [ + 'ufw', + ] + + @property + def ufw_services(self) -> list[str]: + return [ + 'ufw.service', + ] + + def install( + self, + install_session: 'Installer', + firewall_config: FirewallConfiguration, + ) -> None: + debug(f'Installing firewall: {firewall_config.firewall.value}') + + match firewall_config.firewall: + case Firewall.UFW: + install_session.add_additional_packages(self.ufw_packages) + install_session.enable_service(self.ufw_services) diff --git a/archinstall/lib/applications/application_handler.py b/archinstall/lib/applications/application_handler.py index 9bcf2ae015..e7d16058d5 100644 --- a/archinstall/lib/applications/application_handler.py +++ b/archinstall/lib/applications/application_handler.py @@ -2,6 +2,7 @@ from archinstall.applications.audio import AudioApp from archinstall.applications.bluetooth import BluetoothApp +from archinstall.applications.firewall import FirewallApp from archinstall.applications.power_management import PowerManagementApp from archinstall.applications.print_service import PrintServiceApp from archinstall.lib.models import Audio @@ -36,5 +37,11 @@ def install_applications(self, install_session: 'Installer', app_config: Applica if app_config.print_service_config and app_config.print_service_config.enabled: PrintServiceApp().install(install_session) + if app_config.firewall_config: + FirewallApp().install( + install_session, + app_config.firewall_config, + ) + application_handler = ApplicationHandler() diff --git a/archinstall/lib/applications/application_menu.py b/archinstall/lib/applications/application_menu.py index c6ef9d177c..7dd56e9871 100644 --- a/archinstall/lib/applications/application_menu.py +++ b/archinstall/lib/applications/application_menu.py @@ -7,6 +7,8 @@ Audio, AudioConfiguration, BluetoothConfiguration, + Firewall, + FirewallConfiguration, PowerManagement, PowerManagementConfiguration, PrintServiceConfiguration, @@ -70,6 +72,12 @@ def _define_menu_options(self) -> list[MenuItem]: enabled=SysInfo.has_battery(), key='power_management_config', ), + MenuItem( + text=tr('Firewall'), + action=select_firewall, + preview_action=self._prev_firewall, + key='firewall_config', + ), ] def _prev_power_management(self, item: MenuItem) -> str | None: @@ -102,6 +110,12 @@ def _prev_print_service(self, item: MenuItem) -> str | None: return output return None + def _prev_firewall(self, item: MenuItem) -> str | None: + if item.value is not None: + config: FirewallConfiguration = item.value + return f'{tr("Firewall")}: {config.firewall.value}' + return None + def select_power_management(preset: PowerManagementConfiguration | None = None) -> PowerManagementConfiguration | None: group = MenuItemGroup.from_enum(PowerManagement) @@ -203,3 +217,26 @@ def select_audio(preset: AudioConfiguration | None = None) -> AudioConfiguration return AudioConfiguration(audio=result.get_value()) case ResultType.Reset: raise ValueError('Unhandled result type') + + +def select_firewall(preset: FirewallConfiguration | None = None) -> FirewallConfiguration | None: + group = MenuItemGroup.from_enum(Firewall) + + if preset: + group.set_focus_by_value(preset.firewall) + + result = SelectMenu[Firewall]( + group, + allow_skip=True, + alignment=Alignment.CENTER, + allow_reset=True, + frame=FrameProperties.min(tr('Firewall')), + ).run() + + match result.type_: + case ResultType.Skip: + return preset + case ResultType.Selection: + return FirewallConfiguration(firewall=result.get_value()) + case ResultType.Reset: + return None diff --git a/archinstall/lib/models/application.py b/archinstall/lib/models/application.py index 95163e2b1e..8874c5af8d 100644 --- a/archinstall/lib/models/application.py +++ b/archinstall/lib/models/application.py @@ -30,6 +30,14 @@ class PrintServiceConfigSerialization(TypedDict): enabled: bool +class Firewall(StrEnum): + UFW = 'ufw' + + +class FirewallConfigSerialization(TypedDict): + firewall: str + + class ZramAlgorithm(StrEnum): ZSTD = 'zstd' LZO_RLE = 'lzo-rle' @@ -43,6 +51,7 @@ class ApplicationSerialization(TypedDict): audio_config: NotRequired[AudioConfigSerialization] power_management_config: NotRequired[PowerManagementConfigSerialization] print_service_config: NotRequired[PrintServiceConfigSerialization] + firewall_config: NotRequired[FirewallConfigSerialization] @dataclass @@ -101,6 +110,22 @@ def parse_arg(arg: dict[str, Any]) -> 'PrintServiceConfiguration': return PrintServiceConfiguration(arg['enabled']) +@dataclass +class FirewallConfiguration: + firewall: Firewall + + def json(self) -> FirewallConfigSerialization: + return { + 'firewall': self.firewall.value, + } + + @staticmethod + def parse_arg(arg: dict[str, Any]) -> 'FirewallConfiguration': + return FirewallConfiguration( + Firewall(arg['firewall']), + ) + + @dataclass(frozen=True) class ZramConfiguration: enabled: bool @@ -122,6 +147,7 @@ class ApplicationConfiguration: audio_config: AudioConfiguration | None = None power_management_config: PowerManagementConfiguration | None = None print_service_config: PrintServiceConfiguration | None = None + firewall_config: FirewallConfiguration | None = None @staticmethod def parse_arg( @@ -146,6 +172,9 @@ def parse_arg( if args and (print_service_config := args.get('print_service_config')) is not None: app_config.print_service_config = PrintServiceConfiguration.parse_arg(print_service_config) + if args and (firewall_config := args.get('firewall_config')) is not None: + app_config.firewall_config = FirewallConfiguration.parse_arg(firewall_config) + return app_config def json(self) -> ApplicationSerialization: @@ -163,4 +192,7 @@ def json(self) -> ApplicationSerialization: if self.print_service_config: config['print_service_config'] = self.print_service_config.json() + if self.firewall_config: + config['firewall_config'] = self.firewall_config.json() + return config