How to Add a New Optimizer
Marin builds on Levanter for its training code, so any training changes made in Levanter are automatically available in Marin. However, thanks to Levanter's support for Optax, you can also add new optimizers directly in Marin without needing to merge a pull request upstream.
In this guide, we’ll walk through adding an AdaMax optimizer as an example.
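AdaMax, introduced alongside Adam by Kingma and Ba, is a variant of Adam that replaces Adam's second-moment estimate with an exponentially weighted infinity norm. For reference, the update that the optimizer built in this guide applies at step $t$ can be sketched as follows. The notation is ours, not the guide's. $\hat g_t$ is the gradient after global-norm clipping. $\beta_1$, $\beta_2$, $\epsilon$ and $\lambda$ correspond to `beta1`, `beta2`, `epsilon` and `weight_decay`. $\eta_t$ is the scheduled learning rate, and $m_0 = u_0 = 0$.

This formulation assumes two things. First, it follows our reading of Optax's `scale_by_adamax`, which adds $\epsilon$ inside the max rather than in the denominator. Second, it ignores the weight-decay mask, so every parameter is decayed.

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\,\hat g_t \\
u_t &= \max\!\left(\beta_2\, u_{t-1},\ |\hat g_t| + \epsilon\right) \\
\theta_t &= \theta_{t-1} - \eta_t \left( \frac{m_t}{(1 - \beta_1^{t})\,u_t} + \lambda\,\theta_{t-1} \right)
\end{aligned}
$$

Only the first moment is bias-corrected. The running max $u_t$ does not need bias correction because it is not pulled toward zero by its initialization the way an average is.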
Steps to Add an Optimizer
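- Import Optax and `OptimizerConfig`: `import optax` and `from levanter.optim import OptimizerConfig`.

- Define a new optimizer by subclassing `OptimizerConfig` and adding optimizer-specific parameters as dataclass fields:

    ```python
    from dataclasses import dataclass

    @dataclass
    class AdamaxConfig(OptimizerConfig):
        beta1: float = 0.9  # decay rate of the first-moment (gradient mean) estimate
        beta2: float = 0.95  # decay rate of the infinity-norm estimate
        epsilon: float = 1e-8  # small constant for numerical stability
        max_grad_norm: float | None = 1.0  # global-norm clipping threshold; None disables clipping
    ```

    `OptimizerConfig` has a number of fields that are common to all optimizers. These include `weight_decay`, `learning_rate`, `lr_schedule`, `min_lr_ratio`, `warmup`, `decay`, `rewarmup`, `cycles`, and `cycle_length`. See the documentation for the `OptimizerConfig` class for further details about these fields.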
- Implement the `build()` method to define the optimizer's update rule. This method should return an Optax optimizer. Optax lets you define components that are gradient transformations and then chain them together to obtain the final update rule.

    ```python
    def build(self, num_train_steps):
        print(f"Building optimizer: {self.__class__.__name__}")

        # Register the optimizer class if not already registered
        try:
            OptimizerConfig.register_subclass("adamax")(AdamaxConfig)
        except ValueError:
            pass

        def _optimizer(learning_rate):
            components = []

            # Add gradient clipping if specified
            if self.max_grad_norm is not None:
                components.append(optax.clip_by_global_norm(self.max_grad_norm))

            # Add the AdaMax rescaling. Use scale_by_adamax rather than optax.adamax:
            # optax.adamax requires a learning rate and applies it itself, which would
            # clash with the explicit optax.scale(-learning_rate) step below.
            components.append(
                optax.scale_by_adamax(
                    b1=self.beta1,
                    b2=self.beta2,
                    eps=self.epsilon,
                )
            )

            # Add decoupled weight decay if specified
            if self.weight_decay > 0:
                components.append(
                    optax.add_decayed_weights(
                        self.weight_decay, self.build_weight_decay_mask()
                    )
                )

            # Scale by the negative learning rate so the update descends the loss
            components.append(optax.scale(-learning_rate))

            # Chain all components together
            return optax.chain(*components)

        # Inject hyperparameters using the learning rate scheduler
        return optax.inject_hyperparams(_optimizer)(
            learning_rate=self.lr_scheduler(num_train_steps)
        )
    ```

    `optax.inject_hyperparams` is an Optax wrapper that passes schedules (or stateful hyperparameters) into the optimizer. It also makes the current learning rate part of the optimizer state, which lets us log it in the tracker.

    Note: You should also register your optimizer class with an identifier, as shown above. Alternatively, apply `OptimizerConfig.register_subclass("adamax")` as a class decorator on `AdamaxConfig`. This registers the class when it is defined, rather than the first time `build()` runs.
- Use the optimizer in your training script. You can instantiate it and pass it directly into your training config:

    ```python
    optimizer = AdamaxConfig(
        beta1=0.9,
        beta2=0.95,
        epsilon=1e-8,
        max_grad_norm=1.0,
        weight_decay=0.1,
        learning_rate=1e-4,  # the OptimizerConfig field is learning_rate, not lr
    )
    ```

    and use it in `TrainLmConfig`:

    ```python
    from levanter.main.train_lm import TrainLmConfig

    trainer_config = TrainLmConfig(
        # ...
        optimizer=optimizer,
        # ...
    )
    ```

    Or inside a `SimpleTrainConfig`:

    ```python
    train_config = SimpleTrainConfig(
        # ...
        optimizer_config=AdamaxConfig(
            beta1=0.9,
            beta2=0.95,
            epsilon=1e-8,
            max_grad_norm=1.0,
            weight_decay=0.1,
            learning_rate=1e-4,
        ),
        # ...
    )
    ```

    Then pass it into `default_train`, which sets the optimizer config correctly in the training step.
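To see what the chain assembled in `build()` does to actual numbers, here is a dependency-free sketch in plain Python. It mirrors each component in the same order: global-norm clipping, the AdaMax rescaling, decoupled weight decay, and scaling by the negative learning rate. It is an illustration under stated assumptions, not Optax's implementation. Parameters are flat lists of floats, and every parameter is decayed because there is no weight-decay mask. `toy_warmup_schedule` is a hypothetical stand-in for Levanter's `lr_scheduler` and does not reproduce its actual shape. All function names here are our own.

```python
import math
from dataclasses import dataclass
from typing import List, Optional

Vector = List[float]


@dataclass
class AdamaxState:
    count: int  # number of updates applied so far
    mu: Vector  # first moment: exponential moving average of (clipped) gradients
    nu: Vector  # infinity-norm moment: decayed running max of |gradient|


def init_adamax_state(params: Vector) -> AdamaxState:
    return AdamaxState(count=0, mu=[0.0] * len(params), nu=[0.0] * len(params))


def clip_by_global_norm(updates: Vector, max_norm: float) -> Vector:
    """Rescale the whole vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(u * u for u in updates))
    if norm <= max_norm:
        return list(updates)
    return [u * max_norm / norm for u in updates]


def scale_by_adamax(updates: Vector, state: AdamaxState, b1: float, b2: float, eps: float):
    """AdaMax rescaling only; like the chained Optax component, it applies no learning rate."""
    count = state.count + 1
    mu = [b1 * m + (1 - b1) * g for m, g in zip(state.mu, updates)]
    nu = [max(b2 * v, abs(g) + eps) for v, g in zip(state.nu, updates)]
    correction = 1 - b1**count  # only the first moment is bias-corrected
    scaled = [(m / correction) / v for m, v in zip(mu, nu)]
    return scaled, AdamaxState(count=count, mu=mu, nu=nu)


def toy_warmup_schedule(step: int, peak_lr: float = 1e-2, warmup_steps: int = 10) -> float:
    """Hypothetical linear warmup to peak_lr; NOT Levanter's lr_scheduler."""
    return peak_lr * min(1.0, (step + 1) / warmup_steps)


def adamax_chain_step(
    params: Vector,
    grads: Vector,
    state: AdamaxState,
    learning_rate: float,
    beta1: float = 0.9,
    beta2: float = 0.95,
    epsilon: float = 1e-8,
    max_grad_norm: Optional[float] = 1.0,
    weight_decay: float = 0.1,
):
    updates = grads
    if max_grad_norm is not None:  # 1. global-norm clipping
        updates = clip_by_global_norm(updates, max_grad_norm)
    updates, state = scale_by_adamax(updates, state, beta1, beta2, epsilon)  # 2. AdaMax
    if weight_decay > 0:  # 3. decoupled weight decay, added before the lr scaling
        updates = [u + weight_decay * p for u, p in zip(updates, params)]
    updates = [-learning_rate * u for u in updates]  # 4. scale by -learning_rate
    new_params = [p + u for p, u in zip(params, updates)]
    return new_params, state
```

Here is an example with `params = [1.0, -2.0]`, `grads = [0.5, -0.25]`, and `learning_rate=0.01`. A single `adamax_chain_step` from a fresh `init_adamax_state(params)` moves the parameters to roughly `[0.989, -1.988]`. On the first step, the bias-corrected AdaMax update is close to the sign of the gradient. Weight decay then adds `0.1 * param` before the learning-rate scaling. Because decay is added after the adaptive rescaling, it is decoupled from AdaMax's per-coordinate scaling but is still multiplied by the learning rate. In a training loop, you would pass `toy_warmup_schedule(step)` (or a real schedule) as `learning_rate` on each step, which is the role `inject_hyperparams` plays in the Optax version.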
That’s it! You can now define new optimizers in this manner and train models using them, all within Marin. For optimizers that are widely useful or “standard,” consider submitting a pull request to Levanter.
Further reading: