Training Configuration¶
This is the API documentation for the training configuration classes.
SimpleTrainConfig
dataclass
¶
Attributes¶
train_batch_size
instance-attribute
¶
train_batch_size: int | IntSchedule
The batch size for training. If an IntSchedule is provided, the batch size will be varied according to the schedule.
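As a rough illustration of a scheduled batch size, the sketch below looks up the value in effect at a given step. The `(start_step, value)` pair representation and the `batch_size_at` helper are assumptions made for this example; they are not the library's `IntSchedule` type.

```python
def batch_size_at(step, train_batch_size):
    """Return the batch size in effect at `step`.

    `train_batch_size` is either a plain int or a list of
    (start_step, value) pairs sorted by start_step.
    The pair format is hypothetical, chosen only for illustration.
    """
    if isinstance(train_batch_size, int):
        # A plain int means a constant batch size.
        return train_batch_size
    current = train_batch_size[0][1]
    for start_step, value in train_batch_size:
        # Keep the value of the latest segment that has already started.
        if step >= start_step:
            current = value
    return current


# Hypothetical schedule: 64 examples per step until step 1000, then 128.
schedule = [(0, 64), (1000, 128)]
early = batch_size_at(999, schedule)
late = batch_size_at(1000, schedule)
```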
rewarmup
class-attribute
instance-attribute
¶
rewarmup: float | None = None
Used to re-warm up the learning rate after a decay cycle.
ema_beta
class-attribute
instance-attribute
¶
ema_beta: float | None = None
Exponential moving average (EMA) beta.
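For readers unfamiliar with the term, the following is a minimal sketch of an EMA update over a dictionary of scalar parameters. It assumes the common convention in which beta weights the previous average; the `ema_update` helper is hypothetical and is not the library's implementation.

```python
def ema_update(ema_params, params, beta):
    """One EMA step: new_ema = beta * old_ema + (1 - beta) * current.

    Assumes beta weights the previous average (a common convention,
    not confirmed by this page).
    """
    return {
        name: beta * ema_params[name] + (1.0 - beta) * params[name]
        for name in params
    }


ema = {"w": 1.0}
ema = ema_update(ema, {"w": 0.0}, beta=0.9)
```

With a beta close to 1, the average changes slowly and smooths out step-to-step noise in the parameters.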
skip_bad_steps
class-attribute
instance-attribute
¶
skip_bad_steps: bool = False
If True, skips steps where the loss or gradient norm is significantly higher than the historical mean.
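One way to picture this kind of filter is sketched below. The `BadStepFilter` class, the three-sigma threshold, and the minimum history length are all assumptions for illustration; the actual criterion used by the library is not specified on this page.

```python
import statistics
from dataclasses import dataclass, field


@dataclass
class BadStepFilter:
    """Flags a step whose loss is far above the running mean (hypothetical)."""

    threshold_sigmas: float = 3.0  # assumed threshold, not from the docs
    min_history: int = 5           # assumed warm-up before filtering starts
    history: list = field(default_factory=list)

    def should_skip(self, loss: float) -> bool:
        if len(self.history) >= self.min_history:
            mean = statistics.fmean(self.history)
            std = statistics.pstdev(self.history)
            # Guard against a zero standard deviation on flat histories.
            if loss > mean + self.threshold_sigmas * max(std, 1e-8):
                return True  # skipped steps are not added to the history
        self.history.append(loss)
        return False


bad_step_filter = BadStepFilter()
decisions = [bad_step_filter.should_skip(x) for x in [1.0, 1.1, 0.9, 1.0, 1.05, 50.0, 1.0]]
```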
steps_per_eval
class-attribute
instance-attribute
¶
steps_per_eval: int | None = None
How often to compute validation losses, in steps.
steps_per_export
class-attribute
instance-attribute
¶
steps_per_export: int | None = None
How often to keep a permanent checkpoint. None (default) keeps only the final checkpoint; rolling temporary checkpoints are still written for resumption.
steps_per_task_eval
class-attribute
instance-attribute
¶
steps_per_task_eval: int | None = None
How often to run task evaluations, in steps.
steps_per_hf_export
class-attribute
instance-attribute
¶
steps_per_hf_export: int | None = None
How often to save HuggingFace checkpoints. None means match steps_per_export; -1 disables HuggingFace export.
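The two sentinel values can be read as the following resolution rule. The `resolve_hf_export_interval` helper is hypothetical and only restates the behavior described above.

```python
def resolve_hf_export_interval(steps_per_hf_export, steps_per_export):
    """Return the effective HF export interval, or None if export is disabled."""
    if steps_per_hf_export is None:
        # None: follow the permanent-checkpoint interval.
        return steps_per_export
    if steps_per_hf_export == -1:
        # -1: HF export is disabled.
        return None
    return steps_per_hf_export


follows_export = resolve_hf_export_interval(None, 1000)
disabled = resolve_hf_export_interval(-1, 1000)
explicit = resolve_hf_export_interval(500, 1000)
```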
hf_generation_eos_token_ids
class-attribute
instance-attribute
¶
EOS token IDs to write to generation_config.json. None means no generation config.
per_device_parallelism
class-attribute
instance-attribute
¶
per_device_parallelism: int = -1
How many examples to process in parallel on each device. -1 (default) means train_batch_size/num_devices (no gradient accumulation). Set to a positive value to enable gradient accumulation.
per_device_eval_parallelism
class-attribute
instance-attribute
¶
per_device_eval_parallelism: int | None = None
Number of examples to evaluate in parallel on each device.
max_eval_batches
class-attribute
instance-attribute
¶
max_eval_batches: int | None = None
Maximum number of batches to evaluate on. None means all batches.
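Conceptually, this cap behaves like truncating the evaluation iterator. The sketch below uses `itertools.islice` to show the idea; the `limit_eval_batches` helper is hypothetical.

```python
from itertools import islice


def limit_eval_batches(batches, max_eval_batches=None):
    """Yield at most `max_eval_batches` batches; None means no limit."""
    # islice treats a stop value of None as "run until exhausted".
    return islice(batches, max_eval_batches)


capped = list(limit_eval_batches(range(10), 3))
uncapped = list(limit_eval_batches(range(4), None))
```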
initialize_from_checkpoint_path
class-attribute
instance-attribute
¶
initialize_from_checkpoint_path: str | None = None
If set, training starts from the checkpoint at this path. Otherwise, training starts from scratch.
initialize_from_hf
class-attribute
instance-attribute
¶
initialize_from_hf: str | None = None
If set, training starts from the HuggingFace (HF) model at this path. Otherwise, training starts from scratch.
reset_data_loader_on_init
class-attribute
instance-attribute
¶
reset_data_loader_on_init: bool = True
Pairs with initialize_from_checkpoint_path. If True, the data loader is reset so that it starts from step 0 when initializing from a checkpoint. Otherwise, it resumes from the step stored in the checkpoint.
allow_partial_checkpoint
class-attribute
instance-attribute
¶
allow_partial_checkpoint: bool = False
Allow loading partial checkpoints. This is useful, for example, when converting an existing training run to use EMA.
pad_tokenizer_to_match_model
class-attribute
instance-attribute
¶
pad_tokenizer_to_match_model: bool = False
If True, pad the tokenizer's vocab to match the model's vocab size by adding dummy tokens. Useful when the model checkpoint has a larger vocab than the tokenizer (e.g., Qwen models pad their vocab to be divisible by 4 for TPU efficiency).
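The padding idea can be sketched on a plain token-to-ID dictionary. The `pad_vocab` helper and the dummy token naming are assumptions for illustration, and the sketch assumes token IDs are contiguous starting at 0; a real tokenizer object would be extended through its own API.

```python
def pad_vocab(vocab, model_vocab_size):
    """Return a copy of `vocab` extended with dummy tokens up to `model_vocab_size`.

    Assumes token IDs are contiguous and start at 0.
    """
    if len(vocab) > model_vocab_size:
        raise ValueError("tokenizer vocab is larger than the model vocab")
    padded = dict(vocab)
    for token_id in range(len(vocab), model_vocab_size):
        # Hypothetical naming scheme for the filler tokens.
        padded[f"<dummy_token_{token_id}>"] = token_id
    return padded


padded_vocab = pad_vocab({"a": 0, "b": 1, "c": 2}, model_vocab_size=4)
```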
optimizer_config
class-attribute
instance-attribute
¶
optimizer_config: OptimizerConfig | None = None
Optimizer configuration to use. If not set, Adam will be used.
watch
class-attribute
instance-attribute
¶
watch: WatchConfig = field(default_factory=WatchConfig)
Config for watching gradients, parameters, etc. Default is to log norms of gradients and parameters.
profiler
class-attribute
instance-attribute
¶
profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
JAX profiler settings for training.
explicit_mesh_axes
class-attribute
instance-attribute
¶
explicit_mesh_axes: bool = False
If True, build the device mesh with AxisType.Explicit axes.
Required for models that call jax.sharding.reshard(..., PartitionSpec(...)).
SimpleSFTConfig
dataclass
¶
A simplified configuration for Supervised Fine-Tuning (SFT) that works for both single dataset and mixture training approaches.
Attributes¶
train_batch_size
class-attribute
instance-attribute
¶
train_batch_size: int | IntSchedule = 128
The batch size for training. If an IntSchedule is provided, the batch size will be varied according to the schedule.
num_train_steps
class-attribute
instance-attribute
¶
num_train_steps: int = 10000
Number of training steps.
learning_rate
class-attribute
instance-attribute
¶
learning_rate: float = 5e-06
Learning rate for the optimizer.
tokenizer
class-attribute
instance-attribute
¶
tokenizer: str | None = None
Tokenizer to use for training.
initialize_from_hf
class-attribute
instance-attribute
¶
initialize_from_hf: str | None = None
HF model name or path to initialize from (e.g., 'meta-llama/Llama-3.1-8B'). Mutually exclusive with initialize_from_checkpoint_path.
initialize_from_checkpoint_path
class-attribute
instance-attribute
¶
initialize_from_checkpoint_path: str | None = None
Path to a Levanter checkpoint to initialize from. Mutually exclusive with initialize_from_hf.
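The mutual exclusivity of the two initialization options could be enforced with a check along these lines. The `validate_init_source` helper is hypothetical; how the library actually reports the conflict is not described on this page.

```python
def validate_init_source(initialize_from_hf=None, initialize_from_checkpoint_path=None):
    """Raise if both initialization sources are set; return the chosen one otherwise."""
    if initialize_from_hf is not None and initialize_from_checkpoint_path is not None:
        raise ValueError(
            "initialize_from_hf and initialize_from_checkpoint_path are mutually exclusive"
        )
    # Returns None when neither is set.
    return initialize_from_hf or initialize_from_checkpoint_path


source = validate_init_source(initialize_from_hf="meta-llama/Llama-3.1-8B")
```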
max_seq_len
class-attribute
instance-attribute
¶
max_seq_len: int = 4096
Maximum sequence length for training.
weight_decay
class-attribute
instance-attribute
¶
weight_decay: float = 0.0
Weight decay for the optimizer.
warmup
class-attribute
instance-attribute
¶
warmup: float = 0.03
Fraction of training steps to use for learning rate warmup.
decay
class-attribute
instance-attribute
¶
decay: float = 0.0
Fraction of training steps to use for learning rate decay.
lr_schedule
class-attribute
instance-attribute
¶
lr_schedule: str = 'linear'
Learning rate schedule to use: 'linear', 'cosine', etc.
min_lr_ratio
class-attribute
instance-attribute
¶
min_lr_ratio: float = 0.0
Minimum learning rate as a ratio of the base learning rate.
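To show how warmup, decay, and min_lr_ratio might combine under a 'linear' schedule, here is a minimal sketch. It assumes that warmup ramps linearly from zero, that the decay phase occupies the final fraction of training, and that the rate is held constant in between; the `lr_at_step` helper is hypothetical and the library's exact schedule may differ.

```python
def lr_at_step(step, num_train_steps=10000, learning_rate=5e-6,
               warmup=0.03, decay=0.0, min_lr_ratio=0.0):
    """Learning rate at `step` for an assumed warmup / hold / linear-decay shape."""
    warmup_steps = round(warmup * num_train_steps)
    decay_steps = round(decay * num_train_steps)
    decay_start = num_train_steps - decay_steps
    min_lr = learning_rate * min_lr_ratio

    if step < warmup_steps:
        # Linear ramp from 0 up to the base learning rate.
        return learning_rate * step / warmup_steps
    if decay_steps == 0 or step < decay_start:
        # Constant phase between warmup and decay.
        return learning_rate
    # Linear interpolation from the base rate down to min_lr.
    progress = min((step - decay_start) / decay_steps, 1.0)
    return learning_rate + (min_lr - learning_rate) * progress


# With the defaults shown above, there is no decay phase after warmup.
mid_training_lr = lr_at_step(5000)
```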
max_grad_norm
class-attribute
instance-attribute
¶
max_grad_norm: float | None = None
Maximum gradient norm for gradient clipping.
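Gradient clipping by global norm can be sketched in plain Python as follows. The `clip_by_global_norm` helper operates on a flat list of floats for readability and is an illustration of the general technique, not the optimizer's actual code path; treating None as "no clipping" is an assumption.

```python
import math


def clip_by_global_norm(grads, max_grad_norm=None):
    """Scale `grads` so their L2 norm does not exceed `max_grad_norm`."""
    if max_grad_norm is None:
        # Assumed meaning of None: clipping is disabled.
        return list(grads)
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_grad_norm:
        return list(grads)
    scale = max_grad_norm / norm
    return [g * scale for g in grads]


clipped = clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0)
```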
steps_per_eval
class-attribute
instance-attribute
¶
steps_per_eval: int = 1000
How often to run validation losses.
steps_per_checkpoint
class-attribute
instance-attribute
¶
steps_per_checkpoint: int | None = None
How often to keep a permanent checkpoint. None (default) keeps only the final checkpoint; rolling temporary checkpoints are still written for resumption.
steps_per_hf_export
class-attribute
instance-attribute
¶
steps_per_hf_export: int = 500
How often to save HuggingFace checkpoints.
hf_generation_eos_token_ids
class-attribute
instance-attribute
¶
EOS token IDs to write to generation_config.json. None means no generation config. For chat models, include the turn-boundary token (e.g. [128001, 128009]).
mixture_block_size
class-attribute
instance-attribute
¶
mixture_block_size: int = 2048
Block size for dataset mixing (only used with mixture training).
stop_strategy
class-attribute
instance-attribute
¶
stop_strategy: str = 'restart'
Strategy for handling dataset completion (only used with mixture training). Options: 'restart' or 'exit'.
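A plausible reading of the two options is sketched below: 'restart' loops over an exhausted dataset, while 'exit' stops when it runs out. This interpretation and the `dataset_stream` helper are assumptions; the precise semantics in mixture training are not spelled out here.

```python
from itertools import cycle, islice


def dataset_stream(examples, stop_strategy="restart"):
    """Iterate over `examples` according to an assumed reading of stop_strategy."""
    if stop_strategy == "restart":
        # Start over from the beginning whenever the dataset is exhausted.
        return cycle(examples)
    if stop_strategy == "exit":
        # Stop once the dataset has been consumed.
        return iter(examples)
    raise ValueError(f"unknown stop_strategy: {stop_strategy!r}")


restarted = list(islice(dataset_stream([1, 2], "restart"), 5))
exited = list(dataset_stream([1, 2], "exit"))
```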
node_count
class-attribute
instance-attribute
¶
node_count: int = 1
Number of TPU slices for training.
pad_tokenizer_to_match_model
class-attribute
instance-attribute
¶
pad_tokenizer_to_match_model: bool = False
If True, pad the tokenizer's vocab to match the model's vocab size by adding dummy tokens. Useful when the model checkpoint has a larger vocab than the tokenizer (e.g., Qwen models pad their vocab to be divisible by 4 for TPU efficiency).
per_device_parallelism
class-attribute
instance-attribute
¶
per_device_parallelism: int = -1
How many examples to process in parallel on each device. -1 (default) means train_batch_size/num_devices (no gradient accumulation). Set to a positive value to enable gradient accumulation. For example, with 8 devices, train_batch_size=32, and per_device_parallelism=1, each optimizer step accumulates gradients over 4 microbatches.
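The arithmetic behind the example with 8 devices can be written out as a small helper. The `grad_accum_steps` function is hypothetical, and the divisibility check is an assumption about how invalid combinations would be handled.

```python
def grad_accum_steps(train_batch_size, num_devices, per_device_parallelism=-1):
    """Number of microbatches accumulated per optimizer step."""
    if per_device_parallelism == -1:
        # Each device processes train_batch_size / num_devices examples at once.
        return 1
    microbatch_size = num_devices * per_device_parallelism
    if train_batch_size % microbatch_size != 0:
        raise ValueError("train_batch_size must be divisible by the microbatch size")
    return train_batch_size // microbatch_size


steps = grad_accum_steps(train_batch_size=32, num_devices=8, per_device_parallelism=1)
```

Lower values of per_device_parallelism trade throughput for a smaller memory footprint per device, since fewer examples are resident at once.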