Draft
Changes from all commits
32 commits
d32b7ab
avoiding shape copy, torch dynamo and torch autograd overheads
vthumbe1503 Dec 12, 2025
e724815
minor additional change
vthumbe1503 Dec 12, 2025
b725f5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2025
7b031d0
changes done to remove the additional nvte_make_shape calls
vthumbe1503 Dec 23, 2025
1141d72
Merge branch 'main' into cpu_optimizations_v2
vthumbe1503 Dec 24, 2025
d6ac3f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 24, 2025
51dd309
some additional changes
vthumbe1503 Dec 28, 2025
425182f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 28, 2025
f8748c0
some optimizations
vthumbe1503 Dec 28, 2025
7288628
got rid of vector in makeTransformerEngineTensor
vthumbe1503 Dec 30, 2025
c713f0d
Merge branch 'cpu_optimizations_v2' of github.com:vthumbe1503/Transfo…
vthumbe1503 Dec 30, 2025
a66f46b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2025
f1e0781
Merge branch 'main' into cpu_optimizations_v2
vthumbe1503 Dec 30, 2025
b334a74
minor miss
vthumbe1503 Dec 30, 2025
b966f66
Merge branch 'cpu_optimizations_v2' of github.com:vthumbe1503/Transfo…
vthumbe1503 Dec 30, 2025
58bf0f0
all shape copies removed
vthumbe1503 Dec 31, 2025
8a9bb77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 31, 2025
116761f
clean up
vthumbe1503 Jan 1, 2026
801f89f
fix merge conflixt
vthumbe1503 Jan 1, 2026
a6eb2b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 1, 2026
43b693e
minor other change
vthumbe1503 Jan 1, 2026
bc7ba8b
Merge branch 'cpu_optimizations_v2' of github.com:vthumbe1503/Transfo…
vthumbe1503 Jan 1, 2026
9026f1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 1, 2026
b833d15
minor opt
vthumbe1503 Jan 1, 2026
9003c7d
Merge branch 'cpu_optimizations_v2' of github.com:vthumbe1503/Transfo…
vthumbe1503 Jan 1, 2026
5d77eda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 1, 2026
29f8426
minor cleanup
vthumbe1503 Jan 1, 2026
4cf81e1
Merge branch 'cpu_optimizations_v2' of github.com:vthumbe1503/Transfo…
vthumbe1503 Jan 1, 2026
7d75815
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 1, 2026
da650a5
remove uncessary code
vthumbe1503 Jan 1, 2026
1189e51
Merge branch 'cpu_optimizations_v2' of github.com:vthumbe1503/Transfo…
vthumbe1503 Jan 1, 2026
9283238
other minor cleanups
vthumbe1503 Jan 1, 2026
10 changes: 6 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   ret.Atype = A.data.dtype;
   ret.A_scale_inv = A.scale_inv.dptr;
   ret.lda = is_A_transposed ? k : m;
-  if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+  int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
     // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
     if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
       ret.A = A.columnwise_data.dptr;
@@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     } else {
       NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
     }
-  } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
+  } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
     // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
     // data with the mirrored transpose-flag if we don't have row-wise data.
     NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   ret.Btype = B.data.dtype;
   ret.B_scale_inv = B.scale_inv.dptr;
   ret.ldb = is_B_transposed ? n : k;
-  if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+  int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
     // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
     if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
       ret.B = B.columnwise_data.dptr;
@@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     } else {
       NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
     }
-  } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
+  } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
     // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
     // data with the mirrored transpose-flag if we don't have row-wise data.
     NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
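Reviewer note on the hunks above: `nvte_is_non_tn_fp8_gemm_supported()` was previously evaluated inside both branch conditions, so a single pass through `CanonicalizeGemmInput` could hit the per-device capability lookup several times. Hoisting the result into a local pays that cost once per input. A minimal sketch of the pattern, where `query_arch_support()` is a hypothetical stand-in for the real capability check:

```cpp
#include <cstdio>

// Hypothetical stand-in for nvte_is_non_tn_fp8_gemm_supported(): cheap but
// not free, since the real call consults per-device state each time.
static int query_arch_support() {
  std::puts("capability query executed");
  return 1;
}

static void canonicalize(bool is_A_transposed, bool is_B_transposed) {
  // Hoisted once, then reused by every branch condition below.
  const int non_tn_fp8_supported = query_arch_support();
  if (!non_tn_fp8_supported && !is_A_transposed) { /* use column-wise A */ }
  if (!non_tn_fp8_supported && is_B_transposed) { /* use column-wise B */ }
}

int main() { canonicalize(false, true); }
```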
@@ -518,6 +518,143 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor);
  */
 namespace transformer_engine {
 
+/*! \class NVTEShapeWrapper
+ *  \brief C++ wrapper for NVTEShape with container-like interface.
+ */
+class NVTEShapeWrapper {
+ private:
+  NVTEShape data;
+
+ public:
+  // Default constructor
+  NVTEShapeWrapper() { data.ndim = 0; }
+  NVTEShapeWrapper(int ndim) { data.ndim = ndim; }
+  // Constructor from NVTEShape (direct assignment by reference)
+  NVTEShapeWrapper(const NVTEShape &shape) { data = shape; }
+
+  // Constructor from vector (creates a copy)
+  template <typename T>
+  NVTEShapeWrapper(const std::vector<T> &shape_vec) {
+    data.ndim = shape_vec.size();
+    for (size_t i = 0; i < data.ndim; ++i) {
+      data.data[i] = static_cast<size_t>(shape_vec[i]);
+    }
+  }
+  // Constructor from initializer list
+  NVTEShapeWrapper(const std::initializer_list<size_t> &shape_list) {
+    data.ndim = shape_list.size();
+    size_t i = 0;
+    for (const auto &val : shape_list) {
+      data.data[i++] = val;
+    }
+  }
+
+  // Copy constructor
+  NVTEShapeWrapper(const NVTEShapeWrapper &other) : data(other.data) {}
+
+  // Move constructor from another NVTEShapeWrapper
+  NVTEShapeWrapper(NVTEShapeWrapper &&other) noexcept : data(other.data) { other.data.ndim = 0; }
+
+  // Move constructor from NVTEShape rvalue reference
+  NVTEShapeWrapper(NVTEShape &&shape) noexcept : data(shape) {}
+
+  // Copy assignment operator
+  NVTEShapeWrapper &operator=(const NVTEShapeWrapper &other) {
+    if (this != &other) {
+      data = other.data;
+    }
+    return *this;
+  }
+
+  // Assignment operator from vector
+  template <typename T>
+  NVTEShapeWrapper &operator=(const std::vector<T> &shape_vec) {
+    data.ndim = shape_vec.size();
+    for (size_t i = 0; i < data.ndim; ++i) {
+      data.data[i] = static_cast<size_t>(shape_vec[i]);
+    }
+    return *this;
+  }
+
+  // Assignment operator from initializer list
+  NVTEShapeWrapper &operator=(const std::initializer_list<size_t> &shape_list) {
+    data.ndim = shape_list.size();
+    size_t i = 0;
+    for (const auto &val : shape_list) {
+      data.data[i++] = val;
+    }
+    return *this;
+  }
+
+  // Move assignment operator from another NVTEShapeWrapper
+  NVTEShapeWrapper &operator=(NVTEShapeWrapper &&other) noexcept {
+    if (this != &other) {
+      data = other.data;
+      other.data.ndim = 0;
+    }
+    return *this;
+  }
+
+  // Move assignment operator from NVTEShape rvalue reference
+  NVTEShapeWrapper &operator=(NVTEShape &&shape) noexcept {
+    data = shape;
+    return *this;
+  }
+
+  operator NVTEShape &() { return data; }
+  operator const NVTEShape &() const { return data; }
+
+  // Iterator support
+  size_t *begin() { return data.data; }
+  const size_t *begin() const { return data.data; }
+  size_t *end() { return data.data + data.ndim; }
+  const size_t *end() const { return data.data + data.ndim; }
+
+  // Index access
+  size_t &operator[](size_t idx) { return data.data[idx]; }
+  const size_t &operator[](size_t idx) const { return data.data[idx]; }
+
+  // Back access
+  size_t &back() { return data.data[data.ndim - 1]; }
+  const size_t &back() const { return data.data[data.ndim - 1]; }
+
+  // Front access
+  size_t &front() { return data.data[0]; }
+  const size_t &front() const { return data.data[0]; }
+
+  // Size access
+  size_t size() const { return data.ndim; }
+  bool empty() const { return data.ndim == 0; }
+
+  // Container operations
+  void push_back(size_t value) {
+    if (data.ndim < 15) {
+      data.data[data.ndim++] = value;
+    }
+  }
+
+  void clear() { data.ndim = 0; }
+
+  void resize(size_t new_size) {
+    if (new_size <= 15) {
+      data.ndim = new_size;
+    }
+  }
+
+  // Equality comparison with another NVTEShapeWrapper
+  bool operator==(const NVTEShapeWrapper &other) const {
+    if (data.ndim != other.data.ndim) {
+      return false;
+    }
+    for (size_t i = 0; i < data.ndim; ++i) {
+      if (data.data[i] != other.data.data[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
 /*! \enum DType
  *  \brief TE datatype.
  */
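For context on the class above: it is meant as a drop-in replacement for `std::vector<size_t>` in shape-handling paths, backed by the fixed inline array inside `NVTEShape`, so no heap allocation happens (the `15` bounds in `push_back`/`resize` suggest a 15-entry capacity). A short usage sketch against the interface as added; treat it as illustrative only, and the include path as an assumption:

```cpp
#include <cstddef>

#include <transformer_engine/transformer_engine.h>  // assumed include path for NVTEShapeWrapper

using transformer_engine::NVTEShapeWrapper;

size_t num_elements() {
  NVTEShapeWrapper shape = {32, 128, 64};  // initializer-list constructor
  shape.push_back(2);                      // grows in place, no heap allocation
  size_t n = 1;
  for (size_t d : shape) n *= d;           // begin()/end() enable range-for
  return n;                                // 32 * 128 * 64 * 2
}

// Passing the wrapper where the C API expects an NVTEShape works through the
// implicit conversion operators, e.g.: const NVTEShape &raw = shape;
```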
2 changes: 1 addition & 1 deletion transformer_engine/common/transformer_engine.cpp
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
 }
 
 int nvte_is_non_tn_fp8_gemm_supported() {
-  int num_devices = transformer_engine::cuda::num_devices();
+  static int num_devices = transformer_engine::cuda::num_devices();
   static std::vector<int> cache(num_devices, -1);
   static std::vector<std::once_flag> flags(num_devices);
   int device_id = transformer_engine::cuda::current_device();
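The single `static` added above is load-bearing: function-local statics are initialized exactly once per process (thread-safely since C++11), so the CUDA device count is no longer re-queried on every call; the `cache`/`flags` vectors were already amortized this way. A self-contained sketch of the same memoization shape, with the `query_*`/`probe_*` helpers as hypothetical stand-ins for the real CUDA queries:

```cpp
#include <mutex>
#include <vector>

// Hypothetical stand-ins for transformer_engine::cuda::num_devices(),
// cuda::current_device(), and the per-device capability probe.
static int query_num_devices() { return 2; }
static int query_current_device() { return 0; }
static bool probe_capability(int /*device*/) { return true; }

int cached_capability() {
  static int num_devices = query_num_devices();           // evaluated once
  static std::vector<int> cache(num_devices, -1);         // one slot per device
  static std::vector<std::once_flag> flags(num_devices);  // one-time probe guards
  const int dev = query_current_device();
  std::call_once(flags[dev], [&] { cache[dev] = probe_capability(dev) ? 1 : 0; });
  return cache[dev];
}
```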
91 changes: 53 additions & 38 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -13,29 +13,23 @@
 namespace transformer_engine::pytorch {
 
 /*! convert fp4 data shape back to original shape */
-std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
-  std::vector<size_t> ret;
+NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose) {
+  NVTEShapeWrapper ret;
+  const NVTEShapeWrapper input_shape(shape);
   size_t start_idx = (transpose) ? 1 : 0;
-  for (size_t i = start_idx; i < shape.size() - 1; ++i) {
-    ret.push_back(shape[i]);
+  for (size_t i = start_idx; i < input_shape.size() - 1; ++i) {
+    ret.push_back(input_shape[i]);
   }
-  ret.push_back(shape.back() * 2);
+  ret.push_back(input_shape.back() * 2);
   if (transpose) {
-    ret.push_back(shape.front());
+    ret.push_back(input_shape.front());
   }
-  return ret;
-}
-
-std::vector<size_t> getTensorShape(const at::Tensor& t) {
-  std::vector<size_t> shape;
-  for (auto s : t.sizes()) {
-    shape.push_back(s);
-  }
-  return shape;
+  return static_cast<NVTEShape>(ret);
 }
 
-NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
+NVTEShape getTensorShape(const at::Tensor& t) {
   NVTEShape ret;
+  const c10::IntArrayRef& torch_shape = t.sizes();
   ret.ndim = torch_shape.size();
   constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
   NVTE_CHECK(ret.ndim < max_dimensions,
@@ -48,6 +42,23 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
   return ret;
 }
 
+template <typename T>
+NVTEShape make_nvte_1d_shape(T dim0) {
+  NVTEShape shape;
+  shape.ndim = 1;
+  shape.data[0] = static_cast<size_t>(dim0);
+  return shape;
+}
+
+template <typename T, typename U>
+NVTEShape make_nvte_2d_shape(T dim0, U dim1) {
+  NVTEShape shape;
+  shape.ndim = 2;
+  shape.data[0] = static_cast<size_t>(dim0);
+  shape.data[1] = static_cast<size_t>(dim1);
+  return shape;
+}
+
 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) {
   init_extension();
   if (quantizer.is_none()) {
@@ -112,17 +123,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
   return transformer_engine::TensorWrapper(data_ptr, shape, type);
 }
 
-transformer_engine::TensorWrapper makeTransformerEngineTensor(
-    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type) {
-  return transformer_engine::TensorWrapper(data_ptr, shape, type);
-}
-
 transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
   transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
-  std::vector<size_t> shape;
-  for (auto s : tensor.sizes()) {
-    shape.push_back(s);
-  }
+  NVTEShape shape = getTensorShape(tensor);
   return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
 }
 
@@ -164,32 +167,30 @@ makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_l
 }
 
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
-    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
-    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape,
+    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
     NVTEScalingMode scaling_mode) {
   TensorWrapper ret(scaling_mode);
   ret.set_rowwise_data(data_ptr, type, shape);
-  const std::vector<size_t> meta_shape{1};
-  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
-  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape);
+  ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape);
   auto scale_inv_dtype =
       (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
   ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
   return ret;
 }
 
 transformer_engine::TensorWrapper makeTransformerEngineTensor(
-    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
-    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
-    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
-    const std::vector<size_t>& scale_inv_shape,
-    const std::vector<size_t>& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) {
+    void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
+    const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
+    void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
+    const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
+    NVTEScalingMode scaling_mode) {
   TensorWrapper ret(scaling_mode);
   ret.set_rowwise_data(data_ptr, type, shape);
   ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
-  const std::vector<size_t> meta_shape{1};
-  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
-  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
+  ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape);
+  ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape);
   auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING)   ? DType::kFloat8E8M0
                          : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
                                                                    : DType::kFloat32;
@@ -230,6 +231,9 @@ template size_t product<size_t>(const std::vector<size_t>& shape);
 template int64_t product<int64_t>(const std::vector<int64_t>& shape);
 
 size_t product(const NVTEShape& shape, size_t begin, size_t end) {
+  if (end == -1) {
+    end = shape.ndim;
+  }
   NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end,
              " in a shape with ", shape.ndim, " entries");
   size_t ret = 1;
@@ -322,4 +326,15 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_pe
   return philox_args;
 }
 
+// Explicit template instantiations for make_nvte_1d_shape
+template NVTEShape make_nvte_1d_shape<int>(int dim0);
+template NVTEShape make_nvte_1d_shape<int64_t>(int64_t dim0);
+template NVTEShape make_nvte_1d_shape<size_t>(size_t dim0);
+
+// Explicit template instantiations for make_nvte_2d_shape
+template NVTEShape make_nvte_2d_shape<int64_t, int64_t>(int64_t dim0, int64_t dim1);
+template NVTEShape make_nvte_2d_shape<size_t, size_t>(size_t dim0, size_t dim1);
+template NVTEShape make_nvte_2d_shape<int64_t, size_t>(int64_t dim0, size_t dim1);
+template NVTEShape make_nvte_2d_shape<size_t, int64_t>(size_t dim0, int64_t dim1);
+
 }  // namespace transformer_engine::pytorch
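One linker-facing detail in this file: `make_nvte_1d_shape`/`make_nvte_2d_shape` are defined in the .cpp, so the explicit instantiations at the bottom are what make each integer-type combination available to other translation units; a combination missing from that list would fail at link time. A hypothetical call site (the declaration header and the caller are assumptions, not part of this diff):

```cpp
#include <cstdint>

#include <transformer_engine/transformer_engine.h>  // assumed include path for NVTEShape

namespace transformer_engine::pytorch {
// Assumed to be declared in a shared extension header.
template <typename T, typename U>
NVTEShape make_nvte_2d_shape(T dim0, U dim1);
}  // namespace transformer_engine::pytorch

NVTEShape flatten_to_2d(int64_t rows, size_t cols) {
  // Resolves against the <int64_t, size_t> instantiation in common.cpp;
  // builds the shape directly, with no intermediate std::vector.
  return transformer_engine::pytorch::make_nvte_2d_shape(rows, cols);
}
```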