diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a83a1e0a80..528a8c4828 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ #define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +break the code here + #include #include #include @@ -33,6 +35,9 @@ #include "transformer_engine/activation.h" #include "transformer_engine/multi_stream.h" + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace + XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); + namespace transformer_engine { namespace jax {