@@ -970,3 +970,223 @@ def test_advanced_indexing_compute_follows_data():
970970 dpt .put (x , ind0 , val1 , axis = 0 )
971971 with pytest .raises (ExecutionPlacementError ):
972972 x [ind0 ] = val1
973+
974+
975+ #######
976+
977+
978+ def test_extract_all_1d ():
979+ x = dpt .arange (30 , dtype = "i4" )
980+ sel = dpt .ones (30 , dtype = "?" )
981+ sel [::2 ] = False
982+
983+ res = x [sel ]
984+ expected_res = dpt .asnumpy (x )[dpt .asnumpy (sel )]
985+ assert (dpt .asnumpy (res ) == expected_res ).all ()
986+
987+ res2 = dpt .extract (sel , x )
988+ assert (dpt .asnumpy (res2 ) == expected_res ).all ()
989+
990+
991+ def test_extract_all_2d ():
992+ x = dpt .reshape (dpt .arange (30 , dtype = "i4" ), (5 , 6 ))
993+ sel = dpt .ones (30 , dtype = "?" )
994+ sel [::2 ] = False
995+ sel = dpt .reshape (sel , x .shape )
996+
997+ res = x [sel ]
998+ expected_res = dpt .asnumpy (x )[dpt .asnumpy (sel )]
999+ assert (dpt .asnumpy (res ) == expected_res ).all ()
1000+
1001+ res2 = dpt .extract (sel , x )
1002+ assert (dpt .asnumpy (res2 ) == expected_res ).all ()
1003+
1004+
1005+ def test_extract_2D_axis0 ():
1006+ x = dpt .reshape (dpt .arange (30 , dtype = "i4" ), (5 , 6 ))
1007+ sel = dpt .ones (x .shape [0 ], dtype = "?" )
1008+ sel [::2 ] = False
1009+
1010+ res = x [sel ]
1011+ expected_res = dpt .asnumpy (x )[dpt .asnumpy (sel )]
1012+ assert (dpt .asnumpy (res ) == expected_res ).all ()
1013+
1014+
1015+ def test_extract_2D_axis1 ():
1016+ x = dpt .reshape (dpt .arange (30 , dtype = "i4" ), (5 , 6 ))
1017+ sel = dpt .ones (x .shape [1 ], dtype = "?" )
1018+ sel [::2 ] = False
1019+
1020+ res = x [:, sel ]
1021+ expected = dpt .asnumpy (x )[:, dpt .asnumpy (sel )]
1022+ assert (dpt .asnumpy (res ) == expected ).all ()
1023+
1024+
1025+ def test_extract_begin ():
1026+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1027+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1028+ sel = dpt .zeros ((3 , 3 ), dtype = "?" )
1029+ sel [0 , 0 ] = True
1030+ sel [1 , 1 ] = True
1031+ z = y [sel ]
1032+ expected = dpt .asnumpy (y )[[0 , 1 ], [0 , 1 ]]
1033+ assert (dpt .asnumpy (z ) == expected ).all ()
1034+
1035+
1036+ def test_extract_end ():
1037+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1038+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1039+ sel = dpt .zeros ((4 , 4 ), dtype = "?" )
1040+ sel [0 , 0 ] = True
1041+ z = y [..., sel ]
1042+ expected = dpt .asnumpy (y )[..., [0 ], [0 ]]
1043+ assert (dpt .asnumpy (z ) == expected ).all ()
1044+
1045+
1046+ def test_extract_middle ():
1047+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1048+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1049+ sel = dpt .zeros ((3 , 4 ), dtype = "?" )
1050+ sel [0 , 0 ] = True
1051+ z = y [:, sel ]
1052+ expected = dpt .asnumpy (y )[:, [0 ], [0 ], :]
1053+ assert (dpt .asnumpy (z ) == expected ).all ()
1054+
1055+
1056+ def test_extract_empty_result ():
1057+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1058+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1059+ sel = dpt .zeros ((3 , 4 ), dtype = "?" )
1060+ z = y [:, sel ]
1061+ assert z .shape == (
1062+ y .shape [0 ],
1063+ 0 ,
1064+ y .shape [3 ],
1065+ )
1066+
1067+
1068+ def test_place_all_1d ():
1069+ x = dpt .arange (10 , dtype = "i2" )
1070+ sel = dpt .zeros (10 , dtype = "?" )
1071+ sel [0 ::2 ] = True
1072+ val = dpt .zeros (5 , dtype = x .dtype )
1073+ x [sel ] = val
1074+ assert (dpt .asnumpy (x ) == np .array ([0 , 1 , 0 , 3 , 0 , 5 , 0 , 7 , 0 , 9 ])).all ()
1075+ dpt .place (x , sel , dpt .asarray (2 ))
1076+ assert (dpt .asnumpy (x ) == np .array ([2 , 1 , 2 , 3 , 2 , 5 , 2 , 7 , 2 , 9 ])).all ()
1077+
1078+
1079+ def test_place_2d_axis0 ():
1080+ x = dpt .reshape (dpt .arange (12 , dtype = "i2" ), (3 , 4 ))
1081+ sel = dpt .asarray ([True , False , True ])
1082+ val = dpt .zeros ((2 , 4 ), dtype = x .dtype )
1083+ x [sel ] = val
1084+ expected_x = np .stack (
1085+ (
1086+ np .zeros (4 , dtype = "i2" ),
1087+ np .arange (4 , 8 , dtype = "i2" ),
1088+ np .zeros (4 , dtype = "i2" ),
1089+ )
1090+ )
1091+ assert (dpt .asnumpy (x ) == expected_x ).all ()
1092+
1093+
1094+ def test_place_2d_axis1 ():
1095+ x = dpt .reshape (dpt .arange (12 , dtype = "i2" ), (3 , 4 ))
1096+ sel = dpt .asarray ([True , False , True , False ])
1097+ val = dpt .zeros ((3 , 2 ), dtype = x .dtype )
1098+ x [:, sel ] = val
1099+ expected_x = np .array (
1100+ [[0 , 1 , 0 , 3 ], [0 , 5 , 0 , 7 ], [0 , 9 , 0 , 11 ]], dtype = "i2"
1101+ )
1102+ assert (dpt .asnumpy (x ) == expected_x ).all ()
1103+
1104+
1105+ def test_place_2d_axis1_scalar ():
1106+ x = dpt .reshape (dpt .arange (12 , dtype = "i2" ), (3 , 4 ))
1107+ sel = dpt .asarray ([True , False , True , False ])
1108+ val = dpt .zeros (tuple (), dtype = x .dtype )
1109+ x [:, sel ] = val
1110+ expected_x = np .array (
1111+ [[0 , 1 , 0 , 3 ], [0 , 5 , 0 , 7 ], [0 , 9 , 0 , 11 ]], dtype = "i2"
1112+ )
1113+ assert (dpt .asnumpy (x ) == expected_x ).all ()
1114+
1115+
1116+ def test_place_all_slices ():
1117+ x = dpt .reshape (dpt .arange (12 , dtype = "i2" ), (3 , 4 ))
1118+ sel = dpt .asarray (
1119+ [
1120+ [False , True , True , False ],
1121+ [True , True , False , False ],
1122+ [False , False , True , True ],
1123+ ],
1124+ dtype = "?" ,
1125+ )
1126+ y = dpt .ones_like (x )
1127+ y [sel ] = x [sel ]
1128+
1129+
1130+ def test_place_some_slices_begin ():
1131+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1132+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1133+ sel = dpt .zeros ((3 , 3 ), dtype = "?" )
1134+ sel [0 , 0 ] = True
1135+ sel [1 , 1 ] = True
1136+ z = y [sel ]
1137+ w = dpt .zeros_like (y )
1138+ w [sel ] = z
1139+
1140+
1141+ def test_place_some_slices_mid ():
1142+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1143+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1144+ sel = dpt .zeros ((3 , 4 ), dtype = "?" )
1145+ sel [0 , 0 ] = True
1146+ sel [1 , 1 ] = True
1147+ z = y [:, sel ]
1148+ w = dpt .zeros_like (y )
1149+ w [:, sel ] = z
1150+
1151+
1152+ def test_place_some_slices_end ():
1153+ x = dpt .reshape (dpt .arange (3 * 3 * 4 * 4 , dtype = "i2" ), (3 , 4 , 3 , 4 ))
1154+ y = dpt .permute_dims (x , (2 , 0 , 3 , 1 ))
1155+ sel = dpt .zeros ((4 , 4 ), dtype = "?" )
1156+ sel [0 , 0 ] = True
1157+ sel [1 , 1 ] = True
1158+ z = y [:, :, sel ]
1159+ w = dpt .zeros_like (y )
1160+ w [:, :, sel ] = z
1161+
1162+
1163+ def test_place_cycling ():
1164+ x = dpt .zeros (10 , dtype = "f4" )
1165+ y = dpt .asarray ([2 , 3 ])
1166+ sel = dpt .ones (x .size , dtype = "?" )
1167+ dpt .place (x , sel , y )
1168+ expected = np .array (
1169+ [
1170+ 2 ,
1171+ 3 ,
1172+ ]
1173+ * 5 ,
1174+ dtype = x .dtype ,
1175+ )
1176+ assert (dpt .asnumpy (x ) == expected ).all ()
1177+
1178+
1179+ def test_place_subset ():
1180+ x = dpt .zeros (10 , dtype = "f4" )
1181+ y = dpt .ones_like (x )
1182+ sel = dpt .ones (x .size , dtype = "?" )
1183+ sel [::2 ] = False
1184+ dpt .place (x , sel , y )
1185+ expected = np .array ([1 , 3 , 5 , 7 , 9 ], dtype = x .dtype )
1186+ assert (dpt .asnumpy (x ) == expected ).all ()
1187+
1188+
1189+ def test_nonzero ():
1190+ x = dpt .concat ((dpt .zeros (3 ), dpt .ones (4 ), dpt .zeros (3 )))
1191+ (i ,) = dpt .nonzero (x )
1192+ assert dpt .asnumpy (i ) == np .array ([3 , 4 , 5 , 6 ]).all ()
0 commit comments