Skip to content

Commit 78eb4a2

Browse files
authored
fix: Enable create_test to correctly parse and apply external providers defined in YAML pipeline specifications. (#37216)
1 parent f7a5bbd commit 78eb4a2

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

sdks/python/apache_beam/yaml/yaml_testing.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -411,6 +411,13 @@ def create_test(
411411
**yaml_transform.SafeLineLoader.strip_metadata(
412412
pipeline_spec.get('options', {})))
413413

414+
providers = yaml_provider.merge_providers(
415+
yaml_provider.parse_providers('', pipeline_spec.get('providers', [])),
416+
{
417+
'AssertEqualAndRecord': yaml_provider.as_provider_list(
418+
'AssertEqualAndRecord', AssertEqualAndRecord)
419+
})
420+
414421
def get_name(transform):
415422
if 'name' in transform:
416423
return str(transform['name'])
@@ -428,7 +435,8 @@ def get_name(transform):
428435
mock_outputs = [{
429436
'name': get_name(t),
430437
'elements': [
431-
_try_row_as_dict(row) for row in _first_n(t, options, max_num_inputs)
438+
_try_row_as_dict(row)
439+
for row in _first_n(t, options, max_num_inputs, providers)
432440
],
433441
} for t in input_transforms]
434442

@@ -504,15 +512,18 @@ def record(element):
504512
return pcoll | beam.Map(record)
505513

506514

507-
def _first_n(transform_spec, options, n):
515+
def _first_n(transform_spec, options, n, providers=None):
508516
recorder = RecordElements(n)
517+
if providers is None:
518+
providers = {
519+
'AssertEqualAndRecord': yaml_provider.as_provider_list(
520+
'AssertEqualAndRecord', AssertEqualAndRecord)
521+
}
509522
try:
510523
with beam.Pipeline(options=options) as p:
511524
_ = (
512525
p
513-
| yaml_transform.YamlTransform(
514-
transform_spec,
515-
providers={'AssertEqualAndRecord': AssertEqualAndRecord})
526+
| yaml_transform.YamlTransform(transform_spec, providers=providers)
516527
| recorder)
517528
except _DoneException:
518529
pass

sdks/python/apache_beam/yaml/yaml_testing_test.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -356,6 +356,45 @@ def test_toplevel_providers(self):
356356
}]
357357
})
358358

359+
def test_create_with_external_providers(self):
360+
"""Test that create_test works with external providers defined in the
361+
pipeline spec.
362+
363+
This test validates the fix for issue #37136 where external providers
364+
defined in YAML files were not recognized when running tests.
365+
"""
366+
pipeline = '''
367+
pipeline:
368+
type: chain
369+
transforms:
370+
- type: Create
371+
config:
372+
elements:
373+
- {a: 1, b: 2}
374+
- {a: 2, b: 3}
375+
- {a: 3, b: 4}
376+
- {a: 4, b: 5}
377+
- {a: 5, b: 6}
378+
- type: MyCustomTransform
379+
- type: LogForTesting
380+
providers:
381+
- type: yaml
382+
transforms:
383+
MyCustomTransform:
384+
body:
385+
type: MapToFields
386+
config:
387+
language: python
388+
fields:
389+
sum_ab: a + b
390+
'''
391+
test_spec = yaml_testing.create_test(
392+
pipeline, max_num_inputs=10, min_num_outputs=3)
393+
394+
self.assertEqual(len(test_spec['expected_inputs']), 1)
395+
self.assertGreaterEqual(len(test_spec['expected_inputs'][0]['elements']), 3)
396+
yaml_testing.run_test(pipeline, test_spec)
397+
359398

360399
if __name__ == '__main__':
361400
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)