Skip to content

Commit 5ee9ffe

Browse files
qihqi and Cyrilvallez authored
Let transformers know when a model is being traced via jax.jit (torchax) (#42611)
* Let transformers know when a model is being traced via jax.jit * Move check earlier * Add docstring * add docstring --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
1 parent ba1ad53 commit 5ee9ffe

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

src/transformers/utils/import_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,34 @@ def is_torch_fx_proxy(x):
13081308
return False
13091309

13101310

1311+
def is_jax_jitting(x) -> bool:
    """Return True if `x` is being traced inside a `jax.jit` context, False otherwise.

    When a torch model is compiled with `jax.jit` using torchax, the tensors that flow through the model are
    instances of `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has a `jax` method that returns
    the inner JAX array
    (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
    Duck typing is used here: if the inner JAX array is a `jax.core.Tracer`, we are in a tracing context
    (see more at: https://github.com/jax-ml/jax/discussions/9241).

    Args:
        x: The object to inspect, typically a `torch.Tensor`.

    Returns:
        bool: Whether we are inside of `jax.jit` tracing. Any failure while probing `x` (jax cannot be
        imported, `x.jax` is not callable, the call raises, ...) is reported as False.
    """
    # Objects without a `jax` attribute cannot expose an inner JAX array, so there is nothing to probe.
    if not hasattr(x, "jax"):
        return False
    try:
        # Imported lazily inside the `try` so that a failing import simply yields False.
        import jax

        return isinstance(x.jax(), jax.core.Tracer)
    except Exception:
        # Deliberate best-effort detection: a probing failure must never break the caller.
        return False
1337+
1338+
13111339
def is_jit_tracing() -> bool:
13121340
try:
13131341
import torch
@@ -1327,12 +1355,14 @@ def is_cuda_stream_capturing() -> bool:
13271355

13281356

13291357
def is_tracing(tensor=None) -> bool:
    """Checks whether we are tracing a graph with dynamo (compile or export), torch.jit, torch.fx, jax.jit (with
    torchax) or CUDA stream capturing.

    Args:
        tensor: Optional object to inspect. When provided, it is additionally checked with `is_torch_fx_proxy`
            and `is_jax_jitting`; when `None`, only the global (argument-free) checks are performed.

    Returns:
        bool: True if any of the checks reports that tracing/capturing is in progress, False otherwise.
    """
    # Note that `is_torchdynamo_compiling` checks both compiling and exporting (the export check is stricter and
    # only checks export)
    _is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
    if tensor is not None:
        # Short-circuit: the tensor-specific probes (including the lazy `import jax` in `is_jax_jitting`) are
        # skipped once an earlier check has already established that we are tracing.
        _is_tracing = _is_tracing or is_torch_fx_proxy(tensor) or is_jax_jitting(tensor)
    return _is_tracing
13371367

13381368

0 commit comments

Comments
 (0)