diff --git a/invokeai/frontend/web/src/common/util/promptAttention.test.ts b/invokeai/frontend/web/src/common/util/promptAttention.test.ts new file mode 100644 index 00000000000..33db5735dcd --- /dev/null +++ b/invokeai/frontend/web/src/common/util/promptAttention.test.ts @@ -0,0 +1,169 @@ +import { describe, expect, it } from 'vitest'; + +import { adjustPromptAttention } from './promptAttention'; + +describe('promptAttention', () => { + describe('adjustPromptAttention', () => { + describe('cross-boundary selection', () => { + it('should split group and apply attention when selection spans from inside group to outside (increment)', () => { + // "(a b)+ c" selecting "b through c": + // - a was in + group, not selected → gets + attention → a+ + // - b was in + group (effective +), selected, increment → ++ → b++ + // - c was neutral, selected, increment → + + const prompt = '(a b)+ c'; + // Selection from 'b' (position 3) to end of 'c' (position 8) + const result = adjustPromptAttention(prompt, 3, 8, 'increment'); + + expect(result.prompt).toBe('a+ b++ c+'); + }); + + it('should split group and apply attention when selection spans from inside group to outside (decrement)', () => { + // "(a b)+ c" selecting "b through c": + // - a was in + group, not selected → gets + attention → a+ + // - b was in + group (effective +), selected, decrement → neutralizes → b + // - c was neutral, selected, decrement → - + const prompt = '(a b)+ c'; + // Selection from 'b' (position 3) to end of 'c' (position 8) + const result = adjustPromptAttention(prompt, 3, 8, 'decrement'); + + expect(result.prompt).toBe('a+ b c-'); + }); + + it('should split group when selection starts before group and ends inside (increment)', () => { + // "a (b c)+" selecting "a and b": + // - a was neutral, selected, increment → + + // - b was in + group (effective +), selected, increment → ++ + // - c was in + group, not selected → keeps + attention + const prompt = 'a (b c)+'; + // Selection from 'a' (position 0) to 'b' (position 3) + const result = adjustPromptAttention(prompt, 0, 4, 'increment'); + + expect(result.prompt).toBe('a+ b++ c+'); + }); + + it('should handle nested groups with cross-boundary selection', () => { + const prompt = '((a b)+)+ c'; + // Selection from inside nested group to outside + const result = adjustPromptAttention(prompt, 2, 11, 'increment'); + + expect(result.prompt).toBe('(((a b)+)+ c)+'); + }); + + it('should handle selection spanning multiple groups', () => { + const prompt = '(a)+ (b)+'; + // Selection spanning both groups + const result = adjustPromptAttention(prompt, 0, 9, 'increment'); + + expect(result.prompt).toBe('((a)+ (b)+)+'); + }); + + it('should split negative group correctly (decrement on negative group)', () => { + // "(a b)- c" selecting "b through c": + // - a was in - group, not selected → keeps - attention → a- + // - b was in - group (effective -), selected, decrement → -- → b-- + // - c was neutral, selected, decrement → - + const prompt = '(a b)- c'; + const result = adjustPromptAttention(prompt, 3, 8, 'decrement'); + + expect(result.prompt).toBe('a- b-- c-'); + }); + + it('should split negative group correctly (increment on negative group)', () => { + // "(a b)- c" selecting "b through c": + // - a was in - group, not selected → keeps - attention → a- + // - b was in - group (effective -), selected, increment → neutralizes → b + // - c was neutral, selected, increment → + + const prompt = '(a b)- c'; + const result = adjustPromptAttention(prompt, 3, 8, 'increment'); + + expect(result.prompt).toBe('a- b c+'); + }); + + it('should handle multiple non-selected items in group', () => { + // "(a b c)+ d" selecting only "c d": + // - a, b not selected → keep + attention + // - c was in + group (effective +), selected, decrement → neutralizes → c + // - d selected, decrement → - + const prompt = '(a b c)+ d'; + const result = adjustPromptAttention(prompt, 5, 10, 'decrement'); + + expect(result.prompt).toBe('a+ b+ c d-'); + }); + + it('should handle word with existing attention in group when crossing boundary', () => { + // "x (d- e)+" with "x d" selected and incremented: + // - x at root, selected → increment → x+ + // - d HAS own attention (-), selected → adjust own only (- → neutral) → d + // - e not selected → gets group's + attention → e+ + const prompt = 'x (d- e)+'; + // Select from x to d (positions 0 to 5, covering x and d-) + const result = adjustPromptAttention(prompt, 0, 5, 'increment'); + + expect(result.prompt).toBe('x+ d e+'); + }); + + it('should handle complex multi-group case', () => { + // "(a+ b)+ c (d- e)+" with "c d" selected and incremented: + // - First group untouched since no children selected + // - c at root, selected → increment → c+ + // - Second group split: + // - d HAS own attention (-), selected → adjust own only (- → neutral) → d + // - e not selected → gets group's + attention → e+ + const prompt = '(a+ b)+ c (d- e)+'; + // Select from c to d + const result = adjustPromptAttention(prompt, 8, 14, 'increment'); + + expect(result.prompt).toBe('(a+ b)+ c+ d e+'); + }); + }); + + describe('single word', () => { + it('should add + when incrementing word without attention', () => { + const prompt = 'hello world'; + const result = adjustPromptAttention(prompt, 0, 5, 'increment'); + + expect(result.prompt).toBe('hello+ world'); + }); + + it('should add - when decrementing word without attention', () => { + const prompt = 'hello world'; + const result = adjustPromptAttention(prompt, 0, 5, 'decrement'); + + expect(result.prompt).toBe('hello- world'); + }); + }); + + describe('existing group', () => { + it('should adjust group attention when cursor is at group boundary', () => { + const prompt = '(hello world)+'; + // Cursor at the closing paren + const result = adjustPromptAttention(prompt, 13, 14, 'increment'); + + expect(result.prompt).toBe('(hello world)++'); + }); + + it('should remove group when attention becomes neutral', () => { + const prompt = '(hello world)+'; + const result = adjustPromptAttention(prompt, 0, 14, 'decrement'); + + expect(result.prompt).toBe('hello world'); + }); + }); + + describe('multiple words without group', () => { + it('should create new group with + when incrementing multiple words', () => { + const prompt = 'hello world'; + const result = adjustPromptAttention(prompt, 0, 11, 'increment'); + + expect(result.prompt).toBe('(hello world)+'); + }); + + it('should create new group with - when decrementing multiple words', () => { + const prompt = 'hello world'; + const result = adjustPromptAttention(prompt, 0, 11, 'decrement'); + + expect(result.prompt).toBe('(hello world)-'); + }); + }); + }); +}); diff --git a/invokeai/frontend/web/src/common/util/promptAttention.ts b/invokeai/frontend/web/src/common/util/promptAttention.ts index d303b7e4b29..35748044fe7 100644 --- a/invokeai/frontend/web/src/common/util/promptAttention.ts +++ b/invokeai/frontend/web/src/common/util/promptAttention.ts @@ -5,155 +5,102 @@ import { type ASTNode, type Attention, parseTokens, serialize, tokenize } from ' const log = logger('events'); -/** - * Behavior Rules: - * - * ATTENTION SYNTAX: - * - Words support +/- attention: `word+`, `word-`, `word++`, `word--`, etc. - * - Groups support both +/- and numeric: `(words)+`, `(words)1.2`, `(words)0.8` - * - `word++` is roughly equivalent to `(word)1.2` in effect - * - Mixed attention like `word+-` or numeric on words `word1.2` is invalid - * - * ADJUSTMENT RULES: - * - `word++` down → `word+` - * - `word+` down → `word` (attention removed) - * - `word` down → `word-` - * - `word-` down → `word--` - * - `(content)1.2` down → `(content)1.1` - * - `(content)1.1` down → `content` (group unwrapped) - * - `content content` down → `(content content)0.9` (new group created) - * - * SELECTION BEHAVIOR: - * - Cursor/selection within a word → expand to full word, adjust word attention - * - Cursor touching group boundary (parens or weight) → adjust group attention - * - Selection entirely within a group → adjust that group's attention - * - Selection spans multiple content nodes → create new group with initial attention - * - Whitespace and punctuation are ignored for content selection - */ - type AttentionDirection = 'increment' | 'decrement'; +type SelectionBounds = { start: number; end: number }; +type AdjustmentResult = { prompt: string; selectionStart: number; selectionEnd: number }; -type SelectionBounds = { - start: number; - end: number; -}; - -type AdjustmentResult = { - prompt: string; - selectionStart: number; - selectionEnd: number; -}; - -/** - * A node with its position in the serialized prompt string. - */ type PositionedNode = { node: ASTNode; start: number; end: number; - /** Position of content start (after opening paren for groups) */ contentStart: number; - /** Position of content end (before closing paren for groups) */ contentEnd: number; parent: PositionedNode | null; }; // ============================================================================ -// ATTENTION COMPUTATION +// ATTENTION HELPERS // ============================================================================ -/** - * Checks if attention is symbol-based (+, -, ++, etc.) - */ -function isSymbolAttention(attention: Attention | undefined): attention is string { - return typeof attention === 'string' && /^[+-]+$/.test(attention); +function isSymbolAttention(a: Attention | undefined): a is string { + return typeof a === 'string' && /^[+-]+$/.test(a); } -/** - * Checks if attention is numeric (1.2, 0.8, etc.) - */ -function isNumericAttention(attention: Attention | undefined): attention is number { - return typeof attention === 'number'; +function isNumericAttention(a: Attention | undefined): a is number { + return typeof a === 'number'; } -/** - * Computes adjusted attention for symbol-based attention (+/-). - * Used for both words and groups with symbol attention. - */ -function adjustSymbolAttention(direction: AttentionDirection, attention: string | undefined): string | undefined { +/** Convert attention to numeric level: + = 1, ++ = 2, - = -1, 1.1 = 1, etc. */ +function attentionToLevel(attention: Attention | undefined): number { + if (isNumericAttention(attention)) { + return Math.round((attention - 1.0) * 10); + } if (!attention) { - return direction === 'increment' ? '+' : '-'; + return 0; } - - if (direction === 'increment') { - // Going up: remove '-' if present, otherwise add '+' - if (attention.endsWith('-')) { - const result = attention.slice(0, -1); - return result || undefined; - } - return `${attention}+`; - } else { - // Going down: remove '+' if present, otherwise add '-' - if (attention.endsWith('+')) { - const result = attention.slice(0, -1); - return result || undefined; + let level = 0; + for (const c of String(attention)) { + if (c === '+') { + level++; + } else if (c === '-') { + level--; } - return `${attention}-`; } + return level; } -/** - * Computes adjusted attention for numeric attention. - * Only used for groups. - */ -function adjustNumericAttention(direction: AttentionDirection, attention: number): number | undefined { - const step = direction === 'increment' ? 0.1 : -0.1; - const result = parseFloat((attention + step).toFixed(1)); - - // 1.0 is default - return undefined to signal unwrapping - if (result === 1.0) { +/** Convert level back to symbol attention */ +function levelToSymbol(level: number): string | undefined { + if (level === 0) { return undefined; } + return level > 0 ? '+'.repeat(level) : '-'.repeat(-level); +} - return result; +/** Combine two attention values (for flattening group attention onto children) */ +function combineAttention(a: Attention | undefined, b: Attention | undefined): Attention | undefined { + return levelToSymbol(attentionToLevel(a) + attentionToLevel(b)); } -/** - * Computes the new attention value based on direction and current attention. - * Returns undefined if attention should be removed (normalize to default). - */ -function computeAttention( - direction: AttentionDirection, - attention: Attention | undefined, - isGroup: boolean -): Attention | undefined { - // No current attention - if (attention === undefined) { - if (isGroup) { - // Groups going down from neutral get 0.9 - return direction === 'increment' ? '+' : 0.9; - } - return direction === 'increment' ? '+' : '-'; +/** Adjust attention by one step in the given direction */ +function computeAttention(direction: AttentionDirection, attention: Attention | undefined): Attention | undefined { + if (isNumericAttention(attention)) { + const result = parseFloat((attention + (direction === 'increment' ? 0.1 : -0.1)).toFixed(1)); + return result === 1.0 ? undefined : result; } - - // Symbol attention (+, -, ++, etc.) if (isSymbolAttention(attention)) { - return adjustSymbolAttention(direction, attention); + const level = attentionToLevel(attention); + return levelToSymbol(direction === 'increment' ? level + 1 : level - 1); } + return direction === 'increment' ? '+' : '-'; +} - // Numeric attention (only valid for groups) - if (isNumericAttention(attention)) { - return adjustNumericAttention(direction, attention); +/** Apply attention to a node, flattening groups if combined attention is neutral */ +function applyAttentionToNode(node: ASTNode, attention: Attention | undefined): ASTNode[] { + if (node.type === 'word') { + const combined = combineAttention(attention, node.attention); + return combined === node.attention ? [node] : [{ ...node, attention: combined }]; } - - // Parse string numbers - const numValue = parseFloat(String(attention)); - if (!isNaN(numValue)) { - return adjustNumericAttention(direction, numValue); + if (node.type === 'group') { + const combined = combineAttention(attention, node.attention); + if (combined === undefined) { + return node.children.flatMap((c) => applyAttentionToNode(c, undefined)); + } + return [{ ...node, attention: combined }]; } + return [node]; +} - // Fallback: treat as no attention - return direction === 'increment' ? '+' : '-'; +/** Adjust a content node's attention, unwrapping groups if result is neutral */ +function adjustContentNode(node: ASTNode, direction: AttentionDirection): ASTNode[] { + if (node.type === 'word') { + return [{ ...node, attention: computeAttention(direction, node.attention) }]; + } + if (node.type === 'group') { + const newAttn = computeAttention(direction, node.attention); + return newAttn === undefined ? node.children : [{ ...node, attention: newAttn }]; + } + return [node]; } // ============================================================================ @@ -334,6 +281,51 @@ function findContentNodes(positions: PositionedNode[], selection: SelectionBound }); } +/** + * Elevates nodes to their parent groups when selection crosses group boundaries. + * This ensures we don't try to extract partial groups which would break parsing. + * + * For example, if selection spans from inside a group to outside it, + * the nodes inside the group are replaced with the group itself. + */ +function elevateToTopLevelNodes(nodes: PositionedNode[]): PositionedNode[] { + if (nodes.length === 0) { + return nodes; + } + + // Check if any nodes have different parent contexts + const hasRootLevel = nodes.some((n) => n.parent === null); + const hasNestedNodes = nodes.some((n) => n.parent !== null); + + // If all nodes are at the same level (all root or all same parent), no elevation needed + if (!hasRootLevel || !hasNestedNodes) { + return nodes; + } + + // We have nodes at different levels - elevate nested nodes to their parent groups + const result: PositionedNode[] = []; + const seenNodes = new Set(); + + for (const node of nodes) { + if (node.parent === null) { + // Root level node - keep as is (if not already added) + if (!seenNodes.has(node)) { + seenNodes.add(node); + result.push(node); + } + } else { + // Nested node - elevate to parent group (if not already added) + if (!seenNodes.has(node.parent)) { + seenNodes.add(node.parent); + result.push(node.parent); + } + } + } + + // Sort by start position to maintain correct order + return result.sort((a, b) => a.start - b.start); +} + /** * Finds the single word the cursor is within (not just touching). */ @@ -382,8 +374,23 @@ type AdjustmentStrategy = | { type: 'adjust-word'; node: PositionedNode } | { type: 'adjust-group'; node: PositionedNode } | { type: 'create-group'; nodes: PositionedNode[] } + | { type: 'split-group'; group: PositionedNode; selectedChildren: PositionedNode[]; allPositions: PositionedNode[] } | { type: 'no-op' }; +/** + * Gets direct children of a group from the position map. + */ +function getGroupChildren(positions: PositionedNode[], group: PositionedNode): PositionedNode[] { + return positions.filter((p) => p.parent === group); +} + +/** + * Checks if a positioned node is a content node (word, group, or embedding). + */ +function isContentNode(p: PositionedNode): boolean { + return p.node.type === 'word' || p.node.type === 'group' || p.node.type === 'embedding'; +} + function determineStrategy(positions: PositionedNode[], selection: SelectionBounds): AdjustmentStrategy { const contentNodes = findContentNodes(positions, selection); @@ -429,7 +436,36 @@ function determineStrategy(positions: PositionedNode[], selection: SelectionBoun return { type: 'create-group', nodes: contentNodes }; } - // Multiple content nodes - create a new group + // Multiple content nodes - check if selection crosses group boundaries + // This happens when some nodes are inside a group and some are outside + const hasRootLevel = contentNodes.some((n) => n.parent === null); + const nestedNodes = contentNodes.filter((n) => n.parent !== null); + + if (hasRootLevel && nestedNodes.length > 0) { + // Selection crosses group boundaries - need to split groups + // Find all unique parent groups that are partially selected + const parentGroups = new Set(); + for (const node of nestedNodes) { + if (node.parent) { + parentGroups.add(node.parent); + } + } + + // For each parent group, check if it's partially selected + for (const group of parentGroups) { + const groupChildren = getGroupChildren(positions, group).filter(isContentNode); + const selectedChildren = groupChildren.filter((child) => + contentNodes.some((cn) => cn === child || (cn.start >= child.start && cn.end <= child.end)) + ); + + // If not all children are selected, we need to split + if (selectedChildren.length > 0 && selectedChildren.length < groupChildren.length) { + return { type: 'split-group', group, selectedChildren, allPositions: positions }; + } + } + } + + // No partial group selections - create a new group return { type: 'create-group', nodes: contentNodes }; } @@ -443,23 +479,17 @@ export function adjustPromptAttention( direction: AttentionDirection ): AdjustmentResult { try { - // Handle empty prompt if (!prompt.trim()) { return { prompt, selectionStart, selectionEnd }; } - // Normalize selection const selection: SelectionBounds = { start: Math.min(selectionStart, selectionEnd), end: Math.max(selectionStart, selectionEnd), }; - // Parse and build position map - const tokens = tokenize(prompt); - const ast = parseTokens(tokens); + const ast = parseTokens(tokenize(prompt)); const { positions } = buildPositionMap(ast); - - // Determine what to do const strategy = determineStrategy(positions, selection); switch (strategy.type) { @@ -467,94 +497,117 @@ export function adjustPromptAttention( return { prompt, selectionStart, selectionEnd }; case 'adjust-word': { - const wordPos = strategy.node; - const word = wordPos.node as ASTNode & { type: 'word' }; - const newAttention = computeAttention(direction, word.attention, false); - - const updatedWord: ASTNode = { - type: 'word', - text: word.text, - attention: newAttention, - }; - - const newAST = replaceNodeInAST(ast, word, updatedWord); - const newPrompt = serialize(newAST); - const newWordText = serialize([updatedWord]); - + const { node: pos } = strategy; + const word = pos.node as ASTNode & { type: 'word' }; + const updated: ASTNode = { ...word, attention: computeAttention(direction, word.attention) }; + const newPrompt = serialize(replaceNodeInAST(ast, word, updated)); return { prompt: newPrompt, - selectionStart: wordPos.start, - selectionEnd: wordPos.start + newWordText.length, + selectionStart: pos.start, + selectionEnd: pos.start + serialize([updated]).length, }; } case 'adjust-group': { - const groupPos = strategy.node; - const group = groupPos.node as ASTNode & { type: 'group' }; - const newAttention = computeAttention(direction, group.attention, true); - - // If attention becomes undefined (1.0), unwrap the group - if (newAttention === undefined) { - const newAST = replaceNodeInAST(ast, group, group.children); - const newPrompt = serialize(newAST); - const childrenText = serialize(group.children); + const { node: pos } = strategy; + const group = pos.node as ASTNode & { type: 'group' }; + const newAttn = computeAttention(direction, group.attention); + if (newAttn === undefined) { + const newPrompt = serialize(replaceNodeInAST(ast, group, group.children)); return { prompt: newPrompt, - selectionStart: groupPos.start, - selectionEnd: groupPos.start + childrenText.length, + selectionStart: pos.start, + selectionEnd: pos.start + serialize(group.children).length, }; } - const updatedGroup: ASTNode = { - type: 'group', - children: group.children, - attention: newAttention, + const updated: ASTNode = { ...group, attention: newAttn }; + const newPrompt = serialize(replaceNodeInAST(ast, group, updated)); + return { + prompt: newPrompt, + selectionStart: pos.start, + selectionEnd: pos.start + serialize([updated]).length, }; + } + + case 'split-group': { + const { group: groupPos, selectedChildren, allPositions } = strategy; + const group = groupPos.node as ASTNode & { type: 'group' }; + const groupAttn = group.attention; + const selectedSet = new Set(selectedChildren.map((c) => c.node)); + + // Rebuild group children with proper attention handling + const rebuiltNodes: ASTNode[] = []; + for (const childPos of getGroupChildren(allPositions, groupPos)) { + if (!isContentNode(childPos)) { + rebuiltNodes.push(childPos.node); // Preserve whitespace/punct + continue; + } + + const isSelected = selectedSet.has(childPos.node); + if (!isSelected) { + // Not selected: flatten group attention onto child + rebuiltNodes.push(...applyAttentionToNode(childPos.node, groupAttn)); + } else { + // Selected: adjust based on whether child has own attention + const hasOwnAttn = + childPos.node.type === 'word' || childPos.node.type === 'group' + ? childPos.node.attention !== undefined + : false; + const effectiveAttn = hasOwnAttn + ? (childPos.node as { attention?: Attention }).attention + : combineAttention(groupAttn, undefined); + + rebuiltNodes.push( + ...adjustContentNode( + childPos.node.type === 'word' || childPos.node.type === 'group' + ? { ...childPos.node, attention: effectiveAttn } + : childPos.node, + direction + ) + ); + } + } + + // Replace group and adjust root-level selected nodes + let newAST = replaceNodeInAST(ast, group, rebuiltNodes); + for (const rootNode of findContentNodes(allPositions, selection).filter((n) => n.parent === null)) { + newAST = replaceNodeInAST(newAST, rootNode.node, adjustContentNode(rootNode.node, direction)); + } - const newAST = replaceNodeInAST(ast, group, updatedGroup); const newPrompt = serialize(newAST); - const newGroupText = serialize([updatedGroup]); + const allSelected = [ + ...selectedChildren, + ...findContentNodes(allPositions, selection).filter((n) => n.parent === null), + ]; + const sorted = allSelected.sort((a, b) => a.start - b.start); return { prompt: newPrompt, - selectionStart: groupPos.start, - selectionEnd: groupPos.start + newGroupText.length, + selectionStart: sorted[0]?.start ?? selectionStart, + selectionEnd: newPrompt.length - (prompt.length - (sorted[sorted.length - 1]?.end ?? selectionEnd)), }; } case 'create-group': { - const nodes = strategy.nodes; - const firstNode = nodes[0]!; - const lastNode = nodes[nodes.length - 1]!; - - // Get the text range to wrap - const wrapStart = firstNode.start; - const wrapEnd = lastNode.end; + const elevated = elevateToTopLevelNodes(strategy.nodes); + const sorted = elevated.sort((a, b) => a.start - b.start); + const wrapStart = sorted[0]!.start; + const wrapEnd = sorted[sorted.length - 1]!.end; - // Parse just the selected portion - const selectedText = prompt.substring(wrapStart, wrapEnd); - const selectedTokens = tokenize(selectedText); - const selectedAST = parseTokens(selectedTokens); - - // Create new group with appropriate attention - const newAttention = computeAttention(direction, undefined, true); + const selectedAST = parseTokens(tokenize(prompt.substring(wrapStart, wrapEnd))); const newGroup: ASTNode = { type: 'group', children: selectedAST, - attention: newAttention, + attention: computeAttention(direction, undefined), }; - // Reconstruct prompt - const before = prompt.substring(0, wrapStart); - const after = prompt.substring(wrapEnd); - const newGroupText = serialize([newGroup]); - const newPrompt = before + newGroupText + after; - + const newPrompt = prompt.substring(0, wrapStart) + serialize([newGroup]) + prompt.substring(wrapEnd); return { prompt: newPrompt, selectionStart: wrapStart, - selectionEnd: wrapStart + newGroupText.length, + selectionEnd: wrapStart + serialize([newGroup]).length, }; } } diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceNegativePrompt.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceNegativePrompt.tsx index 05348223753..07fe7ca770f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceNegativePrompt.tsx @@ -9,6 +9,7 @@ import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/s import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; import { usePrompt } from 'features/prompt/usePrompt'; +import { usePromptAttentionHotkeys } from 'features/prompt/usePromptAttentionHotkeys'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useTranslation } from 'react-i18next'; @@ -45,6 +46,8 @@ export const RegionalGuidanceNegativePrompt = memo(() => { onChange: _onChange, }); + usePromptAttentionHotkeys({ textareaRef, onPromptChange: _onChange }); + return ( @@ -59,8 +62,9 @@ export const RegionalGuidanceNegativePrompt = memo(() => { variant="outline" paddingInlineStart={2} paddingInlineEnd={8} - fontSize="sm" zIndex="0 !important" + fontFamily="mono" + fontSize="sm" _focusVisible={_focusVisible} /> { onChange: _onChange, }); + usePromptAttentionHotkeys({ textareaRef, onPromptChange: _onChange }); + return ( @@ -61,6 +64,8 @@ export const RegionalGuidancePositivePrompt = memo(() => { paddingInlineEnd={8} minH={28} zIndex="0 !important" + fontFamily="mono" + fontSize="sm" _focusVisible={_focusVisible} /> { onChange: _onChange, }); + usePromptAttentionHotkeys({ + textareaRef, + onPromptChange: (prompt) => dispatch(negativePromptChanged(prompt)), + }); + return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index d95b3298af0..b6c91ed2896 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -1,7 +1,6 @@ import { Box, Flex, Textarea } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize'; -import { adjustPromptAttention } from 'common/util/promptAttention'; import { positivePromptChanged, selectModelSupportsNegativePrompt, @@ -16,6 +15,7 @@ import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModeP import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; import { usePrompt } from 'features/prompt/usePrompt'; +import { usePromptAttentionHotkeys } from 'features/prompt/usePromptAttentionHotkeys'; import { selectStylePresetActivePresetId, selectStylePresetViewMode, @@ -193,66 +193,9 @@ export const ParamPositivePrompt = memo(() => { dependencies: [promptHistoryApi.next, isPromptFocused], }); - // Adjust prompt attention up - useRegisteredHotkeys({ - id: 'promptWeightUp', - category: 'app', - callback: (e) => { - if (isPromptFocused() && textareaRef.current) { - e.preventDefault(); - const textarea = textareaRef.current; - const result = adjustPromptAttention( - textarea.value, - textarea.selectionStart, - textarea.selectionEnd, - 'increment' - ); - - // Update the prompt - dispatch(positivePromptChanged(result.prompt)); - - // Update selection after React re-renders - setTimeout(() => { - if (textareaRef.current) { - textareaRef.current.setSelectionRange(result.selectionStart, result.selectionEnd); - textareaRef.current.focus(); - } - }, 0); - } - }, - options: { preventDefault: true, enableOnFormTags: ['TEXTAREA'] }, - dependencies: [dispatch, isPromptFocused], - }); - - // Adjust prompt attention down - useRegisteredHotkeys({ - id: 'promptWeightDown', - category: 'app', - callback: (e) => { - if (isPromptFocused() && textareaRef.current) { - e.preventDefault(); - const textarea = textareaRef.current; - const result = adjustPromptAttention( - textarea.value, - textarea.selectionStart, - textarea.selectionEnd, - 'decrement' - ); - - // Update the prompt - dispatch(positivePromptChanged(result.prompt)); - - // Update selection after React re-renders - setTimeout(() => { - if (textareaRef.current) { - textareaRef.current.setSelectionRange(result.selectionStart, result.selectionEnd); - textareaRef.current.focus(); - } - }, 0); - } - }, - options: { preventDefault: true, enableOnFormTags: ['TEXTAREA'] }, - dependencies: [dispatch, isPromptFocused], + usePromptAttentionHotkeys({ + textareaRef, + onPromptChange: (prompt) => dispatch(positivePromptChanged(prompt)), }); return ( diff --git a/invokeai/frontend/web/src/features/prompt/usePromptAttentionHotkeys.ts b/invokeai/frontend/web/src/features/prompt/usePromptAttentionHotkeys.ts new file mode 100644 index 00000000000..545857b9e2c --- /dev/null +++ b/invokeai/frontend/web/src/features/prompt/usePromptAttentionHotkeys.ts @@ -0,0 +1,68 @@ +import { adjustPromptAttention } from 'common/util/promptAttention'; +import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; +import type { RefObject } from 'react'; +import { useCallback } from 'react'; + +type UsePromptAttentionHotkeysArgs = { + textareaRef: RefObject; + onPromptChange: (prompt: string) => void; +}; + +/** + * Registers hotkeys for adjusting prompt attention weights (Ctrl+Up/Down). + * Handles both increment and decrement operations with proper selection restoration. + * Uses execCommand to integrate with browser's native undo stack. + */ +export const usePromptAttentionHotkeys = ({ + textareaRef, + onPromptChange: _onPromptChange, +}: UsePromptAttentionHotkeysArgs) => { + const isPromptFocused = useCallback(() => document.activeElement === textareaRef.current, [textareaRef]); + + const handleAttentionAdjustment = useCallback( + (direction: 'increment' | 'decrement') => { + const textarea = textareaRef.current; + if (!textarea) { + return; + } + + const result = adjustPromptAttention(textarea.value, textarea.selectionStart, textarea.selectionEnd, direction); + + // Use execCommand to make the change undo-able by the browser. + // This triggers the textarea's native onChange, which syncs React state. + textarea.focus(); + textarea.setSelectionRange(0, textarea.value.length); + document.execCommand('insertText', false, result.prompt); + + // Restore the selection to cover the adjusted portion + textarea.setSelectionRange(result.selectionStart, result.selectionEnd); + }, + [textareaRef] + ); + + useRegisteredHotkeys({ + id: 'promptWeightUp', + category: 'app', + callback: (e) => { + if (isPromptFocused() && textareaRef.current) { + e.preventDefault(); + handleAttentionAdjustment('increment'); + } + }, + options: { preventDefault: true, enableOnFormTags: ['TEXTAREA'] }, + dependencies: [isPromptFocused, handleAttentionAdjustment], + }); + + useRegisteredHotkeys({ + id: 'promptWeightDown', + category: 'app', + callback: (e) => { + if (isPromptFocused() && textareaRef.current) { + e.preventDefault(); + handleAttentionAdjustment('decrement'); + } + }, + options: { preventDefault: true, enableOnFormTags: ['TEXTAREA'] }, + dependencies: [isPromptFocused, handleAttentionAdjustment], + }); +};