@@ -473,9 +473,14 @@ def annotate_method(method: Callable[_P, _Ret]):
         if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None:
             method.__ray_enable_task_events__ = kwargs["enable_task_events"]
         if "tensor_transport" in kwargs:
-            method.__ray_tensor_transport__ = TensorTransportEnum.from_str(
-                kwargs["tensor_transport"]
+            tensor_transport = kwargs["tensor_transport"]
+            from ray.experimental.gpu_object_manager.util import (
+                normalize_and_validate_tensor_transport,
             )
+
+            tensor_transport = normalize_and_validate_tensor_transport(tensor_transport)
+            method.__ray_tensor_transport__ = tensor_transport
+
         return method
 
     # Check if decorator is called without parentheses (args[0] would be the function)
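The helper's body is not part of this diff; it is imported from `ray/experimental/gpu_object_manager/util.py`. A minimal, self-contained sketch of what such a normalizer could look like, assuming it accepts either a string or an enum member and returns the canonical enum name used in the string comparisons later in this file (the stand-in enum and the exact behavior are assumptions, not Ray's actual implementation):

```python
from enum import Enum
from typing import Union


class TensorTransportEnum(Enum):
    # Stand-in for Ray's internal enum so the sketch runs on its own.
    OBJECT_STORE = 0
    NCCL = 1
    GLOO = 2


def normalize_and_validate_tensor_transport(
    tensor_transport: Union[str, TensorTransportEnum],
) -> str:
    """Hypothetical sketch of the imported helper, not Ray's real code."""
    if isinstance(tensor_transport, TensorTransportEnum):
        return tensor_transport.name
    name = str(tensor_transport).upper()
    if name not in TensorTransportEnum.__members__:
        raise ValueError(
            f"Invalid tensor_transport {tensor_transport!r}; expected one of "
            f"{sorted(TensorTransportEnum.__members__)}."
        )
    return name


assert normalize_and_validate_tensor_transport("nccl") == "NCCL"
assert normalize_and_validate_tensor_transport(TensorTransportEnum.OBJECT_STORE) == "OBJECT_STORE"
```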
@@ -521,7 +526,7 @@ def __init__(
         enable_task_events: bool,
         decorator: Optional[Any] = None,
         signature: Optional[List[inspect.Parameter]] = None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         """Initialize an _ActorMethodMetadata.
 
@@ -599,7 +604,7 @@ def __init__(
         enable_task_events: bool,
         decorator=None,
         signature: Optional[List[inspect.Parameter]] = None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         """Initialize an ActorMethod.
 
@@ -649,10 +654,10 @@ def __init__(
         # and return the resulting ObjectRefs.
         self._decorator = decorator
 
-        # If the task call doesn't specify a tensor transport option, use `_tensor_transport`
+        # If the task call doesn't specify a tensor transport option, use `OBJECT_STORE`
         # as the default transport for this actor method.
         if tensor_transport is None:
-            tensor_transport = TensorTransportEnum.OBJECT_STORE
+            tensor_transport = TensorTransportEnum.OBJECT_STORE.name
         self._tensor_transport = tensor_transport
 
     def __call__(self, *args, **kwargs):
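With this change the cached per-method default is the enum member's name, a plain string, rather than the member itself, so the later checks in `_remote` become string comparisons. A tiny standalone illustration of the Python enum semantics this relies on (the enum here is a stand-in with only the relevant member):

```python
from enum import Enum


class TensorTransportEnum(Enum):
    # Stand-in containing only the member referenced above.
    OBJECT_STORE = 0


# .name is the member's identifier as a string, which is what gets stored now.
assert TensorTransportEnum.OBJECT_STORE.name == "OBJECT_STORE"
```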
@@ -695,7 +700,12 @@ def options(self, **options):
695700
696701 tensor_transport = options .get ("tensor_transport" , None )
697702 if tensor_transport is not None :
698- options ["tensor_transport" ] = TensorTransportEnum .from_str (tensor_transport )
703+ from ray .experimental .gpu_object_manager .util import (
704+ normalize_and_validate_tensor_transport ,
705+ )
706+
707+ tensor_transport = normalize_and_validate_tensor_transport (tensor_transport )
708+ options ["tensor_transport" ] = tensor_transport
699709
700710 class FuncWrapper :
701711 def remote (self , * args , ** kwargs ):
@@ -800,7 +810,7 @@ def _remote(
         concurrency_group=None,
         _generator_backpressure_num_objects=None,
         enable_task_events=None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         if num_returns is None:
             num_returns = self._num_returns
@@ -820,23 +830,23 @@ def _remote(
         if tensor_transport is None:
             tensor_transport = self._tensor_transport
 
-        if tensor_transport != TensorTransportEnum.OBJECT_STORE:
+        if tensor_transport != TensorTransportEnum.OBJECT_STORE.name:
             if num_returns != 1:
                 raise ValueError(
-                    f"Currently, methods with tensor_transport={tensor_transport.name} only support 1 return value. "
+                    f"Currently, methods with tensor_transport={tensor_transport} only support 1 return value. "
                     "Please make sure the actor method is decorated with `@ray.method(num_returns=1)` (the default)."
                 )
             if not self._actor._ray_enable_tensor_transport:
                 raise ValueError(
-                    f'Currently, methods with .options(tensor_transport="{tensor_transport.name}") are not supported when enable_tensor_transport=False. '
+                    f'Currently, methods with .options(tensor_transport="{tensor_transport}") are not supported when enable_tensor_transport=False. '
                     "Please set @ray.remote(enable_tensor_transport=True) on the actor class definition."
                 )
             gpu_object_manager = ray._private.worker.global_worker.gpu_object_manager
             if not gpu_object_manager.actor_has_tensor_transport(
                 self._actor, tensor_transport
             ):
                 raise ValueError(
-                    f'{self._actor} does not have tensor transport {tensor_transport.name} available. If using a collective-based transport ("nccl" or "gloo"), please create a communicator with '
+                    f'{self._actor} does not have tensor transport {tensor_transport} available. If using a collective-based transport ("nccl" or "gloo"), please create a communicator with '
                     "`ray.experimental.collective.create_collective_group` "
                     "before calling actor tasks with non-default tensor_transport."
                 )
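The three checks above describe the setup a caller needs before using a non-default transport. A usage sketch that puts them together, based only on the pieces named in the error messages (`@ray.remote(enable_tensor_transport=True)`, `@ray.method(tensor_transport=...)`, and `ray.experimental.collective.create_collective_group`); the `backend` argument name and the exact group-creation signature are assumptions and may differ from the released API:

```python
import ray
import torch
from ray.experimental.collective import create_collective_group


@ray.remote(num_gpus=1, enable_tensor_transport=True)
class Worker:
    @ray.method(tensor_transport="nccl")
    def produce(self):
        # The returned tensor is sent over NCCL instead of the object store.
        return torch.ones(4, device="cuda")

    def consume(self, t):
        return t.sum().item()


sender, receiver = Worker.remote(), Worker.remote()

# A communicator must exist before any call that uses a non-default
# tensor_transport; otherwise _remote raises the ValueError shown above.
create_collective_group([sender, receiver], backend="nccl")

ref = sender.produce.remote()
print(ray.get(receiver.consume.remote(ref)))  # 4.0
```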
@@ -876,7 +886,7 @@ def invocation(args, kwargs):
             invocation = self._decorator(invocation)
 
         object_refs = invocation(args, kwargs)
-        if tensor_transport != TensorTransportEnum.OBJECT_STORE:
+        if tensor_transport != TensorTransportEnum.OBJECT_STORE.name:
            # Currently, we only support transfer tensor out-of-band when
            # num_returns is 1.
            assert isinstance(object_refs, ObjectRef)
@@ -979,14 +989,16 @@ def create(
         self.enable_task_events = {}
         self.generator_backpressure_num_objects = {}
         self.concurrency_group_for_methods = {}
-        self.method_name_to_tensor_transport: Dict[str, TensorTransportEnum] = {}
+        self.method_name_to_tensor_transport: Dict[str, str] = {}
 
         # Check whether any actor methods specify a non-default tensor transport.
         self.has_tensor_transport_methods = any(
             getattr(
-                method, "__ray_tensor_transport__", TensorTransportEnum.OBJECT_STORE
+                method,
+                "__ray_tensor_transport__",
+                TensorTransportEnum.OBJECT_STORE.name,
             )
-            != TensorTransportEnum.OBJECT_STORE
+            != TensorTransportEnum.OBJECT_STORE.name
             for _, method in actor_methods
         )
 
@@ -1941,7 +1953,7 @@ def __init__(
         method_generator_backpressure_num_objects: Dict[str, int],
         method_enable_task_events: Dict[str, bool],
         enable_tensor_transport: bool,
-        method_name_to_tensor_transport: Dict[str, TensorTransportEnum],
+        method_name_to_tensor_transport: Dict[str, str],
         actor_method_cpus: int,
         actor_creation_function_descriptor,
         cluster_and_job,
@@ -1968,7 +1980,7 @@ def __init__(
                 this actor. If True, then methods can be called with
                 .options(tensor_transport=...) to specify a non-default tensor
                 transport.
-            method_name_to_tensor_transport: Dictionary mapping method names to their tensor transport settings.
+            method_name_to_tensor_transport: Dictionary mapping method names to their tensor transport type.
             actor_method_cpus: The number of CPUs required by actor methods.
             actor_creation_function_descriptor: The function descriptor for actor creation.
             cluster_and_job: The cluster and job information.
@@ -2079,7 +2091,7 @@ def _actor_method_call(
         concurrency_group_name: Optional[str] = None,
         generator_backpressure_num_objects: Optional[int] = None,
         enable_task_events: Optional[bool] = None,
-        tensor_transport: Optional[TensorTransportEnum] = None,
+        tensor_transport: Optional[str] = None,
     ):
         """Method execution stub for an actor handle.
 
@@ -2157,6 +2169,12 @@ def _actor_method_call(
         if generator_backpressure_num_objects is None:
             generator_backpressure_num_objects = -1
 
+        tensor_transport_enum = TensorTransportEnum.OBJECT_STORE
+        if (
+            tensor_transport is not None
+            and tensor_transport != TensorTransportEnum.OBJECT_STORE.name
+        ):
+            tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT
         object_refs = worker.core_worker.submit_actor_task(
             self._ray_actor_language,
             self._ray_actor_id,
@@ -2171,7 +2189,7 @@
             concurrency_group_name if concurrency_group_name is not None else b"",
             generator_backpressure_num_objects,
             enable_task_events,
-            tensor_transport.value,
+            tensor_transport_enum.value,
         )
 
         if num_returns == STREAMING_GENERATOR_RETURN:
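Note the design choice at the core-worker boundary: the Python layer keeps the transport as a string for validation and for the GPU object manager, while `submit_actor_task` only distinguishes the default object store from a direct transport. A standalone illustration of that collapse, using a stand-in enum with the two member names referenced in this hunk (the numeric values are made up):

```python
from enum import Enum
from typing import Optional


class TensorTransportEnum(Enum):
    # Stand-in mirroring the two members used above; values are arbitrary.
    OBJECT_STORE = 0
    DIRECT_TRANSPORT = 1


def to_core_worker_value(tensor_transport: Optional[str]) -> int:
    """Sketch of the mapping performed before submit_actor_task is called."""
    if (
        tensor_transport is None
        or tensor_transport == TensorTransportEnum.OBJECT_STORE.name
    ):
        return TensorTransportEnum.OBJECT_STORE.value
    return TensorTransportEnum.DIRECT_TRANSPORT.value


assert to_core_worker_value(None) == 0
assert to_core_worker_value("OBJECT_STORE") == 0
assert to_core_worker_value("NCCL") == 1
```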