@@ -16,7 +16,9 @@
 
 
 import argparse
+import gc
 import json
+import sys
 from multiprocessing import Pool
 from typing import Dict, List, Tuple, Union
 
@@ -94,15 +96,22 @@ def run_benchmarks(args: argparse.Namespace) -> int:
     bench_cases = early_filtering(bench_cases, param_filters)
 
     # prefetch datasets
-    if args.prefetch_datasets:
+    if args.prefetch_datasets or args.describe_datasets:
         # trick: collect unique dataset names only, to avoid loading
         # the same dataset in different cases/processes
         dataset_cases = {get_data_name(case): case for case in bench_cases}
         logger.debug(f"Unique dataset names to load:\n{list(dataset_cases.keys())}")
         n_proc = min([16, cpu_count(), len(dataset_cases)])
         logger.info(f"Prefetching datasets with {n_proc} processes")
         with Pool(n_proc) as pool:
-            pool.map(load_data, dataset_cases.values())
+            datasets = pool.map(load_data, dataset_cases.values())
+        if args.describe_datasets:
+            for (data, data_description), data_name in zip(datasets, dataset_cases.keys()):
+                print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}")
+            sys.exit(0)
+        # free memory used by prefetched datasets
+        del datasets
+        gc.collect()
 
     # run bench_cases
     return_code, result = call_benchmarks(
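The deduplication "trick" in the hunk above relies on dict construction: when several benchmark cases share a dataset name, later entries overwrite earlier ones, so each dataset is loaded at most once. A minimal self-contained sketch of the pattern, with a hypothetical case layout and a stand-in get_data_name (the real helper lives elsewhere in the repo):

    # hypothetical case structure; stands in for the repo's real get_data_name
    def get_data_name(case: dict) -> str:
        return case["data"]["name"]

    cases = [
        {"data": {"name": "mnist"}, "algorithm": "knn"},
        {"data": {"name": "mnist"}, "algorithm": "svm"},  # same dataset, new case
        {"data": {"name": "higgs"}, "algorithm": "knn"},
    ]
    # later cases with the same key overwrite earlier ones
    unique_cases = {get_data_name(c): c for c in cases}
    print(list(unique_cases.keys()))  # ['mnist', 'higgs']

The new zip(datasets, dataset_cases.keys()) on the added lines is order-safe: pool.map returns results in input order, and a dict's .values() and .keys() iterate in the same insertion order (guaranteed since Python 3.7).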
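This hunk reads args.describe_datasets, but the argparse registration of the new flag falls outside the visible hunks. A minimal sketch of how such flags could be declared, assuming plain store_true options (the option names are inferred from the attributes used above; the help text is illustrative, not from the source):

    # sketch only: the real parser setup is not shown in this diff
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prefetch-datasets",
        action="store_true",
        help="load all unique datasets before running benchmark cases",
    )
    parser.add_argument(
        "--describe-datasets",
        action="store_true",
        help="print shape and parameters of each unique dataset, then exit",
    )
    args = parser.parse_args()

With the second flag set, the run prefetches every unique dataset, prints one name/shape/parameters entry per dataset, and exits with status 0 before any benchmark case executes.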