diff --git a/benchmarking/switchback/README.md b/benchmarking/switchback/README.md deleted file mode 100644 index b73569030..000000000 --- a/benchmarking/switchback/README.md +++ /dev/null @@ -1,4 +0,0 @@ -Steps: - -1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). -2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. diff --git a/benchmarking/switchback/info_a100_py2.jsonl b/benchmarking/switchback/info_a100_py2.jsonl deleted file mode 100644 index 53cda62cf..000000000 --- a/benchmarking/switchback/info_a100_py2.jsonl +++ /dev/null @@ -1,60 +0,0 @@ -{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28139352798461914, "standard_gw": 0.2811811864376068, "standard_gx": 0.30258670449256897, "rowwise_fwd": 0.1994594931602478, "rowwise_bwd": 0.16159191727638245, "global_fwd": 0.19502267241477966, "global_bwd": 0.16080215573310852, "x_quantize_rowwise": 0.03306940197944641, "g_quantize_rowwise": 0.08210167288780212, "w_quantize_rowwise": 0.03385916352272034, "w_quantize_colwise_transpose": 0.08635595440864563, "w_quantize_global": 0.09237229824066162, "w_quantize_global_transpose": 0.10007619857788086, "time_standard": 0.8651614189147949, "time_rowwise": 0.8776187896728516, "time_global": 0.944625586271286} -{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.262625515460968, "standard_gw": 0.2806223928928375, "standard_gx": 0.31118839979171753, "rowwise_fwd": 0.1828707754611969, "rowwise_bwd": 0.21236762404441833, "global_fwd": 0.16665831208229065, "global_bwd": 0.19929558038711548, "x_quantize_rowwise": 0.08227676153182983, "g_quantize_rowwise": 0.03310292959213257, "w_quantize_rowwise": 0.032648444175720215, "w_quantize_colwise_transpose": 0.09015202522277832, "w_quantize_global": 0.0988692045211792, "w_quantize_global_transpose": 0.10057538747787476, "time_standard": 0.8544363081455231, "time_rowwise": 0.9140409529209137, "time_global": 0.96140056848526} -{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5731917917728424, "standard_gw": 0.5709454417228699, "standard_gx": 0.5963630974292755, "rowwise_fwd": 0.37662312388420105, "rowwise_bwd": 0.281747430562973, "global_fwd": 0.36768242716789246, "global_bwd": 0.28043612837791443, "x_quantize_rowwise": 0.046547502279281616, "g_quantize_rowwise": 0.15532970428466797, "w_quantize_rowwise": 0.032436102628707886, "w_quantize_colwise_transpose": 0.08635222911834717, "w_quantize_global": 0.0947415828704834, "w_quantize_global_transpose": 0.10129809379577637, "time_standard": 1.7405003309249878, "time_rowwise": 1.5499815344810486, "time_global": 1.616980880498886} -{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5341619253158569, "standard_gw": 0.5690865218639374, "standard_gx": 0.599835067987442, "rowwise_fwd": 0.3233291208744049, "rowwise_bwd": 0.41359663009643555, "global_fwd": 0.2831108868122101, "global_bwd": 0.37280842661857605, "x_quantize_rowwise": 0.15563145279884338, "g_quantize_rowwise": 0.046741217374801636, "w_quantize_rowwise": 0.03306940197944641, "w_quantize_colwise_transpose": 0.09020790457725525, 
"w_quantize_global": 0.0925213098526001, "w_quantize_global_transpose": 0.09945780038833618, "time_standard": 1.7030835151672363, "time_rowwise": 1.6316622495651245, "time_global": 1.6193576157093048} -{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2199915945529938, "standard_gw": 1.1069811880588531, "standard_gx": 1.09761580824852, "rowwise_fwd": 0.738043338060379, "rowwise_bwd": 0.5549229681491852, "global_fwd": 0.7219798862934113, "global_bwd": 0.5512163043022156, "x_quantize_rowwise": 0.08748471736907959, "g_quantize_rowwise": 0.3023110330104828, "w_quantize_rowwise": 0.03182142972946167, "w_quantize_colwise_transpose": 0.08632615208625793, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.10032951831817627, "time_standard": 3.424588590860367, "time_rowwise": 2.9078908264636993, "time_global": 2.9647573828697205} -{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1040829122066498, "standard_gw": 1.1221766471862793, "standard_gx": 1.1548101902008057, "rowwise_fwd": 0.581938773393631, "rowwise_bwd": 0.7480122148990631, "global_fwd": 0.5537159740924835, "global_bwd": 0.7232688367366791, "x_quantize_rowwise": 0.30193477869033813, "g_quantize_rowwise": 0.08745118975639343, "w_quantize_rowwise": 0.03374740481376648, "w_quantize_colwise_transpose": 0.09068101644515991, "w_quantize_global": 0.09645149111747742, "w_quantize_global_transpose": 0.10189786553382874, "time_standard": 3.3810697495937347, "time_rowwise": 2.9659420251846313, "time_global": 2.9868967831134796} -{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.4533793330192566, "standard_gw": 2.1938569843769073, "standard_gx": 2.179361879825592, "rowwise_fwd": 1.4615543186664581, "rowwise_bwd": 1.0522231459617615, "global_fwd": 1.4288239181041718, "global_bwd": 1.0450035333633423, "x_quantize_rowwise": 0.1691766083240509, "g_quantize_rowwise": 0.5951300263404846, "w_quantize_rowwise": 0.03337860107421875, "w_quantize_colwise_transpose": 0.08653849363327026, "w_quantize_global": 0.0940859317779541, "w_quantize_global_transpose": 0.09976327419281006, "time_standard": 6.826598197221756, "time_rowwise": 5.5918581783771515, "time_global": 5.625840276479721} -{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.1698065102100372, "standard_gw": 2.1875128149986267, "standard_gx": 2.2887587547302246, "rowwise_fwd": 1.0762326419353485, "rowwise_bwd": 1.4638006687164307, "global_fwd": 1.0450668632984161, "global_bwd": 1.4308765530586243, "x_quantize_rowwise": 0.5953535437583923, "g_quantize_rowwise": 0.16899779438972473, "w_quantize_rowwise": 0.03240257501602173, "w_quantize_colwise_transpose": 0.09106099605560303, "w_quantize_global": 0.09546056389808655, "w_quantize_global_transpose": 0.09852275252342224, "time_standard": 6.6460780799388885, "time_rowwise": 5.615361034870148, "time_global": 5.621790885925293} -{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.858218133449554, "standard_gw": 4.3631307780742645, "standard_gx": 4.404045641422272, "rowwise_fwd": 2.9063820838928223, "rowwise_bwd": 2.094462513923645, "global_fwd": 2.8426870703697205, "global_bwd": 2.0792782306671143, "x_quantize_rowwise": 0.33241137862205505, "g_quantize_rowwise": 1.1817105114459991, "w_quantize_rowwise": 0.03374367952346802, 
"w_quantize_colwise_transpose": 0.08633732795715332, "w_quantize_global": 0.09231641888618469, "w_quantize_global_transpose": 0.100012868642807, "time_standard": 13.62539455294609, "time_rowwise": 10.998178273439407, "time_global": 10.991547256708145} -{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.246581345796585, "standard_gw": 4.42587211728096, "standard_gx": 4.581417888402939, "rowwise_fwd": 2.1114833652973175, "rowwise_bwd": 2.9050447046756744, "global_fwd": 2.0806826651096344, "global_bwd": 2.85966694355011, "x_quantize_rowwise": 1.1816024780273438, "g_quantize_rowwise": 0.33330172300338745, "w_quantize_rowwise": 0.033445656299591064, "w_quantize_colwise_transpose": 0.09065866470336914, "w_quantize_global": 0.09239837527275085, "w_quantize_global_transpose": 0.09984523057937622, "time_standard": 13.253871351480484, "time_rowwise": 11.081408709287643, "time_global": 11.073369532823563} -{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.4859529435634613, "standard_gw": 0.46338513493537903, "standard_gx": 0.42321905493736267, "rowwise_fwd": 0.2761557698249817, "rowwise_bwd": 0.20775198936462402, "global_fwd": 0.2713911235332489, "global_bwd": 0.20639970898628235, "x_quantize_rowwise": 0.033095479011535645, "g_quantize_rowwise": 0.11894106864929199, "w_quantize_rowwise": 0.03125518560409546, "w_quantize_colwise_transpose": 0.1424551010131836, "w_quantize_global": 0.07288157939910889, "w_quantize_global_transpose": 0.08071959018707275, "time_standard": 1.372557133436203, "time_rowwise": 1.2730397284030914, "time_global": 1.2468136847019196} -{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.3920421004295349, "standard_gw": 0.44424086809158325, "standard_gx": 0.4759356379508972, "rowwise_fwd": 0.23231282830238342, "rowwise_bwd": 0.28430670499801636, "global_fwd": 0.20883232355117798, "global_bwd": 0.2741999924182892, "x_quantize_rowwise": 0.12018159031867981, "g_quantize_rowwise": 0.03195926547050476, "w_quantize_rowwise": 0.026017427444458008, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.07734447717666626, "w_quantize_global_transpose": 0.0788569450378418, "time_standard": 1.3122186064720154, "time_rowwise": 1.2863576412200928, "time_global": 1.235615462064743} -{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0111741721630096, "standard_gw": 0.9267590939998627, "standard_gx": 0.8254274725914001, "rowwise_fwd": 0.5434826016426086, "rowwise_bwd": 0.4077926278114319, "global_fwd": 0.5318708717823029, "global_bwd": 0.40537863969802856, "x_quantize_rowwise": 0.059738755226135254, "g_quantize_rowwise": 0.2299174666404724, "w_quantize_rowwise": 0.02545863389968872, "w_quantize_colwise_transpose": 0.14269724488258362, "w_quantize_global": 0.07300823926925659, "w_quantize_global_transpose": 0.07878988981246948, "time_standard": 2.7633607387542725, "time_rowwise": 2.335846424102783, "time_global": 2.305462956428528} -{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8095316588878632, "standard_gw": 0.8607134222984314, "standard_gx": 0.9204968810081482, "rowwise_fwd": 0.4275888204574585, "rowwise_bwd": 0.5485899746417999, "global_fwd": 0.41000545024871826, "global_bwd": 0.5317628383636475, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 
0.059254467487335205, "w_quantize_rowwise": 0.02466142177581787, "w_quantize_colwise_transpose": 0.14865398406982422, "w_quantize_global": 0.07582828402519226, "w_quantize_global_transpose": 0.08231401443481445, "time_standard": 2.5907419621944427, "time_rowwise": 2.2996440529823303, "time_global": 2.2500604391098022} -{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0658522844314575, "standard_gw": 1.718364655971527, "standard_gx": 1.6660578548908234, "rowwise_fwd": 1.066897064447403, "rowwise_bwd": 0.8070804178714752, "global_fwd": 1.0473169386386871, "global_bwd": 0.8021742105484009, "x_quantize_rowwise": 0.11274218559265137, "g_quantize_rowwise": 0.4518181085586548, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.14259666204452515, "w_quantize_global": 0.07484853267669678, "w_quantize_global_transpose": 0.07976219058036804, "time_standard": 5.450274795293808, "time_rowwise": 4.326000809669495, "time_global": 4.287026822566986} -{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.7549192309379578, "standard_gw": 1.6954988241195679, "standard_gx": 1.8179528415203094, "rowwise_fwd": 0.8649080991744995, "rowwise_bwd": 1.0746456682682037, "global_fwd": 0.8023083209991455, "global_bwd": 1.0471977293491364, "x_quantize_rowwise": 0.45225024223327637, "g_quantize_rowwise": 0.11286512017250061, "w_quantize_rowwise": 0.0252649188041687, "w_quantize_colwise_transpose": 0.14732033014297485, "w_quantize_global": 0.07537379860877991, "w_quantize_global_transpose": 0.0807642936706543, "time_standard": 6.268370896577835, "time_rowwise": 4.372753202915192, "time_global": 4.266258329153061} -{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.098430275917053, "standard_gw": 3.3501461148262024, "standard_gx": 5.560480058193207, "rowwise_fwd": 2.112947404384613, "rowwise_bwd": 1.605246216058731, "global_fwd": 2.0697638392448425, "global_bwd": 1.5953518450260162, "x_quantize_rowwise": 0.21921470761299133, "g_quantize_rowwise": 0.8956789970397949, "w_quantize_rowwise": 0.02710893750190735, "w_quantize_colwise_transpose": 0.14268234372138977, "w_quantize_global": 0.07259473204612732, "w_quantize_global_transpose": 0.07899105548858643, "time_standard": 13.009056448936462, "time_rowwise": 8.35302472114563, "time_global": 8.281741291284561} -{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.586959421634674, "standard_gw": 3.358360379934311, "standard_gx": 3.6434978246688843, "rowwise_fwd": 1.6269534826278687, "rowwise_bwd": 2.128206193447113, "global_fwd": 1.5950687229633331, "global_bwd": 2.0831897854804993, "x_quantize_rowwise": 0.8954145014286041, "g_quantize_rowwise": 0.21914392709732056, "w_quantize_rowwise": 0.026203691959381104, "w_quantize_colwise_transpose": 0.14658644795417786, "w_quantize_global": 0.07478520274162292, "w_quantize_global_transpose": 0.07964670658111572, "time_standard": 12.58881762623787, "time_rowwise": 8.400868624448776, "time_global": 8.305609226226807} -{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.229725062847137, "standard_gw": 6.791356950998306, "standard_gx": 6.806455552577972, "rowwise_fwd": 4.252471029758453, "rowwise_bwd": 3.2062679529190063, "global_fwd": 4.175614565610886, "global_bwd": 3.1837262213230133, "x_quantize_rowwise": 
0.4321373999118805, "g_quantize_rowwise": 1.787092536687851, "w_quantize_rowwise": 0.0270158052444458, "w_quantize_colwise_transpose": 0.1424252986907959, "w_quantize_global": 0.07348507642745972, "w_quantize_global_transpose": 0.07829815149307251, "time_standard": 21.827537566423416, "time_rowwise": 16.63876697421074, "time_global": 16.52171090245247} -{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.279478669166565, "standard_gw": 6.7345499992370605, "standard_gx": 7.206875830888748, "rowwise_fwd": 3.209315240383148, "rowwise_bwd": 4.256397485733032, "global_fwd": 3.180190920829773, "global_bwd": 4.177983850240707, "x_quantize_rowwise": 1.7836056649684906, "g_quantize_rowwise": 0.4321075975894928, "w_quantize_rowwise": 0.03205239772796631, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09316205978393555, "w_quantize_global_transpose": 0.10086596012115479, "time_standard": 25.220904499292374, "time_rowwise": 16.5947787463665, "time_global": 16.502466052770615} -{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5776733160018921, "standard_gw": 0.5300231277942657, "standard_gx": 0.6005913019180298, "rowwise_fwd": 0.33330172300338745, "rowwise_bwd": 0.2957060933113098, "global_fwd": 0.32876431941986084, "global_bwd": 0.29108673334121704, "x_quantize_rowwise": 0.03466755151748657, "g_quantize_rowwise": 0.12264400720596313, "w_quantize_rowwise": 0.033874064683914185, "w_quantize_colwise_transpose": 0.1775398850440979, "w_quantize_global": 0.09503215551376343, "w_quantize_global_transpose": 0.10617449879646301, "time_standard": 1.7082877457141876, "time_rowwise": 1.5277564525604248, "time_global": 1.5083923935890198} -{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5164109170436859, "standard_gw": 0.5367249250411987, "standard_gx": 0.5876161158084869, "rowwise_fwd": 0.3132447600364685, "rowwise_bwd": 0.3396235406398773, "global_fwd": 0.2943649888038635, "global_bwd": 0.33209100365638733, "x_quantize_rowwise": 0.12357160449028015, "g_quantize_rowwise": 0.035997480154037476, "w_quantize_rowwise": 0.03213062882423401, "w_quantize_colwise_transpose": 0.17676874995231628, "w_quantize_global": 0.09861215949058533, "w_quantize_global_transpose": 0.0998862087726593, "time_standard": 1.6407519578933716, "time_rowwise": 1.5580616891384125, "time_global": 1.5212483704090118} -{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.2096501886844635, "standard_gw": 1.0663382709026337, "standard_gx": 1.0961703956127167, "rowwise_fwd": 0.6396733224391937, "rowwise_bwd": 0.5173943936824799, "global_fwd": 0.6296299397945404, "global_bwd": 0.5130060017108917, "x_quantize_rowwise": 0.06211921572685242, "g_quantize_rowwise": 0.2361498773097992, "w_quantize_rowwise": 0.03260001540184021, "w_quantize_colwise_transpose": 0.17679482698440552, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.09913742542266846, "time_standard": 3.372158855199814, "time_rowwise": 2.7310699224472046, "time_global": 2.6999935507774353} -{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.1065565049648285, "standard_gw": 1.0664314031600952, "standard_gx": 1.1266544461250305, "rowwise_fwd": 0.5352050065994263, "rowwise_bwd": 0.6464086472988129, "global_fwd": 0.513765960931778, 
"global_bwd": 0.6284862756729126, "x_quantize_rowwise": 0.23620948195457458, "g_quantize_rowwise": 0.062271952629089355, "w_quantize_rowwise": 0.031460076570510864, "w_quantize_colwise_transpose": 0.17675384879112244, "w_quantize_global": 0.09486451745033264, "w_quantize_global_transpose": 0.09898096323013306, "time_standard": 3.2996423542499542, "time_rowwise": 2.7547404170036316, "time_global": 2.7010105550289154} -{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.4367496371269226, "standard_gw": 2.0806193351745605, "standard_gx": 2.19624862074852, "rowwise_fwd": 1.2554042041301727, "rowwise_bwd": 1.0227933526039124, "global_fwd": 1.2322552502155304, "global_bwd": 1.0152235627174377, "x_quantize_rowwise": 0.11792033910751343, "g_quantize_rowwise": 0.4639364778995514, "w_quantize_rowwise": 0.03241002559661865, "w_quantize_colwise_transpose": 0.17657503485679626, "w_quantize_global": 0.09655207395553589, "w_quantize_global_transpose": 0.09958073496818542, "time_standard": 6.713617593050003, "time_rowwise": 5.149658769369125, "time_global": 5.106087774038315} -{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.1935217082500458, "standard_gw": 2.0055584609508514, "standard_gx": 2.1882541477680206, "rowwise_fwd": 1.0396353900432587, "rowwise_bwd": 1.2542344629764557, "global_fwd": 1.0161921381950378, "global_bwd": 1.233428716659546, "x_quantize_rowwise": 0.4642195999622345, "g_quantize_rowwise": 0.11782720685005188, "w_quantize_rowwise": 0.033117830753326416, "w_quantize_colwise_transpose": 0.17696991562843323, "w_quantize_global": 0.09416043758392334, "w_quantize_global_transpose": 0.10101497173309326, "time_standard": 6.387334316968918, "time_rowwise": 5.091562867164612, "time_global": 5.032401531934738} -{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.804681986570358, "standard_gw": 4.763372242450714, "standard_gx": 4.064023494720459, "rowwise_fwd": 2.484843134880066, "rowwise_bwd": 1.9691288471221924, "global_fwd": 2.441786229610443, "global_bwd": 1.9574686884880066, "x_quantize_rowwise": 0.2294592559337616, "g_quantize_rowwise": 0.9196549654006958, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.1768544316291809, "w_quantize_global": 0.09644776582717896, "w_quantize_global_transpose": 0.09847059845924377, "time_standard": 13.632077723741531, "time_rowwise": 10.574690997600555, "time_global": 10.506659746170044} -{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.0907710790634155, "standard_gw": 3.9793066680431366, "standard_gx": 4.302978515625, "rowwise_fwd": 1.992940902709961, "rowwise_bwd": 2.4996213614940643, "global_fwd": 1.9551962614059448, "global_bwd": 2.457551658153534, "x_quantize_rowwise": 0.9200014173984528, "g_quantize_rowwise": 0.2293996512889862, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.17882883548736572, "w_quantize_global": 0.09540095925331116, "w_quantize_global_transpose": 0.09880587458610535, "time_standard": 12.373056262731552, "time_rowwise": 9.831476956605911, "time_global": 9.73566249012947} -{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.655728936195374, "standard_gw": 8.261296898126602, "standard_gx": 8.064884692430496, "rowwise_fwd": 5.007706582546234, "rowwise_bwd": 3.8615092635154724, 
"global_fwd": 4.920527338981628, "global_bwd": 3.8330331444740295, "x_quantize_rowwise": 0.45276060700416565, "g_quantize_rowwise": 1.8306002020835876, "w_quantize_rowwise": 0.031366944313049316, "w_quantize_colwise_transpose": 0.1766495406627655, "w_quantize_global": 0.09412690997123718, "w_quantize_global_transpose": 0.09780004620552063, "time_standard": 25.981910526752472, "time_rowwise": 19.621890038251877, "time_global": 19.49014514684677} -{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 8.033104240894318, "standard_gw": 8.2889124751091, "standard_gx": 8.622754365205765, "rowwise_fwd": 3.8747042417526245, "rowwise_bwd": 5.003921687602997, "global_fwd": 3.8315393030643463, "global_bwd": 4.9162134528160095, "x_quantize_rowwise": 1.8304847180843353, "g_quantize_rowwise": 0.4522763192653656, "w_quantize_rowwise": 0.03413110971450806, "w_quantize_colwise_transpose": 0.1771189272403717, "w_quantize_global": 0.09519979357719421, "w_quantize_global_transpose": 0.09930506348609924, "time_standard": 24.944771081209183, "time_rowwise": 19.661549478769302, "time_global": 19.51393112540245} -{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7954612374305725, "standard_gw": 0.7456131279468536, "standard_gx": 0.8799619972705841, "rowwise_fwd": 0.43267011642456055, "rowwise_bwd": 0.34622475504875183, "global_fwd": 0.42615458369255066, "global_bwd": 0.344250351190567, "x_quantize_rowwise": 0.03748014569282532, "g_quantize_rowwise": 0.13304129242897034, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.2407953143119812, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10305643081665039, "time_standard": 2.4210363626480103, "time_rowwise": 1.96877121925354, "time_global": 1.8842294812202454} -{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7120333611965179, "standard_gw": 0.7622130215167999, "standard_gx": 0.8262209594249725, "rowwise_fwd": 0.3702230751514435, "rowwise_bwd": 0.4419572651386261, "global_fwd": 0.3479123115539551, "global_bwd": 0.4306286573410034, "x_quantize_rowwise": 0.13308599591255188, "g_quantize_rowwise": 0.037495046854019165, "w_quantize_rowwise": 0.03398209810256958, "w_quantize_colwise_transpose": 0.23782625794410706, "w_quantize_global": 0.09853765368461609, "w_quantize_global_transpose": 0.10247156023979187, "time_standard": 2.3004673421382904, "time_rowwise": 2.016782760620117, "time_global": 1.9123442471027374} -{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6292817890644073, "standard_gw": 1.5109702944755554, "standard_gx": 1.482747495174408, "rowwise_fwd": 0.8386112749576569, "rowwise_bwd": 0.6844550371170044, "global_fwd": 0.8220970630645752, "global_bwd": 0.6802082061767578, "x_quantize_rowwise": 0.06883963942527771, "g_quantize_rowwise": 0.25641173124313354, "w_quantize_rowwise": 0.033054500818252563, "w_quantize_colwise_transpose": 0.24027004837989807, "w_quantize_global": 0.0967271625995636, "w_quantize_global_transpose": 0.102948397397995, "time_standard": 4.622999578714371, "time_rowwise": 3.6326125264167786, "time_global": 3.5382024943828583} -{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4877021312713623, "standard_gw": 1.5015341341495514, "standard_gx": 1.529306173324585, "rowwise_fwd": 
0.715944916009903, "rowwise_bwd": 0.8529908955097198, "global_fwd": 0.680088996887207, "global_bwd": 0.8224695920944214, "x_quantize_rowwise": 0.2568177878856659, "g_quantize_rowwise": 0.06864592432975769, "w_quantize_rowwise": 0.03343448042869568, "w_quantize_colwise_transpose": 0.23645907640457153, "w_quantize_global": 0.09399279952049255, "w_quantize_global_transpose": 0.10286271572113037, "time_standard": 4.518542438745499, "time_rowwise": 3.665827214717865, "time_global": 3.5264119505882263} -{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.261040896177292, "standard_gw": 2.8816498816013336, "standard_gx": 2.8357282280921936, "rowwise_fwd": 1.6594752669334412, "rowwise_bwd": 1.359265297651291, "global_fwd": 1.6287527978420258, "global_bwd": 1.3503879308700562, "x_quantize_rowwise": 0.13146549463272095, "g_quantize_rowwise": 0.5035959184169769, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.24086236953735352, "w_quantize_global": 0.0945068895816803, "w_quantize_global_transpose": 0.10332837700843811, "time_standard": 8.978419005870819, "time_rowwise": 6.8106986582279205, "time_global": 6.693687289953232} -{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.848360687494278, "standard_gw": 2.8955675661563873, "standard_gx": 3.0499882996082306, "rowwise_fwd": 1.3900883495807648, "rowwise_bwd": 1.6595833003520966, "global_fwd": 1.3514049351215363, "global_bwd": 1.629263162612915, "x_quantize_rowwise": 0.5036592483520508, "g_quantize_rowwise": 0.13118237257003784, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.23709610104560852, "w_quantize_global": 0.0951625406742096, "w_quantize_global_transpose": 0.10216236114501953, "time_standard": 8.793916553258896, "time_rowwise": 6.851561367511749, "time_global": 6.708402186632156} -{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.4978525042533875, "standard_gw": 6.462603807449341, "standard_gx": 5.5987648665905, "rowwise_fwd": 3.2996535301208496, "rowwise_bwd": 2.6320070028305054, "global_fwd": 3.2426007091999054, "global_bwd": 2.612769603729248, "x_quantize_rowwise": 0.2561397850513458, "g_quantize_rowwise": 0.9984448552131653, "w_quantize_rowwise": 0.033076852560043335, "w_quantize_colwise_transpose": 0.24232640862464905, "w_quantize_global": 0.09618699550628662, "w_quantize_global_transpose": 0.10257214307785034, "time_standard": 18.559221178293228, "time_rowwise": 13.9242522418499, "time_global": 13.771317899227142} -{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.5702440440654755, "standard_gw": 5.717620253562927, "standard_gx": 6.08203187584877, "rowwise_fwd": 2.649586647748947, "rowwise_bwd": 3.315173089504242, "global_fwd": 2.6132799685001373, "global_bwd": 3.257807344198227, "x_quantize_rowwise": 0.9980201721191406, "g_quantize_rowwise": 0.256560742855072, "w_quantize_rowwise": 0.03356859087944031, "w_quantize_colwise_transpose": 0.23729726672172546, "w_quantize_global": 0.09495764970779419, "w_quantize_global_transpose": 0.103779137134552, "time_standard": 17.369896173477173, "time_rowwise": 13.207826763391495, "time_global": 13.04202526807785} -{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 13.058379292488098, "standard_gw": 11.480242013931274, "standard_gx": 
11.092845350503922, "rowwise_fwd": 6.637874990701675, "rowwise_bwd": 5.24790957570076, "global_fwd": 6.521012634038925, "global_bwd": 5.214303731918335, "x_quantize_rowwise": 0.5057565867900848, "g_quantize_rowwise": 1.989319920539856, "w_quantize_rowwise": 0.03439188003540039, "w_quantize_colwise_transpose": 0.24280324578285217, "w_quantize_global": 0.09520724415779114, "w_quantize_global_transpose": 0.10240450501441956, "time_standard": 35.631466656923294, "time_rowwise": 26.138298213481903, "time_global": 25.908246636390686} -{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.13397628068924, "standard_gw": 11.371888220310211, "standard_gx": 12.12756335735321, "rowwise_fwd": 5.2495077252388, "rowwise_bwd": 6.638709455728531, "global_fwd": 5.215313285589218, "global_bwd": 6.5222084522247314, "x_quantize_rowwise": 1.9870512187480927, "g_quantize_rowwise": 0.5058236420154572, "w_quantize_rowwise": 0.034634023904800415, "w_quantize_colwise_transpose": 0.23674964904785156, "w_quantize_global": 0.09457767009735107, "w_quantize_global_transpose": 0.10183081030845642, "time_standard": 34.63342785835266, "time_rowwise": 26.024363934993744, "time_global": 25.798693299293518} -{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2125298380851746, "standard_gw": 1.1111274361610413, "standard_gx": 1.0840706527233124, "rowwise_fwd": 0.6057210266590118, "rowwise_bwd": 0.51865354180336, "global_fwd": 0.5952082574367523, "global_bwd": 0.5167685449123383, "x_quantize_rowwise": 0.045686960220336914, "g_quantize_rowwise": 0.15827640891075134, "w_quantize_rowwise": 0.04361197352409363, "w_quantize_colwise_transpose": 0.34067779779434204, "w_quantize_global": 0.13644620776176453, "w_quantize_global_transpose": 0.14925003051757812, "time_standard": 3.407727926969528, "time_rowwise": 2.823755145072937, "time_global": 2.7127638459205627} -{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.0731369256973267, "standard_gw": 1.1365897953510284, "standard_gx": 1.1498592793941498, "rowwise_fwd": 0.5573518574237823, "rowwise_bwd": 0.615488737821579, "global_fwd": 0.5220361053943634, "global_bwd": 0.5939789116382599, "x_quantize_rowwise": 0.15765801072120667, "g_quantize_rowwise": 0.04369020462036133, "w_quantize_rowwise": 0.047359615564346313, "w_quantize_colwise_transpose": 0.5526281893253326, "w_quantize_global": 0.13606995344161987, "w_quantize_global_transpose": 0.15017390251159668, "time_standard": 3.359586000442505, "time_rowwise": 3.1107664108276367, "time_global": 2.7401968836784363} -{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4274885654449463, "standard_gw": 2.1799951791763306, "standard_gx": 2.1426528692245483, "rowwise_fwd": 1.195710152387619, "rowwise_bwd": 1.027170568704605, "global_fwd": 1.1747106909751892, "global_bwd": 1.0251589119434357, "x_quantize_rowwise": 0.08098781108856201, "g_quantize_rowwise": 0.3052949905395508, "w_quantize_rowwise": 0.043764710426330566, "w_quantize_colwise_transpose": 0.33987686038017273, "w_quantize_global": 0.13646483421325684, "w_quantize_global_transpose": 0.14739856123924255, "time_standard": 6.750136613845825, "time_rowwise": 5.172800272703171, "time_global": 5.050010979175568} -{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1661892533302307, "standard_gw": 
2.0948275923728943, "standard_gx": 2.306375652551651, "rowwise_fwd": 1.0587647557258606, "rowwise_bwd": 1.1999905109405518, "global_fwd": 1.0296404361724854, "global_bwd": 1.1749230325222015, "x_quantize_rowwise": 0.3054030239582062, "g_quantize_rowwise": 0.08077546954154968, "w_quantize_rowwise": 0.047225505113601685, "w_quantize_colwise_transpose": 0.600133091211319, "w_quantize_global": 0.13613328337669373, "w_quantize_global_transpose": 0.1484006643295288, "time_standard": 6.567392498254776, "time_rowwise": 5.387119948863983, "time_global": 4.97010350227356} -{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.807606339454651, "standard_gw": 4.170913249254227, "standard_gx": 4.117622971534729, "rowwise_fwd": 2.370934933423996, "rowwise_bwd": 1.9481778144836426, "global_fwd": 2.3383721709251404, "global_bwd": 1.9443817436695099, "x_quantize_rowwise": 0.1547597348690033, "g_quantize_rowwise": 0.6000511348247528, "w_quantize_rowwise": 0.04361942410469055, "w_quantize_colwise_transpose": 0.3403201699256897, "w_quantize_global": 0.13600289821624756, "w_quantize_global_transpose": 0.1474134624004364, "time_standard": 13.096142560243607, "time_rowwise": 9.628776460886002, "time_global": 9.491894394159317} -{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.1619837284088135, "standard_gw": 4.181284457445145, "standard_gx": 4.635505378246307, "rowwise_fwd": 1.9684135913848877, "rowwise_bwd": 2.3750364780426025, "global_fwd": 1.9445866346359253, "global_bwd": 2.3551955819129944, "x_quantize_rowwise": 0.6004162132740021, "g_quantize_rowwise": 0.15468522906303406, "w_quantize_rowwise": 0.04730746150016785, "w_quantize_colwise_transpose": 0.5999617278575897, "w_quantize_global": 0.1364201307296753, "w_quantize_global_transpose": 0.14847144484519958, "time_standard": 12.978773564100266, "time_rowwise": 9.927105158567429, "time_global": 9.521059691905975} -{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.52371209859848, "standard_gw": 8.354485034942627, "standard_gx": 8.69860127568245, "rowwise_fwd": 4.717472940683365, "rowwise_bwd": 3.8843750953674316, "global_fwd": 4.645414650440216, "global_bwd": 3.8761012256145477, "x_quantize_rowwise": 0.3024861216545105, "g_quantize_rowwise": 1.1897757649421692, "w_quantize_rowwise": 0.04366785287857056, "w_quantize_colwise_transpose": 0.33988431096076965, "w_quantize_global": 0.1359507441520691, "w_quantize_global_transpose": 0.14724582433700562, "time_standard": 26.576798409223557, "time_rowwise": 18.832147121429443, "time_global": 18.651459366083145} -{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.307881653308868, "standard_gw": 8.214320987462997, "standard_gx": 9.21182706952095, "rowwise_fwd": 3.8919784128665924, "rowwise_bwd": 4.72346693277359, "global_fwd": 3.8761794567108154, "global_bwd": 4.673641175031662, "x_quantize_rowwise": 1.1893920600414276, "g_quantize_rowwise": 0.3024972975254059, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.6039328873157501, "w_quantize_global": 0.13624504208564758, "w_quantize_global_transpose": 0.14867261052131653, "time_standard": 25.734029710292816, "time_rowwise": 18.972668796777725, "time_global": 18.540948629379272} -{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 
19.30372044444084, "standard_gw": 16.480475664138794, "standard_gx": 17.61433482170105, "rowwise_fwd": 9.49602946639061, "rowwise_bwd": 7.768530398607254, "global_fwd": 9.3533955514431, "global_bwd": 7.749464362859726, "x_quantize_rowwise": 0.5977451801300049, "g_quantize_rowwise": 2.3684948682785034, "w_quantize_rowwise": 0.04375725984573364, "w_quantize_colwise_transpose": 0.34042075276374817, "w_quantize_global": 0.13628974556922913, "w_quantize_global_transpose": 0.14671683311462402, "time_standard": 53.398530930280685, "time_rowwise": 37.09545359015465, "time_global": 36.83258220553398} -{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.041003495454788, "standard_gw": 17.770148813724518, "standard_gx": 17.70009845495224, "rowwise_fwd": 7.756810635328293, "rowwise_bwd": 9.502101689577103, "global_fwd": 7.7384114265441895, "global_bwd": 9.36170294880867, "x_quantize_rowwise": 2.3686252534389496, "g_quantize_rowwise": 0.5980581045150757, "w_quantize_rowwise": 0.04723668098449707, "w_quantize_colwise_transpose": 0.6035342812538147, "w_quantize_global": 0.13603642582893372, "w_quantize_global_transpose": 0.1485198736190796, "time_standard": 53.511250764131546, "time_rowwise": 38.64651545882225, "time_global": 38.121502846479416} -{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.598241299390793, "standard_gw": 4.294309765100479, "standard_gx": 4.261095076799393, "rowwise_fwd": 2.0976848900318146, "rowwise_bwd": 1.9718967378139496, "global_fwd": 2.0763762295246124, "global_bwd": 1.9703581929206848, "x_quantize_rowwise": 0.08216872811317444, "g_quantize_rowwise": 0.4405900835990906, "w_quantize_rowwise": 0.1553371548652649, "w_quantize_colwise_transpose": 1.6110725700855255, "w_quantize_global": 0.481240451335907, "w_quantize_global_transpose": 0.5061514675617218, "time_standard": 13.153646141290665, "time_rowwise": 10.653059929609299, "time_global": 9.85119491815567} -{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.35885414481163, "standard_gw": 4.29583340883255, "standard_gx": 4.5370906591415405, "rowwise_fwd": 2.0015686750411987, "rowwise_bwd": 2.097565680742264, "global_fwd": 1.969795674085617, "global_bwd": 2.075403928756714, "x_quantize_rowwise": 0.43984130024909973, "g_quantize_rowwise": 0.08216127753257751, "w_quantize_rowwise": 0.22544339299201965, "w_quantize_colwise_transpose": 2.4342015385627747, "w_quantize_global": 0.48087164759635925, "w_quantize_global_transpose": 0.5099289119243622, "time_standard": 13.19177821278572, "time_rowwise": 11.576615273952484, "time_global": 9.85383614897728} -{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.09888744354248, "standard_gw": 8.230950683355331, "standard_gx": 8.465446531772614, "rowwise_fwd": 4.182614386081696, "rowwise_bwd": 3.747660666704178, "global_fwd": 4.138719290494919, "global_bwd": 3.74777615070343, "x_quantize_rowwise": 0.15515834093093872, "g_quantize_rowwise": 0.8699297904968262, "w_quantize_rowwise": 0.15544891357421875, "w_quantize_colwise_transpose": 1.6132444143295288, "w_quantize_global": 0.48100948333740234, "w_quantize_global_transpose": 0.5051903426647186, "time_standard": 25.795284658670425, "time_rowwise": 18.955007195472717, "time_global": 18.128734081983566} -{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, 
"standard_fwd": 8.378107100725174, "standard_gw": 8.923027664422989, "standard_gx": 9.049762040376663, "rowwise_fwd": 3.765825182199478, "rowwise_bwd": 4.183519631624222, "global_fwd": 3.744799643754959, "global_bwd": 4.1590481996536255, "x_quantize_rowwise": 0.8693933486938477, "g_quantize_rowwise": 0.1553073525428772, "w_quantize_rowwise": 0.2258792519569397, "w_quantize_colwise_transpose": 2.4386271834373474, "w_quantize_global": 0.4811100661754608, "w_quantize_global_transpose": 0.5102269351482391, "time_standard": 26.350896805524826, "time_rowwise": 20.5615796148777, "time_global": 18.842913210392} -{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.266115337610245, "standard_gw": 17.671160399913788, "standard_gx": 17.10302010178566, "rowwise_fwd": 8.347474038600922, "rowwise_bwd": 7.514089345932007, "global_fwd": 8.263226598501205, "global_bwd": 7.487393915653229, "x_quantize_rowwise": 0.3021806478500366, "g_quantize_rowwise": 1.7319358885288239, "w_quantize_rowwise": 0.15519559383392334, "w_quantize_colwise_transpose": 1.6133114695549011, "w_quantize_global": 0.48247724771499634, "w_quantize_global_transpose": 0.506427139043808, "time_standard": 53.04029583930969, "time_rowwise": 37.3353473842144, "time_global": 36.44480183720589} -{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 17.73649826645851, "standard_gw": 16.359902918338776, "standard_gx": 18.0993489921093, "rowwise_fwd": 7.493957877159119, "rowwise_bwd": 8.352488279342651, "global_fwd": 7.486194372177124, "global_bwd": 8.28903540968895, "x_quantize_rowwise": 1.7313472926616669, "g_quantize_rowwise": 0.30205026268959045, "w_quantize_rowwise": 0.2255477011203766, "w_quantize_colwise_transpose": 2.4363920092582703, "w_quantize_global": 0.4815347492694855, "w_quantize_global_transpose": 0.5103759467601776, "time_standard": 52.195750176906586, "time_rowwise": 36.90168634057045, "time_global": 35.16044095158577} -{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.309611052274704, "standard_gw": 32.85098075866699, "standard_gx": 34.34552624821663, "rowwise_fwd": 16.74525812268257, "rowwise_bwd": 15.026237815618515, "global_fwd": 16.574162989854813, "global_bwd": 14.977734535932541, "x_quantize_rowwise": 0.5954466760158539, "g_quantize_rowwise": 3.4569576382637024, "w_quantize_rowwise": 0.15521422028541565, "w_quantize_colwise_transpose": 1.6133897006511688, "w_quantize_global": 0.4822872579097748, "w_quantize_global_transpose": 0.5065612494945526, "time_standard": 103.50611805915833, "time_rowwise": 70.44348493218422, "time_global": 69.44413110613823} -{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.40017828345299, "standard_gw": 33.037226647138596, "standard_gx": 36.30436211824417, "rowwise_fwd": 15.043705701828003, "rowwise_bwd": 16.756191849708557, "global_fwd": 15.011314302682877, "global_bwd": 16.580048948526382, "x_quantize_rowwise": 3.4548528492450714, "g_quantize_rowwise": 0.5951337516307831, "w_quantize_rowwise": 0.22584572434425354, "w_quantize_colwise_transpose": 2.4329908192157745, "w_quantize_global": 0.4813261330127716, "w_quantize_global_transpose": 0.5101598799228668, "time_standard": 104.74176704883575, "time_rowwise": 71.54594734311104, "time_global": 69.67006251215935} -{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": 
false, "standard_fwd": 73.40333238244057, "standard_gw": 73.76311346888542, "standard_gx": 70.41774317622185, "rowwise_fwd": 33.37597846984863, "rowwise_bwd": 30.345775187015533, "global_fwd": 33.00366923213005, "global_bwd": 30.218638479709625, "x_quantize_rowwise": 1.1825822293758392, "g_quantize_rowwise": 6.902601569890976, "w_quantize_rowwise": 0.15529245138168335, "w_quantize_colwise_transpose": 1.6109198331832886, "w_quantize_global": 0.48149004578590393, "w_quantize_global_transpose": 0.5066059529781342, "time_standard": 217.58418902754784, "time_rowwise": 147.33626320958138, "time_global": 146.05870097875595} -{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 71.5160183608532, "standard_gw": 73.76786693930626, "standard_gx": 72.98104092478752, "rowwise_fwd": 30.291248112916946, "rowwise_bwd": 33.36654230952263, "global_fwd": 30.181586742401123, "global_bwd": 33.082425594329834, "x_quantize_rowwise": 6.902430206537247, "g_quantize_rowwise": 1.1815279722213745, "w_quantize_rowwise": 0.2262219786643982, "w_quantize_colwise_transpose": 2.4421699345111847, "w_quantize_global": 0.4816502332687378, "w_quantize_global_transpose": 0.5105249583721161, "time_standard": 218.26492622494698, "time_rowwise": 148.17800745368004, "time_global": 146.1080126464367} diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py deleted file mode 100644 index fd0dd7d58..000000000 --- a/benchmarking/switchback/make_plot_with_jsonl.py +++ /dev/null @@ -1,151 +0,0 @@ -import matplotlib.gridspec as gridspec -import matplotlib.pyplot as plt -import pandas as pd - -cmap = plt.get_cmap("cool") - -if __name__ == "__main__": - fig = plt.figure(tight_layout=True, figsize=(12, 3.5)) - gs = gridspec.GridSpec(1, 2) - - dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] - batch_size_for_plot1 = 32768 - batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17] - dims_to_xtick = [1024, 2048, 4096] - logscale_plot1 = True - - ax = fig.add_subplot(gs[0, 0]) - - # TODO: change this to what you want. 
- rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True) - df = rdf[rdf.batch_size == batch_size_for_plot1] - - # first plot the time occupied by different operations - for k, marker, ls, color, name in [ - ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"), - ( - "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", - "o", - "-", - "C4", - "SwitchBack int8 (sum of parts)", - ), - ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"), - ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"), - ("standard_gx", "^", ":", "gray", "Matmul GX (both)"), - ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"), - ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"), - ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"), - ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"), - ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"), - ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"), - ]: - xs = [] - ys = [] - for embed_dim in dims_to_consider: - # average over dim -> 4*dim and 4*dim -> dim - df_ = df[df.dim_in == embed_dim] - df_ = df_[df_.dim_out == embed_dim * 4] - xs.append(embed_dim) - y_ = 0 - for k_ in k.split("+"): - y_ += df_[k_].values[0] - df_ = df[df.dim_in == embed_dim * 4] - df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split("+"): - y_ += df_[k_].values[0] - ys.append(y_ * 0.5) - - ax.plot( - xs, - ys, - color=color, - label=name, - marker=marker, - markersize=5 if marker == "s" else 5, - linestyle=ls, - linewidth=2 if "+" in k else 1.0, - ) - - ax.set_xlabel("dim", fontsize=13) - ax.set_ylabel("time (ms)", fontsize=13) - - ax.grid() - - ax.set_xscale("log") - if logscale_plot1: - ax.set_yscale("log") - - ax.tick_params(axis="x", labelsize=11) - ax.tick_params(axis="y", labelsize=11) - - ax.set_xticks(dims_to_xtick) - ax.set_xticklabels(dims_to_xtick) - ax.set_xticks([], minor=True) - - leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10) - leg.get_texts()[0].set_fontweight("bold") - leg.get_texts()[1].set_fontweight("bold") - plt.subplots_adjust(left=0.1) - ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20) - - ax = fig.add_subplot(gs[0, 1]) - - # now plot the % speedup for different batch sizes - for j, batch_size in enumerate(batch_sizes_for_plot2): - all_xs, all_ys = [], [] - for k, marker, ls, color, name in [ - ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"), - ( - "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", - "o", - "-", - "C4", - "SwitchBack int8 (total time)", - ), - ]: - xs, ys = [], [] - df = rdf[rdf.batch_size == batch_size] - for embed_dim in dims_to_consider: - df_ = df[df.dim_in == embed_dim] - df_ = df_[df_.dim_out == embed_dim * 4] - xs.append(embed_dim) - y_ = 0 - for k_ in k.split("+"): - y_ += df_[k_].values[0] - df_ = df[df.dim_in == embed_dim * 4] - df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split("+"): - y_ += df_[k_].values[0] - ys.append(y_ * 0.5) - all_xs.append(xs) - all_ys.append(ys) - - color = cmap(j * 0.25) - real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] - markers = ["^", "v", "P", "o"] - ax.plot( - all_xs[0], - real_ys, - color=color, - 
label=f"batch * sequence length = {batch_size}", - marker=markers[j], - markersize=5 if marker == "s" else 5, - ) - - ax.legend() - ax.set_xlabel("dim", fontsize=13) - ax.set_xscale("log") - ax.grid() - ax.set_ylabel(r"% speedup", fontsize=13) - - ax.tick_params(axis="x", labelsize=11) - ax.tick_params(axis="y", labelsize=11) - - ax.set_xticks(dims_to_xtick) - ax.set_xticklabels(dims_to_xtick) - ax.set_xticks([], minor=True) - - ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20) - - plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight") diff --git a/benchmarking/switchback/plot_with_info.pdf b/benchmarking/switchback/plot_with_info.pdf deleted file mode 100644 index d186e91b7..000000000 Binary files a/benchmarking/switchback/plot_with_info.pdf and /dev/null differ diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py deleted file mode 100644 index eaba0e9cd..000000000 --- a/benchmarking/switchback/speed_benchmark.py +++ /dev/null @@ -1,160 +0,0 @@ -import json -import time - -import torch - -from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( - int8_matmul_mixed_dequantize, -) -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( - int8_matmul_rowwise_dequantize, -) -from bitsandbytes.triton.quantize_columnwise_and_transpose import ( - quantize_columnwise_and_transpose, -) -from bitsandbytes.triton.quantize_global import ( - quantize_global, - quantize_global_transpose, -) -from bitsandbytes.triton.quantize_rowwise import quantize_rowwise - -# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. - - -def get_time(k, fn, info_dict): - for _ in range(repeat // 2): - fn() - - torch.cuda.synchronize() - start = time.time() - for _ in range(repeat): - fn() - - torch.cuda.synchronize() - end = time.time() - ms = (end - start) / repeat * 1000 - print(f"time {k}: {ms:.3f} ms") - info_dict[k] = ms - - -if __name__ == "__main__": - torch.manual_seed(0) - wm = 4 - for dim in [1024, 1280, 1408, 1664, 2048, 4096]: - # note "batch_size" is actually "batch_size * embed_dim", which is why it's large - for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]: - # switch switches dim_in and dim_out - for switch in [False, True]: - # hparams - repeat = 64 - batch_size = batch_size - dim_out = dim * wm - dim_in = dim - if switch: - dim_out = dim - dim_in = wm * dim - - dim_in = round(dim_in) - dim_out = round(dim_out) - - # simulate forward pass - x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() - g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() - w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() - - x_int8 = x.clone().to(torch.int8) - g_int8 = g.clone().to(torch.int8) - w_int8 = w.clone().to(torch.int8) - wt_int8 = w.t().contiguous().clone().to(torch.int8) - state_x_rowwise = x.max(dim=1)[0] - state_g_rowwise = g.max(dim=1)[0] - state_w_columnwise = w.max(dim=0)[0] - state_w_rowwise = w.max(dim=1)[0] - state_w_global = w.max() - - info = { - "repeat": repeat, - "batch_size": batch_size, - "dim_out": dim_out, - "dim_in": dim_in, - "wm": wm, - "switch": switch, - } - - get_time("standard_fwd", lambda: x.matmul(w.t()), info) - get_time("standard_gw", lambda: g.t().matmul(x), info) - get_time("standard_gx", lambda: g.matmul(w), info) - get_time( - "rowwise_fwd", - lambda: int8_matmul_rowwise_dequantize( - x_int8, - w_int8.t(), - state_x_rowwise, - state_w_columnwise, - None, - ), - info, - 
) - get_time( - "rowwise_bwd", - lambda: int8_matmul_rowwise_dequantize( - g_int8, - wt_int8.t(), - state_x_rowwise, - state_w_rowwise, - None, - ), - info, - ) - get_time( - "global_fwd", - lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), - info, - ) - get_time( - "global_bwd", - lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), - info, - ) - get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info) - get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info) - get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info) - get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info) - get_time("w_quantize_global", lambda: quantize_global(w), info) - get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info) - - time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"] - time_rowwise = ( - info["x_quantize_rowwise"] - + info["g_quantize_rowwise"] - + info["w_quantize_colwise_transpose"] - + info["w_quantize_rowwise"] - + info["standard_gw"] - + info["rowwise_fwd"] - + info["rowwise_bwd"] - ) - time_global = ( - info["x_quantize_rowwise"] - + info["g_quantize_rowwise"] - + info["w_quantize_global"] - + info["w_quantize_global_transpose"] - + info["standard_gw"] - + info["global_fwd"] - + info["global_bwd"] - ) - - print("TOTAL STANDARD", time_standard) - print("TOTAL ROWWISE", time_rowwise) - print("TOTAL GLOBAL", time_global) - - print("speedup", -100 * (time_global - time_standard) / time_standard) - - info["time_standard"] = time_standard - info["time_rowwise"] = time_rowwise - info["time_global"] = time_global - - info_json = json.dumps(info) - - # TODO: change this to what you want. - with open("speed_benchmark/info.jsonl", "a") as file: - file.write(info_json + "\n") diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index bdc150e5e..729342070 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -9,7 +9,7 @@ import torch -from . import _ops, research, utils +from . import _ops, utils from .autograd._functions import ( MatmulLtState, matmul, diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index da168e17b..614c19051 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -53,16 +53,12 @@ def get_current_outlier_idx(self): @dataclass class MatmulLtState: - _tile_indices: Optional[torch.Tensor] = None # TODO: remove - force_no_igemmlt: bool = False CB: Optional[torch.Tensor] = None - CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove SB: Optional[torch.Tensor] = None SCB: Optional[torch.Tensor] = None - CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove SBt: Optional[torch.Tensor] = None CBt: Optional[torch.Tensor] = None @@ -75,22 +71,29 @@ class MatmulLtState: is_training = True has_fp16_weights = True use_pool = False - formatB = "row" # TODO: Deprecate/remove + + # Deprecated attributes kept for downstream compatibility (TGI, vLLM). + # These are always None and will be fully removed in the next release. 
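+ # Example of the intended behavior (illustrative sketch): after `state = MatmulLtState()`,
+ # reading `state.CxB` or `state.formatB` goes through `__getattr__` below, emits a
+ # FutureWarning, and evaluates to None; any other unknown attribute still raises AttributeError.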
+ _deprecated_fields = frozenset({"CxB", "CxBt", "formatB", "_tile_indices"}) + + def __getattr__(self, name): + if name in MatmulLtState._deprecated_fields: + warnings.warn( + f"MatmulLtState.{name} is deprecated and will be removed in the next bitsandbytes release.", + FutureWarning, + stacklevel=2, + ) + return None + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def reset_grads(self): self.CB = None - self.CxB = None self.SB = None self.SCB = None - self.CxBt = None self.SBt = None self.CBt = None - @property - def tile_indices(self): - raise ValueError("tile_indices is no longer supported.") - class MatMul8bitLt(torch.autograd.Function): @staticmethod @@ -293,7 +296,6 @@ def backward(ctx, grad_output): class MatMul4Bit(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @staticmethod def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None): diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index ec96a440c..34e3d5faa 100644 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -4,9 +4,8 @@ import torch try: - import triton.language as tl # noqa: F401 - import triton # noqa: F401 + import triton.language as tl # noqa: F401 triton_available = True except ImportError: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index fe687e1e8..784eeafe5 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -11,7 +11,6 @@ import numpy as np import torch from torch import Tensor -from typing_extensions import deprecated from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict @@ -20,32 +19,6 @@ name2qmap = {} """C FUNCTIONS FOR OPTIMIZERS""" -str2optimizer8bit = { - "adam": ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ), - "momentum": ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ), - "rmsprop": ( - lib.crmsprop_static_8bit_grad_32, - lib.crmsprop_static_8bit_grad_16, - ), - "lion": ( - lib.clion_static_8bit_grad_32, - lib.clion_static_8bit_grad_16, - ), - "lamb": ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ), - "lars": ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ), -} class GlobalPageManager: @@ -1069,110 +1042,6 @@ def dequantize_4bit( return out -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def quantize( - A: Tensor, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, -) -> tuple[Tensor, tuple[Tensor, Tensor]]: - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - code = code.to(A.device) - - absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: - absmax = absmax.float() - inp = A / absmax - out = quantize_no_absmax(inp, code, out) - return out, (absmax, code) - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def dequantize( - A: Tensor, - state: Optional[tuple[Tensor, Tensor]] = None, - absmax: Optional[torch.Tensor] = None, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, -) -> Tensor: - assert state is not None or absmax is not None - if code is None and state is None: - if "dynamic" not in 
name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - code = code.to(A.device) - - if state is None: - state = (absmax, code) - out = dequantize_no_absmax(A, state[1], out) - return out * state[0] - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - """ - Quantizes input tensor to 8-bit. - - Quantizes the 32-bit input tensor `A` to the 8-bit output tensor - `out` using the quantization map `code`. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - code : torch.Tensor - The quantization map. - out : torch.Tensor, optional - The output tensor. Needs to be of type byte. - - Returns - ------- - torch.Tensor: - Quantized 8-bit tensor. - """ - with _cuda_device_of(A): - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) - is_on_gpu([A, out]) - lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) - - return out - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - """ - Dequantizes the 8-bit tensor to 32-bit. - - Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via - the quantization map `code`. - - Parameters - ---------- - A : torch.Tensor - The 8-bit input tensor. - code : torch.Tensor - The quantization map. - out : torch.Tensor - The 32-bit output tensor. - - Returns - ------- - torch.Tensor: - 32-bit output tensor. - """ - with _cuda_device_of(A): - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) - is_on_gpu([code, A, out]) - stream = _get_tensor_stream(A) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) - - return out - - def optimizer_update_32bit( optimizer_name: str, g: Tensor, @@ -1262,143 +1131,6 @@ def optimizer_update_32bit( ) -@deprecated( - "This function is deprecated and will be removed in a future release. " - "Please use optimizer_update_8bit_blockwise instead. ", - category=FutureWarning, -) -def optimizer_update_8bit( - optimizer_name: str, - g: Tensor, - p: Tensor, - state1: Tensor, - state2: Optional[torch.Tensor], - beta1: float, - beta2: float, - eps: float, - step: int, - lr: float, - qmap1: Tensor, - qmap2: Optional[torch.Tensor], - max1: Tensor, - max2: Optional[torch.Tensor], - new_max1: Tensor, - new_max2: Optional[torch.Tensor], - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - unorm_vec: Optional[torch.Tensor] = None, - max_unorm: float = 0.0, -) -> None: - """ - Performs an inplace Adam update. - - Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. - Uses AdamW formulation if weight decay > 0.0. - - Parameters - ---------- - optimizer_name : str - The name of the optimizer. Choices {adam, momentum} - g : torch.Tensor - Gradient tensor. - p : torch.Tensor - Parameter tensor. - state1 : torch.Tensor - Adam state 1. - state2 : torch.Tensor - Adam state 2. - beta1 : float - Adam beta1. - beta2 : float - Adam beta2. - eps : float - Adam epsilon. - weight_decay : float - Weight decay. - step : int - Current optimizer step. - lr : float - The learning rate. - qmap1 : torch.Tensor - Quantization map for first Adam state. - qmap2 : torch.Tensor - Quantization map for second Adam state. - max1 : torch.Tensor - Max value for first Adam state update. 
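The dynamic-map `quantize`/`dequantize` helpers removed above were already marked deprecated; the blockwise functions that stay in `bitsandbytes.functional` cover the same round trip. A rough usage sketch, assuming a CUDA tensor and the current `quantize_blockwise`/`dequantize_blockwise` signatures:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4096, device="cuda")
    qA, quant_state = F.quantize_blockwise(A)             # uint8 codes plus per-block absmax
    A_restored = F.dequantize_blockwise(qA, quant_state)  # approximate round trip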
- max2 : torch.Tensor - Max value for second Adam state update. - new_max1 : torch.Tensor - Max value for the next Adam update of the first state. - new_max2 : torch.Tensor - Max value for the next Adam update of the second state. - gnorm_scale : float - The factor to rescale the gradient to the max clip value. - unorm_vec : torch.Tensor - The tensor for the update norm. - max_unorm : float - The maximum update norm relative to the weight norm. - """ - - param_norm = 0.0 - if max_unorm > 0.0: - param_norm = torch.norm(p.data.float()) - - with _cuda_device_of(g): - is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(max1), - get_ptr(max2), - get_ptr(new_max1), - get_ptr(new_max2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_int32(g.numel()), - ) - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(max1), - get_ptr(max2), - get_ptr(new_max1), - get_ptr(new_max2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_int32(g.numel()), - ) - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - - def optimizer_update_8bit_blockwise( optimizer_name: str, g: Tensor, @@ -1445,48 +1177,6 @@ def optimizer_update_8bit_blockwise( ) -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): - """Applies percentile clipping - - grad: torch.Tensor - The gradient tensor. - gnorm_vec: torch.Tensor - Vector of gradient norms. 100 elements expected. - step: int - The current optimization steps (number of past gradient norms). 
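For reference, the scaling factor this removed helper produced can be reproduced in a few lines of plain PyTorch; a sketch of the same logic (a ring buffer of the last 100 squared gradient norms, clipped at the given percentile):

    import torch

    def percentile_clip_scale(grad, gnorm_vec, step, percentile=5):
        gnorm_vec[step % 100] = grad.float().pow(2).sum()         # store the squared norm
        current_gnorm = gnorm_vec[step % 100].sqrt()
        clip_value = gnorm_vec.sort().values[percentile].sqrt()
        return torch.clamp(clip_value / current_gnorm, max=1.0)  # gnorm_scale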
- - """ - with _cuda_device_of(grad): - is_on_gpu([grad, gnorm_vec]) - if grad.dtype == torch.float32: - lib.cpercentile_clipping_g32( - get_ptr(grad), - get_ptr(gnorm_vec), - ct.c_int32(step), - ct.c_int32(grad.numel()), - ) - elif grad.dtype == torch.float16: - lib.cpercentile_clipping_g16( - get_ptr(grad), - get_ptr(gnorm_vec), - ct.c_int32(step), - ct.c_int32(grad.numel()), - ) - else: - raise ValueError(f"Gradient type {grad.dtype} not supported!") - - current_gnorm = torch.sqrt(gnorm_vec[step % 100]) - vals, _ = torch.sort(gnorm_vec) - clip_value = torch.sqrt(vals[percentile]) - gnorm_scale = 1.0 - - if current_gnorm > clip_value: - gnorm_scale = clip_value / current_gnorm - - return current_gnorm, clip_value, gnorm_scale - - def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): if not torch.cuda.is_initialized(): torch.cuda.init() diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 20aff67a3..54c2614bd 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -16,11 +16,4 @@ OutlierAwareLinear, Params4bit, StableEmbedding, - SwitchBackLinearBnb, -) -from .triton_based_modules import ( - StandardLinear, - SwitchBackLinear, - SwitchBackLinearGlobal, - SwitchBackLinearVectorwise, ) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 67847f40c..9f05ac6fb 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1137,42 +1137,3 @@ def forward(self, x): w = self.quantize_weight(self.weight, self.outlier_dim) self.weight.data.copy_(w) self.is_quantized = True - - -class SwitchBackLinearBnb(nn.Linear): - def __init__( - self, - input_features, - output_features, - bias=True, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - device=None, - ): - super().__init__(input_features, output_features, bias, device) - self.state = bnb.MatmulLtState() - self.index = index - - self.state.threshold = threshold - self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward - if threshold > 0.0 and not has_fp16_weights: - self.state.use_pool = True - - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) - - def init_8bit_state(self): - self.state.CB = self.weight.CB - self.state.SCB = self.weight.SCB - self.weight.CB = None - self.weight.SCB = None - - def forward(self, x): - self.state.is_training = self.training - - if self.weight.CB is not None: - self.init_8bit_state() - - return bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py deleted file mode 100644 index aa8494942..000000000 --- a/bitsandbytes/nn/triton_based_modules.py +++ /dev/null @@ -1,264 +0,0 @@ -from functools import partial - -import torch -import torch.nn as nn - -from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise -from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( - int8_matmul_mixed_dequantize, -) -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( - int8_matmul_rowwise_dequantize, -) -from bitsandbytes.triton.quantize_columnwise_and_transpose import ( - quantize_columnwise_and_transpose, -) -from bitsandbytes.triton.quantize_global import ( - quantize_global, - quantize_global_transpose, -) -from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.triton_utils 
import is_triton_available - - -class _switchback_global(torch.autograd.Function): - @staticmethod - def forward(ctx, X_3D, W, bias): - # reshape input to [N * L, D] - X = X_3D.view(-1, X_3D.size(-1)) - - # rowwise quantize for X, global quantize for W - X_int8, state_X = quantize_rowwise(X) - W_int8, state_W = quantize_global(W) - - # save for backward. - ctx.save_for_backward = X, W - - # matmult, fused dequant and add bias - # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) - - @staticmethod - def backward(ctx, G_3D): - # reshape input to [N_out * L, D] - G = G_3D.reshape(-1, G_3D.size(-1)) - - grad_X = grad_W = grad_bias = None - - X, W = ctx.save_for_backward - if ctx.needs_input_grad[0]: - # rowwise quantize for G, global quantize for W - # for W, we also fuse the transpose operation because only A @ B^T is supported - # so we transpose once then call .t() in the matmul - G_int8, state_G = quantize_rowwise(G) - W_int8, state_W = quantize_global_transpose(W) - grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], - -1, - ) - if ctx.needs_input_grad[1]: - # backward pass uses standard weight grad - grad_W = torch.matmul(G.t(), X.to(G.dtype)) - if ctx.needs_input_grad[2]: - grad_bias = G.sum(dim=0) - - return grad_X, grad_W, grad_bias - - -class _switchback_vectorrize(torch.autograd.Function): - @staticmethod - def forward(ctx, X_3D, W, bias): - # reshape input to [N * L, D] - X = X_3D.view(-1, X_3D.size(-1)) - - ctx.save_for_backward = X, W - # rowwise quantize for X - # columnwise quantize for W (first rowwise, transpose later) - X_int8, state_X = quantize_rowwise(X) - W_int8, state_W = quantize_rowwise(W) - - # matmult, fused dequant and add bias - # call kernel which expects rowwise quantized X and W - return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) - - @staticmethod - def backward(ctx, G_3D): - X, W = ctx.save_for_backward - - G = G_3D.reshape(-1, G_3D.size(-1)) - - grad_X = grad_W = grad_bias = None - - if ctx.needs_input_grad[0]: - # rowwise quantize for G, columnwise quantize for W and fused transpose - # we call .t() for weight later because only A @ B^T is supported - G_int8, state_G = quantize_rowwise(G) - W_int8, state_W = quantize_columnwise_and_transpose(W) - grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], - -1, - ) - if ctx.needs_input_grad[1]: - # backward pass uses standard weight grad - grad_W = torch.matmul(G.t(), X.to(G.dtype)) - if ctx.needs_input_grad[2]: - grad_bias = G.sum(dim=0) - - return grad_X, grad_W, grad_bias - - -class _switchback_global_mem_efficient(torch.autograd.Function): - @staticmethod - def forward(ctx, X_3D, W, bias): - # reshape input to [N * L, D] - X = X_3D.view(-1, X_3D.size(-1)) - X_3D_sz = X_3D.size() - - # rowwise quantize for X, global quantize for W - X_int8, state_X = quantize_rowwise(X) - del X - W_int8, state_W = quantize_global(W) - - # save for backward. 
- ctx.save_for_backward = X_int8, state_X, W_int8, state_W - - # matmult, fused dequant and add bias - # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1) - - @staticmethod - def backward(ctx, G_3D): - # reshape input to [N_out * L, D] - G = G_3D.reshape(-1, G_3D.size(-1)) - G_3D_sz = G_3D.size() - - grad_X = grad_W = grad_bias = None - - X_int8, state_X, W_int8, state_W = ctx.save_for_backward - if ctx.needs_input_grad[1]: - real_X = dequantize_rowwise(X_int8, state_X) - del X_int8 - grad_W = torch.matmul(G.t(), real_X.to(G.dtype)) - del real_X - if ctx.needs_input_grad[2]: - grad_bias = G.sum(dim=0) - if ctx.needs_input_grad[0]: - G_int8, state_G = quantize_rowwise(G) - del G - W_int8 = W_int8.t().contiguous() - grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1) - - return grad_X, grad_W, grad_bias - - -class SwitchBackLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - vector_wise_quantization: bool = False, - mem_efficient: bool = False, - ): - super().__init__(in_features, out_features, bias, device, dtype) - - if not is_triton_available(): - raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear. - Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""") - - # By default, we use the global quantization. - self.vector_wise_quantization = vector_wise_quantization - if self.vector_wise_quantization: - self._fn = _switchback_vectorrize - if mem_efficient: - print("mem efficient is not supported for vector-wise quantization.") - exit(1) - else: - if mem_efficient: - self._fn = _switchback_global_mem_efficient - else: - self._fn = _switchback_global - - def prepare_for_eval(self): - # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass. - # Note this is experimental and not tested thoroughly. - # Note this needs to be explicitly called with something like - # def cond_prepare(m): - # if hasattr(m, "prepare_for_eval"): - # m.prepare_for_eval() - # model.apply(cond_prepare) - print("=> preparing for eval.") - if self.vector_wise_quantization: - W_int8, state_W = quantize_rowwise(self.weight) - else: - W_int8, state_W = quantize_global(self.weight) - - self.register_buffer("W_int8", W_int8) - self.register_buffer("state_W", state_W) - - del self.weight - - def forward(self, x): - if self.training: - return self._fn.apply(x, self.weight, self.bias) - else: - # If it hasn't been "prepared for eval", run the standard forward pass. - if not hasattr(self, "W_int8"): - return self._fn.apply(x, self.weight, self.bias) - - # Otherwise, use pre-computed weights. 
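For reference, the quantization scheme these removed SwitchBack modules implemented (row-wise int8 scales for the activations, one global scale for the weight, dequantization fused into the int8 matmul) can be illustrated in plain PyTorch:

    import torch

    def rowwise_int8(x):                                   # one scale per row
        scale = x.abs().amax(dim=1, keepdim=True) / 127.0
        return torch.round(x / scale).to(torch.int8), scale

    def global_int8(w):                                    # one scale for the whole tensor
        scale = w.abs().amax() / 127.0
        return torch.round(w / scale).to(torch.int8), scale

    X, W = torch.randn(16, 64), torch.randn(32, 64)
    Xq, sx = rowwise_int8(X)
    Wq, sw = global_int8(W)
    Y = (Xq.float() @ Wq.float().t()) * sx * sw            # approx. X @ W.t(); the Triton kernels fused this dequant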
- X = x.view(-1, x.size(-1)) - X_int8, state_X = quantize_rowwise(X) - - if self.vector_wise_quantization: - return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( - *x.size()[:-1], - -1, - ) - else: - return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( - *x.size()[:-1], - -1, - ) - - -SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) -SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) -SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) - - -# This is just the standard linear function. -class StandardLinearFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias=None): - X = input.view(-1, input.size(-1)) - - ctx.save_for_backward(X, weight, bias) - output = input.matmul(weight.t()) - if bias is not None: - output += bias.unsqueeze(0).expand_as(output) - return output.view(*input.size()[:-1], -1) - - @staticmethod - def backward(ctx, grad_output_3D): - input, weight, bias = ctx.saved_tensors - - grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1)) - - grad_input = grad_weight = grad_bias = None - - if ctx.needs_input_grad[0]: - grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1) - if ctx.needs_input_grad[1]: - grad_weight = grad_output.t().matmul(input.to(grad_output.dtype)) - if bias is not None and ctx.needs_input_grad[2]: - grad_bias = grad_output.sum(0) - - return grad_input, grad_weight, grad_bias - - -class StandardLinear(nn.Linear): - def forward(self, x): - return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 7459dece1..b871f2bf4 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -17,8 +17,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Base Adagrad optimizer. @@ -42,10 +40,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -67,8 +61,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, ) @@ -84,8 +76,6 @@ def __init__( optim_bits=8, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ 8-bit Adagrad optimizer. @@ -109,10 +99,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. 
""" if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -124,7 +110,6 @@ def __init__( raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: raise ValueError("Lr Decay != 0.0 not supported!") - assert block_wise super().__init__( "adagrad", params, @@ -135,8 +120,6 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, - block_wise, ) @@ -152,8 +135,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ 32-bit Adagrad optimizer. @@ -177,10 +158,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -202,6 +179,4 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, ) diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 22a217c3b..63210bdc3 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -18,8 +18,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -44,10 +42,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -61,8 +55,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -79,8 +71,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -107,10 +97,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -133,8 +119,6 @@ def __init__( 8, # Hardcoded to 8 bits args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -151,8 +135,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -177,10 +159,6 @@ def __init__( An object with additional arguments. 
min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -194,8 +172,6 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -212,8 +188,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -238,10 +212,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -255,8 +225,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) @@ -273,8 +241,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -301,10 +267,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -327,8 +289,6 @@ def __init__( 8, # Hardcoded to 8 bits args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) @@ -345,8 +305,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -371,10 +329,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. 
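Since `percentile_clipping` and `block_wise` are removed from every optimizer constructor rather than deprecated, call sites that still pass them now fail with a TypeError (these signatures have no **kwargs fallback); the fix is to drop the keywords, as block-wise quantization is now always on. A hypothetical call site, with `model` standing in for any nn.Module:

    import bitsandbytes as bnb

    # old: bnb.optim.Adam8bit(model.parameters(), lr=1e-3, percentile_clipping=95, block_wise=True)
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, min_8bit_size=4096)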
""" @@ -388,7 +342,5 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 5f225c9ad..36e151dfc 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -18,8 +18,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -44,10 +42,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -61,8 +55,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -79,8 +71,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -107,10 +97,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -133,8 +119,6 @@ def __init__( 8, # Hardcoded to 8 bits args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -151,8 +135,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -177,10 +159,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -194,8 +172,6 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -212,8 +188,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Paged AdamW optimizer. @@ -237,10 +211,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. 
- percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ super().__init__( "adam", @@ -252,8 +222,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) @@ -270,8 +238,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Paged 8-bit AdamW optimizer. @@ -297,10 +263,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ # Validate unsupported parameters if amsgrad: @@ -321,8 +283,6 @@ def __init__( 8, # Hardcoded to 8 bits args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) @@ -339,8 +299,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Paged 32-bit AdamW optimizer. @@ -364,10 +322,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ super().__init__( "adam", @@ -379,7 +333,5 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) diff --git a/bitsandbytes/optim/ademamix.py b/bitsandbytes/optim/ademamix.py index 928289adb..48a62198e 100644 --- a/bitsandbytes/optim/ademamix.py +++ b/bitsandbytes/optim/ademamix.py @@ -129,8 +129,6 @@ def __init__( optim_bits=optim_bits, args=None, min_8bit_size=min_8bit_size, - percentile_clipping=100, - block_wise=True, is_paged=is_paged, alpha=alpha, t_alpha=t_alpha, @@ -142,8 +140,6 @@ def init_state(self, group, p, gindex, pindex): # In our AdEMAMix implementation, we use `state` to hold # both the fast and slow EMAs. Here we override the base # `Optimizer2State` to allocate a buffer twice as large. - # Additional consideration: we do not support block_wise=False, - # percentile clipping, or max_unorm. 
config = self.get_config(gindex, pindex, group) @@ -380,8 +376,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=min_8bit_size, - percentile_clipping=100, - block_wise=True, is_paged=is_paged, alpha=alpha, t_alpha=t_alpha, diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 8d29cbbfe..6dcfd383f 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -19,8 +19,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=False, max_unorm=1.0, ): """ @@ -49,10 +47,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. max_unorm (`float`, defaults to 1.0): The maximum gradient norm. """ @@ -66,8 +60,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, max_unorm=1.0, ) @@ -85,8 +77,6 @@ def __init__( adam_w_mode=True, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=False, max_unorm=1.0, ): """ @@ -113,10 +103,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. max_unorm (`float`, defaults to 1.0): The maximum gradient norm. """ @@ -130,8 +116,6 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, - block_wise, max_unorm=1.0, ) @@ -149,8 +133,6 @@ def __init__( adam_w_mode=True, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=False, max_unorm=1.0, ): """ @@ -177,10 +159,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. max_unorm (`float`, defaults to 1.0): The maximum gradient norm. """ @@ -194,7 +172,5 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, max_unorm=1.0, ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index fa2af57bc..c2f5aa784 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -20,7 +20,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, max_unorm=0.02, ): """ @@ -45,8 +44,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. 
- percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. max_unorm (`float`, defaults to 0.02): The maximum gradient norm. """ @@ -62,9 +59,7 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, max_unorm=max_unorm, - block_wise=False, ) @@ -79,7 +74,6 @@ def __init__( nesterov=False, args=None, min_8bit_size=4096, - percentile_clipping=100, max_unorm=0.02, ): """ @@ -102,8 +96,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. max_unorm (`float`, defaults to 0.02): The maximum gradient norm. """ @@ -119,9 +111,7 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, max_unorm=max_unorm, - block_wise=False, ) @@ -136,7 +126,6 @@ def __init__( nesterov=False, args=None, min_8bit_size=4096, - percentile_clipping=100, max_unorm=0.02, ): """ @@ -159,8 +148,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. max_unorm (`float`, defaults to 0.02): The maximum gradient norm. """ @@ -176,9 +163,7 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, max_unorm=max_unorm, - block_wise=False, ) diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2e4163694..6100491f6 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -15,8 +15,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -37,10 +35,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -54,8 +48,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -69,8 +61,6 @@ def __init__( weight_decay=0, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -89,10 +79,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. 
- block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -106,8 +92,6 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -121,8 +105,6 @@ def __init__( weight_decay=0, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, is_paged=False, ): """ @@ -141,10 +123,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ @@ -158,8 +136,6 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=is_paged, ) @@ -174,8 +150,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Paged Lion optimizer. @@ -195,10 +169,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ super().__init__( "lion", @@ -210,8 +180,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) @@ -225,8 +193,6 @@ def __init__( weight_decay=0, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Paged 8-bit Lion optimizer. @@ -240,16 +206,10 @@ def __init__( The beta values are the decay rates of the first and second-order moment of the optimizer. weight_decay (`float`, defaults to 0): The weight decay value for the optimizer. - optim_bits (`int`, defaults to 32): - The number of bits of the optimizer state. args (`object`, defaults to `None`): An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ super().__init__( "lion", @@ -261,8 +221,6 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) @@ -276,8 +234,6 @@ def __init__( weight_decay=0, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Paged 32-bit Lion optimizer. @@ -291,16 +247,10 @@ def __init__( The beta values are the decay rates of the first and second-order moment of the optimizer. 
weight_decay (`float`, defaults to 0): The weight decay value for the optimizer. - optim_bits (`int`, defaults to 32): - The number of bits of the optimizer state. args (`object`, defaults to `None`): An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ super().__init__( "lion", @@ -312,7 +262,5 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, is_paged=True, ) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index db7a35231..5c00a42bf 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -59,7 +59,7 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None) The key-values of the optimizer config for the input parameters are overridden This can be both, optimizer parameters like `betas` or `lr`, or it can be - 8-bit specific parameters like `optim_bits` or `percentile_clipping`. + 8-bit specific parameters like `optim_bits`. Arguments: parameters (`torch.Tensor` or `list(torch.Tensors)`): @@ -345,8 +345,6 @@ def get_config(self, gindex, pindex, group): config["t_beta3"] = group.get("t_beta3", 0) config["optim_bits"] = self.args.optim_bits config["min_8bit_size"] = self.args.min_8bit_size - config["percentile_clipping"] = self.args.percentile_clipping - config["block_wise"] = self.args.block_wise config["max_unorm"] = self.args.max_unorm config["skip_zeros"] = self.args.skip_zeros @@ -393,8 +391,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, max_unorm=0.0, skip_zeros=False, is_paged=False, @@ -424,10 +420,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. max_unorm (`float`, defaults to 0.0): The maximum value to normalize each block with. 
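With the `block_wise` flag gone, the `init_state` hunks further down always size the quantization-statistics buffers per 256-element block; a quick worked example of that sizing arithmetic:

    blocksize = 256
    n = 1_000_000                                    # p.numel() for a 1M-element tensor
    blocks = (n // blocksize) + bool(n % blocksize)  # ceil(n / 256)
    assert blocks == 3907                            # absmax1 / absmax2 get shape (3907,)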
skip_zeros (`bool`, defaults to `False`): @@ -466,8 +458,6 @@ def __init__( args = {} args["optim_bits"] = optim_bits args["min_8bit_size"] = min_8bit_size - args["percentile_clipping"] = percentile_clipping - args["block_wise"] = block_wise args["max_unorm"] = max_unorm args["skip_zeros"] = skip_zeros @@ -510,21 +500,12 @@ def init_state(self, group, p, gindex, pindex): state["state2"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap2"] = self.name2qmap["udynamic"] - if config["block_wise"]: - blocksize = 256 - n = p.numel() - blocks = (n // blocksize) + bool(n % blocksize) + blocksize = 256 + n = p.numel() + blocks = (n // blocksize) + bool(n % blocksize) - state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - - if config["percentile_clipping"] < 100: - state["gnorm_vec"] = torch.zeros((100,), device=p.device) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) if config["max_unorm"] > 0.0: state["unorm_vec"] = torch.zeros((1,), device=p.device) @@ -543,16 +524,6 @@ def update_step(self, group, p, gindex, pindex): state["step"] += 1 step = state["step"] - if config["percentile_clipping"] < 100: - _current_gnorm, _clip_value, gnorm_scale = F.percentile_clipping( - grad, - state["gnorm_vec"], - step, - config["percentile_clipping"], - ) - else: - gnorm_scale = 1.0 - if state["state1"].dtype == torch.float: F.optimizer_update_32bit( self.optimizer_name, @@ -568,40 +539,13 @@ def update_step(self, group, p, gindex, pindex): config["betas"][2] if len(config["betas"]) >= 3 else 0.0, config.get("alpha", 0.0), config["weight_decay"], - gnorm_scale, + 1.0, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, max_unorm=config["max_unorm"], skip_zeros=config["skip_zeros"], ) - elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: - F.optimizer_update_8bit( - self.optimizer_name, - grad, - p, - state["state1"], - state["state2"], - config["betas"][0], - config["betas"][1], - config["eps"], - step, - config["lr"], - state["qmap1"], - state["qmap2"], - state["max1"], - state["max2"], - state["new_max1"], - state["new_max2"], - config["weight_decay"], - gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, - max_unorm=config["max_unorm"], - ) - - # swap maxes - state["max1"], state["new_max1"] = state["new_max1"], state["max1"] - state["max2"], state["new_max2"] = state["new_max2"], state["max2"] - elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + elif state["state1"].dtype == torch.uint8: F.optimizer_update_8bit_blockwise( self.optimizer_name, grad, @@ -620,7 +564,7 @@ def update_step(self, group, p, gindex, pindex): state["absmax1"], state["absmax2"], config["weight_decay"], - gnorm_scale=gnorm_scale, + gnorm_scale=1.0, skip_zeros=config["skip_zeros"], ) @@ -637,8 +581,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, max_unorm=0.0, skip_zeros=False, is_paged=False, @@ -665,10 +607,6 @@ def __init__( An object with 
additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. max_unorm (`float`, defaults to 0.0): The maximum value to normalize each block with. skip_zeros (`bool`, defaults to `False`): @@ -692,8 +630,6 @@ def __init__( args = {} args["optim_bits"] = optim_bits args["min_8bit_size"] = min_8bit_size - args["percentile_clipping"] = percentile_clipping - args["block_wise"] = block_wise args["max_unorm"] = max_unorm args["skip_zeros"] = skip_zeros @@ -731,18 +667,11 @@ def init_state(self, group, p, gindex, pindex): state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] - if config["block_wise"]: - blocksize = 256 - n = p.numel() - blocks = (n // blocksize) + bool(n % blocksize) - - state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + blocksize = 256 + n = p.numel() + blocks = (n // blocksize) + bool(n % blocksize) - if config["percentile_clipping"] < 100: - state["gnorm_vec"] = torch.zeros((100,), device=p.device) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) if config["max_unorm"] > 0.0: state["unorm_vec"] = torch.zeros((1,), device=p.device) @@ -761,16 +690,6 @@ def update_step(self, group, p, gindex, pindex): state["step"] += 1 step = state["step"] - if config["percentile_clipping"] < 100: - _current_gnorm, _clip_value, gnorm_scale = F.percentile_clipping( - grad, - state["gnorm_vec"], - step, - config["percentile_clipping"], - ) - else: - gnorm_scale = 1.0 - if state["state1"].dtype == torch.float: F.optimizer_update_32bit( self.optimizer_name, @@ -786,38 +705,13 @@ def update_step(self, group, p, gindex, pindex): 0.0, 0.0, config["weight_decay"], - gnorm_scale, + 1.0, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, max_unorm=config["max_unorm"], skip_zeros=config["skip_zeros"], ) - elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: - F.optimizer_update_8bit( - self.optimizer_name, - grad, - p, - state["state1"], - None, - config["betas"][0], - config["betas"][1], - config["eps"], - step, - config["lr"], - state["qmap1"], - None, - state["max1"], - None, - state["new_max1"], - None, - config["weight_decay"], - gnorm_scale, - state["unorm_vec"] if config["max_unorm"] > 0.0 else None, - max_unorm=config["max_unorm"], - ) - - state["max1"], state["new_max1"] = state["new_max1"], state["max1"] - elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + elif state["state1"].dtype == torch.uint8: F.optimizer_update_8bit_blockwise( self.optimizer_name, grad, @@ -836,6 +730,6 @@ def update_step(self, group, p, gindex, pindex): state["absmax1"], None, config["weight_decay"], - gnorm_scale=gnorm_scale, + gnorm_scale=1.0, skip_zeros=config["skip_zeros"], ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 25611309b..54c1fbda0 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -18,8 +18,6 @@ def 
__init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Base RMSprop optimizer. @@ -45,10 +43,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") @@ -64,8 +58,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, ) @@ -81,8 +73,6 @@ def __init__( centered=False, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ 8-bit RMSprop optimizer. @@ -102,16 +92,10 @@ def __init__( The momentum value speeds up the optimizer by taking bigger steps. centered (`bool`, defaults to `False`): Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. - optim_bits (`int`, defaults to 32): - The number of bits of the optimizer state. args (`object`, defaults to `None`): An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") @@ -127,8 +111,6 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, - block_wise, ) @@ -144,8 +126,6 @@ def __init__( centered=False, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ 32-bit RMSprop optimizer. @@ -171,10 +151,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: @@ -191,6 +167,4 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, ) diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index ec18f036c..75fc71474 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -17,8 +17,6 @@ def __init__( optim_bits=32, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ Base SGD optimizer. @@ -42,10 +40,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. 
- percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") @@ -59,8 +53,6 @@ def __init__( optim_bits, args, min_8bit_size, - percentile_clipping, - block_wise, ) @@ -75,8 +67,6 @@ def __init__( nesterov=False, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ 8-bit SGD optimizer. @@ -98,10 +88,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") @@ -115,8 +101,6 @@ def __init__( 8, args, min_8bit_size, - percentile_clipping, - block_wise, ) @@ -131,8 +115,6 @@ def __init__( nesterov=False, args=None, min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, ): """ 32-bit SGD optimizer. @@ -154,10 +136,6 @@ def __init__( An object with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. - percentile_clipping (`int`, defaults to 100): - Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. - block_wise (`bool`, defaults to `True`): - Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") @@ -171,6 +149,4 @@ def __init__( 32, args, min_8bit_size, - percentile_clipping, - block_wise, ) diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py deleted file mode 100644 index 31db4f282..000000000 --- a/bitsandbytes/research/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . 
import nn -from .autograd._functions import ( - matmul_fp8_global, - matmul_fp8_mixed, - switchback_bnb, -) diff --git a/bitsandbytes/research/autograd/__init__.py b/bitsandbytes/research/autograd/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py deleted file mode 100644 index 1ea147a90..000000000 --- a/bitsandbytes/research/autograd/_functions.py +++ /dev/null @@ -1,396 +0,0 @@ -from functools import reduce # Required in Python 3 -import operator -from typing import Optional -import warnings - -import torch - -from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState -import bitsandbytes.functional as F - - -# math.prod not compatible with python < 3.8 -def prod(iterable): - return reduce(operator.mul, iterable, 1) - - -class MatMulFP8Mixed(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): - # default of pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - - B_shape = B.shape - if A.shape[-1] == B_shape[0]: - return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - - # 1. Dequantize - # 2. MatmulnN - cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) - fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) - - cB, state = F.quantize(B.float(), code=fw_code) - fp8B = F.dequantize(cB, state).to(B.dtype) - - output = torch.matmul(fp8A, fp8B) - - # output is half - - # 3. Save state - ctx.fw_code = fw_code - ctx.bw_code = bw_code - ctx.bsz = bsz - ctx.bsz2 = bsz2 - ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype - - if any(ctx.needs_input_grad[:2]): - # NOTE: we send back A, and re-quant. - ctx.tensors = (A, fp8B) - else: - ctx.tensors = (None, None) - - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None - - req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad - A, B = ctx.tensors - - grad_A, grad_B = None, None - - # TODO: Fix blocksize to be output_dim - cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) - fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) - - # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) - # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) - - # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') - # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose - # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) - - # not supported by PyTorch. 
TODO: create work-around - if req_gradA: - grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) - - if req_gradB: - if len(A.shape) == 3: - At = A.transpose(2, 1).contiguous() - else: - At = A.transpose(1, 0).contiguous() - # cA, state = F.quantize(At.float(), code=ctx.fw_code) - # fp8At = F.dequantize(cA, state).to(A.dtype) - grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype) - - return grad_A, grad_B, None, None, None, None, None - - -class MatMulFP8Global(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): - # default of pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - - B_shape = B.shape - if A.shape[-1] == B_shape[0]: - return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - - # 1. Dequantize - # 2. MatmulnN - cA, state = F.quantize(A.float(), code=fw_code) - fp8A = F.dequantize(cA, state).to(A.dtype) - - cB, state = F.quantize(B.float(), code=fw_code) - fp8B = F.dequantize(cB, state).to(B.dtype) - - output = torch.matmul(fp8A, fp8B) - - # output is half - - # 3. Save state - ctx.fw_code = fw_code - ctx.bw_code = bw_code - ctx.bsz = bsz - ctx.bsz2 = bsz2 - ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype - - if any(ctx.needs_input_grad[:2]): - # NOTE: we send back A, and re-quant. - ctx.tensors = (A, fp8B) - else: - ctx.tensors = (None, None) - - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None - - req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad - A, B = ctx.tensors - - grad_A, grad_B = None, None - - # TODO: Fix blocksize to be output_dim - cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code) - fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype) - - # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) - # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) - - # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') - # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose - # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) - - # not supported by PyTorch. 
TODO: create work-around - if req_gradA: - grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) - - if req_gradB: - if len(A.shape) == 3: - At = A.transpose(2, 1).contiguous() - else: - At = A.transpose(1, 0).contiguous() - cA, state = F.quantize(At.float(), code=ctx.fw_code) - fp8At = F.dequantize(cA, state).to(A.dtype) - grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype) - - return grad_A, grad_B, None, None, None, None, None - - -class SwitchBackBnb(torch.autograd.Function): - @staticmethod - def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None): - state = state or MatmulLtState() - - # default to pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - ctx.bias = bias - if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) - - # 1. Quantize A - # 2. Quantize B - # 3. Matmul - # 4. Mixed-precision decomposition matmul - # 5. Save state - input_shape = A.shape - if state.outlier_pool is None: - state.outlier_pool = GlobalOutlierPooler.get_instance() - - # Cast A to fp16 - if A.dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") - - # 1. Quantize A - if len(A.shape) == 3: - A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold) - - if state.threshold > 0.0 and outlier_cols is not None: - if state.has_fp16_weights: - idx = outlier_cols - CA[:, idx] = 0 - subA = A[:, idx] - state.subB = B[:, idx].t().contiguous() - state.idx = idx - else: - if state.SB is None: - state.SB = (state.CB.shape, "row") - else: - if not state.has_fp16_weights and state.SB is None: - state.SB = (state.CB.shape, "row") - subA = None - - # 2. Quantize B - if state.has_fp16_weights: - # print('B shape', B.shape) - has_grad = getattr(B, "grad", None) is not None - is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) - if is_transposed: - B = B.contiguous() - - if (state.is_training and not has_grad) or state.SB is None: - state.reset_grads() - ( - state.CB, - state.CBt, - state.SCB, - state.SCBt, - _, - ) = F.int8_double_quant(B.to(torch.float16)) - state.SB = (state.CB.shape, "row") - else: - has_grad = False - - if outlier_cols is not None and not state.has_fp16_weights: - # extract outliers - state.idx = outlier_cols - outliers = state.CB[:, state.idx.long()].clone() - state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) - CA[:, state.idx.long()] = 0 - - subA = A[:, state.idx.long()] - - shapeB = state.SB[0] - - if len(input_shape) == 3: - output_shape = (input_shape[0], input_shape[1], shapeB[0]) - else: - output_shape = (input_shape[0], shapeB[0]) - - # 3. Matmul - out32 = F.int8_linear_matmul(CA, state.CB) - # we apply the fused bias here - - if bias is None or bias.dtype == torch.float16: - output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype) - else: # apply bias separately - output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype) - output.add_(bias) - - # 4. Mixed-precision decomposition matmul - if outlier_cols is not None and subA is not None: - output += torch.matmul(subA, state.subB) - - # 5. 
Save state - ctx.state = state - - ctx.grad_shape = input_shape - ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype - - if any(ctx.needs_input_grad[:2]): - ctx.tensors = (CAt, subA, A) - ctx.tensor_states = (SCAt, state.idx) - else: - ctx.tensors = [None, None, None] - ctx.tensor_states = (None, None) - ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x: x - return clone_func(output.view(output_shape)) - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - - req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - _CAt, _subA, A = ctx.tensors - _SCAt, _idx = ctx.tensor_states - state = ctx.state - grad_A = grad_B = grad_bias = None - - if req_gradBias: - # compute grad_bias first before changing grad_output dtype - grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) - - # Cast grad_output to fp16 - if len(grad_output.shape) == 3: - grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - - _Cgrad, _Cgradt, _SCgrad, _SCgradt, _outlier_cols = F.int8_double_quant(grad_output.to(torch.float16)) - - if req_gradB: - # print('back A shape', A.shape) - # print('grad output t shape', grad_output.t().shape) - grad_B = torch.matmul(grad_output.t(), A) - - if req_gradA: - if state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - else: - raise Exception("State must contain either CBt or CB matrix for backward") - - return grad_A, grad_B, None, grad_bias, None - - -def get_block_sizes(input_matrix, weight_matrix): - input_features = input_matrix.shape[-1] - output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1] - array = [4096, 2048, 1024, 512, 256, 128, 64, 0] - bsz, bsz2 = 1024, 1024 - for i, k in enumerate(array): - if input_features > array[i + 1]: - bsz = k - break - for i, k in enumerate(array): - if output_features > array[i + 1]: - bsz2 = k - break - - return bsz, bsz2 - - -def matmul_fp8_global( - A: torch.Tensor, - B: torch.Tensor, - fw_code: torch.Tensor, - bw_code: torch.Tensor, - out: Optional[torch.Tensor] = None, - bsz: int = -1, - bsz2: int = -1, -): - if bsz == -1 or bsz2 == -1: - bsz, bsz2 = get_block_sizes(A, B) - return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) - - -def matmul_fp8_mixed( - A: torch.Tensor, - B: torch.Tensor, - fw_code: torch.Tensor, - bw_code: torch.Tensor, - out: Optional[torch.Tensor] = None, - bsz: int = -1, - bsz2: int = -1, -): - if bsz == -1 or bsz2 == -1: - bsz, bsz2 = get_block_sizes(A, B) - return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) - - -def switchback_bnb( - A: torch.Tensor, - B: torch.Tensor, - out: Optional[torch.Tensor] = None, - state: Optional[MatmulLtState] = None, - threshold=0.0, - bias=None, -): - state = state or MatmulLtState() - if threshold > 0.0: - state.threshold = threshold - return SwitchBackBnb.apply(A, B, out, bias, state) diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py deleted file mode 100644 index 417011218..000000000 --- a/bitsandbytes/research/nn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .modules import LinearFP8Global, LinearFP8Mixed diff --git 
a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py deleted file mode 100644 index 57c0f3358..000000000 --- a/bitsandbytes/research/nn/modules.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import TypeVar - -import torch -from torch import nn - -import bitsandbytes as bnb - -T = TypeVar("T", bound="torch.nn.Module") - - -class LinearFP8Mixed(nn.Linear): - def __init__(self, input_features, output_features, bias=True): - super().__init__(input_features, output_features, bias) - self.bw_code = None - self.fw_code = None - array = [4096, 2048, 1024, 512, 256, 128, 64, 0] - for i, k in enumerate(array): - if input_features > array[i + 1]: - self.bsz = k - break - for i, k in enumerate(array): - if output_features > array[i + 1]: - self.bsz2 = k - break - - def forward(self, x: torch.Tensor): - if self.fw_code is None: - self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) - self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - - out = bnb.research.matmul_fp8_mixed( - x, - self.weight.t(), - fw_code=self.fw_code, - bw_code=self.bw_code, - bsz=self.bsz, - bsz2=self.bsz2, - ) - if self.bias is not None: - out += self.bias - - return out - - -class LinearFP8Global(nn.Linear): - def __init__(self, input_features, output_features, bias=True): - super().__init__(input_features, output_features, bias) - self.bw_code = None - self.fw_code = None - array = [4096, 2048, 1024, 512, 256, 128, 64, 0] - for i, k in enumerate(array): - if input_features > array[i + 1]: - self.bsz = k - break - for i, k in enumerate(array): - if output_features > array[i + 1]: - self.bsz2 = k - break - - def forward(self, x: torch.Tensor): - if self.fw_code is None: - self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) - self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - - out = bnb.matmul_fp8_global( - x, - self.weight.t(), - fw_code=self.fw_code, - bw_code=self.bw_code, - bsz=self.bsz, - bsz2=self.bsz2, - ) - if self.bias is not None: - out += self.bias - - return out diff --git a/bitsandbytes/triton/__init__.py b/bitsandbytes/triton/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py deleted file mode 100644 index 26eab84f2..000000000 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ /dev/null @@ -1,64 +0,0 @@ -import math - -import torch - -from bitsandbytes.triton.triton_utils import is_triton_available - -if not is_triton_available(): - - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): - return None -else: - import triton - import triton.language as tl - - # rowwise quantize - - # TODO: autotune this better. 
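As a reading aid for the removal above: the row-wise dequantize kernel rescales each int8 row by its stored per-row absmax divided by 127. A minimal plain-PyTorch sketch of that computation follows; the function name is ours and the sketch is illustrative, not part of the deleted module.

```py
import torch

def dequantize_rowwise_reference(x_int8: torch.Tensor, state_x: torch.Tensor) -> torch.Tensor:
    # x_int8:  (rows, cols) int8 tensor produced by row-wise quantization.
    # state_x: (rows,) per-row absmax values saved when the tensor was quantized.
    # Mirrors `output = max_val * x * inv_127` in the Triton kernel below.
    return (state_x.float().unsqueeze(1) * x_int8.float() / 127.0).to(torch.float16)
```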
- @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=["n_elements"], - ) - @triton.jit - def _dequantize_rowwise( - x_ptr, - state_x, - output_ptr, - inv_127, - n_elements, - BLOCK_SIZE: tl.constexpr, - P2: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - arange = tl.arange(0, P2) - offsets = block_start + arange - row_mask = arange < BLOCK_SIZE - x = tl.load(x_ptr + offsets, mask=row_mask) - max_val = tl.load(state_x + pid) - output = max_val * x * inv_127 - tl.store(output_ptr + offsets, output, mask=row_mask) - - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): - output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) - - P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) - - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (x.shape[0],) - _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) - return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py deleted file mode 100644 index 5fcb927d4..000000000 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch - -from bitsandbytes.triton.triton_utils import is_triton_available - -if not is_triton_available(): - - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): - return None -else: - import triton - import triton.language as tl - - from .matmul_perf_model import early_config_prune, estimate_matmul_time - - # This is a matmul kernel based on triton.ops.matmul - # It is modified to support rowwise quantized input and global quantized weight - # It's purpose is fused matmul then dequantize - # It does support bias. 
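The comment block above summarizes the kernel: row-wise quantized input, globally quantized weight, a fused int8 matmul followed by dequantization, with optional bias. A small PyTorch reference of that math, assuming the scales are absmax values as produced by the quantize kernels in this diff; the name and float32 accumulation are ours (the real kernel accumulates in int32):

```py
import torch

def int8_matmul_mixed_dequantize_reference(a_int8, b_int8, state_x, state_w, bias=None):
    # a_int8: (M, K) int8 row-wise quantized activations; state_x: (M,) per-row absmax.
    # b_int8: (K, N) int8 globally quantized weight;      state_w: scalar absmax.
    acc = torch.matmul(a_int8.float(), b_int8.float())  # int8 x int8 accumulation
    out = state_x.float()[:, None] * float(state_w) * acc / (127.0 * 127.0)  # undo both scales
    if bias is not None:  # the kernel fuses the bias add
        out = out + bias.float()
    return out.to(torch.float16)
```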
- - def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, - num_stages=num_stages, - num_warps=num_warps, - ), - ) - # split_k - for split_k in [2, 4, 8, 16]: - configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ), - ) - return configs - - @triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), - # good for int8 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), - *get_configs_io_bound(), - ], - key=["M", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, - ) - @triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - }, - ) - @triton.jit - def _int8_matmul_mixed_dequantize( - A, - B, - C, - bias, - state_x_ptr, - state_w_ptr, - M, - N, - K, - divfactor: tl.constexpr, - has_bias: tl.constexpr, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = tl.cdiv(M, 
BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - w_factor = tl.load(state_w_ptr) - x_factor = tl.load(state_x_ptr + ram)[:, None] - - # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) - acc += tl.dot(a, b) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - - acc = w_factor * (x_factor * (acc * divfactor)) - acc = acc.to(C.dtype.element_ty) - - # conditionally add bias - if has_bias: - bias = tl.load(bias + rn).to(C.dtype.element_ty) - acc = acc + bias[None, :] - - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) - - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): - device = a.device - divfactor = 1.0 / (127.0 * 127.0) - has_bias = 0 if bias is None else 1 - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # allocates output - c = torch.empty((M, N), device=device, dtype=torch.float16) - # accumulator types - ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # launch int8_matmul_mixed_dequantize kernel - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) - _int8_matmul_mixed_dequantize[grid]( - a, - b, - c, - bias, - state_x, - state_w, - M, - N, - K, - divfactor, - has_bias, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - GROUP_M=8, - ACC_TYPE=ACC_TYPE, - ) - return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py deleted file mode 100644 index 05e30a4c9..000000000 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ /dev/null @@ -1,207 +0,0 @@ -import torch - -from bitsandbytes.triton.triton_utils import is_triton_available - -if not is_triton_available(): - - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - return None -else: - import triton - import triton.language as tl - - from .matmul_perf_model import early_config_prune, 
estimate_matmul_time - - # This is a matmul kernel based on triton.ops.matmul - # It is modified to support rowwise quantized input and columnwise quantized weight - # It's purpose is fused matmul then dequantize - # It does support bias. - - def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, - num_stages=num_stages, - num_warps=num_warps, - ), - ) - # split_k - for split_k in [2, 4, 8, 16]: - configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ), - ) - return configs - - @triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), - # good for int8 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), - *get_configs_io_bound(), - ], - key=["M", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, - ) - @triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - }, - ) - @triton.jit - def _int8_matmul_rowwise_dequantize( - A, - B, - C, - bias, - state_x_ptr, - state_w_ptr, - M, - N, - K, - divfactor, - has_bias: tl.constexpr, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - 
BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - w_factor = tl.load(state_w_ptr + rbn)[None, :] - x_factor = tl.load(state_x_ptr + ram)[:, None] - - # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) - acc += tl.dot(a, b) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - - acc = w_factor * (x_factor * (acc * divfactor)) - acc = acc.to(C.dtype.element_ty) - - if has_bias: - bias = tl.load(bias + rn).to(C.dtype.element_ty) - acc = acc + bias[None, :] - - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) - - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - divfactor = 1.0 / (127.0 * 127.0) - - has_bias = 0 if bias is None else 1 - - device = a.device - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # allocates output - c = torch.empty((M, N), device=device, dtype=torch.float16) - # accumulator types - ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # launch int8_matmul_rowwise_dequantize kernel - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) - _int8_matmul_rowwise_dequantize[grid]( - a, - b, - c, - bias, - state_x, - state_w, - M, - N, - K, - divfactor, - has_bias, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - GROUP_M=8, - ACC_TYPE=ACC_TYPE, - ) - return c diff --git a/bitsandbytes/triton/matmul_perf_model.py b/bitsandbytes/triton/matmul_perf_model.py deleted file mode 100644 index e843a3a39..000000000 --- a/bitsandbytes/triton/matmul_perf_model.py +++ /dev/null @@ -1,211 +0,0 @@ -# Adapted from 
https://github.com/triton-lang/kernels/blob/eeeebdd8be7d13629de22d600621e6234057eed3/kernels/matmul_perf_model.py -# https://github.com/triton-lang/kernels is licensed under the MIT License. - -import functools -import heapq - -import torch - -from triton import cdiv -from triton.runtime import driver -from triton.testing import ( - get_dram_gbps, - get_max_simd_tflops, - get_max_tensorcore_tflops, - nvsmi, -) - - -@functools.lru_cache -def get_clock_rate_in_khz(): - try: - return nvsmi(["clocks.max.sm"])[0] * 1e3 - except FileNotFoundError: - import pynvml - - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 - - -def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): - """return compute throughput in TOPS""" - total_warps = num_ctas * min(num_warps, 4) - num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = ( - min(num_subcores, total_warps) - / num_subcores - * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) - ) - return tflops - - -def get_simd_tflops(device, num_ctas, num_warps, dtype): - """return compute throughput in TOPS""" - total_warps = num_ctas * min(num_warps, 4) - num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = ( - min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) - ) - return tflops - - -def get_tflops(device, num_ctas, num_warps, dtype): - capability = torch.cuda.get_device_capability(device) - if capability[0] < 8 and dtype == torch.float32: - return get_simd_tflops(device, num_ctas, num_warps, dtype) - return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) - - -def estimate_matmul_time( - # backend, device, - num_warps, - num_stages, # - A, - B, - C, # - M, - N, - K, # - BLOCK_M, - BLOCK_N, - BLOCK_K, - SPLIT_K, # - debug=False, - **kwargs, # -): - """return estimated running time in ms - = max(compute, loading) + store""" - device = torch.cuda.current_device() - dtype = A.dtype - dtsize = A.element_size() - - num_cta_m = cdiv(M, BLOCK_M) - num_cta_n = cdiv(N, BLOCK_N) - num_cta_k = SPLIT_K - num_ctas = num_cta_m * num_cta_n * num_cta_k - - # If the input is smaller than the block size - M, N = max(M, BLOCK_M), max(N, BLOCK_N) - - # time to compute - total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS - tput = get_tflops(device, num_ctas, num_warps, dtype) - compute_ms = total_ops / tput - - # time to load data - num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] - active_cta_ratio = min(1, num_ctas / num_sm) - active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate - active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% - dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s - l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
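To make the heuristic concrete, here is a toy evaluation of the formula stated in the docstring above, `max(compute, loading) + store`; all numbers are invented for illustration only:

```py
compute_ms = 0.80  # time to execute the FLOPs at the estimated tensor-core throughput
load_ms = 1.20     # time to stream A and B through DRAM/L2 at the estimated bandwidth
store_ms = 0.15    # time to write C (plus zeroing/reduction when SPLIT_K > 1)
total_ms = max(compute_ms, load_ms) + store_ms  # 1.35 ms: this case is bandwidth-bound
```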
- # assume 80% of (following) loads are in L2 cache - load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) - load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) - load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) - load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) - # total - total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB - total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) - # loading time in ms - load_ms = total_dram / dram_bw + total_l2 / l2_bw - - # estimate storing time - store_bw = dram_bw * 0.6 # :o - store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB - if SPLIT_K == 1: - store_ms = store_c_dram / store_bw - else: - reduce_bw = store_bw - store_ms = store_c_dram / reduce_bw - # c.zero_() - zero_ms = M * N * 2 / (1024 * 1024) / store_bw - store_ms += zero_ms - - total_time_ms = max(compute_ms, load_ms) + store_ms - if debug: - print( - f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " - f"loading time: {load_ms}ms, store time: {store_ms}ms, " - f"Activate CTAs: {active_cta_ratio * 100}%" - ) - return total_time_ms - - -def early_config_prune(configs, named_args, **kwargs): - device = torch.cuda.current_device() - capability = torch.cuda.get_device_capability() - # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages - dtsize = named_args["A"].element_size() - dtype = named_args["A"].dtype - - # 1. make sure we have enough smem - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( - kw["BLOCK_M"], - kw["BLOCK_N"], - kw["BLOCK_K"], - config.num_stages, - ) - - max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] - required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize - if required_shared_memory <= max_shared_memory: - pruned_configs.append(config) - configs = pruned_configs - - # Some dtypes do not allow atomic_add - if dtype not in [torch.float16, torch.float32]: - configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] - - # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) - configs_map = {} - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = ( - kw["BLOCK_M"], - kw["BLOCK_N"], - kw["BLOCK_K"], - kw["SPLIT_K"], - config.num_warps, - config.num_stages, - ) - - key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) - if key in configs_map: - configs_map[key].append((config, num_stages)) - else: - configs_map[key] = [(config, num_stages)] - - pruned_configs = [] - for k, v in configs_map.items(): - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k - if capability[0] >= 8: - # compute cycles (only works for ampere GPUs) - mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) - mma_cycles = mmas / min(4, num_warps) * 8 - - ldgsts_latency = 300 # Does this matter? 
- optimal_num_stages = ldgsts_latency / mma_cycles - - # nearest stages, prefer large #stages - nearest = heapq.nsmallest( - 2, - v, - key=lambda x: ( - 10 + abs(x[1] - optimal_num_stages) - if (x[1] - optimal_num_stages) < 0 - else x[1] - optimal_num_stages - ), - ) - - for n in nearest: - pruned_configs.append(n[0]) - else: # Volta & Turing only supports num_stages <= 2 - random_config = v[0][0] - random_config.num_stages = 2 - pruned_configs.append(random_config) - return pruned_configs diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py deleted file mode 100644 index b8eeffd0c..000000000 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ /dev/null @@ -1,75 +0,0 @@ -import math - -import torch - -from bitsandbytes.triton.triton_utils import is_triton_available - -if not is_triton_available(): - - def quantize_columnwise_and_transpose(x: torch.Tensor): - return None -else: - import triton - import triton.language as tl - - # This kernel does fused columnwise quantization and transpose. - - # TODO: autotune this better. - @triton.autotune( - configs=[ - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_stages=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=16, num_warps=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=["n_elements"], - ) - @triton.jit - def _quantize_columnwise_and_transpose( - x_ptr, - output_ptr, - output_maxs, - n_elements, - M: tl.constexpr, - N: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - P2: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid - p2_arange = tl.arange(0, P2) - p2_arange_mask = p2_arange < M - arange = p2_arange * N - offsets = block_start + arange - x = tl.load(x_ptr + offsets, mask=p2_arange_mask) - abs_x = tl.abs(x) - max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127.0 * (x / max_val)) - - new_start = pid * M - new_offsets = new_start + p2_arange - tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) - tl.store(output_maxs + pid, max_val) - - def quantize_columnwise_and_transpose(x: torch.Tensor): - M, N = x.shape - output = torch.empty(N, M, device=x.device, dtype=torch.int8) - output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) - - P2 = int(2 ** (math.ceil(math.log2(M)))) - - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) - return output, output_maxs diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py deleted file mode 100644 index f35bdd304..000000000 --- a/bitsandbytes/triton/quantize_global.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch - -from bitsandbytes.triton.triton_utils import is_triton_available - -if not is_triton_available(): - - def quantize_global_transpose(input): - return None - - def quantize_global(x: torch.Tensor): - return None -else: - import triton - import triton.language as tl - - # global quantize - 
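For contrast with the row-wise scheme, the global quantize kernel deleted here uses a single absmax for the whole tensor. A short illustrative sketch (the function name is ours):

```py
import torch

def quantize_global_reference(x: torch.Tensor):
    absmax = x.abs().max()
    # Matches `llrint(127 * x * absmax_inv)` in the kernel below: one scale per tensor.
    return torch.round(127.0 * (x / absmax)).to(torch.int8), absmax
```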
@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), - triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), - ], - key=["n_elements"], - ) - @triton.jit - def _quantize_global( - x_ptr, - absmax_inv_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) - tl.store(output_ptr + offsets, output, mask=mask) - - def quantize_global(x: torch.Tensor): - absmax = x.abs().max().unsqueeze(0) - absmax_inv = 1.0 / absmax - output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - _quantize_global[grid](x, absmax_inv, output, n_elements) - return output, absmax - - # global quantize and transpose - @triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), - # ... - ], - key=["M", "N"], - ) - @triton.jit - def _quantize_global_transpose( - A, - absmax_inv_ptr, - B, - stride_am, - stride_an, - stride_bn, - stride_bm, - M, - N, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - GROUP_M: tl.constexpr, - ): - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // group_size - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) - mask = (rm < M)[:, None] & (rn < N)[None, :] - a = tl.load(A, mask=mask) - absmax_inv = tl.load(absmax_inv_ptr) - - # rematerialize to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - - output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) - - tl.store(B, output, mask=mask) - - def quantize_global_transpose(input): - absmax = input.abs().max().unsqueeze(0) - absmax_inv = 1.0 / absmax - M, N = input.shape - out = torch.empty(N, M, device="cuda", dtype=torch.int8) - - assert out.size(0) == N and out.size(1) == M - assert input.stride(0) == 1 or input.stride(1) == 1 - assert out.stride(0) == 1 or out.stride(1) == 1 - - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) - _quantize_global_transpose[grid]( - input, - absmax_inv, - out, - input.stride(0), - input.stride(1), - out.stride(0), - out.stride(1), - M, - N, - ) - return out, absmax diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py deleted file mode 100644 index f92ace02c..000000000 --- a/bitsandbytes/triton/quantize_rowwise.py +++ /dev/null @@ -1,67 +0,0 @@ -import math - -import torch - -from bitsandbytes.triton.triton_utils import is_triton_available - -if not is_triton_available(): - - def quantize_rowwise(x: torch.Tensor): - return None -else: - import triton - import triton.language as tl - - # rowwise quantize - - # TODO: autotune this better. 
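The counterpart to the dequantize sketch earlier: row-wise quantization scales each row so its largest magnitude maps to 127, then rounds. A minimal PyTorch reference (the name is ours; like the kernel, it assumes no all-zero rows):

```py
import torch

def quantize_rowwise_reference(x: torch.Tensor):
    absmax = x.abs().amax(dim=1, keepdim=True)                 # per-row absolute maximum
    x_int8 = torch.round(127.0 * (x / absmax)).to(torch.int8)  # llrint(127 * x / max_val)
    return x_int8, absmax.squeeze(1).to(torch.float16)         # int8 rows plus per-row scales
```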
- @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=["n_elements"], - ) - @triton.jit - def _quantize_rowwise( - x_ptr, - output_ptr, - output_maxs, - n_elements, - BLOCK_SIZE: tl.constexpr, - P2: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - arange = tl.arange(0, P2) - offsets = block_start + arange - row_mask = arange < BLOCK_SIZE - x = tl.load(x_ptr + offsets, mask=row_mask) - - abs_x = tl.abs(x) - max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127.0 * (x / max_val)) - tl.store(output_ptr + offsets, output, mask=row_mask) - tl.store(output_maxs + pid, max_val) - - def quantize_rowwise(x: torch.Tensor): - output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) - output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) - - P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) - - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (x.shape[0],) - _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) - return output, output_maxs diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py deleted file mode 100644 index f6bedd8cd..000000000 --- a/bitsandbytes/triton/triton_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -import functools - - -@functools.lru_cache(None) -def is_triton_available(): - try: - from torch.utils._triton import has_triton, has_triton_package - - return has_triton_package() and has_triton() - except Exception: - return False diff --git a/docs/source/optimizers.mdx b/docs/source/optimizers.mdx index 7d04f82b1..3e5f6a2aa 100644 --- a/docs/source/optimizers.mdx +++ b/docs/source/optimizers.mdx @@ -30,12 +30,12 @@ import bitsandbytes as bnb adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) ``` -Other parameters you can configure include the learning rate (`lr`), the decay rates (`betas`), the number of bits of the optimizer state (`optim_bits`), and percentile clipping (`percentile_clipping`) which can increase stability. For example, to initialize a 32-bit [`~bitsandbytes.optim.Adam`] optimizer with 5th percentile clipping: +Other parameters you can configure include the learning rate (`lr`), the decay rates (`betas`), and the number of bits of the optimizer state (`optim_bits`). 
For example, to initialize a 32-bit [`~bitsandbytes.optim.Adam`] optimizer: ```py import bitsandbytes as bnb -adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5) +adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32) ``` ## Optimize unstable parameters diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 7134925c1..b942e6ab7 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -22,8 +22,8 @@ @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) @pytest.mark.parametrize( "funcs", - [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], - ids=["func=matmul", "func=switchback_bnb"], + [(torch.matmul, bnb.matmul)], + ids=["func=matmul"], ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @@ -34,10 +34,6 @@ def test_matmullt( device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias ): if device != "cuda": - if funcs[1] == bnb.research.switchback_bnb: - # TODO: Deprecate/remove? - pytest.skip("switchback_bnb only works on CUDA.") - if req_grad[1]: # This will be deprecated for CUDA in the future. We don't expect # this to work on any other device. diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py deleted file mode 100644 index feb17c966..000000000 --- a/tests/test_deprecated.py +++ /dev/null @@ -1,175 +0,0 @@ -import pytest -import torch - -import bitsandbytes as bnb -from bitsandbytes import functional as F -from tests.helpers import BOOLEAN_TRIPLES, describe_dtype, get_test_dims, id_formatter -from tests.test_autograd import TRANSPOSE_VALS - - -@pytest.mark.deprecated -def test_dynamic_quantization(): - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize(A1) - A2 = F.dequantize(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diff.mean().item() < 0.0135 - print(sum(diffs) / len(diffs)) - print(sum(reldiffs) / len(reldiffs)) - - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize(A1) - A2 = F.dequantize(C, S) - diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - assert diff < 0.004 - - -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) -@pytest.mark.deprecated -def test_percentile_clipping(gtype): - gnorm_vec1 = torch.zeros(100, device="cuda") - gnorm_vec2 = torch.zeros(100, device="cuda") - n = 4 - step = 0 - percentile = 5 - for i in range(20): - step += 1 - g = torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) - assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 - - gnorm2 = torch.norm(g.float()) - if step == 1: - gnorm_vec1[:] = gnorm2 - else: - gnorm_vec1[step % 100] = gnorm2 - - vals, _ = torch.sort(gnorm_vec1) - clip1 = vals[percentile] - - torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) - torch.testing.assert_close(clip1, clip2) - torch.testing.assert_close(gnorm1, gnorm2) - - -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], 
ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
-@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
-@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
-@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize(
-    "funcs",
-    [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)],
-    ids=["matmul_fp8_mixed", "matmul_fp8_global"],
-)
-@pytest.mark.deprecated
-@pytest.mark.skip("Deprecated functionality, to be removed.")
-def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
-    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
-    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    req_grad = list(req_grad)
-    req_grad[2] = False
-
-    for i in range(3):
-        # normal multiply
-        if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
-            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
-
-            torch.nn.init.xavier_uniform_(B)
-
-            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
-            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
-
-            if not transpose[0] and transpose[1]:
-                out_torch = funcs[0](A, B.t())
-                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
-            elif not transpose[0] and not transpose[1]:
-                out_torch = funcs[0](A, B)
-                out_bnb = funcs[1](A, B, fw_code, bw_code)
-
-            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
-
-            n = out_bnb.numel()
-            err = torch.abs(out_bnb - out_torch).float().mean().item()
-            if n > 0:
-                assert err < 0.115
-                # assert err < 0.20
-            if any(req_grad):
-                out_bnb.data.copy_(out_torch)
-                torch.cuda.synchronize()
-                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
-                loss_bnb.backward()
-                gradA1 = A.grad
-                gradB1 = B.grad
-                A.grad = None
-                B.grad = None
-
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
-                loss_torch.backward()
-                gradA2 = A.grad
-                gradB2 = B.grad
-                A.grad = None
-                B.grad = None
-
-            if req_grad[0]:
-                torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
-
-            if req_grad[1]:
-                n = gradB1.numel()
-                if dim2 > 0:
-                    assert torch.abs(gradB1).sum() > 0.0
-                    assert torch.abs(gradB2).sum() > 0.0
-                else:
-                    assert torch.abs(gradB1).sum() == 0.0
-                    assert torch.abs(gradB2).sum() == 0.0
-                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-
-                assert (idx == 0).sum().item() <= n * 0.1
-                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                assert (idx == 0).sum().item() <= n * 0.02
-                grad_err = (gradB1 - gradB2).abs().mean()
-                assert grad_err.item() < 0.003
-                torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
-
-
-@pytest.mark.deprecated
-def test_fp8linear():
-    b = 10
-    h = 1024
-    inp = torch.randn(b, h).cuda()
-    fp32 = torch.nn.Linear(h, h * 2).cuda()
-    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
-    fp32b = torch.nn.Linear(h * 2, h).cuda()
-    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
-
-    fp8.weight.data.copy_(fp32.weight.data)
-    fp8.bias.data.copy_(fp32.bias.data)
-    fp8b.weight.data.copy_(fp32b.weight.data)
-    fp8b.bias.data.copy_(fp32b.bias.data)
-
-    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
-    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
-
-    err = (a - b).abs().mean()
-
-    a.mean().backward()
-    b.mean().backward()
-
-    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
-    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
-
-    assert err < 0.05
-    assert graderr < 0.00002
-    assert bgraderr < 0.00002
diff --git a/tests/test_functional.py b/tests/test_functional.py
index d2e3f0847..4670847ff 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -706,7 +706,6 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias):

     @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
     @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
-    @pytest.mark.deprecated
     def test_int8_double_quant(self, dim1, dim2):
         for i in range(k):
             A = torch.randn(dim1, dim2, device="cuda").half()
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 190d9a206..6216639b0 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -49,16 +49,16 @@ def rm_path(path):
 )

 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
-str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
+str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx))
 str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
 str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
 str2optimizers["paged_adam8bit_blockwise"] = (
     torch.optim.Adam,
-    lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
+    lambda pxx: bnb.optim.PagedAdam8bit(pxx),
 )
 str2optimizers["paged_adamw8bit_blockwise"] = (
     torch.optim.AdamW,
-    lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
+    lambda pxx: bnb.optim.PagedAdamW8bit(pxx),
 )

 str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
@@ -90,25 +90,25 @@ def rm_path(path):

 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
-str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
-str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
+str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx))
+str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx))

 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
+    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9),
 )
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
+    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9),
 )

 str2optimizers["rmsprop"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
+    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9),
 )
 str2optimizers["rmsprop8bit_blockwise"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
+    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9),
 )

 str2statenames = {}
@@ -462,94 +462,6 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
                 torch_optimizer.state[p1][name1].copy_(s.data)


-@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits"))
-@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-@pytest.mark.deprecated
-def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
-    if dim1 == 1 and dim2 == 1:
-        return
-    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
-    beta1 = 0.9
-    beta2 = 0.999
-    lr = 0.001
-    eps = 1e-8
-    p1 = p1.cuda()
-    p2 = p1.clone()
-    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
-    adam2 = bnb.optim.Adam(
-        [p2],
-        lr,
-        (beta1, beta2),
-        eps,
-        optim_bits=optim_bits,
-        percentile_clipping=5,
-    )
-
-    gnorm_vec = torch.zeros(100).cuda()
-    step = 0
-
-    for i in range(50):
-        step += 1
-        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
-        g2 = g1.clone()
-        p2.grad = g2
-
-        _current_gnorm, _clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
-        g1 = (g1.float() * gnorm_scale).to(gtype)
-        p1.grad = g1
-
-        adam1.step()
-        adam2.step()
-
-        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
-        if optim_bits == 32:
-            torch.testing.assert_close(p1, p2)
-            torch.testing.assert_close(
-                adam1.state[p1]["state1"],
-                adam2.state[p2]["state1"],
-                atol=5e-5,
-                rtol=1e-4,
-            )
-            torch.testing.assert_close(
-                adam1.state[p1]["state2"],
-                adam2.state[p2]["state2"],
-                atol=5e-5,
-                rtol=1e-4,
-            )
-        elif optim_bits == 8:
-            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
-            torch.testing.assert_close(
-                adam1.state[p1]["state1"],
-                adam2.state[p2]["state1"],
-                atol=2,
-                rtol=1e-3,
-            )
-            torch.testing.assert_close(
-                adam1.state[p1]["state2"],
-                adam2.state[p2]["state2"],
-                atol=2,
-                rtol=1e-3,
-            )
-            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
-            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
-        if i % 10 == 0 and i > 0:
-            path = get_temp_dir()
-            torch.save(adam2.state_dict(), join(path, "opt.pt"))
-            del adam2
-            adam2 = None
-            adam2 = bnb.optim.Adam(
-                [p2],
-                lr,
-                (beta1, beta2),
-                eps,
-                optim_bits=optim_bits,
-                percentile_clipping=5,
-            )
-            adam2.load_state_dict(torch.load(join(path, "opt.pt")))
-
-
 optimizer_names_benchmark = [
     "adam8bit_blockwise",
     "paged_adam8bit_blockwise",
diff --git a/tests/test_triton.py b/tests/test_triton.py
deleted file mode 100644
index b245e534a..000000000
--- a/tests/test_triton.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import pytest
-import torch
-
-from bitsandbytes.nn import Linear8bitLt
-from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
-from bitsandbytes.triton.triton_utils import is_triton_available
-from tests.helpers import TRUE_FALSE
-
-
-@pytest.mark.skipif(
-    not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
-    reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
-)
-@pytest.mark.deprecated
-@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
-def test_switchback(vector_wise_quantization):
-    for dim in [83]:
-        for batch in [13]:
-            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            switchback = (
-                SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
-            )
-            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
-            switchback.weight.data.copy_(standard.weight)
-            switchback.bias.data.copy_(standard.bias)
-            baseline.weight.data.copy_(standard.weight)
-            baseline.bias.data.copy_(standard.bias)
-
-            x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
-            x2 = x1.clone().detach().requires_grad_(True)
-            x3 = x1.clone().detach().requires_grad_(True)
-
-            out_standard = standard(x1)
-            (2**10 * out_standard.abs().mean()).backward()
-
-            print(x2.dtype)
-            out_sb = switchback(x2)
-            (2**10 * out_sb.abs().mean()).backward()
-
-            out_baseline = baseline(x3)
-            (2**10 * out_baseline.abs().mean()).backward()
-
-            err_sb = (out_standard - out_sb).abs().mean()
-            err_baseline = (out_standard - out_baseline).abs().mean()
-            print("OUT", err_sb, err_baseline)
-            assert err_sb < 2 * err_baseline
-
-            err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
-            err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
-
-            print("GW2", err_sb, err_baseline)
-            assert err_sb < 2 * err_baseline
-
-            err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
-            err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
-
-            print("GW1", err_sb, err_baseline)
-            assert err_sb < 2 * err_baseline
-
-            err_sb = (x1.grad - x2.grad).abs().mean()
-            err_baseline = (x1.grad - x3.grad).abs().mean()
-
-            print("GX1", err_sb, err_baseline)
-            assert err_sb < 2 * err_baseline