Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions archinstall/applications/firewall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING

from archinstall.lib.models.application import Firewall, FirewallConfiguration
from archinstall.lib.output import debug

if TYPE_CHECKING:
from archinstall.lib.installer import Installer


class FirewallApp:
@property
def ufw_packages(self) -> list[str]:
return [
'ufw',
]

@property
def ufw_services(self) -> list[str]:
return [
'ufw.service',
]

def install(
self,
install_session: 'Installer',
firewall_config: FirewallConfiguration,
) -> None:
debug(f'Installing firewall: {firewall_config.firewall.value}')

match firewall_config.firewall:
case Firewall.UFW:
install_session.add_additional_packages(self.ufw_packages)
install_session.enable_service(self.ufw_services)
7 changes: 7 additions & 0 deletions archinstall/lib/applications/application_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from archinstall.applications.audio import AudioApp
from archinstall.applications.bluetooth import BluetoothApp
from archinstall.applications.firewall import FirewallApp
from archinstall.applications.power_management import PowerManagementApp
from archinstall.applications.print_service import PrintServiceApp
from archinstall.lib.models import Audio
Expand Down Expand Up @@ -36,5 +37,11 @@ def install_applications(self, install_session: 'Installer', app_config: Applica
if app_config.print_service_config and app_config.print_service_config.enabled:
PrintServiceApp().install(install_session)

if app_config.firewall_config:
FirewallApp().install(
install_session,
app_config.firewall_config,
)


application_handler = ApplicationHandler()
37 changes: 37 additions & 0 deletions archinstall/lib/applications/application_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
Audio,
AudioConfiguration,
BluetoothConfiguration,
Firewall,
FirewallConfiguration,
PowerManagement,
PowerManagementConfiguration,
PrintServiceConfiguration,
Expand Down Expand Up @@ -70,6 +72,12 @@ def _define_menu_options(self) -> list[MenuItem]:
enabled=SysInfo.has_battery(),
key='power_management_config',
),
MenuItem(
text=tr('Firewall'),
action=select_firewall,
preview_action=self._prev_firewall,
key='firewall_config',
),
]

def _prev_power_management(self, item: MenuItem) -> str | None:
Expand Down Expand Up @@ -102,6 +110,12 @@ def _prev_print_service(self, item: MenuItem) -> str | None:
return output
return None

def _prev_firewall(self, item: MenuItem) -> str | None:
if item.value is not None:
config: FirewallConfiguration = item.value
return f'{tr("Firewall")}: {config.firewall.value}'
return None


def select_power_management(preset: PowerManagementConfiguration | None = None) -> PowerManagementConfiguration | None:
group = MenuItemGroup.from_enum(PowerManagement)
Expand Down Expand Up @@ -203,3 +217,26 @@ def select_audio(preset: AudioConfiguration | None = None) -> AudioConfiguration
return AudioConfiguration(audio=result.get_value())
case ResultType.Reset:
raise ValueError('Unhandled result type')


def select_firewall(preset: FirewallConfiguration | None = None) -> FirewallConfiguration | None:
group = MenuItemGroup.from_enum(Firewall)

if preset:
group.set_focus_by_value(preset.firewall)

result = SelectMenu[Firewall](
group,
allow_skip=True,
alignment=Alignment.CENTER,
allow_reset=True,
frame=FrameProperties.min(tr('Firewall')),
).run()

match result.type_:
case ResultType.Skip:
return preset
case ResultType.Selection:
return FirewallConfiguration(firewall=result.get_value())
case ResultType.Reset:
return None
32 changes: 32 additions & 0 deletions archinstall/lib/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ class PrintServiceConfigSerialization(TypedDict):
enabled: bool


class Firewall(StrEnum):
UFW = 'ufw'


class FirewallConfigSerialization(TypedDict):
firewall: str


class ZramAlgorithm(StrEnum):
ZSTD = 'zstd'
LZO_RLE = 'lzo-rle'
Expand All @@ -43,6 +51,7 @@ class ApplicationSerialization(TypedDict):
audio_config: NotRequired[AudioConfigSerialization]
power_management_config: NotRequired[PowerManagementConfigSerialization]
print_service_config: NotRequired[PrintServiceConfigSerialization]
firewall_config: NotRequired[FirewallConfigSerialization]


@dataclass
Expand Down Expand Up @@ -101,6 +110,22 @@ def parse_arg(arg: dict[str, Any]) -> 'PrintServiceConfiguration':
return PrintServiceConfiguration(arg['enabled'])


@dataclass
class FirewallConfiguration:
firewall: Firewall

def json(self) -> FirewallConfigSerialization:
return {
'firewall': self.firewall.value,
}

@staticmethod
def parse_arg(arg: dict[str, Any]) -> 'FirewallConfiguration':
return FirewallConfiguration(
Firewall(arg['firewall']),
)


@dataclass(frozen=True)
class ZramConfiguration:
enabled: bool
Expand All @@ -122,6 +147,7 @@ class ApplicationConfiguration:
audio_config: AudioConfiguration | None = None
power_management_config: PowerManagementConfiguration | None = None
print_service_config: PrintServiceConfiguration | None = None
firewall_config: FirewallConfiguration | None = None

@staticmethod
def parse_arg(
Expand All @@ -146,6 +172,9 @@ def parse_arg(
if args and (print_service_config := args.get('print_service_config')) is not None:
app_config.print_service_config = PrintServiceConfiguration.parse_arg(print_service_config)

if args and (firewall_config := args.get('firewall_config')) is not None:
app_config.firewall_config = FirewallConfiguration.parse_arg(firewall_config)

return app_config

def json(self) -> ApplicationSerialization:
Expand All @@ -163,4 +192,7 @@ def json(self) -> ApplicationSerialization:
if self.print_service_config:
config['print_service_config'] = self.print_service_config.json()

if self.firewall_config:
config['firewall_config'] = self.firewall_config.json()

return config