Skip to content
20 changes: 13 additions & 7 deletions bittensor_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5100,13 +5100,19 @@ def stake_remove(
default=self.config.get("wallet_name") or defaults.wallet.name,
)
if include_hotkeys:
if len(include_hotkeys) > 1:
print_error("Cannot unstake_all from multiple hotkeys at once.")
return False
elif is_valid_ss58_address(include_hotkeys[0]):
hotkey_ss58_address = include_hotkeys[0]
else:
print_error("Invalid hotkey ss58 address.")
# Multiple hotkeys are now supported with batching
# Initialize wallet as it's needed to resolve hotkey names and get coldkey
wallet = self.wallet_ask(
wallet_name,
wallet_path,
wallet_hotkey,
ask_for=[WO.NAME, WO.PATH],
)
if len(include_hotkeys) == 1:
# Single hotkey - use hotkey_ss58_address for backward compatibility
if is_valid_ss58_address(include_hotkeys[0]):
hotkey_ss58_address = include_hotkeys[0]
# If it's a hotkey name, it will be handled by the unstake_all function
return False
elif all_hotkeys:
wallet = self.wallet_ask(
Expand Down
100 changes: 100 additions & 0 deletions bittensor_cli/src/bittensor/subtensor_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,106 @@ async def create_signed(call_to_sign, n):
)
return False, err_msg, None

async def sign_and_send_batch_extrinsic(
self,
calls: list[GenericCall],
wallet: Wallet,
wait_for_inclusion: bool = True,
wait_for_finalization: bool = False,
era: Optional[dict[str, int]] = None,
proxy: Optional[str] = None,
nonce: Optional[int] = None,
sign_with: Literal["coldkey", "hotkey", "coldkeypub"] = "coldkey",
batch_type: Literal["batch", "batch_all"] = "batch",
mev_protection: bool = False,
) -> tuple[bool, str, Optional[AsyncExtrinsicReceipt], list[dict]]:
"""
:param calls: List of prepared Call objects to batch together
:param wallet: the wallet whose key will be used to sign the extrinsic
:param wait_for_inclusion: whether to wait until the extrinsic call is included on the chain
:param wait_for_finalization: whether to wait until the extrinsic call is finalized on the chain
:param era: The length (in blocks) for which a transaction should be valid.
:param proxy: The real account used to create the proxy. None if not using a proxy for this call.
:param nonce: The nonce used to submit this extrinsic call.
:param sign_with: Determine which of the wallet's keypairs to use to sign the extrinsic call.
:param batch_type: "batch" (stops on first error) or "batch_all" (executes all, fails if any fail)
:param mev_protection: If set, uses Mev Protection on the extrinsic, thus encrypting it.

:return: (success, error message, extrinsic receipt | None, list of individual call results)
"""
if not calls:
return False, "No calls provided for batching operation", None, []

if len(calls) == 1:
success, err_msg, receipt = await self.sign_and_send_extrinsic(
call=calls[0],
wallet=wallet,
wait_for_inclusion=wait_for_inclusion,
wait_for_finalization=wait_for_finalization,
era=era,
proxy=proxy,
nonce=nonce,
sign_with=sign_with,
mev_protection=mev_protection,
)
return success, err_msg, receipt, [{"success": success, "error": err_msg}]

batch_call = await self.substrate.compose_call(
call_module="Utility",
call_function=batch_type,
call_params={"calls": calls},
)

success, err_msg, receipt = await self.sign_and_send_extrinsic(
call=batch_call,
wallet=wallet,
wait_for_inclusion=wait_for_inclusion,
wait_for_finalization=wait_for_finalization,
era=era,
proxy=proxy,
nonce=nonce,
sign_with=sign_with,
mev_protection=mev_protection,
)

# Parse batch results if successful
call_results = []
if success and receipt:
try:
# Extract batch execution results from receipt
# The receipt should contain information about which calls succeeded/failed
for i, call in enumerate(calls):
call_results.append(
{
"index": i,
"call": call,
"success": True, # Will be updated if we can parse receipt
}
)
except Exception:
# If we can't parse results, assume all succeeded if batch succeeded
for i, call in enumerate(calls):
call_results.append(
{
"index": i,
"call": call,
"success": success,
}
)
else:
# If batch failed, mark all as failed
for i, call in enumerate(calls):
call_results.append(
{
"index": i,
"call": call,
"success": False,
"error": err_msg,
}
)

