diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 9c20f9376..ee1781a8b 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -64,9 +64,9 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None) parameters (`torch.Tensor` or `list(torch.Tensors)`): The input parameters. key (`str`): - The hyperparamter to override. + The hyperparameter to override. value: - The hyperparameter values. + The hyperparameter value. key_value_dict (`dict`): A dictionary with multiple key-values to override. @@ -115,7 +115,7 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False): Base 8-bit optimizer class. Arguments: - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. @@ -291,7 +291,7 @@ def step(self, closure=None): self.update_step(group, p, gindex, pindex) torch.cuda.synchronize() if self.is_paged: - # all paged operation are asynchronous, we need + # all paged operations are asynchronous, we need # to sync to make sure all tensors are in the right state torch.cuda.synchronize() @@ -371,7 +371,7 @@ def __init__( Arguments: optimizer_name (`str`): The name of the optimizer. - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -428,7 +428,6 @@ def __init__( if args is None: args = {} args["optim_bits"] = optim_bits - args["percentile_clipping"] = 100 args["min_8bit_size"] = min_8bit_size args["percentile_clipping"] = percentile_clipping args["block_wise"] = block_wise @@ -613,7 +612,7 @@ def __init__( Arguments: optimizer_name (`str`): The name of the optimizer. - params (`torch.tensor`): + params (`torch.Tensor`): The input parameters to optimize. lr (`float`, defaults to 1e-3): The learning rate. @@ -655,7 +654,6 @@ def __init__( if args is None: args = {} args["optim_bits"] = optim_bits - args["percentile_clipping"] = 100 args["min_8bit_size"] = min_8bit_size args["percentile_clipping"] = percentile_clipping args["block_wise"] = block_wise