|
13 | 13 | import time |
14 | 14 | import uuid |
15 | 15 | from pathlib import Path |
16 | | -from typing import Callable, Optional |
| 16 | +from typing import TYPE_CHECKING, Callable, Optional |
17 | 17 |
|
18 | 18 | import paramiko |
19 | 19 | import pytest |
|
22 | 22 | from helpers.ssh_client import REFSSHClient |
23 | 23 | from helpers.web_client import REFWebClient |
24 | 24 |
|
| 25 | +if TYPE_CHECKING: |
| 26 | + from helpers.ref_instance import REFInstance |
| 27 | + |
25 | 28 | SSHClientFactory = Callable[[str, str], REFSSHClient] |
26 | 29 |
|
27 | 30 |
|
| 31 | +def _enable_tcp_forwarding(ref_instance: "REFInstance") -> bool: |
| 32 | + """Enable TCP port forwarding in system settings.""" |
| 33 | + |
| 34 | + def _enable() -> bool: |
| 35 | + from flask import current_app |
| 36 | + |
| 37 | + from ref.model.settings import SystemSettingsManager |
| 38 | + |
| 39 | + SystemSettingsManager.ALLOW_TCP_PORT_FORWARDING.value = True |
| 40 | + current_app.db.session.commit() |
| 41 | + return True |
| 42 | + |
| 43 | + return ref_instance.remote_exec(_enable) |
| 44 | + |
| 45 | + |
| 46 | +def _disable_tcp_forwarding(ref_instance: "REFInstance") -> bool: |
| 47 | + """Disable TCP port forwarding in system settings.""" |
| 48 | + |
| 49 | + def _disable() -> bool: |
| 50 | + from flask import current_app |
| 51 | + |
| 52 | + from ref.model.settings import SystemSettingsManager |
| 53 | + |
| 54 | + SystemSettingsManager.ALLOW_TCP_PORT_FORWARDING.value = False |
| 55 | + current_app.db.session.commit() |
| 56 | + return True |
| 57 | + |
| 58 | + return ref_instance.remote_exec(_disable) |
| 59 | + |
| 60 | + |
| 61 | +def _get_tcp_forwarding_setting(ref_instance: "REFInstance") -> bool: |
| 62 | + """Get the current TCP port forwarding setting value.""" |
| 63 | + |
| 64 | + def _get() -> bool: |
| 65 | + from ref.model.settings import SystemSettingsManager |
| 66 | + |
| 67 | + return SystemSettingsManager.ALLOW_TCP_PORT_FORWARDING.value # type: ignore[return-value] |
| 68 | + |
| 69 | + return ref_instance.remote_exec(_get) |
| 70 | + |
| 71 | + |
28 | 72 | class PortForwardingTestState: |
29 | 73 | """Shared state for port forwarding tests.""" |
30 | 74 |
|
@@ -72,6 +116,19 @@ def test_01_admin_login( |
72 | 116 | success = web_client.login("0", admin_password) |
73 | 117 | assert success, "Admin login failed" |
74 | 118 |
|
| 119 | + @pytest.mark.e2e |
| 120 | + def test_01b_enable_tcp_forwarding( |
| 121 | + self, |
| 122 | + ref_instance: "REFInstance", |
| 123 | + ): |
| 124 | + """Enable TCP port forwarding in system settings.""" |
| 125 | + result = _enable_tcp_forwarding(ref_instance) |
| 126 | + assert result is True, "Failed to enable TCP port forwarding" |
| 127 | + |
| 128 | + # Verify the setting was actually changed |
| 129 | + value = _get_tcp_forwarding_setting(ref_instance) |
| 130 | + assert value is True, "TCP port forwarding setting not enabled" |
| 131 | + |
75 | 132 | @pytest.mark.e2e |
76 | 133 | def test_02_create_exercise( |
77 | 134 | self, |
@@ -452,160 +509,6 @@ def test_http_server_request_response( |
452 | 509 | pass |
453 | 510 | client.close() |
454 | 511 |
|
455 | | - @pytest.mark.e2e |
456 | | - def test_multiple_concurrent_channels( |
457 | | - self, |
458 | | - ssh_host: str, |
459 | | - ssh_port: int, |
460 | | - port_forwarding_state: PortForwardingTestState, |
461 | | - ): |
462 | | - """ |
463 | | - Test multiple concurrent port forwarding channels. |
464 | | -
|
465 | | - This test verifies that multiple forwarding channels can be |
466 | | - opened and used simultaneously over the same SSH connection. |
467 | | - """ |
468 | | - assert port_forwarding_state.student_private_key is not None |
469 | | - assert port_forwarding_state.exercise_name is not None |
470 | | - |
471 | | - pkey = _parse_private_key(port_forwarding_state.student_private_key) |
472 | | - client = _create_ssh_client( |
473 | | - ssh_host, ssh_port, port_forwarding_state.exercise_name, pkey |
474 | | - ) |
475 | | - |
476 | | - test_ports = [19881, 19882, 19883] |
477 | | - |
478 | | - try: |
479 | | - # Write and start echo servers on multiple ports |
480 | | - sftp = client.open_sftp() |
481 | | - sftp.file("/tmp/echo_server.py", "w").write(ECHO_SERVER_SCRIPT) |
482 | | - sftp.close() |
483 | | - |
484 | | - for port in test_ports: |
485 | | - _, stdout, _ = client.exec_command( |
486 | | - f"python3 /tmp/echo_server.py {port} &" |
487 | | - ) |
488 | | - stdout.channel.recv_exit_status() |
489 | | - |
490 | | - time.sleep(0.5) |
491 | | - |
492 | | - transport = client.get_transport() |
493 | | - assert transport is not None |
494 | | - |
495 | | - # Open channels to all servers |
496 | | - channels = [] |
497 | | - for port in test_ports: |
498 | | - channel = transport.open_channel( |
499 | | - "direct-tcpip", |
500 | | - ("127.0.0.1", port), |
501 | | - ("127.0.0.1", 0), |
502 | | - ) |
503 | | - channel.settimeout(10.0) |
504 | | - channels.append((port, channel)) |
505 | | - |
506 | | - # Send data through all channels and verify responses |
507 | | - for port, channel in channels: |
508 | | - test_msg = f"Message to port {port}".encode() |
509 | | - channel.sendall(test_msg) |
510 | | - response = channel.recv(1024) |
511 | | - expected = b"ECHO:" + test_msg |
512 | | - assert response == expected, ( |
513 | | - f"Port {port}: Expected {expected!r}, got {response!r}" |
514 | | - ) |
515 | | - |
516 | | - # Close all channels |
517 | | - for _, channel in channels: |
518 | | - channel.close() |
519 | | - |
520 | | - finally: |
521 | | - # Cleanup |
522 | | - try: |
523 | | - for port in test_ports: |
524 | | - client.exec_command(f"pkill -f 'echo_server.py {port}'") |
525 | | - client.exec_command("rm -f /tmp/echo_server.py") |
526 | | - except Exception: |
527 | | - pass |
528 | | - client.close() |
529 | | - |
530 | | - @pytest.mark.e2e |
531 | | - def test_large_data_transfer( |
532 | | - self, |
533 | | - ssh_host: str, |
534 | | - ssh_port: int, |
535 | | - port_forwarding_state: PortForwardingTestState, |
536 | | - ): |
537 | | - """ |
538 | | - Test transferring larger amounts of data through port forwarding. |
539 | | -
|
540 | | - This verifies that the forwarding handles data beyond single packets. |
541 | | - """ |
542 | | - assert port_forwarding_state.student_private_key is not None |
543 | | - assert port_forwarding_state.exercise_name is not None |
544 | | - |
545 | | - pkey = _parse_private_key(port_forwarding_state.student_private_key) |
546 | | - client = _create_ssh_client( |
547 | | - ssh_host, ssh_port, port_forwarding_state.exercise_name, pkey |
548 | | - ) |
549 | | - |
550 | | - test_port = 19884 |
551 | | - |
552 | | - try: |
553 | | - # Write the echo server script |
554 | | - sftp = client.open_sftp() |
555 | | - sftp.file("/tmp/echo_server.py", "w").write(ECHO_SERVER_SCRIPT) |
556 | | - sftp.close() |
557 | | - |
558 | | - # Start the echo server |
559 | | - _, stdout, _ = client.exec_command( |
560 | | - f"python3 /tmp/echo_server.py {test_port} &" |
561 | | - ) |
562 | | - stdout.channel.recv_exit_status() |
563 | | - time.sleep(0.5) |
564 | | - |
565 | | - transport = client.get_transport() |
566 | | - assert transport is not None |
567 | | - |
568 | | - # Open channel |
569 | | - channel = transport.open_channel( |
570 | | - "direct-tcpip", |
571 | | - ("127.0.0.1", test_port), |
572 | | - ("127.0.0.1", 0), |
573 | | - ) |
574 | | - channel.settimeout(10.0) |
575 | | - |
576 | | - # Send larger data (64KB) |
577 | | - large_data = b"X" * (64 * 1024) |
578 | | - channel.sendall(large_data) |
579 | | - |
580 | | - # Receive response |
581 | | - response = b"" |
582 | | - expected_len = len(b"ECHO:") + len(large_data) |
583 | | - while len(response) < expected_len: |
584 | | - try: |
585 | | - chunk = channel.recv(8192) |
586 | | - if not chunk: |
587 | | - break |
588 | | - response += chunk |
589 | | - except socket.timeout: |
590 | | - break |
591 | | - |
592 | | - channel.close() |
593 | | - |
594 | | - # Verify response |
595 | | - assert response.startswith(b"ECHO:"), "Response should start with ECHO:" |
596 | | - assert len(response) == expected_len, ( |
597 | | - f"Expected {expected_len} bytes, got {len(response)}" |
598 | | - ) |
599 | | - |
600 | | - finally: |
601 | | - # Cleanup |
602 | | - try: |
603 | | - client.exec_command(f"pkill -f 'echo_server.py {test_port}'") |
604 | | - client.exec_command("rm -f /tmp/echo_server.py") |
605 | | - except Exception: |
606 | | - pass |
607 | | - client.close() |
608 | | - |
609 | 512 | @pytest.mark.e2e |
610 | 513 | def test_direct_tcpip_channel_can_be_opened( |
611 | 514 | self, |
@@ -912,8 +815,81 @@ def test_remote_port_forwarding_request( |
912 | 815 | # Remote port forwarding might be restricted |
913 | 816 | # This is acceptable - we're just testing the capability |
914 | 817 | if "rejected" in str(e).lower() or "denied" in str(e).lower(): |
| 818 | + # |
915 | 819 | pytest.skip(f"Remote port forwarding not available: {e}") |
916 | 820 | raise |
917 | 821 |
|
918 | 822 | finally: |
919 | 823 | client.close() |
| 824 | + |
| 825 | + |
| 826 | +class TestTCPForwardingSettingEnforcement: |
| 827 | + """ |
| 828 | + Test that TCP port forwarding can be enabled/disabled via system settings. |
| 829 | +
|
| 830 | + These tests verify that the ALLOW_TCP_PORT_FORWARDING setting is properly |
| 831 | + enforced by the SSH server. |
| 832 | + """ |
| 833 | + |
| 834 | + @pytest.mark.e2e |
| 835 | + def test_forwarding_blocked_when_disabled( |
| 836 | + self, |
| 837 | + ssh_host: str, |
| 838 | + ssh_port: int, |
| 839 | + ref_instance: "REFInstance", |
| 840 | + port_forwarding_state: PortForwardingTestState, |
| 841 | + ): |
| 842 | + """ |
| 843 | + Verify TCP forwarding fails when the setting is disabled. |
| 844 | +
|
| 845 | + This test disables TCP forwarding and verifies that opening a |
| 846 | + direct-tcpip channel fails with the expected error. |
| 847 | + """ |
| 848 | + assert port_forwarding_state.student_private_key is not None |
| 849 | + assert port_forwarding_state.exercise_name is not None |
| 850 | + |
| 851 | + # Disable TCP forwarding |
| 852 | + _disable_tcp_forwarding(ref_instance) |
| 853 | + |
| 854 | + # Verify the setting is disabled |
| 855 | + assert _get_tcp_forwarding_setting(ref_instance) is False |
| 856 | + |
| 857 | + pkey = _parse_private_key(port_forwarding_state.student_private_key) |
| 858 | + |
| 859 | + # Need a fresh SSH connection to pick up the new setting |
| 860 | + client = paramiko.SSHClient() |
| 861 | + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) |
| 862 | + |
| 863 | + try: |
| 864 | + client.connect( |
| 865 | + hostname=ssh_host, |
| 866 | + port=ssh_port, |
| 867 | + username=port_forwarding_state.exercise_name, |
| 868 | + pkey=pkey, |
| 869 | + timeout=5.0, |
| 870 | + allow_agent=False, |
| 871 | + look_for_keys=False, |
| 872 | + ) |
| 873 | + |
| 874 | + transport = client.get_transport() |
| 875 | + assert transport is not None |
| 876 | + |
| 877 | + # Try to open a direct-tcpip channel - this should fail |
| 878 | + with pytest.raises(paramiko.ChannelException) as exc_info: |
| 879 | + transport.open_channel( |
| 880 | + "direct-tcpip", |
| 881 | + ("127.0.0.1", 12345), |
| 882 | + ("127.0.0.1", 0), |
| 883 | + timeout=3.0, |
| 884 | + ) |
| 885 | + |
| 886 | + # Error code 1 = "Administratively prohibited" |
| 887 | + # Error code 2 = "Connect failed" (also acceptable) |
| 888 | + assert exc_info.value.code in (1, 2), ( |
| 889 | + f"Expected channel error code 1 or 2, got {exc_info.value.code}" |
| 890 | + ) |
| 891 | + |
| 892 | + finally: |
| 893 | + client.close() |
| 894 | + # Re-enable TCP forwarding for subsequent tests |
| 895 | + _enable_tcp_forwarding(ref_instance) |
0 commit comments