Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ def predump(self, data, **kwargs):


class PyTorchDistributionSchema(metaclass=PatchedSchemaMeta):
type = StringTransformedEnum(required=True, allowed_values=DistributionType.PYTORCH)
type = StringTransformedEnum(
required=True,
allowed_values=[DistributionType.PYTORCH, DistributionType.TORCH_DISTRIBUTED]
)
process_count_per_instance = fields.Int()

@post_load
Expand Down
2 changes: 2 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class DistributionType:
TENSORFLOW = "tensorflow"
PYTORCH = "pytorch"
RAY = "ray"
# Legacy alias for backwards compatibility with AML SDK v1.5
TORCH_DISTRIBUTED = "torch.distributed"


class JobType(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW,
DistributionType.PYTORCH: RestDistributionType.PY_TORCH,
DistributionType.RAY: RestDistributionType.RAY,
# Support legacy alias - maps to the same REST type as PYTORCH
DistributionType.TORCH_DISTRIBUTED: RestDistributionType.PY_TORCH,
}


Expand Down Expand Up @@ -226,4 +228,6 @@ def _to_rest_object(self) -> RestRay:
DistributionType.TENSORFLOW: TensorFlowDistribution,
DistributionType.PYTORCH: PyTorchDistribution,
DistributionType.RAY: RayDistribution,
# Support legacy alias for backwards compatibility with AML SDK v1.5
DistributionType.TORCH_DISTRIBUTED: PyTorchDistribution,
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,15 @@ def from_dict_to_rest_io(

def from_dict_to_rest_distribution(distribution_dict: Dict) -> Union[PyTorch, Mpi, TensorFlow, Ray]:
target_type = distribution_dict["distribution_type"].lower()
if target_type == "pytorch":
if target_type == "pytorch" or target_type == "torch.distributed":
return PyTorch(**distribution_dict)
if target_type == "mpi":
return Mpi(**distribution_dict)
if target_type == "tensorflow":
return TensorFlow(**distribution_dict)
if target_type == "ray":
return Ray(**distribution_dict)
msg = "Distribution type must be pytorch, mpi, tensorflow or ray: {}".format(target_type)
msg = "Distribution type must be pytorch, torch.distributed, mpi, tensorflow or ray: {}".format(target_type)
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message lists "torch.distributed" separately from "pytorch", which might confuse users since torch.distributed is documented as an internal alias for backwards compatibility. Consider updating the message to either:

  1. Only mention "pytorch" as the valid type and omit "torch.distributed" from user-facing error messages (since it's a legacy alias)
  2. Clarify that "torch.distributed" is a legacy alias

For consistency with the PR's intention to support the alias transparently, option 1 would be more user-friendly.

Suggested change
msg = "Distribution type must be pytorch, torch.distributed, mpi, tensorflow or ray: {}".format(target_type)
msg = "Distribution type must be pytorch, mpi, tensorflow or ray: {}".format(target_type)

Copilot uses AI. Check for mistakes.
raise ValidationException(
message=msg,
no_personal_data_message=msg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,37 @@ def test_invalid_distribution_config(self):
with pytest.raises(ValidationError):
CommandJob(**schema.load(cfg))

def test_pytorch_torch_distributed_interchangeable(self):
    """Verify `type: pytorch` and `type: torch.distributed` deserialize identically."""
    torch_distributed_yml = "./tests/test_configs/command_job/dist_job_pytorch_torch_distributed.yml"
    pytorch_yml = "./tests/test_configs/command_job/dist_job_1.yml"

    schema = CommandJobSchema(
        context={BASE_PATH_CONTEXT_KEY: Path(torch_distributed_yml).parent}
    )

    # Load both configs through the same schema; only the distribution type differs.
    jobs = {}
    for label, config_path in (
        ("torch.distributed", torch_distributed_yml),
        ("pytorch", pytorch_yml),
    ):
        with open(config_path, "r") as stream:
            jobs[label] = CommandJob(**schema.load(yaml.safe_load(stream)))

    # Both spellings must yield a PyTorchDistribution entity.
    from azure.ai.ml.entities._job.distribution import PyTorchDistribution
    assert isinstance(jobs["torch.distributed"].distribution, PyTorchDistribution)
    assert isinstance(jobs["pytorch"].distribution, PyTorchDistribution)

    # Round-trip the legacy-alias job through its REST representation.
    rest_obj = jobs["torch.distributed"]._to_rest_object()
    reloaded = CommandJob._load_from_rest(rest_obj)
    assert isinstance(reloaded.distribution, PyTorchDistribution)
    assert reloaded.distribution.process_count_per_instance == 4
Comment on lines +101 to +105
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test validates that torch.distributed deserializes correctly and creates a PyTorchDistribution object. However, it doesn't verify that the resulting REST object uses the correct distribution type (RestDistributionType.PY_TORCH). Consider adding an assertion to check that the REST representation is correct, for example:

rest_obj = job_torch_distributed._to_rest_object()
assert rest_obj.properties.distribution.distribution_type == "PyTorch"

This would ensure that the SDK_TO_REST mapping is working correctly and that the alias normalization is complete.

Copilot uses AI. Check for mistakes.
Comment on lines +101 to +105
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test validates round-trip serialization but doesn't verify what value the type field has after deserialization and re-serialization. According to the PR description, torch.distributed should be normalized to pytorch internally. Consider adding an assertion to verify the normalized type after round-trip:

rest_obj = job_torch_distributed._to_rest_object()
reconstructed = CommandJob._load_from_rest(rest_obj)
# Verify that the type is normalized to "pytorch" after round-trip
assert reconstructed.distribution.type == DistributionType.PYTORCH

This ensures that the internal normalization is working correctly and that torch.distributed is transparently converted to pytorch.

Copilot uses AI. Check for mistakes.

def test_deserialize_inputs(self):
test_path = "./tests/test_configs/command_job/command_job_inputs_test.yml"
with open("./tests/test_configs/command_job/command_job_inputs_rest.yml", "r") as f:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import pytest
from azure.ai.ml.entities._job.pipeline._pipeline_job_helpers import from_dict_to_rest_distribution


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestDistributionHelpers:
"""Test distribution helper functions to ensure torch.distributed is supported"""

def test_from_dict_to_rest_distribution_pytorch(self):
    """Test that pytorch type is properly handled"""
    payload = {"distribution_type": "pytorch", "process_count_per_instance": 4}
    rest_distribution = from_dict_to_rest_distribution(payload)
    # Helper must return a distribution object carrying the process count through.
    assert rest_distribution is not None
    assert hasattr(rest_distribution, 'process_count_per_instance')
    assert rest_distribution.process_count_per_instance == 4

def test_from_dict_to_rest_distribution_torch_distributed(self):
    """Test that torch.distributed type is properly handled"""
    payload = {"distribution_type": "torch.distributed", "process_count_per_instance": 4}
    rest_distribution = from_dict_to_rest_distribution(payload)
    # The legacy alias must be accepted and keep the process count intact.
    assert rest_distribution is not None
    assert hasattr(rest_distribution, 'process_count_per_instance')
    assert rest_distribution.process_count_per_instance == 4
Comment on lines +25 to +34
Copy link

Copilot AI Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test validates that the from_dict_to_rest_distribution helper function accepts torch.distributed, but it doesn't verify that both pytorch and torch.distributed types result in the same REST object type. Consider adding an assertion to verify that both types produce equivalent REST objects:

# Compare results from both types
pytorch_dict = {"distribution_type": "pytorch", "process_count_per_instance": 4}
torch_dist_dict = {"distribution_type": "torch.distributed", "process_count_per_instance": 4}
pytorch_result = from_dict_to_rest_distribution(pytorch_dict)
torch_dist_result = from_dict_to_rest_distribution(torch_dist_dict)
assert type(pytorch_result) == type(torch_dist_result)

This would strengthen the test by verifying that the alias normalization produces equivalent results.

Copilot uses AI. Check for mistakes.

def test_from_dict_to_rest_distribution_pytorch_case_insensitive(self):
    """Test that PyTorch (mixed case) type is properly handled"""
    # Mixed-case type string: the helper lowercases before dispatching.
    rest_distribution = from_dict_to_rest_distribution(
        {"distribution_type": "PyTorch", "process_count_per_instance": 2}
    )
    assert rest_distribution is not None
    assert rest_distribution.process_count_per_instance == 2

def test_from_dict_to_rest_distribution_torch_distributed_case_insensitive(self):
    """Test that torch.distributed (mixed case) type is properly handled"""
    # Mixed-case legacy alias: the helper lowercases before dispatching.
    rest_distribution = from_dict_to_rest_distribution(
        {"distribution_type": "Torch.Distributed", "process_count_per_instance": 2}
    )
    assert rest_distribution is not None
    assert rest_distribution.process_count_per_instance == 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Test config: command job whose distribution uses the legacy
# "torch.distributed" type alias; it should load the same as "pytorch".
command: pip freeze
environment: azureml:AzureML-sklearn-1.0-ubuntu20.04-py38-cpu:33
name: "test_pytorch_torch_distributed"
compute: "azureml:testCompute"
distribution:
  # Legacy alias for the pytorch distribution type.
  type: "torch.distributed"
  process_count_per_instance: 4
experiment_name: mfe-test1
properties:
  test_property: test_value
resources:
  instance_count: 2
limits:
  # Job timeout — units not stated here; presumably minutes. TODO confirm against schema.
  timeout: 30
Loading