
Commit 0a271d8

model-conversion : add verbose flag in run-org-model.py (#18194)
This commit adds a --verbose flag to the run-org-model.py script to enable or disable detailed debug output, such as input and output tensors for each layer. Debug utilities (summarize, debug_hook, setup_rope_debug) have been moved to utils/common.py.

The motivation for this is that the detailed debug output can be useful for diagnosing issues with model conversion or execution, but it can also produce a large amount of output that may not always be needed. The script will also be further cleaned/refactored in follow-up commits.
1 parent 52fc7fe commit 0a271d8
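
For context, a minimal sketch of how the gated hook registration works after this change. It mirrors the diff below; the "gpt2" checkpoint and the stripped-down argparse setup are only illustrative, and the real script additionally handles MODEL_PATH, prompt files, and unreleased model classes.

import argparse
from transformers import AutoModelForCausalLM
from utils.common import debug_hook  # helper moved to utils/common.py by this commit

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

if args.verbose:
    # Only with --verbose: dump each leaf module's input/output tensors
    # in llama.cpp debug style during the forward pass.
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # only leaf modules
            module.register_forward_hook(debug_hook(name))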

File tree

2 files changed: +147 -122 lines changed

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 17 additions & 122 deletions
@@ -2,135 +2,22 @@
 
 import argparse
 import os
+import sys
 import importlib
 from pathlib import Path
 
+# Add parent directory to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np
-
-### If you want to dump RoPE activations, apply this monkey patch to the model
-### class from Transformers that you are running (replace apertus.modeling_apertus
-### with the proper package and class for your model
-### === START ROPE DEBUG ===
-# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
-
-# orig_rope = apply_rotary_pos_emb
-# torch.set_printoptions(threshold=float('inf'))
-# torch.set_printoptions(precision=6, sci_mode=False)
-
-# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-#     # log inputs
-#     summarize(q, "RoPE.q_in")
-#     summarize(k, "RoPE.k_in")
-
-#     # call original
-#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
-
-#     # log outputs
-#     summarize(q_out, "RoPE.q_out")
-#     summarize(k_out, "RoPE.k_out")
-
-#     return q_out, k_out
-
-# # Patch it
-# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402
-# apertus_mod.apply_rotary_pos_emb = debug_rope
-### == END ROPE DEBUG ===
-
-
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
-    """
-    Print a tensor in llama.cpp debug style.
-
-    Supports:
-    - 2D tensors (seq, hidden)
-    - 3D tensors (batch, seq, hidden)
-    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
-
-    Shows first and last max_vals of each vector per sequence position.
-    """
-    t = tensor.detach().to(torch.float32).cpu()
-
-    # Determine dimensions
-    if t.ndim == 3:
-        _, s, _ = t.shape
-    elif t.ndim == 2:
-        _, s = 1, t.shape[0]
-        t = t.unsqueeze(0)
-    elif t.ndim == 4:
-        _, s, _, _ = t.shape
-    else:
-        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
-        return
-
-    ten_shape = t.shape
-
-    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
-    print(" [")
-    print(" [")
-
-    # Determine indices for first and last sequences
-    first_indices = list(range(min(s, max_seq)))
-    last_indices = list(range(max(0, s - max_seq), s))
-
-    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
-    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
-    # Combine indices
-    if has_overlap:
-        # If there's overlap, just use the combined unique indices
-        indices = sorted(list(set(first_indices + last_indices)))
-        separator_index = None
-    else:
-        # If no overlap, we'll add a separator between first and last sequences
-        indices = first_indices + last_indices
-        separator_index = len(first_indices)
-
-    for i, si in enumerate(indices):
-        # Add separator if needed
-        if separator_index is not None and i == separator_index:
-            print(" ...")
-
-        # Extract appropriate slice
-        vec = t[0, si]
-        if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
-            flat = vec.flatten().tolist()
-        else: # 2D or 3D case
-            flat = vec.tolist()
-
-        # First and last slices
-        first = flat[:max_vals]
-        last = flat[-max_vals:] if len(flat) >= max_vals else flat
-        first_str = ", ".join(f"{v:12.4f}" for v in first)
-        last_str = ", ".join(f"{v:12.4f}" for v in last)
-
-        print(f" [{first_str}, ..., {last_str}]")
-
-    print(" ],")
-    print(" ]")
-    print(f" sum = {t.sum().item():.6f}\n")
-
-
-def debug_hook(name):
-    def fn(_m, input, output):
-        if isinstance(input, torch.Tensor):
-            summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
-            summarize(input[0], name + "_in")
-        if isinstance(output, torch.Tensor):
-            summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
-            summarize(output[0], name + "_out")
-
-    return fn
-
-
-unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+from utils.common import debug_hook
 
 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
 parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
+parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
 args = parser.parse_args()
 
 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -139,6 +26,12 @@ def fn(_m, input, output):
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
+### If you want to dump RoPE activations, uncomment the following lines:
+### === START ROPE DEBUG ===
+# from utils.common import setup_rope_debug
+# setup_rope_debug("transformers.models.apertus.modeling_apertus")
+### == END ROPE DEBUG ===
+
 
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -156,6 +49,7 @@ def fn(_m, input, output):
 print("BOS token id: ", config.bos_token_id)
 print("EOS token id: ", config.eos_token_id)
 
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = (
@@ -184,9 +78,10 @@ def fn(_m, input, output):
         model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
     )
 
-for name, module in model.named_modules():
-    if len(list(module.children())) == 0: # only leaf modules
-        module.register_forward_hook(debug_hook(name))
+if args.verbose:
+    for name, module in model.named_modules():
+        if len(list(module.children())) == 0: # only leaf modules
+            module.register_forward_hook(debug_hook(name))
 
 model_name = os.path.basename(model_path)
 # Printing the Model class to allow for easier debugging. This can be useful

examples/model-conversion/scripts/utils/common.py

Lines changed: 130 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 import os
 import sys
+import torch
+
 
 def get_model_name_from_env_path(env_path_name):
    model_path = os.getenv(env_path_name)
@@ -18,3 +20,131 @@ def get_model_name_from_env_path(env_path_name):
        name = name[:-5]
 
    return name
+
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+    """
+    Print a tensor in llama.cpp debug style.
+
+    Supports:
+    - 2D tensors (seq, hidden)
+    - 3D tensors (batch, seq, hidden)
+    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+    Shows first and last max_vals of each vector per sequence position.
+    """
+    t = tensor.detach().to(torch.float32).cpu()
+
+    # Determine dimensions
+    if t.ndim == 3:
+        _, s, _ = t.shape
+    elif t.ndim == 2:
+        _, s = 1, t.shape[0]
+        t = t.unsqueeze(0)
+    elif t.ndim == 4:
+        _, s, _, _ = t.shape
+    else:
+        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+        return
+
+    ten_shape = t.shape
+
+    print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
+    print(" [")
+    print(" [")
+
+    # Determine indices for first and last sequences
+    first_indices = list(range(min(s, max_seq)))
+    last_indices = list(range(max(0, s - max_seq), s))
+
+    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+    # Combine indices
+    if has_overlap:
+        # If there's overlap, just use the combined unique indices
+        indices = sorted(list(set(first_indices + last_indices)))
+        separator_index = None
+    else:
+        # If no overlap, we'll add a separator between first and last sequences
+        indices = first_indices + last_indices
+        separator_index = len(first_indices)
+
+    for i, si in enumerate(indices):
+        # Add separator if needed
+        if separator_index is not None and i == separator_index:
+            print(" ...")
+
+        # Extract appropriate slice
+        vec = t[0, si]
+        if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
+            flat = vec.flatten().tolist()
+        else: # 2D or 3D case
+            flat = vec.tolist()
+
+        # First and last slices
+        first = flat[:max_vals]
+        last = flat[-max_vals:] if len(flat) >= max_vals else flat
+        first_str = ", ".join(f"{v:12.4f}" for v in first)
+        last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+        print(f" [{first_str}, ..., {last_str}]")
+
+    print(" ],")
+    print(" ]")
+    print(f" sum = {t.sum().item():.6f}\n")
+
+
+def debug_hook(name):
+    def fn(_m, input, output):
+        if isinstance(input, torch.Tensor):
+            summarize(input, name + "_in")
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
+            summarize(input[0], name + "_in")
+        if isinstance(output, torch.Tensor):
+            summarize(output, name + "_out")
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+            summarize(output[0], name + "_out")
+
+    return fn
+
+
+def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
+    """
+    Apply monkey patch to dump RoPE activations for debugging.
+
+    Args:
+        model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
+        function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")
+
+    Example:
+        from utils.common import setup_rope_debug
+        setup_rope_debug("transformers.models.apertus.modeling_apertus")
+    """
+    import importlib
+
+    # Import the module and get the original function
+    module = importlib.import_module(model_module_path)
+    orig_rope = getattr(module, function_name)
+
+    # Set torch print options for better debugging
+    torch.set_printoptions(threshold=float('inf'))
+    torch.set_printoptions(precision=6, sci_mode=False)
+
+    def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        # log inputs
+        summarize(q, "RoPE.q_in")
+        summarize(k, "RoPE.k_in")
+
+        # call original
+        q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+        # log outputs
+        summarize(q_out, "RoPE.q_out")
+        summarize(k_out, "RoPE.k_out")
+
+        return q_out, k_out
+
+    # Patch it
+    setattr(module, function_name, debug_rope)
+    print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
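
For reference, a minimal sketch of exercising the relocated helpers on their own. The tensor shape and layer name are illustrative; summarize and setup_rope_debug are the functions added to utils/common.py above.

import torch
from utils.common import summarize, setup_rope_debug

x = torch.randn(1, 8, 4096)  # (batch, seq, hidden) - an illustrative activation
summarize(x, "model.layers.0.mlp_out")  # prints first/last values per position and the sum

# Optionally patch a model's RoPE function so the q/k tensors are dumped as well;
# the module path must match the Transformers model actually being run.
# setup_rope_debug("transformers.models.apertus.modeling_apertus")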