return success, err_msg, receipt, call_results

async def get_children(self, hotkey, netuid) -> tuple[bool, list, str]:
"""
This method retrieves the children of a given hotkey and netuid. It queries the SubtensorModule's ChildKeys
Expand Down
240 changes: 200 additions & 40 deletions bittensor_cli/src/commands/stake/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,56 +474,216 @@ async def stake_extrinsic(
successes = defaultdict(dict)
error_messages = defaultdict(dict)
extrinsic_ids = defaultdict(dict)
with console.status(f"\n:satellite: Staking on netuid(s): {netuids} ...") as status:
if safe_staking:
stake_coroutines = {}
for i, (ni, am, curr, price_with_tolerance) in enumerate(
zip(
netuids,
amounts_to_stake,
current_stake_balances,
prices_with_tolerance,

# Collect all calls for batching
calls_to_batch = []
call_metadata = [] # Track (netuid, staking_address, amount, current_stake, price_with_tolerance) for each call

with console.status(
f"\n:satellite: Preparing batch staking on netuid(s): {netuids} ..."
) as status:
# Get next nonce for batch
next_nonce = await subtensor.substrate.get_account_next_index(coldkey_ss58)
# Get block_hash at the beginning to speed up compose_call operations
block_hash = await subtensor.substrate.get_chain_head()

# Collect all calls - iterate through the same order as when building the lists
# The lists are built in order: for each hotkey, for each netuid
list_idx = 0
price_idx = 0
for hotkey in hotkeys_to_stake_to:
for netuid in netuids:
# Safety check: if we've processed all items from the first loop, stop
if list_idx >= len(amounts_to_stake):
break

# Verify subnet exists (same check as first loop)
# If subnet doesn't exist, it was skipped in first loop, so list_idx won't advance
# We need to skip it here too to stay in sync
subnet_info = all_subnets.get(netuid)
if not subnet_info:
# This netuid was skipped in first loop (doesn't exist)
# Don't advance list_idx, just continue to next netuid
continue

am = amounts_to_stake[list_idx]
curr = current_stake_balances[list_idx]
staking_address = hotkey[1]
price_with_tol = (
prices_with_tolerance[price_idx]
if safe_staking and price_idx < len(prices_with_tolerance)
else None
)
):
for _, staking_address in hotkeys_to_stake_to:
# Regular extrinsic for root subnet
if ni == 0:
stake_coroutines[(ni, staking_address)] = stake_extrinsic(
netuid_i=ni,
amount_=am,
current=curr,
staking_address_ss58=staking_address,
status_=status,
)
else:
stake_coroutines[(ni, staking_address)] = safe_stake_extrinsic(
netuid_=ni,
amount_=am,
current_stake=curr,
hotkey_ss58_=staking_address,
price_limit=price_with_tolerance,
status_=status,
if safe_staking:
price_idx += 1

call_metadata.append(
(netuid, staking_address, am, curr, price_with_tol)
)

if safe_staking and netuid != 0 and price_with_tol:
# Safe staking for non-root subnets
call = await subtensor.substrate.compose_call(
call_module="SubtensorModule",
call_function="add_stake_limit",
call_params={
"hotkey": staking_address,
"netuid": netuid,
"amount_staked": am.rao,
"limit_price": price_with_tol.rao,
"allow_partial": allow_partial_stake,
},
block_hash=block_hash,
)
else:
# Regular staking for root subnet or non-safe staking
call = await subtensor.substrate.compose_call(
call_module="SubtensorModule",
call_function="add_stake",
call_params={
"hotkey": staking_address,
"netuid": netuid,
"amount_staked": am.rao,
},
block_hash=block_hash,
)
calls_to_batch.append(call)
list_idx += 1

# If we have multiple calls, batch them; otherwise send single call
if len(calls_to_batch) > 1:
status.update(
f"\n:satellite: Batching {len(calls_to_batch)} stake operations..."
)
(
batch_success,
batch_err_msg,
batch_receipt,
call_results,
) = await subtensor.sign_and_send_batch_extrinsic(
calls=calls_to_batch,
wallet=wallet,
era={"period": era},
proxy=proxy,
nonce=next_nonce,
mev_protection=mev_protection,
batch_type="batch_all", # Use batch_all to execute all even if some fail
)

if batch_success and batch_receipt:
if mev_protection:
inner_hash = batch_err_msg
mev_shield_id = await extract_mev_shield_id(batch_receipt)
(
mev_success,
mev_error,
batch_receipt,
) = await wait_for_extrinsic_by_hash(
subtensor=subtensor,
extrinsic_hash=inner_hash,
shield_id=mev_shield_id,
submit_block_hash=batch_receipt.block_hash,
status=status,
)
if not mev_success:
status.stop()
print_error(f"\n:cross_mark: [red]Failed[/red]: {mev_error}")
batch_success = False
batch_err_msg = mev_error

if batch_success:
if not json_output:
await print_extrinsic_id(batch_receipt)
batch_ext_id = await batch_receipt.get_extrinsic_identifier()

# Fetch updated balances for display
block_hash = await subtensor.substrate.get_chain_head()
current_balance = await subtensor.get_balance(
coldkey_ss58, block_hash
)

# Fetch all stake balances in parallel
if not json_output:
stake_fetch_tasks = [
subtensor.get_stake(
hotkey_ss58=staking_address,
coldkey_ss58=coldkey_ss58,
netuid=ni,
block_hash=block_hash,
)
for ni, staking_address, _, _, _ in call_metadata
]
new_stakes = await asyncio.gather(*stake_fetch_tasks)

# Process results for each call
for idx, (ni, staking_address, am, curr, _) in enumerate(
call_metadata
):
# For batch_all, we assume all succeeded if batch succeeded
# Individual call results would need to be parsed from receipt events
successes[ni][staking_address] = True
error_messages[ni][staking_address] = ""
extrinsic_ids[ni][staking_address] = batch_ext_id

if not json_output:
new_stake = new_stakes[idx]
console.print(
f":white_heavy_check_mark: [dark_sea_green3]Finalized. "
f"Stake added to netuid: {ni}, hotkey: {staking_address}[/dark_sea_green3]"
)
console.print(
f"Subnet: [{COLOR_PALETTE['GENERAL']['SUBHEADING']}]"
f"{ni}[/{COLOR_PALETTE['GENERAL']['SUBHEADING']}] "
f"Stake:\n"
f" [blue]{curr}[/blue] "
f":arrow_right: "
f"[{COLOR_PALETTE['STAKE']['STAKE_AMOUNT']}]{new_stake}\n"
)

# Show final coldkey balance
if not json_output:
console.print(
f"Coldkey Balance:\n "
f"[blue]{current_wallet_balance}[/blue] "
f":arrow_right: "
f"[{COLOR_PALETTE['STAKE']['STAKE_AMOUNT']}]{current_balance}"
)
else:
stake_coroutines = {
(ni, staking_address): stake_extrinsic(
else:
# Batch failed
for ni, staking_address, _, _, _ in call_metadata:
successes[ni][staking_address] = False
error_messages[ni][staking_address] = batch_err_msg
else:
# Batch submission failed
for ni, staking_address, _, _, _ in call_metadata:
successes[ni][staking_address] = False
error_messages[ni][staking_address] = (
batch_err_msg or "Batch submission failed"
)
elif len(calls_to_batch) == 1:
# Single call - use regular extrinsic
ni, staking_address, am, curr, price_with_tol = call_metadata[0]

if safe_staking and ni != 0 and price_with_tol:
success, er_msg, ext_receipt = await safe_stake_extrinsic(
netuid_=ni,
amount_=am,
current_stake=curr,
hotkey_ss58_=staking_address,
price_limit=price_with_tol,
status_=status,
)
else:
success, er_msg, ext_receipt = await stake_extrinsic(
netuid_i=ni,
amount_=am,
current=curr,
staking_address_ss58=staking_address,
status_=status,
)
for i, (ni, am, curr) in enumerate(
zip(netuids, amounts_to_stake, current_stake_balances)
)
for _, staking_address in hotkeys_to_stake_to
}
# We can gather them all at once but balance reporting will be in race-condition.
for (ni, staking_address), coroutine in stake_coroutines.items():
success, er_msg, ext_receipt = await coroutine
successes[ni][staking_address] = success
error_messages[ni][staking_address] = er_msg
if success:
if success and ext_receipt:
extrinsic_ids[ni][
staking_address
] = await ext_receipt.get_extrinsic_identifier()
Expand Down
Loading
Loading