diff --git a/wish-api/tests/test_placeholder.py b/wish-api/tests/test_placeholder.py new file mode 100644 index 00000000..7848cf07 --- /dev/null +++ b/wish-api/tests/test_placeholder.py @@ -0,0 +1,5 @@ +"""Placeholder test file for wish-api.""" + +def test_placeholder(): + """Placeholder test to satisfy pytest configuration.""" + assert True diff --git a/wish-command-execution/src/wish_command_execution/backend/bash.py b/wish-command-execution/src/wish_command_execution/backend/bash.py index cceafeec..84485770 100644 --- a/wish-command-execution/src/wish_command_execution/backend/bash.py +++ b/wish-command-execution/src/wish_command_execution/backend/bash.py @@ -132,8 +132,14 @@ async def execute_command(self, wish: Wish, command: str, cmd_num: int, log_file Commands are executed in the working directory /app/{run_id}/ to isolate command execution from the application source code. """ - # Create command result - result = CommandResult.create(cmd_num, command, log_files, timeout_sec) + # Create command result with Command object + from wish_models.command_result.command import Command, CommandType + command_obj = Command( + command=command, + tool_type=CommandType.BASH, + tool_parameters={"timeout": timeout_sec} + ) + result = CommandResult.create(cmd_num, command_obj, log_files, timeout_sec) wish.command_results.append(result) diff --git a/wish-command-execution/src/wish_command_execution/backend/sliver.py b/wish-command-execution/src/wish_command_execution/backend/sliver.py index 029c403b..17eb3407 100644 --- a/wish-command-execution/src/wish_command_execution/backend/sliver.py +++ b/wish-command-execution/src/wish_command_execution/backend/sliver.py @@ -67,8 +67,14 @@ async def execute_command(self, wish: Wish, command: str, cmd_num: int, log_file log_files: The log files to write output to. timeout_sec: The timeout in seconds for this command. """ - # Create command result - result = CommandResult.create(cmd_num, command, log_files, timeout_sec) + # Create command result with Command object + from wish_models.command_result.command import Command, CommandType + command_obj = Command( + command=command, + tool_type=CommandType.BASH, # Sliver executes shell commands + tool_parameters={"timeout": timeout_sec, "session_id": self.session_id} + ) + result = CommandResult.create(cmd_num, command_obj, log_files, timeout_sec) wish.command_results.append(result) try: diff --git a/wish-command-execution/tests/backend/test_bash_backend.py b/wish-command-execution/tests/backend/test_bash_backend.py index 2778b8c1..e13ed5a2 100644 --- a/wish-command-execution/tests/backend/test_bash_backend.py +++ b/wish-command-execution/tests/backend/test_bash_backend.py @@ -76,7 +76,7 @@ async def test_execute_command(self, mock_makedirs, mock_open, mock_popen, backe # Verify that the command result was added to the wish assert len(wish.command_results) == 1 - assert wish.command_results[0].command == cmd + assert wish.command_results[0].command.command == cmd assert wish.command_results[0].num == cmd_num # Verify that the command was added to running_commands @@ -113,7 +113,7 @@ async def test_execute_command_subprocess_error( # Verify that the command result was added to the wish assert len(wish.command_results) == 1 - assert wish.command_results[0].command == cmd + assert wish.command_results[0].command.command == cmd assert wish.command_results[0].num == cmd_num # Verify that the command result was updated with the error state diff --git a/wish-command-execution/tests/backend/test_sliver_backend.py b/wish-command-execution/tests/backend/test_sliver_backend.py index c64617ca..a5ca06c8 100644 --- a/wish-command-execution/tests/backend/test_sliver_backend.py +++ b/wish-command-execution/tests/backend/test_sliver_backend.py @@ -92,7 +92,7 @@ async def test_execute_command(sliver_backend, wish, log_files, mock_sliver_clie # Check that the command result was added to the wish assert len(wish.command_results) == 1 - assert wish.command_results[0].command == "whoami" + assert wish.command_results[0].command.command == "whoami" assert wish.command_results[0].num == 1 # Check that the log files were written to @@ -107,7 +107,9 @@ async def test_cancel_command(sliver_backend, wish, log_files): """Test cancelling a command.""" # Add a command result to the wish timeout_sec = 60 - result = CommandResult.create(1, "whoami", log_files, timeout_sec) + from wish_models.command_result.command import Command + command = Command.create_bash_command("whoami") + result = CommandResult.create(1, command, log_files, timeout_sec) wish.command_results.append(result) # Cancel the command diff --git a/wish-command-generation-api/src/wish_command_generation_api/nodes/network_error_handler.py b/wish-command-generation-api/src/wish_command_generation_api/nodes/network_error_handler.py index 41a16619..c8607275 100644 --- a/wish-command-generation-api/src/wish_command_generation_api/nodes/network_error_handler.py +++ b/wish-command-generation-api/src/wish_command_generation_api/nodes/network_error_handler.py @@ -99,7 +99,7 @@ def handle_network_error(state: Annotated[GraphState, "Current state"], settings # Format the feedback as JSON string feedback_str = ( - json.dumps([result.model_dump() for result in state.failed_command_results], ensure_ascii=False) + json.dumps([result.model_dump(mode="json") for result in state.failed_command_results], ensure_ascii=False) if state.failed_command_results else "[]" ) diff --git a/wish-command-generation-api/src/wish_command_generation_api/nodes/timeout_handler.py b/wish-command-generation-api/src/wish_command_generation_api/nodes/timeout_handler.py index ce531a1f..09885082 100644 --- a/wish-command-generation-api/src/wish_command_generation_api/nodes/timeout_handler.py +++ b/wish-command-generation-api/src/wish_command_generation_api/nodes/timeout_handler.py @@ -107,7 +107,7 @@ def handle_timeout(state: Annotated[GraphState, "Current state"], settings_obj: # Format the feedback as JSON string feedback_str = ( - json.dumps([result.model_dump() for result in state.failed_command_results], ensure_ascii=False) + json.dumps([result.model_dump(mode="json") for result in state.failed_command_results], ensure_ascii=False) if state.failed_command_results else "[]" ) diff --git a/wish-command-generation-api/tests/integrated/test_api_with_feedback.py b/wish-command-generation-api/tests/integrated/test_api_with_feedback.py index 526ca581..60b668fc 100644 --- a/wish-command-generation-api/tests/integrated/test_api_with_feedback.py +++ b/wish-command-generation-api/tests/integrated/test_api_with_feedback.py @@ -21,7 +21,11 @@ def test_lambda_handler_with_feedback(settings): act_result = [ { "num": 1, - "command": "nmap -p- 10.10.10.40", + "command": { + "command": "nmap -p- 10.10.10.40", + "tool_type": "bash", + "tool_parameters": {} + }, "state": "TIMEOUT", "exit_code": 1, "log_summary": "timeout", @@ -78,7 +82,11 @@ def test_lambda_handler_with_network_error_feedback(settings): act_result = [ { "num": 1, - "command": "nmap -p- 10.10.10.40", + "command": { + "command": "nmap -p- 10.10.10.40", + "tool_type": "bash", + "tool_parameters": {} + }, "state": "NETWORK_ERROR", "exit_code": 1, "log_summary": "Connection closed by peer", @@ -135,7 +143,11 @@ def test_lambda_handler_with_multiple_feedback(settings): act_result = [ { "num": 1, - "command": "nmap -p1-1000 10.10.10.40", + "command": { + "command": "nmap -p1-1000 10.10.10.40", + "tool_type": "bash", + "tool_parameters": {} + }, "state": "SUCCESS", "exit_code": 0, "log_summary": "Scan completed successfully", @@ -148,7 +160,11 @@ def test_lambda_handler_with_multiple_feedback(settings): }, { "num": 2, - "command": "nmap -p1001-65535 10.10.10.40", + "command": { + "command": "nmap -p1001-65535 10.10.10.40", + "tool_type": "bash", + "tool_parameters": {} + }, "state": "TIMEOUT", "exit_code": 1, "log_summary": "timeout", diff --git a/wish-command-generation-api/tests/integrated/test_docs_integration.py b/wish-command-generation-api/tests/integrated/test_docs_integration.py index 7011e4c7..bf285c64 100644 --- a/wish-command-generation-api/tests/integrated/test_docs_integration.py +++ b/wish-command-generation-api/tests/integrated/test_docs_integration.py @@ -4,6 +4,7 @@ import pytest from wish_models.command_result import CommandResult, CommandState, LogFiles +from wish_models.command_result.command import Command, CommandType from wish_models.settings import Settings from wish_models.utc_datetime import UtcDatetime @@ -63,7 +64,11 @@ def test_command_generation_with_network_error_feedback(settings): act_result = [ CommandResult( num=1, - command="smbclient -N //10.10.10.40/Users --option='client min protocol'=LANMAN1", + command=Command( + command="smbclient -N //10.10.10.40/Users --option='client min protocol'=LANMAN1", + tool_type=CommandType.BASH, + tool_parameters={} + ), state=CommandState.NETWORK_ERROR, exit_code=1, log_summary="Connection closed by peer", @@ -118,7 +123,11 @@ def test_command_generation_with_timeout_feedback(settings): act_result = [ CommandResult( num=1, - command="nmap -p- 10.10.10.40", + command=Command( + command="nmap -p- 10.10.10.40", + tool_type=CommandType.BASH, + tool_parameters={} + ), state=CommandState.TIMEOUT, exit_code=1, log_summary="timeout", diff --git a/wish-command-generation-api/tests/integrated/test_graph_with_feedback.py b/wish-command-generation-api/tests/integrated/test_graph_with_feedback.py index 0a789b90..d642be92 100644 --- a/wish-command-generation-api/tests/integrated/test_graph_with_feedback.py +++ b/wish-command-generation-api/tests/integrated/test_graph_with_feedback.py @@ -4,6 +4,7 @@ import pytest from wish_models.command_result import CommandResult, CommandState, LogFiles +from wish_models.command_result.command import Command, CommandType from wish_models.settings import Settings from wish_models.utc_datetime import UtcDatetime @@ -66,7 +67,11 @@ def test_graph_with_timeout_feedback(settings): act_result = [ CommandResult( num=1, - command="nmap -p- 10.10.10.40", + command=Command( + command="nmap -p- 10.10.10.40", + tool_type=CommandType.BASH, + tool_parameters={} + ), state=CommandState.TIMEOUT, exit_code=1, log_summary="timeout", @@ -126,7 +131,11 @@ def test_graph_with_network_error_feedback(settings): act_result = [ CommandResult( num=1, - command="nmap -p- 10.10.10.40", + command=Command( + command="nmap -p- 10.10.10.40", + tool_type=CommandType.BASH, + tool_parameters={} + ), state=CommandState.NETWORK_ERROR, exit_code=1, log_summary="Connection closed by peer", @@ -186,7 +195,11 @@ def test_graph_with_unknown_error_feedback(settings): act_result = [ CommandResult( num=1, - command="nmap -p- 10.10.10.40", + command=Command( + command="nmap -p- 10.10.10.40", + tool_type=CommandType.BASH, + tool_parameters={} + ), state=CommandState.OTHERS, exit_code=1, log_summary="Unknown error", diff --git a/wish-command-generation-api/tests/unit/test_command_modifier.py b/wish-command-generation-api/tests/unit/test_command_modifier.py index c13e8793..e5fff60e 100644 --- a/wish-command-generation-api/tests/unit/test_command_modifier.py +++ b/wish-command-generation-api/tests/unit/test_command_modifier.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest +from wish_models import CommandState from wish_models.test_factories.command_input_factory import CommandInputFactory from wish_models.test_factories.command_result_factory import CommandResultSuccessFactory from wish_models.test_factories.settings_factory import SettingsFactory @@ -199,8 +200,7 @@ def test_modify_command_preserve_state(mock_modify, settings): # Create a state with additional fields processed_query = "processed test query" failed_command_result = CommandResultSuccessFactory( - command="test command", - state="SUCCESS", + state=CommandState.SUCCESS, exit_code=0, log_summary="success", ) diff --git a/wish-command-generation-api/tests/unit/test_feedback_analyzer.py b/wish-command-generation-api/tests/unit/test_feedback_analyzer.py index 290a7e54..2db70976 100644 --- a/wish-command-generation-api/tests/unit/test_feedback_analyzer.py +++ b/wish-command-generation-api/tests/unit/test_feedback_analyzer.py @@ -36,7 +36,6 @@ def test_analyze_feedback_timeout(settings): """Test analyzing feedback with a TIMEOUT error.""" # Arrange timeout_result = CommandResultSuccessFactory( - command="test command", state=CommandState.TIMEOUT, exit_code=1, log_summary="timeout", @@ -56,7 +55,6 @@ def test_analyze_feedback_network_error(settings): """Test analyzing feedback with a NETWORK_ERROR.""" # Arrange network_error_result = CommandResultSuccessFactory( - command="test command", state=CommandState.NETWORK_ERROR, exit_code=1, log_summary="network error", @@ -77,21 +75,18 @@ def test_analyze_feedback_multiple_errors(settings): # Arrange success_result = CommandResultSuccessFactory( num=1, - command="command1", state=CommandState.SUCCESS, exit_code=0, log_summary="success", ) network_error_result = CommandResultSuccessFactory( num=2, - command="command2", state=CommandState.NETWORK_ERROR, exit_code=1, log_summary="network error", ) timeout_result = CommandResultSuccessFactory( num=3, - command="command3", state=CommandState.TIMEOUT, exit_code=1, log_summary="timeout", @@ -132,7 +127,6 @@ def test_analyze_feedback_preserve_state(settings): CommandInputFactory(command="find . -name '*.py'", timeout_sec=60) ] timeout_result = CommandResultSuccessFactory( - command="test command", state=CommandState.TIMEOUT, exit_code=1, log_summary="timeout", diff --git a/wish-command-generation-api/tests/unit/test_network_error_handler.py b/wish-command-generation-api/tests/unit/test_network_error_handler.py index 0bd6caea..222024c5 100644 --- a/wish-command-generation-api/tests/unit/test_network_error_handler.py +++ b/wish-command-generation-api/tests/unit/test_network_error_handler.py @@ -90,7 +90,7 @@ def test_handle_network_error_success(settings, mock_network_error_response): assert result.is_retry is True assert result.error_type == "NETWORK_ERROR" assert len(result.failed_command_results) == 1 - assert result.failed_command_results[0].command == "nmap -p- 10.10.10.40" + assert result.failed_command_results[0].command.command == "nmap -p- 10.10.10.40" assert result.failed_command_results[0].state == CommandState.NETWORK_ERROR @@ -273,5 +273,5 @@ def test_handle_network_error_preserve_state(mock_handler, settings): assert result.is_retry is True assert result.error_type == "NETWORK_ERROR" assert len(result.failed_command_results) == 1 - assert result.failed_command_results[0].command == "nmap -p- 10.10.10.40" + assert result.failed_command_results[0].command.command == "nmap -p- 10.10.10.40" assert result.failed_command_results[0].state == CommandState.NETWORK_ERROR diff --git a/wish-command-generation-api/tests/unit/test_timeout_handler.py b/wish-command-generation-api/tests/unit/test_timeout_handler.py index 9ab988e3..70532192 100644 --- a/wish-command-generation-api/tests/unit/test_timeout_handler.py +++ b/wish-command-generation-api/tests/unit/test_timeout_handler.py @@ -46,7 +46,6 @@ def test_handle_timeout_not_timeout(settings): """Test handling timeout when the error is not a timeout.""" # Arrange network_error_result = CommandResultSuccessFactory( - command="test command", state=CommandState.NETWORK_ERROR, exit_code=1, log_summary="network error", @@ -93,7 +92,7 @@ def test_handle_timeout_success(mock_handler, settings, mock_timeout_response): assert result.is_retry is True assert result.error_type == "TIMEOUT" assert len(result.failed_command_results) == 1 - assert result.failed_command_results[0].command == "nmap -p- 10.10.10.40" + assert result.failed_command_results[0].command.command == "nmap -p- 10.10.10.40" assert result.failed_command_results[0].state == CommandState.TIMEOUT @@ -239,5 +238,5 @@ def test_handle_timeout_preserve_state(mock_handler, settings): assert result.is_retry is True assert result.error_type == "TIMEOUT" assert len(result.failed_command_results) == 1 - assert result.failed_command_results[0].command == "nmap -p- 10.10.10.40" + assert result.failed_command_results[0].command.command == "nmap -p- 10.10.10.40" assert result.failed_command_results[0].state == CommandState.TIMEOUT diff --git a/wish-log-analysis-api/docs/graph.svg b/wish-log-analysis-api/docs/graph.svg index f173a317..fcd0d5ab 100644 --- a/wish-log-analysis-api/docs/graph.svg +++ b/wish-log-analysis-api/docs/graph.svg @@ -22,7 +22,7 @@ command_state_classifier - + log_summarization->command_state_classifier @@ -34,7 +34,7 @@ result_combiner - + command_state_classifier->result_combiner @@ -46,7 +46,7 @@ __end__ - + result_combiner->__end__ @@ -58,7 +58,7 @@ __start__ - + __start__->log_summarization diff --git a/wish-log-analysis-api/src/wish_log_analysis_api/app.py b/wish-log-analysis-api/src/wish_log_analysis_api/app.py index f206c048..70f4c024 100644 --- a/wish-log-analysis-api/src/wish_log_analysis_api/app.py +++ b/wish-log-analysis-api/src/wish_log_analysis_api/app.py @@ -56,7 +56,7 @@ def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]: "headers": { "Content-Type": "application/json" }, - "body": json.dumps(response.model_dump()) + "body": json.dumps(response.model_dump(mode="json")) } except Exception as e: logger.exception("Error handling request") diff --git a/wish-log-analysis-api/tests/integration/test_library_usage.py b/wish-log-analysis-api/tests/integration/test_library_usage.py index 0dbc203d..65456f28 100644 --- a/wish-log-analysis-api/tests/integration/test_library_usage.py +++ b/wish-log-analysis-api/tests/integration/test_library_usage.py @@ -58,9 +58,10 @@ def test_end_to_end_analysis(mock_chat_openai, mock_str_output_parser): try: # Create command result log_files = LogFiles(stdout=Path(stdout_path), stderr=Path(stderr_path)) + from wish_models.command_result.command import Command command_result = CommandResult( num=1, - command="echo 'test'", + command=Command.create_bash_command("echo 'test'"), state=CommandState.DOING, exit_code=0, log_files=log_files, @@ -138,9 +139,10 @@ def test_custom_config_integration(mock_chat_openai, mock_str_output_parser): try: # Create command result log_files = LogFiles(stdout=Path(stdout_path), stderr=Path(stderr_path)) + from wish_models.command_result.command import Command command_result = CommandResult( num=1, - command="unknown_command", + command=Command.create_bash_command("unknown_command"), state=CommandState.DOING, exit_code=127, log_files=log_files, diff --git a/wish-log-analysis-api/tests/test_app.py b/wish-log-analysis-api/tests/test_app.py index c9a028fd..0913357b 100644 --- a/wish-log-analysis-api/tests/test_app.py +++ b/wish-log-analysis-api/tests/test_app.py @@ -46,7 +46,6 @@ def mock_openai_api(): def command_result(): """Create a test command result.""" return CommandResultSuccessFactory.build( - command="ls -la", stdout="file1.txt\nfile2.txt", stderr=None, exit_code=0, @@ -59,7 +58,6 @@ def command_result(): def analyzed_command_result(): """Create a test analyzed command result.""" return CommandResultSuccessFactory.build( - command="ls -la", stdout="file1.txt\nfile2.txt", stderr=None, exit_code=0, @@ -73,7 +71,7 @@ def lambda_event(command_result): """Create a test Lambda event.""" return { "body": json.dumps({ - "command_result": command_result.model_dump() + "command_result": command_result.model_dump(mode="json") }) } @@ -168,7 +166,10 @@ def test_handler_success(self, lambda_event, analyzed_command_result, command_re body = json.loads(response["body"]) assert "analyzed_command_result" in body - assert body["analyzed_command_result"]["command"] == "ls -la" + assert ( + body["analyzed_command_result"]["command"]["command"] + == analyzed_command_result.command.command + ) assert body["analyzed_command_result"]["state"] == "SUCCESS" # Just check that log_summary exists, not its exact content assert "log_summary" in body["analyzed_command_result"] diff --git a/wish-log-analysis-api/tests/unit/test_analyzer.py b/wish-log-analysis-api/tests/unit/test_analyzer.py index c3694dee..9c59bbf7 100644 --- a/wish-log-analysis-api/tests/unit/test_analyzer.py +++ b/wish-log-analysis-api/tests/unit/test_analyzer.py @@ -38,10 +38,11 @@ def sample_log_files(): @pytest.fixture def sample_command_result(sample_log_files): """Create a sample command result for testing""" + from wish_models.command_result.command import Command return CommandResult( num=1, - command="ls -la", - state="DOING", + command=Command.create_bash_command("ls -la"), + state=CommandState.DOING, exit_code=0, log_files=sample_log_files, created_at=UtcDatetime.now(), diff --git a/wish-log-analysis/src/wish_log_analysis/app.py b/wish-log-analysis/src/wish_log_analysis/app.py index fec8a83f..9fd0c0f7 100644 --- a/wish-log-analysis/src/wish_log_analysis/app.py +++ b/wish-log-analysis/src/wish_log_analysis/app.py @@ -37,7 +37,7 @@ def analyze(self, command_result: CommandResult) -> LogAnalysisOutput: # Send API request try: print(f"API request destination: {self.api_url}") - request_data = {"command_result": command_result.model_dump()} + request_data = {"command_result": command_result.model_dump(mode="json")} print(f"Request data: {json.dumps(request_data, indent=2)}") response = requests.post( diff --git a/wish-log-analysis/tests/test_app.py b/wish-log-analysis/tests/test_app.py index 9837f778..d1b43146 100644 --- a/wish-log-analysis/tests/test_app.py +++ b/wish-log-analysis/tests/test_app.py @@ -38,7 +38,6 @@ def test_analyze_success(self): """Test that the client successfully analyzes a command result.""" # Create a command result command_result = CommandResultDoingFactory.build( - command="ls -la", log_files={"stdout": "file1.txt\nfile2.txt", "stderr": ""}, exit_code=0, log_summary=None, @@ -46,7 +45,6 @@ def test_analyze_success(self): # Create the expected response analyzed_result = CommandResultSuccessFactory.build( - command="ls -la", log_files={"stdout": "file1.txt\nfile2.txt", "stderr": ""}, exit_code=0, log_summary="Listed files: file1.txt, file2.txt", @@ -57,7 +55,7 @@ def test_analyze_success(self): m.post( "http://localhost:3000/analyze", json={ - "analyzed_command_result": analyzed_result.model_dump(), + "analyzed_command_result": analyzed_result.model_dump(mode="json"), "error": None, }, ) @@ -73,14 +71,13 @@ def test_analyze_success(self): # Verify the request assert m.last_request.json() == { - "command_result": command_result.model_dump() + "command_result": command_result.model_dump(mode="json") } def test_analyze_api_error(self): """Test that the client handles API errors.""" # Create a command result command_result = CommandResultDoingFactory.build( - command="ls -la", log_files={"stdout": "file1.txt\nfile2.txt", "stderr": ""}, exit_code=0, log_summary=None, @@ -91,7 +88,7 @@ def test_analyze_api_error(self): m.post( "http://localhost:3000/analyze", json={ - "analyzed_command_result": command_result.model_dump(), + "analyzed_command_result": command_result.model_dump(mode="json"), "error": "API error", }, ) @@ -109,7 +106,6 @@ def test_analyze_request_exception(self): """Test that the client handles request exceptions.""" # Create a command result command_result = CommandResultDoingFactory.build( - command="ls -la", log_files={"stdout": "file1.txt\nfile2.txt", "stderr": ""}, exit_code=0, log_summary=None, @@ -136,7 +132,6 @@ def test_analyze_result(): """Test the analyze_result function.""" # Create a command result command_result = CommandResultDoingFactory.build( - command="ls -la", log_files={"stdout": "file1.txt\nfile2.txt", "stderr": ""}, exit_code=0, log_summary=None, @@ -144,20 +139,21 @@ def test_analyze_result(): ) # Create the expected response + from wish_models.command_result.command import Command analyzed_result = CommandResultSuccessFactory.build( - command="ls -la", log_files={"stdout": "file1.txt\nfile2.txt", "stderr": ""}, exit_code=0, log_summary="Listed files: file1.txt, file2.txt", timeout_sec=60, ) + analyzed_result.command = Command.create_bash_command("ls -la") # Mock the API response with requests_mock.Mocker() as m: m.post( "http://localhost:3000/analyze", json={ - "analyzed_command_result": analyzed_result.model_dump(), + "analyzed_command_result": analyzed_result.model_dump(mode="json"), "error": None, }, ) @@ -166,7 +162,7 @@ def test_analyze_result(): result = analyze_result(command_result) # Verify the result - assert result.command == "ls -la" + assert result.command == command_result.command # Command should be unchanged assert result.exit_code == 0 assert result.state == CommandState.SUCCESS assert result.log_summary == "Listed files: file1.txt, file2.txt" diff --git a/wish-models/src/wish_models/command_result/command.py b/wish-models/src/wish_models/command_result/command.py new file mode 100644 index 00000000..ebc6f086 --- /dev/null +++ b/wish-models/src/wish_models/command_result/command.py @@ -0,0 +1,144 @@ +"""Command model for representing commands with tool information.""" + +import json +from enum import Enum +from typing import Any, Dict + +from pydantic import BaseModel + + +class CommandType(Enum): + """Supported command execution tools.""" + + BASH = "bash" + MSFCONSOLE = "msfconsole" + # Future extensions + PYTHON = "python" + POWERSHELL = "powershell" + + +class Command(BaseModel): + """Represents a command with tool type and parameters.""" + + command: str + """The actual command string to execute.""" + + tool_type: CommandType + """The tool that should execute this command.""" + + tool_parameters: Dict[str, Any] = {} + """Tool-specific parameters for command execution.""" + + def to_json(self) -> str: + """Serialize Command to JSON string. + + Returns: + JSON string representation of the Command. + """ + return self.model_dump_json() + + @classmethod + def from_json(cls, json_str: str) -> "Command": + """Deserialize Command from JSON string. + + Args: + json_str: JSON string representation of a Command. + + Returns: + Command object. + + Raises: + ValueError: If JSON is invalid or doesn't represent a valid Command. + """ + return cls.model_validate_json(json_str) + + @classmethod + def create_bash_command( + cls, + command: str, + timeout: int = 300, + category: str = "", + **kwargs + ) -> "Command": + """Create a bash command with standard parameters. + + Args: + command: The bash command to execute. + timeout: Command timeout in seconds. + category: Command category hint (network, file, process, etc.). + **kwargs: Additional tool parameters. + + Returns: + Command configured for bash execution. + """ + tool_parameters = { + "timeout": timeout, + **kwargs + } + if category: + tool_parameters["category"] = category + + return cls( + command=command, + tool_type=CommandType.BASH, + tool_parameters=tool_parameters + ) + + @classmethod + def create_msfconsole_command( + cls, + command: str, + module: str = "", + rhosts: str = "", + lhost: str = "", + **kwargs + ) -> "Command": + """Create an msfconsole command with standard parameters. + + Args: + command: The msfconsole command sequence to execute. + module: The metasploit module path. + rhosts: Target host(s). + lhost: Local host for reverse connections. + **kwargs: Additional tool parameters. + + Returns: + Command configured for msfconsole execution. + """ + tool_parameters = {**kwargs} + if module: + tool_parameters["module"] = module + if rhosts: + tool_parameters["rhosts"] = rhosts + if lhost: + tool_parameters["lhost"] = lhost + + return cls( + command=command, + tool_type=CommandType.MSFCONSOLE, + tool_parameters=tool_parameters + ) + + +def parse_command_from_string(input_str: str) -> Command: + """Parse command input from string (JSON or plain command). + + Args: + input_str: Either a JSON string representing a Command or plain command string. + + Returns: + Command object. + + Raises: + ValueError: If input cannot be parsed as a valid command. + """ + # Try to parse as JSON first + try: + data = json.loads(input_str) + if isinstance(data, dict) and "command" in data and "tool_type" in data: + return Command.model_validate(data) + except (json.JSONDecodeError, ValueError): + pass + + # If not JSON, treat as plain bash command + return Command.create_bash_command(input_str) diff --git a/wish-models/src/wish_models/command_result/command_result.py b/wish-models/src/wish_models/command_result/command_result.py index ff9f9797..1f8e3ab1 100644 --- a/wish-models/src/wish_models/command_result/command_result.py +++ b/wish-models/src/wish_models/command_result/command_result.py @@ -2,6 +2,7 @@ from pydantic import BaseModel +from wish_models.command_result.command import Command from wish_models.command_result.command_state import CommandState from wish_models.command_result.log_files import LogFiles from wish_models.utc_datetime import UtcDatetime @@ -15,7 +16,7 @@ class CommandResult(BaseModel): It starts from 1.""" - command: str + command: Command """Command executed.""" state: CommandState @@ -52,7 +53,7 @@ class CommandResult(BaseModel): It's None before the command is finished.""" @classmethod - def create(cls, num: int, command: str, log_files: LogFiles, timeout_sec: int) -> "CommandResult": + def create(cls, num: int, command: Command, log_files: LogFiles, timeout_sec: int) -> "CommandResult": return cls( num=num, command=command, @@ -74,7 +75,15 @@ def to_json(self) -> str: return self.model_dump_json(indent=2) def to_dict(self) -> dict: - return self.model_dump() + return self.model_dump(mode="json") + + def get_command_string(self) -> str: + """Get the command string for execution. + + Returns: + The command string from the Command object. + """ + return self.command.command def finish(self, exit_code: int, state: CommandState | None = None, log_summary: str | None = None) -> None: """Mark the command as finished and update its state. diff --git a/wish-models/src/wish_models/test_factories/__init__.py b/wish-models/src/wish_models/test_factories/__init__.py index c7cffad9..178fd84c 100644 --- a/wish-models/src/wish_models/test_factories/__init__.py +++ b/wish-models/src/wish_models/test_factories/__init__.py @@ -1,3 +1,4 @@ +from .command_factory import BashCommandFactory, CommandFactory, MsfconsoleCommandFactory from .command_result_factory import CommandResultDoingFactory, CommandResultSuccessFactory from .executable_collection_factory import ExecutableCollectionFactory from .executable_info_factory import ExecutableInfoFactory @@ -8,6 +9,9 @@ from .wish_factory import WishDoingFactory, WishDoneFactory __all__ = [ + "CommandFactory", + "BashCommandFactory", + "MsfconsoleCommandFactory", "CommandResultDoingFactory", "CommandResultSuccessFactory", "LogFilesFactory", diff --git a/wish-models/src/wish_models/test_factories/command_factory.py b/wish-models/src/wish_models/test_factories/command_factory.py new file mode 100644 index 00000000..bef2a04d --- /dev/null +++ b/wish-models/src/wish_models/test_factories/command_factory.py @@ -0,0 +1,37 @@ +"""Factory class for Command model.""" + +import factory + +from wish_models.command_result.command import Command, CommandType + + +class CommandFactory(factory.Factory): + """Factory for creating Command objects.""" + + class Meta: + model = Command + + command = factory.Faker("sentence") + tool_type = CommandType.BASH + tool_parameters = factory.Dict({"timeout": 300}) + + +class BashCommandFactory(CommandFactory): + """Factory for creating bash commands.""" + + tool_type = CommandType.BASH + tool_parameters = factory.Dict({ + "timeout": 300, + "category": factory.Faker("random_element", elements=["network", "file", "process", "system"]) + }) + + +class MsfconsoleCommandFactory(CommandFactory): + """Factory for creating msfconsole commands.""" + + tool_type = CommandType.MSFCONSOLE + tool_parameters = factory.Dict({ + "module": factory.Faker("file_path"), + "rhosts": factory.Faker("ipv4"), + "lhost": factory.Faker("ipv4") + }) diff --git a/wish-models/src/wish_models/test_factories/command_result_factory.py b/wish-models/src/wish_models/test_factories/command_result_factory.py index d5cbfcbc..86e89727 100644 --- a/wish-models/src/wish_models/test_factories/command_result_factory.py +++ b/wish-models/src/wish_models/test_factories/command_result_factory.py @@ -1,6 +1,7 @@ import factory from wish_models import CommandResult, CommandState +from wish_models.test_factories.command_factory import BashCommandFactory from wish_models.test_factories.log_files_factory import LogFilesFactory from wish_models.test_factories.utc_datetime_factory import UtcDatetimeFactory @@ -10,7 +11,7 @@ class Meta: model = CommandResult num = factory.Faker("random_int", min=1) - command = factory.Faker("sentence") + command = factory.SubFactory(BashCommandFactory) state = CommandState.SUCCESS timeout_sec = 60 exit_code = 0 @@ -25,7 +26,7 @@ class Meta: model = CommandResult num = factory.Faker("random_int", min=1) - command = factory.Faker("sentence") + command = factory.SubFactory(BashCommandFactory) state = CommandState.DOING timeout_sec = 60 log_files = factory.SubFactory(LogFilesFactory) diff --git a/wish-models/src/wish_models/test_factories/command_result_network_error_factory.py b/wish-models/src/wish_models/test_factories/command_result_network_error_factory.py index b8a025de..859cecd5 100644 --- a/wish-models/src/wish_models/test_factories/command_result_network_error_factory.py +++ b/wish-models/src/wish_models/test_factories/command_result_network_error_factory.py @@ -5,6 +5,8 @@ import factory from wish_models import CommandResult, CommandState +from wish_models.command_result.command import Command +from wish_models.test_factories.command_factory import BashCommandFactory from wish_models.test_factories.log_files_factory import LogFilesFactory from wish_models.test_factories.utc_datetime_factory import UtcDatetimeFactory @@ -16,7 +18,7 @@ class Meta: model = CommandResult num = factory.Faker("random_int", min=1) - command = factory.Faker("sentence") + command = factory.SubFactory(BashCommandFactory) state = CommandState.NETWORK_ERROR timeout_sec = 60 exit_code = 1 @@ -38,7 +40,7 @@ def create_with_command(cls, command: str, timeout_sec: int = 60, log_summary: s CommandResult: ネットワークエラーのCommandResult """ return cls( - command=command, + command=Command.create_bash_command(command), timeout_sec=timeout_sec, log_summary=log_summary or "Connection closed by peer" ) diff --git a/wish-models/src/wish_models/test_factories/command_result_timeout_factory.py b/wish-models/src/wish_models/test_factories/command_result_timeout_factory.py index 460e31b5..928f18df 100644 --- a/wish-models/src/wish_models/test_factories/command_result_timeout_factory.py +++ b/wish-models/src/wish_models/test_factories/command_result_timeout_factory.py @@ -5,6 +5,8 @@ import factory from wish_models import CommandResult, CommandState +from wish_models.command_result.command import Command +from wish_models.test_factories.command_factory import BashCommandFactory from wish_models.test_factories.log_files_factory import LogFilesFactory from wish_models.test_factories.utc_datetime_factory import UtcDatetimeFactory @@ -16,7 +18,7 @@ class Meta: model = CommandResult num = factory.Faker("random_int", min=1) - command = factory.Faker("sentence") + command = factory.SubFactory(BashCommandFactory) state = CommandState.TIMEOUT timeout_sec = 60 exit_code = 1 @@ -37,6 +39,6 @@ def create_with_command(cls, command: str, timeout_sec: int = 60) -> CommandResu CommandResult: タイムアウトエラーのCommandResult """ return cls( - command=command, + command=Command.create_bash_command(command), timeout_sec=timeout_sec ) diff --git a/wish-models/src/wish_models/wish/wish.py b/wish-models/src/wish_models/wish/wish.py index c4ab5861..c248b197 100644 --- a/wish-models/src/wish_models/wish/wish.py +++ b/wish-models/src/wish_models/wish/wish.py @@ -52,7 +52,7 @@ def to_json(self) -> str: return self.model_dump_json(indent=2) def to_dict(self) -> dict: - return self.model_dump() + return self.model_dump(mode="json") @classmethod def _gen_id(cls) -> str: diff --git a/wish-models/tests/command_result/test_command_result.py b/wish-models/tests/command_result/test_command_result.py index dd214670..32b2d5d5 100644 --- a/wish-models/tests/command_result/test_command_result.py +++ b/wish-models/tests/command_result/test_command_result.py @@ -16,7 +16,11 @@ def test_command_result_creation(self): now = UtcDatetime.now() data = { "num": 1, - "command": "echo hello", + "command": { + "command": "echo hello", + "tool_type": "bash", + "tool_parameters": {} + }, "timeout_sec": 0, "exit_code": 0, "state": "SUCCESS", @@ -26,7 +30,7 @@ def test_command_result_creation(self): "finished_at": now, } cr = CommandResult.from_dict(data) - assert cr.command == "echo hello" + assert cr.command.command == "echo hello" assert cr.timeout_sec == 0 assert cr.exit_code == 0 assert cr.state == CommandState.SUCCESS @@ -51,7 +55,8 @@ def test_to_json_and_from_json(self, command_result): def test_create(self): # Arrange num = 1 - command = "echo hello" + from wish_models.command_result.command import Command + command = Command.create_bash_command("echo hello") log_files = LogFilesFactory.create() now = UtcDatetime.now() timeout_sec = 60 @@ -137,5 +142,5 @@ def test_parse_command_results_json(): json_str = json.dumps(data_list) results = parse_command_results_json(json_str) assert len(results) == 2 - assert results[0].command == data_list[0]["command"] - assert results[1].command == data_list[1]["command"] + assert results[0].command.command == data_list[0]["command"]["command"] + assert results[1].command.command == data_list[1]["command"]["command"]