2323
2424from ..utils .bench_case import get_bench_case_value
2525from ..utils .logger import logger
26-
26+ from mpi4py import MPI
2727
2828def convert_data (data , dformat : str , order : str , dtype : str , device : str = None ):
2929 if isinstance (data , csr_matrix ) and dformat != "csr_matrix" :
@@ -113,8 +113,36 @@ def split_and_transform_data(bench_case, data, data_description):
113113 "KNeighbors" in get_bench_case_value (bench_case , "algorithm:estimator" , "" )
114114 and int (get_bench_case_value (bench_case , "bench:mpi_params:n" , 1 )) > 1
115115 )
116- if distributed_split == "rank_based" or knn_split_train :
117- from mpi4py import MPI
116+
117+ if distributed_split == "sample_shift" :
118+ comm = MPI .COMM_WORLD
119+ rank = comm .Get_rank ()
120+ size = comm .Get_size ()
121+
122+ n_train = len (x_train )
123+ n_test = len (x_test )
124+
125+ train_start = 0
126+ train_end = n_train
127+ test_start = 0
128+ test_end = n_test
129+
130+ adjust_number = (math .sqrt (rank ) * 0.003 ) + 1
131+
132+ if "y" in data :
133+ x_train , y_train = (
134+ x_train [train_start :train_end ] * adjust_number ,
135+ y_train [train_start :train_end ],
136+ )
137+
138+ x_test , y_test = x_test [test_start :test_end ] * adjust_number , y_test [test_start :test_end ]
139+ else :
140+ x_train = x_train [train_start :train_end ]
141+
142+ x_test = x_test [test_start :test_end ] * adjust_number
143+
144+ elif distributed_split == "rank_based" or knn_split_train :
145+
118146
119147 comm = MPI .COMM_WORLD
120148 rank = comm .Get_rank ()
@@ -127,6 +155,7 @@ def split_and_transform_data(bench_case, data, data_description):
127155 train_end = (1 + rank ) * n_train // size
128156 test_start = rank * n_test // size
129157 test_end = (1 + rank ) * n_test // size
158+ x_train_rank = x_train [train_start :train_end ]
130159
131160 if "y" in data :
132161 x_train , y_train = (
@@ -138,7 +167,7 @@ def split_and_transform_data(bench_case, data, data_description):
138167 else :
139168 x_train = x_train [train_start :train_end ]
140169 if distributed_split == "rank_based" :
141- x_test = x_test [test_start :test_end ]
170+ x_test = x_test [test_start :test_end ] * adjust_number
142171
143172 device = get_bench_case_value (bench_case , "algorithm:device" , None )
144173 common_data_format = get_bench_case_value (bench_case , "data:format" , "pandas" )
0 commit comments