Commit 70b0ff0

[core][rdt] Rework passing transport through for bring your own transport (#59216)
Signed-off-by: dayshah <dhyey2019@gmail.com>
1 parent 1180868 commit 70b0ff0

18 files changed: +179 -201 lines changed (only the files below are shown on this page)
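In short: tensor transports are now specified and carried as plain strings through the Python layer, with normalization and validation centralized in ray.experimental.gpu_object_manager.util, and only the core-worker boundary still sees a (now two-valued) TensorTransportEnum. A hedged end-to-end sketch of the user-facing API this diff touches (the actor and method names are illustrative, and a NIXL-capable cluster is assumed):

import ray
import torch

@ray.remote(enable_tensor_transport=True)
class Producer:
    @ray.method(tensor_transport="nixl")  # transport is passed as a string
    def make_tensor(self) -> torch.Tensor:
        return torch.ones(2, 2)

producer = Producer.remote()
ref = producer.make_tensor.remote()
tensor = ray.get(ref)  # fetched out-of-band over NIXL

# ray.put and ray.get accept only one-sided transports ("nixl" or the
# default object store); two-sided transports like "gloo" are rejected.
ref2 = ray.put(torch.zeros(2), _tensor_transport="nixl")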

doc/source/ray-core/doc_code/direct_transport_gloo.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def sum(self, tensor: torch.Tensor):
     ray.get(tensor)
 
 assert (
-    "Currently ray.get() only supports OBJECT_STORE and NIXL tensor transport, got TensorTransportEnum.GLOO, please specify the correct tensor transport in ray.get()"
+    "Trying to use two-sided tensor transport: GLOO for ray.get. This is only supported for one-sided transports such as NIXL or the OBJECT_STORE."
     in str(e.value)
 )
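For context, the changed string sits inside the doc's error-handling example; a hedged reconstruction of the surrounding block (the rest of the file is not part of this diff, and the exception type is assumed):

import pytest

# Calling ray.get directly on a GLOO-transport object is rejected, because
# GLOO is a two-sided transport; the doc test pins the new error message.
with pytest.raises(ValueError) as e:
    ray.get(tensor)
assert (
    "Trying to use two-sided tensor transport: GLOO for ray.get. This is only "
    "supported for one-sided transports such as NIXL or the OBJECT_STORE."
    in str(e.value)
)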

python/ray/_private/custom_types.py

Lines changed: 1 addition & 3 deletions
@@ -124,9 +124,7 @@
 # See `common.proto` for more details.
 class TensorTransportEnum(Enum):
     OBJECT_STORE = TensorTransport.Value("OBJECT_STORE")
-    NCCL = TensorTransport.Value("NCCL")
-    GLOO = TensorTransport.Value("GLOO")
-    NIXL = TensorTransport.Value("NIXL")
+    DIRECT_TRANSPORT = TensorTransport.Value("DIRECT_TRANSPORT")
 
     @classmethod
     def from_str(cls, name: str) -> "TensorTransportEnum":
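The enum that crosses into the core worker now only distinguishes object-store objects from direct-transport ones; the concrete backend ("nccl", "gloo", "nixl", or a user-supplied one) travels as a string in the Python layer. An illustrative sketch of the collapse, mirroring the logic that appears in worker.py and actor.py below (the numeric values are assumptions; the real ones come from common.proto):

from enum import Enum

class TensorTransportEnumSketch(Enum):
    OBJECT_STORE = 0
    DIRECT_TRANSPORT = 1

def to_core_enum(transport_name: str) -> TensorTransportEnumSketch:
    # Anything that is not the object store collapses to DIRECT_TRANSPORT
    # before it is handed to the core worker.
    if transport_name == TensorTransportEnumSketch.OBJECT_STORE.name:
        return TensorTransportEnumSketch.OBJECT_STORE
    return TensorTransportEnumSketch.DIRECT_TRANSPORT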

python/ray/_private/worker.py

Lines changed: 29 additions & 45 deletions
@@ -801,7 +801,7 @@ def put_object(
         value: Any,
         owner_address: Optional[str] = None,
         _is_experimental_channel: bool = False,
-        _tensor_transport: str = "object_store",
+        _tensor_transport: str = TensorTransportEnum.OBJECT_STORE.name,
     ):
         """Put value in the local object store.
@@ -835,18 +835,15 @@ def put_object(
                 "ray.ObjectRef in a list and call 'put' on it."
             )
         tensors = None
-        tensor_transport: TensorTransportEnum = TensorTransportEnum.from_str(
-            _tensor_transport
+        from ray.experimental.gpu_object_manager.util import (
+            normalize_and_validate_tensor_transport,
+            validate_one_sided,
         )
-        if tensor_transport not in [
-            TensorTransportEnum.OBJECT_STORE,
-            TensorTransportEnum.NIXL,
-        ]:
-            raise ValueError(
-                "Currently, Ray Direct Transport only supports 'object_store' and 'nixl' for tensor transport in ray.put()."
-            )
+
+        tensor_transport = normalize_and_validate_tensor_transport(_tensor_transport)
+        validate_one_sided(tensor_transport, "ray.put")
         try:
-            if tensor_transport != TensorTransportEnum.OBJECT_STORE:
+            if tensor_transport != TensorTransportEnum.OBJECT_STORE.name:
                 (
                     serialized_value,
                     tensors,
@@ -867,19 +864,24 @@ def put_object(
         # object. Instead, clients will keep the object pinned.
         pin_object = not _is_experimental_channel
 
+        tensor_transport_enum = TensorTransportEnum.OBJECT_STORE
+        if tensor_transport != TensorTransportEnum.OBJECT_STORE.name:
+            tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT
+
         # This *must* be the first place that we construct this python
         # ObjectRef because an entry with 0 local references is created when
         # the object is Put() in the core worker, expecting that this python
         # reference will be created. If another reference is created and
         # removed before this one, it will corrupt the state in the
         # reference counter.
+
         ret = self.core_worker.put_object(
             serialized_value,
             pin_object=pin_object,
             owner_address=owner_address,
             inline_small_object=True,
             _is_experimental_channel=_is_experimental_channel,
-            tensor_transport_val=tensor_transport.value,
+            tensor_transport_val=tensor_transport_enum.value,
         )
         if tensors:
             self.gpu_object_manager.put_object(ret, tensor_transport, tensors)
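normalize_and_validate_tensor_transport and validate_one_sided come from ray.experimental.gpu_object_manager.util, which this diff does not show. A plausible sketch of their contracts, inferred from the call sites here and from the doc_code assertion above (assumptions, not the actual implementation):

ONE_SIDED = {"OBJECT_STORE", "NIXL"}  # assumed set of one-sided transports
TWO_SIDED = {"NCCL", "GLOO"}          # assumed set of two-sided transports

def normalize_and_validate_tensor_transport(name: str) -> str:
    # Assumed contract: canonicalize case and reject unknown transports.
    normalized = name.upper()
    if normalized not in ONE_SIDED | TWO_SIDED:
        raise ValueError(f"Invalid tensor transport: {name}")
    return normalized

def validate_one_sided(transport: str, api: str) -> None:
    # Assumed contract: ray.put/ray.get need one-sided semantics, so
    # collective (two-sided) transports are rejected with the message the
    # doc test above checks for.
    if transport not in ONE_SIDED:
        raise ValueError(
            f"Trying to use two-sided tensor transport: {transport} for {api}. "
            "This is only supported for one-sided transports such as NIXL or "
            "the OBJECT_STORE."
        )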
@@ -896,43 +898,26 @@ def deserialize_objects(
         self,
         serialized_objects,
         object_refs,
-        tensor_transport_hint: Optional[TensorTransportEnum] = None,
+        tensor_transport_hint: Optional[str] = None,
     ):
         gpu_objects: Dict[str, List["torch.Tensor"]] = {}
         for obj_ref, (_, _, tensor_transport) in zip(object_refs, serialized_objects):
-            # TODO: Here tensor_transport_hint is set by the user in ray.get(), tensor_transport is set
-            # in serialize_objects by ray.method(tensor_transport="xxx"), and obj_ref.tensor_transport()
-            # is set by ray.put(). We may clean up this logic in the future.
             if (
                 tensor_transport is None
                 or tensor_transport == TensorTransportEnum.OBJECT_STORE
-            ) and (
-                obj_ref is None
-                or obj_ref.tensor_transport() == TensorTransportEnum.OBJECT_STORE.value
             ):
                 # The object is not a gpu object, so we cannot use other external transport to
                 # fetch it.
                 continue
 
-            # If the object is a gpu object, we can choose to use the object store or other external
-            # transport to fetch it. The `tensor_transport_hint` has the highest priority, then the
-            # tensor_transport in obj_ref.tensor_transport(), then the tensor_transport in serialize_objects,
-            # then the default value `OBJECT_STORE`.
-            chosen_tensor_transport = (
-                tensor_transport_hint
-                or (
-                    TensorTransportEnum(obj_ref.tensor_transport()) if obj_ref else None
-                )
-                or tensor_transport
-                or TensorTransportEnum.OBJECT_STORE
-            )
-
             object_id = obj_ref.hex()
             if object_id not in gpu_objects:
                 # If using a non-object store transport, then tensors will be sent
                 # out-of-band. Get them before deserializing the object store data.
+                # The user can choose OBJECT_STORE as the hint to fetch the RDT object
+                # through the object store.
                 gpu_objects[object_id] = self.gpu_object_manager.get_gpu_object(
-                    object_id, tensor_transport=chosen_tensor_transport
+                    object_id, tensor_transport=tensor_transport_hint
                 )
 
         # Function actor manager or the import thread may call pickle.loads
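The new comment documents a user-visible behavior: when the hint is omitted, the GPU object manager resolves the object's own transport, and passing "object_store" forces the fetch through the plain object store. A hedged usage sketch (assumes ref points at an RDT object, e.g. from the NIXL example above):

t1 = ray.get(ref)                                    # hint is None: use the
                                                     # object's own transport
t2 = ray.get(ref, _tensor_transport="object_store")  # force the object store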
@@ -983,16 +968,6 @@ def get_objects(
                     f"Attempting to call `get` on the value {object_ref}, "
                     "which is not an ray.ObjectRef."
                 )
-        tensor_transport: TensorTransportEnum = (
-            TensorTransportEnum.from_str(_tensor_transport)
-            if _tensor_transport is not None
-            else None
-        )
-        assert tensor_transport in [
-            TensorTransportEnum.OBJECT_STORE,
-            TensorTransportEnum.NIXL,
-            None,
-        ], "Currently, RDT only supports 'object_store' and 'nixl' for tensor transport in ray.get()."
         timeout_ms = (
             int(timeout * 1000) if timeout is not None and timeout != -1 else -1
         )
@@ -1004,7 +979,7 @@
         )
 
         debugger_breakpoint = b""
-        for data, metadata, _ in serialized_objects:
+        for _, metadata, _ in serialized_objects:
             if metadata:
                 metadata_fields = metadata.split(b",")
                 if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
@@ -1016,8 +991,17 @@
         if skip_deserialization:
             return None, debugger_breakpoint
 
+        if _tensor_transport is not None:
+            from ray.experimental.gpu_object_manager.util import (
+                normalize_and_validate_tensor_transport,
+            )
+
+            _tensor_transport = normalize_and_validate_tensor_transport(
+                _tensor_transport
+            )
+
         values = self.deserialize_objects(
-            serialized_objects, object_refs, tensor_transport_hint=tensor_transport
+            serialized_objects, object_refs, tensor_transport_hint=_tensor_transport
         )
         if not return_exceptions:
             # Raise exceptions instead of returning them to the user.

python/ray/_raylet.pyx

Lines changed: 2 additions & 2 deletions
@@ -2874,7 +2874,7 @@ cdef class CoreWorker:
             c_object_ids, timeout_ms, results)
         check_status(op_status)
 
-        return RayObjectsToSerializedRayObjects(results)
+        return RayObjectsToSerializedRayObjects(results, object_refs)
 
     def get_if_local(self, object_refs):
         """Get objects from local plasma store directly
@@ -2886,7 +2886,7 @@
         check_status(
             CCoreWorkerProcess.GetCoreWorker().GetIfLocal(
                 c_object_ids, &results))
-        return RayObjectsToSerializedRayObjects(results)
+        return RayObjectsToSerializedRayObjects(results, object_refs)
 
     def object_exists(self, ObjectRef object_ref, memory_store_only=False):
         cdef:
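Both call sites now pass object_refs alongside the raw results, so each serialized value stays paired with the ref it belongs to; deserialize_objects (above) consumes exactly that pairing. A small illustrative sketch of the consumed shape (stand-in values, not the real Cython types):

serialized_objects = [(b"data", b"metadata", None)]  # (data, metadata, tensor_transport)
object_refs = [None]                                 # stand-in for the matching ObjectRef

for obj_ref, (data, metadata, tensor_transport) in zip(object_refs, serialized_objects):
    # deserialize_objects uses the per-entry tensor_transport (and the ref's
    # object id) to detect RDT objects and fetch their tensors out-of-band.
    pass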

python/ray/actor.py

Lines changed: 38 additions & 20 deletions
@@ -473,9 +473,14 @@ def annotate_method(method: Callable[_P, _Ret]):
     if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None:
         method.__ray_enable_task_events__ = kwargs["enable_task_events"]
     if "tensor_transport" in kwargs:
-        method.__ray_tensor_transport__ = TensorTransportEnum.from_str(
-            kwargs["tensor_transport"]
+        tensor_transport = kwargs["tensor_transport"]
+        from ray.experimental.gpu_object_manager.util import (
+            normalize_and_validate_tensor_transport,
         )
+
+        tensor_transport = normalize_and_validate_tensor_transport(tensor_transport)
+        method.__ray_tensor_transport__ = tensor_transport
+
     return method
 
 # Check if decorator is called without parentheses (args[0] would be the function)
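Note that __ray_tensor_transport__ now stores a normalized string rather than an enum member, which is why later comparisons use TensorTransportEnum.OBJECT_STORE.name. A minimal sketch of the effect, assuming normalization uppercases its input:

def sketch_annotate(method, tensor_transport):
    # Stand-in for normalize_and_validate_tensor_transport: canonicalize the
    # string and attach it to the method, exactly where the enum used to go.
    method.__ray_tensor_transport__ = tensor_transport.upper()
    return method

def produce():
    pass

sketch_annotate(produce, "nixl")
assert produce.__ray_tensor_transport__ == "NIXL"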
@@ -521,7 +526,7 @@ def __init__(
         enable_task_events: bool,
         decorator: Optional[Any] = None,
         signature: Optional[List[inspect.Parameter]] = None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         """Initialize an _ActorMethodMetadata.
@@ -599,7 +604,7 @@ def __init__(
         enable_task_events: bool,
         decorator=None,
         signature: Optional[List[inspect.Parameter]] = None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         """Initialize an ActorMethod.
@@ -649,10 +654,10 @@ def __init__(
         # and return the resulting ObjectRefs.
         self._decorator = decorator
 
-        # If the task call doesn't specify a tensor transport option, use `_tensor_transport`
+        # If the task call doesn't specify a tensor transport option, use `OBJECT_STORE`
         # as the default transport for this actor method.
         if tensor_transport is None:
-            tensor_transport = TensorTransportEnum.OBJECT_STORE
+            tensor_transport = TensorTransportEnum.OBJECT_STORE.name
         self._tensor_transport = tensor_transport
 
     def __call__(self, *args, **kwargs):
@@ -695,7 +700,12 @@ def options(self, **options):
 
         tensor_transport = options.get("tensor_transport", None)
         if tensor_transport is not None:
-            options["tensor_transport"] = TensorTransportEnum.from_str(tensor_transport)
+            from ray.experimental.gpu_object_manager.util import (
+                normalize_and_validate_tensor_transport,
+            )
+
+            tensor_transport = normalize_and_validate_tensor_transport(tensor_transport)
+            options["tensor_transport"] = tensor_transport
 
         class FuncWrapper:
             def remote(self, *args, **kwargs):
@@ -800,7 +810,7 @@ def _remote(
         concurrency_group=None,
         _generator_backpressure_num_objects=None,
         enable_task_events=None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         if num_returns is None:
             num_returns = self._num_returns
@@ -820,23 +830,23 @@
         if tensor_transport is None:
             tensor_transport = self._tensor_transport
 
-        if tensor_transport != TensorTransportEnum.OBJECT_STORE:
+        if tensor_transport != TensorTransportEnum.OBJECT_STORE.name:
             if num_returns != 1:
                 raise ValueError(
-                    f"Currently, methods with tensor_transport={tensor_transport.name} only support 1 return value. "
+                    f"Currently, methods with tensor_transport={tensor_transport} only support 1 return value. "
                     "Please make sure the actor method is decorated with `@ray.method(num_returns=1)` (the default)."
                 )
             if not self._actor._ray_enable_tensor_transport:
                 raise ValueError(
-                    f'Currently, methods with .options(tensor_transport="{tensor_transport.name}") are not supported when enable_tensor_transport=False. '
+                    f'Currently, methods with .options(tensor_transport="{tensor_transport}") are not supported when enable_tensor_transport=False. '
                     "Please set @ray.remote(enable_tensor_transport=True) on the actor class definition."
                )
            gpu_object_manager = ray._private.worker.global_worker.gpu_object_manager
            if not gpu_object_manager.actor_has_tensor_transport(
                self._actor, tensor_transport
            ):
                raise ValueError(
-                    f'{self._actor} does not have tensor transport {tensor_transport.name} available. If using a collective-based transport ("nccl" or "gloo"), please create a communicator with '
+                    f'{self._actor} does not have tensor transport {tensor_transport} available. If using a collective-based transport ("nccl" or "gloo"), please create a communicator with '
                    "`ray.experimental.collective.create_collective_group` "
                    "before calling actor tasks with non-default tensor_transport."
                )
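The error message above names the prerequisite for collective-based (two-sided) transports; a hedged sketch of that flow, adapted from the direct transport docs (the create_collective_group signature and backend name are assumed):

import ray
import torch
from ray.experimental.collective import create_collective_group

@ray.remote(enable_tensor_transport=True)
class Actor:
    def produce(self) -> torch.Tensor:
        return torch.ones(4)

    def consume(self, t: torch.Tensor) -> float:
        return float(t.sum())

sender, receiver = Actor.remote(), Actor.remote()
# Create the communicator before any non-default tensor_transport call:
create_collective_group([sender, receiver], backend="gloo")

ref = sender.produce.options(tensor_transport="gloo").remote()
out = ray.get(receiver.consume.remote(ref))  # tensor moves over the gloo group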
@@ -876,7 +886,7 @@ def invocation(args, kwargs):
            invocation = self._decorator(invocation)
 
        object_refs = invocation(args, kwargs)
-        if tensor_transport != TensorTransportEnum.OBJECT_STORE:
+        if tensor_transport != TensorTransportEnum.OBJECT_STORE.name:
            # Currently, we only support transfer tensor out-of-band when
            # num_returns is 1.
            assert isinstance(object_refs, ObjectRef)
@@ -979,14 +989,16 @@ def create(
        self.enable_task_events = {}
        self.generator_backpressure_num_objects = {}
        self.concurrency_group_for_methods = {}
-        self.method_name_to_tensor_transport: Dict[str, TensorTransportEnum] = {}
+        self.method_name_to_tensor_transport: Dict[str, str] = {}
 
        # Check whether any actor methods specify a non-default tensor transport.
        self.has_tensor_transport_methods = any(
            getattr(
-                method, "__ray_tensor_transport__", TensorTransportEnum.OBJECT_STORE
+                method,
+                "__ray_tensor_transport__",
+                TensorTransportEnum.OBJECT_STORE.name,
            )
-            != TensorTransportEnum.OBJECT_STORE
+            != TensorTransportEnum.OBJECT_STORE.name
            for _, method in actor_methods
        )
 
@@ -1941,7 +1953,7 @@ def __init__(
        method_generator_backpressure_num_objects: Dict[str, int],
        method_enable_task_events: Dict[str, bool],
        enable_tensor_transport: bool,
-        method_name_to_tensor_transport: Dict[str, TensorTransportEnum],
+        method_name_to_tensor_transport: Dict[str, str],
        actor_method_cpus: int,
        actor_creation_function_descriptor,
        cluster_and_job,
@@ -1968,7 +1980,7 @@ def __init__(
                this actor. If True, then methods can be called with
                .options(tensor_transport=...) to specify a non-default tensor
                transport.
-            method_name_to_tensor_transport: Dictionary mapping method names to their tensor transport settings.
+            method_name_to_tensor_transport: Dictionary mapping method names to their tensor transport type.
            actor_method_cpus: The number of CPUs required by actor methods.
            actor_creation_function_descriptor: The function descriptor for actor creation.
            cluster_and_job: The cluster and job information.
@@ -2079,7 +2091,7 @@ def _actor_method_call(
        concurrency_group_name: Optional[str] = None,
        generator_backpressure_num_objects: Optional[int] = None,
        enable_task_events: Optional[bool] = None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
    ):
        """Method execution stub for an actor handle.
 
@@ -2157,6 +2169,12 @@
        if generator_backpressure_num_objects is None:
            generator_backpressure_num_objects = -1
 
+        tensor_transport_enum = TensorTransportEnum.OBJECT_STORE
+        if (
+            tensor_transport is not None
+            and tensor_transport != TensorTransportEnum.OBJECT_STORE.name
+        ):
+            tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT
        object_refs = worker.core_worker.submit_actor_task(
            self._ray_actor_language,
            self._ray_actor_id,
@@ -2171,7 +2189,7 @@
            concurrency_group_name if concurrency_group_name is not None else b"",
            generator_backpressure_num_objects,
            enable_task_events,
-            tensor_transport.value,
+            tensor_transport_enum.value,
        )
 
        if num_returns == STREAMING_GENERATOR_RETURN:
