"""
|
|
2025.6.1
|
|
2025.6.2
|
|
4.52.4
|
|
0.18.2
|
|
__UNSLOTH_VERSIONING__
|
|
"""
|
|
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.sft_trainer import (
    Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable,
    ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling,
    DataCollatorWithFlattening, Dataset, EvalPrediction, FeatureExtractionMixin,
    IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel,
    PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer,
    TrainerCallback, TrainingArguments, Union, contextlib, dataclass, dataclasses,
    defaultdict, generate_model_card, get_act_offloading_ctx_manager,
    get_comet_experiment_url, get_peft_model, is_peft_available, is_wandb_available,
    nn, os, pad, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training,
    torch, transformers, version, warnings,
)

import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

# `wandb` is referenced in `create_model_card` below; import it only when available.
if is_wandb_available():
    import wandb

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # Loop to reduce peak memory consumption (kept for reference):
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim = -1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
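

# Hedged usage sketch (not part of the generated module): `selective_log_softmax`
# returns per-token log-probabilities and should match gathering from `F.log_softmax`
# directly. The tensor shapes below are illustrative assumptions only.
def _example_selective_log_softmax():
    logits = torch.randn(2, 5, 32)          # (batch, seq_len, vocab_size)
    index  = torch.randint(0, 32, (2, 5))   # token ids whose log-probs we want
    per_token_logps = selective_log_softmax(logits, index)
    reference = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(per_token_logps, reference, atol = 1e-5)
    return per_token_logps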


@dataclass
class UnslothSFTConfig(SFTConfig):
    """
    Configuration class for the [`SFTTrainer`].

    Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`SFTTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        dataset_text_field (`str`, *optional*, defaults to `"text"`):
            Name of the column that contains text data in the dataset.
        dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
            `skip_prepare_dataset`.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        eos_token (`str` or `None`, *optional*, defaults to `None`):
            Token used to indicate the end of a turn or sequence. If `None`, it defaults to
            `processing_class.eos_token`.
        pad_token (`int` or `None`, *optional*, defaults to `None`):
            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also
            `None`, it falls back to `processing_class.eos_token`.
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the
            right. If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
        packing (`bool`, *optional*, defaults to `False`):
            Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define the sequence
            length.
        padding_free (`bool`, *optional*, defaults to `False`):
            Whether to perform forward passes without padding by flattening all sequences in the batch into a single
            continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
            supported with the `flash_attention_2` attention implementation, which can efficiently handle the
            flattened batch structure.
        pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
            If set, the sequences will be padded to a multiple of this value.
        eval_packing (`bool` or `None`, *optional*, defaults to `None`):
            Whether to pack the eval dataset. If `None`, uses the same value as `packing`.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `2e-5`):
            Initial learning rate for the [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        completion_only_loss (`bool` or `None`, *optional*, defaults to `None`):
            Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed
            only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If
            `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the
            dataset: loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on
            the full sequence for [language modeling](#language-modeling) datasets.
        activation_offloading (`bool`, *optional*, defaults to `False`):
            Whether to offload the activations to the CPU.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None, overwrite_output_dir = None,
        do_train = False, do_eval = False, do_predict = False,
        eval_strategy = 'no', prediction_loss_only = False,
        per_device_train_batch_size = 4, per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None, per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2, eval_accumulation_steps = 2, eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05, weight_decay = 0.01,
        adam_beta1 = 0.9, adam_beta2 = 0.999, adam_epsilon = 1e-08,
        max_grad_norm = 1.0, num_train_epochs = 3.0, max_steps = -1,
        lr_scheduler_type = 'linear', warmup_ratio = 0.1, warmup_steps = 0,
        log_level = 'passive', log_level_replica = 'warning', log_on_each_node = True,
        logging_dir = None, logging_strategy = 'steps', logging_first_step = False,
        logging_steps = 1, logging_nan_inf_filter = False,
        save_strategy = 'steps', save_steps = 500, save_total_limit = None,
        save_safetensors = True, save_on_each_node = False, save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False, use_cpu = False, use_mps_device = False,
        seed = 3407, data_seed = 3407,
        jit_mode_eval = False, use_ipex = False,
        bf16 = False, fp16 = False, fp16_opt_level = 'O1', half_precision_backend = 'auto',
        bf16_full_eval = False, fp16_full_eval = False, tf32 = None,
        local_rank = -1, ddp_backend = None,
        tpu_num_cores = None, tpu_metrics_debug = False, debug = '',
        dataloader_drop_last = False, eval_steps = None, dataloader_num_workers = 0,
        dataloader_prefetch_factor = None, past_index = -1,
        run_name = None, disable_tqdm = None,
        remove_unused_columns = True, label_names = None,
        load_best_model_at_end = False, metric_for_best_model = None, greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '', fsdp_min_num_params = 0, fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None, deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit', optim_args = None, adafactor = False,
        group_by_length = False, length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None, ddp_bucket_cap_mb = None, ddp_broadcast_buffers = None,
        dataloader_pin_memory = True, dataloader_persistent_workers = False,
        skip_memory_metrics = True, use_legacy_prediction_loop = False,
        push_to_hub = False, resume_from_checkpoint = None,
        hub_model_id = None, hub_strategy = 'every_save', hub_token = None,
        hub_private_repo = None, hub_always_push = False,
        gradient_checkpointing = False, gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False, eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None, push_to_hub_organization = None, push_to_hub_token = None,
        mp_parameters = '', auto_find_batch_size = False, full_determinism = False,
        torchdynamo = None, ray_scope = 'last', ddp_timeout = 1800,
        torch_compile = False, torch_compile_backend = None, torch_compile_mode = None,
        include_tokens_per_second = False, include_num_input_tokens_seen = False,
        neftune_noise_alpha = None, optim_target_modules = None,
        batch_eval_metrics = False, eval_on_start = False,
        use_liger_kernel = False, eval_use_gather_object = False,
        average_tokens_across_devices = False,
        model_init_kwargs = None,
        dataset_text_field = 'text', dataset_kwargs = None, dataset_num_proc = None,
        eos_token = None, pad_token = None,
        max_length = 1024, packing = False, padding_free = False,
        pad_to_multiple_of = None, eval_packing = None,
        completion_only_loss = None, activation_offloading = False,
        max_seq_length = None,
        vllm_sampling_params = None, unsloth_num_chunks = -1,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()
        super().__init__(
            output_dir = output_dir, overwrite_output_dir = overwrite_output_dir,
            do_train = do_train, do_eval = do_eval, do_predict = do_predict,
            eval_strategy = eval_strategy, prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps, eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate, weight_decay = weight_decay,
            adam_beta1 = adam_beta1, adam_beta2 = adam_beta2, adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm, num_train_epochs = num_train_epochs, max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type, warmup_ratio = warmup_ratio, warmup_steps = warmup_steps,
            log_level = log_level, log_level_replica = log_level_replica, log_on_each_node = log_on_each_node,
            logging_dir = logging_dir, logging_strategy = logging_strategy,
            logging_first_step = logging_first_step, logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy, save_steps = save_steps, save_total_limit = save_total_limit,
            save_safetensors = save_safetensors, save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda, use_cpu = use_cpu, use_mps_device = use_mps_device,
            seed = seed, data_seed = data_seed,
            jit_mode_eval = jit_mode_eval, use_ipex = use_ipex,
            bf16 = bf16, fp16 = fp16, fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval, fp16_full_eval = fp16_full_eval, tf32 = tf32,
            local_rank = local_rank, ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores, tpu_metrics_debug = tpu_metrics_debug, debug = debug,
            dataloader_drop_last = dataloader_drop_last, eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index, run_name = run_name, disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns, label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model, greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp, fsdp_min_num_params = fsdp_min_num_params, fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config, deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim, optim_args = optim_args, adafactor = adafactor,
            group_by_length = group_by_length, length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb, ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub, resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id, hub_strategy = hub_strategy, hub_token = hub_token,
            hub_private_repo = hub_private_repo, hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters, auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism, torchdynamo = torchdynamo,
            ray_scope = ray_scope, ddp_timeout = ddp_timeout,
            torch_compile = torch_compile, torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics, eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel, eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            dataset_text_field = dataset_text_field, dataset_kwargs = dataset_kwargs,
            dataset_num_proc = dataset_num_proc,
            eos_token = eos_token, pad_token = pad_token,
            max_length = max_length, packing = packing, padding_free = padding_free,
            pad_to_multiple_of = pad_to_multiple_of, eval_packing = eval_packing,
            completion_only_loss = completion_only_loss,
            activation_offloading = activation_offloading,
            max_seq_length = max_seq_length,
            **kwargs,
        )
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
    pass
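

# Hedged usage sketch (not part of the generated module): the config is constructed
# exactly like `trl.SFTConfig`; the extra Unsloth fields are optional. All values
# below are illustrative assumptions, not recommended settings.
def _example_unsloth_sft_config():
    return UnslothSFTConfig(
        output_dir = "outputs",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        max_length = 2048,
        learning_rate = 2e-4,
        unsloth_num_chunks = -1,  # -1 lets Unsloth choose the chunking
    )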


class _UnslothSFTTrainer(Trainer):
    """Internal SFT trainer: TRL's `SFTTrainer` logic with Unsloth's patches applied."""

    _tag_names = ["trl", "sft"]

    def __init__(
        self,
        model: Union[str, nn.Module, PreTrainedModel],
        args: Optional[Union[SFTConfig, TrainingArguments]] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_loss_func: Optional[Callable] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
    ):
        # Args
        model_id = model if isinstance(model, str) else model.config._name_or_path
        if args is None:
            model_name = model_id.split("/")[-1]
            args = SFTConfig(f"{model_name}-SFT")
        elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
            dict_args = args.to_dict()
            dict_args["hub_token"] = args.hub_token  # to_dict hides the hub_token
            dict_args.pop("push_to_hub_token")
            args = SFTConfig(**dict_args)

        # Handle the tokenizer
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model_id)

        if args.eos_token is not None:
            eos_token = args.eos_token
            eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
            if eos_token_id is None:
                raise ValueError(
                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
                    "in the vocabulary before using it as an EOS token."
                )
            processing_class.eos_token_id = eos_token_id

        # Model
        if args.model_init_kwargs is not None and not isinstance(model, str):
            warnings.warn(
                "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
                "The `model_init_kwargs` will be ignored."
            )
        if isinstance(model, str):
            model = self._create_model_from_path(model, args)

        # PEFT configuration and model wrapping
        # (intentionally disabled here: Unsloth applies PEFT/LoRA adapters itself before the trainer is built)
        if False:
            model = self._prepare_peft_model(model, peft_config, args)

        # Data collator
        if args.padding_free:
            if data_collator is not None:
                raise ValueError("Passing a custom data collator is not supported when using padding-free.")
            if args.packing:
                warnings.warn(
                    "You are passing `packing=True` and `padding_free=True`, which is not recommended. Please refer "
                    "to the documentation to understand why this is not recommended."
                )
            if model.config._attn_implementation != "flash_attention_2":
                warnings.warn(
                    "Padding-free training is enabled, but the attention implementation is not set to "
                    "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
                    "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
                    "other implementations may lead to unexpected behavior. To ensure compatibility, set "
                    "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
                    "attention mechanism can handle flattened sequences."
                )
            if args.per_device_train_batch_size == 1:
                warnings.warn(
                    "You are using a per_device_train_batch_size of 1 with padding-free training. A batch size "
                    "of 1 annihilates the benefits of padding-free training. Please consider increasing the batch "
                    "size to at least 2."
                )
            data_collator = DataCollatorWithFlattening()

        if args.completion_only_loss is None:
            first_example = next(iter(train_dataset))
            self.completion_only_loss = "prompt" in first_example
        else:
            self.completion_only_loss = args.completion_only_loss

        if data_collator is None:
            # Get the pad token: if not provided, use the one from the processing class or the eos token
            # if the processing class does not have a pad token.
            pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
            pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
            if pad_token_id is None:
                raise ValueError(
                    f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
                    "in the vocabulary before using it as a padding token."
                )
            data_collator = DataCollatorForLanguageModeling(
                pad_token_id, self.completion_only_loss, args.pad_to_multiple_of
            )

        # Dataset
        preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
        if preprocess_dataset:
            if self.completion_only_loss and formatting_func:
                raise ValueError(
                    "A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
                    "Using a formatter converts the dataset to a language modeling type, conflicting with "
                    "completion-only loss. To resolve this, apply your formatting function before passing the "
                    "dataset, or disable `completion_only_loss` in `SFTConfig`."
                )

            train_dataset = self._prepare_dataset(
                train_dataset, processing_class, args, args.packing, formatting_func, "train"
            )
            if eval_dataset is not None:
                packing = args.packing if args.eval_packing is None else args.eval_packing
                if isinstance(eval_dataset, dict):
                    eval_dataset = {
                        key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                        for key, dataset in eval_dataset.items()
                    }
                else:
                    eval_dataset = self._prepare_dataset(
                        eval_dataset, processing_class, args, packing, formatting_func, "eval"
                    )

        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0

        # Initialize the Trainer. Parent class will handle:
        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
        # - FSDP setup
        # - Distributed training setup
        # - Optimizer and scheduler creation
        # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
        super_init_kwargs = {}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
        else:
            if optimizer_cls_and_kwargs is not None:
                warnings.warn(
                    "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
                    "The default optimizer will be used. "
                    "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
                )
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_loss_func=compute_loss_func,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            **super_init_kwargs,
        )

        # Initialize activation offloading context
        if self.args.activation_offloading:
            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
        else:
            self.maybe_activation_offload_context = contextlib.nullcontext()

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
        """Creates a model from a path or model identifier."""
        model_init_kwargs = args.model_init_kwargs or {}
        # Handle torch dtype
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
            pass  # torch_dtype is already a torch.dtype or "auto" or None
        elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
            torch_dtype = getattr(torch, torch_dtype)
            model_init_kwargs["torch_dtype"] = torch_dtype
        else:
            raise ValueError(
                "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
                f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
            )
        # Disable caching if gradient checkpointing is enabled (not supported)
        # if args.gradient_checkpointing:
        #     model_init_kwargs["use_cache"] = False

        # Create model
        model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return model

    def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
        """Prepares a model for PEFT training."""
        if not is_peft_available():
            raise ImportError("To use PeftModel, you need to install the `peft` library.")

        if not isinstance(peft_config, PeftConfig):
            raise ValueError(
                f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
                "to pass a PeftConfig object to the SFTTrainer."
            )

        if isinstance(model, PeftModel):
            return model

        # Handle quantized models (QLoRA)
        is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)

        is_sharded_qlora = False
        if getattr(model, "is_loaded_in_4bit", False):
            # Check if model is sharded (FSDP/DS-Zero3)
            for _, param in model.named_parameters():
                if param.__class__.__name__ == "Params4bit":
                    is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
                    break

        # Prepare model for kbit training if needed
        if is_qlora and not is_sharded_qlora:
            model = self._prepare_model_for_kbit_training(model, args)
            # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
            args = dataclasses.replace(args, gradient_checkpointing=False)
        elif args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Create PEFT model
        if (
            version.parse(peft.__version__) >= version.parse("0.12")  # autocast_adapter_dtype introduced in 0.12
            and getattr(model, "is_loaded_in_4bit", False)
            and is_sharded_qlora
        ):
            model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
        else:
            model = get_peft_model(model, peft_config)

        # Handle bf16 casting for 4-bit models
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
            peft_module_casting_to_bf16(model)

        return model

    def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
        """Prepares a quantized model for kbit training."""
        prepare_model_kwargs = {
            "use_gradient_checkpointing": args.gradient_checkpointing,
            "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
        }

        return prepare_model_for_kbit_training(model, **prepare_model_kwargs)

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
        )

        if use_reentrant:
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        return model

    def _prepare_dataset(
        self,
        dataset: Union[Dataset, IterableDataset],
        processing_class,
        args,
        packing: bool,
        formatting_func: Optional[Callable[[dict], str]],
        dataset_name: str,
    ) -> Union[Dataset, IterableDataset]:
        # All Unsloth Zoo code licensed under LGPLv3
        if isinstance(dataset, ConstantLengthDataset): return dataset

        map_kwargs = {}
        use_desc = isinstance(dataset, Dataset)
        is_vlm = hasattr(processing_class, "tokenizer")
        tokenizer = processing_class
        if is_vlm: tokenizer = processing_class.tokenizer

        # Get max length
        max_seq_length = getattr(args, "max_length", 0)
        if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
        if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
        dataset_text_field = getattr(args, "dataset_text_field", "text")
        do_truncation = max_seq_length != 0
        do_formatting_func = False
        do_tokenize = True

        # Get correct column names
        column_names = set(next(iter(dataset)).keys())
        used_column_names = ["input_ids"]
        if "attention_mask" in column_names:
            used_column_names.append("attention_mask")

        # Check if already tokenized so we can skip
        from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
        if "labels" in column_names:
            # Most likely forgot data collator!
            if is_vlm and not hasattr(tokenizer, "pad"):
                # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
                raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
            self.data_collator = DataCollatorForSeq2Seq(tokenizer)
            used_column_names.append("labels")
            do_tokenize = False
        elif "input_ids" in column_names:
            # Skip dataset prep, and set data collator
            if is_vlm and not hasattr(tokenizer, "pad"):
                # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
                raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
            self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
            do_tokenize = False
        elif dataset_text_field not in column_names:
            do_formatting_func = True
            if formatting_func is None:
                raise RuntimeError("Unsloth: You must specify a `formatting_func`")
        pass

        if do_tokenize:
            # Check double BOS tokens
            if do_formatting_func:
                test_text = formatting_func(next(iter(dataset)))
                if not isinstance(test_text, list):
                    raise ValueError(
                        "Unsloth: The `formatting_func` should return a list of processed strings."
                    )
                test_text = test_text[0]
            else:
                test_text = next(iter(dataset))[dataset_text_field][0]

            # Get chat template
            chat_template = getattr(processing_class, 'chat_template', '')
            if chat_template == '' and is_vlm:
                chat_template = getattr(tokenizer, 'chat_template', '')
            if chat_template is None:
                chat_template = ''

            # Get bos_token
            add_special_tokens = True
            bos_token_1 = getattr(processing_class, 'bos_token', None)
            bos_token_2 = getattr(tokenizer, 'bos_token', None)
            bos_token = bos_token_1 or bos_token_2

            if bos_token is not None:
                if test_text.startswith(bos_token) or bos_token in chat_template:
                    add_special_tokens = False
                    print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
            pass

            # Create tokenize function
            def _tokenize(example):
                return tokenizer(
                    example[dataset_text_field] if not do_formatting_func else formatting_func(example),
                    truncation = do_truncation,
                    max_length = max_seq_length,
                    return_token_type_ids = False,
                    add_special_tokens = add_special_tokens,
                )
            pass

            if not isinstance(dataset, IterableDataset):
                map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
            else:
                map_kwargs["batch_size"] = dataset._ex_iterable.batch_size

            if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
            dataset = dataset.map(_tokenize, batched = True, **map_kwargs)

            # If VLM, switch data collator since .pad is needed!
            if is_vlm and not hasattr(processing_class, "pad"):
                data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
                self.data_collator = data_collator
            pass
        pass
        if packing:
            print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
            return dataset
            # NOTE: the packing path below is unreachable because of the early return above.
            if max_seq_length == 0:
                raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")

            if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
            dataset = dataset.select_columns(used_column_names).map(
                pack_examples,
                batched = True,
                fn_kwargs = {"seq_length": max_seq_length,},
                **map_kwargs,
            )
        pass
        return dataset

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
        # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
        # dataset. So we need to override the default signature columns to include "completion_mask" as well.
        if self._signature_columns is None:
            self._signature_columns = ["input_ids", "attention_mask", "completion_mask", "labels"]

    def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
        outputs = super().compute_loss(
            model,
            inputs,
            return_outputs = return_outputs,
            num_items_in_batch = num_items_in_batch,
        )
        return outputs

    # Override training step to add activation offloading context.
    def training_step(self, *args, **kwargs):
        with self.maybe_activation_offload_context:
            return super().training_step(*args, **kwargs)

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        mode = "train" if self.model.training else "eval"
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics[mode].clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="SFT",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
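

# Hedged illustration (not part of the generated module): `_prepare_dataset` above
# requires `formatting_func` to return a *list* of formatted strings when the text
# column is missing (it is called on batched examples). The field names used here
# ("question", "answer") are assumptions for the sketch only.
def _example_formatting_func(batch):
    return [
        f"### Instruction:\n{q}\n\n### Response:\n{a}"
        for q, a in zip(batch["question"], batch["answer"])
    ]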


class UnslothSFTTrainer(_UnslothSFTTrainer):
    """
    Trainer for the Supervised Fine-Tuning (SFT) method.

    This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.

    Example:

    ```python
    from datasets import load_dataset
    from trl import SFTTrainer

    dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

    trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
              a path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
              in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
        args ([`SFTConfig`], *optional*, defaults to `None`):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator (`DataCollator`, *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or
            `eval_dataset`. Will default to [`~transformers.default_data_collator`] if no `processing_class` is
            provided, an instance of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is
            a feature extractor or tokenizer.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
            [prompt-completion](#prompt-completion) type. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).

            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's
            name with [`~transformers.AutoTokenizer.from_pretrained`].
        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on
            your model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args`
            in `args`. Incompatible with the `optimizers` argument.

            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices
            before initializing the Trainer.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications
            made by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
        formatting_func (`Optional[Callable]`):
            Formatting function applied to the dataset before tokenization. Applying the formatting function
            explicitly converts the dataset into a [language modeling](#language-modeling) type.
    """

    def __init__(
        self,
        model,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_loss_func = None,
        compute_metrics = None,
        callbacks = None,
        optimizer_cls_and_kwargs = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        if args is None: args = UnslothSFTConfig()
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since the model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if 'max_length' not in locals() and not hasattr(args, 'max_length'):
            pass
        else:
            if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
                if hasattr(args, 'max_length'):
                    args.max_length = args.max_seq_length
                    max_length = args.max_length
            else:
                model_max_length = getattr(model, 'max_seq_length', None)
                # print(model_max_length, 'mml1')
                if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
                # print(model_max_length, 'mml2')
                if model_max_length is not None:
                    args.max_length = model_max_length
                    max_length = args.max_length
                elif hasattr(args, 'max_length') and args.max_length is not None:
                    max_length = args.max_length
                    # If we are here, we are in the odd case where max_length is set but max_seq_length is not
                    setattr(model, 'max_seq_length', max_length)
                else:
                    print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
                    args.max_length = 1024
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
        other_metrics = []
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('sft_trainer', other_metrics)
        IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
        from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
        from unsloth_zoo.training_utils import fix_zero_training_loss
        if 'tokenizer' not in locals(): tokenizer = processing_class
        fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
        fix_zero_training_loss(model, tokenizer, train_dataset)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_loss_func = compute_loss_func,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,
            **kwargs,
        )
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
    pass
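

# Hedged end-to-end sketch (not part of the generated module): wiring the wrappers
# defined above together. The model name, dataset and hyperparameters are illustrative
# assumptions; `FastLanguageModel` is the usual Unsloth entry point for loading models.
def _example_unsloth_sft_run():
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    model, tokenizer = FastLanguageModel.from_pretrained(
        "unsloth/Qwen2-0.5B-Instruct",  # hypothetical choice of model
        max_seq_length = 1024,
        load_in_4bit = True,
    )
    dataset = load_dataset("roneneldan/TinyStories", split = "train[:1%]")
    trainer = UnslothSFTTrainer(
        model = model,
        processing_class = tokenizer,
        args = UnslothSFTConfig(output_dir = "outputs", max_steps = 10),
        train_dataset = dataset,
    )
    trainer.train()
    return trainer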