Support torch.distributed as alias for pytorch distribution type #44968
In the CommandJob schema unit tests, a new test verifies that `type: pytorch` and `type: torch.distributed` load interchangeably:

```diff
@@ -73,6 +73,37 @@ def test_invalid_distribution_config(self):
         with pytest.raises(ValidationError):
             CommandJob(**schema.load(cfg))
 
+    def test_pytorch_torch_distributed_interchangeable(self):
+        """Test that type: pytorch and type: torch.distributed work interchangeably"""
+        # Test with torch.distributed type
+        path_torch_distributed = "./tests/test_configs/command_job/dist_job_pytorch_torch_distributed.yml"
+        # Test with pytorch type (existing file)
+        path_pytorch = "./tests/test_configs/command_job/dist_job_1.yml"
+
+        context = {BASE_PATH_CONTEXT_KEY: Path(path_torch_distributed).parent}
+        schema = CommandJobSchema(context=context)
+
+        # Load and verify torch.distributed type
+        with open(path_torch_distributed, "r") as f:
+            cfg_torch_distributed = yaml.safe_load(f)
+        job_torch_distributed = CommandJob(**schema.load(cfg_torch_distributed))
+
+        # Load and verify pytorch type
+        with open(path_pytorch, "r") as f:
+            cfg_pytorch = yaml.safe_load(f)
+        job_pytorch = CommandJob(**schema.load(cfg_pytorch))
+
+        # Both should create PyTorchDistribution objects
+        from azure.ai.ml.entities._job.distribution import PyTorchDistribution
+        assert isinstance(job_torch_distributed.distribution, PyTorchDistribution)
+        assert isinstance(job_pytorch.distribution, PyTorchDistribution)
+
+        # Verify roundtrip for torch.distributed
+        rest_obj = job_torch_distributed._to_rest_object()
+        reconstructed = CommandJob._load_from_rest(rest_obj)
+        assert isinstance(reconstructed.distribution, PyTorchDistribution)
+        assert reconstructed.distribution.process_count_per_instance == 4
+
     def test_deserialize_inputs(self):
         test_path = "./tests/test_configs/command_job/command_job_inputs_test.yml"
         with open("./tests/test_configs/command_job/command_job_inputs_rest.yml", "r") as f:
```
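Only test files appear in this diff, so the production change itself is not visible here; the PR title and the helper tests below imply some normalization of the type string. A minimal sketch of what such an alias mapping could look like, with all names hypothetical:

```python
# Hypothetical sketch of a type-normalization step (not the PR's actual code):
# lower-case the incoming distribution type and collapse known aliases onto
# their canonical names before dispatching to the distribution classes.
DISTRIBUTION_TYPE_ALIASES = {"torch.distributed": "pytorch"}


def normalize_distribution_type(distribution_type: str) -> str:
    """Return the canonical, lower-cased distribution type for a raw string."""
    normalized = distribution_type.lower()
    return DISTRIBUTION_TYPE_ALIASES.get(normalized, normalized)


assert normalize_distribution_type("Torch.Distributed") == "pytorch"
assert normalize_distribution_type("PyTorch") == "pytorch"
```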
A new unit-test file exercises `from_dict_to_rest_distribution` directly:

```python
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import pytest
from azure.ai.ml.entities._job.pipeline._pipeline_job_helpers import from_dict_to_rest_distribution


@pytest.mark.unittest
@pytest.mark.training_experiences_test
class TestDistributionHelpers:
    """Test distribution helper functions to ensure torch.distributed is supported"""

    def test_from_dict_to_rest_distribution_pytorch(self):
        """Test that pytorch type is properly handled"""
        distribution_dict = {
            "distribution_type": "pytorch",
            "process_count_per_instance": 4
        }
        result = from_dict_to_rest_distribution(distribution_dict)
        assert result is not None
        assert hasattr(result, 'process_count_per_instance')
        assert result.process_count_per_instance == 4

    def test_from_dict_to_rest_distribution_torch_distributed(self):
        """Test that torch.distributed type is properly handled"""
        distribution_dict = {
            "distribution_type": "torch.distributed",
            "process_count_per_instance": 4
        }
        result = from_dict_to_rest_distribution(distribution_dict)
        assert result is not None
        assert hasattr(result, 'process_count_per_instance')
        assert result.process_count_per_instance == 4

    def test_from_dict_to_rest_distribution_pytorch_case_insensitive(self):
        """Test that PyTorch (mixed case) type is properly handled"""
        distribution_dict = {
            "distribution_type": "PyTorch",
            "process_count_per_instance": 2
        }
        result = from_dict_to_rest_distribution(distribution_dict)
        assert result is not None
        assert result.process_count_per_instance == 2

    def test_from_dict_to_rest_distribution_torch_distributed_case_insensitive(self):
        """Test that Torch.Distributed (mixed case) type is properly handled"""
        distribution_dict = {
            "distribution_type": "Torch.Distributed",
            "process_count_per_instance": 2
        }
        result = from_dict_to_rest_distribution(distribution_dict)
        assert result is not None
        assert result.process_count_per_instance == 2
```
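The four helper tests differ only in their inputs; a parametrized version, shown here as a sketch using the same imports as the file above, would cover the same matrix more compactly:

```python
import pytest
from azure.ai.ml.entities._job.pipeline._pipeline_job_helpers import from_dict_to_rest_distribution


# Sketch only: one parametrized test covering both the canonical and alias
# spellings, in both lower and mixed case.
@pytest.mark.parametrize(
    "distribution_type, process_count",
    [
        ("pytorch", 4),
        ("torch.distributed", 4),
        ("PyTorch", 2),
        ("Torch.Distributed", 2),
    ],
)
def test_from_dict_to_rest_distribution_aliases(distribution_type, process_count):
    result = from_dict_to_rest_distribution(
        {"distribution_type": distribution_type, "process_count_per_instance": process_count}
    )
    assert result is not None
    assert result.process_count_per_instance == process_count
```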
A new 14-line test config exercises the alias; this is the file referenced above as `./tests/test_configs/command_job/dist_job_pytorch_torch_distributed.yml`:

```yaml
command: pip freeze
environment: azureml:AzureML-sklearn-1.0-ubuntu20.04-py38-cpu:33
name: "test_pytorch_torch_distributed"
compute: "azureml:testCompute"
distribution:
  type: "torch.distributed"
  process_count_per_instance: 4
experiment_name: mfe-test1
properties:
  test_property: test_value
resources:
  instance_count: 2
limits:
  timeout: 30
```
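The config can also be loaded through the public SDK entry point rather than through the schema directly; a minimal sketch, assuming the file lives at the path used in the tests:

```python
from azure.ai.ml import load_job

# Expected behavior under this PR: the alias deserializes to the same
# entity type as "pytorch".
job = load_job("./tests/test_configs/command_job/dist_job_pytorch_torch_distributed.yml")
print(type(job.distribution).__name__)              # PyTorchDistribution
print(job.distribution.process_count_per_instance)  # 4
```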
Review comment: The error message lists "torch.distributed" separately from "pytorch", which might confuse users, since torch.distributed is documented as an internal alias kept for backwards compatibility. Consider updating the message either to drop "torch.distributed" from the list, treating the alias transparently, or to note explicitly that it is an alias for "pytorch". For consistency with the PR's intention to support the alias transparently, the first option would be more user-friendly.
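A sketch of what the first option could look like; the names and the message text are hypothetical, since the validation code that builds the message is not part of this diff:

```python
# Hypothetical illustration only: keep internal aliases out of the
# user-facing list of supported distribution types.
SUPPORTED_DISTRIBUTION_TYPES = ["pytorch", "torch.distributed", "tensorflow", "mpi"]
INTERNAL_ALIASES = {"torch.distributed"}


def unsupported_type_message(bad_type: str) -> str:
    """Build an error message that omits alias spellings."""
    public_types = sorted(t for t in SUPPORTED_DISTRIBUTION_TYPES if t not in INTERNAL_ALIASES)
    return f"Unsupported distribution type '{bad_type}'. Supported types: {', '.join(public_types)}."
```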