Commit d79bf21

greptile fixes
Signed-off-by: Alp Dener <adener@nvidia.com>
1 parent dd8eaf3 commit d79bf21

5 files changed: 20 additions & 15 deletions


tests/pytorch/distributed/run_gemm_with_overlap.py
Lines changed: 0 additions & 2 deletions

@@ -408,7 +408,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
     if opts.comm_type == tex.CommOverlapType.AG:
         # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P)
         local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size)
-        local_kernel2_t_shape = (0, )
     local_inp_shape = (outer_size // tp_size, hidden_size)
     if ub_obj2 is not None:
         local_kernel2_t_shape = (hidden_size, ffn_hidden_size // tp_size)

@@ -479,7 +478,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
         ref_g = torch.stack(bulk_inp_list).sum(dim=0)
     else:
         ref_g = torch.matmul(inp_g, ker_g)
-    ref2_g = (0, )
     if ub_obj2 is not None:
         inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
         ref2_g = torch.matmul(inp2_g, ker2_g)

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
Lines changed: 1 addition & 1 deletion

@@ -332,7 +332,7 @@ void CommOverlapCore::cublasmp_gemm_rs(const TensorWrapper &A, bool transa, cons
   int64_t m = transa ? A.size(0) : A.size(1);
   int64_t n = transb ? B.size(1) : B.size(0);
   int64_t k_local = transa ? A.size(1) : A.size(0);
-  int64_t k = k * _tp_size;
+  int64_t k = k_local * _tp_size;

   nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
                            pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
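
The removed line initialized `k` from itself before it had a value, so it read an indeterminate value instead of scaling the per-rank contraction dimension `k_local` by the tensor-parallel group size. A minimal standalone sketch of the corrected arithmetic, with made-up sizes and outside of Transformer Engine:

    // Standalone sketch (hypothetical sizes, not TE code). The buggy line read
    // `k` in its own initializer; the fix derives the global contraction
    // dimension from the per-rank slice.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t tp_size = 4;      // assumed tensor-parallel group size
      const int64_t k_local = 1024;   // per-rank slice of the contraction dimension

      // int64_t k = k * tp_size;     // bug: `k` is used before it is initialized
      int64_t k = k_local * tp_size;  // fix: global K = local slice * TP size

      std::printf("global k = %lld\n", static_cast<long long>(k));
      return 0;
    }

Most compilers can warn about the self-initializing form (for example Clang's -Wuninitialized).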

transformer_engine/jax/csrc/extensions/cgemm_helper.cpp
Lines changed: 4 additions & 3 deletions

@@ -132,7 +132,7 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces
   NVTE_CHECK_NCCL(ncclGroupEnd());

   // Allocate device memory for barrier operations
-  NVTE_CHECK_CUDA(cudaMalloc(&reinterpret_cast<int>(handler._device_barrier), sizeof(int)));
+  NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int)));

   handler._initialize = true;
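
The deleted call took the address of a `reinterpret_cast` result, which is not an lvalue; `cudaMalloc` wants the address of the pointer variable itself. A standalone sketch of the corrected call, assuming a plain `int *` device flag rather than TE's `CommunicatorHandler` member:

    // Minimal CUDA runtime sketch (hypothetical names, not the TE handler).
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      int *device_barrier = nullptr;  // device-side barrier flag

      // Pass the address of the pointer variable; the typed cudaMalloc overload
      // in cuda_runtime.h accepts int ** directly.
      cudaError_t err = cudaMalloc(&device_barrier, sizeof(int));
      if (err != cudaSuccess) {
        std::fprintf(stderr, "cudaMalloc failed: %s\n", cudaGetErrorString(err));
        return 1;
      }

      cudaMemset(device_barrier, 0, sizeof(int));  // initialize the flag to zero
      cudaFree(device_barrier);
      return 0;
    }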

@@ -195,8 +195,9 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector<size_t> bu
   std::unique_ptr<CommOverlapCore> executor;
   if (use_cublasmp) {
     executor = std::make_unique<CommOverlapP2PBase>(
-        reinterpret_cast<int64_t>(comm_handler.get_comm_for_current_device()), comm_handler.tp_size,
-        comm_handler.get_tp_domain_id(), cgemm_config.num_comm_sm, cgemm_config.aggregate_ag);
+        reinterpret_cast<int64_t>(comm_handler.get_comm_for_current_device()),
+        comm_handler.get_tp_domain_id(), comm_handler.tp_size, cgemm_config.num_comm_sm,
+        cgemm_config.aggregate_ag);
   } else {
     executor = std::make_unique<CommOverlapP2PBase>(
         buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices,
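
The reorder above matters because `comm_handler.get_tp_domain_id()` and `comm_handler.tp_size` are both plain integers, so passing them in the wrong order compiles cleanly and only misbehaves at runtime. A generic sketch of that pitfall, using a hypothetical function rather than the real `CommOverlapP2PBase` constructor:

    // Hypothetical signature: the rank within the TP domain comes before the
    // domain size, mirroring the corrected call above.
    #include <cassert>

    void configure_overlap(int tp_rank, int tp_size) {
      // A swapped call site usually trips this check at runtime, never at compile time.
      assert(tp_rank >= 0 && tp_rank < tp_size);
    }

    int main() {
      const int tp_rank = 7, tp_size = 8;
      // configure_overlap(tp_size, tp_rank);  // compiles, but passes 8 as the rank
      configure_overlap(tp_rank, tp_size);     // matches the declared parameter order
      return 0;
    }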

transformer_engine/pytorch/csrc/extensions.h
Lines changed: 13 additions & 7 deletions

@@ -525,13 +525,19 @@ class CommOverlapHelper : torch::CustomClassHolder {
   void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                     ExtComm comm);

-  void ub_barrier(ExtComm comm);a
+  void ub_barrier(ExtComm comm);

   int64_t get_nccl_comm_ptr(std::string comm_name) {
+#ifdef USE_C10_NCCL
     NVTE_CHECK(backend_is_nccl,
                "Comm+GEMM overlap with cuBLASMp backend requires a tensor-parallel process ",
                "group with NCCL backend.");
-    return reinterpret_cast<c10d::ProcessGroupNCCL *>(pgs[comm_name])->getCommPtr();
+    c10d::ProcessGroupNCCL *nccl_pg = reinterpret_cast<c10d::ProcessGroupNCCL *>(pgs[comm_name]);
+    return nccl_pg->getCommPtr();
+#else
+    NVTE_ERROR("Internal TE Error: CommOverlapHelper::get_nccl_comm_ptr() is an internal API that ",
+               "should only be used when TE is built with the NVTE_WITH_CUBLASMP=1 flag.");
+#endif
   }
 };
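
The new `#ifdef USE_C10_NCCL` guard compiles the NCCL-dependent lookup only when that backend is available and turns any other use of the accessor into a hard error. A generic sketch of the same guard pattern, with hypothetical macro and function names rather than TE's:

    // Compile-time feature gate (names are illustrative only).
    #include <cstdint>
    #include <stdexcept>

    // #define WITH_NCCL_BACKEND  // normally defined by the build system

    int64_t get_backend_comm_ptr() {
    #ifdef WITH_NCCL_BACKEND
      // Feature build: return the stored communicator handle.
      static int64_t comm_handle = 0;  // stand-in for a real communicator pointer
      return comm_handle;
    #else
      // Fallback build: fail loudly instead of returning a bogus pointer.
      throw std::runtime_error(
          "get_backend_comm_ptr() requires a build with WITH_NCCL_BACKEND defined");
    #endif
    }

    int main() {
      try {
        (void)get_backend_comm_ptr();
      } catch (const std::exception &) {
        // Expected when the feature flag is not defined.
      }
      return 0;
    }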

@@ -542,11 +548,11 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
               int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
               int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
               bool set_sm_margin = true, bool atomic_gemm = false,
-              bool rs_overlap_first_gemm = false);
+              bool rs_overlap_first_gemm= false);

-  CommOverlap(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16,
+  CommOverlap(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16,
               bool atomic_gemm = false)
-      : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm,
+      : CommOverlapBase(helper->get_nccl_comm_ptr("intra"), tp_rank, tp_size, num_comm_sm,
                         atomic_gemm) {}

   ~CommOverlap() {}

@@ -570,9 +576,9 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
               bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
               bool aggregate = false);

-  CommOverlapP2P(CommOverlapHelper *helper, int tp_size, int tp_rank, int num_comm_sm = 16,
+  CommOverlapP2P(CommOverlapHelper *helper, int tp_rank, int tp_size, int num_comm_sm = 16,
                  bool atomic_gemm = false)
-      : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_size, tp_rank, num_comm_sm,
+      : CommOverlapP2PBase(helper->get_nccl_comm_ptr("intra"), tp_rank, tp_size, num_comm_sm,
                            atomic_gemm) {}

   ~CommOverlapP2P() {}

transformer_engine/pytorch/csrc/extensions/pybind.cpp
Lines changed: 2 additions & 2 deletions

@@ -491,7 +491,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
            py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
       .def(py::init<CommOverlapHelper *, int, int, int, bool>(), py::arg("helper"),
-           py::arg("tp_size"), py::arg("tp_rank"), py::arg("num_comm_sm") = 0,
+           py::arg("tp_rank"), py::arg("tp_size"), py::arg("num_comm_sm") = 0,
            py::arg("atomic_gemm") = false, py::call_guard<py::gil_scoped_release>())
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)

@@ -512,7 +512,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
            py::arg("use_ce") = true, py::arg("aggregate") = false)
       .def(py::init<CommOverlapHelper *, int, int, int, bool>(), py::arg("helper"),
-           py::arg("tp_size"), py::arg("tp_rank"), py::arg("num_comm_sm") = 0,
+           py::arg("tp_rank"), py::arg("tp_size"), py::arg("num_comm_sm") = 0,
            py::arg("atomic_gemm") = false, py::call_guard<py::gil_scoped_release>())
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
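
Swapping the `py::arg` labels keeps the keyword names aligned with the reordered C++ constructors in extensions.h: in pybind11, `py::arg` entries label the bound callable's positional parameters in declaration order. A toy sketch of the idea (not the real TE bindings):

    // Toy pybind11 module (class and module names are made up).
    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    struct Overlap {
      int rank, size;
      Overlap(int tp_rank, int tp_size) : rank(tp_rank), size(tp_size) {}
    };

    PYBIND11_MODULE(overlap_demo, m) {
      py::class_<Overlap>(m, "Overlap")
          // Keyword names follow the constructor's parameter order: rank, then size.
          .def(py::init<int, int>(), py::arg("tp_rank"), py::arg("tp_size"));
    }

From Python, `Overlap(tp_rank=3, tp_size=8)` and `Overlap(3, 8)` then build the same object; with the labels transposed, the keyword form would silently feed the group size into the rank parameter.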
