Training Configuration

This is the API documentation for the training configuration classes.

SimpleTrainConfig dataclass

Attributes
resources instance-attribute
resources: ResourceConfig
train_batch_size instance-attribute
train_batch_size: int | IntSchedule

The batch size for training. If an IntSchedule is provided, the batch size will be varied according to the schedule.

num_train_steps instance-attribute
num_train_steps: int
learning_rate instance-attribute
learning_rate: float
train_seq_len class-attribute instance-attribute
train_seq_len: int | None = None
data_seed class-attribute instance-attribute
data_seed: int | None = None
weight_decay class-attribute instance-attribute
weight_decay: float | None = None
beta1 class-attribute instance-attribute
beta1: float | None = None
beta2 class-attribute instance-attribute
beta2: float | None = None
epsilon class-attribute instance-attribute
epsilon: float | None = None
max_grad_norm class-attribute instance-attribute
max_grad_norm: float | None = None
warmup class-attribute instance-attribute
warmup: float | None = None
decay class-attribute instance-attribute
decay: float | None = None
rewarmup class-attribute instance-attribute
rewarmup: float | None = None

The rewarmup parameter is used to re-warm the learning rate after a decay cycle.

lr_schedule class-attribute instance-attribute
lr_schedule: str | None = None
min_lr_ratio class-attribute instance-attribute
min_lr_ratio: float | None = None
cycle_length class-attribute instance-attribute
cycle_length: int | list[int] | None = None
z_loss_weight class-attribute instance-attribute
z_loss_weight: float | None = None
ema_beta class-attribute instance-attribute
ema_beta: float | None = None

Beta (decay factor) for the exponential moving average (EMA) of the model weights.
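
As a rough illustration of what ema_beta controls, the sketch below applies the standard EMA update, ema ← beta · ema + (1 − beta) · param, to a dict of scalar stand-ins for parameters. The helper `ema_update` is hypothetical; Levanter's actual implementation operates on JAX pytrees and may differ in details such as bias correction.

```python
def ema_update(ema_params: dict[str, float], params: dict[str, float], beta: float) -> dict[str, float]:
    """Hypothetical helper: return beta * old_ema + (1 - beta) * current for every parameter."""
    return {name: beta * ema_params[name] + (1.0 - beta) * value for name, value in params.items()}


# Start the EMA at 0.0 and feed it a constant parameter value of 1.0 for three steps.
ema = {"w": 0.0}
for _ in range(3):
    ema = ema_update(ema, {"w": 1.0}, beta=0.9)
print(round(ema["w"], 3))  # 0.271
```

Larger values of beta average over a longer window (roughly 1 / (1 − beta) steps), so the EMA reacts more slowly to recent updates.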

skip_bad_steps class-attribute instance-attribute
skip_bad_steps: bool = False

If True, skips steps where the loss or grad is significantly higher than the historical mean.
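
This page does not define what "significantly higher" means. The sketch below assumes one plausible rule: flag a value that sits more than `num_stds` standard deviations above the mean of recent history. The function name, the threshold, and the minimum-history requirement are all hypothetical and are not Levanter's actual criterion.

```python
import statistics


def is_bad_step(history: list[float], value: float, num_stds: float = 5.0, min_history: int = 10) -> bool:
    """Hypothetical rule: flag `value` if it exceeds mean + num_stds * stdev of `history`."""
    if len(history) < min_history:
        return False  # not enough history to judge yet
    mean = statistics.fmean(history)
    std = statistics.pstdev(history)
    return value > mean + num_stds * std


losses = [2.0, 2.1, 1.9, 2.0, 2.05, 1.95, 2.0, 2.1, 1.9, 2.0]
print(is_bad_step(losses, 2.1))   # False: within normal variation
print(is_bad_step(losses, 10.0))  # True: a loss spike
```

A trainer using a rule like this would discard the update for a flagged step rather than apply it, and would typically keep flagged values out of the history so a single spike does not inflate the threshold.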

steps_per_eval class-attribute instance-attribute
steps_per_eval: int | None = None

How often (in steps) to compute validation losses.

steps_per_export class-attribute instance-attribute
steps_per_export: int | None = None

How often to keep a permanent checkpoint. None (default) keeps only the final checkpoint; rolling temporary checkpoints are still written for resumption.

steps_per_task_eval class-attribute instance-attribute
steps_per_task_eval: int | None = None

How often (in steps) to run task evaluations.

steps_per_hf_export class-attribute instance-attribute
steps_per_hf_export: int | None = None

How often to export HuggingFace checkpoints. None (default) means match steps_per_export; -1 disables HuggingFace export.

hf_generation_eos_token_ids class-attribute instance-attribute
hf_generation_eos_token_ids: list[int] | None = None

EOS token IDs to write to generation_config.json. None means no generation config.

per_device_parallelism class-attribute instance-attribute
per_device_parallelism: int = -1

How many examples to process in parallel on each device. -1 (default) means train_batch_size/num_devices (no gradient accumulation). Set to a positive value to enable gradient accumulation.
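
The relationship between train_batch_size, the device count, and per_device_parallelism is simple arithmetic. The sketch below is a hypothetical helper that assumes the batch must split evenly into microbatches; it computes how many microbatches are accumulated per optimizer step.

```python
def grad_accum_steps(train_batch_size: int, num_devices: int, per_device_parallelism: int = -1) -> int:
    """Hypothetical helper: number of microbatches accumulated per optimizer step."""
    if per_device_parallelism == -1:
        # Default: each device takes its full share of the batch, so no accumulation.
        per_device_parallelism = train_batch_size // num_devices
    if per_device_parallelism <= 0:
        raise ValueError("per_device_parallelism must be positive (or -1)")
    microbatch_size = per_device_parallelism * num_devices  # examples processed at once
    if train_batch_size % microbatch_size != 0:
        raise ValueError("train_batch_size must be a multiple of per_device_parallelism * num_devices")
    return train_batch_size // microbatch_size


print(grad_accum_steps(32, 8))     # 1 (no accumulation)
print(grad_accum_steps(32, 8, 1))  # 4
```

The second call reproduces the example given for SimpleSFTConfig: 8 devices, a batch size of 32, and per_device_parallelism=1 give 4 accumulation steps.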

per_device_eval_parallelism class-attribute instance-attribute
per_device_eval_parallelism: int | None = None

Number of examples to evaluate in parallel on each device.

max_eval_batches class-attribute instance-attribute
max_eval_batches: int | None = None

Maximum number of batches to evaluate on. None means all batches.

initialize_from_checkpoint_path class-attribute instance-attribute
initialize_from_checkpoint_path: str | None = None

If set, training initializes from the checkpoint at this path. Otherwise, training starts from scratch.

initialize_from_hf class-attribute instance-attribute
initialize_from_hf: str | None = None

If set, training starts from the HuggingFace model at this path. Otherwise, training starts from scratch.

reset_data_loader_on_init class-attribute instance-attribute
reset_data_loader_on_init: bool = True

Pairs with initialize_from_checkpoint_path. If True, the data loader is reset so that it starts from step 0. Otherwise, it resumes from the step stored in the checkpoint.

allow_partial_checkpoint class-attribute instance-attribute
allow_partial_checkpoint: bool = False

Allow loading partial checkpoints. This is useful, for example, when converting an existing training run to use EMA.

int8 class-attribute instance-attribute
int8: bool = False

If True, enables int8 (quantized) training in Levanter.

pad_tokenizer_to_match_model class-attribute instance-attribute
pad_tokenizer_to_match_model: bool = False

If True, pad the tokenizer's vocab to match the model's vocab size by adding dummy tokens. Useful when the model checkpoint has a larger vocab than the tokenizer (e.g., Qwen models pad their vocab to be divisible by 4 for TPU efficiency).
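
The padding itself amounts to adding model_vocab_size − tokenizer_vocab_size placeholder tokens. The sketch below computes that list; the function name and the `<|dummy_i|>` naming scheme are hypothetical, and the vocab sizes in the example are illustrative only.

```python
def dummy_tokens_needed(tokenizer_vocab_size: int, model_vocab_size: int) -> list[str]:
    """Hypothetical helper: placeholder tokens that pad the tokenizer up to the model's vocab size."""
    if model_vocab_size < tokenizer_vocab_size:
        raise ValueError("model vocab is smaller than the tokenizer vocab; padding cannot fix this")
    # Hypothetical naming scheme for the filler tokens.
    return [f"<|dummy_{i}|>" for i in range(model_vocab_size - tokenizer_vocab_size)]


extra = dummy_tokens_needed(50_000, 50_304)
print(len(extra), extra[0])  # 304 <|dummy_0|>
```

With a Hugging Face tokenizer, a list like this could be passed to `tokenizer.add_tokens(...)`; whether the library pads in exactly this way is not stated here.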

optimizer_config class-attribute instance-attribute
optimizer_config: OptimizerConfig | None = None

Optimizer configuration to use. If not set, Adam will be used.

watch class-attribute instance-attribute
watch: WatchConfig = field(default_factory=WatchConfig)

Config for watching gradients, parameters, etc. Default is to log norms of gradients and parameters.

profiler class-attribute instance-attribute
profiler: ProfilerConfig = field(
    default_factory=ProfilerConfig
)

JAX profiler settings for training.

explicit_mesh_axes class-attribute instance-attribute
explicit_mesh_axes: bool = False

If True, build the device mesh with AxisType.Explicit axes.

Required for models that call jax.sharding.reshard(..., PartitionSpec(...)).

tensor_parallel_size class-attribute instance-attribute
tensor_parallel_size: int = 1

Size of the model (tensor parallel) axis. Values greater than 1 shard model weights and activations across multiple devices. Useful when batch_size < num_chips.
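
To see why a larger model axis helps when the batch is smaller than the chip count, assume a 2D mesh in which the model axis has size tensor_parallel_size and the remaining devices form the data axis. The sketch below computes that split; the helper and the axis names "data" and "model" are hypothetical, and the real mesh may use different or additional axes.

```python
def mesh_shape(num_devices: int, tensor_parallel_size: int = 1) -> dict[str, int]:
    """Hypothetical helper: split devices into a (data, model) grid."""
    if tensor_parallel_size < 1 or num_devices % tensor_parallel_size != 0:
        raise ValueError("tensor_parallel_size must be a positive divisor of num_devices")
    return {"data": num_devices // tensor_parallel_size, "model": tensor_parallel_size}


# 8 chips with tensor_parallel_size=4 leave only 2 data-parallel groups, so a batch of 2
# can still keep every device busy (each example is sharded across 4 devices).
print(mesh_shape(8, 4))  # {'data': 2, 'model': 4}
```

Under this assumption, the batch only needs to cover the data axis rather than every chip.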

env_vars class-attribute instance-attribute
env_vars: dict[str, str] | None = None

Environment variables to pass to the training task.

SimpleSFTConfig dataclass

A simplified configuration for supervised fine-tuning (SFT) that works for both single-dataset and mixture training.

Attributes
resources instance-attribute
resources: ResourceConfig
train_batch_size class-attribute instance-attribute
train_batch_size: int | IntSchedule = 128

The batch size for training. If an IntSchedule is provided, the batch size will be varied according to the schedule.

num_train_steps class-attribute instance-attribute
num_train_steps: int = 10000

Number of training steps.

learning_rate class-attribute instance-attribute
learning_rate: float = 5e-06

Learning rate for the optimizer.

tokenizer class-attribute instance-attribute
tokenizer: str | None = None

Tokenizer to use for training.

initialize_from_hf class-attribute instance-attribute
initialize_from_hf: str | None = None

HF model name or path to initialize from (e.g., 'meta-llama/Llama-3.1-8B'). Mutually exclusive with initialize_from_checkpoint_path.

initialize_from_checkpoint_path class-attribute instance-attribute
initialize_from_checkpoint_path: str | None = None

Path to a Levanter checkpoint to initialize from. Mutually exclusive with initialize_from_hf.

max_seq_len class-attribute instance-attribute
max_seq_len: int = 4096

Maximum sequence length for training.

weight_decay class-attribute instance-attribute
weight_decay: float = 0.0

Weight decay for the optimizer.

beta1 class-attribute instance-attribute
beta1: float | None = None

AdamW optimizer beta1.

beta2 class-attribute instance-attribute
beta2: float | None = None

AdamW optimizer beta2.

warmup class-attribute instance-attribute
warmup: float = 0.03

Fraction of training steps to use for learning rate warmup.

decay class-attribute instance-attribute
decay: float = 0.0

Fraction of training steps to use for learning rate decay.

lr_schedule class-attribute instance-attribute
lr_schedule: str = 'linear'

Learning rate schedule to use: 'linear', 'cosine', etc.

min_lr_ratio class-attribute instance-attribute
min_lr_ratio: float = 0.0

Minimum learning rate as a ratio of the base learning rate.
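
One way to read warmup, decay, lr_schedule, and min_lr_ratio together is sketched below: a linear ramp over the first warmup fraction of steps, a constant phase, then a linear or cosine decay over the last decay fraction of steps down to min_lr_ratio times the base rate. This interpretation and the helper `lr_at_step` are assumptions for illustration; Levanter's actual schedule (including how decay=0.0 is treated) may differ.

```python
import math


def lr_at_step(
    step: int,
    num_train_steps: int,
    learning_rate: float,
    warmup: float = 0.03,
    decay: float = 0.0,
    lr_schedule: str = "linear",
    min_lr_ratio: float = 0.0,
) -> float:
    """Hypothetical schedule: warmup, constant phase, then decay to min_lr_ratio * learning_rate."""
    warmup_steps = round(warmup * num_train_steps)  # round() avoids float truncation surprises
    decay_steps = round(decay * num_train_steps)
    min_lr = min_lr_ratio * learning_rate
    if warmup_steps > 0 and step < warmup_steps:
        return learning_rate * (step + 1) / warmup_steps  # linear ramp up to the base rate
    decay_start = num_train_steps - decay_steps
    if decay_steps == 0 or step < decay_start:
        return learning_rate  # constant phase
    frac = (step - decay_start) / decay_steps  # goes from 0 toward 1 across the decay window
    if lr_schedule == "linear":
        scale = 1.0 - frac
    elif lr_schedule == "cosine":
        scale = 0.5 * (1.0 + math.cos(math.pi * frac))
    else:
        raise ValueError(f"unsupported lr_schedule: {lr_schedule}")
    return min_lr + (learning_rate - min_lr) * scale


# 1000 steps, 10% warmup, decay over the last 50% down to 10% of the base rate.
print(lr_at_step(750, 1000, 1e-3, warmup=0.1, decay=0.5, min_lr_ratio=0.1))
```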

max_grad_norm class-attribute instance-attribute
max_grad_norm: float | None = None

Maximum gradient norm for gradient clipping.

steps_per_eval class-attribute instance-attribute
steps_per_eval: int = 1000

How often to run validation losses.

steps_per_checkpoint class-attribute instance-attribute
steps_per_checkpoint: int | None = None

How often to keep a permanent checkpoint. None (default) keeps only the final checkpoint; rolling temporary checkpoints are still written for resumption.

steps_per_hf_export class-attribute instance-attribute
steps_per_hf_export: int = 500

How often to save HuggingFace checkpoints.

hf_generation_eos_token_ids class-attribute instance-attribute
hf_generation_eos_token_ids: list[int] | None = None

EOS token IDs to write to generation_config.json. None means no generation config. For chat models, include the turn-boundary token (e.g. [128001, 128009]).

mixture_block_size class-attribute instance-attribute
mixture_block_size: int = 2048

Block size for dataset mixing (only used with mixture training).

stop_strategy class-attribute instance-attribute
stop_strategy: str = 'restart'

Strategy for handling dataset completion (only used with mixture training). Options: 'restart' or 'exit'.
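
The difference between the two stop strategies can be illustrated with a toy mixer. The real mixture samples datasets by weight in blocks of mixture_block_size; this sketch replaces that with a fixed interleaving `pattern`, and it assumes 'exit' stops the whole mixture as soon as any dataset runs out. The function `mix` and its parameters are hypothetical.

```python
from collections.abc import Iterator, Sequence


def mix(
    datasets: Sequence[Sequence[str]],
    pattern: Sequence[int],
    stop_strategy: str = "restart",
    max_items: int = 10,
) -> Iterator[str]:
    """Hypothetical mixer: interleave datasets per `pattern`, handling exhaustion per `stop_strategy`."""
    if stop_strategy not in ("restart", "exit"):
        raise ValueError("stop_strategy must be 'restart' or 'exit'")
    if not pattern or any(len(d) == 0 for d in datasets):
        raise ValueError("pattern and every dataset must be non-empty")
    iters = [iter(d) for d in datasets]
    emitted = 0
    while True:
        for idx in pattern:
            if emitted >= max_items:
                return  # cap output, since 'restart' would otherwise never end
            try:
                item = next(iters[idx])
            except StopIteration:
                if stop_strategy == "exit":
                    return  # stop the whole mixture when any dataset runs out
                iters[idx] = iter(datasets[idx])  # start this dataset over
                item = next(iters[idx])
            yield item
            emitted += 1


a, b = ["a1", "a2", "a3"], ["b1"]
print(list(mix([a, b], [0, 1], "restart", max_items=6)))  # ['a1', 'b1', 'a2', 'b1', 'a3', 'b1']
print(list(mix([a, b], [0, 1], "exit")))                  # ['a1', 'b1', 'a2']
```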

seed class-attribute instance-attribute
seed: int = 0

Random seed for training.

node_count class-attribute instance-attribute
node_count: int = 1

Number of TPU slices for training.

int8 class-attribute instance-attribute
int8: bool = False

If True, enables int8 (quantized) training in Levanter.

pad_tokenizer_to_match_model class-attribute instance-attribute
pad_tokenizer_to_match_model: bool = False

If True, pad the tokenizer's vocab to match the model's vocab size by adding dummy tokens. Useful when the model checkpoint has a larger vocab than the tokenizer (e.g., Qwen models pad their vocab to be divisible by 4 for TPU efficiency).

z_loss_weight class-attribute instance-attribute
z_loss_weight: float = 0.0

per_device_parallelism class-attribute instance-attribute
per_device_parallelism: int = -1

How many examples to process in parallel on each device. -1 (default) means train_batch_size/num_devices (no gradient accumulation). Set to a positive value to enable gradient accumulation. For example, with 8 devices, batch_size=32, and per_device_parallelism=1, you get gradient accumulation of 4.

reinit_tokens class-attribute instance-attribute
reinit_tokens: list[str] | bool = False

If set to a list of tokens, reinitializes the embeddings for those tokens. If True, reinitializes the default set of tokens for the Llama 3 tokenizer.