[core][rdt] Rework passing transport through for bring your own transport #59216
@@ -801,7 +801,7 @@ def put_object(
         value: Any,
         owner_address: Optional[str] = None,
         _is_experimental_channel: bool = False,
-        _tensor_transport: str = "object_store",
+        _tensor_transport: str = "OBJECT_STORE",
     ):
         """Put value in the local object store.
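At the call site this only changes the spelling of the default; the string is upper-cased before use (next hunk), so existing lower-case values keep working. A minimal usage sketch, assuming `ray.put` forwards its experimental `_tensor_transport` keyword down to this `put_object` signature (that wiring is not shown in this diff):

```python
import ray

ray.init()

# Default transport: equivalent to passing "OBJECT_STORE" explicitly.
ref_default = ray.put({"weights": [1.0, 2.0, 3.0]})

# Lower-case still works because put_object() upper-cases the string
# before validating it (see the next hunk).
ref_explicit = ray.put({"weights": [1.0, 2.0, 3.0]}, _tensor_transport="object_store")

assert ray.get(ref_default) == ray.get(ref_explicit)
```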
@@ -835,18 +835,16 @@ def put_object(
                 "ray.ObjectRef in a list and call 'put' on it."
             )
         tensors = None
-        tensor_transport: TensorTransportEnum = TensorTransportEnum.from_str(
-            _tensor_transport
-        )
+        tensor_transport = _tensor_transport.upper()
         if tensor_transport not in [
-            TensorTransportEnum.OBJECT_STORE,
-            TensorTransportEnum.NIXL,
+            "OBJECT_STORE",
+            "NIXL",
         ]:
             raise ValueError(
                 "Currently, Ray Direct Transport only supports 'object_store' and 'nixl' for tensor transport in ray.put()."
             )
         try:
-            if tensor_transport != TensorTransportEnum.OBJECT_STORE:
+            if tensor_transport != "OBJECT_STORE":
                 (
                     serialized_value,
                     tensors,
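The enum round-trip is replaced by a plain normalize-then-validate step on the string. A self-contained sketch of that pattern, written as a standalone helper rather than the actual Ray code:

```python
# Transports accepted by ray.put() after this change, per the hunk above.
_PUT_TRANSPORTS = ("OBJECT_STORE", "NIXL")


def normalize_put_transport(transport: str) -> str:
    """Upper-case the user-supplied transport and reject unknown values."""
    normalized = transport.upper()
    if normalized not in _PUT_TRANSPORTS:
        raise ValueError(
            "Currently, Ray Direct Transport only supports 'object_store' "
            "and 'nixl' for tensor transport in ray.put()."
        )
    return normalized


assert normalize_put_transport("nixl") == "NIXL"
assert normalize_put_transport("object_store") == "OBJECT_STORE"
```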
@@ -867,19 +865,24 @@ def put_object(
         # object. Instead, clients will keep the object pinned.
         pin_object = not _is_experimental_channel

+        tensor_transport_enum = TensorTransportEnum.OBJECT_STORE
+        if tensor_transport != "OBJECT_STORE":
+            tensor_transport_enum = TensorTransportEnum.DIRECT_TRANSPORT
+
         # This *must* be the first place that we construct this python
         # ObjectRef because an entry with 0 local references is created when
         # the object is Put() in the core worker, expecting that this python
         # reference will be created. If another reference is created and
         # removed before this one, it will corrupt the state in the
         # reference counter.
         ret = self.core_worker.put_object(
             serialized_value,
             pin_object=pin_object,
             owner_address=owner_address,
             inline_small_object=True,
             _is_experimental_channel=_is_experimental_channel,
-            tensor_transport_val=tensor_transport.value,
+            tensor_transport_val=tensor_transport_enum.value,
         )
         if tensors:
             self.gpu_object_manager.put_object(ret, tensor_transport, tensors)
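The core worker still records an enum value; only the user-facing transport became a string, and any non-object-store transport now collapses into a single DIRECT_TRANSPORT marker, which is what enables bring-your-own-transport names. A hedged sketch of that mapping, using a stand-in enum since the real `TensorTransportEnum` is defined elsewhere in Ray:

```python
from enum import Enum


class TensorTransportEnum(Enum):
    # Stand-in for Ray's internal enum; the member values here are illustrative.
    OBJECT_STORE = 0
    DIRECT_TRANSPORT = 1


def to_core_worker_transport(tensor_transport: str) -> TensorTransportEnum:
    """Collapse any non-object-store transport (e.g. "NIXL" or a custom
    bring-your-own transport) into DIRECT_TRANSPORT, mirroring the hunk above."""
    if tensor_transport != "OBJECT_STORE":
        return TensorTransportEnum.DIRECT_TRANSPORT
    return TensorTransportEnum.OBJECT_STORE


assert to_core_worker_transport("NIXL") is TensorTransportEnum.DIRECT_TRANSPORT
assert to_core_worker_transport("OBJECT_STORE") is TensorTransportEnum.OBJECT_STORE
```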
@@ -896,43 +899,26 @@ def deserialize_objects(
         self,
         serialized_objects,
         object_refs,
-        tensor_transport_hint: Optional[TensorTransportEnum] = None,
+        tensor_transport_hint: Optional[str] = None,
     ):
         gpu_objects: Dict[str, List["torch.Tensor"]] = {}
         for obj_ref, (_, _, tensor_transport) in zip(object_refs, serialized_objects):
-            # TODO: Here tensor_transport_hint is set by the user in ray.get(), tensor_transport is set
-            # in serialize_objects by ray.method(tensor_transport="xxx"), and obj_ref.tensor_transport()
-            # is set by ray.put(). We may clean up this logic in the future.
-            if (
-                tensor_transport is None
-                or tensor_transport == TensorTransportEnum.OBJECT_STORE
-            ) and (
-                obj_ref is None
-                or obj_ref.tensor_transport() == TensorTransportEnum.OBJECT_STORE.value
-            ):
-                # The object is not a gpu object, so we cannot use other external transport to
-                # fetch it.
-                continue
-
-            # If the object is a gpu object, we can choose to use the object store or other external
-            # transport to fetch it. The `tensor_transport_hint` has the highest priority, then the
-            # tensor_transport in obj_ref.tensor_transport(), then the tensor_transport in serialize_objects,
-            # then the default value `OBJECT_STORE`.
-            chosen_tensor_transport = (
-                tensor_transport_hint
-                or (
-                    TensorTransportEnum(obj_ref.tensor_transport()) if obj_ref else None
-                )
-                or tensor_transport
-                or TensorTransportEnum.OBJECT_STORE
-            )
-
             object_id = obj_ref.hex()
             if object_id not in gpu_objects:
                 # If using a non-object store transport, then tensors will be sent
                 # out-of-band. Get them before deserializing the object store data.
+                # The user can choose OBJECT_STORE as the hint to fetch the RDT object
+                # through the object store.
                 gpu_objects[object_id] = self.gpu_object_manager.get_gpu_object(
-                    object_id, tensor_transport=chosen_tensor_transport
+                    object_id, tensor_transport=tensor_transport_hint
                 )

         # Function actor manager or the import thread may call pickle.loads
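The removed block resolved the effective transport from several sources; after this change, the string hint computed in `get_objects` is handed straight to the GPU object manager. A toy comparison of the two resolution strategies, as an illustration of the control-flow change rather than Ray's exact implementation:

```python
from typing import Optional


def resolve_transport_old(
    hint: Optional[str],
    ref_transport: Optional[str],
    serialized_transport: Optional[str],
) -> str:
    """Pre-PR sketch: hint > ObjectRef metadata > serialized metadata >
    object-store default, per the removed comment above."""
    return hint or ref_transport or serialized_transport or "OBJECT_STORE"


def resolve_transport_new(hint: Optional[str]) -> Optional[str]:
    """Post-PR sketch: the hint from get_objects() is forwarded unchanged;
    the GPU object manager decides how to fetch the tensors."""
    return hint


assert resolve_transport_old(None, "NIXL", None) == "NIXL"
assert resolve_transport_new("OBJECT_STORE") == "OBJECT_STORE"
```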
@@ -983,16 +969,6 @@ def get_objects(
                     f"Attempting to call `get` on the value {object_ref}, "
                     "which is not an ray.ObjectRef."
                 )
-        tensor_transport: TensorTransportEnum = (
-            TensorTransportEnum.from_str(_tensor_transport)
-            if _tensor_transport is not None
-            else None
-        )
-        assert tensor_transport in [
-            TensorTransportEnum.OBJECT_STORE,
-            TensorTransportEnum.NIXL,
-            None,
-        ], "Currently, RDT only supports 'object_store' and 'nixl' for tensor transport in ray.get()."
         timeout_ms = (
             int(timeout * 1000) if timeout is not None and timeout != -1 else -1
         )
@@ -1004,7 +980,7 @@ def get_objects(
         )

         debugger_breakpoint = b""
-        for data, metadata, _ in serialized_objects:
+        for _, metadata, _ in serialized_objects:
             if metadata:
                 metadata_fields = metadata.split(b",")
                 if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
@@ -1016,6 +992,9 @@ def get_objects(
         if skip_deserialization:
             return None, debugger_breakpoint

+        tensor_transport = (
+            _tensor_transport.upper() if _tensor_transport is not None else None
+        )
         values = self.deserialize_objects(
             serialized_objects, object_refs, tensor_transport_hint=tensor_transport
         )
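End to end, the transport now travels as an upper-cased string from the public API down to `deserialize_objects`. A hedged usage sketch, assuming `ray.get` exposes the `_tensor_transport` hint that `get_objects` receives above (the public keyword name is inferred from this diff) and that an RDT-capable environment with NIXL is available:

```python
import ray
import torch

ray.init()

# Put a tensor with an out-of-band transport; the string is upper-cased
# and validated by put_object() as shown in the earlier hunks.
ref = ray.put(torch.ones(4), _tensor_transport="nixl")

# Fetch through the object store instead: get_objects() upper-cases the
# hint and forwards it to deserialize_objects() unchanged.
tensor = ray.get(ref, _tensor_transport="object_store")
```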